From 7472bb681061e7f606603f697213a2a041fecf1b Mon Sep 17 00:00:00 2001 From: Marius Killinger <155577904+marius-baseten@users.noreply.github.com> Date: Mon, 2 Dec 2024 13:33:52 -0800 Subject: [PATCH] Chains Streaming, fixes BT-10339 (#1261) --- .github/workflows/pr.yml | 35 -- .../examples/streaming/streaming_chain.py | 107 +++++ truss-chains/tests/chains_e2e_test.py | 53 ++- truss-chains/tests/test_framework.py | 54 ++- truss-chains/tests/test_streaming.py | 203 +++++++++ truss-chains/truss_chains/code_gen.py | 132 ++++-- truss-chains/truss_chains/definitions.py | 25 +- truss-chains/truss_chains/framework.py | 96 ++++- truss-chains/truss_chains/model_skeleton.py | 5 +- truss-chains/truss_chains/remote.py | 8 +- truss-chains/truss_chains/streaming.py | 395 ++++++++++++++++++ truss-chains/truss_chains/stub.py | 68 ++- truss-chains/truss_chains/utils.py | 68 +-- truss/templates/server/common/schema.py | 6 +- truss/templates/server/model_wrapper.py | 23 +- truss/templates/server/truss_server.py | 6 +- truss/templates/shared/serialization.py | 23 +- 17 files changed, 1128 insertions(+), 179 deletions(-) create mode 100644 truss-chains/examples/streaming/streaming_chain.py create mode 100644 truss-chains/tests/test_streaming.py create mode 100644 truss-chains/truss_chains/streaming.py diff --git a/.github/workflows/pr.yml b/.github/workflows/pr.yml index f1955c325..963da7544 100644 --- a/.github/workflows/pr.yml +++ b/.github/workflows/pr.yml @@ -51,38 +51,3 @@ jobs: with: use-verbose-mode: "yes" folder-path: "docs" - - enforce-chains-example-docs-sync: - runs-on: ubuntu-20.04 - steps: - - uses: actions/checkout@v4 - with: - lfs: true - fetch-depth: 2 - - - name: Fetch main branch - run: git fetch origin main - - - name: Check if chains examples were modified - id: check_files - run: | - if git diff --name-only origin/main | grep -q '^truss-chains/examples/.*'; then - echo "chains_docs_update_needed=true" >> $GITHUB_ENV - echo "Chains examples were modified." - else - echo "chains_docs_update_needed=false" >> $GITHUB_ENV - echo "Chains examples were not modified." - echo "::notice file=truss-chains/examples/::Chains examples not modified." - fi - - - name: Enforce acknowledgment in PR description - if: env.chains_docs_update_needed == 'true' - env: - DESCRIPTION: ${{ github.event.pull_request.body }} - run: | - if [[ "$DESCRIPTION" != *"UPDATE_DOCS=done"* && "$DESCRIPTION" != *"UPDATE_DOCS=not_needed"* ]]; then - echo "::error file=truss-chains/examples/::Chains examples were modified and ack not found in PR description. Verify whether docs need to be update (https://github.com/basetenlabs/docs.baseten.co/tree/main/chains) and add an ack tag `UPDATE_DOCS={done|not_needed}` to the PR description." - exit 1 - else - echo "::notice file=truss-chains/examples/::Chains examples modified and ack found int PR description." 
- fi diff --git a/truss-chains/examples/streaming/streaming_chain.py b/truss-chains/examples/streaming/streaming_chain.py new file mode 100644 index 000000000..4b1b2488d --- /dev/null +++ b/truss-chains/examples/streaming/streaming_chain.py @@ -0,0 +1,107 @@ +import asyncio +import time +from typing import AsyncIterator + +import pydantic + +import truss_chains as chains +from truss_chains import streaming + + +class Header(pydantic.BaseModel): + time: float + msg: str + + +class MyDataChunk(pydantic.BaseModel): + words: list[str] + + +class Footer(pydantic.BaseModel): + time: float + duration_sec: float + msg: str + + +class ConsumerOutput(pydantic.BaseModel): + header: Header + chunks: list[MyDataChunk] + footer: Footer + strings: str + + +STREAM_TYPES = streaming.stream_types( + MyDataChunk, header_type=Header, footer_type=Footer +) + + +class Generator(chains.ChainletBase): + """Example that streams fully structured pydantic items with header and footer.""" + + async def run_remote(self) -> AsyncIterator[bytes]: + print("Entering Generator") + streamer = streaming.stream_writer(STREAM_TYPES) + header = Header(time=time.time(), msg="Start.") + yield streamer.yield_header(header) + for i in range(1, 5): + data = MyDataChunk( + words=[chr(x + 70) * x for x in range(1, i + 1)], + ) + print("Yield") + yield streamer.yield_item(data) + await asyncio.sleep(0.05) + + end_time = time.time() + footer = Footer(time=end_time, duration_sec=end_time - header.time, msg="Done.") + yield streamer.yield_footer(footer) + print("Exiting Generator") + + +class StringGenerator(chains.ChainletBase): + """Minimal streaming example with strings (e.g. for raw LLM output).""" + + async def run_remote(self) -> AsyncIterator[str]: + # Note: the "chunk" boundaries are lost, when streaming raw strings. You must + # add spaces and linebreaks to the items yourself.. + yield "First " + yield "second " + yield "last." 
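# --- Editor's note (illustrative sketch, not part of the original patch) --------------
# Each `streamer.yield_*` call above returns an already-framed `bytes` chunk. Based on
# `truss_chains/streaming.py` introduced later in this PR, a frame is a 5-byte tag
# (uint8 delimiter + big-endian uint32 payload length) followed by the pydantic model's
# JSON. A hypothetical manual decode of one such chunk (the name `chunk` is assumed,
# purely for illustration):
#
#   import struct
#   delimiter, length = struct.unpack(">BI", chunk[:5])   # frame kind + payload size
#   payload = chunk[5 : 5 + length]                        # JSON-encoded pydantic model
#   item = MyDataChunk.model_validate_json(payload)
#
# In practice, consumers should use `streaming.stream_reader`, as the `Consumer`
# chainlet below does, rather than unpacking frames by hand.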
+ + +class Consumer(chains.ChainletBase): + """Consume that reads the raw streams and parses them.""" + + def __init__( + self, + generator=chains.depends(Generator), + string_generator=chains.depends(StringGenerator), + ): + self._generator = generator + self._string_generator = string_generator + + async def run_remote(self) -> ConsumerOutput: + print("Entering Consumer") + reader = streaming.stream_reader(STREAM_TYPES, self._generator.run_remote()) + print("Consuming...") + header = await reader.read_header() + chunks = [] + async for data in reader.read_items(): + print(f"Read: {data}") + chunks.append(data) + + footer = await reader.read_footer() + strings = [] + async for part in self._string_generator.run_remote(): + strings.append(part) + + print("Exiting Consumer") + return ConsumerOutput( + header=header, chunks=chunks, footer=footer, strings="".join(strings) + ) + + +if __name__ == "__main__": + with chains.run_local(): + chain = Consumer() + result = asyncio.run(chain.run_remote()) + print(result) diff --git a/truss-chains/tests/chains_e2e_test.py b/truss-chains/tests/chains_e2e_test.py index a64adc6f1..29d7ca894 100644 --- a/truss-chains/tests/chains_e2e_test.py +++ b/truss-chains/tests/chains_e2e_test.py @@ -13,8 +13,8 @@ @pytest.mark.integration def test_chain(): with ensure_kill_all(): - root = Path(__file__).parent.resolve() - chain_root = root / "itest_chain" / "itest_chain.py" + tests_root = Path(__file__).parent.resolve() + chain_root = tests_root / "itest_chain" / "itest_chain.py" with framework.import_target(chain_root, "ItestChain") as entrypoint: options = definitions.PushOptionsLocalDocker( chain_name="integration-test", use_local_chains_src=True @@ -81,8 +81,8 @@ def test_chain(): @pytest.mark.asyncio async def test_chain_local(): - root = Path(__file__).parent.resolve() - chain_root = root / "itest_chain" / "itest_chain.py" + tests_root = Path(__file__).parent.resolve() + chain_root = tests_root / "itest_chain" / "itest_chain.py" with framework.import_target(chain_root, "ItestChain") as entrypoint: with public_api.run_local(): with pytest.raises(ValueError): @@ -119,3 +119,48 @@ async def test_chain_local(): match="Chainlets cannot be naively instantiated", ): await entrypoint().run_remote(length=20, num_partitions=5) + + +@pytest.mark.integration +def test_streaming_chain(): + examples_root = Path(__file__).parent.parent.resolve() / "examples" + chain_root = examples_root / "streaming" / "streaming_chain.py" + with framework.import_target(chain_root, "Consumer") as entrypoint: + service = remote.push( + entrypoint, + options=definitions.PushOptionsLocalDocker( + chain_name="stream", + only_generate_trusses=False, + use_local_chains_src=True, + ), + ) + assert service is not None + response = service.run_remote({}) + assert response.status_code == 200 + print(response.json()) + result = response.json() + print(result) + assert result["header"]["msg"] == "Start." + assert result["chunks"][0]["words"] == ["G"] + assert result["chunks"][1]["words"] == ["G", "HH"] + assert result["chunks"][2]["words"] == ["G", "HH", "III"] + assert result["chunks"][3]["words"] == ["G", "HH", "III", "JJJJ"] + assert result["footer"]["duration_sec"] > 0 + assert result["strings"] == "First second last." 
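# Editor's note (not part of the original patch): the expected chunk contents follow
# directly from the example's Generator chainlet, which builds chunk i as
# `[chr(x + 70) * x for x in range(1, i + 1)]` for i = 1..4:
#
#   >>> [[chr(x + 70) * x for x in range(1, i + 1)] for i in range(1, 5)]
#   [['G'], ['G', 'HH'], ['G', 'HH', 'III'], ['G', 'HH', 'III', 'JJJJ']]
#
# (chr(71) == "G", chr(72) == "H", ...), which is what the assertions above and in the
# local variant below verify.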
+ + +@pytest.mark.asyncio +async def test_streaming_chain_local(): + examples_root = Path(__file__).parent.parent.resolve() / "examples" + chain_root = examples_root / "streaming" / "streaming_chain.py" + with framework.import_target(chain_root, "Consumer") as entrypoint: + with public_api.run_local(): + result = await entrypoint().run_remote() + print(result) + assert result.header.msg == "Start." + assert result.chunks[0].words == ["G"] + assert result.chunks[1].words == ["G", "HH"] + assert result.chunks[2].words == ["G", "HH", "III"] + assert result.chunks[3].words == ["G", "HH", "III", "JJJJ"] + assert result.footer.duration_sec > 0 + assert result.strings == "First second last." diff --git a/truss-chains/tests/test_framework.py b/truss-chains/tests/test_framework.py index 5f33a3c00..c29324606 100644 --- a/truss-chains/tests/test_framework.py +++ b/truss-chains/tests/test_framework.py @@ -2,7 +2,7 @@ import contextlib import logging import re -from typing import List +from typing import AsyncIterator, Iterator, List import pydantic import pytest @@ -505,3 +505,55 @@ def run_remote(argument: object): ... with pytest.raises(definitions.ChainsUsageError, match=match), _raise_errors(): with public_api.run_local(): MultiIssue() + + +def test_raises_iterator_no_yield(): + match = ( + rf"{TEST_FILE}:\d+ \(IteratorNoYield\.run_remote\) \[kind: IO_TYPE_ERROR\].*" + r"If the endpoint returns an iterator \(streaming\), it must have `yield` statements" + ) + + with pytest.raises(definitions.ChainsUsageError, match=match), _raise_errors(): + + class IteratorNoYield(chains.ChainletBase): + async def run_remote(self) -> AsyncIterator[str]: + return "123" # type: ignore[return-value] + + +def test_raises_yield_no_iterator(): + match = ( + rf"{TEST_FILE}:\d+ \(YieldNoIterator\.run_remote\) \[kind: IO_TYPE_ERROR\].*" + r"If the endpoint is streaming \(has `yield` statements\), the return type must be an iterator" + ) + + with pytest.raises(definitions.ChainsUsageError, match=match), _raise_errors(): + + class YieldNoIterator(chains.ChainletBase): + async def run_remote(self) -> str: # type: ignore[misc] + yield "123" + + +def test_raises_iterator_sync(): + match = ( + rf"{TEST_FILE}:\d+ \(IteratorSync\.run_remote\) \[kind: IO_TYPE_ERROR\].*" + r"Streaming endpoints \(containing `yield` statements\) are only supported for async endpoints" + ) + + with pytest.raises(definitions.ChainsUsageError, match=match), _raise_errors(): + + class IteratorSync(chains.ChainletBase): + def run_remote(self) -> Iterator[str]: + yield "123" + + +def test_raises_iterator_no_arg(): + match = ( + rf"{TEST_FILE}:\d+ \(IteratorNoArg\.run_remote\) \[kind: IO_TYPE_ERROR\].*" + r"Iterators must be annotated with type \(one of \['str', 'bytes'\]\)" + ) + + with pytest.raises(definitions.ChainsUsageError, match=match), _raise_errors(): + + class IteratorNoArg(chains.ChainletBase): + async def run_remote(self) -> AsyncIterator: + yield "123" diff --git a/truss-chains/tests/test_streaming.py b/truss-chains/tests/test_streaming.py new file mode 100644 index 000000000..88dd5421a --- /dev/null +++ b/truss-chains/tests/test_streaming.py @@ -0,0 +1,203 @@ +import asyncio +from typing import AsyncIterator + +import pydantic +import pytest + +from truss_chains import streaming + + +class Header(pydantic.BaseModel): + time: float + msg: str + + +class MyDataChunk(pydantic.BaseModel): + words: list[str] + + +class Footer(pydantic.BaseModel): + time: float + duration_sec: float + msg: str + + +async def to_bytes_iterator(data_stream) -> 
AsyncIterator[bytes]: + for data in data_stream: + yield data + await asyncio.sleep(0) + + +@pytest.mark.asyncio +async def test_streaming_with_header_and_footer(): + types = streaming.stream_types( + item_type=MyDataChunk, header_type=Header, footer_type=Footer + ) + + writer = streaming.stream_writer(types) + header = Header(time=123.456, msg="Start of stream") + items = [ + MyDataChunk(words=["hello", "world"]), + MyDataChunk(words=["foo", "bar"]), + MyDataChunk(words=["baz"]), + ] + footer = Footer(time=789.012, duration_sec=665.556, msg="End of stream") + + data_stream = [] + data_stream.append(writer.yield_header(header)) + for item in items: + data_stream.append(writer.yield_item(item)) + data_stream.append(writer.yield_footer(footer)) + + reader = streaming.stream_reader(types, to_bytes_iterator(data_stream)) + # Assert that serialization roundtrip works. + read_header = await reader.read_header() + assert read_header == header + read_items = [] + async for item in reader.read_items(): + read_items.append(item) + assert read_items == items + read_footer = await reader.read_footer() + assert read_footer == footer + + +@pytest.mark.asyncio +async def test_streaming_with_items_only(): + types = streaming.stream_types(item_type=MyDataChunk) + writer = streaming.stream_writer(types) + + items = [ + MyDataChunk(words=["hello", "world"]), + MyDataChunk(words=["foo", "bar"]), + MyDataChunk(words=["baz"]), + ] + + data_stream = [] + for item in items: + data_stream.append(writer.yield_item(item)) + + reader = streaming.stream_reader(types, to_bytes_iterator(data_stream)) + read_items = [] + async for item in reader.read_items(): + read_items.append(item) + + assert read_items == items + + +@pytest.mark.asyncio +async def test_reading_header_when_none_sent(): + types = streaming.stream_types(item_type=MyDataChunk, header_type=Header) + writer = streaming.stream_writer(types) + items = [MyDataChunk(words=["hello", "world"])] + + data_stream = [] + for item in items: + data_stream.append(writer.yield_item(item)) + + reader = streaming.stream_reader(types, to_bytes_iterator(data_stream)) + with pytest.raises(ValueError, match="Stream does not contain header."): + await reader.read_header() + + +@pytest.mark.asyncio +async def test_reading_items_with_wrong_model(): + types_writer = streaming.stream_types(item_type=MyDataChunk) + types_reader = streaming.stream_types(item_type=Header) # Wrong item type + writer = streaming.stream_writer(types_writer) + items = [MyDataChunk(words=["hello", "world"])] + data_stream = [] + for item in items: + data_stream.append(writer.yield_item(item)) + + reader = streaming.stream_reader(types_reader, to_bytes_iterator(data_stream)) + + with pytest.raises(pydantic.ValidationError): + async for item in reader.read_items(): + pass + + +@pytest.mark.asyncio +async def test_streaming_with_wrong_order(): + types = streaming.stream_types( + item_type=MyDataChunk, + header_type=Header, + footer_type=Footer, + ) + + writer = streaming.stream_writer(types) + header = Header(time=123.456, msg="Start of stream") + items = [MyDataChunk(words=["hello", "world"])] + footer = Footer(time=789.012, duration_sec=665.556, msg="End of stream") + data_stream = [] + for item in items: + data_stream.append(writer.yield_item(item)) + + with pytest.raises( + ValueError, match="Cannot yield header after other data has been sent." 
+ ): + data_stream.append(writer.yield_header(header)) + data_stream.append(writer.yield_footer(footer)) + + reader = streaming.stream_reader(types, to_bytes_iterator(data_stream)) + # Try to read header, should fail because the first data is an item + with pytest.raises(ValueError, match="Stream does not contain header."): + await reader.read_header() + + +@pytest.mark.asyncio +async def test_reading_items_without_consuming_header(): + types = streaming.stream_types(item_type=MyDataChunk, header_type=Header) + writer = streaming.stream_writer(types) + header = Header(time=123.456, msg="Start of stream") + items = [MyDataChunk(words=["hello", "world"])] + + data_stream = [] + data_stream.append(writer.yield_header(header)) + for item in items: + data_stream.append(writer.yield_item(item)) + + reader = streaming.stream_reader(types, to_bytes_iterator(data_stream)) + # Try to read items without consuming header + with pytest.raises( + ValueError, + match="Called `read_items`, but there the stream contains header data", + ): + async for item in reader.read_items(): + pass + + +@pytest.mark.asyncio +async def test_reading_footer_when_none_sent(): + types = streaming.stream_types(item_type=MyDataChunk, footer_type=Footer) + writer = streaming.stream_writer(types) + items = [MyDataChunk(words=["hello", "world"])] + data_stream = [] + for item in items: + data_stream.append(writer.yield_item(item)) + + reader = streaming.stream_reader(types, to_bytes_iterator(data_stream)) + read_items = [] + async for item in reader.read_items(): + read_items.append(item) + assert read_items == items + + # Try to read footer, expect an error + with pytest.raises(ValueError, match="Stream does not contain footer."): + await reader.read_footer() + + +@pytest.mark.asyncio +async def test_reading_footer_with_no_items(): + types = streaming.stream_types(item_type=MyDataChunk, footer_type=Footer) + writer = streaming.stream_writer(types) + footer = Footer(time=789.012, duration_sec=665.556, msg="End of stream") + data_stream = [writer.yield_footer(footer)] + + reader = streaming.stream_reader(types, to_bytes_iterator(data_stream)) + read_items = [] + async for item in reader.read_items(): + read_items.append(item) + assert len(read_items) == 0 + + read_footer = await reader.read_footer() + assert read_footer == footer diff --git a/truss-chains/truss_chains/code_gen.py b/truss-chains/truss_chains/code_gen.py index 832e7c524..6ec2e98ca 100644 --- a/truss-chains/truss_chains/code_gen.py +++ b/truss-chains/truss_chains/code_gen.py @@ -93,7 +93,7 @@ def _update_src(new_source: _Source, src_parts: list[str], imports: set[str]) -> imports.update(new_source.imports) -def _gen_import_and_ref(raw_type: Any) -> _Source: +def _gen_pydantic_import_and_ref(raw_type: Any) -> _Source: """Returns e.g. ("from sub_package import module", "module.OutputType").""" if raw_type.__module__ == "__main__": # TODO: assuming that main is copied into package dir and can be imported. @@ -122,7 +122,7 @@ def _gen_import_and_ref(raw_type: Any) -> _Source: def _gen_type_import_and_ref(type_descr: definitions.TypeDescriptor) -> _Source: """Returns e.g. 
("from sub_package import module", "module.OutputType").""" if type_descr.is_pydantic: - return _gen_import_and_ref(type_descr.raw) + return _gen_pydantic_import_and_ref(type_descr.raw) elif isinstance(type_descr.raw, type): if not type_descr.raw.__module__ == "builtins": @@ -134,11 +134,21 @@ def _gen_type_import_and_ref(type_descr: definitions.TypeDescriptor) -> _Source: return _Source(src=str(type_descr.raw)) +def _gen_streaming_type_import_and_ref( + stream_type: definitions.StreamingTypeDescriptor, +) -> _Source: + """Unlike other `_gen`-helpers, this does not define a type, it creates a symbol.""" + mod = stream_type.origin_type.__module__ + arg = stream_type.arg_type.__name__ + type_src = f"{mod}.{stream_type.origin_type.__name__}[{arg}]" + return _Source(src=type_src, imports={f"import {mod}"}) + + def _gen_chainlet_import_and_ref( chainlet_descriptor: definitions.ChainletAPIDescriptor, ) -> _Source: """Returns e.g. ("from sub_package import module", "module.OutputType").""" - return _gen_import_and_ref(chainlet_descriptor.chainlet_cls) + return _gen_pydantic_import_and_ref(chainlet_descriptor.chainlet_cls) # I/O used by Stubs and Truss models ################################################### @@ -206,28 +216,30 @@ async def run_remote( ) -> tuple[shared_chainlet.SplitTextOutput, int]: ``` """ - if endpoint.is_generator: - raise NotImplementedError("Generator.") - imports = set() - args = [] + args = ["self"] for arg in endpoint.input_args: arg_ref = _gen_type_import_and_ref(arg.type) imports.update(arg_ref.imports) args.append(f"{arg.name}: {arg_ref.src}") - outputs: list[str] = [] - for output_type in endpoint.output_types: - _update_src(_gen_type_import_and_ref(output_type), outputs, imports) - - if len(outputs) == 1: - output = outputs[0] + if endpoint.is_streaming: + streaming_src = _gen_streaming_type_import_and_ref(endpoint.streaming_type) + imports.update(streaming_src.imports) + output = streaming_src.src else: - output = f"tuple[{', '.join(outputs)}]" + outputs: list[str] = [] + for output_type in endpoint.output_types: + _update_src(_gen_type_import_and_ref(output_type), outputs, imports) + + if len(outputs) == 1: + output = outputs[0] + else: + output = f"tuple[{', '.join(outputs)}]" def_str = "async def" if endpoint.is_async else "def" return _Source( - src=f"{def_str} {endpoint.name}(self, {','.join(args)}) -> {output}:", + src=f"{def_str} {endpoint.name}({','.join(args)}) -> {output}:", imports=imports, ) @@ -244,23 +256,42 @@ def _stub_endpoint_body_src( return SplitTextOutput.model_validate(json_result).output ``` """ - if endpoint.is_generator: - raise NotImplementedError("Generator") - imports: set[str] = set() args = [f"{arg.name}={arg.name}" for arg in endpoint.input_args] - inputs = f"{_get_input_model_name(chainlet_name)}({', '.join(args)}).model_dump()" + if args: + inputs = ( + f"{_get_input_model_name(chainlet_name)}({', '.join(args)}).model_dump()" + ) + else: + inputs = "{}" + parts = [] # Invoke remote. - if endpoint.is_async: - remote_call = f"await self._remote.predict_async({inputs})" + if not endpoint.is_streaming: + if endpoint.is_async: + remote_call = f"await self._remote.predict_async({inputs})" + else: + remote_call = f"self._remote.predict_sync({inputs})" + + parts = [f"json_result = {remote_call}"] + # Unpack response and parse as pydantic models if needed. 
+ output_model_name = _get_output_model_name(chainlet_name) + parts.append(f"return {output_model_name}.model_validate(json_result).root") else: - remote_call = f"self._remote.predict_sync({inputs})" + if endpoint.is_async: + parts.append( + f"async for data in await self._remote.predict_async_stream({inputs}):", + ) + if endpoint.streaming_type.is_string: + parts.append(_indent("yield data.decode()")) + else: + parts.append(_indent("yield data")) + else: + raise NotImplementedError( + "`Streaming endpoints (containing `yield` statements) are only " + "supported for async endpoints." + ) - parts = [f"json_result = {remote_call}"] - # Unpack response and parse as pydantic models if needed. - output_model_name = _get_output_model_name(chainlet_name) - parts.append(f"return {output_model_name}.model_validate(json_result).root") return _Source(src="\n".join(parts), imports=imports) @@ -290,8 +321,9 @@ async def run_remote( src_parts: list[str] = [] input_src = _gen_truss_input_pydantic(chainlet) _update_src(input_src, src_parts, imports) - output_src = _gen_truss_output_pydantic(chainlet) - _update_src(output_src, src_parts, imports) + if not chainlet.endpoint.is_streaming: + output_src = _gen_truss_output_pydantic(chainlet) + _update_src(output_src, src_parts, imports) signature = _stub_endpoint_signature_src(chainlet.endpoint) imports.update(signature.imports) body = _stub_endpoint_body_src(chainlet.endpoint, chainlet.name) @@ -396,42 +428,51 @@ def _gen_load_src(chainlet_descriptor: definitions.ChainletAPIDescriptor) -> _So def _gen_predict_src(chainlet_descriptor: definitions.ChainletAPIDescriptor) -> _Source: """Generates AST for the `predict` method of the truss model.""" - if chainlet_descriptor.endpoint.is_generator: - raise NotImplementedError("Generator.") - imports: set[str] = {"from truss_chains import utils"} parts: list[str] = [] def_str = "async def" if chainlet_descriptor.endpoint.is_async else "def" input_model_name = _get_input_model_name(chainlet_descriptor.name) - output_model_name = _get_output_model_name(chainlet_descriptor.name) + if chainlet_descriptor.endpoint.is_streaming: + streaming_src = _gen_streaming_type_import_and_ref( + chainlet_descriptor.endpoint.streaming_type + ) + imports.update(streaming_src.imports) + output_type_name = streaming_src.src + else: + output_type_name = _get_output_model_name(chainlet_descriptor.name) imports.add("import starlette.requests") imports.add("from truss_chains import stub") parts.append( f"{def_str} predict(self, inputs: {input_model_name}, " - f"request: starlette.requests.Request) -> {output_model_name}:" + f"request: starlette.requests.Request) -> {output_type_name}:" ) # Add error handling context manager: parts.append( _indent( f"with stub.trace_parent(request), utils.exception_to_http_error(" - f'include_stack=True, chainlet_name="{chainlet_descriptor.name}"):' + f'chainlet_name="{chainlet_descriptor.name}"):' ) ) # Invoke Chainlet. - maybe_await = "await " if chainlet_descriptor.endpoint.is_async else "" + if ( + chainlet_descriptor.endpoint.is_async + and not chainlet_descriptor.endpoint.is_streaming + ): + maybe_await = "await " + else: + maybe_await = "" run_remote = chainlet_descriptor.endpoint.name - # `exclude_unset` is important to handle arguments where `run_remote` has a default - # correctly. In that case the pydantic model has an optional field and defaults to - # `None`. But there might also be situations where the user explicitly passes a - # value of `None`. 
So the condition whether to pass that argument or not is - # whether it was *set* in the model. It is considered unset, if the incoming JSON - # (from which the model was parsed/initialized) does not have that key. + # See docs of `pydantic_set_field_dict` for why this is needed. args = "**utils.pydantic_set_field_dict(inputs)" parts.append( _indent(f"result = {maybe_await}self._chainlet.{run_remote}({args})", 2) ) - result_pydantic = f"{output_model_name}(result)" - parts.append(_indent(f"return {result_pydantic}")) + if chainlet_descriptor.endpoint.is_streaming: + # Streaming returns raw iterator, no pydantic model. + parts.append(_indent("return result")) + else: + result_pydantic = f"{output_type_name}(result)" + parts.append(_indent(f"return {result_pydantic}")) return _Source(src="\n".join(parts), imports=imports) @@ -496,8 +537,9 @@ def _gen_truss_chainlet_file( input_src = _gen_truss_input_pydantic(chainlet_descriptor) _update_src(input_src, src_parts, imports) - output_src = _gen_truss_output_pydantic(chainlet_descriptor) - _update_src(output_src, src_parts, imports) + if not chainlet_descriptor.endpoint.is_streaming: + output_src = _gen_truss_output_pydantic(chainlet_descriptor) + _update_src(output_src, src_parts, imports) model_src = _gen_truss_chainlet_model(chainlet_descriptor) _update_src(model_src, src_parts, imports) diff --git a/truss-chains/truss_chains/definitions.py b/truss-chains/truss_chains/definitions.py index efe3c1095..0510f9f4c 100644 --- a/truss-chains/truss_chains/definitions.py +++ b/truss-chains/truss_chains/definitions.py @@ -524,6 +524,19 @@ def is_pydantic(self) -> bool: ) +class StreamingTypeDescriptor(TypeDescriptor): + origin_type: type + arg_type: type + + @property + def is_string(self) -> bool: + return self.arg_type is str + + @property + def is_pydantic(self) -> bool: + return False + + class InputArg(SafeModelNonSerializable): name: str type: TypeDescriptor @@ -535,7 +548,17 @@ class EndpointAPIDescriptor(SafeModelNonSerializable): input_args: list[InputArg] output_types: list[TypeDescriptor] is_async: bool - is_generator: bool + is_streaming: bool + + @property + def streaming_type(self) -> StreamingTypeDescriptor: + if ( + not self.is_streaming + or len(self.output_types) != 1 + or not isinstance(self.output_types[0], StreamingTypeDescriptor) + ): + raise ValueError(f"{self} is not a streaming endpoint.") + return cast(StreamingTypeDescriptor, self.output_types[0]) class DependencyDescriptor(SafeModelNonSerializable): diff --git a/truss-chains/truss_chains/framework.py b/truss-chains/truss_chains/framework.py index 771ee73d0..db2f822fa 100644 --- a/truss-chains/truss_chains/framework.py +++ b/truss-chains/truss_chains/framework.py @@ -36,11 +36,13 @@ _SIMPLE_TYPES = {int, float, complex, bool, str, bytes, None} _SIMPLE_CONTAINERS = {list, dict} +_STREAM_TYPES = {bytes, str} _DOCS_URL_CHAINING = ( "https://docs.baseten.co/chains/concepts#depends-call-other-chainlets" ) _DOCS_URL_LOCAL = "https://docs.baseten.co/chains/guide#local-development" +_DOCS_URL_STREAMING = "https://docs.baseten.co/chains/guide#streaming" _ENTRYPOINT_ATTR_NAME = "_chains_entrypoint" @@ -48,6 +50,7 @@ _P = ParamSpec("_P") _R = TypeVar("_R") + # Error Collector ###################################################################### @@ -296,6 +299,38 @@ def _validate_io_type( _collect_error(error_msg, _ErrorKind.IO_TYPE_ERROR, location) +def _validate_streaming_output_type( + annotation: Any, location: _ErrorLocation +) -> definitions.StreamingTypeDescriptor: + origin = 
get_origin(annotation) + assert origin in (collections.abc.AsyncIterator, collections.abc.Iterator) + args = get_args(annotation) + if len(args) < 1: + _collect_error( + f"Iterators must be annotated with type (one of {list(x.__name__ for x in _STREAM_TYPES)}).", + _ErrorKind.IO_TYPE_ERROR, + location, + ) + return definitions.StreamingTypeDescriptor( + raw=annotation, origin_type=origin, arg_type=bytes + ) + + assert len(args) == 1, "Iterator type annotations cannot have more than 1 arg." + arg = args[0] + if arg not in _STREAM_TYPES: + msg = ( + "Streaming endpoints (containing `yield` statements) can only yield string " + "or byte items. For streaming structured pydantic data, use `stream_writer`" + "and `stream_reader` helpers.\n" + f"See streaming docs: {_DOCS_URL_STREAMING}" + ) + _collect_error(msg, _ErrorKind.IO_TYPE_ERROR, location) + + return definitions.StreamingTypeDescriptor( + raw=annotation, origin_type=origin, arg_type=arg + ) + + def _validate_endpoint_params( params: list[inspect.Parameter], location: _ErrorLocation ) -> list[definitions.InputArg]: @@ -336,8 +371,9 @@ def _validate_endpoint_params( def _validate_endpoint_output_types( - annotation: Any, signature, location: _ErrorLocation + annotation: Any, signature, location: _ErrorLocation, is_streaming: bool ) -> list[definitions.TypeDescriptor]: + has_streaming_type = False if annotation == inspect.Parameter.empty: _collect_error( "Return values of endpoints must be type annotated. Got:\n" @@ -346,14 +382,36 @@ def _validate_endpoint_output_types( location, ) return [] - if get_origin(annotation) is tuple: + origin = get_origin(annotation) + if origin is tuple: output_types = [] for i, arg in enumerate(get_args(annotation)): _validate_io_type(arg, f"return_type[{i}]", location) output_types.append(definitions.TypeDescriptor(raw=arg)) + + elif origin in (collections.abc.AsyncIterator, collections.abc.Iterator): + output_types = [_validate_streaming_output_type(annotation, location)] + has_streaming_type = True + if not is_streaming: + _collect_error( + "If the endpoint returns an iterator (streaming), it must have `yield` " + "statements.", + _ErrorKind.IO_TYPE_ERROR, + location, + ) else: _validate_io_type(annotation, "return_type", location) output_types = [definitions.TypeDescriptor(raw=annotation)] + + if is_streaming and not has_streaming_type: + _collect_error( + "If the endpoint is streaming (has `yield` statements), the return type " + "must be an iterator (e.g. `AsyncIterator[bytes]`). Got:\n" + f"\t{location.method_name}{signature} -> {annotation}", + _ErrorKind.IO_TYPE_ERROR, + location, + ) + return output_types @@ -384,7 +442,7 @@ def _validate_and_describe_endpoint( # Return a "neutral dummy" if validation fails, this allows to safely # continue checking for more errors. return definitions.EndpointAPIDescriptor( - input_args=[], output_types=[], is_async=False, is_generator=False + input_args=[], output_types=[], is_async=False, is_streaming=False ) # This is the unbound method. @@ -402,26 +460,38 @@ def _validate_and_describe_endpoint( # Return a "neutral dummy" if validation fails, this allows to safely # continue checking for more errors. 
return definitions.EndpointAPIDescriptor( - input_args=[], output_types=[], is_async=False, is_generator=False + input_args=[], output_types=[], is_async=False, is_streaming=False ) signature = inspect.signature(endpoint_method) input_args = _validate_endpoint_params( list(signature.parameters.values()), location ) - output_types = _validate_endpoint_output_types( - signature.return_annotation, signature, location - ) - if inspect.isasyncgenfunction(endpoint_method): is_async = True - is_generator = True + is_streaming = True elif inspect.iscoroutinefunction(endpoint_method): is_async = True - is_generator = False + is_streaming = False else: is_async = False - is_generator = inspect.isgeneratorfunction(endpoint_method) + is_streaming = inspect.isgeneratorfunction(endpoint_method) + + output_types = _validate_endpoint_output_types( + signature.return_annotation, + signature, + location, + is_streaming, + ) + + if is_streaming: + if not is_async: + _collect_error( + "`Streaming endpoints (containing `yield` statements) are only " + "supported for async endpoints.", + _ErrorKind.IO_TYPE_ERROR, + location, + ) if not is_async: warnings.warn( @@ -446,7 +516,7 @@ def _validate_and_describe_endpoint( input_args=input_args, output_types=output_types, is_async=is_async, - is_generator=is_generator, + is_streaming=is_streaming, ) @@ -995,7 +1065,7 @@ def __init_local__(self: definitions.ABCChainlet, **kwargs) -> None: assert chainlet_cls._init_is_patched # Dependency chainlets are instantiated here, using their __init__ # that is patched for local. - logging.warning(f"Making first {dep.name}.") + logging.info(f"Making first {dep.name}.") instance = chainlet_cls() # type: ignore # Here init args are patched. cls_to_instance[chainlet_cls] = instance kwargs_mod[arg_name] = instance diff --git a/truss-chains/truss_chains/model_skeleton.py b/truss-chains/truss_chains/model_skeleton.py index 6f637e8d9..4aa053178 100644 --- a/truss-chains/truss_chains/model_skeleton.py +++ b/truss-chains/truss_chains/model_skeleton.py @@ -16,9 +16,8 @@ def __init__( config: dict, data_dir: pathlib.Path, secrets: secrets_resolver.Secrets, - environment: Optional[ - dict - ] = None, # TODO: Remove the default value once all truss versions are synced up. + # TODO: Remove the default value once all truss versions are synced up. + environment: Optional[dict] = None, ) -> None: truss_metadata: definitions.TrussMetadata = ( definitions.TrussMetadata.model_validate( diff --git a/truss-chains/truss_chains/remote.py b/truss-chains/truss_chains/remote.py index 91304c4ba..2b5f73863 100644 --- a/truss-chains/truss_chains/remote.py +++ b/truss-chains/truss_chains/remote.py @@ -44,8 +44,7 @@ class DockerTrussService(b10_service.TrussService): """This service is for Chainlets (not for Chains).""" def __init__(self, port: int, is_draft: bool, **kwargs): - # http://localhost:{port} seems to only work *sometimes* with docker. 
- remote_url = f"http://host.docker.internal:{port}" + remote_url = f"http://localhost:{port}" self._port = port super().__init__(remote_url, is_draft, **kwargs) @@ -411,8 +410,11 @@ def push( is_draft=True, port=port, ) + docker_internal_url = service.predict_url.replace( + "localhost", "host.docker.internal" + ) chainlet_to_predict_url[chainlet_artifact.display_name] = { - "predict_url": service.predict_url, + "predict_url": docker_internal_url, } chainlet_to_service[chainlet_artifact.name] = service diff --git a/truss-chains/truss_chains/streaming.py b/truss-chains/truss_chains/streaming.py new file mode 100644 index 000000000..9d9a1cae8 --- /dev/null +++ b/truss-chains/truss_chains/streaming.py @@ -0,0 +1,395 @@ +import asyncio +import dataclasses +import enum +import struct +import sys +from collections.abc import AsyncIterator +from typing import Generic, Optional, Protocol, Type, TypeVar, Union, overload + +import pydantic + +_TAG_SIZE = 5 # uint8 + uint32. +_JSONType = Union[ + str, int, float, bool, None, list["_JSONType"], dict[str, "_JSONType"] +] +_T = TypeVar("_T") + +if sys.version_info < (3, 10): + + async def anext(iterable: AsyncIterator[_T]) -> _T: + return await iterable.__anext__() + + +# Note on the (verbose) typing in this module: we want exact typing of the reader and +# writer helpers, while also allowing flexibility to users to leave out header/footer +# if not needed. +# Putting both a constraint on the header/footer types to be pydantic +# models, but also letting them be optional is not well-supported by typing tools, +# (missing feature is using type variables a constraints on other type variables). +# +# A functional, yet verbose workaround that gives correct variadic type inference, +# is using intermediate type variables `HeaderT` <-> `HeaderTT` and in conjunction with +# mapping out all usage combinations with overloads (the overloads essentially allow +# "conditional" binding of type vars). These overloads also allow to use granular +# reader/writer sub-classes conditionally, that have the read/write methods only for the +# data types configured, and implemented DRY with mixin classes. +ItemT = TypeVar("ItemT", bound=pydantic.BaseModel) +HeaderT = TypeVar("HeaderT", bound=pydantic.BaseModel) +FooterT = TypeVar("FooterT", bound=pydantic.BaseModel) + +# Since header/footer could also be `None`, we need an extra type variable that +# can assume either `Type[HeaderT]` or `None` - `Type[None]` causes issues. +HeaderTT = TypeVar("HeaderTT") +FooterTT = TypeVar("FooterTT") + + +@dataclasses.dataclass +class StreamTypes(Generic[ItemT, HeaderTT, FooterTT]): + item_type: Type[ItemT] + header_type: HeaderTT # Is either `Type[HeaderT]` or `None`. + footer_type: FooterTT # Is either `Type[FooterT]` or `None`. + + +@overload +def stream_types( + item_type: Type[ItemT], + *, + header_type: Type[HeaderT], + footer_type: Type[FooterT], +) -> StreamTypes[ItemT, HeaderT, FooterT]: ... + + +@overload +def stream_types( + item_type: Type[ItemT], + *, + header_type: Type[HeaderT], +) -> StreamTypes[ItemT, HeaderT, None]: ... + + +@overload +def stream_types( + item_type: Type[ItemT], + *, + footer_type: Type[FooterT], +) -> StreamTypes[ItemT, None, FooterT]: ... + + +@overload +def stream_types(item_type: Type[ItemT]) -> StreamTypes[ItemT, None, None]: ... 
+ + +def stream_types( + item_type: Type[ItemT], + *, + header_type: Optional[Type[HeaderT]] = None, + footer_type: Optional[Type[FooterT]] = None, +) -> StreamTypes: + """Creates a bundle of item type and potentially header/footer types, + each as pydantic model.""" + # This indirection for creating `StreamTypes` is needed to get generic typing. + return StreamTypes(item_type, header_type, footer_type) + + +# Reading ############################################################################## + + +class _Delimiter(enum.IntEnum): + NOT_SET = enum.auto() + HEADER = enum.auto() + ITEM = enum.auto() + FOOTER = enum.auto() + END = enum.auto() + + +class _Streamer(Generic[ItemT, HeaderTT, FooterTT]): + _stream_types: StreamTypes[ItemT, HeaderTT, FooterTT] + + def __init__(self, types: StreamTypes[ItemT, HeaderTT, FooterTT]) -> None: + self._stream_types = types + + +# Reading ############################################################################## + + +class _ByteReader: + """Helper to provide `readexactly` API for an async bytes iterator.""" + + def __init__(self, source: AsyncIterator[bytes]) -> None: + self._source = source + self._buffer = bytearray() + + async def readexactly(self, num_bytes: int) -> bytes: + while len(self._buffer) < num_bytes: + try: + chunk = await anext(self._source) + except StopAsyncIteration: + break + self._buffer.extend(chunk) + + if len(self._buffer) < num_bytes: + if len(self._buffer) == 0: + raise EOFError() + raise asyncio.IncompleteReadError(self._buffer, num_bytes) + + result = bytes(self._buffer[:num_bytes]) + del self._buffer[:num_bytes] + return result + + +class _StreamReaderProtocol(Protocol[ItemT, HeaderTT, FooterTT]): + _stream_types: StreamTypes[ItemT, HeaderTT, FooterTT] + _footer_data: Optional[bytes] + + async def _read(self) -> tuple[_Delimiter, bytes]: ... + + +class _StreamReader(_Streamer[ItemT, HeaderTT, FooterTT]): + _stream: _ByteReader + _footer_data: Optional[bytes] + + def __init__( + self, + types: StreamTypes[ItemT, HeaderTT, FooterTT], + stream: AsyncIterator[bytes], + ) -> None: + super().__init__(types) + self._stream = _ByteReader(stream) + self._footer_data = None + + @staticmethod + def _unpack_tag(tag: bytes) -> tuple[_Delimiter, int]: + enum_value, length = struct.unpack(">BI", tag) + return _Delimiter(enum_value), length + + async def _read(self) -> tuple[_Delimiter, bytes]: + try: + tag = await self._stream.readexactly(_TAG_SIZE) + # It's ok to read nothing (end of stream), but unexpected to read partial. + except asyncio.IncompleteReadError: + raise + except EOFError: + return _Delimiter.END, b"" + + delimiter, length = self._unpack_tag(tag) + if not length: + return delimiter, b"" + data_bytes = await self._stream.readexactly(length) + print(f"Read Delimiter: {delimiter}") + return delimiter, data_bytes + + async def read_items(self) -> AsyncIterator[ItemT]: + delimiter, data_bytes = await self._read() + if delimiter == _Delimiter.HEADER: + raise ValueError( + "Called `read_items`, but there the stream contains header data, which " + "is not consumed. Call `read_header` first or remove sending a header." + ) + if delimiter in (_Delimiter.FOOTER, _Delimiter.END): # In case of 0 items. + self._footer_data = data_bytes + return + + assert delimiter == _Delimiter.ITEM + while True: + yield self._stream_types.item_type.model_validate_json(data_bytes) + # We don't know if the next data is another item, footer or the end. 
+ delimiter, data_bytes = await self._read() + if delimiter == _Delimiter.END: + return + if delimiter == _Delimiter.FOOTER: + self._footer_data = data_bytes + return + + +class _HeaderReadMixin(_Streamer[ItemT, HeaderT, FooterTT]): + async def read_header( + self: _StreamReaderProtocol[ItemT, HeaderT, FooterTT], + ) -> HeaderT: + delimiter, data_bytes = await self._read() + if delimiter != _Delimiter.HEADER: + raise ValueError("Stream does not contain header.") + return self._stream_types.header_type.model_validate_json(data_bytes) + + +class _FooterReadMixin(_Streamer[ItemT, HeaderTT, FooterT]): + _footer_data: Optional[bytes] + + async def read_footer( + self: _StreamReaderProtocol[ItemT, HeaderTT, FooterT], + ) -> FooterT: + if self._footer_data is None: + delimiter, data_bytes = await self._read() + if delimiter != _Delimiter.FOOTER: + raise ValueError("Stream does not contain footer.") + self._footer_data = data_bytes + + footer = self._stream_types.footer_type.model_validate_json(self._footer_data) + self._footer_data = None + return footer + + +class StreamReaderWithHeader( + _StreamReader[ItemT, HeaderT, FooterTT], _HeaderReadMixin[ItemT, HeaderT, FooterTT] +): ... + + +class StreamReaderWithFooter( + _StreamReader[ItemT, HeaderTT, FooterT], _FooterReadMixin[ItemT, HeaderTT, FooterT] +): ... + + +class StreamReaderFull( + _StreamReader[ItemT, HeaderT, FooterT], + _HeaderReadMixin[ItemT, HeaderT, FooterT], + _FooterReadMixin[ItemT, HeaderT, FooterT], +): ... + + +@overload +def stream_reader( + types: StreamTypes[ItemT, None, None], + stream: AsyncIterator[bytes], +) -> _StreamReader[ItemT, None, None]: ... + + +@overload +def stream_reader( + types: StreamTypes[ItemT, HeaderT, None], + stream: AsyncIterator[bytes], +) -> StreamReaderWithHeader[ItemT, HeaderT, None]: ... + + +@overload +def stream_reader( + types: StreamTypes[ItemT, None, FooterT], + stream: AsyncIterator[bytes], +) -> StreamReaderWithFooter[ItemT, None, FooterT]: ... + + +@overload +def stream_reader( + types: StreamTypes[ItemT, HeaderT, FooterT], + stream: AsyncIterator[bytes], +) -> StreamReaderFull[ItemT, HeaderT, FooterT]: ... + + +def stream_reader( + types: StreamTypes[ItemT, HeaderTT, FooterTT], + stream: AsyncIterator[bytes], +) -> _StreamReader: + if types.header_type is None and types.footer_type is None: + return _StreamReader(types, stream) + if types.header_type is None: + return StreamReaderWithFooter(types, stream) + if types.footer_type is None: + return StreamReaderWithHeader(types, stream) + + return StreamReaderFull(types, stream) + + +# Writing ############################################################################## + + +class _StreamWriterProtocol(Protocol[ItemT, HeaderTT, FooterTT]): + _stream_types: StreamTypes[ItemT, HeaderTT, FooterTT] + _last_sent: _Delimiter + + def _serialize(self, obj: pydantic.BaseModel, delimiter: _Delimiter) -> bytes: ... + + +class _StreamWriter(_Streamer[ItemT, HeaderTT, FooterTT]): + def __init__(self, types: StreamTypes[ItemT, HeaderTT, FooterTT]) -> None: + super().__init__(types) + self._last_sent = _Delimiter.NOT_SET + self._stream_types = types + + @staticmethod + def _pack_tag(delimiter: _Delimiter, length: int) -> bytes: + return struct.pack(">BI", delimiter.value, length) + + def _serialize(self, obj: pydantic.BaseModel, delimiter: _Delimiter) -> bytes: + data_bytes = obj.model_dump_json().encode() + data = bytearray(self._pack_tag(delimiter, len(data_bytes))) + data.extend(data_bytes) + # Starlette cannot handle byte array, but view works.. 
+ return memoryview(data) + + def yield_item(self, item: ItemT) -> bytes: + if self._last_sent in (_Delimiter.FOOTER, _Delimiter.END): + raise ValueError("Cannot yield item after sending footer / closing stream.") + self._last_sent = _Delimiter.ITEM + return self._serialize(item, _Delimiter.ITEM) + + +class _HeaderWriteMixin(_Streamer[ItemT, HeaderT, FooterTT]): + def yield_header( + self: _StreamWriterProtocol[ItemT, HeaderT, FooterTT], header: HeaderT + ) -> bytes: + if self._last_sent != _Delimiter.NOT_SET: + raise ValueError("Cannot yield header after other data has been sent.") + self._last_sent = _Delimiter.HEADER + return self._serialize(header, _Delimiter.HEADER) + + +class _FooterWriteMixin(_Streamer[ItemT, HeaderTT, FooterT]): + def yield_footer( + self: _StreamWriterProtocol[ItemT, HeaderTT, FooterT], footer: FooterT + ) -> bytes: + if self._last_sent == _Delimiter.END: + raise ValueError("Cannot yield footer after closing stream.") + self._last_sent = _Delimiter.FOOTER + return self._serialize(footer, _Delimiter.FOOTER) + + +class StreamWriterWithHeader( + _StreamWriter[ItemT, HeaderT, FooterTT], _HeaderWriteMixin[ItemT, HeaderT, FooterTT] +): ... + + +class StreamWriterWithFooter( + _StreamWriter[ItemT, HeaderTT, FooterT], _FooterWriteMixin[ItemT, HeaderTT, FooterT] +): ... + + +class StreamWriterFull( + _StreamWriter[ItemT, HeaderT, FooterT], + _HeaderWriteMixin[ItemT, HeaderT, FooterT], + _FooterWriteMixin[ItemT, HeaderT, FooterT], +): ... + + +@overload +def stream_writer( + types: StreamTypes[ItemT, None, None], +) -> _StreamWriter[ItemT, None, None]: ... + + +@overload +def stream_writer( + types: StreamTypes[ItemT, HeaderT, None], +) -> StreamWriterWithHeader[ItemT, HeaderT, None]: ... + + +@overload +def stream_writer( + types: StreamTypes[ItemT, None, FooterT], +) -> StreamWriterWithFooter[ItemT, None, FooterT]: ... + + +@overload +def stream_writer( + types: StreamTypes[ItemT, HeaderT, FooterT], +) -> StreamWriterFull[ItemT, HeaderT, FooterT]: ... 
+ + +def stream_writer( + types: StreamTypes[ItemT, HeaderTT, FooterTT], +) -> _StreamWriter: + if types.header_type is None and types.footer_type is None: + return _StreamWriter(types) + if types.header_type is None: + return StreamWriterWithFooter(types) + if types.footer_type is None: + return StreamWriterWithHeader(types) + + return StreamWriterFull(types) diff --git a/truss-chains/truss_chains/stub.py b/truss-chains/truss_chains/stub.py index 6e0927a30..5de4f66de 100644 --- a/truss-chains/truss_chains/stub.py +++ b/truss-chains/truss_chains/stub.py @@ -6,7 +6,17 @@ import ssl import threading import time -from typing import Any, ClassVar, Iterator, Mapping, Optional, Type, TypeVar, final +from typing import ( + Any, + AsyncIterator, + ClassVar, + Iterator, + Mapping, + Optional, + Type, + TypeVar, + final, +) import aiohttp import httpx @@ -127,6 +137,9 @@ async def _client_async(self) -> aiohttp.ClientSession: return self._cached_async_client[0] def predict_sync(self, json_payload): + headers = { + definitions.OTEL_TRACE_PARENT_HEADER_KEY: _trace_parent_context.get() + } retrying = tenacity.Retrying( stop=tenacity.stop_after_attempt(self._service_descriptor.options.retries), retry=tenacity.retry_if_exception_type(Exception), @@ -139,14 +152,14 @@ def predict_sync(self, json_payload): try: with self._sync_num_requests as num_requests: self._maybe_warn_for_overload(num_requests) - resp = self._client_sync().post( + response = self._client_sync().post( self._service_descriptor.predict_url, json=json_payload, - headers={ - definitions.OTEL_TRACE_PARENT_HEADER_KEY: _trace_parent_context.get() - }, + headers=headers, ) - return utils.handle_response(resp, self.name) + utils.response_raise_errors(response, self.name) + return response.json() + # As a special case we invalidate the client in case of certificate # errors. This has happened in the past and is a defensive measure. except ssl.SSLError: @@ -154,6 +167,39 @@ def predict_sync(self, json_payload): raise async def predict_async(self, json_payload): + headers = { + definitions.OTEL_TRACE_PARENT_HEADER_KEY: _trace_parent_context.get() + } + retrying = tenacity.AsyncRetrying( + stop=tenacity.stop_after_attempt(self._service_descriptor.options.retries), + retry=tenacity.retry_if_exception_type(Exception), + reraise=True, + ) + async for attempt in retrying: + with attempt: + if (num := attempt.retry_state.attempt_number) > 1: + logging.info(f"Retrying `{self.name}`, " f"attempt {num}") + try: + client = await self._client_async() + async with self._async_num_requests as num_requests: + self._maybe_warn_for_overload(num_requests) + async with client.post( + self._service_descriptor.predict_url, + json=json_payload, + headers=headers, + ) as response: + await utils.async_response_raise_errors(response, self.name) + return await response.json() + # As a special case we invalidate the client in case of certificate + # errors. This has happened in the past and is a defensive measure. + except ssl.SSLError: + self._cached_async_client = None + raise + + async def predict_async_stream(self, json_payload) -> AsyncIterator[bytes]: # type: ignore[return] # Handled by retries. 
+ headers = { + definitions.OTEL_TRACE_PARENT_HEADER_KEY: _trace_parent_context.get() + } retrying = tenacity.AsyncRetrying( stop=tenacity.stop_after_attempt(self._service_descriptor.options.retries), retry=tenacity.retry_if_exception_type(Exception), @@ -167,14 +213,14 @@ async def predict_async(self, json_payload): client = await self._client_async() async with self._async_num_requests as num_requests: self._maybe_warn_for_overload(num_requests) - resp = await client.post( + response = await client.post( self._service_descriptor.predict_url, json=json_payload, - headers={ - definitions.OTEL_TRACE_PARENT_HEADER_KEY: _trace_parent_context.get() - }, + headers=headers, ) - return await utils.handle_async_response(resp, self.name) + await utils.async_response_raise_errors(response, self.name) + return response.content.iter_any() + # As a special case we invalidate the client in case of certificate # errors. This has happened in the past and is a defensive measure. except ssl.SSLError: diff --git a/truss-chains/truss_chains/utils.py b/truss-chains/truss_chains/utils.py index f29853542..28a485451 100644 --- a/truss-chains/truss_chains/utils.py +++ b/truss-chains/truss_chains/utils.py @@ -186,28 +186,27 @@ def populate_chainlet_service_predict_urls( # Error Propagation Utils. ############################################################# +# TODO: move request related code into `stub.py`. -def _handle_exception( - exception: Exception, include_stack: bool, chainlet_name: str -) -> NoReturn: +def _handle_exception(exception: Exception, chainlet_name: str) -> NoReturn: """Raises `fastapi.HTTPException` with `RemoteErrorDetail` as detail.""" if hasattr(exception, "__module__"): exception_module_name = exception.__module__ else: exception_module_name = None - if include_stack: - error_stack = traceback.extract_tb(exception.__traceback__) - # Exclude the error handling functions from the stack trace. - exclude_frames = {exception_to_http_error.__name__, handle_response.__name__} - final_tb = [frame for frame in error_stack if frame.name not in exclude_frames] - stack = list( - [definitions.StackFrame.from_frame_summary(frame) for frame in final_tb] - ) - else: - stack = [] - + error_stack = traceback.extract_tb(exception.__traceback__) + # Exclude the error handling functions from the stack trace. + exclude_frames = { + exception_to_http_error.__name__, + response_raise_errors.__name__, + async_response_raise_errors.__name__, + } + final_tb = [frame for frame in error_stack if frame.name not in exclude_frames] + stack = list( + [definitions.StackFrame.from_frame_summary(frame) for frame in final_tb] + ) error = definitions.RemoteErrorDetail( remote_name=chainlet_name, exception_cls_name=exception.__class__.__name__, @@ -221,11 +220,12 @@ def _handle_exception( @contextlib.contextmanager -def exception_to_http_error(include_stack: bool, chainlet_name: str) -> Iterator[None]: +def exception_to_http_error(chainlet_name: str) -> Iterator[None]: + # TODO: move chainlet name from here to caller side. try: yield except Exception as e: - _handle_exception(e, include_stack, chainlet_name) + _handle_exception(e, chainlet_name) def _resolve_exception_class( @@ -279,8 +279,8 @@ def _handle_response_error(response_json: dict, remote_name: str): raise exception_cls(msg) -def handle_response(response: httpx.Response, remote_name: str) -> Any: - """For successful requests returns JSON, otherwise raises error. 
+def response_raise_errors(response: httpx.Response, remote_name: str) -> None: + """In case of error, raise it. If the response error contains `RemoteErrorDetail`, it tries to re-raise the same exception that was raised remotely and falls back to @@ -334,17 +334,11 @@ def handle_response(response: httpx.Response, remote_name: str) -> Any: ) from e _handle_response_error(response_json=response_json, remote_name=remote_name) - return response.json() - -async def handle_async_response( +async def async_response_raise_errors( response: aiohttp.ClientResponse, remote_name: str -) -> Any: - """For successful requests returns JSON, otherwise raises error. - - See `handle_response` for more details on the specifics of the error-handling - here. - """ +) -> None: + """Async version of `async_response_raise_errors`.""" if response.status >= 400: try: response_json = await response.json() @@ -353,10 +347,10 @@ async def handle_async_response( "Could not get JSON from error response. Status: " f"`{response.status}`." ) from e - _handle_response_error(response_json=response_json, remote_name=remote_name) - return await response.json() + +######################################################################################## class InjectedError(Exception): @@ -417,7 +411,21 @@ def issubclass_safe(x: Any, cls: type) -> bool: def pydantic_set_field_dict(obj: pydantic.BaseModel) -> dict[str, pydantic.BaseModel]: - """Like `BaseModel.model_dump(exclude_unset=True), but only top-level.""" + """Like `BaseModel.model_dump(exclude_unset=True), but only top-level. + + This is used to get kwargs for invoking a function, while dropping fields for which + there is no value explicitly set in the pydantic model. A field is considered unset + if the key was not present in the incoming JSON request (from which the model was + parsed/initialized) and the pydantic model has a default value, such as `None`. + + By dropping these unset fields, the default values from the function definition + will be used instead. This behavior ensures correct handling of arguments where + the function has a default, such as in the case of `run_remote`. If the model has + an optional field defaulting to `None`, this approach differentiates between + the user explicitly passing a value of `None` and the field being unset in the + request. + + """ return {name: getattr(obj, name) for name in obj.__fields_set__} diff --git a/truss/templates/server/common/schema.py b/truss/templates/server/common/schema.py index 89e7060f4..0201af824 100644 --- a/truss/templates/server/common/schema.py +++ b/truss/templates/server/common/schema.py @@ -2,8 +2,10 @@ from typing import ( Any, AsyncGenerator, + AsyncIterator, Awaitable, Generator, + Iterator, List, Optional, Type, @@ -83,7 +85,7 @@ def _annotation_is_pydantic_model(annotation: Any) -> bool: def _parse_output_type(output_annotation: Any) -> Optional[OutputType]: """ - Therea are 4 possible cases for output_annotation: + There are 4 possible cases for output_annotation: 1. Data object -- represented by a Pydantic BaseModel 2. Streaming -- represented by a Generator or AsyncGenerator 3. 
Async -- represented by an Awaitable @@ -117,7 +119,7 @@ def _parse_output_type(output_annotation: Any) -> Optional[OutputType]: def _is_generator_type(annotation: Any) -> bool: base_type = get_origin(annotation) return isinstance(base_type, type) and issubclass( - base_type, (Generator, AsyncGenerator) + base_type, (Generator, AsyncGenerator, Iterator, AsyncIterator) ) diff --git a/truss/templates/server/model_wrapper.py b/truss/templates/server/model_wrapper.py index 82bab57d4..ab28713d2 100644 --- a/truss/templates/server/model_wrapper.py +++ b/truss/templates/server/model_wrapper.py @@ -27,6 +27,7 @@ ) import opentelemetry.sdk.trace as sdk_trace +import pydantic import starlette.requests import starlette.responses from anyio import Semaphore, to_thread @@ -56,6 +57,15 @@ TRT_LLM_EXTENSION_NAME = "trt_llm" POLL_FOR_ENVIRONMENT_UPDATES_TIMEOUT_SECS = 30 +InputType = Union[serialization.JSONType, serialization.MsgPackType, pydantic.BaseModel] +OutputType = Union[ + serialization.JSONType, + serialization.MsgPackType, + Generator[bytes, None, None], + AsyncGenerator[bytes, None], + "starlette.responses.Response", +] + @asynccontextmanager async def deferred_semaphore_and_span( @@ -520,7 +530,7 @@ async def poll_for_environment_updates(self) -> None: async def preprocess( self, - inputs: serialization.InputType, + inputs: InputType, request: starlette.requests.Request, ) -> Any: descriptor = self.model_descriptor.preprocess @@ -538,7 +548,7 @@ async def predict( self, inputs: Any, request: starlette.requests.Request, - ) -> Union[serialization.OutputType, Any]: + ) -> Union[OutputType, Any]: # The result can be a serializable data structure, byte-generator, a request, # or, if `postprocessing` is used, anything. In the last case postprocessing # must convert the result to something serializable. @@ -555,9 +565,9 @@ async def predict( async def postprocess( self, - result: Union[serialization.InputType, Any], + result: Union[InputType, Any], request: starlette.requests.Request, - ) -> serialization.OutputType: + ) -> OutputType: # The postprocess function can handle outputs of `predict`, but not # generators and responses - in that case predict must return directly # and postprocess is skipped. @@ -642,9 +652,9 @@ async def _buffered_response_generator() -> AsyncGenerator[bytes, None]: async def __call__( self, - inputs: Optional[serialization.InputType], + inputs: Optional[InputType], request: starlette.requests.Request, - ) -> serialization.OutputType: + ) -> OutputType: """ Returns result from: preprocess -> predictor -> postprocess. 
""" @@ -726,6 +736,7 @@ async def __call__( ), tracing.detach_context(): postprocess_result = await self.postprocess(predict_result, request) + final_result: OutputType if isinstance(postprocess_result, BaseModel): # If we return a pydantic object, convert it back to a dict with tracing.section_as_event(span_post, "dump-pydantic"): diff --git a/truss/templates/server/truss_server.py b/truss/templates/server/truss_server.py index 42b2293ae..37ab4c223 100644 --- a/truss/templates/server/truss_server.py +++ b/truss/templates/server/truss_server.py @@ -16,7 +16,7 @@ from fastapi import Depends, FastAPI, HTTPException, Request from fastapi.responses import ORJSONResponse, StreamingResponse from fastapi.routing import APIRoute as FastAPIRoute -from model_wrapper import ModelWrapper +from model_wrapper import InputType, ModelWrapper from opentelemetry import propagate as otel_propagate from opentelemetry import trace from opentelemetry.sdk import trace as sdk_trace @@ -104,7 +104,7 @@ async def _parse_body( body_raw: bytes, truss_schema: Optional[TrussSchema], span: trace.Span, - ) -> serialization.InputType: + ) -> InputType: if self.is_binary(request): with tracing.section_as_event(span, "binary-deserialize"): inputs = serialization.truss_msgpack_deserialize(body_raw) @@ -157,7 +157,7 @@ async def predict( with self._tracer.start_as_current_span( "predict-endpoint", context=trace_ctx ) as span: - inputs: Optional[serialization.InputType] + inputs: Optional[InputType] if model.model_descriptor.skip_input_parsing: inputs = None else: diff --git a/truss/templates/shared/serialization.py b/truss/templates/shared/serialization.py index a1281d4d4..21b099892 100644 --- a/truss/templates/shared/serialization.py +++ b/truss/templates/shared/serialization.py @@ -2,22 +2,9 @@ import uuid from datetime import date, datetime, time, timedelta from decimal import Decimal -from typing import ( - TYPE_CHECKING, - Any, - AsyncGenerator, - Callable, - Dict, - Generator, - List, - Optional, - Union, -) - -import pydantic +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union if TYPE_CHECKING: - import starlette.responses from numpy.typing import NDArray @@ -38,14 +25,6 @@ List["MsgPackType"], Dict[str, "MsgPackType"], ] -InputType = Union[JSONType, MsgPackType, pydantic.BaseModel] -OutputType = Union[ - JSONType, - MsgPackType, - Generator[bytes, None, None], - AsyncGenerator[bytes, None], - "starlette.responses.Response", -] # mostly cribbed from django.core.serializer.DjangoJSONEncoder