Skip to content

Commit

Permalink
Add support for PGVector
Browse files Browse the repository at this point in the history
  • Loading branch information
timonv committed Dec 30, 2024
1 parent f14e082 commit c8589ac
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 5 deletions.
1 change: 1 addition & 0 deletions swiftide-integrations/src/pgvector/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
//! - Efficient vector storage and indexing
//! - Connection pooling with automatic retries
//! - Batch operations for optimized performance
//! - Metadata included in retrieval
//!
//! The functionality is primarily used through the [`PgVector`] client, which implements
//! the [`Persist`] trait for seamless integration with indexing and query pipelines.
Expand Down
103 changes: 98 additions & 5 deletions swiftide-integrations/src/pgvector/retrieve.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
use crate::pgvector::{PgVector, PgVectorBuilder};
use crate::pgvector::{FieldConfig, PgVector, PgVectorBuilder};
use anyhow::{anyhow, Result};
use async_trait::async_trait;
use pgvector::Vector;
use sqlx::{prelude::FromRow, types::Uuid};
use sqlx::{prelude::FromRow, types::Uuid, Column, Row};
use swiftide_core::{
document::Document,
indexing::Metadata,
querying::{
search_strategies::{CustomStrategy, SimilaritySingleEmbedding},
states, Query,
Expand All @@ -12,10 +14,47 @@ use swiftide_core::{
};

#[allow(dead_code)]
#[derive(Debug, Clone, FromRow)]
#[derive(Debug, Clone)]
/// A single row returned by a vector search query, decoded through the
/// manual `FromRow` implementation in this file.
struct VectorSearchResult {
// Read from the `id` column of the row.
id: Uuid,
// Read from the `chunk` column; passed as the content when converting into a `Document`.
chunk: String,
// Collected from the row's columns whose names start with `meta_`.
metadata: Metadata,
}

impl From<VectorSearchResult> for Document {
fn from(val: VectorSearchResult) -> Self {
Document::new(val.chunk, Some(val.metadata))
}
}

/// Decodes a Postgres row into a [`VectorSearchResult`].
///
/// The `id` and `chunk` columns are read directly. Every column whose name
/// starts with `meta_` is decoded as JSON; when the decoded value is a JSON
/// object, its first entry is copied into the resulting [`Metadata`]. Values
/// that are not objects, or objects without entries, are skipped.
///
/// # Errors
///
/// Returns the underlying `sqlx::Error` when `id`, `chunk`, or any `meta_`
/// column cannot be decoded.
impl FromRow<'_, sqlx::postgres::PgRow> for VectorSearchResult {
    fn from_row(row: &sqlx::postgres::PgRow) -> Result<Self, sqlx::Error> {
        let mut metadata = Metadata::default();

        // Metadata fields are stored each as prefixed meta_ fields. Perhaps we should add a
        // single metadata field instead of multiple fields.
        for column in row.columns() {
            let name = column.name();
            if !name.starts_with("meta_") {
                continue;
            }

            let value = row.try_get::<serde_json::Value, _>(name)?;

            // Only the first entry of the object is used; iterating the entries directly
            // avoids collecting the keys and a second, panicking lookup.
            if let Some((key, inner)) = value.as_object().and_then(|object| object.iter().next()) {
                metadata.insert(key.clone(), inner.clone());
            }
        }

        Ok(VectorSearchResult {
            id: row.try_get("id")?,
            chunk: row.try_get("chunk")?,
            metadata,
        })
    }
}

#[allow(clippy::redundant_closure_for_method_calls)]
Expand All @@ -40,6 +79,12 @@ impl Retrieve<SimilaritySingleEmbedding<String>> for PgVector {
let default_columns: Vec<_> = PgVectorBuilder::default_fields()
.iter()
.map(|f| f.field_name().to_string())
.chain(
self.fields
.iter()
.filter(|f| matches!(f, FieldConfig::Metadata(_)))
.map(|f| f.field_name().to_string()),
)
.collect();

// Start building the SQL query
Expand Down Expand Up @@ -89,7 +134,7 @@ impl Retrieve<SimilaritySingleEmbedding<String>> for PgVector {
.fetch_all(pool)
.await?;

let docs = data.into_iter().map(|r| r.chunk.into()).collect();
let docs = data.into_iter().map(Into::into).collect();

Ok(query_state.retrieved_documents(docs))
}
Expand Down Expand Up @@ -132,7 +177,7 @@ impl Retrieve<CustomStrategy<sqlx::QueryBuilder<'static, sqlx::Postgres>>> for P
.map_err(|e| anyhow!("Failed to execute search query: {}", e))?;

// Transform results into documents
let documents = results.into_iter().map(|r| r.chunk.into()).collect();
let documents = results.into_iter().map(Into::into).collect();

// Update query state with retrieved documents
Ok(query.retrieved_documents(documents))
Expand Down Expand Up @@ -212,4 +257,52 @@ mod tests {
.unwrap();
assert_eq!(result.documents().len(), 0);
}

#[test_log::test(tokio::test)]
async fn test_retrieve_docs_with_metadata() {
    // Storage configured with two metadata fields and a single combined embedding.
    let ctx = TestContext::setup_with_cfg(
        vec!["other", "text"].into(),
        HashSet::from([EmbeddedField::Combined]),
    )
    .await
    .expect("Test setup failed");

    // One node carrying both metadata fields and a 384-dimensional vector.
    let node = indexing::Node::new("test_query1")
        .with_metadata([
            ("other", serde_json::Value::from(10)),
            ("text", serde_json::Value::from("some text")),
        ])
        .with_vectors([(EmbeddedField::Combined, vec![1.0; 384])])
        .to_owned();

    ctx.pgv_storage
        .batch_store(vec![node])
        .await
        .try_collect::<Vec<_>>()
        .await
        .unwrap();

    // Query with the same embedding that was stored for the node.
    let mut query = Query::<states::Pending>::new("test_query");
    query.embedding = Some(vec![1.0; 384]);

    let strategy = SimilaritySingleEmbedding::<()>::default();
    let retrieved = ctx
        .pgv_storage
        .retrieve(&strategy, query.clone())
        .await
        .unwrap();

    assert_eq!(retrieved.documents().len(), 1);

    // The stored metadata must come back on the retrieved document.
    let document = retrieved.documents().first().unwrap();
    let expected_metadata = [
        ("other", serde_json::Value::from(10)),
        ("text", serde_json::Value::from("some text")),
    ];
    for (key, expected) in expected_metadata {
        assert_eq!(document.metadata().get(key), Some(&expected));
    }
}
}

0 comments on commit c8589ac

Please sign in to comment.