Skip to content

Commit

Permalink
bump briton package to 0.3.13.dev3 (#1286)
Browse files Browse the repository at this point in the history
* bump briton package to 0.3.13.dev3 and truss version to 0.9.57

* add default runtime

* 0.9.57rc0

* fix existing runtime case

* 0.9.57rc1

* use briton==0.3.13.dev4
  • Loading branch information
joostinyi authored Dec 16, 2024
1 parent a5bde43 commit e61c2f0
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 9 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "truss"
version = "0.9.56rc3"
version = "0.9.57rc1"
description = "A seamless bridge from model development to model delivery"
license = "MIT"
readme = "README.md"
Expand Down
2 changes: 1 addition & 1 deletion truss/base/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@
TRTLLM_SPEC_DEC_DRAFT_MODEL_NAME = "draft"
TRTLLM_BASE_IMAGE = "baseten/briton-server:v0.13.0-4fd8a10-5e5c3d7"
TRTLLM_PYTHON_EXECUTABLE = "/usr/bin/python3"
BASE_TRTLLM_REQUIREMENTS = ["briton==0.3.12.dev8"]
BASE_TRTLLM_REQUIREMENTS = ["briton==0.3.13.dev4"]
AUDIO_MODEL_TRTLLM_REQUIREMENTS = [
"--extra-index-url https://pypi.nvidia.com",
"tensorrt_cu12_bindings==10.2.0.post1",
Expand Down
19 changes: 12 additions & 7 deletions truss/base/trt_llm_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,13 +221,18 @@ def migrate_runtime_fields(cls, data: Any) -> Any:
f"Found extra fields {list(extra_runtime_fields.keys())} in build configuration, unspecified runtime fields will be configured using these values."
" This configuration of deprecated fields is scheduled for removal, please upgrade to the latest truss version and update configs according to https://docs.baseten.co/performance/engine-builder-config."
)
data.get("runtime").update(
{
k: v
for k, v in extra_runtime_fields.items()
if k not in data.get("runtime")
}
)
if data.get("runtime"):
data.get("runtime").update(
{
k: v
for k, v in extra_runtime_fields.items()
if k not in data.get("runtime")
}
)
else:
data.update(
{"runtime": {k: v for k, v in extra_runtime_fields.items()}}
)
data.update({"build": valid_build_fields})
return data
return data
Expand Down
31 changes: 31 additions & 0 deletions truss/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -791,6 +791,37 @@ def trtllm_config(default_config) -> Dict[str, Any]:

@pytest.fixture
def deprecated_trtllm_config(default_config) -> Dict[str, Any]:
trtllm_config = default_config
trtllm_config["resources"] = {
"accelerator": Accelerator.L4.value,
"cpu": "1",
"memory": "24Gi",
"use_gpu": True,
}
trtllm_config["trt_llm"] = {
"build": {
"base_model": "llama",
"max_seq_len": 2048,
"max_batch_size": 512,
# start deprecated fields
"kv_cache_free_gpu_mem_fraction": 0.1,
"enable_chunked_context": True,
"batch_scheduler_policy": TrussTRTLLMBatchSchedulerPolicy.MAX_UTILIZATION.value,
"request_default_max_tokens": 10,
"total_token_limit": 50,
# end deprecated fields
"checkpoint_repository": {
"source": "HF",
"repo": "meta/llama4-500B",
},
"gather_all_token_logits": False,
},
}
return trtllm_config


@pytest.fixture
def deprecated_trtllm_config_with_runtime_existing(default_config) -> Dict[str, Any]:
trtllm_config = default_config
trtllm_config["resources"] = {
"accelerator": Accelerator.L4.value,
Expand Down
15 changes: 15 additions & 0 deletions truss/tests/trt_llm/test_trt_llm_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,21 @@ def test_trt_llm_configuration_init_and_migrate_deprecated_runtime_fields(
deprecated_trtllm_config,
):
trt_llm_config = TRTLLMConfiguration(**deprecated_trtllm_config["trt_llm"])
assert trt_llm_config.runtime.model_dump() == {
"kv_cache_free_gpu_mem_fraction": 0.1,
"enable_chunked_context": True,
"batch_scheduler_policy": TrussTRTLLMBatchSchedulerPolicy.MAX_UTILIZATION.value,
"request_default_max_tokens": 10,
"total_token_limit": 50,
}


def test_trt_llm_configuration_init_and_migrate_deprecated_runtime_fields_existing_runtime(
deprecated_trtllm_config_with_runtime_existing,
):
trt_llm_config = TRTLLMConfiguration(
**deprecated_trtllm_config_with_runtime_existing["trt_llm"]
)
assert trt_llm_config.runtime.model_dump() == {
"kv_cache_free_gpu_mem_fraction": 0.1,
"enable_chunked_context": True,
Expand Down

0 comments on commit e61c2f0

Please sign in to comment.