Skip to content

Commit

Permalink
[lmi][lcnc] fallback to accelerate backend when non text-generation m… (
Browse files Browse the repository at this point in the history
  • Loading branch information
siddvenk authored Mar 26, 2024
1 parent 191b084 commit 761664e
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 1 deletion.
28 changes: 28 additions & 0 deletions wlm/src/main/java/ai/djl/serving/wlm/LmiConfigRecommender.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,17 @@
import ai.djl.util.Utils;
import ai.djl.util.cuda.CudaUtils;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Map;
import java.util.Properties;
import java.util.Set;

/** A utility class to auto configure LMI model properties. */
public final class LmiConfigRecommender {

private static final Logger logger = LoggerFactory.getLogger(LmiConfigRecommender.class);
private static final Map<String, String> MODEL_TO_ROLLING_BATCH =
Map.ofEntries(
Map.entry("falcon", "lmi-dist"),
Expand All @@ -46,6 +51,9 @@ public final class LmiConfigRecommender {
Map.entry("qwen2", "vllm"),
Map.entry("stablelm", "vllm"));

private static final Set<String> OPTIMIZED_TASK_ARCHITECTURES =
Set.of("ForCausalLM", "LMHeadModel", "ForConditionalGeneration");

private LmiConfigRecommender() {}

static void configure(Properties lmiProperties, LmiUtils.HuggingFaceModelConfig modelConfig) {
Expand All @@ -64,6 +72,9 @@ private static void setRollingBatch(
String features = Utils.getEnvOrSystemProperty("SERVING_FEATURES");
if (!"auto".equals(rollingBatch)) {
return;
} else if (!isTextGenerationModel(modelConfig)) {
// Non text-generation use-cases are not compatible with rolling batch
rollingBatch = "disable";
} else if (isVLLMEnabled(features) && isLmiDistEnabled(features)) {
rollingBatch = MODEL_TO_ROLLING_BATCH.getOrDefault(modelConfig.getModelType(), "auto");
} else if (LmiUtils.isTrtLLM(lmiProperties)) {
Expand Down Expand Up @@ -103,4 +114,21 @@ private static boolean isVLLMEnabled(String features) {
/** Returns whether the lmi-dist rolling batch backend is included in the serving features. */
private static boolean isLmiDistEnabled(String features) {
    if (features == null) {
        return false;
    }
    return features.contains("lmi-dist");
}

/**
 * Returns whether any of the model's architectures ends with a suffix that marks a
 * text-generation task (see {@code OPTIMIZED_TASK_ARCHITECTURES}). Logs a warning when
 * none match, since such models fall back to HuggingFace Accelerate.
 */
private static boolean isTextGenerationModel(LmiUtils.HuggingFaceModelConfig modelConfig) {
    boolean supported =
            modelConfig.getArchitectures().stream()
                    .anyMatch(
                            arch ->
                                    OPTIMIZED_TASK_ARCHITECTURES.stream()
                                            .anyMatch(arch::endsWith));
    if (supported) {
        return true;
    }
    logger.warn(
            "The model task architecture {} is not supported for optimized inference. LMI will"
                + " attempt to load the model using HuggingFace Accelerate. Optimized inference"
                + " performance is only available for the following task architectures: {}",
            modelConfig.getArchitectures(),
            OPTIMIZED_TASK_ARCHITECTURES);
    return false;
}
}
29 changes: 29 additions & 0 deletions wlm/src/main/java/ai/djl/serving/wlm/LmiUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,11 @@
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.stream.Stream;

Expand Down Expand Up @@ -270,14 +274,39 @@ static final class HuggingFaceModelConfig {
@SerializedName("model_type")
private String modelType;

@SerializedName("architectures")
private List<String> configArchitectures;

@SerializedName("auto_map")
private Map<String, String> autoMap;

@SerializedName("_diffusers_version")
private String diffusersVersion;

private Set<String> allArchitectures;

/**
 * Returns the HuggingFace {@code model_type}. When the config has no model type but does
 * carry a {@code _diffusers_version}, the model is treated as "stable-diffusion";
 * otherwise returns null.
 */
public String getModelType() {
    if (modelType != null) {
        return modelType;
    }
    if (diffusersVersion != null) {
        return "stable-diffusion";
    }
    return null;
}

/**
 * Returns the union of architectures declared in the config and in {@code auto_map},
 * computing it lazily on first access. NOTE(review): lazy init is not synchronized —
 * presumably this config is only read from a single thread; confirm with callers.
 */
public Set<String> getArchitectures() {
    if (allArchitectures != null) {
        return allArchitectures;
    }
    determineAllArchitectures();
    return allArchitectures;
}

/**
 * Populates {@code allArchitectures} with every architecture name found in the config's
 * {@code architectures} list plus the keys of {@code auto_map}; either source may be
 * absent (null) in the parsed config.
 */
private void determineAllArchitectures() {
    Set<String> architectures = new HashSet<>();
    if (configArchitectures != null) {
        architectures.addAll(configArchitectures);
    }
    if (autoMap != null) {
        architectures.addAll(autoMap.keySet());
    }
    allArchitectures = architectures;
}
}
}
2 changes: 2 additions & 0 deletions wlm/src/test/java/ai/djl/serving/wlm/ModelInfoTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,8 @@ public void testInferLMIEngine() throws IOException, ModelException {
put("tiiuae/falcon-7b", "lmi-dist");
put("mistralai/Mistral-7B-v0.1", "vllm");
put("src/test/resources/local-hf-model", "vllm");
put("HuggingFaceH4/tiny-random-LlamaForSequenceClassification", "disable");
put("THUDM/chatglm3-6b", "vllm");
}
};
Path modelStore = Paths.get("build/models");
Expand Down
3 changes: 2 additions & 1 deletion wlm/src/test/resources/local-hf-model/config.json
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
{
"model_type": "gpt2"
"model_type": "gpt2",
"architectures": ["GPT2LMHeadModel"]
}

0 comments on commit 761664e

Please sign in to comment.