Skip to content

Commit

Permalink
Merge branch 'master' into feat/setup-ci
Browse files Browse the repository at this point in the history
  • Loading branch information
timonv authored Jun 13, 2024
2 parents 8854767 + 4d79d27 commit f19ec5f
Show file tree
Hide file tree
Showing 15 changed files with 1,090 additions and 68 deletions.
707 changes: 701 additions & 6 deletions Cargo.lock

Large diffs are not rendered by default.

10 changes: 10 additions & 0 deletions swiftide/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,13 @@ tree-sitter = [
"dep:tree-sitter-javascript",
]
openai = ["dep:async-openai"]

[dev-dependencies]
test-log = "0.2.16"
testcontainers = "0.17.0"
mockall = "0.12.1"
temp-dir = "0.1.13"
wiremock = "0.6.0"

[lints.clippy]
blocks_in_conditions = "allow"
68 changes: 68 additions & 0 deletions swiftide/src/ingestion/ingestion_pipeline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,3 +161,71 @@ impl IngestionPipeline {
Ok(())
}
}

// Unit test for `IngestionPipeline` built entirely from the mocks that
// `automock` generates for the traits in `crate::traits`.
#[cfg(test)]
mod tests {

use super::*;
use crate::traits::*;
use futures_util::stream;
use mockall::Sequence;

// Runs a full pipeline (load -> transform -> batch transform -> chunk -> store)
// and checks that a single loaded node ends up as three stored chunks.
#[test_log::test(tokio::test)]
async fn test_simple_run() {
let mut loader = MockLoader::new();
let mut transformer = MockTransformer::new();
let mut batch_transformer = MockBatchableTransformer::new();
let mut chunker = MockChunkerTransformer::new();
let mut storage = MockStorage::new();

// Expectations registered with `in_sequence(&mut seq)` must be satisfied in
// the order they are declared below: loader, batch transformer, chunker, store.
let mut seq = Sequence::new();

// The loader is consumed exactly once and yields a single default node.
loader
.expect_into_stream()
.times(1)
.in_sequence(&mut seq)
.returning(|| Box::pin(stream::iter(vec![Ok(IngestionNode::default())])));

// Not part of the sequence and without a call-count constraint; it only
// overwrites the node's chunk text.
transformer.expect_transform_node().returning(|mut node| {
node.chunk = "transformed".to_string();
Ok(node)
});

// Passes every node of the batch through unchanged, wrapped in `Ok`.
batch_transformer
.expect_batch_transform()
.times(1)
.in_sequence(&mut seq)
.returning(|nodes| Box::pin(stream::iter(nodes.into_iter().map(Ok))));

// Fans the single incoming node out into three clones whose chunk text is
// `transformed_chunk_0` .. `transformed_chunk_2`.
chunker
.expect_transform_node()
.times(1)
.in_sequence(&mut seq)
.returning(|node| {
let mut nodes = vec![];
for i in 0..3 {
let mut node = node.clone();
node.chunk = format!("transformed_chunk_{}", i);
nodes.push(Ok(node));
}
Box::pin(stream::iter(nodes))
});

// Storage: setup succeeds, no batch size is reported, and `store` must be
// called three times, each time with a node produced by the chunker.
storage.expect_setup().returning(|| Ok(()));
storage.expect_batch_size().returning(|| None);
storage
.expect_store()
.times(3)
.in_sequence(&mut seq)
.withf(|node| node.chunk.starts_with("transformed_chunk_"))
.returning(|_| Ok(()));

// NOTE(review): the `1` passed to `then_in_batch` is presumably the batch
// size/concurrency -- confirm against `IngestionPipeline::then_in_batch`.
let pipeline = IngestionPipeline::from_loader(loader)
.then(transformer)
.then_in_batch(1, batch_transformer)
.then_chunk(chunker)
.store_with(storage);

// Any pipeline error fails the test; unmet mock expectations are reported
// by mockall.
pipeline.run().await.unwrap();
}
}
10 changes: 8 additions & 2 deletions swiftide/src/integrations/openai/embed.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use anyhow::Result;
use anyhow::{Context as _, Result};
use async_openai::types::CreateEmbeddingRequestArgs;
use async_trait::async_trait;

