Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: Make list.rs non generic & simplify the code #1118

Closed
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 29 additions & 59 deletions native/spark-expr/src/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,21 @@
// specific language governing permissions and limitations
// under the License.

use arrow::array::as_list_array;
use arrow::{
array::{as_primitive_array, Capacities, MutableArrayData},
buffer::{NullBuffer, OffsetBuffer},
datatypes::ArrowNativeType,
record_batch::RecordBatch,
};
use arrow_array::{
make_array, Array, ArrayRef, GenericListArray, Int32Array, OffsetSizeTrait, StructArray,
make_array, Array, ArrayRef, GenericListArray, Int32Array, ListArray, StructArray,
};
use arrow_schema::{DataType, Field, FieldRef, Schema};
use datafusion::logical_expr::ColumnarValue;
use datafusion::physical_expr_common::physical_expr::down_cast_any_ref;
use datafusion_common::{
cast::{as_int32_array, as_large_list_array, as_list_array},
internal_err, DataFusionError, Result as DataFusionResult, ScalarValue,
cast::as_int32_array, internal_err, DataFusionError, Result as DataFusionResult, ScalarValue,
};
use datafusion_physical_expr::PhysicalExpr;
use std::{
Expand Down Expand Up @@ -72,7 +72,7 @@ impl ListExtract {

fn child_field(&self, input_schema: &Schema) -> DataFusionResult<FieldRef> {
match self.child.data_type(input_schema)? {
DataType::List(field) | DataType::LargeList(field) => Ok(field),
DataType::List(field) => Ok(field),
data_type => Err(DataFusionError::Internal(format!(
"Unexpected data type in ListExtract: {:?}",
data_type
Expand Down Expand Up @@ -127,19 +127,7 @@ impl PhysicalExpr for ListExtract {

match child_value.data_type() {
DataType::List(_) => {
let list_array = as_list_array(&child_value)?;
let index_array = as_int32_array(&ordinal_value)?;

list_extract(
list_array,
index_array,
&default_value,
self.fail_on_error,
adjust_index,
)
}
DataType::LargeList(_) => {
let list_array = as_large_list_array(&child_value)?;
let list_array = as_list_array(&child_value);
let index_array = as_int32_array(&ordinal_value)?;

list_extract(
Expand Down Expand Up @@ -220,8 +208,8 @@ fn zero_based_index(index: i32, len: usize) -> DataFusionResult<Option<usize>> {
}
}

fn list_extract<O: OffsetSizeTrait>(
list_array: &GenericListArray<O>,
fn list_extract(
list_array: &ListArray,
index_array: &Int32Array,
default_value: &ScalarValue,
fail_on_error: bool,
Expand Down Expand Up @@ -329,7 +317,6 @@ impl PhysicalExpr for GetArrayStructFields {
let struct_field = self.child_field(input_schema)?;
match self.child.data_type(input_schema)? {
DataType::List(_) => Ok(DataType::List(struct_field)),
DataType::LargeList(_) => Ok(DataType::LargeList(struct_field)),
data_type => Err(DataFusionError::Internal(format!(
"Unexpected data type in GetArrayStructFields: {:?}",
data_type
Expand All @@ -347,12 +334,7 @@ impl PhysicalExpr for GetArrayStructFields {

match child_value.data_type() {
DataType::List(_) => {
let list_array = as_list_array(&child_value)?;

get_array_struct_fields(list_array, self.ordinal)
}
DataType::LargeList(_) => {
let list_array = as_large_list_array(&child_value)?;
let list_array = as_list_array(&child_value);

get_array_struct_fields(list_array, self.ordinal)
}
Expand Down Expand Up @@ -388,8 +370,8 @@ impl PhysicalExpr for GetArrayStructFields {
}
}

fn get_array_struct_fields<O: OffsetSizeTrait>(
list_array: &GenericListArray<O>,
fn get_array_struct_fields(
list_array: &ListArray,
ordinal: usize,
) -> DataFusionResult<ColumnarValue> {
let values = list_array
Expand Down Expand Up @@ -452,7 +434,6 @@ impl ArrayInsert {
pub fn array_type(&self, data_type: &DataType) -> DataFusionResult<DataType> {
match data_type {
DataType::List(field) => Ok(DataType::List(Arc::clone(field))),
DataType::LargeList(field) => Ok(DataType::LargeList(Arc::clone(field))),
data_type => Err(DataFusionError::Internal(format!(
"Unexpected src array type in ArrayInsert: {:?}",
data_type
Expand Down Expand Up @@ -497,7 +478,6 @@ impl PhysicalExpr for ArrayInsert {

let src_element_type = match self.array_type(src_value.data_type())? {
DataType::List(field) => &field.data_type().clone(),
DataType::LargeList(field) => &field.data_type().clone(),
_ => unreachable!(),
};

Expand All @@ -514,27 +494,13 @@ impl PhysicalExpr for ArrayInsert {
)));
}

match src_value.data_type() {
DataType::List(_) => {
let list_array = as_list_array(&src_value)?;
array_insert(
list_array,
&item_value,
&pos_value,
self.legacy_negative_index,
)
}
DataType::LargeList(_) => {
let list_array = as_large_list_array(&src_value)?;
array_insert(
list_array,
&item_value,
&pos_value,
self.legacy_negative_index,
)
}
_ => unreachable!(), // This case is checked already
}
let list_array = as_list_array(&src_value);
array_insert(
list_array,
&item_value,
&pos_value,
self.legacy_negative_index,
)
}

fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
Expand Down Expand Up @@ -566,8 +532,8 @@ impl PhysicalExpr for ArrayInsert {
}
}

fn array_insert<O: OffsetSizeTrait>(
list_array: &GenericListArray<O>,
fn array_insert(
list_array: &ListArray,
items_array: &ArrayRef,
pos_array: &ArrayRef,
legacy_mode: bool,
Expand All @@ -587,7 +553,7 @@ fn array_insert<O: OffsetSizeTrait>(
let mut mutable_values =
MutableArrayData::with_capacities(vec![&values_data, &item_data], true, new_capacity);

let mut new_offsets = vec![O::usize_as(0)];
let mut new_offsets: Vec<i32> = vec![0];
let mut new_nulls = Vec::<bool>::with_capacity(list_array.len());

let pos_data: &Int32Array = as_primitive_array(&pos_array); // Spark supports only i32 for positions
Expand All @@ -601,7 +567,7 @@ fn array_insert<O: OffsetSizeTrait>(
if list_array.is_null(row_index) {
// In Spark if value of the array is NULL than nothing happens
mutable_values.extend_nulls(1);
new_offsets.push(new_offsets[row_index] + O::one());
new_offsets.push(new_offsets[row_index] + 1);
new_nulls.push(false);
continue;
}
Expand Down Expand Up @@ -630,14 +596,17 @@ fn array_insert<O: OffsetSizeTrait>(
mutable_values.extend(0, start, start + corrected_pos);
mutable_values.extend(1, row_index, row_index + 1);
mutable_values.extend(0, start + corrected_pos, end);
new_offsets.push(new_offsets[row_index] + O::usize_as(new_array_len));
// new_array_len is less than MAX_ROUNDED_ARRAY_LENGTH that is less than i32 max value
new_offsets.push(new_offsets[row_index] + i32::from_usize(new_array_len).unwrap());
} else {
mutable_values.extend(0, start, end);
mutable_values.extend_nulls(new_array_len - (end - start));
mutable_values.extend(1, row_index, row_index + 1);
// In that case spark actualy makes array longer than expected;
// For example, if pos is equal to 5, len is eq to 3, than resulted len will be 5
new_offsets.push(new_offsets[row_index] + O::usize_as(new_array_len) + O::one());
new_offsets
.push(new_offsets[row_index] + i32::from_usize(new_array_len).unwrap() + 1);
// new_array_len is less than MAX_ROUNDED_ARRAY_LENGTH that is less than i32 max value
}
} else {
// This comment is takes from the Apache Spark source code as is:
Expand All @@ -655,7 +624,8 @@ fn array_insert<O: OffsetSizeTrait>(
mutable_values.extend(1, row_index, row_index + 1);
mutable_values.extend_nulls(new_array_len - (end - start + 1));
mutable_values.extend(0, start, end);
new_offsets.push(new_offsets[row_index] + O::usize_as(new_array_len));
// new_array_len is less than MAX_ROUNDED_ARRAY_LENGTH that is less than i32 max value
new_offsets.push(new_offsets[row_index] + i32::from_usize(new_array_len).unwrap());
}
if is_item_null {
if (start == end) || (values.is_null(row_index)) {
Expand All @@ -674,7 +644,7 @@ fn array_insert<O: OffsetSizeTrait>(
DataType::LargeList(field) => field.data_type(),
_ => unreachable!(),
};
let new_array = GenericListArray::<O>::try_new(
let new_array = ListArray::try_new(
Arc::new(Field::new("item", data_type.clone(), true)),
OffsetBuffer::new(new_offsets.into()),
data,
Expand Down
Loading