Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Typed buffer #16

Merged
merged 3 commits into from
Sep 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 34 additions & 4 deletions crates/rustc_codegen_spirv/src/abi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -936,18 +936,48 @@ fn trans_intrinsic_type<'tcx>(
.err("#[spirv(runtime_array)] type must have size 4"));
}

// We use a generic to indicate the underlying element type.
// The spirv type of it will be generated by querying the type of the first generic.
// We use a generic param to indicate the underlying element type.
// The SPIR-V element type will be generated from the first generic param.
if let Some(elem_ty) = args.types().next() {
let element = cx.layout_of(elem_ty).spirv_type(span, cx);
Ok(SpirvType::RuntimeArray { element }.def(span, cx))
Ok(SpirvType::RuntimeArray {
element: cx.layout_of(elem_ty).spirv_type(span, cx),
}
.def(span, cx))
} else {
Err(cx
.tcx
.dcx()
.err("#[spirv(runtime_array)] type must have a generic element type"))
}
}
IntrinsicType::TypedBuffer => {
if ty.size != Size::from_bytes(4) {
return Err(cx
.tcx
.sess
.dcx()
.err("#[spirv(typed_buffer)] type must have size 4"));
}

// We use a generic param to indicate the underlying data type.
// The SPIR-V data type will be generated from the first generic param.
if let Some(data_ty) = args.types().next() {
// HACK(eddyb) this should be a *pointer* to an "interface block",
// but SPIR-V screwed up and used no explicit indirection for the
// descriptor indexing case, and instead made a `RuntimeArray` of
// `InterfaceBlock`s be an "array of typed buffer resources".
Ok(SpirvType::InterfaceBlock {
inner_type: cx.layout_of(data_ty).spirv_type(span, cx),
}
.def(span, cx))
} else {
Err(cx
.tcx
.sess
.dcx()
.err("#[spirv(typed_buffer)] type must have a generic data type"))
}
}
IntrinsicType::Matrix => {
let span = def_id_for_spirv_type_adt(ty)
.map(|did| cx.tcx.def_span(did))
Expand Down
1 change: 1 addition & 0 deletions crates/rustc_codegen_spirv/src/attr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ pub enum IntrinsicType {
SampledImage,
RayQueryKhr,
RuntimeArray,
TypedBuffer,
Matrix,
}

Expand Down
6 changes: 5 additions & 1 deletion crates/rustc_codegen_spirv/src/builder/spirv_asm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -700,7 +700,11 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> {
};
ty = match cx.lookup_type(ty) {
SpirvType::Array { element, .. }
| SpirvType::RuntimeArray { element } => element,
| SpirvType::RuntimeArray { element }
// HACK(eddyb) this is pretty bad because it's not
// checking that the index is an `OpConstant 0`, but
// there's no other valid choice anyway.
| SpirvType::InterfaceBlock { inner_type: element } => element,

SpirvType::Adt { field_types, .. } => *index_to_usize()
.and_then(|i| field_types.get(i))
Expand Down
171 changes: 98 additions & 73 deletions crates/rustc_codegen_spirv/src/codegen_cx/entry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -495,96 +495,114 @@ impl<'tcx> CodegenCx<'tcx> {
.dcx()
.span_fatal(hir_param.ty_span, "pair type not supported yet")
}
// FIXME(eddyb) should this talk about "typed buffers" instead of "interface blocks"?
// FIXME(eddyb) should we talk about "descriptor indexing" or
// actually use more reasonable terms like "resource arrays"?
let needs_interface_block_and_supports_descriptor_indexing = matches!(
storage_class,
Ok(StorageClass::Uniform | StorageClass::StorageBuffer)
);
let needs_interface_block = needs_interface_block_and_supports_descriptor_indexing
|| storage_class == Ok(StorageClass::PushConstant);
// NOTE(eddyb) `#[spirv(typed_buffer)]` adds `SpirvType::InterfaceBlock`s
// which must bypass the automated ones (i.e. the user is taking control).
let has_explicit_interface_block = needs_interface_block_and_supports_descriptor_indexing
&& {
// Peel off arrays first (used for "descriptor indexing").
let outermost_or_array_element = match self.lookup_type(value_spirv_type) {
SpirvType::Array { element, .. } | SpirvType::RuntimeArray { element } => {
element
}
_ => value_spirv_type,
};
matches!(
self.lookup_type(outermost_or_array_element),
SpirvType::InterfaceBlock { .. }
)
};
let var_ptr_spirv_type;
let (value_ptr, value_len) = match storage_class {
Ok(
StorageClass::PushConstant | StorageClass::Uniform | StorageClass::StorageBuffer,
) => {
let var_spirv_type = SpirvType::InterfaceBlock {
inner_type: value_spirv_type,
}
.def(hir_param.span, self);
var_ptr_spirv_type = self.type_ptr_to(var_spirv_type);

let zero_u32 = self.constant_u32(hir_param.span, 0).def_cx(self);
let value_ptr_spirv_type = self.type_ptr_to(value_spirv_type);
let value_ptr = bx
.emit()
.in_bounds_access_chain(
value_ptr_spirv_type,
None,
var_id.unwrap(),
[zero_u32].iter().cloned(),
)
.unwrap()
.with_type(value_ptr_spirv_type);
let (value_ptr, value_len) = if needs_interface_block && !has_explicit_interface_block {
let var_spirv_type = SpirvType::InterfaceBlock {
inner_type: value_spirv_type,
}
.def(hir_param.span, self);
var_ptr_spirv_type = self.type_ptr_to(var_spirv_type);

let zero_u32 = self.constant_u32(hir_param.span, 0).def_cx(self);
let value_ptr_spirv_type = self.type_ptr_to(value_spirv_type);
let value_ptr = bx
.emit()
.in_bounds_access_chain(
value_ptr_spirv_type,
None,
var_id.unwrap(),
[zero_u32].iter().cloned(),
)
.unwrap()
.with_type(value_ptr_spirv_type);

let value_len = if is_unsized_with_len {
match self.lookup_type(value_spirv_type) {
SpirvType::RuntimeArray { .. } => {}
_ => {
self.tcx.dcx().span_err(
hir_param.ty_span,
"only plain slices are supported as unsized types",
);
}
let value_len = if is_unsized_with_len {
match self.lookup_type(value_spirv_type) {
SpirvType::RuntimeArray { .. } => {}
_ => {
self.tcx.dcx().span_err(
hir_param.ty_span,
"only plain slices are supported as unsized types",
);
}
}

// FIXME(eddyb) shouldn't this be `usize`?
let len_spirv_type = self.type_isize();
let len = bx
.emit()
.array_length(len_spirv_type, None, var_id.unwrap(), 0)
.unwrap();

Some(len.with_type(len_spirv_type))
} else {
if is_unsized {
// It's OK to use a RuntimeArray<u32> and not have a length parameter, but
// it's just nicer ergonomics to use a slice.
self.tcx
.dcx()
.span_warn(hir_param.ty_span, "use &[T] instead of &RuntimeArray<T>");
}
None
};
// FIXME(eddyb) shouldn't this be `usize`?
let len_spirv_type = self.type_isize();
let len = bx
.emit()
.array_length(len_spirv_type, None, var_id.unwrap(), 0)
.unwrap();

(Ok(value_ptr), value_len)
}
Ok(StorageClass::UniformConstant) => {
var_ptr_spirv_type = self.type_ptr_to(value_spirv_type);
Some(len.with_type(len_spirv_type))
} else {
if is_unsized {
// It's OK to use a RuntimeArray<u32> and not have a length parameter, but
// it's just nicer ergonomics to use a slice.
self.tcx
.dcx()
.span_warn(hir_param.ty_span, "use &[T] instead of &RuntimeArray<T>");
}
None
};

(Ok(value_ptr), value_len)
} else {
var_ptr_spirv_type = self.type_ptr_to(value_spirv_type);

// FIXME(eddyb) should we talk about "descriptor indexing" or
// actually use more reasonable terms like "resource arrays"?
let unsized_is_descriptor_indexing =
needs_interface_block_and_supports_descriptor_indexing
|| storage_class == Ok(StorageClass::UniformConstant);
if unsized_is_descriptor_indexing {
match self.lookup_type(value_spirv_type) {
SpirvType::RuntimeArray { .. } => {
if is_unsized_with_len {
self.tcx.dcx().span_err(
hir_param.ty_span,
"uniform_constant must use &RuntimeArray<T>, not &[T]",
"descriptor indexing must use &RuntimeArray<T>, not &[T]",
);
}
}
_ => {
if is_unsized {
self.tcx.dcx().span_err(
hir_param.ty_span,
"only plain slices are supported as unsized types",
"only RuntimeArray is supported, not other unsized types",
);
}
}
}

let value_len = if is_pair {
// We've already emitted an error, fill in a placeholder value
Some(bx.undef(self.type_isize()))
} else {
None
};

(Ok(var_id.unwrap().with_type(var_ptr_spirv_type)), value_len)
}
_ => {
var_ptr_spirv_type = self.type_ptr_to(value_spirv_type);

} else {
// FIXME(eddyb) determine, based on the type, what kind of type
// this is, to narrow it further to e.g. "buffer in a non-buffer
// storage class" or "storage class expects fixed data sizes".
if is_unsized {
self.tcx.dcx().span_fatal(
hir_param.ty_span,
Expand All @@ -597,12 +615,19 @@ impl<'tcx> CodegenCx<'tcx> {
),
);
}

(
var_id.map(|var_id| var_id.with_type(var_ptr_spirv_type)),
None,
)
}

let value_len = if is_pair {
// We've already emitted an error, fill in a placeholder value
Some(bx.undef(self.type_isize()))
} else {
None
};

(
var_id.map(|var_id| var_id.with_type(var_ptr_spirv_type)),
value_len,
)
};

// Compute call argument(s) to match what the Rust entry `fn` expects,
Expand Down
10 changes: 4 additions & 6 deletions crates/rustc_codegen_spirv/src/spirv_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -347,9 +347,8 @@ impl SpirvType<'_> {
| Self::AccelerationStructureKhr
| Self::RayQueryKhr
| Self::Sampler
| Self::SampledImage { .. } => Size::from_bytes(4),

Self::InterfaceBlock { inner_type } => cx.lookup_type(inner_type).sizeof(cx)?,
| Self::SampledImage { .. }
| Self::InterfaceBlock { .. } => Size::from_bytes(4),
};
Some(result)
}
Expand Down Expand Up @@ -377,9 +376,8 @@ impl SpirvType<'_> {
| Self::AccelerationStructureKhr
| Self::RayQueryKhr
| Self::Sampler
| Self::SampledImage { .. } => Align::from_bytes(4).unwrap(),

Self::InterfaceBlock { inner_type } => cx.lookup_type(inner_type).alignof(cx),
| Self::SampledImage { .. }
| Self::InterfaceBlock { .. } => Align::from_bytes(4).unwrap(),
}
}

Expand Down
4 changes: 4 additions & 0 deletions crates/rustc_codegen_spirv/src/symbols.rs
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,10 @@ impl Symbols {
"runtime_array",
SpirvAttribute::IntrinsicType(IntrinsicType::RuntimeArray),
),
(
"typed_buffer",
SpirvAttribute::IntrinsicType(IntrinsicType::TypedBuffer),
),
(
"matrix",
SpirvAttribute::IntrinsicType(IntrinsicType::Matrix),
Expand Down
2 changes: 2 additions & 0 deletions crates/spirv-std/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,13 +107,15 @@ mod runtime_array;
mod sampler;
pub mod scalar;
pub(crate) mod sealed;
mod typed_buffer;
pub mod vector;

pub use self::sampler::Sampler;
pub use crate::macros::Image;
pub use byte_addressable_buffer::ByteAddressableBuffer;
pub use num_traits;
pub use runtime_array::*;
pub use typed_buffer::*;

pub use glam;

Expand Down
Loading