diff --git a/crates/wasmi/Cargo.toml b/crates/wasmi/Cargo.toml index 76deaa9bc8..9a7f656095 100644 --- a/crates/wasmi/Cargo.toml +++ b/crates/wasmi/Cargo.toml @@ -23,6 +23,8 @@ spin = { version = "0.9", default-features = false, features = [ ] } smallvec = { version = "1.10.0", features = ["union"] } multi-stash = { version = "0.2.0" } +num-traits = { version = "0.2", default-features = false } +num-derive = "0.4" [dev-dependencies] wat = "1" @@ -33,7 +35,7 @@ criterion = { version = "0.5", default-features = false } [features] default = ["std"] -std = ["wasmi_core/std", "wasmi_arena/std", "wasmparser/std", "spin/std"] +std = ["wasmi_core/std", "wasmi_arena/std", "wasmparser/std", "spin/std", "num-traits/std"] [[bench]] name = "benches" diff --git a/crates/wasmi/src/engine/bytecode/construct.rs b/crates/wasmi/src/engine/bytecode/construct.rs index 05e6334af5..4867c61017 100644 --- a/crates/wasmi/src/engine/bytecode/construct.rs +++ b/crates/wasmi/src/engine/bytecode/construct.rs @@ -167,6 +167,13 @@ macro_rules! constructor_for { }; } +impl Instruction { + /// Creates a new [`Instruction::BranchCmpFallback`]. + pub fn branch_cmp_fallback(lhs: Register, rhs: Register, params: Register) -> Self { + Self::BranchCmpFallback { lhs, rhs, params } + } +} + macro_rules! 
constructor_for_branch_binop { ( $( fn $name:ident() -> Self::$op_code:ident; )* ) => { impl Instruction { diff --git a/crates/wasmi/src/engine/bytecode/mod.rs b/crates/wasmi/src/engine/bytecode/mod.rs index 5413aa526e..a9f8e279e7 100644 --- a/crates/wasmi/src/engine/bytecode/mod.rs +++ b/crates/wasmi/src/engine/bytecode/mod.rs @@ -17,9 +17,11 @@ pub(crate) use self::{ BranchBinOpInstr, BranchBinOpInstrImm, BranchBinOpInstrImm16, + BranchComparator, BranchOffset, BranchOffset16, CallIndirectParams, + ComparatorOffsetParam, DataSegmentIdx, ElementSegmentIdx, FuncIdx, @@ -364,6 +366,32 @@ pub enum Instruction { offset: BranchOffset, }, + /// A fallback instruction for cmp+branch instructions with branch offsets that cannot be 16-bit encoded. + /// + /// # Note + /// + /// This instruction fits in a single instruction word but arguably executes slower than + /// cmp+branch instructions with a 16-bit encoded branch offset. It only ever gets encoded + /// and used whenever a branch offset of a cmp+branch instruction cannot be 16-bit encoded. + BranchCmpFallback { + /// The left-hand side value for the comparison. + lhs: Register, + /// The right-hand side value for the comparison. + /// + /// # Note + /// + /// We allocate constant values as function local constant values and use + /// their register to only require a single fallback instruction variant. + rhs: Register, + /// The register that stores the [`ComparatorOffsetParam`] of this instruction. + /// + /// # Note + /// + /// The [`ComparatorOffsetParam`] is loaded from the register as a `u64` value and + /// decoded into a [`ComparatorOffsetParam`] before accessing its comparator + /// and 32-bit branch offset fields. + params: Register, + }, /// A fused [`Instruction::I32And`] and Wasm branch instruction. BranchI32And(BranchBinOpInstr), /// A fused [`Instruction::I32And`] and Wasm branch instruction.
diff --git a/crates/wasmi/src/engine/bytecode/utils.rs b/crates/wasmi/src/engine/bytecode/utils.rs index 7d9852c5aa..dfaaba9b9b 100644 --- a/crates/wasmi/src/engine/bytecode/utils.rs +++ b/crates/wasmi/src/engine/bytecode/utils.rs @@ -3,6 +3,8 @@ use crate::{ engine::{Instr, TranslationError}, Error, }; +use num_derive::FromPrimitive; +use wasmi_core::UntypedValue; #[cfg(doc)] use super::Instruction; @@ -524,12 +526,6 @@ impl From for BranchOffset { } impl BranchOffset16 { - /// Creates a 16-bit [`BranchOffset16`] from a 32-bit [`BranchOffset`] if possible. - pub fn new(offset: BranchOffset) -> Option { - let offset16 = i16::try_from(offset.to_i32()).ok()?; - Some(Self(offset16)) - } - /// Returns `true` if the [`BranchOffset16`] has been initialized. pub fn is_init(self) -> bool { self.to_i16() != 0 @@ -541,12 +537,14 @@ impl BranchOffset16 { /// /// - If the [`BranchOffset`] have already been initialized. /// - If the given [`BranchOffset`] is not properly initialized. + /// + /// # Errors + /// + /// If `valid_offset` cannot be encoded as 16-bit [`BranchOffset16`]. pub fn init(&mut self, valid_offset: BranchOffset) -> Result<(), Error> { assert!(valid_offset.is_init()); assert!(!self.is_init()); - let Some(valid_offset16) = Self::new(valid_offset) else { - return Err(Error::from(TranslationError::BranchOffsetOutOfBounds)); - }; + let valid_offset16 = Self::try_from(valid_offset)?; *self = valid_offset16; Ok(()) } @@ -751,13 +749,17 @@ impl BranchOffset { /// /// If the resulting [`BranchOffset`] is uninitialized, aka equal to 0. 
pub fn from_src_to_dst(src: Instr, dst: Instr) -> Result { - fn make_err() -> Error { - Error::from(TranslationError::BranchOffsetOutOfBounds) - } let src = i64::from(src.into_u32()); let dst = i64::from(dst.into_u32()); - let offset = dst.checked_sub(src).ok_or_else(make_err)?; - let offset = i32::try_from(offset).map_err(|_| make_err())?; + let Some(offset) = dst.checked_sub(src) else { + // Note: This never needs to be called on backwards branches since they are immediately resolved. + unreachable!( + "offset for forward branches must have `src` be smaller than or equal to `dst`" + ); + }; + let Ok(offset) = i32::try_from(offset) else { + return Err(Error::from(TranslationError::BranchOffsetOutOfBounds)); + }; Ok(Self(offset)) } @@ -822,3 +824,103 @@ impl BlockFuel { u64::from(self.0) } } + +/// Encodes the conditional branch comparator. +#[derive(Debug, Copy, Clone, PartialEq, Eq, FromPrimitive)] +#[repr(u32)] +pub enum BranchComparator { + I32Eq = 0, + I32Ne = 1, + I32LtS = 2, + I32LtU = 3, + I32LeS = 4, + I32LeU = 5, + I32GtS = 6, + I32GtU = 7, + I32GeS = 8, + I32GeU = 9, + + I32And = 10, + I32Or = 11, + I32Xor = 12, + I32AndEqz = 13, + I32OrEqz = 14, + I32XorEqz = 15, + + I64Eq = 16, + I64Ne = 17, + I64LtS = 18, + I64LtU = 19, + I64LeS = 20, + I64LeU = 21, + I64GtS = 22, + I64GtU = 23, + I64GeS = 24, + I64GeU = 25, + + F32Eq = 26, + F32Ne = 27, + F32Lt = 28, + F32Le = 29, + F32Gt = 30, + F32Ge = 31, + + F64Eq = 32, + F64Ne = 33, + F64Lt = 34, + F64Le = 35, + F64Gt = 36, + F64Ge = 37, +} + +/// Encodes the conditional branch comparator and 32-bit offset of the [`Instruction::BranchCmpFallback`]. +/// +/// # Note +/// +/// This type can be converted from and to a `u64` value. +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub struct ComparatorOffsetParam { + /// Encodes the actual binary operator for the conditional branch. + pub cmp: BranchComparator, + /// Encodes the 32-bit branching offset.
+ pub offset: BranchOffset, +} + +impl ComparatorOffsetParam { + /// Creates a new [`ComparatorOffsetParam`]. + pub fn new(cmp: BranchComparator, offset: BranchOffset) -> Self { + Self { cmp, offset } + } + + /// Creates a new [`ComparatorOffsetParam`] from the given `u64` value (comparator in the upper 32 bits, offset in the lower 32 bits). + /// + /// Returns `None` if the upper 32 bits of the `u64` do not encode a valid [`BranchComparator`]. + pub fn from_u64(value: u64) -> Option<Self> { + use num_traits::FromPrimitive as _; + let hi = (value >> 32) as u32; + let lo = (value & 0xFFFF_FFFF) as u32; + let cmp = BranchComparator::from_u32(hi)?; + let offset = BranchOffset::from(lo as i32); + Some(Self { cmp, offset }) + } + + /// Creates a new [`ComparatorOffsetParam`] from the given [`UntypedValue`]. + /// + /// Returns `None` if the [`UntypedValue`] has an invalid encoding. + pub fn from_untyped(value: UntypedValue) -> Option<Self> { + Self::from_u64(u64::from(value)) + } + + /// Converts the [`ComparatorOffsetParam`] into a `u64` value that [`Self::from_u64`] decodes back: comparator in the upper 32 bits, offset bits in the lower 32 bits. + pub fn as_u64(&self) -> u64 { + let hi = self.cmp as u64; + let lo = u64::from(self.offset.to_i32() as u32); // cast via `u32` so negative offsets are not sign-extended into the comparator bits + (hi << 32) | lo + } +} + +impl From<ComparatorOffsetParam> for UntypedValue { + fn from(params: ComparatorOffsetParam) -> Self { + Self::from(params.as_u64()) + } +} diff --git a/crates/wasmi/src/engine/executor/instrs.rs b/crates/wasmi/src/engine/executor/instrs.rs index 15306f6933..9650b10c14 100644 --- a/crates/wasmi/src/engine/executor/instrs.rs +++ b/crates/wasmi/src/engine/executor/instrs.rs @@ -260,6 +260,9 @@ impl<'ctx, 'engine> Executor<'ctx, 'engine> { Instr::BranchTable { index, len_targets } => { self.execute_branch_table(index, len_targets) } + Instr::BranchCmpFallback { lhs, rhs, params } => { + self.execute_branch_cmp_fallback(lhs, rhs, params) + } Instr::BranchI32And(instr) => self.execute_branch_i32_and(instr), Instr::BranchI32AndImm(instr) => self.execute_branch_i32_and_imm(instr), Instr::BranchI32Or(instr) => self.execute_branch_i32_or(instr), diff --git a/crates/wasmi/src/engine/executor/instrs/branch.rs
b/crates/wasmi/src/engine/executor/instrs/branch.rs index 8a256f5d35..50cdf53d1b 100644 --- a/crates/wasmi/src/engine/executor/instrs/branch.rs +++ b/crates/wasmi/src/engine/executor/instrs/branch.rs @@ -2,8 +2,10 @@ use super::Executor; use crate::engine::bytecode::{ BranchBinOpInstr, BranchBinOpInstrImm16, + BranchComparator, BranchOffset, BranchOffset16, + ComparatorOffsetParam, Const16, Const32, Register, @@ -56,10 +58,23 @@ impl<'ctx, 'engine> Executor<'ctx, 'engine> { where T: From, { - let lhs: T = self.get_register_as(instr.lhs); - let rhs: T = self.get_register_as(instr.rhs); + self.execute_branch_binop_raw::(instr.lhs, instr.rhs, instr.offset, f) + } + + /// Executes a generic fused compare and branch instruction with raw inputs. + fn execute_branch_binop_raw( + &mut self, + lhs: Register, + rhs: Register, + offset: impl Into, + f: fn(T, T) -> bool, + ) where + T: From, + { + let lhs: T = self.get_register_as(lhs); + let rhs: T = self.get_register_as(rhs); if f(lhs, rhs) { - return self.branch_to(instr.offset.into()); + return self.branch_to(offset.into()); } self.next_instr() } @@ -78,6 +93,72 @@ impl<'ctx, 'engine> Executor<'ctx, 'engine> { } } +fn cmp_eq(a: T, b: T) -> bool +where + T: PartialEq, +{ + a == b +} + +fn cmp_ne(a: T, b: T) -> bool +where + T: PartialEq, +{ + a != b +} + +fn cmp_lt(a: T, b: T) -> bool +where + T: PartialOrd, +{ + a < b +} + +fn cmp_le(a: T, b: T) -> bool +where + T: PartialOrd, +{ + a <= b +} + +fn cmp_gt(a: T, b: T) -> bool +where + T: PartialOrd, +{ + a > b +} + +fn cmp_ge(a: T, b: T) -> bool +where + T: PartialOrd, +{ + a >= b +} + +fn cmp_i32_and(a: i32, b: i32) -> bool { + (a & b) != 0 +} + +fn cmp_i32_or(a: i32, b: i32) -> bool { + (a | b) != 0 +} + +fn cmp_i32_xor(a: i32, b: i32) -> bool { + (a ^ b) != 0 +} + +fn cmp_i32_and_eqz(a: i32, b: i32) -> bool { + !cmp_i32_and(a, b) +} + +fn cmp_i32_or_eqz(a: i32, b: i32) -> bool { + !cmp_i32_or(a, b) +} + +fn cmp_i32_xor_eqz(a: i32, b: i32) -> bool { + !cmp_i32_xor(a, b) +} 
+ macro_rules! impl_execute_branch_binop { ( $( ($ty:ty, Instruction::$op_name:ident, $fn_name:ident, $op:expr) ),* $(,)? ) => { impl<'ctx, 'engine> Executor<'ctx, 'engine> { @@ -92,47 +173,47 @@ macro_rules! impl_execute_branch_binop { } } impl_execute_branch_binop! { - (i32, Instruction::BranchI32And, execute_branch_i32_and, |a, b| (a & b) != 0), - (i32, Instruction::BranchI32Or, execute_branch_i32_or, |a, b| (a | b) != 0), - (i32, Instruction::BranchI32Xor, execute_branch_i32_xor, |a, b| (a ^ b) != 0), - (i32, Instruction::BranchI32AndEqz, execute_branch_i32_and_eqz, |a, b| (a & b) == 0), - (i32, Instruction::BranchI32OrEqz, execute_branch_i32_or_eqz, |a, b| (a | b) == 0), - (i32, Instruction::BranchI32XorEqz, execute_branch_i32_xor_eqz, |a, b| (a ^ b) == 0), - (i32, Instruction::BranchI32Eq, execute_branch_i32_eq, |a, b| a == b), - (i32, Instruction::BranchI32Ne, execute_branch_i32_ne, |a, b| a != b), - (i32, Instruction::BranchI32LtS, execute_branch_i32_lt_s, |a, b| a < b), - (u32, Instruction::BranchI32LtU, execute_branch_i32_lt_u, |a, b| a < b), - (i32, Instruction::BranchI32LeS, execute_branch_i32_le_s, |a, b| a <= b), - (u32, Instruction::BranchI32LeU, execute_branch_i32_le_u, |a, b| a <= b), - (i32, Instruction::BranchI32GtS, execute_branch_i32_gt_s, |a, b| a > b), - (u32, Instruction::BranchI32GtU, execute_branch_i32_gt_u, |a, b| a > b), - (i32, Instruction::BranchI32GeS, execute_branch_i32_ge_s, |a, b| a >= b), - (u32, Instruction::BranchI32GeU, execute_branch_i32_ge_u, |a, b| a >= b), - - (i64, Instruction::BranchI64Eq, execute_branch_i64_eq, |a, b| a == b), - (i64, Instruction::BranchI64Ne, execute_branch_i64_ne, |a, b| a != b), - (i64, Instruction::BranchI64LtS, execute_branch_i64_lt_s, |a, b| a < b), - (u64, Instruction::BranchI64LtU, execute_branch_i64_lt_u, |a, b| a < b), - (i64, Instruction::BranchI64LeS, execute_branch_i64_le_s, |a, b| a <= b), - (u64, Instruction::BranchI64LeU, execute_branch_i64_le_u, |a, b| a <= b), - (i64, 
Instruction::BranchI64GtS, execute_branch_i64_gt_s, |a, b| a > b), - (u64, Instruction::BranchI64GtU, execute_branch_i64_gt_u, |a, b| a > b), - (i64, Instruction::BranchI64GeS, execute_branch_i64_ge_s, |a, b| a >= b), - (u64, Instruction::BranchI64GeU, execute_branch_i64_ge_u, |a, b| a >= b), - - (f32, Instruction::BranchF32Eq, execute_branch_f32_eq, |a, b| a == b), - (f32, Instruction::BranchF32Ne, execute_branch_f32_ne, |a, b| a != b), - (f32, Instruction::BranchF32Lt, execute_branch_f32_lt, |a, b| a < b), - (f32, Instruction::BranchF32Le, execute_branch_f32_le, |a, b| a <= b), - (f32, Instruction::BranchF32Gt, execute_branch_f32_gt, |a, b| a > b), - (f32, Instruction::BranchF32Ge, execute_branch_f32_ge, |a, b| a >= b), - - (f64, Instruction::BranchF64Eq, execute_branch_f64_eq, |a, b| a == b), - (f64, Instruction::BranchF64Ne, execute_branch_f64_ne, |a, b| a != b), - (f64, Instruction::BranchF64Lt, execute_branch_f64_lt, |a, b| a < b), - (f64, Instruction::BranchF64Le, execute_branch_f64_le, |a, b| a <= b), - (f64, Instruction::BranchF64Gt, execute_branch_f64_gt, |a, b| a > b), - (f64, Instruction::BranchF64Ge, execute_branch_f64_ge, |a, b| a >= b), + (i32, Instruction::BranchI32And, execute_branch_i32_and, cmp_i32_and), + (i32, Instruction::BranchI32Or, execute_branch_i32_or, cmp_i32_or), + (i32, Instruction::BranchI32Xor, execute_branch_i32_xor, cmp_i32_xor), + (i32, Instruction::BranchI32AndEqz, execute_branch_i32_and_eqz, cmp_i32_and_eqz), + (i32, Instruction::BranchI32OrEqz, execute_branch_i32_or_eqz, cmp_i32_or_eqz), + (i32, Instruction::BranchI32XorEqz, execute_branch_i32_xor_eqz, cmp_i32_xor_eqz), + (i32, Instruction::BranchI32Eq, execute_branch_i32_eq, cmp_eq), + (i32, Instruction::BranchI32Ne, execute_branch_i32_ne, cmp_ne), + (i32, Instruction::BranchI32LtS, execute_branch_i32_lt_s, cmp_lt), + (u32, Instruction::BranchI32LtU, execute_branch_i32_lt_u, cmp_lt), + (i32, Instruction::BranchI32LeS, execute_branch_i32_le_s, cmp_le), + (u32, 
Instruction::BranchI32LeU, execute_branch_i32_le_u, cmp_le), + (i32, Instruction::BranchI32GtS, execute_branch_i32_gt_s, cmp_gt), + (u32, Instruction::BranchI32GtU, execute_branch_i32_gt_u, cmp_gt), + (i32, Instruction::BranchI32GeS, execute_branch_i32_ge_s, cmp_ge), + (u32, Instruction::BranchI32GeU, execute_branch_i32_ge_u, cmp_ge), + + (i64, Instruction::BranchI64Eq, execute_branch_i64_eq, cmp_eq), + (i64, Instruction::BranchI64Ne, execute_branch_i64_ne, cmp_ne), + (i64, Instruction::BranchI64LtS, execute_branch_i64_lt_s, cmp_lt), + (u64, Instruction::BranchI64LtU, execute_branch_i64_lt_u, cmp_lt), + (i64, Instruction::BranchI64LeS, execute_branch_i64_le_s, cmp_le), + (u64, Instruction::BranchI64LeU, execute_branch_i64_le_u, cmp_le), + (i64, Instruction::BranchI64GtS, execute_branch_i64_gt_s, cmp_gt), + (u64, Instruction::BranchI64GtU, execute_branch_i64_gt_u, cmp_gt), + (i64, Instruction::BranchI64GeS, execute_branch_i64_ge_s, cmp_ge), + (u64, Instruction::BranchI64GeU, execute_branch_i64_ge_u, cmp_ge), + + (f32, Instruction::BranchF32Eq, execute_branch_f32_eq, cmp_eq), + (f32, Instruction::BranchF32Ne, execute_branch_f32_ne, cmp_ne), + (f32, Instruction::BranchF32Lt, execute_branch_f32_lt, cmp_lt), + (f32, Instruction::BranchF32Le, execute_branch_f32_le, cmp_le), + (f32, Instruction::BranchF32Gt, execute_branch_f32_gt, cmp_gt), + (f32, Instruction::BranchF32Ge, execute_branch_f32_ge, cmp_ge), + + (f64, Instruction::BranchF64Eq, execute_branch_f64_eq, cmp_eq), + (f64, Instruction::BranchF64Ne, execute_branch_f64_ne, cmp_ne), + (f64, Instruction::BranchF64Lt, execute_branch_f64_lt, cmp_lt), + (f64, Instruction::BranchF64Le, execute_branch_f64_le, cmp_le), + (f64, Instruction::BranchF64Gt, execute_branch_f64_gt, cmp_gt), + (f64, Instruction::BranchF64Ge, execute_branch_f64_ge, cmp_ge), } macro_rules! impl_execute_branch_binop_imm { @@ -149,31 +230,83 @@ macro_rules! impl_execute_branch_binop_imm { } } impl_execute_branch_binop_imm! 
{ - (i32, Instruction::BranchI32AndImm, execute_branch_i32_and_imm, |a, b| (a & b) != 0), - (i32, Instruction::BranchI32OrImm, execute_branch_i32_or_imm, |a, b| (a | b) != 0), - (i32, Instruction::BranchI32XorImm, execute_branch_i32_xor_imm, |a, b| (a ^ b) != 0), - (i32, Instruction::BranchI32AndEqzImm, execute_branch_i32_and_eqz_imm, |a, b| (a & b) == 0), - (i32, Instruction::BranchI32OrEqzImm, execute_branch_i32_or_eqz_imm, |a, b| (a | b) == 0), - (i32, Instruction::BranchI32XorEqzImm, execute_branch_i32_xor_eqz_imm, |a, b| (a ^ b) == 0), - (i32, Instruction::BranchI32EqImm, execute_branch_i32_eq_imm, |a, b| a == b), - (i32, Instruction::BranchI32NeImm, execute_branch_i32_ne_imm, |a, b| a != b), - (i32, Instruction::BranchI32LtSImm, execute_branch_i32_lt_s_imm, |a, b| a < b), - (u32, Instruction::BranchI32LtUImm, execute_branch_i32_lt_u_imm, |a, b| a < b), - (i32, Instruction::BranchI32LeSImm, execute_branch_i32_le_s_imm, |a, b| a <= b), - (u32, Instruction::BranchI32LeUImm, execute_branch_i32_le_u_imm, |a, b| a <= b), - (i32, Instruction::BranchI32GtSImm, execute_branch_i32_gt_s_imm, |a, b| a > b), - (u32, Instruction::BranchI32GtUImm, execute_branch_i32_gt_u_imm, |a, b| a > b), - (i32, Instruction::BranchI32GeSImm, execute_branch_i32_ge_s_imm, |a, b| a >= b), - (u32, Instruction::BranchI32GeUImm, execute_branch_i32_ge_u_imm, |a, b| a >= b), - - (i64, Instruction::BranchI64EqImm, execute_branch_i64_eq_imm, |a, b| a == b), - (i64, Instruction::BranchI64NeImm, execute_branch_i64_ne_imm, |a, b| a != b), - (i64, Instruction::BranchI64LtSImm, execute_branch_i64_lt_s_imm, |a, b| a < b), - (u64, Instruction::BranchI64LtUImm, execute_branch_i64_lt_u_imm, |a, b| a < b), - (i64, Instruction::BranchI64LeSImm, execute_branch_i64_le_s_imm, |a, b| a <= b), - (u64, Instruction::BranchI64LeUImm, execute_branch_i64_le_u_imm, |a, b| a <= b), - (i64, Instruction::BranchI64GtSImm, execute_branch_i64_gt_s_imm, |a, b| a > b), - (u64, Instruction::BranchI64GtUImm, 
execute_branch_i64_gt_u_imm, |a, b| a > b), - (i64, Instruction::BranchI64GeSImm, execute_branch_i64_ge_s_imm, |a, b| a >= b), - (u64, Instruction::BranchI64GeUImm, execute_branch_i64_ge_u_imm, |a, b| a >= b), + (i32, Instruction::BranchI32AndImm, execute_branch_i32_and_imm, cmp_i32_and), + (i32, Instruction::BranchI32OrImm, execute_branch_i32_or_imm, cmp_i32_or), + (i32, Instruction::BranchI32XorImm, execute_branch_i32_xor_imm, cmp_i32_xor), + (i32, Instruction::BranchI32AndEqzImm, execute_branch_i32_and_eqz_imm, cmp_i32_and_eqz), + (i32, Instruction::BranchI32OrEqzImm, execute_branch_i32_or_eqz_imm, cmp_i32_or_eqz), + (i32, Instruction::BranchI32XorEqzImm, execute_branch_i32_xor_eqz_imm, cmp_i32_xor_eqz), + (i32, Instruction::BranchI32EqImm, execute_branch_i32_eq_imm, cmp_eq), + (i32, Instruction::BranchI32NeImm, execute_branch_i32_ne_imm, cmp_ne), + (i32, Instruction::BranchI32LtSImm, execute_branch_i32_lt_s_imm, cmp_lt), + (u32, Instruction::BranchI32LtUImm, execute_branch_i32_lt_u_imm, cmp_lt), + (i32, Instruction::BranchI32LeSImm, execute_branch_i32_le_s_imm, cmp_le), + (u32, Instruction::BranchI32LeUImm, execute_branch_i32_le_u_imm, cmp_le), + (i32, Instruction::BranchI32GtSImm, execute_branch_i32_gt_s_imm, cmp_gt), + (u32, Instruction::BranchI32GtUImm, execute_branch_i32_gt_u_imm, cmp_gt), + (i32, Instruction::BranchI32GeSImm, execute_branch_i32_ge_s_imm, cmp_ge), + (u32, Instruction::BranchI32GeUImm, execute_branch_i32_ge_u_imm, cmp_ge), + + (i64, Instruction::BranchI64EqImm, execute_branch_i64_eq_imm, cmp_eq), + (i64, Instruction::BranchI64NeImm, execute_branch_i64_ne_imm, cmp_ne), + (i64, Instruction::BranchI64LtSImm, execute_branch_i64_lt_s_imm, cmp_lt), + (u64, Instruction::BranchI64LtUImm, execute_branch_i64_lt_u_imm, cmp_lt), + (i64, Instruction::BranchI64LeSImm, execute_branch_i64_le_s_imm, cmp_le), + (u64, Instruction::BranchI64LeUImm, execute_branch_i64_le_u_imm, cmp_le), + (i64, Instruction::BranchI64GtSImm, execute_branch_i64_gt_s_imm, cmp_gt), 
+ (u64, Instruction::BranchI64GtUImm, execute_branch_i64_gt_u_imm, cmp_gt), + (i64, Instruction::BranchI64GeSImm, execute_branch_i64_ge_s_imm, cmp_ge), + (u64, Instruction::BranchI64GeUImm, execute_branch_i64_ge_u_imm, cmp_ge), +} + +impl<'ctx, 'engine> Executor<'ctx, 'engine> { + /// Executes an [`Instruction::BranchCmpFallback`]. + pub fn execute_branch_cmp_fallback(&mut self, lhs: Register, rhs: Register, params: Register) { + use BranchComparator as C; + let params = self.get_register(params); + let Some(params) = ComparatorOffsetParam::from_untyped(params) else { + panic!("encountered invalidaly encoded ComparatorOffsetParam: {params:?}") + }; + let offset = params.offset; + match params.cmp { + C::I32Eq => self.execute_branch_binop_raw::(lhs, rhs, offset, cmp_eq), + C::I32Ne => self.execute_branch_binop_raw::(lhs, rhs, offset, cmp_ne), + C::I32LtS => self.execute_branch_binop_raw::(lhs, rhs, offset, cmp_lt), + C::I32LtU => self.execute_branch_binop_raw::(lhs, rhs, offset, cmp_lt), + C::I32LeS => self.execute_branch_binop_raw::(lhs, rhs, offset, cmp_le), + C::I32LeU => self.execute_branch_binop_raw::(lhs, rhs, offset, cmp_le), + C::I32GtS => self.execute_branch_binop_raw::(lhs, rhs, offset, cmp_gt), + C::I32GtU => self.execute_branch_binop_raw::(lhs, rhs, offset, cmp_gt), + C::I32GeS => self.execute_branch_binop_raw::(lhs, rhs, offset, cmp_ge), + C::I32GeU => self.execute_branch_binop_raw::(lhs, rhs, offset, cmp_ge), + C::I32And => self.execute_branch_binop_raw::(lhs, rhs, offset, cmp_i32_and), + C::I32Or => self.execute_branch_binop_raw::(lhs, rhs, offset, cmp_i32_or), + C::I32Xor => self.execute_branch_binop_raw::(lhs, rhs, offset, cmp_i32_xor), + C::I32AndEqz => self.execute_branch_binop_raw::(lhs, rhs, offset, cmp_i32_and_eqz), + C::I32OrEqz => self.execute_branch_binop_raw::(lhs, rhs, offset, cmp_i32_or_eqz), + C::I32XorEqz => self.execute_branch_binop_raw::(lhs, rhs, offset, cmp_i32_xor_eqz), + C::I64Eq => self.execute_branch_binop_raw::(lhs, rhs, 
offset, cmp_eq), + C::I64Ne => self.execute_branch_binop_raw::(lhs, rhs, offset, cmp_ne), + C::I64LtS => self.execute_branch_binop_raw::(lhs, rhs, offset, cmp_lt), + C::I64LtU => self.execute_branch_binop_raw::(lhs, rhs, offset, cmp_lt), + C::I64LeS => self.execute_branch_binop_raw::(lhs, rhs, offset, cmp_le), + C::I64LeU => self.execute_branch_binop_raw::(lhs, rhs, offset, cmp_le), + C::I64GtS => self.execute_branch_binop_raw::(lhs, rhs, offset, cmp_gt), + C::I64GtU => self.execute_branch_binop_raw::(lhs, rhs, offset, cmp_gt), + C::I64GeS => self.execute_branch_binop_raw::(lhs, rhs, offset, cmp_ge), + C::I64GeU => self.execute_branch_binop_raw::(lhs, rhs, offset, cmp_ge), + C::F32Eq => self.execute_branch_binop_raw::(lhs, rhs, offset, cmp_eq), + C::F32Ne => self.execute_branch_binop_raw::(lhs, rhs, offset, cmp_ne), + C::F32Lt => self.execute_branch_binop_raw::(lhs, rhs, offset, cmp_lt), + C::F32Le => self.execute_branch_binop_raw::(lhs, rhs, offset, cmp_le), + C::F32Gt => self.execute_branch_binop_raw::(lhs, rhs, offset, cmp_gt), + C::F32Ge => self.execute_branch_binop_raw::(lhs, rhs, offset, cmp_ge), + C::F64Eq => self.execute_branch_binop_raw::(lhs, rhs, offset, cmp_eq), + C::F64Ne => self.execute_branch_binop_raw::(lhs, rhs, offset, cmp_ne), + C::F64Lt => self.execute_branch_binop_raw::(lhs, rhs, offset, cmp_lt), + C::F64Le => self.execute_branch_binop_raw::(lhs, rhs, offset, cmp_le), + C::F64Gt => self.execute_branch_binop_raw::(lhs, rhs, offset, cmp_gt), + C::F64Ge => self.execute_branch_binop_raw::(lhs, rhs, offset, cmp_ge), + }; + } } diff --git a/crates/wasmi/src/engine/translator/instr_encoder.rs b/crates/wasmi/src/engine/translator/instr_encoder.rs index 5b9aa5ab71..7c41376a91 100644 --- a/crates/wasmi/src/engine/translator/instr_encoder.rs +++ b/crates/wasmi/src/engine/translator/instr_encoder.rs @@ -10,8 +10,10 @@ use crate::{ bytecode::{ BinInstr, BinInstrImm16, + BranchComparator, BranchOffset, BranchOffset16, + ComparatorOffsetParam, Const16, 
Const32, Instruction, @@ -277,6 +279,10 @@ impl InstrEncoder { /// /// Returns an uninitialized [`BranchOffset`] if the `label` cannot yet /// be resolved and defers resolution to later. + /// + /// # Errors + /// + /// If the [`BranchOffset`] cannot be encoded in 32 bits. pub fn try_resolve_label_for( &mut self, label: LabelRef, @@ -290,9 +296,11 @@ impl InstrEncoder { /// # Panics /// /// If this is used before all branching labels have been pinned. - pub fn update_branch_offsets(&mut self) -> Result<(), Error> { + pub fn update_branch_offsets(&mut self, stack: &mut ValueStack) -> Result<(), Error> { for (user, offset) in self.labels.resolved_users() { - self.instrs.get_mut(user).update_branch_offset(offset?)?; + self.instrs + .get_mut(user) + .update_branch_offset(stack, offset?)?; } Ok(()) } @@ -884,18 +892,42 @@ impl InstrEncoder { type BranchCmpConstructor = fn(Register, Register, BranchOffset16) -> Instruction; type BranchCmpImmConstructor = fn(Register, Const16, BranchOffset16) -> Instruction; + /// Create an [`Instruction::BranchCmpFallback`]. + fn make_branch_cmp_fallback( + stack: &mut ValueStack, + cmp: BranchComparator, + lhs: Register, + rhs: Register, + offset: BranchOffset, + ) -> Result { + let params = stack.alloc_const(ComparatorOffsetParam::new(cmp, offset))?; + Ok(Instruction::branch_cmp_fallback(lhs, rhs, params)) + } + /// Encode an unoptimized `branch_eqz` instruction. /// /// This is used as fallback whenever fusing compare and branch instructions is not possible. 
fn encode_branch_eqz_fallback( this: &mut InstrEncoder, + stack: &mut ValueStack, condition: Register, label: LabelRef, ) -> Result<(), Error> { - let offset = this - .try_resolve_label(label) - .and_then(BranchOffset16::try_from)?; - this.push_instr(Instruction::branch_i32_eqz(condition, offset))?; + let offset = this.try_resolve_label(label)?; + let instr = match BranchOffset16::try_from(offset) { + Ok(offset) => Instruction::branch_i32_eqz(condition, offset), + Err(_) => { + let zero = stack.alloc_const(0_i32)?; + make_branch_cmp_fallback( + stack, + BranchComparator::I32Eq, + condition, + zero, + offset, + )? + } + }; + this.push_instr(instr)?; Ok(()) } @@ -908,6 +940,7 @@ impl InstrEncoder { last_instr: Instr, instr: BinInstr, label: LabelRef, + cmp: BranchComparator, make_instr: BranchCmpConstructor, ) -> Result, Error> { if matches!(stack.get_register_space(instr.result), RegisterSpace::Local) { @@ -917,9 +950,11 @@ impl InstrEncoder { return Ok(None); } let offset = this.try_resolve_label_for(label, last_instr)?; - let instr = BranchOffset16::new(offset) - .map(|offset16| make_instr(instr.lhs, instr.rhs, offset16)); - Ok(instr) + let instr = match BranchOffset16::try_from(offset) { + Ok(offset) => make_instr(instr.lhs, instr.rhs, offset), + Err(_) => make_branch_cmp_fallback(stack, cmp, instr.lhs, instr.rhs, offset)?, + }; + Ok(Some(instr)) } /// Create a fused cmp+branch instruction with a 16-bit immediate and wrap it in a `Some`. 
@@ -931,8 +966,12 @@ impl InstrEncoder { last_instr: Instr, instr: BinInstrImm16, label: LabelRef, + cmp: BranchComparator, make_instr: BranchCmpImmConstructor, - ) -> Result, Error> { + ) -> Result, Error> + where + T: From> + Into, + { if matches!(stack.get_register_space(instr.result), RegisterSpace::Local) { // We need to filter out instructions that store their result // into a local register slot because they introduce observable behavior @@ -940,120 +979,86 @@ impl InstrEncoder { return Ok(None); } let offset = this.try_resolve_label_for(label, last_instr)?; - let instr = BranchOffset16::new(offset) - .map(|offset16| make_instr(instr.reg_in, instr.imm_in, offset16)); - Ok(instr) + let instr = match BranchOffset16::try_from(offset) { + Ok(offset) => make_instr(instr.reg_in, instr.imm_in, offset), + Err(_) => { + let rhs = stack.alloc_const(T::from(instr.imm_in))?; + make_branch_cmp_fallback(stack, cmp, instr.reg_in, rhs, offset)? + } + }; + Ok(Some(instr)) } + use BranchComparator as Cmp; use Instruction as I; let Some(last_instr) = self.last_instr else { - return encode_branch_eqz_fallback(self, condition, label); + return encode_branch_eqz_fallback(self, stack, condition, label); }; #[rustfmt::skip] let fused_instr = match *self.instrs.get(last_instr) { - I::I32EqImm16(instr) if instr.imm_in.is_zero() => { - match stack.get_register_space(instr.result) { - RegisterSpace::Local => None, - _ => { - let offset16 = self.try_resolve_label_for(label, last_instr) - .and_then(BranchOffset16::try_from)?; - Some(Instruction::branch_i32_nez(instr.reg_in, offset16)) - } - } - } - I::I64EqImm16(instr) if instr.imm_in.is_zero() => { - match stack.get_register_space(instr.result) { - RegisterSpace::Local => None, - _ => { - let offset16 = self.try_resolve_label_for(label, last_instr) - .and_then(BranchOffset16::try_from)?; - Some(Instruction::branch_i64_nez(instr.reg_in, offset16)) - } - } - } - I::I32NeImm16(instr) if instr.imm_in.is_zero() => { - match 
stack.get_register_space(instr.result) { - RegisterSpace::Local => None, - _ => { - let offset16 = self.try_resolve_label_for(label, last_instr) - .and_then(BranchOffset16::try_from)?; - Some(Instruction::branch_i32_eqz(instr.reg_in, offset16)) - } - } - } - I::I64NeImm16(instr) if instr.imm_in.is_zero() => { - match stack.get_register_space(instr.result) { - RegisterSpace::Local => None, - _ => { - let offset16 = self.try_resolve_label_for(label, last_instr) - .and_then(BranchOffset16::try_from)?; - Some(Instruction::branch_i64_eqz(instr.reg_in, offset16)) - } - } - } - I::I32And(instr) => fuse(self, stack, last_instr, instr, label, I::branch_i32_and_eqz as _)?, - I::I32Or(instr) => fuse(self, stack, last_instr, instr, label, I::branch_i32_or_eqz as _)?, - I::I32Xor(instr) => fuse(self, stack, last_instr, instr, label, I::branch_i32_xor_eqz as _)?, - I::I32AndEqz(instr) => fuse(self, stack, last_instr, instr, label, I::branch_i32_and as _)?, - I::I32OrEqz(instr) => fuse(self, stack, last_instr, instr, label, I::branch_i32_or as _)?, - I::I32XorEqz(instr) => fuse(self, stack, last_instr, instr, label, I::branch_i32_xor as _)?, - I::I32Eq(instr) => fuse(self, stack, last_instr, instr, label, I::branch_i32_ne as _)?, - I::I32Ne(instr) => fuse(self, stack, last_instr, instr, label, I::branch_i32_eq as _)?, - I::I32LtS(instr) => fuse(self, stack, last_instr, instr, label, I::branch_i32_ge_s as _)?, - I::I32LtU(instr) => fuse(self, stack, last_instr, instr, label, I::branch_i32_ge_u as _)?, - I::I32LeS(instr) => fuse(self, stack, last_instr, instr, label, I::branch_i32_gt_s as _)?, - I::I32LeU(instr) => fuse(self, stack, last_instr, instr, label, I::branch_i32_gt_u as _)?, - I::I32GtS(instr) => fuse(self, stack, last_instr, instr, label, I::branch_i32_le_s as _)?, - I::I32GtU(instr) => fuse(self, stack, last_instr, instr, label, I::branch_i32_le_u as _)?, - I::I32GeS(instr) => fuse(self, stack, last_instr, instr, label, I::branch_i32_lt_s as _)?, - I::I32GeU(instr) => 
fuse(self, stack, last_instr, instr, label, I::branch_i32_lt_u as _)?, - I::I64Eq(instr) => fuse(self, stack, last_instr, instr, label, I::branch_i64_ne as _)?, - I::I64Ne(instr) => fuse(self, stack, last_instr, instr, label, I::branch_i64_eq as _)?, - I::I64LtS(instr) => fuse(self, stack, last_instr, instr, label, I::branch_i64_ge_s as _)?, - I::I64LtU(instr) => fuse(self, stack, last_instr, instr, label, I::branch_i64_ge_u as _)?, - I::I64LeS(instr) => fuse(self, stack, last_instr, instr, label, I::branch_i64_gt_s as _)?, - I::I64LeU(instr) => fuse(self, stack, last_instr, instr, label, I::branch_i64_gt_u as _)?, - I::I64GtS(instr) => fuse(self, stack, last_instr, instr, label, I::branch_i64_le_s as _)?, - I::I64GtU(instr) => fuse(self, stack, last_instr, instr, label, I::branch_i64_le_u as _)?, - I::I64GeS(instr) => fuse(self, stack, last_instr, instr, label, I::branch_i64_lt_s as _)?, - I::I64GeU(instr) => fuse(self, stack, last_instr, instr, label, I::branch_i64_lt_u as _)?, - I::F32Eq(instr) => fuse(self, stack, last_instr, instr, label, I::branch_f32_ne as _)?, - I::F32Ne(instr) => fuse(self, stack, last_instr, instr, label, I::branch_f32_eq as _)?, + I::I32And(instr) => fuse(self, stack, last_instr, instr, label, Cmp::I32AndEqz, I::branch_i32_and_eqz as _)?, + I::I32Or(instr) => fuse(self, stack, last_instr, instr, label, Cmp::I32OrEqz, I::branch_i32_or_eqz as _)?, + I::I32Xor(instr) => fuse(self, stack, last_instr, instr, label, Cmp::I32XorEqz, I::branch_i32_xor_eqz as _)?, + I::I32AndEqz(instr) => fuse(self, stack, last_instr, instr, label, Cmp::I32And, I::branch_i32_and as _)?, + I::I32OrEqz(instr) => fuse(self, stack, last_instr, instr, label, Cmp::I32Or, I::branch_i32_or as _)?, + I::I32XorEqz(instr) => fuse(self, stack, last_instr, instr, label, Cmp::I32Xor, I::branch_i32_xor as _)?, + I::I32Eq(instr) => fuse(self, stack, last_instr, instr, label, Cmp::I32Ne, I::branch_i32_ne as _)?, + I::I32Ne(instr) => fuse(self, stack, last_instr, instr, label, 
Cmp::I32Eq, I::branch_i32_eq as _)?, + I::I32LtS(instr) => fuse(self, stack, last_instr, instr, label, Cmp::I32GeS, I::branch_i32_ge_s as _)?, + I::I32LtU(instr) => fuse(self, stack, last_instr, instr, label, Cmp::I32GeU, I::branch_i32_ge_u as _)?, + I::I32LeS(instr) => fuse(self, stack, last_instr, instr, label, Cmp::I32GtS, I::branch_i32_gt_s as _)?, + I::I32LeU(instr) => fuse(self, stack, last_instr, instr, label, Cmp::I32GtU, I::branch_i32_gt_u as _)?, + I::I32GtS(instr) => fuse(self, stack, last_instr, instr, label, Cmp::I32LeS, I::branch_i32_le_s as _)?, + I::I32GtU(instr) => fuse(self, stack, last_instr, instr, label, Cmp::I32LeU, I::branch_i32_le_u as _)?, + I::I32GeS(instr) => fuse(self, stack, last_instr, instr, label, Cmp::I32LtS, I::branch_i32_lt_s as _)?, + I::I32GeU(instr) => fuse(self, stack, last_instr, instr, label, Cmp::I32LtU, I::branch_i32_lt_u as _)?, + I::I64Eq(instr) => fuse(self, stack, last_instr, instr, label, Cmp::I64Ne, I::branch_i64_ne as _)?, + I::I64Ne(instr) => fuse(self, stack, last_instr, instr, label, Cmp::I64Eq, I::branch_i64_eq as _)?, + I::I64LtS(instr) => fuse(self, stack, last_instr, instr, label, Cmp::I64GeS, I::branch_i64_ge_s as _)?, + I::I64LtU(instr) => fuse(self, stack, last_instr, instr, label, Cmp::I64GeU, I::branch_i64_ge_u as _)?, + I::I64LeS(instr) => fuse(self, stack, last_instr, instr, label, Cmp::I64GtS, I::branch_i64_gt_s as _)?, + I::I64LeU(instr) => fuse(self, stack, last_instr, instr, label, Cmp::I64GtU, I::branch_i64_gt_u as _)?, + I::I64GtS(instr) => fuse(self, stack, last_instr, instr, label, Cmp::I64LeS, I::branch_i64_le_s as _)?, + I::I64GtU(instr) => fuse(self, stack, last_instr, instr, label, Cmp::I64LeU, I::branch_i64_le_u as _)?, + I::I64GeS(instr) => fuse(self, stack, last_instr, instr, label, Cmp::I64LtS, I::branch_i64_lt_s as _)?, + I::I64GeU(instr) => fuse(self, stack, last_instr, instr, label, Cmp::I64LtU, I::branch_i64_lt_u as _)?, + I::F32Eq(instr) => fuse(self, stack, last_instr, instr, 
label, Cmp::F32Ne, I::branch_f32_ne as _)?, + I::F32Ne(instr) => fuse(self, stack, last_instr, instr, label, Cmp::F32Eq, I::branch_f32_eq as _)?, // Note: We cannot fuse cmp+branch for float comparison operators due to how NaN values are treated. - I::I32AndImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, I::branch_i32_and_eqz_imm as _)?, - I::I32OrImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, I::branch_i32_or_eqz_imm as _)?, - I::I32XorImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, I::branch_i32_xor_eqz_imm as _)?, - I::I32AndEqzImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, I::branch_i32_and_imm as _)?, - I::I32OrEqzImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, I::branch_i32_or_imm as _)?, - I::I32XorEqzImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, I::branch_i32_xor_imm as _)?, - I::I32EqImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, I::branch_i32_ne_imm as _)?, - I::I32NeImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, I::branch_i32_eq_imm as _)?, - I::I32LtSImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, I::branch_i32_ge_s_imm as _)?, - I::I32LtUImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, I::branch_i32_ge_u_imm as _)?, - I::I32LeSImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, I::branch_i32_gt_s_imm as _)?, - I::I32LeUImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, I::branch_i32_gt_u_imm as _)?, - I::I32GtSImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, I::branch_i32_le_s_imm as _)?, - I::I32GtUImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, I::branch_i32_le_u_imm as _)?, - I::I32GeSImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, I::branch_i32_lt_s_imm as _)?, - I::I32GeUImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, I::branch_i32_lt_u_imm as _)?, - I::I64EqImm16(instr) => fuse_imm(self, stack, 
last_instr, instr, label, I::branch_i64_ne_imm as _)?, - I::I64NeImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, I::branch_i64_eq_imm as _)?, - I::I64LtSImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, I::branch_i64_ge_s_imm as _)?, - I::I64LtUImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, I::branch_i64_ge_u_imm as _)?, - I::I64LeSImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, I::branch_i64_gt_s_imm as _)?, - I::I64LeUImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, I::branch_i64_gt_u_imm as _)?, - I::I64GtSImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, I::branch_i64_le_s_imm as _)?, - I::I64GtUImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, I::branch_i64_le_u_imm as _)?, - I::I64GeSImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, I::branch_i64_lt_s_imm as _)?, - I::I64GeUImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, I::branch_i64_lt_u_imm as _)?, + I::I32AndImm16(instr) => fuse_imm::(self, stack, last_instr, instr, label, Cmp::I32AndEqz, I::branch_i32_and_eqz_imm as _)?, + I::I32OrImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, Cmp::I32OrEqz, I::branch_i32_or_eqz_imm as _)?, + I::I32XorImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, Cmp::I32XorEqz, I::branch_i32_xor_eqz_imm as _)?, + I::I32AndEqzImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, Cmp::I32And, I::branch_i32_and_imm as _)?, + I::I32OrEqzImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, Cmp::I32Or, I::branch_i32_or_imm as _)?, + I::I32XorEqzImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, Cmp::I32Xor, I::branch_i32_xor_imm as _)?, + I::I32EqImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, Cmp::I32Ne, I::branch_i32_ne_imm as _)?, + I::I32NeImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, Cmp::I32Eq, I::branch_i32_eq_imm as _)?, + I::I32LtSImm16(instr) => 
fuse_imm(self, stack, last_instr, instr, label, Cmp::I32GeS, I::branch_i32_ge_s_imm as _)?, + I::I32LtUImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, Cmp::I32GeU, I::branch_i32_ge_u_imm as _)?, + I::I32LeSImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, Cmp::I32GtS, I::branch_i32_gt_s_imm as _)?, + I::I32LeUImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, Cmp::I32GtU, I::branch_i32_gt_u_imm as _)?, + I::I32GtSImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, Cmp::I32LeS, I::branch_i32_le_s_imm as _)?, + I::I32GtUImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, Cmp::I32LeU, I::branch_i32_le_u_imm as _)?, + I::I32GeSImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, Cmp::I32LtS, I::branch_i32_lt_s_imm as _)?, + I::I32GeUImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, Cmp::I32LtU, I::branch_i32_lt_u_imm as _)?, + I::I64EqImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, Cmp::I64Ne, I::branch_i64_ne_imm as _)?, + I::I64NeImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, Cmp::I64Eq, I::branch_i64_eq_imm as _)?, + I::I64LtSImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, Cmp::I64GeS, I::branch_i64_ge_s_imm as _)?, + I::I64LtUImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, Cmp::I64GeU, I::branch_i64_ge_u_imm as _)?, + I::I64LeSImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, Cmp::I64GtS, I::branch_i64_gt_s_imm as _)?, + I::I64LeUImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, Cmp::I64GtU, I::branch_i64_gt_u_imm as _)?, + I::I64GtSImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, Cmp::I64LeS, I::branch_i64_le_s_imm as _)?, + I::I64GtUImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, Cmp::I64LeU, I::branch_i64_le_u_imm as _)?, + I::I64GeSImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, Cmp::I64LtS, I::branch_i64_lt_s_imm as _)?, + 
I::I64GeUImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, Cmp::I64LtU, I::branch_i64_lt_u_imm as _)?, _ => None, }; if let Some(fused_instr) = fused_instr { _ = mem::replace(self.instrs.get_mut(last_instr), fused_instr); return Ok(()); } - encode_branch_eqz_fallback(self, condition, label) + encode_branch_eqz_fallback(self, stack, condition, label) } /// Encodes a `branch_nez` instruction and tries to fuse it with a previous comparison instruction. @@ -1066,18 +1071,42 @@ impl InstrEncoder { type BranchCmpConstructor = fn(Register, Register, BranchOffset16) -> Instruction; type BranchCmpImmConstructor = fn(Register, Const16, BranchOffset16) -> Instruction; + /// Create an [`Instruction::BranchCmpFallback`]. + fn make_branch_cmp_fallback( + stack: &mut ValueStack, + cmp: BranchComparator, + lhs: Register, + rhs: Register, + offset: BranchOffset, + ) -> Result { + let params = stack.alloc_const(ComparatorOffsetParam::new(cmp, offset))?; + Ok(Instruction::branch_cmp_fallback(lhs, rhs, params)) + } + /// Encode an unoptimized `branch_nez` instruction. /// /// This is used as fallback whenever fusing compare and branch instructions is not possible. fn encode_branch_nez_fallback( this: &mut InstrEncoder, + stack: &mut ValueStack, condition: Register, label: LabelRef, ) -> Result<(), Error> { - let offset = this - .try_resolve_label(label) - .and_then(BranchOffset16::try_from)?; - this.push_instr(Instruction::branch_i32_nez(condition, offset))?; + let offset = this.try_resolve_label(label)?; + let instr = match BranchOffset16::try_from(offset) { + Ok(offset) => Instruction::branch_i32_nez(condition, offset), + Err(_) => { + let zero = stack.alloc_const(0_i32)?; + make_branch_cmp_fallback( + stack, + BranchComparator::I32Ne, + condition, + zero, + offset, + )? 
+ } + }; + this.push_instr(instr)?; Ok(()) } @@ -1090,6 +1119,7 @@ impl InstrEncoder { last_instr: Instr, instr: BinInstr, label: LabelRef, + cmp: BranchComparator, make_instr: BranchCmpConstructor, ) -> Result, Error> { if matches!(stack.get_register_space(instr.result), RegisterSpace::Local) { @@ -1099,9 +1129,11 @@ impl InstrEncoder { return Ok(None); } let offset = this.try_resolve_label_for(label, last_instr)?; - let instr = BranchOffset16::new(offset) - .map(|offset16| make_instr(instr.lhs, instr.rhs, offset16)); - Ok(instr) + let instr = match BranchOffset16::try_from(offset) { + Ok(offset) => make_instr(instr.lhs, instr.rhs, offset), + Err(_) => make_branch_cmp_fallback(stack, cmp, instr.lhs, instr.rhs, offset)?, + }; + Ok(Some(instr)) } /// Create a fused cmp+branch instruction with a 16-bit immediate and wrap it in a `Some`. @@ -1113,8 +1145,12 @@ impl InstrEncoder { last_instr: Instr, instr: BinInstrImm16, label: LabelRef, + cmp: BranchComparator, make_instr: BranchCmpImmConstructor, - ) -> Result, Error> { + ) -> Result, Error> + where + T: From> + Into, + { if matches!(stack.get_register_space(instr.result), RegisterSpace::Local) { // We need to filter out instructions that store their result // into a local register slot because they introduce observable behavior @@ -1122,129 +1158,95 @@ impl InstrEncoder { return Ok(None); } let offset = this.try_resolve_label_for(label, last_instr)?; - let instr = BranchOffset16::new(offset) - .map(|offset16| make_instr(instr.reg_in, instr.imm_in, offset16)); - Ok(instr) + let instr = match BranchOffset16::try_from(offset) { + Ok(offset) => make_instr(instr.reg_in, instr.imm_in, offset), + Err(_) => { + let rhs = stack.alloc_const(T::from(instr.imm_in))?; + make_branch_cmp_fallback(stack, cmp, instr.reg_in, rhs, offset)? 
+ } + }; + Ok(Some(instr)) } + use BranchComparator as Cmp; use Instruction as I; let Some(last_instr) = self.last_instr else { - return encode_branch_nez_fallback(self, condition, label); + return encode_branch_nez_fallback(self, stack, condition, label); }; #[rustfmt::skip] let fused_instr = match *self.instrs.get(last_instr) { - I::I32EqImm16(instr) if instr.imm_in.is_zero() => { - match stack.get_register_space(instr.result) { - RegisterSpace::Local => None, - _ => { - let offset16 = self.try_resolve_label_for(label, last_instr) - .and_then(BranchOffset16::try_from)?; - Some(Instruction::branch_i32_eqz(instr.reg_in, offset16)) - } - } - } - I::I64EqImm16(instr) if instr.imm_in.is_zero() => { - match stack.get_register_space(instr.result) { - RegisterSpace::Local => None, - _ => { - let offset16 = self.try_resolve_label_for(label, last_instr) - .and_then(BranchOffset16::try_from)?; - Some(Instruction::branch_i64_eqz(instr.reg_in, offset16)) - } - } - } - I::I32NeImm16(instr) if instr.imm_in.is_zero() => { - match stack.get_register_space(instr.result) { - RegisterSpace::Local => None, - _ => { - let offset16 = self.try_resolve_label_for(label, last_instr) - .and_then(BranchOffset16::try_from)?; - Some(Instruction::branch_i32_nez(instr.reg_in, offset16)) - } - } - } - I::I64NeImm16(instr) if instr.imm_in.is_zero() => { - match stack.get_register_space(instr.result) { - RegisterSpace::Local => None, - _ => { - let offset16 = self.try_resolve_label_for(label, last_instr) - .and_then(BranchOffset16::try_from)?; - Some(Instruction::branch_i64_nez(instr.reg_in, offset16)) - } - } - } - I::I32And(instr) => fuse(self, stack, last_instr, instr, label, I::branch_i32_and as _)?, - I::I32Or(instr) => fuse(self, stack, last_instr, instr, label, I::branch_i32_or as _)?, - I::I32Xor(instr) => fuse(self, stack, last_instr, instr, label, I::branch_i32_xor as _)?, - I::I32AndEqz(instr) => fuse(self, stack, last_instr, instr, label, I::branch_i32_and_eqz as _)?, - 
I::I32OrEqz(instr) => fuse(self, stack, last_instr, instr, label, I::branch_i32_or_eqz as _)?, - I::I32XorEqz(instr) => fuse(self, stack, last_instr, instr, label, I::branch_i32_xor_eqz as _)?, - I::I32Eq(instr) => fuse(self, stack, last_instr, instr, label, I::branch_i32_eq as _)?, - I::I32Ne(instr) => fuse(self, stack, last_instr, instr, label, I::branch_i32_ne as _)?, - I::I32LtS(instr) => fuse(self, stack, last_instr, instr, label, I::branch_i32_lt_s as _)?, - I::I32LtU(instr) => fuse(self, stack, last_instr, instr, label, I::branch_i32_lt_u as _)?, - I::I32LeS(instr) => fuse(self, stack, last_instr, instr, label, I::branch_i32_le_s as _)?, - I::I32LeU(instr) => fuse(self, stack, last_instr, instr, label, I::branch_i32_le_u as _)?, - I::I32GtS(instr) => fuse(self, stack, last_instr, instr, label, I::branch_i32_gt_s as _)?, - I::I32GtU(instr) => fuse(self, stack, last_instr, instr, label, I::branch_i32_gt_u as _)?, - I::I32GeS(instr) => fuse(self, stack, last_instr, instr, label, I::branch_i32_ge_s as _)?, - I::I32GeU(instr) => fuse(self, stack, last_instr, instr, label, I::branch_i32_ge_u as _)?, - I::I64Eq(instr) => fuse(self, stack, last_instr, instr, label, I::branch_i64_eq as _)?, - I::I64Ne(instr) => fuse(self, stack, last_instr, instr, label, I::branch_i64_ne as _)?, - I::I64LtS(instr) => fuse(self, stack, last_instr, instr, label, I::branch_i64_lt_s as _)?, - I::I64LtU(instr) => fuse(self, stack, last_instr, instr, label, I::branch_i64_lt_u as _)?, - I::I64LeS(instr) => fuse(self, stack, last_instr, instr, label, I::branch_i64_le_s as _)?, - I::I64LeU(instr) => fuse(self, stack, last_instr, instr, label, I::branch_i64_le_u as _)?, - I::I64GtS(instr) => fuse(self, stack, last_instr, instr, label, I::branch_i64_gt_s as _)?, - I::I64GtU(instr) => fuse(self, stack, last_instr, instr, label, I::branch_i64_gt_u as _)?, - I::I64GeS(instr) => fuse(self, stack, last_instr, instr, label, I::branch_i64_ge_s as _)?, - I::I64GeU(instr) => fuse(self, stack, 
last_instr, instr, label, I::branch_i64_ge_u as _)?, - I::F32Eq(instr) => fuse(self, stack, last_instr, instr, label, I::branch_f32_eq as _)?, - I::F32Ne(instr) => fuse(self, stack, last_instr, instr, label, I::branch_f32_ne as _)?, - I::F32Lt(instr) => fuse(self, stack, last_instr, instr, label, I::branch_f32_lt as _)?, - I::F32Le(instr) => fuse(self, stack, last_instr, instr, label, I::branch_f32_le as _)?, - I::F32Gt(instr) => fuse(self, stack, last_instr, instr, label, I::branch_f32_gt as _)?, - I::F32Ge(instr) => fuse(self, stack, last_instr, instr, label, I::branch_f32_ge as _)?, - I::F64Eq(instr) => fuse(self, stack, last_instr, instr, label, I::branch_f64_eq as _)?, - I::F64Ne(instr) => fuse(self, stack, last_instr, instr, label, I::branch_f64_ne as _)?, - I::F64Lt(instr) => fuse(self, stack, last_instr, instr, label, I::branch_f64_lt as _)?, - I::F64Le(instr) => fuse(self, stack, last_instr, instr, label, I::branch_f64_le as _)?, - I::F64Gt(instr) => fuse(self, stack, last_instr, instr, label, I::branch_f64_gt as _)?, - I::F64Ge(instr) => fuse(self, stack, last_instr, instr, label, I::branch_f64_ge as _)?, - I::I32AndImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, I::branch_i32_and_imm as _)?, - I::I32OrImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, I::branch_i32_or_imm as _)?, - I::I32XorImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, I::branch_i32_xor_imm as _)?, - I::I32AndEqzImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, I::branch_i32_and_eqz_imm as _)?, - I::I32OrEqzImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, I::branch_i32_or_eqz_imm as _)?, - I::I32XorEqzImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, I::branch_i32_xor_eqz_imm as _)?, - I::I32EqImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, I::branch_i32_eq_imm as _)?, - I::I32NeImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, I::branch_i32_ne_imm as _)?, - 
I::I32LtSImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, I::branch_i32_lt_s_imm as _)?, - I::I32LtUImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, I::branch_i32_lt_u_imm as _)?, - I::I32LeSImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, I::branch_i32_le_s_imm as _)?, - I::I32LeUImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, I::branch_i32_le_u_imm as _)?, - I::I32GtSImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, I::branch_i32_gt_s_imm as _)?, - I::I32GtUImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, I::branch_i32_gt_u_imm as _)?, - I::I32GeSImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, I::branch_i32_ge_s_imm as _)?, - I::I32GeUImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, I::branch_i32_ge_u_imm as _)?, - I::I64EqImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, I::branch_i64_eq_imm as _)?, - I::I64NeImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, I::branch_i64_ne_imm as _)?, - I::I64LtSImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, I::branch_i64_lt_s_imm as _)?, - I::I64LtUImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, I::branch_i64_lt_u_imm as _)?, - I::I64LeSImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, I::branch_i64_le_s_imm as _)?, - I::I64LeUImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, I::branch_i64_le_u_imm as _)?, - I::I64GtSImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, I::branch_i64_gt_s_imm as _)?, - I::I64GtUImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, I::branch_i64_gt_u_imm as _)?, - I::I64GeSImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, I::branch_i64_ge_s_imm as _)?, - I::I64GeUImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, I::branch_i64_ge_u_imm as _)?, + I::I32And(instr) => fuse(self, stack, last_instr, instr, label, Cmp::I32And, I::branch_i32_and as 
_)?, + I::I32Or(instr) => fuse(self, stack, last_instr, instr, label, Cmp::I32Or, I::branch_i32_or as _)?, + I::I32Xor(instr) => fuse(self, stack, last_instr, instr, label, Cmp::I32Xor, I::branch_i32_xor as _)?, + I::I32AndEqz(instr) => fuse(self, stack, last_instr, instr, label, Cmp::I32AndEqz, I::branch_i32_and_eqz as _)?, + I::I32OrEqz(instr) => fuse(self, stack, last_instr, instr, label, Cmp::I32OrEqz, I::branch_i32_or_eqz as _)?, + I::I32XorEqz(instr) => fuse(self, stack, last_instr, instr, label, Cmp::I32XorEqz, I::branch_i32_xor_eqz as _)?, + I::I32Eq(instr) => fuse(self, stack, last_instr, instr, label, Cmp::I32Eq, I::branch_i32_eq as _)?, + I::I32Ne(instr) => fuse(self, stack, last_instr, instr, label, Cmp::I32Ne, I::branch_i32_ne as _)?, + I::I32LtS(instr) => fuse(self, stack, last_instr, instr, label, Cmp::I32LtS, I::branch_i32_lt_s as _)?, + I::I32LtU(instr) => fuse(self, stack, last_instr, instr, label, Cmp::I32LtU, I::branch_i32_lt_u as _)?, + I::I32LeS(instr) => fuse(self, stack, last_instr, instr, label, Cmp::I32LeS, I::branch_i32_le_s as _)?, + I::I32LeU(instr) => fuse(self, stack, last_instr, instr, label, Cmp::I32LeU, I::branch_i32_le_u as _)?, + I::I32GtS(instr) => fuse(self, stack, last_instr, instr, label, Cmp::I32GtS, I::branch_i32_gt_s as _)?, + I::I32GtU(instr) => fuse(self, stack, last_instr, instr, label, Cmp::I32GtU, I::branch_i32_gt_u as _)?, + I::I32GeS(instr) => fuse(self, stack, last_instr, instr, label, Cmp::I32GeS, I::branch_i32_ge_s as _)?, + I::I32GeU(instr) => fuse(self, stack, last_instr, instr, label, Cmp::I32GeU, I::branch_i32_ge_u as _)?, + I::I64Eq(instr) => fuse(self, stack, last_instr, instr, label, Cmp::I64Eq, I::branch_i64_eq as _)?, + I::I64Ne(instr) => fuse(self, stack, last_instr, instr, label, Cmp::I64Ne, I::branch_i64_ne as _)?, + I::I64LtS(instr) => fuse(self, stack, last_instr, instr, label, Cmp::I64LtS, I::branch_i64_lt_s as _)?, + I::I64LtU(instr) => fuse(self, stack, last_instr, instr, label, Cmp::I64LtU, 
I::branch_i64_lt_u as _)?, + I::I64LeS(instr) => fuse(self, stack, last_instr, instr, label, Cmp::I64LeS, I::branch_i64_le_s as _)?, + I::I64LeU(instr) => fuse(self, stack, last_instr, instr, label, Cmp::I64LeU, I::branch_i64_le_u as _)?, + I::I64GtS(instr) => fuse(self, stack, last_instr, instr, label, Cmp::I64GtS, I::branch_i64_gt_s as _)?, + I::I64GtU(instr) => fuse(self, stack, last_instr, instr, label, Cmp::I64GtU, I::branch_i64_gt_u as _)?, + I::I64GeS(instr) => fuse(self, stack, last_instr, instr, label, Cmp::I64GeS, I::branch_i64_ge_s as _)?, + I::I64GeU(instr) => fuse(self, stack, last_instr, instr, label, Cmp::I64GeU, I::branch_i64_ge_u as _)?, + I::F32Eq(instr) => fuse(self, stack, last_instr, instr, label, Cmp::F32Eq, I::branch_f32_eq as _)?, + I::F32Ne(instr) => fuse(self, stack, last_instr, instr, label, Cmp::F32Ne, I::branch_f32_ne as _)?, + I::F32Lt(instr) => fuse(self, stack, last_instr, instr, label, Cmp::F32Lt, I::branch_f32_lt as _)?, + I::F32Le(instr) => fuse(self, stack, last_instr, instr, label, Cmp::F32Le, I::branch_f32_le as _)?, + I::F32Gt(instr) => fuse(self, stack, last_instr, instr, label, Cmp::F32Gt, I::branch_f32_gt as _)?, + I::F32Ge(instr) => fuse(self, stack, last_instr, instr, label, Cmp::F32Ge, I::branch_f32_ge as _)?, + I::F64Eq(instr) => fuse(self, stack, last_instr, instr, label, Cmp::F64Eq, I::branch_f64_eq as _)?, + I::F64Ne(instr) => fuse(self, stack, last_instr, instr, label, Cmp::F64Ne, I::branch_f64_ne as _)?, + I::F64Lt(instr) => fuse(self, stack, last_instr, instr, label, Cmp::F64Lt, I::branch_f64_lt as _)?, + I::F64Le(instr) => fuse(self, stack, last_instr, instr, label, Cmp::F64Le, I::branch_f64_le as _)?, + I::F64Gt(instr) => fuse(self, stack, last_instr, instr, label, Cmp::F64Gt, I::branch_f64_gt as _)?, + I::F64Ge(instr) => fuse(self, stack, last_instr, instr, label, Cmp::F64Ge, I::branch_f64_ge as _)?, + I::I32AndImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, Cmp::I32And, I::branch_i32_and_imm as 
_)?, + I::I32OrImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, Cmp::I32Or, I::branch_i32_or_imm as _)?, + I::I32XorImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, Cmp::I32Xor, I::branch_i32_xor_imm as _)?, + I::I32AndEqzImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, Cmp::I32AndEqz, I::branch_i32_and_eqz_imm as _)?, + I::I32OrEqzImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, Cmp::I32OrEqz, I::branch_i32_or_eqz_imm as _)?, + I::I32XorEqzImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, Cmp::I32XorEqz, I::branch_i32_xor_eqz_imm as _)?, + I::I32EqImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, Cmp::I32Eq, I::branch_i32_eq_imm as _)?, + I::I32NeImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, Cmp::I32Ne, I::branch_i32_ne_imm as _)?, + I::I32LtSImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, Cmp::I32LtS, I::branch_i32_lt_s_imm as _)?, + I::I32LtUImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, Cmp::I32LtU, I::branch_i32_lt_u_imm as _)?, + I::I32LeSImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, Cmp::I32LeS, I::branch_i32_le_s_imm as _)?, + I::I32LeUImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, Cmp::I32LeU, I::branch_i32_le_u_imm as _)?, + I::I32GtSImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, Cmp::I32GtS, I::branch_i32_gt_s_imm as _)?, + I::I32GtUImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, Cmp::I32GtU, I::branch_i32_gt_u_imm as _)?, + I::I32GeSImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, Cmp::I32GeS, I::branch_i32_ge_s_imm as _)?, + I::I32GeUImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, Cmp::I32GeU, I::branch_i32_ge_u_imm as _)?, + I::I64EqImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, Cmp::I64Eq, I::branch_i64_eq_imm as _)?, + I::I64NeImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, 
Cmp::I64Ne, I::branch_i64_ne_imm as _)?, + I::I64LtSImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, Cmp::I64LtS, I::branch_i64_lt_s_imm as _)?, + I::I64LtUImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, Cmp::I64LtU, I::branch_i64_lt_u_imm as _)?, + I::I64LeSImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, Cmp::I64LeS, I::branch_i64_le_s_imm as _)?, + I::I64LeUImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, Cmp::I64LeU, I::branch_i64_le_u_imm as _)?, + I::I64GtSImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, Cmp::I64GtS, I::branch_i64_gt_s_imm as _)?, + I::I64GtUImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, Cmp::I64GtU, I::branch_i64_gt_u_imm as _)?, + I::I64GeSImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, Cmp::I64GeS, I::branch_i64_ge_s_imm as _)?, + I::I64GeUImm16(instr) => fuse_imm(self, stack, last_instr, instr, label, Cmp::I64GeU, I::branch_i64_ge_u_imm as _)?, _ => None, }; if let Some(fused_instr) = fused_instr { _ = mem::replace(self.instrs.get_mut(last_instr), fused_instr); return Ok(()); } - encode_branch_nez_fallback(self, condition, label) + encode_branch_nez_fallback(self, stack, condition, label) } }
@@ -1254,76 +1256,107 @@ impl Instruction {
     /// # Panics
     ///
     /// If `self` is not a branch [`Instruction`].
-    pub fn update_branch_offset(&mut self, new_offset: BranchOffset) -> Result<(), Error> {
+    #[rustfmt::skip]
+    pub fn update_branch_offset(&mut self, stack: &mut ValueStack, new_offset: BranchOffset) -> Result<(), Error> {
+        /// Initializes the 16-bit offset of `instr` if possible.
+        ///
+        /// If `new_offset` cannot be encoded as 16-bit offset `self` is replaced with a fallback instruction.
+        macro_rules! init_offset {
+            ($instr:expr, $new_offset:expr, $cmp:expr) => {{
+                if $instr.offset.init($new_offset).is_err() {
+                    let params = stack.alloc_const(ComparatorOffsetParam::new($cmp, $new_offset))?;
+                    *self = Instruction::branch_cmp_fallback($instr.lhs, $instr.rhs, params);
+                }
+                Ok(())
+            }};
+        }
+
+        /// Initializes the 16-bit offset of `instr` with an immediate `rhs` operand if possible.
+        ///
+        /// If `new_offset` cannot be encoded as 16-bit offset the `rhs` immediate is converted to `$ty`,
+        /// allocated as constant on the `stack` and `self` is replaced with a fallback instruction.
+        macro_rules! init_offset_imm {
+            ($ty:ty, $instr:expr, $new_offset:expr, $cmp:expr) => {{
+                if $instr.offset.init($new_offset).is_err() {
+                    let rhs = stack.alloc_const(<$ty>::from($instr.rhs))?;
+                    let params = stack.alloc_const(ComparatorOffsetParam::new($cmp, $new_offset))?;
+                    *self = Instruction::branch_cmp_fallback($instr.lhs, rhs, params);
+                }
+                Ok(())
+            }};
+        }
+
+        use Instruction as I;
+        use BranchComparator as Cmp;
         match self {
             Instruction::Branch { offset } => {
                 offset.init(new_offset);
                 Ok(())
             }
-            Instruction::BranchI32And(instr)
-            | Instruction::BranchI32Or(instr)
-            | Instruction::BranchI32Xor(instr)
-            | Instruction::BranchI32AndEqz(instr)
-            | Instruction::BranchI32OrEqz(instr)
-            | Instruction::BranchI32XorEqz(instr)
-            | Instruction::BranchI32Eq(instr)
-            | Instruction::BranchI32Ne(instr)
-            | Instruction::BranchI32LtS(instr)
-            | Instruction::BranchI32LtU(instr)
-            | Instruction::BranchI32LeS(instr)
-            | Instruction::BranchI32LeU(instr)
-            | Instruction::BranchI32GtS(instr)
-            | Instruction::BranchI32GtU(instr)
-            | Instruction::BranchI32GeS(instr)
-            | Instruction::BranchI32GeU(instr)
-            | Instruction::BranchI64Eq(instr)
-            | Instruction::BranchI64Ne(instr)
-            | Instruction::BranchI64LtS(instr)
-            | Instruction::BranchI64LtU(instr)
-            | Instruction::BranchI64LeS(instr)
-            | Instruction::BranchI64LeU(instr)
-            | Instruction::BranchI64GtS(instr)
-            | Instruction::BranchI64GtU(instr)
-            | Instruction::BranchI64GeS(instr)
-            | Instruction::BranchI64GeU(instr)
-            | Instruction::BranchF32Eq(instr)
-            | Instruction::BranchF32Ne(instr)
-            | Instruction::BranchF32Lt(instr)
-            | Instruction::BranchF32Le(instr)
-            | Instruction::BranchF32Gt(instr)
-            | Instruction::BranchF32Ge(instr)
-            | Instruction::BranchF64Eq(instr)
-            | Instruction::BranchF64Ne(instr)
-            | Instruction::BranchF64Lt(instr)
-            | Instruction::BranchF64Le(instr)
-            | Instruction::BranchF64Gt(instr)
-            | Instruction::BranchF64Ge(instr) => instr.offset.init(new_offset),
-            Instruction::BranchI32AndImm(instr)
-            | Instruction::BranchI32OrImm(instr)
-            | Instruction::BranchI32XorImm(instr)
-            | Instruction::BranchI32AndEqzImm(instr)
-            | Instruction::BranchI32OrEqzImm(instr)
-            | Instruction::BranchI32XorEqzImm(instr)
-            | Instruction::BranchI32EqImm(instr)
-            | Instruction::BranchI32NeImm(instr)
-            | Instruction::BranchI32LtSImm(instr)
-            | Instruction::BranchI32LeSImm(instr)
-            | Instruction::BranchI32GtSImm(instr)
-            | Instruction::BranchI32GeSImm(instr) => instr.offset.init(new_offset),
-            Instruction::BranchI32LtUImm(instr)
-            | Instruction::BranchI32LeUImm(instr)
-            | Instruction::BranchI32GtUImm(instr)
-            | Instruction::BranchI32GeUImm(instr) => instr.offset.init(new_offset),
-            Instruction::BranchI64EqImm(instr)
-            | Instruction::BranchI64NeImm(instr)
-            | Instruction::BranchI64LtSImm(instr)
-            | Instruction::BranchI64LeSImm(instr)
-            | Instruction::BranchI64GtSImm(instr)
-            | Instruction::BranchI64GeSImm(instr) => instr.offset.init(new_offset),
-            Instruction::BranchI64LtUImm(instr)
-            | Instruction::BranchI64LeUImm(instr)
-            | Instruction::BranchI64GtUImm(instr)
-            | Instruction::BranchI64GeUImm(instr) => instr.offset.init(new_offset),
+            I::BranchI32And(instr) => init_offset!(instr, new_offset, Cmp::I32And),
+            I::BranchI32Or(instr) => init_offset!(instr, new_offset, Cmp::I32Or),
+            I::BranchI32Xor(instr) => init_offset!(instr, new_offset, Cmp::I32Xor),
+            I::BranchI32AndEqz(instr) => init_offset!(instr, new_offset, Cmp::I32AndEqz),
+            I::BranchI32OrEqz(instr) => init_offset!(instr, new_offset, Cmp::I32OrEqz),
+            I::BranchI32XorEqz(instr) => init_offset!(instr, new_offset, Cmp::I32XorEqz),
+            I::BranchI32Eq(instr) => init_offset!(instr, new_offset, Cmp::I32Eq),
+            I::BranchI32Ne(instr) => init_offset!(instr, new_offset, Cmp::I32Ne),
+            I::BranchI32LtS(instr) => init_offset!(instr, new_offset, Cmp::I32LtS),
+            I::BranchI32LtU(instr) => init_offset!(instr, new_offset, Cmp::I32LtU),
+            I::BranchI32LeS(instr) => init_offset!(instr, new_offset, Cmp::I32LeS),
+            I::BranchI32LeU(instr) => init_offset!(instr, new_offset, Cmp::I32LeU),
+            I::BranchI32GtS(instr) => init_offset!(instr, new_offset, Cmp::I32GtS),
+            I::BranchI32GtU(instr) => init_offset!(instr, new_offset, Cmp::I32GtU),
+            I::BranchI32GeS(instr) => init_offset!(instr, new_offset, Cmp::I32GeS),
+            I::BranchI32GeU(instr) => init_offset!(instr, new_offset, Cmp::I32GeU),
+            I::BranchI64Eq(instr) => init_offset!(instr, new_offset, Cmp::I64Eq),
+            I::BranchI64Ne(instr) => init_offset!(instr, new_offset, Cmp::I64Ne),
+            I::BranchI64LtS(instr) => init_offset!(instr, new_offset, Cmp::I64LtS),
+            I::BranchI64LtU(instr) => init_offset!(instr, new_offset, Cmp::I64LtU),
+            I::BranchI64LeS(instr) => init_offset!(instr, new_offset, Cmp::I64LeS),
+            I::BranchI64LeU(instr) => init_offset!(instr, new_offset, Cmp::I64LeU),
+            I::BranchI64GtS(instr) => init_offset!(instr, new_offset, Cmp::I64GtS),
+            I::BranchI64GtU(instr) => init_offset!(instr, new_offset, Cmp::I64GtU),
+            I::BranchI64GeS(instr) => init_offset!(instr, new_offset, Cmp::I64GeS),
+            I::BranchI64GeU(instr) => init_offset!(instr, new_offset, Cmp::I64GeU),
+            I::BranchF32Eq(instr) => init_offset!(instr, new_offset, Cmp::F32Eq),
+            I::BranchF32Ne(instr) => init_offset!(instr, new_offset, Cmp::F32Ne),
+            I::BranchF32Lt(instr) => init_offset!(instr, new_offset, Cmp::F32Lt),
+            I::BranchF32Le(instr) => init_offset!(instr, new_offset, Cmp::F32Le),
+            I::BranchF32Gt(instr) => init_offset!(instr, new_offset, Cmp::F32Gt),
+            I::BranchF32Ge(instr) => init_offset!(instr, new_offset, Cmp::F32Ge),
+            I::BranchF64Eq(instr) => init_offset!(instr, new_offset, Cmp::F64Eq),
+            I::BranchF64Ne(instr) => init_offset!(instr, new_offset, Cmp::F64Ne),
+            I::BranchF64Lt(instr) => init_offset!(instr, new_offset, Cmp::F64Lt),
+            I::BranchF64Le(instr) => init_offset!(instr, new_offset, Cmp::F64Le),
+            I::BranchF64Gt(instr) => init_offset!(instr, new_offset, Cmp::F64Gt),
+            I::BranchF64Ge(instr) => init_offset!(instr, new_offset, Cmp::F64Ge),
+            I::BranchI32AndImm(instr) => init_offset_imm!(i32, instr, new_offset, Cmp::I32And),
+            I::BranchI32OrImm(instr) => init_offset_imm!(i32, instr, new_offset, Cmp::I32Or),
+            I::BranchI32XorImm(instr) => init_offset_imm!(i32, instr, new_offset, Cmp::I32Xor),
+            I::BranchI32AndEqzImm(instr) => init_offset_imm!(i32, instr, new_offset, Cmp::I32AndEqz),
+            I::BranchI32OrEqzImm(instr) => init_offset_imm!(i32, instr, new_offset, Cmp::I32OrEqz),
+            I::BranchI32XorEqzImm(instr) => init_offset_imm!(i32, instr, new_offset, Cmp::I32XorEqz),
+            I::BranchI32EqImm(instr) => init_offset_imm!(i32, instr, new_offset, Cmp::I32Eq),
+            I::BranchI32NeImm(instr) => init_offset_imm!(i32, instr, new_offset, Cmp::I32Ne),
+            I::BranchI32LtSImm(instr) => init_offset_imm!(i32, instr, new_offset, Cmp::I32LtS),
+            I::BranchI32LeSImm(instr) => init_offset_imm!(i32, instr, new_offset, Cmp::I32LeS),
+            I::BranchI32GtSImm(instr) => init_offset_imm!(i32, instr, new_offset, Cmp::I32GtS),
+            I::BranchI32GeSImm(instr) => init_offset_imm!(i32, instr, new_offset, Cmp::I32GeS),
+            I::BranchI32LtUImm(instr) => init_offset_imm!(u32, instr, new_offset, Cmp::I32LtU),
+            I::BranchI32LeUImm(instr) => init_offset_imm!(u32, instr, new_offset, Cmp::I32LeU),
+            I::BranchI32GtUImm(instr) => init_offset_imm!(u32, instr, new_offset, Cmp::I32GtU),
+            I::BranchI32GeUImm(instr) => init_offset_imm!(u32, instr, new_offset, Cmp::I32GeU),
+            I::BranchI64EqImm(instr) => init_offset_imm!(i64, instr, new_offset, Cmp::I64Eq),
+            I::BranchI64NeImm(instr) => init_offset_imm!(i64, instr, new_offset, Cmp::I64Ne),
+            I::BranchI64LtSImm(instr) => init_offset_imm!(i64, instr, new_offset, Cmp::I64LtS),
+            I::BranchI64LeSImm(instr) => init_offset_imm!(i64, instr, new_offset, Cmp::I64LeS),
+            I::BranchI64GtSImm(instr) => init_offset_imm!(i64, instr, new_offset, Cmp::I64GtS),
+            I::BranchI64GeSImm(instr) => init_offset_imm!(i64, instr, new_offset, Cmp::I64GeS),
+            I::BranchI64LtUImm(instr) => init_offset_imm!(u64, instr, new_offset, Cmp::I64LtU),
+            I::BranchI64LeUImm(instr) => init_offset_imm!(u64, instr, new_offset, Cmp::I64LeU),
+            I::BranchI64GtUImm(instr) => init_offset_imm!(u64, instr, new_offset, Cmp::I64GtU),
+            I::BranchI64GeUImm(instr) => init_offset_imm!(u64, instr, new_offset, Cmp::I64GeU),
             _ => panic!("tried to update branch offset of a non-branch instruction: {self:?}"),
         }
     }
diff --git a/crates/wasmi/src/engine/translator/mod.rs b/crates/wasmi/src/engine/translator/mod.rs index 3fe55dcad9..7d73f61eed 100644 --- a/crates/wasmi/src/engine/translator/mod.rs +++ b/crates/wasmi/src/engine/translator/mod.rs @@ -492,7 +492,9 @@ impl<'parser> WasmTranslator<'parser> for FuncTranslator { self.alloc .instr_encoder .defrag_registers(&mut self.alloc.stack)?; - self.alloc.instr_encoder.update_branch_offsets()?; + self.alloc + .instr_encoder + .update_branch_offsets(&mut self.alloc.stack)?; let len_registers = self.alloc.stack.len_registers(); if let Some(fuel_costs) = self.fuel_costs() { // Note: Fuel metering is enabled so we need to bump the fuel diff --git a/crates/wasmi/src/engine/translator/relink_result.rs b/crates/wasmi/src/engine/translator/relink_result.rs index 70eed472ea..0d15f58155 100644 --- a/crates/wasmi/src/engine/translator/relink_result.rs +++ b/crates/wasmi/src/engine/translator/relink_result.rs @@ -63,6 +63,7 @@ impl Instruction { | I::ReturnNezSpan { .. } | I::ReturnNezMany { .. } | I::Branch { .. } + | I::BranchCmpFallback { .. 
} | I::BranchI32And(_) | I::BranchI32AndImm(_) | I::BranchI32Or(_) diff --git a/crates/wasmi/src/engine/translator/visit_register.rs b/crates/wasmi/src/engine/translator/visit_register.rs index f8963d6b76..4a75a86534 100644 --- a/crates/wasmi/src/engine/translator/visit_register.rs +++ b/crates/wasmi/src/engine/translator/visit_register.rs @@ -80,6 +80,7 @@ impl VisitInputRegisters for Instruction { Instruction::Branch { .. } => {}, Instruction::BranchTable { index, .. } => f(index), + Instruction::BranchCmpFallback { lhs, rhs, .. } => visit_registers!(f, lhs, rhs), Instruction::BranchI32And(instr) => instr.visit_input_registers(f), Instruction::BranchI32AndImm(instr) => instr.visit_input_registers(f), Instruction::BranchI32Or(instr) => instr.visit_input_registers(f),