Skip to content

Commit

Permalink
Merge pull request #1205 from basetenlabs/bump-version-0.9.45
Browse files Browse the repository at this point in the history
Release 0.9.45
  • Loading branch information
zhyncs authored Oct 29, 2024
2 parents 5a4626e + 9332603 commit a298392
Show file tree
Hide file tree
Showing 44 changed files with 891 additions and 1,062 deletions.
2 changes: 0 additions & 2 deletions .gitattributes

This file was deleted.

3 changes: 2 additions & 1 deletion .github/workflows/pr.yml
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,9 @@ jobs:
- name: Enforce acknowledgment in PR description
if: env.chains_docs_update_needed == 'true'
env:
DESCRIPTION: ${{ github.event.pull_request.body }}
run: |
DESCRIPTION="${{ github.event.pull_request.body }}"
if [[ "$DESCRIPTION" != *"UPDATE_DOCS=done"* && "$DESCRIPTION" != *"UPDATE_DOCS=not_needed"* ]]; then
echo "::error file=truss-chains/examples/::Chains examples were modified and ack not found in PR description. Verify whether docs need to be update (https://github.com/basetenlabs/docs.baseten.co/tree/main/chains) and add an ack tag `UPDATE_DOCS={done|not_needed}` to the PR description."
exit 1
Expand Down
1 change: 0 additions & 1 deletion .python-version

This file was deleted.

