From e61c2f01c0309172623d61cf787c59f7c747742b Mon Sep 17 00:00:00 2001
From: joostinyi <63941848+joostinyi@users.noreply.github.com>
Date: Mon, 16 Dec 2024 11:12:43 -0800
Subject: [PATCH] bump briton package to 0.3.13.dev3 (#1286)

* bump briton package to 0.3.13.dev3 and truss version to 0.9.57

* add default runtime

* 0.9.57rc0

* fix existing runtime case

* 0.9.57rc1

* use briton==0.3.13.dev4
---
 pyproject.toml                             |  2 +-
 truss/base/constants.py                    |  2 +-
 truss/base/trt_llm_config.py               | 19 ++++++++-----
 truss/tests/conftest.py                    | 31 ++++++++++++++++++++++
 truss/tests/trt_llm/test_trt_llm_config.py | 15 +++++++++++
 5 files changed, 60 insertions(+), 9 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 2f5634499..c9a21f506 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "truss"
-version = "0.9.56rc3"
+version = "0.9.57rc1"
 description = "A seamless bridge from model development to model delivery"
 license = "MIT"
 readme = "README.md"
diff --git a/truss/base/constants.py b/truss/base/constants.py
index 0cce6fa13..ae295fa04 100644
--- a/truss/base/constants.py
+++ b/truss/base/constants.py
@@ -109,7 +109,7 @@ TRTLLM_SPEC_DEC_DRAFT_MODEL_NAME = "draft"
 
 TRTLLM_BASE_IMAGE = "baseten/briton-server:v0.13.0-4fd8a10-5e5c3d7"
 TRTLLM_PYTHON_EXECUTABLE = "/usr/bin/python3"
-BASE_TRTLLM_REQUIREMENTS = ["briton==0.3.12.dev8"]
+BASE_TRTLLM_REQUIREMENTS = ["briton==0.3.13.dev4"]
 AUDIO_MODEL_TRTLLM_REQUIREMENTS = [
     "--extra-index-url https://pypi.nvidia.com",
     "tensorrt_cu12_bindings==10.2.0.post1",
diff --git a/truss/base/trt_llm_config.py b/truss/base/trt_llm_config.py
index 194adcec7..1c6ce8d0d 100644
--- a/truss/base/trt_llm_config.py
+++ b/truss/base/trt_llm_config.py
@@ -221,13 +221,18 @@ def migrate_runtime_fields(cls, data: Any) -> Any:
                     f"Found extra fields {list(extra_runtime_fields.keys())} in build configuration, unspecified runtime fields will be configured using these values."
                     " This configuration of deprecated fields is scheduled for removal, please upgrade to the latest truss version and update configs according to https://docs.baseten.co/performance/engine-builder-config."
                 )
-                data.get("runtime").update(
-                    {
-                        k: v
-                        for k, v in extra_runtime_fields.items()
-                        if k not in data.get("runtime")
-                    }
-                )
+                if data.get("runtime"):
+                    data.get("runtime").update(
+                        {
+                            k: v
+                            for k, v in extra_runtime_fields.items()
+                            if k not in data.get("runtime")
+                        }
+                    )
+                else:
+                    data.update(
+                        {"runtime": {k: v for k, v in extra_runtime_fields.items()}}
+                    )
                 data.update({"build": valid_build_fields})
                 return data
         return data
diff --git a/truss/tests/conftest.py b/truss/tests/conftest.py
index e515ccf43..3892966c6 100644
--- a/truss/tests/conftest.py
+++ b/truss/tests/conftest.py
@@ -791,6 +791,37 @@ def trtllm_config(default_config) -> Dict[str, Any]:
 
 @pytest.fixture
 def deprecated_trtllm_config(default_config) -> Dict[str, Any]:
+    trtllm_config = default_config
+    trtllm_config["resources"] = {
+        "accelerator": Accelerator.L4.value,
+        "cpu": "1",
+        "memory": "24Gi",
+        "use_gpu": True,
+    }
+    trtllm_config["trt_llm"] = {
+        "build": {
+            "base_model": "llama",
+            "max_seq_len": 2048,
+            "max_batch_size": 512,
+            # start deprecated fields
+            "kv_cache_free_gpu_mem_fraction": 0.1,
+            "enable_chunked_context": True,
+            "batch_scheduler_policy": TrussTRTLLMBatchSchedulerPolicy.MAX_UTILIZATION.value,
+            "request_default_max_tokens": 10,
+            "total_token_limit": 50,
+            # end deprecated fields
+            "checkpoint_repository": {
+                "source": "HF",
+                "repo": "meta/llama4-500B",
+            },
+            "gather_all_token_logits": False,
+        },
+    }
+    return trtllm_config
+
+
+@pytest.fixture
+def deprecated_trtllm_config_with_runtime_existing(default_config) -> Dict[str, Any]:
     trtllm_config = default_config
     trtllm_config["resources"] = {
         "accelerator": Accelerator.L4.value,
diff --git a/truss/tests/trt_llm/test_trt_llm_config.py b/truss/tests/trt_llm/test_trt_llm_config.py
index e6a0578d1..852baefb3 100644
--- a/truss/tests/trt_llm/test_trt_llm_config.py
+++ b/truss/tests/trt_llm/test_trt_llm_config.py
@@ -15,6 +15,21 @@ def test_trt_llm_configuration_init_and_migrate_deprecated_runtime_fields(
     deprecated_trtllm_config,
 ):
     trt_llm_config = TRTLLMConfiguration(**deprecated_trtllm_config["trt_llm"])
+    assert trt_llm_config.runtime.model_dump() == {
+        "kv_cache_free_gpu_mem_fraction": 0.1,
+        "enable_chunked_context": True,
+        "batch_scheduler_policy": TrussTRTLLMBatchSchedulerPolicy.MAX_UTILIZATION.value,
+        "request_default_max_tokens": 10,
+        "total_token_limit": 50,
+    }
+
+
+def test_trt_llm_configuration_init_and_migrate_deprecated_runtime_fields_existing_runtime(
+    deprecated_trtllm_config_with_runtime_existing,
+):
+    trt_llm_config = TRTLLMConfiguration(
+        **deprecated_trtllm_config_with_runtime_existing["trt_llm"]
+    )
     assert trt_llm_config.runtime.model_dump() == {
         "kv_cache_free_gpu_mem_fraction": 0.1,
         "enable_chunked_context": True,
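
Note: below is a minimal usage sketch, not part of the patch, of the case the new else branch in migrate_runtime_fields covers: constructing a TRTLLMConfiguration from a config whose deprecated runtime fields sit under "build" with no "runtime" block at all, mirroring the deprecated_trtllm_config fixture above. Before this change, data.get("runtime").update(...) would presumably fail when "runtime" was absent; with the default-runtime handling it is created from the migrated values. The import path is assumed from the file touched above; the surrounding project setup is not shown in the patch.

    # Assumed sketch: deprecated runtime fields under "build", no "runtime" key.
    from truss.base.trt_llm_config import TRTLLMConfiguration

    deprecated_trt_llm = {
        "build": {
            "base_model": "llama",
            "max_seq_len": 2048,
            "max_batch_size": 512,
            # deprecated runtime fields living under "build"
            "kv_cache_free_gpu_mem_fraction": 0.1,
            "enable_chunked_context": True,
            "request_default_max_tokens": 10,
            "total_token_limit": 50,
            "checkpoint_repository": {"source": "HF", "repo": "meta/llama4-500B"},
            "gather_all_token_logits": False,
        },
        # note: no "runtime" entry here
    }

    config = TRTLLMConfiguration(**deprecated_trt_llm)
    # The migrated fields now appear under runtime, as asserted in the new tests.
    print(config.runtime.model_dump())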