Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Impl. const. eval. for yet more trivial component-wise numeric built-ins in WGSL #5098

Merged
merged 20 commits into from
Jan 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,26 @@ Bottom level categories:
- `step`
- `tan`
- `tanh`
- [#5098](https://github.com/gfx-rs/wgpu/pull/5098) by @ErichDonGubler:
- `ceil`
- `countLeadingZeros`
- `countOneBits`
- `countTrailingZeros`
- `degrees`
- `exp`
- `exp2`
- `floor`
- `fract`
- `fma`
- `inverseSqrt`
- `log`
- `log2`
- `max`
- `min`
- `radians`
- `reverseBits`
- `sign`
- `trunc`
- Eager release of GPU resources comes from device.trackers. By @bradwerth in [#5075](https://github.com/gfx-rs/wgpu/pull/5075)


Expand Down
193 changes: 161 additions & 32 deletions naga/src/proc/constant_evaluator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,34 @@ gen_component_wise_extractor! {
],
}

// Declares the `component_wise_concrete_int` extractor together with its `ConcreteInt`
// literal grouping. Only the concrete 32-bit integer literals (`U32`, `I32`) are listed;
// `AbstractInt` and the float literals are intentionally absent from this invocation.
//
// The extractor is invoked as `component_wise_concrete_int!(...)` by the bit-manipulation
// built-ins (`CountTrailingZeros`, `CountLeadingZeros`, `CountOneBits`, `ReverseBits`).
//
// NOTE(review): the definition of `gen_component_wise_extractor!` is not visible here;
// presumably it emits an enum, a function, and a same-named helper macro — confirm against
// the macro definition.
gen_component_wise_extractor! {
component_wise_concrete_int -> ConcreteInt,
// Maps each accepted literal variant to the generated variant and its Rust primitive type.
literals: [
U32 => U32: u32,
I32 => I32: i32,
],
// Scalar kinds this extractor accepts.
scalar_kinds: [
Sint,
Uint,
],
}

// Declares the `component_wise_signed` extractor together with its `Signed` literal
// grouping. It lists every literal type that can carry a sign: abstract and concrete
// floats, and abstract and concrete signed integers. `U32` is not listed.
//
// The extractor is invoked as `component_wise_signed!(...)` by the `Sign` built-in, which
// calls `signum()` on each component.
//
// NOTE(review): Rust's floating-point `signum` returns `1.0` for `+0.0` and `-1.0` for
// `-0.0`, while WGSL's `sign` is specified to return `0` for a zero input — confirm that the
// `Sign` arm produces the WGSL result for zero-valued float components.
//
// NOTE(review): the definition of `gen_component_wise_extractor!` is not visible here;
// presumably it emits an enum, a function, and a same-named helper macro — confirm against
// the macro definition.
gen_component_wise_extractor! {
component_wise_signed -> Signed,
// Maps each accepted literal variant to the generated variant and its Rust primitive type.
literals: [
AbstractFloat => AbstractFloat: f64,
AbstractInt => AbstractInt: i64,
F32 => F32: f32,
I32 => I32: i32,
],
// Scalar kinds this extractor accepts.
scalar_kinds: [
Sint,
AbstractInt,
Float,
AbstractFloat,
],
}

#[derive(Debug)]
enum Behavior {
Wgsl,
Expand Down Expand Up @@ -809,7 +837,9 @@ impl<'a> ConstantEvaluator<'a> {
));
}

