From 65f6fb735ab925a92d801d1eec6b6b70ac7950b8 Mon Sep 17 00:00:00 2001 From: Qing Lan Date: Thu, 23 May 2024 17:15:54 -0700 Subject: [PATCH] [vLLM/LMI-Dist]change the default logprobs to 1 (#1965) --- .../setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py | 1 + .../python/setup/djl_python/rolling_batch/vllm_rolling_batch.py | 1 + 2 files changed, 2 insertions(+) diff --git a/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py b/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py index 3a27458d6..331e5d0b8 100644 --- a/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py +++ b/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py @@ -125,6 +125,7 @@ def translate_lmi_dist_params(self, parameters: dict): parameters["use_beam_search"] = True if parameters.pop("decoder_input_details", False): parameters["prompt_logprobs"] = 1 + parameters["logprobs"] = parameters.get("logprobs", 1) parameters = filter_unused_generation_params( parameters, LMI_DIST_GENERATION_PARAMS, diff --git a/engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py b/engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py index b8d602544..4cd17fb1d 100644 --- a/engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py +++ b/engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py @@ -85,6 +85,7 @@ def translate_vllm_params(self, parameters: dict) -> dict: parameters["use_beam_search"] = True if parameters.pop("decoder_input_details", False): parameters["prompt_logprobs"] = 1 + parameters["logprobs"] = parameters.get("logprobs", 1) parameters = filter_unused_generation_params(parameters, VLLM_GENERATION_PARAMS, "vllm",