Skip to content

Commit

Permalink
Adding config option for checkpointing (#50)
Browse files Browse the repository at this point in the history
* Adding config option for checkpointing

* Add maturin build step for ci (#52)

* fix: correct python module name

* Fixing the streaming join example (#54)

* Fixing the streaming join example

* format

* add drop_columns

* update python internal package name

---------

Co-authored-by: Matt Green <emgeee@users.noreply.github.com>

* merge with main

* Adding config option for checkpointing

* merge with main

* Cargo fmt

---------

Co-authored-by: Matt Green <emgeee@users.noreply.github.com>
  • Loading branch information
ameyc and emgeee authored Nov 7, 2024
1 parent 134344c commit ebcde97
Show file tree
Hide file tree
Showing 9 changed files with 108 additions and 60 deletions.
7 changes: 4 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -96,11 +96,12 @@ Details about developing the python bindings can be found in [py-denormalized/RE

### Checkpointing

We use SlateDB for state backend. Initialize your Job Context to a path to local directory -
We use SlateDB for state backend. Initialize your Job Context with a custom config and a path for SlateDB backend to store state -

```
let ctx = Context::new()?
.with_slatedb_backend(String::from("/tmp/checkpoints/simple-agg-checkpoint-1"))
let config = Context::default_config().set_bool("denormalized_config.checkpoint", true);
let ctx = Context::with_config(config)?
.with_slatedb_backend(String::from("/tmp/checkpoints/simple-agg/job1"))
.await;
```

Expand Down
24 changes: 17 additions & 7 deletions crates/core/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use datafusion::execution::{
session_state::SessionStateBuilder,
};

use crate::config_extensions::denormalized_config::DenormalizedConfig;
use crate::datasource::kafka::TopicReader;
use crate::datastream::DataStream;
use crate::physical_optimizer::EnsureHashPartititionOnGroupByForStreamingAggregates;
Expand All @@ -17,12 +18,13 @@ use denormalized_common::error::{DenormalizedError, Result};

#[derive(Clone)]
pub struct Context {
pub session_conext: Arc<SessionContext>,
pub session_context: Arc<SessionContext>,
}

impl Context {
pub fn new() -> Result<Self, DenormalizedError> {
let config = SessionConfig::new()
pub fn default_config() -> SessionConfig {
let ext_config = DenormalizedConfig::default();
let mut config = SessionConfig::new()
.set(
"datafusion.execution.batch_size",
&datafusion::common::ScalarValue::UInt64(Some(32)),
Expand All @@ -34,8 +36,16 @@ impl Context {
&datafusion::common::ScalarValue::Boolean(Some(false)),
);

let runtime = Arc::new(RuntimeEnv::default());
let _ = config.options_mut().extensions.insert(ext_config);
config
}

pub fn new() -> Result<Self, DenormalizedError> {
Context::with_config(Context::default_config())
}

pub fn with_config(config: SessionConfig) -> Result<Self, DenormalizedError> {
let runtime = Arc::new(RuntimeEnv::default());
let state = SessionStateBuilder::new()
.with_default_features()
.with_config(config)
Expand All @@ -48,15 +58,15 @@ impl Context {
.build();

Ok(Self {
session_conext: Arc::new(SessionContext::new_with_state(state)),
session_context: Arc::new(SessionContext::new_with_state(state)),
})
}

pub async fn from_topic(&self, topic: TopicReader) -> Result<DataStream, DenormalizedError> {
let topic_name = topic.0.topic.clone();
self.register_table(topic_name.clone(), Arc::new(topic))
.await?;
let df = self.session_conext.table(topic_name.as_str()).await?;
let df = self.session_context.table(topic_name.as_str()).await?;
let ds = DataStream::new(Arc::new(df), Arc::new(self.clone()));
Ok(ds)
}
Expand All @@ -66,7 +76,7 @@ impl Context {
name: String,
table: Arc<impl TableProvider + 'static>,
) -> Result<(), DenormalizedError> {
self.session_conext
self.session_context
.register_table(name.as_str(), table.clone())?;

Ok(())
Expand Down
27 changes: 16 additions & 11 deletions crates/core/src/datasource/kafka/kafka_stream_read.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use arrow_array::{Array, ArrayRef, PrimitiveArray, RecordBatch, StringArray, Str
use arrow_schema::{DataType, Field, SchemaRef, TimeUnit};
use crossbeam::channel;
use denormalized_orchestrator::channel_manager::{create_channel, get_sender, take_receiver};
use denormalized_orchestrator::orchestrator::{self, OrchestrationMessage};
use denormalized_orchestrator::orchestrator::OrchestrationMessage;
use futures::executor::block_on;
use log::{debug, error};
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -83,13 +83,13 @@ impl PartitionStream for KafkaStreamRead {
}

fn execute(&self, ctx: Arc<TaskContext>) -> SendableRecordBatchStream {
let _config_options = ctx
let config_options = ctx
.session_config()
.options()
.extensions
.get::<DenormalizedConfig>();

let mut should_checkpoint = false; //config_options.map_or(false, |c| c.checkpoint);
let should_checkpoint = config_options.map_or(false, |c| c.checkpoint);

let node_id = self.exec_node_id.unwrap();
let partition_tag = self
Expand All @@ -101,13 +101,16 @@ impl PartitionStream for KafkaStreamRead {

let channel_tag = format!("{}_{}", node_id, partition_tag);
let mut serialized_state: Option<Vec<u8>> = None;
let state_backend = get_global_slatedb().unwrap();
let mut state_backend = None;

let mut starting_offsets: HashMap<i32, i64> = HashMap::new();
if orchestrator::SHOULD_CHECKPOINT {

if should_checkpoint {
create_channel(channel_tag.as_str(), 10);
let backend = get_global_slatedb().unwrap();
debug!("checking for last checkpointed offsets");
serialized_state = block_on(state_backend.clone().get(channel_tag.as_bytes().to_vec()));
serialized_state = block_on(backend.get(channel_tag.as_bytes().to_vec()));
state_backend = Some(backend);
}

if let Some(serialized_state) = serialized_state {
Expand Down Expand Up @@ -151,25 +154,26 @@ impl PartitionStream for KafkaStreamRead {
builder.spawn(async move {
let mut epoch = 0;
let mut receiver: Option<channel::Receiver<OrchestrationMessage>> = None;
if orchestrator::SHOULD_CHECKPOINT {
if should_checkpoint {
let orchestrator_sender = get_sender("orchestrator");
let msg: OrchestrationMessage =
OrchestrationMessage::RegisterStream(channel_tag.clone());
orchestrator_sender.as_ref().unwrap().send(msg).unwrap();
receiver = take_receiver(channel_tag.as_str());
}
let mut checkpoint_batch = false;

loop {
//let mut checkpoint_barrier: Option<String> = None;
let mut _checkpoint_barrier: Option<i64> = None;

if orchestrator::SHOULD_CHECKPOINT {
if should_checkpoint {
let r = receiver.as_ref().unwrap();
for message in r.try_iter() {
debug!("received checkpoint barrier for {:?}", message);
if let OrchestrationMessage::CheckpointBarrier(epoch_ts) = message {
epoch = epoch_ts;
should_checkpoint = true;
checkpoint_batch = true;
}
}
}
Expand Down Expand Up @@ -245,7 +249,7 @@ impl PartitionStream for KafkaStreamRead {
let tx_result = tx.send(Ok(timestamped_record_batch)).await;
match tx_result {
Ok(_) => {
if should_checkpoint {
if checkpoint_batch {
debug!("about to checkpoint offsets");
let off = BatchReadMetadata {
epoch,
Expand All @@ -255,9 +259,10 @@ impl PartitionStream for KafkaStreamRead {
};
state_backend
.as_ref()
.unwrap()
.put(channel_tag.as_bytes().to_vec(), off.to_bytes().unwrap());
debug!("checkpointed offsets {:?}", off);
should_checkpoint = false;
checkpoint_batch = false;
}
}
Err(err) => error!("result err {:?}. shutdown signal detected.", err),
Expand Down
19 changes: 13 additions & 6 deletions crates/core/src/datastream.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use datafusion::common::runtime::SpawnedTask;
use datafusion::logical_expr::LogicalPlan;
use datafusion::physical_plan::ExecutionPlanProperties;
use denormalized_orchestrator::orchestrator;
use futures::StreamExt;
use log::debug;
use log::info;
Expand All @@ -18,6 +17,7 @@ use datafusion::logical_expr::{
};
use datafusion::physical_plan::display::DisplayableExecutionPlan;

use crate::config_extensions::denormalized_config::DenormalizedConfig;
use crate::context::Context;
use crate::datasource::kafka::{ConnectionOpts, KafkaTopicBuilder};
use crate::logical_plan::StreamingLogicalPlanBuilder;
Expand Down Expand Up @@ -240,7 +240,12 @@ impl DataStream {

let mut maybe_orchestrator_handle = None;

if orchestrator::SHOULD_CHECKPOINT {
let config = self.context.session_context.copied_config();
let config_options = config.options().extensions.get::<DenormalizedConfig>();

let should_checkpoint = config_options.map_or(false, |c| c.checkpoint);

if should_checkpoint {
let mut orchestrator = Orchestrator::default();
let cloned_shutdown_rx = self.shutdown_rx.clone();
let orchestrator_handle =
Expand Down Expand Up @@ -286,10 +291,12 @@ impl DataStream {

log::info!("Stream processing stopped. Cleaning up...");

let state_backend = get_global_slatedb();
if let Ok(db) = state_backend {
log::info!("Closing the state backend (slatedb)...");
db.close().await.unwrap();
if should_checkpoint {
let state_backend = get_global_slatedb();
if let Ok(db) = state_backend {
log::info!("Closing the state backend (slatedb)...");
db.close().await.unwrap();
}
}

// Join the orchestrator handle if it exists, ensuring it is joined and awaited
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,14 @@ use datafusion::{
};

use denormalized_orchestrator::{
channel_manager::take_receiver,
orchestrator::{self, OrchestrationMessage},
channel_manager::take_receiver, orchestrator::OrchestrationMessage,
};
use futures::{executor::block_on, Stream, StreamExt};
use log::debug;
use serde::{Deserialize, Serialize};

use crate::{
config_extensions::denormalized_config::DenormalizedConfig,
physical_plan::utils::time::RecordBatchWatermark,
state_backend::slatedb::{get_global_slatedb, SlateDBWrapper},
utils::serialization::ArrayContainer,
Expand Down Expand Up @@ -73,11 +73,11 @@ pub struct GroupedWindowAggStream {
group_by: PhysicalGroupBy,
group_schema: Arc<Schema>,
context: Arc<TaskContext>,
epoch: i64,
checkpoint: bool,
partition: usize,
channel_tag: String,
receiver: Option<Receiver<OrchestrationMessage>>,
state_backend: Arc<SlateDBWrapper>,
state_backend: Option<Arc<SlateDBWrapper>>,
}

#[derive(Serialize, Deserialize)]
Expand Down Expand Up @@ -147,11 +147,23 @@ impl GroupedWindowAggStream {
.and_then(|tag| take_receiver(tag.as_str()));

let channel_tag: String = channel_tag.unwrap_or(String::from(""));
let state_backend = get_global_slatedb().unwrap();

let serialized_state = block_on(state_backend.get(channel_tag.as_bytes().to_vec()));
let config_options = context
.session_config()
.options()
.extensions
.get::<DenormalizedConfig>();

let checkpoint = config_options.map_or(false, |c| c.checkpoint);

let mut serialized_state: Option<Vec<u8>> = None;
let mut state_backend = None;
if checkpoint {
let backend = get_global_slatedb().unwrap();
serialized_state = block_on(backend.get(channel_tag.as_bytes().to_vec()));
state_backend = Some(backend);
}

//let window_frames: BTreeMap<SystemTime, GroupedAggWindowFrame> = BTreeMap::new();
let mut stream = Self {
schema: agg_schema,
input,
Expand All @@ -166,7 +178,7 @@ impl GroupedWindowAggStream {
group_by,
group_schema,
context,
epoch: 0,
checkpoint,
partition,
channel_tag,
receiver,
Expand Down Expand Up @@ -340,19 +352,19 @@ impl GroupedWindowAggStream {
return Poll::Pending;
}
};
self.epoch += 1;

if orchestrator::SHOULD_CHECKPOINT {
let mut checkpoint_batch = false;

if self.checkpoint {
let r = self.receiver.as_ref().unwrap();
let mut epoch: u128 = 0;
for message in r.try_iter() {
debug!("received checkpoint barrier for {:?}", message);
if let OrchestrationMessage::CheckpointBarrier(epoch_ts) = message {
epoch = epoch_ts;
if let OrchestrationMessage::CheckpointBarrier(_epoch_ts) = message {
checkpoint_batch = true;
}
}

if epoch != 0 {
if checkpoint_batch {
// Prepare data for checkpointing

// Clone or extract necessary data
Expand Down Expand Up @@ -400,7 +412,7 @@ impl GroupedWindowAggStream {
let key = self.channel_tag.as_bytes().to_vec();

// Clone or use `Arc` for `state_backend`
let state_backend = self.state_backend.clone();
let state_backend = self.state_backend.clone().unwrap();

state_backend.put(key, serialized_checkpoint);
}
Expand Down
25 changes: 18 additions & 7 deletions crates/core/src/physical_plan/continuous/streaming_window.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,16 +40,19 @@ use datafusion::{
};
use denormalized_orchestrator::{
channel_manager::{create_channel, get_sender},
orchestrator::{self, OrchestrationMessage},
orchestrator::OrchestrationMessage,
};
use futures::{Stream, StreamExt};
use tracing::debug;

use crate::physical_plan::{
continuous::grouped_window_agg_stream::GroupedWindowAggStream,
utils::{
accumulators::{create_accumulators, AccumulatorItem},
time::{system_time_from_epoch, RecordBatchWatermark},
use crate::{
config_extensions::denormalized_config::DenormalizedConfig,
physical_plan::{
continuous::grouped_window_agg_stream::GroupedWindowAggStream,
utils::{
accumulators::{create_accumulators, AccumulatorItem},
time::{system_time_from_epoch, RecordBatchWatermark},
},
},
};

Expand Down Expand Up @@ -427,7 +430,15 @@ impl ExecutionPlan for StreamingWindowExec {
.node_id()
.expect("expected node id to be set.");

let channel_tag = if orchestrator::SHOULD_CHECKPOINT {
let config_options = context
.session_config()
.options()
.extensions
.get::<DenormalizedConfig>();

let checkpoint = config_options.map_or(false, |c| c.checkpoint);

let channel_tag = if checkpoint {
let tag = format!("{}_{}", node_id, partition);
create_channel(tag.as_str(), 10);
let orchestrator_sender = get_sender("orchestrator");
Expand Down
2 changes: 0 additions & 2 deletions crates/orchestrator/src/orchestrator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@ pub struct Orchestrator {
senders: HashMap<String, channel::Sender<OrchestrationMessage>>,
}

pub const SHOULD_CHECKPOINT: bool = false; // THIS WILL BE MOVED INTO CONFIG

/**
* 1. Keep track of checkpoint per source.
* 2. Tell each downstream which checkpoints it needs to know.
Expand Down
Loading

0 comments on commit ebcde97

Please sign in to comment.