From 1774b84f00a83fe69af4a2b6a6daf397d4d9b32d Mon Sep 17 00:00:00 2001 From: RK Date: Wed, 20 Nov 2024 15:28:12 +0530 Subject: [PATCH] feat(integrations): Add PGVector support for indexing (#392) --- .github/workflows/lint.yml | 39 ++ .github/workflows/test.yml | 26 - Cargo.lock | 470 +++++++++++++- Cargo.toml | 6 + examples/Cargo.toml | 8 + examples/index_md_into_pgvector.rs | 75 +++ swiftide-integrations/Cargo.toml | 35 +- swiftide-integrations/src/lib.rs | 2 + .../src/pgvector/fixtures.rs | 186 ++++++ swiftide-integrations/src/pgvector/mod.rs | 432 +++++++++++++ swiftide-integrations/src/pgvector/persist.rs | 93 +++ .../src/pgvector/pgv_table_types.rs | 575 ++++++++++++++++++ swiftide-test-utils/Cargo.toml | 2 +- swiftide-test-utils/src/test_utils.rs | 29 +- swiftide/Cargo.toml | 4 + 15 files changed, 1947 insertions(+), 35 deletions(-) create mode 100644 .github/workflows/lint.yml create mode 100644 examples/index_md_into_pgvector.rs create mode 100644 swiftide-integrations/src/pgvector/fixtures.rs create mode 100644 swiftide-integrations/src/pgvector/mod.rs create mode 100644 swiftide-integrations/src/pgvector/persist.rs create mode 100644 swiftide-integrations/src/pgvector/pgv_table_types.rs diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 00000000..99477cd2 --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,39 @@ +name: CI + +on: + pull_request: + merge_group: + push: + branches: + - master + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}-lint + +env: + CARGO_TERM_COLOR: always + RUSTFLAGS: "-Dwarnings" + +jobs: + lint: + name: Lint + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@stable + with: + components: rustfmt + - uses: r7kamura/rust-problem-matchers@v1 + - name: Install Protoc + uses: arduino/setup-protoc@v3 + with: + repo-token: ${{ secrets.GITHUB_TOKEN }} + - name: Check typos + uses: crate-ci/typos@master + - name: "Rustfmt" + run: cargo fmt --all --check + - name: Lint dependencies + uses: EmbarkStudios/cargo-deny-action@v2 + - name: clippy + run: cargo clippy --all-targets --all-features --workspace diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index bac7cdf0..88e0a552 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -9,37 +9,12 @@ on: concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}-test - cancel-in-progress: true env: CARGO_TERM_COLOR: always RUSTFLAGS: "-Dwarnings" jobs: - lint: - name: Lint - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v4 - - uses: dtolnay/rust-toolchain@stable - with: - components: rustfmt - - uses: r7kamura/rust-problem-matchers@v1 - - name: Install Protoc - uses: arduino/setup-protoc@v3 - with: - repo-token: ${{ secrets.GITHUB_TOKEN }} - - uses: r7kamura/rust-problem-matchers@v1 - - name: Check typos - uses: crate-ci/typos@master - - name: "Rustfmt" - run: cargo fmt --all --check - - name: Lint dependencies - uses: EmbarkStudios/cargo-deny-action@v2 - - name: clippy - run: cargo clippy --all-targets --all-features --workspace - test: name: Test runs-on: ubuntu-latest @@ -50,6 +25,5 @@ jobs: uses: arduino/setup-protoc@v3 with: repo-token: ${{ secrets.GITHUB_TOKEN }} - - uses: r7kamura/rust-problem-matchers@v1 - name: "Test" run: cargo test --all-features --tests diff --git a/Cargo.lock b/Cargo.lock index 2a2dfc29..06d8d0d9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1253,6 +1253,12 @@ dependencies = [ "vsimd", ] +[[package]] +name = "base64ct" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b" + [[package]] name = "benchmarks" version = "0.14.2" @@ -1284,6 +1290,9 @@ name = "bitflags" version = "2.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de" +dependencies = [ + "serde", +] [[package]] name = "bitpacking" @@ -1748,6 +1757,12 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "const-oid" +version = "0.9.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c2459377285ad874054d797f3ccebf984978aa39129f6eafde5cdc8315b612f8" + [[package]] name = "const-random" version = "0.1.18" @@ -1836,6 +1851,21 @@ dependencies = [ "libc", ] +[[package]] +name = "crc" +version = "3.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69e6e4d7b33a94f0991c26729976b10ebde1d34c3ee82408fb536164fa10d636" +dependencies = [ + "crc-catalog", +] + +[[package]] +name = "crc-catalog" +version = "2.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "19d374276b40fb8bbdee95aef7c7fa6b5316ec764510eb64b8dd0e2ed0d7e7f5" + [[package]] name = "crc32c" version = "0.6.8" @@ -2453,6 +2483,17 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "der" +version = "0.7.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f55bf8e7b65898637379c1b74eb1551107c8294ed26d855ceb9fd1a09cfc9bc0" +dependencies = [ + "const-oid", + "pem-rfc7468", + "zeroize", +] + [[package]] name = "deranged" version = "0.3.11" @@ -2525,6 +2566,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" dependencies = [ "block-buffer", + "const-oid", "crypto-common", "subtle", ] @@ -2593,6 +2635,12 @@ version = "0.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "77c90badedccf4105eca100756a0b1289e191f6fcbdadd3cee1d2f614f97da8f" +[[package]] +name = "dotenvy" +version = "0.15.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1aaf95b3e5c8f23aa320147307562d361db0ae0d51242340f558153b4eb2439b" + [[package]] name = "downcast" version = "0.11.0" @@ -2649,6 +2697,9 @@ name = "either" version = "1.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" +dependencies = [ + "serde", +] [[package]] name = "encode_unicode" @@ -2946,6 +2997,17 @@ dependencies = [ "miniz_oxide", ] +[[package]] +name = "flume" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da0e4dd2a88388a1f4ccc7c9ce104604dab68d9f408dc34cd45823d5a9069095" +dependencies = [ + "futures-core", + "futures-sink", + "spin", +] + [[package]] name = "fluvio" version = "0.24.0" @@ -3359,6 +3421,17 @@ dependencies = [ "futures-util", ] +[[package]] +name = "futures-intrusive" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d930c203dd0b6ff06e0201a4a2fe9149b43c684fd4420555b26d21b1a02956f" +dependencies = [ + "futures-core", + "lock_api", + "parking_lot 0.12.3", +] + [[package]] name = "futures-io" version = "0.3.31" @@ -3626,6 +3699,15 @@ dependencies = [ "foldhash", ] +[[package]] +name = "hashlink" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ba4ff7128dee98c7dc9794b6a411377e1404dba1c97deb8d1a55297bd25d8af" +dependencies = [ + "hashbrown 0.14.5", +] + [[package]] name = "heck" version = "0.4.1" @@ -3718,6 +3800,15 @@ dependencies = [ "tracing", ] +[[package]] +name = "hkdf" +version = "0.12.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b5f8eb2ad728638ea2c7d47a21db23b7b58a72ed6a38256b8a1849f15fbbdf7" +dependencies = [ + "hmac", +] + [[package]] name = "hmac" version = "0.12.1" @@ -4946,6 +5037,9 @@ name = "lazy_static" version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" +dependencies = [ + "spin", +] [[package]] name = "lazycell" @@ -5091,6 +5185,17 @@ dependencies = [ "redox_syscall 0.5.7", ] +[[package]] +name = "libsqlite3-sys" +version = "0.30.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2e99fb7a497b1e3339bc746195567ed8d3e24945ecd636e3619d20b9de9e9149" +dependencies = [ + "cc", + "pkg-config", + "vcpkg", +] + [[package]] name = "linked-hash-map" version = "0.5.6" @@ -5577,6 +5682,23 @@ dependencies = [ "num-traits", ] +[[package]] +name = "num-bigint-dig" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc84195820f291c7697304f3cbdadd1cb7199c0efc917ff5eafd71225c136151" +dependencies = [ + "byteorder", + "lazy_static", + "libm", + "num-integer", + "num-iter", + "num-traits", + "rand 0.8.5", + "smallvec", + "zeroize", +] + [[package]] name = "num-complex" version = "0.4.6" @@ -5996,6 +6118,15 @@ dependencies = [ "stfu8", ] +[[package]] +name = "pem-rfc7468" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88b39c9bfcfc231068454382784bb460aae594343fb030d46e9f50a645418412" +dependencies = [ + "base64ct", +] + [[package]] name = "percent-encoding" version = "2.3.1" @@ -6063,6 +6194,15 @@ dependencies = [ "indexmap 2.6.0", ] +[[package]] +name = "pgvector" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e0e8871b6d7ca78348c6cd29b911b94851f3429f0cd403130ca17f26c1fb91a6" +dependencies = [ + "sqlx", +] + [[package]] name = "pharos" version = "0.5.3" @@ -6260,6 +6400,27 @@ dependencies = [ "futures-io", ] +[[package]] +name = "pkcs1" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8ffb9f10fa047879315e6625af03c164b16962a5368d724ed16323b68ace47f" +dependencies = [ + "der", + "pkcs8", + "spki", +] + +[[package]] +name = "pkcs8" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f950b2377845cebe5cf8b5165cb3cc1a5e0fa5cfa3e1f7f55707d8fd82e0a7b7" +dependencies = [ + "der", + "spki", +] + [[package]] name = "pkg-config" version = "0.3.31" @@ -6443,7 +6604,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "22505a5c94da8e3b7c2996394d1c933236c4d743e81a410bcca4e6989fc066a4" dependencies = [ "bytes", - "heck 0.5.0", + "heck 0.4.1", "itertools 0.12.1", "log", "multimap", @@ -7100,6 +7261,26 @@ dependencies = [ "byteorder", ] +[[package]] +name = "rsa" +version = "0.9.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d0e5124fcb30e76a7e79bfee683a2746db83784b86289f6251b54b7950a0dfc" +dependencies = [ + "const-oid", + "digest", + "num-bigint-dig", + "num-integer", + "num-traits", + "pkcs1", + "pkcs8", + "rand_core 0.6.4", + "signature", + "spki", + "subtle", + "zeroize", +] + [[package]] name = "rust-stemmers" version = "1.2.0" @@ -7548,6 +7729,17 @@ dependencies = [ "stable_deref_trait", ] +[[package]] +name = "sha1" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + [[package]] name = "sha1_smol" version = "1.0.1" @@ -7598,6 +7790,16 @@ dependencies = [ "libc", ] +[[package]] +name = "signature" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77549399552de45a898a580c1b41d445bf730df867cc44e6c0233bbc4b8329de" +dependencies = [ + "digest", + "rand_core 0.6.4", +] + [[package]] name = "simd-adler32" version = "0.3.7" @@ -7669,6 +7871,9 @@ name = "smallvec" version = "1.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" +dependencies = [ + "serde", +] [[package]] name = "snafu" @@ -7771,6 +7976,19 @@ name = "spin" version = "0.9.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" +dependencies = [ + "lock_api", +] + +[[package]] +name = "spki" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d91ed6c858b01f942cd56b37a94b3e0a1798290327d1236e4d9cf4eaca44d29d" +dependencies = [ + "base64ct", + "der", +] [[package]] name = "spm_precompiled" @@ -7784,6 +8002,16 @@ dependencies = [ "unicode-segmentation", ] +[[package]] +name = "sqlformat" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7bba3a93db0cc4f7bdece8bb09e77e2e785c20bfebf79eb8340ed80708048790" +dependencies = [ + "nom", + "unicode_categories", +] + [[package]] name = "sqlparser" version = "0.49.0" @@ -7805,6 +8033,208 @@ dependencies = [ "syn 2.0.87", ] +[[package]] +name = "sqlx" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93334716a037193fac19df402f8571269c84a00852f6a7066b5d2616dcd64d3e" +dependencies = [ + "sqlx-core", + "sqlx-macros", + "sqlx-mysql", + "sqlx-postgres", + "sqlx-sqlite", +] + +[[package]] +name = "sqlx-core" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d4d8060b456358185f7d50c55d9b5066ad956956fddec42ee2e8567134a8936e" +dependencies = [ + "atoi", + "byteorder", + "bytes", + "chrono", + "crc", + "crossbeam-queue", + "either", + "event-listener 5.3.1", + "futures-channel", + "futures-core", + "futures-intrusive", + "futures-io", + "futures-util", + "hashbrown 0.14.5", + "hashlink", + "hex", + "indexmap 2.6.0", + "log", + "memchr", + "once_cell", + "paste", + "percent-encoding", + "serde", + "serde_json", + "sha2", + "smallvec", + "sqlformat", + "thiserror", + "tokio", + "tokio-stream", + "tracing", + "url", + "uuid", +] + +[[package]] +name = "sqlx-macros" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cac0692bcc9de3b073e8d747391827297e075c7710ff6276d9f7a1f3d58c6657" +dependencies = [ + "proc-macro2", + "quote", + "sqlx-core", + "sqlx-macros-core", + "syn 2.0.87", +] + +[[package]] +name = "sqlx-macros-core" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1804e8a7c7865599c9c79be146dc8a9fd8cc86935fa641d3ea58e5f0688abaa5" +dependencies = [ + "dotenvy", + "either", + "heck 0.5.0", + "hex", + "once_cell", + "proc-macro2", + "quote", + "serde", + "serde_json", + "sha2", + "sqlx-core", + "sqlx-mysql", + "sqlx-postgres", + "sqlx-sqlite", + "syn 2.0.87", + "tempfile", + "tokio", + "url", +] + +[[package]] +name = "sqlx-mysql" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "64bb4714269afa44aef2755150a0fc19d756fb580a67db8885608cf02f47d06a" +dependencies = [ + "atoi", + "base64 0.22.1", + "bitflags 2.6.0", + "byteorder", + "bytes", + "chrono", + "crc", + "digest", + "dotenvy", + "either", + "futures-channel", + "futures-core", + "futures-io", + "futures-util", + "generic-array", + "hex", + "hkdf", + "hmac", + "itoa 1.0.11", + "log", + "md-5", + "memchr", + "once_cell", + "percent-encoding", + "rand 0.8.5", + "rsa", + "serde", + "sha1", + "sha2", + "smallvec", + "sqlx-core", + "stringprep", + "thiserror", + "tracing", + "uuid", + "whoami", +] + +[[package]] +name = "sqlx-postgres" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6fa91a732d854c5d7726349bb4bb879bb9478993ceb764247660aee25f67c2f8" +dependencies = [ + "atoi", + "base64 0.22.1", + "bitflags 2.6.0", + "byteorder", + "chrono", + "crc", + "dotenvy", + "etcetera", + "futures-channel", + "futures-core", + "futures-io", + "futures-util", + "hex", + "hkdf", + "hmac", + "home", + "itoa 1.0.11", + "log", + "md-5", + "memchr", + "once_cell", + "rand 0.8.5", + "serde", + "serde_json", + "sha2", + "smallvec", + "sqlx-core", + "stringprep", + "thiserror", + "tracing", + "uuid", + "whoami", +] + +[[package]] +name = "sqlx-sqlite" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d5b2cf34a45953bfd3daaf3db0f7a7878ab9b7a6b91b422d24a7a9e4c857b680" +dependencies = [ + "atoi", + "chrono", + "flume", + "futures-channel", + "futures-core", + "futures-executor", + "futures-intrusive", + "futures-util", + "libsqlite3-sys", + "log", + "percent-encoding", + "serde", + "serde_urlencoded", + "sqlx-core", + "tracing", + "url", + "uuid", +] + [[package]] name = "stable_deref_trait" version = "1.2.0" @@ -7872,6 +8302,17 @@ version = "0.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c3c3ee6129eec20fed59acf2e9cfb3ffd20d0bbe39fe334c22af0edc56dfe752" +[[package]] +name = "stringprep" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b4df3d392d81bd458a8a621b8bffbd2302a12ffe288a9d931670948749463b1" +dependencies = [ + "unicode-bidi", + "unicode-normalization", + "unicode-properties", +] + [[package]] name = "strsim" version = "0.11.1" @@ -7987,9 +8428,12 @@ dependencies = [ "qdrant-client", "serde_json", "spider", + "sqlx", "swiftide", + "swiftide-test-utils", "temp-dir", "tokio", + "tracing", "tracing-subscriber", ] @@ -8048,6 +8492,7 @@ dependencies = [ "mockall", "ollama-rs", "parquet", + "pgvector", "qdrant-client", "redb", "redis", @@ -8057,6 +8502,7 @@ dependencies = [ "serde", "serde_json", "spider", + "sqlx", "strum", "strum_macros", "swiftide-core", @@ -9195,6 +9641,12 @@ dependencies = [ "smallvec", ] +[[package]] +name = "unicode-properties" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e70f2a8b45122e719eb623c01822704c4e0907e7e426a05927e1a1cfff5b75d0" + [[package]] name = "unicode-segmentation" version = "1.12.0" @@ -9390,6 +9842,12 @@ version = "0.11.0+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" +[[package]] +name = "wasite" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8dad83b4f25e74f184f64c43b150b91efe7647395b42289f38e50566d82855b" + [[package]] name = "wasm-bindgen" version = "0.2.95" @@ -9495,6 +9953,16 @@ version = "0.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "53a85b86a771b1c87058196170769dd264f66c0782acf1ae6cc51bfd64b39082" +[[package]] +name = "whoami" +version = "1.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "372d5b87f58ec45c384ba03563b03544dc5fadc3983e434b286913f5b4a9bb6d" +dependencies = [ + "redox_syscall 0.5.7", + "wasite", +] + [[package]] name = "widestring" version = "1.1.0" diff --git a/Cargo.toml b/Cargo.toml index 3e2058ca..189704ef 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -53,7 +53,12 @@ arrow-array = { version = "52.2", default-features = false } arrow = { version = "52.2", default-features = false } parquet = { version = "52.2", default-features = false, features = ["async"] } redb = { version = "2.2" } +sqlx = { version = "0.8.2", features = [ + "postgres", + "uuid", +], default-features = false } aws-config = "1.5" +pgvector = { version = "0.4.0", features = ["sqlx"], default-features = false } aws-credential-types = "1.2" aws-sdk-bedrockruntime = "1.61" criterion = { version = "0.5.1", default-features = false } @@ -91,6 +96,7 @@ wiremock = "0.6.0" test-case = "3.3.1" insta = { version = "1.41.1", features = ["yaml"] } + [workspace.lints.rust] unsafe_code = "forbid" unexpected_cfgs = { level = "warn", check-cfg = [ diff --git a/examples/Cargo.toml b/examples/Cargo.toml index f243055a..b0148ea2 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -21,6 +21,7 @@ swiftide = { path = "../swiftide/", features = [ "ollama", "fluvio", "lancedb", + "pgvector", ] } tracing-subscriber = { workspace = true } serde_json = { workspace = true } @@ -28,6 +29,9 @@ spider = { workspace = true } qdrant-client = { workspace = true } fluvio = { workspace = true } temp-dir = { workspace = true } +sqlx = { workspace = true } +swiftide-test-utils = { path = "../swiftide-test-utils" } +tracing = { workspace = true } [[example]] doc-scrape-examples = true @@ -91,3 +95,7 @@ path = "fluvio.rs" [[example]] name = "lancedb" path = "lancedb.rs" + +[[example]] +name = "index-md-pgvector" +path = "index_md_into_pgvector.rs" diff --git a/examples/index_md_into_pgvector.rs b/examples/index_md_into_pgvector.rs new file mode 100644 index 00000000..a8dfc5c1 --- /dev/null +++ b/examples/index_md_into_pgvector.rs @@ -0,0 +1,75 @@ +/** +* This example demonstrates how to index markdown into PGVector +*/ +use std::path::PathBuf; +use swiftide::{ + indexing::{ + self, + loaders::FileLoader, + transformers::{ + metadata_qa_text::NAME as METADATA_QA_TEXT_NAME, ChunkMarkdown, Embed, MetadataQAText, + }, + EmbeddedField, + }, + integrations::{self, pgvector::PgVector}, +}; + +#[tokio::main] +async fn main() -> Result<(), Box> { + tracing_subscriber::fmt::init(); + tracing::info!("Starting PgVector indexing test"); + + // Get the manifest directory path + let manifest_dir = std::env::var("CARGO_MANIFEST_DIR").expect("CARGO_MANIFEST_DIR not set"); + + // Create a PathBuf to test dataset from the manifest directory + let test_dataset_path = PathBuf::from(manifest_dir).join("../README.md"); + + tracing::info!("Test Dataset path: {:?}", test_dataset_path); + + let (_pgv_db_container, pgv_db_url) = swiftide_test_utils::start_postgres().await; + + tracing::info!("pgv_db_url :: {:#?}", pgv_db_url); + + let llm_client = integrations::ollama::Ollama::default() + .with_default_prompt_model("llama3.2:latest") + .to_owned(); + + let fastembed = + integrations::fastembed::FastEmbed::try_default().expect("Could not create FastEmbed"); + + // Configure Pgvector with a default vector size, a single embedding + // and in addition to embedding the text metadata, also store it in a field + let pgv_storage = PgVector::builder() + .db_url(pgv_db_url) + .vector_size(384) + .with_vector(EmbeddedField::Combined) + .with_metadata(METADATA_QA_TEXT_NAME) + .table_name("swiftide_pgvector_test".to_string()) + .build() + .unwrap(); + + // Drop the existing test table before running the test + tracing::info!("Dropping existing test table & index if it exists"); + let drop_table_sql = "DROP TABLE IF EXISTS swiftide_pgvector_test"; + let drop_index_sql = "DROP INDEX IF EXISTS swiftide_pgvector_test_embedding_idx"; + + if let Ok(pool) = pgv_storage.get_pool().await { + sqlx::query(drop_table_sql).execute(pool).await?; + sqlx::query(drop_index_sql).execute(pool).await?; + } else { + return Err("Failed to get database connection pool".into()); + } + + tracing::info!("Starting indexing pipeline"); + indexing::Pipeline::from_loader(FileLoader::new(test_dataset_path).with_extensions(&["md"])) + .then_chunk(ChunkMarkdown::from_chunk_range(10..2048)) + .then(MetadataQAText::new(llm_client.clone())) + .then_in_batch(Embed::new(fastembed.clone()).with_batch_size(100)) + .then_store_with(pgv_storage.clone()) + .run() + .await?; + + tracing::info!("PgVector Indexing test completed successfully"); + Ok(()) +} diff --git a/swiftide-integrations/Cargo.toml b/swiftide-integrations/Cargo.toml index 30e717ce..0ceaf7cc 100644 --- a/swiftide-integrations/Cargo.toml +++ b/swiftide-integrations/Cargo.toml @@ -28,13 +28,24 @@ strum_macros = { workspace = true } regex = { workspace = true } futures-util = { workspace = true } - # Integrations async-openai = { workspace = true, optional = true } qdrant-client = { workspace = true, optional = true, default-features = false, features = [ "serde", ] } -redis = { workspace = true, features = ["aio", "tokio-comp", "connection-manager", "tokio-rustls-comp"], optional = true } +sqlx = { workspace = true, optional = true, features = [ + "postgres", + "runtime-tokio", + "chrono", + "uuid", +] } +pgvector = { workspace = true, optional = true, features = ["sqlx"] } +redis = { workspace = true, features = [ + "aio", + "tokio-comp", + "connection-manager", + "tokio-rustls-comp", +], optional = true } tree-sitter = { workspace = true, optional = true } tree-sitter-rust = { workspace = true, optional = true } tree-sitter-python = { workspace = true, optional = true } @@ -45,13 +56,22 @@ tree-sitter-java = { workspace = true, optional = true } fastembed = { workspace = true, optional = true } spider = { workspace = true, optional = true } htmd = { workspace = true, optional = true } -aws-config = { workspace = true, features = ["behavior-version-latest"], optional = true } -aws-credential-types = { workspace = true, features = ["hardcoded-credentials"], optional = true } -aws-sdk-bedrockruntime = { workspace = true, features = ["behavior-version-latest"], optional = true } +aws-config = { workspace = true, features = [ + "behavior-version-latest", +], optional = true } +aws-credential-types = { workspace = true, features = [ + "hardcoded-credentials", +], optional = true } +aws-sdk-bedrockruntime = { workspace = true, features = [ + "behavior-version-latest", +], optional = true } secrecy = { workspace = true, optional = true } reqwest = { workspace = true, optional = true } ollama-rs = { workspace = true, optional = true } -deadpool = { workspace = true, features = ["managed", "rt_tokio_1"], optional = true } +deadpool = { workspace = true, features = [ + "managed", + "rt_tokio_1", +], optional = true } fluvio = { workspace = true, optional = true } arrow-array = { workspace = true, optional = true } lancedb = { workspace = true, optional = true } @@ -82,12 +102,15 @@ test-case = { workspace = true } indoc = { workspace = true } insta = { workspace = true } + [features] default = ["rustls"] # Ensures rustls is used rustls = ["reqwest/rustls-tls-native-roots"] # Qdrant for storage qdrant = ["dep:qdrant-client", "swiftide-core/qdrant"] +# PgVector for storage +pgvector = ["dep:sqlx", "dep:pgvector"] # Redis for caching and storage redis = ["dep:redis"] # Tree-sitter for code operations and chunking diff --git a/swiftide-integrations/src/lib.rs b/swiftide-integrations/src/lib.rs index d1e38a08..74f1c1d6 100644 --- a/swiftide-integrations/src/lib.rs +++ b/swiftide-integrations/src/lib.rs @@ -16,6 +16,8 @@ pub mod ollama; pub mod openai; #[cfg(feature = "parquet")] pub mod parquet; +#[cfg(feature = "pgvector")] +pub mod pgvector; #[cfg(feature = "qdrant")] pub mod qdrant; #[cfg(feature = "redb")] diff --git a/swiftide-integrations/src/pgvector/fixtures.rs b/swiftide-integrations/src/pgvector/fixtures.rs new file mode 100644 index 00000000..6508893a --- /dev/null +++ b/swiftide-integrations/src/pgvector/fixtures.rs @@ -0,0 +1,186 @@ +//! Test fixtures and utilities for pgvector integration testing. +//! +//! Provides test infrastructure and helper types to verify vector storage and retrieval: +//! - Mock data generation for different embedding modes +//! - Test containers for `PostgreSQL` with pgvector extension +//! - Common test scenarios and assertions +//! +//! # Examples +//! +//! ```rust +//! use swiftide_integrations::pgvector::fixtures::{TestContext, PgVectorTestData}; +//! use swiftide_core::indexing::{EmbedMode, EmbeddedField}; +//! +//! # async fn example() -> Result<(), Box> { +//! // Initialize test context with PostgreSQL container +//! let context = TestContext::setup_with_cfg( +//! Some(vec!["category", "priority"]), +//! vec![EmbeddedField::Combined].into_iter().collect() +//! ).await?; +//! +//! // Create test data for different embedding modes +//! let test_data = PgVectorTestData { +//! embed_mode: EmbedMode::SingleWithMetadata, +//! chunk: "test content", +//! metadata: None, +//! vectors: vec![PgVectorTestData::create_test_vector( +//! EmbeddedField::Combined, +//! 1.0 +//! )], +//! }; +//! # Ok(()) +//! # } +//! ``` +//! +//! The module supports testing for: +//! - Single embedding with/without metadata +//! - Per-field embeddings +//! - Combined embedding modes +//! - Different vector configurations +//! - Various metadata scenarios +use crate::pgvector::PgVector; +use std::collections::HashSet; +use swiftide_core::{ + indexing::{self, EmbeddedField}, + Persist, +}; +use testcontainers::{ContainerAsync, GenericImage}; + +/// Test data structure for pgvector integration testing. +/// +/// Provides a flexible structure to test different embedding modes and configurations, +/// including metadata handling and vector generation. +/// +/// # Examples +/// +/// ```rust +/// use swiftide_integrations::pgvector::fixtures::PgVectorTestData; +/// use swiftide_core::indexing::{EmbedMode, EmbeddedField}; +/// +/// let test_data = PgVectorTestData { +/// embed_mode: EmbedMode::SingleWithMetadata, +/// chunk: "test content", +/// metadata: None, +/// vectors: vec![PgVectorTestData::create_test_vector( +/// EmbeddedField::Combined, +/// 1.0 +/// )], +/// }; +/// ``` +#[derive(Clone)] +pub(crate) struct PgVectorTestData<'a> { + /// Embedding mode for the test case + pub embed_mode: indexing::EmbedMode, + /// Test content chunk + pub chunk: &'a str, + /// Optional metadata for testing metadata handling + pub metadata: Option, + /// Vector embeddings with their corresponding fields + pub vectors: Vec<(indexing::EmbeddedField, Vec)>, +} + +impl<'a> PgVectorTestData<'a> { + pub(crate) fn to_node(&self) -> indexing::Node { + // Create the initial builder + let mut base_builder = indexing::Node::builder(); + + // Set the required fields + let mut builder = base_builder.chunk(self.chunk).embed_mode(self.embed_mode); + + // Add metadata if it exists + if let Some(metadata) = &self.metadata { + builder = builder.metadata(metadata.clone()); + } + + // Build the node and add vectors + let mut node = builder.build().unwrap(); + node.vectors = Some(self.vectors.clone().into_iter().collect()); + node + } + + pub(crate) fn create_test_vector( + field: EmbeddedField, + base_value: f32, + ) -> (EmbeddedField, Vec) { + (field, vec![base_value; 384]) + } +} + +/// Test context managing `PostgreSQL` container and pgvector storage. +/// +/// Handles the lifecycle of test containers and provides configured storage +/// instances for testing. +/// +/// # Examples +/// +/// ```rust +/// # use swiftide_integrations::pgvector::fixtures::TestContext; +/// # use swiftide_core::indexing::EmbeddedField; +/// # async fn example() -> Result<(), Box> { +/// // Setup test context with specific configuration +/// let context = TestContext::setup_with_cfg( +/// Some(vec!["category"]), +/// vec![EmbeddedField::Combined].into_iter().collect() +/// ).await?; +/// +/// // Use context for testing +/// context.pgv_storage.setup().await?; +/// # Ok(()) +/// # } +/// ``` +pub(crate) struct TestContext { + /// Configured pgvector storage instance + pub(crate) pgv_storage: PgVector, + /// Container instance running `PostgreSQL` with pgvector + _pgv_db_container: ContainerAsync, +} + +impl TestContext { + /// Set up the test context, initializing `PostgreSQL` and `PgVector` storage + /// with configurable metadata fields + pub(crate) async fn setup_with_cfg( + metadata_fields: Option>, + vector_fields: HashSet, + ) -> Result> { + // Start `PostgreSQL` container and obtain the connection URL + let (pgv_db_container, pgv_db_url) = swiftide_test_utils::start_postgres().await; + tracing::info!("Postgres database URL: {:#?}", pgv_db_url); + + // Initialize the connection pool outside of the builder chain + let mut connection_pool = PgVector::builder(); + + // Configure PgVector storage + let mut builder = connection_pool + .db_url(pgv_db_url) + .vector_size(384) + .table_name("swiftide_pgvector_test".to_string()); + + // Add all vector fields + for vector_field in vector_fields { + builder = builder.with_vector(vector_field); + } + + // Add all metadata fields + if let Some(metadata_fields_inner) = metadata_fields { + for field in metadata_fields_inner { + builder = builder.with_metadata(field); + } + }; + + let pgv_storage = builder.build().map_err(|err| { + tracing::error!("Failed to build PgVector: {}", err); + err + })?; + + // Set up PgVector storage (create the table if not exists) + pgv_storage.setup().await.map_err(|err| { + tracing::error!("PgVector setup failed: {}", err); + err + })?; + + Ok(Self { + pgv_storage, + _pgv_db_container: pgv_db_container, + }) + } +} diff --git a/swiftide-integrations/src/pgvector/mod.rs b/swiftide-integrations/src/pgvector/mod.rs new file mode 100644 index 00000000..51ffcf8c --- /dev/null +++ b/swiftide-integrations/src/pgvector/mod.rs @@ -0,0 +1,432 @@ +//! Integration module for `PostgreSQL` vector database (pgvector) operations. +//! +//! This module provides a client interface for vector similarity search operations using pgvector, +//! supporting: +//! - Vector collection management with configurable schemas +//! - Efficient vector storage and indexing +//! - Connection pooling with automatic retries +//! - Batch operations for optimized performance +//! +//! The functionality is primarily used through the [`PgVector`] client, which implements +//! the [`Persist`] trait for seamless integration with indexing and query pipelines. +//! +//! # Example +//! ```rust +//! # use swiftide_integrations::pgvector::PgVector; +//! # async fn example() -> anyhow::Result<()> { +//! let client = PgVector::builder() +//! .db_url("postgresql://localhost:5432/vectors") +//! .vector_size(384) +//! .build()?; +//! +//! # Ok(()) +//! # } +//! ``` +#[cfg(test)] +mod fixtures; + +mod persist; +mod pgv_table_types; +use anyhow::Result; +use derive_builder::Builder; +use sqlx::PgPool; +use std::fmt; +use std::sync::Arc; +use std::sync::OnceLock; +use tokio::time::Duration; + +use pgv_table_types::{FieldConfig, MetadataConfig, VectorConfig}; + +/// Default maximum connections for the database connection pool. +const DB_POOL_CONN_MAX: u32 = 10; + +/// Default maximum retries for database connection attempts. +const DB_POOL_CONN_RETRY_MAX: u32 = 3; + +/// Delay between connection retry attempts, in seconds. +const DB_POOL_CONN_RETRY_DELAY_SECS: u64 = 3; + +/// Default batch size for storing nodes. +const BATCH_SIZE: usize = 50; + +/// Represents a Pgvector client with configuration options. +/// +/// This struct is used to interact with the Pgvector vector database, providing methods to manage vector collections, +/// store data, and ensure efficient searches. The client can be cloned with low cost as it shares connections. +#[derive(Builder, Clone)] +#[builder(setter(into, strip_option), build_fn(error = "anyhow::Error"))] +pub struct PgVector { + /// Name of the table to store vectors. + #[builder(default = "String::from(\"swiftide_pgv_store\")")] + table_name: String, + + /// Default vector size; can be customized per configuration. + vector_size: i32, + + /// Batch size for storing nodes. + #[builder(default = "BATCH_SIZE")] + batch_size: usize, + + /// Field configurations for the `PgVector` table schema. + /// + /// Supports multiple field types (see [`FieldConfig`]). + #[builder(default)] + fields: Vec, + + /// Database connection URL. + db_url: String, + + /// Maximum connections allowed in the connection pool. + #[builder(default = "DB_POOL_CONN_MAX")] + db_max_connections: u32, + + /// Maximum retry attempts for establishing a database connection. + #[builder(default = "DB_POOL_CONN_RETRY_MAX")] + db_max_retry: u32, + + /// Delay between retry attempts for database connections. + #[builder(default = "Duration::from_secs(DB_POOL_CONN_RETRY_DELAY_SECS)")] + db_conn_retry_delay: Duration, + + /// Lazy-initialized database connection pool. + #[builder(default = "Arc::new(OnceLock::new())")] + connection_pool: Arc>, + + /// SQL statement used for executing bulk insert. + #[builder(default = "Arc::new(OnceLock::new())")] + sql_stmt_bulk_insert: Arc>, +} + +impl fmt::Debug for PgVector { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("PgVector") + .field("table_name", &self.table_name) + .field("vector_size", &self.vector_size) + .field("batch_size", &self.batch_size) + .finish() + } +} + +impl PgVector { + /// Creates a new instance of `PgVectorBuilder` with default settings. + /// + /// # Returns + /// + /// A new `PgVectorBuilder`. + pub fn builder() -> PgVectorBuilder { + PgVectorBuilder::default() + } + + /// Retrieves a connection pool for `PostgreSQL`. + /// + /// This function returns the connection pool used for interacting with the `PostgreSQL` database. + /// It fetches the pool from the `PgDBConnectionPool` struct. + /// + /// # Returns + /// + /// A `Result` that, on success, contains the `PgPool` representing the database connection pool. + /// On failure, an error is returned. + /// + /// # Errors + /// + /// This function will return an error if it fails to retrieve the connection pool, which could occur + /// if the underlying connection to `PostgreSQL` has not been properly established. + pub async fn get_pool(&self) -> Result<&PgPool> { + self.pool_get_or_initialize().await + } +} + +impl PgVectorBuilder { + /// Adds a vector configuration to the builder. + /// + /// # Arguments + /// + /// * `config` - The vector configuration to add, which can be converted into a `VectorConfig`. + /// + /// # Returns + /// + /// A mutable reference to the builder with the new vector configuration added. + pub fn with_vector(&mut self, config: impl Into) -> &mut Self { + // Use `get_or_insert_with` to initialize `fields` if it's `None` + self.fields + .get_or_insert_with(Self::default_fields) + .push(FieldConfig::Vector(config.into())); + + self + } + + /// Sets the metadata configuration for the vector similarity search. + /// + /// This method allows you to specify metadata configurations for vector similarity search using `MetadataConfig`. + /// The provided configuration will be added as a new field in the builder. + /// + /// # Arguments + /// + /// * `config` - The metadata configuration to use. + /// + /// # Returns + /// + /// * Returns a mutable reference to `self` for method chaining. + pub fn with_metadata(&mut self, config: impl Into) -> &mut Self { + // Use `get_or_insert_with` to initialize `fields` if it's `None` + self.fields + .get_or_insert_with(Self::default_fields) + .push(FieldConfig::Metadata(config.into())); + + self + } + + fn default_fields() -> Vec { + vec![FieldConfig::ID, FieldConfig::Chunk] + } +} + +#[cfg(test)] +mod tests { + use crate::pgvector::fixtures::{PgVectorTestData, TestContext}; + use futures_util::TryStreamExt; + use std::collections::HashSet; + use swiftide_core::{ + indexing::{self, EmbedMode, EmbeddedField}, + Persist, + }; + use test_case::test_case; + + #[test_case( + // SingleWithMetadata - No Metadata + vec![ + PgVectorTestData { + embed_mode: EmbedMode::SingleWithMetadata, + chunk: "single_no_meta_1", + metadata: None, + vectors: vec![PgVectorTestData::create_test_vector(EmbeddedField::Combined, 1.0)], + }, + PgVectorTestData { + embed_mode: EmbedMode::SingleWithMetadata, + chunk: "single_no_meta_2", + metadata: None, + vectors: vec![PgVectorTestData::create_test_vector(EmbeddedField::Combined, 1.1)], + } + ], + HashSet::from([EmbeddedField::Combined]) + ; "SingleWithMetadata mode without metadata")] + #[test_case( + // SingleWithMetadata - With Metadata + vec![ + PgVectorTestData { + embed_mode: EmbedMode::SingleWithMetadata, + chunk: "single_with_meta_1", + metadata: Some(vec![ + ("category", "A"), + ("priority", "high") + ].into()), + vectors: vec![PgVectorTestData::create_test_vector(EmbeddedField::Combined, 1.2)], + }, + PgVectorTestData { + embed_mode: EmbedMode::SingleWithMetadata, + chunk: "single_with_meta_2", + metadata: Some(vec![ + ("category", "B"), + ("priority", "low") + ].into()), + vectors: vec![PgVectorTestData::create_test_vector(EmbeddedField::Combined, 1.3)], + } + ], + HashSet::from([EmbeddedField::Combined]) + ; "SingleWithMetadata mode with metadata")] + #[test_case( + // PerField - No Metadata + vec![ + PgVectorTestData { + embed_mode: EmbedMode::PerField, + chunk: "per_field_no_meta_1", + metadata: None, + vectors: vec![ + PgVectorTestData::create_test_vector(EmbeddedField::Chunk, 1.2), + PgVectorTestData::create_test_vector(EmbeddedField::Metadata("category".into()), 2.2), + PgVectorTestData::create_test_vector(EmbeddedField::Metadata("priority".into()), 3.2), + ], + }, + PgVectorTestData { + embed_mode: EmbedMode::PerField, + chunk: "per_field_no_meta_2", + metadata: None, + vectors: vec![ + PgVectorTestData::create_test_vector(EmbeddedField::Chunk, 1.3), + PgVectorTestData::create_test_vector(EmbeddedField::Metadata("category".into()), 2.3), + PgVectorTestData::create_test_vector(EmbeddedField::Metadata("priority".into()), 3.3), + ], + } + ], + HashSet::from([ + EmbeddedField::Chunk, + EmbeddedField::Metadata("category".into()), + EmbeddedField::Metadata("priority".into()), + ]) + ; "PerField mode without metadata")] + #[test_case( + // PerField - With Metadata + vec![ + PgVectorTestData { + embed_mode: EmbedMode::PerField, + chunk: "single_with_meta_1", + metadata: Some(vec![ + ("category", "A"), + ("priority", "high") + ].into()), + vectors: vec![ + PgVectorTestData::create_test_vector(EmbeddedField::Chunk, 1.2), + PgVectorTestData::create_test_vector(EmbeddedField::Metadata("category".into()), 2.2), + PgVectorTestData::create_test_vector(EmbeddedField::Metadata("priority".into()), 3.2), + ], + }, + PgVectorTestData { + embed_mode: EmbedMode::PerField, + chunk: "single_with_meta_2", + metadata: Some(vec![ + ("category", "B"), + ("priority", "low") + ].into()), + vectors: vec![ + PgVectorTestData::create_test_vector(EmbeddedField::Chunk, 1.3), + PgVectorTestData::create_test_vector(EmbeddedField::Metadata("category".into()), 2.3), + PgVectorTestData::create_test_vector(EmbeddedField::Metadata("priority".into()), 3.3), + ], + } + ], + HashSet::from([ + EmbeddedField::Chunk, + EmbeddedField::Metadata("category".into()), + EmbeddedField::Metadata("priority".into()), + ]) + ; "PerField mode with metadata")] + #[test_case( + // Both - No Metadata + vec![ + PgVectorTestData { + embed_mode: EmbedMode::Both, + chunk: "both_no_meta_1", + metadata: None, + vectors: vec![ + PgVectorTestData::create_test_vector(EmbeddedField::Combined, 3.0), + PgVectorTestData::create_test_vector(EmbeddedField::Chunk, 3.1) + ], + }, + PgVectorTestData { + embed_mode: EmbedMode::Both, + chunk: "both_no_meta_2", + metadata: None, + vectors: vec![ + PgVectorTestData::create_test_vector(EmbeddedField::Combined, 3.2), + PgVectorTestData::create_test_vector(EmbeddedField::Chunk, 3.3) + ], + } + ], + HashSet::from([EmbeddedField::Combined, EmbeddedField::Chunk]) + ; "Both mode without metadata")] + #[test_case( + // Both - With Metadata + vec![ + PgVectorTestData { + embed_mode: EmbedMode::Both, + chunk: "both_with_meta_1", + metadata: Some(vec![ + ("category", "P"), + ("priority", "urgent"), + ("tag", "test1") + ].into()), + vectors: vec![ + PgVectorTestData::create_test_vector(EmbeddedField::Combined, 3.4), + PgVectorTestData::create_test_vector(EmbeddedField::Chunk, 3.5), + PgVectorTestData::create_test_vector(EmbeddedField::Metadata("category".into()), 3.6), + PgVectorTestData::create_test_vector(EmbeddedField::Metadata("priority".into()), 3.7), + PgVectorTestData::create_test_vector(EmbeddedField::Metadata("tag".into()), 3.8) + ], + }, + PgVectorTestData { + embed_mode: EmbedMode::Both, + chunk: "both_with_meta_2", + metadata: Some(vec![ + ("category", "Q"), + ("priority", "low"), + ("tag", "test2") + ].into()), + vectors: vec![ + PgVectorTestData::create_test_vector(EmbeddedField::Combined, 3.9), + PgVectorTestData::create_test_vector(EmbeddedField::Chunk, 4.0), + PgVectorTestData::create_test_vector(EmbeddedField::Metadata("category".into()), 4.1), + PgVectorTestData::create_test_vector(EmbeddedField::Metadata("priority".into()), 4.2), + PgVectorTestData::create_test_vector(EmbeddedField::Metadata("tag".into()), 4.3) + ], + } + ], + HashSet::from([ + EmbeddedField::Combined, + EmbeddedField::Chunk, + EmbeddedField::Metadata("category".into()), + EmbeddedField::Metadata("priority".into()), + EmbeddedField::Metadata("tag".into()), + ]) + ; "Both mode with metadata")] + #[test_log::test(tokio::test)] + async fn test_persist_nodes( + test_cases: Vec>, + vector_fields: HashSet, + ) { + // Extract all possible metadata fields from test cases + let metadata_fields: Vec<&str> = test_cases + .iter() + .filter_map(|case| case.metadata.as_ref()) + .flat_map(|metadata| metadata.iter().map(|(key, _)| key.as_str())) + .collect::>() + .into_iter() + .collect(); + + // Initialize test context with all required metadata fields + let test_context = TestContext::setup_with_cfg(Some(metadata_fields), vector_fields) + .await + .expect("Test setup failed"); + + // Convert test cases to nodes and store them + let nodes: Vec = test_cases.iter().map(PgVectorTestData::to_node).collect(); + + // Test batch storage + let stored_nodes = test_context + .pgv_storage + .batch_store(nodes.clone()) + .await + .try_collect::>() + .await + .expect("Failed to store nodes"); + + assert_eq!( + stored_nodes.len(), + nodes.len(), + "All nodes should be stored" + ); + + // Verify storage for each test case + for (test_case, stored_node) in test_cases.iter().zip(stored_nodes.iter()) { + // 1. Verify basic node properties + assert_eq!( + stored_node.chunk, test_case.chunk, + "Stored chunk should match" + ); + assert_eq!( + stored_node.embed_mode, test_case.embed_mode, + "Embed mode should match" + ); + + // 2. Verify vectors were stored correctly + let stored_vectors = stored_node + .vectors + .as_ref() + .expect("Vectors should be present"); + assert_eq!( + stored_vectors.len(), + test_case.vectors.len(), + "Vector count should match" + ); + } + } +} diff --git a/swiftide-integrations/src/pgvector/persist.rs b/swiftide-integrations/src/pgvector/persist.rs new file mode 100644 index 00000000..6b9973ae --- /dev/null +++ b/swiftide-integrations/src/pgvector/persist.rs @@ -0,0 +1,93 @@ +//! Storage persistence implementation for vector embeddings. +//! +//! Implements the [`Persist`] trait for [`PgVector`], providing vector storage capabilities: +//! - Database schema initialization and setup +//! - Single-node storage operations +//! - Optimized batch storage with configurable batch sizes +//! +//! The implementation ensures thread-safe concurrent access and handles +//! connection management automatically. +use crate::pgvector::PgVector; +use anyhow::{anyhow, Result}; +use async_trait::async_trait; +use swiftide_core::{ + indexing::{IndexingStream, Node}, + Persist, +}; + +#[async_trait] +impl Persist for PgVector { + #[tracing::instrument(skip_all)] + async fn setup(&self) -> Result<()> { + // Get or initialize the connection pool + let pool = self.pool_get_or_initialize().await?; + + if self.sql_stmt_bulk_insert.get().is_none() { + let sql = self.generate_unnest_upsert_sql()?; + + self.sql_stmt_bulk_insert + .set(sql) + .map_err(|_| anyhow!("SQL bulk store statement is already set"))?; + } + + let mut tx = pool.begin().await?; + + // Create extension + let sql = "CREATE EXTENSION IF NOT EXISTS vector"; + sqlx::query(sql).execute(&mut *tx).await?; + + // Create table + let create_table_sql = self.generate_create_table_sql()?; + sqlx::query(&create_table_sql).execute(&mut *tx).await?; + + // Create HNSW index + let index_sql = self.create_index_sql()?; + sqlx::query(&index_sql).execute(&mut *tx).await?; + + tx.commit().await?; + + Ok(()) + } + + #[tracing::instrument(skip_all)] + async fn store(&self, node: Node) -> Result { + let mut nodes = vec![node; 1]; + self.store_nodes(&nodes).await?; + + let node = nodes.swap_remove(0); + + Ok(node) + } + + #[tracing::instrument(skip_all)] + async fn batch_store(&self, nodes: Vec) -> IndexingStream { + self.store_nodes(&nodes).await.map(|()| nodes).into() + } + + fn batch_size(&self) -> Option { + Some(self.batch_size) + } +} + +#[cfg(test)] +mod tests { + use crate::pgvector::fixtures::TestContext; + use std::collections::HashSet; + use swiftide_core::{indexing::EmbeddedField, Persist}; + + #[test_log::test(tokio::test)] + async fn test_persist_setup_no_error_when_table_exists() { + let test_context = TestContext::setup_with_cfg( + vec!["filter"].into(), + HashSet::from([EmbeddedField::Combined]), + ) + .await + .expect("Test setup failed"); + + test_context + .pgv_storage + .setup() + .await + .expect("PgVector setup should not fail when the table already exists"); + } +} diff --git a/swiftide-integrations/src/pgvector/pgv_table_types.rs b/swiftide-integrations/src/pgvector/pgv_table_types.rs new file mode 100644 index 00000000..f49f9211 --- /dev/null +++ b/swiftide-integrations/src/pgvector/pgv_table_types.rs @@ -0,0 +1,575 @@ +//! `PostgreSQL` table schema and type conversion utilities for vector storage. +//! +//! Provides schema configuration and data type conversion functionality: +//! - Table schema generation with vector and metadata columns +//! - Field configuration for different vector embedding types +//! - HNSW index creation for similarity search optimization +//! - Bulk data preparation and SQL query generation +//! +use crate::pgvector::PgVector; +use anyhow::{anyhow, Result}; +use pgvector as ExtPgVector; +use regex::Regex; +use sqlx::postgres::PgArguments; +use sqlx::postgres::PgPoolOptions; +use sqlx::PgPool; +use swiftide_core::indexing::{EmbeddedField, Node}; +use tokio::time::sleep; + +/// Configuration for vector embedding columns in the `PostgreSQL` table. +/// +/// This struct defines how vector embeddings are stored and managed in the database, +/// mapping Swiftide's embedded fields to `PostgreSQL` vector columns. +#[derive(Clone, Debug)] +pub struct VectorConfig { + embedded_field: EmbeddedField, + field: String, +} + +impl VectorConfig { + pub fn new(embedded_field: &EmbeddedField) -> Self { + Self { + embedded_field: embedded_field.clone(), + field: format!( + "vector_{}", + PgVector::normalize_field_name(&embedded_field.to_string()), + ), + } + } +} + +impl From for VectorConfig { + fn from(val: EmbeddedField) -> Self { + Self::new(&val) + } +} + +/// Configuration for metadata fields in the `PostgreSQL` table. +/// +/// Handles the mapping and storage of metadata fields, ensuring proper column naming +/// and type conversion for `PostgreSQL` compatibility. +#[derive(Clone, Debug)] +pub struct MetadataConfig { + field: String, + original_field: String, +} + +impl MetadataConfig { + pub fn new>(original_field: T) -> Self { + let original = original_field.into(); + Self { + field: format!("meta_{}", PgVector::normalize_field_name(&original)), + original_field: original, + } + } +} + +impl> From for MetadataConfig { + fn from(val: T) -> Self { + Self::new(val.as_ref()) + } +} + +/// Field configuration types supported in the `PostgreSQL` table schema. +/// +/// Represents different field types that can be configured in the table schema, +/// including vector embeddings, metadata, and system fields. +#[derive(Clone, Debug)] +pub enum FieldConfig { + /// `Vector` - Vector embedding field configuration + Vector(VectorConfig), + /// `Metadata` - Metadata field configuration + Metadata(MetadataConfig), + /// `Chunk` - Text content storage field + Chunk, + /// `ID` - Primary key field + ID, +} + +impl FieldConfig { + pub fn field_name(&self) -> &str { + match self { + FieldConfig::Vector(config) => &config.field, + FieldConfig::Metadata(config) => &config.field, + FieldConfig::Chunk => "chunk", + FieldConfig::ID => "id", + } + } +} + +/// Internal structure for managing bulk upsert operations. +/// +/// Collects and organizes data for efficient bulk insertions and updates, +/// grouping related fields for UNNEST-based operations. +struct BulkUpsertData<'a> { + ids: Vec, + chunks: Vec<&'a str>, + metadata_fields: Vec>, + vector_fields: Vec>, + field_mapping: FieldMapping<'a>, +} + +struct FieldMapping<'a> { + metadata_names: Vec<&'a str>, + vector_names: Vec<&'a str>, +} + +impl<'a> BulkUpsertData<'a> { + fn new(fields: &'a [FieldConfig], size: usize) -> Self { + let (metadata_names, vector_names): (Vec<&str>, Vec<&str>) = ( + fields + .iter() + .filter_map(|field| match field { + FieldConfig::Metadata(config) => Some(config.field.as_str()), + _ => None, + }) + .collect(), + fields + .iter() + .filter_map(|field| match field { + FieldConfig::Vector(config) => Some(config.field.as_str()), + _ => None, + }) + .collect(), + ); + + Self { + ids: Vec::with_capacity(size), + chunks: Vec::with_capacity(size), + metadata_fields: vec![Vec::with_capacity(size); metadata_names.len()], + vector_fields: vec![Vec::with_capacity(size); vector_names.len()], + field_mapping: FieldMapping { + metadata_names, + vector_names, + }, + } + } + + fn get_metadata_index(&self, field: &str) -> Option { + self.field_mapping + .metadata_names + .iter() + .position(|&name| name == field) + } + + fn get_vector_index(&self, field: &str) -> Option { + self.field_mapping + .vector_names + .iter() + .position(|&name| name == field) + } +} + +impl PgVector { + /// Generates a SQL statement to create a table for storing vector embeddings. + /// + /// The table will include columns for an ID, chunk data, metadata, and a vector embedding. + /// + /// # Returns + /// + /// * The generated SQL statement. + /// + /// # Errors + /// + /// * Returns an error if the table name is invalid or if `vector_size` is not configured. + pub fn generate_create_table_sql(&self) -> Result { + // Validate table_name and field_name (e.g., check against allowed patterns) + if !Self::is_valid_identifier(&self.table_name) { + return Err(anyhow::anyhow!("Invalid table name")); + } + + let columns: Vec = self + .fields + .iter() + .map(|field| match field { + FieldConfig::ID => "id UUID NOT NULL".to_string(), + FieldConfig::Chunk => format!("{} TEXT NOT NULL", field.field_name()), + FieldConfig::Metadata(_) => format!("{} JSONB", field.field_name()), + FieldConfig::Vector(_) => { + format!("{} VECTOR({})", field.field_name(), self.vector_size) + } + }) + .chain(std::iter::once("PRIMARY KEY (id)".to_string())) + .collect(); + + let sql = format!( + "CREATE TABLE IF NOT EXISTS {} (\n {}\n)", + self.table_name, + columns.join(",\n ") + ); + + Ok(sql) + } + + /// Generates the SQL statement to create an HNSW index on the vector column. + /// + /// # Errors + /// + /// Returns an error if: + /// - No vector field is found in the table configuration. + /// - The table name or field name is invalid. + pub fn create_index_sql(&self) -> Result { + let index_name = format!("{}_embedding_idx", self.table_name); + let vector_field = self + .fields + .iter() + .find(|f| matches!(f, FieldConfig::Vector(_))) + .ok_or_else(|| anyhow::anyhow!("No vector field found in configuration"))? + .field_name(); + + // Validate table_name and field_name (e.g., check against allowed patterns) + if !Self::is_valid_identifier(&self.table_name) + || !Self::is_valid_identifier(&index_name) + || !Self::is_valid_identifier(vector_field) + { + return Err(anyhow::anyhow!("Invalid table or field name")); + } + + Ok(format!( + "CREATE INDEX IF NOT EXISTS {} ON {} USING hnsw ({} vector_cosine_ops)", + index_name, &self.table_name, vector_field + )) + } + + /// Stores a list of nodes in the database using an upsert operation. + /// + /// # Arguments + /// + /// * `nodes` - A slice of `Node` objects to be stored. + /// + /// # Returns + /// + /// * `Result<()>` - `Ok` if all nodes are successfully stored, `Err` otherwise. + /// + /// # Errors + /// + /// This function will return an error if: + /// - The database connection pool is not established. + /// - Any of the SQL queries fail to execute due to schema mismatch, constraint violations, or connectivity issues. + /// - Committing the transaction fails. + pub async fn store_nodes(&self, nodes: &[Node]) -> Result<()> { + let pool = self.pool_get_or_initialize().await?; + + let mut tx = pool.begin().await?; + let bulk_data = self.prepare_bulk_data(nodes)?; + + let sql = self + .sql_stmt_bulk_insert + .get() + .ok_or_else(|| anyhow!("SQL bulk insert statement not set"))?; + + tracing::info!("Sql statement :: {:#?}", sql); + + let query = self.bind_bulk_data_to_query(sqlx::query(sql), &bulk_data)?; + + query + .execute(&mut *tx) + .await + .map_err(|e| anyhow!("Failed to store nodes: {:?}", e))?; + + tx.commit() + .await + .map_err(|e| anyhow!("Failed to commit transaction: {:?}", e)) + } + + /// Prepares data from nodes into vectors for bulk processing. + #[allow(clippy::implicit_clone)] + fn prepare_bulk_data<'a>(&'a self, nodes: &'a [Node]) -> Result> { + let mut bulk_data = BulkUpsertData::new(&self.fields, nodes.len()); + + for node in nodes { + bulk_data.ids.push(node.id()); + bulk_data.chunks.push(node.chunk.as_str()); + + for field in &self.fields { + match field { + FieldConfig::Metadata(config) => { + let idx = bulk_data + .get_metadata_index(config.field.as_str()) + .ok_or_else(|| anyhow!("Invalid metadata field"))?; + + let value = node + .metadata + .get(&config.original_field) + .ok_or_else(|| anyhow!("Missing metadata field"))?; + + bulk_data.metadata_fields[idx].push(value.clone()); + } + FieldConfig::Vector(config) => { + let idx = bulk_data + .get_vector_index(config.field.as_str()) + .ok_or_else(|| anyhow!("Invalid vector field"))?; + + let data = node + .vectors + .as_ref() + .and_then(|v| v.get(&config.embedded_field)) + .map(|v| v.to_vec()) + .unwrap_or_default(); + + bulk_data.vector_fields[idx].push(ExtPgVector::Vector::from(data)); + } + _ => continue, + } + } + } + + Ok(bulk_data) + } + + /// Generates SQL for UNNEST-based bulk upsert. + /// + /// # Returns + /// + /// * `Result` - The generated SQL statement or an error if fields are empty. + /// + /// # Errors + /// + /// Returns an error if `self.fields` is empty, as no valid SQL can be generated. + pub(crate) fn generate_unnest_upsert_sql(&self) -> Result { + if self.fields.is_empty() { + return Err(anyhow!("Cannot generate upsert SQL with empty fields")); + } + + let mut columns = Vec::new(); + let mut unnest_params = Vec::new(); + let mut param_counter = 1; + + for field in &self.fields { + let name = field.field_name(); + columns.push(name.to_string()); + + unnest_params.push(format!( + "${param_counter}::{}", + match field { + FieldConfig::ID => "UUID[]", + FieldConfig::Chunk => "TEXT[]", + FieldConfig::Metadata(_) => "JSONB[]", + FieldConfig::Vector(_) => "VECTOR[]", + } + )); + + param_counter += 1; + } + + let update_columns = self + .fields + .iter() + .filter(|field| !matches!(field, FieldConfig::ID)) // Skip ID field in updates + .map(|field| { + let name = field.field_name(); + format!("{name} = EXCLUDED.{name}") + }) + .collect::>() + .join(", "); + + Ok(format!( + r#" + INSERT INTO {} ({}) + SELECT {} + FROM UNNEST({}) AS t({}) + ON CONFLICT (id) DO UPDATE SET {}"#, + self.table_name, + columns.join(", "), + columns.join(", "), + unnest_params.join(", "), + columns.join(", "), + update_columns + )) + } + + /// Binds bulk data to the SQL query, ensuring data arrays are matched to corresponding fields. + /// + /// # Errors + /// + /// Returns an error if any metadata or vector field is missing from the bulk data. + #[allow(clippy::implicit_clone)] + fn bind_bulk_data_to_query<'a>( + &self, + mut query: sqlx::query::Query<'a, sqlx::Postgres, PgArguments>, + bulk_data: &'a BulkUpsertData, + ) -> Result> { + for field in &self.fields { + query = match field { + FieldConfig::ID => query.bind(&bulk_data.ids), + FieldConfig::Chunk => query.bind(&bulk_data.chunks), + FieldConfig::Vector(config) => { + let idx = bulk_data + .get_vector_index(config.field.as_str()) + .ok_or_else(|| { + anyhow!("Vector field {} not found in bulk data", config.field) + })?; + query.bind(&bulk_data.vector_fields[idx]) + } + FieldConfig::Metadata(config) => { + let idx = bulk_data + .get_metadata_index(config.field.as_str()) + .ok_or_else(|| { + anyhow!("Metadata field {} not found in bulk data", config.field) + })?; + query.bind(&bulk_data.metadata_fields[idx]) + } + }; + } + Ok(query) + } + + /// Retrieves the name of the vector column configured in the schema. + /// + /// # Returns + /// * `Ok(String)` - The name of the vector column if exactly one is configured. + /// # Errors + /// * `Error::NoEmbedding` - If no vector field is configured in the schema. + /// * `Error::MultipleEmbeddings` - If multiple vector fields are configured in the schema. + pub fn get_vector_column_name(&self) -> Result { + let vector_fields: Vec<_> = self + .fields + .iter() + .filter(|field| matches!(field, FieldConfig::Vector(_))) + .collect(); + + match vector_fields.as_slice() { + [field] => Ok(field.field_name().to_string()), + [] => Err(anyhow!("No vector field configured in schema")), + _ => Err(anyhow!("Multiple vector fields configured in schema")), + } + } +} + +impl PgVector { + pub(crate) fn normalize_field_name(field: &str) -> String { + // Define the special characters as an array + let special_chars: [char; 4] = ['(', '[', '{', '<']; + + // First split by special characters and take the first part + let base_text = field + .split(|c| special_chars.contains(&c)) + .next() + .unwrap_or(field) + .trim(); + + // Split by whitespace, take up to 3 words, convert to lowercase + let normalized = base_text + .split_whitespace() + .take(3) + .collect::>() + .join("_") + .to_lowercase(); + + // Ensure the result only contains alphanumeric chars and underscores + normalized + .chars() + .filter(|c| c.is_alphanumeric() || *c == '_') + .collect() + } + + pub(crate) fn is_valid_identifier(identifier: &str) -> bool { + // PostgreSQL identifier rules: + // 1. Must start with a letter (a-z) or underscore + // 2. Subsequent characters can be letters, underscores, digits (0-9), or dollar signs + // 3. Maximum length is 63 bytes + // 4. Cannot be a reserved keyword + + // Check length + if identifier.is_empty() || identifier.len() > 63 { + return false; + } + + // Use a regular expression to check the pattern + let identifier_regex = Regex::new(r"^[a-zA-Z_][a-zA-Z0-9_$]*$").unwrap(); + if !identifier_regex.is_match(identifier) { + return false; + } + + // Check if it's not a reserved keyword + !Self::is_reserved_keyword(identifier) + } + + pub(crate) fn is_reserved_keyword(word: &str) -> bool { + // This list is not exhaustive. You may want to expand it based on + // the PostgreSQL version you're using. + const RESERVED_KEYWORDS: &[&str] = &[ + "SELECT", "FROM", "WHERE", "INSERT", "UPDATE", "DELETE", "DROP", "CREATE", "TABLE", + "INDEX", "ALTER", "ADD", "COLUMN", "AND", "OR", "NOT", "NULL", "TRUE", + "FALSE", + // Add more keywords as needed + ]; + + RESERVED_KEYWORDS.contains(&word.to_uppercase().as_str()) + } +} + +impl PgVector { + async fn create_pool(&self) -> Result { + let pool_options = PgPoolOptions::new().max_connections(self.db_max_connections); + + for attempt in 1..=self.db_max_retry { + match pool_options.clone().connect(self.db_url.as_ref()).await { + Ok(pool) => { + tracing::info!("Successfully established database connection"); + return Ok(pool); + } + Err(err) if attempt < self.db_max_retry => { + tracing::warn!( + error = %err, + attempt = attempt, + max_retries = self.db_max_retry, + "Database connection attempt failed, retrying..." + ); + sleep(self.db_conn_retry_delay).await; + } + Err(err) => { + return Err(anyhow!(err).context("Failed to establish database connection")); + } + } + } + + Err(anyhow!( + "Max connection retries ({}) exceeded", + self.db_max_retry + )) + } + + /// Returns a reference to the `PgPool` if it is already initialized, + /// or creates and initializes it if it is not. + /// + /// # Errors + /// This function will return an error if pool creation fails. + pub async fn pool_get_or_initialize(&self) -> Result<&PgPool> { + if let Some(pool) = self.connection_pool.get() { + return Ok(pool); + } + + let pool = self.create_pool().await?; + self.connection_pool + .set(pool) + .map_err(|_| anyhow!("Pool already initialized"))?; + + // Re-check if the pool was set successfully, otherwise return an error + self.connection_pool + .get() + .ok_or_else(|| anyhow!("Failed to retrieve connection pool after setting it")) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_valid_identifiers() { + assert!(PgVector::is_valid_identifier("valid_name")); + assert!(PgVector::is_valid_identifier("_valid_name")); + assert!(PgVector::is_valid_identifier("valid_name_123")); + assert!(PgVector::is_valid_identifier("validName")); + } + + #[test] + fn test_invalid_identifiers() { + assert!(!PgVector::is_valid_identifier("")); // Empty string + assert!(!PgVector::is_valid_identifier(&"a".repeat(64))); // Too long + assert!(!PgVector::is_valid_identifier("123_invalid")); // Starts with a number + assert!(!PgVector::is_valid_identifier("invalid-name")); // Contains hyphen + assert!(!PgVector::is_valid_identifier("select")); // Reserved keyword + } +} diff --git a/swiftide-test-utils/Cargo.toml b/swiftide-test-utils/Cargo.toml index f31b5744..840729f7 100644 --- a/swiftide-test-utils/Cargo.toml +++ b/swiftide-test-utils/Cargo.toml @@ -13,7 +13,7 @@ homepage.workspace = true [dependencies] swiftide-core = { path = "../swiftide-core", features = ["test-utils"] } -swiftide-integrations = { path = "../swiftide-integrations", all-features = true } +swiftide-integrations = { path = "../swiftide-integrations", features = ["openai"] } async-openai = { workspace = true } qdrant-client = { workspace = true, default-features = false, features = [ diff --git a/swiftide-test-utils/src/test_utils.rs b/swiftide-test-utils/src/test_utils.rs index 86ba416a..3ce54df6 100644 --- a/swiftide-test-utils/src/test_utils.rs +++ b/swiftide-test-utils/src/test_utils.rs @@ -3,7 +3,9 @@ use serde_json::json; use testcontainers::{ - core::wait::HttpWaitStrategy, runners::AsyncRunner as _, ContainerAsync, GenericImage, + core::{wait::HttpWaitStrategy, IntoContainerPort, WaitFor}, + runners::AsyncRunner, + ContainerAsync, GenericImage, ImageExt, }; use wiremock::matchers::{method, path}; use wiremock::{Mock, MockServer, ResponseTemplate}; @@ -70,6 +72,31 @@ pub async fn start_redis() -> (ContainerAsync, String) { (redis, redis_url) } +/// Setup Postgres container. +/// Returns container server and `server_url`. +pub async fn start_postgres() -> (ContainerAsync, String) { + let postgres = testcontainers::GenericImage::new("pgvector/pgvector", "pg17") + .with_wait_for(WaitFor::message_on_stdout( + "database system is ready to accept connections", + )) + .with_exposed_port(5432.tcp()) + .with_env_var("POSTGRES_USER", "myuser") + .with_env_var("POSTGRES_PASSWORD", "mypassword") + .with_env_var("POSTGRES_DB", "mydatabase") + .start() + .await + .expect("Failed to start Postgres container"); + + // Construct the connection URL using the dynamically assigned port + let host_port = postgres.get_host_port_ipv4(5432).await.unwrap(); + let postgres_url = format!( + "postgresql://myuser:mypassword@127.0.0.1:{}/mydatabase", + host_port + ); + + (postgres, postgres_url) +} + /// Mock embeddings creation endpoint. /// `embeddings_count` controls number of returned embedding vectors. pub async fn mock_embeddings(mock_server: &MockServer, embeddings_count: u8) { diff --git a/swiftide/Cargo.toml b/swiftide/Cargo.toml index 202bb88e..766ee35d 100644 --- a/swiftide/Cargo.toml +++ b/swiftide/Cargo.toml @@ -35,6 +35,7 @@ all = [ "aws-bedrock", "groq", "ollama", + "pgvector", ] #! ### Integrations @@ -42,6 +43,9 @@ all = [ ## Enables Qdrant for storage and retrieval qdrant = ["swiftide-integrations/qdrant", "swiftide-core/qdrant"] +## Enables PgVector for storage and retrieval +pgvector = ["swiftide-integrations/pgvector"] + ## Enables Redis as an indexing cache and storage redis = ["swiftide-integrations/redis"]