From 4f96c6388ee94a828e32deaf9ab18213590d33a9 Mon Sep 17 00:00:00 2001 From: "David L. Qiu" Date: Wed, 3 Apr 2024 07:58:37 -0700 Subject: [PATCH 1/5] improve support for custom providers adds 3 new provider class attributes: 1. `manages_history` 2. `unsupported_slash_commands` 3. `persona` --- .../jupyter_ai_magics/__init__.py | 3 ++ .../jupyter_ai_magics/models/persona.py | 39 +++++++++++++++++++ .../jupyter_ai_magics/providers.py | 26 +++++++++++++ .../jupyter_ai_magics/static/jupyternaut.svg | 9 +++++ .../jupyter_ai/chat_handlers/base.py | 21 +++++++++- .../jupyter_ai/chat_handlers/default.py | 16 +++++--- .../jupyter_ai/chat_handlers/help.py | 21 +++++++--- .../jupyter-ai/jupyter_ai/config_manager.py | 12 ++++++ packages/jupyter-ai/jupyter_ai/extension.py | 37 ++++++++++++++++-- packages/jupyter-ai/jupyter_ai/models.py | 20 +++++++++- .../src/components/chat-messages.tsx | 13 ++++--- packages/jupyter-ai/src/handler.ts | 6 +++ 12 files changed, 200 insertions(+), 23 deletions(-) create mode 100644 packages/jupyter-ai-magics/jupyter_ai_magics/models/persona.py create mode 100644 packages/jupyter-ai-magics/jupyter_ai_magics/static/jupyternaut.svg diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py b/packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py index f81ea11c8..616ce6093 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py @@ -26,6 +26,9 @@ TogetherAIProvider, ) +# expose JupyternautPersona on the package root +# required by `jupyter-ai`. +from .models.persona import JupyternautPersona, Persona def load_ipython_extension(ipython): ipython.register_magics(AiMagics) diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/models/persona.py b/packages/jupyter-ai-magics/jupyter_ai_magics/models/persona.py new file mode 100644 index 000000000..15468f6ff --- /dev/null +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/models/persona.py @@ -0,0 +1,39 @@ + +from typing import ClassVar +from langchain.pydantic_v1 import BaseModel +import os + +JUPYTERNAUT_AVATAR_PATH = str(os.path.join(os.path.dirname(__file__), '..', 'static', 'jupyternaut.svg')) +JUPYTERNAUT_AVATAR_ROUTE = "api/ai/static/jupyternaut.svg" + +class Persona(BaseModel): + """ + Model of an **agent persona**, a struct that includes the name & avatar + shown on agent replies in the chat UI. + + Each persona is specific to a single provider, set on the `persona` field. + """ + + name: ClassVar[str] = ... + """ + Name of the persona, e.g. "Jupyternaut". This is used to render the name + shown on agent replies in the chat UI. + """ + + avatar_route: ClassVar[str] = ... + """ + The server route that should be used the avatar of this persona. This is + used to render the avatar shown on agent replies in the chat UI. + """ + + avatar_path: ClassVar[str] = ... + """ + The path to the avatar SVG file on the server filesystem. The server should + serve the file at this path on the route specified by `avatar_route`. + """ + +class JupyternautPersona(Persona): + name: ClassVar[str] = "Jupyternaut" + avatar_route: ClassVar[str] = JUPYTERNAUT_AVATAR_ROUTE + avatar_path: ClassVar[str] = JUPYTERNAUT_AVATAR_PATH + \ No newline at end of file diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py index 3fcdf9abc..55c7cc812 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py @@ -46,6 +46,8 @@ except: from pydantic.main import ModelMetaclass +from .models.persona import Persona + CHAT_SYSTEM_PROMPT = """ You are Jupyternaut, a conversational assistant living in JupyterLab to help users. @@ -214,6 +216,30 @@ class Config: """User inputs expected by this provider when initializing it. Each `Field` `f` should be passed in the constructor as a keyword argument, keyed by `f.key`.""" + manages_history: ClassVar[bool] = False + """Whether this provider manages its own conversation history upstream. If + set to `True`, Jupyter AI will not pass the chat history to this provider + when invoked.""" + + persona: ClassVar[Optional[Persona]] = None + """ + The **persona** of this provider, a struct that defines the name and avatar + shown on agent replies in the chat UI. When set to `None`, `jupyter-ai` will + choose a default persona when rendering agent messages by this provider. + + Because this field is set to `None` by default, `jupyter-ai` will render a + default persona for all providers that are included natively with the + `jupyter-ai` package. This field is reserved for Jupyter AI modules that + serve a custom provider and want to distinguish it in the chat UI. + """ + + unsupported_slash_commands: ClassVar[set] = {} + """ + A set of slash commands unsupported by this provider. Unsupported slash + commands are not shown in the help message, and cannot be used while this + provider is selected. + """ + # # instance attrs # diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/static/jupyternaut.svg b/packages/jupyter-ai-magics/jupyter_ai_magics/static/jupyternaut.svg new file mode 100644 index 000000000..dd800d538 --- /dev/null +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/static/jupyternaut.svg @@ -0,0 +1,9 @@ + + + + + + diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py index bcbba00ba..0cd1222a6 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py @@ -16,7 +16,7 @@ from dask.distributed import Client as DaskClient from jupyter_ai.config_manager import ConfigManager, Logger -from jupyter_ai.models import AgentChatMessage, ChatMessage, HumanChatMessage +from jupyter_ai.models import AgentChatMessage, ChatMessage, HumanChatMessage, PersonaDescription from jupyter_ai_magics.providers import BaseProvider from langchain.pydantic_v1 import BaseModel @@ -94,10 +94,17 @@ async def on_message(self, message: HumanChatMessage): `self.handle_exc()` when an exception is raised. This method is called by RootChatHandler when it routes a human message to this chat handler. """ + lm_provider_klass = self.config_manager.lm_provider + + # ensure the current slash command is supported + if self.routing_type.routing_method == "slash_command": + slash_command = "/" + self.routing_type.slash_id if self.routing_type.slash_id else "" + if slash_command in lm_provider_klass.unsupported_slash_commands: + self.reply("Sorry, the selected language model does not support this slash command.") + return # check whether the configured LLM can support a request at this time. if self.uses_llm and BaseChatHandler._requests_count > 0: - lm_provider_klass = self.config_manager.lm_provider lm_provider_params = self.config_manager.lm_provider_params lm_provider = lm_provider_klass(**lm_provider_params) @@ -159,11 +166,21 @@ async def _default_handle_exc(self, e: Exception, message: HumanChatMessage): self.reply(response, message) def reply(self, response: str, human_msg: Optional[HumanChatMessage] = None): + """ + Sends an agent message, usually in response to a received + `HumanChatMessage`. + """ + persona = self.config_manager.persona + agent_msg = AgentChatMessage( id=uuid4().hex, time=time.time(), body=response, reply_to=human_msg.id if human_msg else "", + persona=PersonaDescription( + name=persona.name, + avatar_route=persona.avatar_route + ) ) for handler in self._root_chat_handlers.values(): diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py index 8352a8f8d..59652af72 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py @@ -2,7 +2,7 @@ from jupyter_ai.models import ChatMessage, ClearMessage, HumanChatMessage from jupyter_ai_magics.providers import BaseProvider -from langchain.chains import ConversationChain +from langchain.chains import ConversationChain, LLMChain from langchain.memory import ConversationBufferWindowMemory from .base import BaseChatHandler, SlashCommandRoutingType @@ -30,14 +30,20 @@ def create_llm_chain( llm = provider(**unified_parameters) prompt_template = llm.get_chat_prompt_template() + self.llm = llm self.memory = ConversationBufferWindowMemory( return_messages=llm.is_chat_provider, k=2 ) - self.llm = llm - self.llm_chain = ConversationChain( - llm=llm, prompt=prompt_template, verbose=True, memory=self.memory - ) + if llm.manages_history: + self.llm_chain = LLMChain( + llm=llm, prompt=prompt_template, verbose=True + ) + + else: + self.llm_chain = ConversationChain( + llm=llm, prompt=prompt_template, verbose=True, memory=self.memory + ) def clear_memory(self): # clear chain memory diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/help.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/help.py index ebb8f0383..cd153003e 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/help.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/help.py @@ -2,11 +2,12 @@ from typing import Dict from uuid import uuid4 -from jupyter_ai.models import AgentChatMessage, HumanChatMessage +from jupyter_ai.models import AgentChatMessage, HumanChatMessage, PersonaDescription +from jupyter_ai_magics import Persona from .base import BaseChatHandler, SlashCommandRoutingType -HELP_MESSAGE = """Hi there! I'm Jupyternaut, your programming assistant. +HELP_MESSAGE = """Hi there! I'm {persona_name}, your programming assistant. You can ask me a question using the text box below. You can also use these commands: {commands} @@ -15,7 +16,11 @@ """ -def _format_help_message(chat_handlers: Dict[str, BaseChatHandler]): +def _format_help_message(chat_handlers: Dict[str, BaseChatHandler], persona: Persona, unsupported_slash_commands: set): + if unsupported_slash_commands: + keys = set(chat_handlers.keys()) - unsupported_slash_commands + chat_handlers = { key: chat_handlers[key] for key in keys } + commands = "\n".join( [ f"* `{command_name}` — {handler.help}" @@ -23,15 +28,19 @@ def _format_help_message(chat_handlers: Dict[str, BaseChatHandler]): if command_name != "default" ] ) - return HELP_MESSAGE.format(commands=commands) + return HELP_MESSAGE.format(commands=commands, persona_name=persona.name) -def HelpMessage(chat_handlers: Dict[str, BaseChatHandler]): +def build_help_message(chat_handlers: Dict[str, BaseChatHandler], persona: Persona, unsupported_slash_commands: set): return AgentChatMessage( id=uuid4().hex, time=time.time(), - body=_format_help_message(chat_handlers), + body=_format_help_message(chat_handlers, persona, unsupported_slash_commands), reply_to="", + persona=PersonaDescription( + name=persona.name, + avatar_route=persona.avatar_route + ) ) diff --git a/packages/jupyter-ai/jupyter_ai/config_manager.py b/packages/jupyter-ai/jupyter_ai/config_manager.py index 01d3fe766..ad6150130 100644 --- a/packages/jupyter-ai/jupyter_ai/config_manager.py +++ b/packages/jupyter-ai/jupyter_ai/config_manager.py @@ -15,6 +15,7 @@ get_em_provider, get_lm_provider, ) +from jupyter_ai_magics import Persona, JupyternautPersona from jupyter_core.paths import jupyter_data_dir from traitlets import Integer, Unicode from traitlets.config import Configurable @@ -452,3 +453,14 @@ def em_provider_params(self): "model_id": em_lid, **authn_fields, } + + @property + def persona(self) -> Persona: + """ + The current agent persona, set by the selected LM provider. If the + selected LM provider is `None`, this property returns + `JupyternautPersona` by default. + """ + lm_provider = self.lm_provider + persona = getattr(lm_provider, 'persona', None) or JupyternautPersona + return persona diff --git a/packages/jupyter-ai/jupyter_ai/extension.py b/packages/jupyter-ai/jupyter_ai/extension.py index 245a1c957..4c412309a 100644 --- a/packages/jupyter-ai/jupyter_ai/extension.py +++ b/packages/jupyter-ai/jupyter_ai/extension.py @@ -5,9 +5,11 @@ from dask.distributed import Client as DaskClient from importlib_metadata import entry_points from jupyter_ai.chat_handlers.learn import Retriever +from jupyter_ai_magics import JupyternautPersona from jupyter_ai_magics.utils import get_em_providers, get_lm_providers from jupyter_server.extension.application import ExtensionApp from traitlets import Dict, List, Unicode +from tornado.web import StaticFileHandler from .chat_handlers import ( AskChatHandler, @@ -18,7 +20,7 @@ HelpChatHandler, LearnChatHandler, ) -from .chat_handlers.help import HelpMessage +from .chat_handlers.help import build_help_message from .completions.handlers import DefaultInlineCompletionHandler from .config_manager import ConfigManager from .handlers import ( @@ -30,6 +32,9 @@ RootChatHandler, ) +JUPYTERNAUT_AVATAR_PATH = JupyternautPersona.avatar_path +JUPYTERNAUT_AVATAR_ROUTE = JupyternautPersona.avatar_route + class AiExtension(ExtensionApp): name = "jupyter_ai" @@ -41,6 +46,11 @@ class AiExtension(ExtensionApp): (r"api/ai/providers?", ModelProviderHandler), (r"api/ai/providers/embeddings?", EmbeddingsModelProviderHandler), (r"api/ai/completion/inline/?", DefaultInlineCompletionHandler), + + # serve the default persona avatar at this path. + # the `()` at the end of the URL denotes an empty regex capture group, + # required by Tornado. + (fr"{JUPYTERNAUT_AVATAR_ROUTE}()", StaticFileHandler, {"path": JUPYTERNAUT_AVATAR_PATH}), ] allowed_providers = List( @@ -303,13 +313,32 @@ def initialize_settings(self): # Make help always appear as the last command jai_chat_handlers["/help"] = help_chat_handler - self.settings["chat_history"].append( - HelpMessage(chat_handlers=jai_chat_handlers) - ) + # bind chat handlers to settings self.settings["jai_chat_handlers"] = jai_chat_handlers + # show help message at server start + self._show_help_message() + latency_ms = round((time.time() - start) * 1000) self.log.info(f"Initialized Jupyter AI server extension in {latency_ms} ms.") + + def _show_help_message(self): + """ + Method that ensures a dynamically-generated help message is included in + the chat history shown to users. + """ + chat_handlers = self.settings["jai_chat_handlers"] + config_manager: ConfigManager = self.settings["jai_config_manager"] + lm_provider = config_manager.lm_provider + + if not lm_provider: + return + + persona = config_manager.persona + unsupported_slash_commands = lm_provider.unsupported_slash_commands if lm_provider else set() + help_message = build_help_message(chat_handlers, persona, unsupported_slash_commands) + self.settings["chat_history"].append(help_message) + async def _get_dask_client(self): return DaskClient(processes=False, asynchronous=True) diff --git a/packages/jupyter-ai/jupyter_ai/models.py b/packages/jupyter-ai/jupyter_ai/models.py index 41509a74d..156c5e85c 100644 --- a/packages/jupyter-ai/jupyter_ai/models.py +++ b/packages/jupyter-ai/jupyter_ai/models.py @@ -29,13 +29,31 @@ class ChatClient(ChatUser): id: str +class PersonaDescription(BaseModel): + """ + Description of a persona to a chat client. + """ + name: str + avatar_route: str + + class AgentChatMessage(BaseModel): type: Literal["agent"] = "agent" id: str time: float body: str - # message ID of the HumanChatMessage it is replying to + reply_to: str + """ + Message ID of the HumanChatMessage being replied to. This is set to an empty + string if not applicable. + """ + + persona: PersonaDescription + """ + The persona of the selected provider. If the selected provider is `None`, + this defaults to a description of `JupyternautPersona`. + """ class HumanChatMessage(BaseModel): diff --git a/packages/jupyter-ai/src/components/chat-messages.tsx b/packages/jupyter-ai/src/components/chat-messages.tsx index dd889cf78..8a2e4b658 100644 --- a/packages/jupyter-ai/src/components/chat-messages.tsx +++ b/packages/jupyter-ai/src/components/chat-messages.tsx @@ -2,10 +2,11 @@ import React, { useState, useEffect } from 'react'; import { Avatar, Box, Typography } from '@mui/material'; import type { SxProps, Theme } from '@mui/material'; +import { IRenderMimeRegistry } from '@jupyterlab/rendermime'; +import { ServerConnection } from '@jupyterlab/services'; +// TODO: delete jupyternaut from frontend package import { AiService } from '../handler'; -import { IRenderMimeRegistry } from '@jupyterlab/rendermime'; -import { Jupyternaut } from '../icons'; import { RendermimeMarkdown } from './rendermime-markdown'; import { useCollaboratorsContext } from '../contexts/collaborators-context'; @@ -49,9 +50,11 @@ export function ChatMessageHeader(props: ChatMessageHeaderProps): JSX.Element { ); } else { + const baseUrl = ServerConnection.makeSettings().baseUrl; + const avatar_url = baseUrl + props.message.persona.avatar_route; avatar = ( - - + + ); } @@ -59,7 +62,7 @@ export function ChatMessageHeader(props: ChatMessageHeaderProps): JSX.Element { const name = props.message.type === 'human' ? props.message.client.display_name - : 'Jupyternaut'; + : props.message.persona.name; return ( Date: Wed, 3 Apr 2024 08:01:12 -0700 Subject: [PATCH 2/5] pre-commit --- .../jupyter_ai_magics/__init__.py | 7 +++--- .../jupyter_ai_magics/models/persona.py | 11 ++++++---- .../jupyter_ai_magics/providers.py | 1 - .../jupyter_ai/chat_handlers/base.py | 20 ++++++++++++----- .../jupyter_ai/chat_handlers/default.py | 4 +--- .../jupyter_ai/chat_handlers/help.py | 21 ++++++++++++------ .../jupyter-ai/jupyter_ai/config_manager.py | 6 ++--- packages/jupyter-ai/jupyter_ai/extension.py | 22 ++++++++++++------- packages/jupyter-ai/jupyter_ai/models.py | 1 + 9 files changed, 58 insertions(+), 35 deletions(-) diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py b/packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py index 616ce6093..7c609a606 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py @@ -12,6 +12,10 @@ from .exception import store_exception from .magics import AiMagics +# expose JupyternautPersona on the package root +# required by `jupyter-ai`. +from .models.persona import JupyternautPersona, Persona + # expose model providers on the package root from .providers import ( AI21Provider, @@ -26,9 +30,6 @@ TogetherAIProvider, ) -# expose JupyternautPersona on the package root -# required by `jupyter-ai`. -from .models.persona import JupyternautPersona, Persona def load_ipython_extension(ipython): ipython.register_magics(AiMagics) diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/models/persona.py b/packages/jupyter-ai-magics/jupyter_ai_magics/models/persona.py index 15468f6ff..f33d7d7cd 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/models/persona.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/models/persona.py @@ -1,11 +1,14 @@ - +import os from typing import ClassVar + from langchain.pydantic_v1 import BaseModel -import os -JUPYTERNAUT_AVATAR_PATH = str(os.path.join(os.path.dirname(__file__), '..', 'static', 'jupyternaut.svg')) +JUPYTERNAUT_AVATAR_PATH = str( + os.path.join(os.path.dirname(__file__), "..", "static", "jupyternaut.svg") +) JUPYTERNAUT_AVATAR_ROUTE = "api/ai/static/jupyternaut.svg" + class Persona(BaseModel): """ Model of an **agent persona**, a struct that includes the name & avatar @@ -32,8 +35,8 @@ class Persona(BaseModel): serve the file at this path on the route specified by `avatar_route`. """ + class JupyternautPersona(Persona): name: ClassVar[str] = "Jupyternaut" avatar_route: ClassVar[str] = JUPYTERNAUT_AVATAR_ROUTE avatar_path: ClassVar[str] = JUPYTERNAUT_AVATAR_PATH - \ No newline at end of file diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py index 55c7cc812..c07061b1d 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py @@ -48,7 +48,6 @@ from .models.persona import Persona - CHAT_SYSTEM_PROMPT = """ You are Jupyternaut, a conversational assistant living in JupyterLab to help users. You are not a language model, but rather an application built on a foundation model from {provider_name} called {local_model_id}. diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py index 0cd1222a6..418791b65 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py @@ -16,7 +16,12 @@ from dask.distributed import Client as DaskClient from jupyter_ai.config_manager import ConfigManager, Logger -from jupyter_ai.models import AgentChatMessage, ChatMessage, HumanChatMessage, PersonaDescription +from jupyter_ai.models import ( + AgentChatMessage, + ChatMessage, + HumanChatMessage, + PersonaDescription, +) from jupyter_ai_magics.providers import BaseProvider from langchain.pydantic_v1 import BaseModel @@ -98,9 +103,13 @@ async def on_message(self, message: HumanChatMessage): # ensure the current slash command is supported if self.routing_type.routing_method == "slash_command": - slash_command = "/" + self.routing_type.slash_id if self.routing_type.slash_id else "" + slash_command = ( + "/" + self.routing_type.slash_id if self.routing_type.slash_id else "" + ) if slash_command in lm_provider_klass.unsupported_slash_commands: - self.reply("Sorry, the selected language model does not support this slash command.") + self.reply( + "Sorry, the selected language model does not support this slash command." + ) return # check whether the configured LLM can support a request at this time. @@ -178,9 +187,8 @@ def reply(self, response: str, human_msg: Optional[HumanChatMessage] = None): body=response, reply_to=human_msg.id if human_msg else "", persona=PersonaDescription( - name=persona.name, - avatar_route=persona.avatar_route - ) + name=persona.name, avatar_route=persona.avatar_route + ), ) for handler in self._root_chat_handlers.values(): diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py index 59652af72..df288d409 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py @@ -36,9 +36,7 @@ def create_llm_chain( ) if llm.manages_history: - self.llm_chain = LLMChain( - llm=llm, prompt=prompt_template, verbose=True - ) + self.llm_chain = LLMChain(llm=llm, prompt=prompt_template, verbose=True) else: self.llm_chain = ConversationChain( diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/help.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/help.py index cd153003e..b9a5e4460 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/help.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/help.py @@ -16,11 +16,15 @@ """ -def _format_help_message(chat_handlers: Dict[str, BaseChatHandler], persona: Persona, unsupported_slash_commands: set): +def _format_help_message( + chat_handlers: Dict[str, BaseChatHandler], + persona: Persona, + unsupported_slash_commands: set, +): if unsupported_slash_commands: keys = set(chat_handlers.keys()) - unsupported_slash_commands - chat_handlers = { key: chat_handlers[key] for key in keys } - + chat_handlers = {key: chat_handlers[key] for key in keys} + commands = "\n".join( [ f"* `{command_name}` — {handler.help}" @@ -31,16 +35,19 @@ def _format_help_message(chat_handlers: Dict[str, BaseChatHandler], persona: Per return HELP_MESSAGE.format(commands=commands, persona_name=persona.name) -def build_help_message(chat_handlers: Dict[str, BaseChatHandler], persona: Persona, unsupported_slash_commands: set): +def build_help_message( + chat_handlers: Dict[str, BaseChatHandler], + persona: Persona, + unsupported_slash_commands: set, +): return AgentChatMessage( id=uuid4().hex, time=time.time(), body=_format_help_message(chat_handlers, persona, unsupported_slash_commands), reply_to="", persona=PersonaDescription( - name=persona.name, - avatar_route=persona.avatar_route - ) + name=persona.name, avatar_route=persona.avatar_route + ), ) diff --git a/packages/jupyter-ai/jupyter_ai/config_manager.py b/packages/jupyter-ai/jupyter_ai/config_manager.py index ad6150130..392e44601 100644 --- a/packages/jupyter-ai/jupyter_ai/config_manager.py +++ b/packages/jupyter-ai/jupyter_ai/config_manager.py @@ -8,6 +8,7 @@ from deepmerge import always_merger as Merger from jsonschema import Draft202012Validator as Validator from jupyter_ai.models import DescribeConfigResponse, GlobalConfig, UpdateConfigRequest +from jupyter_ai_magics import JupyternautPersona, Persona from jupyter_ai_magics.utils import ( AnyProvider, EmProvidersDict, @@ -15,7 +16,6 @@ get_em_provider, get_lm_provider, ) -from jupyter_ai_magics import Persona, JupyternautPersona from jupyter_core.paths import jupyter_data_dir from traitlets import Integer, Unicode from traitlets.config import Configurable @@ -453,7 +453,7 @@ def em_provider_params(self): "model_id": em_lid, **authn_fields, } - + @property def persona(self) -> Persona: """ @@ -462,5 +462,5 @@ def persona(self) -> Persona: `JupyternautPersona` by default. """ lm_provider = self.lm_provider - persona = getattr(lm_provider, 'persona', None) or JupyternautPersona + persona = getattr(lm_provider, "persona", None) or JupyternautPersona return persona diff --git a/packages/jupyter-ai/jupyter_ai/extension.py b/packages/jupyter-ai/jupyter_ai/extension.py index 4c412309a..12199b676 100644 --- a/packages/jupyter-ai/jupyter_ai/extension.py +++ b/packages/jupyter-ai/jupyter_ai/extension.py @@ -8,8 +8,8 @@ from jupyter_ai_magics import JupyternautPersona from jupyter_ai_magics.utils import get_em_providers, get_lm_providers from jupyter_server.extension.application import ExtensionApp -from traitlets import Dict, List, Unicode from tornado.web import StaticFileHandler +from traitlets import Dict, List, Unicode from .chat_handlers import ( AskChatHandler, @@ -46,11 +46,14 @@ class AiExtension(ExtensionApp): (r"api/ai/providers?", ModelProviderHandler), (r"api/ai/providers/embeddings?", EmbeddingsModelProviderHandler), (r"api/ai/completion/inline/?", DefaultInlineCompletionHandler), - # serve the default persona avatar at this path. # the `()` at the end of the URL denotes an empty regex capture group, # required by Tornado. - (fr"{JUPYTERNAUT_AVATAR_ROUTE}()", StaticFileHandler, {"path": JUPYTERNAUT_AVATAR_PATH}), + ( + rf"{JUPYTERNAUT_AVATAR_ROUTE}()", + StaticFileHandler, + {"path": JUPYTERNAUT_AVATAR_PATH}, + ), ] allowed_providers = List( @@ -321,25 +324,28 @@ def initialize_settings(self): latency_ms = round((time.time() - start) * 1000) self.log.info(f"Initialized Jupyter AI server extension in {latency_ms} ms.") - + def _show_help_message(self): """ Method that ensures a dynamically-generated help message is included in the chat history shown to users. """ chat_handlers = self.settings["jai_chat_handlers"] - config_manager: ConfigManager = self.settings["jai_config_manager"] + config_manager: ConfigManager = self.settings["jai_config_manager"] lm_provider = config_manager.lm_provider if not lm_provider: return persona = config_manager.persona - unsupported_slash_commands = lm_provider.unsupported_slash_commands if lm_provider else set() - help_message = build_help_message(chat_handlers, persona, unsupported_slash_commands) + unsupported_slash_commands = ( + lm_provider.unsupported_slash_commands if lm_provider else set() + ) + help_message = build_help_message( + chat_handlers, persona, unsupported_slash_commands + ) self.settings["chat_history"].append(help_message) - async def _get_dask_client(self): return DaskClient(processes=False, asynchronous=True) diff --git a/packages/jupyter-ai/jupyter_ai/models.py b/packages/jupyter-ai/jupyter_ai/models.py index 156c5e85c..11999a09c 100644 --- a/packages/jupyter-ai/jupyter_ai/models.py +++ b/packages/jupyter-ai/jupyter_ai/models.py @@ -33,6 +33,7 @@ class PersonaDescription(BaseModel): """ Description of a persona to a chat client. """ + name: str avatar_route: str From 0ed6644feb0453a1da976fe0d1ec5c23c87fc3dd Mon Sep 17 00:00:00 2001 From: "David L. Qiu" Date: Wed, 3 Apr 2024 08:14:58 -0700 Subject: [PATCH 3/5] add comment about jupyternaut icon in frontend --- packages/jupyter-ai/src/icons.ts | 3 +++ 1 file changed, 3 insertions(+) diff --git a/packages/jupyter-ai/src/icons.ts b/packages/jupyter-ai/src/icons.ts index a510c17a7..30d40a165 100644 --- a/packages/jupyter-ai/src/icons.ts +++ b/packages/jupyter-ai/src/icons.ts @@ -15,4 +15,7 @@ export const jupyternautIcon = new LabIcon({ svgstr: jupyternautSvg }); +// this icon is only used in the status bar. +// to configure the icon shown on agent replies in the chat UI, please specify a +// custom `Persona`. export const Jupyternaut = jupyternautIcon.react; From 5ecce2e337a8e85241571b9944189facc2b33b50 Mon Sep 17 00:00:00 2001 From: "David L. Qiu" Date: Thu, 4 Apr 2024 10:31:53 -0700 Subject: [PATCH 4/5] remove 'avatar_path' from 'Persona', drop 'PersonaDescription' --- .../jupyter_ai_magics/models/persona.py | 25 +++---------------- .../jupyter_ai/chat_handlers/base.py | 4 +-- .../jupyter_ai/chat_handlers/help.py | 4 +-- packages/jupyter-ai/jupyter_ai/extension.py | 8 +++--- packages/jupyter-ai/jupyter_ai/models.py | 12 ++------- .../jupyter_ai}/static/jupyternaut.svg | 0 packages/jupyter-ai/src/handler.ts | 4 +-- 7 files changed, 17 insertions(+), 40 deletions(-) rename packages/{jupyter-ai-magics/jupyter_ai_magics => jupyter-ai/jupyter_ai}/static/jupyternaut.svg (100%) diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/models/persona.py b/packages/jupyter-ai-magics/jupyter_ai_magics/models/persona.py index f33d7d7cd..54d45b7ce 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/models/persona.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/models/persona.py @@ -1,14 +1,5 @@ -import os -from typing import ClassVar - from langchain.pydantic_v1 import BaseModel -JUPYTERNAUT_AVATAR_PATH = str( - os.path.join(os.path.dirname(__file__), "..", "static", "jupyternaut.svg") -) -JUPYTERNAUT_AVATAR_ROUTE = "api/ai/static/jupyternaut.svg" - - class Persona(BaseModel): """ Model of an **agent persona**, a struct that includes the name & avatar @@ -17,26 +8,18 @@ class Persona(BaseModel): Each persona is specific to a single provider, set on the `persona` field. """ - name: ClassVar[str] = ... + name: str = ... """ Name of the persona, e.g. "Jupyternaut". This is used to render the name shown on agent replies in the chat UI. """ - avatar_route: ClassVar[str] = ... + avatar_route: str = ... """ The server route that should be used the avatar of this persona. This is used to render the avatar shown on agent replies in the chat UI. """ - avatar_path: ClassVar[str] = ... - """ - The path to the avatar SVG file on the server filesystem. The server should - serve the file at this path on the route specified by `avatar_route`. - """ - -class JupyternautPersona(Persona): - name: ClassVar[str] = "Jupyternaut" - avatar_route: ClassVar[str] = JUPYTERNAUT_AVATAR_ROUTE - avatar_path: ClassVar[str] = JUPYTERNAUT_AVATAR_PATH +JUPYTERNAUT_AVATAR_ROUTE = "api/ai/static/jupyternaut.svg" +JupyternautPersona = Persona(name="Jupyternaut", avatar_route=JUPYTERNAUT_AVATAR_ROUTE) diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py index 418791b65..973972c67 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py @@ -20,8 +20,8 @@ AgentChatMessage, ChatMessage, HumanChatMessage, - PersonaDescription, ) +from jupyter_ai_magics import Persona from jupyter_ai_magics.providers import BaseProvider from langchain.pydantic_v1 import BaseModel @@ -186,7 +186,7 @@ def reply(self, response: str, human_msg: Optional[HumanChatMessage] = None): time=time.time(), body=response, reply_to=human_msg.id if human_msg else "", - persona=PersonaDescription( + persona=Persona( name=persona.name, avatar_route=persona.avatar_route ), ) diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/help.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/help.py index b9a5e4460..d79c32fd7 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/help.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/help.py @@ -2,7 +2,7 @@ from typing import Dict from uuid import uuid4 -from jupyter_ai.models import AgentChatMessage, HumanChatMessage, PersonaDescription +from jupyter_ai.models import AgentChatMessage, HumanChatMessage from jupyter_ai_magics import Persona from .base import BaseChatHandler, SlashCommandRoutingType @@ -45,7 +45,7 @@ def build_help_message( time=time.time(), body=_format_help_message(chat_handlers, persona, unsupported_slash_commands), reply_to="", - persona=PersonaDescription( + persona=Persona( name=persona.name, avatar_route=persona.avatar_route ), ) diff --git a/packages/jupyter-ai/jupyter_ai/extension.py b/packages/jupyter-ai/jupyter_ai/extension.py index 12199b676..bbf29263a 100644 --- a/packages/jupyter-ai/jupyter_ai/extension.py +++ b/packages/jupyter-ai/jupyter_ai/extension.py @@ -1,6 +1,6 @@ -import logging import re import time +import os from dask.distributed import Client as DaskClient from importlib_metadata import entry_points @@ -32,9 +32,11 @@ RootChatHandler, ) -JUPYTERNAUT_AVATAR_PATH = JupyternautPersona.avatar_path -JUPYTERNAUT_AVATAR_ROUTE = JupyternautPersona.avatar_route +JUPYTERNAUT_AVATAR_ROUTE = JupyternautPersona.avatar_route +JUPYTERNAUT_AVATAR_PATH = str( + os.path.join(os.path.dirname(__file__), "static", "jupyternaut.svg") +) class AiExtension(ExtensionApp): name = "jupyter_ai" diff --git a/packages/jupyter-ai/jupyter_ai/models.py b/packages/jupyter-ai/jupyter_ai/models.py index 11999a09c..ff711a9cc 100644 --- a/packages/jupyter-ai/jupyter_ai/models.py +++ b/packages/jupyter-ai/jupyter_ai/models.py @@ -1,6 +1,7 @@ from typing import Any, Dict, List, Literal, Optional, Union from jupyter_ai_magics.providers import AuthStrategy, Field +from jupyter_ai_magics import Persona from langchain.pydantic_v1 import BaseModel, validator DEFAULT_CHUNK_SIZE = 2000 @@ -29,15 +30,6 @@ class ChatClient(ChatUser): id: str -class PersonaDescription(BaseModel): - """ - Description of a persona to a chat client. - """ - - name: str - avatar_route: str - - class AgentChatMessage(BaseModel): type: Literal["agent"] = "agent" id: str @@ -50,7 +42,7 @@ class AgentChatMessage(BaseModel): string if not applicable. """ - persona: PersonaDescription + persona: Persona """ The persona of the selected provider. If the selected provider is `None`, this defaults to a description of `JupyternautPersona`. diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/static/jupyternaut.svg b/packages/jupyter-ai/jupyter_ai/static/jupyternaut.svg similarity index 100% rename from packages/jupyter-ai-magics/jupyter_ai_magics/static/jupyternaut.svg rename to packages/jupyter-ai/jupyter_ai/static/jupyternaut.svg diff --git a/packages/jupyter-ai/src/handler.ts b/packages/jupyter-ai/src/handler.ts index a554c07ad..7848dc20e 100644 --- a/packages/jupyter-ai/src/handler.ts +++ b/packages/jupyter-ai/src/handler.ts @@ -68,7 +68,7 @@ export namespace AiService { id: string; }; - export type PersonaDescription = { + export type Persona = { name: string; avatar_route: string; }; @@ -79,7 +79,7 @@ export namespace AiService { time: number; body: string; reply_to: string; - persona: PersonaDescription; + persona: Persona; }; export type HumanChatMessage = { From 568adcf647a6543db0f460176798607f4cc6eb63 Mon Sep 17 00:00:00 2001 From: "David L. Qiu" Date: Thu, 4 Apr 2024 10:32:32 -0700 Subject: [PATCH 5/5] pre-commit --- .../jupyter_ai_magics/models/persona.py | 1 + packages/jupyter-ai/jupyter_ai/chat_handlers/base.py | 10 ++-------- packages/jupyter-ai/jupyter_ai/chat_handlers/help.py | 4 +--- packages/jupyter-ai/jupyter_ai/extension.py | 4 ++-- packages/jupyter-ai/jupyter_ai/models.py | 2 +- 5 files changed, 7 insertions(+), 14 deletions(-) diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/models/persona.py b/packages/jupyter-ai-magics/jupyter_ai_magics/models/persona.py index 54d45b7ce..fe25397b0 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/models/persona.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/models/persona.py @@ -1,5 +1,6 @@ from langchain.pydantic_v1 import BaseModel + class Persona(BaseModel): """ Model of an **agent persona**, a struct that includes the name & avatar diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py index 973972c67..1ae80c5c5 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py @@ -16,11 +16,7 @@ from dask.distributed import Client as DaskClient from jupyter_ai.config_manager import ConfigManager, Logger -from jupyter_ai.models import ( - AgentChatMessage, - ChatMessage, - HumanChatMessage, -) +from jupyter_ai.models import AgentChatMessage, ChatMessage, HumanChatMessage from jupyter_ai_magics import Persona from jupyter_ai_magics.providers import BaseProvider from langchain.pydantic_v1 import BaseModel @@ -186,9 +182,7 @@ def reply(self, response: str, human_msg: Optional[HumanChatMessage] = None): time=time.time(), body=response, reply_to=human_msg.id if human_msg else "", - persona=Persona( - name=persona.name, avatar_route=persona.avatar_route - ), + persona=Persona(name=persona.name, avatar_route=persona.avatar_route), ) for handler in self._root_chat_handlers.values(): diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/help.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/help.py index d79c32fd7..e46038da5 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/help.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/help.py @@ -45,9 +45,7 @@ def build_help_message( time=time.time(), body=_format_help_message(chat_handlers, persona, unsupported_slash_commands), reply_to="", - persona=Persona( - name=persona.name, avatar_route=persona.avatar_route - ), + persona=Persona(name=persona.name, avatar_route=persona.avatar_route), ) diff --git a/packages/jupyter-ai/jupyter_ai/extension.py b/packages/jupyter-ai/jupyter_ai/extension.py index bbf29263a..0a66a8b1b 100644 --- a/packages/jupyter-ai/jupyter_ai/extension.py +++ b/packages/jupyter-ai/jupyter_ai/extension.py @@ -1,6 +1,6 @@ +import os import re import time -import os from dask.distributed import Client as DaskClient from importlib_metadata import entry_points @@ -32,12 +32,12 @@ RootChatHandler, ) - JUPYTERNAUT_AVATAR_ROUTE = JupyternautPersona.avatar_route JUPYTERNAUT_AVATAR_PATH = str( os.path.join(os.path.dirname(__file__), "static", "jupyternaut.svg") ) + class AiExtension(ExtensionApp): name = "jupyter_ai" handlers = [ diff --git a/packages/jupyter-ai/jupyter_ai/models.py b/packages/jupyter-ai/jupyter_ai/models.py index ff711a9cc..32353a694 100644 --- a/packages/jupyter-ai/jupyter_ai/models.py +++ b/packages/jupyter-ai/jupyter_ai/models.py @@ -1,7 +1,7 @@ from typing import Any, Dict, List, Literal, Optional, Union -from jupyter_ai_magics.providers import AuthStrategy, Field from jupyter_ai_magics import Persona +from jupyter_ai_magics.providers import AuthStrategy, Field from langchain.pydantic_v1 import BaseModel, validator DEFAULT_CHUNK_SIZE = 2000