Skip to content

Commit

Permalink
feat!: project struct array fields in arrow conversion (#254)
Browse files Browse the repository at this point in the history
  • Loading branch information
mbrobbel authored Oct 1, 2024
1 parent 60f078c commit 22608e3
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 6 deletions.
7 changes: 7 additions & 0 deletions narrow-derive/src/struct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,7 @@ impl Struct<'_> {
let field_name = self.ident.to_string();
let tokens = if matches!(self.fields, Fields::Unit) {
quote!(impl #impl_generics #narrow::arrow::StructArrayTypeFields for #ident #ty_generics #where_clause {
const NAMES: &'static [&'static str] = &[#field_name];
fn fields() -> ::arrow_schema::Fields {
::arrow_schema::Fields::from([
::std::sync::Arc::new(::arrow_schema::Field::new(#field_name, ::arrow_schema::DataType::Null, true)),
Expand All @@ -281,6 +282,7 @@ impl Struct<'_> {
} else {
// Fields
let field_ident = self.field_idents().map(|ident| ident.to_string());
let field_name = field_ident.clone();
let field_ty = self.field_types();
let field_ty_drop = self.field_types_drop_option();
let fields = quote!(
Expand All @@ -290,6 +292,11 @@ impl Struct<'_> {
);
quote! {
impl #impl_generics #narrow::arrow::StructArrayTypeFields for #ident #ty_generics #where_clause {
const NAMES: &'static [&'static str] = &[
#(
#field_name,
)*
];
fn fields() -> ::arrow_schema::Fields {
::arrow_schema::Fields::from([
#fields
Expand Down
1 change: 1 addition & 0 deletions src/arrow/array/null.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ mod tests {
const INPUT: [(); 4] = [(), (), (), ()];

#[test]
#[cfg(feature = "derive")]
fn derive() {
#[derive(crate::ArrayType, Copy, Clone, Debug, Default)]
struct Unit;
Expand Down
81 changes: 75 additions & 6 deletions src/arrow/array/struct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ use crate::{

/// Arrow schema interop trait for the fields of a struct array type.
pub trait StructArrayTypeFields {
/// The names of the fields.
const NAMES: &'static [&'static str];

/// Returns the fields of this struct array.
fn fields() -> Fields;
}
Expand Down Expand Up @@ -102,26 +105,48 @@ where
impl<T: StructArrayType, Buffer: BufferType> From<arrow_array::StructArray>
for StructArray<T, false, Buffer>
where
<T as StructArrayType>::Array<Buffer>: From<Vec<Arc<dyn arrow_array::Array>>>,
<T as StructArrayType>::Array<Buffer>:
From<Vec<Arc<dyn arrow_array::Array>>> + StructArrayTypeFields,
{
fn from(value: arrow_array::StructArray) -> Self {
let (_fields, arrays, nulls_opt) = value.into_parts();
let (fields, arrays, nulls_opt) = value.into_parts();
// Project
let projected = <<T as StructArrayType>::Array<Buffer> as StructArrayTypeFields>::NAMES
.iter()
.map(|field| {
fields
.find(field)
.unwrap_or_else(|| panic!("expected struct array with field: {field}"))
})
.map(|(idx, _)| Arc::clone(&arrays[idx]))
.collect::<Vec<_>>();
match nulls_opt {
Some(_) => panic!("expected array without a null buffer"),
None => StructArray(arrays.into()),
None => StructArray(projected.into()),
}
}
}

impl<T: StructArrayType, Buffer: BufferType> From<arrow_array::StructArray>
for StructArray<T, true, Buffer>
where
<T as StructArrayType>::Array<Buffer>: From<Vec<Arc<dyn arrow_array::Array>>> + Length,
<T as StructArrayType>::Array<Buffer>:
From<Vec<Arc<dyn arrow_array::Array>>> + Length + StructArrayTypeFields,
Bitmap<Buffer>: From<NullBuffer> + FromIterator<bool>,
{
fn from(value: arrow_array::StructArray) -> Self {
let (_fields, arrays, nulls_opt) = value.into_parts();
let data = arrays.into();
let (fields, arrays, nulls_opt) = value.into_parts();
// Project
let projected = <<T as StructArrayType>::Array<Buffer> as StructArrayTypeFields>::NAMES
.iter()
.map(|field| {
fields
.find(field)
.unwrap_or_else(|| panic!("expected struct array with field: {field}"))
})
.map(|(idx, _)| Arc::clone(&arrays[idx]))
.collect::<Vec<_>>();
let data = projected.into();
match nulls_opt {
Some(null_buffer) => StructArray(Nullable {
data,
Expand Down Expand Up @@ -264,6 +289,7 @@ mod tests {
type Array<Buffer: BufferType> = FooArray<Buffer>;
}
impl<Buffer: BufferType> StructArrayTypeFields for FooArray<Buffer> {
const NAMES: &'static [&'static str] = &["a"];
fn fields() -> Fields {
Fields::from(vec![Field::new("a", DataType::UInt32, false)])
}
Expand Down Expand Up @@ -437,4 +463,47 @@ mod tests {
)))
);
}

#[test]
#[should_panic(expected = "expected struct array with field: c")]
#[cfg(feature = "derive")]
fn projected() {
#[derive(narrow_derive::ArrayType)]
struct Foo {
a: u32,
b: bool,
c: u64,
}

#[derive(narrow_derive::ArrayType, Debug, PartialEq)]
struct Bar {
b: bool,
a: u32,
}

let foo_array = [
Foo {
a: 1,
b: false,
c: 2,
},
Foo {
a: 2,
b: true,
c: 3,
},
]
.into_iter()
.collect::<StructArray<Foo>>();

let arrow_array = arrow_array::StructArray::from(foo_array);
let bar_array = StructArray::<Bar>::from(arrow_array);
assert_eq!(
bar_array.clone().into_iter().collect::<Vec<_>>(),
[Bar { b: false, a: 1 }, Bar { b: true, a: 2 }]
);

let bar_arrow_array = arrow_array::StructArray::from(bar_array);
let _ = StructArray::<Foo>::from(bar_arrow_array);
}
}

0 comments on commit 22608e3

Please sign in to comment.