Skip to content

Commit

Permalink
feat(query): Add custom SQL query generation for pgvector search (#478)
Browse files Browse the repository at this point in the history
Adds support for custom retrieval queries with the sqlx query builder for PGVector. Puts down the fundamentals for custom query building for any retriever.

---------

Signed-off-by: shamb0 <r.raajey@gmail.com>
Co-authored-by: Swabbie (Bosun) <155570396+SwabbieBosun@users.noreply.github.com>
  • Loading branch information
shamb0 and SwabbieBosun authored Dec 30, 2024
1 parent b55bf0b commit 584695e
Show file tree
Hide file tree
Showing 6 changed files with 340 additions and 8 deletions.
118 changes: 118 additions & 0 deletions swiftide-core/src/search_strategies/custom_strategy.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
//! Generic vector search strategy framework for customizable query generation.
//!
//! Provides core abstractions for vector similarity search:
//! - Generic query type parameter for retriever-specific implementations
//! - Flexible query generation through closure-based configuration
//!
//! This module implements a strategy pattern for vector similarity search,
//! allowing different retrieval backends to provide their own query generation
//! logic while maintaining a consistent interface. The framework emphasizes
//! composition over inheritance, enabling configuration through closures
//! rather than struct fields.
use crate::querying::{self, states, Query};
use anyhow::{anyhow, Result};
use std::marker::PhantomData;
use std::sync::Arc;

/// A type alias for query generation functions.
///
/// The query generator takes a pending query state and produces a
/// retriever-specific query type. All configuration parameters should
/// be captured in the closure's environment.
type QueryGenerator<Q> = Arc<dyn Fn(&Query<states::Pending>) -> Result<Q> + Send + Sync>;

/// `CustomStrategy` provides a flexible way to generate retriever-specific search queries.
///
/// This struct implements a strategy pattern for vector similarity search, allowing
/// different retrieval backends to provide their own query generation logic. Configuration
/// is managed through the query generation closure, promoting a more flexible and
/// composable design.
///
/// # Type Parameters
/// * `Q` - The retriever-specific query type (e.g., `sqlx::QueryBuilder` for `PostgreSQL`)
///
/// # Examples
/// ```rust
/// // Define search configuration
/// const MAX_SEARCH_RESULTS: i64 = 5;
///
/// // Create a custom search strategy
/// let strategy = CustomStrategy::from_query(|query_node| {
/// let mut builder = QueryBuilder::new();
///
/// // Configure search parameters within the closure
/// builder.push(" LIMIT ");
/// builder.push_bind(MAX_SEARCH_RESULTS);
///
/// Ok(builder)
/// });
/// ```
///
/// # Implementation Notes
/// - Search configuration (like result limits and vector fields) should be defined
/// in the closure's scope
/// - Implementers are responsible for validating configuration values
/// - The query generator has access to the full query state for maximum flexibility
pub struct CustomStrategy<Q> {
/// The query generation function now returns a `Q`
query: Option<QueryGenerator<Q>>,

/// `PhantomData` to handle the generic parameter
_marker: PhantomData<Q>,
}

impl<Q: Send + Sync + 'static> querying::SearchStrategy for CustomStrategy<Q> {}

impl<Q> Default for CustomStrategy<Q> {
fn default() -> Self {
Self {
query: None,
_marker: PhantomData,
}
}
}

// Manual Clone implementation instead of derive
impl<Q> Clone for CustomStrategy<Q> {
fn clone(&self) -> Self {
Self {
query: self.query.clone(), // Arc clone is fine
_marker: PhantomData,
}
}
}

impl<Q: Send + Sync + 'static> CustomStrategy<Q> {
/// Creates a new `CustomStrategy` with a query generation function.
///
/// The provided closure should contain all necessary configuration for
/// query generation. This design allows for more flexible configuration
/// management compared to struct-level fields.
///
/// # Parameters
/// * `query` - A closure that generates retriever-specific queries
pub fn from_query(
query: impl Fn(&Query<states::Pending>) -> Result<Q> + Send + Sync + 'static,
) -> Self {
Self {
query: Some(Arc::new(query)),
_marker: PhantomData,
}
}

/// Gets the query builder, which can then be used to build the actual query
///
/// # Errors
/// This function will return an error if:
/// - No query function has been set (use `from_query` to set a query function).
/// - The query function fails while processing the provided `query_node`.
pub fn build_query(&self, query_node: &Query<states::Pending>) -> Result<Q> {
match &self.query {
Some(query_fn) => Ok(query_fn(query_node)?),
None => Err(anyhow!(
"No query function has been set. Use from_query() to set a query function."
)),
}
}
}
2 changes: 2 additions & 0 deletions swiftide-core/src/search_strategies/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
//!
//! The strategy is also yielded to the Retriever and can contain addition configuration
mod custom_strategy;
mod hybrid_search;
mod similarity_single_embedding;

