diff --git a/engines/python/setup/djl_python/tensorrt_llm.py b/engines/python/setup/djl_python/tensorrt_llm.py
index 20fc11112..f1c9ec01e 100644
--- a/engines/python/setup/djl_python/tensorrt_llm.py
+++ b/engines/python/setup/djl_python/tensorrt_llm.py
@@ -11,22 +11,13 @@
 # BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
 # the specific language governing permissions and limitations under the License.
-import os
-import logging
-import tensorrt_llm_toolkit
-from tensorrt_llm_toolkit.utils import utils as toolkit_utils
-
-from transformers import AutoConfig
-
-from djl_python.encode_decode import encode, decode
 from djl_python.inputs import Input
 from djl_python.outputs import Output
 from djl_python.rolling_batch.rolling_batch import get_content_type_from_output_formatter
 from djl_python.rolling_batch.trtllm_rolling_batch import TRTLLMRollingBatch
 from djl_python.properties_manager.trt_properties import TensorRtLlmProperties
-from djl_python.chat_completions.chat_utils import is_chat_completions_request, parse_chat_completions_request
-
-from djl_python.properties_manager.properties import is_rolling_batch_enabled
+from djl_python.tensorrt_llm_python import TRTLLMPythonService
+from djl_python.utils import parse_input


 class TRTLLMService(object):
@@ -36,109 +27,20 @@ class TRTLLMService(object):
     calls TensorRT-LLM in the back-end.
     """

-    PYTHON_BACKEND_SUPPORTED_MODELS = {'t5'}
-
     def __init__(self):
         self.initialized = False
         self.trt_configs = None
         self.rolling_batch = None
-        self.model = None
-        self.is_rolling_batch_enabled = True

     def initialize(self, properties: dict):
         self.trt_configs = TensorRtLlmProperties(**properties)
-        self.is_rolling_batch_enabled = is_rolling_batch_enabled(
-            self.trt_configs.rolling_batch)
-        self._load_model(properties)
+
+        self.rolling_batch = TRTLLMRollingBatch(
+            self.trt_configs.model_id_or_path, properties, **properties)
         self.initialized = True
         return

-    def parse_input(
-            self, inputs: Input
-    ) -> tuple[list[str], list[int], list[dict], dict, list]:
-        """
-        Preprocessing function that extracts information from Input objects.
-
-        :param inputs (Input): a batch of inputs, each corresponding to a new request
-
-        :return input_data (list[str]): a list of strings, each string being the prompt in a new request
-        :return input_size (list[int]): a list of ints being the size of each new request
-        :return parameters (list[dict]): parameters pertaining to each request
-        :return errors (dict): a dictionary mapping int indices to corresponding error strings if any
-        :return batch (list): a list of Input objects contained in inputs (each one corresponds to a request)
-        """
-        input_data = []
-        input_size = []
-        parameters = []
-        errors = {}
-        batch = inputs.get_batches()
-        for i, item in enumerate(batch):
-            try:
-                content_type = item.get_property("Content-Type")
-                input_map = decode(item, content_type)
-            except Exception as e:  # pylint: disable=broad-except
-                logging.warning(f"Parse input failed: {i}")
-                input_size.append(0)
-                errors[i] = str(e)
-                continue
-
-            if is_chat_completions_request(input_map):
-                _inputs, _param = parse_chat_completions_request(
-                    input_map, True, self.rolling_batch.get_tokenizer())
-            else:
-                _inputs = input_map.pop("inputs", input_map)
-                _param = input_map.pop("parameters", {})
-                _param["stream"] = input_map.pop("stream", False)
-            if not isinstance(_inputs, list):
-                _inputs = [_inputs]
-            input_data.extend(_inputs)
-            input_size.append(len(_inputs))
-
-            if "cached_prompt" in input_map:
-                _param["cached_prompt"] = input_map.pop("cached_prompt")
-            if "seed" not in _param:
-                # set server provided seed if seed is not part of request
-                if item.contains_key("seed"):
-                    _param["seed"] = item.get_as_string(key="seed")
-            if not "output_formatter" in _param:
-                _param["output_formatter"] = self.trt_configs.output_formatter
-
-            for _ in range(input_size[i]):
-                parameters.append(_param)
-
-        return input_data, input_size, parameters, errors, batch
-
-    def _get_config(self, properties):
-        model_path = self.trt_configs.model_id_or_path
-        if not os.path.isfile(os.path.join(model_path, 'config.json')):
-            model_path = toolkit_utils.get_python_backend_engine_path(
-                model_path, properties)
-            if not os.path.isfile(os.path.join(model_path, 'config.json')):
-                raise ValueError(
-                    f"Could not find config.json in {self.trt_configs.model_id_or_path} or"
-                    f"{model_path} for TensorRT python backend")
-
-        return AutoConfig.from_pretrained(
-            model_path, trust_remote_code=self.trt_configs.trust_remote_code)
-
-    def _load_model(self, properties):
-        if self.is_rolling_batch_enabled:
-            self.rolling_batch = TRTLLMRollingBatch(
-                self.trt_configs.model_id_or_path, properties, **properties)
-        else:
-            model_config = self._get_config(properties)
-            if model_config.model_type in self.PYTHON_BACKEND_SUPPORTED_MODELS:
-                self.model = tensorrt_llm_toolkit.init_inference(
-                    self.trt_configs.model_id_or_path,
-                    **properties,
-                    use_python_backend=True)
-            else:
-                raise ValueError(
-                    f"You cannot disable rolling batch if its not any of these models"
-                    f" {self.PYTHON_BACKEND_SUPPORTED_MODELS}. Please enable it with auto or trtllm "
-                    f"values to option.rolling_batch")
-
-    def rolling_batch_inference(self, inputs: Input) -> Output:
+    def inference(self, inputs: Input) -> Output:
         """
         Does preprocessing and sends new requests to the rolling batch script for inference
@@ -148,8 +50,9 @@ def rolling_batch_inference(self, inputs: Input) -> Output:
         """
         outputs = Output()

-        input_data, input_size, parameters, errors, batch = self.parse_input(
-            inputs)
+        input_data, input_size, parameters, errors, batch = parse_input(
+            inputs, self.rolling_batch.get_tokenizer(),
+            self.trt_configs.output_formatter)
         if len(input_data) == 0:
             for i in range(len(batch)):
                 err = errors.get(i)
@@ -184,46 +87,6 @@ def rolling_batch_inference(self, inputs: Input) -> Output:
         return outputs

-    # TODO TrtLLM python backend: Change it once TrtLLM supports T5 with inflight batching.
-    def inference(self, inputs: Input) -> Output:
-        """
-        Does preprocessing and sends new requests to the rolling batch script for inference
-
-        :param inputs (Input): a batch of inputs, each corresponding to a new request
-
-        :return outputs (Output): a batch of outputs that contain status code, output text, and other information
-        """
-        outputs = Output()
-
-        input_data, input_size, parameters, errors, batch = self.parse_input(
-            inputs)
-        if len(input_data) == 0:
-            for i in range(len(batch)):
-                err = errors.get(i)
-                outputs.add(err, key="data", batch_index=i)
-            return outputs
-
-        params = parameters[0]
-        result = self.model.generate(input_data, **params)
-        result = [{"generated_text": s} for s in result.batch_generation()]
-        idx = 0
-        for i, item in enumerate(batch):
-            content_type = item.get_property("Content-Type")
-            accept = item.get_property("Accept")
-            if not accept:
-                content_type = content_type if content_type else "application/json"
-                accept = content_type if content_type.startswith(
-                    "tensor/") else "application/json"
-            elif "*/*" in accept:
-                accept = "application/json"
-
-            encode(outputs,
-                   result[idx:idx + input_size[i]],
-                   accept,
-                   key=inputs.get_content().key_at(i))
-            idx += input_size[i]
-        return outputs
-

 _service = TRTLLMService()


@@ -232,11 +95,15 @@ def handle(inputs: Input) -> Output:
     """
     Handler function for the default TensorRT-LLM handler.

-    :param inputs (Input): a batch of inputs, each corresponding to a new request
+    :param inputs: (Input) a batch of inputs, each corresponding to a new request

     :return outputs (Output): a batch of outputs that contain status code, output text, and other information.
""" + global _service if not _service.initialized: + properties = inputs.get_properties() + if properties.get("rolling_batch", "disable") == "disable": + _service = TRTLLMPythonService() # stateful model _service.initialize(inputs.get_properties()) @@ -244,7 +111,4 @@ def handle(inputs: Input) -> Output: # initialization request return None - if _service.is_rolling_batch_enabled: - return _service.rolling_batch_inference(inputs) - else: - return _service.inference(inputs) + return _service.inference(inputs) diff --git a/engines/python/setup/djl_python/tensorrt_llm_python.py b/engines/python/setup/djl_python/tensorrt_llm_python.py new file mode 100644 index 000000000..1365cba1e --- /dev/null +++ b/engines/python/setup/djl_python/tensorrt_llm_python.py @@ -0,0 +1,180 @@ +import os +import torch + +import tensorrt_llm_toolkit +from tensorrt_llm_toolkit.utils import utils as toolkit_utils + +from transformers import AutoConfig + +from djl_python.properties_manager.trt_properties import TensorRtLlmProperties +from djl_python.encode_decode import encode +from djl_python.inputs import Input +from djl_python.outputs import Output +from djl_python.utils import parse_input + + +def _get_value_based_on_tensor(value, index=None): + if isinstance(value, torch.Tensor): + if index: + return value.cpu().numpy()[index] + else: + return value.cpu().item() + else: + return value + + +def _get_generation_result_from_python_backend(generations, inputs_size): + batch_size = sum(inputs_size) + tokens_results = [[] for _ in range(batch_size) + ] # list[list], [batch_size, generated_tokens_len] + prediction_results = [{} for _ in range(batch_size) + ] # list[dict], [batch_size] + cum_log_probs = [0.0 + for _ in range(batch_size)] # list[dict], [batch_size] + for generation in generations: # each token of whole batch + for i in range(len(generation)): # loop through each batch + # generation will be none, when it is already finished for that input + if not generation[i]: + continue + # generated_text will not be none, only during the last token. 
+            if generation[i].generated_text:
+                result = {
+                    "generated_text": generation[i].generated_text,
+                    'details': {
+                        # TODO: add finish reason
+                        "tokens": tokens_results[i]
+                    }
+                }
+                prediction_results[i] = result
+            else:
+                curr_cum_log_prob = _get_value_based_on_tensor(
+                    generation[i].cum_logprob)
+                log_prob = curr_cum_log_prob - cum_log_probs[i]
+                token_result = {
+                    'id':
+                    _get_value_based_on_tensor(generation[i].token_id,
+                                               index=0),
+                    'text':
+                    generation[i].token_text,
+                    'log_prob':
+                    log_prob if i < len(tokens_results) else curr_cum_log_prob,
+                }
+                cum_log_probs[i] = curr_cum_log_prob
+                tokens_results[i].append(token_result)
+    return prediction_results
+
+
+def _get_accept_and_content_type(batch_item) -> tuple[str, str]:
+    content_type = batch_item.get_property("Content-Type")
+    accept = batch_item.get_property("Accept")
+    if not accept:
+        content_type = content_type if content_type else "application/json"
+        accept = content_type if content_type.startswith(
+            "tensor/") else "application/json"
+    elif "*/*" in accept:
+        accept = "application/json"
+    return content_type, accept
+
+
+class TRTLLMPythonService:
+
+    PYTHON_BACKEND_SUPPORTED_MODELS = {'t5'}
+
+    def __init__(self):
+        self.model = None
+        self.trt_configs = None
+        self.initialized = False
+
+    def initialize(self, properties: dict):
+        self.trt_configs = TensorRtLlmProperties(**properties)
+        self._load_model(properties)
+        self.initialized = True
+        return
+
+    def inference(self, inputs: Input) -> Output:
+        """
+        Does preprocessing and runs inference through the TensorRT-LLM python backend
+
+        :param inputs: (Input) a batch of inputs, each corresponding to a new request
+
+        :return outputs (Output): a batch of outputs that contain status code, output text, and other information
+        """
+        outputs = Output()
+
+        input_data, input_size, parameters, errors, batch = parse_input(
+            inputs, None, self.trt_configs.output_formatter)
+        if len(input_data) == 0:
+            for i in range(len(batch)):
+                err = errors.get(i)
+                outputs.add(err, key="data", batch_index=i)
+            return outputs
+
+        params = parameters[0]
+        if params.get("details", False):
+            return self._stream_inference(inputs, input_data, input_size,
+                                          params, batch)
+
+        detokenized_python_response = self.model.generate(input_data, **params)
+        results = [{
+            "generated_text": s
+        } for s in detokenized_python_response.batch_generation()]
+        offset = 0
+        for i, item in enumerate(batch):
+            content_type, accept = _get_accept_and_content_type(item)
+            batch_item = results[offset] if input_size[i] == 1 else results[
+                offset:offset + input_size[i]]
+            encode(outputs,
+                   batch_item,
+                   accept,
+                   key=inputs.get_content().key_at(i))
+            offset += input_size[i]
+        return outputs
+
+    def _load_model(self, properties):
+        model_config = self._get_config(properties)
+        if model_config.model_type in self.PYTHON_BACKEND_SUPPORTED_MODELS:
+            self.model = tensorrt_llm_toolkit.init_inference(
+                self.trt_configs.model_id_or_path,
+                **properties,
+                use_python_backend=True)
+        else:
+            raise ValueError(
+                f"You cannot disable rolling batch if it's not one of these models:"
+                f" {self.PYTHON_BACKEND_SUPPORTED_MODELS}. Please enable it with auto or trtllm "
+                f"values to option.rolling_batch")
+
+    def _get_config(self, properties):
+        model_path = self.trt_configs.model_id_or_path
+        if not os.path.isfile(os.path.join(model_path, 'config.json')):
+            model_path = toolkit_utils.get_python_backend_engine_path(
+                model_path, properties)
+            if not os.path.isfile(os.path.join(model_path, 'config.json')):
+                raise ValueError(
+                    f"Could not find config.json in {self.trt_configs.model_id_or_path} or"
+                    f" {model_path} for TensorRT python backend")
+
+        return AutoConfig.from_pretrained(
+            model_path, trust_remote_code=self.trt_configs.trust_remote_code)
+
+    # TODO TrtLLM python backend: Change it once T5 bug is fixed.
+    def _stream_inference(self, inputs: Input, input_data: list[str],
+                          input_size: list[int], parameters: dict,
+                          batch: list) -> Output:
+        outputs = Output()
+        detokenized_python_response = self.model.generate(
+            input_data, **parameters)
+        generations = detokenized_python_response.stream_batch_generation()
+        results = _get_generation_result_from_python_backend(
+            generations, input_size)
+        offset = 0
+        for i, item in enumerate(batch):
+            item = batch[i]
+            content_type, accept = _get_accept_and_content_type(item)
+            batch_item = results[offset] if input_size[i] == 1 else results[
+                offset:offset + input_size[i]]
+            encode(outputs,
+                   batch_item,
+                   accept,
+                   key=inputs.get_content().key_at(i))
+            offset += input_size[i]
+        return outputs
diff --git a/engines/python/setup/djl_python/utils.py b/engines/python/setup/djl_python/utils.py
new file mode 100644
index 000000000..075915f9c
--- /dev/null
+++ b/engines/python/setup/djl_python/utils.py
@@ -0,0 +1,61 @@
+import logging
+from djl_python.inputs import Input
+from djl_python.encode_decode import encode, decode
+from djl_python.chat_completions.chat_utils import is_chat_completions_request, parse_chat_completions_request
+
+
+def parse_input(
+        inputs: Input, tokenizer, output_formatter
+) -> tuple[list[str], list[int], list[dict], dict, list]:
+    """
+    Preprocessing function that extracts information from Input objects.
+
+    :param inputs: (Input) a batch of inputs, each corresponding to a new request
+    :param tokenizer: the tokenizer used for inference
+
+    :return input_data (list[str]): a list of strings, each string being the prompt in a new request
+    :return input_size (list[int]): a list of ints being the size of each new request
+    :return parameters (list[dict]): parameters pertaining to each request
+    :return errors (dict): a dictionary mapping int indices to corresponding error strings if any
+    :return batch (list): a list of Input objects contained in inputs (each one corresponds to a request)
+    """
+    input_data = []
+    input_size = []
+    parameters = []
+    errors = {}
+    batch = inputs.get_batches()
+    for i, item in enumerate(batch):
+        try:
+            content_type = item.get_property("Content-Type")
+            input_map = decode(item, content_type)
+        except Exception as e:  # pylint: disable=broad-except
+            logging.warning(f"Parse input failed: {i}")
+            input_size.append(0)
+            errors[i] = str(e)
+            continue
+
+        if is_chat_completions_request(input_map):
+            _inputs, _param = parse_chat_completions_request(
+                input_map, True, tokenizer)
+        else:
+            _inputs = input_map.pop("inputs", input_map)
+            _param = input_map.pop("parameters", {})
+            _param["stream"] = input_map.pop("stream", False)
+        if not isinstance(_inputs, list):
+            _inputs = [_inputs]
+        input_data.extend(_inputs)
+        input_size.append(len(_inputs))
+
+        if "cached_prompt" in input_map:
+            _param["cached_prompt"] = input_map.pop("cached_prompt")
+        if "seed" not in _param:
+            # set server provided seed if seed is not part of request
+            if item.contains_key("seed"):
+                _param["seed"] = item.get_as_string(key="seed")
+        if "output_formatter" not in _param:
+            _param["output_formatter"] = output_formatter
+
+        for _ in range(input_size[i]):
+            parameters.append(_param)
+
+    return input_data, input_size, parameters, errors, batch
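Note (illustrative, not part of the patch): the sketch below mirrors, with plain dicts instead of the real djl_python Input/Output classes, the two behaviours this change wires together. handle() now swaps in TRTLLMPythonService when option.rolling_batch is disabled, and the shared parse_input() normalizes each request payload. The names pick_service and parse_payload are hypothetical stand-ins used only for illustration.

# Standalone sketch; uses only the Python standard library.


def pick_service(properties: dict) -> str:
    # Mirrors the dispatch in handle(): the python backend is only used
    # when rolling batch is disabled.
    if properties.get("rolling_batch", "disable") == "disable":
        return "TRTLLMPythonService"  # python backend (e.g. T5)
    return "TRTLLMService"  # in-flight (rolling) batching via TRTLLMRollingBatch


def parse_payload(input_map: dict, output_formatter=None) -> tuple[list, dict]:
    # Simplified non-chat branch of parse_input(): normalize prompts and
    # per-request parameters.
    prompts = input_map.pop("inputs", input_map)
    params = input_map.pop("parameters", {})
    params["stream"] = input_map.pop("stream", False)
    if not isinstance(prompts, list):
        prompts = [prompts]
    if "output_formatter" not in params:
        params["output_formatter"] = output_formatter
    return prompts, params


if __name__ == "__main__":
    print(pick_service({"rolling_batch": "trtllm"}))  # -> TRTLLMService
    print(pick_service({}))  # -> TRTLLMPythonService
    print(parse_payload({"inputs": "Hello", "parameters": {"max_new_tokens": 64}}))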