Skip to content

Commit

Permalink
Merge branch 'deepjavalibrary:master' into take_dev3
Browse files Browse the repository at this point in the history
  • Loading branch information
KexinFeng authored May 16, 2022
2 parents d6aab6b + e9a14c8 commit ef526ad
Show file tree
Hide file tree
Showing 16 changed files with 499 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ public abstract class ObjectDetectionTranslator extends BaseImageTranslator<Dete
protected List<String> classes;
protected double imageWidth;
protected double imageHeight;
protected boolean applyRatio;

/**
* Creates the {@link ObjectDetectionTranslator} from the given builder.
Expand All @@ -42,6 +43,7 @@ protected ObjectDetectionTranslator(ObjectDetectionBuilder<?> builder) {
this.synsetLoader = builder.synsetLoader;
this.imageWidth = builder.imageWidth;
this.imageHeight = builder.imageHeight;
this.applyRatio = builder.applyRatio;
}

/** {@inheritDoc} */
Expand All @@ -60,6 +62,7 @@ public abstract static class ObjectDetectionBuilder<T extends ObjectDetectionBui
protected float threshold = 0.2f;
protected double imageWidth;
protected double imageHeight;
protected boolean applyRatio;

/**
* Sets the threshold for prediction accuracy.
Expand Down Expand Up @@ -87,6 +90,23 @@ public T optRescaleSize(double imageWidth, double imageHeight) {
return self();
}

/**
* Determine Whether to divide output object width/height on the inference result. Default
* false.
*
* <p>DetectedObject value should always bring a ratio based on the width/height instead of
* actual width/height. Most of the model will produce ratio as the inference output. This
* function is aimed to cover those who produce the pixel value. Make this to true to divide
* the width/height in postprocessing in order to get ratio in detectedObjects.
*
* @param value whether to apply ratio
* @return this builder
*/
public T optApplyRatio(boolean value) {
this.applyRatio = value;
return self();
}

/**
* Get resized image width.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,17 @@ public DetectedObjects processOutput(TranslatorContext ctx, NDList list) {
double y = imageHeight > 0 ? box[1] / imageHeight : box[1];
double w = imageWidth > 0 ? box[2] / imageWidth - x : box[2] - x;
double h = imageHeight > 0 ? box[3] / imageHeight - y : box[3] - y;

Rectangle rect = new Rectangle(x, y, w, h);
Rectangle rect;
if (applyRatio) {
rect =
new Rectangle(
x / imageWidth,
y / imageHeight,
w / imageWidth,
h / imageHeight);
} else {
rect = new Rectangle(x, y, w, h);
}
retNames.add(className);
retProbs.add(probability);
retBB.add(rect);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,17 @@ public DetectedObjects processOutput(TranslatorContext ctx, NDList list) {
}
retClasses.add(classes.get(classIndices[i]));
retProbs.add(probs[i]);
Rectangle rect = new Rectangle(boxX[i], boxY[i], boxWidth[i], boxHeight[i]);
Rectangle rect;
if (applyRatio) {
rect =
new Rectangle(
boxX[i] / imageWidth,
boxY[i] / imageHeight,
boxWidth[i] / imageWidth,
boxHeight[i] / imageHeight);
} else {
rect = new Rectangle(boxX[i], boxY[i], boxWidth[i], boxHeight[i]);
}
retBB.add(rect);
}
return new DetectedObjects(retClasses, retProbs, retBB);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,17 @@ protected DetectedObjects nms(List<IntermediateResult> list) {
Rectangle rec = detections[0].getLocation();
retClasses.add(detections[0].id);
retProbs.add(detections[0].confidence);
retBB.add(new Rectangle(rec.getX(), rec.getY(), rec.getWidth(), rec.getHeight()));
if (applyRatio) {
retBB.add(
new Rectangle(
rec.getX() / imageWidth,
rec.getY() / imageHeight,
rec.getWidth() / imageWidth,
rec.getHeight() / imageHeight));
} else {
retBB.add(
new Rectangle(rec.getX(), rec.getY(), rec.getWidth(), rec.getHeight()));
}
pq.clear();
for (int j = 1; j < detections.length; j++) {
IntermediateResult detection = detections[j];
Expand Down
12 changes: 11 additions & 1 deletion api/src/main/java/ai/djl/ndarray/NDList.java
Original file line number Diff line number Diff line change
Expand Up @@ -292,8 +292,18 @@ public void detach() {
* @return the byte array
*/
public byte[] encode() {
return encode(false);
}

