Skip to content

Commit

Permalink
Add support for HybridSearch strategy
Browse files Browse the repository at this point in the history
Signed-off-by: shamb0 <r.raajey@gmail.com>
  • Loading branch information
shamb0 committed Nov 25, 2024
1 parent a9e2a9e commit 266a62f
Show file tree
Hide file tree
Showing 7 changed files with 115 additions and 81 deletions.
7 changes: 1 addition & 6 deletions swiftide-core/src/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
//! `states::Answered`: The query has been answered
use derive_builder::Builder;

use crate::{util::debug_long_utf8, AdvanceEmbedding, Embedding, SparseEmbedding};
use crate::{util::debug_long_utf8, Embedding, SparseEmbedding};

type Document = String;

Expand All @@ -34,9 +34,6 @@ pub struct Query<STATE: QueryState> {

#[builder(default)]
pub sparse_embedding: Option<SparseEmbedding>,

#[builder(default)]
pub adv_embedding: Option<AdvanceEmbedding>,
}

impl<STATE: std::fmt::Debug + QueryState> std::fmt::Debug for Query<STATE> {
Expand All @@ -47,7 +44,6 @@ impl<STATE: std::fmt::Debug + QueryState> std::fmt::Debug for Query<STATE> {
.field("state", &self.state)
.field("transformation_history", &self.transformation_history)
.field("embedding", &self.embedding.is_some())
.field("adv_embedding", &self.adv_embedding.is_some())
.finish()
}
}
Expand Down Expand Up @@ -75,7 +71,6 @@ impl<STATE: Clone + QueryState> Query<STATE> {
transformation_history: self.transformation_history,
embedding: self.embedding,
sparse_embedding: self.sparse_embedding,
adv_embedding: self.adv_embedding,
}
}

Expand Down
18 changes: 0 additions & 18 deletions swiftide-core/src/type_aliases.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#![cfg_attr(coverage_nightly, coverage(off))]

use crate::indexing::EmbeddedField;
use serde::{Deserialize, Serialize};

pub type Embedding = Vec<f32>;
Expand All @@ -21,20 +20,3 @@ impl std::fmt::Debug for SparseEmbedding {
.finish()
}
}

#[derive(Serialize, Deserialize, Clone, PartialEq)]
pub struct AdvanceEmbedding {
pub embedded_field: EmbeddedField,
pub field_value: Vec<f32>,
}
pub type AdvanceEmbeddings = Vec<AdvanceEmbedding>;

