diff --git a/docs/source/developers/index.md b/docs/source/developers/index.md index 46b6c5719..465b5d1ee 100644 --- a/docs/source/developers/index.md +++ b/docs/source/developers/index.md @@ -492,7 +492,7 @@ def create_llm_chain( prompt_template = FIX_PROMPT_TEMPLATE self.prompt_template = prompt_template - runnable = prompt_template | llm # type:ignore + runnable = prompt_template | llm | StrOutputParser() # type:ignore self.llm_chain = runnable ``` diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py index e25cfd88c..b1206da9c 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py @@ -3,6 +3,7 @@ from jupyter_ai_magics.providers import BaseProvider from jupyterlab_chat.models import Message +from langchain_core.output_parsers import StrOutputParser from langchain_core.runnables.history import RunnableWithMessageHistory from ..context_providers import ContextProviderException, find_commands @@ -36,7 +37,7 @@ def create_llm_chain( self.llm = llm self.prompt_template = prompt_template - runnable = prompt_template | llm # type:ignore + runnable = prompt_template | llm | StrOutputParser() # type:ignore if not llm.manages_history: runnable = RunnableWithMessageHistory( runnable=runnable, # type:ignore[arg-type]