Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[2.x] Upgrade to LangChain v0.3 and Pydantic v2 #1199

Merged
merged 12 commits into from
Jan 15, 2025
Prev Previous commit
Next Next commit
fix history impl for Pydantic v2, fixes chat
  • Loading branch information
dlqqq committed Jan 14, 2025

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
commit b246404ea2de6afca400af3f44eed3e944d4ccaa
33 changes: 23 additions & 10 deletions packages/jupyter-ai/jupyter_ai/history.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
import time
from typing import List, Optional, Sequence, Set, Union
from typing import List, Optional, Sequence, Set

from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import BaseMessage
from pydantic import BaseModel, PrivateAttr

from .models import HumanChatMessage

HUMAN_MSG_ID_KEY = "_jupyter_ai_human_msg_id"


class BoundedChatHistory(BaseChatMessageHistory, BaseModel):
class BoundedChatHistory(BaseChatMessageHistory):
"""
An in-memory implementation of `BaseChatMessageHistory` that stores up to
`k` exchanges between a user and an LLM.
@@ -19,10 +18,16 @@ class BoundedChatHistory(BaseChatMessageHistory, BaseModel):
messages and 2 AI messages. If `k` is set to `None` all messages are kept.
"""

k: Union[int, None]
clear_time: float = 0.0
cleared_msgs: Set[str] = set()
_all_messages: List[BaseMessage] = PrivateAttr(default_factory=list)
def __init__(
self,
k: Optional[int] = None,
clear_time: float = 0.0,
cleared_msgs: Set[str] = set(),
):
self.k = k
self.clear_time = clear_time
self.cleared_msgs = cleared_msgs
self._all_messages = []

@property
def messages(self) -> List[BaseMessage]: # type:ignore[override]
@@ -67,7 +72,7 @@ async def aclear(self) -> None:
self.clear()


class WrappedBoundedChatHistory(BaseChatMessageHistory, BaseModel):
class WrappedBoundedChatHistory(BaseChatMessageHistory):
"""
Wrapper around `BoundedChatHistory` that only appends an `AgentChatMessage`
if the `HumanChatMessage` it is replying to was not cleared. If a chat
@@ -88,8 +93,16 @@ class WrappedBoundedChatHistory(BaseChatMessageHistory, BaseModel):
Reference: https://python.langchain.com/v0.1/docs/expression_language/how_to/message_history/
"""

history: BoundedChatHistory
last_human_msg: HumanChatMessage
def __init__(
self,
history: BoundedChatHistory,
last_human_msg: HumanChatMessage,
*args,
**kwargs,
):
self.history = history
self.last_human_msg = last_human_msg
super().__init__(*args, **kwargs)

@property
def messages(self) -> List[BaseMessage]: # type:ignore[override]