From 36d02123b383d26d6538e846e113e993f205e4bc Mon Sep 17 00:00:00 2001 From: Sindhu Somasundaram <56774226+sindhuvahinis@users.noreply.github.com> Date: Wed, 10 Jul 2024 09:24:15 -0700 Subject: [PATCH] [python] parse input only when new requests are received (#2155) --- .../python/setup/djl_python/huggingface.py | 2 +- .../python/setup/djl_python/input_parser.py | 21 ++++++++++++++++++- .../rolling_batch/lmi_dist_rolling_batch.py | 7 ++++--- .../rolling_batch/neuron_rolling_batch.py | 6 +++--- .../djl_python/rolling_batch/rolling_batch.py | 15 ++++--------- .../rolling_batch/scheduler_rolling_batch.py | 6 +++--- .../rolling_batch/trtllm_rolling_batch.py | 6 +++--- .../rolling_batch/vllm_rolling_batch.py | 6 +++--- .../python/setup/djl_python/tensorrt_llm.py | 3 ++- .../setup/djl_python/tensorrt_llm_python.py | 3 ++- .../tests/rolling_batch/fake_rolling_batch.py | 8 +++---- 11 files changed, 49 insertions(+), 34 deletions(-) diff --git a/engines/python/setup/djl_python/huggingface.py b/engines/python/setup/djl_python/huggingface.py index dc3cf6413..e6ed93359 100644 --- a/engines/python/setup/djl_python/huggingface.py +++ b/engines/python/setup/djl_python/huggingface.py @@ -226,7 +226,7 @@ def inference(self, inputs: Input) -> Output: **self.input_format_args) requests = parsed_input.requests errors = parsed_input.errors - if len(requests) == 0: + if errors and len(parsed_input.batch) == len(errors): for i in range(len(parsed_input.batch)): err = errors.get(i) if is_rolling_batch_enabled(self.hf_configs.rolling_batch): diff --git a/engines/python/setup/djl_python/input_parser.py b/engines/python/setup/djl_python/input_parser.py index 230af5f72..b5b73eb79 100644 --- a/engines/python/setup/djl_python/input_parser.py +++ b/engines/python/setup/djl_python/input_parser.py @@ -30,6 +30,23 @@ class ParsedInput: batch: List = field(default_factory=lambda: []) +def get_batch_start_id(batch, **kwargs): + if kwargs.get("is_rolling_batch"): + # for rolling batch, we only need to parse the new requests, as the active requests kept in cache. + rolling_batch = kwargs.get("rolling_batch") + active_requests_len = len(rolling_batch.active_requests) + batch_size = len(batch) + if batch_size > active_requests_len: + # if batch_size > active_requests_len, then new requests are received + return active_requests_len + else: + # no new requests are received, so sending batch_size, nothing will be parsed. + return batch_size + else: + # for non-rolling batch, python process only receives new requests. + return 0 + + def parse_input_with_formatter(inputs: Input, **kwargs) -> ParsedInput: """ Preprocessing function that extracts information from Input objects. 
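The get_batch_start_id helper added in the hunk above decides where parsing should start in the incoming batch. A minimal sketch of its behaviour (illustration only, not part of the patch; _FakeRollingBatch is a hypothetical stand-in that exposes just the one attribute the helper reads):

    # Sketch: how get_batch_start_id (defined in the hunk above) is expected to behave.
    from djl_python.input_parser import get_batch_start_id

    class _FakeRollingBatch:
        """Hypothetical stand-in exposing only active_requests."""
        def __init__(self, active):
            self.active_requests = active

    batch = ["req-0", "req-1", "req-2", "req-3"]

    # Rolling batch with two requests already cached: parsing starts at index 2,
    # so only the newly received tail of the batch is parsed.
    start = get_batch_start_id(batch,
                               is_rolling_batch=True,
                               rolling_batch=_FakeRollingBatch(["req-0", "req-1"]))
    assert start == 2

    # Non-rolling batch: the Python process only ever receives new requests,
    # so parsing always starts at index 0.
    assert get_batch_start_id(batch, is_rolling_batch=False) == 0
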
@@ -44,7 +61,9 @@ def parse_input_with_formatter(inputs: Input, **kwargs) -> ParsedInput: kwargs["is_rolling_batch"] = is_rolling_batch_enabled( kwargs.get("configs").rolling_batch) request_id_counter = get_req_id_counter(kwargs) - for i, input_item in enumerate(batch): + start_batch_id = get_batch_start_id(batch, **kwargs) + for i in range(start_batch_id, len(batch)): + input_item = batch[i] try: request_id = request_id_counter.next_id( ) if request_id_counter else i diff --git a/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py b/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py index 47ea09e6e..644c550ff 100644 --- a/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py +++ b/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py @@ -143,15 +143,15 @@ def translate_lmi_dist_params(self, parameters: dict): return parameters @stop_on_any_exception - def inference(self, requests: List[Request]) -> List: + def inference(self, new_requests: List[Request]) -> List: """ Adds new requests and gets output tokens from the backend. - :param requests: List of requests + :param new_requests: List of requests :return results: List of dictionaries, one for each request, that contain output tokens and other data. """ - new_requests = self.get_new_requests(requests) + self.add_new_requests(new_requests) # step 0: register new requests to engine for request in new_requests: request_id = str(request.id) @@ -159,6 +159,7 @@ def inference(self, requests: List[Request]) -> List: request_params = RequestParams(**params) lora_request_params = get_lora_request_params( request, self.lora_ids) + # Constructing Request in lmi-dist library lmi_dist_request = Request( id=request_id, prompt=request.input_text, diff --git a/engines/python/setup/djl_python/rolling_batch/neuron_rolling_batch.py b/engines/python/setup/djl_python/rolling_batch/neuron_rolling_batch.py index deb4ef00d..96d12798f 100644 --- a/engines/python/setup/djl_python/rolling_batch/neuron_rolling_batch.py +++ b/engines/python/setup/djl_python/rolling_batch/neuron_rolling_batch.py @@ -98,17 +98,17 @@ def append_speculated_generations(self, generation, request, req_ids): speculated_generation = generation.speculated_generations.dequeue() @stop_on_any_exception - def inference(self, requests: List[Request]) -> list: + def inference(self, new_requests: List[Request]) -> list: """ Loads new requests and gets output tokens from all currently active requests from the Neuron backend. - :param requests: List[Request] List of requests + :param new_requests: List[Request] List of requests :return: generated batch decoded tokens - list of dictionaries, one for each request, that contain output tokens and other data. """ - new_requests = self.get_new_requests(requests) + self.add_new_requests(new_requests) if len(new_requests) > 0: generations = self.scheduler.prefill(new_requests) else: diff --git a/engines/python/setup/djl_python/rolling_batch/rolling_batch.py b/engines/python/setup/djl_python/rolling_batch/rolling_batch.py index 4146516e5..4dfb93712 100644 --- a/engines/python/setup/djl_python/rolling_batch/rolling_batch.py +++ b/engines/python/setup/djl_python/rolling_batch/rolling_batch.py @@ -93,30 +93,23 @@ def get_tokenizer(self): raise RuntimeError("get_tokenizer function not supported") @abstractmethod - def inference(self, requests: List[Request]) -> List: + def inference(self, new_requests: List[Request]) -> List: """ Performs prefill and decode operations for the batch. 
- :param requests: List[Request] List of requests + :param new_requests: List[Request] List of requests :return: generated batch decoded tokens """ pass - def get_new_requests(self, requests: List[Request]) -> List[Request]: + def add_new_requests(self, requests: List[Request]): """ Adds requests to the batch when there is availability :param requests: List[Request] List of requests - - :return: list of current active requests (including those that have just been added) """ - total_req_len = len(self.active_requests) - batch_size = len(requests) - if batch_size > total_req_len: - for i in range(total_req_len, batch_size): - self.active_requests.append(requests[i]) - return self.active_requests[total_req_len:] + self.active_requests.extend(requests) @abstractmethod def preprocess_requests(self, requests: List[Request]): diff --git a/engines/python/setup/djl_python/rolling_batch/scheduler_rolling_batch.py b/engines/python/setup/djl_python/rolling_batch/scheduler_rolling_batch.py index 543bd1b33..189cf5635 100644 --- a/engines/python/setup/djl_python/rolling_batch/scheduler_rolling_batch.py +++ b/engines/python/setup/djl_python/rolling_batch/scheduler_rolling_batch.py @@ -69,14 +69,14 @@ def __init__(self, model_id_or_path: str, properties: dict, self._init_scheduler() @stop_on_any_exception - def inference(self, requests: List) -> List: + def inference(self, new_requests: List) -> List: """ Performs prefill and decode operations for the batch. - :param requests: List[Request] List of requests + :param new_requests: List[Request] List of requests :return: generated batch decoded tokens """ - new_requests = self.get_new_requests(requests) + self.add_new_requests(new_requests) preprocessed_new_requests = self.preprocess_requests(new_requests) self._prefill_and_decode(preprocessed_new_requests) diff --git a/engines/python/setup/djl_python/rolling_batch/trtllm_rolling_batch.py b/engines/python/setup/djl_python/rolling_batch/trtllm_rolling_batch.py index 4af3168d1..02e0f185b 100644 --- a/engines/python/setup/djl_python/rolling_batch/trtllm_rolling_batch.py +++ b/engines/python/setup/djl_python/rolling_batch/trtllm_rolling_batch.py @@ -87,12 +87,12 @@ def translate_triton_params(self, parameters: dict) -> dict: return parameters @stop_on_any_exception - def inference(self, requests: List[Request]) -> List: + def inference(self, new_requests: List[Request]) -> List: """ Loads new requests into the batch when there is availability, and gets output tokens from the backend asynchronously. - :param requests: List[Request] List of requests + :param new_requests: List[Request] List of requests :param input_data: List of input prompts. :param parameters: List of settings pertaining to each request. :param adapters: List of adapters inputs for each request in a batch @@ -100,7 +100,7 @@ def inference(self, requests: List[Request]) -> List: :return results: List of dictionaries, one for each request, that contain output tokens and other data. 
""" # add pending requests to active requests list - new_requests = self.get_new_requests(requests) + self.add_new_requests(new_requests) # step 0: register new active requests for request in new_requests: param = self.translate_triton_params(request.parameters) diff --git a/engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py b/engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py index 691367d92..4c7976991 100644 --- a/engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py +++ b/engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py @@ -107,15 +107,15 @@ def translate_vllm_params(self, parameters: dict) -> dict: return parameters @stop_on_any_exception - def inference(self, requests: List[Request]) -> List: + def inference(self, new_requests: List[Request]) -> List: """ Adds new requests and gets output tokens from the backend. - :param requests: List[Request] List of requests + :param new_requests: List[Request] List of requests :return results: List of dictionaries, one for each request, that contain output tokens and other data. """ - new_requests = self.get_new_requests(requests) + self.add_new_requests(new_requests) # step 0: register new requests to engine for request in new_requests: request_id = random_uuid() diff --git a/engines/python/setup/djl_python/tensorrt_llm.py b/engines/python/setup/djl_python/tensorrt_llm.py index bafbca296..f5e7d614d 100644 --- a/engines/python/setup/djl_python/tensorrt_llm.py +++ b/engines/python/setup/djl_python/tensorrt_llm.py @@ -64,7 +64,8 @@ def inference(self, inputs: Input) -> Output: parsed_input = parse_input_with_formatter(inputs, **self.input_format_args) - if len(parsed_input.requests) == 0: + if parsed_input.errors and len(parsed_input.requests) == len( + parsed_input.errors): for i in range(len(parsed_input.batch)): err = parsed_input.errors.get(i) err = {"data": "", "last": True, "code": 424, "error": err} diff --git a/engines/python/setup/djl_python/tensorrt_llm_python.py b/engines/python/setup/djl_python/tensorrt_llm_python.py index e269f8b6d..2ef11219f 100644 --- a/engines/python/setup/djl_python/tensorrt_llm_python.py +++ b/engines/python/setup/djl_python/tensorrt_llm_python.py @@ -115,7 +115,8 @@ def inference(self, inputs: Input) -> Output: parsed_input = parse_input_with_formatter(inputs, **self.input_format_args) - if len(parsed_input.requests) == 0: + if parsed_input.errors and len(parsed_input.requests) == len( + parsed_input.errors): for i in range(len(parsed_input.batch)): err = parsed_input.errors.get(i) outputs.add(err, key="data", batch_index=i) diff --git a/engines/python/setup/djl_python/tests/rolling_batch/fake_rolling_batch.py b/engines/python/setup/djl_python/tests/rolling_batch/fake_rolling_batch.py index b43edc4bb..235ebd70a 100644 --- a/engines/python/setup/djl_python/tests/rolling_batch/fake_rolling_batch.py +++ b/engines/python/setup/djl_python/tests/rolling_batch/fake_rolling_batch.py @@ -65,8 +65,8 @@ def reset(self): @profile_objects @stop_on_any_exception - def inference(self, requests: List[Request]) -> List: - new_requests = self.get_new_requests(requests) + def inference(self, new_requests: List[Request]) -> List: + self.add_new_requests(new_requests) for new_request in new_requests: max_len = new_request.parameters[ @@ -118,10 +118,10 @@ def reset(self): @profile_objects @stop_on_any_exception - def inference(self, requests: List[Request]): + def inference(self, new_requests: List[Request]): if self.dead_counter.get_id() < self.dead_trigger: 
self.dead_counter.next_id() - return super().inference(requests) + return super().inference(new_requests) else: raise RuntimeError("Death trigger triggered...")
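
Taken together, the patch moves the "which requests are new?" decision out of get_new_requests() and into the input parser: parse_input_with_formatter now skips the first len(active_requests) items of a rolling batch, and each engine-specific rolling batch simply appends the already-filtered requests via add_new_requests(). A rough end-to-end sketch of the resulting flow, using simplified stand-ins rather than the real DJL Serving classes:

    # Rough sketch of the post-patch request flow (illustration only; the class and
    # helper names below are simplified stand-ins, not the real DJL Serving classes).

    class SketchRollingBatch:
        def __init__(self):
            self.active_requests = []

        def add_new_requests(self, requests):
            # Callers now pass only the newly parsed requests, so the batch just
            # appends them; no diffing against active_requests is needed anymore.
            self.active_requests.extend(requests)

        def inference(self, new_requests):
            self.add_new_requests(new_requests)
            # ... a real backend would register new_requests here, then decode one
            # step for every request in self.active_requests ...
            return [f"token-for-{req}" for req in self.active_requests]


    def parse(item):
        # Stand-in for the real per-request parsing (formatter, request id, etc.).
        return item


    def handle_batch(raw_batch, rolling_batch):
        # Mirrors parse_input_with_formatter: only the tail of the batch that is
        # not already tracked in active_requests gets parsed on this invocation.
        start = len(rolling_batch.active_requests)
        new_requests = [parse(item) for item in raw_batch[start:]]
        return rolling_batch.inference(new_requests)


    rb = SketchRollingBatch()
    handle_batch(["r0", "r1"], rb)        # parses and registers r0, r1
    handle_batch(["r0", "r1", "r2"], rb)  # only r2 is parsed; r0, r1 stay cached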