From 697655bd6f5498229c207ae8c3b477cba0a5b041 Mon Sep 17 00:00:00 2001 From: Aziz Zayed Date: Tue, 15 Jun 2021 08:38:22 -0700 Subject: [PATCH] Squashed commit of the following: commit 0092f8e30971b6c1ea01925fdb68a34fcff356cf Author: Aziz Zayed Date: Tue Jun 15 08:22:51 2021 -0700 Fixed truncated-normal bug commit a6ded8c776b0f454328a85aa6d88f4e798519c50 Author: Aziz Zayed Date: Mon Jun 14 13:33:30 2021 -0700 [pytorch] Add BigGAN demo commit f14561449bb046ebd4d51fed7d2e5387a414cf5d Merge: a8a1a9b68 ec8405b9e Author: Abd-El-Aziz Zayed <48853777+AzizZayed@users.noreply.github.com> Date: Fri Jun 11 20:45:34 2021 -0700 Merge branch 'deepjavalibrary:master' into master commit ec8405b9e52ca4474ee17f53d98832204fda1dd3 Author: Abd-El-Aziz Zayed <48853777+AzizZayed@users.noreply.github.com> Date: Fri Jun 11 14:53:59 2021 -0700 [pytorch] Add oneHot operator (#1014) [tensoflow] Add truncated normal operation commit 50600fd3162cc648dfc6bc07cf3baa36404a05f8 Author: Frank Liu Date: Fri Jun 11 14:53:43 2021 -0700 upgrade dependencies version (#1012) Change-Id: I709938f69f21096bc5cd29a24191f0f282dcbc97 commit 3379fd2eac9da0e1a1f0b9253353a32c89061891 Author: Frank Liu Date: Fri Jun 11 14:53:29 2021 -0700 [serving] Fix flaky test (#1013) Change-Id: I13b89e04516c59a3d28ecafd49f4f808630b22fb commit 23157fd9de3dd70c63fbe1ecd859f978f1d796c6 Author: Frank Liu Date: Thu Jun 10 16:31:03 2021 -0700 Enable spotbugs for java 11+ (#1010) Change-Id: I74effbf45492a5cf50e09ba8af0223d2b1bcb5a5 commit 4f3870853ef11c93875a300462f0389c460895b0 Author: Frank Liu Date: Thu Jun 10 16:30:50 2021 -0700 Fix model zoo test typo (#1009) Change-Id: I7c0109c6e5fc0ece16288082fd830718f20ad489 commit a8a1a9b6822d3403b07b8be6130682e770c9ce71 Merge: 77809f49e 30b03f4fe Author: Aziz Zayed Date: Thu Jun 10 15:16:05 2021 -0700 Merge Truncated-Normal branch commit 77809f49ed05c4f6822613be044c8c604ad8b134 Author: Frank Liu Date: Thu Jun 10 14:07:43 2021 -0700 Make model zoo test weekly (#1004) Change-Id: I1c73df17cb077b9ce8905fcc2fc8bbb37b9688d8 commit 0aec8cad5aa5623a58217b2495c7e5d72205f582 Author: Abd-El-Aziz Zayed <48853777+AzizZayed@users.noreply.github.com> Date: Thu Jun 10 12:46:16 2021 -0700 [tensoflow] Add truncated normal operation (#1005) commit 30b03f4fee78f0558c5a1577c00d00dfbee85425 Author: Aziz Zayed Date: Wed Jun 9 01:40:33 2021 -0700 [tensoflow] Add truncated normal operation commit d8e7e1d368ff6d32430bf7ff1bff0f21af3e50aa Author: Frank Liu Date: Wed Jun 9 07:55:15 2021 -0700 Fixes #999, hanlde UTF16 surrogate charactors properly. (#1003) Change-Id: I19e77cf5a8282bea901434041806eb102549ec0f commit b0fe73a97ea4310238c0d9a4d3db618b39b93a36 Author: Frank Liu Date: Tue Jun 8 17:56:19 2021 -0700 [pytorch] Update load model jupyter notebook (#1002) Change-Id: I1889aa93d2002e6ce02c740d2d1d3517bf586760 commit 828693029c0463828d9fae714c70514f9c3bf50b Author: Frank Liu Date: Tue Jun 8 15:29:27 2021 -0700 [tensorflow] fix optOption usage document (#1001) Change-Id: Ie044839cf082d63010a5c26d3f2f8833447919c6 commit a26f5b2e4bd24cbda1ab2e6aba8435364cd139b9 Author: Abd-El-Aziz Zayed <48853777+AzizZayed@users.noreply.github.com> Date: Tue Jun 8 15:29:10 2021 -0700 Updated PyTorch Docs (#1000) * Added auto softmax metadata for action_recognition * Update PyTorch Docs commit e6890f9e9f63b986dc527a9d1a7175354106b5ce Author: Lanking Date: Mon Jun 7 18:25:19 2021 -0700 upgrade xgboost (#993) commit a0dcf3ac01da15ef3e540b6a0c6047c899ddf21c Author: Lanking Date: Mon Jun 7 18:25:12 2021 -0700 bump up onnx runtime version (#992) --- .../java/ai/djl/ndarray/BaseNDManager.java | 19 + api/src/main/java/ai/djl/ndarray/NDArray.java | 51 + .../main/java/ai/djl/ndarray/NDManager.java | 59 + api/src/main/native/djl/utils.h | 34 +- .../how_to_import_tensorflow_models_in_DJL.md | 2 +- examples/build.gradle | 2 +- .../inference/biggan/BigGANCategory.java | 98 ++ .../inference/biggan/BigGANInput.java | 76 ++ .../inference/biggan/BigGANTranslator.java | 93 ++ .../examples/inference/biggan/Generator.java | 97 ++ .../inference/biggan/package-info.java | 15 + examples/src/main/resources/categories.txt | 1000 +++++++++++++++++ .../ai/djl/sentencepiece/SpTokenizerTest.java | 15 + gradle.properties | 12 +- .../tests/model_zoo/ModelZooTest.java | 6 +- .../src/test/translator/MyTranslator.java | 2 +- jupyter/load_pytorch_model.ipynb | 33 +- .../ai/djl/onnxruntime/engine/OrtEngine.java | 2 +- paddlepaddle/paddlepaddle-native/build.gradle | 4 - .../java/ai/djl/pytorch/engine/PtNDArray.java | 12 + .../java/ai/djl/pytorch/jni/JniUtils.java | 8 + .../ai/djl/pytorch/jni/PyTorchLibrary.java | 2 + ...jl_pytorch_jni_PyTorchLibrary_inference.cc | 8 +- ...ytorch_jni_PyTorchLibrary_nn_functional.cc | 9 + .../java/ai/djl/serving/ModelServerTest.java | 32 +- .../ai/djl/tensorflow/engine/TfNDManager.java | 25 + .../djl/tensorflow/engine/TfSymbolBlock.java | 2 +- tensorflow/tensorflow-native/build.gradle | 4 - tflite/tflite-native/build.gradle | 4 - 29 files changed, 1657 insertions(+), 69 deletions(-) create mode 100644 examples/src/main/java/ai/djl/examples/inference/biggan/BigGANCategory.java create mode 100644 examples/src/main/java/ai/djl/examples/inference/biggan/BigGANInput.java create mode 100644 examples/src/main/java/ai/djl/examples/inference/biggan/BigGANTranslator.java create mode 100644 examples/src/main/java/ai/djl/examples/inference/biggan/Generator.java create mode 100644 examples/src/main/java/ai/djl/examples/inference/biggan/package-info.java create mode 100644 examples/src/main/resources/categories.txt diff --git a/api/src/main/java/ai/djl/ndarray/BaseNDManager.java b/api/src/main/java/ai/djl/ndarray/BaseNDManager.java index 64b52355710..8012cecb05c 100644 --- a/api/src/main/java/ai/djl/ndarray/BaseNDManager.java +++ b/api/src/main/java/ai/djl/ndarray/BaseNDManager.java @@ -16,6 +16,7 @@ import ai.djl.ndarray.types.DataType; import ai.djl.ndarray.types.Shape; import ai.djl.util.PairList; +import ai.djl.util.RandomUtils; import java.nio.Buffer; import java.nio.file.Path; import java.util.UUID; @@ -153,6 +154,24 @@ public NDArray randomNormal(float loc, float scale, Shape shape, DataType dataTy throw new UnsupportedOperationException("Not supported!"); } + /** {@inheritDoc} */ + @Override + public NDArray truncatedNormal(float loc, float scale, Shape shape, DataType dataType) { + int sampleSize = (int) shape.size(); + double[] dist = new double[sampleSize]; + + for (int i = 0; i < sampleSize; i++) { + double sample = RandomUtils.nextGaussian(); + while (sample < -2 || sample > 2) { + sample = RandomUtils.nextGaussian(); + } + + dist[i] = sample; + } + + return create(dist).muli(scale).addi(loc).reshape(shape).toType(dataType, false); + } + /** {@inheritDoc} */ @Override public NDArray randomMultinomial(int n, NDArray pValues) { diff --git a/api/src/main/java/ai/djl/ndarray/NDArray.java b/api/src/main/java/ai/djl/ndarray/NDArray.java index ebb3148b285..57cabbd2da5 100644 --- a/api/src/main/java/ai/djl/ndarray/NDArray.java +++ b/api/src/main/java/ai/djl/ndarray/NDArray.java @@ -4592,6 +4592,57 @@ default NDArray oneHot(int depth) { return oneHot(depth, 1f, 0f, DataType.FLOAT32); } + /** + * Returns a one-hot {@code NDArray}. + * + *
    + *
  • The locations represented by indices take value 1, while all other locations take value + * 0. + *
  • If the input {@code NDArray} is rank N, the output will have rank N+1. The new axis is + * appended at the end. + *
  • If {@code NDArray} is a scalar the output shape will be a vector of length depth. + *
  • If {@code NDArray} is a vector of length features, the output shape will be features x + * depth. + *
  • If {@code NDArray} is a matrix with shape [batch, features], the output shape will be + * batch x features x depth. + *
+ * + *

Examples + * + *

+     * jshell> NDArray array = manager.create(new int[] {1, 0, 2, 0});
+     * jshell> array.oneHot(3);
+     * ND: (4, 3) cpu() float32
+     * [[0., 1., 0.],
+     *  [1., 0., 0.],
+     *  [0., 0., 1.],
+     *  [1., 0., 0.],
+     * ]
+     * jshell> NDArray array = manager.create(new int[][] {{1, 0}, {1, 0}, {2, 0}});
+     * jshell> array.oneHot(3);
+     * ND: (3, 2, 3) cpu() float32
+     * [[[0., 1., 0.],
+     *   [1., 0., 0.],
+     *  ],
+     *  [[0., 1., 0.],
+     *   [1., 0., 0.],
+     *  ],
+     *  [[0., 0., 1.],
+     *   [1., 0., 0.],
+     *  ],
+     * ]
+     * 
+ * + * @param depth Depth of the one hot dimension. + * @param dataType dataType of the output. + * @return one-hot encoding of this {@code NDArray} + * @see Classification-problems + */ + default NDArray oneHot(int depth, DataType dataType) { + return oneHot(depth, 0f, 1f, dataType); + } + /** * Returns a one-hot {@code NDArray}. * diff --git a/api/src/main/java/ai/djl/ndarray/NDManager.java b/api/src/main/java/ai/djl/ndarray/NDManager.java index dae10ec8c99..826ee66d17f 100644 --- a/api/src/main/java/ai/djl/ndarray/NDManager.java +++ b/api/src/main/java/ai/djl/ndarray/NDManager.java @@ -1232,6 +1232,65 @@ default NDArray randomNormal( return newSubManager(device).randomNormal(loc, scale, shape, dataType); } + /** + * Draws random samples from a normal (Gaussian) distribution with mean 0 and standard deviation + * 1, discarding and re-drawing any samples that are more than two standard deviations from the + * mean. + * + *

Samples are distributed according to a normal distribution parametrized by mean = 0 and + * standard deviation = 1. + * + * @param shape the output {@link Shape} + * @return the drawn samples {@link NDArray} + */ + default NDArray truncatedNormal(Shape shape) { + return truncatedNormal(0f, 1f, shape, DataType.FLOAT32); + } + + /** + * Draws random samples from a normal (Gaussian) distribution with mean 0 and standard deviation + * 1, discarding and re-drawing any samples that are more than two standard deviations from the + * mean. + * + * @param shape the output {@link Shape} + * @param dataType the {@link DataType} of the {@link NDArray} + * @return the drawn samples {@link NDArray} + */ + default NDArray truncatedNormal(Shape shape, DataType dataType) { + return truncatedNormal(0.0f, 1.0f, shape, dataType); + } + + /** + * Draws random samples from a normal (Gaussian) distribution, discarding and re-drawing any + * samples that are more than two standard deviations from the mean. + * + * @param loc the mean (centre) of the distribution + * @param scale the standard deviation (spread or "width") of the distribution + * @param shape the output {@link Shape} + * @param dataType the {@link DataType} of the {@link NDArray} + * @return the drawn samples {@link NDArray} + */ + NDArray truncatedNormal(float loc, float scale, Shape shape, DataType dataType); + + /** + * Draws random samples from a normal (Gaussian) distribution, discarding and re-drawing any + * samples that are more than two standard deviations from the mean. + * + * @param loc the mean (centre) of the distribution + * @param scale the standard deviation (spread or "width") of the distribution + * @param shape the output {@link Shape} + * @param dataType the {@link DataType} of the {@link NDArray} + * @param device the {@link Device} of the {@link NDArray} + * @return the drawn samples {@link NDArray} + */ + default NDArray truncatedNormal( + float loc, float scale, Shape shape, DataType dataType, Device device) { + if (device == null || device.equals(getDevice())) { + return truncatedNormal(loc, scale, shape, dataType); + } + return newSubManager(device).truncatedNormal(loc, scale, shape, dataType); + } + /** * Draw samples from a multinomial distribution. * diff --git a/api/src/main/native/djl/utils.h b/api/src/main/native/djl/utils.h index edce9e54a9c..ba78b63843c 100644 --- a/api/src/main/native/djl/utils.h +++ b/api/src/main/native/djl/utils.h @@ -29,9 +29,21 @@ inline std::string GetStringFromJString(JNIEnv* env, jstring jstr) { if (jstr == nullptr) { return std::string(); } - const char* c_str = env->GetStringUTFChars(jstr, JNI_FALSE); - std::string str = std::string(c_str); - env->ReleaseStringUTFChars(jstr, c_str); + + // TODO: cache reflection to improve performance + const jclass string_class = env->GetObjectClass(jstr); + const jmethodID getbytes_method = env->GetMethodID(string_class, "getBytes", "(Ljava/lang/String;)[B"); + + const jstring charset = env->NewStringUTF("UTF-8"); + const jbyteArray jbytes = (jbyteArray) env->CallObjectMethod(jstr, getbytes_method, charset); + env->DeleteLocalRef(charset); + + const jsize length = env->GetArrayLength(jbytes); + jbyte* c_str = env->GetByteArrayElements(jbytes, NULL); + std::string str = std::string(reinterpret_cast(c_str), length); + + env->ReleaseByteArrayElements(jbytes, c_str, RELEASE_MODE); + env->DeleteLocalRef(jbytes); return str; } @@ -100,9 +112,23 @@ inline std::vector GetVecFromJStringArray(JNIEnv* env, jobjectArray // String[] inline jobjectArray GetStringArrayFromVec(JNIEnv* env, const std::vector &vec) { jobjectArray array = env->NewObjectArray(vec.size(), env->FindClass("Ljava/lang/String;"), nullptr); + + // TODO: cache reflection to improve performance + const jclass string_class = env->FindClass("java/lang/String"); + const jmethodID ctor = env->GetMethodID(string_class, "", "([BLjava/lang/String;)V"); + const jstring charset = env->NewStringUTF("UTF-8"); + for (int i = 0; i < vec.size(); ++i) { - env->SetObjectArrayElement(array, i, env->NewStringUTF(vec[i].c_str())); + const char* c_str = vec[i].c_str(); + int len = vec[i].length(); + auto jbytes = env->NewByteArray(len); + env->SetByteArrayRegion(jbytes, 0, len, reinterpret_cast(c_str)); + jobject jstr = env->NewObject(string_class, ctor, jbytes, charset); + env->DeleteLocalRef(jbytes); + env->SetObjectArrayElement(array, i, jstr); } + + env->DeleteLocalRef(charset); return array; } diff --git a/docs/tensorflow/how_to_import_tensorflow_models_in_DJL.md b/docs/tensorflow/how_to_import_tensorflow_models_in_DJL.md index a1e968e568d..8c373dfd4b4 100644 --- a/docs/tensorflow/how_to_import_tensorflow_models_in_DJL.md +++ b/docs/tensorflow/how_to_import_tensorflow_models_in_DJL.md @@ -121,7 +121,7 @@ Criteria criteria = .setTypes(Image.class, DetectedObjects.class) .optFilter("backbone", "mobilenet_v2") .optEngine("TensorFlow") - .optOption("Tags", new String[] {}) + .optOption("Tags", "") .optOption("SignatureDefKey", "default") .optProgress(new ProgressBar()) .build(); diff --git a/examples/build.gradle b/examples/build.gradle index 223e538b00e..a37b02691e6 100644 --- a/examples/build.gradle +++ b/examples/build.gradle @@ -44,7 +44,7 @@ dependencies { } application { - mainClassName = System.getProperty("main", "ai.djl.examples.inference.ObjectDetection") + mainClassName = System.getProperty("main", "ai.djl.examples.inference.biggan.Generator") } run { diff --git a/examples/src/main/java/ai/djl/examples/inference/biggan/BigGANCategory.java b/examples/src/main/java/ai/djl/examples/inference/biggan/BigGANCategory.java new file mode 100644 index 00000000000..6e3823510ff --- /dev/null +++ b/examples/src/main/java/ai/djl/examples/inference/biggan/BigGANCategory.java @@ -0,0 +1,98 @@ +/* + * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.examples.inference.biggan; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public final class BigGANCategory { + private static final Logger logger = LoggerFactory.getLogger(BigGANCategory.class); + + public static final int NUMBER_OF_CATEGORIES = 1000; + private static final Map CATEGORIES_BY_NAME = + new ConcurrentHashMap<>(NUMBER_OF_CATEGORIES); + private static String[] categoriesById; + + private int id; + private String[] names; + + static { + try { + parseCategories(); + } catch (IOException e) { + logger.error("Error parsing the ImageNet categories: {}", e); + } + createCategoriesByName(); + } + + private BigGANCategory(int id, String[] names) { + this.id = id; + this.names = names; + } + + public int getId() { + return id; + } + + public String[] getNames() { + return names.clone(); + } + + public static BigGANCategory id(int id) { + String names = categoriesById[id]; + int index = names.indexOf(','); + if (index < 0) { + return of(names); + } else { + return of(names.substring(0, index)); + } + } + + public static BigGANCategory of(String name) { + if (!CATEGORIES_BY_NAME.containsKey(name)) { + throw new IllegalArgumentException(name + " is not a valid category."); + } + return CATEGORIES_BY_NAME.get(name); + } + + private static void createCategoriesByName() { + for (int i = 0; i < NUMBER_OF_CATEGORIES; i++) { + String[] categoryNames = categoriesById[i].split(", "); + BigGANCategory category = new BigGANCategory(i, categoryNames); + + for (String name : categoryNames) { + CATEGORIES_BY_NAME.put(name, category); + } + } + } + + private static void parseCategories() throws IOException { + String filePath = "src/main/resources/categories.txt"; + + List fileLines = Files.readAllLines(Paths.get(filePath)); + List categories = new ArrayList<>(NUMBER_OF_CATEGORIES); + for (String line : fileLines) { + int nameIndex = line.indexOf(':') + 2; + categories.add(line.substring(nameIndex)); + } + + categoriesById = categories.toArray(new String[] {}); + } +} diff --git a/examples/src/main/java/ai/djl/examples/inference/biggan/BigGANInput.java b/examples/src/main/java/ai/djl/examples/inference/biggan/BigGANInput.java new file mode 100644 index 00000000000..790143b8e5f --- /dev/null +++ b/examples/src/main/java/ai/djl/examples/inference/biggan/BigGANInput.java @@ -0,0 +1,76 @@ +/* + * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.examples.inference.biggan; + +public final class BigGANInput { + private int sampleSize; + private float truncation; + private BigGANCategory category; + + public BigGANInput(int sampleSize, float truncation, BigGANCategory category) { + this.sampleSize = sampleSize; + this.truncation = truncation; + this.category = category; + } + + BigGANInput(Builder builder) { + this.sampleSize = builder.sampleSize; + this.truncation = builder.truncation; + this.category = builder.category; + } + + public int getSampleSize() { + return sampleSize; + } + + public float getTruncation() { + return truncation; + } + + public BigGANCategory getCategory() { + return category; + } + + public static Builder builder() { + return new Builder(); + } + + public static final class Builder { + private int sampleSize = 5; + private float truncation = 0.5f; + private BigGANCategory category; + + Builder() { + category = BigGANCategory.of("Egyptian cat"); + } + + public Builder optSampleSize(int sampleSize) { + this.sampleSize = sampleSize; + return this; + } + + public Builder optTruncation(float truncation) { + this.truncation = truncation; + return this; + } + + public Builder setCategory(BigGANCategory category) { + this.category = category; + return this; + } + + public BigGANInput build() { + return new BigGANInput(this); + } + } +} diff --git a/examples/src/main/java/ai/djl/examples/inference/biggan/BigGANTranslator.java b/examples/src/main/java/ai/djl/examples/inference/biggan/BigGANTranslator.java new file mode 100644 index 00000000000..cc872705b12 --- /dev/null +++ b/examples/src/main/java/ai/djl/examples/inference/biggan/BigGANTranslator.java @@ -0,0 +1,93 @@ +/* + * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.examples.inference.biggan; + +import ai.djl.engine.Engine; +import ai.djl.modality.cv.Image; +import ai.djl.modality.cv.ImageFactory; +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.NDManager; +import ai.djl.ndarray.types.DataType; +import ai.djl.ndarray.types.Shape; +import ai.djl.translate.Batchifier; +import ai.djl.translate.Translator; +import ai.djl.translate.TranslatorContext; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +final class BigGANTranslator implements Translator { + private static final Logger logger = LoggerFactory.getLogger(BigGANTranslator.class); + private static final int SEED_COLUMN_SIZE = 128; + + @Override + public Image[] processOutput(TranslatorContext ctx, NDList list) throws Exception { + logOutputList(list); + + NDArray output = list.get(0).addi(1).muli(128).clip(0, 255).toType(DataType.UINT8, false); + + int sampleSize = (int) output.getShape().get(0); + Image[] images = new Image[sampleSize]; + + for (int i = 0; i < sampleSize; i++) { + images[i] = ImageFactory.getInstance().fromNDArray(output.get(i)); + } + + return images; + } + + private void logOutputList(NDList list) { + logger.info(""); + logger.info("MY OUTPUT:"); + list.forEach(array -> logger.info(" out: {}", array.getShape())); + } + + @Override + public NDList processInput(TranslatorContext ctx, BigGANInput input) throws Exception { + Engine.getInstance().setRandomSeed(0); + NDManager manager = ctx.getNDManager(); + + NDArray categoryArray = createCategoryArray(manager, input); + NDArray seed = + manager.truncatedNormal(new Shape(input.getSampleSize(), SEED_COLUMN_SIZE)) + .muli(input.getTruncation()); + NDArray truncation = manager.create(input.getTruncation()); + + logInputArrays(categoryArray, seed, truncation); + return new NDList(seed, categoryArray, truncation); + } + + private NDArray createCategoryArray(NDManager manager, BigGANInput input) { + int categoryId = input.getCategory().getId(); + int sampleSize = input.getSampleSize(); + + int[] indices = new int[sampleSize]; + for (int i = 0; i < sampleSize; i++) { + indices[i] = categoryId; + } + return manager.create(indices).oneHot(BigGANCategory.NUMBER_OF_CATEGORIES); + } + + private void logInputArrays(NDArray categoryArray, NDArray seed, NDArray truncation) { + logger.info(""); + logger.info("MY INPUTS: "); + logger.info(" y: {}", categoryArray.getShape()); + logger.info(" z: {}", seed.get(":, :10")); + logger.info(" truncation: {}", truncation.getShape()); + } + + @Override + public Batchifier getBatchifier() { + return null; + } +} diff --git a/examples/src/main/java/ai/djl/examples/inference/biggan/Generator.java b/examples/src/main/java/ai/djl/examples/inference/biggan/Generator.java new file mode 100644 index 00000000000..589b93d9389 --- /dev/null +++ b/examples/src/main/java/ai/djl/examples/inference/biggan/Generator.java @@ -0,0 +1,97 @@ +/* + * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.examples.inference.biggan; + +import ai.djl.ModelException; +import ai.djl.engine.Engine; +import ai.djl.inference.Predictor; +import ai.djl.modality.cv.Image; +import ai.djl.repository.zoo.Criteria; +import ai.djl.repository.zoo.ModelZoo; +import ai.djl.repository.zoo.ZooModel; +import ai.djl.training.util.DownloadUtils; +import ai.djl.training.util.ProgressBar; +import ai.djl.translate.TranslateException; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public final class Generator { + + private static final Logger logger = LoggerFactory.getLogger(Generator.class); + + private Generator() {} + + public static void main(String[] args) throws ModelException, TranslateException, IOException { + Image[] generatedImages = Generator.generate(); + + if (generatedImages == null) { + logger.info("This example only works for PyTorch Engine"); + } else { + logger.info("Using PyTorch Engine. {} images generated.", generatedImages.length); + saveImages(generatedImages); + } + } + + private static void saveImages(Image[] generatedImages) throws IOException { + Path outputPath = Paths.get("build/output/gan/"); + Files.createDirectories(outputPath); + + for (int i = 0; i < generatedImages.length; i++) { + Path imagePath = outputPath.resolve("image" + i + ".jpg"); + generatedImages[i].save(Files.newOutputStream(imagePath), "jpg"); + } + logger.info("Generated images have been saved in: {}", outputPath); + } + + public static Image[] generate() throws IOException, ModelException, TranslateException { + if (!"PyTorch".equals(Engine.getInstance().getEngineName())) { + return null; + } + + String modelPath = "build/models/gan/"; + String modelName = "biggan-deep-256"; + + DownloadUtils.download( + "https://djl-ai.s3.amazonaws.com/mlrepo/model/cv/gan/ai/djl/pytorch/biggan-deep/0.0.1/" + + modelName + + ".pt.gz", + modelPath + modelName + ".pt", + new ProgressBar()); + + Criteria criteria = + Criteria.builder() + .setTypes(BigGANInput.class, Image[].class) + .optModelName(modelName) + .optModelPath(Paths.get(modelPath)) + .optTranslator(new BigGANTranslator()) + .optProgress(new ProgressBar()) + .build(); + + BigGANInput input = + BigGANInput.builder() + .setCategory(BigGANCategory.of("tiger cat")) + .optSampleSize(5) + .optTruncation(0.5f) + .build(); + + try (ZooModel model = ModelZoo.loadModel(criteria)) { + try (Predictor generator = model.newPredictor()) { + return generator.predict(input); + } + } + } +} diff --git a/examples/src/main/java/ai/djl/examples/inference/biggan/package-info.java b/examples/src/main/java/ai/djl/examples/inference/biggan/package-info.java new file mode 100644 index 00000000000..ff0fbaee26d --- /dev/null +++ b/examples/src/main/java/ai/djl/examples/inference/biggan/package-info.java @@ -0,0 +1,15 @@ +/* + * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ + +/** Contains an examples using BigGAN from DeepMind (Google). */ +package ai.djl.examples.inference.biggan; diff --git a/examples/src/main/resources/categories.txt b/examples/src/main/resources/categories.txt new file mode 100644 index 00000000000..63eeb203c62 --- /dev/null +++ b/examples/src/main/resources/categories.txt @@ -0,0 +1,1000 @@ +0: tench, Tinca tinca +1: goldfish, Carassius auratus +2: great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias +3: tiger shark, Galeocerdo cuvieri +4: hammerhead, hammerhead shark +5: electric ray, crampfish, numbfish, torpedo +6: stingray +7: cock +8: hen +9: ostrich, Struthio camelus +10: brambling, Fringilla montifringilla +11: goldfinch, Carduelis carduelis +12: house finch, linnet, Carpodacus mexicanus +13: junco, snowbird +14: indigo bunting, indigo finch, indigo bird, Passerina cyanea +15: robin, American robin, Turdus migratorius +16: bulbul +17: jay +18: magpie +19: chickadee +20: water ouzel, dipper +21: kite +22: bald eagle, American eagle, Haliaeetus leucocephalus +23: vulture +24: great grey owl, great gray owl, Strix nebulosa +25: European fire salamander, Salamandra salamandra +26: common newt, Triturus vulgaris +27: eft +28: spotted salamander, Ambystoma maculatum +29: axolotl, mud puppy, Ambystoma mexicanum +30: bullfrog, Rana catesbeiana +31: tree frog, tree-frog +32: tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui +33: loggerhead, loggerhead turtle, Caretta caretta +34: leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea +35: mud turtle +36: terrapin +37: box turtle, box tortoise +38: banded gecko +39: common iguana, iguana, Iguana iguana +40: American chameleon, anole, Anolis carolinensis +41: whiptail, whiptail lizard +42: agama +43: frilled lizard, Chlamydosaurus kingi +44: alligator lizard +45: Gila monster, Heloderma suspectum +46: green lizard, Lacerta viridis +47: African chameleon, Chamaeleo chamaeleon +48: Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis +49: African crocodile, Nile crocodile, Crocodylus niloticus +50: American alligator, Alligator mississipiensis +51: triceratops +52: thunder snake, worm snake, Carphophis amoenus +53: ringneck snake, ring-necked snake, ring snake +54: hognose snake, puff adder, sand viper +55: green snake, grass snake +56: king snake, kingsnake +57: garter snake, grass snake +58: water snake +59: vine snake +60: night snake, Hypsiglena torquata +61: boa constrictor, Constrictor constrictor +62: rock python, rock snake, Python sebae +63: Indian cobra, Naja naja +64: green mamba +65: sea snake +66: horned viper, cerastes, sand viper, horned asp, Cerastes cornutus +67: diamondback, diamondback rattlesnake, Crotalus adamanteus +68: sidewinder, horned rattlesnake, Crotalus cerastes +69: trilobite +70: harvestman, daddy longlegs, Phalangium opilio +71: scorpion +72: black and gold garden spider, Argiope aurantia +73: barn spider, Araneus cavaticus +74: garden spider, Aranea diademata +75: black widow, Latrodectus mactans +76: tarantula +77: wolf spider, hunting spider +78: tick +79: centipede +80: black grouse +81: ptarmigan +82: ruffed grouse, partridge, Bonasa umbellus +83: prairie chicken, prairie grouse, prairie fowl +84: peacock +85: quail +86: partridge +87: African grey, African gray, Psittacus erithacus +88: macaw +89: sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita +90: lorikeet +91: coucal +92: bee eater +93: hornbill +94: hummingbird +95: jacamar +96: toucan +97: drake +98: red-breasted merganser, Mergus serrator +99: goose +100: black swan, Cygnus atratus +101: tusker +102: echidna, spiny anteater, anteater +103: platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus +104: wallaby, brush kangaroo +105: koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus +106: wombat +107: jellyfish +108: sea anemone, anemone +109: brain coral +110: flatworm, platyhelminth +111: nematode, nematode worm, roundworm +112: conch +113: snail +114: slug +115: sea slug, nudibranch +116: chiton, coat-of-mail shell, sea cradle, polyplacophore +117: chambered nautilus, pearly nautilus, nautilus +118: Dungeness crab, Cancer magister +119: rock crab, Cancer irroratus +120: fiddler crab +121: king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica +122: American lobster, Northern lobster, Maine lobster, Homarus americanus +123: spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish +124: crayfish, crawfish, crawdad, crawdaddy +125: hermit crab +126: isopod +127: white stork, Ciconia ciconia +128: black stork, Ciconia nigra +129: spoonbill +130: flamingo +131: little blue heron, Egretta caerulea +132: American egret, great white heron, Egretta albus +133: bittern +134: crane +135: limpkin, Aramus pictus +136: European gallinule, Porphyrio porphyrio +137: American coot, marsh hen, mud hen, water hen, Fulica americana +138: bustard +139: ruddy turnstone, Arenaria interpres +140: red-backed sandpiper, dunlin, Erolia alpina +141: redshank, Tringa totanus +142: dowitcher +143: oystercatcher, oyster catcher +144: pelican +145: king penguin, Aptenodytes patagonica +146: albatross, mollymawk +147: grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus +148: killer whale, killer, orca, grampus, sea wolf, Orcinus orca +149: dugong, Dugong dugon +150: sea lion +151: Chihuahua +152: Japanese spaniel +153: Maltese dog, Maltese terrier, Maltese +154: Pekinese, Pekingese, Peke +155: Shih-Tzu +156: Blenheim spaniel +157: papillon +158: toy terrier +159: Rhodesian ridgeback +160: Afghan hound, Afghan +161: basset, basset hound +162: beagle +163: bloodhound, sleuthhound +164: bluetick +165: black-and-tan coonhound +166: Walker hound, Walker foxhound +167: English foxhound +168: redbone +169: borzoi, Russian wolfhound +170: Irish wolfhound +171: Italian greyhound +172: whippet +173: Ibizan hound, Ibizan Podenco +174: Norwegian elkhound, elkhound +175: otterhound, otter hound +176: Saluki, gazelle hound +177: Scottish deerhound, deerhound +178: Weimaraner +179: Staffordshire bullterrier, Staffordshire bull terrier +180: American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier +181: Bedlington terrier +182: Border terrier +183: Kerry blue terrier +184: Irish terrier +185: Norfolk terrier +186: Norwich terrier +187: Yorkshire terrier +188: wire-haired fox terrier +189: Lakeland terrier +190: Sealyham terrier, Sealyham +191: Airedale, Airedale terrier +192: cairn, cairn terrier +193: Australian terrier +194: Dandie Dinmont, Dandie Dinmont terrier +195: Boston bull, Boston terrier +196: miniature schnauzer +197: giant schnauzer +198: standard schnauzer +199: Scotch terrier, Scottish terrier, Scottie +200: Tibetan terrier, chrysanthemum dog +201: silky terrier, Sydney silky +202: soft-coated wheaten terrier +203: West Highland white terrier +204: Lhasa, Lhasa apso +205: flat-coated retriever +206: curly-coated retriever +207: golden retriever +208: Labrador retriever +209: Chesapeake Bay retriever +210: German short-haired pointer +211: vizsla, Hungarian pointer +212: English setter +213: Irish setter, red setter +214: Gordon setter +215: Brittany spaniel +216: clumber, clumber spaniel +217: English springer, English springer spaniel +218: Welsh springer spaniel +219: cocker spaniel, English cocker spaniel, cocker +220: Sussex spaniel +221: Irish water spaniel +222: kuvasz +223: schipperke +224: groenendael +225: malinois +226: briard +227: kelpie +228: komondor +229: Old English sheepdog, bobtail +230: Shetland sheepdog, Shetland sheep dog, Shetland +231: collie +232: Border collie +233: Bouvier des Flandres, Bouviers des Flandres +234: Rottweiler +235: German shepherd, German shepherd dog, German police dog, alsatian +236: Doberman, Doberman pinscher +237: miniature pinscher +238: Greater Swiss Mountain dog +239: Bernese mountain dog +240: Appenzeller +241: EntleBucher +242: boxer +243: bull mastiff +244: Tibetan mastiff +245: French bulldog +246: Great Dane +247: Saint Bernard, St Bernard +248: Eskimo dog, husky +249: malamute, malemute, Alaskan malamute +250: Siberian husky +251: dalmatian, coach dog, carriage dog +252: affenpinscher, monkey pinscher, monkey dog +253: basenji +254: pug, pug-dog +255: Leonberg +256: Newfoundland, Newfoundland dog +257: Great Pyrenees +258: Samoyed, Samoyede +259: Pomeranian +260: chow, chow chow +261: keeshond +262: Brabancon griffon +263: Pembroke, Pembroke Welsh corgi +264: Cardigan, Cardigan Welsh corgi +265: toy poodle +266: miniature poodle +267: standard poodle +268: Mexican hairless +269: timber wolf, grey wolf, gray wolf, Canis lupus +270: white wolf, Arctic wolf, Canis lupus tundrarum +271: red wolf, maned wolf, Canis rufus, Canis niger +272: coyote, prairie wolf, brush wolf, Canis latrans +273: dingo, warrigal, warragal, Canis dingo +274: dhole, Cuon alpinus +275: African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus +276: hyena, hyaena +277: red fox, Vulpes vulpes +278: kit fox, Vulpes macrotis +279: Arctic fox, white fox, Alopex lagopus +280: grey fox, gray fox, Urocyon cinereoargenteus +281: tabby, tabby cat +282: tiger cat +283: Persian cat +284: Siamese cat, Siamese +285: Egyptian cat +286: cougar, puma, catamount, mountain lion, painter, panther, Felis concolor +287: lynx, catamount +288: leopard, Panthera pardus +289: snow leopard, ounce, Panthera uncia +290: jaguar, panther, Panthera onca, Felis onca +291: lion, king of beasts, Panthera leo +292: tiger, Panthera tigris +293: cheetah, chetah, Acinonyx jubatus +294: brown bear, bruin, Ursus arctos +295: American black bear, black bear, Ursus americanus, Euarctos americanus +296: ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus +297: sloth bear, Melursus ursinus, Ursus ursinus +298: mongoose +299: meerkat, mierkat +300: tiger beetle +301: ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle +302: ground beetle, carabid beetle +303: long-horned beetle, longicorn, longicorn beetle +304: leaf beetle, chrysomelid +305: dung beetle +306: rhinoceros beetle +307: weevil +308: fly +309: bee +310: ant, emmet, pismire +311: grasshopper, hopper +312: cricket +313: walking stick, walkingstick, stick insect +314: cockroach, roach +315: mantis, mantid +316: cicada, cicala +317: leafhopper +318: lacewing, lacewing fly +319: "dragonfly, darning needle, devils darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk" +320: damselfly +321: admiral +322: ringlet, ringlet butterfly +323: monarch, monarch butterfly, milkweed butterfly, Danaus plexippus +324: cabbage butterfly +325: sulphur butterfly, sulfur butterfly +326: lycaenid, lycaenid butterfly +327: starfish, sea star +328: sea urchin +329: sea cucumber, holothurian +330: wood rabbit, cottontail, cottontail rabbit +331: hare +332: Angora, Angora rabbit +333: hamster +334: porcupine, hedgehog +335: fox squirrel, eastern fox squirrel, Sciurus niger +336: marmot +337: beaver +338: guinea pig, Cavia cobaya +339: sorrel +340: zebra +341: hog, pig, grunter, squealer, Sus scrofa +342: wild boar, boar, Sus scrofa +343: warthog +344: hippopotamus, hippo, river horse, Hippopotamus amphibius +345: ox +346: water buffalo, water ox, Asiatic buffalo, Bubalus bubalis +347: bison +348: ram, tup +349: bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis +350: ibex, Capra ibex +351: hartebeest +352: impala, Aepyceros melampus +353: gazelle +354: Arabian camel, dromedary, Camelus dromedarius +355: llama +356: weasel +357: mink +358: polecat, fitch, foulmart, foumart, Mustela putorius +359: black-footed ferret, ferret, Mustela nigripes +360: otter +361: skunk, polecat, wood pussy +362: badger +363: armadillo +364: three-toed sloth, ai, Bradypus tridactylus +365: orangutan, orang, orangutang, Pongo pygmaeus +366: gorilla, Gorilla gorilla +367: chimpanzee, chimp, Pan troglodytes +368: gibbon, Hylobates lar +369: siamang, Hylobates syndactylus, Symphalangus syndactylus +370: guenon, guenon monkey +371: patas, hussar monkey, Erythrocebus patas +372: baboon +373: macaque +374: langur +375: colobus, colobus monkey +376: proboscis monkey, Nasalis larvatus +377: marmoset +378: capuchin, ringtail, Cebus capucinus +379: howler monkey, howler +380: titi, titi monkey +381: spider monkey, Ateles geoffroyi +382: squirrel monkey, Saimiri sciureus +383: Madagascar cat, ring-tailed lemur, Lemur catta +384: indri, indris, Indri indri, Indri brevicaudatus +385: Indian elephant, Elephas maximus +386: African elephant, Loxodonta africana +387: lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens +388: giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca +389: barracouta, snoek +390: eel +391: coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch +392: rock beauty, Holocanthus tricolor +393: anemone fish +394: sturgeon +395: gar, garfish, garpike, billfish, Lepisosteus osseus +396: lionfish +397: puffer, pufferfish, blowfish, globefish +398: abacus +399: abaya +400: "academic gown, academic robe, judges robe" +401: accordion, piano accordion, squeeze box +402: acoustic guitar +403: aircraft carrier, carrier, flattop, attack aircraft carrier +404: airliner +405: airship, dirigible +406: altar +407: ambulance +408: amphibian, amphibious vehicle +409: analog clock +410: apiary, bee house +411: apron +412: ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin +413: assault rifle, assault gun +414: backpack, back pack, knapsack, packsack, rucksack, haversack +415: bakery, bakeshop, bakehouse +416: balance beam, beam +417: balloon +418: ballpoint, ballpoint pen, ballpen, Biro +419: Band Aid +420: banjo +421: bannister, banister, balustrade, balusters, handrail +422: barbell +423: barber chair +424: barbershop +425: barn +426: barometer +427: barrel, cask +428: barrow, garden cart, lawn cart, wheelbarrow +429: baseball +430: basketball +431: bassinet +432: bassoon +433: bathing cap, swimming cap +434: bath towel +435: bathtub, bathing tub, bath, tub +436: beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon +437: beacon, lighthouse, beacon light, pharos +438: beaker +439: bearskin, busby, shako +440: beer bottle +441: beer glass +442: bell cote, bell cot +443: bib +444: bicycle-built-for-two, tandem bicycle, tandem +445: bikini, two-piece +446: binder, ring-binder +447: binoculars, field glasses, opera glasses +448: birdhouse +449: boathouse +450: bobsled, bobsleigh, bob +451: bolo tie, bolo, bola tie, bola +452: bonnet, poke bonnet +453: bookcase +454: bookshop, bookstore, bookstall +455: bottlecap +456: bow +457: bow tie, bow-tie, bowtie +458: brass, memorial tablet, plaque +459: brassiere, bra, bandeau +460: breakwater, groin, groyne, mole, bulwark, seawall, jetty +461: breastplate, aegis, egis +462: broom +463: bucket, pail +464: buckle +465: bulletproof vest +466: bullet train, bullet +467: butcher shop, meat market +468: cab, hack, taxi, taxicab +469: caldron, cauldron +470: candle, taper, wax light +471: cannon +472: canoe +473: can opener, tin opener +474: cardigan +475: car mirror +476: carousel, carrousel, merry-go-round, roundabout, whirligig +477: "carpenters kit, tool kit" +478: carton +479: car wheel +480: cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM +481: cassette +482: cassette player +483: castle +484: catamaran +485: CD player +486: cello, violoncello +487: cellular telephone, cellular phone, cellphone, cell, mobile phone +488: chain +489: chainlink fence +490: chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour +491: chain saw, chainsaw +492: chest +493: chiffonier, commode +494: chime, bell, gong +495: china cabinet, china closet +496: Christmas stocking +497: church, church building +498: cinema, movie theater, movie theatre, movie house, picture palace +499: cleaver, meat cleaver, chopper +500: cliff dwelling +501: cloak +502: clog, geta, patten, sabot +503: cocktail shaker +504: coffee mug +505: coffeepot +506: coil, spiral, volute, whorl, helix +507: combination lock +508: computer keyboard, keypad +509: confectionery, confectionary, candy store +510: container ship, containership, container vessel +511: convertible +512: corkscrew, bottle screw +513: cornet, horn, trumpet, trump +514: cowboy boot +515: cowboy hat, ten-gallon hat +516: cradle +517: crane +518: crash helmet +519: crate +520: crib, cot +521: Crock Pot +522: croquet ball +523: crutch +524: cuirass +525: dam, dike, dyke +526: desk +527: desktop computer +528: dial telephone, dial phone +529: diaper, nappy, napkin +530: digital clock +531: digital watch +532: dining table, board +533: dishrag, dishcloth +534: dishwasher, dish washer, dishwashing machine +535: disk brake, disc brake +536: dock, dockage, docking facility +537: dogsled, dog sled, dog sleigh +538: dome +539: doormat, welcome mat +540: drilling platform, offshore rig +541: drum, membranophone, tympan +542: drumstick +543: dumbbell +544: Dutch oven +545: electric fan, blower +546: electric guitar +547: electric locomotive +548: entertainment center +549: envelope +550: espresso maker +551: face powder +552: feather boa, boa +553: file, file cabinet, filing cabinet +554: fireboat +555: fire engine, fire truck +556: fire screen, fireguard +557: flagpole, flagstaff +558: flute, transverse flute +559: folding chair +560: football helmet +561: forklift +562: fountain +563: fountain pen +564: four-poster +565: freight car +566: French horn, horn +567: frying pan, frypan, skillet +568: fur coat +569: garbage truck, dustcart +570: gasmask, respirator, gas helmet +571: gas pump, gasoline pump, petrol pump, island dispenser +572: goblet +573: go-kart +574: golf ball +575: golfcart, golf cart +576: gondola +577: gong, tam-tam +578: gown +579: grand piano, grand +580: greenhouse, nursery, glasshouse +581: grille, radiator grille +582: grocery store, grocery, food market, market +583: guillotine +584: hair slide +585: hair spray +586: half track +587: hammer +588: hamper +589: hand blower, blow dryer, blow drier, hair dryer, hair drier +590: hand-held computer, hand-held microcomputer +591: handkerchief, hankie, hanky, hankey +592: hard disc, hard disk, fixed disk +593: harmonica, mouth organ, harp, mouth harp +594: harp +595: harvester, reaper +596: hatchet +597: holster +598: home theater, home theatre +599: honeycomb +600: hook, claw +601: hoopskirt, crinoline +602: horizontal bar, high bar +603: horse cart, horse-cart +604: hourglass +605: iPod +606: iron, smoothing iron +607: "jack-o-lantern" +608: jean, blue jean, denim +609: jeep, landrover +610: jersey, T-shirt, tee shirt +611: jigsaw puzzle +612: jinrikisha, ricksha, rickshaw +613: joystick +614: kimono +615: knee pad +616: knot +617: lab coat, laboratory coat +618: ladle +619: lampshade, lamp shade +620: laptop, laptop computer +621: lawn mower, mower +622: lens cap, lens cover +623: letter opener, paper knife, paperknife +624: library +625: lifeboat +626: lighter, light, igniter, ignitor +627: limousine, limo +628: liner, ocean liner +629: lipstick, lip rouge +630: Loafer +631: lotion +632: loudspeaker, speaker, speaker unit, loudspeaker system, speaker system +633: "loupe, jewelers loupe" +634: lumbermill, sawmill +635: magnetic compass +636: mailbag, postbag +637: mailbox, letter box +638: maillot +639: maillot, tank suit +640: manhole cover +641: maraca +642: marimba, xylophone +643: mask +644: matchstick +645: maypole +646: maze, labyrinth +647: measuring cup +648: medicine chest, medicine cabinet +649: megalith, megalithic structure +650: microphone, mike +651: microwave, microwave oven +652: military uniform +653: milk can +654: minibus +655: miniskirt, mini +656: minivan +657: missile +658: mitten +659: mixing bowl +660: mobile home, manufactured home +661: Model T +662: modem +663: monastery +664: monitor +665: moped +666: mortar +667: mortarboard +668: mosque +669: mosquito net +670: motor scooter, scooter +671: mountain bike, all-terrain bike, off-roader +672: mountain tent +673: mouse, computer mouse +674: mousetrap +675: moving van +676: muzzle +677: nail +678: neck brace +679: necklace +680: nipple +681: notebook, notebook computer +682: obelisk +683: oboe, hautboy, hautbois +684: ocarina, sweet potato +685: odometer, hodometer, mileometer, milometer +686: oil filter +687: organ, pipe organ +688: oscilloscope, scope, cathode-ray oscilloscope, CRO +689: overskirt +690: oxcart +691: oxygen mask +692: packet +693: paddle, boat paddle +694: paddlewheel, paddle wheel +695: padlock +696: paintbrush +697: "pajama, pyjama, pjs, jammies" +698: palace +699: panpipe, pandean pipe, syrinx +700: paper towel +701: parachute, chute +702: parallel bars, bars +703: park bench +704: parking meter +705: passenger car, coach, carriage +706: patio, terrace +707: pay-phone, pay-station +708: pedestal, plinth, footstall +709: pencil box, pencil case +710: pencil sharpener +711: perfume, essence +712: Petri dish +713: photocopier +714: pick, plectrum, plectron +715: pickelhaube +716: picket fence, paling +717: pickup, pickup truck +718: pier +719: piggy bank, penny bank +720: pill bottle +721: pillow +722: ping-pong ball +723: pinwheel +724: pirate, pirate ship +725: pitcher, ewer +726: "plane, carpenters plane, woodworking plane" +727: planetarium +728: plastic bag +729: plate rack +730: plow, plough +731: "plunger, plumbers helper" +732: Polaroid camera, Polaroid Land camera +733: pole +734: police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria +735: poncho +736: pool table, billiard table, snooker table +737: pop bottle, soda bottle +738: pot, flowerpot +739: "potters wheel" +740: power drill +741: prayer rug, prayer mat +742: printer +743: prison, prison house +744: projectile, missile +745: projector +746: puck, hockey puck +747: punching bag, punch bag, punching ball, punchball +748: purse +749: quill, quill pen +750: quilt, comforter, comfort, puff +751: racer, race car, racing car +752: racket, racquet +753: radiator +754: radio, wireless +755: radio telescope, radio reflector +756: rain barrel +757: recreational vehicle, RV, R.V. +758: reel +759: reflex camera +760: refrigerator, icebox +761: remote control, remote +762: restaurant, eating house, eating place, eatery +763: revolver, six-gun, six-shooter +764: rifle +765: rocking chair, rocker +766: rotisserie +767: rubber eraser, rubber, pencil eraser +768: rugby ball +769: rule, ruler +770: running shoe +771: safe +772: safety pin +773: saltshaker, salt shaker +774: sandal +775: sarong +776: sax, saxophone +777: scabbard +778: scale, weighing machine +779: school bus +780: schooner +781: scoreboard +782: screen, CRT screen +783: screw +784: screwdriver +785: seat belt, seatbelt +786: sewing machine +787: shield, buckler +788: shoe shop, shoe-shop, shoe store +789: shoji +790: shopping basket +791: shopping cart +792: shovel +793: shower cap +794: shower curtain +795: ski +796: ski mask +797: sleeping bag +798: slide rule, slipstick +799: sliding door +800: slot, one-armed bandit +801: snorkel +802: snowmobile +803: snowplow, snowplough +804: soap dispenser +805: soccer ball +806: sock +807: solar dish, solar collector, solar furnace +808: sombrero +809: soup bowl +810: space bar +811: space heater +812: space shuttle +813: spatula +814: speedboat +815: "spider web, spiders web" +816: spindle +817: sports car, sport car +818: spotlight, spot +819: stage +820: steam locomotive +821: steel arch bridge +822: steel drum +823: stethoscope +824: stole +825: stone wall +826: stopwatch, stop watch +827: stove +828: strainer +829: streetcar, tram, tramcar, trolley, trolley car +830: stretcher +831: studio couch, day bed +832: stupa, tope +833: submarine, pigboat, sub, U-boat +834: suit, suit of clothes +835: sundial +836: sunglass +837: sunglasses, dark glasses, shades +838: sunscreen, sunblock, sun blocker +839: suspension bridge +840: swab, swob, mop +841: sweatshirt +842: swimming trunks, bathing trunks +843: swing +844: switch, electric switch, electrical switch +845: syringe +846: table lamp +847: tank, army tank, armored combat vehicle, armoured combat vehicle +848: tape player +849: teapot +850: teddy, teddy bear +851: television, television system +852: tennis ball +853: thatch, thatched roof +854: theater curtain, theatre curtain +855: thimble +856: thresher, thrasher, threshing machine +857: throne +858: tile roof +859: toaster +860: tobacco shop, tobacconist shop, tobacconist +861: toilet seat +862: torch +863: totem pole +864: tow truck, tow car, wrecker +865: toyshop +866: tractor +867: trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi +868: tray +869: trench coat +870: tricycle, trike, velocipede +871: trimaran +872: tripod +873: triumphal arch +874: trolleybus, trolley coach, trackless trolley +875: trombone +876: tub, vat +877: turnstile +878: typewriter keyboard +879: umbrella +880: unicycle, monocycle +881: upright, upright piano +882: vacuum, vacuum cleaner +883: vase +884: vault +885: velvet +886: vending machine +887: vestment +888: viaduct +889: violin, fiddle +890: volleyball +891: waffle iron +892: wall clock +893: wallet, billfold, notecase, pocketbook +894: wardrobe, closet, press +895: warplane, military plane +896: washbasin, handbasin, washbowl, lavabo, wash-hand basin +897: washer, automatic washer, washing machine +898: water bottle +899: water jug +900: water tower +901: whiskey jug +902: whistle +903: wig +904: window screen +905: window shade +906: Windsor tie +907: wine bottle +908: wing +909: wok +910: wooden spoon +911: wool, woolen, woollen +912: worm fence, snake fence, snake-rail fence, Virginia fence +913: wreck +914: yawl +915: yurt +916: web site, website, internet site, site +917: comic book +918: crossword puzzle, crossword +919: street sign +920: traffic light, traffic signal, stoplight +921: book jacket, dust cover, dust jacket, dust wrapper +922: menu +923: plate +924: guacamole +925: consomme +926: hot pot, hotpot +927: trifle +928: ice cream, icecream +929: ice lolly, lolly, lollipop, popsicle +930: French loaf +931: bagel, beigel +932: pretzel +933: cheeseburger +934: hotdog, hot dog, red hot +935: mashed potato +936: head cabbage +937: broccoli +938: cauliflower +939: zucchini, courgette +940: spaghetti squash +941: acorn squash +942: butternut squash +943: cucumber, cuke +944: artichoke, globe artichoke +945: bell pepper +946: cardoon +947: mushroom +948: Granny Smith +949: strawberry +950: orange +951: lemon +952: fig +953: pineapple, ananas +954: banana +955: jackfruit, jak, jack +956: custard apple +957: pomegranate +958: hay +959: carbonara +960: chocolate sauce, chocolate syrup +961: dough +962: meat loaf, meatloaf +963: pizza, pizza pie +964: potpie +965: burrito +966: red wine +967: espresso +968: cup +969: eggnog +970: alp +971: bubble +972: cliff, drop, drop-off +973: coral reef +974: geyser +975: lakeside, lakeshore +976: promontory, headland, head, foreland +977: sandbar, sand bar +978: seashore, coast, seacoast, sea-coast +979: valley, vale +980: volcano +981: ballplayer, baseball player +982: groom, bridegroom +983: scuba diver +984: rapeseed +985: daisy +986: "yellow ladys slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum" +987: corn +988: acorn +989: hip, rose hip, rosehip +990: buckeye, horse chestnut, conker +991: coral fungus +992: agaric +993: gyromitra +994: stinkhorn, carrion fungus +995: earthstar +996: hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa +997: bolete +998: ear, spike, capitulum +999: toilet tissue, toilet paper, bathroom tissue \ No newline at end of file diff --git a/extensions/sentencepiece/src/test/java/ai/djl/sentencepiece/SpTokenizerTest.java b/extensions/sentencepiece/src/test/java/ai/djl/sentencepiece/SpTokenizerTest.java index 803a22b0146..873a6b7846d 100644 --- a/extensions/sentencepiece/src/test/java/ai/djl/sentencepiece/SpTokenizerTest.java +++ b/extensions/sentencepiece/src/test/java/ai/djl/sentencepiece/SpTokenizerTest.java @@ -53,6 +53,21 @@ public void testTokenize() throws IOException { } } + @Test + @SuppressWarnings("AvoidEscapedUnicodeCharacters") + public void testUtf16Tokenize() throws IOException { + if (System.getProperty("os.name").startsWith("Win")) { + throw new SkipException("Skip windows test."); + } + Path modelPath = Paths.get("build/test/models/sententpiece_test_model.model"); + try (SpTokenizer tokenizer = new SpTokenizer(modelPath)) { + String original = "\uD83D\uDC4B\uD83D\uDC4B"; + List tokens = tokenizer.tokenize(original); + List expected = Arrays.asList("▁", "\uD83D\uDC4B\uD83D\uDC4B"); + Assert.assertEquals(tokens, expected); + } + } + @Test public void testEncodeDecode() throws IOException { if (System.getProperty("os.name").startsWith("Win")) { diff --git a/gradle.properties b/gradle.properties index 58639f2cd83..0f791657470 100644 --- a/gradle.properties +++ b/gradle.properties @@ -13,21 +13,21 @@ pytorch_version=1.8.1 tensorflow_version=2.4.1 tflite_version=2.4.1 dlr_version=1.6.0 -onnxruntime_version=1.7.0 +onnxruntime_version=1.8.0 paddlepaddle_version=2.0.2 sentencepiece_version=0.1.95 fasttext_version=0.9.2 mkl_dnn_version=0.21.2-1.5.2 -xgboost_version=1.3.1 +xgboost_version=1.4.1 antlr_version=4.7.2 commons_cli_version=1.4 commons_compress_version=1.20 commons_csv_version=1.8 -gson_version=2.8.6 -jna_version=5.3.0 -netty_version=4.1.51.Final +gson_version=2.8.7 +jna_version=5.8.0 +netty_version=4.1.65.Final slf4j_version=1.7.30 log4j_slf4j_version=2.13.3 -testng_version=7.1.0 +testng_version=7.3.0 powermock_version=2.0.7 diff --git a/integration/src/main/java/ai/djl/integration/tests/model_zoo/ModelZooTest.java b/integration/src/main/java/ai/djl/integration/tests/model_zoo/ModelZooTest.java index fd245d065db..b45ba290388 100644 --- a/integration/src/main/java/ai/djl/integration/tests/model_zoo/ModelZooTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/model_zoo/ModelZooTest.java @@ -23,6 +23,7 @@ import ai.djl.util.Utils; import java.io.IOException; import java.nio.file.Paths; +import java.util.Calendar; import java.util.List; import java.util.ServiceLoader; import org.testng.SkipException; @@ -49,7 +50,10 @@ public void tearDown() { @Test public void testDownloadModels() throws IOException, ModelException { if (!Boolean.getBoolean("nightly") || Boolean.getBoolean("offline")) { - throw new SkipException("Nightly only"); + throw new SkipException("Weekly only"); + } + if (Calendar.SATURDAY != Calendar.getInstance().get(Calendar.DAY_OF_WEEK)) { + throw new SkipException("Weekly only"); } ServiceLoader providers = ServiceLoader.load(ZooProvider.class); diff --git a/integration/src/test/translator/MyTranslator.java b/integration/src/test/translator/MyTranslator.java index 601d4a66ca6..34b745bf5e5 100644 --- a/integration/src/test/translator/MyTranslator.java +++ b/integration/src/test/translator/MyTranslator.java @@ -54,7 +54,7 @@ public Output processOutput(TranslatorContext ctx, NDList list) { } @Override - public void setArguments(Map arguments) { + public void setArguments(Map arguments) { } @Override diff --git a/jupyter/load_pytorch_model.ipynb b/jupyter/load_pytorch_model.ipynb index a21550611b8..6a9e0c7073a 100644 --- a/jupyter/load_pytorch_model.ipynb +++ b/jupyter/load_pytorch_model.ipynb @@ -42,6 +42,7 @@ "metadata": {}, "outputs": [], "source": [ + "import java.nio.file.*;\n", "import java.awt.image.*;\n", "import ai.djl.*;\n", "import ai.djl.inference.*;\n", @@ -120,18 +121,15 @@ "metadata": {}, "outputs": [], "source": [ - "Pipeline pipeline = new Pipeline();\n", - "pipeline.add(new Resize(256))\n", - " .add(new CenterCrop(224, 224))\n", - " .add(new ToTensor())\n", - " .add(new Normalize(\n", - " new float[] {0.485f, 0.456f, 0.406f},\n", - " new float[] {0.229f, 0.224f, 0.225f}));\n", - "\n", "Translator translator = ImageClassificationTranslator.builder()\n", - " .setPipeline(pipeline)\n", - " .optApplySoftmax(true)\n", - " .build();" + " .addTransform(new Resize(256))\n", + " .addTransform(new CenterCrop(224, 224))\n", + " .addTransform(new ToTensor())\n", + " .addTransform(new Normalize(\n", + " new float[] {0.485f, 0.456f, 0.406f},\n", + " new float[] {0.229f, 0.224f, 0.225f}))\n", + " .optApplySoftmax(true)\n", + " .build();" ] }, { @@ -140,9 +138,7 @@ "source": [ "## Step 3: Load your model\n", "\n", - "Next, we will set the model zoo location to the `build/pytorch_models` directory we saved the model to. You can also create your own [`Repository`](https://javadoc.io/static/ai.djl/repository/0.11.0/index.html?ai/djl/repository/Repository.html) to avoid manually managing files.\n", - "\n", - "Next, we add some search criteria to find the resnet18 model and load it." + "Next, we add some search criteria to find the resnet18 model and load it. In this case, we need to tell `Criteria` where to locate the model by calling `.optModelPath()` API." ] }, { @@ -151,14 +147,9 @@ "metadata": {}, "outputs": [], "source": [ - "// Search for models in the build/pytorch_models folder\n", - "System.setProperty(\"ai.djl.repository.zoo.location\", \"build/pytorch_models/resnet18\");\n", - "\n", "Criteria criteria = Criteria.builder()\n", " .setTypes(Image.class, Classifications.class)\n", - " // only search the model in local directory\n", - " // \"ai.djl.localmodelzoo:{name of the model}\"\n", - " .optArtifactId(\"ai.djl.localmodelzoo:resnet18\")\n", + " .optModelPath(Paths.get(\"build/pytorch_models/resnet18\"))\n", " .optTranslator(translator)\n", " .optProgress(new ProgressBar()).build();\n", "\n", @@ -230,7 +221,7 @@ "mimetype": "text/x-java-source", "name": "Java", "pygments_lexer": "java", - "version": "12.0.2+10" + "version": "14.0.2+12" }, "pycharm": { "stem_cell": { diff --git a/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngine.java b/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngine.java index afe713f453c..bb28345b8a1 100644 --- a/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngine.java +++ b/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngine.java @@ -77,7 +77,7 @@ private Engine getAlternativeEngine() { /** {@inheritDoc} */ @Override public String getVersion() { - return "1.7.0"; + return "1.8.0"; } /** {@inheritDoc} */ diff --git a/paddlepaddle/paddlepaddle-native/build.gradle b/paddlepaddle/paddlepaddle-native/build.gradle index 66381e392da..74fb55b5511 100644 --- a/paddlepaddle/paddlepaddle-native/build.gradle +++ b/paddlepaddle/paddlepaddle-native/build.gradle @@ -285,7 +285,3 @@ publishing.repositories { } } } - -if (JavaVersion.current() == JavaVersion.VERSION_1_8) { - tasks.getByName("spotbugsMain").enabled = false -} diff --git a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java index 3e0ede826a5..246b9e57fee 100644 --- a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java +++ b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java @@ -1408,6 +1408,18 @@ public NDArray norm(int order, int[] axes, boolean keepDims) { return JniUtils.norm(this, order, axes, keepDims); } + /** {@inheritDoc} */ + @Override + public NDArray oneHot(int depth) { + return JniUtils.oneHot(this, depth, DataType.FLOAT32); + } + + /** {@inheritDoc} */ + @Override + public NDArray oneHot(int depth, DataType dataType) { + return JniUtils.oneHot(this, depth, dataType); + } + /** {@inheritDoc} */ @Override public NDArray oneHot(int depth, float onValue, float offValue, DataType dataType) { diff --git a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java index c98f4e60d75..64a8965f736 100644 --- a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java +++ b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java @@ -720,6 +720,14 @@ public static PtNDArray cumSum(PtNDArray ndArray, long dim) { ndArray.getManager(), PyTorchLibrary.LIB.torchCumSum(ndArray.getHandle(), dim)); } + public static PtNDArray oneHot(PtNDArray ndArray, int depth, DataType dataType) { + return new PtNDArray( + ndArray.getManager(), + PyTorchLibrary.LIB.torchNNOneHot( + ndArray.toType(DataType.INT64, false).getHandle(), depth)) + .toType(dataType, false); + } + public static NDList split(PtNDArray ndArray, long size, long axis) { long[] ndPtrs = PyTorchLibrary.LIB.torchSplit(ndArray.getHandle(), size, axis); NDList list = new NDList(); diff --git a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java index 2145ada3df9..0bb7dbafccf 100644 --- a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java +++ b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java @@ -448,6 +448,8 @@ native long torchNNMaxPool( native long torchNNLpPool( long inputHandle, double normType, long[] kernelSize, long[] stride, boolean ceilMode); + native long torchNNOneHot(long inputHandle, int depth); + native boolean torchRequiresGrad(long inputHandle); native String torchGradFnName(long inputHandle); diff --git a/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_inference.cc b/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_inference.cc index 643bd20e5f8..66f2e171ba4 100644 --- a/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_inference.cc +++ b/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_inference.cc @@ -99,8 +99,8 @@ JNIEXPORT void JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_moduleWrite( API_BEGIN() auto* module_ptr = reinterpret_cast(module_handle); #if defined(__ANDROID__) - env->ThrowNew(ENGINE_EXCEPTION_CLASS, "This kind of mode is not supported on Android"); - return; + env->ThrowNew(ENGINE_EXCEPTION_CLASS, "This kind of mode is not supported on Android"); + return; #endif std::ostringstream stream; module_ptr->save(stream); @@ -207,8 +207,8 @@ JNIEXPORT void JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_moduleSave( API_BEGIN() auto* module_ptr = reinterpret_cast(jhandle); #if defined(__ANDROID__) - env->ThrowNew(ENGINE_EXCEPTION_CLASS, "This kind of mode is not supported on Android"); - return; + env->ThrowNew(ENGINE_EXCEPTION_CLASS, "This kind of mode is not supported on Android"); + return; #endif module_ptr->save(djl::utils::jni::GetStringFromJString(env, jpath)); API_END() diff --git a/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_nn_functional.cc b/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_nn_functional.cc index ecf7bb2b98c..59afabcecf5 100644 --- a/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_nn_functional.cc +++ b/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_nn_functional.cc @@ -37,6 +37,15 @@ JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchLogSoftmax( API_END_RETURN() } +JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchNNOneHot( + JNIEnv* env, jobject jthis, jlong jhandle, jint jdepth) { + API_BEGIN() + const auto* tensor_ptr = reinterpret_cast(jhandle); + const auto* result_ptr = new torch::Tensor(torch::nn::functional::one_hot(*tensor_ptr, jdepth)); + return reinterpret_cast(result_ptr); + API_END_RETURN() +} + JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchNNInterpolate( JNIEnv* env, jobject jthis, jlong jhandle, jlongArray jsize, jint jmode, jboolean jalign_corners) { API_BEGIN() diff --git a/serving/serving/src/test/java/ai/djl/serving/ModelServerTest.java b/serving/serving/src/test/java/ai/djl/serving/ModelServerTest.java index 057720cb186..f663fb3d0b4 100644 --- a/serving/serving/src/test/java/ai/djl/serving/ModelServerTest.java +++ b/serving/serving/src/test/java/ai/djl/serving/ModelServerTest.java @@ -170,7 +170,7 @@ public void test() // plugin tests testStaticHtmlRequest(); - channel.close(); + channel.close().sync(); // negative test case that channel will be closed by server testInvalidUri(); @@ -437,7 +437,7 @@ private void testInvalidUri() throws InterruptedException { channel.writeAndFlush(req).sync(); latch.await(); channel.closeFuture().sync(); - channel.close(); + channel.close().sync(); if (!System.getProperty("os.name").startsWith("Win")) { ErrorResponse resp = JsonUtils.GSON.fromJson(result, ErrorResponse.class); @@ -457,7 +457,7 @@ private void testInvalidDescribeModel() throws InterruptedException { channel.writeAndFlush(req).sync(); latch.await(); channel.closeFuture().sync(); - channel.close(); + channel.close().sync(); if (!System.getProperty("os.name").startsWith("Win")) { ErrorResponse resp = JsonUtils.GSON.fromJson(result, ErrorResponse.class); @@ -476,7 +476,7 @@ private void testInvalidPredictionsUri() throws InterruptedException { channel.writeAndFlush(req).sync(); latch.await(); channel.closeFuture().sync(); - channel.close(); + channel.close().sync(); if (!System.getProperty("os.name").startsWith("Win")) { ErrorResponse resp = JsonUtils.GSON.fromJson(result, ErrorResponse.class); @@ -496,7 +496,7 @@ private void testPredictionsModelNotFound() throws InterruptedException { channel.writeAndFlush(req).sync(); latch.await(); channel.closeFuture().sync(); - channel.close(); + channel.close().sync(); if (!System.getProperty("os.name").startsWith("Win")) { ErrorResponse resp = JsonUtils.GSON.fromJson(result, ErrorResponse.class); @@ -515,7 +515,7 @@ private void testInvalidManagementUri() throws InterruptedException { channel.writeAndFlush(req).sync(); latch.await(); channel.closeFuture().sync(); - channel.close(); + channel.close().sync(); if (!System.getProperty("os.name").startsWith("Win")) { ErrorResponse resp = JsonUtils.GSON.fromJson(result, ErrorResponse.class); @@ -534,7 +534,7 @@ private void testInvalidManagementMethod() throws InterruptedException { channel.writeAndFlush(req).sync(); latch.await(); channel.closeFuture().sync(); - channel.close(); + channel.close().sync(); if (!System.getProperty("os.name").startsWith("Win")) { ErrorResponse resp = JsonUtils.GSON.fromJson(result, ErrorResponse.class); @@ -553,7 +553,7 @@ private void testInvalidPredictionsMethod() throws InterruptedException { channel.writeAndFlush(req).sync(); latch.await(); channel.closeFuture().sync(); - channel.close(); + channel.close().sync(); if (!System.getProperty("os.name").startsWith("Win")) { ErrorResponse resp = JsonUtils.GSON.fromJson(result, ErrorResponse.class); @@ -573,7 +573,7 @@ private void testDescribeModelNotFound() throws InterruptedException { channel.writeAndFlush(req).sync(); latch.await(); channel.closeFuture().sync(); - channel.close(); + channel.close().sync(); if (!System.getProperty("os.name").startsWith("Win")) { ErrorResponse resp = JsonUtils.GSON.fromJson(result, ErrorResponse.class); @@ -592,7 +592,7 @@ private void testRegisterModelMissingUrl() throws InterruptedException { channel.writeAndFlush(req).sync(); latch.await(); channel.closeFuture().sync(); - channel.close(); + channel.close().sync(); if (!System.getProperty("os.name").startsWith("Win")) { ErrorResponse resp = JsonUtils.GSON.fromJson(result, ErrorResponse.class); @@ -612,7 +612,7 @@ private void testRegisterModelNotFound() throws InterruptedException { channel.writeAndFlush(req).sync(); latch.await(); channel.closeFuture().sync(); - channel.close(); + channel.close().sync(); if (!System.getProperty("os.name").startsWith("Win")) { ErrorResponse resp = JsonUtils.GSON.fromJson(result, ErrorResponse.class); @@ -638,7 +638,7 @@ private void testRegisterModelConflict() channel.writeAndFlush(req); latch.await(); channel.closeFuture().sync(); - channel.close(); + channel.close().sync(); if (!System.getProperty("os.name").startsWith("Win")) { ErrorResponse resp = JsonUtils.GSON.fromJson(result, ErrorResponse.class); @@ -660,7 +660,7 @@ private void testInvalidScaleModel() throws InterruptedException { channel.writeAndFlush(req).sync(); latch.await(); channel.closeFuture().sync(); - channel.close(); + channel.close().sync(); if (!System.getProperty("os.name").startsWith("Win")) { ErrorResponse resp = JsonUtils.GSON.fromJson(result, ErrorResponse.class); @@ -679,7 +679,7 @@ private void testScaleModelNotFound() throws InterruptedException { channel.writeAndFlush(req).sync(); latch.await(); channel.closeFuture().sync(); - channel.close(); + channel.close().sync(); if (!System.getProperty("os.name").startsWith("Win")) { ErrorResponse resp = JsonUtils.GSON.fromJson(result, ErrorResponse.class); @@ -698,7 +698,7 @@ private void testUnregisterModelNotFound() throws InterruptedException { channel.writeAndFlush(req).sync(); latch.await(); channel.closeFuture().sync(); - channel.close(); + channel.close().sync(); if (!System.getProperty("os.name").startsWith("Win")) { ErrorResponse resp = JsonUtils.GSON.fromJson(result, ErrorResponse.class); @@ -730,7 +730,7 @@ private void testServiceUnavailable() throws InterruptedException { channel.writeAndFlush(req).sync(); latch.await(); channel.closeFuture().sync(); - channel.close(); + channel.close().sync(); if (!System.getProperty("os.name").startsWith("Win")) { ErrorResponse resp = JsonUtils.GSON.fromJson(result, ErrorResponse.class); diff --git a/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDManager.java b/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDManager.java index 71e731f47c9..a200673b8c8 100644 --- a/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDManager.java +++ b/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDManager.java @@ -242,6 +242,31 @@ public NDArray randomNormal(float loc, float scale, Shape shape, DataType dataTy } } + /** {@inheritDoc} */ + @Override + public NDArray truncatedNormal(float loc, float scale, Shape shape, DataType dataType) { + if (DataType.STRING.equals(dataType)) { + throw new IllegalArgumentException("String data type is not supported!"); + } + NDArray axes = create(shape.getShape()); + TfOpExecutor opBuilder = + opExecutor("TruncatedNormal").addInput(axes).addParam("dtype", dataType); + Integer seed = getEngine().getSeed(); + if (seed != null) { + // seed1 is graph-level seed + // set it to default graph seed used by tensorflow + // https://github.com/tensorflow/tensorflow/blob/85c8b2a817f95a3e979ecd1ed95bff1dc1335cff/tensorflow/python/framework/random_seed.py#L31 + opBuilder.addParam("seed", 87654321); + opBuilder.addParam("seed2", seed); + } + try (NDArray array = opBuilder.buildSingletonOrThrow(); + NDArray temp = array.mul(scale)) { + return temp.add(loc); + } finally { + axes.close(); + } + } + /** {@inheritDoc} */ @Override public TfNDManager newSubManager(Device device) { diff --git a/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfSymbolBlock.java b/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfSymbolBlock.java index d2c48b49133..e676f9aefc5 100644 --- a/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfSymbolBlock.java +++ b/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfSymbolBlock.java @@ -79,7 +79,7 @@ public TfSymbolBlock(SavedModelBundle bundle, String signatureDefKey) { + "not found in Saved Model Bundle." + "Available keys: " + String.join(" ", keys) - + "Please use .optOptions(\"SignatureDefKey\", \"value\") with Criteria.builder to load the model." + + "Please use .optOption(\"SignatureDefKey\", \"value\") with Criteria.builder to load the model." + "Normally the value is \"default\" for TF1.x models and \"serving_default\" for TF2.x models. " + "Refer to: https://www.tensorflow.org/guide/saved_model" + "Loading the model using next available key."); diff --git a/tensorflow/tensorflow-native/build.gradle b/tensorflow/tensorflow-native/build.gradle index aec7fec84db..97b6450381c 100644 --- a/tensorflow/tensorflow-native/build.gradle +++ b/tensorflow/tensorflow-native/build.gradle @@ -366,7 +366,3 @@ task downloadTensorflowNativeLib() { new File("${BINARY_ROOT}/auto").mkdirs() } } - -if (JavaVersion.current() == JavaVersion.VERSION_1_8) { - tasks.getByName("spotbugsMain").enabled = false -} diff --git a/tflite/tflite-native/build.gradle b/tflite/tflite-native/build.gradle index f890b8f879e..2bd992c0a20 100644 --- a/tflite/tflite-native/build.gradle +++ b/tflite/tflite-native/build.gradle @@ -179,7 +179,3 @@ flavorNames.each { flavor -> checkstyleMain.enabled = false pmdMain.enabled = false - -if (JavaVersion.current() == JavaVersion.VERSION_1_8) { - tasks.getByName("spotbugsMain").enabled = false -}