diff --git a/engines/python/setup/djl_python/chat_completions/chat_utils.py b/engines/python/setup/djl_python/chat_completions/chat_utils.py
index 4509306d5..287361a66 100644
--- a/engines/python/setup/djl_python/chat_completions/chat_utils.py
+++ b/engines/python/setup/djl_python/chat_completions/chat_utils.py
@@ -31,7 +31,10 @@ def parse_chat_completions_request(input_map: Dict, is_rolling_batch: bool,
             f"Cannot provide chat completion for tokenizer: {tokenizer.__class__}, "
             f"please ensure that your tokenizer supports chat templates.")
     chat_params = ChatProperties(**input_map)
-    param = chat_params.model_dump(by_alias=True, exclude_none=True)
+    exclude = {"messages"}
+    param = chat_params.model_dump(by_alias=True,
+                                   exclude_none=True,
+                                   exclude=exclude)
     messages = chat_params.messages
     images = []
     tokenizer_inputs = []
diff --git a/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py b/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py
index 644c550ff..abb97ab35 100644
--- a/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py
+++ b/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py
@@ -23,7 +23,8 @@
 from djl_python.rolling_batch.rolling_batch import RollingBatch, stop_on_any_exception, filter_unused_generation_params
 from djl_python.rolling_batch.rolling_batch_vllm_utils import (
     get_speculative_decoding_metrics_record, update_request_cache_with_output,
-    supports_speculative_decoding, get_lora_request_params, DTYPE_MAPPER)
+    supports_speculative_decoding, get_lora_request_params, DTYPE_MAPPER,
+    get_prompt_inputs)
 from djl_python.telemetry import telemetry_manager
 from djl_python.properties_manager.lmi_dist_rb_properties import LmiDistRbProperties
 
@@ -155,6 +156,7 @@ def inference(self, new_requests: List[Request]) -> List:
         # step 0: register new requests to engine
         for request in new_requests:
             request_id = str(request.id)
+            llm_input = get_prompt_inputs(request)
             params = self.translate_lmi_dist_params(request.parameters)
             request_params = RequestParams(**params)
             lora_request_params = get_lora_request_params(
@@ -162,7 +164,8 @@
             # Constructing Request in lmi-dist library
             lmi_dist_request = Request(
                 id=request_id,
-                prompt=request.input_text,
+                prompt=llm_input.get("prompt"),
+                multi_modal_input=llm_input.get("multi_modal_data"),
                 params=request_params,
                 lora_request=lora_request_params["lora_request"]
                 if lora_request_params else None)
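
Note on the new helper (reviewer sketch, not part of the patch): get_prompt_inputs is imported from rolling_batch_vllm_utils, whose definition is outside this diff. Judging only from how the call sites above consume its return value, it presumably maps a Request onto a dict carrying the prompt text plus any multi-modal payload. A minimal sketch under that assumption, with the attribute names and dict shape hypothetical:

from typing import Any, Dict

def get_prompt_inputs(request) -> Dict[str, Any]:
    # Hypothetical reconstruction: the real helper lives in
    # djl_python.rolling_batch.rolling_batch_vllm_utils and is not shown here.
    llm_input: Dict[str, Any] = {"prompt": request.input_text}
    # Assumption: image inputs parsed from a chat-completions request are
    # attached to the request; vLLM-style engines pass them alongside the
    # prompt under a "multi_modal_data" key.
    images = getattr(request, "images", None)
    if images:
        llm_input["multi_modal_data"] = {"image": images}
    return llm_input

Under this sketch the Request construction stays uniform for both paths: text-only requests simply get multi_modal_input=None, since llm_input.get("multi_modal_data") returns None when the key is absent.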