Skip to content

Commit

Permalink
fix(streaming): accumulate citations (#844)
Browse files Browse the repository at this point in the history
  • Loading branch information
RobertCraigie authored and stainless-app[bot] committed Jan 27, 2025
1 parent 14bf8fe commit 872c614
Show file tree
Hide file tree
Showing 4 changed files with 150 additions and 64 deletions.
95 changes: 64 additions & 31 deletions src/anthropic/lib/streaming/_beta_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
from typing_extensions import Self, Iterator, Awaitable, AsyncIterator, assert_never

import httpx
from pydantic import BaseModel

from ..._utils import consume_sync_iterator, consume_async_iterator
from ..._models import build, construct_type
from ._beta_types import (
BetaTextEvent,
BetaCitationEvent,
BetaInputJsonEvent,
BetaMessageStopEvent,
BetaMessageStreamEvent,
Expand Down Expand Up @@ -314,24 +316,40 @@ def build_events(
events_to_fire.append(event)

content_block = message_snapshot.content[event.index]
if event.delta.type == "text_delta" and content_block.type == "text":
events_to_fire.append(
build(
BetaTextEvent,
type="text",
text=event.delta.text,
snapshot=content_block.text,
if event.delta.type == "text_delta":
if content_block.type == "text":
events_to_fire.append(
build(
BetaTextEvent,
type="text",
text=event.delta.text,
snapshot=content_block.text,
)
)
)
elif event.delta.type == "input_json_delta" and content_block.type == "tool_use":
events_to_fire.append(
build(
BetaInputJsonEvent,
type="input_json",
partial_json=event.delta.partial_json,
snapshot=content_block.input,
elif event.delta.type == "input_json_delta":
if content_block.type == "tool_use":
events_to_fire.append(
build(
BetaInputJsonEvent,
type="input_json",
partial_json=event.delta.partial_json,
snapshot=content_block.input,
)
)
)
elif event.delta.type == "citations_delta":
if content_block.type == "text":
events_to_fire.append(
build(
BetaCitationEvent,
type="citation",
citation=event.delta.citation,
snapshot=content_block.citations or [],
)
)
else:
# we only want exhaustive checking for linters, not at runtime
if TYPE_CHECKING: # type: ignore[unreachable]
assert_never(event.delta)
elif event.type == "content_block_stop":
content_block = message_snapshot.content[event.index]

Expand All @@ -354,6 +372,9 @@ def accumulate_event(
event: BetaRawMessageStreamEvent,
current_snapshot: BetaMessage | None,
) -> BetaMessage:
if not isinstance(event, BaseModel): # pyright: ignore[reportUnnecessaryIsInstance]
raise TypeError(f"Unexpected event runtime type - {event}")

if current_snapshot is None:
if event.type == "message_start":
return BetaMessage.construct(**cast(Any, event.message.to_dict()))
Expand All @@ -370,21 +391,33 @@ def accumulate_event(
)
elif event.type == "content_block_delta":
content = current_snapshot.content[event.index]
if content.type == "text" and event.delta.type == "text_delta":
content.text += event.delta.text
elif content.type == "tool_use" and event.delta.type == "input_json_delta":
from jiter import from_json

# we need to keep track of the raw JSON string as well so that we can
# re-parse it for each delta, for now we just store it as an untyped
# property on the snapshot
json_buf = cast(bytes, getattr(content, JSON_BUF_PROPERTY, b""))
json_buf += bytes(event.delta.partial_json, "utf-8")

if json_buf:
content.input = from_json(json_buf, partial_mode=True)

setattr(content, JSON_BUF_PROPERTY, json_buf)
if event.delta.type == "text_delta":
if content.type == "text":
content.text += event.delta.text
elif event.delta.type == "input_json_delta":
if content.type == "tool_use":
from jiter import from_json

# we need to keep track of the raw JSON string as well so that we can
# re-parse it for each delta, for now we just store it as an untyped
# property on the snapshot
json_buf = cast(bytes, getattr(content, JSON_BUF_PROPERTY, b""))
json_buf += bytes(event.delta.partial_json, "utf-8")

if json_buf:
content.input = from_json(json_buf, partial_mode=True)

setattr(content, JSON_BUF_PROPERTY, json_buf)
elif event.delta.type == "citations_delta":
if content.type == "text":
if not content.citations:
content.citations = [event.delta.citation]
else:
content.citations.append(event.delta.citation)
else:
# we only want exhaustive checking for linters, not at runtime
if TYPE_CHECKING: # type: ignore[unreachable]
assert_never(event.delta)
elif event.type == "message_delta":
current_snapshot.stop_reason = event.delta.stop_reason
current_snapshot.stop_sequence = event.delta.stop_sequence
Expand Down
14 changes: 13 additions & 1 deletion src/anthropic/lib/streaming/_beta_types.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import Union
from typing_extensions import Literal, Annotated
from typing_extensions import List, Literal, Annotated

from ..._models import BaseModel
from ...types.beta import (
Expand All @@ -13,6 +13,7 @@
BetaRawContentBlockStartEvent,
)
from ..._utils._transform import PropertyInfo
from ...types.beta.beta_citations_delta import Citation


class BetaTextEvent(BaseModel):
Expand All @@ -25,6 +26,16 @@ class BetaTextEvent(BaseModel):
"""The entire accumulated text"""


class BetaCitationEvent(BaseModel):
    # Helper stream event for citations on the beta messages stream.
    # The beta stream's build_events constructs one of these when a
    # `citations_delta` arrives for a content block whose type is "text",
    # pairing the new citation with the block's accumulated citations.
    type: Literal["citation"]  # fixed tag value for this event kind

    citation: Citation
    """The new citation"""

    # Filled from the text block's `citations`; build_events substitutes an
    # empty list when that attribute is falsy.
    snapshot: List[Citation]
    """All of the accumulated citations"""


class BetaInputJsonEvent(BaseModel):
type: Literal["input_json"]

Expand Down Expand Up @@ -57,6 +68,7 @@ class BetaContentBlockStopEvent(BetaRawContentBlockStopEvent):
BetaMessageStreamEvent = Annotated[
Union[
BetaTextEvent,
BetaCitationEvent,
BetaInputJsonEvent,
BetaRawMessageStartEvent,
BetaRawMessageDeltaEvent,
Expand Down
91 changes: 60 additions & 31 deletions src/anthropic/lib/streaming/_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from ._types import (
TextEvent,
CitationEvent,
InputJsonEvent,
MessageStopEvent,
MessageStreamEvent,
Expand Down Expand Up @@ -315,24 +316,40 @@ def build_events(
events_to_fire.append(event)

content_block = message_snapshot.content[event.index]
if event.delta.type == "text_delta" and content_block.type == "text":
events_to_fire.append(
build(
TextEvent,
type="text",
text=event.delta.text,
snapshot=content_block.text,
if event.delta.type == "text_delta":
if content_block.type == "text":
events_to_fire.append(
build(
TextEvent,
type="text",
text=event.delta.text,
snapshot=content_block.text,
)
)
)
elif event.delta.type == "input_json_delta" and content_block.type == "tool_use":
events_to_fire.append(
build(
InputJsonEvent,
type="input_json",
partial_json=event.delta.partial_json,
snapshot=content_block.input,
elif event.delta.type == "input_json_delta":
if content_block.type == "tool_use":
events_to_fire.append(
build(
InputJsonEvent,
type="input_json",
partial_json=event.delta.partial_json,
snapshot=content_block.input,
)
)
)
elif event.delta.type == "citations_delta":
if content_block.type == "text":
events_to_fire.append(
build(
CitationEvent,
type="citation",
citation=event.delta.citation,
snapshot=content_block.citations or [],
)
)
else:
# we only want exhaustive checking for linters, not at runtime
if TYPE_CHECKING: # type: ignore[unreachable]
assert_never(event.delta)
elif event.type == "content_block_stop":
content_block = message_snapshot.content[event.index]

Expand Down Expand Up @@ -374,21 +391,33 @@ def accumulate_event(
)
elif event.type == "content_block_delta":
content = current_snapshot.content[event.index]
if content.type == "text" and event.delta.type == "text_delta":
content.text += event.delta.text
elif content.type == "tool_use" and event.delta.type == "input_json_delta":
from jiter import from_json

# we need to keep track of the raw JSON string as well so that we can
# re-parse it for each delta, for now we just store it as an untyped
# property on the snapshot
json_buf = cast(bytes, getattr(content, JSON_BUF_PROPERTY, b""))
json_buf += bytes(event.delta.partial_json, "utf-8")

if json_buf:
content.input = from_json(json_buf, partial_mode=True)

setattr(content, JSON_BUF_PROPERTY, json_buf)
if event.delta.type == "text_delta":
if content.type == "text":
content.text += event.delta.text
elif event.delta.type == "input_json_delta":
if content.type == "tool_use":
from jiter import from_json

# we need to keep track of the raw JSON string as well so that we can
# re-parse it for each delta, for now we just store it as an untyped
# property on the snapshot
json_buf = cast(bytes, getattr(content, JSON_BUF_PROPERTY, b""))
json_buf += bytes(event.delta.partial_json, "utf-8")

if json_buf:
content.input = from_json(json_buf, partial_mode=True)

setattr(content, JSON_BUF_PROPERTY, json_buf)
elif event.delta.type == "citations_delta":
if content.type == "text":
if not content.citations:
content.citations = [event.delta.citation]
else:
content.citations.append(event.delta.citation)
else:
# we only want exhaustive checking for linters, not at runtime
if TYPE_CHECKING: # type: ignore[unreachable]
assert_never(event.delta)
elif event.type == "message_delta":
current_snapshot.stop_reason = event.delta.stop_reason
current_snapshot.stop_sequence = event.delta.stop_sequence
Expand Down
14 changes: 13 additions & 1 deletion src/anthropic/lib/streaming/_types.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import Union
from typing_extensions import Literal, Annotated
from typing_extensions import List, Literal, Annotated

from ...types import (
Message,
Expand All @@ -13,6 +13,7 @@
)
from ..._models import BaseModel
from ..._utils._transform import PropertyInfo
from ...types.citations_delta import Citation


class TextEvent(BaseModel):
Expand All @@ -25,6 +26,16 @@ class TextEvent(BaseModel):
"""The entire accumulated text"""


class CitationEvent(BaseModel):
    # Helper stream event for citations on the messages stream.
    # The stream's build_events constructs one of these when a
    # `citations_delta` arrives for a content block whose type is "text",
    # pairing the new citation with the block's accumulated citations.
    type: Literal["citation"]  # fixed tag value for this event kind

    citation: Citation
    """The new citation"""

    # Filled from the text block's `citations`; build_events substitutes an
    # empty list when that attribute is falsy.
    snapshot: List[Citation]
    """All of the accumulated citations"""


class InputJsonEvent(BaseModel):
type: Literal["input_json"]

Expand Down Expand Up @@ -57,6 +68,7 @@ class ContentBlockStopEvent(RawContentBlockStopEvent):
MessageStreamEvent = Annotated[
Union[
TextEvent,
CitationEvent,
InputJsonEvent,
RawMessageStartEvent,
RawMessageDeltaEvent,
Expand Down

0 comments on commit 872c614

Please sign in to comment.