refactor: update LLM trait to append and clear chat messages
pythops committed Jan 30, 2024
1 parent eb26543 commit 5a09626
Showing 5 changed files with 44 additions and 33 deletions.
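
In short: the chat history that App previously carried in its llm_messages field now lives inside each LLM backend, behind two new trait methods, append_chat_msg and clear. Below is a minimal, self-contained sketch of that ownership pattern; Llm and Backend are simplified stand-ins, and the real trait uses #[async_trait], its ask streaming the answer through an UnboundedSender instead of returning a String.

use std::collections::HashMap;

// Simplified stand-in for the `LLM` trait after this commit: the backend
// owns its message history instead of receiving it on every `ask` call.
trait Llm {
    fn append_chat_msg(&mut self, chat: String);
    fn clear(&mut self);
    fn ask(&self) -> String; // the real version is async and streams events
}

#[derive(Default)]
struct Backend {
    messages: Vec<HashMap<String, String>>,
}

impl Llm for Backend {
    fn append_chat_msg(&mut self, chat: String) {
        let mut conv = HashMap::new();
        conv.insert("role".to_string(), "user".to_string());
        conv.insert("content".to_string(), chat);
        self.messages.push(conv);
    }

    fn clear(&mut self) {
        self.messages.clear();
    }

    fn ask(&self) -> String {
        // History is read from `self`; callers no longer pass it in.
        format!("answering over {} stored messages", self.messages.len())
    }
}

fn main() {
    let mut llm = Backend::default();
    llm.append_chat_msg("hello".to_string());
    println!("{}", llm.ask());
    llm.clear(); // what the handler now calls when a new chat starts
}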
3 changes: 0 additions & 3 deletions src/app.rs
@@ -2,7 +2,6 @@ use crate::history::History;
 use crate::prompt::Prompt;
 use crate::{chat::Chat, help::Help};
 use std;
-use std::collections::HashMap;
 use std::sync::atomic::AtomicBool;
 
 use crate::notification::Notification;
@@ -30,7 +29,6 @@ pub struct App<'a> {
     pub prompt: Prompt<'a>,
     pub chat: Chat<'a>,
     pub focused_block: FocusedBlock,
-    pub llm_messages: Vec<HashMap<String, String>>,
     pub history: History<'a>,
     pub notifications: Vec<Notification>,
     pub spinner: Spinner,
@@ -49,7 +47,6 @@ impl<'a> App<'a> {
             prompt: Prompt::default(),
             chat: Chat::new(),
             focused_block: FocusedBlock::Prompt,
-            llm_messages: Vec::new(),
             history: History::new(),
             notifications: Vec::new(),
             spinner: Spinner::default(),
16 changes: 14 additions & 2 deletions src/chatgpt.rs
@@ -20,6 +20,7 @@ pub struct ChatGPT {
     openai_api_key: String,
     model: String,
     url: String,
+    messages: Vec<HashMap<String, String>>,
 }
 
 impl ChatGPT {
@@ -44,15 +45,26 @@ You need to define one wether in the configuration file or as an environment var
             openai_api_key,
             model: config.model,
             url: config.url,
+            messages: Vec::new(),
         }
     }
 }
 
 #[async_trait]
 impl LLM for ChatGPT {
+    fn clear(&mut self) {
+        self.messages = Vec::new();
+    }
+
+    fn append_chat_msg(&mut self, chat: String) {
+        let mut conv: HashMap<String, String> = HashMap::new();
+        conv.insert("role".to_string(), "user".to_string());
+        conv.insert("content".to_string(), chat);
+        self.messages.push(conv);
+    }
+
     async fn ask(
         &self,
-        chat_messages: Vec<HashMap<String, String>>,
         sender: UnboundedSender<Event>,
         terminate_response_signal: Arc<AtomicBool>,
     ) -> Result<(), Box<dyn std::error::Error>> {
@@ -73,7 +85,7 @@ impl LLM for ChatGPT {
             ])),
         ];
 
-        messages.extend(chat_messages);
+        messages.extend(self.messages.clone());
 
         let body: Value = json!({
             "model": self.model,
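
With the history stored on the struct, ask only needs to splice self.messages in after the system prompt when building the request body. Here is a hedged sketch of that assembly step; the system-prompt text, the model name, and any body fields beyond "model" and "messages" are assumptions, since the diff truncates the json! call.

use serde_json::{json, Value};
use std::collections::HashMap;

fn build_body(model: &str, stored: &[HashMap<String, String>]) -> Value {
    // System prompt first, then the history the backend now owns.
    let mut messages = vec![HashMap::from([
        ("role".to_string(), "system".to_string()),
        ("content".to_string(), "You are a helpful assistant.".to_string()),
    ])];
    messages.extend(stored.iter().cloned()); // the diff's `self.messages.clone()`

    json!({
        "model": model,
        "messages": messages,
    })
}

fn main() {
    let stored = vec![HashMap::from([
        ("role".to_string(), "user".to_string()),
        ("content".to_string(), "hi".to_string()),
    ])];
    println!("{}", build_body("gpt-3.5-turbo", &stored));
}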
34 changes: 17 additions & 17 deletions src/handler.rs
@@ -8,19 +8,19 @@ use crate::{
 
 use crate::llm::LLM;
 use crossterm::event::{KeyCode, KeyEvent, KeyModifiers};
-use std::collections::HashMap;
 
 use ratatui::text::Line;
 
 use crate::notification::{Notification, NotificationLevel};
 use std::sync::Arc;
+use tokio::sync::Mutex;
 
 use tokio::sync::mpsc::UnboundedSender;
 
 pub async fn handle_key_events(
     key_event: KeyEvent,
     app: &mut App<'_>,
-    llm: Arc<impl LLM + 'static>,
+    llm: Arc<Mutex<impl LLM + 'static>>,
     sender: UnboundedSender<Event>,
 ) -> AppResult<()> {
     match key_event.code {
@@ -120,7 +120,12 @@ pub async fn handle_key_events(
                 app.history.text.push(app.chat.plain_chat.clone());
 
                 app.chat = Chat::default();
-                app.llm_messages = Vec::new();
+
+                let llm = llm.clone();
+                {
+                    let mut llm = llm.lock().await;
+                    llm.clear();
+                }
 
                 app.chat.scroll = 0;
             }
@@ -248,13 +253,11 @@ pub async fn handle_key_events(
                     );
                 }
 
-                let conv = HashMap::from([
-                    ("role".into(), "user".into()),
-                    ("content".into(), user_input.into()),
-                ]);
-                app.llm_messages.push(conv);
-
-                let llm_messages = app.llm_messages.clone();
+                let llm = llm.clone();
+                {
+                    let mut llm = llm.lock().await;
+                    llm.append_chat_msg(user_input.into());
+                }
 
                 app.spinner.active = true;
 
@@ -267,14 +270,11 @@ pub async fn handle_key_events(
 
                 let sender = sender.clone();
 
                 let llm = llm.clone();
 
                 tokio::spawn(async move {
-                    let res = llm
-                        .ask(
-                            llm_messages.to_vec(),
-                            sender.clone(),
-                            terminate_response_signal,
-                        )
-                        .await;
+                    let llm = llm.lock().await;
+                    let res = llm.ask(sender.clone(), terminate_response_signal).await;
 
                     if let Err(e) = res {
                         sender
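
The handler-side pattern: because append_chat_msg and clear take &mut self, the backend is now shared as Arc<Mutex<impl LLM>>. Mutations happen behind short-lived locks, while the spawned request task takes its own Arc clone and holds the lock across the .await. That works because this is tokio::sync::Mutex, whose guard may be held across await points; a std::sync::MutexGuard is not Send, so it could not survive an .await inside tokio::spawn. A small runnable sketch of the shape, with Backend as a stand-in:

use std::sync::Arc;
use tokio::sync::Mutex;

#[derive(Default)]
struct Backend {
    messages: Vec<String>,
}

impl Backend {
    fn append_chat_msg(&mut self, chat: String) {
        self.messages.push(chat);
    }

    async fn ask(&self) {
        // Stand-in for the streaming request that reads `self.messages`.
        println!("asking with {} messages", self.messages.len());
    }
}

#[tokio::main]
async fn main() {
    let llm = Arc::new(Mutex::new(Backend::default()));

    // Mutate behind a short-lived lock, as the handler does on submit.
    {
        let mut guard = llm.lock().await;
        guard.append_chat_msg("hello".to_string());
    }

    // Clone the Arc into the task; the guard is then held across `.await`.
    let llm = llm.clone();
    tokio::spawn(async move {
        let guard = llm.lock().await;
        guard.ask().await;
    })
    .await
    .unwrap();
}

One consequence of locking inside the spawned task: the mutex is held for the entire streamed response, so a clear or append triggered mid-stream waits until ask returns.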
7 changes: 4 additions & 3 deletions src/llm.rs
@@ -3,7 +3,6 @@ use crate::config::Config;
 use crate::event::Event;
 use async_trait::async_trait;
 use serde::Deserialize;
-use std::collections::HashMap;
 use std::sync::atomic::AtomicBool;
 use tokio::sync::mpsc::UnboundedSender;
 
@@ -13,10 +12,12 @@ use std::sync::Arc;
 pub trait LLM: Send + Sync {
     async fn ask(
         &self,
-        chat_messages: Vec<HashMap<String, String>>,
         sender: UnboundedSender<Event>,
         terminate_response_signal: Arc<AtomicBool>,
     ) -> Result<(), Box<dyn std::error::Error>>;
+
+    fn append_chat_msg(&mut self, chat: String);
+    fn clear(&mut self);
 }
 
 #[derive(Clone, Debug)]
@@ -31,7 +32,7 @@ pub enum LLMBackend {
     ChatGPT,
 }
 
-pub struct LLMModel {}
+pub struct LLMModel;
 
 impl LLMModel {
     pub async fn init(model: &LLMBackend, config: Arc<Config>) -> impl LLM {
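
Two details in this hunk drive the changes in the other files. First, append_chat_msg and clear take &mut self, and an Arc alone only hands out shared references; that is exactly why main.rs and handler.rs wrap the backend in Arc<Mutex<...>>. Second, pub struct LLMModel {} becomes the unit struct pub struct LLMModel;, the idiomatic spelling for a field-less type. A tiny sketch of the first point, with Backend again a hypothetical stand-in:

use std::sync::Arc;
use tokio::sync::Mutex;

struct Backend {
    messages: Vec<String>,
}

impl Backend {
    fn clear(&mut self) {
        self.messages.clear();
    }
}

#[tokio::main]
async fn main() {
    // let shared = Arc::new(Backend { messages: vec![] });
    // shared.clear(); // error[E0596]: cannot borrow data in an `Arc` as mutable

    let shared = Arc::new(Mutex::new(Backend { messages: vec![] }));
    shared.lock().await.clear(); // OK: the mutex grants exclusive access
}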
17 changes: 9 additions & 8 deletions src/main.rs
@@ -1,19 +1,19 @@
 use ratatui::backend::CrosstermBackend;
 use ratatui::Terminal;
-use std::collections::HashMap;
 use std::{env, io};
 use tenere::app::{App, AppResult};
 use tenere::cli;
 use tenere::config::Config;
 use tenere::event::{Event, EventHandler};
 use tenere::formatter::Formatter;
 use tenere::handler::handle_key_events;
-use tenere::llm::LLMAnswer;
+use tenere::llm::{LLMAnswer, LLM};
 use tenere::tui::Tui;
 
 use tenere::llm::LLMModel;
 
 use std::sync::Arc;
+use tokio::sync::Mutex;
 
 use clap::crate_version;
 
@@ -28,7 +28,9 @@ async fn main() -> AppResult<()> {
 
     let mut app = App::new(config.clone(), &formatter);
 
-    let llm = Arc::new(LLMModel::init(&config.model, config.clone()).await);
+    let llm = Arc::new(Mutex::new(
+        LLMModel::init(&config.model, config.clone()).await,
+    ));
 
     let backend = CrosstermBackend::new(io::stderr());
     let terminal = Terminal::new(backend)?;
@@ -53,11 +55,10 @@ async fn main() -> AppResult<()> {
             Event::LLMEvent(LLMAnswer::EndAnswer) => {
                 app.chat.handle_answer(LLMAnswer::EndAnswer, &formatter);
 
-                // TODO: factor this into llm struct or trait
-                let mut conv: HashMap<String, String> = HashMap::new();
-                conv.insert("role".to_string(), "user".to_string());
-                conv.insert("content".to_string(), app.chat.answer.plain_answer.clone());
-                app.llm_messages.push(conv);
+                {
+                    let mut llm = llm.lock().await;
+                    llm.append_chat_msg(app.chat.answer.plain_answer.clone());
+                }
 
                 app.terminate_response_signal
                     .store(false, std::sync::atomic::Ordering::Relaxed);
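
Note that the finished answer is pushed back through the same append_chat_msg used for prompts, and that method hardcodes the "user" role, so the answer is stored under "user" as well; the removed inline code did the same, so behavior is unchanged. The deleted // TODO: factor this into llm struct or trait comment is precisely what this commit resolves.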
