From 56e2dd1e2824024d6b144a09c4ea511ff4bd6f8a Mon Sep 17 00:00:00 2001 From: Matt Green Date: Thu, 8 Aug 2024 15:45:38 -0700 Subject: [PATCH] add dataframe passthrough methods to datastream --- Cargo.lock | 12 +++++++++++ crates/core/Cargo.toml | 1 + crates/core/src/datastream.rs | 19 ++++++++++++++++- examples/examples/simple_aggregation.rs | 27 ++++++++++++++++--------- 4 files changed, 48 insertions(+), 11 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 7a1ce9e..6aac1f0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1064,6 +1064,17 @@ dependencies = [ "strum 0.26.3", ] +[[package]] +name = "delegate" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4e018fccbeeb50ff26562ece792ed06659b9c2dae79ece77c4456bb10d9bf79b" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.72", +] + [[package]] name = "df-streams-core" version = "0.1.0" @@ -1079,6 +1090,7 @@ dependencies = [ "bincode", "chrono", "datafusion", + "delegate", "futures", "half", "itertools 0.13.0", diff --git a/crates/core/Cargo.toml b/crates/core/Cargo.toml index f031124..44d23e9 100644 --- a/crates/core/Cargo.toml +++ b/crates/core/Cargo.toml @@ -27,3 +27,4 @@ serde.workspace = true rocksdb = "0.22.0" bincode = "1.3.3" half = "2.4.1" +delegate = "0.12.0" diff --git a/crates/core/src/datastream.rs b/crates/core/src/datastream.rs index 067acdf..1ff117b 100644 --- a/crates/core/src/datastream.rs +++ b/crates/core/src/datastream.rs @@ -19,6 +19,20 @@ pub struct DataStream { } impl DataStream { + pub fn filter(&self, predicate: Expr) -> Result { + let (session_state, plan) = self.df.as_ref().clone().into_parts(); + + let plan = LogicalPlanBuilder::from(plan).filter(predicate)?.build()?; + + Ok(Self { + df: Arc::new(DataFrame::new(session_state, plan)), + context: self.context.clone(), + }) + } + + // drop_columns, sync, columns: &[&str] + // count + pub fn streaming_window( &self, group_expr: Vec, @@ -46,7 +60,10 @@ impl DataStream { if batch.num_rows() > 0 { println!( "{}", - arrow::util::pretty::pretty_format_batches(&[batch]).unwrap() + datafusion::common::arrow::util::pretty::pretty_format_batches(&[ + batch + ]) + .unwrap() ); } } diff --git a/examples/examples/simple_aggregation.rs b/examples/examples/simple_aggregation.rs index 257b9fb..7d8736a 100644 --- a/examples/examples/simple_aggregation.rs +++ b/examples/examples/simple_aggregation.rs @@ -2,6 +2,7 @@ use std::time::Duration; use datafusion::error::Result; use datafusion::functions_aggregate::average::avg; +use datafusion::logical_expr::lit; use datafusion::logical_expr::{col, max, min}; use df_streams_core::context::Context; @@ -29,16 +30,22 @@ async fn main() -> Result<()> { ])) .await?; - let ds = ctx.from_topic(source_topic).await?.streaming_window( - vec![], - vec![ - min(col("temperature")).alias("min"), - max(col("temperature")).alias("max"), - avg(col("temperature")).alias("average"), - ], - Duration::from_millis(1_000), // 5 second window - None, - )?; + let ds = ctx + .from_topic(source_topic) + .await? + .streaming_window( + vec![], + vec![ + min(col("temperature")).alias("min"), + max(col("temperature")).alias("max"), + avg(col("temperature")).alias("average"), + ], + Duration::from_millis(1_000), // 5 second window + None, + )? + .filter(col("max").gt(lit(114)))?; + + println!("{}", ds.df.logical_plan().display_indent()); ds.clone().print_stream().await?;