From 00fc9518a7c0b8ce191f75dc1f5e86a92fe194bf Mon Sep 17 00:00:00 2001 From: "marius.baseten" Date: Mon, 6 Jan 2025 10:30:36 -0800 Subject: [PATCH] Prep for Smoke Tests --- truss-chains/truss_chains/__init__.py | 2 + .../deployment/deployment_client.py | 31 +++++++-- .../truss_chains/remote_chainlet/stub.py | 6 +- .../truss_chains/remote_chainlet/utils.py | 66 ++++++++----------- truss/remote/baseten/remote.py | 7 +- truss/remote/remote_factory.py | 3 +- truss/remote/truss_remote.py | 2 +- 7 files changed, 66 insertions(+), 51 deletions(-) diff --git a/truss-chains/truss_chains/__init__.py b/truss-chains/truss_chains/__init__.py index b5c8bc070..252218e41 100644 --- a/truss-chains/truss_chains/__init__.py +++ b/truss-chains/truss_chains/__init__.py @@ -28,6 +28,7 @@ DeployedServiceDescriptor, DeploymentContext, DockerImage, + GenericRemoteException, RemoteConfig, RemoteErrorDetail, RPCOptions, @@ -55,6 +56,7 @@ "DeploymentContext", "DockerImage", "RPCOptions", + "GenericRemoteException", "RemoteConfig", "RemoteErrorDetail", "DeployedServiceDescriptor", diff --git a/truss-chains/truss_chains/deployment/deployment_client.py b/truss-chains/truss_chains/deployment/deployment_client.py index c7db3c093..13299fc06 100644 --- a/truss-chains/truss_chains/deployment/deployment_client.py +++ b/truss-chains/truss_chains/deployment/deployment_client.py @@ -129,12 +129,10 @@ class ChainService(abc.ABC): """ _name: str - _entrypoint_service: b10_service.TrussService _entrypoint_fake_json_data: Any - def __init__(self, name: str, entrypoint_service: b10_service.TrussService): + def __init__(self, name: str): self._name = name - self._entrypoint_service = entrypoint_service self._entrypoint_fake_json_data = None @property @@ -151,12 +149,12 @@ def status_page_url(self) -> str: def run_remote_url(self) -> str: """URL to invoke the entrypoint.""" + @abc.abstractmethod def run_remote(self, json: Dict) -> Any: """Invokes the entrypoint with JSON data. Returns: The JSON response.""" - return self._entrypoint_service.predict(json) @abc.abstractmethod def get_info(self) -> list[b10_types.DeployedChainlet]: @@ -183,7 +181,10 @@ def entrypoint_fake_json_data(self, fake_data: Any) -> None: class BasetenChainService(ChainService): + # TODO: entrypoint service is for truss model - make chains-specific. + # E.g. chain/chainlet will not have model URLs anymore. _chain_deployment_handle: b10_core.ChainDeploymentHandleAtomic + _entrypoint_service: b10_service.BasetenService _remote: b10_remote.BasetenRemote def __init__( @@ -193,8 +194,9 @@ def __init__( chain_deployment_handle: b10_core.ChainDeploymentHandleAtomic, remote: b10_remote.BasetenRemote, ) -> None: - super().__init__(name, entrypoint_service) + super().__init__(name) self._chain_deployment_handle = chain_deployment_handle + self._entrypoint_service = entrypoint_service self._remote = remote @property @@ -208,6 +210,13 @@ def run_remote_url(self) -> str: self._chain_deployment_handle.is_draft, ) + def run_remote(self, json: Dict) -> Any: + """Invokes the entrypoint with JSON data. + + Returns: + The JSON response.""" + return self._entrypoint_service.predict(json) + @property def status_page_url(self) -> str: """Link to status page on Baseten.""" @@ -231,14 +240,24 @@ def get_info(self) -> list[b10_types.DeployedChainlet]: class DockerChainService(ChainService): + _entrypoint_service: DockerTrussService + def __init__(self, name: str, entrypoint_service: DockerTrussService) -> None: - super().__init__(name, entrypoint_service) + super().__init__(name) + self._entrypoint_service = entrypoint_service @property def run_remote_url(self) -> str: """URL to invoke the entrypoint.""" return self._entrypoint_service.predict_url + def run_remote(self, json: Dict) -> Any: + """Invokes the entrypoint with JSON data. + + Returns: + The JSON response.""" + return self._entrypoint_service.predict(json) + @property def status_page_url(self) -> str: """Not Implemented..""" diff --git a/truss-chains/truss_chains/remote_chainlet/stub.py b/truss-chains/truss_chains/remote_chainlet/stub.py index 81bd940a0..33fe5af4b 100644 --- a/truss-chains/truss_chains/remote_chainlet/stub.py +++ b/truss-chains/truss_chains/remote_chainlet/stub.py @@ -267,7 +267,8 @@ def _make_request_params( if isinstance(inputs, pydantic.BaseModel): if self._service_descriptor.options.use_binary: data_dict = inputs.model_dump(mode="python") - kwargs["data"] = serialization.truss_msgpack_serialize(data_dict) + data_key = "content" if for_httpx else "data" + kwargs[data_key] = serialization.truss_msgpack_serialize(data_dict) headers["Content-Type"] = "application/octet-stream" else: data_key = "content" if for_httpx else "data" @@ -275,7 +276,8 @@ def _make_request_params( headers["Content-Type"] = "application/json" else: # inputs is JSON dict. if self._service_descriptor.options.use_binary: - kwargs["data"] = serialization.truss_msgpack_serialize(inputs) + data_key = "content" if for_httpx else "data" + kwargs[data_key] = serialization.truss_msgpack_serialize(inputs) headers["Content-Type"] = "application/octet-stream" else: kwargs["json"] = inputs diff --git a/truss-chains/truss_chains/remote_chainlet/utils.py b/truss-chains/truss_chains/remote_chainlet/utils.py index 8b90f006a..05159e6d5 100644 --- a/truss-chains/truss_chains/remote_chainlet/utils.py +++ b/truss-chains/truss_chains/remote_chainlet/utils.py @@ -210,20 +210,24 @@ def _resolve_exception_class( return exception_cls -def _handle_response_error(response_json: dict, remote_name: str): +def _handle_response_error(response_json: dict, remote_name: str, status: int): try: error_json = response_json["error"] except KeyError as e: logging.error(f"response_json: {response_json}") raise ValueError( - "Could not get `error` field from JSON from chainlet error response" + "Could not get `error` field from JSON from chainlet " + f"error response. HTTP status: {status}." ) from e try: error = definitions.RemoteErrorDetail.model_validate(error_json) except pydantic.ValidationError as e: if isinstance(error_json, str): - msg = f"Remote error occurred in `{remote_name}`: '{error_json}'" + msg = ( + f"Remote error occurred in `{remote_name}` " + f"(HTTP status {status}): '{error_json}'" + ) raise definitions.GenericRemoteException(msg) from None raise ValueError( "Could not parse chainlet error. Error details are expected to be either a " @@ -238,7 +242,7 @@ def _handle_response_error(response_json: dict, remote_name: str): error_format = "\n".join(lines + [last_line]) msg = ( f"(showing chained remote errors, root error at the bottom)\n" - f"├─ Error in dependency Chainlet `{remote_name}`:\n" + f"├─ Error in dependency Chainlet `{remote_name}` (HTTP status {status}):\n" f"{error_format}" ) raise exception_cls(msg) @@ -255,38 +259,24 @@ def response_raise_errors(response: httpx.Response, remote_name: str) -> None: Chainlet that raised an exception. E.g. the message might look like this: ``` - RemoteChainletError in "Chain" - Traceback (most recent call last): - File "/app/model/Chainlet.py", line 112, in predict - result = await self._chainlet.run( - File "/app/model/Chainlet.py", line 79, in run - value += self._text_to_num.run(part) - File "/packages/remote_stubs.py", line 21, in run - json_result = self.predict_sync(json_args) - File "/packages/truss_chains/stub.py", line 37, in predict_sync - return utils.handle_response( - ValueError: (showing remote errors, root message at the bottom) - --> Preceding Remote Cause: - RemoteChainletError in "TextToNum" - Traceback (most recent call last): - File "/app/model/Chainlet.py", line 113, in predict - result = self._chainlet.run(data=payload["data"]) - File "/app/model/Chainlet.py", line 54, in run - generated_text = self._replicator.run(data) - File "/packages/remote_stubs.py", line 7, in run - json_result = self.predict_sync(json_args) - File "/packages/truss_chains/stub.py", line 37, in predict_sync - return utils.handle_response( - ValueError: (showing remote errors, root message at the bottom) - --> Preceding Remote Cause: - RemoteChainletError in "TextReplicator" - Traceback (most recent call last): - File "/app/model/Chainlet.py", line 112, in predict - result = self._chainlet.run(data=payload["data"]) - File "/app/model/Chainlet.py", line 36, in run - raise ValueError(f"This input is too long: {len(data)}.") - ValueError: This input is too long: 100. - + Chainlet-Traceback (most recent call last): + File "/packages/itest_chain.py", line 132, in run_remote + value = self._accumulate_parts(text_parts.parts) + File "/packages/itest_chain.py", line 144, in _accumulate_parts + value += self._text_to_num.run_remote(part) + ValueError: (showing chained remote errors, root error at the bottom) + ├─ Error in dependency Chainlet `TextToNum`: + │ Chainlet-Traceback (most recent call last): + │ File "/packages/itest_chain.py", line 87, in run_remote + │ generated_text = self._replicator.run_remote(data) + │ ValueError: (showing chained remote errors, root error at the bottom) + │ ├─ Error in dependency Chainlet `TextReplicator`: + │ │ Chainlet-Traceback (most recent call last): + │ │ File "/packages/itest_chain.py", line 52, in run_remote + │ │ validate_data(data) + │ │ File "/packages/itest_chain.py", line 36, in validate_data + │ │ raise ValueError(f"This input is too long: {len(data)}.") + ╰ ╰ ValueError: This input is too long: 100. ``` """ if response.is_error: @@ -297,7 +287,7 @@ def response_raise_errors(response: httpx.Response, remote_name: str) -> None: "Could not get JSON from error response. Status: " f"`{response.status_code}`." ) from e - _handle_response_error(response_json=response_json, remote_name=remote_name) + _handle_response_error(response_json, remote_name, response.status_code) async def async_response_raise_errors( @@ -312,4 +302,4 @@ async def async_response_raise_errors( "Could not get JSON from error response. Status: " f"`{response.status}`." ) from e - _handle_response_error(response_json=response_json, remote_name=remote_name) + _handle_response_error(response_json, remote_name, response.status) diff --git a/truss/remote/baseten/remote.py b/truss/remote/baseten/remote.py index e32edb2ca..ac9796362 100644 --- a/truss/remote/baseten/remote.py +++ b/truss/remote/baseten/remote.py @@ -68,8 +68,8 @@ class Config: class BasetenRemote(TrussRemote): - def __init__(self, remote_url: str, api_key: str, **kwargs): - super().__init__(remote_url, **kwargs) + def __init__(self, remote_url: str, api_key: str): + super().__init__(remote_url) self._auth_service = AuthService(api_key=api_key) self._api = BasetenApi(remote_url, self._auth_service) @@ -310,7 +310,8 @@ def push_chain_atomic( model_id = chain_deployment_handle.entrypoint_model_id model_version_id = chain_deployment_handle.entrypoint_model_version_id - + # TODO: entrypoint service is for truss model - make chains-specific. + # E.g. chain/chainlet will not have model URLs anymore. entrypoint_service = BasetenService( model_id=model_id, model_version_id=model_version_id, diff --git a/truss/remote/remote_factory.py b/truss/remote/remote_factory.py index 426080d24..6b82b84b4 100644 --- a/truss/remote/remote_factory.py +++ b/truss/remote/remote_factory.py @@ -1,4 +1,5 @@ import inspect +import os try: from configparser import DEFAULTSECT, ConfigParser # type: ignore @@ -16,7 +17,7 @@ from truss.remote.baseten import BasetenRemote from truss.remote.truss_remote import RemoteConfig, TrussRemote -USER_TRUSSRC_PATH = Path("~/.trussrc").expanduser() +USER_TRUSSRC_PATH = Path(os.environ.get("USER_TRUSSRC_PATH", "~/.trussrc")).expanduser() def load_config() -> ConfigParser: diff --git a/truss/remote/truss_remote.py b/truss/remote/truss_remote.py index 727c6e322..351df697e 100644 --- a/truss/remote/truss_remote.py +++ b/truss/remote/truss_remote.py @@ -199,7 +199,7 @@ class TrussRemote(ABC): """ - def __init__(self, remote_url: str, **kwargs) -> None: + def __init__(self, remote_url: str) -> None: self._remote_url = remote_url @property