Skip to content

Commit

Permalink
refactor(7181): use hashmap entry for mutable ref
Browse files (browse the repository at this point in the history)
  • Loading branch information
wiedld committed Oct 29, 2023
1 parent e74ba5d commit b123748
Showing 1 changed file with 10 additions and 11 deletions.
21 changes: 10 additions & 11 deletions datafusion/physical-plan/src/sorts/cascade.rs
Original file line number | Diff line number | Diff line change
Expand Up @@ -24,7 +24,7 @@ use arrow::record_batch::RecordBatch;
use datafusion_common::Result;
use datafusion_execution::memory_pool::MemoryReservation;
use futures::Stream;
use std::collections::HashMap;
use std::collections::{HashMap, hash_map::Entry};
use std::marker::Send;
use std::task::{ready, Context, Poll};
use std::{pin::Pin, sync::Arc};
Expand Down Expand Up @@ -260,7 +260,7 @@ impl<C: CursorValues + Send + Unpin + 'static> SortPreservingCascadeStream<C> {
) -> Result<RecordBatch> {
let (batch_rowsets, sort_order) = yielded_sort_order;

let mut batches_needed = Vec::with_capacity(sort_order.len());
let mut batches_to_interleave = Vec::with_capacity(sort_order.len());
let mut batches_seen: HashMap<BatchId, (usize, usize)> =
HashMap::with_capacity(sort_order.len()); // (batch_idx, max_row_idx)

Expand All @@ -274,27 +274,26 @@ impl<C: CursorValues + Send + Unpin + 'static> SortPreservingCascadeStream<C> {
let mut adjusted_sort_order = Vec::with_capacity(sort_order.len());

for (batch_id, row_idx) in sort_order.iter() {
let batch_idx = match batches_seen.get(batch_id) {
Some((batch_idx, _)) => *batch_idx,
None => {
let batch_idx = batches_seen.len();
batches_needed.push(*batch_id);
let batch_idx = match batches_seen.entry(*batch_id) {
Entry::Occupied(entry) => entry.get().0,
Entry::Vacant(entry) => {
let batch_idx = batches_to_interleave.len();
batches_to_interleave.push(*batch_id);
entry.insert((batch_idx, *row_idx));
batch_idx
}
};
adjusted_sort_order.push((batch_idx, row_idx_offsets[batch_id] + *row_idx));
batches_seen
.insert(*batch_id, (batch_idx, row_idx_offsets[batch_id] + *row_idx));
}

let batches = self
.record_batch_collector
.get_batches(batches_needed.as_slice());
.get_batches(batches_to_interleave.as_slice());

// remove record_batches (from the batch tracker) that are fully yielded
let batches_to_remove = batches
.iter()
.zip(batches_needed)
.zip(batches_to_interleave)
.filter_map(|(batch, batch_id)| {
let max_row_idx = batches_seen[&batch_id].1;
if batch.num_rows() == max_row_idx + 1 {
Expand Down

0 comments on commit b123748

Please sign in to comment.