Skip to content

Commit

Permalink
[rolling batch] allow client timeouts to cancel engine requests
Browse files Browse the repository at this point in the history
  • Loading branch information
siddvenk committed Feb 8, 2025
1 parent b71f0f2 commit bd0d7b3
Show file tree
Hide file tree
Showing 6 changed files with 46 additions and 5 deletions.
3 changes: 3 additions & 0 deletions engines/python/setup/djl_python/input_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,9 @@ def add_server_maintained_params(request_input: RequestInput,
request_input.server_parameters["output_formatter"] = kwargs.get(
"configs").output_formatter

if input_item.get_property("cancelled"):
request_input.is_cancelled = True

output_formatter = request_input.server_parameters["output_formatter"]
if output_formatter == "json" or output_formatter == "sse":
request_input.tgi_compat = kwargs.get("configs").tgi_compat
Expand Down
11 changes: 11 additions & 0 deletions engines/python/setup/djl_python/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.
import inspect
import json
from typing import Union, Callable, Any, List, Dict, Optional

from djl_python.output_formatter import get_output_formatter, adapt_legacy_output_formatter
Expand Down Expand Up @@ -108,6 +109,8 @@ def get_next_token(self) -> str:
:return: next_token
"""
if self.is_cancelled():
return ""
if self.next_token_str:
return self.next_token_str
if self.legacy_formatter:
Expand Down Expand Up @@ -181,3 +184,11 @@ def get_client_request_id(self) -> str:
:return: the requestId specified in the HTTP request
"""
return self.request_input.client_request_id

def is_cancelled(self) -> bool:
    """
    Reports the cancellation flag carried by this request's input.

    :return: the current value of request_input.is_cancelled
    """
    cancelled = self.request_input.is_cancelled
    return cancelled
1 change: 1 addition & 0 deletions engines/python/setup/djl_python/request_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ class RequestInput:
parameters: Dict = field(default_factory=lambda: {})
server_parameters: Dict = field(default_factory=lambda: {})
tgi_compat: bool = False
is_cancelled: bool = False


@dataclass
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -171,14 +171,17 @@ def postprocess_results(self) -> List[dict]:
req = self.active_requests[i]
res = {
"data": req.get_next_token(),
"last": req.is_last_token(),
"last": req.is_last_token() or req.is_cancelled(),
"content_type": req.get_content_type(),
"request_id": req.get_client_request_id(),
}
if req.get_error_message():
res["error"] = req.get_error_message()
if req.get_error_code():
res["code"] = req.get_error_code()
if req.is_cancelled():
res["error"] = res.get("error", "request has been cancelled")
res["code"] = res.get("code", 499)
req.reset_next_token()
results.append(res)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.
import logging
from collections import OrderedDict

from vllm import LLMEngine, SamplingParams
from vllm.sampling_params import RequestOutputKind
from vllm.utils import random_uuid, AtomicCounter
from vllm.utils import AtomicCounter

from djl_python.request import Request
from djl_python.rolling_batch.rolling_batch import RollingBatch, stop_on_any_exception, filter_unused_generation_params
Expand Down Expand Up @@ -154,6 +155,15 @@ def translate_vllm_params(self, parameters: dict) -> dict:
remove_unused_params=True)
return parameters

def cancel_requests(self):
    """
    Aborts engine work for every active request that is flagged as cancelled.

    For each cancelled request in self.active_requests this aborts the
    request in the engine, drops its entry from self.request_cache (a
    missing entry is ignored), and logs the cancellation.

    :return: None
    """
    for req in self.active_requests:
        if not req.is_cancelled():
            continue
        # Resolve the id once; it keys the engine request, the cache entry
        # and the log line.
        request_id = req.get_client_request_id()
        self.engine.abort_request(request_id)
        self.request_cache.pop(request_id, None)
        # Lazy %-style args so the message is only formatted when the
        # record is actually emitted.
        logging.info("RequestId[%s] has been cancelled", request_id)

@stop_on_any_exception
def inference(self, new_requests: List[Request]) -> List:
"""
Expand All @@ -164,9 +174,10 @@ def inference(self, new_requests: List[Request]) -> List:
:return results: List of dictionaries, one for each request, that contain output tokens and other data.
"""
self.add_new_requests(new_requests)
self.cancel_requests()
# step 0: register new requests to engine
for request in new_requests:
request_id = random_uuid()
request_id = request.get_client_request_id()
# Chat completions request route
if request.parameters.get("sampling_params") is not None:
prompt_inputs = request.parameters.get("engine_prompt")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,10 @@ public void run() {
String key = prefix + entry.getKey();
batch.addProperty(key, entry.getValue());
}
if (req.isCancelled()) {
String key = prefix + "cancelled";
batch.addProperty(key, "true");
}

batch.add(prefix + "data", req.getRequest());
String seed = req.getSeed();
Expand Down Expand Up @@ -223,12 +227,16 @@ public void run() {
}

public Output addInput(Input input, int timeout) throws TranslateException {
String requestId = input.getProperty("requestId", "");
String requestIdLogPrefix = "RequestId=[" + requestId + "]";
if (input.isCancelled()) {
logger.warn("{} has been cancelled, not processing request", requestIdLogPrefix);
return new Output(499, "request has been cancelled due to client disconnect");
}
try {
lock.lock();
if (list.size() >= maxRollingBatchSize) {
// Input always reflects a single request here
String requestId = input.getProperty("requestId", "");
String requestIdLogPrefix = "RequestId=[" + requestId + "]";
logger.debug(
"{} exceed max_rolling_batch_size: {}",
requestIdLogPrefix,
Expand Down Expand Up @@ -370,5 +378,9 @@ void addResponse(byte[] json, Map<String, String> properties) {
data.appendContent(nextToken.getBytes(StandardCharsets.UTF_8), last);
}
}

/**
 * Returns whether the input backing this request reports itself as cancelled.
 *
 * @return the result of {@code input.isCancelled()}
 */
boolean isCancelled() {
    final boolean cancelled = input.isCancelled();
    return cancelled;
}
}
}

0 comments on commit bd0d7b3

Please sign in to comment.