/**
* Encodes the NDList to byte array.
*
* @param numpy encode in npz format if true
* @return the byte array
*/
public byte[] encode(boolean numpy) {
try (ByteArrayOutputStream baos = new ByteArrayOutputStream()) {
encode(baos);
encode(baos, numpy);
return baos.toByteArray();
} catch (IOException e) {
throw new AssertionError("NDList is not writable", e);
Expand Down
15 changes: 12 additions & 3 deletions api/src/main/java/ai/djl/translate/ServingTranslatorFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,7 @@ public Batchifier getBatchifier() {
public NDList processInput(TranslatorContext ctx, Input input) throws TranslateException {
NDManager manager = ctx.getNDManager();
try {
ctx.setAttachment("properties", input.getProperties());
return input.getDataAsNDList(manager);
} catch (IllegalArgumentException e) {
throw new TranslateException("Input is not a NDList data type", e);
Expand All @@ -259,11 +260,19 @@ public NDList processInput(TranslatorContext ctx, Input input) throws TranslateE

/** {@inheritDoc} */
@Override
@SuppressWarnings("unchecked")
public Output processOutput(TranslatorContext ctx, NDList list) {
Map<String, String> prop = (Map<String, String>) ctx.getAttachment("properties");
String contentType = prop.get("Content-Type");

Output output = new Output();
// TODO: find a way to pass NDList out
output.add(list.getAsBytes());
output.addProperty("Content-Type", "tensor/ndlist");
if ("tensor/npz".equalsIgnoreCase(contentType)) {
output.add(list.encode(true));
output.addProperty("Content-Type", "tensor/npz");
} else {
output.add(list.encode(false));
output.addProperty("Content-Type", "tensor/ndlist");
}
return output;
}
}
Expand Down
27 changes: 27 additions & 0 deletions api/src/test/java/ai/djl/translate/BatchifierTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
/*
* Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.translate;

import org.testng.Assert;
import org.testng.annotations.Test;

public class BatchifierTest {

@Test
public void testBatchifier() {
Batchifier batchifier = Batchifier.fromString("stack");
Assert.assertEquals(batchifier, Batchifier.STACK);

Assert.assertThrows(() -> Batchifier.fromString("invalid"));
}
}
84 changes: 84 additions & 0 deletions api/src/test/java/ai/djl/translate/ServingTranslatorTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
/*
* Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.translate;

import ai.djl.Model;
import ai.djl.ModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.Input;
import ai.djl.modality.Output;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
import ai.djl.nn.Blocks;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.util.Utils;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import org.testng.Assert;
import org.testng.annotations.AfterClass;
import org.testng.annotations.Test;

public class ServingTranslatorTest {

@AfterClass
public void tierDown() {
Utils.deleteQuietly(Paths.get("build/model"));
}

@Test
public void testNumpy() throws IOException, TranslateException, ModelException {
Path path = Paths.get("build/model");
Files.createDirectories(path);
Input input = new Input();

try (NDManager manager = NDManager.newBaseManager()) {
Block block = Blocks.identityBlock();
block.initialize(manager, DataType.FLOAT32, new Shape(1));
Model model = Model.newInstance("identity");
model.setBlock(block);
model.save(path, null);
model.close();
NDList list = new NDList();
list.add(manager.create(10f));
input.add(list.encode(true));
input.add("Content-Type", "tensor/npz");
}

Criteria<Input, Output> criteria =
Criteria.builder()
.setTypes(Input.class, Output.class)
.optModelPath(path)
.optModelName("identity")
.optBlock(Blocks.identityBlock())
.build();

try (ZooModel<Input, Output> model = criteria.loadModel();
Predictor<Input, Output> predictor = model.newPredictor()) {
Output output = predictor.predict(input);
try (NDManager manager = NDManager.newBaseManager()) {
NDList list = output.getDataAsNDList(manager);
Assert.assertEquals(list.size(), 1);
Assert.assertEquals(list.get(0).toFloatArray()[0], 10f);
}
Input invalid = new Input();
invalid.add("String");
Assert.assertThrows(TranslateException.class, () -> predictor.predict(invalid));
}
}
}
15 changes: 15 additions & 0 deletions api/src/test/java/ai/djl/translate/package-info.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
/*
* Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/

/** Contains tests for {@link ai.djl.translate}. */
package ai.djl.translate;
Loading

0 comments on commit ef526ad

Please sign in to comment.