diff --git a/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py b/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py index 3a27458d6..331e5d0b8 100644 --- a/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py +++ b/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py @@ -125,6 +125,7 @@ def translate_lmi_dist_params(self, parameters: dict): parameters["use_beam_search"] = True if parameters.pop("decoder_input_details", False): parameters["prompt_logprobs"] = 1 + parameters["logprobs"] = parameters.get("logprobs", 1) parameters = filter_unused_generation_params( parameters, LMI_DIST_GENERATION_PARAMS, diff --git a/engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py b/engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py index b8d602544..4cd17fb1d 100644 --- a/engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py +++ b/engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py @@ -85,6 +85,7 @@ def translate_vllm_params(self, parameters: dict) -> dict: parameters["use_beam_search"] = True if parameters.pop("decoder_input_details", False): parameters["prompt_logprobs"] = 1 + parameters["logprobs"] = parameters.get("logprobs", 1) parameters = filter_unused_generation_params(parameters, VLLM_GENERATION_PARAMS, "vllm",