[vllm, lmi-dist] add support for top_n_tokens #2051

Merged · 3 commits · Jun 17, 2024
69 changes: 38 additions & 31 deletions engines/python/setup/djl_python/output_formatter.py
@@ -13,7 +13,7 @@
import json
import logging
import time
from typing import Union, Callable
from typing import Union, Callable, Optional, Dict

from typing_extensions import deprecated

@@ -43,6 +43,9 @@ def get_sequence_details(request_output: RequestOutput,
if parameters.get("decoder_input_details"):
sequence_details["prefill"] = request_output.get_prompt_tokens_as_dict(
)
if parameters.get("top_n_tokens", 0) > 0:
sequence_details["top_tokens"] = request_output.get_top_tokens_as_dict(
sequence_index)
return sequence_details


@@ -106,31 +109,45 @@ def _json_output_formatter(request_output: RequestOutput):
json_encoded_str = f"[{json_encoded_str}"
json_encoded_str = f"{json_encoded_str}{json.dumps(next_token.text, ensure_ascii=False)[1:-1]}"
if last_token:
if parameters.get("details", tgi_compat):
final_dict = {
"finish_reason": best_sequence.finish_reason,
"generated_tokens": len(best_sequence.tokens),
"inputs": request_output.input.input_text,
"tokens": request_output.get_tokens_as_dict(),
}

if parameters.get("decoder_input_details"):
final_dict[
"prefill"] = request_output.get_prompt_tokens_as_dict()
details_str = f"\"details\": {json.dumps(final_dict, ensure_ascii=False)}"
json_encoded_str = f"{json_encoded_str}\", {details_str}}}"
elif best_sequence.finish_reason == "error":
final_dict = {"finish_reason": best_sequence.finish_reason}
details_str = f"\"details\": {json.dumps(final_dict, ensure_ascii=False)}"
details_dict = get_details_dict(request_output, include_tokens=True)
if details_dict:
details_str = f"\"details\": {json.dumps(details_dict, ensure_ascii=False)}"
json_encoded_str = f"{json_encoded_str}\", {details_str}}}"
else:
json_encoded_str = f"{json_encoded_str}\"}}"
if tgi_compat:
json_encoded_str = f"{json_encoded_str}]"

return json_encoded_str


def get_details_dict(request_output: RequestOutput,
include_tokens: bool = True) -> Optional[Dict]:
parameters = request_output.input.parameters
best_sequence = request_output.sequences[
request_output.best_sequence_index]
if parameters.get("details", request_output.input.tgi_compat):
final_dict = {
"finish_reason": best_sequence.finish_reason,
"generated_tokens": len(best_sequence.tokens),
"inputs": request_output.input.input_text,
}

if include_tokens:
final_dict["tokens"] = request_output.get_tokens_as_dict()

if parameters.get("decoder_input_details"):
final_dict["prefill"] = request_output.get_prompt_tokens_as_dict()
if parameters.get("top_n_tokens", 0) > 0:
final_dict["top_tokens"] = request_output.get_top_tokens_as_dict(
request_output.best_sequence_index)

return final_dict
elif best_sequence.finish_reason == "error":
return {"finish_reason": best_sequence.finish_reason}
else:
return None
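
For reference, the `details` object that `get_details_dict` assembles would look roughly like the sketch below when a client sends `details=true` and `top_n_tokens=2`. The values are invented for illustration and the exact keys of each token dict come from `Token.as_dict()`; only the overall shape matters here.

```python
# Illustrative only -- invented values; the real payload is built by
# get_details_dict() above from the accumulated RequestOutput.
example_details = {
    "finish_reason": "length",
    "generated_tokens": 2,
    "inputs": "The capital of France is",
    # one dict per generated token (omitted by the jsonlines formatter)
    "tokens": [
        {"id": 6342, "text": " Paris", "log_prob": -0.05},
        {"id": 13, "text": ".", "log_prob": -0.61},
    ],
    # one list of top-n candidates per generated position (top_n_tokens=2)
    "top_tokens": [
        [
            {"id": 6342, "text": " Paris", "log_prob": -0.05},
            {"id": 9479, "text": " Lyon", "log_prob": -3.87},
        ],
        [
            {"id": 13, "text": ".", "log_prob": -0.61},
            {"id": 11, "text": ",", "log_prob": -1.24},
        ],
    ],
}
```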


def _jsonlines_output_formatter(request_output: RequestOutput):
"""
jsonlines output formatter
@@ -148,19 +165,9 @@ def _jsonlines_output_formatter(request_output: RequestOutput):
if last_token:
generated_text = get_generated_text(best_sequence, request_output)
final_dict["generated_text"] = generated_text
if parameters.get("details", tgi_compat):
final_dict["details"] = {
"finish_reason": best_sequence.finish_reason,
"generated_tokens": len(best_sequence.tokens),
"inputs": request_output.input.input_text,
}
if parameters.get("decoder_input_details"):
final_dict["details"][
"prefill"] = request_output.get_prompt_tokens_as_dict()
elif best_sequence.finish_reason == "error":
final_dict["details"] = {
"finish_reason": best_sequence.finish_reason
}
details_dict = get_details_dict(request_output, include_tokens=False)
if details_dict:
final_dict["details"] = details_dict
json_encoded_str = json.dumps(final_dict, ensure_ascii=False) + "\n"
return json_encoded_str

18 changes: 18 additions & 0 deletions engines/python/setup/djl_python/request_io.py
@@ -237,3 +237,21 @@ def get_prompt_tokens_as_dict(self):
else:
tokens.append(token.as_dict())
return tokens

def get_top_tokens_as_dict(self, sequence_index=0):
"""Returns the top tokens of the given sequence index as a dictionary.
If not given, returns the top tokens of the first sequence index as a dictionary.

:param sequence_index: index of the sequence to get the top tokens from.
:return: top tokens of the given sequence index as a dictionary.
"""
top_tokens = []
for top_token in self.sequences[sequence_index].top_tokens:
top_token_list = []
for token in top_token:
if self.input.tgi_compat:
top_token_list.append(token.as_tgi_dict())
else:
top_token_list.append(token.as_dict())
top_tokens.append(top_token_list)
return top_tokens
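
As a rough illustration of how a client might consume the nested structure returned by `get_top_tokens_as_dict` (one inner list per generated position; the `id`/`text`/`log_prob` keys are assumed from the `Token` fields used elsewhere in this PR):

```python
def most_likely_alternatives(top_tokens):
    """For each generated position, pick the candidate with the highest log prob.

    `top_tokens` is the nested list produced by get_top_tokens_as_dict():
    a list with one inner list of token dicts per generated token.
    """
    best = []
    for candidates in top_tokens:
        top = max(candidates, key=lambda t: t["log_prob"])
        best.append((top["text"], top["log_prob"]))
    return best

# e.g. most_likely_alternatives(details["top_tokens"]) -> [(" Paris", -0.05), (".", -0.61)]
```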
@@ -128,7 +128,14 @@ def translate_lmi_dist_params(self, parameters: dict):
parameters["use_beam_search"] = True
if parameters.pop("decoder_input_details", False):
parameters["prompt_logprobs"] = 1
parameters["logprobs"] = parameters.get("logprobs", 1)
if "best_of" in parameters.keys():
# if n is not explicitly set, return `best_of` sequences for TGI compatibility.
if "n" not in parameters.keys():
parameters["n"] = parameters["best_of"]
if "top_n_tokens" in parameters.keys():
parameters["logprobs"] = parameters.pop("top_n_tokens")
else:
parameters["logprobs"] = parameters.get("logprobs", 1)
parameters = filter_unused_generation_params(
parameters,
LMI_DIST_GENERATION_PARAMS,
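
A minimal standalone sketch of the parameter translation this hunk performs, for illustration only (the real method also runs `filter_unused_generation_params` and the rest of `translate_lmi_dist_params`):

```python
def sketch_translate_top_n(parameters: dict) -> dict:
    """Illustrative stand-in for the new top_n_tokens / best_of handling."""
    params = dict(parameters)
    if "best_of" in params and "n" not in params:
        # mirror TGI: return `best_of` sequences unless n is set explicitly
        params["n"] = params["best_of"]
    if "top_n_tokens" in params:
        # top_n_tokens maps onto the backend's logprobs knob
        params["logprobs"] = params.pop("top_n_tokens")
    else:
        params.setdefault("logprobs", 1)
    return params

# e.g. {"best_of": 2, "top_n_tokens": 3} -> {"best_of": 2, "n": 2, "logprobs": 3}
```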
@@ -63,6 +63,9 @@ def update_request_cache_with_output(request_cache: OrderedDict,
request_output.best_sequence_index = vllm_request_output.outputs[
0].index
request_cache.pop(request_id)
for i in range(1, len(vllm_request_output.outputs)):
index = vllm_request_output.outputs[i].index
request_output.other_sequences_indices.append(index)

return request_cache

@@ -105,17 +108,21 @@ def update_multiple_sequences(cache, request_output, vllm_request_output):
output_token_texts = [text] * len(
new_token_ids) if not output_token_texts else output_token_texts

top_tokens = []
# calculate log probs
if completion_output.logprobs:
new_logprobs_list = completion_output.logprobs[
prev_len:
cur_len] if prev_len < cur_len else completion_output.logprobs
new_logprobs = [
# NOTE: vLLM 0.4.1 changed logprob type
logprobs[token_id] if isinstance(logprobs[token_id], float)
else logprobs[token_id].logprob
for token_id, logprobs in zip(new_token_ids, new_logprobs_list)
]
new_logprobs = []
for token_id, logprobs in zip(new_token_ids, new_logprobs_list):
    # log prob of the token that was actually generated at this position
    new_logprobs.append(logprobs[token_id].logprob)
    # every returned candidate becomes a top token for this position
    for token_id_key, logprob in logprobs.items():
        top_tokens.append(
            Token(id=token_id_key,
                  text=logprob.decoded_token,
                  log_prob=logprob.logprob))

else:
new_logprobs = [None] * len(new_token_ids)

@@ -139,6 +146,10 @@ def update_multiple_sequences(cache, request_output, vllm_request_output):
is_last_token = finish_reason is not None
request_output.sequences[sequence_index].set_next_token(
token, is_last_token)
top_tokens.append(token)

request_output.sequences[sequence_index].set_next_top_tokens(
top_tokens)

cache[f"sequence_index_{sequence_index}"]["curr_length"] = len(
completion_output.text)
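
To make the logprobs handling above concrete, here is a small sketch of how one generation step's vLLM logprobs become `Token` objects. `FakeLogprob` stands in for vLLM's `Logprob` (assumed to expose `.logprob` and `.decoded_token`, which is all the code above relies on); `Token` is the class from `djl_python.request_io` used in this PR.

```python
from dataclasses import dataclass

from djl_python.request_io import Token


@dataclass
class FakeLogprob:
    """Stand-in for vllm's Logprob; only the attributes used above."""
    logprob: float
    decoded_token: str


# vLLM returns, per generated position, a dict mapping token id -> Logprob
step_logprobs = {
    6342: FakeLogprob(-0.05, " Paris"),
    9479: FakeLogprob(-3.87, " Lyon"),
}
sampled_id = 6342

# log prob recorded for the token that was actually sampled at this position
new_logprob = step_logprobs[sampled_id].logprob

# every returned candidate becomes a top token for this position
top_tokens = [
    Token(id=tid, text=lp.decoded_token, log_prob=lp.logprob)
    for tid, lp in step_logprobs.items()
]
```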
@@ -15,12 +15,10 @@

from vllm import EngineArgs, LLMEngine, SamplingParams
from vllm.utils import random_uuid
from vllm.lora.request import LoRARequest
from djl_python.rolling_batch.rolling_batch import RollingBatch, stop_on_any_exception, filter_unused_generation_params
from djl_python.request_io import Token
from djl_python.rolling_batch.rolling_batch_vllm_utils import (
update_request_cache_with_output, get_lora_request_params, DTYPE_MAPPER,
FINISH_REASON_MAPPER, get_engine_args_from_config)
update_request_cache_with_output, get_lora_request_params,
get_engine_args_from_config)
from djl_python.properties_manager.vllm_rb_properties import VllmRbProperties
from typing import List

@@ -75,16 +73,12 @@ def translate_vllm_params(self, parameters: dict) -> dict:
if "seed" in parameters.keys():
parameters["seed"] = int(parameters["seed"])

# if parameters' do_sample is not set, we set temperature=0, to do greedy
if "do_sample" not in parameters.keys():
is_beam_search = "num_beams" in parameters.keys()
is_best_of = "best_of" in parameters.keys(
) and parameters["best_of"] > 1
if not (is_beam_search and is_best_of):
# if temperature is zero, vLLM does greedy sampling
parameters['temperature'] = 0
elif parameters.pop("do_sample"):
parameters["temperature"] = 0
# If `do_sample` is not provided, force temperature=0.0 so vLLM does greedy decoding;
# otherwise keep the user-provided temperature, defaulting to 1.0.
if not parameters.pop('do_sample', False):
parameters['temperature'] = 0.0
else:
parameters['temperature'] = parameters.get('temperature', 1.0)
if "stop_sequences" in parameters.keys():
parameters["stop"] = parameters.pop("stop_sequences")
if "ignore_eos_token" in parameters.keys():
@@ -94,7 +88,16 @@
parameters["use_beam_search"] = True
if parameters.pop("decoder_input_details", False):
parameters["prompt_logprobs"] = 1
parameters["logprobs"] = parameters.get("logprobs", 1)

# if n is not explicitly set when best_of is set, return `best_of` sequences for TGI compatibility.
if "best_of" in parameters.keys():
if "n" not in parameters.keys():
parameters["n"] = parameters["best_of"]

if "top_n_tokens" in parameters.keys():
parameters["logprobs"] = parameters.pop("top_n_tokens")
else:
parameters["logprobs"] = parameters.get("logprobs", 1)
parameters = filter_unused_generation_params(parameters,
VLLM_GENERATION_PARAMS,
"vllm",
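
Finally, an end-to-end sketch of a request that exercises the new parameter. The payload follows the TGI-style `inputs`/`parameters` schema the formatter assumes; the endpoint URL is a placeholder for whatever your DJL Serving deployment exposes.

```python
import json
import urllib.request

URL = "http://localhost:8080/invocations"  # placeholder endpoint

payload = {
    "inputs": "The capital of France is",
    "parameters": {
        "max_new_tokens": 2,
        "details": True,
        "top_n_tokens": 2,  # new: top-2 candidates per generated position
    },
}
req = urllib.request.Request(
    URL,
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    body = json.load(resp)

# with details=true, the response carries the details object sketched earlier
print(body["details"]["top_tokens"])
```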