[naga msl-out] Avoid UB by making all loops bounded.
In MSL output, avoid undefined behavior due to unbounded loops by
adding an unpredictable, never-actually-taken `break` to the bottom of
each loop body, rather than adding an unpredictable,
never-actually-taken branch over each loop.

This will probably have more of a performance impact, because it
affects each iteration of the loop, but unlike branching over the
loop, which leaves infinite loops (and thus undefined behavior) in the
output, this actually ensures that no loop presented to Metal is
unbounded, so that there is no undefined behavior present that the
optimizer could use to make unwelcome inferences.

Fixes gfx-rs#6528.
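
For illustration, the difference between the two schemes can be sketched as a pair of Rust helpers that build the MSL text for a loop. This is a minimal sketch, not Naga's writer: `guard_before_loop` and `break_at_bottom` are hypothetical names, and the real output routes the check through a preprocessor macro with a generated variable name.

```rust
/// Old scheme (hypothetical helper): an unpredictable branch *over* the loop.
/// The `while(true)` itself still has no exit the compiler must respect.
fn guard_before_loop(body: &str) -> String {
    format!(
        "if (volatile bool unpredictable = true; unpredictable)\nwhile(true) {{\n    {}\n}}\n",
        body
    )
}

/// New scheme (hypothetical helper): an unpredictable `break` as the last
/// statement of the loop body, so every loop has a possible exit.
fn break_at_bottom(body: &str) -> String {
    format!(
        "while(true) {{\n    {}\n    {{ volatile bool unpredictable = false; if (unpredictable) break; }}\n}}\n",
        body
    )
}
```

Calling `break_at_bottom("i = i + 1u;")` yields a loop whose final statement is the volatile-guarded `break`, whereas `guard_before_loop` leaves the loop body untouched and only wraps the loop in an `if`.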
jimblandy committed Nov 18, 2024
1 parent e59f003 commit 0b82776
Showing 10 changed files with 98 additions and 79 deletions.
96 changes: 47 additions & 49 deletions naga/src/back/msl/writer.rs
@@ -383,10 +383,10 @@ pub struct Writer<W> {
/// padding inserted **before** them (i.e. between fields at index - 1 and index)
struct_member_pads: FastHashSet<(Handle<crate::Type>, u32)>,

-/// Name of the loop reachability macro.
+/// Name of the force-bounded-loop macro.
///
-/// See `emit_loop_reachable_macro` for details.
-loop_reachable_macro_name: String,
+/// See `emit_force_bounded_loop_macro` for details.
+force_bounded_loop_macro_name: String,
}

impl crate::Scalar {
@@ -682,7 +682,7 @@ impl<W: Write> Writer<W> {
#[cfg(test)]
put_block_stack_pointers: Default::default(),
struct_member_pads: FastHashSet::default(),
-loop_reachable_macro_name: String::default(),
+force_bounded_loop_macro_name: String::default(),
}
}

@@ -693,12 +693,13 @@ impl<W: Write> Writer<W> {
self.out
}

-/// Define a macro to invoke before loops, to defeat MSL infinite loop
-/// reasoning.
+/// Define a macro to invoke at the bottom of each loop body, to
+/// defeat MSL infinite loop reasoning.
///
/// If we haven't done so already, emit the definition of a preprocessor
-/// macro to be invoked before each loop in the generated MSL, to ensure
-/// that the MSL compiler's optimizations do not remove bounds checks.
+/// macro to be invoked at the end of each loop body in the generated MSL,
+/// to ensure that the MSL compiler's optimizations do not remove bounds
+/// checks.
///
/// Only the first call to this function for a given module actually causes
/// the macro definition to be written. Subsequent loops can simply use the
@@ -764,52 +765,51 @@ impl<W: Write> Writer<W> {
/// nicely, after having stolen data from elsewhere in the GPU address
/// space.
///
-/// Ideally, Naga would prevent UB entirely via some means that persuades
-/// the MSL compiler that no loop Naga generates is infinite. One approach
-/// would be to add inline assembly to each loop that is annotated as
-/// potentially branching out of the loop, but which in fact generates no
-/// instructions. Unfortunately, inline assembly is not handled correctly by
-/// some Metal device drivers. Further experimentation hasn't produced a
-/// satisfactory approach.
+/// To avoid UB, Naga must persuade the MSL compiler that no loop Naga
+/// generates is infinite. One approach would be to add inline assembly to
+/// each loop that is annotated as potentially branching out of the loop,
+/// but which in fact generates no instructions. Unfortunately, inline
+/// assembly is not handled correctly by some Metal device drivers.
///
-/// Instead, we accept that the MSL compiler may determine that some loops
-/// are infinite, and focus instead on preventing the range analysis from
-/// being affected. We transform *every* loop into something like this:
+/// Instead, we add the following code to the bottom of every loop:
///
```ignore
-/// if (volatile bool unpredictable = true; unpredictable)
-/// while (true) { }
+/// if (volatile bool unpredictable = false; unpredictable)
+/// break;
```
///
-/// Since the `volatile` qualifier prevents the compiler from assuming that
-/// the `if` condition is true, it cannot be sure the infinite loop is
-/// reached, and thus it cannot assume the entire structure is unreachable.
-/// This prevents the range analysis impact described above.
+/// Although the `if` condition will always be false in any real execution,
+/// the `volatile` qualifier prevents the compiler from assuming this. Thus,
+/// it must assume that the `break` might be reached, and hence that the
+/// loop is not unbounded. This prevents the range analysis impact described
+/// above.
///
/// Unfortunately, what makes this a kludge, not a hack, is that this
/// solution leaves the GPU executing a pointless conditional branch, at
-/// runtime, before each loop. There's no part of the system that has a
-/// global enough view to be sure that `unpredictable` is true, and remove
-/// it from the code.
+/// runtime, in every iteration of the loop. There's no part of the system
+/// that has a global enough view to be sure that `unpredictable` is false,
+/// and remove it from the code. Adding the branch also affects
+/// optimization: for example, it's impossible to unroll this loop. This
+/// transformation has been observed to significantly hurt performance.
///
/// To make our output a bit more legible, we pull the condition out into a
/// preprocessor macro defined at the top of the module.
///
-/// This approach is also used by Chromium WebGPU's Dawn shader compiler, as of
-/// <https://github.com/google/dawn/commit/ffd485c685040edb1e678165dcbf0e841cfa0298>.
-fn emit_loop_reachable_macro(&mut self) -> BackendResult {
-if !self.loop_reachable_macro_name.is_empty() {
+/// This approach is also used by Chromium WebGPU's Dawn shader compiler:
+/// <https://dawn.googlesource.com/dawn/+/a37557db581c2b60fb1cd2c01abdb232927dd961/src/tint/lang/msl/writer/printer/printer.cc#222>
+fn emit_force_bounded_loop_macro(&mut self) -> BackendResult {
+if !self.force_bounded_loop_macro_name.is_empty() {
return Ok(());
}

-self.loop_reachable_macro_name = self.namer.call("LOOP_IS_REACHABLE");
-let loop_reachable_volatile_name = self.namer.call("unpredictable_jump_over_loop");
+self.force_bounded_loop_macro_name = self.namer.call("LOOP_IS_BOUNDED");
+let loop_bounded_volatile_name = self.namer.call("unpredictable_break_from_loop");
writeln!(
self.out,
-"#define {} if (volatile bool {} = true; {})",
-self.loop_reachable_macro_name,
-loop_reachable_volatile_name,
-loop_reachable_volatile_name,
+"#define {} {{ volatile bool {} = false; if ({}) break; }}",
+self.force_bounded_loop_macro_name,
+loop_bounded_volatile_name,
+loop_bounded_volatile_name,
)?;

Ok(())
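
The define-once behavior of `emit_force_bounded_loop_macro` can be sketched in isolation as follows. This is only a sketch under stated assumptions: `MacroState` and its methods are hypothetical stand-ins for the relevant fields of `Writer`, and fixed strings replace the names that the real code obtains from `self.namer.call`.

```rust
/// Hypothetical stand-in for the parts of the MSL `Writer` involved here.
struct MacroState {
    /// Generated MSL text.
    out: String,
    /// Empty means the macro has not been defined for the current module.
    macro_name: String,
}

impl MacroState {
    fn new() -> Self {
        MacroState {
            out: String::new(),
            macro_name: String::new(),
        }
    }

    /// Write the `#define` on first use; later calls are no-ops.
    fn emit_macro(&mut self) {
        if !self.macro_name.is_empty() {
            return;
        }
        // Assumption: fixed names, where the real writer asks its namer.
        self.macro_name = String::from("LOOP_IS_BOUNDED");
        let flag = "unpredictable_break_from_loop";
        self.out.push_str(&format!(
            "#define {} {{ volatile bool {} = false; if ({}) break; }}\n",
            self.macro_name, flag, flag
        ));
    }

    /// Start a fresh module: drop the output and forget the macro name,
    /// so the next call to `emit_macro` writes the definition again.
    fn reset(&mut self) {
        self.out.clear();
        self.macro_name.clear();
    }
}
```

Using an empty name as the "not yet defined" marker keeps the state to a single `String`, which matches how the diff checks `is_empty()` before emitting and calls `clear()` when the writer is reset.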
@@ -3045,15 +3045,10 @@ impl<W: Write> Writer<W> {
ref continuing,
break_if,
} => {
-self.emit_loop_reachable_macro()?;
if !continuing.is_empty() || break_if.is_some() {
let gate_name = self.namer.call("loop_init");
writeln!(self.out, "{level}bool {gate_name} = true;")?;
-writeln!(
-self.out,
-"{level}{} while(true) {{",
-self.loop_reachable_macro_name,
-)?;
+writeln!(self.out, "{level}while(true) {{",)?;
let lif = level.next();
let lcontinuing = lif.next();
writeln!(self.out, "{lif}if (!{gate_name}) {{")?;
@@ -3068,13 +3063,16 @@ impl<W: Write> Writer<W> {
writeln!(self.out, "{lif}}}")?;
writeln!(self.out, "{lif}{gate_name} = false;")?;
} else {
-writeln!(
-self.out,
-"{level}{} while(true) {{",
-self.loop_reachable_macro_name,
-)?;
+writeln!(self.out, "{level}while(true) {{",)?;
}
self.put_block(level.next(), body, context)?;
+self.emit_force_bounded_loop_macro()?;
+writeln!(
+self.out,
+"{}{}",
+level.next(),
+self.force_bounded_loop_macro_name
+)?;
writeln!(self.out, "{level}}}")?;
}
crate::Statement::Break => {
@@ -3553,7 +3551,7 @@ impl<W: Write> Writer<W> {
&[CLAMPED_LOD_LOAD_PREFIX],
&mut self.names,
);
-self.loop_reachable_macro_name.clear();
+self.force_bounded_loop_macro_name.clear();
self.struct_member_pads.clear();

writeln!(
14 changes: 9 additions & 5 deletions naga/tests/out/msl/atomicCompareExchange.msl
@@ -76,9 +76,8 @@ kernel void test_atomic_compare_exchange_i32_(
uint i = 0u;
int old = {};
bool exchanged = {};
-#define LOOP_IS_REACHABLE if (volatile bool unpredictable_jump_over_loop = true; unpredictable_jump_over_loop)
bool loop_init = true;
-LOOP_IS_REACHABLE while(true) {
+while(true) {
if (!loop_init) {
uint _e27 = i;
i = _e27 + 1u;
@@ -94,7 +93,7 @@ kernel void test_atomic_compare_exchange_i32_(
int _e8 = metal::atomic_load_explicit(&arr_i32_.inner[_e6], metal::memory_order_relaxed);
old = _e8;
exchanged = false;
-LOOP_IS_REACHABLE while(true) {
+while(true) {
bool _e12 = exchanged;
if (!(_e12)) {
} else {
@@ -109,8 +108,11 @@ kernel void test_atomic_compare_exchange_i32_(
old = _e23.old_value;
exchanged = _e23.exchanged;
}
+#define LOOP_IS_BOUNDED { volatile bool unpredictable_break_from_loop = false; if (unpredictable_break_from_loop) break; }
+LOOP_IS_BOUNDED
}
}
+LOOP_IS_BOUNDED
}
return;
}
@@ -123,7 +125,7 @@ kernel void test_atomic_compare_exchange_u32_(
uint old_1 = {};
bool exchanged_1 = {};
bool loop_init_1 = true;
-LOOP_IS_REACHABLE while(true) {
+while(true) {
if (!loop_init_1) {
uint _e27 = i_1;
i_1 = _e27 + 1u;
@@ -139,7 +141,7 @@ kernel void test_atomic_compare_exchange_u32_(
uint _e8 = metal::atomic_load_explicit(&arr_u32_.inner[_e6], metal::memory_order_relaxed);
old_1 = _e8;
exchanged_1 = false;
-LOOP_IS_REACHABLE while(true) {
+while(true) {
bool _e12 = exchanged_1;
if (!(_e12)) {
} else {
@@ -154,8 +156,10 @@ kernel void test_atomic_compare_exchange_u32_(
old_1 = _e23.old_value;
exchanged_1 = _e23.exchanged;
}
+LOOP_IS_BOUNDED
}
}
+LOOP_IS_BOUNDED
}
return;
}
5 changes: 3 additions & 2 deletions naga/tests/out/msl/boids.msl
@@ -55,9 +55,8 @@ kernel void main_(
vPos = _e8;
metal::float2 _e14 = particlesSrc.particles[index].vel;
vVel = _e14;
-#define LOOP_IS_REACHABLE if (volatile bool unpredictable_jump_over_loop = true; unpredictable_jump_over_loop)
bool loop_init = true;
-LOOP_IS_REACHABLE while(true) {
+while(true) {
if (!loop_init) {
uint _e91 = i;
i = _e91 + 1u;
@@ -106,6 +105,8 @@ kernel void main_(
int _e88 = cVelCount;
cVelCount = _e88 + 1;
}
+#define LOOP_IS_BOUNDED { volatile bool unpredictable_break_from_loop = false; if (unpredictable_break_from_loop) break; }
+LOOP_IS_BOUNDED
}
int _e94 = cMassCount;
if (_e94 > 0) {
14 changes: 9 additions & 5 deletions naga/tests/out/msl/break-if.msl
@@ -7,15 +7,16 @@ using metal::uint;

void breakIfEmpty(
) {
-#define LOOP_IS_REACHABLE if (volatile bool unpredictable_jump_over_loop = true; unpredictable_jump_over_loop)
bool loop_init = true;
-LOOP_IS_REACHABLE while(true) {
+while(true) {
if (!loop_init) {
if (true) {
break;
}
}
loop_init = false;
+#define LOOP_IS_BOUNDED { volatile bool unpredictable_break_from_loop = false; if (unpredictable_break_from_loop) break; }
+LOOP_IS_BOUNDED
}
return;
}
@@ -26,7 +27,7 @@ void breakIfEmptyBody(
bool b = {};
bool c = {};
bool loop_init_1 = true;
-LOOP_IS_REACHABLE while(true) {
+while(true) {
if (!loop_init_1) {
b = a;
bool _e2 = b;
@@ -37,6 +38,7 @@ void breakIfEmptyBody(
}
}
loop_init_1 = false;
+LOOP_IS_BOUNDED
}
return;
}
@@ -47,7 +49,7 @@ void breakIf(
bool d = {};
bool e = {};
bool loop_init_2 = true;
-LOOP_IS_REACHABLE while(true) {
+while(true) {
if (!loop_init_2) {
bool _e5 = e;
if (a_1 == e) {
@@ -58,6 +60,7 @@ void breakIf(
d = a_1;
bool _e2 = d;
e = a_1 != _e2;
+LOOP_IS_BOUNDED
}
return;
}
@@ -66,7 +69,7 @@ void breakIfSeparateVariable(
) {
uint counter = 0u;
bool loop_init_3 = true;
-LOOP_IS_REACHABLE while(true) {
+while(true) {
if (!loop_init_3) {
uint _e5 = counter;
if (counter == 5u) {
@@ -76,6 +79,7 @@ void breakIfSeparateVariable(
loop_init_3 = false;
uint _e3 = counter;
counter = _e3 + 1u;
+LOOP_IS_BOUNDED
}
return;
}
5 changes: 3 additions & 2 deletions naga/tests/out/msl/collatz.msl
@@ -19,8 +19,7 @@ uint collatz_iterations(
uint n = {};
uint i = 0u;
n = n_base;
-#define LOOP_IS_REACHABLE if (volatile bool unpredictable_jump_over_loop = true; unpredictable_jump_over_loop)
-LOOP_IS_REACHABLE while(true) {
+while(true) {
uint _e4 = n;
if (_e4 > 1u) {
} else {
@@ -38,6 +37,8 @@ uint collatz_iterations(
uint _e20 = i;
i = _e20 + 1u;
}
+#define LOOP_IS_BOUNDED { volatile bool unpredictable_break_from_loop = false; if (unpredictable_break_from_loop) break; }
+LOOP_IS_BOUNDED
}
uint _e23 = i;
return _e23;