diff --git a/engines/onnxruntime/onnxruntime-engine/README.md b/engines/onnxruntime/onnxruntime-engine/README.md index 87e690fb674..900f987c912 100644 --- a/engines/onnxruntime/onnxruntime-engine/README.md +++ b/engines/onnxruntime/onnxruntime-engine/README.md @@ -85,3 +85,11 @@ Gradle: } implementation "com.microsoft.onnxruntime:onnxruntime_gpu:1.11.0" ``` + +#### Enable TensorRT execution + +ONNXRuntime offers TensorRT execution as the backend. In DJL, user can specify the followings in the Criteria to enable: + +``` +optOption("ortDevice", "TensorRT") +``` diff --git a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtModel.java b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtModel.java index dfd9caa1fbc..7ac6e0ce85b 100644 --- a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtModel.java +++ b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtModel.java @@ -77,10 +77,6 @@ public void load(Path modelPath, String prefix, Map options) try { SessionOptions ortOptions = getSessionOptions(options); - Device device = manager.getDevice(); - if (device.isGpu()) { - ortOptions.addCUDA(manager.getDevice().getDeviceId()); - } OrtSession session = env.createSession(modelFile.toString(), ortOptions); block = new OrtSymbolBlock(session, (OrtNDManager) manager); } catch (OrtException e) { @@ -100,10 +96,6 @@ public void load(InputStream is, Map options) try { byte[] buf = Utils.toByteArray(is); SessionOptions ortOptions = getSessionOptions(options); - Device device = manager.getDevice(); - if (device.isGpu()) { - ortOptions.addCUDA(manager.getDevice().getDeviceId()); - } OrtSession session = env.createSession(buf, ortOptions); block = new OrtSymbolBlock(session, (OrtNDManager) manager); } catch (OrtException e) { @@ -186,6 +178,28 @@ private SessionOptions getSessionOptions(Map options) throws OrtExcep ortSession.setCPUArenaAllocator(true); } + Device device = manager.getDevice(); + if (options.containsKey("ortDevice")) { + String ortDevice = (String) options.get("ortDevice"); + switch (ortDevice) { + case "TensorRT": + if (!device.isGpu()) { + throw new IllegalArgumentException("TensorRT required GPU device."); + } + ortSession.addTensorrt(device.getDeviceId()); + break; + case "ROCM": + ortSession.addROCM(); + break; + case "CoreML": + ortSession.addCoreML(); + break; + default: + throw new IllegalArgumentException("Invalid ortDevice: " + ortDevice); + } + } else if (device.isGpu()) { + ortSession.addCUDA(device.getDeviceId()); + } return ortSession; } }