Skip to content

Commit

Permalink
feat: nulls first kernels (#3789)
Browse files Browse the repository at this point in the history
  • Loading branch information
universalmind303 authored Feb 12, 2025
1 parent e5ff39f commit ca36593
Show file tree
Hide file tree
Showing 9 changed files with 225 additions and 89 deletions.
6 changes: 5 additions & 1 deletion daft/dataframe/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1709,6 +1709,7 @@ def sort(
self,
by: Union[ColumnInputType, List[ColumnInputType]],
desc: Union[bool, List[bool]] = False,
nulls_first: Optional[Union[bool, List[bool]]] = None,
) -> "DataFrame":
"""Sorts DataFrame globally.
Expand Down Expand Up @@ -1768,9 +1769,12 @@ def sort(
by,
]

if nulls_first is None:
nulls_first = desc

sort_by = self.__column_input_to_expression(by)

builder = self._builder.sort(sort_by=sort_by, descending=desc, nulls_first=desc)
builder = self._builder.sort(sort_by=sort_by, descending=desc, nulls_first=nulls_first)
return DataFrame(builder)

@DataframePublicAPI
Expand Down
6 changes: 4 additions & 2 deletions daft/execution/execution_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -935,12 +935,15 @@ class ReduceMergeAndSort(ReduceInstruction):
sort_by: ExpressionsProjection
descending: list[bool]
bounds: MicroPartition
nulls_first: list[bool] | None = None

def run(self, inputs: list[MicroPartition]) -> list[MicroPartition]:
return self._reduce_merge_and_sort(inputs)

def _reduce_merge_and_sort(self, inputs: list[MicroPartition]) -> list[MicroPartition]:
    """Concatenate all input partitions and globally sort the result.

    Sorts by ``self.sort_by`` with per-key ``self.descending`` flags and,
    when provided, per-key ``self.nulls_first`` placement (``None`` lets the
    MicroPartition sort fall back to its default null ordering).

    Returns a single-element list containing the merged, sorted partition.
    """
    # Merge every incoming shard into one partition, then sort once so the
    # output is totally ordered within this reduce task.
    partition = MicroPartition.concat(inputs).sort(
        self.sort_by, descending=self.descending, nulls_first=self.nulls_first
    )
    return [partition]

def run_partial_metadata(self, input_metadatas: list[PartialPartitionMetadata]) -> list[PartialPartitionMetadata]:
Expand Down Expand Up @@ -969,7 +972,6 @@ def _reduce_to_quantiles(self, inputs: list[MicroPartition]) -> list[MicroPartit
merged = MicroPartition.concat(inputs)

nulls_first = self.nulls_first if self.nulls_first is not None else self.descending

# Skip evaluation of expressions by converting to Column Expression, since evaluation was done in Sample
merged_sorted = merged.sort(
self.sort_by.to_column_expressions(), descending=self.descending, nulls_first=nulls_first
Expand Down
1 change: 1 addition & 0 deletions daft/execution/physical_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -1646,6 +1646,7 @@ def sort(
execution_step.ReduceMergeAndSort(
sort_by=sort_by,
descending=descending,
nulls_first=nulls_first,
bounds=per_part_boundaries,
)
for per_part_boundaries in per_partition_bounds
Expand Down
87 changes: 48 additions & 39 deletions src/daft-core/src/array/ops/arrow2/sort/primitive/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,14 @@ pub fn idx_sort<I, F>(
cmp: F,
length: usize,
descending: bool,
nulls_first: bool,
) -> PrimitiveArray<I>
where
I: Index,
F: Fn(&I, &I) -> std::cmp::Ordering,
{
let (mut indices, start_idx, end_idx) =
generate_initial_indices::<I>(validity, length, descending);
generate_initial_indices::<I>(validity, length, descending, nulls_first);
let indices_slice = &mut indices.as_mut_slice()[start_idx..end_idx];

if !descending {
Expand All @@ -33,13 +34,18 @@ pub fn multi_column_idx_sort<I, F>(
others_cmp: &DynComparator,
length: usize,
first_col_desc: bool,
first_col_nulls_first: bool,
) -> PrimitiveArray<I>
where
I: Index,
F: Fn(&I, &I) -> std::cmp::Ordering,
{
let (mut indices, start_idx, end_idx) =
generate_initial_indices::<I>(first_col_validity, length, first_col_desc);
let (mut indices, start_idx, end_idx) = generate_initial_indices::<I>(
first_col_validity,
length,
first_col_desc,
first_col_nulls_first,
);
let indices_slice = &mut indices.as_mut_slice()[start_idx..end_idx];

indices_slice.sort_unstable_by(|a, b| overall_cmp(a, b));
Expand All @@ -60,55 +66,58 @@ fn generate_initial_indices<I>(
validity: Option<&Bitmap>,
length: usize,
descending: bool,
nulls_first: bool,
) -> (Vec<I>, usize, usize)
where
I: Index,
{
let mut start_idx: usize = 0;
let mut end_idx: usize = length;

if let Some(validity) = validity {
// number of null values
let n_nulls = validity.unset_bits();
// number of non null values
let n_valid = length.saturating_sub(n_nulls);
let mut indices = vec![I::default(); length];
if descending {
let mut nulls = 0;
let mut valids = 0;
validity
.iter()
.zip(I::range(0, length).unwrap())
.for_each(|(is_valid, index)| {
if is_valid {
indices[validity.unset_bits() + valids] = index;
let mut nulls = 0;
let mut valids = 0;
validity
.iter()
.zip(I::range(0, length).unwrap())
.for_each(|(is_not_null, index)| {
match (is_not_null, nulls_first) {
// value && nulls first
(true, true) => {
indices[n_nulls + valids] = index;
valids += 1;
} else {
indices[nulls] = index;
nulls += 1;
}
});
start_idx = validity.unset_bits();
} else {
let last_valid_index = length.saturating_sub(validity.unset_bits());
let mut nulls = 0;
let mut valids = 0;
validity
.iter()
.zip(I::range(0, length).unwrap())
.for_each(|(x, index)| {
if x {
// value && nulls last
(true, false) => {
indices[valids] = index;
valids += 1;
} else {
indices[last_valid_index + nulls] = index;
}
// null && nulls first
(false, true) => {
indices[nulls] = index;
nulls += 1;
}
// null && nulls last
(false, false) => {
indices[n_valid + nulls] = index;
nulls += 1;
}
});
end_idx = last_valid_index;
}
}
});

// either `descending` or `nulls_first` means that nulls come first
let (start_idx, end_idx) = if descending || nulls_first {
// since nulls come first, our valid values start at the end of the nulls
(n_nulls, length)
} else {
// since nulls come last, our valid values start at the beginning of the array
(0, n_valid)
};

(indices, start_idx, end_idx)
} else {
(
I::range(0, length).unwrap().collect::<Vec<_>>(),
start_idx,
end_idx,
)
(I::range(0, length).unwrap().collect::<Vec<_>>(), 0, length)
}
}
2 changes: 2 additions & 0 deletions src/daft-core/src/array/ops/arrow2/sort/primitive/indices.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ pub fn indices_sorted_unstable_by<I, T, F>(
array: &PrimitiveArray<T>,
cmp: F,
descending: bool,
nulls_first: bool,
) -> PrimitiveArray<I>
where
I: Index,
Expand All @@ -29,6 +30,7 @@ where
},
array.len(),
descending,
nulls_first,
)
}
}
Loading

0 comments on commit ca36593

Please sign in to comment.