Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[python] Fix backport of rolling batch non-streaming non-200 error code support #2478

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 83 additions & 1 deletion engines/python/setup/djl_python/output_formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@
import json
import logging
import time
from typing import Union, Callable
from typing import Dict, Union, Callable

from typing_extensions import deprecated

from djl_python.request_io import Token, TextGenerationOutput, RequestOutput
from djl_python.utils import serving_backport_for_non_streaming_http_error_codes_enabled


def _json_output_formatter(request_output: RequestOutput):
Expand All @@ -26,6 +27,10 @@ def _json_output_formatter(request_output: RequestOutput):

:return: formatted output
"""
if serving_backport_for_non_streaming_http_error_codes_enabled():
return _json_output_formatter_backport_for_non_streaming_http_error_codes(
request_output)

best_sequence = request_output.sequences[
request_output.best_sequence_index]

Expand Down Expand Up @@ -65,6 +70,83 @@ def _json_output_formatter(request_output: RequestOutput):
return json_encoded_str


def _json_output_formatter_backport_for_non_streaming_http_error_codes(
        request_output: TextGenerationOutput):
    """
    json output formatter that allows non-streaming requests to return non-200 error codes on error.

    Backported from djl-serving v0.29.0.

    On error the payload carries ``"generated_text": None`` plus ``error`` /
    ``code`` keys so the serving layer can surface a non-200 HTTP status
    instead of a partial generation.

    :return: formatted output
    """

    def _get_last_token(seq):
        # Token recorded at the sequence's last-token index, or None when no
        # last token has been recorded yet (falsy index means "not set").
        if seq._last_token_index:
            return seq.tokens[seq._last_token_index]
        return None

    def _get_generated_text(sequence, request_output):
        # Concatenate every generated token's text; prepend the original
        # prompt when the request asked for "return_full_text".
        parameters = request_output.input.parameters
        generated_text = request_output.input.input_text if parameters.get(
            "return_full_text") else ""
        for token in sequence.tokens:
            generated_text += token.text
        return generated_text

    def _get_details_dict(request_output: TextGenerationOutput,
                          include_tokens: bool = True) -> Dict:
        # Build the optional "details" section. Full details are emitted only
        # when the request asked for them (or tgi_compat defaults them on);
        # otherwise an error finish still reports its finish_reason so the
        # caller below can detect the failure, and the happy path gets {}.
        parameters = request_output.input.parameters
        best_sequence = request_output.sequences[
            request_output.best_sequence_index]
        if parameters.get("details", request_output.input.tgi_compat):
            final_dict = {
                "finish_reason": best_sequence.finish_reason,
                "generated_tokens": len(best_sequence.tokens),
                "inputs": request_output.input.input_text,
            }

            if include_tokens:
                final_dict["tokens"] = request_output.get_tokens_as_dict()

            if parameters.get("decoder_input_details"):
                final_dict[
                    "prefill"] = request_output.get_prompt_tokes_as_dict()
            return final_dict
        elif best_sequence.finish_reason == "error":
            return {"finish_reason": best_sequence.finish_reason}
        else:
            return {}

    best_sequence = request_output.sequences[
        request_output.best_sequence_index]
    # TODO: Fix this so it is not required. Right now, this call is needed to
    # advance the token iterator, which is needed for rolling batch to work properly
    next_token, _, _ = best_sequence.get_next_token()
    if not request_output.finished:
        # Mid-generation: emit nothing until the request has finished.
        return ""
    details = _get_details_dict(request_output, include_tokens=True)
    if details.get("finish_reason") == "error":
        final_token = _get_last_token(best_sequence)
        # In non-streaming, request either succeeds or fails so do not provide the
        # partial generation response that may exist
        result = {
            "generated_text": None,
            "error": getattr(final_token, "error_msg", "error"),
            "code": 400,
            "details": details,
        }
        return json.dumps(result, ensure_ascii=False)
    generated_text = _get_generated_text(best_sequence, request_output)
    result = {
        "generated_text": generated_text,
    }
    if details:
        result["details"] = details
    if request_output.input.tgi_compat:
        # TGI compatibility mode wraps the payload in a single-element list.
        result = [result]
    return json.dumps(result, ensure_ascii=False)


def _jsonlines_output_formatter(request_output: RequestOutput):
"""
jsonlines output formatter
Expand Down
4 changes: 4 additions & 0 deletions engines/python/setup/djl_python/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from djl_python.output_formatter import get_output_formatter, _json_output_formatter, sse_response_formatter, \
adapt_legacy_output_formatter
from djl_python.request_io import Token, TextGenerationOutput, TextInput, RequestOutput
from djl_python.utils import serving_backport_for_non_streaming_http_error_codes_enabled


