Skip to content
This repository has been archived by the owner on Jan 29, 2025. It is now read-only.

Validate image expressions #607

Merged
merged 4 commits into from
Mar 24, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,6 @@ pub enum ShaderStage {
#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
#[allow(missing_docs)] // The names are self evident
pub enum StorageClass {
/// Function locals.
Function,
Expand Down
6 changes: 5 additions & 1 deletion src/valid/analyzer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ Figures out the following properties:
- expression reference counts
!*/

use super::{CallError, ExpressionError, FunctionError, ModuleInfo, ValidationFlags};
use super::{CallError, ExpressionError, FunctionError, ModuleInfo, ShaderStages, ValidationFlags};
use crate::{
arena::{Arena, Handle},
proc::{ResolveContext, TypeResolution},
Expand Down Expand Up @@ -164,6 +164,8 @@ impl ExpressionInfo {
pub struct FunctionInfo {
/// Validation flags.
flags: ValidationFlags,
/// Set of shader stages where calling this function is valid.
pub available_stages: ShaderStages,
/// Uniformity characteristics.
pub uniformity: Uniformity,
/// Function may kill the invocation.
Expand Down Expand Up @@ -676,6 +678,7 @@ impl ModuleInfo {
) -> Result<FunctionInfo, FunctionError> {
let mut info = FunctionInfo {
flags,
available_stages: ShaderStages::all(),
uniformity: Uniformity::new(),
may_kill: false,
sampling_set: crate::FastHashSet::default(),
Expand Down Expand Up @@ -779,6 +782,7 @@ fn uniform_control_flow() {

let mut info = FunctionInfo {
flags: ValidationFlags::all(),
available_stages: ShaderStages::all(),
uniformity: Uniformity::new(),
may_kill: false,
sampling_set: crate::FastHashSet::default(),
Expand Down
375 changes: 349 additions & 26 deletions src/valid/expression.rs

Large diffs are not rendered by default.

111 changes: 87 additions & 24 deletions src/valid/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,10 @@ pub enum FunctionError {
pointer: Handle<crate::Expression>,
value: Handle<crate::Expression>,
},
#[error("The image array can't be indexed by {0:?}")]
InvalidArrayIndex(Handle<crate::Expression>),
#[error("The expression {0:?} is currupted")]
InvalidExpression(Handle<crate::Expression>),
#[error("The expression {0:?} is not an image")]
InvalidImage(Handle<crate::Expression>),
#[error("Image store parameters are invalid")]
InvalidImageStore(#[source] ExpressionError),
#[error("Call to {function:?} is invalid")]
InvalidCall {
function: Handle<crate::Function>,
Expand Down Expand Up @@ -120,6 +118,7 @@ struct BlockContext<'a> {
info: &'a FunctionInfo,
expressions: &'a Arena<crate::Expression>,
types: &'a Arena<crate::Type>,
global_vars: &'a Arena<crate::GlobalVariable>,
functions: &'a Arena<crate::Function>,
return_type: Option<Handle<crate::Type>>,
}
Expand All @@ -131,6 +130,7 @@ impl<'a> BlockContext<'a> {
info,
expressions: &fun.expressions,
types: &module.types,
global_vars: &module.global_variables,
functions: &module.functions,
return_type: fun.result.as_ref().map(|fr| fr.ty),
}
Expand All @@ -142,6 +142,7 @@ impl<'a> BlockContext<'a> {
info: self.info,
expressions: self.expressions,
types: self.types,
global_vars: self.global_vars,
functions: self.functions,
return_type: self.return_type,
}
Expand Down Expand Up @@ -329,7 +330,7 @@ impl super::Validator {
let mut current = pointer;
loop {
let _ = context.resolve_type(current)?;
match context.expressions[current] {
match *context.get_expression(current)? {
crate::Expression::Access { base, .. }
| crate::Expression::AccessIndex { base, .. } => current = base,
crate::Expression::LocalVariable(_)
Expand Down Expand Up @@ -368,28 +369,82 @@ impl super::Validator {
}
S::ImageStore {
image,
coordinate: _,
coordinate,
array_index,
value,
} => {
let _expected_coordinate_ty = match *context.get_expression(image)? {
crate::Expression::GlobalVariable(_var_handle) => (), //TODO
_ => return Err(FunctionError::InvalidImage(image)),
};
match *context.resolve_type(value)? {
Ti::Scalar { .. } | Ti::Vector { .. } => {}
//Note: this code uses a lot of `FunctionError::InvalidImageStore`,
// and could probably be refactored.
let var = match *context.get_expression(image)? {
crate::Expression::GlobalVariable(var_handle) => {
&context.global_vars[var_handle]
}
_ => {
return Err(FunctionError::InvalidStoreValue(value));
return Err(FunctionError::InvalidImageStore(
ExpressionError::ExpectedGlobalVariable,
))
}
}
if let Some(expr) = array_index {
match *context.resolve_type(expr)? {
Ti::Scalar {
kind: crate::ScalarKind::Sint,
width: _,
} => (),
_ => return Err(FunctionError::InvalidArrayIndex(expr)),
};

let value_ty = match context.types[var.ty].inner {
Ti::Image {
class,
arrayed,
dim,
} => {
match context
.resolve_type(coordinate)?
.image_storage_coordinates()
{
Some(coord_dim) if coord_dim == dim => {}
_ => {
return Err(FunctionError::InvalidImageStore(
ExpressionError::InvalidImageCoordinateType(
dim, coordinate,
),
))
}
};
if arrayed != array_index.is_some() {
return Err(FunctionError::InvalidImageStore(
ExpressionError::InvalidImageArrayIndex,
));
}
if let Some(expr) = array_index {
match *context.resolve_type(expr)? {
Ti::Scalar {
kind: crate::ScalarKind::Sint,
width: _,
} => {}
_ => {
return Err(FunctionError::InvalidImageStore(
ExpressionError::InvalidImageArrayIndexType(expr),
))
}
}
}
match class {
crate::ImageClass::Storage(format) => crate::TypeInner::Vector {
kind: format.into(),
size: crate::VectorSize::Quad,
width: 4,
},
_ => {
return Err(FunctionError::InvalidImageStore(
ExpressionError::InvalidImageClass(class),
))
}
}
}
_ => {
return Err(FunctionError::InvalidImageStore(
ExpressionError::ExpectedImageType(var.ty),
))
}
};

if *context.resolve_type(value)? != value_ty {
return Err(FunctionError::InvalidStoreValue(value));
}
}
S::Call {
Expand Down Expand Up @@ -453,7 +508,7 @@ impl super::Validator {
module: &crate::Module,
mod_info: &ModuleInfo,
) -> Result<FunctionInfo, FunctionError> {
let info = mod_info.process_function(fun, module, self.flags)?;
let mut info = mod_info.process_function(fun, module, self.flags)?;

for (var_handle, var) in fun.local_variables.iter() {
self.validate_local_var(var, &module.types, &module.constants)
Expand Down Expand Up @@ -482,8 +537,16 @@ impl super::Validator {
self.valid_expression_set.insert(handle.index());
}
if !self.flags.contains(ValidationFlags::EXPRESSIONS) {
if let Err(error) = self.validate_expression(handle, expr, fun, module, &info) {
return Err(FunctionError::Expression { handle, error });
match self.validate_expression(
handle,
expr,
fun,
module,
&info,
&mod_info.functions,
) {
Ok(stages) => info.available_stages &= stages,
Err(error) => return Err(FunctionError::Expression { handle, error }),
}
}
}
Expand Down
14 changes: 13 additions & 1 deletion src/valid/interface.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use super::{
analyzer::{FunctionInfo, GlobalUse},
Disalignment, FunctionError, ModuleInfo, TypeFlags,
Disalignment, FunctionError, ModuleInfo, ShaderStages, TypeFlags,
};
use crate::arena::{Arena, Handle};

Expand Down Expand Up @@ -56,6 +56,8 @@ pub enum EntryPointError {
UnexpectedWorkgroupSize,
#[error("Workgroup size is out of range")]
OutOfRangeWorkgroupSize,
#[error("Uses operations forbidden at this stage")]
ForbiddenStageOperations,
#[error("Global variable {0:?} is used incorrectly as {1:?}")]
InvalidGlobalUsage(Handle<crate::GlobalVariable>, GlobalUse),
#[error("Bindings for {0:?} conflict with other resource")]
Expand Down Expand Up @@ -370,8 +372,18 @@ impl super::Validator {
return Err(EntryPointError::UnexpectedWorkgroupSize);
}

let stage_bit = match ep.stage {
crate::ShaderStage::Vertex => ShaderStages::VERTEX,
crate::ShaderStage::Fragment => ShaderStages::FRAGMENT,
crate::ShaderStage::Compute => ShaderStages::COMPUTE,
};

let info = self.validate_function(&ep.function, module, &mod_info)?;

if !info.available_stages.contains(stage_bit) {
return Err(EntryPointError::ForbiddenStageOperations);
}

self.location_mask.clear();
for (index, fa) in ep.function.arguments.iter().enumerate() {
let ctx = VaryingContext {
Expand Down
31 changes: 31 additions & 0 deletions src/valid/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,17 @@ bitflags::bitflags! {
}
}

bitflags::bitflags! {
/// Validation flags.
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct ShaderStages: u8 {
const VERTEX = 0x1;
const FRAGMENT = 0x2;
const COMPUTE = 0x4;
}
}

#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct ModuleInfo {
Expand Down Expand Up @@ -125,6 +136,26 @@ impl crate::TypeInner {
Self::Array { .. } | Self::Image { .. } | Self::Sampler { .. } => false,
}
}

fn image_storage_coordinates(&self) -> Option<crate::ImageDimension> {
match *self {
Self::Scalar {
kind: crate::ScalarKind::Sint,
..
} => Some(crate::ImageDimension::D1),
Self::Vector {
size: crate::VectorSize::Bi,
kind: crate::ScalarKind::Sint,
..
} => Some(crate::ImageDimension::D2),
Self::Vector {
size: crate::VectorSize::Tri,
kind: crate::ScalarKind::Sint,
..
} => Some(crate::ImageDimension::D3),
_ => None,
}
}
}

impl Validator {
Expand Down
6 changes: 6 additions & 0 deletions tests/out/collatz.info.ron.snap
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ expression: output
flags: (
bits: 7,
),
available_stages: (
bits: 7,
),
uniformity: (
non_uniform_result: Some(5),
requirements: (
Expand Down Expand Up @@ -350,6 +353,9 @@ expression: output
flags: (
bits: 7,
),
available_stages: (
bits: 7,
),
uniformity: (
non_uniform_result: Some(5),
requirements: (
Expand Down
9 changes: 9 additions & 0 deletions tests/out/shadow.info.ron.snap
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ expression: output
flags: (
bits: 7,
),
available_stages: (
bits: 7,
),
uniformity: (
non_uniform_result: Some(44),
requirements: (
Expand Down Expand Up @@ -1006,6 +1009,9 @@ expression: output
flags: (
bits: 7,
),
available_stages: (
bits: 7,
),
uniformity: (
non_uniform_result: Some(44),
requirements: (
Expand Down Expand Up @@ -2634,6 +2640,9 @@ expression: output
flags: (
bits: 7,
),
available_stages: (
bits: 7,
),
uniformity: (
non_uniform_result: Some(44),
requirements: (
Expand Down