From dba7e4be5799caea1a675a8bdd3aee852734c7cc Mon Sep 17 00:00:00 2001 From: Sanjiv Das Date: Mon, 28 Oct 2024 16:32:14 -0700 Subject: [PATCH] Implement streaming for `/fix` (#1048) * doc for streaming and added streaming for /fix * removed additional docs * added custom pending message * Update fix.py * Update fix.py * fall back * Update fix.py * Update base.py * Update fix.py * Update fix.py * Update fix.py * add back anthropic models removed in rebase --------- Co-authored-by: David L. Qiu --- .../jupyter_ai/chat_handlers/base.py | 6 +++- .../jupyter_ai/chat_handlers/fix.py | 35 ++++++++----------- 2 files changed, 20 insertions(+), 21 deletions(-) diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py index 107b5a000..c844650ad 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py @@ -521,6 +521,7 @@ async def stream_reply( self, input: Input, human_msg: HumanChatMessage, + pending_msg="Generating response", config: Optional[RunnableConfig] = None, ): """ @@ -538,6 +539,9 @@ async def stream_reply( - `config` (optional): A `RunnableConfig` object that specifies additional configuration when streaming from the runnable. + + - `pending_msg` (optional): Changes the default pending message from + "Generating response". """ assert self.llm_chain assert isinstance(self.llm_chain, Runnable) @@ -551,7 +555,7 @@ async def stream_reply( merged_config: RunnableConfig = merge_runnable_configs(base_config, config) # start with a pending message - with self.pending("Generating response", human_msg) as pending_message: + with self.pending(pending_msg, human_msg) as pending_message: # stream response in chunks. this works even if a provider does not # implement streaming, as `astream()` defaults to yielding `_call()` # when `_stream()` is not implemented on the LLM class. diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/fix.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/fix.py index 4daf70e03..390b93cf6 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/fix.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/fix.py @@ -2,7 +2,6 @@ from jupyter_ai.models import CellWithErrorSelection, HumanChatMessage from jupyter_ai_magics.providers import BaseProvider -from langchain.chains import LLMChain from langchain.prompts import PromptTemplate from .base import BaseChatHandler, SlashCommandRoutingType @@ -64,6 +63,7 @@ class FixChatHandler(BaseChatHandler): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + self.prompt_template = None def create_llm_chain( self, provider: Type[BaseProvider], provider_params: Dict[str, str] @@ -73,13 +73,11 @@ def create_llm_chain( **(self.get_model_parameters(provider, provider_params)), } llm = provider(**unified_parameters) - self.llm = llm - # TODO: migrate this class to use a LCEL `Runnable` instead of - # `Chain`, then remove the below ignore comment. - self.llm_chain = LLMChain( # type:ignore[arg-type] - llm=llm, prompt=FIX_PROMPT_TEMPLATE, verbose=True - ) + prompt_template = FIX_PROMPT_TEMPLATE + + runnable = prompt_template | llm # type:ignore + self.llm_chain = runnable async def process_message(self, message: HumanChatMessage): if not (message.selection and message.selection.type == "cell-with-error"): @@ -96,16 +94,13 @@ async def process_message(self, message: HumanChatMessage): extra_instructions = message.prompt[4:].strip() or "None." self.get_llm_chain() - with self.pending("Analyzing error", message): - assert self.llm_chain - # TODO: migrate this class to use a LCEL `Runnable` instead of - # `Chain`, then remove the below ignore comment. - response = await self.llm_chain.apredict( # type:ignore[attr-defined] - extra_instructions=extra_instructions, - stop=["\nHuman:"], - cell_content=selection.source, - error_name=selection.error.name, - error_value=selection.error.value, - traceback="\n".join(selection.error.traceback), - ) - self.reply(response, message) + assert self.llm_chain + + inputs = { + "extra_instructions": extra_instructions, + "cell_content": selection.source, + "traceback": "\n".join(selection.error.traceback), + "error_name": selection.error.name, + "error_value": selection.error.value, + } + await self.stream_reply(inputs, message, pending_msg="Analyzing error")