diff --git a/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py b/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py
index d43cf6c9ba..fbf25a3f06 100644
--- a/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py
+++ b/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py
@@ -19,7 +19,7 @@
 from lmi_dist.arg_utils import VllmEngineArgs
 from lmi_dist.init_engine import engine_from_args
 from lmi_dist.seq2seq_engine import Seq2SeqPreprocessor
-from vllm import SamplingParams
+from vllm.sampling_params import RequestOutputKind
 from vllm.utils import AtomicCounter

 from djl_python.rolling_batch.rolling_batch import RollingBatch, stop_on_any_exception, filter_unused_generation_params
@@ -140,6 +140,7 @@ def translate_lmi_dist_params(self, parameters: dict):

         :return: The same parameters dict, but with lmi-dist style parameter names.
         """
+        parameters["output_kind"] = RequestOutputKind.DELTA
         parameters["max_tokens"] = parameters.pop("max_new_tokens", 30)
         # If `do_sample` is not provided, force temperature=0.0, i.e. greedy
         # else set to user-provided value or default to 1.0
diff --git a/engines/python/setup/djl_python/rolling_batch/rolling_batch_vllm_utils.py b/engines/python/setup/djl_python/rolling_batch/rolling_batch_vllm_utils.py
index bad3cc8eb6..ee7be75b7c 100644
--- a/engines/python/setup/djl_python/rolling_batch/rolling_batch_vllm_utils.py
+++ b/engines/python/setup/djl_python/rolling_batch/rolling_batch_vllm_utils.py
@@ -91,47 +91,26 @@ def update_request_cache_with_output(request_cache: OrderedDict,

 def update_multiple_sequences(cache, request_output, vllm_request_output):
     for completion_output in vllm_request_output.outputs:
-
         sequence_index = completion_output.index
-        if f"sequence_index_{sequence_index}" not in cache:
-            cache[f"sequence_index_{sequence_index}"] = {
-                "curr_length": 0,
-                "num_generated_tokens": 0
-            }
         if sequence_index not in request_output.sequences:
             request_output.sequences[sequence_index] = Sequence()

-        # set token of the sequence
-        # previous length of token ids generated
-        prev_len = cache[f"sequence_index_{sequence_index}"][
-            'num_generated_tokens']
-        # curr length of the token ids generated so far
-        cur_len = len(completion_output.token_ids)
-        cache[f"sequence_index_{sequence_index}"][
-            "num_generated_tokens"] = cur_len
-        # get the newly generated token_ids
-        new_token_ids = completion_output.token_ids[
-            prev_len:
-            cur_len] if prev_len < cur_len else completion_output.token_ids
+        new_token_ids = completion_output.token_ids

         # get the newly generated token texts for speculative decoding
         output_token_texts = []
         if hasattr(completion_output, "output_token_texts"):
-            output_token_texts = completion_output.output_token_texts[
-                prev_len:
-                cur_len] if prev_len < cur_len else completion_output.output_token_texts
+            output_token_texts = completion_output.output_token_texts

         top_tokens = []
         token_texts = []
         # calculate log probs and token_texts
         if completion_output.logprobs:
-            new_logprobs_list = completion_output.logprobs[
-                prev_len:
-                cur_len] if prev_len < cur_len else completion_output.logprobs
             new_logprobs = []
-            for token_id, logprobs in zip(new_token_ids, new_logprobs_list):
+            for token_id, logprobs in zip(new_token_ids,
+                                          completion_output.logprobs):
                 new_logprobs.append(logprobs[token_id].logprob)
                 decoded_token = logprobs[token_id].decoded_token if logprobs[
                     token_id].decoded_token else ""
@@ -141,13 +120,10 @@ def update_multiple_sequences(cache, request_output, vllm_request_output):
                             Token(id=token_id_key,
                                   text=logprob.decoded_token,
                                   log_prob=logprob.logprob))
-
         elif new_token_ids:
             # TODO: Test and remove this. logprobs is always set 1. This case should never happen.
             new_logprobs = [None] * len(new_token_ids)
-            curr_length = cache[f"sequence_index_{sequence_index}"][
-                "curr_length"]
-            token_texts.append(completion_output.text[curr_length:])
+            token_texts.append(completion_output.text)

         if not output_token_texts:
             if len(token_texts) != len(new_token_ids):
@@ -186,9 +162,6 @@ def update_multiple_sequences(cache, request_output, vllm_request_output):
             request_output.sequences[sequence_index].set_next_top_tokens(
                 top_tokens)

-        cache[f"sequence_index_{sequence_index}"]["curr_length"] = len(
-            completion_output.text)
-

 def get_speculative_decoding_metrics_record(
         completion_output: CompletionOutput,
diff --git a/engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py b/engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py
index 66abbf811e..71f80258cd 100644
--- a/engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py
+++ b/engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py
@@ -13,6 +13,7 @@
 from collections import OrderedDict, defaultdict

 from vllm import LLMEngine, SamplingParams
+from vllm.sampling_params import RequestOutputKind
 from vllm.utils import random_uuid, AtomicCounter

 from djl_python.request import Request
@@ -78,6 +79,7 @@ def translate_vllm_params(self, parameters: dict) -> dict:

         :return: The same parameters dict, but with VLLM style parameter names.
         """
