@@ -6,9 +6,10 @@ use std::{collections::BTreeMap, ops::Range, pin::Pin, sync::Arc};
6
6
use crate :: dataset:: fragment:: FragReadConfig ;
7
7
use crate :: dataset:: rowids:: get_row_id_index;
8
8
use crate :: { Error , Result } ;
9
- use arrow:: { array :: as_struct_array , compute:: concat_batches, datatypes:: UInt64Type } ;
9
+ use arrow:: { compute:: concat_batches, datatypes:: UInt64Type } ;
10
10
use arrow_array:: cast:: AsArray ;
11
- use arrow_array:: { RecordBatch , StructArray , UInt64Array } ;
11
+ use arrow_array:: { Array , RecordBatch , StructArray , UInt64Array } ;
12
+ use arrow_buffer:: { ArrowNativeType , BooleanBuffer , Buffer , NullBuffer } ;
12
13
use arrow_schema:: { Field as ArrowField , Schema as ArrowSchema } ;
13
14
use datafusion:: error:: DataFusionError ;
14
15
use datafusion:: physical_plan:: stream:: RecordBatchStreamAdapter ;
@@ -283,9 +284,13 @@ async fn do_take_rows(
283
284
// Remove the rowaddr column.
284
285
let keep_indices = ( 0 ..one_batch. num_columns ( ) - 1 ) . collect :: < Vec < _ > > ( ) ;
285
286
let one_batch = one_batch. project ( & keep_indices) ?;
287
+
288
+ // `arrow_select::take::take` has a bug: it doesn't handle empty structs correctly,
289
+ // so we need to handle it manually here.
290
+ // TODO: remove this once the bug is fixed.
286
291
let struct_arr: StructArray = one_batch. into ( ) ;
287
- let reordered = arrow_select :: take :: take ( & struct_arr, & remapping_index, None ) ?;
288
- Ok ( as_struct_array ( & reordered) . into ( ) )
292
+ let reordered = take_struct_array ( & struct_arr, & remapping_index) ?;
293
+ Ok ( reordered. into ( ) )
289
294
} ?;
290
295
291
296
let batch = projection. project_batch ( batch) . await ?;
@@ -553,6 +558,42 @@ impl TakeBuilder {
553
558
}
554
559
}
555
560
561
+ fn take_struct_array ( array : & StructArray , indices : & UInt64Array ) -> Result < StructArray > {
562
+ let nulls = array. nulls ( ) . map ( |nulls| {
563
+ let is_valid = indices. iter ( ) . map ( |index| {
564
+ if let Some ( index) = index {
565
+ nulls. is_valid ( index. to_usize ( ) . unwrap ( ) )
566
+ } else {
567
+ false
568
+ }
569
+ } ) ;
570
+ NullBuffer :: new ( BooleanBuffer :: new (
571
+ Buffer :: from_iter ( is_valid) ,
572
+ 0 ,
573
+ indices. len ( ) ,
574
+ ) )
575
+ } ) ;
576
+
577
+ if array. fields ( ) . is_empty ( ) {
578
+ return Ok ( StructArray :: new_empty_fields ( indices. len ( ) , nulls) ) ;
579
+ }
580
+
581
+ let arrays = array
582
+ . columns ( )
583
+ . iter ( )
584
+ . map ( |array| {
585
+ let array = match array. data_type ( ) {
586
+ arrow:: datatypes:: DataType :: Struct ( _) => {
587
+ Arc :: new ( take_struct_array ( array. as_struct ( ) , indices) ?)
588
+ }
589
+ _ => arrow_select:: take:: take ( array, indices, None ) ?,
590
+ } ;
591
+ Ok ( array)
592
+ } )
593
+ . collect :: < Result < Vec < _ > > > ( ) ?;
594
+ Ok ( StructArray :: new ( array. fields ( ) . clone ( ) , arrays, nulls) )
595
+ }
596
+
556
597
#[ cfg( test) ]
557
598
mod test {
558
599
use arrow_array:: { Int32Array , RecordBatchIterator , StringArray } ;
0 commit comments