Skip to content

Commit

Permalink
feat: add diff format reflected message support (#16)
Browse files Browse the repository at this point in the history
* feat: add confirm test

* fix: when error occurs, it can't chat again

* feat: reflected message support

* fix: don't add space if there is one when insert mention.
  • Loading branch information
lee88688 authored Nov 21, 2024
1 parent 174fbff commit 746f46f
Show file tree
Hide file tree
Showing 5 changed files with 201 additions and 47 deletions.
131 changes: 93 additions & 38 deletions server/main.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from typing import Dict, List, Optional, Literal
from typing import Dict, Iterator, List, Optional, Literal, Any
from flask import Flask, jsonify, request, Response
from aider.models import Model
from aider.coders import Coder
from aider.io import InputOutput
from dataclasses import dataclass, asdict
import os
import json
from threading import Event

@dataclass
class ChatSetting:
Expand Down Expand Up @@ -109,11 +110,25 @@ class ChatSessionData:

ChatModeType = Literal['ask', 'code']

@dataclass
class ChatChunkData:
# event: data, usage, write, end, error, reflected, log
# data: yield chunk message
# usage: yield usage report
# write: yield write files
# end: end of chat
# error: yield error message
# reflected: yield reflected message
# log: yield log message
event: str
data: Optional[dict] = None

class ChatSessionManager:
chat_type: ChatModeType
diff_format: str
reference_list: List[ChatSessionReference]
setting: Optional[ChatSetting] = None
confirm_ask_result: Optional[Any] = None

def __init__(self):
model = Model('gpt-4o')
Expand Down Expand Up @@ -141,6 +156,8 @@ def __init__(self):
self.diff_format = 'diff'
self.reference_list = []

self.confirm_ask_event = Event()

def update_model(self, setting: ChatSetting):
if self.setting != setting:
self.setting = setting
Expand All @@ -165,7 +182,7 @@ def update_coder(self):
read_only_fnames=(item.fs_path for item in self.reference_list if item.readonly),
)

def chat(self, data: ChatSessionData):
def chat(self, data: ChatSessionData) -> Iterator[ChatChunkData]:
need_update_coder = False
data.reference_list.sort(key=lambda x: x.fs_path)

Expand All @@ -180,8 +197,63 @@ def chat(self, data: ChatSessionData):
if need_update_coder:
self.update_coder()

yield from self.coder.run_stream(data.message)
try:
self.coder.init_before_message()
message = data.message
while message:
self.coder.reflected_message = None
for msg in self.coder.run_stream(message):
data = {
"chunk": msg,
}
yield ChatChunkData(event='data', data=data)

if manager.coder.usage_report:
yield ChatChunkData(event='usage', data=manager.coder.usage_report)

if not self.coder.reflected_message:
break

if self.coder.num_reflections >= self.coder.max_reflections:
self.coder.io.tool_warning(f"Only {self.coder.max_reflections} reflections allowed, stopping.")
return

self.coder.num_reflections += 1
message = self.coder.reflected_message

yield ChatChunkData(event='reflected', data={"message": message})

error_lines = self.coder.io.get_captured_error_lines()
if error_lines:
if not message:
raise Exception('\n'.join(error_lines))
else:
yield ChatChunkData(event='log', data={"message": '\n'.join(error_lines)})

# get write files
write_files = manager.io.get_captured_write_files()
if write_files:
data = {
"write": write_files,
}
yield ChatChunkData(event='write', data=data)

except Exception as e:
# send error to client
error_data = {
"error": str(e)
}
yield ChatChunkData(event='error', data=error_data)
finally:
# send end event to client
yield ChatChunkData(event='end')

def confirm_ask(self):
self.confirm_ask_event.clear()
self.confirm_ask_event.wait()

def confirm_ask_reply(self):
self.confirm_ask_event.set()

class CORS:
def __init__(self, app):
Expand Down Expand Up @@ -214,41 +286,12 @@ def sse():
chat_session_data = ChatSessionData(**data)

def generate():
try:
for msg in manager.chat(chat_session_data):
data = {
"chunk": msg,
}
yield f"event: data\n"
yield f"data: {json.dumps(data)}\n\n"

error_lines = manager.io.get_captured_error_lines()
if error_lines:
raise Exception('\n'.join(error_lines))

if manager.coder.usage_report:
yield f"event: usage\n"
yield f"data: {json.dumps({'usage': manager.coder.usage_report })}\n\n"

# get write files
write_files = manager.io.get_captured_write_files()
if write_files:
data = {
"write": write_files,
}
yield f"event: write\n"
yield f"data: {json.dumps(data)}\n\n"

except Exception as e:
# send error to client
error_data = {
"error": str(e)
}
yield f"event: error\n"
yield f"data: {json.dumps(error_data)}\n\n"
finally:
# send end event to client
yield f"event: end\n\n"
for msg in manager.chat(chat_session_data):
if msg.data:
yield f"event: {msg.event}\n"
yield f"data: {json.dumps(msg.data)}\n\n"
else:
yield f"event: {msg.event}\n\n"

response = Response(generate(), mimetype='text/event-stream')
return response
Expand All @@ -274,5 +317,17 @@ def update_setting():
manager.update_model(setting)
return jsonify({})

@app.route('/api/chat/confirm/ask', methods=['POST'])
def confirm_ask():
manager.confirm_ask()
return jsonify(manager.confirm_ask_result)

@app.route('/api/chat/confirm/reply', methods=['POST'])
def confirm_reply():
data = request.json
manager.confirm_ask_result = data
manager.confirm_ask_reply()
return jsonify({})

