Commit

minor fixes

sindhuvahinis committed Jul 9, 2024
1 parent 009411d commit 1af53f8
Showing 10 changed files with 104 additions and 61 deletions.
18 changes: 9 additions & 9 deletions engines/python/setup/djl_python/huggingface.py
@@ -140,7 +140,6 @@ def __init__(self):
self.peft_config = None
self.stopping_criteria_list = None
self.adapter_registry = {}
self.adapters = None
self.hf_configs = None
self.input_format_args = None

@@ -254,14 +253,14 @@ def _dynamic_batch_inference(self, batch: List, errors: Dict,
inputs: Input, outputs: Output,
requests: List):
# Dynamic batching
input_data, input_size = get_input_details(requests, errors, batch)
parameters = requests[0].request_input.server_parameters
input_data, input_size, parameters, adapters = get_input_details(
requests, errors, batch)

if isinstance(self.model, PeftModelForCausalLM):
if self.adapters is None:
if adapters is None:
# Inference with only base model
self.adapters = [""] * len(input_data)
parameters["adapters"] = self.adapters
adapters = [""] * len(input_data)
parameters["adapters"] = adapters
prediction = self.hf_pipeline(input_data, **parameters)
offset = 0
for i, item in enumerate(batch):
@@ -293,20 +292,21 @@ def _streaming_inference(self, batch: List, request_input: RequestInput,
if len(batch) > 1:
raise NotImplementedError(
"Dynamic batch not supported for generic streaming")

parameters = request_input.server_parameters
outputs.add_property("content-type", "application/jsonlines")
if self.hf_configs.enable_streaming.value == StreamingEnum.huggingface.value:
outputs.add_stream_content(
StreamingUtils.use_hf_default_streamer(
self.model, self.tokenizer, request_input.input_text,
self.hf_configs.device, **request_input.server_parameters))
self.hf_configs.device, **parameters))
else:
stream_generator = StreamingUtils.get_stream_generator(
"Accelerate")
outputs.add_stream_content(
stream_generator(self.model, self.tokenizer,
request_input.input_text,
self.hf_configs.device,
**request_input.server_parameters))
self.hf_configs.device, **parameters))
return outputs

def get_pipeline(self, task: str, model_id_or_path: str, kwargs):
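A minimal sketch of how the dynamic-batching path consumes the new four-value return of `get_input_details` after this change. The wrapper function and its arguments are illustrative, not part of the commit; only the unpacking and the adapter fallback mirror the diff above, and the exact nesting in the handler may differ.

```python
from djl_python.utils import get_input_details


def run_dynamic_batch(hf_pipeline, requests, errors, batch, is_peft_model):
    # get_input_details now also returns the shared parameters dict and the
    # per-request adapters, so callers no longer read
    # requests[0].request_input.server_parameters themselves.
    input_data, input_size, parameters, adapters = get_input_details(
        requests, errors, batch)

    if is_peft_model:
        if adapters is None:
            # Inference with only the base model: one empty adapter per input.
            adapters = [""] * len(input_data)
        parameters["adapters"] = adapters

    # parameters is forwarded to the pipeline exactly as before.
    return hf_pipeline(input_data, **parameters)
```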
60 changes: 32 additions & 28 deletions engines/python/setup/djl_python/input_parser.py
@@ -44,6 +44,8 @@ def parse_input_with_formatter(inputs: Input, **kwargs) -> ParsedInput:
request_id_counter = get_req_id_counter(kwargs)
for i, input_item in enumerate(batch):
try:
kwargs["is_rolling_batch"] = is_rolling_batch_enabled(
kwargs.get("configs").rolling_batch)
request_id = request_id_counter.next_id(
) if request_id_counter else i
# TODO: Decide whether it is a text input based on content-type
@@ -70,7 +72,7 @@ def parse_input_with_formatter(inputs: Input, **kwargs) -> ParsedInput:

def get_req_id_counter(kwargs):
req_id_counter = None
if is_rolling_batch_enabled(kwargs.get("configs").rolling_batch):
if kwargs.get("is_rolling_batch"):
req_id_counter = kwargs.get("rolling_batch").req_id_counter
return req_id_counter

@@ -89,24 +91,27 @@ def parse_text_inputs_params(request_input: TextInput, input_item: Input,
invoke_type = input_item.get_property("X-Amzn-SageMaker-Forwarded-Api")
tokenizer = kwargs.get("tokenizer")
if is_chat_completions_request(input_map):
_inputs, _param = parse_chat_completions_request(
inputs, param = parse_chat_completions_request(
input_map, kwargs.get("is_rolling_batch"), tokenizer)
elif is_3p_request(invoke_type):
_inputs, _param = parse_3p_request(input_map,
kwargs.get("is_rolling_batch"),
tokenizer, invoke_type)
inputs, param = parse_3p_request(input_map,
kwargs.get("is_rolling_batch"),
tokenizer, invoke_type)
else:
_inputs = input_map.pop("inputs", input_map)
_param = input_map.pop("parameters", {})

request_input.input_text = _inputs
request_input.parameters = _param
# assign input_ids
if kwargs.get("tokenizer"):
inputs = input_map.pop("inputs", input_map)
param = input_map.pop("parameters", {})

request_input.input_text = inputs
request_input.parameters = param
# assigns input_ids
# TODO: for dynamic batching, or HF pipeline, tokenizer is applied differently.
if kwargs.get("tokenizer") and kwargs.get("is_rolling_batch"):
request_input.input_ids = tokenizer.encode(request_input.input_text)

# TODO: Instead of modifying user parameters, maintain this in server_parameters.
# Added here for backward compatibility
# re-organize the parameters
if is_rolling_batch_enabled(kwargs.get("configs").rolling_batch):
if kwargs.get("is_rolling_batch"):
if "stream" in input_map:
request_input.parameters["stream"] = input_map.pop("stream")
if "cached_prompt" in input_map:
@@ -124,45 +129,44 @@ def add_server_maintained_params(request_input: TextInput, input_item: Input,
if input_item.contains_key("seed"):
request_input.server_parameters["seed"] = input_item.get_as_string(
key="seed")

# setting the output formatter
if not "output_formatter" in request_input.server_parameters:
request_input.server_parameters["output_formatter"] = kwargs.get(
"configs").output_formatter

request_input.output_formatter = request_input.server_parameters.get(
"output_formatter")

if request_input.output_formatter == "json" or request_input.output_formatter == "sse":
output_formatter = request_input.server_parameters["output_formatter"]
if output_formatter == "json" or output_formatter == "sse":
request_input.tgi_compat = kwargs.get("configs").tgi_compat

# duplicating parameters for client side batching
if isinstance(request_input.input_text, list):
parameters = []
for _ in range(len(request_input.input_text)):
parameters.append(request_input.server_parameters.copy())
request_input.server_parameters = parameters


def parse_adapters(request_input: TextInput, input_item: Input,
input_map: Dict, **kwargs):
adapter_registry = kwargs.get("adapter_registry")
# if the adapter registry exists and is not empty, we assume PEFT is supported for the incoming request
if adapter_registry:
input_len = len(request_input.input_text) if isinstance(
request_input.input_text, list) else 1
adapters_per_item = _fetch_adapters_from_input(input_map, input_item)
if adapters_per_item:
_validate_adapters(adapters_per_item,
kwargs.get("adapter_registry"))
else:
# inference with just base model.
adapters_per_item = [""] * len(request_input.input_text)
adapters_per_item = [""] * input_len

if len(request_input.input_text) != len(adapters_per_item):
if input_len != len(adapters_per_item):
raise ValueError(
f"Number of adapters is not equal to the number of inputs")
# lookup the adapter registry to get the adapter details of the registered adapter.
request_input.adapters = [
adapters_data = [
kwargs.get("adapter_registry").get(adapter, None)
for adapter in adapter_registry
for adapter in adapters_per_item
]
if len(adapters_data) == 1:
adapters_data = adapters_data[0]

request_input.adapters = adapters_data


def _fetch_adapters_from_input(input_map: dict, input_item: Input):
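The adapter lookup added to `parse_adapters` can be read as the following standalone sketch. `resolve_adapters` is a hypothetical helper; the validation and registry plumbing of the real code are simplified away, and the example data at the end is invented for illustration.

```python
from typing import Any, Dict, List, Optional, Union


def resolve_adapters(adapters_per_item: Optional[List[str]],
                     adapter_registry: Dict[str, Any],
                     input_len: int) -> Union[Any, List[Any], None]:
    """Approximation of the lookup logic in parse_adapters."""
    if not adapter_registry:
        # No registry configured: the request keeps adapters unset.
        return None

    if not adapters_per_item:
        # Inference with just the base model: one empty adapter per input.
        adapters_per_item = [""] * input_len

    if input_len != len(adapters_per_item):
        raise ValueError(
            "Number of adapters is not equal to the number of inputs")

    # Look up registered adapter details; unknown names resolve to None.
    adapters_data = [
        adapter_registry.get(adapter, None) for adapter in adapters_per_item
    ]
    # A single-input request carries a scalar instead of a one-element list.
    return adapters_data[0] if len(adapters_data) == 1 else adapters_data


# Example: one registered adapter applied to a single input.
print(resolve_adapters(["sql"], {"sql": {"name": "sql"}}, 1))
```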
18 changes: 16 additions & 2 deletions engines/python/setup/djl_python/request.py
@@ -34,17 +34,31 @@ def __init__(self, request_input: TextInput = None):
:param id: request id
"""

#TODO: Remove some of these redundant attributes and
# use request_input and request_output wherever necessary.
self.id = request_input.request_id
self.request_input = request_input
self.parameters = self.request_input.server_parameters
self.input_text = request_input.input_text
self.last_token = False
self.adapter = request_input.adapters

# server parameters may not be set, if custom input formatter is used.
if not self.request_input.server_parameters:
self.request_input.server_parameters = self.request_input.parameters.copy(
)
self.parameters = self.request_input.server_parameters

# output formatter
stream = self.request_input.parameters.get("stream", False)
request_input.output_formatter = self.parameters.pop(
"output_formatter", request_input.output_formatter)
# stream parameter is only used for determining the output.
stream = self.parameters.pop("stream", False)
# details is only used in output formatter for rolling batch
self.parameters.pop("details", False)
self.output_formatter, self.content_type = get_output_formatter(
request_input.output_formatter, stream, request_input.tgi_compat)
request_input.output_formatter = self.output_formatter
self.legacy_formatter = self._is_output_formatter_legacy()

self.request_output = TextGenerationOutput(request_id=self.id,
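A condensed sketch of the parameter handling that `Request.__init__` performs after this change. The free function and the return tuple are illustrative only; the fallback copy and the pops mirror the diff above.

```python
def prepare_request_parameters(request_input):
    # server_parameters may not be set when a custom input formatter is used;
    # fall back to a copy of the user-supplied parameters.
    if not request_input.server_parameters:
        request_input.server_parameters = request_input.parameters.copy()
    parameters = request_input.server_parameters

    # output_formatter and stream only steer output handling, so they are
    # popped before the parameters reach the backend engine.
    output_formatter = parameters.pop("output_formatter",
                                      request_input.output_formatter)
    stream = parameters.pop("stream", False)
    # "details" is only consumed by the rolling-batch output formatter.
    parameters.pop("details", False)
    return parameters, output_formatter, stream
```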
6 changes: 3 additions & 3 deletions engines/python/setup/djl_python/request_io.py
@@ -129,7 +129,8 @@ class RequestInput:
Attributes:
request_id: The request ID.
output_formatter: Output formatter of the request
parameters: parameters in the request payload, will be used in the output formatter
parameters: parameters in the request payload
server_parameters: parameters that are modified by the built-in handlers to support backend engines.
"""
request_id: int
output_formatter: Union[Callable, str] = None
@@ -147,11 +148,10 @@ class TextInput(RequestInput):
adapters: adapter used for the request.
tokenizer: tokenizer used for the request.
"""
input_text: str = None
input_text: Union[str, List[str]] = None
input_ids: List[int] = field(default_factory=lambda: [])
adapters: Optional[Any] = None
tokenizer: Optional[Any] = None
found_adapters: bool = False

def prompt_tokens_length(self) -> int:
return len(self.input_ids)
10 changes: 5 additions & 5 deletions engines/python/setup/djl_python/tensorrt_llm.py
@@ -29,6 +29,7 @@ class TRTLLMService(object):
"""

def __init__(self):
self.input_format_args = None
self.initialized = False
self.trt_configs = None
self.rolling_batch = None
@@ -40,6 +41,7 @@ def initialize(self, properties: dict):
self.rolling_batch = TRTLLMRollingBatch(
self.trt_configs.model_id_or_path, properties, self.trt_configs)
self.tokenizer = self.rolling_batch.get_tokenizer()
self.input_format_args = self.get_input_format_args()
self.initialized = True
return

@@ -54,16 +56,14 @@ def inference(self, inputs: Input) -> Output:
"""
Does preprocessing and sends new requests to the rolling batch script for inference
:param inputs (Input): a batch of inputs, each corresponding to a new request
:param inputs: (Input) a batch of inputs, each corresponding to a new request
:return outputs (Output): a batch of outputs that contain status code, output text, and other information
"""
outputs = Output()
kwargs = self.__dict__
kwargs[
"configs"] = self.trt_configs # TODO: Rename it to configs, so it would uniform in all handlers

parsed_input = parse_input_with_formatter(inputs, **kwargs)
parsed_input = parse_input_with_formatter(inputs,
**self.input_format_args)
if len(parsed_input.requests) == 0:
for i in range(len(parsed_input.batch)):
err = parsed_input.errors.get(i)
6 changes: 2 additions & 4 deletions engines/python/setup/djl_python/tensorrt_llm_python.py
@@ -121,10 +121,8 @@ def inference(self, inputs: Input) -> Output:
outputs.add(err, key="data", batch_index=i)
return outputs

input_data, input_size = get_input_details(parsed_input.requests,
parsed_input.errors,
parsed_input.batch)
params = parsed_input.requests[0].request_input.server_parameters
input_data, input_size, params, _ = get_input_details(
parsed_input.requests, parsed_input.errors, parsed_input.batch)

if "output_formatter" in params:
# output formatter is not supported for TensorRT-LLM python backend.
29 changes: 29 additions & 0 deletions engines/python/setup/djl_python/tests/test_input_parser.py
@@ -0,0 +1,29 @@
import unittest

from djl_python.input_parser import parse_input_with_formatter
from djl_python.test_model import create_concurrent_batch_request


class InputParserTest(unittest.TestCase):

def test_input_parameters(self):
inputs = [{
"inputs": "The winner of oscar this year is",
"parameters": {
"max_new_tokens": 50
},
"stream": False
}, {
"inputs": "A little redhood is",
"parameters": {
"min_new_tokens": 51,
"max_new_tokens": 256,
},
"stream": True
}]

serving_properties = {"rolling_batch": "disable"}

inputs = create_concurrent_batch_request(
inputs, serving_properties=serving_properties)
parsed_input = parse_input_with_formatter(inputs)
@@ -210,7 +210,7 @@ def test_sse_tgi_compat_fmt(self):
input_text="This is a wonderful day",
parameters={
"max_new_tokens": 256,
"stream": True
"stream": True,
},
tgi_compat=True))
req.set_next_token(Token(244, "He", -0.334532))
10 changes: 4 additions & 6 deletions engines/python/setup/djl_python/transformers_neuronx.py
@@ -219,7 +219,8 @@ def partition(self, properties: dict):
self.initialized = True

def inference(self, inputs: Input) -> Output:
parsed_input = parse_input_with_formatter(inputs, **self.__dict__)
parsed_input = parse_input_with_formatter(inputs,
**self.input_format_args)
errors = parsed_input.errors
requests = parsed_input.requests
outputs = Output()
@@ -229,11 +230,8 @@ def inference(self, inputs: Input) -> Output:
self.rolling_batch)

batch = parsed_input.batch
input_data, input_size = get_input_details(requests, errors, batch)
parameters = parsed_input.requests[0].request_input.server_parameters
# Remove rolling batch default parameters
parameters.pop("output_formatter", None)
parameters.pop("stream", None)
input_data, input_size, parameters, _ = get_input_details(
requests, errors, batch)
model_kwargs = {}

prompt_size = len(input_data)
6 changes: 3 additions & 3 deletions engines/python/setup/djl_python/utils.py
@@ -114,8 +114,8 @@ def get_input_details(requests, errors, batch):
input_size = []
adapters = []
idx = 0
request_input = requests[0].request_input
parameters = request_input.server_parameters
parameters = requests[0].request_input.server_parameters

for i in range(len(batch)):
if i in errors:
input_size.append(0)
@@ -134,4 +134,4 @@

idx += 1
adapters = adapters if adapters else None
return input_data, input_size, adapters
return input_data, input_size, parameters, adapters
