Skip to content

Commit

Permalink
Translate RPC timeout error message. (#1313)
Browse files Browse the repository at this point in the history
  • Loading branch information
marius-baseten authored Jan 15, 2025
1 parent a5023a8 commit bbba46b
Show file tree
Hide file tree
Showing 5 changed files with 113 additions and 6 deletions.
46 changes: 46 additions & 0 deletions truss-chains/tests/test_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,3 +224,49 @@ def test_numpy_chain(mode):
response = service.run_remote({})
assert response.status_code == 200
print(response.json())


# End-to-end check that a dependency call which exceeds its configured timeout
# is reported as a readable `TimeoutError` in the chain's error response.
# The chain defined in `timeout/timeout_chain.py` is pushed with local-docker
# options and then invoked twice over HTTP: once with `use_sync=False` and once
# with `use_sync=True`.
@pytest.mark.asyncio
async def test_timeout():
with ensure_kill_all():
chain_root = TEST_ROOT / "timeout" / "timeout_chain.py"
with framework.import_target(chain_root, "TimeoutChain") as entrypoint:
options = definitions.PushOptionsLocalDocker(
chain_name="integration-test", use_local_chains_src=True
)
service = deployment_client.push(entrypoint, options)

# Swap the docker-internal host name for `localhost`
# (presumably so the URL resolves from the machine running the test - confirm).
url = service.run_remote_url.replace("host.docker.internal", "localhost")
time.sleep(1.0) # Wait for models to be ready.

# Async.
response = requests.post(url, json={"use_sync": False})
# print(response.content)

# The timeout is expected to surface as an HTTP 500 whose JSON body carries a
# structured error under the "error" key.
assert response.status_code == 500
error = definitions.RemoteErrorDetail.model_validate(response.json()["error"])
error_str = error.format()
# Expected formatted traceback: the failing statement inside the entrypoint's
# `run_remote`, followed by the translated timeout message.
# NOTE(review): the dots in `0.5` and `_dep.run_remote` are not escaped, so
# they match any character, not only a literal dot.
error_regex = r"""
Chainlet-Traceback \(most recent call last\):
File \".*?/timeout_chain\.py\", line \d+, in run_remote
result = await self\._dep.run_remote\(\)
TimeoutError: Timeout calling remote Chainlet `Dependency` \(0.5 seconds limit\)\.
"""
# `re.match` anchors only at the start of the string, so trailing text in
# `error_str` is tolerated. The pattern has no `^`/`$`, so `re.MULTILINE`
# does not change the result. On failure the full error string is shown.
assert re.match(error_regex.strip(), error_str.strip(), re.MULTILINE), error_str

# Sync:
# Same assertions as above, but for the synchronous dependency call path.
sync_response = requests.post(url, json={"use_sync": True})
assert sync_response.status_code == 500
sync_error = definitions.RemoteErrorDetail.model_validate(
sync_response.json()["error"]
)
sync_error_str = sync_error.format()
sync_error_regex = r"""
Chainlet-Traceback \(most recent call last\):
File \".*?/timeout_chain\.py\", line \d+, in run_remote
result = self\._dep_sync.run_remote\(\)
TimeoutError: Timeout calling remote Chainlet `DependencySync` \(0.5 seconds limit\)\.
"""
assert re.match(
sync_error_regex.strip(), sync_error_str.strip(), re.MULTILINE
), sync_error_str
35 changes: 35 additions & 0 deletions truss-chains/tests/timeout/timeout_chain.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import asyncio
import time

import truss_chains as chains


class Dependency(chains.ChainletBase):
    """Async chainlet that waits one second before answering.

    `TimeoutChain` depends on it with `timeout_sec=0.5`, i.e. a limit shorter
    than this delay.
    """

    async def run_remote(self) -> bool:
        """Sleep for one second without blocking the event loop, then return True."""
        delay_sec = 1
        await asyncio.sleep(delay_sec)
        return True


class DependencySync(chains.ChainletBase):
    """Synchronous chainlet that waits one second before answering.

    `TimeoutChain` depends on it with `timeout_sec=0.5`, i.e. a limit shorter
    than this delay.
    """

    def run_remote(self) -> bool:
        """Block the calling thread for one second, then return True."""
        delay_sec = 1
        time.sleep(delay_sec)
        return True


@chains.mark_entrypoint
class TimeoutChain(chains.ChainletBase):
    """Entrypoint chainlet that calls one of two slow dependencies.

    Both dependencies are declared with `timeout_sec=0.5`, while each of them
    sleeps for one second in its `run_remote`.
    """

    def __init__(
        self,
        dep=chains.depends(Dependency, timeout_sec=0.5),
        dep_sync=chains.depends(DependencySync, timeout_sec=0.5),
    ):
        self._dep = dep
        self._dep_sync = dep_sync

    async def run_remote(self, use_sync: bool) -> None:
        """Call the sync dependency if `use_sync` is true, else the async one.

        The dependency's result is printed; nothing is returned.
        """
        # Keep the two `result = ...` statements verbatim: `test_timeout` in
        # tests/test_e2e.py matches their source text in the reported traceback.
        if not use_sync:
            result = await self._dep.run_remote()
        else:
            result = self._dep_sync.run_remote()
        print(result)
4 changes: 2 additions & 2 deletions truss-chains/truss_chains/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,7 @@ def get_asset_spec(self) -> AssetSpec:
return self.assets.get_spec()


DEFAULT_TIMEOUT_SEC = 600
DEFAULT_TIMEOUT_SEC = 600.0


class RPCOptions(SafeModel):
Expand All @@ -415,7 +415,7 @@ class RPCOptions(SafeModel):
"""

retries: int = 1
timeout_sec: int = DEFAULT_TIMEOUT_SEC
timeout_sec: float = DEFAULT_TIMEOUT_SEC
use_binary: bool = False


Expand Down
2 changes: 1 addition & 1 deletion truss-chains/truss_chains/public_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def depends_context() -> definitions.DeploymentContext:
def depends(
chainlet_cls: Type[framework.ChainletT],
retries: int = 1,
timeout_sec: int = definitions.DEFAULT_TIMEOUT_SEC,
timeout_sec: float = definitions.DEFAULT_TIMEOUT_SEC,
use_binary: bool = False,
) -> framework.ChainletT:
"""Sets a "symbolic marker" to indicate to the framework that a chainlet is a
Expand Down
32 changes: 29 additions & 3 deletions truss-chains/truss_chains/remote_chainlet/stub.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,16 @@ def _rpc() -> bytes:
utils.response_raise_errors(response, self.name)
return response.content

response_bytes = retry(_rpc)
try:
response_bytes = retry(_rpc)
except httpx.ReadTimeout:
msg = (
f"Timeout calling remote Chainlet `{self.name}` "
f"({self._service_descriptor.options.timeout_sec} seconds limit)."
)
logging.warning(msg)
raise TimeoutError(msg) from None # Prune error stack trace (TMI).

if output_model:
return self._response_to_pydantic(response_bytes, output_model)
return self._response_to_json(response_bytes)
Expand Down Expand Up @@ -357,7 +366,16 @@ async def _rpc() -> bytes:
await utils.async_response_raise_errors(response, self.name)
return await response.read()

response_bytes: bytes = await retry(_rpc)
try:
response_bytes: bytes = await retry(_rpc)
except asyncio.TimeoutError:
msg = (
f"Timeout calling remote Chainlet `{self.name}` "
f"({self._service_descriptor.options.timeout_sec} seconds limit)."
)
logging.warning(msg)
raise TimeoutError(msg) from None # Prune error stack trace (TMI).

if output_model:
return self._response_to_pydantic(response_bytes, output_model)
return self._response_to_json(response_bytes)
Expand All @@ -375,7 +393,15 @@ async def _rpc() -> AsyncIterator[bytes]:
await utils.async_response_raise_errors(response, self.name)
return response.content.iter_any()

return await retry(_rpc)
try:
return await retry(_rpc)
except asyncio.TimeoutError:
msg = (
f"Timeout calling remote Chainlet `{self.name}` "
f"({self._service_descriptor.options.timeout_sec} seconds limit)."
)
logging.warning(msg)
raise TimeoutError(msg) from None # Prune error stack trace (TMI).


StubT = TypeVar("StubT", bound=StubBase)
Expand Down

0 comments on commit bbba46b

Please sign in to comment.