+        parameters["output_kind"] = RequestOutputKind.DELTA
         parameters["max_tokens"] = parameters.pop("max_new_tokens", 30)
         if "seed" in parameters.keys():
             parameters["seed"] = int(parameters["seed"])
diff --git a/engines/python/setup/djl_python/tests/test_rb_vllm_utils.py b/engines/python/setup/djl_python/tests/test_rb_vllm_utils.py
index 41486fb20c..24959627f3 100644
--- a/engines/python/setup/djl_python/tests/test_rb_vllm_utils.py
+++ b/engines/python/setup/djl_python/tests/test_rb_vllm_utils.py
@@ -1,6 +1,5 @@
 import sys
 import unittest
-import uuid
 from dataclasses import dataclass
 from typing import List, Optional, Dict, Union
 from collections import OrderedDict
@@ -12,7 +11,7 @@
 import djl_python

 from djl_python.output_formatter import _json_output_formatter
 from djl_python.request import Request
-from djl_python.request_io import TextGenerationOutput, TextInput, Sequence, Token, RequestInput
+from djl_python.request_io import TextGenerationOutput, TextInput, Sequence, Token
 '''These Mock classes are in compliance with vllm RequestOutput version 0.5.3.post1'''
@@ -148,7 +147,7 @@ def __init__(
             ],
             outputs=[
                 MockCompletionOutput(index=1,
-                                     text=' member of',
+                                     text=' of',
                                      token_ids=[4292, 302],
                                      cumulative_logprob=-4.3041129764169455,
                                      logprobs=[{
@@ -181,7 +180,7 @@ def __init__(
                                      finish_reason=None,
                                      stop_reason=None),
                 MockCompletionOutput(index=0,
-                                     text=' consolidated',
+                                     text='ated',
                                      token_ids=[22968, 601],
                                      cumulative_logprob=-13.402491569519043,
                                      logprobs=[{
@@ -235,7 +234,7 @@ def __init__(
             ],
             outputs=[
                 MockCompletionOutput(index=1,
-                                     text=' member of the',
+                                     text=' the',
                                      token_ids=[4292, 302, 272],
                                      cumulative_logprob=-4.815703457221389,
                                      logprobs=[{
@@ -282,7 +281,7 @@ def __init__(
                                      finish_reason='length',
                                      stop_reason=None),
                 MockCompletionOutput(index=0,
-                                     text=' consolidated or',
+                                     text=' or',
                                      token_ids=[22968, 601, 442],
                                      cumulative_logprob=-20.4010648727417,
                                      logprobs=[{