Merge pull request #1134 from basetenlabs/bump-version-0.9.33
Release 0.9.33
squidarth authored Sep 9, 2024
2 parents 5e34879 + a5033d0 commit 52efc12
Showing 4 changed files with 52 additions and 12 deletions.
pyproject.toml (1 addition, 1 deletion)

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "truss"
-version = "0.9.32"
+version = "0.9.33"
 description = "A seamless bridge from model development to model delivery"
 license = "MIT"
 readme = "README.md"
truss/remote/baseten/api.py (3 additions, 2 deletions)

@@ -235,7 +235,7 @@ def deploy_draft_chain(
         return resp["data"]["deploy_draft_chain"]

     def deploy_chain_deployment(
-        self, chain_id: str, chainlet_data: List[b10_types.ChainletData]
+        self, chain_id: str, chainlet_data: List[b10_types.ChainletData], promote: bool
     ):
         chainlet_data_strings = [
             _chainlet_data_to_graphql_mutation(chainlet) for chainlet in chainlet_data
@@ -245,7 +245,8 @@ def deploy_chain_deployment(
         mutation {{
             deploy_chain_deployment(
                 chain_id: "{chain_id}",
-                chainlets: [{chainlets_string}]
+                chainlets: [{chainlets_string}],
+                promote_after_deploy: {'true' if promote else 'false'},
             ) {{
                 chain_id
                 chain_deployment_id
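As an aside, here is a minimal, self-contained sketch of how the new promote argument renders into the GraphQL text. The render_mutation helper is hypothetical (not part of truss); it only mirrors the f-string pattern above, where doubled braces produce literal GraphQL braces and the conditional expression emits the lowercase GraphQL booleans true/false:

def render_mutation(chain_id: str, chainlets_string: str, promote: bool) -> str:
    # Hypothetical stand-in for the mutation-building f-string shown in the diff.
    return f"""
    mutation {{
        deploy_chain_deployment(
            chain_id: "{chain_id}",
            chainlets: [{chainlets_string}],
            promote_after_deploy: {'true' if promote else 'false'},
        ) {{
            chain_id
            chain_deployment_id
        }}
    }}
    """

print(render_mutation("abc123", "", promote=False))
# The promote_after_deploy field renders as the GraphQL literal false.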
truss/remote/baseten/core.py (6 additions, 1 deletion)

@@ -80,11 +80,16 @@ def create_chain(
     chain_name: str,
     chainlets: List[b10_types.ChainletData],
     is_draft: bool = False,
+    promote: bool = True,
 ) -> ChainDeploymentHandle:
     if is_draft:
         response = api.deploy_draft_chain(chain_name, chainlets)
     elif chain_id:
-        response = api.deploy_chain_deployment(chain_id, chainlets)
+        # This is the only case where promote has relevance, since
+        # if there is no chain already, the first deployment will
+        # already be production, and only published deployments can
+        # be promoted.
+        response = api.deploy_chain_deployment(chain_id, chainlets, promote)
     else:
         response = api.deploy_chain(chain_name, chainlets)
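To make the comment above concrete, here is a purely illustrative sketch (the which_mutation helper and its argument names are hypothetical, not the truss API) of which mutation each create_chain case ends up calling and where promote actually takes effect:

from typing import Optional

def which_mutation(is_draft: bool, chain_id: Optional[str], promote: bool) -> str:
    """Illustrative only: maps each create_chain case to the mutation it calls."""
    if is_draft:
        return "deploy_draft_chain"  # development deployment; promote is ignored
    elif chain_id:
        # Existing chain: promote decides whether the newly published
        # deployment also becomes the production deployment.
        return f"deploy_chain_deployment(promote_after_deploy={str(promote).lower()})"
    else:
        return "deploy_chain"  # the first deployment of a chain is production already

assert which_mutation(is_draft=False, chain_id=None, promote=True) == "deploy_chain"
assert which_mutation(is_draft=False, chain_id="abc", promote=False).endswith("=false)")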
truss/templates/trtllm-briton/src/engine.py (42 additions, 8 deletions)

@@ -224,11 +224,16 @@ async def predict(self, model_input):
         ):
             request.pad_id = self._tokenizer.pad_token_id
         # Add output schema hash if we're function calling or response_format is provided
-        schema_hash = (
-            self._fsm_cache.add_schema(function_calling_schema)
-            if function_calling_schema is not None
-            else self._fsm_cache.add_schema_from_input(model_input)
-        )
+        schema_hash = None
+        try:
+            schema_hash = (
+                self._fsm_cache.add_schema(function_calling_schema)
+                if function_calling_schema is not None
+                else self._fsm_cache.add_schema_from_input(model_input)
+            )
+        # If the input schema is invalid, we should return a 400
+        except NotImplementedError as ex:
+            raise HTTPException(status_code=400, detail=str(ex))
         if schema_hash is not None:
             request.output_schema_hash = schema_hash
         set_briton_request_fields_from_model_input(model_input, request)
@@ -240,23 +245,52 @@ async def predict(self, model_input):
         resp_iter = self._stub.Infer(request)

         async def generate():
+            eos_token = (
+                self._tokenizer.eos_token
+                if hasattr(self._tokenizer, "eos_token")
+                else None
+            )
             async for response in resp_iter:
-                yield response.output_text
+                if eos_token:
+                    yield response.output_text.removesuffix(eos_token)
+                else:
+                    yield response.output_text

         async def build_response():
+            eos_token = (
+                self._tokenizer.eos_token
+                if hasattr(self._tokenizer, "eos_token")
+                else None
+            )
             full_text = ""
             async for delta in resp_iter:
                 full_text += delta.output_text
-            return full_text
+            if eos_token:
+                return full_text.removesuffix(eos_token)
+            else:
+                return full_text

         try:
             if model_input.get("stream", True):
-                return generate()
+                gen = generate()
+                first_chunk = await gen.__anext__()
+
+                async def generate_after_first_chunk():
+                    yield first_chunk
+                    async for chunk in gen:
+                        yield chunk
+
+                return generate_after_first_chunk()
             else:
                 return await build_response()
         except grpc.RpcError as ex:
             if ex.code() == grpc.StatusCode.INVALID_ARGUMENT:
                 raise HTTPException(status_code=400, detail=ex.details())
+            # If the error is another GRPC exception like NotImplemented, we should return a 500
+            else:
+                raise HTTPException(
+                    status_code=500, detail=f"An error has occurred: {ex}"
+                )
         except Exception as ex:
             raise HTTPException(status_code=500, detail=f"An error has occurred: {ex}")
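The streaming branch above pulls the first chunk before returning a generator, so that errors raised at stream start are still caught by the surrounding try/except and surface as proper HTTP errors instead of a broken streamed response. A standalone sketch of that pattern (plain asyncio, hypothetical names, no Briton or gRPC dependencies):

import asyncio

async def flaky_stream():
    # Stand-in for resp_iter: fails before producing any output.
    raise RuntimeError("backend rejected the request")
    yield  # unreachable, but makes this an async generator

async def start_stream(gen):
    first_chunk = await gen.__anext__()  # errors raised here are still catchable

    async def rest():
        yield first_chunk
        async for chunk in gen:
            yield chunk

    return rest()

async def main():
    try:
        stream = await start_stream(flaky_stream())
    except RuntimeError as ex:
        print(f"caught before streaming began: {ex}")
        return
    async for chunk in stream:
        print(chunk)

asyncio.run(main())  # prints: caught before streaming began: backend rejected the request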

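For completeness, a quick illustration of the removesuffix-based EOS stripping used in generate() and build_response() above: str.removesuffix (Python 3.9+) only trims the token when it actually terminates the string and is a no-op otherwise, so chunks without a trailing EOS pass through unchanged. The eos_token value here is an arbitrary example; the engine reads it from the tokenizer when available.

eos_token = "</s>"  # example value only

print("Hello world</s>".removesuffix(eos_token))  # -> "Hello world"
print("Hello wor".removesuffix(eos_token))        # -> "Hello wor" (unchanged)
print("".removesuffix(eos_token))                 # -> "" (safe on empty chunks)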