Skip to content

Commit

Permalink
Merge pull request #1139 from basetenlabs/bump-version-0.9.35
Browse files Browse the repository at this point in the history
Release 0.9.35
  • Loading branch information
bdubayah authored Sep 11, 2024
2 parents 181d0c2 + 7f6f1d2 commit 71bb741
Show file tree
Hide file tree
Showing 11 changed files with 448 additions and 215 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "truss"
version = "0.9.34"
version = "0.9.35"
description = "A seamless bridge from model development to model delivery"
license = "MIT"
readme = "README.md"
Expand Down
2 changes: 1 addition & 1 deletion truss/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@

REGISTRY_BUILD_SECRET_PREFIX = "DOCKER_REGISTRY_"

TRTLLM_BASE_IMAGE = "baseten/briton-server:5fa9436e_v0.0.8"
TRTLLM_BASE_IMAGE = "baseten/briton-server:5fa9436e_v0.0.9"
TRTLLM_PYTHON_EXECUTABLE = "/usr/bin/python3"
BASE_TRTLLM_REQUIREMENTS = [
"grpcio==1.62.3",
Expand Down
192 changes: 132 additions & 60 deletions truss/templates/server/common/errors.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,27 @@
import asyncio
import logging
from http import HTTPStatus
from typing import Optional

from typing import (
Callable,
Coroutine,
Mapping,
NoReturn,
Optional,
TypeVar,
Union,
overload,
)

import fastapi
from fastapi import HTTPException
from fastapi.responses import JSONResponse
from typing_extensions import ParamSpec

# See https://github.com/basetenlabs/baseten/blob/master/docs/Error-Propagation.md
# Identifier of this service; sent in the `X-BASETEN-ERROR-SOURCE` response header.
_TRUSS_SERVER_SERVICE_ID = 4
# The values below are sent in the `X-BASETEN-ERROR-CODE` response header.
# Used by the fallback branch of `exception_handler` (no dedicated mapping).
_BASETEN_UNEXPECTED_ERROR = 500
# Used for missing / not-ready models, user code errors and passed-through
# `HTTPException`s.
_BASETEN_DOWNSTREAM_ERROR_CODE = 600
# Used when the request input could not be parsed (`InputParsingError`).
_BASETEN_CLIENT_ERROR_CODE = 700


class ModelMissingError(Exception):
Expand All @@ -12,40 +32,6 @@ def __str__(self):
return self.path


class InferenceError(RuntimeError):
def __init__(self, reason):
self.reason = reason

def __str__(self):
return self.reason


class InvalidInput(ValueError):
"""
Exception class indicating invalid input arguments.
HTTP Servers should return HTTP_400 (Bad Request).
"""

def __init__(self, reason):
self.reason = reason

def __str__(self):
return self.reason


class ModelNotFound(Exception):
"""
Exception class indicating requested model does not exist.
HTTP Servers should return HTTP_404 (Not Found).
"""

def __init__(self, model_name=None):
self.reason = f"Model with name {model_name} does not exist."

def __str__(self):
return self.reason


class ModelNotReady(RuntimeError):
def __init__(self, model_name: str, detail: Optional[str] = None):
self.model_name = model_name
Expand All @@ -57,44 +43,130 @@ def __str__(self):
return self.error_msg


async def exception_handler(_, exc):
return JSONResponse(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR, content={"error": str(exc)}
)
class InputParsingError(ValueError):
    """Signals that the request input could not be parsed or validated.

    `exception_handler` answers this error with HTTP 400 (Bad Request).
    """


class UserCodeError(Exception):
    """Wraps an exception that escaped from intercepted (user-provided) code.

    `exception_handler` answers this error with HTTP 500 and a generic
    "Internal Server Error" message instead of the wrapped message.
    """


async def invalid_input_handler(_, exc):
return JSONResponse(status_code=HTTPStatus.BAD_REQUEST, content={"error": str(exc)})
def _make_baseten_error_headers(error_code: int) -> Mapping[str, str]:
    """Builds the custom Baseten error headers for a response.

    Args:
        error_code: Baseten error code, rendered zero-padded to three digits.

    Returns:
        A mapping with the error source (this service's ID, zero-padded to two
        digits) and the error code.
    """
    source = format(_TRUSS_SERVER_SERVICE_ID, "02")
    code = format(error_code, "03")
    return {"X-BASETEN-ERROR-SOURCE": source, "X-BASETEN-ERROR-CODE": code}


async def inference_error_handler(_, exc):
def _make_baseten_response(
http_status: int,
info: Union[str, Exception],
baseten_error_code: int,
) -> fastapi.Response:
msg = str(info) if isinstance(info, Exception) else info
return JSONResponse(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR, content={"error": str(exc)}
status_code=http_status,
content={"error": msg},
headers=_make_baseten_error_headers(baseten_error_code),
)


async def generic_exception_handler(_, exc):
return JSONResponse(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
content={"error": f"{type(exc).__name__} : {str(exc)}"},
async def exception_handler(
    request: fastapi.Request, exc: Exception
) -> fastapi.Response:
    """Converts an exception into a JSON error response with Baseten headers.

    Mapping of exception type to (HTTP status, Baseten error code):
      * `ModelMissingError`: 404, downstream error.
      * `ModelNotReady`: 503, downstream error.
      * `InputParsingError`: 400, client error.
      * `UserCodeError`: 500 with a generic message, downstream error.
      * `fastapi.HTTPException`: status and detail of the exception,
        downstream error; headers set on the exception are preserved.
      * anything else: 500 with the exception type and message, unexpected
        error.

    Args:
        request: The request during which the exception occurred (unused).
        exc: The exception to convert.

    Returns:
        The JSON response built by `_make_baseten_response`.
    """
    if isinstance(exc, ModelMissingError):
        return _make_baseten_response(
            HTTPStatus.NOT_FOUND.value, exc, _BASETEN_DOWNSTREAM_ERROR_CODE
        )
    if isinstance(exc, ModelNotReady):
        return _make_baseten_response(
            HTTPStatus.SERVICE_UNAVAILABLE.value, exc, _BASETEN_DOWNSTREAM_ERROR_CODE
        )
    if isinstance(exc, InputParsingError):
        return _make_baseten_response(
            HTTPStatus.BAD_REQUEST.value,
            exc,
            _BASETEN_CLIENT_ERROR_CODE,
        )
    if isinstance(exc, UserCodeError):
        # The wrapped message is not returned to the client, only a generic
        # text (presumably to avoid leaking user-code details; the full error
        # is logged where the exception is intercepted).
        return _make_baseten_response(
            HTTPStatus.INTERNAL_SERVER_ERROR.value,
            "Internal Server Error",
            _BASETEN_DOWNSTREAM_ERROR_CODE,
        )
    if isinstance(exc, fastapi.HTTPException):
        # This is a pass through, but additionally adds our custom error headers.
        response = _make_baseten_response(
            exc.status_code, exc.detail, _BASETEN_DOWNSTREAM_ERROR_CODE
        )
        # Keep headers attached to the exception (e.g. `WWW-Authenticate`), but
        # never let them override headers already set on the response.
        for key, value in (getattr(exc, "headers", None) or {}).items():
            response.headers.setdefault(key, value)
        return response
    # Fallback case.
    # NOTE(review): `NotImplementedError` is listed in `HANDLED_EXCEPTIONS` but
    # has no dedicated branch, so it is reported here as an unhandled 500 -
    # confirm this is intended.
    return _make_baseten_response(
        HTTPStatus.INTERNAL_SERVER_ERROR.value,
        f"Unhandled exception: {type(exc).__name__}: {str(exc)}",
        _BASETEN_UNEXPECTED_ERROR,
    )


async def model_not_found_handler(_, exc):
return JSONResponse(status_code=HTTPStatus.NOT_FOUND, content={"error": str(exc)})
# Exception classes for which `exception_handler` is meant to be registered
# explicitly, one entry per class.
HANDLED_EXCEPTIONS = {
    ModelMissingError,
    ModelNotReady,
    NotImplementedError,
    InputParsingError,
    UserCodeError,
    fastapi.HTTPException,
}


async def model_not_ready_handler(_, exc):
return JSONResponse(
status_code=HTTPStatus.SERVICE_UNAVAILABLE, content={"error": str(exc)}
)
def _intercept_user_exception(exc: Exception, logger: logging.Logger) -> NoReturn:
    """Logs an intercepted exception with its stack trace and re-raises it.

    Args:
        exc: The exception that was caught.
        logger: Logger that receives the error message and stack trace.

    Raises:
        HTTPException: `exc` itself, unchanged, if it is an `HTTPException`.
        UserCodeError: For any other exception; carries `str(exc)` as message
            and is explicitly chained to `exc`.
    """
    # Note that logger.exception logs the stacktrace, such that the user can
    # debug this error from the logs. The exception is passed explicitly so the
    # trace is logged even when no `except` block is active in the caller.
    # TODO: consider removing the wrapper function from the stack trace.
    if isinstance(exc, HTTPException):
        logger.exception("Model raised HTTPException", exc_info=exc, stacklevel=2)
        raise exc
    logger.exception("Internal Server Error", exc_info=exc, stacklevel=2)
    # Chain explicitly so the original error is kept as `__cause__`.
    raise UserCodeError(str(exc)) from exc


async def not_implemented_error_handler(_, exc):
return JSONResponse(
status_code=HTTPStatus.NOT_IMPLEMENTED, content={"error": str(exc)}
)
# Parameter specification of the callable passed to `intercept_exceptions`.
_P = ParamSpec("_P")
# Return type of the callable passed to `intercept_exceptions`.
_R = TypeVar("_R")
_R_async = TypeVar("_R_async", bound=Coroutine) # Return type for async functions


@overload
def intercept_exceptions(
func: Callable[_P, _R], logger: logging.Logger
) -> Callable[_P, _R]: ...


@overload
def intercept_exceptions(
func: Callable[_P, _R_async], logger: logging.Logger
) -> Callable[_P, _R_async]: ...


def intercept_exceptions(
func: Callable[_P, _R], logger: logging.Logger
) -> Callable[_P, _R]:
"""Converts all exceptions to 500-`HTTPException` and logs them.
If exception is already `HTTPException`, re-raises exception as is.
"""
if asyncio.iscoroutinefunction(func):

async def inner_async(*args: _P.args, **kwargs: _P.kwargs) -> _R:
try:
return await func(*args, **kwargs)
except Exception as e:
_intercept_user_exception(e, logger)

return inner_async # type: ignore[return-value]
else:

def inner_sync(*args: _P.args, **kwargs: _P.kwargs) -> _R:
try:
return func(*args, **kwargs)
except Exception as e:
_intercept_user_exception(e, logger)

async def http_exception_handler(_, exc):
return JSONResponse(status_code=exc.status_code, content={"error": exc.detail})
return inner_sync
68 changes: 40 additions & 28 deletions truss/templates/server/common/truss_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,17 @@
from pathlib import Path
from typing import AsyncGenerator, Dict, List, Optional, Union

import common.errors as errors
import shared.util as utils
import pydantic
import uvicorn
from common import tracing
from common import errors, tracing
from common.termination_handler_middleware import TerminationHandlerMiddleware
from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.responses import ORJSONResponse, StreamingResponse
from fastapi.routing import APIRoute as FastAPIRoute
from model_wrapper import ModelWrapper
from opentelemetry import propagate as otel_propagate
from opentelemetry.sdk import trace as sdk_trace
from shared import util
from shared.logging import setup_logging
from shared.secrets_resolver import SecretsResolver
from shared.serialization import (
Expand All @@ -41,6 +41,8 @@
DEFAULT_NUM_SERVER_PROCESSES = 1
WORKER_TERMINATION_TIMEOUT_SECS = 120.0
WORKER_TERMINATION_CHECK_INTERVAL_SECS = 0.5
INFERENCE_SERVER_FAILED_FILE = Path("~/inference_server_crashed.txt").expanduser()
DATE_FORMAT = "%Y-%m-%d %H:%M:%S"


async def parse_body(request: Request) -> bytes:
Expand All @@ -55,12 +57,6 @@ async def parse_body(request: Request) -> bytes:
raise HTTPException(status_code=499, detail=error_message) from exc


FORMAT = "%(asctime)s.%(msecs)03d %(name)s %(levelname)s [%(funcName)s():%(lineno)s] %(message)s"
DATE_FORMAT = "%Y-%m-%d %H:%M:%S"
INFERENCE_SERVER_FAILED_FILE = Path("~/inference_server_crashed.txt").expanduser()
logging.basicConfig(level=logging.INFO, format=FORMAT, datefmt=DATE_FORMAT)


class UvicornCustomServer(multiprocessing.Process):
def __init__(
self, config: uvicorn.Config, sockets: Optional[List[socket.socket]] = None
Expand Down Expand Up @@ -145,20 +141,39 @@ async def predict(
if self.is_binary(request):
with tracing.section_as_event(span, "binary-deserialize"):
body = truss_msgpack_deserialize(body_raw)
if model.truss_schema:
try:
with tracing.section_as_event(span, "parse-pydantic"):
body = model.truss_schema.input_type.parse_obj(body)
except pydantic.ValidationError as e:
raise errors.InputParsingError(
f"Request Validation Error, {str(e)}"
) from e
else:
try:
with tracing.section_as_event(span, "json-deserialize"):
body = json.loads(body_raw)
except json.JSONDecodeError as e:
error_message = f"Invalid JSON payload: {str(e)}"
logging.error(error_message)
raise HTTPException(status_code=400, detail=error_message)

# calls ModelWrapper.__call__, which runs validate, preprocess, predict, and postprocess
if model.truss_schema:
if model.truss_schema:
try:
with tracing.section_as_event(span, "parse-pydantic"):
body = model.truss_schema.input_type.parse_raw(body_raw)
except pydantic.ValidationError as e:
raise errors.InputParsingError(
f"Request Validation Error, {str(e)}"
) from e
else:
try:
with tracing.section_as_event(span, "json-deserialize"):
body = json.loads(body_raw)
except json.JSONDecodeError as e:
raise errors.InputParsingError(
f"Invalid JSON payload: {str(e)}"
) from e

# Calls ModelWrapper.__call__, which runs validate, preprocess, predict,
# and postprocess.
with tracing.section_as_event(span, "model-call"):
response: Union[Dict, Generator] = await model(
body,
headers=utils.transform_keys(
headers=util.transform_keys(
request.headers, lambda key: key.lower()
),
)
Expand Down Expand Up @@ -192,7 +207,6 @@ async def schema(self, model_name: str) -> Dict:

if model.truss_schema is None:
# If there is not a TrussSchema, we return a 404.

if model.ready:
raise HTTPException(status_code=404, detail="No schema found")
else:
Expand Down Expand Up @@ -290,19 +304,17 @@ def create_application(self):
),
],
exception_handlers={
errors.InferenceError: errors.inference_error_handler,
errors.ModelNotFound: errors.model_not_found_handler,
errors.ModelNotReady: errors.model_not_ready_handler,
NotImplementedError: errors.not_implemented_error_handler,
HTTPException: errors.http_exception_handler,
Exception: errors.generic_exception_handler,
exc: errors.exception_handler for exc in errors.HANDLED_EXCEPTIONS
},
)
# The `exception_handlers` above only trigger on exact exception classes.
# This is a fallback that adds our custom headers in all other cases.
app.add_exception_handler(Exception, errors.exception_handler)

def exit_self():
# Note that this kills the current process, the worker process, not
# the main truss_server process.
utils.kill_child_processes(os.getpid())
util.kill_child_processes(os.getpid())
sys.exit()

termination_handler_middleware = TerminationHandlerMiddleware(
Expand Down Expand Up @@ -395,7 +407,7 @@ def stop_servers():
)
for _ in range(termination_check_attempts):
time.sleep(WORKER_TERMINATION_CHECK_INTERVAL_SECS)
if utils.all_processes_dead(servers):
if util.all_processes_dead(servers):
return

for sig in [signal.SIGINT, signal.SIGTERM, signal.SIGQUIT]:
Expand Down
Loading

0 comments on commit 71bb741

Please sign in to comment.