Commit

Now everything works

marius-baseten committed Dec 2, 2024
1 parent ba06987 commit 3702de6
Showing 12 changed files with 474 additions and 499 deletions.
683 changes: 324 additions & 359 deletions poetry.lock

Large diffs are not rendered by default.

62 changes: 39 additions & 23 deletions truss-chains/examples/numpy_and_binary/chain.py
@@ -24,53 +24,69 @@ async def run_remote(self, data: DataModel) -> DataModel:

class AsyncChainletNoInput(chains.ChainletBase):
async def run_remote(self) -> DataModel:
return DataModel(msg="From async no input", np_array=np.full((3, 2, 1), 3))
data = DataModel(msg="From async no input", np_array=np.full((2, 2), 3))
print(data)
return data


class AsyncChainletNoOutput(chains.ChainletBase):
async def run_remote(self, data: DataModel) -> None:
print(data)


class Host(chains.ChainletBase):
"""Consume that reads the raw streams and parses them."""
class HostJSON(chains.ChainletBase):
"""Calls various chainlets in JSON mode."""

def __init__(
self,
sync_chainlet=chains.depends(SynChainlet, use_binary=True),
async_chainlet=chains.depends(AsyncChainlet, use_binary=True),
async_chainlet_no_output=chains.depends(AsyncChainletNoOutput, use_binary=True),
async_chainlet_no_input=chains.depends(AsyncChainletNoInput, use_binary=True),
sync_chainlet=chains.depends(SynChainlet, use_binary=False),
async_chainlet=chains.depends(AsyncChainlet, use_binary=False),
async_chainlet_no_output=chains.depends(
AsyncChainletNoOutput, use_binary=False
),
async_chainlet_no_input=chains.depends(AsyncChainletNoInput, use_binary=False),
):
self._sync_chainlet = sync_chainlet
self._async_chainlet = async_chainlet
self._async_chainlet_no_output = async_chainlet_no_output
self._async_chainlet_no_input = async_chainlet_no_input

async def run_remote(self) -> tuple[DataModel, DataModel]:
async def run_remote(self) -> tuple[DataModel, DataModel, DataModel]:
a = np.ones((3, 2, 1))
data = DataModel(msg="From Host", np_array=a)
sync_result = self._sync_chainlet.run_remote(data)
print(sync_result)
async_result = await self._async_chainlet.run_remote(data)
print(async_result)
await self._async_chainlet_no_output.run_remote(data)
# async_no_input = await self._async_chainlet_no_input.run_remote()
# print(async_no_input)
return sync_result, async_result
# return data, async_result
async_no_input = await self._async_chainlet_no_input.run_remote()
print(async_no_input)
return sync_result, async_result, async_no_input


if __name__ == "__main__":
from truss.templates.shared import serialization
class HostBinary(chains.ChainletBase):
"""Calls various chainlets in binary mode."""

a = np.ones(
(3, 2, 1),
)
out = serialization.truss_msgpack_serialize(a)
print(out)
def __init__(
self,
sync_chainlet=chains.depends(SynChainlet, use_binary=True),
async_chainlet=chains.depends(AsyncChainlet, use_binary=True),
async_chainlet_no_output=chains.depends(AsyncChainletNoOutput, use_binary=True),
async_chainlet_no_input=chains.depends(AsyncChainletNoInput, use_binary=True),
):
self._sync_chainlet = sync_chainlet
self._async_chainlet = async_chainlet
self._async_chainlet_no_output = async_chainlet_no_output
self._async_chainlet_no_input = async_chainlet_no_input

print(serialization.truss_msgpack_deserialize(out))
# data = b'\x81\xa4data\x82\xa3msg\xa9From Host\xa8np_array\x85\xc4\x02nd\xc3\xc4\x04type\xa3<f8\xc4\x04kind\xc4\x00\xc4\x05shape\x93\x03\x02\x01\xc4\x04data\xc4\x18\x00\x00\x00\x00\x00\x00\xf0?\x00\x00\x00\x00\x00\x00\xf0?\x00\x00\x00\x00\x00\x00\xf0?\x00\x00\x00\x00\x00\x00\xf0?\x00\x00\x00\x00\x00\x00\xf0?\x00\x00\x00\x00\x00\x00\xf0?'
#
# serialization.truss_msgpack_deserialize(data)
async def run_remote(self) -> tuple[DataModel, DataModel, DataModel]:
a = np.ones((3, 2, 1))
data = DataModel(msg="From Host", np_array=a)
sync_result = self._sync_chainlet.run_remote(data)
print(sync_result)
async_result = await self._async_chainlet.run_remote(data)
print(async_result)
await self._async_chainlet_no_output.run_remote(data)
async_no_input = await self._async_chainlet_no_input.run_remote()
print(async_no_input)
return sync_result, async_result, async_no_input
2 changes: 1 addition & 1 deletion truss-chains/examples/rag/rag_chain.py
@@ -116,7 +116,7 @@ async def run_remote(self, new_bio: str, bios: list[str]) -> str:
f"People from database: {bios_info}"
)
resp = await self.predict_async(
inputs={
{
"messages": [{"role": "user", "content": prompt}],
"stream": False,
"max_new_tokens": 32,
11 changes: 8 additions & 3 deletions truss-chains/tests/test_e2e.py
@@ -168,15 +168,20 @@ async def test_streaming_chain_local():


@pytest.mark.integration
def test_numpy_chain():
@pytest.mark.parametrize("mode", ["json", "binary"])
def test_numpy_chain(mode):
if mode == "json":
target = "HostBinary"
else:
target = "HostBinary"
with ensure_kill_all():
examples_root = Path(__file__).parent.parent.resolve() / "examples"
chain_root = examples_root / "numpy_and_binary" / "chain.py"
with framework.import_target(chain_root, "Host") as entrypoint:
with framework.import_target(chain_root, target) as entrypoint:
service = remote.push(
entrypoint,
options=definitions.PushOptionsLocalDocker(
chain_name="integration-test-numpy",
chain_name=f"integration-test-numpy-{mode}",
only_generate_trusses=False,
use_local_chains_src=True,
),
10 changes: 4 additions & 6 deletions truss-chains/truss_chains/code_gen.py
@@ -251,9 +251,8 @@ def _stub_endpoint_body_src(
E.g.:
```
json_result = await self.predict_async(
SplitTextInput(inputs=inputs, extra_arg=extra_arg).model_dump())
return SplitTextOutput.model_validate(json_result).output
return await self.predict_async(
SplitTextInput(inputs=inputs, extra_arg=extra_arg), SplitTextOutput).root
```
"""
imports: set[str] = set()
@@ -309,9 +308,8 @@ class SplitText(stub.StubBase):
async def run_remote(
self, inputs: shared_chainlet.SplitTextInput, extra_arg: int
) -> tuple[shared_chainlet.SplitTextOutput, int]:
json_result = await self.predict_async(
SplitTextInput(inputs=inputs, extra_arg=extra_arg).model_dump())
return SplitTextOutput.model_validate(json_result).root
return await self.predict_async(
SplitTextInput(inputs=inputs, extra_arg=extra_arg), SplitTextOutput).root
```
"""
imports = {"from truss_chains import stub"}
16 changes: 14 additions & 2 deletions truss-chains/truss_chains/definitions.py
@@ -397,10 +397,22 @@ def get_asset_spec(self) -> AssetSpec:


class RPCOptions(SafeModel):
"""Options to customize RPCs to dependency chainlets."""
"""Options to customize RPCs to dependency chainlets.
Args:
retries: The number of times to retry the remote chainlet in case of failures
(e.g. due to transient network issues). For streaming, retries are only made
if the request fails before streaming any results back. Failures mid-stream
are not retried.
timeout_sec: Timeout for the HTTP request to this chainlet.
use_binary: Whether to send data in binary format. This can give a parsing
speedup and message size reduction (~25%) for numpy arrays. To integrate,
use ``NumpyArrayField`` as the field type on your pydantic models and set
this option to ``True``. For simple text data, there is no significant benefit.
"""

timeout_sec: int = DEFAULT_TIMEOUT_SEC
retries: int = 1
timeout_sec: int = DEFAULT_TIMEOUT_SEC
use_binary: bool = False
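
For context, a minimal sketch of how ``use_binary`` composes with ``NumpyArrayField``, patterned on the ``numpy_and_binary`` example above; the ``Embedding``/``Producer``/``Consumer`` names and the ``NumpyArrayField`` import path are illustrative assumptions:

```
import numpy as np
import pydantic

import truss_chains as chains
from truss_chains.pydantic_numpy import NumpyArrayField  # import path assumed


class Embedding(pydantic.BaseModel):
    # Serialized as base64-in-JSON by default, as raw msgpack with `use_binary=True`.
    array: NumpyArrayField


class Producer(chains.ChainletBase):
    async def run_remote(self) -> Embedding:
        # Per the NumpyArrayField docstring, a raw ndarray is accepted here.
        return Embedding(array=np.ones((3, 2, 1)))


class Consumer(chains.ChainletBase):
    # `use_binary=True` switches this RPC to octet-stream/msgpack, saving
    # ~25% message size for numpy payloads (no benefit for plain text).
    def __init__(self, producer=chains.depends(Producer, use_binary=True)):
        self._producer = producer

    async def run_remote(self) -> Embedding:
        return await self._producer.run_remote()
```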


8 changes: 7 additions & 1 deletion truss-chains/truss_chains/public_api.py
@@ -59,8 +59,14 @@ def depends(
Args:
chainlet_cls: The chainlet class of the dependency.
retries: The number of times to retry the remote chainlet in case of failures
(e.g. due to transient network issues).
(e.g. due to transient network issues). For streaming, retries are only made
if the request fails before streaming any results back. Failures mid-stream
are not retried.
timeout_sec: Timeout for the HTTP request to this chainlet.
use_binary: Whether to send data in binary format. This can give a parsing
speedup and message size reduction (~25%) for numpy arrays. To integrate,
use ``NumpyArrayField`` as the field type on your pydantic models and set
this option to ``True``. For simple text data, there is no significant benefit.
Returns:
A "symbolic marker" to be used as a default argument in a chainlet's
99 changes: 31 additions & 68 deletions truss-chains/truss_chains/pydantic_numpy.py
@@ -1,5 +1,5 @@
import base64
from typing import TYPE_CHECKING, Any, ClassVar, TypeVar
from typing import TYPE_CHECKING, Any, ClassVar

import pydantic
from pydantic.json_schema import JsonSchemaValue
@@ -10,6 +10,36 @@


class NumpyArrayField:
"""Wrapper class to support numpy arrays as fields on pydantic models and provide
JSON or binary serialization implementations.
The JSON serialization exposes (data, shape, dtype), where the data is
base64-encoded, which leads to ~33% size overhead.
Usage example:
```
import numpy as np
class MyModel(pydantic.BaseModel):
my_array: NumpyArrayField
m = MyModel(my_array=np.arange(4).reshape((2, 2)))
m.my_array.array += 10 # Work with the numpy array.
print(m)
# my_array=NumpyArrayField(
# shape=(2, 2),
# dtype=int64,
# data=[[10 11] [12 13]])
m_json = m.model_dump(mode="json")
print(m_json)
# {'my_array': {
# 'data_b64': 'CgAAAAAAAAALAAAAAAAAAAwAAAAAAAAADQAAAAAAAAA=',
# 'shape': [2, 2],
# 'dtype': 'int64'}}
```
"""

data_key: ClassVar[str] = "data_b64"
shape_key: ClassVar[str] = "shape"
dtype_key: ClassVar[str] = "dtype"
@@ -98,70 +128,3 @@ def __get_pydantic_json_schema__(
}
)
return json_schema


if __name__ == "__main__":
import numpy as np
from truss.templates.shared import serialization

def model_dump_binary(data) -> bytes:
import msgpack
import msgpack_numpy as m

m.patch()
return msgpack.packb(data, use_bin_type=True)

_T = TypeVar("_T", bound=pydantic.BaseModel)

def deserialize_model(data: bytes):
import msgpack
import msgpack_numpy as m

m.patch() # Patch msgpack to support numpy arrays.
return msgpack.unpackb(data, raw=False)

class Header(pydantic.BaseModel):
time: float = 123.45
msg: str = "test"

class MyModel(pydantic.BaseModel):
array: NumpyArrayField
header: Header = Header()

class NestedModel(pydantic.BaseModel):
array_wrapper: MyModel

# print(NestedModel.model_json_schema())

array = np.random.random(size=(2, 10000))
nested_model = NestedModel(array_wrapper=MyModel(array=NumpyArrayField(array)))

def python_roundtrip():
data_dict = nested_model.model_dump(mode="python")
data_bytes = model_dump_binary(data_dict)
deserialize_model(data_bytes)
restored_dict = deserialize_model(data_bytes)
restored = NestedModel.model_validate(restored_dict)
return restored, len(data_bytes)

def python_roundtrip_truss():
data_dict = nested_model.model_dump(mode="python")
data_bytes = model_dump_binary(data_dict)
deserialize_model(data_bytes)
restored_dict = serialization.truss_msgpack_deserialize(data_bytes)
restored = NestedModel.model_validate(restored_dict)
return restored, len(data_bytes)

def json_roundtrip():
data_str = nested_model.model_dump_json()
restored = NestedModel.model_validate_json(data_str)
return restored, len(data_str)

# 34.9 µs ± 191 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
python_restored, python_len = python_roundtrip()
# 554 µs ± 4.66 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
json_restored, json_len = json_roundtrip()
# 554 µs ± 4.66 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
python_restored_truss, python_len_truss = python_roundtrip_truss()

print(python_len, json_len, python_len / json_len)
17 changes: 8 additions & 9 deletions truss-chains/truss_chains/stub.py
@@ -41,7 +41,7 @@


_RetryPolicyT = TypeVar("_RetryPolicyT", tenacity.AsyncRetrying, tenacity.Retrying)
_InputT = TypeVar("_InputT", pydantic.BaseModel, None, Any)
_InputT = TypeVar("_InputT", pydantic.BaseModel, Any)
_OutputT = TypeVar("_OutputT", bound=pydantic.BaseModel)


@@ -456,10 +456,13 @@ def _make_request_params(
data_key = "content" if for_httpx else "data"
kwargs[data_key] = inputs.model_dump_json()
headers["Content-Type"] = "application/json"

elif inputs is not None:
kwargs["json"] = inputs
headers["Content-Type"] = "application/json"
else: # inputs is JSON dict.
if self._service_descriptor.options.use_binary:
kwargs["data"] = serialization.truss_msgpack_serialize(inputs)
headers["Content-Type"] = "application/octet-stream"
else:
kwargs["json"] = inputs
headers["Content-Type"] = "application/json"

kwargs["headers"] = headers
return kwargs
@@ -475,10 +478,6 @@ def _response_to_pydantic(
def _response_to_json(self, response: bytes) -> Any:
if self._service_descriptor.options.use_binary:
return serialization.truss_msgpack_deserialize(response)
# import msgpack
# import msgpack_numpy as m
# m.patch()
# return msgpack.unpackb(response, raw=False)
return json.loads(response)

@overload
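
For reference, a minimal sketch of the msgpack roundtrip performed by the new binary branch above, using the ``serialization`` helpers exactly as imported in this diff; the payload dict is illustrative:

```
import numpy as np

from truss.templates.shared import serialization

payload = {"msg": "From Host", "np_array": np.ones((3, 2, 1))}

# `_make_request_params` with `use_binary=True`: msgpack-encode the JSON-style
# dict (numpy arrays included) and send it as `application/octet-stream`.
body = serialization.truss_msgpack_serialize(payload)

# `_response_to_json` with `use_binary=True`: decode the bytes back into
# Python objects, restoring the numpy array.
restored = serialization.truss_msgpack_deserialize(body)
assert np.array_equal(restored["np_array"], payload["np_array"])
```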
1 change: 1 addition & 0 deletions truss/templates/server/model_wrapper.py
@@ -63,6 +63,7 @@
Generator[bytes, None, None],
AsyncGenerator[bytes, None],
"starlette.responses.Response",
pydantic.BaseModel,
]


Expand Down
2 changes: 1 addition & 1 deletion truss/templates/server/requirements.txt
@@ -7,7 +7,7 @@ fastapi==0.114.1
joblib==1.2.0
loguru==0.7.2
msgpack-numpy==0.4.8
msgpack==1.1.0
msgpack==1.1.0 # Numpy/msgpack versions are finicky (1.0.2 breaks), double-check when changing.
numpy>=1.23.5
opentelemetry-api>=1.25.0
opentelemetry-sdk>=1.25.0