495 changes: 275 additions & 220 deletions poetry.lock

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "truss"
version = "0.9.44"
version = "0.9.45"
description = "A seamless bridge from model development to model delivery"
license = "MIT"
readme = "README.md"
Expand All @@ -27,6 +27,7 @@ packages = [
"Baseten" = "https://baseten.co"

[tool.poetry.dependencies]
aiofiles = "^24.1.0"
blake3 = "^0.3.3"
boto3 = "^1.34.85"
fastapi = ">=0.109.1"
Expand Down Expand Up @@ -96,6 +97,7 @@ pytest = "7.2.0"
pytest-cov = "^3.0.0"
types-PyYAML = "^6.0.12.12"
types-setuptools = "^69.0.0.0"
types-aiofiles = "^24.1.0.20240626"

[tool.poetry.scripts]
truss = 'truss.cli:truss_cli'
Expand Down
12 changes: 9 additions & 3 deletions truss-chains/truss_chains/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

import pydantic
from truss import truss_config
from truss.constants import PRODUCTION_ENVIRONMENT_NAME
from truss.remote import baseten as baseten_remote
from truss.remote import remote_cli, remote_factory

Expand Down Expand Up @@ -609,22 +610,27 @@ class PushOptions(SafeModelNonSerializable):
class PushOptionsBaseten(PushOptions):
remote_provider: baseten_remote.BasetenRemote
publish: bool
promote: bool
environment: Optional[str]

@classmethod
def create(
cls,
chain_name: str,
publish: bool,
promote: bool,
promote: Optional[bool],
only_generate_trusses: bool,
user_env: Mapping[str, str],
remote: Optional[str] = None,
environment: Optional[str] = None,
) -> "PushOptionsBaseten":
if not remote:
remote = remote_cli.inquire_remote_name(
remote_factory.RemoteFactory.get_available_config_names()
)
if promote and not environment:
environment = PRODUCTION_ENVIRONMENT_NAME
if environment:
publish = True
remote_provider = cast(
baseten_remote.BasetenRemote,
remote_factory.RemoteFactory.create(remote=remote),
Expand All @@ -633,9 +639,9 @@ def create(
remote_provider=remote_provider,
chain_name=chain_name,
publish=publish,
promote=promote,
only_generate_trusses=only_generate_trusses,
user_env=user_env,
environment=environment,
)


Expand Down
20 changes: 20 additions & 0 deletions truss-chains/truss_chains/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import pprint
import sys
import types
import warnings
from typing import (
Any,
Callable,
Expand Down Expand Up @@ -268,6 +269,25 @@ def _validate_and_describe_endpoint(
is_async = False
is_generator = inspect.isgeneratorfunction(endpoint_method)

if not is_async:
warnings.warn(
"`run_remote` must be an async (coroutine) function in future releases. "
"Replace `def run_remote(...` with `async def run_remote(...`. "
"Local testing and execution can be done with "
"`asyncio.run(my_chainlet.run_remote(...))`.\n"
"Note on concurrency: previously sync functions were run in threads by the "
"Truss server.\bn"
"For some frameworks this was **unsafe** (e.g. in torch the CUDA context "
"is not thread-safe).\n"
"Additionally, python threads hold the GIL and therefore might not give "
"actual throughput gains.\n"
"To achieve safe and performant concurrency, use framework-specific async "
"APIs (e.g. AsyncLLMEngine for vLLM) or generic async batching like such "
"as https://github.com/hussein-awala/async-batcher.",
DeprecationWarning,
stacklevel=1,
)

return definitions.EndpointAPIDescriptor(
input_args=input_args,
output_types=output_types,
Expand Down
3 changes: 3 additions & 0 deletions truss-chains/truss_chains/public_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ def push(
user_env: Optional[Mapping[str, str]] = None,
only_generate_trusses: bool = False,
remote: Optional[str] = None,
environment: Optional[str] = None,
) -> chains_remote.BasetenChainService:
"""
Deploys a chain remotely (with all dependent chainlets).
Expand All @@ -144,6 +145,7 @@ def push(
``/tmp/.chains_generated``.
remote: name of a remote config in `.trussrc`. If not provided, it will be
inquired.
environment: The name of an environment to promote deployment into.
Returns:
A chain service handle to the deployed chain.
Expand All @@ -156,6 +158,7 @@ def push(
user_env=user_env or {},
only_generate_trusses=only_generate_trusses,
remote=remote,
environment=environment,
)
service = chains_remote.push(entrypoint, options)
assert isinstance(service, chains_remote.BasetenChainService) # Per options above.
Expand Down
8 changes: 2 additions & 6 deletions truss-chains/truss_chains/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,19 +49,15 @@ def _push_to_baseten(
model_name = truss_handle.spec.config.model_name
assert model_name is not None
assert bool(_MODEL_NAME_RE.match(model_name))
if options.promote and not options.publish:
logging.info("`promote=True` overrides `publish` to `True`.")
logging.info(
f"Pushing chainlet `{model_name}` as a truss model on Baseten "
f"(publish={options.publish}, promote={options.promote})."
f"Pushing chainlet `{model_name}` as a truss model on Baseten (publish={options.publish})"
)
# Models must be trusted to use the API KEY secret.
service = options.remote_provider.push(
truss_handle,
model_name=model_name,
trusted=True,
publish=options.publish,
promote=options.promote,
origin=b10_types.ModelOrigin.CHAINS,
)
return cast(b10_service.BasetenService, service)
Expand Down Expand Up @@ -327,7 +323,7 @@ def _create_baseten_chain(
chain_name=baseten_options.chain_name,
chainlets=chainlet_data,
publish=baseten_options.publish,
promote=baseten_options.promote,
environment=baseten_options.environment,
)
return BasetenChainService(
baseten_options.chain_name,
Expand Down
2 changes: 1 addition & 1 deletion truss-chains/truss_chains/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def override_chainlet_to_service_metadata(
chainlet_to_service: Mapping[str, definitions.ServiceDescriptor],
):
# Override predict_urls in chainlet_to_service ServiceDescriptors if dynamic_chainlet_config exists
dynamic_chainlet_config_str = dynamic_config_resolver.get_dynamic_config_value(
dynamic_chainlet_config_str = dynamic_config_resolver.get_dynamic_config_value_sync(
definitions.DYNAMIC_CHAINLET_CONFIG_KEY
)
if dynamic_chainlet_config_str:
Expand Down
32 changes: 29 additions & 3 deletions truss/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,13 @@ def chains():
"""Subcommands for truss chains"""


def _make_chains_curl_snippet(run_remote_url: str) -> str:
def _make_chains_curl_snippet(run_remote_url: str, environment: Optional[str]) -> str:
if environment:
idx = run_remote_url.find("deployment")
if idx != -1:
run_remote_url = (
run_remote_url[:idx] + f"environments/{environment}/run_remote"
)
return (
f"curl -X POST '{run_remote_url}' \\\n"
' -H "Authorization: Api-Key $BASETEN_API_KEY" \\\n'
Expand Down Expand Up @@ -505,6 +511,15 @@ def _create_chains_table(service) -> Tuple[rich.table.Table, List[str]]:
default=False,
help="Replace production chainlets with newly deployed chainlets.",
)
@click.option(
"--environment",
type=str,
required=False,
help=(
"Deploy the chain as a published deployment to the specified environment."
"If specified, --publish is implied and the supplied value of --promote will be ignored."
),
)
@click.option(
"--wait/--no-wait",
type=bool,
Expand Down Expand Up @@ -557,6 +572,7 @@ def push_chain(
dryrun: bool,
user_env: Optional[str],
remote: Optional[str],
environment: Optional[str],
) -> None:
"""
Deploys a chain remotely.
Expand Down Expand Up @@ -597,6 +613,10 @@ def push_chain(
else:
user_env_parsed = {}

if promote and environment:
promote_warning = "`promote` flag and `environment` flag were both specified. Ignoring the value of `promote`"
console.print(promote_warning, style="yellow")

with framework.import_target(source, entrypoint) as entrypoint_cls:
chain_name = name or entrypoint_cls.__name__
options = chains_def.PushOptionsBaseten.create(
Expand All @@ -606,6 +626,7 @@ def push_chain(
only_generate_trusses=dryrun,
user_env=user_env_parsed,
remote=remote,
environment=environment,
)
service = chains_remote.push(entrypoint_cls, options)

Expand All @@ -614,7 +635,9 @@ def push_chain(
return

assert isinstance(service, chains_remote.BasetenChainService)
curl_snippet = _make_chains_curl_snippet(service.run_remote_url)
curl_snippet = _make_chains_curl_snippet(
service.run_remote_url, options.environment
)

table, statuses = _create_chains_table(service)
status_check_wait_sec = 2
Expand Down Expand Up @@ -647,7 +670,10 @@ def push_chain(
for log in intercepted_logs:
console.print(f"\t{log}")
if success:
console.print("Deployment succeeded.", style="bold green")
deploy_success_text = "Deployment succeeded."
if environment:
deploy_success_text = f"Your chain has been deployed into the {options.environment} environment."
console.print(deploy_success_text, style="bold green")
console.print(f"You can run the chain with:\n{curl_snippet}")
if watch: # Note that this command will print a startup message.
chains_remote.watch(
Expand Down
11 changes: 5 additions & 6 deletions truss/config/trt_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,15 @@ class TrussTRTLLMModel(str, Enum):
MISTRAL = "mistral"
DEEPSEEK = "deepseek"
WHISPER = "whisper"
QWEN = "qwen"


class TrussTRTLLMQuantizationType(str, Enum):
NO_QUANT = "no_quant"
WEIGHTS_ONLY_INT8 = "weights_int8"
WEIGHTS_KV_INT8 = "weights_kv_int8"
WEIGHTS_ONLY_INT4 = "weights_int4"
WEIGHTS_KV_INT4 = "weights_kv_int4"
WEIGHTS_INT4_KV_INT8 = "weights_int4_kv_int8"
SMOOTH_QUANT = "smooth_quant"
FP8 = "fp8"
FP8_KV = "fp8_kv"
Expand Down Expand Up @@ -58,10 +59,9 @@ class CheckpointRepository(BaseModel):

class TrussTRTLLMBuildConfiguration(BaseModel):
base_model: TrussTRTLLMModel
max_input_len: int
max_output_len: int
max_batch_size: int
max_num_tokens: Optional[int] = None
max_seq_len: int
max_batch_size: Optional[int] = 256
max_num_tokens: Optional[int] = 8192
max_beam_width: int = 1
max_prompt_embedding_table_size: int = 0
checkpoint_repository: CheckpointRepository
Expand All @@ -75,7 +75,6 @@ class TrussTRTLLMBuildConfiguration(BaseModel):
plugin_configuration: TrussTRTLLMPluginConfiguration = (
TrussTRTLLMPluginConfiguration()
)
use_fused_mlp: bool = False
kv_cache_free_gpu_mem_fraction: float = 0.9
num_builder_gpus: Optional[int] = None
enable_chunked_context: bool = False
Expand Down
12 changes: 2 additions & 10 deletions truss/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,17 +106,9 @@

REGISTRY_BUILD_SECRET_PREFIX = "DOCKER_REGISTRY_"

TRTLLM_BASE_IMAGE = "baseten/briton-server:5fa9436e_v0.0.11"
TRTLLM_BASE_IMAGE = "baseten/briton-server:v0.13.0_v0.0.14"
TRTLLM_PYTHON_EXECUTABLE = "/usr/bin/python3"
BASE_TRTLLM_REQUIREMENTS = [
"grpcio==1.62.3",
"grpcio-tools==1.62.3",
"transformers==4.44.2",
"truss==0.9.42rc010",
"outlines==0.0.46",
"torch==2.4.0",
"sentencepiece==0.2.0",
]
BASE_TRTLLM_REQUIREMENTS = ["briton==0.3.2"]
AUDIO_MODEL_TRTLLM_REQUIREMENTS = [
"--extra-index-url https://pypi.nvidia.com",
"tensorrt_cu12_bindings==10.2.0.post1",
Expand Down
11 changes: 0 additions & 11 deletions truss/contexts/image_builder/serving_image_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,17 +412,6 @@ def copy_into_build_dir(from_path: Path, path_in_build_dir: str):
DEFAULT_BUNDLED_PACKAGES_DIR,
)

tensor_parallel_count = (
config.trt_llm.build.tensor_parallel_count # type: ignore[union-attr]
if config.trt_llm.build is not None
else config.trt_llm.serve.tensor_parallel_count # type: ignore[union-attr]
)

if tensor_parallel_count != config.resources.accelerator.count:
raise ValueError(
"Tensor parallelism and GPU count must be the same for TRT-LLM"
)

config.runtime.predict_concurrency = TRTLLM_PREDICT_CONCURRENCY

if not is_audio_model:
Expand Down
12 changes: 12 additions & 0 deletions truss/local/local_config_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,18 @@ def bptr_data_resolution_dir_path():
bptr_data_dir.mkdir(exist_ok=True, parents=True)
return bptr_data_dir

@staticmethod
def dynamic_config_path():
dynamic_config_dir = LocalConfigHandler.TRUSS_CONFIG_DIR / "b10_dynamic_config"
dynamic_config_dir.mkdir(exist_ok=True, parents=True)
return dynamic_config_dir

@staticmethod
def set_dynamic_config(key: str, value: str):
key_path = LocalConfigHandler.dynamic_config_path() / key
with key_path.open("w") as key_file:
key_file.write(value)

@staticmethod
def _signatures_dir_path():
return LocalConfigHandler.TRUSS_CONFIG_DIR / "signatures"
Expand Down
21 changes: 12 additions & 9 deletions truss/remote/baseten/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,22 +235,25 @@ def deploy_draft_chain(
return resp["data"]["deploy_draft_chain"]

def deploy_chain_deployment(
self, chain_id: str, chainlet_data: List[b10_types.ChainletData], promote: bool
self,
chain_id: str,
chainlet_data: List[b10_types.ChainletData],
environment: Optional[str] = None,
):
chainlet_data_strings = [
_chainlet_data_to_graphql_mutation(chainlet) for chainlet in chainlet_data
]
chainlets_string = ", ".join(chainlet_data_strings)
query_string = f"""
mutation {{
deploy_chain_deployment(
chain_id: "{chain_id}",
chainlets: [{chainlets_string}],
promote_after_deploy: {'true' if promote else 'false'},
) {{
chain_id
chain_deployment_id
}}
deploy_chain_deployment(
chain_id: "{chain_id}",
chainlets: [{chainlets_string}],
{f'environment_name: "{environment}"' if environment else ""}
) {{
chain_id
chain_deployment_id
}}
}}
"""
resp = self._post_graphql_query(query_string)
Expand Down
Loading

0 comments on commit a298392

Please sign in to comment.