diff --git a/src/array/null.rs b/src/array/null.rs index 258d7e21..7368a00a 100644 --- a/src/array/null.rs +++ b/src/array/null.rs @@ -1,10 +1,36 @@ use crate::{Array, ArrayIndex, ArrayType}; -use std::iter::{self, Repeat, Take}; +use std::{ + iter::{self, Repeat, Take}, + marker::PhantomData, +}; /// A sequence of nulls. -#[derive(Debug)] -pub struct NullArray { +/// +/// This array type is also used as [ArrayType] when deriving [Array] for types +/// without fields (unit types). The generic `T` is used to provide iterator +/// implementations for array of these unit types. +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub struct NullArray { len: usize, + _ty: PhantomData T>, +} + +impl NullArray { + /// Returns a new NullArray with the given length. + /// + /// This never allocates. + pub fn with_len(len: usize) -> Self { + Self { + len, + _ty: PhantomData, + } + } + + /// Returns the number of elements in the array, also referred to as its + /// length. + pub fn len(&self) -> usize { + self.len + } } impl Array for NullArray { @@ -73,8 +99,11 @@ impl Array for NullArray { } } -impl ArrayIndex for NullArray { - type Output = (); +impl ArrayIndex for NullArray +where + T: Default, +{ + type Output = T; fn index(&self, index: usize) -> Self::Output { #[cold] @@ -87,6 +116,8 @@ impl ArrayIndex for NullArray { if index >= len { assert_failed(index, len); } + + T::default() } } @@ -94,22 +125,73 @@ impl ArrayType for () { type Array = NullArray; } -impl FromIterator<()> for NullArray { +impl FromIterator for NullArray { fn from_iter(iter: I) -> Self where - I: IntoIterator, + I: IntoIterator, { Self { len: iter.into_iter().count(), + _ty: PhantomData, } } } -impl<'a> IntoIterator for &'a NullArray { - type Item = (); - type IntoIter = Take>; +impl<'a, T> IntoIterator for &'a NullArray +where + T: Clone + Default, +{ + type Item = T; + type IntoIter = Take>; fn into_iter(self) -> Self::IntoIter { - iter::repeat(()).take(self.len) + iter::repeat(T::default()).take(self.len) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn from_iter() { + let vec = vec![(); 100]; + let array = vec.into_iter().collect::(); + assert_eq!(array.len(), 100); + assert!(array.is_null(0)); + } + + #[test] + fn into_iter() { + let vec = vec![(); 100]; + let array = vec.iter().copied().collect::(); + assert_eq!(vec, array.into_iter().collect::>()); + } + + #[test] + fn unit_type() { + #[derive(Clone, Default, Debug, PartialEq)] + struct UnitStruct; + + let vec = vec![UnitStruct; 100]; + let array = vec.iter().cloned().collect::>(); + assert_eq!(array.len(), 100); + assert_eq!(vec, array.into_iter().collect::>()); + + #[derive(Clone, Debug, PartialEq)] + enum UnitEnum { + Unit, + } + + impl Default for UnitEnum { + fn default() -> Self { + UnitEnum::Unit + } + } + + let vec = vec![UnitEnum::default(); 100]; + let array = vec.iter().cloned().collect::>(); + assert_eq!(array.len(), 100); + assert_eq!(vec, array.into_iter().collect::>()); } }