Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ByteAddressableBuffer: allow reading from read-only buffer #17

Merged
merged 7 commits into from
Oct 10, 2024
4 changes: 1 addition & 3 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### 🚨BREAKING🚨
- Signed for loops like `for _ in 0..4i32 {}` no longer compile. We recommend switching to unsigned for loops and casting back to signed integers in the meanwhile.

### Changed 🛠
- [PR#17](https://github.com/Rust-GPU/rust-gpu/pull/17) refactor ByteAddressableBuffer to allow reading from read-only buffers
- [PR#14](https://github.com/Rust-GPU/rust-gpu/pull/14) add subgroup intrinsics matching glsl's [`GL_KHR_shader_subgroup`](https://github.com/KhronosGroup/GLSL/blob/main/extensions/khr/GL_KHR_shader_subgroup.txt)
- [PR#13](https://github.com/Rust-GPU/rust-gpu/pull/13) allow cargo features to be passed to the shader crate
- [PR#12](https://github.com/rust-gpu/rust-gpu/pull/12) updated toolchain to `nightly-2024-04-24`
Expand Down
124 changes: 85 additions & 39 deletions crates/spirv-std/src/byte_addressable_buffer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,74 +43,120 @@ unsafe fn buffer_store_intrinsic<T>(
.write(value);
}

/// `ByteAddressableBuffer` is an untyped blob of data, allowing loads and stores of arbitrary
/// basic data types at arbitrary indices. However, all data must be aligned to size 4, each
/// element within the data (e.g. struct fields) must have a size and alignment of a multiple of 4,
/// and the `byte_index` passed to load and store must be a multiple of 4 (`byte_index` will be
/// rounded down to the nearest multiple of 4). So, it's not technically a *byte* addressable
/// buffer, but rather a *word* buffer, but this naming and behavior was inherited from HLSL (where
/// it's UB to pass in an index not a multiple of 4).
/// `ByteAddressableBuffer` is a view to an untyped blob of data, allowing
/// loads and stores of arbitrary basic data types at arbitrary indices.
///
/// # Alignment
/// All data must be aligned to size 4, each element within the data (e.g.
/// struct fields) must have a size and alignment of a multiple of 4, and the
/// `byte_index` passed to load and store must be a multiple of 4. Technically
/// it is not a *byte* addressable buffer, but rather a *word* buffer, but this
/// naming and behavior was inherited from HLSL (where it's UB to pass in an
/// index not a multiple of 4).
///
/// # Safety
/// Using these functions allows reading a different type from the buffer than
/// was originally written (by a previous `store()` or the host API), allowing
/// all sorts of safety guarantees to be bypassed, making it effectively a
/// transmute.
#[repr(transparent)]
pub struct ByteAddressableBuffer<'a> {
pub struct ByteAddressableBuffer<T> {
/// The underlying array of bytes, able to be directly accessed.
pub data: &'a mut [u32],
pub data: T,
}

impl<'a> ByteAddressableBuffer<'a> {
fn bounds_check<T>(data: &[u32], byte_index: u32) {
let sizeof = mem::size_of::<T>() as u32;
if byte_index % 4 != 0 {
panic!("`byte_index` should be a multiple of 4");
}
if byte_index + sizeof > data.len() as u32 {
let last_byte = byte_index + sizeof;
panic!(
"index out of bounds: the len is {} but loading {} bytes at `byte_index` {} reads until {} (exclusive)",
data.len(),
sizeof,
byte_index,
last_byte,
);
}
}

impl<'a> ByteAddressableBuffer<&'a [u32]> {
/// Creates a `ByteAddressableBuffer` from the untyped blob of data.
#[inline]
pub fn new(data: &'a mut [u32]) -> Self {
pub fn from_slice(data: &'a [u32]) -> Self {
Self { data }
}

/// Loads an arbitrary type from the buffer. `byte_index` must be a multiple of 4, otherwise,
/// it will get silently rounded down to the nearest multiple of 4.
/// Loads an arbitrary type from the buffer. `byte_index` must be a
/// multiple of 4.
///
/// # Safety
/// This function allows writing a type to an untyped buffer, then reading a different type
/// from the same buffer, allowing all sorts of safety guarantees to be bypassed (effectively a
/// transmute)
/// See [`Self`].
pub unsafe fn load<T>(&self, byte_index: u32) -> T {
if byte_index + mem::size_of::<T>() as u32 > self.data.len() as u32 {
panic!("Index out of range");
}
bounds_check::<T>(self.data, byte_index);
buffer_load_intrinsic(self.data, byte_index)
}

/// Loads an arbitrary type from the buffer. `byte_index` must be a multiple of 4, otherwise,
/// it will get silently rounded down to the nearest multiple of 4. Bounds checking is not
/// performed.
/// Loads an arbitrary type from the buffer. `byte_index` must be a
/// multiple of 4.
///
/// # Safety
/// This function allows writing a type to an untyped buffer, then reading a different type
/// from the same buffer, allowing all sorts of safety guarantees to be bypassed (effectively a
/// transmute). Additionally, bounds checking is not performed.
/// See [`Self`]. Additionally, bounds or alignment checking is not performed.
pub unsafe fn load_unchecked<T>(&self, byte_index: u32) -> T {
buffer_load_intrinsic(self.data, byte_index)
}
}

impl<'a> ByteAddressableBuffer<&'a mut [u32]> {
/// Creates a `ByteAddressableBuffer` from the untyped blob of data.
#[inline]
pub fn from_mut_slice(data: &'a mut [u32]) -> Self {
Self { data }
}

/// Create a non-mutable `ByteAddressableBuffer` from this mutable one.
#[inline]
pub fn as_ref(&self) -> ByteAddressableBuffer<&[u32]> {
ByteAddressableBuffer { data: self.data }
}

/// Loads an arbitrary type from the buffer. `byte_index` must be a
/// multiple of 4.
///
/// # Safety
/// See [`Self`].
#[inline]
pub unsafe fn load<T>(&self, byte_index: u32) -> T {
self.as_ref().load(byte_index)
}

/// Loads an arbitrary type from the buffer. `byte_index` must be a
/// multiple of 4.
///
/// # Safety
/// See [`Self`]. Additionally, bounds or alignment checking is not performed.
#[inline]
pub unsafe fn load_unchecked<T>(&self, byte_index: u32) -> T {
self.as_ref().load_unchecked(byte_index)
}

/// Stores an arbitrary type int the buffer. `byte_index` must be a multiple of 4, otherwise,
/// it will get silently rounded down to the nearest multiple of 4.
/// Stores an arbitrary type into the buffer. `byte_index` must be a
/// multiple of 4.
///
/// # Safety
/// This function allows writing a type to an untyped buffer, then reading a different type
/// from the same buffer, allowing all sorts of safety guarantees to be bypassed (effectively a
/// transmute)
/// See [`Self`].
pub unsafe fn store<T>(&mut self, byte_index: u32, value: T) {
if byte_index + mem::size_of::<T>() as u32 > self.data.len() as u32 {
panic!("Index out of range");
}
bounds_check::<T>(self.data, byte_index);
buffer_store_intrinsic(self.data, byte_index, value);
}

/// Stores an arbitrary type int the buffer. `byte_index` must be a multiple of 4, otherwise,
/// it will get silently rounded down to the nearest multiple of 4. Bounds checking is not
/// performed.
/// Stores an arbitrary type into the buffer. `byte_index` must be a
/// multiple of 4.
///
/// # Safety
/// This function allows writing a type to an untyped buffer, then reading a different type
/// from the same buffer, allowing all sorts of safety guarantees to be bypassed (effectively a
/// transmute). Additionally, bounds checking is not performed.
/// See [`Self`]. Additionally, bounds or alignment checking is not performed.
pub unsafe fn store_unchecked<T>(&mut self, byte_index: u32, value: T) {
buffer_store_intrinsic(self.data, byte_index, value);
}
Expand Down
15 changes: 13 additions & 2 deletions tests/ui/byte_addressable_buffer/arr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,22 @@ use spirv_std::{glam::Vec4, ByteAddressableBuffer};

#[spirv(fragment)]
pub fn load(
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &[u32],
out: &mut [i32; 4],
) {
unsafe {
let buf = ByteAddressableBuffer::from_slice(buf);
*out = buf.load(5);
}
}

#[spirv(fragment)]
pub fn load_mut(
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &mut [u32],
out: &mut [i32; 4],
) {
unsafe {
let buf = ByteAddressableBuffer::new(buf);
let buf = ByteAddressableBuffer::from_mut_slice(buf);
*out = buf.load(5);
}
}
Expand All @@ -20,7 +31,7 @@ pub fn store(
#[spirv(flat)] val: [i32; 4],
) {
unsafe {
let mut buf = ByteAddressableBuffer::new(buf);
let mut buf = ByteAddressableBuffer::from_mut_slice(buf);
buf.store(5, val);
}
}
15 changes: 13 additions & 2 deletions tests/ui/byte_addressable_buffer/big_struct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,22 @@ pub struct BigStruct {

#[spirv(fragment)]
pub fn load(
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &[u32],
out: &mut BigStruct,
) {
unsafe {
let buf = ByteAddressableBuffer::from_slice(buf);
*out = buf.load(5);
}
}

#[spirv(fragment)]
pub fn load_mut(
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &mut [u32],
out: &mut BigStruct,
) {
unsafe {
let buf = ByteAddressableBuffer::new(buf);
let buf = ByteAddressableBuffer::from_mut_slice(buf);
*out = buf.load(5);
}
}
Expand All @@ -29,7 +40,7 @@ pub fn store(
#[spirv(flat)] val: BigStruct,
) {
unsafe {
let mut buf = ByteAddressableBuffer::new(buf);
let mut buf = ByteAddressableBuffer::from_mut_slice(buf);
buf.store(5, val);
}
}
15 changes: 13 additions & 2 deletions tests/ui/byte_addressable_buffer/complex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,22 @@ pub struct Nesty {

#[spirv(fragment)]
pub fn load(
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &[u32],
out: &mut Nesty,
) {
unsafe {
let buf = ByteAddressableBuffer::from_slice(buf);
*out = buf.load(5);
}
}

#[spirv(fragment)]
pub fn load_mut(
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &mut [u32],
out: &mut Nesty,
) {
unsafe {
let buf = ByteAddressableBuffer::new(buf);
let buf = ByteAddressableBuffer::from_mut_slice(buf);
*out = buf.load(5);
}
}
Expand All @@ -35,7 +46,7 @@ pub fn store(
val: Nesty,
) {
unsafe {
let mut buf = ByteAddressableBuffer::new(buf);
let mut buf = ByteAddressableBuffer::from_mut_slice(buf);
buf.store(5, val);
}
}
15 changes: 13 additions & 2 deletions tests/ui/byte_addressable_buffer/empty_struct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,22 @@ pub struct EmptyStruct {}

#[spirv(fragment)]
pub fn load(
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &[u32],
out: &mut EmptyStruct,
) {
unsafe {
let buf = ByteAddressableBuffer::from_slice(buf);
*out = buf.load(5);
}
}

#[spirv(fragment)]
pub fn load_mut(
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &mut [u32],
out: &mut EmptyStruct,
) {
unsafe {
let buf = ByteAddressableBuffer::new(buf);
let buf = ByteAddressableBuffer::from_mut_slice(buf);
*out = buf.load(5);
}
}
Expand All @@ -20,7 +31,7 @@ pub fn load(
pub fn store(#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &mut [u32]) {
let val = EmptyStruct {};
unsafe {
let mut buf = ByteAddressableBuffer::new(buf);
let mut buf = ByteAddressableBuffer::from_mut_slice(buf);
buf.store(5, val);
}
}
14 changes: 11 additions & 3 deletions tests/ui/byte_addressable_buffer/f32.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,28 @@ use spirv_std::spirv;
use spirv_std::ByteAddressableBuffer;

#[spirv(fragment)]
pub fn load(
pub fn load(#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &[u32], out: &mut f32) {
unsafe {
let buf = ByteAddressableBuffer::from_slice(buf);
*out = buf.load(5);
}
}

#[spirv(fragment)]
pub fn load_mut(
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &mut [u32],
out: &mut f32,
) {
unsafe {
let buf = ByteAddressableBuffer::new(buf);
let buf = ByteAddressableBuffer::from_mut_slice(buf);
*out = buf.load(5);
}
}

#[spirv(fragment)]
pub fn store(#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &mut [u32], val: f32) {
unsafe {
let mut buf = ByteAddressableBuffer::new(buf);
let mut buf = ByteAddressableBuffer::from_mut_slice(buf);
buf.store(5, val);
}
}
15 changes: 13 additions & 2 deletions tests/ui/byte_addressable_buffer/small_struct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,22 @@ pub struct SmallStruct {

#[spirv(fragment)]
pub fn load(
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &[u32],
out: &mut SmallStruct,
) {
unsafe {
let buf = ByteAddressableBuffer::from_slice(buf);
*out = buf.load(5);
}
}

#[spirv(fragment)]
pub fn load_mut(
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &mut [u32],
out: &mut SmallStruct,
) {
unsafe {
let buf = ByteAddressableBuffer::new(buf);
let buf = ByteAddressableBuffer::from_mut_slice(buf);
*out = buf.load(5);
}
}
Expand All @@ -27,7 +38,7 @@ pub fn store(
) {
let val = SmallStruct { a, b };
unsafe {
let mut buf = ByteAddressableBuffer::new(buf);
let mut buf = ByteAddressableBuffer::from_mut_slice(buf);
buf.store(5, val);
}
}
Loading