Introduce UnsafeFlag to manage disabling ArrayData validation #7027

Merged 2 commits on Feb 6, 2025
Changes from 1 commit
46 changes: 39 additions & 7 deletions arrow-data/src/data.rs
@@ -28,6 +28,7 @@ use std::mem;
 use std::ops::Range;
 use std::sync::Arc;

+use crate::data::private::UnsafeFlag;
 use crate::{equal, validate_binary_view, validate_string_view};

 #[inline]
@@ -255,6 +256,10 @@ impl ArrayData {
         buffers: Vec<Buffer>,
         child_data: Vec<ArrayData>,
     ) -> Self {
+        let mut skip_validation = UnsafeFlag::new();
+        // SAFETY: caller responsible for ensuring data is valid
+        skip_validation.set(true);
+
         ArrayDataBuilder {
             data_type,
             len,
@@ -265,8 +270,7 @@ impl ArrayData {
             buffers,
             child_data,
             align_buffers: false,
-            // SAFETY: caller responsible for ensuring data is valid
-            skip_validation: true,
+            skip_validation,
         }
         .build()
         .unwrap()
@@ -1779,6 +1783,34 @@ impl PartialEq for ArrayData {
     }
 }

+mod private {
+    /// A boolean flag that cannot be mutated outside of unsafe code.
+    ///
+    /// Defaults to a value of false.
+    ///
+    /// This structure is used to enforce safety in the [`ArrayDataBuilder`]
+    #[derive(Debug)]
+    pub struct UnsafeFlag(bool);
+
+    impl UnsafeFlag {
+        /// Creates a new `UnsafeFlag` with the value set to `false`
+        #[inline]
+        pub const fn new() -> Self {
+            Self(false)
+        }
+
+        #[inline]

Review comment (Contributor, Author): The point of this structure is that the compiler ensures the flag can only ever be set to true via unsafe code.

+        pub unsafe fn set(&mut self, val: bool) {
+            self.0 = val;
+        }
+
+        #[inline]
+        pub fn get(&self) -> bool {
+            self.0
+        }
+    }
+}
+
 /// Builder for [`ArrayData`] type
 #[derive(Debug)]
 pub struct ArrayDataBuilder {
@@ -1803,7 +1835,7 @@ pub struct ArrayDataBuilder {
     /// This flag can only be set to true using `unsafe` APIs. However, once true
     /// subsequent calls to `build()` may result in undefined behavior if the data
     /// is not valid.
-    skip_validation: bool,
+    skip_validation: UnsafeFlag,
 }

 impl ArrayDataBuilder {
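
The review comment on `UnsafeFlag::set` can be illustrated outside of arrow-rs. The following is a minimal standalone sketch, not code from this PR: it copies the shape of `UnsafeFlag` from the diff, and `flag_states` is a hypothetical helper added only for illustration. Because the tuple field is private to the `private` module, code elsewhere (even in the same file) cannot write it directly and has to go through the `unsafe` setter.

```rust
mod private {
    /// A boolean flag whose only mutator is an `unsafe fn`,
    /// mirroring the shape of the type introduced in the diff.
    #[derive(Debug)]
    pub struct UnsafeFlag(bool);

    impl UnsafeFlag {
        /// Creates a flag with the value `false`.
        pub const fn new() -> Self {
            Self(false)
        }

        /// # Safety
        /// The caller takes on whatever invariant the flag guards.
        pub unsafe fn set(&mut self, val: bool) {
            self.0 = val;
        }

        /// Reading the flag is always safe.
        pub fn get(&self) -> bool {
            self.0
        }
    }
}

use private::UnsafeFlag;

/// Hypothetical helper (not in arrow-rs): returns the flag value
/// before and after it is set through the `unsafe` setter.
pub fn flag_states() -> (bool, bool) {
    let mut flag = UnsafeFlag::new();
    let before = flag.get();

    // Neither of these compiles from outside `private`:
    // flag.0 = true;   // error: the field is private
    // flag.set(true);  // error: call to unsafe function requires an unsafe block

    // SAFETY: this sketch guards no invariant; a real caller must justify the call.
    unsafe { flag.set(true) };

    (before, flag.get())
}
```

With a plain `bool` field, any code inside the same module could assign `true` without writing `unsafe`; wrapping the flag moves that check from convention to the compiler.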
@@ -1820,7 +1852,7 @@ impl ArrayDataBuilder {
             buffers: vec![],
             child_data: vec![],
             align_buffers: false,
-            skip_validation: false,
+            skip_validation: UnsafeFlag::new(),
         }
     }

@@ -1957,7 +1989,7 @@ impl ArrayDataBuilder {
         }

         // SAFETY: `skip_validation` is only set to true using `unsafe` APIs
-        if !skip_validation || cfg!(feature = "force_validate") {
+        if !skip_validation.get() || cfg!(feature = "force_validate") {
             data.validate_data()?;
         }
         Ok(data)
@@ -2003,7 +2035,7 @@ impl ArrayDataBuilder {
     /// If validation is skipped, the buffers must form a valid Arrow array,
     /// otherwise undefined behavior will result
     pub unsafe fn skip_validation(mut self, skip_validation: bool) -> Self {
-        self.skip_validation = skip_validation;
+        self.skip_validation.set(skip_validation);
         self
     }
 }
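
As a rough sketch of how such a flag can gate validation in a builder, under stated assumptions: `OffsetsBuilder`, its non-decreasing-offsets check, and the `FORCE_VALIDATE` constant (standing in for `cfg!(feature = "force_validate")`) are all hypothetical and not part of arrow-rs. Only the shape of the `build()` condition follows the diff.

```rust
mod private {
    /// Same shape as the `UnsafeFlag` in the diff, repeated so the sketch stands alone.
    #[derive(Debug)]
    pub struct UnsafeFlag(bool);

    impl UnsafeFlag {
        pub const fn new() -> Self {
            Self(false)
        }

        /// # Safety
        /// The caller takes on whatever invariant the flag guards.
        pub unsafe fn set(&mut self, val: bool) {
            self.0 = val;
        }

        pub fn get(&self) -> bool {
            self.0
        }
    }
}

use private::UnsafeFlag;

/// Assumption: stand-in for `cfg!(feature = "force_validate")`, fixed to "disabled".
const FORCE_VALIDATE: bool = false;

/// Hypothetical builder whose offsets must be non-decreasing.
#[derive(Debug)]
pub struct OffsetsBuilder {
    offsets: Vec<i32>,
    skip_validation: UnsafeFlag,
}

impl OffsetsBuilder {
    /// Validation is enabled by default because the flag starts as `false`.
    pub fn new(offsets: Vec<i32>) -> Self {
        Self {
            offsets,
            skip_validation: UnsafeFlag::new(),
        }
    }

    /// # Safety
    /// If validation is skipped, `offsets` must already be non-decreasing.
    pub unsafe fn skip_validation(mut self, skip: bool) -> Self {
        // SAFETY: the caller upholds the contract documented above.
        unsafe { self.skip_validation.set(skip) };
        self
    }

    /// Validates unless the flag was set through the `unsafe` API.
    pub fn build(self) -> Result<Vec<i32>, String> {
        if !self.skip_validation.get() || FORCE_VALIDATE {
            if self.offsets.windows(2).any(|w| w[0] > w[1]) {
                return Err("offsets must be non-decreasing".to_string());
            }
        }
        Ok(self.offsets)
    }
}
```

A safe caller writes `OffsetsBuilder::new(v).build()` and always gets validation. Skipping it requires an explicit `unsafe { builder.skip_validation(true) }` at the call site, which is where the safety argument has to be written down.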
@@ -2020,7 +2052,7 @@ impl From<ArrayData> for ArrayDataBuilder {
             null_bit_buffer: None,
             null_count: None,
             align_buffers: false,
-            skip_validation: false,
+            skip_validation: UnsafeFlag::new(),
         }
     }
 }