Skip to content

Commit

Permalink
remove flag
Browse files Browse the repository at this point in the history
Signed-off-by: jayzhan211 <jayzhan211@gmail.com>
  • Loading branch information
jayzhan211 committed Oct 28, 2024
1 parent 8adf14e commit 8d6c0a6
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 68 deletions.
13 changes: 3 additions & 10 deletions datafusion/physical-plan/benches/spm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ use criterion::{black_box, criterion_group, criterion_main, Criterion};

fn generate_spm_for_round_robin_tie_breaker(
has_same_value: bool,
enable_round_robin_repartition: bool,
batch_count: usize,
partition_count: usize,
) -> SortPreservingMergeExec {
Expand Down Expand Up @@ -83,13 +82,11 @@ fn generate_spm_for_round_robin_tie_breaker(

let exec = MemoryExec::try_new(&partitiones, schema, None).unwrap();
SortPreservingMergeExec::new(sort, Arc::new(exec))
.with_round_robin_repartition(enable_round_robin_repartition)
}

fn run_bench(
c: &mut Criterion,
has_same_value: bool,
enable_round_robin_repartition: bool,
batch_count: usize,
partition_count: usize,
description: &str,
Expand All @@ -99,7 +96,6 @@ fn run_bench(

let spm = Arc::new(generate_spm_for_round_robin_tie_breaker(
has_same_value,
enable_round_robin_repartition,
batch_count,
partition_count,
)) as Arc<dyn ExecutionPlan>;
Expand All @@ -112,16 +108,14 @@ fn run_bench(

fn criterion_benchmark(c: &mut Criterion) {
let params = [
(true, false, "low_card_without_tiebreaker"), // low cardinality, no tie breaker
(true, true, "low_card_with_tiebreaker"), // low cardinality, with tie breaker
(false, false, "high_card_without_tiebreaker"), // high cardinality, no tie breaker
(false, true, "high_card_with_tiebreaker"), // high cardinality, with tie breaker
(true, "low_card"), // low cardinality, with tie breaker
(false, "high_card"), // high cardinality, with tie breaker
];

let batch_counts = [1, 25, 625];
let partition_counts = [2, 8, 32];

for &(has_same_value, enable_round_robin_repartition, cardinality_label) in &params {
for &(has_same_value, cardinality_label) in &params {
for &batch_count in &batch_counts {
for &partition_count in &partition_counts {
let description = format!(
Expand All @@ -131,7 +125,6 @@ fn criterion_benchmark(c: &mut Criterion) {
run_bench(
c,
has_same_value,
enable_round_robin_repartition,
batch_count,
partition_count,
&description,
Expand Down
9 changes: 2 additions & 7 deletions datafusion/physical-plan/src/sorts/merge.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,6 @@ pub(crate) struct SortPreservingMergeStream<C: CursorValues> {
/// Cursors for each input partition. `None` means the input is exhausted
cursors: Vec<Option<Cursor<C>>>,

/// Configuration parameter to enable round-robin selection of tied winners of loser tree.
enable_round_robin_tie_breaker: bool,

/// Flag indicating whether we are in the mode of round-robin
/// tie breaker for the loser tree winners.
round_robin_tie_breaker_mode: bool,
Expand Down Expand Up @@ -138,7 +135,6 @@ impl<C: CursorValues> SortPreservingMergeStream<C> {
batch_size: usize,
fetch: Option<usize>,
reservation: MemoryReservation,
enable_round_robin_tie_breaker: bool,
) -> Self {
let stream_count = streams.partitions();

Expand All @@ -159,7 +155,6 @@ impl<C: CursorValues> SortPreservingMergeStream<C> {
fetch,
produced: 0,
uninitiated_partitions: (0..stream_count).collect(),
enable_round_robin_tie_breaker,
}
}

Expand Down Expand Up @@ -442,7 +437,7 @@ impl<C: CursorValues> SortPreservingMergeStream<C> {
/// This function adjusts the tree by comparing the current winner with challengers from
/// other partitions.
///
/// If `enable_round_robin_tie_breaker` is true and a tie occurs at the final level, the
/// If a tie occurs at the final level, the
/// tie-breaker logic will be applied to ensure fair selection among equal elements.
fn update_loser_tree(&mut self) {
// Start with the current winner
Expand All @@ -455,7 +450,7 @@ impl<C: CursorValues> SortPreservingMergeStream<C> {
while cmp_node != 0 {
let challenger = self.loser_tree[cmp_node];
// If round-robin tie-breaker is enabled and we're at the final comparison (cmp_node == 1)
if self.enable_round_robin_tie_breaker && cmp_node == 1 {
if cmp_node == 1 {
match (&self.cursors[winner], &self.cursors[challenger]) {
(Some(ac), Some(bc)) => {
let ord = ac.cmp(bc);
Expand Down
34 changes: 5 additions & 29 deletions datafusion/physical-plan/src/sorts/sort_preserving_merge.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,6 @@ pub struct SortPreservingMergeExec {
fetch: Option<usize>,
/// Cache holding plan properties like equivalences, output partitioning etc.
cache: PlanProperties,
/// Configuration parameter to enable round-robin selection of tied winners of loser tree.
enable_round_robin_repartition: bool,
}

impl SortPreservingMergeExec {
Expand All @@ -96,7 +94,6 @@ impl SortPreservingMergeExec {
metrics: ExecutionPlanMetricsSet::new(),
fetch: None,
cache,
enable_round_robin_repartition: true,
}
}

Expand All @@ -106,15 +103,6 @@ impl SortPreservingMergeExec {
self
}

/// Sets the selection strategy of tied winners of the loser tree algorithm
pub fn with_round_robin_repartition(
mut self,
enable_round_robin_repartition: bool,
) -> Self {
self.enable_round_robin_repartition = enable_round_robin_repartition;
self
}

/// Input schema
pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
&self.input
Expand Down Expand Up @@ -195,7 +183,6 @@ impl ExecutionPlan for SortPreservingMergeExec {
metrics: self.metrics.clone(),
fetch: limit,
cache: self.cache.clone(),
enable_round_robin_repartition: true,
}))
}

Expand Down Expand Up @@ -295,7 +282,6 @@ impl ExecutionPlan for SortPreservingMergeExec {
.with_batch_size(context.session_config().batch_size())
.with_fetch(self.fetch)
.with_reservation(reservation)
.with_round_robin_tie_breaker(self.enable_round_robin_repartition)
.build()?;

debug!("Got stream result from SortPreservingMergeStream::new_from_receivers");
Expand Down Expand Up @@ -374,9 +360,8 @@ mod tests {
.with_session_config(config);
Ok(Arc::new(task_ctx))
}
fn generate_spm_for_round_robin_tie_breaker(
enable_round_robin_repartition: bool,
) -> Result<Arc<SortPreservingMergeExec>> {
fn generate_spm_for_round_robin_tie_breaker() -> Result<Arc<SortPreservingMergeExec>>
{
let target_batch_size = 12500;
let row_size = 12500;
let a: ArrayRef = Arc::new(Int32Array::from(vec![1; row_size]));
Expand All @@ -403,27 +388,18 @@ mod tests {
RepartitionExec::try_new(Arc::new(exec), Partitioning::RoundRobinBatch(2))?;
let coalesce_batches_exec =
CoalesceBatchesExec::new(Arc::new(repartition_exec), target_batch_size);
let spm = SortPreservingMergeExec::new(sort, Arc::new(coalesce_batches_exec))
.with_round_robin_repartition(enable_round_robin_repartition);
let spm = SortPreservingMergeExec::new(sort, Arc::new(coalesce_batches_exec));
Ok(Arc::new(spm))
}

#[tokio::test(flavor = "multi_thread")]
async fn test_round_robin_tie_breaker_success() -> Result<()> {
async fn test_round_robin_tie_breaker() -> Result<()> {
let task_ctx = generate_task_ctx_for_round_robin_tie_breaker()?;
let spm = generate_spm_for_round_robin_tie_breaker(true)?;
let spm = generate_spm_for_round_robin_tie_breaker()?;
let _collected = collect(spm, task_ctx).await.unwrap();
Ok(())
}

#[tokio::test(flavor = "multi_thread")]
async fn test_round_robin_tie_breaker_fail() -> Result<()> {
let task_ctx = generate_task_ctx_for_round_robin_tie_breaker()?;
let spm = generate_spm_for_round_robin_tie_breaker(false)?;
let _err = collect(spm, task_ctx).await.unwrap_err();
Ok(())
}

#[tokio::test]
async fn test_merge_interleave() {
let task_ctx = Arc::new(TaskContext::default());
Expand Down
29 changes: 7 additions & 22 deletions datafusion/physical-plan/src/sorts/streaming_merge.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ macro_rules! primitive_merge_helper {
}

macro_rules! merge_helper {
($t:ty, $sort:ident, $streams:ident, $schema:ident, $tracking_metrics:ident, $batch_size:ident, $fetch:ident, $reservation:ident, $enable_round_robin_tie_breaker:ident) => {{
($t:ty, $sort:ident, $streams:ident, $schema:ident, $tracking_metrics:ident, $batch_size:ident, $fetch:ident, $reservation:ident) => {{
let streams = FieldCursorStream::<$t>::new($sort, $streams);
return Ok(Box::pin(SortPreservingMergeStream::new(
Box::new(streams),
Expand All @@ -45,7 +45,6 @@ macro_rules! merge_helper {
$batch_size,
$fetch,
$reservation,
$enable_round_robin_tie_breaker,
)));
}};
}
Expand All @@ -59,15 +58,11 @@ pub struct StreamingMergeBuilder<'a> {
batch_size: Option<usize>,
fetch: Option<usize>,
reservation: Option<MemoryReservation>,
enable_round_robin_tie_breaker: bool,
}

impl<'a> StreamingMergeBuilder<'a> {
pub fn new() -> Self {
Self {
enable_round_robin_tie_breaker: true,
..Default::default()
}
Self::default()
}

pub fn with_streams(mut self, streams: Vec<SendableRecordBatchStream>) -> Self {
Expand Down Expand Up @@ -105,14 +100,6 @@ impl<'a> StreamingMergeBuilder<'a> {
self
}

pub fn with_round_robin_tie_breaker(
mut self,
enable_round_robin_tie_breaker: bool,
) -> Self {
self.enable_round_robin_tie_breaker = enable_round_robin_tie_breaker;
self
}

pub fn build(self) -> Result<SendableRecordBatchStream> {
let Self {
streams,
Expand All @@ -122,7 +109,6 @@ impl<'a> StreamingMergeBuilder<'a> {
reservation,
fetch,
expressions,
enable_round_robin_tie_breaker,
} = self;

// Early return if streams or expressions are empty
Expand Down Expand Up @@ -155,11 +141,11 @@ impl<'a> StreamingMergeBuilder<'a> {
let sort = expressions[0].clone();
let data_type = sort.expr.data_type(schema.as_ref())?;
downcast_primitive! {
data_type => (primitive_merge_helper, sort, streams, schema, metrics, batch_size, fetch, reservation, enable_round_robin_tie_breaker),
DataType::Utf8 => merge_helper!(StringArray, sort, streams, schema, metrics, batch_size, fetch, reservation, enable_round_robin_tie_breaker)
DataType::LargeUtf8 => merge_helper!(LargeStringArray, sort, streams, schema, metrics, batch_size, fetch, reservation, enable_round_robin_tie_breaker)
DataType::Binary => merge_helper!(BinaryArray, sort, streams, schema, metrics, batch_size, fetch, reservation, enable_round_robin_tie_breaker)
DataType::LargeBinary => merge_helper!(LargeBinaryArray, sort, streams, schema, metrics, batch_size, fetch, reservation, enable_round_robin_tie_breaker)
data_type => (primitive_merge_helper, sort, streams, schema, metrics, batch_size, fetch, reservation),
DataType::Utf8 => merge_helper!(StringArray, sort, streams, schema, metrics, batch_size, fetch, reservation)
DataType::LargeUtf8 => merge_helper!(LargeStringArray, sort, streams, schema, metrics, batch_size, fetch, reservation)
DataType::Binary => merge_helper!(BinaryArray, sort, streams, schema, metrics, batch_size, fetch, reservation)
DataType::LargeBinary => merge_helper!(LargeBinaryArray, sort, streams, schema, metrics, batch_size, fetch, reservation)
_ => {}
}
}
Expand All @@ -177,7 +163,6 @@ impl<'a> StreamingMergeBuilder<'a> {
batch_size,
fetch,
reservation,
enable_round_robin_tie_breaker,
)))
}
}

0 comments on commit 8d6c0a6

Please sign in to comment.