Skip to content

Commit

Permalink
Chains Streaming, fixes BT-10339 (#1261)
Browse files Browse the repository at this point in the history
  • Loading branch information
marius-baseten authored Dec 2, 2024
1 parent bd60992 commit 7472bb6
Show file tree
Hide file tree
Showing 17 changed files with 1,128 additions and 179 deletions.
35 changes: 0 additions & 35 deletions .github/workflows/pr.yml
Original file line number Diff line number Diff line change
Expand Up @@ -51,38 +51,3 @@ jobs:
with:
use-verbose-mode: "yes"
folder-path: "docs"

enforce-chains-example-docs-sync:
runs-on: ubuntu-20.04
steps:
- uses: actions/checkout@v4
with:
lfs: true
fetch-depth: 2

- name: Fetch main branch
run: git fetch origin main

- name: Check if chains examples were modified
id: check_files
run: |
if git diff --name-only origin/main | grep -q '^truss-chains/examples/.*'; then
echo "chains_docs_update_needed=true" >> $GITHUB_ENV
echo "Chains examples were modified."
else
echo "chains_docs_update_needed=false" >> $GITHUB_ENV
echo "Chains examples were not modified."
echo "::notice file=truss-chains/examples/::Chains examples not modified."
fi
- name: Enforce acknowledgment in PR description
if: env.chains_docs_update_needed == 'true'
env:
DESCRIPTION: ${{ github.event.pull_request.body }}
run: |
if [[ "$DESCRIPTION" != *"UPDATE_DOCS=done"* && "$DESCRIPTION" != *"UPDATE_DOCS=not_needed"* ]]; then
echo "::error file=truss-chains/examples/::Chains examples were modified and ack not found in PR description. Verify whether docs need to be update (https://github.com/basetenlabs/docs.baseten.co/tree/main/chains) and add an ack tag `UPDATE_DOCS={done|not_needed}` to the PR description."
exit 1
else
echo "::notice file=truss-chains/examples/::Chains examples modified and ack found int PR description."
fi
107 changes: 107 additions & 0 deletions truss-chains/examples/streaming/streaming_chain.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
import asyncio
import time
from typing import AsyncIterator

import pydantic

import truss_chains as chains
from truss_chains import streaming


class Header(pydantic.BaseModel):
time: float
msg: str


class MyDataChunk(pydantic.BaseModel):
words: list[str]


class Footer(pydantic.BaseModel):
time: float
duration_sec: float
msg: str


class ConsumerOutput(pydantic.BaseModel):
header: Header
chunks: list[MyDataChunk]
footer: Footer
strings: str


STREAM_TYPES = streaming.stream_types(
MyDataChunk, header_type=Header, footer_type=Footer
)


class Generator(chains.ChainletBase):
"""Example that streams fully structured pydantic items with header and footer."""

async def run_remote(self) -> AsyncIterator[bytes]:
print("Entering Generator")
streamer = streaming.stream_writer(STREAM_TYPES)
header = Header(time=time.time(), msg="Start.")
yield streamer.yield_header(header)
for i in range(1, 5):
data = MyDataChunk(
words=[chr(x + 70) * x for x in range(1, i + 1)],
)
print("Yield")
yield streamer.yield_item(data)
await asyncio.sleep(0.05)

end_time = time.time()
footer = Footer(time=end_time, duration_sec=end_time - header.time, msg="Done.")
yield streamer.yield_footer(footer)
print("Exiting Generator")


class StringGenerator(chains.ChainletBase):
"""Minimal streaming example with strings (e.g. for raw LLM output)."""

async def run_remote(self) -> AsyncIterator[str]:
# Note: the "chunk" boundaries are lost, when streaming raw strings. You must
# add spaces and linebreaks to the items yourself..
yield "First "
yield "second "
yield "last."


class Consumer(chains.ChainletBase):
"""Consume that reads the raw streams and parses them."""

def __init__(
self,
generator=chains.depends(Generator),
string_generator=chains.depends(StringGenerator),
):
self._generator = generator
self._string_generator = string_generator

async def run_remote(self) -> ConsumerOutput:
print("Entering Consumer")
reader = streaming.stream_reader(STREAM_TYPES, self._generator.run_remote())
print("Consuming...")
header = await reader.read_header()
chunks = []
async for data in reader.read_items():
print(f"Read: {data}")
chunks.append(data)

footer = await reader.read_footer()
strings = []
async for part in self._string_generator.run_remote():
strings.append(part)

print("Exiting Consumer")
return ConsumerOutput(
header=header, chunks=chunks, footer=footer, strings="".join(strings)
)


if __name__ == "__main__":
with chains.run_local():
chain = Consumer()
result = asyncio.run(chain.run_remote())
print(result)
53 changes: 49 additions & 4 deletions truss-chains/tests/chains_e2e_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
@pytest.mark.integration
def test_chain():
with ensure_kill_all():
root = Path(__file__).parent.resolve()
chain_root = root / "itest_chain" / "itest_chain.py"
tests_root = Path(__file__).parent.resolve()
chain_root = tests_root / "itest_chain" / "itest_chain.py"
with framework.import_target(chain_root, "ItestChain") as entrypoint:
options = definitions.PushOptionsLocalDocker(
chain_name="integration-test", use_local_chains_src=True
Expand Down Expand Up @@ -81,8 +81,8 @@ def test_chain():

@pytest.mark.asyncio
async def test_chain_local():
root = Path(__file__).parent.resolve()
chain_root = root / "itest_chain" / "itest_chain.py"
tests_root = Path(__file__).parent.resolve()
chain_root = tests_root / "itest_chain" / "itest_chain.py"
with framework.import_target(chain_root, "ItestChain") as entrypoint:
with public_api.run_local():
with pytest.raises(ValueError):
Expand Down Expand Up @@ -119,3 +119,48 @@ async def test_chain_local():
match="Chainlets cannot be naively instantiated",
):
await entrypoint().run_remote(length=20, num_partitions=5)


@pytest.mark.integration
def test_streaming_chain():
examples_root = Path(__file__).parent.parent.resolve() / "examples"
chain_root = examples_root / "streaming" / "streaming_chain.py"
with framework.import_target(chain_root, "Consumer") as entrypoint:
service = remote.push(
entrypoint,
options=definitions.PushOptionsLocalDocker(
chain_name="stream",
only_generate_trusses=False,
use_local_chains_src=True,
),
)
assert service is not None
response = service.run_remote({})
assert response.status_code == 200
print(response.json())
result = response.json()
print(result)
assert result["header"]["msg"] == "Start."
assert result["chunks"][0]["words"] == ["G"]
assert result["chunks"][1]["words"] == ["G", "HH"]
assert result["chunks"][2]["words"] == ["G", "HH", "III"]
assert result["chunks"][3]["words"] == ["G", "HH", "III", "JJJJ"]
assert result["footer"]["duration_sec"] > 0
assert result["strings"] == "First second last."


@pytest.mark.asyncio
async def test_streaming_chain_local():
examples_root = Path(__file__).parent.parent.resolve() / "examples"
chain_root = examples_root / "streaming" / "streaming_chain.py"
with framework.import_target(chain_root, "Consumer") as entrypoint:
with public_api.run_local():
result = await entrypoint().run_remote()
print(result)
assert result.header.msg == "Start."
assert result.chunks[0].words == ["G"]
assert result.chunks[1].words == ["G", "HH"]
assert result.chunks[2].words == ["G", "HH", "III"]
assert result.chunks[3].words == ["G", "HH", "III", "JJJJ"]
assert result.footer.duration_sec > 0
assert result.strings == "First second last."
54 changes: 53 additions & 1 deletion truss-chains/tests/test_framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import contextlib
import logging
import re
from typing import List
from typing import AsyncIterator, Iterator, List

import pydantic
import pytest
Expand Down Expand Up @@ -505,3 +505,55 @@ def run_remote(argument: object): ...
with pytest.raises(definitions.ChainsUsageError, match=match), _raise_errors():
with public_api.run_local():
MultiIssue()


def test_raises_iterator_no_yield():
match = (
rf"{TEST_FILE}:\d+ \(IteratorNoYield\.run_remote\) \[kind: IO_TYPE_ERROR\].*"
r"If the endpoint returns an iterator \(streaming\), it must have `yield` statements"
)

with pytest.raises(definitions.ChainsUsageError, match=match), _raise_errors():

class IteratorNoYield(chains.ChainletBase):
async def run_remote(self) -> AsyncIterator[str]:
return "123" # type: ignore[return-value]


def test_raises_yield_no_iterator():
match = (
rf"{TEST_FILE}:\d+ \(YieldNoIterator\.run_remote\) \[kind: IO_TYPE_ERROR\].*"
r"If the endpoint is streaming \(has `yield` statements\), the return type must be an iterator"
)

with pytest.raises(definitions.ChainsUsageError, match=match), _raise_errors():

class YieldNoIterator(chains.ChainletBase):
async def run_remote(self) -> str: # type: ignore[misc]
yield "123"


def test_raises_iterator_sync():
match = (
rf"{TEST_FILE}:\d+ \(IteratorSync\.run_remote\) \[kind: IO_TYPE_ERROR\].*"
r"Streaming endpoints \(containing `yield` statements\) are only supported for async endpoints"
)

with pytest.raises(definitions.ChainsUsageError, match=match), _raise_errors():

class IteratorSync(chains.ChainletBase):
def run_remote(self) -> Iterator[str]:
yield "123"


def test_raises_iterator_no_arg():
match = (
rf"{TEST_FILE}:\d+ \(IteratorNoArg\.run_remote\) \[kind: IO_TYPE_ERROR\].*"
r"Iterators must be annotated with type \(one of \['str', 'bytes'\]\)"
)

with pytest.raises(definitions.ChainsUsageError, match=match), _raise_errors():

class IteratorNoArg(chains.ChainletBase):
async def run_remote(self) -> AsyncIterator:
yield "123"
Loading

0 comments on commit 7472bb6

Please sign in to comment.