Skip to content

Commit

Permalink
Squashed commit of the following:
Browse files Browse the repository at this point in the history
commit 0092f8e
Author: Aziz Zayed <azayed01@gmail.com>
Date:   Tue Jun 15 08:22:51 2021 -0700

    Fixed truncated-normal bug

commit a6ded8c
Author: Aziz Zayed <azayed01@gmail.com>
Date:   Mon Jun 14 13:33:30 2021 -0700

    [pytorch] Add BigGAN demo

commit f145614
Merge: a8a1a9b ec8405b
Author: Abd-El-Aziz Zayed <48853777+AzizZayed@users.noreply.github.com>
Date:   Fri Jun 11 20:45:34 2021 -0700

    Merge branch 'deepjavalibrary:master' into master

commit ec8405b
Author: Abd-El-Aziz Zayed <48853777+AzizZayed@users.noreply.github.com>
Date:   Fri Jun 11 14:53:59 2021 -0700

    [pytorch] Add oneHot operator (deepjavalibrary#1014)

    [tensoflow] Add truncated normal operation

commit 50600fd
Author: Frank Liu <frankfliu2000@gmail.com>
Date:   Fri Jun 11 14:53:43 2021 -0700

    upgrade dependencies version (deepjavalibrary#1012)

    Change-Id: I709938f69f21096bc5cd29a24191f0f282dcbc97

commit 3379fd2
Author: Frank Liu <frankfliu2000@gmail.com>
Date:   Fri Jun 11 14:53:29 2021 -0700

    [serving] Fix flaky test (deepjavalibrary#1013)

    Change-Id: I13b89e04516c59a3d28ecafd49f4f808630b22fb

commit 23157fd
Author: Frank Liu <frankfliu2000@gmail.com>
Date:   Thu Jun 10 16:31:03 2021 -0700

    Enable spotbugs for java 11+ (deepjavalibrary#1010)

    Change-Id: I74effbf45492a5cf50e09ba8af0223d2b1bcb5a5

commit 4f38708
Author: Frank Liu <frankfliu2000@gmail.com>
Date:   Thu Jun 10 16:30:50 2021 -0700

    Fix model zoo test typo (deepjavalibrary#1009)

    Change-Id: I7c0109c6e5fc0ece16288082fd830718f20ad489

commit a8a1a9b
Merge: 77809f4 30b03f4
Author: Aziz Zayed <azayed01@gmail.com>
Date:   Thu Jun 10 15:16:05 2021 -0700

    Merge Truncated-Normal branch

commit 77809f4
Author: Frank Liu <frankfliu2000@gmail.com>
Date:   Thu Jun 10 14:07:43 2021 -0700

    Make model zoo test weekly (deepjavalibrary#1004)

    Change-Id: I1c73df17cb077b9ce8905fcc2fc8bbb37b9688d8

commit 0aec8ca
Author: Abd-El-Aziz Zayed <48853777+AzizZayed@users.noreply.github.com>
Date:   Thu Jun 10 12:46:16 2021 -0700

    [tensoflow] Add truncated normal operation (deepjavalibrary#1005)

commit 30b03f4
Author: Aziz Zayed <azayed01@gmail.com>
Date:   Wed Jun 9 01:40:33 2021 -0700

    [tensoflow] Add truncated normal operation

commit d8e7e1d
Author: Frank Liu <frankfliu2000@gmail.com>
Date:   Wed Jun 9 07:55:15 2021 -0700

    Fixes deepjavalibrary#999, hanlde UTF16 surrogate charactors properly. (deepjavalibrary#1003)

    Change-Id: I19e77cf5a8282bea901434041806eb102549ec0f

commit b0fe73a
Author: Frank Liu <frankfliu2000@gmail.com>
Date:   Tue Jun 8 17:56:19 2021 -0700

    [pytorch] Update load model jupyter notebook (deepjavalibrary#1002)

    Change-Id: I1889aa93d2002e6ce02c740d2d1d3517bf586760

commit 8286930
Author: Frank Liu <frankfliu2000@gmail.com>
Date:   Tue Jun 8 15:29:27 2021 -0700

    [tensorflow] fix optOption usage document (deepjavalibrary#1001)

    Change-Id: Ie044839cf082d63010a5c26d3f2f8833447919c6

commit a26f5b2
Author: Abd-El-Aziz Zayed <48853777+AzizZayed@users.noreply.github.com>
Date:   Tue Jun 8 15:29:10 2021 -0700

    Updated PyTorch Docs  (deepjavalibrary#1000)

    * Added auto softmax metadata for action_recognition

    * Update PyTorch Docs

commit e6890f9
Author: Lanking <qingla@amazon.com>
Date:   Mon Jun 7 18:25:19 2021 -0700

    upgrade xgboost (deepjavalibrary#993)

commit a0dcf3a
Author: Lanking <qingla@amazon.com>
Date:   Mon Jun 7 18:25:12 2021 -0700

    bump up onnx runtime version (deepjavalibrary#992)
  • Loading branch information
AzizZayed committed Jun 15, 2021
1 parent 2fc24f2 commit 697655b
Show file tree
Hide file tree
Showing 29 changed files with 1,657 additions and 69 deletions.
19 changes: 19 additions & 0 deletions api/src/main/java/ai/djl/ndarray/BaseNDManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.util.PairList;
import ai.djl.util.RandomUtils;
import java.nio.Buffer;
import java.nio.file.Path;
import java.util.UUID;
Expand Down Expand Up @@ -153,6 +154,24 @@ public NDArray randomNormal(float loc, float scale, Shape shape, DataType dataTy
throw new UnsupportedOperationException("Not supported!");
}

/** {@inheritDoc} */
@Override
public NDArray truncatedNormal(float loc, float scale, Shape shape, DataType dataType) {
int sampleSize = (int) shape.size();
double[] dist = new double[sampleSize];

for (int i = 0; i < sampleSize; i++) {
double sample = RandomUtils.nextGaussian();
while (sample < -2 || sample > 2) {
sample = RandomUtils.nextGaussian();
}

dist[i] = sample;
}

return create(dist).muli(scale).addi(loc).reshape(shape).toType(dataType, false);
}

/** {@inheritDoc} */
@Override
public NDArray randomMultinomial(int n, NDArray pValues) {
Expand Down
51 changes: 51 additions & 0 deletions api/src/main/java/ai/djl/ndarray/NDArray.java
Original file line number Diff line number Diff line change
Expand Up @@ -4592,6 +4592,57 @@ default NDArray oneHot(int depth) {
return oneHot(depth, 1f, 0f, DataType.FLOAT32);
}

/**
* Returns a one-hot {@code NDArray}.
*
* <ul>
* <li>The locations represented by indices take value 1, while all other locations take value
* 0.
* <li>If the input {@code NDArray} is rank N, the output will have rank N+1. The new axis is
* appended at the end.
* <li>If {@code NDArray} is a scalar the output shape will be a vector of length depth.
* <li>If {@code NDArray} is a vector of length features, the output shape will be features x
* depth.
* <li>If {@code NDArray} is a matrix with shape [batch, features], the output shape will be
* batch x features x depth.
* </ul>
*
* <p>Examples
*
* <pre>
* jshell&gt; NDArray array = manager.create(new int[] {1, 0, 2, 0});
* jshell&gt; array.oneHot(3);
* ND: (4, 3) cpu() float32
* [[0., 1., 0.],
* [1., 0., 0.],
* [0., 0., 1.],
* [1., 0., 0.],
* ]
* jshell&gt; NDArray array = manager.create(new int[][] {{1, 0}, {1, 0}, {2, 0}});
* jshell&gt; array.oneHot(3);
* ND: (3, 2, 3) cpu() float32
* [[[0., 1., 0.],
* [1., 0., 0.],
* ],
* [[0., 1., 0.],
* [1., 0., 0.],
* ],
* [[0., 0., 1.],
* [1., 0., 0.],
* ],
* ]
* </pre>
*
* @param depth Depth of the one hot dimension.
* @param dataType dataType of the output.
* @return one-hot encoding of this {@code NDArray}
* @see <a
* href=https://d2l.djl.ai/chapter_linear-networks/softmax-regression.html#classification-problems>Classification-problems</a>
*/
default NDArray oneHot(int depth, DataType dataType) {
return oneHot(depth, 0f, 1f, dataType);
}

/**
* Returns a one-hot {@code NDArray}.
*
Expand Down
59 changes: 59 additions & 0 deletions api/src/main/java/ai/djl/ndarray/NDManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -1232,6 +1232,65 @@ default NDArray randomNormal(
return newSubManager(device).randomNormal(loc, scale, shape, dataType);
}

/**
* Draws random samples from a normal (Gaussian) distribution with mean 0 and standard deviation
* 1, discarding and re-drawing any samples that are more than two standard deviations from the
* mean.
*
* <p>Samples are distributed according to a normal distribution parametrized by mean = 0 and
* standard deviation = 1.
*
* @param shape the output {@link Shape}
* @return the drawn samples {@link NDArray}
*/
default NDArray truncatedNormal(Shape shape) {
return truncatedNormal(0f, 1f, shape, DataType.FLOAT32);
}

/**
* Draws random samples from a normal (Gaussian) distribution with mean 0 and standard deviation
* 1, discarding and re-drawing any samples that are more than two standard deviations from the
* mean.
*
* @param shape the output {@link Shape}
* @param dataType the {@link DataType} of the {@link NDArray}
* @return the drawn samples {@link NDArray}
*/
default NDArray truncatedNormal(Shape shape, DataType dataType) {
return truncatedNormal(0.0f, 1.0f, shape, dataType);
}

/**
* Draws random samples from a normal (Gaussian) distribution, discarding and re-drawing any
* samples that are more than two standard deviations from the mean.
*
* @param loc the mean (centre) of the distribution
* @param scale the standard deviation (spread or "width") of the distribution
* @param shape the output {@link Shape}
* @param dataType the {@link DataType} of the {@link NDArray}
* @return the drawn samples {@link NDArray}
*/
NDArray truncatedNormal(float loc, float scale, Shape shape, DataType dataType);

/**
* Draws random samples from a normal (Gaussian) distribution, discarding and re-drawing any
* samples that are more than two standard deviations from the mean.
*
* @param loc the mean (centre) of the distribution
* @param scale the standard deviation (spread or "width") of the distribution
* @param shape the output {@link Shape}
* @param dataType the {@link DataType} of the {@link NDArray}
* @param device the {@link Device} of the {@link NDArray}
* @return the drawn samples {@link NDArray}
*/
default NDArray truncatedNormal(
float loc, float scale, Shape shape, DataType dataType, Device device) {
if (device == null || device.equals(getDevice())) {
return truncatedNormal(loc, scale, shape, dataType);
}
return newSubManager(device).truncatedNormal(loc, scale, shape, dataType);
}

/**
* Draw samples from a multinomial distribution.
*
Expand Down
34 changes: 30 additions & 4 deletions api/src/main/native/djl/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,21 @@ inline std::string GetStringFromJString(JNIEnv* env, jstring jstr) {
if (jstr == nullptr) {
return std::string();
}
const char* c_str = env->GetStringUTFChars(jstr, JNI_FALSE);
std::string str = std::string(c_str);
env->ReleaseStringUTFChars(jstr, c_str);

// TODO: cache reflection to improve performance
const jclass string_class = env->GetObjectClass(jstr);
const jmethodID getbytes_method = env->GetMethodID(string_class, "getBytes", "(Ljava/lang/String;)[B");

const jstring charset = env->NewStringUTF("UTF-8");
const jbyteArray jbytes = (jbyteArray) env->CallObjectMethod(jstr, getbytes_method, charset);
env->DeleteLocalRef(charset);

const jsize length = env->GetArrayLength(jbytes);
jbyte* c_str = env->GetByteArrayElements(jbytes, NULL);
std::string str = std::string(reinterpret_cast<const char *>(c_str), length);

env->ReleaseByteArrayElements(jbytes, c_str, RELEASE_MODE);
env->DeleteLocalRef(jbytes);
return str;
}

Expand Down Expand Up @@ -100,9 +112,23 @@ inline std::vector<std::string> GetVecFromJStringArray(JNIEnv* env, jobjectArray
// String[]
inline jobjectArray GetStringArrayFromVec(JNIEnv* env, const std::vector <std::string> &vec) {
jobjectArray array = env->NewObjectArray(vec.size(), env->FindClass("Ljava/lang/String;"), nullptr);

// TODO: cache reflection to improve performance
const jclass string_class = env->FindClass("java/lang/String");
const jmethodID ctor = env->GetMethodID(string_class, "<init>", "([BLjava/lang/String;)V");
const jstring charset = env->NewStringUTF("UTF-8");

for (int i = 0; i < vec.size(); ++i) {
env->SetObjectArrayElement(array, i, env->NewStringUTF(vec[i].c_str()));
const char* c_str = vec[i].c_str();
int len = vec[i].length();
auto jbytes = env->NewByteArray(len);
env->SetByteArrayRegion(jbytes, 0, len, reinterpret_cast<const jbyte*>(c_str));
jobject jstr = env->NewObject(string_class, ctor, jbytes, charset);
env->DeleteLocalRef(jbytes);
env->SetObjectArrayElement(array, i, jstr);
}

env->DeleteLocalRef(charset);
return array;
}

Expand Down
2 changes: 1 addition & 1 deletion docs/tensorflow/how_to_import_tensorflow_models_in_DJL.md
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ Criteria<Image, DetectedObjects> criteria =
.setTypes(Image.class, DetectedObjects.class)
.optFilter("backbone", "mobilenet_v2")
.optEngine("TensorFlow")
.optOption("Tags", new String[] {})
.optOption("Tags", "")
.optOption("SignatureDefKey", "default")
.optProgress(new ProgressBar())
.build();
Expand Down
2 changes: 1 addition & 1 deletion examples/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ dependencies {
}

application {
mainClassName = System.getProperty("main", "ai.djl.examples.inference.ObjectDetection")
mainClassName = System.getProperty("main", "ai.djl.examples.inference.biggan.Generator")
}

run {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
/*
* Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.examples.inference.biggan;

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public final class BigGANCategory {
private static final Logger logger = LoggerFactory.getLogger(BigGANCategory.class);

public static final int NUMBER_OF_CATEGORIES = 1000;
private static final Map<String, BigGANCategory> CATEGORIES_BY_NAME =
new ConcurrentHashMap<>(NUMBER_OF_CATEGORIES);
private static String[] categoriesById;

private int id;
private String[] names;

static {
try {
parseCategories();
} catch (IOException e) {
logger.error("Error parsing the ImageNet categories: {}", e);
}
createCategoriesByName();
}

private BigGANCategory(int id, String[] names) {
this.id = id;
this.names = names;
}

public int getId() {
return id;
}

public String[] getNames() {
return names.clone();
}

public static BigGANCategory id(int id) {
String names = categoriesById[id];
int index = names.indexOf(',');
if (index < 0) {
return of(names);
} else {
return of(names.substring(0, index));
}
}

public static BigGANCategory of(String name) {
if (!CATEGORIES_BY_NAME.containsKey(name)) {
throw new IllegalArgumentException(name + " is not a valid category.");
}
return CATEGORIES_BY_NAME.get(name);
}

private static void createCategoriesByName() {
for (int i = 0; i < NUMBER_OF_CATEGORIES; i++) {
String[] categoryNames = categoriesById[i].split(", ");
BigGANCategory category = new BigGANCategory(i, categoryNames);

for (String name : categoryNames) {
CATEGORIES_BY_NAME.put(name, category);
}
}
}

private static void parseCategories() throws IOException {
String filePath = "src/main/resources/categories.txt";

List<String> fileLines = Files.readAllLines(Paths.get(filePath));
List<String> categories = new ArrayList<>(NUMBER_OF_CATEGORIES);
for (String line : fileLines) {
int nameIndex = line.indexOf(':') + 2;
categories.add(line.substring(nameIndex));
}

categoriesById = categories.toArray(new String[] {});
}
}
Loading

0 comments on commit 697655b

Please sign in to comment.