Skip to content

Commit

Permalink
refactor: add chat module
Browse files Browse the repository at this point in the history
  • Loading branch information
pythops committed Jan 24, 2024
1 parent 03bfcaf commit a3f911d
Show file tree
Hide file tree
Showing 6 changed files with 175 additions and 107 deletions.
22 changes: 3 additions & 19 deletions src/app.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::help::Help;
use crate::history::History;
use crate::prompt::Prompt;
use crate::{chat::Chat, help::Help};
use std;
use std::collections::HashMap;
use std::sync::atomic::AtomicBool;
Expand All @@ -9,7 +9,7 @@ use crate::notification::Notification;
use crate::spinner::Spinner;
use crate::{config::Config, formatter::Formatter};
use arboard::Clipboard;
use ratatui::text::{Line, Text};
use ratatui::text::Line;

use std::sync::Arc;

Expand All @@ -24,27 +24,12 @@ pub enum FocusedBlock {
Help,
}

#[derive(Debug, Default)]
pub struct Chat<'a> {
pub messages: Vec<String>,
pub formatted_chat: Text<'a>,
pub scroll: u16,
pub length: u16,
}

#[derive(Debug, Default)]
pub struct Answer<'a> {
pub answer: String,
pub formatted_answer: Text<'a>,
}

pub struct App<'a> {
pub running: bool,
pub prompt: Prompt<'a>,
pub chat: Chat<'a>,
pub focused_block: FocusedBlock,
pub llm_messages: Vec<HashMap<String, String>>,
pub answer: Answer<'a>,
pub history: History<'a>,
pub notifications: Vec<Notification>,
pub spinner: Spinner,
Expand All @@ -60,10 +45,9 @@ impl<'a> App<'a> {
Self {
running: true,
prompt: Prompt::default(),
chat: Chat::default(),
chat: Chat::new(),
focused_block: FocusedBlock::Prompt,
llm_messages: Vec::new(),
answer: Answer::default(),
history: History::new(),
notifications: Vec::new(),
spinner: Spinner::default(),
Expand Down
136 changes: 136 additions & 0 deletions src/chat.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
use std::{rc::Rc, sync::atomic::AtomicBool};

use ratatui::{
layout::Rect,
style::Style,
text::Text,
widgets::{Block, BorderType, Borders, Paragraph, Wrap},
Frame,
};

use crate::{app::FocusedBlock, formatter::Formatter, llm::LLMAnswer};

#[derive(Debug, Clone, Default)]
pub struct Answer<'a> {
pub plain_answer: String,
pub formatted_answer: Text<'a>,
}

#[derive(Debug, Clone)]
pub struct Chat<'a> {
pub plain_chat: Vec<String>,
pub formatted_chat: Text<'a>,
pub answer: Answer<'a>,
pub scroll: u16,
area_height: u16,
area_width: u16,
pub automatic_scroll: Rc<AtomicBool>,
pub block: Block<'a>,
}

impl Default for Chat<'_> {
fn default() -> Self {
let block = Block::default()
.border_type(BorderType::default())
.borders(Borders::ALL)
.style(Style::default());

Self {
plain_chat: Vec::new(),
formatted_chat: Text::raw(""),
answer: Answer::default(),
scroll: 0,
area_height: 0,
area_width: 0,
automatic_scroll: Rc::new(AtomicBool::new(true)),
block,
}
}
}

impl Chat<'_> {
pub fn new() -> Self {
Self::default()
}

pub fn handle_answer(&mut self, event: LLMAnswer, formatter: &Formatter) {
match event {
LLMAnswer::StartAnswer => {
self.formatted_chat.lines.pop();
}

LLMAnswer::Answer(answer) => {
self.answer.plain_answer.push_str(answer.as_str());

self.answer.formatted_answer =
formatter.format(format!("🤖: {}", &self.answer.plain_answer).as_str());
// self.move_to_bottom();
}

LLMAnswer::EndAnswer => {
self.formatted_chat
.extend(self.answer.formatted_answer.clone());

self.formatted_chat.extend(Text::raw("\n"));

self.plain_chat
.push(format!("🤖: {}", self.answer.plain_answer));

self.answer = Answer::default();
}
}
}

