diff --git a/.gitignore b/.gitignore
index 6fc57c1d2..f302760ca 100644
--- a/.gitignore
+++ b/.gitignore
@@ -126,3 +126,5 @@ playground/
 
 # reserve path for a dev script
 dev.sh
+
+.vscode
\ No newline at end of file
diff --git a/packages/jupyter-ai/jupyter_ai/extension.py b/packages/jupyter-ai/jupyter_ai/extension.py
index 38d7c8153..da9c0a773 100644
--- a/packages/jupyter-ai/jupyter_ai/extension.py
+++ b/packages/jupyter-ai/jupyter_ai/extension.py
@@ -1,15 +1,30 @@
+import queue
 from jupyter_server.extension.application import ExtensionApp
-from .handlers import PromptAPIHandler, TaskAPIHandler
+from langchain import ConversationChain
+from .handlers import ChatHandler, ChatHistoryHandler, PromptAPIHandler, TaskAPIHandler, ChatAPIHandler
 from importlib_metadata import entry_points
 import inspect
 from .engine import BaseModelEngine
+from .providers import ChatOpenAIProvider
+import os
+
+from langchain.memory import ConversationBufferMemory
+from langchain.prompts import (
+    ChatPromptTemplate,
+    MessagesPlaceholder,
+    SystemMessagePromptTemplate,
+    HumanMessagePromptTemplate
+)
 
 class AiExtension(ExtensionApp):
     name = "jupyter_ai"
     handlers = [
         ("api/ai/prompt", PromptAPIHandler),
+        (r"api/ai/chat/?", ChatAPIHandler),
         (r"api/ai/tasks/?", TaskAPIHandler),
-        (r"api/ai/tasks/([\w\-:]*)", TaskAPIHandler)
+        (r"api/ai/tasks/([\w\-:]*)", TaskAPIHandler),
+        (r"api/ai/chats/?", ChatHandler),
+        (r"api/ai/chats/history/?", ChatHistoryHandler),
     ]
 
     @property
@@ -18,6 +33,7 @@ def ai_engines(self):
             self.settings["ai_engines"] = {}
         return self.settings["ai_engines"]
 
+
     def initialize_settings(self):
         # EP := entry point
@@ -69,5 +85,30 @@ def initialize_settings(self):
         self.settings["ai_default_tasks"] = default_tasks
         self.log.info("Registered all default tasks.")
 
+        ## load OpenAI chat provider
+        if ChatOpenAIProvider.auth_strategy.name in os.environ:
+            self.settings["openai_chat"] = ChatOpenAIProvider(model_id="gpt-3.5-turbo")
+            # Create a conversation memory
+            memory = ConversationBufferMemory(return_messages=True)
+            prompt_template = ChatPromptTemplate.from_messages([
+                SystemMessagePromptTemplate.from_template("The following is a friendly conversation between a human and an AI. The AI is talkative and provides lots of specific details from its context. If the AI does not know the answer to a question, it truthfully says it does not know."),
+                MessagesPlaceholder(variable_name="history"),
+                HumanMessagePromptTemplate.from_template("{input}")
+            ])
+            chain = ConversationChain(
+                llm=self.settings["openai_chat"],
+                prompt=prompt_template,
+                verbose=True,
+                memory=memory
+            )
+            self.settings["chat_provider"] = chain
+
+        self.log.info(f"Registered {self.name} server extension")
+
+        # Add a message queue to the settings to be used by the chat handler
+        self.settings["chat_message_queue"] = queue.Queue()
+
+        # Store chat clients in a dictionary
+        self.settings["chat_clients"] = {}
+
\ No newline at end of file
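For reference, a minimal standalone sketch of the conversation flow this initialize_settings() change wires up (not part of the patch; it assumes OPENAI_API_KEY is set and shortens the system prompt). ConversationChain.predict() both returns the reply and records the exchange in the memory that ChatHistoryHandler later reads:

    from langchain import ConversationChain
    from langchain.chat_models import ChatOpenAI
    from langchain.memory import ConversationBufferMemory
    from langchain.prompts import (
        ChatPromptTemplate,
        HumanMessagePromptTemplate,
        MessagesPlaceholder,
        SystemMessagePromptTemplate,
    )

    # Mirror of the prompt built above, with a shortened system message.
    prompt = ChatPromptTemplate.from_messages([
        SystemMessagePromptTemplate.from_template("You are a helpful assistant."),
        MessagesPlaceholder(variable_name="history"),
        HumanMessagePromptTemplate.from_template("{input}"),
    ])
    memory = ConversationBufferMemory(return_messages=True)
    chain = ConversationChain(
        llm=ChatOpenAI(model_name="gpt-3.5-turbo"),
        prompt=prompt,
        memory=memory,
    )

    print(chain.predict(input="Hi there!"))  # the assistant's reply
    print(memory.chat_memory.messages)       # [HumanMessage(...), AIMessage(...)]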
- if "task_manager" not in self.settings: - self.settings["task_manager"] = TaskManager(engines=self.engines, default_tasks=self.default_tasks) - return self.settings["task_manager"] - + self.finish(json.dumps({ + "output": output, + })) + +class TaskAPIHandler(APIHandler): @tornado.web.authenticated async def get(self, id=None): if id is None: @@ -78,3 +93,148 @@ async def get(self, id=None): raise HTTPError(404, f"Task not found with ID: {id}") self.finish(json.dumps(describe_task_response.dict())) + + +class ChatHistoryHandler(BaseAPIHandler): + """Handler to return message history""" + + _chat_provider = None + _messages = [] + + @property + def chat_provider(self): + if self._chat_provider is None: + self._chat_provider = self.settings["chat_provider"] + return self._chat_provider + + @property + def messages(self): + self._messages = self.chat_provider.memory.chat_memory.messages or [] + return self._messages + + @tornado.web.authenticated + async def get(self): + messages = [] + for message in self.messages: + messages.append(message) + history = ChatHistory(messages=messages) + + self.finish(history.json(models_as_dict=False)) + + @tornado.web.authenticated + async def delete(self): + self.chat_provider.memory.chat_memory.clear() + self.messages = [] + self.set_status(204) + self.finish() + + +class ChatHandler( + JupyterHandler, + websocket.WebSocketHandler +): + """ + A websocket handler for chat. + """ + + _chat_provider = None + _chat_message_queue = None + _messages = [] + + @property + def chat_provider(self): + if self._chat_provider is None: + self._chat_provider = self.settings["chat_provider"] + return self._chat_provider + + @property + def chat_message_queue(self): + if self._chat_message_queue is None: + self._chat_message_queue = self.settings["chat_message_queue"] + return self._chat_message_queue + + @property + def messages(self): + self._messages = self.chat_provider.memory.chat_memory.messages or [] + return self._messages + + def add_chat_client(self, username): + self.settings["chat_clients"][username] = self + self.log.debug("Clients are : %s", self.settings["chat_clients"].keys()) + + def remove_chat_client(self, username): + self.settings["chat_clients"][username] = None + self.log.debug("Chat clients: %s", self.settings['chat_clients'].keys()) + + def initialize(self): + self.log.debug("Initializing websocket connection %s", self.request.path) + + def pre_get(self): + """Handles authentication/authorization. + """ + # authenticate the request before opening the websocket + user = self.current_user + if user is None: + self.log.warning("Couldn't authenticate WebSocket connection") + raise web.HTTPError(403) + + # authorize the user. 
diff --git a/packages/jupyter-ai/jupyter_ai/models.py b/packages/jupyter-ai/jupyter_ai/models.py
index ca283d6b7..e4ca60cfe 100644
--- a/packages/jupyter-ai/jupyter_ai/models.py
+++ b/packages/jupyter-ai/jupyter_ai/models.py
@@ -1,11 +1,16 @@
-from pydantic import BaseModel
-from typing import Dict, List
+from pydantic import BaseModel, validator
+from typing import Dict, List, Literal
+
+from langchain.schema import BaseMessage, _message_to_dict
 
 class PromptRequest(BaseModel):
     task_id: str
     engine_id: str
     prompt_variables: Dict[str, str]
 
+class ChatRequest(BaseModel):
+    prompt: str
+
 class ListEnginesEntry(BaseModel):
     id: str
     name: str
@@ -22,3 +27,12 @@ class DescribeTaskResponse(BaseModel):
     insertion_mode: str
     prompt_template: str
     engines: List[ListEnginesEntry]
+
+class ChatHistory(BaseModel):
+    """History of chat messages"""
+    messages: List[BaseMessage]
+
+    class Config:
+        json_encoders = {
+            BaseMessage: lambda v: _message_to_dict(v)
+        }
\ No newline at end of file
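A quick sketch of what the json_encoders hook in ChatHistory produces. Passing models_as_dict=False is what makes pydantic route BaseMessage instances through _message_to_dict, which is how ChatHistoryHandler.get() serializes the transcript (the import path assumes the installed jupyter_ai package):

    from jupyter_ai.models import ChatHistory
    from langchain.schema import AIMessage, HumanMessage

    history = ChatHistory(messages=[
        HumanMessage(content="Hi"),
        AIMessage(content="Hello! How can I help you today?"),
    ])
    print(history.json(models_as_dict=False))
    # {"messages": [{"type": "human", "data": {"content": "Hi", ...}},
    #               {"type": "ai", "data": {"content": "Hello! ...", ...}}]}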
diff --git a/packages/jupyter-ai/jupyter_ai/providers.py b/packages/jupyter-ai/jupyter_ai/providers.py
index 8ea30d288..2fbb5410e 100644
--- a/packages/jupyter-ai/jupyter_ai/providers.py
+++ b/packages/jupyter-ai/jupyter_ai/providers.py
@@ -7,10 +7,12 @@
     Cohere,
     HuggingFaceHub,
     OpenAI,
-    OpenAIChat,
     SagemakerEndpoint
 )
+
 from pydantic import BaseModel, Extra
+from langchain.chat_models import ChatOpenAI
+
 
 class EnvAuthStrategy(BaseModel):
     """Require one auth token via an environment variable."""
@@ -153,7 +155,7 @@ class OpenAIProvider(BaseProvider, OpenAI):
     pypi_package_deps = ["openai"]
     auth_strategy = EnvAuthStrategy(name="OPENAI_API_KEY")
 
-class ChatOpenAIProvider(BaseProvider, OpenAIChat):
+class ChatOpenAIProvider(BaseProvider, ChatOpenAI):
     id = "openai-chat"
     name = "OpenAI"
     models = [
@@ -168,6 +170,21 @@ class ChatOpenAIProvider(BaseProvider, OpenAIChat):
     pypi_package_deps = ["openai"]
     auth_strategy = EnvAuthStrategy(name="OPENAI_API_KEY")
 
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def append_exchange(self, prompt: str, output: str):
+        """Appends a conversational exchange between user and an OpenAI Chat
+        model to a transcript that will be included in future exchanges."""
+        self.prefix_messages.append({
+            "role": "user",
+            "content": prompt
+        })
+        self.prefix_messages.append({
+            "role": "assistant",
+            "content": output
+        })
+
 class SmEndpointProvider(BaseProvider, SagemakerEndpoint):
     id = "sagemaker-endpoint"
     name = "Sagemaker Endpoint"
diff --git a/packages/jupyter-ai/src/chat_handler.ts b/packages/jupyter-ai/src/chat_handler.ts
new file mode 100644
index 000000000..4aa4bc311
--- /dev/null
+++ b/packages/jupyter-ai/src/chat_handler.ts
@@ -0,0 +1,111 @@
+import { IDisposable } from '@lumino/disposable';
+import { ServerConnection } from '@jupyterlab/services';
+import { URLExt } from '@jupyterlab/coreutils';
+import { Poll } from '@lumino/polling';
+import { AiService, requestAPI } from './handler';
+
+const CHAT_SERVICE_URL = 'api/ai/chats';
+
+export class ChatHandler implements IDisposable {
+  /**
+   * Create a new chat handler.
+   */
+  constructor(options: AiService.IOptions = {}) {
+    this.serverSettings =
+      options.serverSettings ?? ServerConnection.makeSettings();
+
+    this._poll = new Poll({ factory: () => this._subscribe() });
+    this._poll.start();
+  }
+
+  /**
+   * The server settings used to make API requests.
+   */
+  readonly serverSettings: ServerConnection.ISettings;
+
+  /**
+   * Whether the chat handler is disposed.
+   */
+  get isDisposed(): boolean {
+    return this._isDisposed;
+  }
+
+  /**
+   * Dispose the chat handler.
+   */
+  dispose(): void {
+    if (this.isDisposed) {
+      return;
+    }
+    this._isDisposed = true;
+
+    // Clean up poll.
+    this._poll.dispose();
+
+    this._listeners = [];
+
+    // Clean up socket.
+    const socket = this._socket;
+    if (socket) {
+      this._socket = null;
+      socket.onopen = () => undefined;
+      socket.onerror = () => undefined;
+      socket.onmessage = () => undefined;
+      socket.onclose = () => undefined;
+      socket.close();
+    }
+  }
+
+  public addListener(handler: (message: AiService.ChatMessage) => void): void {
+    this._listeners.push(handler);
+  }
+
+  public removeListener(
+    handler: (message: AiService.ChatMessage) => void
+  ): void {
+    const index = this._listeners.indexOf(handler);
+    if (index > -1) {
+      this._listeners.splice(index, 1);
+    }
+  }
+
+  public sendMessage(message: AiService.ChatRequest): void {
+    this._socket?.send(JSON.stringify(message));
+  }
+
+  public async getHistory(): Promise<AiService.ChatHistory> {
+    let data: AiService.ChatHistory = { messages: [] };
+    try {
+      data = await requestAPI('chats/history', {
+        method: 'GET'
+      });
+    } catch (e) {
+      return Promise.reject(e);
+    }
+    return data;
+  }
+
+  private _onMessage(message: AiService.ChatMessage): void {
+    this._listeners.forEach(listener => listener(message));
+  }
+
+  private _subscribe(): Promise<void> {
+    return new Promise((_, reject) => {
+      if (this.isDisposed) {
+        return;
+      }
+      const { token, WebSocket, wsUrl } = this.serverSettings;
+      const url =
+        URLExt.join(wsUrl, CHAT_SERVICE_URL) +
+        (token ? `?token=${encodeURIComponent(token)}` : '');
+      const socket = (this._socket = new WebSocket(url));
+
+      socket.onclose = () => reject(new Error('ChatHandler socket closed'));
+      socket.onmessage = msg =>
+        msg.data && this._onMessage(JSON.parse(msg.data));
+    });
+  }
+
+  private _isDisposed = false;
+  private _poll: Poll;
+  private _socket: WebSocket | null = null;
+  private _listeners: ((msg: any) => void)[] = [];
+}
diff --git a/packages/jupyter-ai/src/handler.ts b/packages/jupyter-ai/src/handler.ts
index 38ed70878..f016c9a22 100644
--- a/packages/jupyter-ai/src/handler.ts
+++ b/packages/jupyter-ai/src/handler.ts
@@ -60,6 +60,28 @@ export namespace AiService {
     };
   }
 
+  export type ChatRequest = {
+    prompt: string;
+  };
+
+  export type ChatResponse = {
+    output: string;
+  };
+
+  export type ChatMessageData = {
+    content: string;
+    additional_kwargs: { [key: string]: any };
+  };
+
+  export type ChatMessage = {
+    type: string;
+    data: ChatMessageData;
+  };
+
+  export type ChatHistory = {
+    messages: ChatMessage[];
+  };
+
   export interface IPromptResponse {
     output: string;
     insertion_mode: 'above' | 'below' | 'replace';
@@ -81,6 +103,20 @@
     return data as IPromptResponse;
   }
 
+  export async function sendChat(request: ChatRequest): Promise<ChatResponse> {
+    let data;
+
+    try {
+      data = await requestAPI('chat', {
+        method: 'POST',
+        body: JSON.stringify(request)
+      });
+    } catch (e) {
+      return Promise.reject(e);
+    }
+    return data as ChatResponse;
+  }
+
   export type ListTasksEntry = {
     id: string;
     name: string;
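For completeness, the api/ai/chat REST route that sendChat() targets can also be exercised directly. A sketch using the requests library (host, port, and token are placeholders, not part of this diff):

    import requests

    resp = requests.post(
        "http://localhost:8888/api/ai/chat",  # placeholder host/port
        headers={"Authorization": "token <server-token>"},
        json={"prompt": "Summarize the selected cell."},
    )
    resp.raise_for_status()
    print(resp.json()["output"])  # matches the ChatResponse type above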