diff --git a/engines/python/setup/djl_python/input_parser.py b/engines/python/setup/djl_python/input_parser.py index 36b0c956b..c520c8eec 100644 --- a/engines/python/setup/djl_python/input_parser.py +++ b/engines/python/setup/djl_python/input_parser.py @@ -141,20 +141,22 @@ def parse_text_inputs_params(request_input: TextInput, input_item: Input, if configs is not None: is_bedrock = configs.bedrock_compat if is_chat_completions_request(input_map): - if type(kwargs.get("rolling_batch")).__name__ == "TRTLLMRollingBatch": - inputs, param = parse_chat_completions_request( + if type(kwargs.get("rolling_batch")).__name__ in [ + "LmiDistRollingBatch", "VLLMRollingBatch" + ]: + inputs, param = parse_chat_completions_request_vllm( input_map, kwargs.get("is_rolling_batch"), + kwargs.get("rolling_batch"), tokenizer, image_token=image_token, configs=configs, is_mistral_tokenizer=is_mistral_tokenizer, ) else: - inputs, param = parse_chat_completions_request_vllm( + inputs, param = parse_chat_completions_request( input_map, kwargs.get("is_rolling_batch"), - kwargs.get("rolling_batch"), tokenizer, image_token=image_token, configs=configs,