diff --git a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtModel.java b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtModel.java index c3d856a83bc..dfd9caa1fbc 100644 --- a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtModel.java +++ b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtModel.java @@ -180,6 +180,12 @@ private SessionOptions getSessionOptions(Map options) throws OrtExcep if (Boolean.parseBoolean(memoryOptimization)) { ortSession.setMemoryPatternOptimization(true); } + + String cpuArena = (String) options.get("cpuArenaAllocator"); + if (Boolean.parseBoolean(cpuArena)) { + ortSession.setCPUArenaAllocator(true); + } + return ortSession; } } diff --git a/engines/onnxruntime/onnxruntime-engine/src/test/java/ai/djl/onnxruntime/engine/OrtTest.java b/engines/onnxruntime/onnxruntime-engine/src/test/java/ai/djl/onnxruntime/engine/OrtTest.java index 94fab66b015..6d6c709f03c 100644 --- a/engines/onnxruntime/onnxruntime-engine/src/test/java/ai/djl/onnxruntime/engine/OrtTest.java +++ b/engines/onnxruntime/onnxruntime-engine/src/test/java/ai/djl/onnxruntime/engine/OrtTest.java @@ -48,6 +48,7 @@ public void testOrt() throws TranslateException, ModelException, IOException { .optOption("executionMode", "SEQUENTIAL") .optOption("optLevel", "NO_OPT") .optOption("memoryPatternOptimization", "true") + .optOption("cpuArenaAllocator", "true") .build(); IrisFlower virginica = new IrisFlower(1.0f, 2.0f, 3.0f, 4.0f);