Skip to content

Commit

Permalink
Prep for Smoke Tests
Browse files: browse the repository at this point in the history
  • Loading branch information
marius-baseten committed Jan 6, 2025
1 parent 24af116 commit 00fc951
Show file tree
Hide file tree
Showing 7 changed files with 66 additions and 51 deletions.
2 changes: 2 additions & 0 deletions truss-chains/truss_chains/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
DeployedServiceDescriptor,
DeploymentContext,
DockerImage,
GenericRemoteException,
RemoteConfig,
RemoteErrorDetail,
RPCOptions,
Expand Down Expand Up @@ -55,6 +56,7 @@
"DeploymentContext",
"DockerImage",
"RPCOptions",
"GenericRemoteException",
"RemoteConfig",
"RemoteErrorDetail",
"DeployedServiceDescriptor",
Expand Down
31 changes: 25 additions & 6 deletions truss-chains/truss_chains/deployment/deployment_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,12 +129,10 @@ class ChainService(abc.ABC):
"""

_name: str
_entrypoint_service: b10_service.TrussService
_entrypoint_fake_json_data: Any

def __init__(self, name: str, entrypoint_service: b10_service.TrussService):
def __init__(self, name: str):
self._name = name
self._entrypoint_service = entrypoint_service
self._entrypoint_fake_json_data = None

@property
Expand All @@ -151,12 +149,12 @@ def status_page_url(self) -> str:
def run_remote_url(self) -> str:
"""URL to invoke the entrypoint."""

@abc.abstractmethod
def run_remote(self, json: Dict) -> Any:
"""Invokes the entrypoint with JSON data.
Returns:
The JSON response."""
return self._entrypoint_service.predict(json)

@abc.abstractmethod
def get_info(self) -> list[b10_types.DeployedChainlet]:
Expand All @@ -183,7 +181,10 @@ def entrypoint_fake_json_data(self, fake_data: Any) -> None:


class BasetenChainService(ChainService):
# TODO: entrypoint service is for truss model - make chains-specific.
# E.g. chain/chainlet will not have model URLs anymore.
_chain_deployment_handle: b10_core.ChainDeploymentHandleAtomic
_entrypoint_service: b10_service.BasetenService
_remote: b10_remote.BasetenRemote

def __init__(
Expand All @@ -193,8 +194,9 @@ def __init__(
chain_deployment_handle: b10_core.ChainDeploymentHandleAtomic,
remote: b10_remote.BasetenRemote,
) -> None:
super().__init__(name, entrypoint_service)
super().__init__(name)
self._chain_deployment_handle = chain_deployment_handle
self._entrypoint_service = entrypoint_service
self._remote = remote

@property
Expand All @@ -208,6 +210,13 @@ def run_remote_url(self) -> str:
self._chain_deployment_handle.is_draft,
)

def run_remote(self, json: Dict) -> Any:
"""Invokes the entrypoint with JSON data.
Returns:
The JSON response."""
return self._entrypoint_service.predict(json)

@property
def status_page_url(self) -> str:
"""Link to status page on Baseten."""
Expand All @@ -231,14 +240,24 @@ def get_info(self) -> list[b10_types.DeployedChainlet]:


class DockerChainService(ChainService):
_entrypoint_service: DockerTrussService

def __init__(self, name: str, entrypoint_service: DockerTrussService) -> None:
super().__init__(name, entrypoint_service)
super().__init__(name)
self._entrypoint_service = entrypoint_service

@property
def run_remote_url(self) -> str:
"""URL to invoke the entrypoint."""
return self._entrypoint_service.predict_url

def run_remote(self, json: Dict) -> Any:
"""Invokes the entrypoint with JSON data.
Returns:
The JSON response."""
return self._entrypoint_service.predict(json)

@property
def status_page_url(self) -> str:
"""Not Implemented.."""
Expand Down
6 changes: 4 additions & 2 deletions truss-chains/truss_chains/remote_chainlet/stub.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,15 +267,17 @@ def _make_request_params(
if isinstance(inputs, pydantic.BaseModel):
if self._service_descriptor.options.use_binary:
data_dict = inputs.model_dump(mode="python")
kwargs["data"] = serialization.truss_msgpack_serialize(data_dict)
data_key = "content" if for_httpx else "data"
kwargs[data_key] = serialization.truss_msgpack_serialize(data_dict)
headers["Content-Type"] = "application/octet-stream"
else:
data_key = "content" if for_httpx else "data"
kwargs[data_key] = inputs.model_dump_json()
headers["Content-Type"] = "application/json"
else: # inputs is JSON dict.
if self._service_descriptor.options.use_binary:
kwargs["data"] = serialization.truss_msgpack_serialize(inputs)
data_key = "content" if for_httpx else "data"
kwargs[data_key] = serialization.truss_msgpack_serialize(inputs)
headers["Content-Type"] = "application/octet-stream"
else:
kwargs["json"] = inputs
Expand Down
66 changes: 28 additions & 38 deletions truss-chains/truss_chains/remote_chainlet/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,20 +210,24 @@ def _resolve_exception_class(
return exception_cls


def _handle_response_error(response_json: dict, remote_name: str):
def _handle_response_error(response_json: dict, remote_name: str, status: int):
try:
error_json = response_json["error"]
except KeyError as e:
logging.error(f"response_json: {response_json}")
raise ValueError(
"Could not get `error` field from JSON from chainlet error response"
"Could not get `error` field from JSON from chainlet "
f"error response. HTTP status: {status}."
) from e

try:
error = definitions.RemoteErrorDetail.model_validate(error_json)
except pydantic.ValidationError as e:
if isinstance(error_json, str):
msg = f"Remote error occurred in `{remote_name}`: '{error_json}'"
msg = (
f"Remote error occurred in `{remote_name}` "
f"(HTTP status {status}): '{error_json}'"
)
raise definitions.GenericRemoteException(msg) from None
raise ValueError(
"Could not parse chainlet error. Error details are expected to be either a "
Expand All @@ -238,7 +242,7 @@ def _handle_response_error(response_json: dict, remote_name: str):
error_format = "\n".join(lines + [last_line])
msg = (
f"(showing chained remote errors, root error at the bottom)\n"
f"├─ Error in dependency Chainlet `{remote_name}`:\n"
f"├─ Error in dependency Chainlet `{remote_name}` (HTTP status {status}):\n"
f"{error_format}"
)
raise exception_cls(msg)
Expand All @@ -255,38 +259,24 @@ def response_raise_errors(response: httpx.Response, remote_name: str) -> None:
Chainlet that raised an exception. E.g. the message might look like this:
```
RemoteChainletError in "Chain"
Traceback (most recent call last):
File "/app/model/Chainlet.py", line 112, in predict
result = await self._chainlet.run(
File "/app/model/Chainlet.py", line 79, in run
value += self._text_to_num.run(part)
File "/packages/remote_stubs.py", line 21, in run
json_result = self.predict_sync(json_args)
File "/packages/truss_chains/stub.py", line 37, in predict_sync
return utils.handle_response(
ValueError: (showing remote errors, root message at the bottom)
--> Preceding Remote Cause:
RemoteChainletError in "TextToNum"
Traceback (most recent call last):
File "/app/model/Chainlet.py", line 113, in predict
result = self._chainlet.run(data=payload["data"])
File "/app/model/Chainlet.py", line 54, in run
generated_text = self._replicator.run(data)
File "/packages/remote_stubs.py", line 7, in run
json_result = self.predict_sync(json_args)
File "/packages/truss_chains/stub.py", line 37, in predict_sync
return utils.handle_response(
ValueError: (showing remote errors, root message at the bottom)
--> Preceding Remote Cause:
RemoteChainletError in "TextReplicator"
Traceback (most recent call last):
File "/app/model/Chainlet.py", line 112, in predict
result = self._chainlet.run(data=payload["data"])
File "/app/model/Chainlet.py", line 36, in run
raise ValueError(f"This input is too long: {len(data)}.")
ValueError: This input is too long: 100.
Chainlet-Traceback (most recent call last):
File "/packages/itest_chain.py", line 132, in run_remote
value = self._accumulate_parts(text_parts.parts)
File "/packages/itest_chain.py", line 144, in _accumulate_parts
value += self._text_to_num.run_remote(part)
ValueError: (showing chained remote errors, root error at the bottom)
├─ Error in dependency Chainlet `TextToNum`:
│ Chainlet-Traceback (most recent call last):
│ File "/packages/itest_chain.py", line 87, in run_remote
│ generated_text = self._replicator.run_remote(data)
│ ValueError: (showing chained remote errors, root error at the bottom)
│ ├─ Error in dependency Chainlet `TextReplicator`:
│ │ Chainlet-Traceback (most recent call last):
│ │ File "/packages/itest_chain.py", line 52, in run_remote
│ │ validate_data(data)
│ │ File "/packages/itest_chain.py", line 36, in validate_data
│ │ raise ValueError(f"This input is too long: {len(data)}.")
╰ ╰ ValueError: This input is too long: 100.
```
"""
if response.is_error:
Expand All @@ -297,7 +287,7 @@ def response_raise_errors(response: httpx.Response, remote_name: str) -> None:
"Could not get JSON from error response. Status: "
f"`{response.status_code}`."
) from e
_handle_response_error(response_json=response_json, remote_name=remote_name)
_handle_response_error(response_json, remote_name, response.status_code)


