diff --git a/rust/benches/array_from_vec.rs b/rust/benches/array_from_vec.rs index 509628d172c66..669b88eaa40d9 100644 --- a/rust/benches/array_from_vec.rs +++ b/rust/benches/array_from_vec.rs @@ -35,7 +35,7 @@ fn array_from_vec(n: usize) { let arr_data = ArrayDataBuilder::new(DataType::Int32) .add_buffer(Buffer::from(v)) .build(); - criterion::black_box(PrimitiveArray::::from(arr_data)); + criterion::black_box(Int32Array::from(arr_data)); } fn criterion_benchmark(c: &mut Criterion) { diff --git a/rust/benches/builder.rs b/rust/benches/builder.rs index 5edc344d2ac9f..04f8a33b5bd55 100644 --- a/rust/benches/builder.rs +++ b/rust/benches/builder.rs @@ -35,9 +35,9 @@ fn bench_primitive(c: &mut Criterion) { "bench_primitive", Benchmark::new("bench_primitive", move |b| { b.iter(|| { - let mut builder = PrimitiveArrayBuilder::::new(64); + let mut builder = Int64Builder::new(64); for _ in 0..NUM_BATCHES { - black_box(builder.push_slice(&data[..])); + let _ = black_box(builder.push_slice(&data[..])); } black_box(builder.finish()); }) @@ -58,9 +58,9 @@ fn bench_bool(c: &mut Criterion) { "bench_bool", Benchmark::new("bench_bool", move |b| { b.iter(|| { - let mut builder = PrimitiveArrayBuilder::::new(64); + let mut builder = BooleanBuilder::new(64); for _ in 0..NUM_BATCHES { - black_box(builder.push_slice(&data[..])); + let _ = black_box(builder.push_slice(&data[..])); } black_box(builder.finish()); }) diff --git a/rust/examples/builders.rs b/rust/examples/builders.rs index d88370b8e7d85..5273558d966e0 100644 --- a/rust/examples/builders.rs +++ b/rust/examples/builders.rs @@ -18,7 +18,7 @@ ///! Many builders are available to easily create different types of arrow arrays extern crate arrow; -use arrow::builder::*; +use arrow::builder::{ArrayBuilder, Int32Builder}; fn main() { // Primitive Arrays @@ -27,7 +27,7 @@ fn main() { // i32, i64, f32, f64) // Create a new builder with a capacity of 100 - let mut primitive_array_builder = PrimitiveArrayBuilder::::new(100); + let mut primitive_array_builder = Int32Builder::new(100); // Push an individual primitive value primitive_array_builder.push(55).unwrap(); diff --git a/rust/examples/dynamic_types.rs b/rust/examples/dynamic_types.rs index 678564e3eccd7..8e6bb5d41c01b 100644 --- a/rust/examples/dynamic_types.rs +++ b/rust/examples/dynamic_types.rs @@ -40,7 +40,7 @@ fn main() { ]); // create some data - let id = PrimitiveArray::from(vec![1, 2, 3, 4, 5]); + let id = Int32Array::from(vec![1, 2, 3, 4, 5]); let nested = StructArray::from(vec![ ( @@ -49,11 +49,11 @@ fn main() { ), ( Field::new("b", DataType::Float64, false), - Arc::new(PrimitiveArray::from(vec![1.1, 2.2, 3.3, 4.4, 5.5])), + Arc::new(Float64Array::from(vec![1.1, 2.2, 3.3, 4.4, 5.5])), ), ( Field::new("c", DataType::Float64, false), - Arc::new(PrimitiveArray::from(vec![2.2, 3.3, 4.4, 5.5, 6.6])), + Arc::new(Float64Array::from(vec![2.2, 3.3, 4.4, 5.5, 6.6])), ), ]); @@ -75,12 +75,12 @@ fn process(batch: &RecordBatch) { let _nested_b = nested .column(1) .as_any() - .downcast_ref::>() + .downcast_ref::() .unwrap(); - let nested_c: &PrimitiveArray = nested + let nested_c: &Float64Array = nested .column(2) .as_any() - .downcast_ref::>() + .downcast_ref::() .unwrap(); let projected_schema = Schema::new(vec![ @@ -91,8 +91,8 @@ fn process(batch: &RecordBatch) { let _ = RecordBatch::new( Arc::new(projected_schema), vec![ - id.clone(), //NOTE: this is cloning the Arc not the array data - Arc::new(PrimitiveArray::::from(nested_c.data())), + id.clone(), // NOTE: this is cloning the Arc not the array data + Arc::new(Float64Array::from(nested_c.data())), ], ); } diff --git a/rust/examples/read_csv.rs b/rust/examples/read_csv.rs index a12cafb46c29a..df66a8112e5f2 100644 --- a/rust/examples/read_csv.rs +++ b/rust/examples/read_csv.rs @@ -17,7 +17,7 @@ extern crate arrow; -use arrow::array::{BinaryArray, PrimitiveArray}; +use arrow::array::{BinaryArray, Float64Array}; use arrow::csv; use arrow::datatypes::{DataType, Field, Schema}; use std::fs::File; @@ -49,12 +49,12 @@ fn main() { let lat = batch .column(1) .as_any() - .downcast_ref::>() + .downcast_ref::() .unwrap(); let lng = batch .column(2) .as_any() - .downcast_ref::>() + .downcast_ref::() .unwrap(); for i in 0..batch.num_rows() { diff --git a/rust/src/array.rs b/rust/src/array.rs index ab44dc07e6804..264aa50121f6c 100644 --- a/rust/src/array.rs +++ b/rust/src/array.rs @@ -22,9 +22,9 @@ use std::io::Write; use std::mem; use std::sync::Arc; -use array_data::*; -use buffer::*; -use builder::PrimitiveArrayBuilder; +use array_data::{ArrayData, ArrayDataRef}; +use buffer::{Buffer, MutableBuffer}; +use builder::*; use datatypes::*; use memory; use util::bit_util; @@ -80,17 +80,17 @@ fn make_array(data: ArrayDataRef) -> ArrayRef { // TODO: here data_type() needs to clone the type - maybe add a type tag enum to // avoid the cloning. match data.data_type().clone() { - DataType::Boolean => Arc::new(PrimitiveArray::::from(data)) as ArrayRef, - DataType::Int8 => Arc::new(PrimitiveArray::::from(data)) as ArrayRef, - DataType::Int16 => Arc::new(PrimitiveArray::::from(data)) as ArrayRef, - DataType::Int32 => Arc::new(PrimitiveArray::::from(data)) as ArrayRef, - DataType::Int64 => Arc::new(PrimitiveArray::::from(data)) as ArrayRef, - DataType::UInt8 => Arc::new(PrimitiveArray::::from(data)) as ArrayRef, - DataType::UInt16 => Arc::new(PrimitiveArray::::from(data)) as ArrayRef, - DataType::UInt32 => Arc::new(PrimitiveArray::::from(data)) as ArrayRef, - DataType::UInt64 => Arc::new(PrimitiveArray::::from(data)) as ArrayRef, - DataType::Float32 => Arc::new(PrimitiveArray::::from(data)) as ArrayRef, - DataType::Float64 => Arc::new(PrimitiveArray::::from(data)) as ArrayRef, + DataType::Boolean => Arc::new(BooleanArray::from(data)) as ArrayRef, + DataType::Int8 => Arc::new(Int8Array::from(data)) as ArrayRef, + DataType::Int16 => Arc::new(Int16Array::from(data)) as ArrayRef, + DataType::Int32 => Arc::new(Int32Array::from(data)) as ArrayRef, + DataType::Int64 => Arc::new(Int64Array::from(data)) as ArrayRef, + DataType::UInt8 => Arc::new(UInt8Array::from(data)) as ArrayRef, + DataType::UInt16 => Arc::new(UInt16Array::from(data)) as ArrayRef, + DataType::UInt32 => Arc::new(UInt32Array::from(data)) as ArrayRef, + DataType::UInt64 => Arc::new(UInt64Array::from(data)) as ArrayRef, + DataType::Float32 => Arc::new(Float32Array::from(data)) as ArrayRef, + DataType::Float64 => Arc::new(Float64Array::from(data)) as ArrayRef, DataType::Utf8 => Arc::new(BinaryArray::from(data)) as ArrayRef, DataType::List(_) => Arc::new(ListArray::from(data)) as ArrayRef, DataType::Struct(_) => Arc::new(StructArray::from(data)) as ArrayRef, @@ -126,188 +126,121 @@ pub struct PrimitiveArray { /// Also note that boolean arrays are bit-packed, so although the underlying pointer is of type /// bool it should be cast back to u8 before being used. /// i.e. `self.raw_values.get() as *const u8` - raw_values: RawPtrBox, + raw_values: RawPtrBox, } -/// Macro to define primitive arrays for different data types and native types. -/// Boolean arrays are bit-packed and so are not defined by this macro -macro_rules! def_primitive_array { - ($data_ty:path, $native_ty:ident) => { - impl PrimitiveArray<$native_ty> { - pub fn new(length: i64, values: Buffer, null_count: i64, offset: i64) -> Self { - let array_data = ArrayData::builder($data_ty) - .len(length) - .add_buffer(values) - .null_count(null_count) - .offset(offset) - .build(); - PrimitiveArray::from(array_data) - } - - /// Returns a `Buffer` holds all the values of this array. - /// - /// Note this doesn't take account into the offset of this array. - pub fn values(&self) -> Buffer { - self.data.buffers()[0].clone() - } +pub type BooleanArray = PrimitiveArray; +pub type Int8Array = PrimitiveArray; +pub type Int16Array = PrimitiveArray; +pub type Int32Array = PrimitiveArray; +pub type Int64Array = PrimitiveArray; +pub type UInt8Array = PrimitiveArray; +pub type UInt16Array = PrimitiveArray; +pub type UInt32Array = PrimitiveArray; +pub type UInt64Array = PrimitiveArray; +pub type Float32Array = PrimitiveArray; +pub type Float64Array = PrimitiveArray; - /// Returns the length of this array - pub fn len(&self) -> i64 { - self.data.len() - } - - /// Returns a raw pointer to the values of this array. - pub fn raw_values(&self) -> *const $native_ty { - unsafe { mem::transmute(self.raw_values.get().offset(self.data.offset() as isize)) } - } +impl Array for PrimitiveArray { + fn as_any(&self) -> &Any { + self + } - /// Returns the primitive value at index `i`. - /// - /// Note this doesn't do any bound checking, for performance reason. - pub fn value(&self, i: i64) -> $native_ty { - unsafe { *(self.raw_values().offset(i as isize)) } - } + fn data(&self) -> ArrayDataRef { + self.data.clone() + } - /// Returns a slice for the given offset and length - /// - /// Note this doesn't do any bound checking, for performance reason. - pub fn value_slice(&self, offset: i64, len: i64) -> &[$native_ty] { - let raw = - unsafe { std::slice::from_raw_parts(self.raw_values(), self.len() as usize) }; - &raw[offset as usize..offset as usize + len as usize] - } + fn data_ref(&self) -> &ArrayDataRef { + &self.data + } +} - /// Returns the minimum value in the array, according to the natural order. - pub fn min(&self) -> Option<$native_ty> { - self.min_max_helper(|a, b| a < b) - } +/// Implementation for primitive arrays with numeric types. +/// Boolean arrays are bit-packed and so implemented separately. +impl PrimitiveArray { + pub fn new(length: i64, values: Buffer, null_count: i64, offset: i64) -> Self { + let array_data = ArrayData::builder(T::get_data_type()) + .len(length) + .add_buffer(values) + .null_count(null_count) + .offset(offset) + .build(); + PrimitiveArray::from(array_data) + } - /// Returns the maximum value in the array, according to the natural order. - pub fn max(&self) -> Option<$native_ty> { - self.min_max_helper(|a, b| a > b) - } + /// Returns a `Buffer` holds all the values of this array. + /// + /// Note this doesn't take account into the offset of this array. + pub fn values(&self) -> Buffer { + self.data.buffers()[0].clone() + } - fn min_max_helper(&self, cmp: F) -> Option<$native_ty> - where - F: Fn($native_ty, $native_ty) -> bool, - { - let mut n: Option<$native_ty> = None; - let data = self.data(); - for i in 0..data.len() { - if data.is_null(i) { - continue; - } - let m = self.value(i as i64); - match n { - None => n = Some(m), - Some(nn) => { - if cmp(m, nn) { - n = Some(m) - } - } - } - } - n - } + /// Returns the length of this array + pub fn len(&self) -> i64 { + self.data.len() + } - // Returns a new primitive array builder - pub fn builder(capacity: i64) -> PrimitiveArrayBuilder<$native_ty> { - PrimitiveArrayBuilder::<$native_ty>::new(capacity) - } - } + /// Returns a raw pointer to the values of this array. + pub fn raw_values(&self) -> *const T::Native { + unsafe { mem::transmute(self.raw_values.get().offset(self.data.offset() as isize)) } + } - /// Constructs a primitive array from a vector. Should only be used for testing. - impl From> for PrimitiveArray<$native_ty> { - fn from(data: Vec<$native_ty>) -> Self { - let array_data = ArrayData::builder($data_ty) - .len(data.len() as i64) - .add_buffer(Buffer::from(data.to_byte_slice())) - .build(); - PrimitiveArray::from(array_data) - } - } + /// Returns the primitive value at index `i`. + /// + /// Note this doesn't do any bound checking, for performance reason. + pub fn value(&self, i: i64) -> T::Native { + unsafe { *(self.raw_values().offset(i as isize)) } + } - impl From>> for PrimitiveArray<$native_ty> { - fn from(data: Vec>) -> Self { - const TY_SIZE: usize = mem::size_of::<$native_ty>(); - const NULL: [u8; TY_SIZE] = [0; TY_SIZE]; + /// Returns a slice for the given offset and length + /// + /// Note this doesn't do any bound checking, for performance reason. + pub fn value_slice(&self, offset: i64, len: i64) -> &[T::Native] { + let raw = unsafe { std::slice::from_raw_parts(self.raw_values(), self.len() as usize) }; + &raw[offset as usize..offset as usize + len as usize] + } - let data_len = data.len(); - let mut null_buf = MutableBuffer::new(data_len).with_bitset(data_len, false); - let mut val_buf = MutableBuffer::new(data_len * TY_SIZE); + /// Returns the minimum value in the array, according to the natural order. + pub fn min(&self) -> Option { + self.min_max_helper(|a, b| a < b) + } - { - let null_slice = null_buf.data_mut(); - for (i, v) in data.iter().enumerate() { - if let Some(n) = v { - bit_util::set_bit(null_slice, i); - // unwrap() in the following should be safe here since we've - // made sure enough space is allocated for the values. - val_buf.write(&n.to_byte_slice()).unwrap(); - } else { - val_buf.write(&NULL).unwrap(); - } - } - } + /// Returns the maximum value in the array, according to the natural order. + pub fn max(&self) -> Option { + self.min_max_helper(|a, b| a > b) + } - let array_data = ArrayData::builder($data_ty) - .len(data_len as i64) - .add_buffer(val_buf.freeze()) - .null_bit_buffer(null_buf.freeze()) - .build(); - PrimitiveArray::from(array_data) + fn min_max_helper(&self, cmp: F) -> Option + where + F: Fn(T::Native, T::Native) -> bool, + { + let mut n: Option = None; + let data = self.data(); + for i in 0..data.len() { + if data.is_null(i) { + continue; } - } - - /// Constructs a `PrimitiveArray` from an array data reference. - impl From for PrimitiveArray<$native_ty> { - fn from(data: ArrayDataRef) -> Self { - assert_eq!( - data.buffers().len(), - 1, - "PrimitiveArray data should contain a single buffer only (values buffer)" - ); - let raw_values = data.buffers()[0].raw_data(); - assert!( - memory::is_aligned::(raw_values, mem::align_of::<$native_ty>()), - "memory is not aligned" - ); - Self { - data, - raw_values: RawPtrBox::new(raw_values as *const $native_ty), + let m = self.value(i as i64); + match n { + None => n = Some(m), + Some(nn) => { + if cmp(m, nn) { + n = Some(m) + } } } } - }; -} - -impl Array for PrimitiveArray { - fn as_any(&self) -> &Any { - self + n } - fn data(&self) -> ArrayDataRef { - self.data.clone() - } - - fn data_ref(&self) -> &ArrayDataRef { - &self.data + // Returns a new primitive array builder + pub fn builder(capacity: i64) -> PrimitiveArrayBuilder { + PrimitiveArrayBuilder::::new(capacity) } } -def_primitive_array!(DataType::UInt8, u8); -def_primitive_array!(DataType::UInt16, u16); -def_primitive_array!(DataType::UInt32, u32); -def_primitive_array!(DataType::UInt64, u64); -def_primitive_array!(DataType::Int8, i8); -def_primitive_array!(DataType::Int16, i16); -def_primitive_array!(DataType::Int32, i32); -def_primitive_array!(DataType::Int64, i64); -def_primitive_array!(DataType::Float32, f32); -def_primitive_array!(DataType::Float64, f64); - /// Specific implementation for Boolean arrays due to bit-packing -impl PrimitiveArray { +impl PrimitiveArray { pub fn new(length: i64, values: Buffer, null_count: i64, offset: i64) -> Self { let array_data = ArrayData::builder(DataType::Boolean) .len(length) @@ -315,7 +248,7 @@ impl PrimitiveArray { .null_count(null_count) .offset(offset) .build(); - PrimitiveArray::from(array_data) + BooleanArray::from(array_data) } /// Returns a `Buffer` holds all the values of this array. @@ -333,13 +266,73 @@ impl PrimitiveArray { } // Returns a new primitive array builder - pub fn builder(capacity: i64) -> PrimitiveArrayBuilder { - PrimitiveArrayBuilder::::new(capacity) + pub fn builder(capacity: i64) -> BooleanBuilder { + BooleanBuilder::new(capacity) } } +// TODO: the macro is needed here because we'd get "conflicting implementations" error +// otherwise with both `From>` and `From>>`. +// We should revisit this in future. +macro_rules! def_numeric_from_vec { + ( $ty:ident, $native_ty:ident, $ty_id:path ) => { + impl From> for PrimitiveArray<$ty> { + fn from(data: Vec<$native_ty>) -> Self { + let array_data = ArrayData::builder($ty_id) + .len(data.len() as i64) + .add_buffer(Buffer::from(data.to_byte_slice())) + .build(); + PrimitiveArray::from(array_data) + } + } + + // Constructs a primitive array from a vector. Should only be used for testing. + impl From>> for PrimitiveArray<$ty> { + fn from(data: Vec>) -> Self { + let data_len = data.len(); + let num_bytes = bit_util::ceil(data_len as i64, 8) as usize; + let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, false); + let mut val_buf = MutableBuffer::new(data_len * mem::size_of::<$native_ty>()); + + { + let null = vec![0; mem::size_of::<$native_ty>()]; + let null_slice = null_buf.data_mut(); + for (i, v) in data.iter().enumerate() { + if let Some(n) = v { + bit_util::set_bit(null_slice, i); + // unwrap() in the following should be safe here since we've + // made sure enough space is allocated for the values. + val_buf.write(&n.to_byte_slice()).unwrap(); + } else { + val_buf.write(&null).unwrap(); + } + } + } + + let array_data = ArrayData::builder($ty_id) + .len(data_len as i64) + .add_buffer(val_buf.freeze()) + .null_bit_buffer(null_buf.freeze()) + .build(); + PrimitiveArray::from(array_data) + } + } + }; +} + +def_numeric_from_vec!(Int8Type, i8, DataType::Int8); +def_numeric_from_vec!(Int16Type, i16, DataType::Int16); +def_numeric_from_vec!(Int32Type, i32, DataType::Int32); +def_numeric_from_vec!(Int64Type, i64, DataType::Int64); +def_numeric_from_vec!(UInt8Type, u8, DataType::UInt8); +def_numeric_from_vec!(UInt16Type, u16, DataType::UInt16); +def_numeric_from_vec!(UInt32Type, u32, DataType::UInt32); +def_numeric_from_vec!(UInt64Type, u64, DataType::UInt64); +def_numeric_from_vec!(Float32Type, f32, DataType::Float32); +def_numeric_from_vec!(Float64Type, f64, DataType::Float64); + /// Constructs a boolean array from a vector. Should only be used for testing. -impl From> for PrimitiveArray { +impl From> for BooleanArray { fn from(data: Vec) -> Self { let num_byte = bit_util::ceil(data.len() as i64, 8) as usize; let mut mut_buf = MutableBuffer::new(num_byte).with_bitset(num_byte, false); @@ -355,11 +348,11 @@ impl From> for PrimitiveArray { .len(data.len() as i64) .add_buffer(mut_buf.freeze()) .build(); - PrimitiveArray::from(array_data) + BooleanArray::from(array_data) } } -impl From>> for PrimitiveArray { +impl From>> for BooleanArray { fn from(data: Vec>) -> Self { let data_len = data.len() as i64; let num_byte = bit_util::ceil(data_len, 8) as usize; @@ -385,13 +378,13 @@ impl From>> for PrimitiveArray { .add_buffer(val_buf.freeze()) .null_bit_buffer(null_buf.freeze()) .build(); - PrimitiveArray::from(array_data) + BooleanArray::from(array_data) } } -/// Constructs a `PrimitiveArray` from an array data reference. -impl From for PrimitiveArray { - fn from(data: ArrayDataRef) -> Self { +/// Constructs a `PrimitiveArray` from an array data reference. +impl From for PrimitiveArray { + default fn from(data: ArrayDataRef) -> Self { assert_eq!( data.buffers().len(), 1, @@ -399,18 +392,16 @@ impl From for PrimitiveArray { ); let raw_values = data.buffers()[0].raw_data(); assert!( - memory::is_aligned::(raw_values, mem::align_of::()), + memory::is_aligned::(raw_values, mem::align_of::()), "memory is not aligned" ); Self { data, - raw_values: RawPtrBox::new(raw_values as *const bool), + raw_values: RawPtrBox::new(raw_values as *const T::Native), } } } -pub type BooleanArray = PrimitiveArray; - /// A list array where each element is a variable-sized sequence of values with the same /// type. pub struct ListArray { @@ -714,7 +705,7 @@ mod tests { fn test_primitive_array_from_vec() { let buf = Buffer::from(&[0, 1, 2, 3, 4].to_byte_slice()); let buf2 = buf.clone(); - let arr = PrimitiveArray::::new(5, buf, 0, 0); + let arr = Int32Array::new(5, buf, 0, 0); let slice = unsafe { ::std::slice::from_raw_parts(arr.raw_values(), 5) }; assert_eq!(buf2, arr.values()); assert_eq!(&[0, 1, 2, 3, 4], slice); @@ -731,7 +722,7 @@ mod tests { #[test] fn test_primitive_array_from_vec_option() { // Test building a primitive array with null values - let arr = PrimitiveArray::::from(vec![Some(0), None, Some(2), None, Some(4)]); + let arr = Int32Array::from(vec![Some(0), None, Some(2), None, Some(4)]); assert_eq!(5, arr.len()); assert_eq!(0, arr.offset()); assert_eq!(2, arr.null_count()); @@ -757,7 +748,7 @@ mod tests { .offset(2) .add_buffer(buf) .build(); - let arr = PrimitiveArray::::from(data); + let arr = Int32Array::from(data); assert_eq!(buf2, arr.values()); assert_eq!(5, arr.len()); assert_eq!(0, arr.null_count()); @@ -772,7 +763,7 @@ mod tests { )] fn test_primitive_array_invalid_buffer_len() { let data = ArrayData::builder(DataType::Int32).len(5).build(); - PrimitiveArray::::from(data); + Int32Array::from(data); } #[test] @@ -780,7 +771,7 @@ mod tests { // 00000010 01001000 let buf = Buffer::from([72_u8, 2_u8]); let buf2 = buf.clone(); - let arr = PrimitiveArray::::new(10, buf, 0, 0); + let arr = BooleanArray::new(10, buf, 0, 0); assert_eq!(buf2, arr.values()); assert_eq!(10, arr.len()); assert_eq!(0, arr.offset()); @@ -795,7 +786,7 @@ mod tests { #[test] fn test_boolean_array_from_vec() { let buf = Buffer::from([10_u8]); - let arr = PrimitiveArray::::from(vec![false, true, false, true]); + let arr = BooleanArray::from(vec![false, true, false, true]); assert_eq!(buf, arr.values()); assert_eq!(4, arr.len()); assert_eq!(0, arr.offset()); @@ -810,7 +801,7 @@ mod tests { #[test] fn test_boolean_array_from_vec_option() { let buf = Buffer::from([10_u8]); - let arr = PrimitiveArray::::from(vec![Some(false), Some(true), None, Some(true)]); + let arr = BooleanArray::from(vec![Some(false), Some(true), None, Some(true)]); assert_eq!(buf, arr.values()); assert_eq!(4, arr.len()); assert_eq!(0, arr.offset()); @@ -838,7 +829,7 @@ mod tests { .offset(2) .add_buffer(buf) .build(); - let arr = PrimitiveArray::::from(data); + let arr = BooleanArray::from(data); assert_eq!(buf2, arr.values()); assert_eq!(5, arr.len()); assert_eq!(2, arr.offset()); @@ -854,7 +845,7 @@ mod tests { )] fn test_boolean_array_invalid_buffer_len() { let data = ArrayData::builder(DataType::Boolean).len(5).build(); - PrimitiveArray::::from(data); + BooleanArray::from(data); } #[test] @@ -1159,11 +1150,11 @@ mod tests { let struct_array = StructArray::from(vec![ ( Field::new("b", DataType::Boolean, false), - Arc::new(PrimitiveArray::from(vec![false, false, true, true])) as Arc, + Arc::new(BooleanArray::from(vec![false, false, true, true])) as Arc, ), ( Field::new("c", DataType::Int32, false), - Arc::new(PrimitiveArray::from(vec![42, 28, 19, 31])), + Arc::new(Int32Array::from(vec![42, 28, 19, 31])), ), ]); assert_eq!(boolean_data, struct_array.column(0).data()); @@ -1175,12 +1166,12 @@ mod tests { fn test_invalid_struct_child_array_lengths() { StructArray::from(vec![ ( - Field::new("b", DataType::Float64, false), - Arc::new(PrimitiveArray::from(vec![1.1])) as Arc, + Field::new("b", DataType::Float32, false), + Arc::new(Float32Array::from(vec![1.1])) as Arc, ), ( Field::new("c", DataType::Float64, false), - Arc::new(PrimitiveArray::from(vec![2.2, 3.3])), + Arc::new(Float64Array::from(vec![2.2, 3.3])), ), ]); } @@ -1192,7 +1183,7 @@ mod tests { let buf = Buffer::from_raw_parts(ptr, 8); let buf2 = buf.slice(1); let array_data = ArrayData::builder(DataType::Int32).add_buffer(buf2).build(); - PrimitiveArray::::from(array_data); + Int32Array::from(array_data); } #[test] @@ -1233,22 +1224,21 @@ mod tests { #[test] fn test_buffer_array_min_max() { - let a = PrimitiveArray::::from(vec![5, 6, 7, 8, 9]); + let a = Int32Array::from(vec![5, 6, 7, 8, 9]); assert_eq!(5, a.min().unwrap()); assert_eq!(9, a.max().unwrap()); } #[test] fn test_buffer_array_min_max_with_nulls() { - let a = PrimitiveArray::::from(vec![Some(5), None, None, Some(8), Some(9)]); + let a = Int32Array::from(vec![Some(5), None, None, Some(8), Some(9)]); assert_eq!(5, a.min().unwrap()); assert_eq!(9, a.max().unwrap()); } #[test] fn test_access_array_concurrently() { - let a = PrimitiveArray::::from(vec![5, 6, 7, 8, 9]); - + let a = Int32Array::from(vec![5, 6, 7, 8, 9]); let ret = thread::spawn(move || a.value(3)).join(); assert!(ret.is_ok()); diff --git a/rust/src/builder.rs b/rust/src/builder.rs index fefbd4b053039..df6b645312e23 100644 --- a/rust/src/builder.rs +++ b/rust/src/builder.rs @@ -23,114 +23,124 @@ use std::io::Write; use std::marker::PhantomData; use std::mem; -use array::{Array, BinaryArray, ListArray, PrimitiveArray}; +use array::*; use array_data::ArrayData; use buffer::{Buffer, MutableBuffer}; -use datatypes::{ArrowPrimitiveType, DataType, ToByteSlice}; +use datatypes::*; use error::{ArrowError, Result}; use util::bit_util; /// Buffer builder with zero-copy build method -pub struct BufferBuilder -where - T: ArrowPrimitiveType, -{ +pub struct BufferBuilder { buffer: MutableBuffer, len: i64, _marker: PhantomData, } -macro_rules! impl_buffer_builder { - ($native_ty:ident) => { - impl BufferBuilder<$native_ty> { - /// Creates a builder with a fixed initial capacity - pub fn new(capacity: i64) -> Self { - let buffer = MutableBuffer::new(capacity as usize * mem::size_of::<$native_ty>()); - Self { - buffer, - len: 0, - _marker: PhantomData, - } - } +pub type BooleanBufferBuilder = BufferBuilder; +pub type Int8BufferBuilder = BufferBuilder; +pub type Int16BufferBuilder = BufferBuilder; +pub type Int32BufferBuilder = BufferBuilder; +pub type Int64BufferBuilder = BufferBuilder; +pub type UInt8BufferBuilder = BufferBuilder; +pub type UInt16BufferBuilder = BufferBuilder; +pub type UInt32BufferBuilder = BufferBuilder; +pub type UInt64BufferBuilder = BufferBuilder; +pub type Float32BufferBuilder = BufferBuilder; +pub type Float64BufferBuilder = BufferBuilder; + +// Trait for buffer builder. This is used mainly to offer separate implementations for +// numeric types and boolean types, while still be able to call methods on buffer builder +// with generic primitive type. +pub trait BufferBuilderTrait { + fn new(capacity: i64) -> Self; + fn len(&self) -> i64; + fn capacity(&self) -> i64; + fn advance(&mut self, i: i64) -> Result<()>; + fn reserve(&mut self, n: i64) -> Result<()>; + fn push(&mut self, v: T::Native) -> Result<()>; + fn push_slice(&mut self, slice: &[T::Native]) -> Result<()>; + fn finish(self) -> Buffer; +} - /// Returns the number of array elements (slots) in the builder - pub fn len(&self) -> i64 { - self.len - } +impl BufferBuilderTrait for BufferBuilder { + /// Creates a builder with a fixed initial capacity + default fn new(capacity: i64) -> Self { + let buffer = MutableBuffer::new(capacity as usize * mem::size_of::()); + Self { + buffer, + len: 0, + _marker: PhantomData, + } + } - // Advances the `len` of the underlying `Buffer` by `i` slots of type T - fn advance(&mut self, i: i64) -> Result<()> { - let new_buffer_len = (self.len + i) as usize * mem::size_of::<$native_ty>(); - self.buffer.resize(new_buffer_len)?; - self.len += i; - Ok(()) - } + /// Returns the number of array elements (slots) in the builder + fn len(&self) -> i64 { + self.len + } - /// Returns the current capacity of the builder (number of elements) - pub fn capacity(&self) -> i64 { - let byte_capacity = self.buffer.capacity(); - (byte_capacity / mem::size_of::<$native_ty>()) as i64 - } + /// Returns the current capacity of the builder (number of elements) + fn capacity(&self) -> i64 { + let bit_capacity = self.buffer.capacity() * 8; + (bit_capacity / T::get_bit_width()) as i64 + } - /// Pushes a value into the builder, growing the internal buffer as needed. - pub fn push(&mut self, v: $native_ty) -> Result<()> { - self.reserve(1)?; - self.write_bytes(v.to_byte_slice(), 1) - } + // Advances the `len` of the underlying `Buffer` by `i` slots of type T + default fn advance(&mut self, i: i64) -> Result<()> { + let new_buffer_len = (self.len + i) as usize * mem::size_of::(); + self.buffer.resize(new_buffer_len)?; + self.len += i; + Ok(()) + } - /// Pushes a slice of type `T`, growing the internal buffer as needed. - pub fn push_slice(&mut self, slice: &[$native_ty]) -> Result<()> { - let array_slots = slice.len() as i64; - self.reserve(array_slots)?; - self.write_bytes(slice.to_byte_slice(), array_slots) - } + /// Reserves memory for `n` elements of type `T`. + default fn reserve(&mut self, n: i64) -> Result<()> { + let new_capacity = self.len + n; + let byte_capacity = mem::size_of::() * new_capacity as usize; + self.buffer.reserve(byte_capacity)?; + Ok(()) + } - /// Reserves memory for `n` elements of type `T`. - pub fn reserve(&mut self, n: i64) -> Result<()> { - let new_capacity = self.len + n; - let byte_capacity = mem::size_of::<$native_ty>() * new_capacity as usize; - self.buffer.reserve(byte_capacity)?; - Ok(()) - } + /// Pushes a value into the builder, growing the internal buffer as needed. + default fn push(&mut self, v: T::Native) -> Result<()> { + self.reserve(1)?; + self.write_bytes(v.to_byte_slice(), 1) + } - /// Consumes this builder and returns an immutable `Buffer`. - pub fn finish(self) -> Buffer { - self.buffer.freeze() - } + /// Pushes a slice of type `T`, growing the internal buffer as needed. + default fn push_slice(&mut self, slice: &[T::Native]) -> Result<()> { + let array_slots = slice.len() as i64; + self.reserve(array_slots)?; + self.write_bytes(slice.to_byte_slice(), array_slots) + } - /// Writes a byte slice to the underlying buffer and updates the `len`, i.e. the number array - /// elements in the builder. Also, converts the `io::Result` required by the `Write` trait - /// to the Arrow `Result` type. - fn write_bytes(&mut self, bytes: &[u8], len_added: i64) -> Result<()> { - let write_result = self.buffer.write(bytes); - // `io::Result` has many options one of which we use, so pattern matching is overkill here - if write_result.is_err() { - Err(ArrowError::MemoryError( - "Could not write to Buffer, not big enough".to_string(), - )) - } else { - self.len += len_added; - Ok(()) - } - } + /// Consumes this builder and returns an immutable `Buffer`. + default fn finish(self) -> Buffer { + self.buffer.freeze() + } +} + +impl BufferBuilder { + /// Writes a byte slice to the underlying buffer and updates the `len`, i.e. the number array + /// elements in the builder. Also, converts the `io::Result` required by the `Write` trait + /// to the Arrow `Result` type. + fn write_bytes(&mut self, bytes: &[u8], len_added: i64) -> Result<()> { + let write_result = self.buffer.write(bytes); + // `io::Result` has many options one of which we use, so pattern matching is overkill here + if write_result.is_err() { + Err(ArrowError::MemoryError( + "Could not write to Buffer, not big enough".to_string(), + )) + } else { + self.len += len_added; + Ok(()) } - }; + } } -impl_buffer_builder!(u8); -impl_buffer_builder!(u16); -impl_buffer_builder!(u32); -impl_buffer_builder!(u64); -impl_buffer_builder!(i8); -impl_buffer_builder!(i16); -impl_buffer_builder!(i32); -impl_buffer_builder!(i64); -impl_buffer_builder!(f32); -impl_buffer_builder!(f64); - -impl BufferBuilder { +impl BufferBuilderTrait for BufferBuilder { /// Creates a builder with a fixed initial capacity. - pub fn new(capacity: i64) -> Self { + fn new(capacity: i64) -> Self { let byte_capacity = bit_util::ceil(capacity, 8); let actual_capacity = bit_util::round_upto_multiple_of_64(byte_capacity) as usize; let mut buffer = MutableBuffer::new(actual_capacity); @@ -142,27 +152,16 @@ impl BufferBuilder { } } - /// Returns the number of array elements (slots) in the builder. - pub fn len(&self) -> i64 { - self.len - } - // Advances the `len` of the underlying `Buffer` by `i` slots of type T - pub fn advance(&mut self, i: i64) -> Result<()> { + fn advance(&mut self, i: i64) -> Result<()> { let new_buffer_len = bit_util::ceil(self.len + i, 8); self.buffer.resize(new_buffer_len as usize)?; self.len += i; Ok(()) } - /// Returns the current capacity of the builder (number of elements) - pub fn capacity(&self) -> i64 { - let byte_capacity = self.buffer.capacity() as i64; - byte_capacity * 8 - } - /// Pushes a value into the builder, growing the internal buffer as needed. - pub fn push(&mut self, v: bool) -> Result<()> { + fn push(&mut self, v: bool) -> Result<()> { self.reserve(1)?; if v { // For performance the `len` of the buffer is not updated on each push but @@ -176,7 +175,7 @@ impl BufferBuilder { } /// Pushes a slice of type `T`, growing the internal buffer as needed. - pub fn push_slice(&mut self, slice: &[bool]) -> Result<()> { + fn push_slice(&mut self, slice: &[bool]) -> Result<()> { let array_slots = slice.len(); for i in 0..array_slots { self.push(slice[i])?; @@ -185,7 +184,7 @@ impl BufferBuilder { } /// Reserves memory for `n` elements of type `T`. - pub fn reserve(&mut self, n: i64) -> Result<()> { + fn reserve(&mut self, n: i64) -> Result<()> { let new_capacity = self.len + n; if new_capacity > self.capacity() { let new_byte_capacity = bit_util::ceil(new_capacity, 8) as usize; @@ -198,7 +197,7 @@ impl BufferBuilder { } /// Consumes this and returns an immutable `Buffer`. - pub fn finish(mut self) -> Buffer { + fn finish(mut self) -> Buffer { // `push` does not update the buffer's `len` so do it before `freeze` is called. let new_buffer_len = bit_util::ceil(self.len, 8) as usize; debug_assert!(new_buffer_len >= self.buffer.len()); @@ -224,110 +223,100 @@ pub trait ArrayBuilder { } /// Array builder for fixed-width primitive types -pub struct PrimitiveArrayBuilder -where - T: ArrowPrimitiveType, -{ +pub struct PrimitiveArrayBuilder { values_builder: BufferBuilder, - bitmap_builder: BufferBuilder, + bitmap_builder: BooleanBufferBuilder, } -macro_rules! impl_primitive_array_builder { - ($data_ty:path, $native_ty:ident) => { - impl ArrayBuilder for PrimitiveArrayBuilder<$native_ty> { - type ArrayType = PrimitiveArray<$native_ty>; +pub type BooleanBuilder = PrimitiveArrayBuilder; +pub type Int8Builder = PrimitiveArrayBuilder; +pub type Int16Builder = PrimitiveArrayBuilder; +pub type Int32Builder = PrimitiveArrayBuilder; +pub type Int64Builder = PrimitiveArrayBuilder; +pub type UInt8Builder = PrimitiveArrayBuilder; +pub type UInt16Builder = PrimitiveArrayBuilder; +pub type UInt32Builder = PrimitiveArrayBuilder; +pub type UInt64Builder = PrimitiveArrayBuilder; +pub type Float32Builder = PrimitiveArrayBuilder; +pub type Float64Builder = PrimitiveArrayBuilder; + +impl ArrayBuilder for PrimitiveArrayBuilder { + type ArrayType = PrimitiveArray; - /// Returns the builder as an owned `Any` type so that it can be `downcast` to a specific - /// implementation before calling it's `finish` method - fn into_any(self) -> Box { - Box::new(self) - } + /// Returns the builder as an owned `Any` type so that it can be `downcast` to a specific + /// implementation before calling it's `finish` method + fn into_any(self) -> Box { + Box::new(self) + } - /// Returns the number of array slots in the builder - fn len(&self) -> i64 { - self.values_builder.len - } + /// Returns the number of array slots in the builder + fn len(&self) -> i64 { + self.values_builder.len + } - /// Builds the PrimitiveArray - fn finish(self) -> PrimitiveArray<$native_ty> { - let len = self.len(); - let null_bit_buffer = self.bitmap_builder.finish(); - let data = ArrayData::builder($data_ty) - .len(len) - .null_count(len - bit_util::count_set_bits(null_bit_buffer.data())) - .add_buffer(self.values_builder.finish()) - .null_bit_buffer(null_bit_buffer) - .build(); - PrimitiveArray::<$native_ty>::from(data) - } - } + /// Builds the PrimitiveArray + fn finish(self) -> PrimitiveArray { + let len = self.len(); + let null_bit_buffer = self.bitmap_builder.finish(); + let data = ArrayData::builder(T::get_data_type()) + .len(len) + .null_count(len - bit_util::count_set_bits(null_bit_buffer.data())) + .add_buffer(self.values_builder.finish()) + .null_bit_buffer(null_bit_buffer) + .build(); + PrimitiveArray::::from(data) + } +} - impl PrimitiveArrayBuilder<$native_ty> { - /// Creates a new primitive array builder - pub fn new(capacity: i64) -> Self { - Self { - values_builder: BufferBuilder::<$native_ty>::new(capacity), - bitmap_builder: BufferBuilder::::new(capacity), - } - } +impl PrimitiveArrayBuilder { + /// Creates a new primitive array builder + pub fn new(capacity: i64) -> Self { + Self { + values_builder: BufferBuilder::::new(capacity), + bitmap_builder: BooleanBufferBuilder::new(capacity), + } + } - /// Returns the capacity of this builder measured in slots of type `T` - pub fn capacity(&self) -> i64 { - self.values_builder.capacity() - } + /// Returns the capacity of this builder measured in slots of type `T` + pub fn capacity(&self) -> i64 { + self.values_builder.capacity() + } - /// Pushes a value of type `T` into the builder - pub fn push(&mut self, v: $native_ty) -> Result<()> { - self.bitmap_builder.push(true)?; - self.values_builder.push(v)?; - Ok(()) - } + /// Pushes a value of type `T` into the builder + pub fn push(&mut self, v: T::Native) -> Result<()> { + self.bitmap_builder.push(true)?; + self.values_builder.push(v)?; + Ok(()) + } - /// Pushes a null slot into the builder - pub fn push_null(&mut self) -> Result<()> { - self.bitmap_builder.push(false)?; - self.values_builder.advance(1)?; - Ok(()) - } + /// Pushes a null slot into the builder + pub fn push_null(&mut self) -> Result<()> { + self.bitmap_builder.push(false)?; + self.values_builder.advance(1)?; + Ok(()) + } - /// Pushes an `Option` into the builder - pub fn push_option(&mut self, v: Option<$native_ty>) -> Result<()> { - match v { - None => self.push_null()?, - Some(v) => self.push(v)?, - }; - Ok(()) - } + /// Pushes an `Option` into the builder + pub fn push_option(&mut self, v: Option) -> Result<()> { + match v { + None => self.push_null()?, + Some(v) => self.push(v)?, + }; + Ok(()) + } - /// Pushes a slice of type `T` into the builder - pub fn push_slice(&mut self, v: &[$native_ty]) -> Result<()> { - self.bitmap_builder.push_slice(&vec![true; v.len()][..])?; - self.values_builder.push_slice(v)?; - Ok(()) - } - } - }; + /// Pushes a slice of type `T` into the builder + pub fn push_slice(&mut self, v: &[T::Native]) -> Result<()> { + self.bitmap_builder.push_slice(&vec![true; v.len()][..])?; + self.values_builder.push_slice(v)?; + Ok(()) + } } -impl_primitive_array_builder!(DataType::Boolean, bool); -impl_primitive_array_builder!(DataType::UInt8, u8); -impl_primitive_array_builder!(DataType::UInt16, u16); -impl_primitive_array_builder!(DataType::UInt32, u32); -impl_primitive_array_builder!(DataType::UInt64, u64); -impl_primitive_array_builder!(DataType::Int8, i8); -impl_primitive_array_builder!(DataType::Int16, i16); -impl_primitive_array_builder!(DataType::Int32, i32); -impl_primitive_array_builder!(DataType::Int64, i64); -impl_primitive_array_builder!(DataType::Float32, f32); -impl_primitive_array_builder!(DataType::Float64, f64); - /// Array builder for `ListArray` -pub struct ListArrayBuilder -where - T: ArrayBuilder, -{ - offsets_builder: BufferBuilder, - bitmap_builder: BufferBuilder, +pub struct ListArrayBuilder { + offsets_builder: Int32BufferBuilder, + bitmap_builder: BooleanBufferBuilder, values_builder: T, len: i64, } @@ -335,11 +324,11 @@ where impl ListArrayBuilder { /// Creates a new `ListArrayBuilder` from a given values array builder pub fn new(values_builder: T) -> Self { - let mut offsets_builder = BufferBuilder::::new(values_builder.len() + 1); + let mut offsets_builder = Int32BufferBuilder::new(values_builder.len() + 1); offsets_builder.push(0).unwrap(); Self { offsets_builder, - bitmap_builder: BufferBuilder::::new(values_builder.len()), + bitmap_builder: BooleanBufferBuilder::new(values_builder.len()), values_builder, len: 0, } @@ -408,32 +397,32 @@ macro_rules! impl_list_array_builder { }; } -impl_list_array_builder!(PrimitiveArrayBuilder); -impl_list_array_builder!(PrimitiveArrayBuilder); -impl_list_array_builder!(PrimitiveArrayBuilder); -impl_list_array_builder!(PrimitiveArrayBuilder); -impl_list_array_builder!(PrimitiveArrayBuilder); -impl_list_array_builder!(PrimitiveArrayBuilder); -impl_list_array_builder!(PrimitiveArrayBuilder); -impl_list_array_builder!(PrimitiveArrayBuilder); -impl_list_array_builder!(PrimitiveArrayBuilder); -impl_list_array_builder!(PrimitiveArrayBuilder); -impl_list_array_builder!(PrimitiveArrayBuilder); -impl_list_array_builder!(ListArrayBuilder>); -impl_list_array_builder!(ListArrayBuilder>); -impl_list_array_builder!(ListArrayBuilder>); -impl_list_array_builder!(ListArrayBuilder>); -impl_list_array_builder!(ListArrayBuilder>); -impl_list_array_builder!(ListArrayBuilder>); -impl_list_array_builder!(ListArrayBuilder>); -impl_list_array_builder!(ListArrayBuilder>); -impl_list_array_builder!(ListArrayBuilder>); -impl_list_array_builder!(ListArrayBuilder>); -impl_list_array_builder!(ListArrayBuilder>); +impl_list_array_builder!(BooleanBuilder); +impl_list_array_builder!(UInt8Builder); +impl_list_array_builder!(UInt16Builder); +impl_list_array_builder!(UInt32Builder); +impl_list_array_builder!(UInt64Builder); +impl_list_array_builder!(Int8Builder); +impl_list_array_builder!(Int16Builder); +impl_list_array_builder!(Int32Builder); +impl_list_array_builder!(Int64Builder); +impl_list_array_builder!(Float32Builder); +impl_list_array_builder!(Float64Builder); +impl_list_array_builder!(ListArrayBuilder); +impl_list_array_builder!(ListArrayBuilder); +impl_list_array_builder!(ListArrayBuilder); +impl_list_array_builder!(ListArrayBuilder); +impl_list_array_builder!(ListArrayBuilder); +impl_list_array_builder!(ListArrayBuilder); +impl_list_array_builder!(ListArrayBuilder); +impl_list_array_builder!(ListArrayBuilder); +impl_list_array_builder!(ListArrayBuilder); +impl_list_array_builder!(ListArrayBuilder); +impl_list_array_builder!(ListArrayBuilder); /// Array builder for `BinaryArray` pub struct BinaryArrayBuilder { - builder: ListArrayBuilder>, + builder: ListArrayBuilder, } impl ArrayBuilder for BinaryArrayBuilder { @@ -459,7 +448,7 @@ impl ArrayBuilder for BinaryArrayBuilder { impl BinaryArrayBuilder { /// Creates a new `BinaryArrayBuilder`, `capacity` is the number of bytes in the values array pub fn new(capacity: i64) -> Self { - let values_builder = PrimitiveArrayBuilder::::new(capacity); + let values_builder = UInt8Builder::new(capacity); Self { builder: ListArrayBuilder::new(values_builder), } @@ -492,14 +481,13 @@ impl BinaryArrayBuilder { #[cfg(test)] mod tests { - use array::Array; use super::*; #[test] fn test_builder_i32_empty() { - let b = BufferBuilder::::new(5); + let b = Int32BufferBuilder::new(5); assert_eq!(0, b.len()); assert_eq!(16, b.capacity()); let a = b.finish(); @@ -508,7 +496,7 @@ mod tests { #[test] fn test_builder_i32_alloc_zero_bytes() { - let mut b = BufferBuilder::::new(0); + let mut b = Int32BufferBuilder::new(0); b.push(123).unwrap(); let a = b.finish(); assert_eq!(4, a.len()); @@ -516,7 +504,7 @@ mod tests { #[test] fn test_builder_i32() { - let mut b = BufferBuilder::::new(5); + let mut b = Int32BufferBuilder::new(5); for i in 0..5 { b.push(i).unwrap(); } @@ -527,7 +515,7 @@ mod tests { #[test] fn test_builder_i32_grow_buffer() { - let mut b = BufferBuilder::::new(2); + let mut b = Int32BufferBuilder::new(2); assert_eq!(16, b.capacity()); for i in 0..20 { b.push(i).unwrap(); @@ -539,14 +527,14 @@ mod tests { #[test] fn test_reserve() { - let mut b = BufferBuilder::::new(2); + let mut b = UInt8BufferBuilder::new(2); assert_eq!(64, b.capacity()); b.reserve(64).unwrap(); assert_eq!(64, b.capacity()); b.reserve(65).unwrap(); assert_eq!(128, b.capacity()); - let mut b = BufferBuilder::::new(2); + let mut b = Int32BufferBuilder::new(2); assert_eq!(16, b.capacity()); b.reserve(16).unwrap(); assert_eq!(16, b.capacity()); @@ -556,13 +544,13 @@ mod tests { #[test] fn test_push_slice() { - let mut b = BufferBuilder::::new(0); + let mut b = UInt8BufferBuilder::new(0); b.push_slice("Hello, ".as_bytes()).unwrap(); b.push_slice("World!".as_bytes()).unwrap(); let buffer = b.finish(); assert_eq!(13, buffer.len()); - let mut b = BufferBuilder::::new(0); + let mut b = Int32BufferBuilder::new(0); b.push_slice(&[32, 54]).unwrap(); let buffer = b.finish(); assert_eq!(8, buffer.len()); @@ -570,7 +558,7 @@ mod tests { #[test] fn test_write_bytes() { - let mut b = BufferBuilder::::new(4); + let mut b = BooleanBufferBuilder::new(4); b.push(false).unwrap(); b.push(true).unwrap(); b.push(false).unwrap(); @@ -580,7 +568,7 @@ mod tests { let buffer = b.finish(); assert_eq!(1, buffer.len()); - let mut b = BufferBuilder::::new(4); + let mut b = BooleanBufferBuilder::new(4); b.push_slice(&[false, true, false, true]).unwrap(); assert_eq!(4, b.len()); assert_eq!(512, b.capacity()); @@ -590,7 +578,7 @@ mod tests { #[test] fn test_write_bytes_i32() { - let mut b = BufferBuilder::::new(4); + let mut b = Int32BufferBuilder::new(4); let bytes = [8, 16, 32, 64].to_byte_slice(); b.write_bytes(bytes, 4).unwrap(); assert_eq!(4, b.len()); @@ -602,7 +590,7 @@ mod tests { #[test] #[should_panic(expected = "Could not write to Buffer, not big enough")] fn test_write_too_many_bytes() { - let mut b = BufferBuilder::::new(0); + let mut b = Int32BufferBuilder::new(0); let bytes = [8, 16, 32, 64].to_byte_slice(); b.write_bytes(bytes, 4).unwrap(); } @@ -611,7 +599,7 @@ mod tests { fn test_boolean_builder_increases_buffer_len() { // 00000010 01001000 let buf = Buffer::from([72_u8, 2_u8]); - let mut builder = BufferBuilder::::new(8); + let mut builder = BooleanBufferBuilder::new(8); for i in 0..10 { if i == 3 || i == 6 || i == 9 { @@ -628,7 +616,7 @@ mod tests { #[test] fn test_primitive_array_builder_i32() { - let mut builder = PrimitiveArray::::builder(5); + let mut builder = Int32Array::builder(5); for i in 0..5 { builder.push(i).unwrap(); } @@ -647,7 +635,7 @@ mod tests { fn test_primitive_array_builder_bool() { // 00000010 01001000 let buf = Buffer::from([72_u8, 2_u8]); - let mut builder = PrimitiveArray::::builder(10); + let mut builder = BooleanArray::builder(10); for i in 0..10 { if i == 3 || i == 6 || i == 9 { builder.push(true).unwrap(); @@ -670,9 +658,9 @@ mod tests { #[test] fn test_primitive_array_builder_push_option() { - let arr1 = PrimitiveArray::::from(vec![Some(0), None, Some(2), None, Some(4)]); + let arr1 = Int32Array::from(vec![Some(0), None, Some(2), None, Some(4)]); - let mut builder = PrimitiveArray::::builder(5); + let mut builder = Int32Array::builder(5); builder.push_option(Some(0)).unwrap(); builder.push_option(None).unwrap(); builder.push_option(Some(2)).unwrap(); @@ -694,9 +682,9 @@ mod tests { #[test] fn test_primitive_array_builder_push_null() { - let arr1 = PrimitiveArray::::from(vec![Some(0), Some(2), None, None, Some(4)]); + let arr1 = Int32Array::from(vec![Some(0), Some(2), None, None, Some(4)]); - let mut builder = PrimitiveArray::::builder(5); + let mut builder = Int32Array::builder(5); builder.push(0).unwrap(); builder.push(2).unwrap(); builder.push_null().unwrap(); @@ -718,9 +706,9 @@ mod tests { #[test] fn test_primitive_array_builder_push_slice() { - let arr1 = PrimitiveArray::::from(vec![Some(0), Some(2), None, None, Some(4)]); + let arr1 = Int32Array::from(vec![Some(0), Some(2), None, None, Some(4)]); - let mut builder = PrimitiveArray::::builder(5); + let mut builder = Int32Array::builder(5); builder.push_slice(&[0, 2]).unwrap(); builder.push_null().unwrap(); builder.push_null().unwrap(); @@ -741,7 +729,7 @@ mod tests { #[test] fn test_list_array_builder() { - let values_builder = PrimitiveArrayBuilder::::new(10); + let values_builder = Int32Builder::new(10); let mut builder = ListArrayBuilder::new(values_builder); // [[0, 1, 2], [3, 4, 5], [6, 7]] @@ -780,7 +768,7 @@ mod tests { #[test] fn test_list_array_builder_nulls() { - let values_builder = PrimitiveArrayBuilder::::new(10); + let values_builder = Int32Builder::new(10); let mut builder = ListArrayBuilder::new(values_builder); // [[0, 1, 2], null, [3, null, 5], [6, 7]] @@ -807,7 +795,7 @@ mod tests { #[test] fn test_list_list_array_builder() { - let primitive_builder = PrimitiveArrayBuilder::::new(10); + let primitive_builder = Int32Builder::new(10); let values_builder = ListArrayBuilder::new(primitive_builder); let mut builder = ListArrayBuilder::new(values_builder); diff --git a/rust/src/csv/reader.rs b/rust/src/csv/reader.rs index cbe53bb076107..dcb35958c5d89 100644 --- a/rust/src/csv/reader.rs +++ b/rust/src/csv/reader.rs @@ -45,8 +45,8 @@ use std::io::BufReader; use std::sync::Arc; use array::{ArrayRef, BinaryArray}; -use builder::{ArrayBuilder, ListArrayBuilder, PrimitiveArrayBuilder}; -use datatypes::{DataType, Schema}; +use builder::*; +use datatypes::*; use error::{ArrowError, Result}; use record_batch::RecordBatch; @@ -87,17 +87,27 @@ impl Reader { } } -macro_rules! build_primitive_array { - ($ROWS:expr, $COL_INDEX:expr, $TY:ty) => {{ - let mut builder = PrimitiveArrayBuilder::<$TY>::new($ROWS.len() as i64); - for row_index in 0..$ROWS.len() { - match $ROWS[row_index].get(*$COL_INDEX) { - Some(s) if s.len() > 0 => builder.push(s.parse::<$TY>().unwrap()).unwrap(), - _ => builder.push_null().unwrap(), - } +fn build_primitive_array( + rows: &[StringRecord], + col_idx: &usize, +) -> Result { + let mut builder = PrimitiveArrayBuilder::::new(rows.len() as i64); + for row_index in 0..rows.len() { + match rows[row_index].get(*col_idx) { + Some(s) if s.len() > 0 => match s.parse::() { + Ok(v) => builder.push(v)?, + Err(_) => { + // TODO: we should surface the underlying error here. + return Err(ArrowError::ParseError(format!( + "Error while parsing value {}", + s + ))); + } + }, + _ => builder.push_null().unwrap(), } - Ok(Arc::new(builder.finish()) as ArrayRef) - }}; + } + Ok(Arc::new(builder.finish()) as ArrayRef) } impl Reader { @@ -133,26 +143,25 @@ impl Reader { .collect(), }; + let rows = &rows[..]; let arrays: Result> = projection .iter() .map(|i| { let field = self.schema.field(*i); - match field.data_type() { - &DataType::Boolean => build_primitive_array!(rows, i, bool), - &DataType::Int8 => build_primitive_array!(rows, i, i8), - &DataType::Int16 => build_primitive_array!(rows, i, i16), - &DataType::Int32 => build_primitive_array!(rows, i, i32), - &DataType::Int64 => build_primitive_array!(rows, i, i64), - &DataType::UInt8 => build_primitive_array!(rows, i, u8), - &DataType::UInt16 => build_primitive_array!(rows, i, u16), - &DataType::UInt32 => build_primitive_array!(rows, i, u32), - &DataType::UInt64 => build_primitive_array!(rows, i, u64), - &DataType::Float32 => build_primitive_array!(rows, i, f32), - &DataType::Float64 => build_primitive_array!(rows, i, f64), + &DataType::Boolean => build_primitive_array::(rows, i), + &DataType::Int8 => build_primitive_array::(rows, i), + &DataType::Int16 => build_primitive_array::(rows, i), + &DataType::Int32 => build_primitive_array::(rows, i), + &DataType::Int64 => build_primitive_array::(rows, i), + &DataType::UInt8 => build_primitive_array::(rows, i), + &DataType::UInt16 => build_primitive_array::(rows, i), + &DataType::UInt32 => build_primitive_array::(rows, i), + &DataType::UInt64 => build_primitive_array::(rows, i), + &DataType::Float32 => build_primitive_array::(rows, i), + &DataType::Float64 => build_primitive_array::(rows, i), &DataType::Utf8 => { - let mut values_builder: PrimitiveArrayBuilder = - PrimitiveArrayBuilder::::new(rows.len() as i64); + let mut values_builder: UInt8Builder = UInt8Builder::new(rows.len() as i64); let mut list_builder = ListArrayBuilder::new(values_builder); for row_index in 0..rows.len() { match rows[row_index].get(*i) { @@ -186,7 +195,7 @@ impl Reader { mod tests { use super::*; - use array::PrimitiveArray; + use array::*; use datatypes::Field; #[test] @@ -208,7 +217,7 @@ mod tests { let lat = batch .column(1) .as_any() - .downcast_ref::>() + .downcast_ref::() .unwrap(); assert_eq!(57.653484, lat.value(0)); diff --git a/rust/src/datatypes.rs b/rust/src/datatypes.rs index a9bd855f059fe..fdb9351e61abc 100644 --- a/rust/src/datatypes.rs +++ b/rust/src/datatypes.rs @@ -24,6 +24,7 @@ use std::fmt; use std::mem::size_of; use std::slice::from_raw_parts; +use std::str::FromStr; use error::{ArrowError, Result}; use serde_json::Value; @@ -68,23 +69,65 @@ pub struct Field { nullable: bool, } +pub trait ArrowNativeType: Send + Sync + Copy + PartialOrd + FromStr + 'static {} + /// Trait indicating a primitive fixed-width type (bool, ints and floats). -/// -/// This trait is a marker trait to indicate a primitive type, i.e. a type that occupies a fixed -/// size in memory as indicated in bit or byte width. -pub trait ArrowPrimitiveType: Send + Sync + Copy + PartialOrd + 'static {} - -impl ArrowPrimitiveType for bool {} -impl ArrowPrimitiveType for u8 {} -impl ArrowPrimitiveType for u16 {} -impl ArrowPrimitiveType for u32 {} -impl ArrowPrimitiveType for u64 {} -impl ArrowPrimitiveType for i8 {} -impl ArrowPrimitiveType for i16 {} -impl ArrowPrimitiveType for i32 {} -impl ArrowPrimitiveType for i64 {} -impl ArrowPrimitiveType for f32 {} -impl ArrowPrimitiveType for f64 {} +pub trait ArrowPrimitiveType: 'static { + /// Corresponding Rust native type for the primitive type. + type Native: ArrowNativeType; + + /// Returns the corresponding Arrow data type of this primitive type. + fn get_data_type() -> DataType; + + /// Returns the bit width of this primitive type. + fn get_bit_width() -> usize; +} + +macro_rules! make_type { + ($name:ident, $native_ty:ty, $data_ty:path, $bit_width:expr) => { + impl ArrowNativeType for $native_ty {} + + pub struct $name {} + + impl ArrowPrimitiveType for $name { + type Native = $native_ty; + + fn get_data_type() -> DataType { + $data_ty + } + + fn get_bit_width() -> usize { + $bit_width + } + } + }; +} + +make_type!(BooleanType, bool, DataType::Boolean, 1); +make_type!(Int8Type, i8, DataType::Int8, 8); +make_type!(Int16Type, i16, DataType::Int16, 16); +make_type!(Int32Type, i32, DataType::Int32, 32); +make_type!(Int64Type, i64, DataType::Int64, 64); +make_type!(UInt8Type, u8, DataType::UInt8, 8); +make_type!(UInt16Type, u16, DataType::UInt16, 16); +make_type!(UInt32Type, u32, DataType::UInt32, 32); +make_type!(UInt64Type, u64, DataType::UInt64, 64); +make_type!(Float32Type, f32, DataType::Float32, 32); +make_type!(Float64Type, f64, DataType::Float64, 64); + +/// A subtype of primitive type that represents numeric values. +pub trait ArrowNumericType: ArrowPrimitiveType {} + +impl ArrowNumericType for Int8Type {} +impl ArrowNumericType for Int16Type {} +impl ArrowNumericType for Int32Type {} +impl ArrowNumericType for Int64Type {} +impl ArrowNumericType for UInt8Type {} +impl ArrowNumericType for UInt16Type {} +impl ArrowNumericType for UInt32Type {} +impl ArrowNumericType for UInt64Type {} +impl ArrowNumericType for Float32Type {} +impl ArrowNumericType for Float64Type {} /// Allows conversion from supported Arrow types to a byte slice. pub trait ToByteSlice { @@ -92,20 +135,14 @@ pub trait ToByteSlice { fn to_byte_slice(&self) -> &[u8]; } -impl ToByteSlice for [T] -where - T: ArrowPrimitiveType, -{ +impl ToByteSlice for [T] { fn to_byte_slice(&self) -> &[u8] { let raw_ptr = self.as_ptr() as *const T as *const u8; unsafe { from_raw_parts(raw_ptr, self.len() * size_of::()) } } } -impl ToByteSlice for T -where - T: ArrowPrimitiveType, -{ +impl ToByteSlice for T { fn to_byte_slice(&self) -> &[u8] { let raw_ptr = self as *const T as *const u8; unsafe { from_raw_parts(raw_ptr, size_of::()) } diff --git a/rust/src/lib.rs b/rust/src/lib.rs index cc6d3ff73f1ab..b2db090cf7c87 100644 --- a/rust/src/lib.rs +++ b/rust/src/lib.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +#![feature(specialization)] + extern crate bytes; extern crate csv as csv_crate; extern crate libc; diff --git a/rust/src/record_batch.rs b/rust/src/record_batch.rs index fe1f39f700d4a..cde1122aadc0a 100644 --- a/rust/src/record_batch.rs +++ b/rust/src/record_batch.rs @@ -82,7 +82,7 @@ mod tests { .len(5) .add_buffer(Buffer::from(v.to_byte_slice())) .build(); - let a = PrimitiveArray::::from(array_data); + let a = Int32Array::from(array_data); let v = vec![b'a', b'b', b'c', b'd', b'e']; let offset_data = vec![0, 1, 2, 3, 4, 5, 6]; diff --git a/rust/src/tensor.rs b/rust/src/tensor.rs index a9f933ffb6eec..e50a3136d2ba1 100644 --- a/rust/src/tensor.rs +++ b/rust/src/tensor.rs @@ -20,14 +20,11 @@ use std::marker::PhantomData; use std::mem; use buffer::Buffer; -use datatypes::{ArrowPrimitiveType, DataType}; +use datatypes::*; /// Computes the strides required assuming a row major memory layout -fn compute_row_major_strides(shape: &Vec) -> Vec -where - T: ArrowPrimitiveType, -{ - let mut remaining_bytes = mem::size_of::(); +fn compute_row_major_strides(shape: &Vec) -> Vec { + let mut remaining_bytes = mem::size_of::(); for i in shape { remaining_bytes = remaining_bytes .checked_mul(*i as usize) @@ -43,11 +40,8 @@ where } /// Computes the strides required assuming a column major memory layout -fn compute_column_major_strides(shape: &Vec) -> Vec -where - T: ArrowPrimitiveType, -{ - let mut remaining_bytes = mem::size_of::(); +fn compute_column_major_strides(shape: &Vec) -> Vec { + let mut remaining_bytes = mem::size_of::(); let mut strides = Vec::::new(); for i in shape { strides.push(remaining_bytes as i64); @@ -59,10 +53,7 @@ where } /// Tensor of primitive types -pub struct Tensor<'a, T> -where - T: ArrowPrimitiveType, -{ +pub struct Tensor<'a, T: ArrowPrimitiveType> { data_type: DataType, buffer: Buffer, shape: Option>, @@ -71,203 +62,199 @@ where _marker: PhantomData, } -macro_rules! impl_tensor { - ($data_ty:path, $native_ty:ident) => { - impl<'a> Tensor<'a, $native_ty> { - /// Creates a new `Tensor` - pub fn new( - buffer: Buffer, - shape: Option>, - strides: Option>, - names: Option>, - ) -> Self { - match &shape { - None => { +pub type BooleanTensor<'a> = Tensor<'a, BooleanType>; +pub type Int8Tensor<'a> = Tensor<'a, Int8Type>; +pub type Int16Tensor<'a> = Tensor<'a, Int16Type>; +pub type Int32Tensor<'a> = Tensor<'a, Int32Type>; +pub type Int64Tensor<'a> = Tensor<'a, Int64Type>; +pub type UInt8Tensor<'a> = Tensor<'a, UInt8Type>; +pub type UInt16Tensor<'a> = Tensor<'a, UInt16Type>; +pub type UInt32Tensor<'a> = Tensor<'a, UInt32Type>; +pub type UInt64Tensor<'a> = Tensor<'a, UInt64Type>; +pub type Float32Tensor<'a> = Tensor<'a, Float32Type>; +pub type Float64Tensor<'a> = Tensor<'a, Float64Type>; + +impl<'a, T: ArrowPrimitiveType> Tensor<'a, T> { + /// Creates a new `Tensor` + pub fn new( + buffer: Buffer, + shape: Option>, + strides: Option>, + names: Option>, + ) -> Self { + match &shape { + None => { + assert_eq!( + buffer.len(), + mem::size_of::(), + "underlying buffer should only contain a single tensor element" + ); + assert_eq!(None, strides); + assert_eq!(None, names); + } + Some(ref s) => { + strides + .iter() + .map(|i| assert_eq!(s.len(), i.len(), "shape and stride dimensions differ")) + .next(); + names + .iter() + .map(|i| { assert_eq!( - buffer.len(), - mem::size_of::<$native_ty>(), - "underlying buffer should only contain a single tensor element" - ); - assert_eq!(None, strides); - assert_eq!(None, names); - } - Some(ref s) => { - strides - .iter() - .map(|i| { - assert_eq!(s.len(), i.len(), "shape and stride dimensions differ") - }) - .next(); - names - .iter() - .map(|i| { - assert_eq!( - s.len(), - i.len(), - "number of dimensions and number of dimension names differ" - ) - }) - .next(); - } - }; - Self { - data_type: $data_ty, - buffer, - shape, - strides, - names, - _marker: PhantomData, - } + s.len(), + i.len(), + "number of dimensions and number of dimension names differ" + ) + }) + .next(); } + }; + Self { + data_type: T::get_data_type(), + buffer, + shape, + strides, + names, + _marker: PhantomData, + } + } - /// Creates a new Tensor using row major memory layout - pub fn new_row_major( - buffer: Buffer, - shape: Option>, - names: Option>, - ) -> Self { - let strides = match &shape { - None => None, - Some(ref s) => Some(compute_row_major_strides::<$native_ty>(&s)), - }; - Self::new(buffer, shape, strides, names) - } + /// Creates a new Tensor using row major memory layout + pub fn new_row_major( + buffer: Buffer, + shape: Option>, + names: Option>, + ) -> Self { + let strides = match &shape { + None => None, + Some(ref s) => Some(compute_row_major_strides::(&s)), + }; + Self::new(buffer, shape, strides, names) + } - /// Creates a new Tensor using column major memory layout - pub fn new_column_major( - buffer: Buffer, - shape: Option>, - names: Option>, - ) -> Self { - let strides = match &shape { - None => None, - Some(ref s) => Some(compute_column_major_strides::<$native_ty>(&s)), - }; - Self::new(buffer, shape, strides, names) - } + /// Creates a new Tensor using column major memory layout + pub fn new_column_major( + buffer: Buffer, + shape: Option>, + names: Option>, + ) -> Self { + let strides = match &shape { + None => None, + Some(ref s) => Some(compute_column_major_strides::(&s)), + }; + Self::new(buffer, shape, strides, names) + } - /// The data type of the `Tensor` - pub fn data_type(&self) -> &DataType { - &self.data_type - } + /// The data type of the `Tensor` + pub fn data_type(&self) -> &DataType { + &self.data_type + } - /// The sizes of the dimensions - pub fn shape(&self) -> Option<&Vec> { - self.shape.as_ref() - } + /// The sizes of the dimensions + pub fn shape(&self) -> Option<&Vec> { + self.shape.as_ref() + } - /// Returns a reference to the underlying `Buffer` - pub fn data(&self) -> &Buffer { - &self.buffer - } + /// Returns a reference to the underlying `Buffer` + pub fn data(&self) -> &Buffer { + &self.buffer + } - /// The number of bytes between elements in each dimension - pub fn strides(&self) -> Option<&Vec> { - self.strides.as_ref() - } + /// The number of bytes between elements in each dimension + pub fn strides(&self) -> Option<&Vec> { + self.strides.as_ref() + } - /// The names of the dimensions - pub fn names(&self) -> Option<&Vec<&'a str>> { - self.names.as_ref() - } + /// The names of the dimensions + pub fn names(&self) -> Option<&Vec<&'a str>> { + self.names.as_ref() + } - /// The number of dimensions - pub fn ndim(&self) -> i64 { - match &self.shape { - None => 0, - Some(v) => v.len() as i64, - } - } + /// The number of dimensions + pub fn ndim(&self) -> i64 { + match &self.shape { + None => 0, + Some(v) => v.len() as i64, + } + } - /// The name of dimension i - pub fn dim_name(&self, i: i64) -> Option<&'a str> { - match &self.names { - None => None, - Some(ref names) => Some(&names[i as usize]), - } - } + /// The name of dimension i + pub fn dim_name(&self, i: i64) -> Option<&'a str> { + match &self.names { + None => None, + Some(ref names) => Some(&names[i as usize]), + } + } - /// The total number of elements in the `Tensor` - pub fn size(&self) -> i64 { - (self.buffer.len() / mem::size_of::<$native_ty>()) as i64 - } + /// The total number of elements in the `Tensor` + pub fn size(&self) -> i64 { + (self.buffer.len() / mem::size_of::()) as i64 + } - /// Indicates if the data is laid out contiguously in memory - pub fn is_contiguous(&self) -> bool { - self.is_row_major() || self.is_column_major() - } + /// Indicates if the data is laid out contiguously in memory + pub fn is_contiguous(&self) -> bool { + self.is_row_major() || self.is_column_major() + } - /// Indicates if the memory layout row major - pub fn is_row_major(&self) -> bool { - match self.shape { - None => false, - Some(ref s) => Some(compute_row_major_strides::<$native_ty>(s)) == self.strides, - } - } + /// Indicates if the memory layout row major + pub fn is_row_major(&self) -> bool { + match self.shape { + None => false, + Some(ref s) => Some(compute_row_major_strides::(s)) == self.strides, + } + } - /// Indicates if the memory layout column major - pub fn is_column_major(&self) -> bool { - match self.shape { - None => false, - Some(ref s) => { - Some(compute_column_major_strides::<$native_ty>(s)) == self.strides - } - } - } + /// Indicates if the memory layout column major + pub fn is_column_major(&self) -> bool { + match self.shape { + None => false, + Some(ref s) => Some(compute_column_major_strides::(s)) == self.strides, } - }; + } } -impl_tensor!(DataType::UInt8, u8); -impl_tensor!(DataType::UInt16, u16); -impl_tensor!(DataType::UInt32, u32); -impl_tensor!(DataType::UInt64, u64); -impl_tensor!(DataType::Int8, i8); -impl_tensor!(DataType::Int16, i16); -impl_tensor!(DataType::Int32, i32); -impl_tensor!(DataType::Int64, i64); -impl_tensor!(DataType::Float32, f32); -impl_tensor!(DataType::Float64, f64); - #[cfg(test)] mod tests { use super::*; use buffer::Buffer; - use builder::BufferBuilder; + use builder::*; #[test] fn test_compute_row_major_strides() { assert_eq!( vec![48, 8], - compute_row_major_strides::(&vec![4_i64, 6]) + compute_row_major_strides::(&vec![4_i64, 6]) ); assert_eq!( vec![24, 4], - compute_row_major_strides::(&vec![4_i64, 6]) + compute_row_major_strides::(&vec![4_i64, 6]) + ); + assert_eq!( + vec![6, 1], + compute_row_major_strides::(&vec![4_i64, 6]) ); - assert_eq!(vec![6, 1], compute_row_major_strides::(&vec![4_i64, 6])); } #[test] fn test_compute_column_major_strides() { assert_eq!( vec![8, 32], - compute_column_major_strides::(&vec![4_i64, 6]) + compute_column_major_strides::(&vec![4_i64, 6]) ); assert_eq!( vec![4, 16], - compute_column_major_strides::(&vec![4_i64, 6]) + compute_column_major_strides::(&vec![4_i64, 6]) ); assert_eq!( vec![1, 4], - compute_column_major_strides::(&vec![4_i64, 6]) + compute_column_major_strides::(&vec![4_i64, 6]) ); } #[test] fn test_zero_dim() { let buf = Buffer::from(&[1]); - let tensor = Tensor::::new(buf, None, None, None); + let tensor = UInt8Tensor::new(buf, None, None, None); assert_eq!(1, tensor.size()); assert_eq!(None, tensor.shape()); assert_eq!(None, tensor.names()); @@ -277,7 +264,7 @@ mod tests { assert_eq!(false, tensor.is_contiguous()); let buf = Buffer::from(&[1, 2, 2, 2]); - let tensor = Tensor::::new(buf, None, None, None); + let tensor = Int32Tensor::new(buf, None, None, None); assert_eq!(1, tensor.size()); assert_eq!(None, tensor.shape()); assert_eq!(None, tensor.names()); @@ -289,12 +276,12 @@ mod tests { #[test] fn test_tensor() { - let mut builder = BufferBuilder::::new(16); + let mut builder = Int32BufferBuilder::new(16); for i in 0..16 { builder.push(i).unwrap(); } let buf = builder.finish(); - let tensor = Tensor::::new(buf, Some(vec![2, 8]), None, None); + let tensor = Int32Tensor::new(buf, Some(vec![2, 8]), None, None); assert_eq!(16, tensor.size()); assert_eq!(Some(vec![2_i64, 8]).as_ref(), tensor.shape()); assert_eq!(None, tensor.strides()); @@ -304,12 +291,12 @@ mod tests { #[test] fn test_new_row_major() { - let mut builder = BufferBuilder::::new(16); + let mut builder = Int32BufferBuilder::new(16); for i in 0..16 { builder.push(i).unwrap(); } let buf = builder.finish(); - let tensor = Tensor::::new_row_major(buf, Some(vec![2, 8]), None); + let tensor = Int32Tensor::new_row_major(buf, Some(vec![2, 8]), None); assert_eq!(16, tensor.size()); assert_eq!(Some(vec![2_i64, 8]).as_ref(), tensor.shape()); assert_eq!(Some(vec![32_i64, 4]).as_ref(), tensor.strides()); @@ -322,12 +309,12 @@ mod tests { #[test] fn test_new_column_major() { - let mut builder = BufferBuilder::::new(16); + let mut builder = Int32BufferBuilder::new(16); for i in 0..16 { builder.push(i).unwrap(); } let buf = builder.finish(); - let tensor = Tensor::::new_column_major(buf, Some(vec![2, 8]), None); + let tensor = Int32Tensor::new_column_major(buf, Some(vec![2, 8]), None); assert_eq!(16, tensor.size()); assert_eq!(Some(vec![2_i64, 8]).as_ref(), tensor.shape()); assert_eq!(Some(vec![4_i64, 8]).as_ref(), tensor.strides()); @@ -340,13 +327,13 @@ mod tests { #[test] fn test_with_names() { - let mut builder = BufferBuilder::::new(8); + let mut builder = Int64BufferBuilder::new(8); for i in 0..8 { builder.push(i).unwrap(); } let buf = builder.finish(); let names = vec!["Dim 1", "Dim 2"]; - let tensor = Tensor::::new_column_major(buf, Some(vec![2, 4]), Some(names)); + let tensor = Int64Tensor::new_column_major(buf, Some(vec![2, 4]), Some(names)); assert_eq!(8, tensor.size()); assert_eq!(Some(vec![2_i64, 4]).as_ref(), tensor.shape()); assert_eq!(Some(vec![8_i64, 16]).as_ref(), tensor.strides()); @@ -361,23 +348,23 @@ mod tests { #[test] #[should_panic(expected = "shape and stride dimensions differ")] fn test_inconsistent_strides() { - let mut builder = BufferBuilder::::new(16); + let mut builder = Int32BufferBuilder::new(16); for i in 0..16 { builder.push(i).unwrap(); } let buf = builder.finish(); - Tensor::::new(buf, Some(vec![2, 8]), Some(vec![2, 8, 1]), None); + Int32Tensor::new(buf, Some(vec![2, 8]), Some(vec![2, 8, 1]), None); } #[test] #[should_panic(expected = "number of dimensions and number of dimension names differ")] fn test_inconsistent_names() { - let mut builder = BufferBuilder::::new(16); + let mut builder = Int32BufferBuilder::new(16); for i in 0..16 { builder.push(i).unwrap(); } let buf = builder.finish(); - Tensor::::new( + Int32Tensor::new( buf, Some(vec![2, 8]), Some(vec![4, 8]),