Expand All @@ -9,8 +9,14 @@ use super::OpenAI;
#[async_trait]
impl Embed for OpenAI {
async fn embed(&self, input: Vec<String>) -> Result<Embeddings> {
let model = self
.default_options
.embed_model
.as_ref()
.context("Model not set")?;

let request = CreateEmbeddingRequestArgs::default()
.model(&self.embed_model)
.model(model)
.input(input)
.build()?;
tracing::debug!(
Expand Down
44 changes: 40 additions & 4 deletions swiftide/src/integrations/openai/mod.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,45 @@
use std::sync::Arc;

use derive_builder::Builder;

mod embed;
mod simple_prompt;

#[derive(Debug)]
#[derive(Debug, Builder, Clone)]
pub struct OpenAI {
client: async_openai::Client<async_openai::config::OpenAIConfig>,
embed_model: String,
prompt_model: String,
#[builder(default = "Arc::new(async_openai::Client::new())", setter(custom))]
client: Arc<async_openai::Client<async_openai::config::OpenAIConfig>>,
#[builder(default)]
default_options: Options,
}

/// Default options used by the `OpenAI` integration.
///
/// Both models are optional. The `Embed` and `SimplePrompt` implementations
/// return a "Model not set" error when the model they need is `None`.
#[derive(Debug, Default, Clone, Builder)]
#[builder(setter(into, strip_option))]
pub struct Options {
/// Model name passed to the embedding request.
#[builder(default)]
pub embed_model: Option<String>,
/// Model name passed to the chat completion request.
#[builder(default)]
pub prompt_model: Option<String>,
}

impl Options {
pub fn builder() -> OptionsBuilder {
OptionsBuilder::default()
}
}

impl OpenAI {
pub fn builder() -> OpenAIBuilder {
OpenAIBuilder::default()
}
}

impl OpenAIBuilder {
    /// Sets the `async_openai` client the integration will use.
    ///
    /// This is the hand-written setter for the `client` field (declared with
    /// `setter(custom)`): callers pass a plain client and it is wrapped in an
    /// `Arc` before being stored. Returns the builder for chaining.
    pub fn client(
        &mut self,
        client: async_openai::Client<async_openai::config::OpenAIConfig>,
    ) -> &mut Self {
        let shared = Arc::new(client);
        self.client = Some(shared);
        self
    }
}
8 changes: 7 additions & 1 deletion swiftide/src/integrations/openai/simple_prompt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,14 @@ use anyhow::{Context as _, Result};
impl SimplePrompt for OpenAI {
#[tracing::instrument(skip(self), err)]
async fn prompt(&self, prompt: &str) -> Result<String> {
let model = self
.default_options
.prompt_model
.as_ref()
.context("Model not set")?;

let request = CreateChatCompletionRequestArgs::default()
.model(&self.prompt_model)
.model(model)
.messages(vec![ChatCompletionRequestUserMessageArgs::default()
.content(prompt)
.build()?
Expand Down
30 changes: 12 additions & 18 deletions swiftide/src/integrations/qdrant/mod.rs
Original file line number Diff line number Diff line change
@@ -1,39 +1,33 @@
mod ingestion_node;
mod persist;

use anyhow::Result;
use derive_builder::Builder;
use qdrant_client::client::QdrantClient;
use qdrant_client::prelude::*;
use qdrant_client::qdrant::vectors_config::Config;
use qdrant_client::qdrant::{VectorParams, VectorsConfig};

const DEFAULT_COLLECTION_NAME: &str = "swiftide";

/// Qdrant storage integration, constructed through the derived `QdrantBuilder`
/// (owned builder pattern).
#[derive(Builder)]
#[builder(pattern = "owned")]
pub struct Qdrant {
/// Client used to talk to Qdrant.
client: QdrantClient,
/// Name of the collection; defaults to `DEFAULT_COLLECTION_NAME`.
#[builder(default = "DEFAULT_COLLECTION_NAME.to_string()")]
collection_name: String,
/// The size (dimensions) of the embedding vectors being stored,
/// e.g. 1536 for small OpenAI embeddings. Has no builder default.
vector_size: usize,
/// Optional batch size; defaults to `None`.
#[builder(default)]
batch_size: Option<usize>,
}

impl Qdrant {
pub fn from_client(client: QdrantClient, collection_name: impl Into<String>) -> Self {
Qdrant {
client,
collection_name: collection_name.into(),
vector_size: 1536,
batch_size: None,
}
}

pub fn with_batch_size(mut self, batch_size: usize) -> Self {
self.batch_size = Some(batch_size);
self
pub fn builder() -> QdrantBuilder {
QdrantBuilder::default()
}

/// The size (dimensions) of the embedding vectors being stored
///
/// I.e. for small openai embeddings this is 1536
pub fn with_vector_size(mut self, vector_size: usize) -> Self {
self.vector_size = vector_size;
self
pub fn try_from_url(url: &str) -> Result<QdrantBuilder> {
Ok(QdrantBuilder::default().client(QdrantClient::from_url(url).build()?))
}

pub async fn create_index_if_not_exists(&self) -> Result<()> {
Expand Down
1 change: 1 addition & 0 deletions swiftide/src/integrations/qdrant/persist.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ impl Storage for Qdrant {

#[tracing::instrument(skip_all, err)]
async fn setup(&self) -> Result<()> {
tracing::debug!("Setting up Qdrant storage");
self.create_index_if_not_exists().await
}

Expand Down
20 changes: 16 additions & 4 deletions swiftide/src/integrations/redis/node_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,11 +124,23 @@ impl NodeCache for RedisNodeCache {
mod tests {
use super::*;
use std::collections::HashMap;
#[tokio::test]
use testcontainers::runners::AsyncRunner;

#[test_log::test(tokio::test)]
async fn test_redis_cache() {
let redis_url = std::env::var("REDIS_URL").expect("REDIS_URL not set");
let cache =
RedisNodeCache::try_from_url(&redis_url, "test").expect("Could not build redis client");
let redis = testcontainers::GenericImage::new("redis", "7.2.4")
.with_exposed_port(6379)
.with_wait_for(testcontainers::core::WaitFor::message_on_stdout(
"Ready to accept connections",
))
.start()
.await
.expect("Redis started");

let host = redis.get_host().await.unwrap();
let port = redis.get_host_port_ipv4(6379).await.unwrap();
let cache = RedisNodeCache::try_from_url(&format!("redis://{host}:{port}"), "test")
.expect("Could not build redis client");
cache.reset_cache().await;

let node = IngestionNode {
Expand Down
15 changes: 10 additions & 5 deletions swiftide/src/integrations/treesitter/splitter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ const DEFAULT_MAX_BYTES: usize = 1500;
/// Splits code files into meaningful chunks
///
/// Supports splitting code files into chunks based on a maximum size or a range of bytes.
#[builder(setter(into), build_fn(error = "anyhow::Error"))]
pub struct CodeSplitter {
/// Maximum size of a chunk in bytes or a range of bytes
#[builder(default, setter(into))]
Expand All @@ -22,7 +23,7 @@ pub struct CodeSplitter {
}

impl CodeSplitterBuilder {
pub fn language(mut self, language: impl TryInto<SupportedLanguages>) -> Result<Self> {
pub fn try_language(mut self, language: impl TryInto<SupportedLanguages>) -> Result<Self> {
self.language = Some(
// For some reason there's a trait conflict, wth
language
Expand Down Expand Up @@ -170,7 +171,8 @@ mod test {
#[test]
fn test_max_bytes_limit() {
let splitter = CodeSplitter::builder()
.language(SupportedLanguages::Rust)?
.try_language(SupportedLanguages::Rust)
.unwrap()
.chunk_size(50)
.build()
.unwrap();
Expand All @@ -197,7 +199,8 @@ mod test {
#[test]
fn test_empty_text() {
let splitter = CodeSplitter::builder()
.language(SupportedLanguages::Rust)?
.try_language(SupportedLanguages::Rust)
.unwrap()
.chunk_size(50)
.build()
.unwrap();
Expand All @@ -212,7 +215,8 @@ mod test {
#[test]
fn test_range_max() {
let splitter = CodeSplitter::builder()
.language(SupportedLanguages::Rust)?
.try_language(SupportedLanguages::Rust)
.unwrap()
.chunk_size(0..50)
.build()
.unwrap();
Expand All @@ -237,7 +241,8 @@ mod test {
#[test]
fn test_range_min_and_max() {
let splitter = CodeSplitter::builder()
.language(SupportedLanguages::Rust)?
.try_language(SupportedLanguages::Rust)
.unwrap()
.chunk_size(20..50)
.build()
.unwrap();
Expand Down
5 changes: 3 additions & 2 deletions swiftide/src/integrations/treesitter/supported_languages.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
// pub use std::str::FromStr as _;
#[allow(unused_imports)]
pub use std::str::FromStr as _;

#[derive(Debug, PartialEq, Clone, Copy, strum_macros::EnumString, strum_macros::Display)]
#[strum(ascii_case_insensitive)]
pub enum SupportedLanguages {
Rust,
Typescript,
Expand Down Expand Up @@ -42,7 +44,6 @@ impl From<SupportedLanguages> for tree_sitter::Language {
#[cfg(test)]
mod test {
use super::*;
use std::str::FromStr;

#[test]
fn test_supported_languages_from_str() {
Expand Down
10 changes: 10 additions & 0 deletions swiftide/src/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,18 @@ use crate::{ingestion::IngestionNode, ingestion::IngestionStream, Embeddings};
use anyhow::Result;
use async_trait::async_trait;

/// All traits are easily mockable under tests
#[cfg(test)]
use mockall::{automock, predicate::*};

#[cfg_attr(test, automock)]
#[async_trait]
/// Transforms single nodes into single nodes
///
/// In test builds, `automock` generates a mock implementation of this trait.
pub trait Transformer: Send + Sync + Debug {
/// Takes ownership of `node` and returns the transformed node, or an error.
async fn transform_node(&self, node: IngestionNode) -> Result<IngestionNode>;
}

#[cfg_attr(test, automock)]
#[async_trait]
/// Transforms batched single nodes into streams of nodes
pub trait BatchableTransformer: Send + Sync + Debug {
Expand All @@ -20,16 +26,19 @@ pub trait BatchableTransformer: Send + Sync + Debug {
}

/// Starting point of a stream
///
/// In test builds, `automock` generates a mock implementation of this trait.
#[cfg_attr(test, automock)]
pub trait Loader {
/// Consumes the loader and returns the stream of ingestion nodes it produces.
fn into_stream(self) -> IngestionStream;
}

#[cfg_attr(test, automock)]
#[async_trait]
/// Turns one node into many nodes
///
/// In test builds, `automock` generates a mock implementation of this trait.
pub trait ChunkerTransformer: Send + Sync + Debug {
/// Takes ownership of `node` and returns a stream of the nodes derived from it.
async fn transform_node(&self, node: IngestionNode) -> IngestionStream;
}

#[cfg_attr(test, automock)]
#[async_trait]
/// Caches nodes, typically by their path and hash
/// Recommended to namespace on the storage
Expand All @@ -51,6 +60,7 @@ pub trait SimplePrompt: Debug + Send + Sync {
async fn prompt(&self, prompt: &str) -> Result<String>;
}

#[cfg_attr(test, automock)]
#[async_trait]
/// Persists nodes
pub trait Storage: Send + Sync {
Expand Down
Loading

0 comments on commit f19ec5f

Please sign in to comment.