Skip to content

Commit

Permalink
implement better chat history UI (jupyterlab#65)
Browse files Browse the repository at this point in the history
* implement better chat history UI

* remove console log

* add ScrollContainer component
  • Loading branch information
dlqqq authored and Marchlak committed Oct 28, 2024
1 parent ee90e89 commit e65e64f
Show file tree
Hide file tree
Showing 11 changed files with 956 additions and 209 deletions.
6 changes: 4 additions & 2 deletions packages/jupyter-ai/jupyter_ai/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Dict, List
import tornado
import uuid
import time

from tornado.web import HTTPError
from pydantic import ValidationError
Expand Down Expand Up @@ -191,10 +192,9 @@ def open(self):
from `self.client_id`."""

client_id = self.generate_client_id()
chat_client_kwargs = {k: v for k, v in asdict(self.current_user).items() if k != "username"}

self.chat_handlers[client_id] = self
self.chat_clients[client_id] = ChatClient(**chat_client_kwargs, id=client_id)
self.chat_clients[client_id] = ChatClient(**asdict(self.current_user), id=client_id)
self.client_id = client_id
self.write_message(ConnectionMessage(client_id=client_id).dict())

Expand Down Expand Up @@ -235,6 +235,7 @@ async def on_message(self, message):
chat_message_id = str(uuid.uuid4())
chat_message = HumanChatMessage(
id=chat_message_id,
time=time.time(),
body=chat_request.prompt,
client=self.chat_client,
)
Expand All @@ -246,6 +247,7 @@ async def on_message(self, message):
response = await ensure_async(self.chat_provider.apredict(input=message.content))
agent_message = AgentChatMessage(
id=str(uuid.uuid4()),
time=time.time(),
body=response,
reply_to=chat_message_id
)
Expand Down
7 changes: 7 additions & 0 deletions packages/jupyter-ai/jupyter_ai/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,12 @@ class ChatRequest(BaseModel):
prompt: str

class ChatClient(BaseModel):
# Client ID assigned by us. Necessary because different JupyterLab clients
# on the same device (i.e. running on multiple tabs/windows) may have the
# same user ID assigned to them by IdentityProvider.
id: str
# User ID assigned by IdentityProvider.
username: str
initials: str
name: str
display_name: str
Expand All @@ -21,13 +26,15 @@ class ChatClient(BaseModel):
class AgentChatMessage(BaseModel):
type: Literal["agent"] = "agent"
id: str
time: float
body: str
# message ID of the HumanChatMessage it is replying to
reply_to: str

class HumanChatMessage(BaseModel):
type: Literal["human"] = "human"
id: str
time: float
body: str
client: ChatClient

Expand Down
19 changes: 11 additions & 8 deletions packages/jupyter-ai/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -60,15 +60,17 @@
"dependencies": {
"@emotion/react": "^11.10.5",
"@emotion/styled": "^11.10.5",
"@jupyterlab/application": "^3.1.0",
"@jupyterlab/cells": "^3.4.2",
"@jupyterlab/coreutils": "^5.1.0",
"@jupyterlab/fileeditor": "^3.5.1",
"@jupyterlab/notebook": "^3.4.2",
"@jupyterlab/services": "^6.1.0",
"@jupyterlab/ui-components": "^3.4.2",
"@jupyterlab/application": "^3.6.3",
"@jupyterlab/cells": "^3.6.3",
"@jupyterlab/coreutils": "^5.6.3",
"@jupyterlab/fileeditor": "^3.6.3",
"@jupyterlab/notebook": "^3.6.3",
"@jupyterlab/services": "^6.6.3",
"@jupyterlab/ui-components": "^3.6.3",
"@jupyterlab/collaboration": "^3.6.3",
"@mui/icons-material": "^5.11.0",
"@mui/material": "^5.11.0",
"date-fns": "^2.29.3",
"react": "^18.2.0",
"react-dom": "^18.2.0",
"react-markdown": "^8.0.6",
Expand Down Expand Up @@ -99,7 +101,8 @@
"stylelint-config-standard": "~24.0.0",
"stylelint-prettier": "^2.0.0",
"ts-jest": "^26.0.0",
"typescript": "~4.1.3"
"typescript": "~4.1.3",
"y-protocols": "^1.0.5"
},
"sideEffects": [
"style/*.css",
Expand Down
180 changes: 121 additions & 59 deletions packages/jupyter-ai/src/components/chat-messages.tsx
Original file line number Diff line number Diff line change
@@ -1,84 +1,146 @@
import React from 'react';
import React, { useState, useEffect } from 'react';

import { Avatar, Box, useTheme } from '@mui/material';
import { Avatar, Box, Typography } from '@mui/material';
import type { SxProps, Theme } from '@mui/material';
import PsychologyIcon from '@mui/icons-material/Psychology';
import { formatDistanceToNowStrict, fromUnixTime } from 'date-fns';
import ReactMarkdown from 'react-markdown';
import remarkMath from 'remark-math';
import rehypeKatex from 'rehype-katex';
import 'katex/dist/katex.min.css';

import { ChatCodeView } from './chat-code-view';
import { AiService } from '../handler';
import { useCollaboratorsContext } from '../contexts/collaborators-context';

type ChatMessagesProps = {
sender: 'self' | 'ai' | string;
messages: string[];
messages: AiService.ChatMessage[];
};

function getAvatar(sender: 'self' | 'ai' | string) {
type ChatMessageHeaderProps = {
message: AiService.ChatMessage;
timestamp: string;
sx?: SxProps<Theme>;
};

export function ChatMessageHeader(props: ChatMessageHeaderProps) {
const collaborators = useCollaboratorsContext();

const sharedStyles: SxProps<Theme> = {
height: '2em',
width: '2em'
height: '24px',
width: '24px'
};

switch (sender) {
case 'self':
return <Avatar src="" sx={{ ...sharedStyles }} />;
case 'ai':
return (
<Avatar sx={{ ...sharedStyles }}>
<PsychologyIcon />
</Avatar>
);
default:
return <Avatar src="?" sx={{ ...sharedStyles }} />;
let avatar: JSX.Element;
if (props.message.type === 'human') {
const bgcolor = collaborators?.[props.message.client.username]?.color;
avatar = (
<Avatar
sx={{
...sharedStyles,
...(bgcolor && { bgcolor })
}}
>
<Typography
sx={{
fontSize: 'var(--jp-ui-font-size1)',
color: 'var(--jp-ui-font-color1)'
}}
>
{props.message.client.initials}
</Typography>
</Avatar>
);
} else {
avatar = (
<Avatar sx={{ ...sharedStyles, bgcolor: 'var(--jp-jupyter-icon-color)' }}>
<PsychologyIcon />
</Avatar>
);
}

const name =
props.message.type === 'human'
? props.message.client.display_name
: 'Jupyter AI';

return (
<Box
sx={{
display: 'flex',
alignItems: 'center',
'& > :not(:last-child)': {
marginRight: 3
},
...props.sx
}}
>
{avatar}
<Box
sx={{
display: 'flex',
flexGrow: 1,
flexWrap: 'wrap',
justifyContent: 'space-between',
alignItems: 'center'
}}
>
<Typography sx={{ fontWeight: 700 }}>{name}</Typography>
<Typography sx={{ fontSize: '0.8em', fontWeight: 300 }}>
{props.timestamp}
</Typography>
</Box>
</Box>
);
}

export function ChatMessages(props: ChatMessagesProps) {
const theme = useTheme();
const radius = theme.spacing(2);
const [timestamps, setTimestamps] = useState<Record<string, string>>({});

/**
* Effect: update cached timestamp strings upon receiving a new message and
* every 5 seconds after that.
*/
useEffect(() => {
function updateTimestamps() {
const newTimestamps: Record<string, string> = {};
for (const message of props.messages) {
newTimestamps[message.id] =
formatDistanceToNowStrict(fromUnixTime(message.time)) + ' ago';
}
setTimestamps(newTimestamps);
}

updateTimestamps();
const intervalId = setInterval(updateTimestamps, 5000);
return () => {
clearInterval(intervalId);
};
}, [props.messages]);

return (
<Box sx={{ display: 'flex', ' > :not(:last-child)': { marginRight: 2 } }}>
{getAvatar(props.sender)}
<Box sx={{ flexGrow: 1, minWidth: 0 }}>
{props.messages.map((message, i) => (
// extra div needed to ensure each bubble is on a new line
<Box key={i}>
<Box
sx={{
display: 'inline-block',
padding: theme.spacing(1, 2),
borderRadius: radius,
marginBottom: 1,
wordBreak: 'break-word',
textAlign: 'left',
maxWidth: '100%',
boxSizing: 'border-box',
...(props.sender === 'self'
? {
backgroundColor: theme.palette.primary.main,
color: theme.palette.common.white
}
: {
backgroundColor: theme.palette.grey[100]
})
}}
>
<ReactMarkdown
components={{
code: ChatCodeView
}}
remarkPlugins={[remarkMath]}
rehypePlugins={[rehypeKatex]}
>
{message}
</ReactMarkdown>
</Box>
</Box>
))}
</Box>
<Box
sx={{ '& > :not(:last-child)': { borderBottom: '1px solid lightgrey' } }}
>
{props.messages.map((message, i) => (
// extra div needed to ensure each bubble is on a new line
<Box key={i} sx={{ padding: 2 }}>
<ChatMessageHeader
message={message}
timestamp={timestamps[message.id]}
sx={{ marginBottom: '12px' }}
/>
<ReactMarkdown
components={{
code: ChatCodeView
}}
remarkPlugins={[remarkMath]}
rehypePlugins={[rehypeKatex]}
>
{message.body}
</ReactMarkdown>
</Box>
))}
</Box>
);
}
Loading

0 comments on commit e65e64f

Please sign in to comment.