From c55407a08659dac7bf51d5651b35799cc9c92cfd Mon Sep 17 00:00:00 2001 From: Qing Lan Date: Thu, 28 Apr 2022 14:40:06 -0700 Subject: [PATCH 1/2] add tensorRT option --- .../onnxruntime/onnxruntime-engine/README.md | 8 ++++++ .../ai/djl/onnxruntime/engine/OrtModel.java | 27 +++++++++++++------ 2 files changed, 27 insertions(+), 8 deletions(-) diff --git a/engines/onnxruntime/onnxruntime-engine/README.md b/engines/onnxruntime/onnxruntime-engine/README.md index 87e690fb674..900f987c912 100644 --- a/engines/onnxruntime/onnxruntime-engine/README.md +++ b/engines/onnxruntime/onnxruntime-engine/README.md @@ -85,3 +85,11 @@ Gradle: } implementation "com.microsoft.onnxruntime:onnxruntime_gpu:1.11.0" ``` + +#### Enable TensorRT execution + +ONNXRuntime offers TensorRT as an execution backend. In DJL, users can enable it by specifying the following option in the Criteria: + +``` +optOption("ortDevice", "TensorRT") +``` diff --git a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtModel.java b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtModel.java index dfd9caa1fbc..7f744f3aef8 100644 --- a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtModel.java +++ b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtModel.java @@ -77,10 +77,6 @@ public void load(Path modelPath, String prefix, Map options) try { SessionOptions ortOptions = getSessionOptions(options); - Device device = manager.getDevice(); - if (device.isGpu()) { - ortOptions.addCUDA(manager.getDevice().getDeviceId()); - } OrtSession session = env.createSession(modelFile.toString(), ortOptions); block = new OrtSymbolBlock(session, (OrtNDManager) manager); } catch (OrtException e) { @@ -100,10 +96,6 @@ public void load(InputStream is, Map options) try { byte[] buf = Utils.toByteArray(is); SessionOptions ortOptions = getSessionOptions(options); - Device device = manager.getDevice(); - if (device.isGpu()) 
{ - ortOptions.addCUDA(manager.getDevice().getDeviceId()); - } OrtSession session = env.createSession(buf, ortOptions); block = new OrtSymbolBlock(session, (OrtNDManager) manager); } catch (OrtException e) { @@ -186,6 +178,25 @@ private SessionOptions getSessionOptions(Map options) throws OrtExcep ortSession.setCPUArenaAllocator(true); } + Device device = manager.getDevice(); + if (options.containsKey("ortDevice")) { + String ortDevice = (String) options.get("ortDevice"); + switch (ortDevice) { + case "TensorRT": + ortSession.addTensorrt(manager.getDevice().getDeviceId()); + break; + case "ROCM": + ortSession.addROCM(); + break; + case "CoreML": + ortSession.addCoreML(); + break; + default: + throw new UnsupportedOperationException(ortDevice + " not supported by DJL"); + } + } else if (device.isGpu()) { + ortSession.addCUDA(manager.getDevice().getDeviceId()); + } return ortSession; } } From ed623ef5a77e32161370bb40dab62673eacaae79 Mon Sep 17 00:00:00 2001 From: Frank Liu Date: Thu, 28 Apr 2022 22:43:07 -0700 Subject: [PATCH 2/2] Validate GPU device for TensorRT on OnnxRuntime Change-Id: I5812a0621b3d5019e3cf7a873e11641cd7dd5658 --- .../main/java/ai/djl/onnxruntime/engine/OrtModel.java | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtModel.java b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtModel.java index 7f744f3aef8..7ac6e0ce85b 100644 --- a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtModel.java +++ b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtModel.java @@ -183,7 +183,10 @@ private SessionOptions getSessionOptions(Map options) throws OrtExcep String ortDevice = (String) options.get("ortDevice"); switch (ortDevice) { case "TensorRT": - ortSession.addTensorrt(manager.getDevice().getDeviceId()); + if (!device.isGpu()) { + throw new 
IllegalArgumentException("TensorRT required GPU device."); + } + ortSession.addTensorrt(device.getDeviceId()); break; case "ROCM": ortSession.addROCM(); @@ -192,10 +195,10 @@ private SessionOptions getSessionOptions(Map options) throws OrtExcep ortSession.addCoreML(); break; default: - throw new UnsupportedOperationException(ortDevice + " not supported by DJL"); + throw new IllegalArgumentException("Invalid ortDevice: " + ortDevice); } } else if (device.isGpu()) { - ortSession.addCUDA(manager.getDevice().getDeviceId()); + ortSession.addCUDA(device.getDeviceId()); } return ortSession; }