Skip to content

Commit

Permalink
chore(thread): always return save codeSourceId if it contains codeQue…
Browse files Browse the repository at this point in the history
…ry (#3847)

Signed-off-by: Wei Zhang <kweizh@tabbyml.com>
  • Loading branch information
zwpaper authored Feb 15, 2025
1 parent dccc4b3 commit dde2122
Show file tree
Hide file tree
Showing 6 changed files with 76 additions and 20 deletions.
18 changes: 16 additions & 2 deletions ee/tabby-db/src/threads.rs
Original file line number Diff line number Diff line change
Expand Up @@ -191,13 +191,27 @@ impl DbConn {
pub async fn update_thread_message_code_attachments(
&self,
message_id: i64,
code_source_id: &str,
code_attachments: &[AttachmentCode],
) -> Result<()> {
let code_attachments = Json(code_attachments);
query!(
"UPDATE thread_messages SET attachment = JSON_SET(attachment, '$.code', JSON(?)), code_source_id = ?, updated_at = DATETIME('now') WHERE id = ?",
"UPDATE thread_messages SET attachment = JSON_SET(attachment, '$.code', JSON(?)), updated_at = DATETIME('now') WHERE id = ?",
code_attachments,
message_id
)
.execute(&self.pool)
.await?;

Ok(())
}

pub async fn update_thread_message_code_source_id(
&self,
message_id: i64,
code_source_id: &str,
) -> Result<()> {
query!(
"UPDATE thread_messages SET code_source_id = ?, updated_at = DATETIME('now') WHERE id = ?",
code_source_id,
message_id
)
Expand Down
2 changes: 0 additions & 2 deletions ee/tabby-schema/src/schema/thread/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -283,8 +283,6 @@ pub struct ThreadAssistantMessageAttachmentsCodeFileList {

#[derive(GraphQLObject)]
pub struct ThreadAssistantMessageAttachmentsCode {
#[graphql(skip)]
pub code_source_id: String,
pub hits: Vec<MessageCodeSearchHit>,
}

Expand Down
2 changes: 1 addition & 1 deletion ee/tabby-webserver/src/service/answer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ impl AnswerService {
if !hits.is_empty() {
let hits = hits.into_iter().map(|x| x.into()).collect::<Vec<_>>();
yield Ok(ThreadRunItem::ThreadAssistantMessageAttachmentsCode(
ThreadAssistantMessageAttachmentsCode { code_source_id: repository.source_id, hits }
ThreadAssistantMessageAttachmentsCode { hits }
));
}
}
Expand Down
1 change: 1 addition & 0 deletions ee/tabby-webserver/src/service/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ impl ServerContext {
db_conn.clone(),
answer.clone(),
Some(auth.clone()),
context.clone(),
));
let page = chat.as_ref().map(|chat| {
Arc::new(page::create(
Expand Down
3 changes: 2 additions & 1 deletion ee/tabby-webserver/src/service/page.rs
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,8 @@ mod tests {
let chat: Arc<dyn ChatCompletionStream> = Arc::new(FakeChatCompletionStream {
return_error: false,
});
let thread = Arc::new(thread::create(db.clone(), None, None));
let context: Arc<dyn ContextService> = Arc::new(FakeContextService);
let thread = Arc::new(thread::create(db.clone(), None, None, context.clone()));
let context: Arc<dyn ContextService> = Arc::new(FakeContextService);
let service = create(db, chat, thread, context);

Expand Down
70 changes: 56 additions & 14 deletions ee/tabby-webserver/src/service/thread.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@ use juniper::ID;
use tabby_db::{AttachmentDoc, DbConn, ThreadMessageDAO};
use tabby_schema::{
auth::AuthenticationService,
bail, from_thread_message_attachment_document,
bail,
context::ContextService,
from_thread_message_attachment_document,
policy::AccessPolicy,
thread::{
self, CreateMessageInput, CreateThreadInput, MessageAttachment, MessageAttachmentDoc,
MessageAttachmentInput, ThreadRunItem, ThreadRunOptionsInput, ThreadRunStream,
ThreadService, UpdateMessageInput,
self, CodeQueryInput, CreateMessageInput, CreateThreadInput, MessageAttachment,
MessageAttachmentDoc, MessageAttachmentInput, ThreadRunItem, ThreadRunOptionsInput,
ThreadRunStream, ThreadService, UpdateMessageInput,
},
AsID, AsRowid, DbEnum, Result,
};
Expand All @@ -22,6 +24,7 @@ struct ThreadServiceImpl {
db: DbConn,
auth: Option<Arc<dyn AuthenticationService>>,
answer: Option<Arc<AnswerService>>,
context: Arc<dyn ContextService>,
}

impl ThreadServiceImpl {
Expand Down Expand Up @@ -108,6 +111,25 @@ impl ThreadServiceImpl {
}
output
}

async fn get_source_id(&self, policy: &AccessPolicy, input: &CodeQueryInput) -> Option<String> {
let helper = self.context.read(Some(policy)).await.ok()?.helper();

if let Some(source_id) = &input.source_id {
if helper.can_access_source_id(source_id) {
Some(source_id.clone())
} else {
None
}
} else if let Some(git_url) = &input.git_url {
helper
.allowed_code_repository()
.closest_match(git_url)
.map(|s| s.to_string())
} else {
None
}
}
}

#[async_trait]
Expand Down Expand Up @@ -187,6 +209,14 @@ impl ThreadService for ThreadServiceImpl {
)
.await?;

if let Some(code_query) = &options.code_query {
if let Some(source_id) = self.get_source_id(policy, code_query).await {
self.db
.update_thread_message_code_source_id(assistant_message_id, &source_id)
.await?;
}
}

let s = answer
.answer(policy, &messages, options, attachment_input)
.await?;
Expand Down Expand Up @@ -223,7 +253,6 @@ impl ThreadService for ThreadServiceImpl {
.collect::<Vec<_>>();
db.update_thread_message_code_attachments(
assistant_message_id,
&x.code_source_id,
&code,
).await?;
}
Expand Down Expand Up @@ -356,8 +385,14 @@ pub fn create(
db: DbConn,
answer: Option<Arc<AnswerService>>,
auth: Option<Arc<dyn AuthenticationService>>,
context: Arc<dyn ContextService>,
) -> impl ThreadService {
ThreadServiceImpl { db, answer, auth }
ThreadServiceImpl {
db,
answer,
auth,
context,
}
}

#[cfg(test)]
Expand Down Expand Up @@ -390,7 +425,8 @@ mod tests {
async fn test_create_thread() {
let db = DbConn::new_in_memory().await.unwrap();
let user_id = create_user(&db).await.as_id();
let service = create(db, None, None);
let context: Arc<dyn ContextService> = Arc::new(FakeContextService);
let service = create(db, None, None, context);

let input = CreateThreadInput {
user_message: CreateMessageInput {
Expand All @@ -406,7 +442,8 @@ mod tests {
async fn test_append_messages() {
let db = DbConn::new_in_memory().await.unwrap();
let user_id = create_user(&db).await.as_id();
let service = create(db, None, None);
let context: Arc<dyn ContextService> = Arc::new(FakeContextService);
let service = create(db, None, None, context);

let thread_id = service
.create(
Expand Down Expand Up @@ -452,7 +489,8 @@ mod tests {
async fn test_delete_thread_message_pair() {
let db = DbConn::new_in_memory().await.unwrap();
let user_id = create_user(&db).await.as_id();
let service = create(db.clone(), None, None);
let context: Arc<dyn ContextService> = Arc::new(FakeContextService);
let service = create(db.clone(), None, None, context);

let thread_id = service
.create(
Expand Down Expand Up @@ -541,7 +579,8 @@ mod tests {
async fn test_get_thread() {
let db = DbConn::new_in_memory().await.unwrap();
let user_id = create_user(&db).await.as_id();
let service = create(db, None, None);
let context: Arc<dyn ContextService> = Arc::new(FakeContextService);
let service = create(db, None, None, context);

let input = CreateThreadInput {
user_message: CreateMessageInput {
Expand All @@ -565,7 +604,8 @@ mod tests {
async fn test_delete_thread() {
let db = DbConn::new_in_memory().await.unwrap();
let user_id = create_user(&db).await.as_id();
let service = create(db.clone(), None, None);
let context: Arc<dyn ContextService> = Arc::new(FakeContextService);
let service = create(db.clone(), None, None, context);

let input = CreateThreadInput {
user_message: CreateMessageInput {
Expand All @@ -592,7 +632,8 @@ mod tests {
async fn test_set_persisted() {
let db = DbConn::new_in_memory().await.unwrap();
let user_id = create_user(&db).await.as_id();
let service = create(db.clone(), None, None);
let context: Arc<dyn ContextService> = Arc::new(FakeContextService);
let service = create(db.clone(), None, None, context);

let input = CreateThreadInput {
user_message: CreateMessageInput {
Expand Down Expand Up @@ -647,7 +688,7 @@ mod tests {
serper,
repo,
));
let service = create(db.clone(), Some(answer_service), None);
let service = create(db.clone(), Some(answer_service), None, context);

let input = CreateThreadInput {
user_message: CreateMessageInput {
Expand All @@ -672,7 +713,8 @@ mod tests {
async fn test_list_threads() {
let db = DbConn::new_in_memory().await.unwrap();
let user_id = create_user(&db).await.as_id();
let service = create(db, None, None);
let context: Arc<dyn ContextService> = Arc::new(FakeContextService);
let service = create(db, None, None, context);

for i in 0..3 {
let input = CreateThreadInput {
Expand Down

0 comments on commit dde2122

Please sign in to comment.