migrate streaming logic to BaseChatHandler
dlqqq committed Oct 18, 2024
1 parent b3a5f8c commit 7107bfe
Showing 2 changed files with 133 additions and 117 deletions.
135 changes: 130 additions & 5 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/base.py
@@ -3,9 +3,10 @@
import os
import time
import traceback
from asyncio import Event
import asyncio
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
ClassVar,
Dict,
@@ -24,16 +25,23 @@
from jupyter_ai.history import WrappedBoundedChatHistory
from jupyter_ai.models import (
AgentChatMessage,
AgentStreamMessage,
AgentStreamChunkMessage,
ChatMessage,
ClosePendingMessage,
HumanChatMessage,
Message,
PendingMessage,
)
from jupyter_ai.callback_handlers import MetadataCallbackHandler
from jupyter_ai_magics import Persona
from jupyter_ai_magics.providers import BaseProvider
from langchain.chains import LLMChain

from langchain.pydantic_v1 import BaseModel
from langchain_core.messages import AIMessageChunk
from langchain_core.runnables import Runnable
from langchain_core.runnables.config import RunnableConfig, merge_configs as merge_runnable_configs
from langchain_core.runnables.utils import Input

if TYPE_CHECKING:
from jupyter_ai.context_providers import BaseCommandContextProvider
@@ -129,7 +137,7 @@ class BaseChatHandler:
"""Dictionary of context providers. Allows chat handlers to reference
context providers, which can be used to provide context to the LLM."""

message_interrupted: Dict[str, Event]
message_interrupted: Dict[str, asyncio.Event]
"""Dictionary mapping an agent message identifier to an asyncio Event
which indicates if the message generation/streaming was interrupted."""

@@ -147,7 +155,7 @@ def __init__(
help_message_template: str,
chat_handlers: Dict[str, "BaseChatHandler"],
context_providers: Dict[str, "BaseCommandContextProvider"],
message_interrupted: Dict[str, Event],
message_interrupted: Dict[str, asyncio.Event],
):
self.log = log
self.config_manager = config_manager
@@ -173,7 +181,7 @@ def __init__(

self.llm: Optional[BaseProvider] = None
self.llm_params: Optional[dict] = None
self.llm_chain: Optional[LLMChain] = None
self.llm_chain: Optional[Runnable] = None

async def on_message(self, message: HumanChatMessage):
"""
@@ -471,3 +479,120 @@ def send_help_message(self, human_msg: Optional[HumanChatMessage] = None) -> Non
)

self.broadcast_message(help_message)

def _start_stream(self, human_msg: HumanChatMessage) -> str:
"""
Sends an `agent-stream` message to indicate the start of a response
stream. Returns the ID of the message, denoted as the `stream_id`.
"""
stream_id = uuid4().hex
stream_msg = AgentStreamMessage(
id=stream_id,
time=time.time(),
body="",
reply_to=human_msg.id,
persona=self.persona,
complete=False,
)

self.broadcast_message(stream_msg)
return stream_id

def _send_stream_chunk(
self,
stream_id: str,
content: str,
complete: bool = False,
metadata: Dict[str, Any] = {},
) -> None:
"""
Sends an `agent-stream-chunk` message containing content that should be
appended to an existing `agent-stream` message with ID `stream_id`.
"""
stream_chunk_msg = AgentStreamChunkMessage(
id=stream_id, content=content, stream_complete=complete, metadata=metadata
)
self.broadcast_message(stream_chunk_msg)

async def stream_reply(self, input: Input, human_msg: HumanChatMessage, config: Optional[RunnableConfig] = None):
"""
Streams a reply to a human message by invoking
`self.llm_chain.astream()`. A LangChain `Runnable` instance must be
bound to `self.llm_chain` before invoking this method.
Arguments
---------
- `input`: The input to your runnable. The type of `input` depends on
the runnable in `self.llm_chain`, but is usually a dictionary whose keys
refer to input variables in your prompt template.
- `human_msg`: The `HumanChatMessage` being replied to.
- `config` (optional): A `RunnableConfig` object that specifies
additional configuration when streaming from the runnable.
"""
assert self.llm_chain

received_first_chunk = False
metadata_handler = MetadataCallbackHandler()
base_config: RunnableConfig = {
"configurable": {"last_human_msg": human_msg},
"callbacks": [metadata_handler]
}
merged_config: RunnableConfig = merge_runnable_configs(base_config, config)

# start with a pending message
with self.pending("Generating response", human_msg) as pending_message:
# stream response in chunks. this works even if a provider does not
# implement streaming, as `astream()` defaults to yielding `_call()`
# when `_stream()` is not implemented on the LLM class.
chunk_generator = self.llm_chain.astream(
input,
config=merged_config
)
stream_interrupted = False
async for chunk in chunk_generator:
if not received_first_chunk:
# when receiving the first chunk, close the pending message and
# start the stream.
self.close_pending(pending_message)
stream_id = self._start_stream(human_msg=human_msg)
received_first_chunk = True
self.message_interrupted[stream_id] = asyncio.Event()

if self.message_interrupted[stream_id].is_set():
try:
# notify the model provider that streaming was interrupted
# (this is essential to allow the model to stop generating)
await chunk_generator.athrow(GenerationInterrupted())
except GenerationInterrupted:
# do not let the exception bubble up in case the
# provider did not handle it
pass
stream_interrupted = True
break

if isinstance(chunk, AIMessageChunk) and isinstance(chunk.content, str):
self._send_stream_chunk(stream_id, chunk.content)
elif isinstance(chunk, str):
self._send_stream_chunk(stream_id, chunk)
else:
self.log.error(f"Unrecognized type of chunk yielded: {type(chunk)}")
break

# complete stream after all chunks have been streamed
stream_tombstone = (
"\n\n(AI response stopped by user)" if stream_interrupted else ""
)
self._send_stream_chunk(
stream_id,
stream_tombstone,
complete=True,
metadata=metadata_handler.jai_metadata,
)
del self.message_interrupted[stream_id]


class GenerationInterrupted(asyncio.CancelledError):
"""Exception raised when streaming is cancelled by the user"""

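Note: the sketch below is not part of this commit; it illustrates how a subclass might use the new `stream_reply()` helper. The handler name, class attributes, and prompt are hypothetical. The only behavior taken from the diff is that `create_llm_chain()` binds a LangChain `Runnable` to `self.llm_chain`, after which `stream_reply()` handles the pending message, the `agent-stream` start message, chunk broadcasting, interruption, and stream completion.

from typing import Dict, Type

from jupyter_ai.models import HumanChatMessage
from jupyter_ai_magics.providers import BaseProvider
from langchain_core.prompts import ChatPromptTemplate

from .base import BaseChatHandler, SlashCommandRoutingType


class ExampleChatHandler(BaseChatHandler):
    """Hypothetical handler illustrating BaseChatHandler.stream_reply()."""

    id = "example"
    name = "Example"
    help = "Streams a reply using the shared streaming logic"
    routing_type = SlashCommandRoutingType(slash_id="example")

    def create_llm_chain(
        self, provider: Type[BaseProvider], provider_params: Dict[str, str]
    ):
        llm = provider(**provider_params)
        prompt = ChatPromptTemplate.from_messages(
            [("system", "You are a helpful assistant."), ("human", "{input}")]
        )
        self.llm = llm
        # any LCEL Runnable works here; stream_reply() only requires that
        # self.llm_chain supports .astream()
        self.llm_chain = prompt | llm

    async def process_message(self, message: HumanChatMessage):
        self.get_llm_chain()
        # stream_reply() opens a pending message, starts an `agent-stream`
        # message on the first chunk, broadcasts each chunk, and completes
        # the stream (with a tombstone if the user interrupted it)
        await self.stream_reply({"input": message.body}, message)

A handler may also pass an optional `RunnableConfig` as the third argument; `stream_reply()` merges it with its own configuration (the `last_human_msg` configurable and the metadata callback handler) via `merge_configs`.
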
115 changes: 3 additions & 112 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/default.py
@@ -1,27 +1,17 @@
import asyncio
import time
from typing import Any, Dict, Type
from uuid import uuid4
from typing import Dict, Type

from jupyter_ai.callback_handlers import MetadataCallbackHandler
from jupyter_ai.models import (
AgentStreamChunkMessage,
AgentStreamMessage,
HumanChatMessage,
)
from jupyter_ai_magics.providers import BaseProvider
from langchain_core.messages import AIMessageChunk
from langchain_core.runnables import ConfigurableFieldSpec
from langchain_core.runnables.history import RunnableWithMessageHistory

from ..context_providers import ContextProviderException, find_commands
from .base import BaseChatHandler, SlashCommandRoutingType


class GenerationInterrupted(asyncio.CancelledError):
"""Exception raised when streaming is cancelled by the user"""


class DefaultChatHandler(BaseChatHandler):
id = "default"
name = "Default"
Expand Down Expand Up @@ -65,55 +55,8 @@ def create_llm_chain(
)
self.llm_chain = runnable

def _start_stream(self, human_msg: HumanChatMessage) -> str:
"""
Sends an `agent-stream` message to indicate the start of a response
stream. Returns the ID of the message, denoted as the `stream_id`.
"""
stream_id = uuid4().hex
stream_msg = AgentStreamMessage(
id=stream_id,
time=time.time(),
body="",
reply_to=human_msg.id,
persona=self.persona,
complete=False,
)

for handler in self._root_chat_handlers.values():
if not handler:
continue

handler.broadcast_message(stream_msg)
break

return stream_id

def _send_stream_chunk(
self,
stream_id: str,
content: str,
complete: bool = False,
metadata: Dict[str, Any] = {},
):
"""
Sends an `agent-stream-chunk` message containing content that should be
appended to an existing `agent-stream` message with ID `stream_id`.
"""
stream_chunk_msg = AgentStreamChunkMessage(
id=stream_id, content=content, stream_complete=complete, metadata=metadata
)

for handler in self._root_chat_handlers.values():
if not handler:
continue

handler.broadcast_message(stream_chunk_msg)
break

async def process_message(self, message: HumanChatMessage):
self.get_llm_chain()
received_first_chunk = False
assert self.llm_chain

inputs = {"input": message.body}
@@ -126,61 +69,9 @@ async def process_message(self, message: HumanChatMessage):
return
inputs["context"] = context_prompt
inputs["input"] = self.replace_prompt(inputs["input"])

await self.stream_reply(inputs, message)

# start with a pending message
with self.pending("Generating response", message) as pending_message:
# stream response in chunks. this works even if a provider does not
# implement streaming, as `astream()` defaults to yielding `_call()`
# when `_stream()` is not implemented on the LLM class.
metadata_handler = MetadataCallbackHandler()
chunk_generator = self.llm_chain.astream(
inputs,
config={
"configurable": {"last_human_msg": message},
"callbacks": [metadata_handler],
},
)
stream_interrupted = False
async for chunk in chunk_generator:
if not received_first_chunk:
# when receiving the first chunk, close the pending message and
# start the stream.
self.close_pending(pending_message)
stream_id = self._start_stream(human_msg=message)
received_first_chunk = True
self.message_interrupted[stream_id] = asyncio.Event()

if self.message_interrupted[stream_id].is_set():
try:
# notify the model provider that streaming was interrupted
# (this is essential to allow the model to stop generating)
await chunk_generator.athrow(GenerationInterrupted())
except GenerationInterrupted:
# do not let the exception bubble up in case the
# provider did not handle it
pass
stream_interrupted = True
break

if isinstance(chunk, AIMessageChunk) and isinstance(chunk.content, str):
self._send_stream_chunk(stream_id, chunk.content)
elif isinstance(chunk, str):
self._send_stream_chunk(stream_id, chunk)
else:
self.log.error(f"Unrecognized type of chunk yielded: {type(chunk)}")
break

# complete stream after all chunks have been streamed
stream_tombstone = (
"\n\n(AI response stopped by user)" if stream_interrupted else ""
)
self._send_stream_chunk(
stream_id,
stream_tombstone,
complete=True,
metadata=metadata_handler.jai_metadata,
)
del self.message_interrupted[stream_id]

async def make_context_prompt(self, human_msg: HumanChatMessage) -> str:
return "\n\n".join(

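Note: interruption is driven from outside `stream_reply()`: another component sets the `asyncio.Event` stored in `self.message_interrupted[stream_id]` when the user stops the response, and `stream_reply()` then throws `GenerationInterrupted` into the chunk generator so the provider can stop generating. The self-contained sketch below demonstrates that `Event` + `athrow()` pattern with a fake generator standing in for `llm_chain.astream()`; all names in it are illustrative, not from the codebase.

import asyncio


class GenerationInterrupted(asyncio.CancelledError):
    """Mirrors jupyter_ai.chat_handlers.base.GenerationInterrupted."""


async def fake_chunks():
    # stands in for llm_chain.astream(); a real provider would stop its
    # request to the model service when this exception reaches it
    try:
        for i in range(100):
            await asyncio.sleep(0.01)
            yield f"chunk {i} "
    finally:
        # provider-side cleanup (e.g. closing an HTTP stream) runs here
        pass


async def main() -> None:
    interrupted = asyncio.Event()
    # simulate the user pressing "Stop" shortly after streaming begins
    asyncio.get_running_loop().call_later(0.05, interrupted.set)

    gen = fake_chunks()
    async for chunk in gen:
        if interrupted.is_set():
            try:
                # mirror stream_reply(): throw into the generator so the
                # producer can stop, then swallow the exception here
                await gen.athrow(GenerationInterrupted())
            except GenerationInterrupted:
                pass
            print("\n(AI response stopped by user)")
            break
        print(chunk, end="", flush=True)


asyncio.run(main())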