pub fn height(&self) -> usize {
let mut chat = self.formatted_chat.clone();

chat.extend(self.answer.formatted_answer.clone());

let nb_lines = chat.lines.len() + 3;
chat.lines.iter().fold(nb_lines, |acc, line| {
acc + line.width() / self.area_width as usize
})
}

pub fn move_to_bottom(&mut self) {
// self.scroll = (self.formatted_chat.height() + self.answer.formatted_answer.height())
// .saturating_sub((self.area_height - 2).into());
}

pub fn render(&mut self, frame: &mut Frame, area: Rect, focused_block: &FocusedBlock) {
let mut text = self.formatted_chat.clone();
text.extend(self.answer.formatted_answer.clone());

self.area_height = area.height;
self.area_width = area.width;

let scroll: u16 = {
if self
.automatic_scroll
.load(std::sync::atomic::Ordering::Relaxed)
{
let scroll = self.height().saturating_sub(self.area_height.into()) as u16;
self.scroll = scroll;
scroll
} else {
self.scroll
}
};

let chat = Paragraph::new(text)
.scroll((scroll, 0))
.wrap(Wrap { trim: false })
.block(
Block::default()
.borders(Borders::ALL)
.style(Style::default())
.border_type(match focused_block {
FocusedBlock::Chat => BorderType::Thick,
_ => BorderType::Rounded,
})
.border_style(Style::default()),
);

frame.render_widget(chat, area);
}
}
34 changes: 25 additions & 9 deletions src/handler.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::llm::LLMAnswer;
use crate::{app::Chat, prompt::Mode};
use crate::{chat::Chat, prompt::Mode};

