diff --git a/wlm/src/main/java/ai/djl/serving/wlm/LmiConfigRecommender.java b/wlm/src/main/java/ai/djl/serving/wlm/LmiConfigRecommender.java index ae0a44b65..3e23e8756 100644 --- a/wlm/src/main/java/ai/djl/serving/wlm/LmiConfigRecommender.java +++ b/wlm/src/main/java/ai/djl/serving/wlm/LmiConfigRecommender.java @@ -87,6 +87,8 @@ private static void setRollingBatch( } else if (!isTextGenerationModel(modelConfig)) { // Non text-generation use-cases are not compatible with rolling batch rollingBatch = "disable"; + } else if (isTnxEnabled(features)) { + rollingBatch = "tnx"; } else if (isLmiDistEnabled(features) && "lmi-dist".equals(MODEL_TO_ROLLING_BATCH.get(modelType))) { rollingBatch = "lmi-dist"; @@ -175,6 +177,10 @@ private static boolean isTrtLlmEnabled(String features) { return features != null && features.contains("trtllm"); } + private static boolean isTnxEnabled(String features) { + return features != null && features.contains("tnx"); + } + private static boolean isT5TrtLlm( LmiUtils.HuggingFaceModelConfig modelConfig, String features) { return isTrtLlmEnabled(features) && "t5".equals(modelConfig.getModelType());