diff --git a/engines/python/setup/djl_python/input_parser.py b/engines/python/setup/djl_python/input_parser.py index 5a4849637..4ceffab1e 100644 --- a/engines/python/setup/djl_python/input_parser.py +++ b/engines/python/setup/djl_python/input_parser.py @@ -202,6 +202,9 @@ def add_server_maintained_params(request_input: RequestInput, request_input.server_parameters["output_formatter"] = kwargs.get( "configs").output_formatter + if input_item.get_property("cancelled"): + request_input.is_cancelled = True + output_formatter = request_input.server_parameters["output_formatter"] if output_formatter == "json" or output_formatter == "sse": request_input.tgi_compat = kwargs.get("configs").tgi_compat diff --git a/engines/python/setup/djl_python/request.py b/engines/python/setup/djl_python/request.py index 30396d24c..97fa470ce 100644 --- a/engines/python/setup/djl_python/request.py +++ b/engines/python/setup/djl_python/request.py @@ -11,6 +11,7 @@ # BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for # the specific language governing permissions and limitations under the License. import inspect +import json from typing import Union, Callable, Any, List, Dict, Optional from djl_python.output_formatter import get_output_formatter, adapt_legacy_output_formatter @@ -108,6 +109,8 @@ def get_next_token(self) -> str: :return: next_token """ + if self.is_cancelled(): + return "" if self.next_token_str: return self.next_token_str if self.legacy_formatter: @@ -181,3 +184,11 @@ def get_client_request_id(self) -> str: :return: the requestId specified in the HTTP request """ return self.request_input.client_request_id + + def is_cancelled(self) -> bool: + """ + Returns whether the request has been cancelled by the client + + :return: true if the request is cancelled + """ + return self.request_input.is_cancelled diff --git a/engines/python/setup/djl_python/request_io.py b/engines/python/setup/djl_python/request_io.py index 7f66fa9a2..1e17ba1c8 100644 --- a/engines/python/setup/djl_python/request_io.py +++ b/engines/python/setup/djl_python/request_io.py @@ -149,6 +149,7 @@ class RequestInput: parameters: Dict = field(default_factory=lambda: {}) server_parameters: Dict = field(default_factory=lambda: {}) tgi_compat: bool = False + is_cancelled: bool = False @dataclass diff --git a/engines/python/setup/djl_python/rolling_batch/rolling_batch.py b/engines/python/setup/djl_python/rolling_batch/rolling_batch.py index ee90ad15d..fb904a600 100644 --- a/engines/python/setup/djl_python/rolling_batch/rolling_batch.py +++ b/engines/python/setup/djl_python/rolling_batch/rolling_batch.py @@ -171,7 +171,7 @@ def postprocess_results(self) -> List[dict]: req = self.active_requests[i] res = { "data": req.get_next_token(), - "last": req.is_last_token(), + "last": req.is_last_token() or req.is_cancelled(), "content_type": req.get_content_type(), "request_id": req.get_client_request_id(), } @@ -179,6 +179,9 @@ def postprocess_results(self) -> List[dict]: res["error"] = req.get_error_message() if req.get_error_code(): res["code"] = req.get_error_code() + if req.is_cancelled(): + res["error"] = res.get("error", "request has been cancelled") + res["code"] = res.get("code", 499) req.reset_next_token() results.append(res) diff --git a/engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py b/engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py index 76e5c8e0c..ab1722ceb 100644 --- a/engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py +++ b/engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py @@ -10,11 +10,12 @@ # or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS" # BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for # the specific language governing permissions and limitations under the License. +import logging from collections import OrderedDict from vllm import LLMEngine, SamplingParams from vllm.sampling_params import RequestOutputKind -from vllm.utils import random_uuid, AtomicCounter +from vllm.utils import AtomicCounter from djl_python.request import Request from djl_python.rolling_batch.rolling_batch import RollingBatch, stop_on_any_exception, filter_unused_generation_params @@ -154,6 +155,15 @@ def translate_vllm_params(self, parameters: dict) -> dict: remove_unused_params=True) return parameters + def cancel_requests(self): + for req in self.active_requests: + if req.is_cancelled(): + self.engine.abort_request(req.get_client_request_id()) + self.request_cache.pop(req.get_client_request_id(), None) + logging.info( + f"RequestId[{req.get_client_request_id()}] has been cancelled" + ) + @stop_on_any_exception def inference(self, new_requests: List[Request]) -> List: """ @@ -164,9 +174,10 @@ def inference(self, new_requests: List[Request]) -> List: :return results: List of dictionaries, one for each request, that contain output tokens and other data. """ self.add_new_requests(new_requests) + self.cancel_requests() # step 0: register new requests to engine for request in new_requests: - request_id = random_uuid() + request_id = request.get_client_request_id() # Chat completions request route if request.parameters.get("sampling_params") is not None: prompt_inputs = request.parameters.get("engine_prompt") diff --git a/engines/python/src/main/java/ai/djl/python/engine/RollingBatch.java b/engines/python/src/main/java/ai/djl/python/engine/RollingBatch.java index a7330813a..44bb09016 100644 --- a/engines/python/src/main/java/ai/djl/python/engine/RollingBatch.java +++ b/engines/python/src/main/java/ai/djl/python/engine/RollingBatch.java @@ -130,6 +130,10 @@ public void run() { String key = prefix + entry.getKey(); batch.addProperty(key, entry.getValue()); } + if (req.isCancelled()) { + String key = prefix + "cancelled"; + batch.addProperty(key, "true"); + } batch.add(prefix + "data", req.getRequest()); String seed = req.getSeed(); @@ -223,12 +227,16 @@ public void run() { } public Output addInput(Input input, int timeout) throws TranslateException { + String requestId = input.getProperty("requestId", ""); + String requestIdLogPrefix = "RequestId=[" + requestId + "]"; + if (input.isCancelled()) { + logger.warn("{} has been cancelled, not processing request", requestIdLogPrefix); + return new Output(499, "request has been cancelled due to client disconnect"); + } try { lock.lock(); if (list.size() >= maxRollingBatchSize) { // Input always reflects a single request here - String requestId = input.getProperty("requestId", ""); - String requestIdLogPrefix = "RequestId=[" + requestId + "]"; logger.debug( "{} exceed max_rolling_batch_size: {}", requestIdLogPrefix, @@ -370,5 +378,9 @@ void addResponse(byte[] json, Map properties) { data.appendContent(nextToken.getBytes(StandardCharsets.UTF_8), last); } } + + boolean isCancelled() { + return input.isCancelled(); + } } }