diff --git a/wlm/src/main/java/ai/djl/serving/wlm/LmiConfigRecommender.java b/wlm/src/main/java/ai/djl/serving/wlm/LmiConfigRecommender.java
index 42e0b0884..af6bd9d9c 100644
--- a/wlm/src/main/java/ai/djl/serving/wlm/LmiConfigRecommender.java
+++ b/wlm/src/main/java/ai/djl/serving/wlm/LmiConfigRecommender.java
@@ -15,12 +15,17 @@
 import ai.djl.util.Utils;
 import ai.djl.util.cuda.CudaUtils;
 
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
 import java.util.Map;
 import java.util.Properties;
+import java.util.Set;
 
 /** A utility class to auto configure LMI model properties. */
 public final class LmiConfigRecommender {
+
+    private static final Logger logger = LoggerFactory.getLogger(LmiConfigRecommender.class);
 
     private static final Map<String, String> MODEL_TO_ROLLING_BATCH =
             Map.ofEntries(
                     Map.entry("falcon", "lmi-dist"),
@@ -46,6 +51,9 @@ public final class LmiConfigRecommender {
                     Map.entry("qwen2", "vllm"),
                     Map.entry("stablelm", "vllm"));
 
+    private static final Set<String> OPTIMIZED_TASK_ARCHITECTURES =
+            Set.of("ForCausalLM", "LMHeadModel", "ForConditionalGeneration");
+
     private LmiConfigRecommender() {}
 
     static void configure(Properties lmiProperties, LmiUtils.HuggingFaceModelConfig modelConfig) {
@@ -64,6 +72,9 @@ private static void setRollingBatch(
         String features = Utils.getEnvOrSystemProperty("SERVING_FEATURES");
         if (!"auto".equals(rollingBatch)) {
             return;
+        } else if (!isTextGenerationModel(modelConfig)) {
+            // Non text-generation use-cases are not compatible with rolling batch
+            rollingBatch = "disable";
         } else if (isVLLMEnabled(features) && isLmiDistEnabled(features)) {
             rollingBatch = MODEL_TO_ROLLING_BATCH.getOrDefault(modelConfig.getModelType(), "auto");
         } else if (LmiUtils.isTrtLLM(lmiProperties)) {
@@ -103,4 +114,21 @@ private static boolean isVLLMEnabled(String features) {
     private static boolean isLmiDistEnabled(String features) {
         return features != null && features.contains("lmi-dist");
     }
+
+    private static boolean isTextGenerationModel(LmiUtils.HuggingFaceModelConfig modelConfig) {
+        for (String arch : modelConfig.getArchitectures()) {
+            boolean isTextGenerationModel =
+                    OPTIMIZED_TASK_ARCHITECTURES.stream().anyMatch(arch::endsWith);
+            if (isTextGenerationModel) {
+                return true;
+            }
+        }
+        logger.warn(
+                "The model task architecture {} is not supported for optimized inference. LMI will"
+                        + " attempt to load the model using HuggingFace Accelerate. Optimized inference"
+                        + " performance is only available for the following task architectures: {}",
+                modelConfig.getArchitectures(),
+                OPTIMIZED_TASK_ARCHITECTURES);
+        return false;
+    }
 }
diff --git a/wlm/src/main/java/ai/djl/serving/wlm/LmiUtils.java b/wlm/src/main/java/ai/djl/serving/wlm/LmiUtils.java
index c8f92b221..a7e9b07d3 100644
--- a/wlm/src/main/java/ai/djl/serving/wlm/LmiUtils.java
+++ b/wlm/src/main/java/ai/djl/serving/wlm/LmiUtils.java
@@ -36,7 +36,11 @@
 import java.nio.file.Files;
 import java.nio.file.Path;
 import java.nio.file.Paths;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
 import java.util.Properties;
+import java.util.Set;
 import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.stream.Stream;
 
@@ -270,14 +274,39 @@ static final class HuggingFaceModelConfig {
         @SerializedName("model_type")
         private String modelType;
 
+        @SerializedName("architectures")
+        private List<String> configArchitectures;
+
+        @SerializedName("auto_map")
+        private Map<String, String> autoMap;
+
         @SerializedName("_diffusers_version")
         private String diffusersVersion;
 
+        private Set<String> allArchitectures;
+
         public String getModelType() {
             if (modelType == null) {
                 return diffusersVersion == null ? null : "stable-diffusion";
             }
             return modelType;
         }
+
+        public Set<String> getArchitectures() {
+            if (allArchitectures == null) {
+                determineAllArchitectures();
+            }
+            return allArchitectures;
+        }
+
+        private void determineAllArchitectures() {
+            allArchitectures = new HashSet<>();
+            if (configArchitectures != null) {
+                allArchitectures.addAll(configArchitectures);
+            }
+            if (autoMap != null) {
+                allArchitectures.addAll(autoMap.keySet());
+            }
+        }
     }
 }
diff --git a/wlm/src/test/java/ai/djl/serving/wlm/ModelInfoTest.java b/wlm/src/test/java/ai/djl/serving/wlm/ModelInfoTest.java
index 527028367..8b0defa97 100644
--- a/wlm/src/test/java/ai/djl/serving/wlm/ModelInfoTest.java
+++ b/wlm/src/test/java/ai/djl/serving/wlm/ModelInfoTest.java
@@ -260,6 +260,8 @@ public void testInferLMIEngine() throws IOException, ModelException {
                 put("tiiuae/falcon-7b", "lmi-dist");
                 put("mistralai/Mistral-7B-v0.1", "vllm");
                 put("src/test/resources/local-hf-model", "vllm");
+                put("HuggingFaceH4/tiny-random-LlamaForSequenceClassification", "disable");
+                put("THUDM/chatglm3-6b", "vllm");
             }
         };
         Path modelStore = Paths.get("build/models");
diff --git a/wlm/src/test/resources/local-hf-model/config.json b/wlm/src/test/resources/local-hf-model/config.json
index 575704404..d0c525b7f 100644
--- a/wlm/src/test/resources/local-hf-model/config.json
+++ b/wlm/src/test/resources/local-hf-model/config.json
@@ -1,3 +1,4 @@
 {
-  "model_type": "gpt2"
+  "model_type": "gpt2",
+  "architectures": ["GPT2LMHeadModel"]
 }