use crate::{
app::{App, AppResult, FocusedBlock},
Expand Down Expand Up @@ -46,6 +46,9 @@ pub fn handle_key_events(
}

FocusedBlock::Chat => {
app.chat
.automatic_scroll
.store(false, std::sync::atomic::Ordering::Relaxed);
app.chat.scroll = app.chat.scroll.saturating_add(1);
}

Expand All @@ -67,6 +70,9 @@ pub fn handle_key_events(
}

FocusedBlock::Chat => {
app.chat
.automatic_scroll
.store(false, std::sync::atomic::Ordering::Relaxed);
app.chat.scroll = app.chat.scroll.saturating_sub(1);
}

Expand All @@ -79,7 +85,7 @@ pub fn handle_key_events(

// `G`: Mo to the bottom
KeyCode::Char('G') => match app.focused_block {
FocusedBlock::Chat => app.chat.scroll = app.chat.length,
FocusedBlock::Chat => app.chat.move_to_bottom(),
FocusedBlock::History => app.history.move_to_bottom(),
_ => (),
},
Expand All @@ -99,7 +105,7 @@ pub fn handle_key_events(
.text
.push(app.chat.formatted_chat.clone());

app.history.text.push(app.chat.messages.clone());
app.history.text.push(app.chat.plain_chat.clone());

app.chat = Chat::default();
app.llm_messages = Vec::new();
Expand All @@ -120,7 +126,7 @@ pub fn handle_key_events(
FocusedBlock::Chat | FocusedBlock::Prompt => {
match std::fs::write(
app.config.archive_file_name.clone(),
app.chat.messages.join(""),
app.chat.plain_chat.join(""),
) {
Ok(_) => {
let notif = Notification::new(
Expand All @@ -145,12 +151,16 @@ pub fn handle_key_events(
KeyCode::Tab => match app.focused_block {
FocusedBlock::Chat => {
app.focused_block = FocusedBlock::Prompt;

app.chat
.automatic_scroll
.store(true, std::sync::atomic::Ordering::Relaxed);

app.prompt.update(&app.focused_block);
}
FocusedBlock::Prompt => {
app.chat.scroll = (app.chat.formatted_chat.height()
+ app.answer.formatted_answer.height())
as u16;
app.chat.move_to_bottom();

app.focused_block = FocusedBlock::Chat;
app.prompt.mode = Mode::Normal;
app.prompt.update(&app.focused_block);
Expand All @@ -173,6 +183,9 @@ pub fn handle_key_events(
{
app.focused_block = FocusedBlock::Help;
app.prompt.update(&app.focused_block);
app.chat
.automatic_scroll
.store(true, std::sync::atomic::Ordering::Relaxed);
}

// Show history
Expand All @@ -183,6 +196,9 @@ pub fn handle_key_events(
{
app.focused_block = FocusedBlock::History;
app.prompt.update(&app.focused_block);
app.chat
.automatic_scroll
.store(true, std::sync::atomic::Ordering::Relaxed);
}

// Discard help & history popups
Expand All @@ -207,11 +223,11 @@ pub fn handle_key_events(

app.prompt.clear();

app.chat.messages.push(format!(" : {}\n", user_input));
app.chat.plain_chat.push(format!("👤 : {}\n", user_input));

app.chat.formatted_chat.extend(
app.formatter
.format(format!(": {}\n", user_input).as_str()),
.format(format!("👤: {}\n", user_input).as_str()),
);

let conv = HashMap::from([
Expand Down
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,5 @@ pub mod prompt;
pub mod help;

pub mod history;

pub mod chat;
31 changes: 7 additions & 24 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use ratatui::backend::CrosstermBackend;
use ratatui::Terminal;
use std::collections::HashMap;
use std::{env, io};
use tenere::app::{Answer, App, AppResult};
use tenere::app::{App, AppResult};
use tenere::cli;
use tenere::config::Config;
use tenere::event::{Event, EventHandler};
Expand All @@ -11,7 +11,6 @@ use tenere::handler::handle_key_events;
use tenere::llm::LLMAnswer;
use tenere::tui::Tui;

use ratatui::text::Text;
use tenere::llm::{LLMBackend, LLMModel};

use std::sync::Arc;
Expand Down Expand Up @@ -53,42 +52,26 @@ fn main() -> AppResult<()> {
Event::Mouse(_) => {}
Event::Resize(_, _) => {}
Event::LLMEvent(LLMAnswer::Answer(answer)) => {
if app.answer.answer.is_empty() {
app.answer
.answer
.push_str(format!("🤖: {}", answer).as_str());
}
app.answer.answer.push_str(answer.as_str());
app.answer.formatted_answer = formatter.format(&app.answer.answer);
app.chat
.handle_answer(LLMAnswer::Answer(answer), &formatter);
}
Event::LLMEvent(LLMAnswer::EndAnswer) => {
app.answer.answer = app
.answer
.answer
.strip_prefix("🤖: ")
.unwrap_or_default()
.to_string();
app.chat.handle_answer(LLMAnswer::EndAnswer, &formatter);

// TODO: factor this into llm struct or trait
let mut conv: HashMap<String, String> = HashMap::new();
conv.insert("role".to_string(), "user".to_string());
conv.insert("content".to_string(), app.answer.answer.clone());
conv.insert("content".to_string(), app.chat.answer.plain_answer.clone());
app.llm_messages.push(conv);

app.chat.formatted_chat.extend(app.answer.formatted_answer);
app.chat.formatted_chat.extend(Text::raw("\n"));

app.chat.messages.push(format!("🤖: {}", app.answer.answer));

app.answer = Answer::default();

app.terminate_response_signal
.store(false, std::sync::atomic::Ordering::Relaxed);
}
Event::LLMEvent(LLMAnswer::StartAnswer) => {
app.spinner.active = false;
app.chat.formatted_chat.lines.pop();
app.chat.handle_answer(LLMAnswer::StartAnswer, &formatter);
}

Event::Notification(notification) => {
app.notifications.push(notification);
}
Expand Down
Loading

0 comments on commit a3f911d

Please sign in to comment.