diff --git a/arrow-data/src/data.rs b/arrow-data/src/data.rs index be7e5f86a04..4314b550d68 100644 --- a/arrow-data/src/data.rs +++ b/arrow-data/src/data.rs @@ -28,6 +28,7 @@ use std::mem; use std::ops::Range; use std::sync::Arc; +use crate::data::private::UnsafeFlag; use crate::{equal, validate_binary_view, validate_string_view}; #[inline] @@ -255,6 +256,10 @@ impl ArrayData { buffers: Vec, child_data: Vec, ) -> Self { + let mut skip_validation = UnsafeFlag::new(); + // SAFETY: caller responsible for ensuring data is valid + skip_validation.set(true); + ArrayDataBuilder { data_type, len, @@ -265,8 +270,7 @@ impl ArrayData { buffers, child_data, align_buffers: false, - // SAFETY: caller responsible for ensuring data is valid - skip_validation: true, + skip_validation, } .build() .unwrap() @@ -1779,6 +1783,36 @@ impl PartialEq for ArrayData { } } +mod private { + /// A boolean flag that cannot be mutated outside of unsafe code. + /// + /// Defaults to a value of false. + /// + /// This structure is used to enforce safety in the [`ArrayDataBuilder`] + /// + /// [`ArrayDataBuilder`]: super::ArrayDataBuilder + #[derive(Debug)] + pub struct UnsafeFlag(bool); + + impl UnsafeFlag { + /// Creates a new `UnsafeFlag` with the value set to `false` + #[inline] + pub const fn new() -> Self { + Self(false) + } + + #[inline] + pub unsafe fn set(&mut self, val: bool) { + self.0 = val; + } + + #[inline] + pub fn get(&self) -> bool { + self.0 + } + } +} + /// Builder for [`ArrayData`] type #[derive(Debug)] pub struct ArrayDataBuilder { @@ -1803,7 +1837,7 @@ pub struct ArrayDataBuilder { /// This flag can only be set to true using `unsafe` APIs. However, once true /// subsequent calls to `build()` may result in undefined behavior if the data /// is not valid. - skip_validation: bool, + skip_validation: UnsafeFlag, } impl ArrayDataBuilder { @@ -1820,7 +1854,7 @@ impl ArrayDataBuilder { buffers: vec![], child_data: vec![], align_buffers: false, - skip_validation: false, + skip_validation: UnsafeFlag::new(), } } @@ -1957,7 +1991,7 @@ impl ArrayDataBuilder { } // SAFETY: `skip_validation` is only set to true using `unsafe` APIs - if !skip_validation || cfg!(feature = "force_validate") { + if !skip_validation.get() || cfg!(feature = "force_validate") { data.validate_data()?; } Ok(data) @@ -2003,7 +2037,7 @@ impl ArrayDataBuilder { /// If validation is skipped, the buffers must form a valid Arrow array, /// otherwise undefined behavior will result pub unsafe fn skip_validation(mut self, skip_validation: bool) -> Self { - self.skip_validation = skip_validation; + self.skip_validation.set(skip_validation); self } } @@ -2020,7 +2054,7 @@ impl From for ArrayDataBuilder { null_bit_buffer: None, null_count: None, align_buffers: false, - skip_validation: false, + skip_validation: UnsafeFlag::new(), } } }