Skip to content

Commit

Permalink
Addressed review comments:
Browse files Browse the repository at this point in the history
- added Postgres test_util,
- completed unit tests for persist and retrieval

Signed-off-by: shamb0 <r.raajey@gmail.com>
  • Loading branch information
shamb0 committed Oct 22, 2024
1 parent bfa44b5 commit 81ef43b
Show file tree
Hide file tree
Showing 13 changed files with 728 additions and 257 deletions.
12 changes: 12 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ temp-dir = "0.1.13"
wiremock = "0.6.0"
test-case = "3.3.1"
insta = { version = "1.39.0", features = ["yaml"] }
tempfile = "3.10.1"
portpicker = "0.1.1"

[workspace.lints.rust]
unsafe_code = "forbid"
Expand Down
1 change: 1 addition & 0 deletions examples/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ qdrant-client = { workspace = true }
fluvio = { workspace = true }
temp-dir = { workspace = true }
sqlx = { workspace = true }
swiftide-test-utils = { path = "../swiftide-test-utils" }

[[example]]
doc-scrape-examples = true
Expand Down
81 changes: 66 additions & 15 deletions examples/index_md_into_pgvector.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* This example demonstrates how to use the Pgvector integration with Swiftide
* This example demonstrates how to use the Pgvector integration with SwiftIDE
*/
use std::path::PathBuf;
use swiftide::{
Expand All @@ -11,9 +11,40 @@ use swiftide::{
},
EmbeddedField,
},
integrations::{self, pgvector::PgVector},
integrations::{self, fastembed::FastEmbed, pgvector::PgVector},
query::{self, answers, query_transformers, response_transformers},
traits::SimplePrompt,
};

async fn ask_query(
llm_client: impl SimplePrompt + Clone + 'static,
embed: FastEmbed,
vector_store: PgVector,
question: String,
) -> Result<String, Box<dyn std::error::Error>> {
// By default the search strategy is SimilaritySingleEmbedding
// which takes the latest query, embeds it, and does a similarity search
//
// Pgvector will return an error if multiple embeddings are set
//
// The pipeline generates subquestions to increase semantic coverage, embeds these in a single
// embedding, retrieves the default top_k documents, summarizes them and uses that as context
// for the final answer.
let pipeline = query::Pipeline::default()
.then_transform_query(query_transformers::GenerateSubquestions::from_client(
llm_client.clone(),
))
.then_transform_query(query_transformers::Embed::from_client(embed))
.then_retrieve(vector_store.clone())
.then_transform_response(response_transformers::Summary::from_client(
llm_client.clone(),
))
.then_answer(answers::Simple::from_client(llm_client.clone()));

let result = pipeline.query(question).await?;
Ok(result.answer().into())
}

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
tracing_subscriber::fmt::init();
Expand All @@ -23,15 +54,15 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let manifest_dir = std::env::var("CARGO_MANIFEST_DIR").expect("CARGO_MANIFEST_DIR not set");

// Create a PathBuf to test dataset from the manifest directory
let test_dataset_path = PathBuf::from(manifest_dir).join("test_dataset");
let test_dataset_path = PathBuf::from(manifest_dir).join("../README.md");

tracing::info!("Test Dataset path: {:?}", test_dataset_path);

let pgv_db_url = std::env::var("DATABASE_URL")
.as_deref()
.unwrap_or("postgresql://myuser:mypassword@localhost:5432/mydatabase")
.to_owned();
let (_pgv_db_container, pgv_db_url, _temp_dir) = swiftide_test_utils::start_postgres().await;

tracing::info!("pgv_db_url :: {:#?}", pgv_db_url);

let ollama_client = integrations::ollama::Ollama::default()
let llm_client = integrations::ollama::Ollama::default()
.with_default_prompt_model("llama3.2:latest")
.to_owned();

Expand All @@ -41,7 +72,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
// Configure Pgvector with a default vector size, a single embedding
// and in addition to embedding the text metadata, also store it in a field
let pgv_storage = PgVector::builder()
.try_from_url(pgv_db_url, Some(10))
.try_connect_to_pool(pgv_db_url, Some(10))
.await
.expect("Failed to connect to postgres server")
.vector_size(384)
Expand All @@ -52,24 +83,44 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
.unwrap();

// Drop the existing test table before running the test
tracing::info!("Dropping existing test table if it exists");
tracing::info!("Dropping existing test table & index if it exists");
let drop_table_sql = "DROP TABLE IF EXISTS swiftide_pgvector_test";
let drop_index_sql = "DROP INDEX IF EXISTS swiftide_pgvector_test_embedding_idx";

if let Some(pool) = pgv_storage.get_pool() {
sqlx::query(drop_table_sql).execute(pool).await?;
if let Ok(pool) = pgv_storage.get_pool() {
sqlx::query(drop_table_sql).execute(&pool).await?;
sqlx::query(drop_index_sql).execute(&pool).await?;
} else {
return Err("Failed to get database connection pool".into());
}

tracing::info!("Starting indexing pipeline");
indexing::Pipeline::from_loader(FileLoader::new(test_dataset_path).with_extensions(&["md"]))
.then_chunk(ChunkMarkdown::from_chunk_range(10..2048))
.then(MetadataQAText::new(ollama_client.clone()))
.then_in_batch(Embed::new(fastembed).with_batch_size(100))
.then(MetadataQAText::new(llm_client.clone()))
.then_in_batch(Embed::new(fastembed.clone()).with_batch_size(100))
.then_store_with(pgv_storage.clone())
.run()
.await?;

tracing::info!("Indexing test completed successfully");
for (i, question) in [
"What is SwiftIDE? Provide a clear, comprehensive summary in under 50 words.",
"How can I use SwiftIDE to connect with the Ethereum blockchain? Please provide a concise, comprehensive summary in less than 50 words.",
]
.iter()
.enumerate()
{
let result = ask_query(
llm_client.clone(),
fastembed.clone(),
pgv_storage.clone(),
question.to_string(),
).await?;
tracing::info!("*** Answer Q{} ***", i + 1);
tracing::info!("{}", result);
tracing::info!("===X===");
}

tracing::info!("PgVector Indexing & retrieval test completed successfully");
Ok(())
}
41 changes: 0 additions & 41 deletions examples/test_dataset/README.md

This file was deleted.

46 changes: 0 additions & 46 deletions scripts/docker/docker-compose-db-pg.yml

This file was deleted.

2 changes: 1 addition & 1 deletion swiftide-indexing/src/loaders/file_loader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ impl Loader for FileLoader {
.filter(|entry| entry.file_type().is_some_and(|ft| ft.is_file()))
.filter(move |entry| self.file_has_extension(entry.path()))
.map(|entry| {
tracing::debug!("Reading file: {:?}", entry);
tracing::info!("Reading file: {:?}", entry);
let content =
std::fs::read_to_string(entry.path()).context("Failed to read file")?;
let original_size = content.len();
Expand Down
Loading

0 comments on commit 81ef43b

Please sign in to comment.