if __name__ == '__main__':
app.run()
50 changes: 49 additions & 1 deletion ui/src/stores/useChatStore.ts
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,10 @@ export const useChatStore = create(
}
});

eventSource.addEventListener('end', () => {
const end = () => {
if (!get().current) {
return;
}
set((state) => {
const history = state.current
? [...state.history, state.current]
Expand All @@ -305,6 +308,50 @@ export const useChatStore = create(
current: undefined,
};
});
};

eventSource.addEventListener('reflected', (event: { data: string }) => {
const reflectedMessage = JSON.parse(event.data) as {
message: string;
};
// reflected message is a user message
set((state) => {
const history = state.current
? [...state.history, state.current]
: [...state.history];
history.push({
type: 'user',
text: reflectedMessage.message,
displayText: reflectedMessage.message,
id: nanoid(),
reflected: true,
referenceList: state.chatReferenceList,
});

const id = state.id;
useChatSessionStore.getState().addSession(id, history);

// create new assistant message for next round
return {
...state,
history,
current: {
id: nanoid(),
text: '',
type: 'assistant',
},
};
});
logToOutput('info', `reflected message: ${reflectedMessage.message}`);
});

eventSource.addEventListener('log', (event: { data: string }) => {
const logMessage = JSON.parse(event.data) as { message: string };
logToOutput('info', `server log: ${logMessage.message}`);
});

eventSource.addEventListener('end', () => {
end();
eventSource.close();
});

Expand All @@ -323,6 +370,7 @@ export const useChatStore = create(
} else {
console.error('EventSource error:', event);
}
end();
eventSource.close();
};
},
Expand Down
1 change: 1 addition & 0 deletions ui/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ export interface ChatUserMessage {
text: string;
displayText: string;
referenceList: ChatReferenceItem[];
reflected?: boolean;
}

export interface ChatAssistantMessage {
Expand Down
52 changes: 45 additions & 7 deletions ui/src/views/chat/chatMessageList.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,18 @@ import './codeTheme.scss';
const messageItemStyle = css({
marginBottom: '16px',
// whiteSpace: 'pre-wrap',
'& h1': {
fontSize: '1.25rem',
fontWeight: 'bold',
lineHeight: '1.25',
margin: '0.5rem 0',
},
'& h2': {
fontSize: '1.25rem',
fontWeight: 'bold',
lineHeight: '1.25',
margin: '0.5rem 0',
},
'& pre': {
overflow: 'auto hidden',
width: '100%',
Expand All @@ -32,13 +44,17 @@ const messageItemStyle = css({
fontFamily: 'var(--vscode-editor-font-family)',
},
},
'& .hljs': {
backgroundColor: 'transparent',
},
});

function ChatUserMessageItem(props: { message: ChatUserMessage }) {
const { message } = props;

return (
<div
className={messageItemStyle}
style={{
backgroundColor: 'var(--vscode-input-background)',
borderRadius: '4px',
Expand All @@ -47,7 +63,11 @@ function ChatUserMessageItem(props: { message: ChatUserMessage }) {
whiteSpace: 'pre-wrap',
}}
>
<Markdown>{message.displayText}</Markdown>
{message.reflected ? (
message.displayText
) : (
<Markdown>{message.displayText}</Markdown>
)}
</div>
);
}
Expand All @@ -73,12 +93,17 @@ const code: MarkdownComponents['code'] = (props) => {
);
};

function ChatAssistantMessageItem(props: { message: ChatAssistantMessage }) {
const { message } = props;
function ChatAssistantMessageItem(props: {
message: ChatAssistantMessage;
useComponents?: boolean;
}) {
const { message, useComponents = true } = props;

return (
<div className={messageItemStyle}>
<Markdown components={{ code }}>{message.text}</Markdown>
<Markdown components={useComponents ? { code } : undefined}>
{message.text}
</Markdown>
{message.usage && (
<div
style={{
Expand All @@ -97,13 +122,19 @@ function ChatAssistantMessageItem(props: { message: ChatAssistantMessage }) {

const ChatMessageItem = memo(function ChatMessageItem(props: {
message: ChatMessage;
useComponents?: boolean;
}) {
const { message } = props;
const { message, useComponents } = props;

if (message.type === 'user') {
return <ChatUserMessageItem message={message} />;
} else if (message.type === 'assistant') {
return <ChatAssistantMessageItem message={message} />;
return (
<ChatAssistantMessageItem
message={message}
useComponents={useComponents}
/>
);
}

return null;
Expand All @@ -125,7 +156,14 @@ export default function ChatMessageList() {
let currentItem: React.ReactNode;
if (current) {
if (current.text) {
currentItem = <ChatMessageItem key={current.id} message={current} />;
currentItem = (
// current may change very fast, so use components may cause performance issue
<ChatMessageItem
key={current.id}
message={current}
useComponents={false}
/>
);
} else {
currentItem = (
<div>
Expand Down
14 changes: 13 additions & 1 deletion ui/src/views/chat/chatTextArea.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,19 @@ const insertMention = (editor: Editor, reference: ChatReferenceItem) => {
const point = Editor.after(editor, editor.selection!);
if (point) {
Transforms.setSelection(editor, { anchor: point, focus: point });
editor.insertText(' ');

// Check if there's already a space after the mention
const after = Editor.after(editor, point);
if (after) {
const range = Editor.range(editor, point, after);
const text = Editor.string(editor, range);
if (text !== ' ') {
editor.insertText(' ');
}
} else {
// If we're at the end of the document, add a space
editor.insertText(' ');
}
}

ReactEditor.focus(editor);
Expand Down

0 comments on commit 746f46f

Please sign in to comment.