[serving] Refactor rolling batch detection logic (#1781)
frankfliu authored Apr 16, 2024
1 parent f535d45 commit fd7479f
Showing 3 changed files with 27 additions and 19 deletions.
serving/docker/pytorch-inf2.Dockerfile (1 change: 1 addition & 0 deletions)

@@ -50,6 +50,7 @@ ENV PYTORCH_PRECXX11=true
 ENV PYTORCH_VERSION=2.1.2
 ENV JAVA_OPTS="-Xmx1g -Xms1g -Xss2m -XX:+ExitOnOutOfMemoryError"
 ENV NEURON_CC_FLAGS="--logfile /tmp/compile.log --temp-dir=/tmp"
+ENV SERVING_FEATURES=vllm,lmi-dist,tnx
 
 ENTRYPOINT ["/usr/local/bin/dockerd-entrypoint.sh"]
 CMD ["serve"]
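The new SERVING_FEATURES variable is what the refactored detection logic below consults through isVllmEnabled, isLmiDistEnabled, and isTrtLlmEnabled. Those helper bodies are not part of this diff; the following is a minimal sketch assuming they substring-match the comma-separated flag (class and variable names here are hypothetical):

public class FeatureFlagSketch {
    public static void main(String[] args) {
        // Assumed parsing; the real helpers live in LmiConfigRecommender outside this diff.
        String features = System.getenv("SERVING_FEATURES"); // e.g. "vllm,lmi-dist,tnx"
        boolean vllm = features != null && features.contains("vllm");
        boolean lmiDist = features != null && features.contains("lmi-dist");
        boolean trtllm = features != null && features.contains("trtllm");
        System.out.println("vllm=" + vllm + ", lmi-dist=" + lmiDist + ", trtllm=" + trtllm);
    }
}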
wlm/src/main/java/ai/djl/serving/wlm/LmiConfigRecommender.java (16 changes: 13 additions & 3 deletions)
@@ -12,6 +12,7 @@
  */
 package ai.djl.serving.wlm;
 
+import ai.djl.util.Ec2Utils;
 import ai.djl.util.Utils;
 import ai.djl.util.cuda.CudaUtils;
 
@@ -80,15 +81,24 @@ private static void setRollingBatch(
             return;
         }
         String rollingBatch = lmiProperties.getProperty("option.rolling_batch", "auto");
+        String modelType = modelConfig.getModelType();
         if (!"auto".equals(rollingBatch)) {
             return;
         } else if (!isTextGenerationModel(modelConfig)) {
             // Non text-generation use-cases are not compatible with rolling batch
             rollingBatch = "disable";
-        } else if (isVllmEnabled(features) && isLmiDistEnabled(features)) {
-            rollingBatch = MODEL_TO_ROLLING_BATCH.getOrDefault(modelConfig.getModelType(), "auto");
-        } else if (LmiUtils.isTrtLlmRollingBatch(lmiProperties)) {
+        } else if (isLmiDistEnabled(features)
+                && "lmi-dist".equals(MODEL_TO_ROLLING_BATCH.get(modelType))) {
+            rollingBatch = "lmi-dist";
+        } else if (isVllmEnabled(features)
+                && "vllm".equals(MODEL_TO_ROLLING_BATCH.get(modelType))) {
+            rollingBatch = "vllm";
+        } else if (isTrtLlmEnabled(features)) {
             rollingBatch = "trtllm";
+        } else if (Ec2Utils.isSageMaker()) {
+            rollingBatch = "scheduler";
+        } else {
+            rollingBatch = "disable";
         }
         lmiProperties.setProperty("option.rolling_batch", rollingBatch);
     }
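The refactor makes the selection order explicit: lmi-dist is preferred when its feature is enabled and the model type maps to it, then vllm, then trtllm, then the SageMaker scheduler fallback, and finally "disable". For readers skimming the diff, here is a self-contained sketch of that chain (the map contents, the substring feature check, and the onSageMaker flag are illustrative assumptions, not the real class):

import java.util.Map;

public class RollingBatchSketch {

    private static final Map<String, String> MODEL_TO_ROLLING_BATCH =
            Map.of("llama", "lmi-dist", "falcon", "lmi-dist", "gpt2", "lmi-dist");

    static String detect(String modelType, String features, boolean onSageMaker) {
        // Mirrors the precedence in the diff above: lmi-dist, vllm, trtllm,
        // SageMaker scheduler, then disable.
        if (features.contains("lmi-dist")
                && "lmi-dist".equals(MODEL_TO_ROLLING_BATCH.get(modelType))) {
            return "lmi-dist";
        } else if (features.contains("vllm")
                && "vllm".equals(MODEL_TO_ROLLING_BATCH.get(modelType))) {
            return "vllm";
        } else if (features.contains("trtllm")) {
            return "trtllm";
        } else if (onSageMaker) {
            return "scheduler";
        }
        return "disable";
    }

    public static void main(String[] args) {
        // lmi-dist wins when both vllm and lmi-dist are enabled for a mapped model
        System.out.println(detect("llama", "vllm,lmi-dist", false)); // lmi-dist
        // an unmapped model type falls through to "disable" off SageMaker
        System.out.println(detect("bert", "vllm,lmi-dist", false)); // disable
    }
}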
wlm/src/test/java/ai/djl/serving/wlm/ModelInfoTest.java (29 changes: 13 additions & 16 deletions)
@@ -13,6 +13,7 @@
 package ai.djl.serving.wlm;
 
 import static org.testng.Assert.assertEquals;
+import static org.testng.Assert.assertNull;
 
 import ai.djl.Device;
 import ai.djl.ModelException;
@@ -39,7 +40,6 @@
 import java.nio.file.Files;
 import java.nio.file.Path;
 import java.nio.file.Paths;
-import java.util.HashMap;
 import java.util.Map;
 import java.util.Properties;
 
@@ -249,21 +249,18 @@ public void testInitModel() throws IOException, ModelException {
     }
 
     @Test
-    public void testInferLMIEngine() throws IOException, ModelException {
+    public void testInferLmiEngine() throws IOException, ModelException {
         // vllm/lmi-dist features enabled
         System.setProperty("SERVING_FEATURES", "vllm,lmi-dist");
         Map<String, String> modelToRollingBatch =
-                new HashMap<>() {
-                    {
-                        put("TheBloke/Llama-2-7B-fp16", "lmi-dist");
-                        put("openai-community/gpt2", "lmi-dist");
-                        put("tiiuae/falcon-7b", "lmi-dist");
-                        put("mistralai/Mistral-7B-v0.1", "lmi-dist");
-                        put("src/test/resources/local-hf-model", "lmi-dist");
-                        put("HuggingFaceH4/tiny-random-LlamaForSequenceClassification", "disable");
-                        put("THUDM/chatglm3-6b", "lmi-dist");
-                    }
-                };
+                Map.of(
+                        "TheBloke/Llama-2-7B-fp16", "lmi-dist",
+                        "openai-community/gpt2", "lmi-dist",
+                        "tiiuae/falcon-7b", "lmi-dist",
+                        "mistralai/Mistral-7B-v0.1", "lmi-dist",
+                        "src/test/resources/local-hf-model", "lmi-dist",
+                        "HuggingFaceH4/tiny-random-LlamaForSequenceClassification", "disable",
+                        "THUDM/chatglm3-6b", "lmi-dist");
         Path modelStore = Paths.get("build/models");
         Path modelDir = modelStore.resolve("lmi_test_model");
         Path prop = modelDir.resolve("serving.properties");
@@ -288,8 +285,8 @@ public void testInferLMIEngine() throws IOException, ModelException {
         }
         ModelInfo<Input, Output> model = new ModelInfo<>("build/models/lmi_test_model");
         model.initialize();
-        assertEquals(model.getProperties().getProperty("option.rolling_batch"), "auto");
-        assertEquals(model.getProperties().getProperty("option.mpi_mode"), null);
+        assertEquals(model.getProperties().getProperty("option.rolling_batch"), "disable");
+        assertNull(model.getProperties().getProperty("option.mpi_mode"));
 
         // invalid hf model case
         try (BufferedWriter writer = Files.newBufferedWriter(prop)) {
@@ -298,7 +295,7 @@
         model = new ModelInfo<>("build/models/lmi_test_model");
         model.initialize();
         assertEquals(model.getEngineName(), "Python");
-        assertEquals(model.getProperties().getProperty("option.rolling_batch"), null);
+        assertNull(model.getProperties().getProperty("option.rolling_batch"));
 
         // TODO: no good way to test trtllm now since it requires converting the model
     }
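One note on the test cleanup above: replacing the double-brace HashMap with Map.of removes an anonymous subclass and yields an immutable map. Map.of overloads accept at most 10 key/value pairs (this test uses 7) and reject null keys and values, both of which are fine here. A minimal illustration:

import java.util.Map;

public class MapOfDemo {
    public static void main(String[] args) {
        // Immutable: put() throws UnsupportedOperationException,
        // and any null key or value throws NullPointerException at creation.
        Map<String, String> m = Map.of("tiiuae/falcon-7b", "lmi-dist");
        System.out.println(m.get("tiiuae/falcon-7b")); // lmi-dist
    }
}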
