diff --git a/Cargo.lock b/Cargo.lock
index 1cd0c77d..c6aa38c5 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2922,6 +2922,7 @@ dependencies = [
  "itertools 0.13.0",
  "mockall",
  "num_cpus",
+ "pin-project-lite",
  "qdrant-client",
  "redis",
  "serde",
diff --git a/swiftide/Cargo.toml b/swiftide/Cargo.toml
index 9813f8aa..b4f7cdfd 100644
--- a/swiftide/Cargo.toml
+++ b/swiftide/Cargo.toml
@@ -28,6 +28,7 @@ tracing = { version = "0.1.40", features = ["log"] }
 strum = "0.26.2"
 strum_macros = "0.26.4"
 num_cpus = "1.16.0"
+pin-project-lite = "0.2"
 
 # Integrations
 async-openai = { version = "0.23.2", optional = true }
diff --git a/swiftide/src/ingestion/ingestion_pipeline.rs b/swiftide/src/ingestion/ingestion_pipeline.rs
index de679fa2..8d08c66f 100644
--- a/swiftide/src/ingestion/ingestion_pipeline.rs
+++ b/swiftide/src/ingestion/ingestion_pipeline.rs
@@ -27,7 +27,7 @@ impl Default for IngestionPipeline {
     /// Creates a default `IngestionPipeline` with an empty stream, no storage, and a concurrency level equal to the number of CPUs.
     fn default() -> Self {
         Self {
-            stream: Box::pin(futures_util::stream::empty()),
+            stream: IngestionStream::empty(),
             storage: Default::default(),
             concurrency: num_cpus::get(),
         }
@@ -47,7 +47,7 @@ impl IngestionPipeline {
     pub fn from_loader(loader: impl Loader + 'static) -> Self {
         let stream = loader.into_stream();
         Self {
-            stream: stream.boxed(),
+            stream,
             ..Default::default()
         }
     }
@@ -95,7 +95,8 @@ impl IngestionPipeline {
                 }
                 .instrument(span)
             })
-            .boxed();
+            .boxed()
+            .into();
         self
     }
@@ -120,7 +121,8 @@ impl IngestionPipeline {
                 async move { transformer.transform_node(node).await }.instrument(span)
             })
             .try_buffer_unordered(concurrency)
-            .boxed();
+            .boxed()
+            .into();
         self
     }
@@ -154,7 +156,7 @@ impl IngestionPipeline {
             })
             .try_buffer_unordered(concurrency) // First get the streams from each future
            .try_flatten_unordered(concurrency) // Then flatten all the streams back into one
-            .boxed();
+            .boxed().into();
         self
     }
@@ -180,7 +182,8 @@ impl IngestionPipeline {
             })
             .try_buffer_unordered(concurrency)
             .try_flatten_unordered(concurrency)
-            .boxed();
+            .boxed()
+            .into();
         self
     }
@@ -210,7 +213,7 @@ impl IngestionPipeline {
                 })
                 .try_buffer_unordered(self.concurrency)
                 .try_flatten_unordered(self.concurrency)
-                .boxed();
+                .boxed().into();
         } else {
             self.stream = self
                 .stream
@@ -222,7 +225,8 @@ impl IngestionPipeline {
                     async move { storage.store(node).await }.instrument(span)
                 })
                 .try_buffer_unordered(self.concurrency)
-                .boxed();
+                .boxed()
+                .into();
         }
 
         self
@@ -232,7 +236,9 @@ impl IngestionPipeline {
     ///
     /// Useful for rate limiting the ingestion pipeline. Uses tokio_stream::StreamExt::throttle internally which has a granularity of 1ms.
     pub fn throttle(mut self, duration: impl Into<Duration>) -> Self {
-        self.stream = tokio_stream::StreamExt::throttle(self.stream, duration.into()).boxed();
+        self.stream = tokio_stream::StreamExt::throttle(self.stream, duration.into())
+            .boxed()
+            .into();
         self
     }
@@ -249,7 +255,8 @@ impl IngestionPipeline {
                 Err(_e) => None,
             }
         })
-        .boxed();
+        .boxed()
+        .into();
         self
     }
@@ -260,7 +267,8 @@ impl IngestionPipeline {
         self.stream = self
             .stream
             .inspect(|result| tracing::debug!("Processing result: {:?}", result))
-            .boxed();
+            .boxed()
+            .into();
         self
     }
@@ -271,7 +279,8 @@ impl IngestionPipeline {
         self.stream = self
             .stream
             .inspect_err(|e| tracing::error!("Error processing node: {:?}", e))
-            .boxed();
+            .boxed()
+            .into();
         self
     }
@@ -282,7 +291,8 @@ impl IngestionPipeline {
         self.stream = self
             .stream
             .inspect_ok(|node| tracing::debug!("Processed node: {:?}", node))
-            .boxed();
+            .boxed()
+            .into();
         self
     }
@@ -333,7 +343,6 @@ mod tests {
     use super::*;
     use crate::ingestion::IngestionNode;
     use crate::traits::*;
-    use futures_util::stream;
     use mockall::Sequence;
 
     /// Tests a simple run of the ingestion pipeline.
@@ -351,7 +360,7 @@
             .expect_into_stream()
             .times(1)
             .in_sequence(&mut seq)
-            .returning(|| Box::pin(stream::iter(vec![Ok(IngestionNode::default())])));
+            .returning(|| vec![Ok(IngestionNode::default())].into());
 
         transformer.expect_transform_node().returning(|mut node| {
             node.chunk = "transformed".to_string();
@@ -363,7 +372,7 @@
             .expect_batch_transform()
             .times(1)
             .in_sequence(&mut seq)
-            .returning(|nodes| Box::pin(stream::iter(nodes.into_iter().map(Ok))));
+            .returning(|nodes| IngestionStream::iter(nodes.into_iter().map(Ok)));
         batch_transformer.expect_concurrency().returning(|| None);
         chunker
@@ -377,7 +386,7 @@
                 node.chunk = format!("transformed_chunk_{}", i);
                 nodes.push(Ok(node));
             }
-            Box::pin(stream::iter(nodes))
+            nodes.into()
         });
         chunker.expect_concurrency().returning(|| None);
@@ -409,7 +418,7 @@
             .expect_into_stream()
             .times(1)
             .in_sequence(&mut seq)
-            .returning(|| Box::pin(stream::iter(vec![Ok(IngestionNode::default())])));
+            .returning(|| vec![Ok(IngestionNode::default())].into());
         transformer
             .expect_transform_node()
             .returning(|_node| Err(anyhow::anyhow!("Error transforming node")));
@@ -435,11 +444,12 @@
             .times(1)
             .in_sequence(&mut seq)
             .returning(|| {
-                Box::pin(stream::iter(vec![
+                vec![
                     Ok(IngestionNode::default()),
                     Ok(IngestionNode::default()),
                     Ok(IngestionNode::default()),
-                ]))
+                ]
+                .into()
             });
         transformer
             .expect_transform_node()
diff --git a/swiftide/src/ingestion/ingestion_stream.rs b/swiftide/src/ingestion/ingestion_stream.rs
index 4dcc3d35..1feefa15 100644
--- a/swiftide/src/ingestion/ingestion_stream.rs
+++ b/swiftide/src/ingestion/ingestion_stream.rs
@@ -1,44 +1,68 @@
+#![allow(clippy::from_over_into)]
+#![cfg(not(tarpaulin_include))]
 //! This module defines the `IngestionStream` type, which is used for handling asynchronous streams of `IngestionNode` items in the ingestion pipeline.
-//!
-//! The `IngestionStream` type is a pinned, boxed, dynamically-dispatched stream that yields `Result<IngestionNode>` items. This type is essential for managing
-//! and processing large volumes of data asynchronously, ensuring efficient and scalable ingestion workflows.
 use anyhow::Result;
-use futures_util::stream::Stream;
+use futures_util::stream::{self, Stream};
+use pin_project_lite::pin_project;
 use std::pin::Pin;
 
 use super::IngestionNode;
 
-/// A type alias for a pinned, boxed, dynamically-dispatched stream of `IngestionNode` items.
-///
-/// This type is used in the ingestion pipeline to handle asynchronous streams of data. Each item in the stream is a `Result<IngestionNode>`,
-/// allowing for error handling during the ingestion process. The `Send` trait is implemented to ensure that the stream can be safely sent
-/// across threads, enabling concurrent processing.
-///
-/// # Type Definition
-/// - `Pin<Box<dyn Stream<Item = Result<IngestionNode>> + Send>>`
-///
-/// # Components
-/// - `Pin`: Ensures that the memory location of the stream is fixed, which is necessary for certain asynchronous operations.
-/// - `Box<dyn Stream<Item = Result<IngestionNode>>>`: A heap-allocated, dynamically-dispatched stream that yields `Result<IngestionNode>` items.
-/// - `Send`: Ensures that the stream can be sent across thread boundaries, facilitating concurrent processing.
-///
-/// # Usage
-/// The `IngestionStream` type is typically used in the ingestion pipeline to process data asynchronously. It allows for efficient handling
-/// of large volumes of data by leveraging Rust's asynchronous capabilities.
-///
-/// # Error Handling
-/// Each item in the stream is a `Result<IngestionNode>`, which means that errors can be propagated and handled during the ingestion process.
-/// This design allows for robust error handling and ensures that the ingestion pipeline can gracefully handle failures.
-///
-/// # Performance Considerations
-/// The use of `Pin` and `Box` ensures that the stream's memory location is fixed and heap-allocated, respectively. This design choice is
-/// crucial for asynchronous operations that require stable memory addresses. Additionally, the `Send` trait enables concurrent processing,
-/// which can significantly improve performance in multi-threaded environments.
-///
-/// # Edge Cases
-/// - The stream may yield errors (`Err` variants) instead of valid `IngestionNode` items. These errors should be handled appropriately
-///   to ensure the robustness of the ingestion pipeline.
-/// - The stream must be pinned to ensure that its memory location remains fixed, which is necessary for certain asynchronous operations.
-
-pub type IngestionStream = Pin<Box<dyn Stream<Item = Result<IngestionNode>> + Send>>;
+pub use futures_util::{StreamExt, TryStreamExt};
+
+// We need to inform the compiler that `inner` is pinned as well
+pin_project! {
+    /// An asynchronous stream of `IngestionNode` items.
+    ///
+    /// Wraps an internal stream of `Result<IngestionNode>` items.
+    ///
+    /// Streams, iterators and vectors of `Result<IngestionNode>` can be converted into an `IngestionStream`.
+    pub struct IngestionStream {
+        #[pin]
+        inner: Pin<Box<dyn Stream<Item = Result<IngestionNode>> + Send>>,
+    }
+}
+
+impl Stream for IngestionStream {
+    type Item = Result<IngestionNode>;
+
+    fn poll_next(
+        self: Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+    ) -> std::task::Poll<Option<Self::Item>> {
+        let this = self.project();
+        this.inner.poll_next(cx)
+    }
+}
+
+impl Into<IngestionStream> for Vec<Result<IngestionNode>> {
+    fn into(self) -> IngestionStream {
+        IngestionStream::iter(self)
+    }
+}
+
+impl Into<IngestionStream> for Pin<Box<dyn Stream<Item = Result<IngestionNode>> + Send>> {
+    fn into(self) -> IngestionStream {
+        IngestionStream { inner: self }
+    }
+}
+
+impl IngestionStream {
+    pub fn empty() -> Self {
+        IngestionStream {
+            inner: stream::empty().boxed(),
+        }
+    }
+
+    // NOTE: Can we really guarantee that the iterator will outlive the stream?
+    pub fn iter<I>(iter: I) -> Self
+    where
+        I: IntoIterator<Item = Result<IngestionNode>> + Send + 'static,
+        <I as IntoIterator>::IntoIter: Send,
+    {
+        IngestionStream {
+            inner: stream::iter(iter).boxed(),
+        }
+    }
+}
diff --git a/swiftide/src/integrations/qdrant/persist.rs b/swiftide/src/integrations/qdrant/persist.rs
index cce6d861..db9b8e8d 100644
--- a/swiftide/src/integrations/qdrant/persist.rs
+++ b/swiftide/src/integrations/qdrant/persist.rs
@@ -4,7 +4,6 @@
 use anyhow::Result;
 use async_trait::async_trait;
-use futures_util::{stream, StreamExt};
 
 use crate::{
     ingestion::{IngestionNode, IngestionStream},
@@ -82,7 +81,7 @@ impl Persist for Qdrant {
             .collect::<Result<Vec<_>>>();
 
         if points.is_err() {
-            return stream::iter(vec![Err(points.unwrap_err())]).boxed();
+            return vec![Err(points.unwrap_err())].into();
         }
 
         let points = points.unwrap();
@@ -93,9 +92,9 @@ impl Persist for Qdrant {
             .await;
 
         if result.is_ok() {
-            stream::iter(nodes.into_iter().map(Ok)).boxed()
+            IngestionStream::iter(nodes.into_iter().map(Ok))
         } else {
-            stream::iter(vec![Err(result.unwrap_err())]).boxed()
+            vec![Err(result.unwrap_err())].into()
         }
     }
 }
diff --git a/swiftide/src/integrations/redis/persist.rs b/swiftide/src/integrations/redis/persist.rs
index ef2bf9de..86eb20d1 100644
--- a/swiftide/src/integrations/redis/persist.rs
+++ b/swiftide/src/integrations/redis/persist.rs
@@ -1,6 +1,5 @@
 use anyhow::{Context as _, Result};
 use async_trait::async_trait;
-use futures_util::{stream, StreamExt};
 
 use crate::{
     ingestion::{IngestionNode, IngestionStream},
@@ -58,7 +57,7 @@ impl Persist for Redis {
             .collect::<Result<Vec<_>>>();
 
         if args.is_err() {
-            return stream::iter(vec![Err(args.unwrap_err())]).boxed();
+            return vec![Err(args.unwrap_err())].into();
         }
 
         let args = args.unwrap();
@@ -70,12 +69,12 @@ impl Persist for Redis {
                 .context("Error persisting to redis");
 
             if result.is_ok() {
-                stream::iter(nodes.into_iter().map(Ok)).boxed()
+                IngestionStream::iter(nodes.into_iter().map(Ok))
             } else {
-                stream::iter(vec![Err(result.unwrap_err())]).boxed()
+                IngestionStream::iter([Err(result.unwrap_err())])
             }
         } else {
-            stream::iter(vec![Err(anyhow::anyhow!("Failed to connect to Redis"))]).boxed()
+            IngestionStream::iter([Err(anyhow::anyhow!("Failed to connect to Redis"))])
         }
     }
 }
diff --git a/swiftide/src/loaders/file_loader.rs b/swiftide/src/loaders/file_loader.rs
index e5386c37..7d355e84 100644
--- a/swiftide/src/loaders/file_loader.rs
+++ b/swiftide/src/loaders/file_loader.rs
@@ -1,5 +1,4 @@
 use crate::{ingestion::IngestionNode, ingestion::IngestionStream, Loader};
-use futures_util::{stream, StreamExt};
 use std::path::PathBuf;
 
 /// The `FileLoader` struct is responsible for loading files from a specified directory,
@@ -103,7 +102,7 @@ impl Loader for FileLoader {
             })
         });
 
-        stream::iter(file_paths).boxed()
+        IngestionStream::iter(file_paths)
     }
 }
diff --git a/swiftide/src/transformers/chunk_code.rs b/swiftide/src/transformers/chunk_code.rs
index a5a15f09..dc52344e 100644
--- a/swiftide/src/transformers/chunk_code.rs
+++ b/swiftide/src/transformers/chunk_code.rs
@@ -1,7 +1,6 @@
 use anyhow::Result;
 use async_trait::async_trait;
 use derive_builder::Builder;
-use futures_util::{stream, StreamExt};
 
 use crate::{
     ingestion::{IngestionNode, IngestionStream},
@@ -90,16 +89,15 @@ impl ChunkerTransformer for ChunkCode {
         let split_result = self.chunker.split(&node.chunk);
 
         if let Ok(split) = split_result {
-            return stream::iter(split.into_iter().map(move |chunk| {
+            IngestionStream::iter(split.into_iter().map(move |chunk| {
                 Ok(IngestionNode {
                     chunk,
                     ..node.clone()
                 })
             }))
-            .boxed();
         } else {
             // Send the error downstream
-            return stream::iter(vec![Err(split_result.unwrap_err())]).boxed();
+            IngestionStream::iter(vec![Err(split_result.unwrap_err())])
         }
     }
diff --git a/swiftide/src/transformers/chunk_markdown.rs b/swiftide/src/transformers/chunk_markdown.rs
index a2fa4777..94a12004 100644
--- a/swiftide/src/transformers/chunk_markdown.rs
+++ b/swiftide/src/transformers/chunk_markdown.rs
@@ -1,7 +1,6 @@
 use crate::{ingestion::IngestionNode, ingestion::IngestionStream, ChunkerTransformer};
 use async_trait::async_trait;
 use derive_builder::Builder;
-use futures_util::{stream, StreamExt};
 use text_splitter::{Characters, MarkdownSplitter};
 
 #[derive(Debug, Builder)]
@@ -42,13 +41,12 @@ impl ChunkerTransformer for ChunkMarkdown {
             .map(|chunk| chunk.to_string())
             .collect::<Vec<String>>();
 
-        stream::iter(chunks.into_iter().map(move |chunk| {
+        IngestionStream::iter(chunks.into_iter().map(move |chunk| {
             Ok(IngestionNode {
                 chunk,
                 ..node.clone()
             })
         }))
-        .boxed()
     }
 
     fn concurrency(&self) -> Option<usize> {
diff --git a/swiftide/src/transformers/embed.rs b/swiftide/src/transformers/embed.rs
index 4470ec2a..321d4c42 100644
--- a/swiftide/src/transformers/embed.rs
+++ b/swiftide/src/transformers/embed.rs
@@ -6,7 +6,6 @@ use crate::{
 };
 use anyhow::Result;
 use async_trait::async_trait;
-use futures_util::{stream, StreamExt};
 
 /// A transformer that can generate embeddings for an `IngestionNode`
 ///
@@ -67,23 +66,21 @@ impl BatchableTransformer for Embed {
         // TODO: We should drop chunks that go over the token limit of the EmbedModel
         let chunks_to_embed: Vec<String> = nodes.iter().map(|n| n.as_embeddable()).collect();
 
-        stream::iter(
-            self.embed_model
-                .embed(chunks_to_embed)
-                .await
-                .map(|embeddings| {
-                    nodes
-                        .into_iter()
-                        .zip(embeddings)
-                        .map(|(mut n, v)| {
-                            n.vector = Some(v);
-                            Ok(n)
-                        })
-                        .collect::<Vec<Result<IngestionNode>>>()
-                })
-                .unwrap_or_else(|e| vec![Err(e)]),
-        )
-        .boxed()
+        self.embed_model
+            .embed(chunks_to_embed)
+            .await
+            .map(|embeddings| {
+                nodes
+                    .into_iter()
+                    .zip(embeddings)
+                    .map(|(mut n, v)| {
+                        n.vector = Some(v);
+                        Ok(n)
+                    })
+                    .collect::<Vec<Result<IngestionNode>>>()
+            })
+            .unwrap_or_else(|e| vec![Err(e)])
+            .into()
     }
 
     fn concurrency(&self) -> Option<usize> {
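
Reviewer note: a minimal sketch of how the conversions introduced in this diff read at a call site. The `#[tokio::main]` harness and the assumption that `swiftide::ingestion` re-exports `StreamExt` (via the `pub use` added in `ingestion_stream.rs`) are illustrative only; the types and impls themselves come from the diff above.

    use anyhow::Result;
    use futures_util::stream;
    use swiftide::ingestion::{IngestionNode, IngestionStream, StreamExt};

    #[tokio::main]
    async fn main() -> Result<()> {
        // A Vec<Result<IngestionNode>> converts directly, replacing the old
        // `stream::iter(vec![...]).boxed()` dance at every call site.
        let mut from_vec: IngestionStream = vec![Ok(IngestionNode::default())].into();

        // An already-boxed stream converts through the second Into impl.
        let boxed = stream::empty::<Result<IngestionNode>>().boxed();
        let _from_stream: IngestionStream = boxed.into();

        // Plain iterators go through the explicit constructor.
        let _from_iter = IngestionStream::iter((0..3).map(|_| Ok(IngestionNode::default())));

        // IngestionStream implements Stream itself, so combinators keep working.
        while let Some(node) = from_vec.next().await {
            println!("{:?}", node?);
        }
        Ok(())
    }

On the `// NOTE` in `ingestion_stream.rs`: `stream::iter` takes the iterator by value, so the stream owns it; the `'static` bound on `I` (plus `Send` on its `IntoIter`) is what allows the owned iterator to be boxed into a `Send + 'static` stream, so no separate lifetime guarantee is needed.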
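On the `#![allow(clippy::from_over_into)]` at the top of the new module: the lint prefers `From` impls because they provide the matching `Into` for free via the blanket `impl<T, U: From<T>> Into<U> for T`. Since `IngestionStream` is a local type, a `From`-based equivalent would also compile; a hypothetical sketch, assuming nothing else forces the `Into` form:

    // Hypothetical alternative to the Vec Into impl in this diff: From on the
    // local type satisfies clippy::from_over_into and still lets call sites
    // write `vec.into()`.
    impl From<Vec<Result<IngestionNode>>> for IngestionStream {
        fn from(nodes: Vec<Result<IngestionNode>>) -> Self {
            IngestionStream::iter(nodes)
        }
    }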