// NOTE: We try to match the declaration order of `MathFunction` here.
match fun {
// comparison
crate::MathFunction::Abs => {
component_wise_scalar(self, span, [arg], |args| match args {
Scalar::AbstractFloat([e]) => Ok(Scalar::AbstractFloat([e.abs()])),
Expand All @@ -819,27 +849,14 @@ impl<'a> ConstantEvaluator<'a> {
Scalar::U32([e]) => Ok(Scalar::U32([e])), // TODO: just re-use the expression, ezpz
})
}
crate::MathFunction::Acos => {
component_wise_float!(self, span, [arg], |e| { Ok([e.acos()]) })
}
crate::MathFunction::Acosh => {
component_wise_float!(self, span, [arg], |e| { Ok([e.acosh()]) })
}
crate::MathFunction::Asin => {
component_wise_float!(self, span, [arg], |e| { Ok([e.asin()]) })
}
crate::MathFunction::Asinh => {
component_wise_float!(self, span, [arg], |e| { Ok([e.asinh()]) })
}
crate::MathFunction::Atan => {
component_wise_float!(self, span, [arg], |e| { Ok([e.atan()]) })
}
crate::MathFunction::Atanh => {
component_wise_float!(self, span, [arg], |e| { Ok([e.atanh()]) })
crate::MathFunction::Min => {
component_wise_scalar!(self, span, [arg, arg1.unwrap()], |e1, e2| {
Ok([e1.min(e2)])
})
}
crate::MathFunction::Pow => {
component_wise_float!(self, span, [arg, arg1.unwrap()], |e1, e2| {
Ok([e1.powf(e2)])
crate::MathFunction::Max => {
component_wise_scalar!(self, span, [arg, arg1.unwrap()], |e1, e2| {
Ok([e1.max(e2)])
})
}
crate::MathFunction::Clamp => {
Expand All @@ -856,12 +873,61 @@ impl<'a> ConstantEvaluator<'a> {
}
)
}
crate::MathFunction::Saturate => {
component_wise_float!(self, span, [arg], |e| { Ok([e.clamp(0., 1.)]) })
}

// trigonometry
crate::MathFunction::Cos => {
component_wise_float!(self, span, [arg], |e| { Ok([e.cos()]) })
}
crate::MathFunction::Cosh => {
component_wise_float!(self, span, [arg], |e| { Ok([e.cosh()]) })
}
crate::MathFunction::Sin => {
component_wise_float!(self, span, [arg], |e| { Ok([e.sin()]) })
}
crate::MathFunction::Sinh => {
component_wise_float!(self, span, [arg], |e| { Ok([e.sinh()]) })
}
crate::MathFunction::Tan => {
component_wise_float!(self, span, [arg], |e| { Ok([e.tan()]) })
}
crate::MathFunction::Tanh => {
component_wise_float!(self, span, [arg], |e| { Ok([e.tanh()]) })
}
crate::MathFunction::Acos => {
component_wise_float!(self, span, [arg], |e| { Ok([e.acos()]) })
}
crate::MathFunction::Asin => {
component_wise_float!(self, span, [arg], |e| { Ok([e.asin()]) })
}
crate::MathFunction::Atan => {
component_wise_float!(self, span, [arg], |e| { Ok([e.atan()]) })
}
crate::MathFunction::Asinh => {
component_wise_float!(self, span, [arg], |e| { Ok([e.asinh()]) })
}
crate::MathFunction::Acosh => {
component_wise_float!(self, span, [arg], |e| { Ok([e.acosh()]) })
}
crate::MathFunction::Atanh => {
component_wise_float!(self, span, [arg], |e| { Ok([e.atanh()]) })
}
crate::MathFunction::Radians => {
component_wise_float!(self, span, [arg], |e1| { Ok([e1.to_radians()]) })
}
crate::MathFunction::Degrees => {
component_wise_float!(self, span, [arg], |e| { Ok([e.to_degrees()]) })
}

// decomposition
crate::MathFunction::Ceil => {
component_wise_float!(self, span, [arg], |e| { Ok([e.ceil()]) })
}
crate::MathFunction::Floor => {
component_wise_float!(self, span, [arg], |e| { Ok([e.floor()]) })
}
crate::MathFunction::Round => {
// TODO: Use `f{32,64}.round_ties_even()` when available on stable. This polyfill
// is shamelessly [~~stolen from~~ inspired by `ndarray-image`][polyfill source],
Expand All @@ -888,29 +954,92 @@ impl<'a> ConstantEvaluator<'a> {
Float::F32([e]) => Ok(Float::F32([(round_ties_even(e as f64) as f32)])),
})
}
crate::MathFunction::Saturate => {
component_wise_float!(self, span, [arg], |e| { Ok([e.clamp(0., 1.)]) })
crate::MathFunction::Fract => {
component_wise_float!(self, span, [arg], |e| {
// N.B., Rust's `fract` is defined as `e - e.trunc()`, which differs from WGSL's
// `e - floor(e)` for negative inputs, so we can't use it here.
Ok([e - e.floor()])
})
}
crate::MathFunction::Sin => {
component_wise_float!(self, span, [arg], |e| { Ok([e.sin()]) })
crate::MathFunction::Trunc => {
component_wise_float!(self, span, [arg], |e| { Ok([e.trunc()]) })
}
crate::MathFunction::Sinh => {
component_wise_float!(self, span, [arg], |e| { Ok([e.sinh()]) })

// exponent
crate::MathFunction::Exp => {
component_wise_float!(self, span, [arg], |e| { Ok([e.exp()]) })
}
crate::MathFunction::Tan => {
component_wise_float!(self, span, [arg], |e| { Ok([e.tan()]) })
crate::MathFunction::Exp2 => {
component_wise_float!(self, span, [arg], |e| { Ok([e.exp2()]) })
}
crate::MathFunction::Tanh => {
component_wise_float!(self, span, [arg], |e| { Ok([e.tanh()]) })
crate::MathFunction::Log => {
component_wise_float!(self, span, [arg], |e| { Ok([e.ln()]) })
}
crate::MathFunction::Sqrt => {
component_wise_float!(self, span, [arg], |e| { Ok([e.sqrt()]) })
crate::MathFunction::Log2 => {
component_wise_float!(self, span, [arg], |e| { Ok([e.log2()]) })
}
crate::MathFunction::Pow => {
component_wise_float!(self, span, [arg, arg1.unwrap()], |e1, e2| {
Ok([e1.powf(e2)])
})
}

// computational
crate::MathFunction::Sign => {
component_wise_signed!(self, span, [arg], |e| { Ok([e.signum()]) })
}
crate::MathFunction::Fma => {
component_wise_float!(
self,
span,
[arg, arg1.unwrap(), arg2.unwrap()],
|e1, e2, e3| { Ok([e1.mul_add(e2, e3)]) }
)
}
crate::MathFunction::Step => {
component_wise_float!(self, span, [arg, arg1.unwrap()], |edge, x| {
Ok([if edge <= x { 1.0 } else { 0.0 }])
})
}
crate::MathFunction::Sqrt => {
component_wise_float!(self, span, [arg], |e| { Ok([e.sqrt()]) })
}
crate::MathFunction::InverseSqrt => {
component_wise_float!(self, span, [arg], |e| { Ok([1. / e.sqrt()]) })
}

// bits
crate::MathFunction::CountTrailingZeros => {
component_wise_concrete_int!(self, span, [arg], |e| {
#[allow(clippy::useless_conversion)]
Ok([e
.trailing_zeros()
.try_into()
.expect("bit count overflowed 32 bits, somehow!?")])
})
}
crate::MathFunction::CountLeadingZeros => {
component_wise_concrete_int!(self, span, [arg], |e| {
#[allow(clippy::useless_conversion)]
Ok([e
.leading_zeros()
.try_into()
.expect("bit count overflowed 32 bits, somehow!?")])
})
}
crate::MathFunction::CountOneBits => {
component_wise_concrete_int!(self, span, [arg], |e| {
#[allow(clippy::useless_conversion)]
Ok([e
.count_ones()
.try_into()
.expect("bit count overflowed 32 bits, somehow!?")])
})
}
crate::MathFunction::ReverseBits => {
component_wise_concrete_int!(self, span, [arg], |e| { Ok([e.reverse_bits()]) })
}

fun => Err(ConstantEvaluatorError::NotImplemented(format!(
"{fun:?} built-in function"
))),
Expand Down
25 changes: 8 additions & 17 deletions naga/tests/out/glsl/math-functions.main.Fragment.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,8 @@ void main() {
vec4 d = radians(v);
vec4 e = clamp(v, vec4(0.0), vec4(1.0));
vec4 g = refract(v, v, 1.0);
int sign_a = sign(-1);
ivec4 sign_b = sign(ivec4(-1));
float sign_c = sign(-1.0);
vec4 sign_d = sign(vec4(-1.0));
ivec4 sign_b = ivec4(-1, -1, -1, -1);
vec4 sign_d = vec4(-1.0, -1.0, -1.0, -1.0);
int const_dot = ( + ivec2(0).x * ivec2(0).x + ivec2(0).y * ivec2(0).y);
uint first_leading_bit_abs = uint(findMSB(0u));
int flb_a = findMSB(-1);
Expand All @@ -75,19 +73,12 @@ void main() {
uint ftb_b = uint(findLSB(1u));
ivec2 ftb_c = findLSB(ivec2(-1));
uvec2 ftb_d = uvec2(findLSB(uvec2(1u)));
uint ctz_a = min(uint(findLSB(0u)), 32u);
int ctz_b = int(min(uint(findLSB(0)), 32u));
uint ctz_c = min(uint(findLSB(4294967295u)), 32u);
int ctz_d = int(min(uint(findLSB(-1)), 32u));
uvec2 ctz_e = min(uvec2(findLSB(uvec2(0u))), uvec2(32u));
ivec2 ctz_f = ivec2(min(uvec2(findLSB(ivec2(0))), uvec2(32u)));
uvec2 ctz_g = min(uvec2(findLSB(uvec2(1u))), uvec2(32u));
ivec2 ctz_h = ivec2(min(uvec2(findLSB(ivec2(1))), uvec2(32u)));
int clz_a = (-1 < 0 ? 0 : 31 - findMSB(-1));
uint clz_b = uint(31 - findMSB(1u));
ivec2 _e67 = ivec2(-1);
ivec2 clz_c = mix(ivec2(31) - findMSB(_e67), ivec2(0), lessThan(_e67, ivec2(0)));
uvec2 clz_d = uvec2(ivec2(31) - findMSB(uvec2(1u)));
uvec2 ctz_e = uvec2(32u, 32u);
ivec2 ctz_f = ivec2(32, 32);
uvec2 ctz_g = uvec2(0u, 0u);
ivec2 ctz_h = ivec2(0, 0);
ivec2 clz_c = ivec2(0, 0);
uvec2 clz_d = uvec2(31u, 31u);
float lde_a = ldexp(1.0, 2);
vec2 lde_b = ldexp(vec2(1.0, 2.0), ivec2(3, 4));
_modf_result_f32_ modf_a = naga_modf(1.5);
Expand Down
3 changes: 1 addition & 2 deletions naga/tests/out/glsl/variations.main.Fragment.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,8 @@ uniform highp samplerCube _group_0_binding_0_fs;

void main_1() {
ivec2 sizeCube = ivec2(0);
float a = 0.0;
float a = 1.0;
sizeCube = ivec2(uvec2(textureSize(_group_0_binding_0_fs, 0).xy));
a = ceil(1.0);
return;
}

Expand Down
25 changes: 8 additions & 17 deletions naga/tests/out/hlsl/math-functions.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,8 @@ void main()
float4 d = radians(v);
float4 e = saturate(v);
float4 g = refract(v, v, 1.0);
int sign_a = sign(-1);
int4 sign_b = sign((-1).xxxx);
float sign_c = sign(-1.0);
float4 sign_d = sign((-1.0).xxxx);
int4 sign_b = int4(-1, -1, -1, -1);
float4 sign_d = float4(-1.0, -1.0, -1.0, -1.0);
int const_dot = dot((int2)0, (int2)0);
uint first_leading_bit_abs = firstbithigh(0u);
int flb_a = asint(firstbithigh(-1));
Expand All @@ -85,19 +83,12 @@ void main()
uint ftb_b = firstbitlow(1u);
int2 ftb_c = asint(firstbitlow((-1).xx));
uint2 ftb_d = firstbitlow((1u).xx);
uint ctz_a = min(32u, firstbitlow(0u));
int ctz_b = asint(min(32u, firstbitlow(0)));
uint ctz_c = min(32u, firstbitlow(4294967295u));
int ctz_d = asint(min(32u, firstbitlow(-1)));
uint2 ctz_e = min((32u).xx, firstbitlow((0u).xx));
int2 ctz_f = asint(min((32u).xx, firstbitlow((0).xx)));
uint2 ctz_g = min((32u).xx, firstbitlow((1u).xx));
int2 ctz_h = asint(min((32u).xx, firstbitlow((1).xx)));
int clz_a = (-1 < 0 ? 0 : 31 - asint(firstbithigh(-1)));
uint clz_b = (31u - firstbithigh(1u));
int2 _expr67 = (-1).xx;
int2 clz_c = (_expr67 < (0).xx ? (0).xx : (31).xx - asint(firstbithigh(_expr67)));
uint2 clz_d = ((31u).xx - firstbithigh((1u).xx));
uint2 ctz_e = uint2(32u, 32u);
int2 ctz_f = int2(32, 32);
uint2 ctz_g = uint2(0u, 0u);
int2 ctz_h = int2(0, 0);
int2 clz_c = int2(0, 0);
uint2 clz_d = uint2(31u, 31u);
float lde_a = ldexp(1.0, 2);
float2 lde_b = ldexp(float2(1.0, 2.0), int2(3, 4));
_modf_result_f32_ modf_a = naga_modf(1.5);
Expand Down
Loading