From 3751f49201c71398144a8913a4443f452534def2 Mon Sep 17 00:00:00 2001 From: RK Date: Wed, 4 Dec 2024 14:46:55 +0530 Subject: [PATCH] feat(query): Add support for single embedding retrieval with PGVector (#406) --- Cargo.lock | 1 + Cargo.toml | 6 +- examples/index_md_into_pgvector.rs | 65 +++- .../src/pgvector/fixtures.rs | 1 + swiftide-integrations/src/pgvector/mod.rs | 289 +++++++++--------- .../src/pgvector/pgv_table_types.rs | 16 +- .../src/pgvector/retrieve.rs | 184 +++++++++++ swiftide/Cargo.toml | 1 + swiftide/tests/pgvector.rs | 232 ++++++++++++++ 9 files changed, 647 insertions(+), 148 deletions(-) create mode 100644 swiftide-integrations/src/pgvector/retrieve.rs create mode 100644 swiftide/tests/pgvector.rs diff --git a/Cargo.lock b/Cargo.lock index 00c70c52..46ad881c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8398,6 +8398,7 @@ dependencies = [ "qdrant-client", "serde", "serde_json", + "sqlx", "swiftide-core", "swiftide-indexing", "swiftide-integrations", diff --git a/Cargo.toml b/Cargo.toml index c5ceec9e..485613c4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -53,10 +53,7 @@ arrow-array = { version = "52.2", default-features = false } arrow = { version = "52.2", default-features = false } parquet = { version = "52.2", default-features = false, features = ["async"] } redb = { version = "2.2" } -sqlx = { version = "0.8.2", features = [ - "postgres", - "uuid", -], default-features = false } +sqlx = { version = "0.8.2", features = ["postgres", "uuid"] } aws-config = "1.5" pgvector = { version = "0.4.0", features = ["sqlx"], default-features = false } aws-credential-types = "1.2" @@ -87,6 +84,7 @@ tree-sitter-ruby = "0.23" tree-sitter-rust = "0.23" tree-sitter-typescript = "0.23" + # Testing test-log = "0.2.16" testcontainers = { version = "0.23.0", features = ["http_wait"] } diff --git a/examples/index_md_into_pgvector.rs b/examples/index_md_into_pgvector.rs index a8dfc5c1..498d42b3 100644 --- a/examples/index_md_into_pgvector.rs +++ b/examples/index_md_into_pgvector.rs @@ -11,9 +11,46 @@ use swiftide::{ }, EmbeddedField, }, - integrations::{self, pgvector::PgVector}, + integrations::{self, fastembed::FastEmbed, pgvector::PgVector}, + query::{self, answers, query_transformers, response_transformers}, + traits::SimplePrompt, }; +async fn ask_query( + llm_client: impl SimplePrompt + Clone + 'static, + embed: FastEmbed, + vector_store: PgVector, + questions: Vec, +) -> Result, Box> { + // By default the search strategy is SimilaritySingleEmbedding + // which takes the latest query, embeds it, and does a similarity search + // + // Pgvector will return an error if multiple embeddings are set + // + // The pipeline generates subquestions to increase semantic coverage, embeds these in a single + // embedding, retrieves the default top_k documents, summarizes them and uses that as context + // for the final answer. + let pipeline = query::Pipeline::default() + .then_transform_query(query_transformers::GenerateSubquestions::from_client( + llm_client.clone(), + )) + .then_transform_query(query_transformers::Embed::from_client(embed)) + .then_retrieve(vector_store.clone()) + .then_transform_response(response_transformers::Summary::from_client( + llm_client.clone(), + )) + .then_answer(answers::Simple::from_client(llm_client.clone())); + + let results: Vec = pipeline + .query_all(questions) + .await? + .iter() + .map(|result| result.answer().to_string()) + .collect(); + + Ok(results) +} + #[tokio::main] async fn main() -> Result<(), Box> { tracing_subscriber::fmt::init(); @@ -62,6 +99,7 @@ async fn main() -> Result<(), Box> { } tracing::info!("Starting indexing pipeline"); + indexing::Pipeline::from_loader(FileLoader::new(test_dataset_path).with_extensions(&["md"])) .then_chunk(ChunkMarkdown::from_chunk_range(10..2048)) .then(MetadataQAText::new(llm_client.clone())) @@ -70,6 +108,29 @@ async fn main() -> Result<(), Box> { .run() .await?; - tracing::info!("PgVector Indexing test completed successfully"); + tracing::info!("PgVector Indexing completed successfully"); + + let questions: Vec = vec![ + "What is SwiftIDE? Provide a clear, comprehensive summary in under 50 words.".into(), + "How can I use SwiftIDE to connect with the Ethereum blockchain? Please provide a concise, comprehensive summary in less than 50 words.".into(), + ]; + + ask_query( + llm_client.clone(), + fastembed.clone(), + pgv_storage.clone(), + questions, + ) + .await? + .iter() + .enumerate() + .for_each(|(i, result)| { + tracing::info!("*** Answer Q{} ***", i + 1); + tracing::info!("{}", result); + tracing::info!("===X==="); + }); + + tracing::info!("PgVector Indexing & retrieval test completed successfully"); + Ok(()) } diff --git a/swiftide-integrations/src/pgvector/fixtures.rs b/swiftide-integrations/src/pgvector/fixtures.rs index 22c31e4b..65331d4e 100644 --- a/swiftide-integrations/src/pgvector/fixtures.rs +++ b/swiftide-integrations/src/pgvector/fixtures.rs @@ -77,6 +77,7 @@ pub(crate) struct PgVectorTestData<'a> { pub metadata: Option, /// Vector embeddings with their corresponding fields pub vectors: Vec<(indexing::EmbeddedField, Vec)>, + pub expected_in_results: bool, } impl PgVectorTestData<'_> { diff --git a/swiftide-integrations/src/pgvector/mod.rs b/swiftide-integrations/src/pgvector/mod.rs index 51ffcf8c..e1882934 100644 --- a/swiftide-integrations/src/pgvector/mod.rs +++ b/swiftide-integrations/src/pgvector/mod.rs @@ -27,6 +27,7 @@ mod fixtures; mod persist; mod pgv_table_types; +mod retrieve; use anyhow::Result; use derive_builder::Builder; use sqlx::PgPool; @@ -188,10 +189,134 @@ mod tests { use std::collections::HashSet; use swiftide_core::{ indexing::{self, EmbedMode, EmbeddedField}, - Persist, + querying::{search_strategies::SimilaritySingleEmbedding, states, Query}, + Persist, Retrieve, }; use test_case::test_case; + #[test_log::test(tokio::test)] + async fn test_metadata_filter_with_vector_search() { + let test_context = TestContext::setup_with_cfg( + vec!["category", "priority"].into(), + HashSet::from([EmbeddedField::Combined]), + ) + .await + .expect("Test setup failed"); + + // Create nodes with different metadata and vectors + let nodes = vec![ + indexing::Node::new("content1") + .with_vectors([(EmbeddedField::Combined, vec![1.0; 384])]) + .with_metadata(vec![("category", "A"), ("priority", "1")]), + indexing::Node::new("content2") + .with_vectors([(EmbeddedField::Combined, vec![1.1; 384])]) + .with_metadata(vec![("category", "A"), ("priority", "2")]), + indexing::Node::new("content3") + .with_vectors([(EmbeddedField::Combined, vec![1.2; 384])]) + .with_metadata(vec![("category", "B"), ("priority", "1")]), + ] + .into_iter() + .map(|node| node.to_owned()) + .collect(); + + // Store all nodes + test_context + .pgv_storage + .batch_store(nodes) + .await + .try_collect::>() + .await + .unwrap(); + + // Test combined metadata and vector search + let mut query = Query::::new("test_query"); + query.embedding = Some(vec![1.0; 384]); + + let search_strategy = + SimilaritySingleEmbedding::from_filter("category = \"A\"".to_string()); + + let result = test_context + .pgv_storage + .retrieve(&search_strategy, query.clone()) + .await + .unwrap(); + + assert_eq!(result.documents().len(), 2); + + assert!(result.documents().contains(&"content1".to_string())); + assert!(result.documents().contains(&"content2".to_string())); + + // Additional test with priority filter + let search_strategy = + SimilaritySingleEmbedding::from_filter("priority = \"1\"".to_string()); + let result = test_context + .pgv_storage + .retrieve(&search_strategy, query) + .await + .unwrap(); + + assert_eq!(result.documents().len(), 2); + assert!(result.documents().contains(&"content1".to_string())); + assert!(result.documents().contains(&"content3".to_string())); + } + + #[test_log::test(tokio::test)] + async fn test_vector_similarity_search_accuracy() { + let test_context = TestContext::setup_with_cfg( + vec!["category", "priority"].into(), + HashSet::from([EmbeddedField::Combined]), + ) + .await + .expect("Test setup failed"); + + // Create nodes with known vector relationships + let base_vector = vec![1.0; 384]; + let similar_vector = base_vector.iter().map(|x| x + 0.1).collect::>(); + let dissimilar_vector = vec![-1.0; 384]; + + let nodes = vec![ + indexing::Node::new("base_content") + .with_vectors([(EmbeddedField::Combined, base_vector)]) + .with_metadata(vec![("category", "A"), ("priority", "1")]), + indexing::Node::new("similar_content") + .with_vectors([(EmbeddedField::Combined, similar_vector)]) + .with_metadata(vec![("category", "A"), ("priority", "2")]), + indexing::Node::new("dissimilar_content") + .with_vectors([(EmbeddedField::Combined, dissimilar_vector)]) + .with_metadata(vec![("category", "B"), ("priority", "1")]), + ] + .into_iter() + .map(|node| node.to_owned()) + .collect(); + + // Store all nodes + test_context + .pgv_storage + .batch_store(nodes) + .await + .try_collect::>() + .await + .unwrap(); + + // Search with base vector + let mut query = Query::::new("test_query"); + query.embedding = Some(vec![1.0; 384]); + + let mut search_strategy = SimilaritySingleEmbedding::<()>::default(); + search_strategy.with_top_k(2); + + let result = test_context + .pgv_storage + .retrieve(&search_strategy, query) + .await + .unwrap(); + + // Verify that similar vectors are retrieved first + assert_eq!(result.documents().len(), 2); + assert!(result.documents().contains(&"base_content".to_string())); + assert!(result.documents().contains(&"similar_content".to_string())); + } + #[test_case( // SingleWithMetadata - No Metadata vec![ @@ -200,12 +325,14 @@ mod tests { chunk: "single_no_meta_1", metadata: None, vectors: vec![PgVectorTestData::create_test_vector(EmbeddedField::Combined, 1.0)], + expected_in_results: true, }, PgVectorTestData { embed_mode: EmbedMode::SingleWithMetadata, chunk: "single_no_meta_2", metadata: None, vectors: vec![PgVectorTestData::create_test_vector(EmbeddedField::Combined, 1.1)], + expected_in_results: true, } ], HashSet::from([EmbeddedField::Combined]) @@ -221,6 +348,7 @@ mod tests { ("priority", "high") ].into()), vectors: vec![PgVectorTestData::create_test_vector(EmbeddedField::Combined, 1.2)], + expected_in_results: true, }, PgVectorTestData { embed_mode: EmbedMode::SingleWithMetadata, @@ -230,144 +358,11 @@ mod tests { ("priority", "low") ].into()), vectors: vec![PgVectorTestData::create_test_vector(EmbeddedField::Combined, 1.3)], + expected_in_results: true, } ], HashSet::from([EmbeddedField::Combined]) ; "SingleWithMetadata mode with metadata")] - #[test_case( - // PerField - No Metadata - vec![ - PgVectorTestData { - embed_mode: EmbedMode::PerField, - chunk: "per_field_no_meta_1", - metadata: None, - vectors: vec![ - PgVectorTestData::create_test_vector(EmbeddedField::Chunk, 1.2), - PgVectorTestData::create_test_vector(EmbeddedField::Metadata("category".into()), 2.2), - PgVectorTestData::create_test_vector(EmbeddedField::Metadata("priority".into()), 3.2), - ], - }, - PgVectorTestData { - embed_mode: EmbedMode::PerField, - chunk: "per_field_no_meta_2", - metadata: None, - vectors: vec![ - PgVectorTestData::create_test_vector(EmbeddedField::Chunk, 1.3), - PgVectorTestData::create_test_vector(EmbeddedField::Metadata("category".into()), 2.3), - PgVectorTestData::create_test_vector(EmbeddedField::Metadata("priority".into()), 3.3), - ], - } - ], - HashSet::from([ - EmbeddedField::Chunk, - EmbeddedField::Metadata("category".into()), - EmbeddedField::Metadata("priority".into()), - ]) - ; "PerField mode without metadata")] - #[test_case( - // PerField - With Metadata - vec![ - PgVectorTestData { - embed_mode: EmbedMode::PerField, - chunk: "single_with_meta_1", - metadata: Some(vec![ - ("category", "A"), - ("priority", "high") - ].into()), - vectors: vec![ - PgVectorTestData::create_test_vector(EmbeddedField::Chunk, 1.2), - PgVectorTestData::create_test_vector(EmbeddedField::Metadata("category".into()), 2.2), - PgVectorTestData::create_test_vector(EmbeddedField::Metadata("priority".into()), 3.2), - ], - }, - PgVectorTestData { - embed_mode: EmbedMode::PerField, - chunk: "single_with_meta_2", - metadata: Some(vec![ - ("category", "B"), - ("priority", "low") - ].into()), - vectors: vec![ - PgVectorTestData::create_test_vector(EmbeddedField::Chunk, 1.3), - PgVectorTestData::create_test_vector(EmbeddedField::Metadata("category".into()), 2.3), - PgVectorTestData::create_test_vector(EmbeddedField::Metadata("priority".into()), 3.3), - ], - } - ], - HashSet::from([ - EmbeddedField::Chunk, - EmbeddedField::Metadata("category".into()), - EmbeddedField::Metadata("priority".into()), - ]) - ; "PerField mode with metadata")] - #[test_case( - // Both - No Metadata - vec![ - PgVectorTestData { - embed_mode: EmbedMode::Both, - chunk: "both_no_meta_1", - metadata: None, - vectors: vec![ - PgVectorTestData::create_test_vector(EmbeddedField::Combined, 3.0), - PgVectorTestData::create_test_vector(EmbeddedField::Chunk, 3.1) - ], - }, - PgVectorTestData { - embed_mode: EmbedMode::Both, - chunk: "both_no_meta_2", - metadata: None, - vectors: vec![ - PgVectorTestData::create_test_vector(EmbeddedField::Combined, 3.2), - PgVectorTestData::create_test_vector(EmbeddedField::Chunk, 3.3) - ], - } - ], - HashSet::from([EmbeddedField::Combined, EmbeddedField::Chunk]) - ; "Both mode without metadata")] - #[test_case( - // Both - With Metadata - vec![ - PgVectorTestData { - embed_mode: EmbedMode::Both, - chunk: "both_with_meta_1", - metadata: Some(vec![ - ("category", "P"), - ("priority", "urgent"), - ("tag", "test1") - ].into()), - vectors: vec![ - PgVectorTestData::create_test_vector(EmbeddedField::Combined, 3.4), - PgVectorTestData::create_test_vector(EmbeddedField::Chunk, 3.5), - PgVectorTestData::create_test_vector(EmbeddedField::Metadata("category".into()), 3.6), - PgVectorTestData::create_test_vector(EmbeddedField::Metadata("priority".into()), 3.7), - PgVectorTestData::create_test_vector(EmbeddedField::Metadata("tag".into()), 3.8) - ], - }, - PgVectorTestData { - embed_mode: EmbedMode::Both, - chunk: "both_with_meta_2", - metadata: Some(vec![ - ("category", "Q"), - ("priority", "low"), - ("tag", "test2") - ].into()), - vectors: vec![ - PgVectorTestData::create_test_vector(EmbeddedField::Combined, 3.9), - PgVectorTestData::create_test_vector(EmbeddedField::Chunk, 4.0), - PgVectorTestData::create_test_vector(EmbeddedField::Metadata("category".into()), 4.1), - PgVectorTestData::create_test_vector(EmbeddedField::Metadata("priority".into()), 4.2), - PgVectorTestData::create_test_vector(EmbeddedField::Metadata("tag".into()), 4.3) - ], - } - ], - HashSet::from([ - EmbeddedField::Combined, - EmbeddedField::Chunk, - EmbeddedField::Metadata("category".into()), - EmbeddedField::Metadata("priority".into()), - EmbeddedField::Metadata("tag".into()), - ]) - ; "Both mode with metadata")] #[test_log::test(tokio::test)] async fn test_persist_nodes( test_cases: Vec>, @@ -405,7 +400,7 @@ mod tests { "All nodes should be stored" ); - // Verify storage for each test case + // Verify storage and retrieval for each test case for (test_case, stored_node) in test_cases.iter().zip(stored_nodes.iter()) { // 1. Verify basic node properties assert_eq!( @@ -427,6 +422,28 @@ mod tests { test_case.vectors.len(), "Vector count should match" ); + + // 3. Test vector similarity search + for (field, vector) in &test_case.vectors { + let mut query = Query::::new("test_query"); + query.embedding = Some(vector.clone()); + + let mut search_strategy = SimilaritySingleEmbedding::<()>::default(); + search_strategy.with_top_k(nodes.len() as u64); + + let result = test_context + .pgv_storage + .retrieve(&search_strategy, query) + .await + .expect("Retrieval should succeed"); + + if test_case.expected_in_results { + assert!( + result.documents().contains(&test_case.chunk.to_string()), + "Document should be found in results for field {field}", + ); + } + } } } } diff --git a/swiftide-integrations/src/pgvector/pgv_table_types.rs b/swiftide-integrations/src/pgvector/pgv_table_types.rs index f49f9211..c681d6c6 100644 --- a/swiftide-integrations/src/pgvector/pgv_table_types.rs +++ b/swiftide-integrations/src/pgvector/pgv_table_types.rs @@ -13,6 +13,7 @@ use regex::Regex; use sqlx::postgres::PgArguments; use sqlx::postgres::PgPoolOptions; use sqlx::PgPool; +use std::collections::BTreeMap; use swiftide_core::indexing::{EmbeddedField, Node}; use tokio::time::sleep; @@ -23,7 +24,7 @@ use tokio::time::sleep; #[derive(Clone, Debug)] pub struct VectorConfig { embedded_field: EmbeddedField, - field: String, + pub(crate) field: String, } impl VectorConfig { @@ -75,7 +76,7 @@ impl> From for MetadataConfig { /// Represents different field types that can be configured in the table schema, /// including vector embeddings, metadata, and system fields. #[derive(Clone, Debug)] -pub enum FieldConfig { +pub(crate) enum FieldConfig { /// `Vector` - Vector embedding field configuration Vector(VectorConfig), /// `Metadata` - Metadata field configuration @@ -258,8 +259,6 @@ impl PgVector { .get() .ok_or_else(|| anyhow!("SQL bulk insert statement not set"))?; - tracing::info!("Sql statement :: {:#?}", sql); - let query = self.bind_bulk_data_to_query(sqlx::query(sql), &bulk_data)?; query @@ -293,7 +292,10 @@ impl PgVector { .get(&config.original_field) .ok_or_else(|| anyhow!("Missing metadata field"))?; - bulk_data.metadata_fields[idx].push(value.clone()); + let mut metadata_map = BTreeMap::new(); + metadata_map.insert(config.original_field.clone(), value.clone()); + + bulk_data.metadata_fields[idx].push(serde_json::to_value(metadata_map)?); } FieldConfig::Vector(config) => { let idx = bulk_data @@ -431,7 +433,9 @@ impl PgVector { match vector_fields.as_slice() { [field] => Ok(field.field_name().to_string()), [] => Err(anyhow!("No vector field configured in schema")), - _ => Err(anyhow!("Multiple vector fields configured in schema")), + _ => Err(anyhow!( + "Search strategy for multiple vector fields in the schema is not yet implemented" + )), } } } diff --git a/swiftide-integrations/src/pgvector/retrieve.rs b/swiftide-integrations/src/pgvector/retrieve.rs new file mode 100644 index 00000000..f2ac3d60 --- /dev/null +++ b/swiftide-integrations/src/pgvector/retrieve.rs @@ -0,0 +1,184 @@ +use crate::pgvector::{PgVector, PgVectorBuilder}; +use anyhow::{anyhow, Result}; +use async_trait::async_trait; +use pgvector::Vector; +use sqlx::{prelude::FromRow, types::Uuid}; +use swiftide_core::{ + querying::{search_strategies::SimilaritySingleEmbedding, states, Query}, + Retrieve, +}; + +#[allow(dead_code)] +#[derive(Debug, Clone, FromRow)] +struct VectorSearchResult { + id: Uuid, + chunk: String, +} + +#[allow(clippy::redundant_closure_for_method_calls)] +#[async_trait] +impl Retrieve> for PgVector { + #[tracing::instrument] + async fn retrieve( + &self, + search_strategy: &SimilaritySingleEmbedding, + query_state: Query, + ) -> Result> { + let embedding = if let Some(embedding) = query_state.embedding.as_ref() { + Vector::from(embedding.clone()) + } else { + return Err(anyhow::Error::msg("Missing embedding in query state")); + }; + + let vector_column_name = self.get_vector_column_name()?; + + let pool = self.pool_get_or_initialize().await?; + + let default_columns: Vec<_> = PgVectorBuilder::default_fields() + .iter() + .map(|f| f.field_name().to_string()) + .collect(); + + // Start building the SQL query + let mut sql = format!( + "SELECT {} FROM {}", + default_columns.join(", "), + self.table_name + ); + + if let Some(filter) = search_strategy.filter() { + let filter_parts: Vec<&str> = filter.split('=').collect(); + if filter_parts.len() == 2 { + let key = filter_parts[0].trim(); + let value = filter_parts[1].trim().trim_matches('"'); + tracing::debug!( + "Filter being applied: key = {:#?}, value = {:#?}", + key, + value + ); + + let sql_filter = format!( + " WHERE meta_{}->>'{}' = '{}'", + PgVector::normalize_field_name(key), + key, + value + ); + sql.push_str(&sql_filter); + } else { + return Err(anyhow!("Invalid filter format")); + } + } + + // Add the ORDER BY clause for vector similarity search + sql.push_str(&format!( + " ORDER BY {} <=> $1 LIMIT $2", + &vector_column_name + )); + + tracing::debug!("Running retrieve with SQL: {}", sql); + + let top_k = i32::try_from(search_strategy.top_k()) + .map_err(|_| anyhow!("Failed to convert top_k to i32"))?; + + let data: Vec = sqlx::query_as(&sql) + .bind(embedding) + .bind(top_k) + .fetch_all(pool) + .await?; + + let docs = data.into_iter().map(|r| r.chunk).collect(); + + Ok(query_state.retrieved_documents(docs)) + } +} + +#[async_trait] +impl Retrieve for PgVector { + async fn retrieve( + &self, + search_strategy: &SimilaritySingleEmbedding, + query: Query, + ) -> Result> { + Retrieve::>::retrieve( + self, + &search_strategy.into_concrete_filter::(), + query, + ) + .await + } +} + +#[cfg(test)] +mod tests { + use crate::pgvector::fixtures::TestContext; + use futures_util::TryStreamExt; + use std::collections::HashSet; + use swiftide_core::{indexing, indexing::EmbeddedField, Persist}; + use swiftide_core::{ + querying::{search_strategies::SimilaritySingleEmbedding, states, Query}, + Retrieve, + }; + + #[test_log::test(tokio::test)] + async fn test_retrieve_multiple_docs_and_filter() { + let test_context = TestContext::setup_with_cfg( + vec!["filter"].into(), + HashSet::from([EmbeddedField::Combined]), + ) + .await + .expect("Test setup failed"); + + let nodes = vec![ + indexing::Node::new("test_query1").with_metadata(("filter", "true")), + indexing::Node::new("test_query2").with_metadata(("filter", "true")), + indexing::Node::new("test_query3").with_metadata(("filter", "false")), + ] + .into_iter() + .map(|node| { + node.with_vectors([(EmbeddedField::Combined, vec![1.0; 384])]); + node.to_owned() + }) + .collect(); + + test_context + .pgv_storage + .batch_store(nodes) + .await + .try_collect::>() + .await + .unwrap(); + + let mut query = Query::::new("test_query"); + query.embedding = Some(vec![1.0; 384]); + + let search_strategy = SimilaritySingleEmbedding::<()>::default(); + let result = test_context + .pgv_storage + .retrieve(&search_strategy, query.clone()) + .await + .unwrap(); + + assert_eq!(result.documents().len(), 3); + + let search_strategy = + SimilaritySingleEmbedding::from_filter("filter = \"true\"".to_string()); + + let result = test_context + .pgv_storage + .retrieve(&search_strategy, query.clone()) + .await + .unwrap(); + + assert_eq!(result.documents().len(), 2); + + let search_strategy = + SimilaritySingleEmbedding::from_filter("filter = \"banana\"".to_string()); + + let result = test_context + .pgv_storage + .retrieve(&search_strategy, query.clone()) + .await + .unwrap(); + assert_eq!(result.documents().len(), 0); + } +} diff --git a/swiftide/Cargo.toml b/swiftide/Cargo.toml index 766ee35d..242a15f2 100644 --- a/swiftide/Cargo.toml +++ b/swiftide/Cargo.toml @@ -109,6 +109,7 @@ serde = { workspace = true } serde_json = { workspace = true } tokio = { workspace = true } arrow-array = { workspace = true } +sqlx = { workspace = true } [lints] workspace = true diff --git a/swiftide/tests/pgvector.rs b/swiftide/tests/pgvector.rs new file mode 100644 index 00000000..2efee6a3 --- /dev/null +++ b/swiftide/tests/pgvector.rs @@ -0,0 +1,232 @@ +//! This module contains tests for the `PgVector` indexing pipeline in the Swiftide project. +//! The tests validate the functionality of the pipeline, ensuring that data is correctly indexed +//! and processed from temporary files, database configurations, and simulated environments. + +use temp_dir::TempDir; + +use sqlx::{prelude::FromRow, types::Uuid}; +use swiftide::{ + indexing::{ + self, loaders, + transformers::{ + self, metadata_qa_code::NAME as METADATA_QA_CODE_NAME, ChunkCode, MetadataQACode, + }, + EmbeddedField, Pipeline, + }, + integrations::{self, pgvector::PgVector}, + query::{ + self, answers, query_transformers, response_transformers, states, Query, + TransformationEvent, + }, +}; +use swiftide_test_utils::{mock_chat_completions, openai_client}; +use wiremock::MockServer; + +#[allow(dead_code)] +#[derive(Debug, Clone, FromRow)] +struct VectorSearchResult { + id: Uuid, + chunk: String, +} + +/// Test case for verifying the PgVector indexing pipeline functionality. +/// +/// This test: +/// - Sets up a temporary file and Postgres database for testing. +/// - Configures a PgVector instance with a vector size of 384. +/// - Executes an indexing pipeline for Rust code chunks with embedded vector metadata. +/// - Performs a similarity-based vector search on the database and validates the retrieved results. +/// +/// Ensures correctness of end-to-end data flow, including table management, vector storage, and query execution. +#[test_log::test(tokio::test)] +async fn test_pgvector_indexing() { + // Setup temporary directory and file for testing + let tempdir = TempDir::new().unwrap(); + let codefile = tempdir.child("main.rs"); + let code = "fn main() { println!(\"Hello, World!\"); }"; + std::fs::write(&codefile, code).unwrap(); + + let (_pgv_db_container, pgv_db_url) = swiftide_test_utils::start_postgres().await; + + // Setup mock servers to simulate API responses + let mock_server = MockServer::start().await; + mock_chat_completions(&mock_server).await; + + // Configure Pgvector with a default vector size, a single embedding + // and in addition to embedding the text metadata, also store it in a field + let pgv_storage = PgVector::builder() + .db_url(pgv_db_url) + .vector_size(384) + .with_vector(EmbeddedField::Combined) + .table_name("swiftide_test") + .build() + .unwrap(); + + // Drop the existing test table before running the test + println!("Dropping existing test table & index if it exists"); + let drop_table_sql = "DROP TABLE IF EXISTS swiftide_test"; + let drop_index_sql = "DROP INDEX IF EXISTS swiftide_test_embedding_idx"; + + if let Ok(pool) = pgv_storage.get_pool().await { + sqlx::query(drop_table_sql) + .execute(pool) + .await + .expect("Failed to execute SQL query for dropping the table"); + sqlx::query(drop_index_sql) + .execute(pool) + .await + .expect("Failed to execute SQL query for dropping the index"); + } else { + panic!("Unable to acquire database connection pool"); + } + + let result = + Pipeline::from_loader(loaders::FileLoader::new(tempdir.path()).with_extensions(&["rs"])) + .then_chunk(ChunkCode::try_for_language("rust").unwrap()) + .then(|mut node: indexing::Node| { + node.with_vectors([(EmbeddedField::Combined, vec![1.0; 384])]); + Ok(node) + }) + .then_store_with(pgv_storage.clone()) + .run() + .await; + + result.expect("PgVector Named vectors test indexing pipeline failed"); + + let pool = pgv_storage + .get_pool() + .await + .expect("Unable to acquire database connection pool"); + + // Start building the SQL query + let sql_vector_query = + "SELECT id, chunk FROM swiftide_test ORDER BY vector_combined <=> $1::VECTOR LIMIT $2"; + + println!("Running retrieve with SQL: {sql_vector_query}"); + + let top_k: i32 = 10; + let embedding = vec![1.0; 384]; + + let data: Vec = sqlx::query_as(sql_vector_query) + .bind(embedding) + .bind(top_k) + .fetch_all(pool) + .await + .expect("Sql named vector query failed"); + + let docs: Vec<_> = data.into_iter().map(|r| r.chunk).collect(); + + println!("Retrieved documents for debugging: {docs:#?}"); + + assert_eq!(docs[0], "fn main() { println!(\"Hello, World!\"); }"); +} + +/// Test the retrieval functionality of `PgVector` integration. +/// +/// This test verifies that a Rust code snippet can be embedded, +/// stored in a PostgreSQL database using `PgVector`, and accurately +/// retrieved using a single similarity-based query pipeline. It sets up +/// a mock OpenAI client, configures `PgVector`, and executes a query +/// to ensure the pipeline retrieves the correct data and generates +/// an expected response. +#[test_log::test(tokio::test)] +async fn test_pgvector_retrieve() { + // Setup temporary directory and file for testing + let tempdir = TempDir::new().unwrap(); + let codefile = tempdir.child("main.rs"); + let code = "fn main() { println!(\"Hello, World!\"); }"; + std::fs::write(&codefile, code).unwrap(); + + let (_pgv_db_container, pgv_db_url) = swiftide_test_utils::start_postgres().await; + + // Setup mock servers to simulate API responses + let mock_server = MockServer::start().await; + mock_chat_completions(&mock_server).await; + + let openai_client = openai_client(&mock_server.uri(), "text-embedding-3-small", "gpt-4o"); + + let fastembed = + integrations::fastembed::FastEmbed::try_default().expect("Could not create FastEmbed"); + + // Configure Pgvector with a default vector size, a single embedding + // and in addition to embedding the text metadata, also store it in a field + let pgv_storage = PgVector::builder() + .db_url(pgv_db_url) + .vector_size(384) + .with_vector(EmbeddedField::Combined) + .with_metadata(METADATA_QA_CODE_NAME) + .with_metadata("filter") + .table_name("swiftide_test") + .build() + .unwrap(); + + // Drop the existing test table before running the test + println!("Dropping existing test table & index if it exists"); + let drop_table_sql = "DROP TABLE IF EXISTS swiftide_test"; + let drop_index_sql = "DROP INDEX IF EXISTS swiftide_test_embedding_idx"; + + if let Ok(pool) = pgv_storage.get_pool().await { + sqlx::query(drop_table_sql) + .execute(pool) + .await + .expect("Failed to execute SQL query for dropping the table"); + sqlx::query(drop_index_sql) + .execute(pool) + .await + .expect("Failed to execute SQL query for dropping the index"); + } else { + panic!("Unable to acquire database connection pool"); + } + + Pipeline::from_loader(loaders::FileLoader::new(tempdir.path()).with_extensions(&["rs"])) + .then_chunk(ChunkCode::try_for_language("rust").unwrap()) + .then(MetadataQACode::new(openai_client.clone())) + .then(|mut node: indexing::Node| { + node.metadata + .insert("filter".to_string(), "true".to_string()); + Ok(node) + }) + .then_in_batch(transformers::Embed::new(fastembed.clone()).with_batch_size(20)) + .log_nodes() + .then_store_with(pgv_storage.clone()) + .run() + .await + .unwrap(); + + let strategy = query::search_strategies::SimilaritySingleEmbedding::from_filter( + "filter = \"true\"".to_string(), + ); + + let query_pipeline = query::Pipeline::from_search_strategy(strategy) + .then_transform_query(query_transformers::GenerateSubquestions::from_client( + openai_client.clone(), + )) + .then_transform_query(query_transformers::Embed::from_client(fastembed.clone())) + .then_retrieve(pgv_storage.clone()) + .then_transform_response(response_transformers::Summary::from_client( + openai_client.clone(), + )) + .then_answer(answers::Simple::from_client(openai_client.clone())); + + let result: Query = query_pipeline.query("What is swiftide?").await.unwrap(); + + println!("{:#?}", &result); + + assert_eq!( + result.answer(), + "\n\nHello there, how may I assist you today?" + ); + let TransformationEvent::Retrieved { documents, .. } = result + .history() + .iter() + .find(|e| matches!(e, TransformationEvent::Retrieved { .. })) + .unwrap() + else { + panic!("No documents found") + }; + + assert_eq!( + documents.first().unwrap(), + "fn main() { println!(\"Hello, World!\"); }" + ); +}