class Request(object):
Expand Down Expand Up @@ -114,6 +115,9 @@ def set_next_token(self,
self.request_output.set_finish_reason(finish_reason)
self.request_output.prompt_tokens_details = prompt_tokens_details
self.last_token = last_token
if (last_token and
serving_backport_for_non_streaming_http_error_codes_enabled()):
self.request_output.finished = True

def get_next_token(self) -> str:
"""
Expand Down
27 changes: 26 additions & 1 deletion engines/python/setup/djl_python/rolling_batch/rolling_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from djl_python.properties_manager.properties import Properties
from djl_python.request import Request
from djl_python.request_io import Token
from djl_python.utils import serving_backport_for_non_streaming_http_error_codes_enabled

FINISH_REASON_MAPPER = ["length", "eos_token", "stop_sequence"]

Expand Down Expand Up @@ -44,14 +45,31 @@ def stop_on_any_exception(func):
def try_catch_handling(self, *args, **kwargs):
    """Wrapper installed by ``stop_on_any_exception``.

    Runs the wrapped inference step; on any exception it marks every
    active request as finished with an error token, flushes the final
    responses via ``postprocess_results()``, resets the batch, and returns
    that response instead of propagating the exception.
    """
    try:
        return func(self, *args, **kwargs)
    except Exception as e:
        logging.exception("Rolling batch inference error")
        # Terminate every in-flight request with a sentinel error token so
        # clients receive a final response rather than hanging.
        for request in self.active_requests:
            token = Token(-1, "", -1, True)
            request.set_next_token(token,
                                   last_token=True,
                                   finish_reason="error")
            if serving_backport_for_non_streaming_http_error_codes_enabled(
            ):
                # Backport: attach the message and a 424 code so the
                # non-streaming path can return a non-200 HTTP status.
                request.error_message = str(e)
                request.error_code = 424
        response = self.postprocess_results()
        if (serving_backport_for_non_streaming_http_error_codes_enabled()
                and isinstance(response, list)):
            # In case postprocess_results implementation doesn't set response "error"
            # or "code", set it the same as we did on the request objects above.
            # Note: We may want to forward-port this. Only downside is if we want
            # `postprocess_results` to be able to "handle" the error, i.e., to
            # intentionally not propagate the error_message and error_code above.
            # But that's still doable by setting `error` to "" and `code` to 200.
            for res in response:
                if res.get("error", None) is None:
                    res["error"] = str(e)
                if res.get("code", None) is None:
                    res["code"] = 424
        # Clear batch state so the next call starts clean.
        self.reset()
        return response

Expand Down Expand Up @@ -161,6 +179,13 @@ def postprocess_results(self) -> list[dict]:
"last": req.is_last_token(),
"content_type": req.get_content_type()
}
if serving_backport_for_non_streaming_http_error_codes_enabled():
error_message = getattr(req, "error_message", None)
error_code = getattr(req, "error_code", None)
if error_message is not None:
res["error"] = error_message
if error_code is not None:
res["code"] = error_code
req.reset_next_token()
results.append(res)

Expand Down
5 changes: 5 additions & 0 deletions engines/python/setup/djl_python/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.
import logging
import os
from typing import Union, Callable, Any, List

from djl_python.inputs import Input
Expand Down Expand Up @@ -192,3 +193,7 @@ def apply_profiling(self, *args, **kwargs):
return result

return apply_profiling


def serving_backport_for_non_streaming_http_error_codes_enabled() -> bool:
    """Return whether the non-streaming HTTP error-code backport is enabled.

    The feature is toggled by the
    ``SERVING_BACKPORT_FOR_NON_STREAMING_HTTP_ERROR_CODES`` environment
    variable: any non-empty value enables it; unset or empty disables it.
    Truthiness is identical to the previous implementation (which returned
    the raw ``str | None`` value), but the function now returns an explicit
    ``bool`` to match its ``*_enabled`` name.

    :return: True if the environment variable is set to a non-empty value.
    """
    return bool(
        os.getenv("SERVING_BACKPORT_FOR_NON_STREAMING_HTTP_ERROR_CODES"))
Original file line number Diff line number Diff line change
Expand Up @@ -374,10 +374,20 @@ void addResponse(byte[] json, Map<String, String> properties) {

if (code != null) {
Map<String, Object> map = new ConcurrentHashMap<>(2);
map.put("code", Integer.parseInt(code));
int httpStatusCode = Integer.parseInt(code);
map.put("code", httpStatusCode);
if (error != null) {
map.put("error", error);
}
if (isBackportForNonStreamingHttpErrorCodes) {
// Update http status code and any error message to the values here, so
// that non-streaming case can return non-200 on errors encountered during
// inference.
output.setCode(httpStatusCode);
if (error != null) {
output.setMessage(error);
}
}
byte[] buffer = JsonUtils.GSON.toJson(map).getBytes(StandardCharsets.UTF_8);
data.appendContent(buffer, true);
} else {
Expand Down
Loading