Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[v3-dev] Dedicate one LangChain history object per chat #1151

Merged
merged 3 commits into from
Dec 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,6 +487,9 @@ def get_llm_chat_memory(
last_human_msg: HumanChatMessage,
**kwargs,
) -> "BaseChatMessageHistory":
if self.ychat:
return self.llm_chat_memory

return WrappedBoundedChatHistory(
history=self.llm_chat_memory,
last_human_msg=last_human_msg,
Expand Down
16 changes: 13 additions & 3 deletions packages/jupyter-ai/jupyter_ai/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
RootChatHandler,
SlashCommandsInfoHandler,
)
from .history import BoundedChatHistory
from .history import BoundedChatHistory, YChatHistory

from jupyter_collaboration import ( # type:ignore[import-untyped] # isort:skip
__version__ as jupyter_collaboration_version,
Expand Down Expand Up @@ -418,9 +418,13 @@ def initialize_settings(self):
# list of chat messages to broadcast to new clients
# this is only used to render the UI, and is not the conversational
# memory object used by the LM chain.
#
# TODO: remove this in v3. this list is only used by the REST API to get
# history in v2 chat.
self.settings["chat_history"] = []

# conversational memory object used by LM chain
# TODO: remove this in v3. this is the history implementation that
# provides memory to the chat model in v2.
self.settings["llm_chat_memory"] = BoundedChatHistory(
k=self.default_max_chat_history
)
Expand Down Expand Up @@ -515,13 +519,19 @@ def _init_chat_handlers(
eps = entry_points()
chat_handler_eps = eps.select(group="jupyter_ai.chat_handlers")
chat_handlers: Dict[str, BaseChatHandler] = {}

if ychat:
llm_chat_memory = YChatHistory(ychat, k=self.default_max_chat_history)
else:
llm_chat_memory = self.settings["llm_chat_memory"]

chat_handler_kwargs = {
"log": self.log,
"config_manager": self.settings["jai_config_manager"],
"model_parameters": self.settings["model_parameters"],
"root_chat_handlers": self.settings["jai_root_chat_handlers"],
"chat_history": self.settings["chat_history"],
"llm_chat_memory": self.settings["llm_chat_memory"],
"llm_chat_memory": llm_chat_memory,
"root_dir": self.serverapp.root_dir,
"dask_client_future": self.settings["dask_client_future"],
"preferred_dir": self.serverapp.contents_manager.preferred_dir,
Expand Down
48 changes: 47 additions & 1 deletion packages/jupyter-ai/jupyter_ai/history.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,61 @@
import time
from typing import List, Optional, Sequence, Set, Union

from jupyterlab_chat.ychat import YChat
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import BaseMessage
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.pydantic_v1 import BaseModel, PrivateAttr

from .constants import BOT
from .models import HumanChatMessage

HUMAN_MSG_ID_KEY = "_jupyter_ai_human_msg_id"


class YChatHistory(BaseChatMessageHistory):
    """
    A `BaseChatMessageHistory` implementation backed by a YChat shared
    document. Reading `messages` yields the preceding `k` exchanges
    (`k * 2` messages); with `k=None`, every preceding message is returned.
    """

    def __init__(self, ychat: YChat, k: Optional[int]):
        # shared chat document that owns the actual message list
        self.ychat = ychat
        # number of exchanges to expose; None means unbounded
        self.k = k

    @property
    def messages(self) -> List[BaseMessage]:  # type:ignore[override]
        """
        Return the last `2 * k` messages preceding the latest message, or
        every preceding message when `k` is `None`.
        """
        # TODO: consider bounding history based on message size (e.g. total
        # char/token count) instead of message count.
        history = self.ychat.get_messages()

        # Drop the final entry: it is the HumanChatMessage the user just
        # submitted, which should not appear in the model's memory yet.
        if self.k is None:
            window = history[:-1]
        else:
            window = history[-2 * self.k - 1 : -1]

        # Map chat-document messages onto LangChain message types: anything
        # sent by the bot becomes an AIMessage, everything else a HumanMessage.
        return [
            (
                AIMessage(content=entry["body"])
                if entry["sender"] == BOT["username"]
                else HumanMessage(content=entry["body"])
            )
            for entry in window
        ]

    def add_message(self, message: BaseMessage) -> None:
        # Intentionally a no-op: the `YChat` shared document is the single
        # source of truth for history, so LangChain writes are ignored.
        return

    def clear(self):
        raise NotImplementedError()


class BoundedChatHistory(BaseChatMessageHistory, BaseModel):
"""
An in-memory implementation of `BaseChatMessageHistory` that stores up to
Expand Down
Loading