fix parse_input signature for backward compatibility (#1733)
sindhuvahinis authored Apr 4, 2024
1 parent 0644638 commit 81964c5
Showing 2 changed files with 58 additions and 26 deletions.
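For orientation, here is a minimal sketch of the backward-compatibility pattern this commit applies (illustrative only, not the exact djl_python code): the new parser returns a small dataclass, and the legacy parse_input wrapper restores the original five-value tuple by unpacking it, so existing callers that unpack five values keep working.

from dataclasses import dataclass, field

@dataclass
class ParsedInput:
    input_data: list          # flattened prompts across all requests
    input_size: list          # number of prompts contributed by each request
    parameters: list          # per-prompt generation parameters
    errors: dict              # request index -> error message
    batch: list               # the original Input objects
    is_client_side_batch: list = field(default_factory=list)

def parse_input_with_client_batch(inputs, tokenizer, output_formatter) -> ParsedInput:
    ...  # populate and return a ParsedInput (see the utils.py diff below)

def parse_input(inputs, tokenizer, output_formatter):
    # Backward-compatible wrapper: same five-value tuple as before the refactor.
    p = parse_input_with_client_batch(inputs, tokenizer, output_formatter)
    return p.input_data, p.input_size, p.parameters, p.errors, p.batch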
43 changes: 32 additions & 11 deletions engines/python/setup/djl_python/tensorrt_llm_python.py
@@ -84,14 +84,35 @@ def __init__(self):
self.model = None
self.trt_configs = None
self.initialized = False
self.parse_input = parse_input_with_client_batch
self.is_client_side_batch = []

def initialize(self, properties: dict):
self.trt_configs = TensorRtLlmProperties(**properties)
self._load_model(properties)
self.initialized = True
return

def parse_input(
self, inputs: Input, tokenizer, output_formatter
) -> tuple[list[str], list[int], list[dict], dict, list]:
"""
Preprocessing function that extracts information from Input objects.
:param output_formatter: output formatter for the request
:param inputs :(Input) a batch of inputs, each corresponding to a new request
:param tokenizer: the tokenizer used for inference
:return input_data (list[str]): a list of strings, each string being the prompt in a new request
:return input_size (list[int]): a list of ints being the size of each new request
:return parameters (list[dict]): parameters pertaining to each request
:return errors (dict): a dictionary mapping int indices to corresponding error strings if any
:return batch (list): a list of Input objects contained in inputs (each one corresponds to a request)
"""
parsed_input = parse_input_with_client_batch(inputs, tokenizer,
output_formatter)
self.is_client_side_batch = parsed_input.is_client_side_batch
return parsed_input.input_data, parsed_input.input_size, parsed_input.parameters, parsed_input.errors, parsed_input.batch

def inference(self, inputs: Input) -> Output:
"""
Does preprocessing and sends new requests to the rolling batch script for inference
@@ -102,7 +123,7 @@ def inference(self, inputs: Input) -> Output:
"""
outputs = Output()

input_data, input_size, parameters, errors, batch, is_client_side_batch = self.parse_input(
input_data, input_size, parameters, errors, batch = self.parse_input(
inputs, None, self.trt_configs.output_formatter)
if len(input_data) == 0:
for i in range(len(batch)):
@@ -113,7 +134,7 @@ def inference(self, inputs: Input) -> Output:
params = parameters[0]
if params.get("details", False):
return self._stream_inference(inputs, input_data, input_size,
params, batch, is_client_side_batch)
params, batch)

detokenized_python_response = self.model.generate(input_data, **params)
results = [{
@@ -122,9 +143,9 @@ def inference(self, inputs: Input) -> Output:
offset = 0
for i, item in enumerate(batch):
content_type, accept = _get_accept_and_content_type(item)
batch_item = results[offset:offset +
input_size[i]] if is_client_side_batch[
i] else results[offset]
batch_item = results[offset:offset + input_size[i]] if i < len(
self.is_client_side_batch
) and self.is_client_side_batch[i] else results[offset]
encode(outputs,
batch_item,
accept,
@@ -160,8 +181,8 @@ def _get_config(self, properties):

# TODO TrtLLM python backend: Change it once T5 bug is fixed.
def _stream_inference(self, inputs: Input, input_data: list[str],
input_size: list[int], parameters: dict, batch: list,
is_client_side_batch: list) -> Output:
input_size: list[int], parameters: dict,
batch: list) -> Output:
outputs = Output()
detokenized_python_response = self.model.generate(
input_data, **parameters)
@@ -172,9 +193,9 @@ def _stream_inference(self, inputs: Input, input_data: list[str],
for i, item in enumerate(batch):
item = batch[i]
accept, content_type = _get_accept_and_content_type(item)
batch_item = results[offset:offset +
input_size[i]] if is_client_side_batch[
i] else results[offset]
batch_item = results[offset:offset + input_size[i]] if i < len(
self.is_client_side_batch
) and self.is_client_side_batch[i] else results[offset]
encode(outputs,
batch_item,
accept,
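The slicing shown above maps the flat list of generations back to per-request outputs: a request that carried a client-side batch of prompts gets a list slice of length input_size[i], while a single-prompt request gets one result, and the i < len(self.is_client_side_batch) guard keeps the lookup safe when the flags were never populated. A standalone sketch of that logic, assuming results is a flat list aligned with the flattened prompts:

def split_results(results, input_size, is_client_side_batch):
    # Map a flat list of generations back to per-request outputs.
    per_request = []
    offset = 0
    for i, size in enumerate(input_size):
        client_batch = i < len(is_client_side_batch) and is_client_side_batch[i]
        # A client-side batch keeps its slice of results; a single prompt keeps one result.
        per_request.append(results[offset:offset + size] if client_batch else results[offset])
        offset += size
    return per_request

# e.g. split_results(["a", "b", "c"], [2, 1], [True, False]) -> [["a", "b"], "c"]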
41 changes: 26 additions & 15 deletions engines/python/setup/djl_python/utils.py
@@ -2,32 +2,38 @@
from djl_python.inputs import Input
from djl_python.encode_decode import encode, decode
from djl_python.chat_completions.chat_utils import is_chat_completions_request, parse_chat_completions_request
from dataclasses import dataclass, field


def parse_input_with_client_batch(
inputs: Input, tokenizer, output_formatter
) -> tuple[list[str], list[int], list[dict], dict, list, list]:
@dataclass
class ParsedInput:
input_data: list[str]
input_size: list[int]
parameters: list[dict]
errors: dict
batch: list
is_client_side_batch: list = field(default_factory=lambda: [])


def parse_input_with_client_batch(inputs: Input, tokenizer,
output_formatter) -> ParsedInput:
"""
Preprocessing function that extracts information from Input objects.
:param output_formatter: output formatter for the request
:param inputs :(Input) a batch of inputs, each corresponding to a new request
:param tokenizer: the tokenizer used for inference
:return input_data (list[str]): a list of strings, each string being the prompt in a new request
:return input_size (list[int]): a list of ints being the size of each new request
:return parameters (list[dict]): parameters pertaining to each request
:return errors (dict): a dictionary mapping int indices to corresponding error strings if any
:return batch (list): a list of Input objects contained in inputs (each one corresponds to a request)
:return is_client_size_batch (list): list of boolean value representing whether the input is a client side batch
:return parsed_input: object of data class that contains all parsed input details
"""

input_data = []
input_size = []
parameters = []
errors = {}
batch = inputs.get_batches()
# only for dynamic batch
is_client_size_batch = [False for _ in range(len(batch))]
is_client_side_batch = [False for _ in range(len(batch))]
for i, item in enumerate(batch):
try:
content_type = item.get_property("Content-Type")
@@ -48,7 +54,7 @@ def parse_input_with_client_batch(
if not isinstance(_inputs, list):
_inputs = [_inputs]
else:
is_client_size_batch[i] = True
is_client_side_batch[i] = True
input_data.extend(_inputs)
input_size.append(len(_inputs))

@@ -64,7 +70,12 @@ def parse_input_with_client_batch(
for _ in range(input_size[i]):
parameters.append(_param)

return input_data, input_size, parameters, errors, batch, is_client_size_batch
return ParsedInput(input_data=input_data,
input_size=input_size,
parameters=parameters,
errors=errors,
batch=batch,
is_client_side_batch=is_client_side_batch)


def parse_input(
@@ -83,6 +94,6 @@ def parse_input(
:return errors (dict): a dictionary mapping int indices to corresponding error strings if any
:return batch (list): a list of Input objects contained in inputs (each one corresponds to a request)
"""
input_data, input_size, parameters, errors, batch, _ = parse_input_with_client_batch(
inputs, tokenizer, output_formatter)
return input_data, input_size, parameters, errors, batch
parsed_input = parse_input_with_client_batch(inputs, tokenizer,
output_formatter)
return parsed_input.input_data, parsed_input.input_size, parsed_input.parameters, parsed_input.errors, parsed_input.batch
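A short usage sketch of the two entry points (the inputs, tokenizer, and output_formatter objects are assumed to come from the serving handler): callers that need the client-side batch flags use the new function and read them from the dataclass, while legacy callers keep the original tuple unpacking untouched.

# New-style caller: keeps access to the client-side batch flags.
parsed = parse_input_with_client_batch(inputs, tokenizer, output_formatter)
if any(parsed.is_client_side_batch):
    pass  # regroup results per request using parsed.input_size

# Legacy caller: unchanged five-value unpacking, exactly as before this commit.
input_data, input_size, parameters, errors, batch = parse_input(
    inputs, tokenizer, output_formatter)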
