From dd2b31c8c9251d6814f51b74bde3d2d58134bac1 Mon Sep 17 00:00:00 2001 From: rcano-baseten Date: Wed, 11 Sep 2024 12:41:35 -0400 Subject: [PATCH 1/4] handle different types of envvars (#1138) * handle different types * pre-commmit gang --- truss/tests/test_config.py | 17 +++++++++++++++++ truss/truss_config.py | 12 +++++++++++- 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/truss/tests/test_config.py b/truss/tests/test_config.py index 5f16c5f73..69cf80f45 100644 --- a/truss/tests/test_config.py +++ b/truss/tests/test_config.py @@ -396,6 +396,23 @@ def test_from_yaml_python_version(): assert result.python_version == "py39" +def test_from_yaml_environment_variables(): + data = { + "description": "this is a test", + "environment_variables": {"foo": "bar", "bool": True, "int": 0}, + } + with tempfile.NamedTemporaryFile(mode="w", delete=False) as yaml_file: + yaml_path = Path(yaml_file.name) + yaml.safe_dump(data, yaml_file) + + result = TrussConfig.from_yaml(yaml_path) + assert result.environment_variables == { + "foo": "bar", + "bool": "true", + "int": "0", + } + + def test_secret_to_path_mapping_correct_type(default_config): data = { "description": "this is a test", diff --git a/truss/truss_config.py b/truss/truss_config.py index 778d454ec..93de10276 100644 --- a/truss/truss_config.py +++ b/truss/truss_config.py @@ -543,7 +543,7 @@ def from_dict(d): requirements_file=d.get("requirements_file", None), requirements=d.get("requirements", []), system_packages=d.get("system_packages", []), - environment_variables=d.get("environment_variables", {}), + environment_variables=_handle_env_vars(d.get("environment_variables", {})), resources=Resources.from_dict(d.get("resources", {})), runtime=Runtime.from_dict(d.get("runtime", {})), build=Build.from_dict(d.get("build", {})), @@ -660,6 +660,16 @@ def validate(self): self._validate_quant_format_and_accelerator_for_trt_llm_builder() +def _handle_env_vars(env_vars: Dict[str, Any]) -> Dict[str, str]: + new_env_vars = {} + for k, v in env_vars.items(): + if isinstance(v, bool): + new_env_vars[k] = str(v).lower() + else: + new_env_vars[k] = str(v) + return new_env_vars + + DATACLASS_TO_REQ_KEYS_MAP = { Resources: {"accelerator", "cpu", "memory", "use_gpu"}, Runtime: {"predict_concurrency"}, From 5f0eb4830c76dc1a3c425c0cf2b085236846f877 Mon Sep 17 00:00:00 2001 From: Marius Killinger <155577904+marius-baseten@users.noreply.github.com> Date: Wed, 11 Sep 2024 10:42:34 -0700 Subject: [PATCH 2/4] Unified error handling in truss server and adding baseten error headers to response (#1132) --- truss/templates/server/common/errors.py | 192 ++++++++++++------ truss/templates/server/common/truss_server.py | 68 ++++--- truss/templates/server/model_wrapper.py | 75 ++----- truss/tests/test_model_inference.py | 16 ++ truss/tests/test_model_schema.py | 30 ++- 5 files changed, 224 insertions(+), 157 deletions(-) diff --git a/truss/templates/server/common/errors.py b/truss/templates/server/common/errors.py index 2218e2327..fea19fe12 100644 --- a/truss/templates/server/common/errors.py +++ b/truss/templates/server/common/errors.py @@ -1,7 +1,27 @@ +import asyncio +import logging from http import HTTPStatus -from typing import Optional - +from typing import ( + Callable, + Coroutine, + Mapping, + NoReturn, + Optional, + TypeVar, + Union, + overload, +) + +import fastapi +from fastapi import HTTPException from fastapi.responses import JSONResponse +from typing_extensions import ParamSpec + +# See 
https://github.com/basetenlabs/baseten/blob/master/docs/Error-Propagation.md +_TRUSS_SERVER_SERVICE_ID = 4 +_BASETEN_UNEXPECTED_ERROR = 500 +_BASETEN_DOWNSTREAM_ERROR_CODE = 600 +_BASETEN_CLIENT_ERROR_CODE = 700 class ModelMissingError(Exception): @@ -12,40 +32,6 @@ def __str__(self): return self.path -class InferenceError(RuntimeError): - def __init__(self, reason): - self.reason = reason - - def __str__(self): - return self.reason - - -class InvalidInput(ValueError): - """ - Exception class indicating invalid input arguments. - HTTP Servers should return HTTP_400 (Bad Request). - """ - - def __init__(self, reason): - self.reason = reason - - def __str__(self): - return self.reason - - -class ModelNotFound(Exception): - """ - Exception class indicating requested model does not exist. - HTTP Servers should return HTTP_404 (Not Found). - """ - - def __init__(self, model_name=None): - self.reason = f"Model with name {model_name} does not exist." - - def __str__(self): - return self.reason - - class ModelNotReady(RuntimeError): def __init__(self, model_name: str, detail: Optional[str] = None): self.model_name = model_name @@ -57,44 +43,130 @@ def __str__(self): return self.error_msg -async def exception_handler(_, exc): - return JSONResponse( - status_code=HTTPStatus.INTERNAL_SERVER_ERROR, content={"error": str(exc)} - ) +class InputParsingError(ValueError): + pass + + +class UserCodeError(Exception): + pass -async def invalid_input_handler(_, exc): - return JSONResponse(status_code=HTTPStatus.BAD_REQUEST, content={"error": str(exc)}) +def _make_baseten_error_headers(error_code: int) -> Mapping[str, str]: + return { + "X-BASETEN-ERROR-SOURCE": f"{_TRUSS_SERVER_SERVICE_ID:02}", + "X-BASETEN-ERROR-CODE": f"{error_code:03}", + } -async def inference_error_handler(_, exc): +def _make_baseten_response( + http_status: int, + info: Union[str, Exception], + baseten_error_code: int, +) -> fastapi.Response: + msg = str(info) if isinstance(info, Exception) else info return JSONResponse( - status_code=HTTPStatus.INTERNAL_SERVER_ERROR, content={"error": str(exc)} + status_code=http_status, + content={"error": msg}, + headers=_make_baseten_error_headers(baseten_error_code), ) -async def generic_exception_handler(_, exc): - return JSONResponse( - status_code=HTTPStatus.INTERNAL_SERVER_ERROR, - content={"error": f"{type(exc).__name__} : {str(exc)}"}, +async def exception_handler( + request: fastapi.Request, exc: Exception +) -> fastapi.Response: + if isinstance(exc, ModelMissingError): + return _make_baseten_response( + HTTPStatus.NOT_FOUND.value, exc, _BASETEN_DOWNSTREAM_ERROR_CODE + ) + if isinstance(exc, ModelNotReady): + return _make_baseten_response( + HTTPStatus.SERVICE_UNAVAILABLE.value, exc, _BASETEN_DOWNSTREAM_ERROR_CODE + ) + if isinstance(exc, InputParsingError): + return _make_baseten_response( + HTTPStatus.BAD_REQUEST.value, + exc, + _BASETEN_CLIENT_ERROR_CODE, + ) + if isinstance(exc, UserCodeError): + return _make_baseten_response( + HTTPStatus.INTERNAL_SERVER_ERROR.value, + "Internal Server Error", + _BASETEN_DOWNSTREAM_ERROR_CODE, + ) + if isinstance(exc, fastapi.HTTPException): + # This is a pass through, but additionally adds our custom error headers. + return _make_baseten_response( + exc.status_code, exc.detail, _BASETEN_DOWNSTREAM_ERROR_CODE + ) + # Fallback case. 
+ return _make_baseten_response( + HTTPStatus.INTERNAL_SERVER_ERROR.value, + f"Unhandled exception: {type(exc).__name__}: {str(exc)}", + _BASETEN_UNEXPECTED_ERROR, ) -async def model_not_found_handler(_, exc): - return JSONResponse(status_code=HTTPStatus.NOT_FOUND, content={"error": str(exc)}) +HANDLED_EXCEPTIONS = { + ModelMissingError, + ModelNotReady, + NotImplementedError, + InputParsingError, + UserCodeError, + fastapi.HTTPException, +} -async def model_not_ready_handler(_, exc): - return JSONResponse( - status_code=HTTPStatus.SERVICE_UNAVAILABLE, content={"error": str(exc)} - ) +def _intercept_user_exception(exc: Exception, logger: logging.Logger) -> NoReturn: + # Note that logger.exception logs the stacktrace, such that the user can + # debug this error from the logs. + # TODO: consider removing the wrapper function from the stack trace. + if isinstance(exc, HTTPException): + logger.exception("Model raised HTTPException", stacklevel=2) + raise exc + else: + logger.exception("Internal Server Error", stacklevel=2) + raise UserCodeError(str(exc)) -async def not_implemented_error_handler(_, exc): - return JSONResponse( - status_code=HTTPStatus.NOT_IMPLEMENTED, content={"error": str(exc)} - ) +_P = ParamSpec("_P") +_R = TypeVar("_R") +_R_async = TypeVar("_R_async", bound=Coroutine) # Return type for async functions + + +@overload +def intercept_exceptions( + func: Callable[_P, _R], logger: logging.Logger +) -> Callable[_P, _R]: ... + + +@overload +def intercept_exceptions( + func: Callable[_P, _R_async], logger: logging.Logger +) -> Callable[_P, _R_async]: ... + + +def intercept_exceptions( + func: Callable[_P, _R], logger: logging.Logger +) -> Callable[_P, _R]: + """Converts all exceptions to 500-`HTTPException` and logs them. + If exception is already `HTTPException`, re-raises exception as is. 
+ """ + if asyncio.iscoroutinefunction(func): + + async def inner_async(*args: _P.args, **kwargs: _P.kwargs) -> _R: + try: + return await func(*args, **kwargs) + except Exception as e: + _intercept_user_exception(e, logger) + + return inner_async # type: ignore[return-value] + else: + def inner_sync(*args: _P.args, **kwargs: _P.kwargs) -> _R: + try: + return func(*args, **kwargs) + except Exception as e: + _intercept_user_exception(e, logger) -async def http_exception_handler(_, exc): - return JSONResponse(status_code=exc.status_code, content={"error": exc.detail}) + return inner_sync diff --git a/truss/templates/server/common/truss_server.py b/truss/templates/server/common/truss_server.py index 7db57a99c..b567ea500 100644 --- a/truss/templates/server/common/truss_server.py +++ b/truss/templates/server/common/truss_server.py @@ -11,10 +11,9 @@ from pathlib import Path from typing import AsyncGenerator, Dict, List, Optional, Union -import common.errors as errors -import shared.util as utils +import pydantic import uvicorn -from common import tracing +from common import errors, tracing from common.termination_handler_middleware import TerminationHandlerMiddleware from fastapi import Depends, FastAPI, HTTPException, Request from fastapi.responses import ORJSONResponse, StreamingResponse @@ -22,6 +21,7 @@ from model_wrapper import ModelWrapper from opentelemetry import propagate as otel_propagate from opentelemetry.sdk import trace as sdk_trace +from shared import util from shared.logging import setup_logging from shared.secrets_resolver import SecretsResolver from shared.serialization import ( @@ -41,6 +41,8 @@ DEFAULT_NUM_SERVER_PROCESSES = 1 WORKER_TERMINATION_TIMEOUT_SECS = 120.0 WORKER_TERMINATION_CHECK_INTERVAL_SECS = 0.5 +INFERENCE_SERVER_FAILED_FILE = Path("~/inference_server_crashed.txt").expanduser() +DATE_FORMAT = "%Y-%m-%d %H:%M:%S" async def parse_body(request: Request) -> bytes: @@ -55,12 +57,6 @@ async def parse_body(request: Request) -> bytes: raise HTTPException(status_code=499, detail=error_message) from exc -FORMAT = "%(asctime)s.%(msecs)03d %(name)s %(levelname)s [%(funcName)s():%(lineno)s] %(message)s" -DATE_FORMAT = "%Y-%m-%d %H:%M:%S" -INFERENCE_SERVER_FAILED_FILE = Path("~/inference_server_crashed.txt").expanduser() -logging.basicConfig(level=logging.INFO, format=FORMAT, datefmt=DATE_FORMAT) - - class UvicornCustomServer(multiprocessing.Process): def __init__( self, config: uvicorn.Config, sockets: Optional[List[socket.socket]] = None @@ -145,20 +141,39 @@ async def predict( if self.is_binary(request): with tracing.section_as_event(span, "binary-deserialize"): body = truss_msgpack_deserialize(body_raw) + if model.truss_schema: + try: + with tracing.section_as_event(span, "parse-pydantic"): + body = model.truss_schema.input_type.parse_obj(body) + except pydantic.ValidationError as e: + raise errors.InputParsingError( + f"Request Validation Error, {str(e)}" + ) from e else: - try: - with tracing.section_as_event(span, "json-deserialize"): - body = json.loads(body_raw) - except json.JSONDecodeError as e: - error_message = f"Invalid JSON payload: {str(e)}" - logging.error(error_message) - raise HTTPException(status_code=400, detail=error_message) - - # calls ModelWrapper.__call__, which runs validate, preprocess, predict, and postprocess + if model.truss_schema: + if model.truss_schema: + try: + with tracing.section_as_event(span, "parse-pydantic"): + body = model.truss_schema.input_type.parse_raw(body_raw) + except pydantic.ValidationError as e: + raise 
errors.InputParsingError( + f"Request Validation Error, {str(e)}" + ) from e + else: + try: + with tracing.section_as_event(span, "json-deserialize"): + body = json.loads(body_raw) + except json.JSONDecodeError as e: + raise errors.InputParsingError( + f"Invalid JSON payload: {str(e)}" + ) from e + + # Calls ModelWrapper.__call__, which runs validate, preprocess, predict, + # and postprocess. with tracing.section_as_event(span, "model-call"): response: Union[Dict, Generator] = await model( body, - headers=utils.transform_keys( + headers=util.transform_keys( request.headers, lambda key: key.lower() ), ) @@ -192,7 +207,6 @@ async def schema(self, model_name: str) -> Dict: if model.truss_schema is None: # If there is not a TrussSchema, we return a 404. - if model.ready: raise HTTPException(status_code=404, detail="No schema found") else: @@ -290,19 +304,17 @@ def create_application(self): ), ], exception_handlers={ - errors.InferenceError: errors.inference_error_handler, - errors.ModelNotFound: errors.model_not_found_handler, - errors.ModelNotReady: errors.model_not_ready_handler, - NotImplementedError: errors.not_implemented_error_handler, - HTTPException: errors.http_exception_handler, - Exception: errors.generic_exception_handler, + exc: errors.exception_handler for exc in errors.HANDLED_EXCEPTIONS }, ) + # Above `exception_handlers` only triggers on exact exception classes. + # This here is a fallback to add our custom headers in all other cases. + app.add_exception_handler(Exception, errors.exception_handler) def exit_self(): # Note that this kills the current process, the worker process, not # the main truss_server process. - utils.kill_child_processes(os.getpid()) + util.kill_child_processes(os.getpid()) sys.exit() termination_handler_middleware = TerminationHandlerMiddleware( @@ -395,7 +407,7 @@ def stop_servers(): ) for _ in range(termination_check_attempts): time.sleep(WORKER_TERMINATION_CHECK_INTERVAL_SECS) - if utils.all_processes_dead(servers): + if util.all_processes_dead(servers): return for sig in [signal.SIGINT, signal.SIGTERM, signal.SIGQUIT]: diff --git a/truss/templates/server/model_wrapper.py b/truss/templates/server/model_wrapper.py index 32bb9f6df..4b7bab46c 100644 --- a/truss/templates/server/model_wrapper.py +++ b/truss/templates/server/model_wrapper.py @@ -15,28 +15,22 @@ Any, AsyncGenerator, Callable, - Coroutine, Dict, Mapping, - NoReturn, Optional, - TypeVar, Union, ) import opentelemetry.sdk.trace as sdk_trace -import pydantic from anyio import Semaphore, to_thread -from common import tracing +from common import errors, tracing from common.patches import apply_patches from common.retry import retry from common.schema import TrussSchema -from fastapi import HTTPException from opentelemetry import trace from pydantic import BaseModel from shared.lazy_data_resolver import LazyDataResolver from shared.secrets_resolver import SecretsResolver -from typing_extensions import ParamSpec MODEL_BASENAME = "model" @@ -231,10 +225,13 @@ async def preprocess( return payload if inspect.iscoroutinefunction(self._model.preprocess): - return await _intercept_exceptions_async(self._model.preprocess)(payload) + return await errors.intercept_exceptions( + self._model.preprocess, self._logger + )(payload) else: return await to_thread.run_sync( - _intercept_exceptions_sync(self._model.preprocess), payload + errors.intercept_exceptions(self._model.preprocess, self._logger), + payload, ) async def predict( @@ -255,10 +252,12 @@ async def predict( return self._model.predict(payload) if 
inspect.iscoroutinefunction(self._model.predict): - return await _intercept_exceptions_async(self._model.predict)(payload) + return await errors.intercept_exceptions(self._model.predict, self._logger)( + payload + ) return await to_thread.run_sync( - _intercept_exceptions_sync(self._model.predict), payload + errors.intercept_exceptions(self._model.predict, self._logger), payload ) async def postprocess( @@ -280,10 +279,12 @@ async def postprocess( return self._model.postprocess(response) if inspect.iscoroutinefunction(self._model.postprocess): - return await _intercept_exceptions_async(self._model.postprocess)(response) + return await errors.intercept_exceptions( + self._model.postprocess, self._logger + )(response) return await to_thread.run_sync( - _intercept_exceptions_sync(self._model.postprocess), response + errors.intercept_exceptions(self._model.postprocess, self._logger), response ) async def write_response_to_queue( @@ -304,7 +305,7 @@ async def write_response_to_queue( async def _streaming_post_process(self, response: Any, span: trace.Span) -> Any: if hasattr(self._model, "postprocess"): - logging.warning( + self._logger.warning( "Predict returned a streaming response, while a postprocess is defined." "Note that in this case, the postprocess will run within the predict lock." ) @@ -388,15 +389,6 @@ async def __call__( String: in case of non-streamed generator (the string is the JSON result). """ with self._tracer.start_as_current_span("call-pre") as span_pre: - if self.truss_schema is not None: - try: - with tracing.section_as_event(span_pre, "parse-pydantic"): - body = self.truss_schema.input_type(**body) - except pydantic.ValidationError as e: - self._logger.info("Request Validation Error") - raise HTTPException( - status_code=400, detail=f"Request Validation Error, {str(e)}" - ) from e with tracing.section_as_event( span_pre, "preprocess" ), tracing.detach_context(): @@ -500,43 +492,6 @@ def _elapsed_ms(since_micro_seconds: float) -> int: return int((time.perf_counter() - since_micro_seconds) * 1000) -def _handle_exception(exception: Exception) -> NoReturn: - # Note that logger.exception logs the stacktrace, such that the user can - # debug this error from the logs. 
- if isinstance(exception, HTTPException): - logging.exception("Model raised HTTPException") - raise exception - else: - logging.exception("Internal Server Error") - raise HTTPException(status_code=500, detail="Internal Server Error") - - -_P = ParamSpec("_P") -_R = TypeVar("_R") - - -def _intercept_exceptions_sync(func: Callable[_P, _R]) -> Callable[_P, _R]: - def inner(*args: _P.args, **kwargs: _P.kwargs) -> _R: - try: - return func(*args, **kwargs) - except Exception as e: - _handle_exception(e) - - return inner - - -def _intercept_exceptions_async( - func: Callable[_P, Coroutine[Any, Any, _R]], -) -> Callable[_P, Coroutine[Any, Any, _R]]: - async def inner(*args: _P.args, **kwargs: _P.kwargs) -> _R: - try: - return await func(*args, **kwargs) - except Exception as e: - _handle_exception(e) - - return inner - - def _init_extensions(config, data_dir, secrets, lazy_data_resolver): extensions = {} extensions_path = Path(__file__).parent / EXTENSIONS_DIR_NAME diff --git a/truss/tests/test_model_inference.py b/truss/tests/test_model_inference.py index c3283627d..e64c67ca2 100644 --- a/truss/tests/test_model_inference.py +++ b/truss/tests/test_model_inference.py @@ -454,6 +454,8 @@ def predict(self, request): assert_logs_contain_error(container.logs(), missing_secret_error_message) assert "Internal Server Error" in response.json()["error"] + assert response.headers["x-baseten-error-source"] == "04" + assert response.headers["x-baseten-error-code"] == "600" with ensure_kill_all(), tempfile.TemporaryDirectory(dir=".") as tmp_work_dir: # Case where the secret is not mounted @@ -473,6 +475,8 @@ def predict(self, request): assert_logs_contain_error(container.logs(), missing_secret_error_message) assert "Internal Server Error" in response.json()["error"] + assert response.headers["x-baseten-error-source"] == "04" + assert response.headers["x-baseten-error-code"] == "600" @pytest.mark.integration @@ -662,6 +666,8 @@ def predict(self, request): assert_logs_contain_error(container.logs(), "ValueError: error") assert "Internal Server Error" in response.json()["error"] + assert response.headers["x-baseten-error-source"] == "04" + assert response.headers["x-baseten-error-code"] == "600" model_preprocess_error = """ class Model: @@ -690,6 +696,8 @@ def predict(self, request): assert_logs_contain_error(container.logs(), "ValueError: error") assert "Internal Server Error" in response.json()["error"] + assert response.headers["x-baseten-error-source"] == "04" + assert response.headers["x-baseten-error-code"] == "600" model_postprocess_error = """ class Model: @@ -717,6 +725,8 @@ def postprocess(self, response): assert "error" in response.json() assert_logs_contain_error(container.logs(), "ValueError: error") assert "Internal Server Error" in response.json()["error"] + assert response.headers["x-baseten-error-source"] == "04" + assert response.headers["x-baseten-error-code"] == "600" model_async = """ class Model: @@ -743,6 +753,8 @@ async def predict(self, request): assert_logs_contain_error(container.logs(), "ValueError: error") assert "Internal Server Error" in response.json()["error"] + assert response.headers["x-baseten-error-source"] == "04" + assert response.headers["x-baseten-error-code"] == "600" @pytest.mark.integration @@ -773,6 +785,8 @@ def predict(self, request): response = requests.post(full_url, json={}) assert response.status_code == 500 assert "error" in response.json() + assert response.headers["x-baseten-error-source"] == "04" + assert response.headers["x-baseten-error-code"] == "600" 
assert_logs_contain_error( container.logs(), @@ -781,6 +795,8 @@ def predict(self, request): ) assert "My custom message." in response.json()["error"] + assert response.headers["x-baseten-error-source"] == "04" + assert response.headers["x-baseten-error-code"] == "600" @pytest.mark.integration diff --git a/truss/tests/test_model_schema.py b/truss/tests/test_model_schema.py index 360fd5680..9457e8fa3 100644 --- a/truss/tests/test_model_schema.py +++ b/truss/tests/test_model_schema.py @@ -5,6 +5,7 @@ import pytest import requests +from truss.templates.shared import serialization from truss.tests.helpers import create_truss from truss.tests.test_testing_utilities_for_other_tests import ensure_kill_all from truss.truss_handle import TrussHandle @@ -35,6 +36,8 @@ def test_truss_with_no_annotations(): assert schema_response.status_code == 404 assert schema_response.json()["error"] == "No schema found" + assert schema_response.headers["x-baseten-error-source"] == "04" + assert schema_response.headers["x-baseten-error-code"] == "600" @pytest.mark.integration @@ -58,8 +61,9 @@ def predict(self, request: str) -> list[str]: schema_response = requests.get(SCHEMA_URL) assert schema_response.status_code == 404 - assert schema_response.json()["error"] == "No schema found" + assert schema_response.headers["x-baseten-error-source"] == "04" + assert schema_response.headers["x-baseten-error-code"] == "600" @pytest.mark.integration @@ -102,36 +106,44 @@ def test_truss_with_annotated_inputs_outputs(): with ensure_kill_all(): _ = tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=True) - - response = requests.post(INFERENCE_URL, json={"prompt": "value"}) + # Valid JSON input. + json_input = {"prompt": "value"} + response = requests.post(INFERENCE_URL, json=json_input) assert response.json() == { "generated_text": "value", } - # An invalid input + # Valid binary input. + byte_input = serialization.truss_msgpack_serialize(json_input) + print(byte_input) + response = requests.post( + INFERENCE_URL, + data=byte_input, + headers={"Content-Type": "application/octet-stream"}, + ) + assert response.content == b"\x81\xaegenerated_text\xa5value" + # An invalid input response = requests.post(INFERENCE_URL, json={"bad_key": "value"}) - assert response.status_code == 400 assert "error" in response.json() - assert ( "Request Validation Error, 1 validation error for ModelInput" "\nprompt\n Field required [type=missing, input_value={'bad_key': 'value'}, input_type=dict]\n" in response.json()["error"] ) + assert response.headers["x-baseten-error-source"] == "04" + assert response.headers["x-baseten-error-code"] == "700" + # Schema response. 
schema_response = requests.get(SCHEMA_URL) - schema = schema_response.json() - assert schema["input_schema"] == { "properties": {"prompt": {"title": "Prompt", "type": "string"}}, "required": ["prompt"], "title": "ModelInput", "type": "object", } - assert schema["output_schema"] == { "properties": { "generated_text": {"title": "Generated Text", "type": "string"} From 8544185ab4296611affdc9cf6c67702d6ef40a4d Mon Sep 17 00:00:00 2001 From: Bryce Dubayah Date: Wed, 11 Sep 2024 14:50:46 -0700 Subject: [PATCH 3/4] BT-12026 BT-12027 BT-12029 Support parallel function calls, optional calls, and mistral (#1133) * Update briton server image and template * improvements for handling concurrency * bump version --- pyproject.toml | 2 +- truss/constants.py | 2 +- .../trtllm-briton/packages/briton_pb2.py | 40 ++-- truss/templates/trtllm-briton/src/engine.py | 209 +++++++++++++++--- 4 files changed, 196 insertions(+), 57 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 187b3f6ad..f07866187 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "truss" -version = "0.9.34" +version = "0.9.35rc1" description = "A seamless bridge from model development to model delivery" license = "MIT" readme = "README.md" diff --git a/truss/constants.py b/truss/constants.py index 4e6cd6f57..e3c64e7f6 100644 --- a/truss/constants.py +++ b/truss/constants.py @@ -103,7 +103,7 @@ REGISTRY_BUILD_SECRET_PREFIX = "DOCKER_REGISTRY_" -TRTLLM_BASE_IMAGE = "baseten/briton-server:5fa9436e_v0.0.8" +TRTLLM_BASE_IMAGE = "baseten/briton-server:5fa9436e_v0.0.9" TRTLLM_PYTHON_EXECUTABLE = "/usr/bin/python3" BASE_TRTLLM_REQUIREMENTS = [ "grpcio==1.62.3", diff --git a/truss/templates/trtllm-briton/packages/briton_pb2.py b/truss/templates/trtllm-briton/packages/briton_pb2.py index a248e06fb..086db0dae 100644 --- a/truss/templates/trtllm-briton/packages/briton_pb2.py +++ b/truss/templates/trtllm-briton/packages/briton_pb2.py @@ -18,7 +18,7 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x0c\x62riton.proto\x12\x06\x62riton"r\n\x06Tensor\x12#\n\x05shape\x18\x01 \x01(\x0b\x32\x14.briton.Tensor.Shape\x12\x1f\n\x05\x64type\x18\x02 \x01(\x0e\x32\x10.briton.DataType\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c\x1a\x14\n\x05Shape\x12\x0b\n\x03\x64im\x18\x01 \x03(\x03"\xb4\x06\n\x10InferenceRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\x03\x12\x12\n\ninput_text\x18\x02 \x01(\t\x12\x11\n\tinput_ids\x18\x03 \x03(\x05\x12\x1f\n\x12request_output_len\x18\x05 \x01(\rH\x00\x88\x01\x01\x12\x13\n\x06\x65nd_id\x18\x06 \x01(\rH\x01\x88\x01\x01\x12\x13\n\x06pad_id\x18\x07 \x01(\rH\x02\x88\x01\x01\x12\x17\n\nbeam_width\x18\n \x01(\rH\x03\x88\x01\x01\x12\x18\n\x0btemperature\x18\x0b \x01(\x02H\x04\x88\x01\x01\x12\x1a\n\rruntime_top_k\x18\x0c \x01(\rH\x05\x88\x01\x01\x12\x1a\n\rruntime_top_p\x18\r \x01(\x02H\x06\x88\x01\x01\x12\x18\n\x0blen_penalty\x18\x0e \x01(\x02H\x07\x88\x01\x01\x12\x1f\n\x12repetition_penalty\x18\x0f \x01(\x02H\x08\x88\x01\x01\x12\x1d\n\x10presence_penalty\x18\x10 \x01(\x02H\t\x88\x01\x01\x12\x11\n\tbad_words\x18\x11 \x03(\t\x12\x12\n\nstop_words\x18\x12 \x03(\t\x12\x19\n\x0clora_task_id\x18\x13 \x01(\x04H\n\x88\x01\x01\x12)\n\x0clora_weights\x18\x14 \x01(\x0b\x32\x0e.briton.TensorH\x0b\x88\x01\x01\x12(\n\x0blora_config\x18\x15 \x01(\x0b\x32\x0e.briton.TensorH\x0c\x88\x01\x01\x12\x18\n\x0brandom_seed\x18\x16 \x01(\x03H\r\x88\x01\x01\x12\x1f\n\x12output_schema_hash\x18\x17 
\x01(\tH\x0e\x88\x01\x01\x42\x15\n\x13_request_output_lenB\t\n\x07_end_idB\t\n\x07_pad_idB\r\n\x0b_beam_widthB\x0e\n\x0c_temperatureB\x10\n\x0e_runtime_top_kB\x10\n\x0e_runtime_top_pB\x0e\n\x0c_len_penaltyB\x15\n\x13_repetition_penaltyB\x13\n\x11_presence_penaltyB\x0f\n\r_lora_task_idB\x0f\n\r_lora_weightsB\x0e\n\x0c_lora_configB\x0e\n\x0c_random_seedB\x15\n\x13_output_schema_hash"R\n\x13InferenceAnswerPart\x12\x12\n\nrequest_id\x18\x01 \x01(\x03\x12\x13\n\x0boutput_text\x18\x02 \x01(\t\x12\x12\n\noutput_ids\x18\x03 \x03(\x05"\xa6\x08\n\x0c\x42ritonConfig\x12\x13\n\x0b\x65ngine_path\x18\x01 \x01(\t\x12\x14\n\x0chf_tokenizer\x18\x02 \x01(\t\x12N\n\x16\x62\x61tch_scheduler_policy\x18\x05 \x01(\x0e\x32).briton.BritonConfig.BatchSchedulerPolicyH\x00\x88\x01\x01\x12\x1f\n\x12\x65nable_trt_overlap\x18\x06 \x01(\x08H\x01\x88\x01\x01\x12)\n\x1cmax_tokens_in_paged_kv_cache\x18\n \x01(\x04H\x02\x88\x01\x01\x12+\n\x1ekv_cache_free_gpu_mem_fraction\x18\x0b \x01(\x02H\x03\x88\x01\x01\x12!\n\x14medusa_decoding_mode\x18\x0c \x01(\x08H\x04\x88\x01\x01\x12#\n\x16\x65nable_chunked_context\x18\r \x01(\x08H\x05\x88\x01\x01\x12"\n\x15\x65nable_kv_cache_reuse\x18\x0e \x01(\x08H\x06\x88\x01\x01\x12\'\n\x1akv_cache_host_memory_bytes\x18\x0f \x01(\x04H\x07\x88\x01\x01\x12(\n\x1blora_cache_max_adapter_size\x18\x10 \x01(\x04H\x08\x88\x01\x01\x12,\n\x1flora_cache_optimal_adapter_size\x18\x11 \x01(\x04H\t\x88\x01\x01\x12+\n\x1elora_cache_gpu_memory_fraction\x18\x12 \x01(\x02H\n\x88\x01\x01\x12)\n\x1clora_cache_host_memory_bytes\x18\x13 \x01(\x04H\x0b\x88\x01\x01\x12\x1a\n\rfsm_cache_dir\x18\x14 \x01(\tH\x0c\x88\x01\x01"D\n\x14\x42\x61tchSchedulerPolicy\x12\x13\n\x0fMAX_UTILIZATION\x10\x00\x12\x17\n\x13GUARANTEED_NO_EVICT\x10\x01\x42\x19\n\x17_batch_scheduler_policyB\x15\n\x13_enable_trt_overlapB\x1f\n\x1d_max_tokens_in_paged_kv_cacheB!\n\x1f_kv_cache_free_gpu_mem_fractionB\x17\n\x15_medusa_decoding_modeB\x19\n\x17_enable_chunked_contextB\x18\n\x16_enable_kv_cache_reuseB\x1d\n\x1b_kv_cache_host_memory_bytesB\x1e\n\x1c_lora_cache_max_adapter_sizeB"\n _lora_cache_optimal_adapter_sizeB!\n\x1f_lora_cache_gpu_memory_fractionB\x1f\n\x1d_lora_cache_host_memory_bytesB\x10\n\x0e_fsm_cache_dir"\x98\x01\n\x10TokenToNextState\x12K\n\x13token_to_next_state\x18\x01 \x03(\x0b\x32..briton.TokenToNextState.TokenToNextStateEntry\x1a\x37\n\x15TokenToNextStateEntry\x12\x0b\n\x03key\x18\x01 \x01(\x05\x12\r\n\x05value\x18\x02 \x01(\x05:\x02\x38\x01"\xfb\x01\n\x0eStatesToTokens\x12\x44\n\x10states_to_tokens\x18\x01 \x03(\x0b\x32*.briton.StatesToTokens.StatesToTokensEntry\x12\x17\n\nvocab_size\x18\x02 \x01(\x05H\x00\x88\x01\x01\x12\x19\n\x0c\x65os_token_id\x18\x03 \x01(\x05H\x01\x88\x01\x01\x1aO\n\x13StatesToTokensEntry\x12\x0b\n\x03key\x18\x01 \x01(\x05\x12\'\n\x05value\x18\x02 \x01(\x0b\x32\x18.briton.TokenToNextState:\x02\x38\x01\x42\r\n\x0b_vocab_sizeB\x0f\n\r_eos_token_id*\xa8\x01\n\x08\x44\x61taType\x12\x0e\n\nDT_INVALID\x10\x00\x12\x0b\n\x07\x44T_INT4\x10\x01\x12\x0b\n\x07\x44T_INT8\x10\x02\x12\x0c\n\x08\x44T_UINT8\x10\x03\x12\x0c\n\x08\x44T_INT32\x10\x04\x12\x0c\n\x08\x44T_INT64\x10\x05\x12\x0e\n\nDT_FLOAT16\x10\n\x12\x0f\n\x0b\x44T_BFLOAT16\x10\x0b\x12\x0e\n\nDT_FLOAT32\x10\x0c\x12\n\n\x06\x44T_FP8\x10\r\x12\x0b\n\x07\x44T_BOOL\x10\x14\x32L\n\x06\x42riton\x12\x42\n\x05Infer\x12\x18.briton.InferenceRequest\x1a\x1b.briton.InferenceAnswerPart"\x00\x30\x01\x62\x06proto3' + b'\n\x0c\x62riton.proto\x12\x06\x62riton"r\n\x06Tensor\x12#\n\x05shape\x18\x01 \x01(\x0b\x32\x14.briton.Tensor.Shape\x12\x1f\n\x05\x64type\x18\x02 
\x01(\x0e\x32\x10.briton.DataType\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c\x1a\x14\n\x05Shape\x12\x0b\n\x03\x64im\x18\x01 \x03(\x03"\x82\x07\n\x10InferenceRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\x03\x12\x12\n\ninput_text\x18\x02 \x01(\t\x12\x11\n\tinput_ids\x18\x03 \x03(\x05\x12\x1f\n\x12request_output_len\x18\x05 \x01(\rH\x00\x88\x01\x01\x12\x13\n\x06\x65nd_id\x18\x06 \x01(\rH\x01\x88\x01\x01\x12\x13\n\x06pad_id\x18\x07 \x01(\rH\x02\x88\x01\x01\x12\x17\n\nbeam_width\x18\n \x01(\rH\x03\x88\x01\x01\x12\x18\n\x0btemperature\x18\x0b \x01(\x02H\x04\x88\x01\x01\x12\x1a\n\rruntime_top_k\x18\x0c \x01(\rH\x05\x88\x01\x01\x12\x1a\n\rruntime_top_p\x18\r \x01(\x02H\x06\x88\x01\x01\x12\x18\n\x0blen_penalty\x18\x0e \x01(\x02H\x07\x88\x01\x01\x12\x1f\n\x12repetition_penalty\x18\x0f \x01(\x02H\x08\x88\x01\x01\x12\x1d\n\x10presence_penalty\x18\x10 \x01(\x02H\t\x88\x01\x01\x12\x11\n\tbad_words\x18\x11 \x03(\t\x12\x12\n\nstop_words\x18\x12 \x03(\t\x12\x19\n\x0clora_task_id\x18\x13 \x01(\x04H\n\x88\x01\x01\x12)\n\x0clora_weights\x18\x14 \x01(\x0b\x32\x0e.briton.TensorH\x0b\x88\x01\x01\x12(\n\x0blora_config\x18\x15 \x01(\x0b\x32\x0e.briton.TensorH\x0c\x88\x01\x01\x12\x18\n\x0brandom_seed\x18\x16 \x01(\x03H\r\x88\x01\x01\x12\x1f\n\x12output_schema_hash\x18\x17 \x01(\tH\x0e\x88\x01\x01\x12\x15\n\x08tools_id\x18\x18 \x01(\rH\x0f\x88\x01\x01\x12\x18\n\x0b\x66orce_tools\x18\x19 \x01(\x08H\x10\x88\x01\x01\x42\x15\n\x13_request_output_lenB\t\n\x07_end_idB\t\n\x07_pad_idB\r\n\x0b_beam_widthB\x0e\n\x0c_temperatureB\x10\n\x0e_runtime_top_kB\x10\n\x0e_runtime_top_pB\x0e\n\x0c_len_penaltyB\x15\n\x13_repetition_penaltyB\x13\n\x11_presence_penaltyB\x0f\n\r_lora_task_idB\x0f\n\r_lora_weightsB\x0e\n\x0c_lora_configB\x0e\n\x0c_random_seedB\x15\n\x13_output_schema_hashB\x0b\n\t_tools_idB\x0e\n\x0c_force_tools"R\n\x13InferenceAnswerPart\x12\x12\n\nrequest_id\x18\x01 \x01(\x03\x12\x13\n\x0boutput_text\x18\x02 \x01(\t\x12\x12\n\noutput_ids\x18\x03 \x03(\x05"\xa6\x08\n\x0c\x42ritonConfig\x12\x13\n\x0b\x65ngine_path\x18\x01 \x01(\t\x12\x14\n\x0chf_tokenizer\x18\x02 \x01(\t\x12N\n\x16\x62\x61tch_scheduler_policy\x18\x05 \x01(\x0e\x32).briton.BritonConfig.BatchSchedulerPolicyH\x00\x88\x01\x01\x12\x1f\n\x12\x65nable_trt_overlap\x18\x06 \x01(\x08H\x01\x88\x01\x01\x12)\n\x1cmax_tokens_in_paged_kv_cache\x18\n \x01(\x04H\x02\x88\x01\x01\x12+\n\x1ekv_cache_free_gpu_mem_fraction\x18\x0b \x01(\x02H\x03\x88\x01\x01\x12!\n\x14medusa_decoding_mode\x18\x0c \x01(\x08H\x04\x88\x01\x01\x12#\n\x16\x65nable_chunked_context\x18\r \x01(\x08H\x05\x88\x01\x01\x12"\n\x15\x65nable_kv_cache_reuse\x18\x0e \x01(\x08H\x06\x88\x01\x01\x12\'\n\x1akv_cache_host_memory_bytes\x18\x0f \x01(\x04H\x07\x88\x01\x01\x12(\n\x1blora_cache_max_adapter_size\x18\x10 \x01(\x04H\x08\x88\x01\x01\x12,\n\x1flora_cache_optimal_adapter_size\x18\x11 \x01(\x04H\t\x88\x01\x01\x12+\n\x1elora_cache_gpu_memory_fraction\x18\x12 \x01(\x02H\n\x88\x01\x01\x12)\n\x1clora_cache_host_memory_bytes\x18\x13 \x01(\x04H\x0b\x88\x01\x01\x12\x1a\n\rfsm_cache_dir\x18\x14 \x01(\tH\x0c\x88\x01\x01"D\n\x14\x42\x61tchSchedulerPolicy\x12\x13\n\x0fMAX_UTILIZATION\x10\x00\x12\x17\n\x13GUARANTEED_NO_EVICT\x10\x01\x42\x19\n\x17_batch_scheduler_policyB\x15\n\x13_enable_trt_overlapB\x1f\n\x1d_max_tokens_in_paged_kv_cacheB!\n\x1f_kv_cache_free_gpu_mem_fractionB\x17\n\x15_medusa_decoding_modeB\x19\n\x17_enable_chunked_contextB\x18\n\x16_enable_kv_cache_reuseB\x1d\n\x1b_kv_cache_host_memory_bytesB\x1e\n\x1c_lora_cache_max_adapter_sizeB"\n 
_lora_cache_optimal_adapter_sizeB!\n\x1f_lora_cache_gpu_memory_fractionB\x1f\n\x1d_lora_cache_host_memory_bytesB\x10\n\x0e_fsm_cache_dir"\x98\x01\n\x10TokenToNextState\x12K\n\x13token_to_next_state\x18\x01 \x03(\x0b\x32..briton.TokenToNextState.TokenToNextStateEntry\x1a\x37\n\x15TokenToNextStateEntry\x12\x0b\n\x03key\x18\x01 \x01(\x05\x12\r\n\x05value\x18\x02 \x01(\x05:\x02\x38\x01"\xfb\x01\n\x0eStatesToTokens\x12\x44\n\x10states_to_tokens\x18\x01 \x03(\x0b\x32*.briton.StatesToTokens.StatesToTokensEntry\x12\x17\n\nvocab_size\x18\x02 \x01(\x05H\x00\x88\x01\x01\x12\x19\n\x0c\x65os_token_id\x18\x03 \x01(\x05H\x01\x88\x01\x01\x1aO\n\x13StatesToTokensEntry\x12\x0b\n\x03key\x18\x01 \x01(\x05\x12\'\n\x05value\x18\x02 \x01(\x0b\x32\x18.briton.TokenToNextState:\x02\x38\x01\x42\r\n\x0b_vocab_sizeB\x0f\n\r_eos_token_id*\xa8\x01\n\x08\x44\x61taType\x12\x0e\n\nDT_INVALID\x10\x00\x12\x0b\n\x07\x44T_INT4\x10\x01\x12\x0b\n\x07\x44T_INT8\x10\x02\x12\x0c\n\x08\x44T_UINT8\x10\x03\x12\x0c\n\x08\x44T_INT32\x10\x04\x12\x0c\n\x08\x44T_INT64\x10\x05\x12\x0e\n\nDT_FLOAT16\x10\n\x12\x0f\n\x0b\x44T_BFLOAT16\x10\x0b\x12\x0e\n\nDT_FLOAT32\x10\x0c\x12\n\n\x06\x44T_FP8\x10\r\x12\x0b\n\x07\x44T_BOOL\x10\x14\x32L\n\x06\x42riton\x12\x42\n\x05Infer\x12\x18.briton.InferenceRequest\x1a\x1b.briton.InferenceAnswerPart"\x00\x30\x01\x62\x06proto3' ) _globals = globals() @@ -30,28 +30,28 @@ _globals["_TOKENTONEXTSTATE_TOKENTONEXTSTATEENTRY"]._serialized_options = b"8\001" _globals["_STATESTOTOKENS_STATESTOTOKENSENTRY"]._options = None _globals["_STATESTOTOKENS_STATESTOTOKENSENTRY"]._serialized_options = b"8\001" - _globals["_DATATYPE"]._serialized_start = 2522 - _globals["_DATATYPE"]._serialized_end = 2690 + _globals["_DATATYPE"]._serialized_start = 2600 + _globals["_DATATYPE"]._serialized_end = 2768 _globals["_TENSOR"]._serialized_start = 24 _globals["_TENSOR"]._serialized_end = 138 _globals["_TENSOR_SHAPE"]._serialized_start = 118 _globals["_TENSOR_SHAPE"]._serialized_end = 138 _globals["_INFERENCEREQUEST"]._serialized_start = 141 - _globals["_INFERENCEREQUEST"]._serialized_end = 961 - _globals["_INFERENCEANSWERPART"]._serialized_start = 963 - _globals["_INFERENCEANSWERPART"]._serialized_end = 1045 - _globals["_BRITONCONFIG"]._serialized_start = 1048 - _globals["_BRITONCONFIG"]._serialized_end = 2110 - _globals["_BRITONCONFIG_BATCHSCHEDULERPOLICY"]._serialized_start = 1661 - _globals["_BRITONCONFIG_BATCHSCHEDULERPOLICY"]._serialized_end = 1729 - _globals["_TOKENTONEXTSTATE"]._serialized_start = 2113 - _globals["_TOKENTONEXTSTATE"]._serialized_end = 2265 - _globals["_TOKENTONEXTSTATE_TOKENTONEXTSTATEENTRY"]._serialized_start = 2210 - _globals["_TOKENTONEXTSTATE_TOKENTONEXTSTATEENTRY"]._serialized_end = 2265 - _globals["_STATESTOTOKENS"]._serialized_start = 2268 - _globals["_STATESTOTOKENS"]._serialized_end = 2519 - _globals["_STATESTOTOKENS_STATESTOTOKENSENTRY"]._serialized_start = 2408 - _globals["_STATESTOTOKENS_STATESTOTOKENSENTRY"]._serialized_end = 2487 - _globals["_BRITON"]._serialized_start = 2692 - _globals["_BRITON"]._serialized_end = 2768 + _globals["_INFERENCEREQUEST"]._serialized_end = 1039 + _globals["_INFERENCEANSWERPART"]._serialized_start = 1041 + _globals["_INFERENCEANSWERPART"]._serialized_end = 1123 + _globals["_BRITONCONFIG"]._serialized_start = 1126 + _globals["_BRITONCONFIG"]._serialized_end = 2188 + _globals["_BRITONCONFIG_BATCHSCHEDULERPOLICY"]._serialized_start = 1739 + _globals["_BRITONCONFIG_BATCHSCHEDULERPOLICY"]._serialized_end = 1807 + _globals["_TOKENTONEXTSTATE"]._serialized_start = 2191 + 
_globals["_TOKENTONEXTSTATE"]._serialized_end = 2343 + _globals["_TOKENTONEXTSTATE_TOKENTONEXTSTATEENTRY"]._serialized_start = 2288 + _globals["_TOKENTONEXTSTATE_TOKENTONEXTSTATEENTRY"]._serialized_end = 2343 + _globals["_STATESTOTOKENS"]._serialized_start = 2346 + _globals["_STATESTOTOKENS"]._serialized_end = 2597 + _globals["_STATESTOTOKENS_STATESTOTOKENSENTRY"]._serialized_start = 2486 + _globals["_STATESTOTOKENS_STATESTOTOKENSENTRY"]._serialized_end = 2565 + _globals["_BRITON"]._serialized_start = 2770 + _globals["_BRITON"]._serialized_end = 2846 # @@protoc_insertion_point(module_scope) diff --git a/truss/templates/trtllm-briton/src/engine.py b/truss/templates/trtllm-briton/src/engine.py index 03e73f309..01b8dfde1 100644 --- a/truss/templates/trtllm-briton/src/engine.py +++ b/truss/templates/trtllm-briton/src/engine.py @@ -1,5 +1,9 @@ +import asyncio +import concurrent.futures +import fcntl import hashlib import json +import multiprocessing import os import signal import socket @@ -38,6 +42,16 @@ # Use a directory that can be picked up by baseten-fs FSM_CACHE_DIR = "/cache/model/fsm_cache" +TOOL_CALL_IDS = { + "llama": 128010, + "mistral": 5, +} + +TOOL_CALL_TOKENS = { + "llama": "<|python_tag|>", + "mistral": "[TOOL_CALLS]", +} + def is_port_available(port, host="localhost"): try: @@ -83,6 +97,7 @@ def __init__(self, **kwargs): truss_trtllm_build_config = TrussTRTLLMBuildConfiguration( **trtllm_config.get("build") ) + self._base_model = truss_trtllm_build_config.base_model self._tp_count = truss_trtllm_build_config.tensor_parallel_count self._tokenizer_repository = ( truss_trtllm_build_config.checkpoint_repository.repo @@ -103,6 +118,11 @@ def __init__(self, **kwargs): self._max_input_len = truss_trtllm_build_config.max_input_len self._max_beam_width = truss_trtllm_build_config.max_beam_width + # TODO(@bdubayah): configure this based on CPU. But os.cpu_count() returns the + # number of CPUs for the entire node, not just the container. 
+ self._max_fsm_workers = 10 + print(f"Using {self._max_fsm_workers} workers for FSM schema generation") + def load(self): if self._loaded: return @@ -111,7 +131,9 @@ def load(self): self._tokenizer_repository, token=self._hf_token ) - self._fsm_cache = FsmCache(Path(FSM_CACHE_DIR), self._tokenizer) + self._fsm_cache = FsmCache( + Path(FSM_CACHE_DIR), self._tokenizer, self._max_fsm_workers + ) # Start engine config_str = f""" @@ -189,14 +211,84 @@ async def predict(self, model_input): channel = grpc.aio.insecure_channel(f"localhost:{BRITON_PORT}") self._stub = briton_pb2_grpc.BritonStub(channel) + # TODO(@bdubayah): refactor into smaller functions function_calling_schema = None - tools = model_input.get("tools", None) + tools = model_input.get("tools") + tool_choice = model_input.get("tool_choice") + force_tools = None + if tool_choice is not None: + if not ( + tool_choice in ["none", "required", "auto"] + or isinstance(tool_choice, dict) + ): + raise HTTPException( + status_code=400, + detail="tool_choice must be 'none', 'required', 'auto', or an object of the form {'type': 'function', 'function': {'name': 'function_name'}}.", + ) + if tool_choice == "none": + tools = None + tool_choice = None + elif tool_choice == "required": + if tools is None: + raise HTTPException( + status_code=400, + detail="tool_choice is 'required' but no tools provided.", + ) + force_tools = True if tools is not None: + if model_input.get("response_format") is not None: + raise HTTPException( + status_code=400, + detail="response_format is not allowed when tools are provided, unless tool_choice is 'none'.", + ) + tool_schemas = { + tool["function"]["name"]: create_tool_schema(tool) for tool in tools + } + if isinstance(tool_choice, dict): + if tool_choice.get("type") != "function": + raise HTTPException( + status_code=400, detail="tool_choice['type'] must be function." + ) + if tool_choice.get("function") is None: + raise HTTPException( + status_code=400, detail="tool_choice['function'] required." + ) + if not isinstance(tool_choice["function"], dict): + raise HTTPException( + status_code=400, + detail="tool_choice['function'] must be an object.", + ) + if tool_choice["function"].get("name") is None: + raise HTTPException( + status_code=400, + detail="tool_choice['function']['name'] required.", + ) + if tool_choice["function"]["name"] not in tool_schemas: + raise HTTPException( + status_code=400, + detail=f"Tool choice function {tool_choice['function']['name']} not in tools.", + ) + tool_schemas = { + tool_choice["function"]["name"]: tool_schemas[ + tool_choice["function"]["name"] + ] + } + force_tools = True + elif tool_choice is None or tool_choice == "auto": + force_tools = False function_calling_schema = { - "anyOf": [create_tool_schema(tool) for tool in tools], + "type": "array", + "items": { + "anyOf": list(tool_schemas.values()), + }, } prompt = model_input.get("prompt", None) + if prompt is not None and tools is not None: + raise HTTPException( + status_code=400, + detail="tools can only be provided in chat mode. 
Please set messages instead of prompt, remove tools, or set tool_choice to 'none'.", + ) if prompt is None and "messages" in model_input: messages = model_input.pop("messages") prompt = self._tokenizer.apply_chat_template( @@ -227,15 +319,18 @@ async def predict(self, model_input): schema_hash = None try: schema_hash = ( - self._fsm_cache.add_schema(function_calling_schema) + await self._fsm_cache.add_schema(function_calling_schema) if function_calling_schema is not None - else self._fsm_cache.add_schema_from_input(model_input) + else await self._fsm_cache.add_schema_from_input(model_input) ) # If the input schema is invalid, we should return a 400 except NotImplementedError as ex: raise HTTPException(status_code=400, detail=str(ex)) if schema_hash is not None: request.output_schema_hash = schema_hash + if force_tools is not None: + request.tools_id = TOOL_CALL_IDS[self._base_model] + request.force_tools = force_tools set_briton_request_fields_from_model_input(model_input, request) for words in ["bad_words", "stop_words"]: if words in model_input: @@ -250,11 +345,14 @@ async def generate(): if hasattr(self._tokenizer, "eos_token") else None ) + tool_call_token = TOOL_CALL_TOKENS.get(self._base_model) async for response in resp_iter: + output_text = response.output_text + if tool_call_token: + output_text = output_text.removeprefix(tool_call_token) if eos_token: - yield response.output_text.removesuffix(eos_token) - else: - yield response.output_text + output_text = output_text.removesuffix(eos_token) + yield output_text async def build_response(): eos_token = ( @@ -262,13 +360,15 @@ async def build_response(): if hasattr(self._tokenizer, "eos_token") else None ) + tool_call_token = TOOL_CALL_TOKENS.get(self._base_model) full_text = "" async for delta in resp_iter: full_text += delta.output_text + if tool_call_token: + full_text = full_text.removeprefix(tool_call_token) if eos_token: - return full_text.removesuffix(eos_token) - else: - return full_text + full_text = full_text.removesuffix(eos_token) + return full_text try: if model_input.get("stream", True): @@ -306,47 +406,86 @@ def create_tool_schema(tool_json: Dict[str, Any]) -> Dict[str, Any]: } +outlines_tokenizer = None + + +def worker(vocab_size: int, end_id: int, schema: Dict[str, Any], output_path: Path): + logits_processor = JSONLogitsProcessor(schema, outlines_tokenizer) + guide = logits_processor.fsm + states_to_tokens = {} + for state, token_to_next_state in guide.states_to_token_maps.items(): + states_to_tokens[state] = briton_pb2.TokenToNextState( # type: ignore[attr-defined] + token_to_next_state=token_to_next_state + ) + states_to_tokens_pb = briton_pb2.StatesToTokens( # type: ignore[attr-defined] + states_to_tokens=states_to_tokens, + vocab_size=vocab_size, + eos_token_id=end_id, + ) + if not output_path.exists(): + try: + fd = os.open(output_path, os.O_CREAT | os.O_EXCL | os.O_WRONLY) + with os.fdopen(fd, "wb") as f: + fcntl.flock(f, fcntl.LOCK_EX) + f.write(states_to_tokens_pb.SerializeToString()) + fcntl.flock(f, fcntl.LOCK_UN) + except FileExistsError: + pass + + +def dummy_task(): + pass + + class FsmCache: - def __init__(self, cache_dir: Path, tokenizer: AutoTokenizer): + def __init__(self, cache_dir: Path, tokenizer: AutoTokenizer, max_workers: int): self._cache_dir = cache_dir if not self._cache_dir.exists(): self._cache_dir.mkdir(parents=True, exist_ok=True) self._cache = set(f.name for f in self._cache_dir.iterdir() if f.is_file()) + self._lock = threading.Lock() self._tokenizer = tokenizer - def add_schema(self, 
schema: Dict[str, Any]) -> str: + # Concurrent FSM generation initialization + # Make sure we fork because (1) it's faster and (2) it seems that spawning + # ends up being sequential + multiprocessing.set_start_method("fork", force=True) + global outlines_tokenizer + outlines_tokenizer = TransformerTokenizer(tokenizer) + # This is very important. The first time JSONLogitsProcessor is called, some library-wide + # initializations are done in memory (that take 5s). By doing it before we fork, we avoid paying + # that cost for each forked process. + _ = JSONLogitsProcessor({"properties": {}}, outlines_tokenizer) + self._executor = concurrent.futures.ProcessPoolExecutor(max_workers=max_workers) + # We must create all processes BEFORE the GRPC python client is started to avoid errors + # forking from the process GRPC is running in + for _ in range(max_workers): + self._executor.submit(dummy_task) + + async def add_schema(self, schema: Dict[str, Any]) -> str: schema_str = json.dumps(schema) schema_hash = hashlib.sha256(schema_str.encode()).hexdigest() if schema_hash not in self._cache: - fsm = self._create_fsm(schema) - (self._cache_dir / schema_hash).write_bytes(fsm.SerializeToString()) - self._cache.add(schema_hash) + loop = asyncio.get_running_loop() + await loop.run_in_executor( + self._executor, + worker, + len(self._tokenizer.vocab), + self._tokenizer.eos_token_id, + schema, + self._cache_dir / schema_hash, + ) + with self._lock: + self._cache.add(schema_hash) return schema_hash - def add_schema_from_input(self, model_input: Dict[str, Any]) -> Optional[str]: + async def add_schema_from_input(self, model_input: Dict[str, Any]) -> Optional[str]: schema_hash = None schema = self._extract_schema(model_input) if schema is not None: - schema_hash = self.add_schema(schema) + schema_hash = await self.add_schema(schema) return schema_hash - def _create_fsm(self, schema: Dict[str, Any]) -> briton_pb2.StatesToTokens: # type: ignore[name-defined] - outlines_tokenizer = TransformerTokenizer(self._tokenizer) - logits_processor = JSONLogitsProcessor(schema, outlines_tokenizer) - guide = logits_processor.fsm - - states_to_tokens = {} - for state, token_to_next_state in guide.states_to_token_maps.items(): - states_to_tokens[state] = briton_pb2.TokenToNextState( # type: ignore[attr-defined] - token_to_next_state=token_to_next_state - ) - states_to_tokens_pb = briton_pb2.StatesToTokens( # type: ignore[attr-defined] - states_to_tokens=states_to_tokens, - vocab_size=len(self._tokenizer.vocab), - eos_token_id=self._tokenizer.eos_token_id, - ) - return states_to_tokens_pb - @staticmethod def _extract_schema(model_input: Dict[str, Any]) -> Optional[Dict[str, Any]]: if "response_format" not in model_input: From 7f6f1d24637ca49ad1a724377b9ec2df44a897f2 Mon Sep 17 00:00:00 2001 From: basetenbot <96544894+basetenbot@users.noreply.github.com> Date: Wed, 11 Sep 2024 21:52:10 +0000 Subject: [PATCH 4/4] Bump version to 0.9.35 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index f07866187..f39032347 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "truss" -version = "0.9.35rc1" +version = "0.9.35" description = "A seamless bridge from model development to model delivery" license = "MIT" readme = "README.md"
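
For reference, a minimal client-side sketch of consuming the error headers added in
patch 2, assuming a truss model server is running locally (the URL below is a
placeholder); the header names and the 04/600/700 values come from
truss/templates/server/common/errors.py in this series.

    import requests

    # Placeholder URL for a locally running truss server (e.g. `truss docker run`).
    INFERENCE_URL = "http://localhost:8090/v1/models/model:predict"

    resp = requests.post(INFERENCE_URL, json={"bad_key": "value"})
    if resp.status_code != 200:
        # "04" identifies the truss server; "700" = client error (e.g. input
        # validation), "600" = downstream/user-code error, "500" = unexpected.
        source = resp.headers.get("x-baseten-error-source")
        code = resp.headers.get("x-baseten-error-code")
        print(f"error from service {source} (code {code}): {resp.json()['error']}")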