From 235780b941a0805b69541f0f4c55c3404091baa8 Mon Sep 17 00:00:00 2001 From: Timon Vonk Date: Tue, 31 Dec 2024 12:39:30 +0100 Subject: [PATCH] feat(query): Documents as first class citizens (#504) For simple RAG, just adding the content of a retrieved document might be enough. However, in more complex use cases, you might want to add metadata as well, as is or for conditional formatting. For instance, when dealing with large amounts of chunked code, providing the path goes a long way. If generated metadata is good enough, could be useful as well. With this retrieved Documents are treated as first class citizens, including any metadata as well. Additionally, this also paves the way for multi retrieval (and multi modal). --- swiftide-core/src/document.rs | 168 ++++++++++++++++++ swiftide-core/src/lib.rs | 1 + swiftide-core/src/metadata.rs | 6 +- swiftide-core/src/query.rs | 128 ++++++------- swiftide-integrations/src/lancedb/mod.rs | 2 + swiftide-integrations/src/lancedb/retrieve.rs | 56 ++++-- swiftide-integrations/src/pgvector/mod.rs | 36 +++- swiftide-integrations/src/pgvector/persist.rs | 2 + .../src/pgvector/retrieve.rs | 102 ++++++++++- swiftide-integrations/src/qdrant/retrieve.rs | 55 +++--- swiftide-query/src/answers/simple.rs | 8 +- swiftide-query/src/evaluators/ragas.rs | 36 ++-- .../src/response_transformers/summary.rs | 6 +- swiftide/tests/lancedb.rs | 22 ++- swiftide/tests/pgvector.rs | 52 +++--- 15 files changed, 491 insertions(+), 189 deletions(-) create mode 100644 swiftide-core/src/document.rs diff --git a/swiftide-core/src/document.rs b/swiftide-core/src/document.rs new file mode 100644 index 00000000..fcee2285 --- /dev/null +++ b/swiftide-core/src/document.rs @@ -0,0 +1,168 @@ +//! Documents are the main data structure that is retrieved via the query pipeline +//! +//! Retrievers are expected to eagerly set any configured metadata on the document, with the same +//! field name used during indexing if applicable. +use std::fmt; + +use derive_builder::Builder; +use serde::{Deserialize, Serialize}; + +use crate::{metadata::Metadata, util::debug_long_utf8}; + +/// A document represents a single unit of retrieved text +#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Builder)] +#[builder(setter(into))] +pub struct Document { + #[builder(default)] + metadata: Metadata, + content: String, +} + +impl From for serde_json::Value { + fn from(document: Document) -> Self { + serde_json::json!({ + "metadata": document.metadata, + "content": document.content, + }) + } +} + +impl From<&Document> for serde_json::Value { + fn from(document: &Document) -> Self { + serde_json::json!({ + "metadata": document.metadata, + "content": document.content, + }) + } +} + +impl PartialOrd for Document { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.content.cmp(&other.content)) + } +} + +impl Ord for Document { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.content.cmp(&other.content) + } +} + +impl fmt::Debug for Document { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Document") + .field("metadata", &self.metadata) + .field("content", &debug_long_utf8(&self.content, 100)) + .finish() + } +} + +impl> From for Document { + fn from(value: T) -> Self { + Document::new(value.as_ref(), None) + } +} + +impl Document { + pub fn new(content: impl Into, metadata: Option) -> Self { + Self { + metadata: metadata.unwrap_or_default(), + content: content.into(), + } + } + + pub fn builder() -> DocumentBuilder { + DocumentBuilder::default() + } + + pub fn content(&self) -> &str { + &self.content + } + + pub fn metadata(&self) -> &Metadata { + &self.metadata + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::metadata::Metadata; + + #[test] + fn test_document_creation() { + let content = "Test content"; + let metadata = Metadata::from([("some", "metadata")]); + let document = Document::new(content, Some(metadata.clone())); + + assert_eq!(document.content(), content); + assert_eq!(document.metadata(), &metadata); + } + + #[test] + fn test_document_default_metadata() { + let content = "Test content"; + let document = Document::new(content, None); + + assert_eq!(document.content(), content); + assert_eq!(document.metadata(), &Metadata::default()); + } + + #[test] + fn test_document_from_str() { + let content = "Test content"; + let document: Document = content.into(); + + assert_eq!(document.content(), content); + assert_eq!(document.metadata(), &Metadata::default()); + } + + #[test] + fn test_document_partial_ord() { + let doc1 = Document::new("A", None); + let doc2 = Document::new("B", None); + + assert!(doc1 < doc2); + } + + #[test] + fn test_document_ord() { + let doc1 = Document::new("A", None); + let doc2 = Document::new("B", None); + + assert!(doc1.cmp(&doc2) == std::cmp::Ordering::Less); + } + + #[test] + fn test_document_debug() { + let content = "Test content"; + let document = Document::new(content, None); + let debug_str = format!("{document:?}"); + + assert!(debug_str.contains("Document")); + assert!(debug_str.contains("metadata")); + assert!(debug_str.contains("content")); + } + + #[test] + fn test_document_to_json() { + let content = "Test content"; + let metadata = Metadata::from([("some", "metadata")]); + let document = Document::new(content, Some(metadata.clone())); + let json_value: serde_json::Value = document.into(); + + assert_eq!(json_value["content"], content); + assert_eq!(json_value["metadata"], serde_json::json!(metadata)); + } + + #[test] + fn test_document_ref_to_json() { + let content = "Test content"; + let metadata = Metadata::from([("some", "metadata")]); + let document = Document::new(content, Some(metadata.clone())); + let json_value: serde_json::Value = (&document).into(); + + assert_eq!(json_value["content"], content); + assert_eq!(json_value["metadata"], serde_json::json!(metadata)); + } +} diff --git a/swiftide-core/src/lib.rs b/swiftide-core/src/lib.rs index 07da7194..9d4536ee 100644 --- a/swiftide-core/src/lib.rs +++ b/swiftide-core/src/lib.rs @@ -12,6 +12,7 @@ pub mod query_traits; mod search_strategies; pub mod type_aliases; +pub mod document; pub mod prompt; pub use type_aliases::*; diff --git a/swiftide-core/src/metadata.rs b/swiftide-core/src/metadata.rs index a9de7755..b13d64fb 100644 --- a/swiftide-core/src/metadata.rs +++ b/swiftide-core/src/metadata.rs @@ -9,7 +9,7 @@ use serde::Deserializer; use crate::util::debug_long_utf8; -#[derive(Clone, Default, PartialEq)] +#[derive(Clone, Default, PartialEq, Eq)] pub struct Metadata { inner: BTreeMap, } @@ -53,6 +53,10 @@ impl Metadata { pub fn into_values(self) -> IntoValues { self.inner.into_values() } + + pub fn is_empty(&self) -> bool { + self.inner.is_empty() + } } impl Extend<(K, V)> for Metadata diff --git a/swiftide-core/src/query.rs b/swiftide-core/src/query.rs index 5e2e0b9a..2e7244ed 100644 --- a/swiftide-core/src/query.rs +++ b/swiftide-core/src/query.rs @@ -7,9 +7,7 @@ //! `states::Answered`: The query has been answered use derive_builder::Builder; -use crate::{util::debug_long_utf8, Embedding, SparseEmbedding}; - -type Document = String; +use crate::{document::Document, util::debug_long_utf8, Embedding, SparseEmbedding}; /// A query is the main object going through a query pipeline /// @@ -24,6 +22,7 @@ pub struct Query { original: String, #[builder(default = "self.original.clone().unwrap_or_default()")] current: String, + #[builder(default = STATE::default())] state: STATE, #[builder(default)] transformation_history: Vec, @@ -34,6 +33,12 @@ pub struct Query { #[builder(default)] pub sparse_embedding: Option, + + /// Documents the query will operate on + /// + /// A query can retrieve multiple times, accumulating documents + #[builder(default)] + documents: Vec, } impl std::fmt::Debug for Query { @@ -71,6 +76,7 @@ impl Query { transformation_history: self.transformation_history, embedding: self.embedding, sparse_embedding: self.sparse_embedding, + documents: self.documents, } } @@ -78,6 +84,34 @@ impl Query { pub fn history(&self) -> &Vec { &self.transformation_history } + + /// Returns the current documents that will be used as context for answer generation + pub fn documents(&self) -> &[Document] { + &self.documents + } + + /// Returns the current documents as mutable + pub fn documents_mut(&mut self) -> &mut Vec { + &mut self.documents + } +} + +impl Query { + /// Add retrieved documents and transition to `states::Retrieved` + pub fn retrieved_documents(mut self, documents: Vec) -> Query { + self.documents.extend(documents.clone()); + self.transformation_history + .push(TransformationEvent::Retrieved { + before: self.current.clone(), + after: String::new(), + documents, + }); + + let state = states::Retrieved; + + self.current.clear(); + self.transition_to(state) + } } impl Query { @@ -100,21 +134,6 @@ impl Query { self.current = new_query; } - - /// Add retrieved documents and transition to `states::Retrieved` - pub fn retrieved_documents(mut self, documents: Vec) -> Query { - self.transformation_history - .push(TransformationEvent::Retrieved { - before: self.current.clone(), - after: String::new(), - documents: documents.clone(), - }); - - let state = states::Retrieved { documents }; - - self.current.clear(); - self.transition_to(state) - } } impl Query { @@ -135,17 +154,11 @@ impl Query { self.current = new_response; } - /// Returns the last retrieved documents - pub fn documents(&self) -> &[Document] { - &self.state.documents - } - /// Transition the query to `states::Answered` #[must_use] - pub fn answered(self, answer: impl Into) -> Query { - let state = states::Answered { - answer: answer.into(), - }; + pub fn answered(mut self, answer: impl Into) -> Query { + self.current = answer.into(); + let state = states::Answered; self.transition_to(state) } } @@ -157,66 +170,37 @@ impl Query { /// Returns the answer of the query pub fn answer(&self) -> &str { - &self.state.answer + &self.current } } /// Marker trait for query states -pub trait QueryState: Send + Sync {} +pub trait QueryState: Send + Sync + Default {} +/// Marker trait for query states that can still retrieve +pub trait CanRetrieve: QueryState {} /// States of a query pub mod states { - use crate::util::debug_long_utf8; - - use super::Builder; - use super::Document; - use super::QueryState; + use super::{CanRetrieve, QueryState}; - #[derive(Debug, Default, Clone)] + #[derive(Debug, Default, Clone, PartialEq)] /// The query is pending and has not been used pub struct Pending; - #[derive(Default, Clone, Builder, PartialEq)] - #[builder(setter(into))] + #[derive(Debug, Default, Clone, PartialEq)] /// Documents have been retrieved - pub struct Retrieved { - pub(crate) documents: Vec, - } - - impl std::fmt::Debug for Retrieved { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("Retrieved") - .field("num_documents", &self.documents.len()) - .field( - "documents", - &self - .documents - .iter() - .map(|d| debug_long_utf8(d, 100)) - .collect::>(), - ) - .finish() - } - } + pub struct Retrieved; - #[derive(Default, Clone, Builder, PartialEq)] - #[builder(setter(into))] + #[derive(Debug, Default, Clone, PartialEq)] /// The query has been answered - pub struct Answered { - pub(crate) answer: String, - } - - impl std::fmt::Debug for Answered { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("Answered") - .field("answer", &debug_long_utf8(&self.answer, 100)) - .finish() - } - } + pub struct Answered; impl QueryState for Pending {} impl QueryState for Retrieved {} impl QueryState for Answered {} + + impl CanRetrieve for Pending {} + impl CanRetrieve for Retrieved {} } impl> From for Query { @@ -301,7 +285,7 @@ mod tests { #[test] fn test_query_retrieved_documents() { let query = Query::::from("test query"); - let documents = vec!["doc1".to_string(), "doc2".to_string()]; + let documents: Vec = vec!["doc1".into(), "doc2".into()]; let query = query.retrieved_documents(documents.clone()); assert_eq!(query.documents(), &documents); assert_eq!(query.history().len(), 1); @@ -323,7 +307,7 @@ mod tests { #[test] fn test_query_transformed_response() { let query = Query::::from("test query"); - let documents = vec!["doc1".to_string(), "doc2".to_string()]; + let documents = vec!["doc1".into(), "doc2".into()]; let mut query = query.retrieved_documents(documents.clone()); query.transformed_response("new response"); @@ -342,7 +326,7 @@ mod tests { #[test] fn test_query_answered() { let query = Query::::from("test query"); - let documents = vec!["doc1".to_string(), "doc2".to_string()]; + let documents = vec!["doc1".into(), "doc2".into()]; let query = query.retrieved_documents(documents); let query = query.answered("the answer"); diff --git a/swiftide-integrations/src/lancedb/mod.rs b/swiftide-integrations/src/lancedb/mod.rs index 9eb60657..04130355 100644 --- a/swiftide-integrations/src/lancedb/mod.rs +++ b/swiftide-integrations/src/lancedb/mod.rs @@ -21,6 +21,8 @@ See examples for more information. Implements `Persist` and `Retrieve`. +If you want to store / retrieve metadata in Lance, the columns can be defined with `with_metadata`. + Note: For querying large tables you manually need to create an index. You can get an active connection via `get_connection`. diff --git a/swiftide-integrations/src/lancedb/retrieve.rs b/swiftide-integrations/src/lancedb/retrieve.rs index 1a65fc99..0de3b9b6 100644 --- a/swiftide-integrations/src/lancedb/retrieve.rs +++ b/swiftide-integrations/src/lancedb/retrieve.rs @@ -1,10 +1,13 @@ -use anyhow::{Context as _, Result}; +use anyhow::Result; +use arrow::datatypes::SchemaRef; use arrow_array::StringArray; use async_trait::async_trait; use futures_util::TryStreamExt; use itertools::Itertools; use lancedb::query::{ExecutableQuery, QueryBase}; use swiftide_core::{ + document::Document, + indexing::Metadata, querying::{search_strategies::SimilaritySingleEmbedding, states, Query}, Retrieve, }; @@ -57,24 +60,49 @@ impl Retrieve> for LanceDB { query_builder = query_builder.only_if(filter); } - let result = query_builder + let batches = query_builder .execute() .await? .try_collect::>() .await?; - let Some(recordbatch) = result.first() else { - return Ok(query.retrieved_documents(vec![])); - }; - - let documents: Vec = recordbatch - .column_by_name("chunk") - .and_then(|raw_array| raw_array.as_any().downcast_ref::()) - .context("Could not cast documents to strings")? - .into_iter() - .flatten() - .map_into() - .collect(); + let mut documents = vec![]; + + for batch in batches { + let schema: SchemaRef = batch.schema(); + + for row_idx in 0..batch.num_rows() { + let mut metadata = Metadata::default(); + let mut content = String::new(); + + for (col_idx, field) in schema.fields().iter().enumerate() { + let column = batch.column(col_idx); + + if let Some(array) = column.as_any().downcast_ref::() { + if field.name() == "chunk" { + // Extract the "content" field + content = array.value(row_idx).to_string(); + } else { + // Assume other fields are part of the metadata + let value = array.value(row_idx).to_string(); + metadata.insert(field.name().clone(), value); + } + } else { + // Handle other array types as necessary + // TODO: Can't we just downcast to serde::Value or fail? + } + } + + documents.push(Document::new( + content, + if metadata.is_empty() { + None + } else { + Some(metadata) + }, + )); + } + } Ok(query.retrieved_documents(documents)) } diff --git a/swiftide-integrations/src/pgvector/mod.rs b/swiftide-integrations/src/pgvector/mod.rs index ed96232c..68ac0be7 100644 --- a/swiftide-integrations/src/pgvector/mod.rs +++ b/swiftide-integrations/src/pgvector/mod.rs @@ -6,6 +6,7 @@ //! - Efficient vector storage and indexing //! - Connection pooling with automatic retries //! - Batch operations for optimized performance +//! - Metadata included in retrieval //! //! The functionality is primarily used through the [`PgVector`] client, which implements //! the [`Persist`] trait for seamless integration with indexing and query pipelines. @@ -192,6 +193,7 @@ mod tests { use futures_util::TryStreamExt; use std::collections::HashSet; use swiftide_core::{ + document::Document, indexing::{self, EmbedMode, EmbeddedField}, querying::{search_strategies::SimilaritySingleEmbedding, states, Query}, Persist, Retrieve, @@ -247,8 +249,13 @@ mod tests { assert_eq!(result.documents().len(), 2); - assert!(result.documents().contains(&"content1".to_string())); - assert!(result.documents().contains(&"content2".to_string())); + let contents = result + .documents() + .iter() + .map(Document::content) + .collect::>(); + assert!(contents.contains(&"content1")); + assert!(contents.contains(&"content2")); // Additional test with priority filter let search_strategy = @@ -260,8 +267,13 @@ mod tests { .unwrap(); assert_eq!(result.documents().len(), 2); - assert!(result.documents().contains(&"content1".to_string())); - assert!(result.documents().contains(&"content3".to_string())); + let contents = result + .documents() + .iter() + .map(Document::content) + .collect::>(); + assert!(contents.contains(&"content1")); + assert!(contents.contains(&"content3")); } #[test_log::test(tokio::test)] @@ -317,8 +329,13 @@ mod tests { // Verify that similar vectors are retrieved first assert_eq!(result.documents().len(), 2); - assert!(result.documents().contains(&"base_content".to_string())); - assert!(result.documents().contains(&"similar_content".to_string())); + let contents = result + .documents() + .iter() + .map(Document::content) + .collect::>(); + assert!(contents.contains(&"base_content")); + assert!(contents.contains(&"similar_content")); } #[test_case( @@ -443,7 +460,12 @@ mod tests { if test_case.expected_in_results { assert!( - result.documents().contains(&test_case.chunk.to_string()), + result + .documents() + .iter() + .map(Document::content) + .collect::>() + .contains(&test_case.chunk), "Document should be found in results for field {field}", ); } diff --git a/swiftide-integrations/src/pgvector/persist.rs b/swiftide-integrations/src/pgvector/persist.rs index 6b9973ae..ab634a83 100644 --- a/swiftide-integrations/src/pgvector/persist.rs +++ b/swiftide-integrations/src/pgvector/persist.rs @@ -5,6 +5,8 @@ //! - Single-node storage operations //! - Optimized batch storage with configurable batch sizes //! +//! NOTE: Persisting and retrieving metadata is not supported at the moment. +//! //! The implementation ensures thread-safe concurrent access and handles //! connection management automatically. use crate::pgvector::PgVector; diff --git a/swiftide-integrations/src/pgvector/retrieve.rs b/swiftide-integrations/src/pgvector/retrieve.rs index ef55b68d..a67349d7 100644 --- a/swiftide-integrations/src/pgvector/retrieve.rs +++ b/swiftide-integrations/src/pgvector/retrieve.rs @@ -1,9 +1,11 @@ -use crate::pgvector::{PgVector, PgVectorBuilder}; +use crate::pgvector::{FieldConfig, PgVector, PgVectorBuilder}; use anyhow::{anyhow, Result}; use async_trait::async_trait; use pgvector::Vector; -use sqlx::{prelude::FromRow, types::Uuid}; +use sqlx::{prelude::FromRow, types::Uuid, Column, Row}; use swiftide_core::{ + document::Document, + indexing::Metadata, querying::{ search_strategies::{CustomStrategy, SimilaritySingleEmbedding}, states, Query, @@ -12,10 +14,46 @@ use swiftide_core::{ }; #[allow(dead_code)] -#[derive(Debug, Clone, FromRow)] +#[derive(Debug, Clone)] struct VectorSearchResult { id: Uuid, chunk: String, + metadata: Metadata, +} + +impl From for Document { + fn from(val: VectorSearchResult) -> Self { + Document::new(val.chunk, Some(val.metadata)) + } +} + +impl FromRow<'_, sqlx::postgres::PgRow> for VectorSearchResult { + fn from_row(row: &sqlx::postgres::PgRow) -> Result { + let mut metadata = Metadata::default(); + + // Metadata fields are stored each as prefixed meta_ fields. Perhaps we should add a single + // metadata field instead of multiple fields. + for column in row.columns() { + if column.name().starts_with("meta_") { + row.try_get::(column.name())? + .as_object() + .and_then(|object| { + object.keys().collect::>().first().map(|key| { + metadata.insert( + key.to_owned(), + object.get(key.as_str()).expect("infallible").clone(), + ); + }) + }); + } + } + + Ok(VectorSearchResult { + id: row.try_get("id")?, + chunk: row.try_get("chunk")?, + metadata, + }) + } } #[allow(clippy::redundant_closure_for_method_calls)] @@ -40,6 +78,12 @@ impl Retrieve> for PgVector { let default_columns: Vec<_> = PgVectorBuilder::default_fields() .iter() .map(|f| f.field_name().to_string()) + .chain( + self.fields + .iter() + .filter(|f| matches!(f, FieldConfig::Metadata(_))) + .map(|f| f.field_name().to_string()), + ) .collect(); // Start building the SQL query @@ -89,7 +133,7 @@ impl Retrieve> for PgVector { .fetch_all(pool) .await?; - let docs = data.into_iter().map(|r| r.chunk).collect(); + let docs = data.into_iter().map(Into::into).collect(); Ok(query_state.retrieved_documents(docs)) } @@ -132,7 +176,7 @@ impl Retrieve>> for P .map_err(|e| anyhow!("Failed to execute search query: {}", e))?; // Transform results into documents - let documents = results.into_iter().map(|r| r.chunk).collect(); + let documents = results.into_iter().map(Into::into).collect(); // Update query state with retrieved documents Ok(query.retrieved_documents(documents)) @@ -212,4 +256,52 @@ mod tests { .unwrap(); assert_eq!(result.documents().len(), 0); } + + #[test_log::test(tokio::test)] + async fn test_retrieve_docs_with_metadata() { + let test_context = TestContext::setup_with_cfg( + vec!["other", "text"].into(), + HashSet::from([EmbeddedField::Combined]), + ) + .await + .expect("Test setup failed"); + + let nodes = vec![indexing::Node::new("test_query1") + .with_metadata([ + ("other", serde_json::Value::from(10)), + ("text", serde_json::Value::from("some text")), + ]) + .with_vectors([(EmbeddedField::Combined, vec![1.0; 384])]) + .to_owned()]; + + test_context + .pgv_storage + .batch_store(nodes) + .await + .try_collect::>() + .await + .unwrap(); + + let mut query = Query::::new("test_query"); + query.embedding = Some(vec![1.0; 384]); + + let search_strategy = SimilaritySingleEmbedding::<()>::default(); + let result = test_context + .pgv_storage + .retrieve(&search_strategy, query.clone()) + .await + .unwrap(); + + assert_eq!(result.documents().len(), 1); + + let doc = result.documents().first().unwrap(); + assert_eq!( + doc.metadata().get("other"), + Some(&serde_json::Value::from(10)) + ); + assert_eq!( + doc.metadata().get("text"), + Some(&serde_json::Value::from("some text")) + ); + } } diff --git a/swiftide-integrations/src/qdrant/retrieve.rs b/swiftide-integrations/src/qdrant/retrieve.rs index dc193868..890ee0f8 100644 --- a/swiftide-integrations/src/qdrant/retrieve.rs +++ b/swiftide-integrations/src/qdrant/retrieve.rs @@ -1,6 +1,7 @@ -use qdrant_client::qdrant::{self, PrefetchQueryBuilder, SearchPointsBuilder}; +use qdrant_client::qdrant::{self, PrefetchQueryBuilder, ScoredPoint, SearchPointsBuilder}; use swiftide_core::{ - indexing::EmbeddedField, + document::Document, + indexing::{EmbeddedField, Metadata}, prelude::{Result, *}, querying::{ search_strategies::{HybridSearch, SimilaritySingleEmbedding}, @@ -53,13 +54,7 @@ impl Retrieve> for Qdrant { let documents = result .into_iter() - .map(|scored_point| { - Ok(scored_point - .payload - .get("content") - .context("Expected document in qdrant payload")? - .to_string()) - }) + .map(scored_point_into_document) .collect::>>()?; Ok(query.retrieved_documents(documents)) @@ -133,22 +128,30 @@ impl Retrieve for Qdrant { let documents = result .into_iter() - .map(|scored_point| { - let value = scored_point - .payload - .get("content") - .context("Expected document in qdrant payload")?; - - Ok(value - .as_str() - .map_or_else(|| value.to_string(), ToString::to_string)) - }) + .map(scored_point_into_document) .collect::>>()?; Ok(query.retrieved_documents(documents)) } } +fn scored_point_into_document(scored_point: ScoredPoint) -> Result { + let content = scored_point + .payload + .get("content") + .context("Expected document in qdrant payload")? + .to_string(); + + let metadata: Metadata = scored_point + .payload + .into_iter() + .filter(|(k, _)| *k != "content") + .collect::>() + .into(); + + Ok(Document::new(content, Some(metadata))) +} + #[cfg(test)] mod tests { use itertools::Itertools as _; @@ -218,7 +221,12 @@ mod tests { .unwrap(); assert_eq!(result.documents().len(), 3); assert_eq!( - result.documents().iter().sorted().collect_vec(), + result + .documents() + .iter() + .sorted() + .map(Document::content) + .collect_vec(), // FIXME: The extra quotes should be removed by serde (via qdrant::Value), but they are // not ["\"test_query1\"", "\"test_query2\"", "\"test_query3\""] @@ -236,7 +244,12 @@ mod tests { .unwrap(); assert_eq!(result.documents().len(), 2); assert_eq!( - result.documents().iter().sorted().collect_vec(), + result + .documents() + .iter() + .sorted() + .map(Document::content) + .collect_vec(), ["\"test_query1\"", "\"test_query2\""] .into_iter() .sorted() diff --git a/swiftide-query/src/answers/simple.rs b/swiftide-query/src/answers/simple.rs index 26709a4b..67bbe767 100644 --- a/swiftide-query/src/answers/simple.rs +++ b/swiftide-query/src/answers/simple.rs @@ -7,6 +7,7 @@ //! as context instead. use std::sync::Arc; use swiftide_core::{ + document::Document, indexing::SimplePrompt, prelude::*, prompt::PromptTemplate, @@ -77,7 +78,12 @@ impl Answer for Simple { #[tracing::instrument(skip_all)] async fn answer(&self, query: Query) -> Result> { let context = if query.current().is_empty() { - &query.documents().join("\n---\n") + &query + .documents() + .iter() + .map(Document::content) + .collect::>() + .join("\n---\n") } else { query.current() }; diff --git a/swiftide-query/src/evaluators/ragas.rs b/swiftide-query/src/evaluators/ragas.rs index b5c5dfa1..3e2787a6 100644 --- a/swiftide-query/src/evaluators/ragas.rs +++ b/swiftide-query/src/evaluators/ragas.rs @@ -120,7 +120,11 @@ impl EvaluationDataSet { .get_mut(question) .ok_or_else(|| anyhow::anyhow!("Question not found"))?; - data.contexts = query.documents().to_vec(); + data.contexts = query + .documents() + .iter() + .map(|d| d.content().to_string()) + .collect::>(); Ok(()) } @@ -236,7 +240,7 @@ impl FromStr for EvaluationDataSet { mod tests { use super::*; use std::sync::Arc; - use swiftide_core::querying::{states, Query, QueryEvaluation}; + use swiftide_core::querying::{Query, QueryEvaluation}; use tokio::sync::RwLock; #[tokio::test] @@ -299,12 +303,7 @@ mod tests { let query = Query::builder() .original("What is Rust?") - .state( - states::RetrievedBuilder::default() - .documents(vec!["Rust is a language".to_string()]) - .build() - .unwrap(), - ) + .documents(vec!["Rust is a language".into()]) .build() .unwrap(); let evaluation = QueryEvaluation::RetrieveDocuments(query.clone()); @@ -325,12 +324,7 @@ mod tests { let query = Query::builder() .original("What is Rust?") - .state( - states::AnsweredBuilder::default() - .answer("A systems programming language") - .build() - .unwrap(), - ) + .current("A systems programming language") .build() .unwrap(); let evaluation = QueryEvaluation::AnswerQuery(query.clone()); @@ -372,12 +366,7 @@ mod tests { let query = Query::builder() .original("What is Rust?") - .state( - states::RetrievedBuilder::default() - .documents(vec!["Rust is a language".to_string()]) - .build() - .unwrap(), - ) + .documents(vec!["Rust is a language".into()]) .build() .unwrap(); dataset @@ -394,12 +383,7 @@ mod tests { let query = Query::builder() .original("What is Rust?") - .state( - states::AnsweredBuilder::default() - .answer("A systems programming language") - .build() - .unwrap(), - ) + .current("A systems programming language") .build() .unwrap(); dataset diff --git a/swiftide-query/src/response_transformers/summary.rs b/swiftide-query/src/response_transformers/summary.rs index 7e5bdc94..621c6bc3 100644 --- a/swiftide-query/src/response_transformers/summary.rs +++ b/swiftide-query/src/response_transformers/summary.rs @@ -60,7 +60,7 @@ fn default_prompt() -> PromptTemplate { {% for document in documents -%} --- - {{ document }} + {{ document.content }} --- {% endfor -%} " @@ -91,7 +91,9 @@ impl TransformResponse for Summary { #[cfg(test)] mod test { + use swiftide_core::document::Document; + use super::*; - assert_default_prompt_snapshot!("documents" => vec!["First document", "Second Document"]); + assert_default_prompt_snapshot!("documents" => vec![Document::from("First document"), Document::from("Second Document")]); } diff --git a/swiftide/tests/lancedb.rs b/swiftide/tests/lancedb.rs index 5202c5d0..f2f1b583 100644 --- a/swiftide/tests/lancedb.rs +++ b/swiftide/tests/lancedb.rs @@ -3,7 +3,7 @@ use swiftide::indexing::{ transformers::{metadata_qa_code::NAME as METADATA_QA_CODE_NAME, ChunkCode, MetadataQACode}, EmbeddedField, }; -use swiftide::query::{self, states, Query, TransformationEvent}; +use swiftide::query::{self, states, Query}; use swiftide_indexing::{loaders, transformers, Pipeline}; use swiftide_integrations::{fastembed::FastEmbed, lancedb::LanceDB}; use swiftide_query::{answers, query_transformers, response_transformers}; @@ -33,6 +33,7 @@ async fn test_lancedb() { .with_vector(EmbeddedField::Combined) .with_metadata(METADATA_QA_CODE_NAME) .with_metadata("filter") + .with_metadata("path") .table_name("swiftide_test") .build() .unwrap(); @@ -41,8 +42,10 @@ async fn test_lancedb() { .then_chunk(ChunkCode::try_for_language("rust").unwrap()) .then(MetadataQACode::new(openai_client.clone())) .then(|mut node: indexing::Node| { + // Add path to metadata, by default, storage will store all metadata fields node.metadata - .insert("filter".to_string(), "true".to_string()); + .insert("path", node.path.display().to_string()); + node.metadata.insert("filter", "true"); Ok(node) }) .then_in_batch(transformers::Embed::new(fastembed.clone()).with_batch_size(20)) @@ -75,17 +78,12 @@ async fn test_lancedb() { result.answer(), "\n\nHello there, how may I assist you today?" ); - let TransformationEvent::Retrieved { documents, .. } = result - .history() - .iter() - .find(|e| matches!(e, TransformationEvent::Retrieved { .. })) - .unwrap() - else { - panic!("No documents found") - }; + + let retrieved_document = result.documents().first().unwrap(); + assert_eq!(retrieved_document.content(), code); assert_eq!( - documents.first().unwrap(), - "fn main() { println!(\"Hello, World!\"); }" + retrieved_document.metadata().get("path").unwrap(), + codefile.to_str().unwrap() ); } diff --git a/swiftide/tests/pgvector.rs b/swiftide/tests/pgvector.rs index d9cf6b35..391a151d 100644 --- a/swiftide/tests/pgvector.rs +++ b/swiftide/tests/pgvector.rs @@ -2,6 +2,8 @@ //! The tests validate the functionality of the pipeline, ensuring that data is correctly indexed //! and processed from temporary files, database configurations, and simulated environments. +use swiftide_core::document::Document; +use swiftide_integrations::treesitter::metadata_qa_code; use temp_dir::TempDir; use anyhow::{anyhow, Result}; @@ -18,10 +20,7 @@ use swiftide::{ self, pgvector::{FieldConfig, PgVector, PgVectorBuilder, VectorConfig}, }, - query::{ - self, answers, query_transformers, response_transformers, states, Query, - TransformationEvent, - }, + query::{self, answers, query_transformers, response_transformers, states, Query}, }; use swiftide_test_utils::{mock_chat_completions, openai_client}; use wiremock::MockServer; @@ -218,19 +217,21 @@ async fn test_pgvector_retrieve() { result.answer(), "\n\nHello there, how may I assist you today?" ); - let TransformationEvent::Retrieved { documents, .. } = result - .history() - .iter() - .find(|e| matches!(e, TransformationEvent::Retrieved { .. })) - .unwrap() - else { - panic!("No documents found") - }; - assert_eq!( - documents.first().unwrap(), - "fn main() { println!(\"Hello, World!\"); }" - ); + let first_document = result.documents().first().unwrap(); + + let expected = Document::builder() + .content("fn main() { println!(\"Hello, World!\"); }") + .metadata([ + ( + metadata_qa_code::NAME, + "\n\nHello there, how may I assist you today?", + ), + ("filter", "true"), + ]) + .build() + .unwrap(); + assert_eq!(first_document, &expected); } /// Tests the dynamic vector similarity search functionality using PostgreSQL. @@ -393,17 +394,12 @@ async fn test_pgvector_retrieve_dynamic_search() { "\n\nHello there, how may I assist you today?" ); - let TransformationEvent::Retrieved { documents, .. } = result - .history() - .iter() - .find(|e| matches!(e, TransformationEvent::Retrieved { .. })) - .unwrap() - else { - panic!("No documents found") - }; + let first_document = result.documents().first().unwrap(); - assert_eq!( - documents.first().unwrap(), - "fn main() { println!(\"Hello, World!\"); }" - ); + // The custom query explicitly skipped metadata + let expected = Document::builder() + .content("fn main() { println!(\"Hello, World!\"); }") + .build() + .unwrap(); + assert_eq!(first_document, &expected); }