diff --git a/swiftide-core/src/query.rs b/swiftide-core/src/query.rs index d9c8fce2..5e2e0b9a 100644 --- a/swiftide-core/src/query.rs +++ b/swiftide-core/src/query.rs @@ -7,7 +7,7 @@ //! `states::Answered`: The query has been answered use derive_builder::Builder; -use crate::{util::debug_long_utf8, AdvanceEmbedding, Embedding, SparseEmbedding}; +use crate::{util::debug_long_utf8, Embedding, SparseEmbedding}; type Document = String; @@ -34,9 +34,6 @@ pub struct Query { #[builder(default)] pub sparse_embedding: Option, - - #[builder(default)] - pub adv_embedding: Option, } impl std::fmt::Debug for Query { @@ -47,7 +44,6 @@ impl std::fmt::Debug for Query { .field("state", &self.state) .field("transformation_history", &self.transformation_history) .field("embedding", &self.embedding.is_some()) - .field("adv_embedding", &self.adv_embedding.is_some()) .finish() } } @@ -75,7 +71,6 @@ impl Query { transformation_history: self.transformation_history, embedding: self.embedding, sparse_embedding: self.sparse_embedding, - adv_embedding: self.adv_embedding, } } diff --git a/swiftide-core/src/type_aliases.rs b/swiftide-core/src/type_aliases.rs index 6c97bc6e..197c56b3 100644 --- a/swiftide-core/src/type_aliases.rs +++ b/swiftide-core/src/type_aliases.rs @@ -1,6 +1,5 @@ #![cfg_attr(coverage_nightly, coverage(off))] -use crate::indexing::EmbeddedField; use serde::{Deserialize, Serialize}; pub type Embedding = Vec; @@ -21,20 +20,3 @@ impl std::fmt::Debug for SparseEmbedding { .finish() } } - -#[derive(Serialize, Deserialize, Clone, PartialEq)] -pub struct AdvanceEmbedding { - pub embedded_field: EmbeddedField, - pub field_value: Vec, -} -pub type AdvanceEmbeddings = Vec; - -impl std::fmt::Debug for AdvanceEmbedding { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - // Start the debug struct formatting - f.debug_struct("AdvanceEmbedding") - .field("embedded_field", &self.embedded_field) - .field("field_value", &self.field_value) - .finish() - } -} diff --git a/swiftide-indexing/src/loaders/file_loader.rs b/swiftide-indexing/src/loaders/file_loader.rs index fff4e800..f05d4ff1 100644 --- a/swiftide-indexing/src/loaders/file_loader.rs +++ b/swiftide-indexing/src/loaders/file_loader.rs @@ -109,7 +109,7 @@ impl Loader for FileLoader { .filter(|entry| entry.file_type().is_some_and(|ft| ft.is_file())) .filter(move |entry| self.file_has_extension(entry.path())) .map(|entry| { - tracing::info!("Reading file: {:?}", entry); + tracing::debug!("Reading file: {:?}", entry); let content = std::fs::read_to_string(entry.path()).context("Failed to read file")?; let original_size = content.len(); diff --git a/swiftide-integrations/src/pgvector/fixtures.rs b/swiftide-integrations/src/pgvector/fixtures.rs index 9307c14d..46ea767d 100644 --- a/swiftide-integrations/src/pgvector/fixtures.rs +++ b/swiftide-integrations/src/pgvector/fixtures.rs @@ -78,7 +78,7 @@ pub(crate) struct PgVectorTestData<'a> { /// Vector embeddings with their corresponding fields pub vectors: Vec<(indexing::EmbeddedField, Vec)>, pub expected_in_results: bool, - pub use_adv_embedding_query: bool, + pub use_hybrid_search: bool, } impl<'a> PgVectorTestData<'a> { diff --git a/swiftide-integrations/src/pgvector/mod.rs b/swiftide-integrations/src/pgvector/mod.rs index f0eba79b..211ca4a7 100644 --- a/swiftide-integrations/src/pgvector/mod.rs +++ b/swiftide-integrations/src/pgvector/mod.rs @@ -189,8 +189,11 @@ mod tests { use std::collections::HashSet; use swiftide_core::{ indexing::{self, EmbedMode, EmbeddedField}, - querying::{search_strategies::SimilaritySingleEmbedding, states, Query}, - AdvanceEmbedding, Persist, Retrieve, + querying::{ + search_strategies::{HybridSearch, SimilaritySingleEmbedding}, + states, Query, + }, + Persist, Retrieve, }; use test_case::test_case; @@ -326,7 +329,7 @@ mod tests { metadata: None, vectors: vec![PgVectorTestData::create_test_vector(EmbeddedField::Combined, 1.0)], expected_in_results: true, - use_adv_embedding_query: false, + use_hybrid_search: false, }, PgVectorTestData { embed_mode: EmbedMode::SingleWithMetadata, @@ -334,7 +337,7 @@ mod tests { metadata: None, vectors: vec![PgVectorTestData::create_test_vector(EmbeddedField::Combined, 1.1)], expected_in_results: true, - use_adv_embedding_query: false, + use_hybrid_search: false, } ], HashSet::from([EmbeddedField::Combined]) @@ -351,7 +354,7 @@ mod tests { ].into()), vectors: vec![PgVectorTestData::create_test_vector(EmbeddedField::Combined, 1.2)], expected_in_results: true, - use_adv_embedding_query: false, + use_hybrid_search: false, }, PgVectorTestData { embed_mode: EmbedMode::SingleWithMetadata, @@ -362,7 +365,7 @@ mod tests { ].into()), vectors: vec![PgVectorTestData::create_test_vector(EmbeddedField::Combined, 1.3)], expected_in_results: true, - use_adv_embedding_query: false, + use_hybrid_search: false, } ], HashSet::from([EmbeddedField::Combined]) @@ -380,7 +383,7 @@ mod tests { PgVectorTestData::create_test_vector(EmbeddedField::Metadata("priority".into()), 3.2), ], expected_in_results: true, - use_adv_embedding_query: true, + use_hybrid_search: true, }, PgVectorTestData { embed_mode: EmbedMode::PerField, @@ -392,7 +395,7 @@ mod tests { PgVectorTestData::create_test_vector(EmbeddedField::Metadata("priority".into()), 3.3), ], expected_in_results: true, - use_adv_embedding_query: true, + use_hybrid_search: true, } ], HashSet::from([ @@ -417,7 +420,7 @@ mod tests { PgVectorTestData::create_test_vector(EmbeddedField::Metadata("priority".into()), 3.2), ], expected_in_results: true, - use_adv_embedding_query: true, + use_hybrid_search: true, }, PgVectorTestData { embed_mode: EmbedMode::PerField, @@ -432,7 +435,7 @@ mod tests { PgVectorTestData::create_test_vector(EmbeddedField::Metadata("priority".into()), 3.3), ], expected_in_results: true, - use_adv_embedding_query: true, + use_hybrid_search: true, } ], HashSet::from([ @@ -453,7 +456,7 @@ mod tests { PgVectorTestData::create_test_vector(EmbeddedField::Chunk, 3.1) ], expected_in_results: true, - use_adv_embedding_query: true, + use_hybrid_search: true, }, PgVectorTestData { embed_mode: EmbedMode::Both, @@ -464,7 +467,7 @@ mod tests { PgVectorTestData::create_test_vector(EmbeddedField::Chunk, 3.3) ], expected_in_results: true, - use_adv_embedding_query: true, + use_hybrid_search: true, } ], HashSet::from([EmbeddedField::Combined, EmbeddedField::Chunk]) @@ -488,7 +491,7 @@ mod tests { PgVectorTestData::create_test_vector(EmbeddedField::Metadata("tag".into()), 3.8) ], expected_in_results: true, - use_adv_embedding_query: true, + use_hybrid_search: true, }, PgVectorTestData { embed_mode: EmbedMode::Both, @@ -506,7 +509,7 @@ mod tests { PgVectorTestData::create_test_vector(EmbeddedField::Metadata("tag".into()), 4.3) ], expected_in_results: true, - use_adv_embedding_query: true, + use_hybrid_search: true, } ], HashSet::from([ @@ -581,25 +584,30 @@ mod tests { for (index, (field, vector)) in test_case.vectors.iter().enumerate() { tracing::warn!("Enter :: {:#?}!", index); let mut query = Query::::new("test_query"); - - if test_case.use_adv_embedding_query { - query.adv_embedding = Some(AdvanceEmbedding { - embedded_field: field.clone(), - field_value: vector.clone(), - }); + query.embedding = Some(vector.clone()); + + let result = if test_case.use_hybrid_search { + let mut search_strategy = HybridSearch::default(); + search_strategy + .with_dense_vector_field(field.clone()) + .with_top_n(nodes.len() as u64); + + test_context + .pgv_storage + .retrieve(&search_strategy, query) + .await + .expect("Retrieval should succeed") } else { - query.embedding = Some(vector.clone()); - } - - let mut search_strategy = SimilaritySingleEmbedding::<()>::default(); - search_strategy.with_top_k(nodes.len() as u64); - - let result = test_context - .pgv_storage - .retrieve(&search_strategy, query) - .await - .expect("Retrieval should succeed"); - + let mut search_strategy = SimilaritySingleEmbedding::<()>::default(); + search_strategy.with_top_k(nodes.len() as u64); + + test_context + .pgv_storage + .retrieve(&search_strategy, query) + .await + .expect("Retrieval should succeed") + }; + if test_case.expected_in_results { assert!( result.documents().contains(&test_case.chunk.to_string()), diff --git a/swiftide-integrations/src/pgvector/pgv_table_types.rs b/swiftide-integrations/src/pgvector/pgv_table_types.rs index a66acb5b..ac598476 100644 --- a/swiftide-integrations/src/pgvector/pgv_table_types.rs +++ b/swiftide-integrations/src/pgvector/pgv_table_types.rs @@ -437,7 +437,9 @@ impl PgVector { match vector_fields.as_slice() { [field] => Ok(field.field_name().to_string()), [] => Err(anyhow!("No vector field configured in schema")), - _ => Err(anyhow!("Multiple vector fields configured in schema")), + _ => Err(anyhow!( + "Multiple vector fields configured in schema use HybridSearch strategy" + )), } } } diff --git a/swiftide-integrations/src/pgvector/retrieve.rs b/swiftide-integrations/src/pgvector/retrieve.rs index 166085cc..36bfb963 100644 --- a/swiftide-integrations/src/pgvector/retrieve.rs +++ b/swiftide-integrations/src/pgvector/retrieve.rs @@ -4,7 +4,10 @@ use async_trait::async_trait; use pgvector::Vector; use sqlx::{prelude::FromRow, types::Uuid}; use swiftide_core::{ - querying::{search_strategies::SimilaritySingleEmbedding, states, Query}, + querying::{ + search_strategies::{HybridSearch, SimilaritySingleEmbedding}, + states, Query, + }, Retrieve, }; @@ -24,30 +27,14 @@ impl Retrieve> for PgVector { search_strategy: &SimilaritySingleEmbedding, query_state: Query, ) -> Result> { - let (vector_column_name, embedding) = match ( - query_state.embedding.as_ref(), - query_state.adv_embedding.as_ref(), - ) { - (Some(embed), None) => { - let vector_column_name = self.get_vector_column_name()?; - let embedding = Vector::from(embed.clone()); - (vector_column_name, embedding) - } - (None, Some(adv_embed)) => { - let vector_column_name = VectorConfig::from(adv_embed.embedded_field.clone()).field; - let embedding = Vector::from(adv_embed.field_value.clone()); - (vector_column_name, embedding) - } - (None, None) => { - return Err(anyhow!("No embedding found in query state")); - } - (Some(_), Some(_)) => { - return Err(anyhow!( - "Both regular and advanced embeddings found. Please provide only one type." - )); - } + let embedding = if let Some(embedding) = query_state.embedding.as_ref() { + Vector::from(embedding.clone()) + } else { + return Err(anyhow::Error::msg("Missing embedding in query state")); }; + let vector_column_name = self.get_vector_column_name()?; + let pool = self.pool_get_or_initialize().await?; let default_columns: Vec<_> = PgVectorBuilder::default_fields() @@ -124,6 +111,60 @@ impl Retrieve for PgVector { } } +#[async_trait] +impl Retrieve for PgVector { + #[tracing::instrument] + async fn retrieve( + &self, + search_strategy: &HybridSearch, + query_state: Query, + ) -> Result> { + let embedding = if let Some(embedding) = query_state.embedding.as_ref() { + Vector::from(embedding.clone()) + } else { + return Err(anyhow::Error::msg("Missing embedding in query state")); + }; + + let vector_column_name = + VectorConfig::from(search_strategy.dense_vector_field().clone()).field; + + let pool = self.pool_get_or_initialize().await?; + + let default_columns: Vec<_> = PgVectorBuilder::default_fields() + .iter() + .map(|f| f.field_name().to_string()) + .collect(); + + // Start building the SQL query + let mut sql = format!( + "SELECT {} FROM {}", + default_columns.join(", "), + self.table_name + ); + + // Add the ORDER BY clause for vector similarity search + sql.push_str(&format!( + " ORDER BY {} <=> $1 LIMIT $2", + &vector_column_name + )); + + tracing::debug!("Running retrieve with SQL: {}", sql); + + let top_k = i32::try_from(search_strategy.top_k()) + .map_err(|_| anyhow!("Failed to convert top_k to i32"))?; + + let data: Vec = sqlx::query_as(&sql) + .bind(embedding) + .bind(top_k) + .fetch_all(pool) + .await?; + + let docs = data.into_iter().map(|r| r.chunk).collect(); + + Ok(query_state.retrieved_documents(docs)) + } +} + #[cfg(test)] mod tests { use crate::pgvector::fixtures::TestContext;