diff --git a/engines/python/src/main/java/ai/djl/python/engine/PyModel.java b/engines/python/src/main/java/ai/djl/python/engine/PyModel.java
index 85a494b1c..15c95f363 100644
--- a/engines/python/src/main/java/ai/djl/python/engine/PyModel.java
+++ b/engines/python/src/main/java/ai/djl/python/engine/PyModel.java
@@ -78,7 +78,6 @@ public void load(Path modelPath, String prefix, Map<String, ?> options) throws IOException {
                     "Python engine does not support dynamic blocks");
         }
         String entryPoint = null;
-        boolean isTrtLlmBackend = false;
         if (options != null) {
             logger.debug("options in serving.properties for model: {}", modelName);
             for (Map.Entry<String, ?> entry : options.entrySet()) {
@@ -121,9 +120,6 @@ public void load(Path modelPath, String prefix, Map<String, ?> options) throws IOException {
                     case "entryPoint":
                         entryPoint = value;
                         break;
-                    case "rolling_batch":
-                        isTrtLlmBackend = "trtllm".equals(value);
-                        break;
                     case "parallel_loading":
                         parallelLoading = Boolean.parseBoolean(value);
                         break;
@@ -158,6 +154,7 @@ public void load(Path modelPath, String prefix, Map<String, ?> options) throws IOException {
             entryPoint = Utils.getenv("DJL_ENTRY_POINT");
             if (entryPoint == null) {
                 Path modelFile = findModelFile(prefix);
+                String features = Utils.getEnvOrSystemProperty("SERVING_FEATURES");
                 // find default entryPoint
                 String engineName = manager.getEngine().getEngineName();
                 if (modelFile != null) {
@@ -167,7 +164,7 @@ public void load(Path modelPath, String prefix, Map<String, ?> options) throws IOException {
                 } else if ("nc".equals(manager.getDevice().getDeviceType())
                         && pyEnv.getTensorParallelDegree() > 0) {
                     entryPoint = "djl_python.transformers_neuronx";
-                } else if (isTrtLlmBackend) {
+                } else if ("trtllm".equals(features)) {
                     entryPoint = "djl_python.tensorrt_llm";
                 } else if (pyEnv.getInitParameters().containsKey("model_id")) {
                     entryPoint = "djl_python.huggingface";
diff --git a/wlm/src/main/java/ai/djl/serving/wlm/LmiConfigRecommender.java b/wlm/src/main/java/ai/djl/serving/wlm/LmiConfigRecommender.java
index af6bd9d9c..3f468718a 100644
--- a/wlm/src/main/java/ai/djl/serving/wlm/LmiConfigRecommender.java
+++ b/wlm/src/main/java/ai/djl/serving/wlm/LmiConfigRecommender.java
@@ -56,20 +56,27 @@ public final class LmiConfigRecommender {

     private LmiConfigRecommender() {}

-    static void configure(Properties lmiProperties, LmiUtils.HuggingFaceModelConfig modelConfig) {
-        setRollingBatch(lmiProperties, modelConfig);
-        setEngine(lmiProperties);
+    static void configure(
+            ModelInfo<?, ?> modelInfo,
+            Properties lmiProperties,
+            LmiUtils.HuggingFaceModelConfig modelConfig) {
+        String features = Utils.getEnvOrSystemProperty("SERVING_FEATURES");
+        setDynamicBatch(lmiProperties, modelConfig, modelInfo, features);
+        setRollingBatch(lmiProperties, modelConfig, features);
+        setEngine(lmiProperties, modelConfig, features);
         setTensorParallelDegree(lmiProperties);
     }

     private static void setRollingBatch(
-            Properties lmiProperties, LmiUtils.HuggingFaceModelConfig modelConfig) {
+            Properties lmiProperties,
+            LmiUtils.HuggingFaceModelConfig modelConfig,
+            String features) {
         // If dynamic batch is enabled, we don't enable rolling batch.
         if (Integer.parseInt(lmiProperties.getProperty("batch_size", "1")) > 1) {
+            lmiProperties.setProperty("option.rolling_batch", "disable");
             return;
         }
         String rollingBatch = lmiProperties.getProperty("option.rolling_batch", "auto");
-        String features = Utils.getEnvOrSystemProperty("SERVING_FEATURES");
         if (!"auto".equals(rollingBatch)) {
             return;
         } else if (!isTextGenerationModel(modelConfig)) {
@@ -77,13 +84,16 @@ private static void setRollingBatch(
             rollingBatch = "disable";
         } else if (isVLLMEnabled(features) && isLmiDistEnabled(features)) {
             rollingBatch = MODEL_TO_ROLLING_BATCH.getOrDefault(modelConfig.getModelType(), "auto");
-        } else if (LmiUtils.isTrtLLM(lmiProperties)) {
+        } else if (LmiUtils.isTrtLLMRollingBatch(lmiProperties)) {
             rollingBatch = "trtllm";
         }
         lmiProperties.setProperty("option.rolling_batch", rollingBatch);
     }

-    private static void setEngine(Properties lmiProperties) {
+    private static void setEngine(
+            Properties lmiProperties,
+            LmiUtils.HuggingFaceModelConfig modelConfig,
+            String features) {
         if (lmiProperties.containsKey("engine")) {
             return;
         }
@@ -93,6 +103,11 @@ private static void setEngine(Properties lmiProperties) {
             engine = "MPI";
             lmiProperties.setProperty("option.mpi_mode", "true");
         }
+        // TODO TrtLLM python backend: Change it once TrtLLM supports T5 with inflight batching.
+        if (isT5TrtLLM(modelConfig, features)) {
+            engine = "MPI";
+            lmiProperties.setProperty("option.mpi_mode", "true");
+        }
         lmiProperties.setProperty("engine", engine);
     }

@@ -107,6 +122,26 @@ private static void setTensorParallelDegree(Properties lmiProperties) {
         lmiProperties.setProperty("option.tensor_parallel_degree", tpDegree);
     }

+    private static void setDynamicBatch(
+            Properties lmiProperties,
+            LmiUtils.HuggingFaceModelConfig modelConfig,
+            ModelInfo<?, ?> modelInfo,
+            String features) {
+        // TODO TrtLLM python backend: Change it once TrtLLM supports T5 with inflight batching.
+        if (isT5TrtLLM(modelConfig, features)) {
+
+            // Do runtime compilation for the TensorRT-LLM T5 model.
+            lmiProperties.setProperty("trtllm_python_backend", String.valueOf(true));
+            lmiProperties.setProperty("option.rolling_batch", "disable");
+
+            // We set batch_size only when the customer did not provide it.
+            if (Integer.parseInt(lmiProperties.getProperty("batch_size", "0")) == 0) {
+                modelInfo.batchSize = 32;
+                lmiProperties.setProperty("batch_size", String.valueOf(32));
+            }
+        }
+    }
+
     private static boolean isVLLMEnabled(String features) {
         return features != null && features.contains("vllm");
     }
@@ -115,6 +150,15 @@ private static boolean isLmiDistEnabled(String features) {
         return features != null && features.contains("lmi-dist");
     }

+    private static boolean isTrtLLMEnabled(String features) {
+        return features != null && features.contains("trtllm");
+    }
+
+    private static boolean isT5TrtLLM(
+            LmiUtils.HuggingFaceModelConfig modelConfig, String features) {
+        return isTrtLLMEnabled(features) && "t5".equals(modelConfig.getModelType());
+    }
+
     private static boolean isTextGenerationModel(LmiUtils.HuggingFaceModelConfig modelConfig) {
         for (String arch : modelConfig.getArchitectures()) {
             boolean isTextGenerationModel =
diff --git a/wlm/src/main/java/ai/djl/serving/wlm/LmiUtils.java b/wlm/src/main/java/ai/djl/serving/wlm/LmiUtils.java
index a7e9b07d3..074a8d8ef 100644
--- a/wlm/src/main/java/ai/djl/serving/wlm/LmiUtils.java
+++ b/wlm/src/main/java/ai/djl/serving/wlm/LmiUtils.java
@@ -55,13 +55,13 @@ static String inferLmiEngine(ModelInfo<?, ?> modelInfo) throws ModelException {
         Properties prop = modelInfo.getProperties();
         HuggingFaceModelConfig modelConfig = getHuggingFaceModelConfig(modelInfo);
         if (modelConfig == null) {
-            String engineName = isTrtLLM(prop) ? "MPI" : "Python";
+            String engineName = isTrtLLMRollingBatch(prop) ? "MPI" : "Python";
             logger.info("No config.json found, use {} engine.", engineName);
             return engineName;
         }
-        LmiConfigRecommender.configure(prop, modelConfig);
+        LmiConfigRecommender.configure(modelInfo, prop, modelConfig);
         logger.info(
-                "Detected engine: {}, rolling_batch: {}, tensor_paralell_degre {}, for modelType:"
+                "Detected engine: {}, rolling_batch: {}, tensor_parallel_degree {}, for modelType:"
                         + " {}",
                 prop.getProperty("engine"),
                 prop.getProperty("option.rolling_batch"),
@@ -70,7 +70,7 @@ static String inferLmiEngine(ModelInfo<?, ?> modelInfo) throws ModelException {
         return prop.getProperty("engine");
     }

-    static boolean isTrtLLM(Properties properties) {
+    static boolean isTrtLLMRollingBatch(Properties properties) {
         String rollingBatch = properties.getProperty("option.rolling_batch");
         if ("trtllm".equals(rollingBatch)) {
             return true;
@@ -84,11 +84,12 @@ static boolean isTrtLLM(Properties properties) {
     }

     static boolean needConvert(ModelInfo<?, ?> info) {
-        return isTrtLLM(info.getProperties());
+        Properties properties = info.getProperties();
+        return isTrtLLMRollingBatch(properties)
+                || properties.containsKey("trtllm_python_backend");
     }

     static void convertTrtLLM(ModelInfo<?, ?> info) throws IOException {
-        info.prop.put("option.rolling_batch", "trtllm");
         Path trtRepo;
         String modelId = null;
         if (info.downloadDir != null) {
@@ -100,18 +101,30 @@ static void convertTrtLLM(ModelInfo<?, ?> info) throws IOException {
                 trtRepo = Paths.get(modelId);
             }
         }
-        if (!isValidTrtLlmModelRepo(trtRepo)) {
-            if (modelId == null) {
-                modelId = trtRepo.toString();
-            }
-            String tpDegree = info.prop.getProperty("option.tensor_parallel_degree");
-            if (tpDegree == null) {
-                tpDegree = Utils.getenv("TENSOR_PARALLEL_DEGREE", "max");
-            }
-            if ("max".equals(tpDegree)) {
-                tpDegree = String.valueOf(CudaUtils.getGpuCount());
-            }
+
+        if (modelId == null) {
+            modelId = trtRepo.toString();
+        }
+        String tpDegree = info.prop.getProperty("option.tensor_parallel_degree");
+        if (tpDegree == null) {
Utils.getenv("TENSOR_PARALLEL_DEGREE", "max"); + } + if ("max".equals(tpDegree)) { + tpDegree = String.valueOf(CudaUtils.getGpuCount()); + } + + // TODO TrtLLM python backend: Change it once TrtLLM supports T5 with inflight batching. + if (info.prop.containsKey("trtllm_python_backend")) { + // Inflight batching support is not available for certain models like t5. + // Python backend models have different model repo format compared to C++ backend. + // And whether it is valid or not is checked in tensorrt_llm_toolkit. So it is not + // necessary to check here. info.downloadDir = buildTrtLlmArtifacts(info.modelDir, modelId, tpDegree); + } else { + info.prop.put("option.rolling_batch", "trtllm"); + if (!isValidTrtLlmModelRepo(trtRepo)) { + info.downloadDir = buildTrtLlmArtifacts(info.modelDir, modelId, tpDegree); + } } }