workflow: Run each split query in parallel
tontinton committed Dec 6, 2024
1 parent 4e083f3 commit cbb4a11
Showing 1 changed file with 46 additions and 16 deletions.
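
Before this change, the Scan step awaited each split's query one after another; the diff below spawns a task per split so the queries run concurrently, then sums any count responses and bails when splits disagree (some returning a count, others streaming logs). The stand-alone sketch that follows illustrates the same fan-out-and-aggregate pattern under a few assumptions: tokio (with the macros feature), futures-util, and color_eyre as dependencies, and a hypothetical query_split function standing in for the real connector.query() call and stream_to_tx plumbing.

// Minimal sketch of the pattern this commit adopts: spawn one task per split,
// await them all with try_join_all, and aggregate any counts.
// `query_split` is a hypothetical stand-in for the real connector query.
use color_eyre::eyre::{bail, Result};
use futures_util::future::try_join_all;

// Hypothetical per-split query: Some(count) for a count response,
// None when the split streamed logs instead.
async fn query_split(i: usize) -> Result<Option<u64>> {
    Ok(Some(i as u64 * 10))
}

#[tokio::main]
async fn main() -> Result<()> {
    // Spawn each split query so they run concurrently instead of serially.
    let split_tasks: Vec<_> = (0..4).map(|i| tokio::spawn(query_split(i))).collect();

    // try_join_all returns early if any task panics or is cancelled;
    // per-split query errors surface from the inner Result below.
    let join_results = try_join_all(split_tasks).await?;

    // Sum the counts, refusing a mix of count and log responses.
    let mut count = None;
    for join_result in join_results {
        if let Some(split_count) = join_result? {
            if let Some(ref mut inner) = count {
                *inner += split_count;
            } else {
                count = Some(split_count);
            }
        } else if count.is_some() {
            bail!("some queries responded with count and some with logs");
        }
    }

    if let Some(total) = count {
        println!("{total}");
    }
    Ok(())
}

Note that with JoinHandles, try_join_all only short-circuits on a panicked or cancelled task; ordinary query errors surface afterwards, when each inner Result is checked via join_result?.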
62 changes: 46 additions & 16 deletions src/workflow/mod.rs
@@ -1,8 +1,8 @@
 use std::sync::Arc;
 
 use async_stream::stream;
-use color_eyre::eyre::{Context, Result};
-use futures_util::{stream::FuturesUnordered, StreamExt};
+use color_eyre::eyre::{bail, Context, Result};
+use futures_util::{future::try_join_all, stream::FuturesUnordered, StreamExt};
 use kinded::Kinded;
 use serde_json::to_string;
 use summarize::{summarize_stream, Summarize};
@@ -109,12 +109,12 @@ impl Workflow {
         let (mut tx, mut next_rx) = mpsc::channel(1);
         let mut rx: Option<mpsc::Receiver<Log>> = None;
 
-        let mut handles = FuturesUnordered::new();
+        let mut tasks = FuturesUnordered::new();
 
         for step in self.steps {
             debug!("Spawning step: {:?}", step);
 
-            let handle = spawn({
+            let task = spawn({
                 let tx = tx.clone();
                 let rx = rx.take();
 
@@ -126,18 +126,48 @@
                             splits,
                             handle,
                         }) => {
+                            let mut split_tasks = Vec::new();
+
                             for (i, split) in splits.into_iter().enumerate() {
-                                let response = connector
-                                    .query(&collection, split.as_ref(), handle.as_ref())
-                                    .await?;
-                                match response {
-                                    QueryResponse::Logs(stream) => {
-                                        stream_to_tx(stream, tx.clone(), &format!("scan({i})"))
-                                            .await?
+                                let collection = collection.clone();
+                                let connector = connector.clone();
+                                let handle = handle.clone();
+                                let tx = tx.clone();
+
+                                split_tasks.push(spawn(async move {
+                                    let response = connector
+                                        .query(&collection, split.as_ref(), handle.as_ref())
+                                        .await?;
+
+                                    match response {
+                                        QueryResponse::Logs(stream) => {
+                                            stream_to_tx(stream, tx, &format!("scan({i})")).await?
+                                        }
+                                        QueryResponse::Count(count) => return Ok(Some(count)),
+                                    }
+
+                                    Ok::<Option<u64>, color_eyre::eyre::Error>(None)
+                                }));
+                            }
+
+                            let join_results = try_join_all(split_tasks).await?;
+
+                            let mut count = None;
+                            for join_result in join_results {
+                                if let Some(split_count) = join_result? {
+                                    if let Some(ref mut inner) = count {
+                                        *inner += split_count;
+                                    } else {
+                                        count = Some(split_count);
                                     }
-                                    QueryResponse::Count(count) => println!("{}", count),
+                                } else if count.is_some() {
+                                    bail!("some queries responded with count and some with logs");
                                 }
                             }
+
+                            if let Some(inner) = count {
+                                println!("{}", inner);
+                            }
                         }
                         WorkflowStep::Filter(ast) => {
                             let stream = filter_stream(&ast, rx_stream(rx.unwrap()))?;
@@ -178,13 +208,13 @@ impl Workflow {
                 }
             });
 
-            handles.push(handle);
+            tasks.push(task);
 
             rx = Some(next_rx);
             (tx, next_rx) = mpsc::channel(1);
         }
 
-        handles.push(spawn(async move {
+        tasks.push(spawn(async move {
             let mut rx = rx.unwrap();
             while let Some(log) = rx.recv().await {
                 println!("{}", to_string(&log).context("log to string")?);
@@ -194,10 +224,10 @@
 
         debug!("Starting to print logs");
 
-        while let Some(join_result) = handles.next().await {
+        while let Some(join_result) = tasks.next().await {
             let result = join_result?;
             if let Err(e) = result {
-                for handle in &handles {
+                for handle in &tasks {
                     handle.abort();
                 }
                 return Err(e.wrap_err("failed one of the workflow steps"));

0 comments on commit cbb4a11
