diff --git a/api/src/main/java/ai/djl/ndarray/BaseNDManager.java b/api/src/main/java/ai/djl/ndarray/BaseNDManager.java index 018852a41a6..9bdc4759199 100644 --- a/api/src/main/java/ai/djl/ndarray/BaseNDManager.java +++ b/api/src/main/java/ai/djl/ndarray/BaseNDManager.java @@ -40,6 +40,7 @@ public abstract class BaseNDManager implements NDManager { private static final Logger logger = LoggerFactory.getLogger(BaseNDManager.class); protected NDManager parent; + protected NDManager alternativeManager; protected String uid; protected String name; protected Device device; @@ -53,6 +54,10 @@ protected BaseNDManager(NDManager parent, Device device) { resources = new ConcurrentHashMap<>(); tempResources = new ConcurrentHashMap<>(); uid = UUID.randomUUID().toString(); + Engine engine = getEngine().getAlternativeEngine(); + if (engine != null) { + alternativeManager = engine.newBaseManager(Device.cpu()); + } } /** {@inheritDoc} */ @@ -368,12 +373,8 @@ public void debugDump(int level) { } } - protected NDManager getAlternativeManager() { - Engine engine = getEngine().getAlternativeEngine(); - if (engine != null) { - return engine.newBaseManager(Device.cpu()); - } - return null; + NDManager getAlternativeManager() { + return alternativeManager; } /** diff --git a/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java b/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java index b82ece0b51b..1a30edb19cf 100644 --- a/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java +++ b/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java @@ -71,6 +71,10 @@ public void attach(NDManager manager) { detach(); this.manager = manager; manager.attachInternal(getUid(), this); + alternativeManager = ((BaseNDManager) manager).getAlternativeManager(); + if (alternativeManager == null) { + alternativeManager = manager; + } } /** {@inheritDoc} */ diff --git a/engines/dlr/dlr-engine/src/main/java/ai/djl/dlr/engine/DlrNDManager.java b/engines/dlr/dlr-engine/src/main/java/ai/djl/dlr/engine/DlrNDManager.java index 86eff274d7f..6939273cbb2 100644 --- a/engines/dlr/dlr-engine/src/main/java/ai/djl/dlr/engine/DlrNDManager.java +++ b/engines/dlr/dlr-engine/src/main/java/ai/djl/dlr/engine/DlrNDManager.java @@ -29,11 +29,8 @@ public class DlrNDManager extends BaseNDManager { private static final DlrNDManager SYSTEM_MANAGER = new SystemManager(); - private NDManager alternativeManager; - private DlrNDManager(NDManager parent, Device device) { super(parent, device); - alternativeManager = getAlternativeManager(); } static DlrNDManager getSystemManager() { diff --git a/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbNDManager.java b/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbNDManager.java index fc984013135..c6443e2b086 100644 --- a/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbNDManager.java +++ b/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbNDManager.java @@ -32,11 +32,8 @@ public class XgbNDManager extends BaseNDManager { private static final XgbNDManager SYSTEM_MANAGER = new SystemManager(); - private NDManager alternativeManager; - private XgbNDManager(NDManager parent, Device device) { super(parent, device); - alternativeManager = getAlternativeManager(); } static XgbNDManager getSystemManager() { diff --git a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtNDManager.java b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtNDManager.java index d9b7f31d252..00def7a4a9e 100644 --- a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtNDManager.java +++ b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtNDManager.java @@ -34,12 +34,10 @@ public class OrtNDManager extends BaseNDManager { private static final OrtNDManager SYSTEM_MANAGER = new SystemManager(); private OrtEnvironment env; - private NDManager alternativeManager; private OrtNDManager(NDManager parent, Device device, OrtEnvironment env) { super(parent, device); this.env = env; - alternativeManager = getAlternativeManager(); } static OrtNDManager getSystemManager() { diff --git a/engines/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpNDManager.java b/engines/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpNDManager.java index 445eadcf4ed..7e81ba0f698 100644 --- a/engines/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpNDManager.java +++ b/engines/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpNDManager.java @@ -29,11 +29,8 @@ public class PpNDManager extends BaseNDManager { private static final PpNDManager SYSTEM_MANAGER = new SystemManager(); - private NDManager alternativeManager; - private PpNDManager(NDManager parent, Device device) { super(parent, device); - alternativeManager = getAlternativeManager(); } static PpNDManager getSystemManager() { diff --git a/engines/tensorrt/src/main/java/ai/djl/tensorrt/engine/TrtNDManager.java b/engines/tensorrt/src/main/java/ai/djl/tensorrt/engine/TrtNDManager.java index 3fb4c5d70db..21d5ea13cad 100644 --- a/engines/tensorrt/src/main/java/ai/djl/tensorrt/engine/TrtNDManager.java +++ b/engines/tensorrt/src/main/java/ai/djl/tensorrt/engine/TrtNDManager.java @@ -29,11 +29,8 @@ public class TrtNDManager extends BaseNDManager { private static final TrtNDManager SYSTEM_MANAGER = new SystemManager(); - private NDManager alternativeManager; - private TrtNDManager(NDManager parent, Device device) { super(parent, device); - alternativeManager = getAlternativeManager(); } static TrtNDManager getSystemManager() { diff --git a/engines/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteNDManager.java b/engines/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteNDManager.java index 5ba7228a3ba..4b2274ad7b8 100644 --- a/engines/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteNDManager.java +++ b/engines/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteNDManager.java @@ -29,11 +29,8 @@ public class TfLiteNDManager extends BaseNDManager { private static final TfLiteNDManager SYSTEM_MANAGER = new SystemManager(); - private NDManager alternativeManager; - private TfLiteNDManager(NDManager parent, Device device) { super(parent, device); - alternativeManager = getAlternativeManager(); } static TfLiteNDManager getSystemManager() {