diff --git a/src/ai/constants.rs b/src/ai/constants.rs index 54325d48..2c39d169 100644 --- a/src/ai/constants.rs +++ b/src/ai/constants.rs @@ -87,18 +87,21 @@ mod test { RelatedDoc { url: "".into(), title: "".into(), + title_parent: None, content: "content1".into(), similarity: 0f64, }, RelatedDoc { url: "".into(), title: "".into(), + title_parent: None, content: "content2".into(), similarity: 0f64, }, RelatedDoc { url: "".into(), title: "".into(), + title_parent: None, content: "content3".into(), similarity: 0f64, }, diff --git a/src/ai/embeddings.rs b/src/ai/embeddings.rs index d72ee5c6..eb49b406 100644 --- a/src/ai/embeddings.rs +++ b/src/ai/embeddings.rs @@ -1,4 +1,5 @@ use async_openai::{config::OpenAIConfig, types::CreateEmbeddingRequestArgs, Client}; +use itertools::Itertools; use crate::{ ai::{constants::EMBEDDING_MODEL, error::AIError}, @@ -41,22 +42,25 @@ const MACRO_EMB_DISTANCE: f64 = 0.78; const MACRO_EMB_SEC_MIN_LENGTH: i64 = 50; const MACRO_EMB_DOC_LIMIT: i64 = 5; -const MACRO_DOCS_QUERY: &str = "select -mdn_doc_macro.mdn_url as url, -mdn_doc_macro.title, -mdn_doc_macro.markdown as content, -mdn_doc_macro.embedding <=> $1 as similarity -from mdn_doc_macro -where length(mdn_doc_macro.markdown) >= $4 -and (mdn_doc_macro.embedding <=> $1) < $2 -and mdn_doc_macro.mdn_url not like '/en-US/docs/MDN%' -order by mdn_doc_macro.embedding <=> $1 -limit $3;"; +const MACRO_DOCS_QUERY: &str = "SELECT + doc.mdn_url AS url, + doc.title, + parent.title_short AS title_parent, + doc.markdown AS content, + doc.embedding <=> $1 AS similarity +FROM mdn_doc_macro doc +LEFT JOIN mdn_doc_macro parent ON parent.mdn_url = SUBSTRING(doc.mdn_url, 1, LENGTH(doc.mdn_url) - STRPOS(REVERSE(doc.mdn_url), '/')) +WHERE LENGTH(doc.markdown) >= $4 + AND (doc.embedding <=> $1) < $2 + AND doc.mdn_url NOT LIKE '/en-US/docs/MDN%' +ORDER BY doc.embedding <=> $1 +LIMIT $3;"; #[derive(sqlx::FromRow, Debug)] pub struct RelatedDoc { pub url: String, pub title: String, + pub title_parent: Option, pub content: String, pub similarity: f64, } @@ -74,13 +78,29 @@ pub async fn get_related_macro_docs( let embedding = pgvector::Vector::from(embedding_res.data.into_iter().next().unwrap().embedding); - let docs: Vec = sqlx::query_as(MACRO_DOCS_QUERY) + + let mut docs: Vec = sqlx::query_as(MACRO_DOCS_QUERY) .bind(embedding) .bind(MACRO_EMB_DISTANCE) .bind(MACRO_EMB_DOC_LIMIT) .bind(MACRO_EMB_SEC_MIN_LENGTH) .fetch_all(pool) .await?; + + let duplicate_titles: Vec = docs + .iter() + .map(|x| x.title.to_string()) + .duplicates() + .collect(); + + docs.iter_mut().for_each(|doc| { + if let (true, Some(title_parent)) = + (duplicate_titles.contains(&doc.title), &doc.title_parent) + { + doc.title = format!("{} ({})", doc.title, title_parent); + } + }); + Ok(docs) }