diff --git a/packages/jupyter-ai/jupyter_ai/actors.py b/packages/jupyter-ai/jupyter_ai/actors.py index 2b900ffb0..ca873923f 100644 --- a/packages/jupyter-ai/jupyter_ai/actors.py +++ b/packages/jupyter-ai/jupyter_ai/actors.py @@ -4,7 +4,7 @@ import time from typing import Union from uuid import uuid4 -from jupyter_ai.models import AgentChatMessage, ChatMessage, HumanChatMessage +from jupyter_ai.models import AgentChatMessage, ChatMessage, HumanChatMessage, ClearMessage from jupyter_ai_magics.providers import ChatOpenAINewProvider from langchain import ConversationChain, OpenAI from langchain.chains import ConversationalRetrievalChain @@ -43,21 +43,25 @@ class Router(): add a corresponding command in the `COMMANDS` dictionary. """ - def __init__(self, log: Logger): + def __init__(self, log: Logger, reply_queue: Queue): self.log = log + self.reply_queue = reply_queue def route_message(self, message): # assign default actor - actor = ray.get_actor(ACTOR_TYPE.DEFAULT) + default = ray.get_actor(ACTOR_TYPE.DEFAULT) - if(message.body.startswith("/")): + if message.body.startswith("/"): command = message.body.split(' ', 1)[0] if command in COMMANDS.keys(): actor = ray.get_actor(COMMANDS[command].value) - - - actor.process_message.remote(message) + actor.process_message.remote(message) + if command == '/clear': + reply_message = ClearMessage() + self.reply_queue.put(reply_message) + else: + default.process_message.remote(message) class BaseActor(): """Base actor implemented by actors that are called by the `Router`""" diff --git a/packages/jupyter-ai/jupyter_ai/extension.py b/packages/jupyter-ai/jupyter_ai/extension.py index 127f83825..4b11d1279 100644 --- a/packages/jupyter-ai/jupyter_ai/extension.py +++ b/packages/jupyter-ai/jupyter_ai/extension.py @@ -101,7 +101,10 @@ def initialize_settings(self): reply_queue = Queue() self.settings["reply_queue"] = reply_queue - router = Router.options(name="router").remote(log=self.log) + router = Router.options(name="router").remote( + reply_queue=reply_queue, + log=self.log + ) default_actor = DefaultActor.options(name=ACTOR_TYPE.DEFAULT.value).remote( reply_queue=reply_queue, log=self.log diff --git a/packages/jupyter-ai/jupyter_ai/handlers.py b/packages/jupyter-ai/jupyter_ai/handlers.py index 5bbc22860..54207e4df 100644 --- a/packages/jupyter-ai/jupyter_ai/handlers.py +++ b/packages/jupyter-ai/jupyter_ai/handlers.py @@ -16,7 +16,7 @@ from jupyter_server.utils import ensure_async from .task_manager import TaskManager -from .models import ChatHistory, PromptRequest, ChatRequest, ChatMessage, AgentChatMessage, HumanChatMessage, ConnectionMessage, ChatClient +from .models import ChatHistory, PromptRequest, ChatRequest, ChatMessage, Message, AgentChatMessage, HumanChatMessage, ConnectionMessage, ChatClient from langchain.schema import HumanMessage class APIHandler(BaseAPIHandler): @@ -205,7 +205,7 @@ def open(self): self.log.info(f"Client connected. ID: {client_id}") self.log.debug("Clients are : %s", self.chat_handlers.keys()) - def broadcast_message(self, message: ChatMessage): + def broadcast_message(self, message: Message): """Broadcasts message to all connected clients. Appends message to `self.chat_history`. """ @@ -218,7 +218,9 @@ def broadcast_message(self, message: ChatMessage): if client: client.write_message(message.dict()) - self.chat_history.append(message) + # Only append ChatMessage instances to history, not control messages + if isinstance(message, HumanChatMessage) or isinstance(message, AgentChatMessage): + self.chat_history.append(message) async def on_message(self, message): self.log.debug("Message recieved: %s", message) @@ -230,11 +232,6 @@ async def on_message(self, message): self.log.error(e) return - # message sent to the agent instance - message = HumanMessage( - content=chat_request.prompt, - additional_kwargs=dict(user=asdict(self.current_user)) - ) # message broadcast to chat clients chat_message_id = str(uuid.uuid4()) chat_message = HumanChatMessage( @@ -247,6 +244,12 @@ async def on_message(self, message): # broadcast the message to other clients self.broadcast_message(message=chat_message) + # Clear the message history if given the /clear command + if chat_request.prompt.startswith('/'): + command = chat_request.prompt.split(' ', 1)[0] + if command == '/clear': + self.chat_history.clear() + # process through the router router = ray.get_actor("router") router.route_message.remote(chat_message) diff --git a/packages/jupyter-ai/jupyter_ai/models.py b/packages/jupyter-ai/jupyter_ai/models.py index 2800a545b..0f6bf2b32 100644 --- a/packages/jupyter-ai/jupyter_ai/models.py +++ b/packages/jupyter-ai/jupyter_ai/models.py @@ -42,6 +42,9 @@ class ConnectionMessage(BaseModel): type: Literal["connection"] = "connection" client_id: str +class ClearMessage(BaseModel): + type: Literal["clear"] = "clear" + # the type of messages being broadcast to clients ChatMessage = Union[ AgentChatMessage, @@ -51,7 +54,8 @@ class ConnectionMessage(BaseModel): Message = Union[ AgentChatMessage, HumanChatMessage, - ConnectionMessage + ConnectionMessage, + ClearMessage ] class ListEnginesEntry(BaseModel): diff --git a/packages/jupyter-ai/src/components/chat.tsx b/packages/jupyter-ai/src/components/chat.tsx index 6b4f96c8b..f365acef4 100644 --- a/packages/jupyter-ai/src/components/chat.tsx +++ b/packages/jupyter-ai/src/components/chat.tsx @@ -45,6 +45,9 @@ function ChatBody({ chatHandler }: ChatBodyProps): JSX.Element { function handleChatEvents(message: AiService.Message) { if (message.type === 'connection') { return; + } else if (message.type === 'clear') { + setMessages([]); + return; } setMessages(messageGroups => [...messageGroups, message]); diff --git a/packages/jupyter-ai/src/handler.ts b/packages/jupyter-ai/src/handler.ts index 793a17b27..6ccad4545 100644 --- a/packages/jupyter-ai/src/handler.ts +++ b/packages/jupyter-ai/src/handler.ts @@ -98,8 +98,12 @@ export namespace AiService { client_id: string; }; + export type ClearMessage = { + type: 'clear' + } + export type ChatMessage = AgentChatMessage | HumanChatMessage; - export type Message = AgentChatMessage | HumanChatMessage | ConnectionMessage; + export type Message = AgentChatMessage | HumanChatMessage | ConnectionMessage | ClearMessage; export type ChatHistory = { messages: ChatMessage[];