Skip to content

Commit

Permalink
toy sin and dot working
Browse files Browse the repository at this point in the history
  • Loading branch information
jimblandy committed Feb 4, 2025
1 parent 6f2345c commit c504147
Show file tree
Hide file tree
Showing 6 changed files with 415 additions and 120 deletions.
204 changes: 144 additions & 60 deletions naga/src/common/wgsl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,146 @@ impl StandardFilterableTriggeringRule {
}
}

impl crate::MathFunction {
pub fn to_wgsl(self) -> &'static str {
use crate::MathFunction as Mf;

match self {
Mf::Abs => "abs",
Mf::Min => "min",
Mf::Max => "max",
Mf::Clamp => "clamp",
Mf::Saturate => "saturate",
Mf::Cos => "cos",
Mf::Cosh => "cosh",
Mf::Sin => "sin",
Mf::Sinh => "sinh",
Mf::Tan => "tan",
Mf::Tanh => "tanh",
Mf::Acos => "acos",
Mf::Asin => "asin",
Mf::Atan => "atan",
Mf::Atan2 => "atan2",
Mf::Asinh => "asinh",
Mf::Acosh => "acosh",
Mf::Atanh => "atanh",
Mf::Radians => "radians",
Mf::Degrees => "degrees",
Mf::Ceil => "ceil",
Mf::Floor => "floor",
Mf::Round => "round",
Mf::Fract => "fract",
Mf::Trunc => "trunc",
Mf::Modf => "modf",
Mf::Frexp => "frexp",
Mf::Ldexp => "ldexp",
Mf::Exp => "exp",
Mf::Exp2 => "exp2",
Mf::Log => "log",
Mf::Log2 => "log2",
Mf::Pow => "pow",
Mf::Dot => "dot",
Mf::Cross => "cross",
Mf::Distance => "distance",
Mf::Length => "length",
Mf::Normalize => "normalize",
Mf::FaceForward => "faceForward",
Mf::Reflect => "reflect",
Mf::Refract => "refract",
Mf::Sign => "sign",
Mf::Fma => "fma",
Mf::Mix => "mix",
Mf::Step => "step",
Mf::SmoothStep => "smoothstep",
Mf::Sqrt => "sqrt",
Mf::InverseSqrt => "inverseSqrt",
Mf::Transpose => "transpose",
Mf::Determinant => "determinant",
Mf::QuantizeToF16 => "quantizeToF16",
Mf::CountTrailingZeros => "countTrailingZeros",
Mf::CountLeadingZeros => "countLeadingZeros",
Mf::CountOneBits => "countOneBits",
Mf::ReverseBits => "reverseBits",
Mf::ExtractBits => "extractBits",
Mf::InsertBits => "insertBits",
Mf::FirstTrailingBit => "firstTrailingBit",
Mf::FirstLeadingBit => "firstLeadingBit",
Mf::Pack4x8snorm => "pack4x8snorm",
Mf::Pack4x8unorm => "pack4x8unorm",
Mf::Pack2x16snorm => "pack2x16snorm",
Mf::Pack2x16unorm => "pack2x16unorm",
Mf::Pack2x16float => "pack2x16float",
Mf::Pack4xI8 => "pack4xI8",
Mf::Pack4xU8 => "pack4xU8",
Mf::Unpack4x8snorm => "unpack4x8snorm",
Mf::Unpack4x8unorm => "unpack4x8unorm",
Mf::Unpack2x16snorm => "unpack2x16snorm",
Mf::Unpack2x16unorm => "unpack2x16unorm",
Mf::Unpack2x16float => "unpack2x16float",
Mf::Unpack4xI8 => "unpack4xI8",
Mf::Unpack4xU8 => "unpack4xU8",
Mf::Inverse => "{matrix inverse}",
Mf::Outer => "{vector outer product}",
}
}
}

