From 77c154b545ec44ff34eba4cc62a5497aaa41dfc6 Mon Sep 17 00:00:00 2001 From: Siddharth Venkatesan Date: Wed, 11 Dec 2024 14:44:11 -0700 Subject: [PATCH] =?UTF-8?q?forward=20port=20of=20patch=20in=200.28.0=20tha?= =?UTF-8?q?t=20terminates=20a=20python=20worker=20when=20=E2=80=A6=20(#263?= =?UTF-8?q?1)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../src/main/java/ai/djl/python/engine/PyPredictor.java | 5 +++++ .../src/main/java/ai/djl/python/engine/PyProcess.java | 7 +++++++ 2 files changed, 12 insertions(+) diff --git a/engines/python/src/main/java/ai/djl/python/engine/PyPredictor.java b/engines/python/src/main/java/ai/djl/python/engine/PyPredictor.java index dd72112ef..431e74c78 100644 --- a/engines/python/src/main/java/ai/djl/python/engine/PyPredictor.java +++ b/engines/python/src/main/java/ai/djl/python/engine/PyPredictor.java @@ -14,6 +14,7 @@ import ai.djl.Device; import ai.djl.Model; +import ai.djl.engine.EngineException; import ai.djl.inference.Predictor; import ai.djl.modality.Input; import ai.djl.modality.Output; @@ -63,6 +64,10 @@ public PyPredictor( @Override @SuppressWarnings("unchecked") public List batchPredict(List inputs) throws TranslateException { + if (process.isModelUnrecoverable()) { + throw new EngineException( + "Backend Python process is unrecoverable. Initiating worker termination"); + } if (!process.isReady()) { // TODO: wait for restart throw new TranslateException("Backend Python process is stopped."); diff --git a/engines/python/src/main/java/ai/djl/python/engine/PyProcess.java b/engines/python/src/main/java/ai/djl/python/engine/PyProcess.java index 05b5368a9..b43b88f16 100644 --- a/engines/python/src/main/java/ai/djl/python/engine/PyProcess.java +++ b/engines/python/src/main/java/ai/djl/python/engine/PyProcess.java @@ -51,6 +51,7 @@ class PyProcess { private CountDownLatch latch; private volatile boolean started; // NOPMD private volatile boolean modelLoaded; // NOPMD + private volatile boolean modelUnrecoverable; // NOPMD private AtomicInteger restartCount; private CompletableFuture restartFuture; private boolean trtLlmMode; @@ -147,6 +148,8 @@ Output predict(Input inputs, int timeout, boolean initialLoad) { if (!initialLoad) { logger.info("Restart python process ..."); restartFuture = CompletableFuture.runAsync(this::startPythonProcess); + } else { + modelUnrecoverable = true; } if (e instanceof EngineException) { throw (EngineException) e; @@ -263,6 +266,10 @@ boolean isReady() { return started && modelLoaded; } + boolean isModelUnrecoverable() { + return modelUnrecoverable; + } + private static String[] getHosts(int clusterSize) { String leaderAddr = Utils.getenv("DJL_LEADER_ADDR"); String workerAddrFormat = Utils.getenv("DJL_WORKER_ADDR_FORMAT");