async def async_response_raise_errors(
Expand All @@ -312,4 +302,4 @@ async def async_response_raise_errors(
"Could not get JSON from error response. Status: "
f"`{response.status}`."
) from e
_handle_response_error(response_json=response_json, remote_name=remote_name)
_handle_response_error(response_json, remote_name, response.status)
7 changes: 4 additions & 3 deletions truss/remote/baseten/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ class Config:


class BasetenRemote(TrussRemote):
def __init__(self, remote_url: str, api_key: str, **kwargs):
super().__init__(remote_url, **kwargs)
def __init__(self, remote_url: str, api_key: str):
super().__init__(remote_url)
self._auth_service = AuthService(api_key=api_key)
self._api = BasetenApi(remote_url, self._auth_service)

Expand Down Expand Up @@ -310,7 +310,8 @@ def push_chain_atomic(

model_id = chain_deployment_handle.entrypoint_model_id
model_version_id = chain_deployment_handle.entrypoint_model_version_id

# TODO: entrypoint service is for truss model - make chains-specific.
# E.g. chain/chainlet will not have model URLs anymore.
entrypoint_service = BasetenService(
model_id=model_id,
model_version_id=model_version_id,
Expand Down
3 changes: 2 additions & 1 deletion truss/remote/remote_factory.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import inspect
import os

try:
from configparser import DEFAULTSECT, ConfigParser # type: ignore
Expand All @@ -16,7 +17,7 @@
from truss.remote.baseten import BasetenRemote
from truss.remote.truss_remote import RemoteConfig, TrussRemote

USER_TRUSSRC_PATH = Path("~/.trussrc").expanduser()
USER_TRUSSRC_PATH = Path(os.environ.get("USER_TRUSSRC_PATH", "~/.trussrc")).expanduser()


def load_config() -> ConfigParser:
Expand Down
2 changes: 1 addition & 1 deletion truss/remote/truss_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ class TrussRemote(ABC):
"""

def __init__(self, remote_url: str, **kwargs) -> None:
def __init__(self, remote_url: str) -> None:
self._remote_url = remote_url

@property
Expand Down

0 comments on commit 00fc951

Please sign in to comment.