impl crate::BuiltIn {
pub fn to_wgsl(self) -> &'static str {
match self {
crate::BuiltIn::Position { invariant: true } => "@position @invariant",
crate::BuiltIn::Position { invariant: false } => "@position",
crate::BuiltIn::ViewIndex => "view_index",
crate::BuiltIn::BaseInstance => "{BaseInstance}",
crate::BuiltIn::BaseVertex => "{BaseVertex}",
crate::BuiltIn::ClipDistance => "{ClipDistance}",
crate::BuiltIn::CullDistance => "{CullDistance}",
crate::BuiltIn::InstanceIndex => "instance_index",
crate::BuiltIn::PointSize => "{PointSize}",
crate::BuiltIn::VertexIndex => "vertex_index",
crate::BuiltIn::DrawID => "{DrawId}",
crate::BuiltIn::FragDepth => "frag_depth",
crate::BuiltIn::PointCoord => "{PointCoord}",
crate::BuiltIn::FrontFacing => "front_facing",
crate::BuiltIn::PrimitiveIndex => "primitive_index",
crate::BuiltIn::SampleIndex => "sample_index",
crate::BuiltIn::SampleMask => "sample_mask",
crate::BuiltIn::GlobalInvocationId => "global_invocation_id",
crate::BuiltIn::LocalInvocationId => "local_invocation_id",
crate::BuiltIn::LocalInvocationIndex => "local_invocation_index",
crate::BuiltIn::WorkGroupId => "workgroup_id",
crate::BuiltIn::WorkGroupSize => "{WorkGroupSize}",
crate::BuiltIn::NumWorkGroups => "num_workgroups",
crate::BuiltIn::NumSubgroups => "num_subgroups",
crate::BuiltIn::SubgroupId => "{SubgroupId}",
crate::BuiltIn::SubgroupSize => "subgroup_size",
crate::BuiltIn::SubgroupInvocationId => "subgroup_invocation_id",
}
}
}

impl crate::Interpolation {
pub fn to_wgsl(self) -> &'static str {
match self {
crate::Interpolation::Perspective => "perspective",
crate::Interpolation::Linear => "linear",
crate::Interpolation::Flat => "flat",
}
}
}

impl crate::Sampling {
pub fn to_wgsl(self) -> &'static str {
match self {
crate::Sampling::Center => "center",
crate::Sampling::Centroid => "centroid",
crate::Sampling::Sample => "sample",
crate::Sampling::First => "first",
crate::Sampling::Either => "either",
}
}
}

pub struct Wgslish<T>(pub T);

impl Display for Wgslish<&crate::TypeInner> {
Expand Down Expand Up @@ -191,7 +331,7 @@ impl Display for Wgslish<crate::AddressSpace> {
crate::AddressSpace::WorkGroup => "workgroup",
crate::AddressSpace::Uniform => "uniform",
crate::AddressSpace::Storage { access } => {
return write!(f, "{access:?}");
return write!(f, "storage, {access:?}");
}
crate::AddressSpace::Handle => "handle",
crate::AddressSpace::PushConstant => "push_constant",
Expand All @@ -203,7 +343,7 @@ impl Display for Wgslish<crate::AddressSpace> {
impl Display for Wgslish<&crate::Binding> {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
match *self.0 {
crate::Binding::BuiltIn(built_in) => Wgslish(built_in).fmt(f),
crate::Binding::BuiltIn(built_in) => f.write_str(built_in.to_wgsl()),
crate::Binding::Location {
location,
second_blend_source,
Expand All @@ -215,69 +355,13 @@ impl Display for Wgslish<&crate::Binding> {
f.write_str(" @second_blend_source")?;
}
if let Some(interpolation) = interpolation {
write!(f, " {}", Wgslish(interpolation))?;
write!(f, " {}", interpolation.to_wgsl())?;
}
if let Some(sampling) = sampling {
write!(f, " {}", Wgslish(sampling))?;
write!(f, " {}", sampling.to_wgsl())?;
}
Ok(())
}
}
}
}

impl Display for Wgslish<crate::BuiltIn> {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
f.write_str(match self.0 {
crate::BuiltIn::Position { invariant: true } => "@position @invariant",
crate::BuiltIn::Position { invariant: false } => "@position",
crate::BuiltIn::ViewIndex => "view_index",
crate::BuiltIn::BaseInstance => "{BaseInstance}",
crate::BuiltIn::BaseVertex => "{BaseVertex}",
crate::BuiltIn::ClipDistance => "{ClipDistance}",
crate::BuiltIn::CullDistance => "{CullDistance}",
crate::BuiltIn::InstanceIndex => "instance_index",
crate::BuiltIn::PointSize => "{PointSize}",
crate::BuiltIn::VertexIndex => "vertex_index",
crate::BuiltIn::DrawID => "{DrawId}",
crate::BuiltIn::FragDepth => "frag_depth",
crate::BuiltIn::PointCoord => "{PointCoord}",
crate::BuiltIn::FrontFacing => "front_facing",
crate::BuiltIn::PrimitiveIndex => "primitive_index",
crate::BuiltIn::SampleIndex => "sample_index",
crate::BuiltIn::SampleMask => "sample_mask",
crate::BuiltIn::GlobalInvocationId => "global_invocation_id",
crate::BuiltIn::LocalInvocationId => "local_invocation_id",
crate::BuiltIn::LocalInvocationIndex => "local_invocation_index",
crate::BuiltIn::WorkGroupId => "workgroup_id",
crate::BuiltIn::WorkGroupSize => "{WorkGroupSize}",
crate::BuiltIn::NumWorkGroups => "num_workgroups",
crate::BuiltIn::NumSubgroups => "num_subgroups",
crate::BuiltIn::SubgroupId => "{SubgroupId}",
crate::BuiltIn::SubgroupSize => "subgroup_size",
crate::BuiltIn::SubgroupInvocationId => "subgroup_invocation_id",
})
}
}

impl Display for Wgslish<crate::Interpolation> {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
f.write_str(match self.0 {
crate::Interpolation::Perspective => "perspective",
crate::Interpolation::Linear => "linear",
crate::Interpolation::Flat => "flat",
})
}
}