impl std::fmt::Debug for AdvanceEmbedding {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
// Start the debug struct formatting
f.debug_struct("AdvanceEmbedding")
.field("embedded_field", &self.embedded_field)
.field("field_value", &self.field_value)
.finish()
}
}
2 changes: 1 addition & 1 deletion swiftide-indexing/src/loaders/file_loader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ impl Loader for FileLoader {
.filter(|entry| entry.file_type().is_some_and(|ft| ft.is_file()))
.filter(move |entry| self.file_has_extension(entry.path()))
.map(|entry| {
tracing::info!("Reading file: {:?}", entry);
tracing::debug!("Reading file: {:?}", entry);
let content =
std::fs::read_to_string(entry.path()).context("Failed to read file")?;
let original_size = content.len();
Expand Down
2 changes: 1 addition & 1 deletion swiftide-integrations/src/pgvector/fixtures.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ pub(crate) struct PgVectorTestData<'a> {
/// Vector embeddings with their corresponding fields
pub vectors: Vec<(indexing::EmbeddedField, Vec<f32>)>,
pub expected_in_results: bool,
pub use_adv_embedding_query: bool,
pub use_hybrid_search: bool,
}

impl<'a> PgVectorTestData<'a> {
Expand Down
76 changes: 45 additions & 31 deletions swiftide-integrations/src/pgvector/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -189,8 +189,11 @@ mod tests {
use std::collections::HashSet;
use swiftide_core::{
indexing::{self, EmbedMode, EmbeddedField},
querying::{search_strategies::SimilaritySingleEmbedding, states, Query},
AdvanceEmbedding, Persist, Retrieve,
querying::{
search_strategies::{HybridSearch, SimilaritySingleEmbedding},
states, Query,
},
Persist, Retrieve,
};
use test_case::test_case;

Expand Down Expand Up @@ -326,15 +329,15 @@ mod tests {
metadata: None,
vectors: vec![PgVectorTestData::create_test_vector(EmbeddedField::Combined, 1.0)],
expected_in_results: true,
use_adv_embedding_query: false,
use_hybrid_search: false,
},
PgVectorTestData {
embed_mode: EmbedMode::SingleWithMetadata,
chunk: "single_no_meta_2",
metadata: None,
vectors: vec![PgVectorTestData::create_test_vector(EmbeddedField::Combined, 1.1)],
expected_in_results: true,
use_adv_embedding_query: false,
use_hybrid_search: false,
}
],
HashSet::from([EmbeddedField::Combined])
Expand All @@ -351,7 +354,7 @@ mod tests {
].into()),
vectors: vec![PgVectorTestData::create_test_vector(EmbeddedField::Combined, 1.2)],
expected_in_results: true,
use_adv_embedding_query: false,
use_hybrid_search: false,
},
PgVectorTestData {
embed_mode: EmbedMode::SingleWithMetadata,
Expand All @@ -362,7 +365,7 @@ mod tests {
].into()),
vectors: vec![PgVectorTestData::create_test_vector(EmbeddedField::Combined, 1.3)],
expected_in_results: true,
use_adv_embedding_query: false,
use_hybrid_search: false,
}
],
HashSet::from([EmbeddedField::Combined])
Expand All @@ -380,7 +383,7 @@ mod tests {
PgVectorTestData::create_test_vector(EmbeddedField::Metadata("priority".into()), 3.2),
],
expected_in_results: true,
use_adv_embedding_query: true,
use_hybrid_search: true,
},
PgVectorTestData {
embed_mode: EmbedMode::PerField,
Expand All @@ -392,7 +395,7 @@ mod tests {
PgVectorTestData::create_test_vector(EmbeddedField::Metadata("priority".into()), 3.3),
],
expected_in_results: true,
use_adv_embedding_query: true,
use_hybrid_search: true,
}
],
HashSet::from([
Expand All @@ -417,7 +420,7 @@ mod tests {
PgVectorTestData::create_test_vector(EmbeddedField::Metadata("priority".into()), 3.2),
],
expected_in_results: true,
use_adv_embedding_query: true,
use_hybrid_search: true,
},
PgVectorTestData {
embed_mode: EmbedMode::PerField,
Expand All @@ -432,7 +435,7 @@ mod tests {
PgVectorTestData::create_test_vector(EmbeddedField::Metadata("priority".into()), 3.3),
],
expected_in_results: true,
use_adv_embedding_query: true,
use_hybrid_search: true,
}
],
HashSet::from([
Expand All @@ -453,7 +456,7 @@ mod tests {
PgVectorTestData::create_test_vector(EmbeddedField::Chunk, 3.1)
],
expected_in_results: true,
use_adv_embedding_query: true,
use_hybrid_search: true,
},
PgVectorTestData {
embed_mode: EmbedMode::Both,
Expand All @@ -464,7 +467,7 @@ mod tests {
PgVectorTestData::create_test_vector(EmbeddedField::Chunk, 3.3)
],
expected_in_results: true,
use_adv_embedding_query: true,
use_hybrid_search: true,
}
],
HashSet::from([EmbeddedField::Combined, EmbeddedField::Chunk])
Expand All @@ -488,7 +491,7 @@ mod tests {
PgVectorTestData::create_test_vector(EmbeddedField::Metadata("tag".into()), 3.8)
],
expected_in_results: true,
use_adv_embedding_query: true,
use_hybrid_search: true,
},
PgVectorTestData {
embed_mode: EmbedMode::Both,
Expand All @@ -506,7 +509,7 @@ mod tests {
PgVectorTestData::create_test_vector(EmbeddedField::Metadata("tag".into()), 4.3)
],
expected_in_results: true,
use_adv_embedding_query: true,
use_hybrid_search: true,
}
],
HashSet::from([
Expand Down Expand Up @@ -581,24 +584,35 @@ mod tests {
for (index, (field, vector)) in test_case.vectors.iter().enumerate() {
tracing::warn!("Enter :: {:#?}!", index);
let mut query = Query::<states::Pending>::new("test_query");

if test_case.use_adv_embedding_query {
query.adv_embedding = Some(AdvanceEmbedding {
embedded_field: field.clone(),
field_value: vector.clone(),
});
query.embedding = Some(vector.clone());

let result = if test_case.use_hybrid_search {
let mut search_strategy = HybridSearch::default();
search_strategy
.with_dense_vector_field(field.clone())
.with_top_n(nodes.len() as u64);

test_context
.pgv_storage
.retrieve(&search_strategy, query)
.await
.expect("Retrieval should succeed")
} else {
query.embedding = Some(vector.clone());
}

let mut search_strategy = SimilaritySingleEmbedding::<()>::default();
search_strategy.with_top_k(nodes.len() as u64);

let result = test_context
.pgv_storage
.retrieve(&search_strategy, query)
.await
.expect("Retrieval should succeed");
let mut search_strategy = SimilaritySingleEmbedding::<()>::default();
search_strategy.with_top_k(nodes.len() as u64);

test_context
.pgv_storage
.retrieve(&search_strategy, query)
.await
.expect("Retrieval should succeed")
};

// let result = test_context
// .pgv_storage
// .retrieve(&*search_strategy, query)
// .await
// .expect("Retrieval should succeed");

if test_case.expected_in_results {
assert!(
Expand Down
4 changes: 3 additions & 1 deletion swiftide-integrations/src/pgvector/pgv_table_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,9 @@ impl PgVector {
match vector_fields.as_slice() {
[field] => Ok(field.field_name().to_string()),
[] => Err(anyhow!("No vector field configured in schema")),
_ => Err(anyhow!("Multiple vector fields configured in schema")),
_ => Err(anyhow!(
"Multiple vector fields configured in schema use HybridSearch strategy"
)),
}
}
}
Expand Down
87 changes: 64 additions & 23 deletions swiftide-integrations/src/pgvector/retrieve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@ use async_trait::async_trait;
use pgvector::Vector;
use sqlx::{prelude::FromRow, types::Uuid};
use swiftide_core::{
querying::{search_strategies::SimilaritySingleEmbedding, states, Query},
querying::{
search_strategies::{HybridSearch, SimilaritySingleEmbedding},
states, Query,
},
Retrieve,
};

Expand All @@ -24,30 +27,14 @@ impl Retrieve<SimilaritySingleEmbedding<String>> for PgVector {
search_strategy: &SimilaritySingleEmbedding<String>,
query_state: Query<states::Pending>,
) -> Result<Query<states::Retrieved>> {
let (vector_column_name, embedding) = match (
query_state.embedding.as_ref(),
query_state.adv_embedding.as_ref(),
) {
(Some(embed), None) => {
let vector_column_name = self.get_vector_column_name()?;
let embedding = Vector::from(embed.clone());
(vector_column_name, embedding)
}
(None, Some(adv_embed)) => {
let vector_column_name = VectorConfig::from(adv_embed.embedded_field.clone()).field;
let embedding = Vector::from(adv_embed.field_value.clone());
(vector_column_name, embedding)
}
(None, None) => {
return Err(anyhow!("No embedding found in query state"));
}
(Some(_), Some(_)) => {
return Err(anyhow!(
"Both regular and advanced embeddings found. Please provide only one type."
));
}
let embedding = if let Some(embedding) = query_state.embedding.as_ref() {
Vector::from(embedding.clone())
} else {
return Err(anyhow::Error::msg("Missing embedding in query state"));
};

let vector_column_name = self.get_vector_column_name()?;

let pool = self.pool_get_or_initialize().await?;

let default_columns: Vec<_> = PgVectorBuilder::default_fields()
Expand Down Expand Up @@ -124,6 +111,60 @@ impl Retrieve<SimilaritySingleEmbedding> for PgVector {
}
}

#[async_trait]
impl Retrieve<HybridSearch> for PgVector {
#[tracing::instrument]
async fn retrieve(
&self,
search_strategy: &HybridSearch,
query_state: Query<states::Pending>,
) -> Result<Query<states::Retrieved>> {
let embedding = if let Some(embedding) = query_state.embedding.as_ref() {
Vector::from(embedding.clone())
} else {
return Err(anyhow::Error::msg("Missing embedding in query state"));
};

let vector_column_name =
VectorConfig::from(search_strategy.dense_vector_field().clone()).field;

let pool = self.pool_get_or_initialize().await?;

let default_columns: Vec<_> = PgVectorBuilder::default_fields()
.iter()
.map(|f| f.field_name().to_string())
.collect();

// Start building the SQL query
let mut sql = format!(
"SELECT {} FROM {}",
default_columns.join(", "),
self.table_name
);

// Add the ORDER BY clause for vector similarity search
sql.push_str(&format!(
" ORDER BY {} <=> $1 LIMIT $2",
&vector_column_name
));

tracing::debug!("Running retrieve with SQL: {}", sql);

let top_k = i32::try_from(search_strategy.top_k())
.map_err(|_| anyhow!("Failed to convert top_k to i32"))?;

let data: Vec<VectorSearchResult> = sqlx::query_as(&sql)
.bind(embedding)
.bind(top_k)
.fetch_all(pool)
.await?;

let docs = data.into_iter().map(|r| r.chunk).collect();

Ok(query_state.retrieved_documents(docs))
}
}

#[cfg(test)]
mod tests {
use crate::pgvector::fixtures::TestContext;
Expand Down

0 comments on commit 266a62f

Please sign in to comment.