From f6ca67442bda48289e00e78a9bda7b0ba1da88eb Mon Sep 17 00:00:00 2001
From: Somasundaram
Date: Wed, 3 Jul 2024 18:10:44 -0700
Subject: [PATCH] [python] Fix new logprobs computation in vllm_utils

---
 .../djl_python/rolling_batch/rolling_batch_vllm_utils.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/engines/python/setup/djl_python/rolling_batch/rolling_batch_vllm_utils.py b/engines/python/setup/djl_python/rolling_batch/rolling_batch_vllm_utils.py
index b82f99632..8f1bbfed0 100644
--- a/engines/python/setup/djl_python/rolling_batch/rolling_batch_vllm_utils.py
+++ b/engines/python/setup/djl_python/rolling_batch/rolling_batch_vllm_utils.py
@@ -116,8 +116,8 @@ def update_multiple_sequences(cache, request_output, vllm_request_output):
                                                        cur_len] if prev_len < cur_len else completion_output.logprobs
         new_logprobs = []
         for token_id, logprobs in zip(new_token_ids, new_logprobs_list):
+            new_logprobs.append(logprobs[token_id].logprob)
             for token_id_key, logprob in logprobs.items():
-                new_logprobs.append(logprobs[token_id].logprob)
                 top_tokens.append(
                     Token(id=token_id_key,
                           text=logprob.decoded_token,
@@ -137,7 +137,7 @@ def update_multiple_sequences(cache, request_output, vllm_request_output):
         for i, (token_id, token_text, logprob) in enumerate(
                 zip(new_token_ids, output_token_texts, new_logprobs)):
             token = Token(token_id, token_text, logprob)
-            is_last_token = i == (len(new_logprobs) -
+            is_last_token = i == (len(new_token_ids) -
                                   1) and finish_reason is not None
             request_output.sequences[sequence_index].set_next_token(
                 token, is_last_token)
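
Why the two hunks go together: before this patch, the append of the chosen
token's logprob sat inside the loop over the top-k candidate dictionary, so
new_logprobs gained one duplicate entry per candidate rather than one entry
per generated token. The inflated length then defeated the end-of-sequence
check in the second hunk: i ranges over zip(new_token_ids, ...), so it could
never reach len(new_logprobs) - 1 whenever top-k > 1, and the final token was
never flagged. Comparing against len(new_token_ids) - 1 ties the check to the
number of generated tokens. Below is a minimal, self-contained sketch of the
first bug; FakeLogprob and the sample values are hypothetical stand-ins for
vLLM's per-token Logprob objects, not the real classes.

    from dataclasses import dataclass


    @dataclass
    class FakeLogprob:
        logprob: float
        decoded_token: str


    # Two newly generated token ids, each paired with a top-2 logprob
    # dictionary keyed by candidate token id (vLLM-style layout).
    new_token_ids = [11, 42]
    new_logprobs_list = [
        {11: FakeLogprob(-0.1, "foo"), 7: FakeLogprob(-2.3, "bar")},
        {42: FakeLogprob(-0.5, "baz"), 9: FakeLogprob(-1.9, "qux")},
    ]

    # Before the patch: appending inside the inner loop records the chosen
    # token's logprob once per top-k candidate, duplicating entries.
    buggy = []
    for token_id, logprobs in zip(new_token_ids, new_logprobs_list):
        for token_id_key, logprob in logprobs.items():
            buggy.append(logprobs[token_id].logprob)

    # After the patch: append once per generated token, before walking the
    # top-k candidates (top_tokens bookkeeping elided here).
    fixed = []
    for token_id, logprobs in zip(new_token_ids, new_logprobs_list):
        fixed.append(logprobs[token_id].logprob)
        for token_id_key, logprob in logprobs.items():
            pass

    print(len(buggy))  # 4 -- one entry per top-k candidate (wrong)
    print(len(fixed))  # 2 -- one entry per generated token (matches new_token_ids)

Running the sketch prints 4 for the buggy list and 2 for the fixed one; after
the fix, new_logprobs and new_token_ids agree in length, and keying
is_last_token off len(new_token_ids) keeps the check independent of the
logprob bookkeeping.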