diff --git a/engines/python/setup/djl_python/output_formatter.py b/engines/python/setup/djl_python/output_formatter.py
index 9e132b7ce..fbbe608a4 100644
--- a/engines/python/setup/djl_python/output_formatter.py
+++ b/engines/python/setup/djl_python/output_formatter.py
@@ -13,7 +13,7 @@
 import json
 import logging
 import time
-from typing import Union, Callable
+from typing import Union, Callable, Optional, Dict
 
 from typing_extensions import deprecated
 
@@ -43,6 +43,9 @@ def get_sequence_details(request_output: RequestOutput,
     if parameters.get("decoder_input_details"):
         sequence_details["prefill"] = request_output.get_prompt_tokens_as_dict(
         )
+    if parameters.get("top_n_tokens", 0) > 0:
+        sequence_details["top_tokens"] = request_output.get_top_tokens_as_dict(
+            sequence_index)
     return sequence_details
 
 
@@ -106,31 +109,45 @@ def _json_output_formatter(request_output: RequestOutput):
             json_encoded_str = f"[{json_encoded_str}"
     json_encoded_str = f"{json_encoded_str}{json.dumps(next_token.text, ensure_ascii=False)[1:-1]}"
     if last_token:
-        if parameters.get("details", tgi_compat):
-            final_dict = {
-                "finish_reason": best_sequence.finish_reason,
-                "generated_tokens": len(best_sequence.tokens),
-                "inputs": request_output.input.input_text,
-                "tokens": request_output.get_tokens_as_dict(),
-            }
-
-            if parameters.get("decoder_input_details"):
-                final_dict[
-                    "prefill"] = request_output.get_prompt_tokens_as_dict()
-            details_str = f"\"details\": {json.dumps(final_dict, ensure_ascii=False)}"
-            json_encoded_str = f"{json_encoded_str}\", {details_str}}}"
-        elif best_sequence.finish_reason == "error":
-            final_dict = {"finish_reason": best_sequence.finish_reason}
-            details_str = f"\"details\": {json.dumps(final_dict, ensure_ascii=False)}"
+        details_dict = get_details_dict(request_output, include_tokens=True)
+        if details_dict:
+            details_str = f"\"details\": {json.dumps(details_dict, ensure_ascii=False)}"
             json_encoded_str = f"{json_encoded_str}\", {details_str}}}"
         else:
             json_encoded_str = f"{json_encoded_str}\"}}"
         if tgi_compat:
             json_encoded_str = f"{json_encoded_str}]"
-
     return json_encoded_str
 
 
+def get_details_dict(request_output: RequestOutput,
+                     include_tokens: bool = True) -> Optional[Dict]:
+    parameters = request_output.input.parameters
+    best_sequence = request_output.sequences[
+        request_output.best_sequence_index]
+    if parameters.get("details", request_output.input.tgi_compat):
+        final_dict = {
+            "finish_reason": best_sequence.finish_reason,
+            "generated_tokens": len(best_sequence.tokens),
+            "inputs": request_output.input.input_text,
+        }
+
+        if include_tokens:
+            final_dict["tokens"] = request_output.get_tokens_as_dict()
+
+        if parameters.get("decoder_input_details"):
+            final_dict["prefill"] = request_output.get_prompt_tokens_as_dict()
+        if parameters.get("top_n_tokens", 0) > 0:
+            final_dict["top_tokens"] = request_output.get_top_tokens_as_dict(
+                request_output.best_sequence_index)
+
+        return final_dict
+    elif best_sequence.finish_reason == "error":
+        return {"finish_reason": best_sequence.finish_reason}
+    else:
+        return None
+
+
 def _jsonlines_output_formatter(request_output: RequestOutput):
     """
     jsonlines output formatter
@@ -148,19 +165,9 @@ def _jsonlines_output_formatter(request_output: RequestOutput):
     if last_token:
         generated_text = get_generated_text(best_sequence, request_output)
         final_dict["generated_text"] = generated_text
-        if parameters.get("details", tgi_compat):
-            final_dict["details"] = {
-                "finish_reason": best_sequence.finish_reason,
-                "generated_tokens": len(best_sequence.tokens),
"inputs": request_output.input.input_text, - } - if parameters.get("decoder_input_details"): - final_dict["details"][ - "prefill"] = request_output.get_prompt_tokens_as_dict() - elif best_sequence.finish_reason == "error": - final_dict["details"] = { - "finish_reason": best_sequence.finish_reason - } + details_dict = get_details_dict(request_output, include_tokens=False) + if details_dict: + final_dict["details"] = details_dict json_encoded_str = json.dumps(final_dict, ensure_ascii=False) + "\n" return json_encoded_str diff --git a/engines/python/setup/djl_python/request_io.py b/engines/python/setup/djl_python/request_io.py index f6377d729..2fd125127 100644 --- a/engines/python/setup/djl_python/request_io.py +++ b/engines/python/setup/djl_python/request_io.py @@ -237,3 +237,21 @@ def get_prompt_tokens_as_dict(self): else: tokens.append(token.as_dict()) return tokens + + def get_top_tokens_as_dict(self, sequence_index=0): + """Returns the top tokens of the given sequence index as a dictionary. + If not given, returns the top tokens of the first sequence index as a dictionary. + + :param sequence_index: index of the sequence to get the top tokens from. + :return: top tokens of the given sequence index as a dictionary. + """ + top_tokens = [] + for top_token in self.sequences[sequence_index].top_tokens: + top_token_list = [] + for token in top_token: + if self.input.tgi_compat: + top_token_list.append(token.as_tgi_dict()) + else: + top_token_list.append(token.as_dict()) + top_tokens.append(top_token_list) + return top_tokens diff --git a/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py b/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py index db4978995..94674ac73 100644 --- a/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py +++ b/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py @@ -128,7 +128,14 @@ def translate_lmi_dist_params(self, parameters: dict): parameters["use_beam_search"] = True if parameters.pop("decoder_input_details", False): parameters["prompt_logprobs"] = 1 - parameters["logprobs"] = parameters.get("logprobs", 1) + if "best_of" in parameters.keys(): + # if n is not explicitly set, we return `best_of` values sequences. 
+ if "n" not in "best_of": + parameters["n"] = parameters["best_of"] + if "top_n_tokens" in parameters.keys(): + parameters["logprobs"] = parameters.pop("top_n_tokens") + else: + parameters["logprobs"] = parameters.get("logprobs", 1) parameters = filter_unused_generation_params( parameters, LMI_DIST_GENERATION_PARAMS, diff --git a/engines/python/setup/djl_python/rolling_batch/rolling_batch_vllm_utils.py b/engines/python/setup/djl_python/rolling_batch/rolling_batch_vllm_utils.py index b09ad9c4c..c4a5e55f5 100644 --- a/engines/python/setup/djl_python/rolling_batch/rolling_batch_vllm_utils.py +++ b/engines/python/setup/djl_python/rolling_batch/rolling_batch_vllm_utils.py @@ -63,6 +63,9 @@ def update_request_cache_with_output(request_cache: OrderedDict, request_output.best_sequence_index = vllm_request_output.outputs[ 0].index request_cache.pop(request_id) + for i in range(1, len(vllm_request_output.outputs)): + index = vllm_request_output.outputs[i].index + request_output.other_sequences_indices.append(index) return request_cache @@ -105,17 +108,21 @@ def update_multiple_sequences(cache, request_output, vllm_request_output): output_token_texts = [text] * len( new_token_ids) if not output_token_texts else output_token_texts + top_tokens = [] # calculate log probs if completion_output.logprobs: new_logprobs_list = completion_output.logprobs[ prev_len: cur_len] if prev_len < cur_len else completion_output.logprobs - new_logprobs = [ - # NOTE: vLLM 0.4.1 changed logprob type - logprobs[token_id] if isinstance(logprobs[token_id], float) - else logprobs[token_id].logprob - for token_id, logprobs in zip(new_token_ids, new_logprobs_list) - ] + new_logprobs = [] + for token_id, logprobs in zip(new_token_ids, new_logprobs_list): + for token_id_key, logprob in logprobs.items(): + new_logprobs.append(logprobs[token_id].logprob) + top_tokens.append( + Token(id=token_id_key, + text=logprob.decoded_token, + log_prob=logprob.logprob)) + else: new_logprobs = [None] * len(new_token_ids) @@ -139,6 +146,10 @@ def update_multiple_sequences(cache, request_output, vllm_request_output): is_last_token = finish_reason is not None request_output.sequences[sequence_index].set_next_token( token, is_last_token) + top_tokens.append(token) + + request_output.sequences[sequence_index].set_next_top_tokens( + top_tokens) cache[f"sequence_index_{sequence_index}"]["curr_length"] = len( completion_output.text) diff --git a/engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py b/engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py index b07c2f7ec..aac050f1e 100644 --- a/engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py +++ b/engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py @@ -15,12 +15,10 @@ from vllm import EngineArgs, LLMEngine, SamplingParams from vllm.utils import random_uuid -from vllm.lora.request import LoRARequest from djl_python.rolling_batch.rolling_batch import RollingBatch, stop_on_any_exception, filter_unused_generation_params -from djl_python.request_io import Token from djl_python.rolling_batch.rolling_batch_vllm_utils import ( - update_request_cache_with_output, get_lora_request_params, DTYPE_MAPPER, - FINISH_REASON_MAPPER, get_engine_args_from_config) + update_request_cache_with_output, get_lora_request_params, + get_engine_args_from_config) from djl_python.properties_manager.vllm_rb_properties import VllmRbProperties from typing import List @@ -75,16 +73,12 @@ def translate_vllm_params(self, parameters: dict) -> dict: if "seed" in 
         if "seed" in parameters.keys():
             parameters["seed"] = int(parameters["seed"])
-        # if parameters' do_sample is not set, we set temperature=0, to do greedy
-        if "do_sample" not in parameters.keys():
-            is_beam_search = "num_beams" in parameters.keys()
-            is_best_of = "best_of" in parameters.keys(
-            ) and parameters["best_of"] > 1
-            if not (is_beam_search and is_best_of):
-                # if temperature is zero, vLLM does greedy sampling
-                parameters['temperature'] = 0
-        elif parameters.pop("do_sample"):
-            parameters["temperature"] = 0
+        # If `do_sample` is not provided, force temperature=0.0, i.e. greedy decoding;
+        # otherwise keep the user-provided temperature, defaulting to 1.0.
+        if not parameters.pop('do_sample', False):
+            parameters['temperature'] = 0.0
+        else:
+            parameters['temperature'] = parameters.get('temperature', 1.0)
         if "stop_sequences" in parameters.keys():
             parameters["stop"] = parameters.pop("stop_sequences")
         if "ignore_eos_token" in parameters.keys():
@@ -94,7 +88,16 @@ def translate_vllm_params(self, parameters: dict) -> dict:
             parameters["use_beam_search"] = True
         if parameters.pop("decoder_input_details", False):
             parameters["prompt_logprobs"] = 1
-        parameters["logprobs"] = parameters.get("logprobs", 1)
+
+        # if n is not explicitly set when best_of is set, we return `best_of` sequences for tgi compatibility.
+        if "best_of" in parameters.keys():
+            if "n" not in parameters.keys():
+                parameters["n"] = parameters["best_of"]
+
+        if "top_n_tokens" in parameters.keys():
+            parameters["logprobs"] = parameters.pop("top_n_tokens")
+        else:
+            parameters["logprobs"] = parameters.get("logprobs", 1)
         parameters = filter_unused_generation_params(parameters,
                                                      VLLM_GENERATION_PARAMS,
                                                      "vllm",
diff --git a/engines/python/setup/djl_python/tests/rolling_batch_test_scripts/test_rb_vllm_utils.py b/engines/python/setup/djl_python/tests/rolling_batch_test_scripts/test_rb_vllm_utils.py
index be803a526..a5ebbf679 100644
--- a/engines/python/setup/djl_python/tests/rolling_batch_test_scripts/test_rb_vllm_utils.py
+++ b/engines/python/setup/djl_python/tests/rolling_batch_test_scripts/test_rb_vllm_utils.py
@@ -98,7 +98,11 @@ def __init__(
                     2032:
                     MockLogprob(logprob=-3.0240092277526855,
                                 rank=1,
-                                decoded_token=' big')
+                                decoded_token=' big'),
+                    888:
+                    MockLogprob(logprob=-4.4099884033203125,
+                                rank=3,
+                                decoded_token=' new')
                 }],
                 finish_reason=None,
                 stop_reason=None),
@@ -114,7 +118,15 @@ def __init__(
                     2032:
                     MockLogprob(logprob=-3.0240092277526855,
                                 rank=1,
-                                decoded_token=' big')
+                                decoded_token=' big'),
+                    17372:
+                    MockLogprob(logprob=-13.409988403320312,
+                                rank=10489,
+                                decoded_token=' crown'),
+                    888:
+                    MockLogprob(logprob=-4.4099884033203125,
+                                rank=3,
+                                decoded_token=' new'),
                 }],
                 finish_reason=None,
                 stop_reason=None)
@@ -147,12 +159,24 @@ def __init__(
                     2032:
                     MockLogprob(logprob=-3.0240092277526855,
                                 rank=1,
-                                decoded_token=' big')
+                                decoded_token=' big'),
+                    888:
+                    MockLogprob(logprob=-4.4099884033203125,
+                                rank=3,
+                                decoded_token=' new'),
                 }, {
                     302:
                     MockLogprob(logprob=-0.03010374866425991,
                                 rank=1,
-                                decoded_token=' of')
+                                decoded_token=' of'),
+                    235290:
+                    MockLogprob(logprob=-2.2026185989379883,
+                                rank=1,
+                                decoded_token='-'),
+                    578:
+                    MockLogprob(logprob=-2.2026185989379883,
+                                rank=2,
+                                decoded_token=' and')
                 }],
                 finish_reason=None,
                 stop_reason=None),
@@ -168,7 +192,15 @@ def __init__(
                     2032:
                     MockLogprob(logprob=-3.0240092277526855,
                                 rank=1,
-                                decoded_token=' big')
+                                decoded_token=' big'),
+                    17372:
+                    MockLogprob(logprob=-13.409988403320312,
+                                rank=10489,
+                                decoded_token=' crown'),
+                    888:
+                    MockLogprob(logprob=-4.4099884033203125,
+                                rank=3,
+                                decoded_token=' new'),
                 }, {
                     601:
                     MockLogprob(logprob=-1.2847318649291992,
@@ -177,7 +209,11 @@
                     1028:
                     MockLogprob(logprob=-0.909731924533844,
                                 rank=1,
-                                decoded_token='ator')
+                                decoded_token='ator'),
+                    1162:
+                    MockLogprob(logprob=-0.8929234743118286,
+                                rank=2,
+                                decoded_token=' year')
                 }],
                 finish_reason=None,
                 stop_reason=None)
@@ -211,17 +247,37 @@
                     2032:
                     MockLogprob(logprob=-3.0240092277526855,
                                 rank=1,
-                                decoded_token=' big')
+                                decoded_token=' big'),
+                    888:
+                    MockLogprob(logprob=-4.4099884033203125,
+                                rank=3,
+                                decoded_token=' new'),
                 }, {
                     302:
                     MockLogprob(logprob=-0.03010374866425991,
                                 rank=1,
-                                decoded_token=' of')
+                                decoded_token=' of'),
+                    235290:
+                    MockLogprob(logprob=-2.2026185989379883,
+                                rank=1,
+                                decoded_token='-'),
+                    578:
+                    MockLogprob(logprob=-2.2026185989379883,
+                                rank=2,
+                                decoded_token=' and')
                 }, {
                     272:
                     MockLogprob(logprob=-0.5115904808044434,
                                 rank=1,
-                                decoded_token=' the')
+                                decoded_token=' the'),
+                    169181:
+                    MockLogprob(logprob=-8.463325500488281,
+                                rank=196,
+                                decoded_token=' aviator'),
+                    194366:
+                    MockLogprob(logprob=-2.463325023651123,
+                                rank=1,
+                                decoded_token=' Realtor')
                 }],
                 finish_reason='length',
                 stop_reason=None),
@@ -237,7 +293,15 @@
                     2032:
                     MockLogprob(logprob=-3.0240092277526855,
                                 rank=1,
-                                decoded_token=' big')
+                                decoded_token=' big'),
+                    17372:
+                    MockLogprob(logprob=-13.409988403320312,
+                                rank=10489,
+                                decoded_token=' crown'),
+                    888:
+                    MockLogprob(logprob=-4.4099884033203125,
+                                rank=3,
+                                decoded_token=' new'),
                 }, {
                     601:
                     MockLogprob(logprob=-1.2847318649291992,
@@ -246,7 +310,11 @@
                     1028:
                     MockLogprob(logprob=-0.909731924533844,
                                 rank=1,
-                                decoded_token='ator')
+                                decoded_token='ator'),
+                    1162:
+                    MockLogprob(logprob=-0.8929234743118286,
+                                rank=2,
+                                decoded_token=' year')
                 }, {
                     442:
                     MockLogprob(logprob=-6.998573303222656,
@@ -255,7 +323,15 @@
                     28725:
                     MockLogprob(logprob=-3.7798233032226562,
                                 rank=1,
-                                decoded_token=',')
+                                decoded_token=','),
+                    1622:
+                    MockLogprob(logprob=-4.463325023651123,
+                                rank=2,
+                                decoded_token=' New'),
+                    576:
+                    MockLogprob(logprob=-4.463325023651123,
+                                rank=3,
+                                decoded_token=' of')
                 }],
                 finish_reason='length',
                 stop_reason=None)
@@ -287,7 +363,8 @@ def test_multiple_sequences(self):
             "details": True,
             "decoder_input_details": True,
             "best_of": 2,
-            "n": 2
+            "n": 2,
+            "top_n_tokens": 3
         }
         # 1. Creates the request
@@ -333,6 +410,56 @@ def test_multiple_sequences(self):
                       log_prob=-6.998573303222656,
                       special_token=None)
             ],
+            top_tokens=[[
+                Token(id=22968,
+                      text=' consolid',
+                      log_prob=-12.117759704589844,
+                      special_token=None),
+                Token(id=2032,
+                      text=' big',
+                      log_prob=-3.0240092277526855,
+                      special_token=None),
+                Token(id=17372,
+                      text=' crown',
+                      log_prob=-13.409988403320312,
+                      special_token=None),
+                Token(id=888,
+                      text=' new',
+                      log_prob=-4.4099884033203125,
+                      special_token=None)
+            ],
+            [
+                Token(id=601,
+                      text='ated',
+                      log_prob=-1.2847318649291992,
+                      special_token=None),
+                Token(id=1028,
+                      text='ator',
+                      log_prob=-0.909731924533844,
+                      special_token=None),
+                Token(id=1162,
+                      text=' year',
+                      log_prob=-0.8929234743118286,
+                      special_token=None)
+            ],
+            [
+                Token(id=442,
+                      text=' or',
+                      log_prob=-6.998573303222656,
+                      special_token=None),
+                Token(id=28725,
+                      text=',',
+                      log_prob=-3.7798233032226562,
+                      special_token=None),
+                Token(id=1622,
+                      text=' New',
+                      log_prob=-4.463325023651123,
+                      special_token=None),
+                Token(id=576,
+                      text=' of',
+                      log_prob=-4.463325023651123,
+                      special_token=None)
+            ]],
             finish_reason='length',
             cumulative_log_prob=-20.4010648727417,
             stop_reason=None),
@@ -352,6 +479,48 @@ def test_multiple_sequences(self):
                       log_prob=-0.5115904808044434,
                       special_token=None)
             ],
+            top_tokens=[[
+                Token(id=4292,
+                      text=' member',
+                      log_prob=-4.2740092277526855,
+                      special_token=None),
+                Token(id=2032,
+                      text=' big',
+                      log_prob=-3.0240092277526855,
+                      special_token=None),
+                Token(id=888,
+                      text=' new',
+                      log_prob=-4.4099884033203125,
+                      special_token=None)
+            ],
+            [
+                Token(id=302,
+                      text=' of',
+                      log_prob=-0.03010374866425991,
+                      special_token=None),
+                Token(id=235290,
+                      text='-',
+                      log_prob=-2.2026185989379883,
+                      special_token=None),
+                Token(id=578,
+                      text=' and',
+                      log_prob=-2.2026185989379883,
+                      special_token=None)
+            ],
+            [
+                Token(id=272,
+                      text=' the',
+                      log_prob=-0.5115904808044434,
+                      special_token=None),
+                Token(id=169181,
+                      text=' aviator',
+                      log_prob=-8.463325500488281,
+                      special_token=None),
+                Token(id=194366,
+                      text=' Realtor',
+                      log_prob=-2.463325023651123,
+                      special_token=None)
+            ]],
             finish_reason='length',
             cumulative_log_prob=-4.815703457221389,
             stop_reason=None,
@@ -372,3 +541,12 @@ def test_multiple_sequences(self):
                 self.assertTrue(
                     _compare_tokens(token, actual_sequence.tokens[token_index]))
+            for top_tokens_index, top_tokens in enumerate(sequence.top_tokens):
+                self.assertEqual(
+                    len(top_tokens),
+                    len(actual_sequence.top_tokens[top_tokens_index]))
+                for token_index, token in enumerate(top_tokens):
+                    self.assertTrue(
+                        _compare_tokens(
+                            token, actual_sequence.top_tokens[top_tokens_index]
+                            [token_index]))
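
For quick reference, a minimal sketch of request parameters that exercise the new path; it mirrors the parameters dict used in test_multiple_sequences above, with the expected effect summarized in comments (illustrative only, not part of the diff):

    # Illustrative sketch: with "top_n_tokens" > 0, translate_vllm_params /
    # translate_lmi_dist_params map it to the backend "logprobs" setting, and
    # get_details_dict() adds a "top_tokens" entry (one list of candidate tokens
    # per generated position of the best sequence) to the "details" section.
    parameters = {
        "details": True,
        "decoder_input_details": True,
        "best_of": 2,
        "n": 2,
        "top_n_tokens": 3,
    }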