impl Display for Wgslish<crate::Sampling> {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
f.write_str(match self.0 {
crate::Sampling::Center => "center",
crate::Sampling::Centroid => "centroid",
crate::Sampling::Sample => "sample",
crate::Sampling::First => "first",
crate::Sampling::Either => "either",
})
}
}
53 changes: 53 additions & 0 deletions naga/src/front/wgsl/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,18 @@ pub(crate) enum Error<'a> {
expected: Range<u32>,
found: u32,
},
WrongArgumentType {
function: String,
call_span: Span,
arg_span: Span,
arg_index: u32,
found: String,
allowed: Vec<String>,
},
AmbiguousCall {
call_span: Span,
alternatives: Vec<String>,
},
FunctionReturnsVoid(Span),
FunctionMustUseUnused(Span),
FunctionMustUseReturnsVoid(Span, Span),
Expand Down Expand Up @@ -809,6 +821,47 @@ impl<'a> Error<'a> {
labels: vec![(span, "wrong number of arguments".into())],
notes: vec![],
},
Error::WrongArgumentType {
ref function,
call_span,
arg_span,
arg_index,
ref found,
ref allowed,
} => {
let message = format!(
"This call to `{function}` cannot accept a value of type `{found}` for argument #{}",
arg_index + 1,
);
let labels = vec![
(call_span, "The arguments to this function call have incorrect types".into()),
(arg_span, format!(
"This argument has type `{found}`",
).into())
];

let mut notes = vec![];
if arg_index > 0 {
notes.push("Given the types of the preceding arguments,".into());
notes.push(format!("the following types are allowed for argument #{}:", arg_index + 1));
} else {
notes.push("The following types are allowed for the first argument:".to_string());
};
notes.extend(allowed.iter().map(|ty| format!("allowed type: {ty}")));

ParseError { message, labels, notes }
},
Error::AmbiguousCall { call_span, ref alternatives } => {
let message = "Function call is ambiguous: more than one overload could apply".into();
let labels = vec![
(call_span, "More than one overload of this function could apply to these arguments".into()),
];
let mut notes = vec![
"All of the following overloads could apply, but no one overload is clearly preferable:".into()
];
notes.extend(alternatives.iter().map(|alt| format!("possible overload: {alt}")));
ParseError { message, labels, notes }
},
Error::FunctionReturnsVoid(span) => ParseError {
message: "function does not return any value".to_string(),
labels: vec![(span, "".into())],
Expand Down
2 changes: 1 addition & 1 deletion naga/src/front/wgsl/lower/conversion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ impl<'source> super::ExpressionContext<'source, '_, '_> {
self.convert_leaf_scalar(expr, expr_span, goal_scalar)
}

/// Try to convert `expr`'s leaf scalar to `goal` using automatic conversions.
/// Try to convert `expr`'s leaf scalar to `goal_scalar` using automatic conversions.
///
/// If no conversions are necessary, return `expr` unchanged.
///
Expand Down
Loading

0 comments on commit c504147

Please sign in to comment.