Skip to content

Commit

Permalink
forward port of patch in 0.28.0 that terminates a python worker when … (
Browse files Browse the repository at this point in the history
  • Loading branch information
siddvenk authored Dec 11, 2024
1 parent d7729ce commit 77c154b
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import ai.djl.Device;
import ai.djl.Model;
import ai.djl.engine.EngineException;
import ai.djl.inference.Predictor;
import ai.djl.modality.Input;
import ai.djl.modality.Output;
Expand Down Expand Up @@ -63,6 +64,10 @@ public PyPredictor(
@Override
@SuppressWarnings("unchecked")
public List<O> batchPredict(List<I> inputs) throws TranslateException {
if (process.isModelUnrecoverable()) {
throw new EngineException(
"Backend Python process is unrecoverable. Initiating worker termination");
}
if (!process.isReady()) {
// TODO: wait for restart
throw new TranslateException("Backend Python process is stopped.");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ class PyProcess {
private CountDownLatch latch;
private volatile boolean started; // NOPMD
private volatile boolean modelLoaded; // NOPMD
private volatile boolean modelUnrecoverable; // NOPMD
private AtomicInteger restartCount;
private CompletableFuture<Void> restartFuture;
private boolean trtLlmMode;
Expand Down Expand Up @@ -147,6 +148,8 @@ Output predict(Input inputs, int timeout, boolean initialLoad) {
if (!initialLoad) {
logger.info("Restart python process ...");
restartFuture = CompletableFuture.runAsync(this::startPythonProcess);
} else {
modelUnrecoverable = true;
}
if (e instanceof EngineException) {
throw (EngineException) e;
Expand Down Expand Up @@ -263,6 +266,10 @@ boolean isReady() {
return started && modelLoaded;
}

boolean isModelUnrecoverable() {
return modelUnrecoverable;
}

private static String[] getHosts(int clusterSize) {
String leaderAddr = Utils.getenv("DJL_LEADER_ADDR");
String workerAddrFormat = Utils.getenv("DJL_WORKER_ADDR_FORMAT");
Expand Down

0 comments on commit 77c154b

Please sign in to comment.