pub(crate) const DEFAULT_TOP_K: u64 = 10;
pub(crate) const DEFAULT_TOP_N: u64 = 10;

pub use custom_strategy::*;
pub use hybrid_search::*;
pub use similarity_single_embedding::*;

Expand Down
8 changes: 6 additions & 2 deletions swiftide-integrations/src/pgvector/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ use std::sync::Arc;
use std::sync::OnceLock;
use tokio::time::Duration;

use pgv_table_types::{FieldConfig, MetadataConfig, VectorConfig};
pub use pgv_table_types::{FieldConfig, MetadataConfig, VectorConfig};

/// Default maximum connections for the database connection pool.
const DB_POOL_CONN_MAX: u32 = 10;
Expand Down Expand Up @@ -135,6 +135,10 @@ impl PgVector {
pub async fn get_pool(&self) -> Result<&PgPool> {
self.pool_get_or_initialize().await
}

pub fn get_table_name(&self) -> &str {
&self.table_name
}
}

impl PgVectorBuilder {
Expand Down Expand Up @@ -177,7 +181,7 @@ impl PgVectorBuilder {
self
}

fn default_fields() -> Vec<FieldConfig> {
pub fn default_fields() -> Vec<FieldConfig> {
vec![FieldConfig::ID, FieldConfig::Chunk]
}
}
Expand Down
6 changes: 3 additions & 3 deletions swiftide-integrations/src/pgvector/pgv_table_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use tokio::time::sleep;
#[derive(Clone, Debug)]
pub struct VectorConfig {
embedded_field: EmbeddedField,
pub(crate) field: String,
pub field: String,
}

impl VectorConfig {
Expand Down Expand Up @@ -76,7 +76,7 @@ impl<T: AsRef<str>> From<T> for MetadataConfig {
/// Represents different field types that can be configured in the table schema,
/// including vector embeddings, metadata, and system fields.
#[derive(Clone, Debug)]
pub(crate) enum FieldConfig {
pub enum FieldConfig {
/// `Vector` - Vector embedding field configuration
Vector(VectorConfig),
/// `Metadata` - Metadata field configuration
Expand Down Expand Up @@ -441,7 +441,7 @@ impl PgVector {
}

impl PgVector {
pub(crate) fn normalize_field_name(field: &str) -> String {
pub fn normalize_field_name(field: &str) -> String {
// Define the special characters as an array
let special_chars: [char; 4] = ['(', '[', '{', '<'];

Expand Down
33 changes: 32 additions & 1 deletion swiftide-integrations/src/pgvector/retrieve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@ use async_trait::async_trait;
use pgvector::Vector;
use sqlx::{prelude::FromRow, types::Uuid};
use swiftide_core::{
querying::{search_strategies::SimilaritySingleEmbedding, states, Query},
querying::{
search_strategies::{CustomStrategy, SimilaritySingleEmbedding},
states, Query,
},
Retrieve,
};

Expand Down Expand Up @@ -108,6 +111,34 @@ impl Retrieve<SimilaritySingleEmbedding> for PgVector {
}
}

#[async_trait]
impl Retrieve<CustomStrategy<sqlx::QueryBuilder<'static, sqlx::Postgres>>> for PgVector {
async fn retrieve(
&self,
search_strategy: &CustomStrategy<sqlx::QueryBuilder<'static, sqlx::Postgres>>,
query: Query<states::Pending>,
) -> Result<Query<states::Retrieved>> {
// Get the database pool
let pool = self.get_pool().await?;

// Build the custom query using both strategy and query state
let mut query_builder = search_strategy.build_query(&query)?;

// Execute the query using the builder's built-in methods
let results = query_builder
.build_query_as::<VectorSearchResult>() // Convert to a typed query
.fetch_all(pool) // Execute and get all results
.await
.map_err(|e| anyhow!("Failed to execute search query: {}", e))?;

// Transform results into documents
let documents = results.into_iter().map(|r| r.chunk).collect();

// Update query state with retrieved documents
Ok(query.retrieved_documents(documents))
}
}

#[cfg(test)]
mod tests {
use crate::pgvector::fixtures::TestContext;
Expand Down
Loading

0 comments on commit 584695e

Please sign in to comment.