From 3863577ae31084cc7afbd72d846d16a076b88832 Mon Sep 17 00:00:00 2001 From: Jeffrey <22608443+Jefffrey@users.noreply.github.com> Date: Fri, 29 Dec 2023 22:02:58 +1100 Subject: [PATCH] Introduce ProjectionMask (#51) --- src/arrow_reader.rs | 9 ++- src/arrow_reader/column.rs | 6 +- src/async_arrow_reader.rs | 2 +- src/lib.rs | 1 + src/projection.rs | 61 +++++++++++++++ src/schema.rs | 147 ++++++++++++++++++++++++++----------- 6 files changed, 177 insertions(+), 49 deletions(-) create mode 100644 src/projection.rs diff --git a/src/arrow_reader.rs b/src/arrow_reader.rs index 94eb1b69..b488ed41 100644 --- a/src/arrow_reader.rs +++ b/src/arrow_reader.rs @@ -38,6 +38,7 @@ use crate::arrow_reader::column::timestamp::new_timestamp_iter; use crate::arrow_reader::column::NullableIterator; use crate::builder::BoxedArrayBuilder; use crate::error::{self, InvalidColumnSnafu, IoSnafu, Result}; +use crate::projection::ProjectionMask; use crate::proto::stream::Kind; use crate::proto::StripeFooter; use crate::reader::decompress::{Compression, Decompressor}; @@ -816,7 +817,8 @@ pub struct Cursor { impl Cursor { pub fn new>(mut reader: R, fields: &[T]) -> Result { let file_metadata = Arc::new(read_metadata(&mut reader)?); - let projected_data_type = file_metadata.root_data_type().project(fields); + let mask = ProjectionMask::named_roots(file_metadata.root_data_type(), fields); + let projected_data_type = file_metadata.root_data_type().project(&mask); Ok(Self { reader, file_metadata, @@ -840,7 +842,8 @@ impl Cursor { impl Cursor { pub async fn new_async>(mut reader: R, fields: &[T]) -> Result { let file_metadata = Arc::new(read_metadata_async(&mut reader).await?); - let projected_data_type = file_metadata.root_data_type().project(fields); + let mask = ProjectionMask::named_roots(file_metadata.root_data_type(), fields); + let projected_data_type = file_metadata.root_data_type().project(&mask); Ok(Self { reader, file_metadata, @@ -914,7 +917,7 @@ impl Stripe { let columns = 
projected_data_type .children() .iter() - .map(|(name, data_type)| Column::new(name, data_type, &footer, info.number_of_rows())) + .map(|col| Column::new(col.name(), col.data_type(), &footer, info.number_of_rows())) .collect(); let mut stream_map = HashMap::new(); diff --git a/src/arrow_reader/column.rs b/src/arrow_reader/column.rs index f08e4525..d8b54ef5 100644 --- a/src/arrow_reader/column.rs +++ b/src/arrow_reader/column.rs @@ -105,11 +105,11 @@ impl Column { | DataType::Date { .. } => vec![], DataType::Struct { children, .. } => children .iter() - .map(|(name, data_type)| Column { + .map(|col| Column { number_of_rows: self.number_of_rows, footer: self.footer.clone(), - name: name.clone(), - data_type: data_type.clone(), + name: col.name().to_string(), + data_type: col.data_type().clone(), }) .collect(), DataType::List { child, .. } => { diff --git a/src/async_arrow_reader.rs b/src/async_arrow_reader.rs index 6d254d63..db120055 100644 --- a/src/async_arrow_reader.rs +++ b/src/async_arrow_reader.rs @@ -214,7 +214,7 @@ impl Stripe { let columns = projected_data_type .children() .iter() - .map(|(name, data_type)| Column::new(name, data_type, &footer, info.number_of_rows())) + .map(|col| Column::new(col.name(), col.data_type(), &footer, info.number_of_rows())) .collect(); let mut stream_map = HashMap::new(); diff --git a/src/lib.rs b/src/lib.rs index 07f2134f..c2f6c1b7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,6 +2,7 @@ pub mod arrow_reader; pub mod async_arrow_reader; pub(crate) mod builder; pub mod error; +pub mod projection; pub mod proto; pub mod reader; pub mod schema; diff --git a/src/projection.rs b/src/projection.rs new file mode 100644 index 00000000..c1ed0112 --- /dev/null +++ b/src/projection.rs @@ -0,0 +1,61 @@ +use crate::schema::RootDataType; + +// TODO: be able to nest project (project columns within struct type) + +/// Specifies which column indices to project from an ORC type. 
+#[derive(Debug, Clone)] +pub struct ProjectionMask { + /// Indices of column in ORC type, can refer to nested types + /// (not only root level columns) + indices: Option<Vec<usize>>, +} + +impl ProjectionMask { + /// Project all columns. + pub fn all() -> Self { + Self { indices: None } + } + + /// Project only specific columns from the root type by column index. + pub fn roots(root_data_type: &RootDataType, indices: impl IntoIterator<Item = usize>) -> Self { + // TODO: return error if column index not found? + let input_indices = indices.into_iter().collect::<Vec<_>>(); + // By default always project root + let mut indices = vec![0]; + root_data_type + .children() + .iter() + .filter(|col| input_indices.contains(&col.data_type().column_index())) + .for_each(|col| indices.extend(col.data_type().all_indices())); + Self { + indices: Some(indices), + } + } + + /// Project only specific columns from the root type by column name. + pub fn named_roots<T>(root_data_type: &RootDataType, names: &[T]) -> Self + where + T: AsRef<str>, + { + // TODO: return error if column name not found? + // By default always project root + let mut indices = vec![0]; + let names = names.iter().map(AsRef::as_ref).collect::<Vec<_>>(); + root_data_type + .children() + .iter() + .filter(|col| names.contains(&col.name())) + .for_each(|col| indices.extend(col.data_type().all_indices())); + Self { + indices: Some(indices), + } + } + + /// Check whether an ORC column is projected, by index.
+ pub fn is_index_projected(&self, index: usize) -> bool { + match &self.indices { + Some(indices) => indices.contains(&index), + None => true, + } + } +} diff --git a/src/schema.rs b/src/schema.rs index 8d6c96c8..10dc358b 100644 --- a/src/schema.rs +++ b/src/schema.rs @@ -5,6 +5,7 @@ use std::sync::Arc; use snafu::{ensure, OptionExt}; use crate::error::{NoTypesSnafu, Result, UnexpectedSnafu}; +use crate::projection::ProjectionMask; use crate::proto; use arrow::datatypes::{DataType as ArrowDataType, Field, Schema, TimeUnit, UnionMode}; @@ -22,7 +23,7 @@ use arrow::datatypes::{DataType as ArrowDataType, Field, Schema, TimeUnit, Union /// See: #[derive(Debug, Clone)] pub struct RootDataType { - children: Vec<(String, DataType)>, + children: Vec, } impl RootDataType { @@ -32,7 +33,7 @@ impl RootDataType { } /// Base columns of the file. - pub fn children(&self) -> &[(String, DataType)] { + pub fn children(&self) -> &[NamedColumn] { &self.children } @@ -41,24 +42,22 @@ impl RootDataType { let fields = self .children .iter() - .map(|(name, dt)| { - let dt = dt.to_arrow_data_type(); - Field::new(name, dt, true) + .map(|col| { + let dt = col.data_type().to_arrow_data_type(); + Field::new(col.name(), dt, true) }) .collect::>(); Schema::new_with_metadata(fields, user_metadata.clone()) } - /// Project only specific columns from the root type by column name. - pub fn project>(&self, fields: &[T]) -> Self { - // TODO: change project to accept project mask (vec of bools) instead of relying on col names? - // TODO: be able to nest project? (i.e. project child struct data type) unsure if actually desirable - let fields = fields.iter().map(AsRef::as_ref).collect::>(); + /// Create new root data type based on mask of columns to project. 
+ pub fn project(&self, mask: &ProjectionMask) -> Self { + // TODO: fix logic here to account for nested projection let children = self .children .iter() - .filter(|c| fields.contains(&c.0.as_str())) - .map(|c| c.to_owned()) + .filter(|col| mask.is_index_projected(col.data_type().column_index())) + .map(|col| col.to_owned()) .collect::>(); Self { children } } @@ -75,18 +74,40 @@ impl Display for RootDataType { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "ROOT")?; for child in &self.children { - write!(f, "\n {} {}", child.0, child.1)?; + write!(f, "\n {child}")?; } Ok(()) } } +#[derive(Debug, Clone)] +pub struct NamedColumn { + name: String, + data_type: DataType, +} + +impl NamedColumn { + pub fn name(&self) -> &str { + self.name.as_str() + } + + pub fn data_type(&self) -> &DataType { + &self.data_type + } +} + +impl Display for NamedColumn { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{} {}", self.name(), self.data_type()) + } +} + /// Helper function since this is duplicated for [`RootDataType`] and [`DataType::Struct`] /// parsing from proto. fn parse_struct_children_from_proto( types: &[proto::Type], column_index: usize, -) -> Result> { +) -> Result> { // These pre-conditions should always be upheld, especially as this is a private function assert!(column_index < types.len()); let ty = &types[column_index]; @@ -107,8 +128,8 @@ fn parse_struct_children_from_proto( .map(|(&index, name)| { let index = index as usize; let name = name.to_owned(); - let dt = DataType::from_proto(types, index)?; - Ok((name, dt)) + let data_type = DataType::from_proto(types, index)?; + Ok(NamedColumn { name, data_type }) }) .collect::>>()?; Ok(children) @@ -174,7 +195,7 @@ pub enum DataType { /// collection of children types. 
Struct { column_index: usize, - children: Vec<(String, DataType)>, + children: Vec<NamedColumn>, }, /// Compound type where each value in the column is a list of values /// of another type, specified by the child type. @@ -227,22 +248,66 @@ impl DataType { } } + /// Column indices of every nested descendant type (children, grandchildren, ...), excluding this type's own index; empty for primitives. + pub fn children_indices(&self) -> Vec<usize> { + match self { + DataType::Boolean { .. } + | DataType::Byte { .. } + | DataType::Short { .. } + | DataType::Int { .. } + | DataType::Long { .. } + | DataType::Float { .. } + | DataType::Double { .. } + | DataType::String { .. } + | DataType::Varchar { .. } + | DataType::Char { .. } + | DataType::Binary { .. } + | DataType::Decimal { .. } + | DataType::Timestamp { .. } + | DataType::TimestampWithLocalTimezone { .. } + | DataType::Date { .. } => vec![], + DataType::Struct { children, .. } => children + .iter() + .flat_map(|col| col.data_type().all_indices()) + .collect(), + DataType::List { child, .. } => child.all_indices(), + DataType::Map { key, value, .. } => { + let mut indices = key.all_indices(); + indices.extend(value.all_indices()); + indices + } + DataType::Union { variants, .. } => variants + .iter() + .flat_map(|dt| dt.all_indices()) + .collect(), + } + } + + /// Includes self index and all children column indices.
+ pub fn all_indices(&self) -> Vec { + let mut indices = vec![self.column_index()]; + indices.extend(self.children_indices()); + indices + } + fn from_proto(types: &[proto::Type], column_index: usize) -> Result { + use proto::r#type::Kind; + let ty = types.get(column_index).context(UnexpectedSnafu { msg: format!("Column index out of bounds: {column_index}"), })?; let dt = match ty.kind() { - proto::r#type::Kind::Boolean => Self::Boolean { column_index }, - proto::r#type::Kind::Byte => Self::Byte { column_index }, - proto::r#type::Kind::Short => Self::Short { column_index }, - proto::r#type::Kind::Int => Self::Int { column_index }, - proto::r#type::Kind::Long => Self::Long { column_index }, - proto::r#type::Kind::Float => Self::Float { column_index }, - proto::r#type::Kind::Double => Self::Double { column_index }, - proto::r#type::Kind::String => Self::String { column_index }, - proto::r#type::Kind::Binary => Self::Binary { column_index }, - proto::r#type::Kind::Timestamp => Self::Timestamp { column_index }, - proto::r#type::Kind::List => { + Kind::Boolean => Self::Boolean { column_index }, + Kind::Byte => Self::Byte { column_index }, + Kind::Short => Self::Short { column_index }, + Kind::Int => Self::Int { column_index }, + Kind::Long => Self::Long { column_index }, + Kind::Float => Self::Float { column_index }, + Kind::Double => Self::Double { column_index }, + Kind::String => Self::String { column_index }, + Kind::Binary => Self::Binary { column_index }, + Kind::Timestamp => Self::Timestamp { column_index }, + Kind::List => { ensure!( ty.subtypes.len() == 1, UnexpectedSnafu { @@ -260,7 +325,7 @@ impl DataType { child, } } - proto::r#type::Kind::Map => { + Kind::Map => { ensure!( ty.subtypes.len() == 2, UnexpectedSnafu { @@ -281,14 +346,14 @@ impl DataType { value, } } - proto::r#type::Kind::Struct => { + Kind::Struct => { let children = parse_struct_children_from_proto(types, column_index)?; Self::Struct { column_index, children, } } - proto::r#type::Kind::Union 
=> { + Kind::Union => { ensure!( ty.subtypes.len() <= 256, UnexpectedSnafu { @@ -312,23 +377,21 @@ impl DataType { variants, } } - proto::r#type::Kind::Decimal => Self::Decimal { + Kind::Decimal => Self::Decimal { column_index, precision: ty.precision(), scale: ty.scale(), }, - proto::r#type::Kind::Date => Self::Date { column_index }, - proto::r#type::Kind::Varchar => Self::Varchar { + Kind::Date => Self::Date { column_index }, + Kind::Varchar => Self::Varchar { column_index, max_length: ty.maximum_length(), }, - proto::r#type::Kind::Char => Self::Char { + Kind::Char => Self::Char { column_index, max_length: ty.maximum_length(), }, - proto::r#type::Kind::TimestampInstant => { - Self::TimestampWithLocalTimezone { column_index } - } + Kind::TimestampInstant => Self::TimestampWithLocalTimezone { column_index }, }; Ok(dt) } @@ -358,9 +421,9 @@ impl DataType { DataType::Struct { children, .. } => { let children = children .iter() - .map(|(name, dt)| { - let dt = dt.to_arrow_data_type(); - Field::new(name, dt, true) + .map(|col| { + let dt = col.data_type().to_arrow_data_type(); + Field::new(col.name(), dt, true) }) .collect(); ArrowDataType::Struct(children) @@ -434,7 +497,7 @@ impl Display for DataType { } => { write!(f, "STRUCT")?; for child in children { - write!(f, "\n {} {}", child.0, child.1)?; + write!(f, "\n {child}")?; } Ok(()) }