From 6204a51a76a3a5ad3283dd102878ae000c92c4c8 Mon Sep 17 00:00:00 2001 From: DianQK Date: Sat, 3 Feb 2024 22:44:01 +0800 Subject: [PATCH 1/4] Update matches_reduce_branches.rs --- ...h_i128_u128.MatchBranchSimplification.diff | 42 +++++ ...atch_i16_i8.MatchBranchSimplification.diff | 37 ++++ ...atch_i8_i16.MatchBranchSimplification.diff | 37 ++++ ..._i16_failed.MatchBranchSimplification.diff | 37 ++++ ...atch_u8_i16.MatchBranchSimplification.diff | 32 ++++ ...ch_u8_i16_2.MatchBranchSimplification.diff | 44 +++++ ..._i16_failed.MatchBranchSimplification.diff | 32 ++++ ...16_fallback.MatchBranchSimplification.diff | 31 ++++ ...atch_u8_u16.MatchBranchSimplification.diff | 37 ++++ tests/mir-opt/matches_reduce_branches.rs | 161 +++++++++++++++++- 10 files changed, 486 insertions(+), 4 deletions(-) create mode 100644 tests/mir-opt/matches_reduce_branches.match_i128_u128.MatchBranchSimplification.diff create mode 100644 tests/mir-opt/matches_reduce_branches.match_i16_i8.MatchBranchSimplification.diff create mode 100644 tests/mir-opt/matches_reduce_branches.match_i8_i16.MatchBranchSimplification.diff create mode 100644 tests/mir-opt/matches_reduce_branches.match_i8_i16_failed.MatchBranchSimplification.diff create mode 100644 tests/mir-opt/matches_reduce_branches.match_u8_i16.MatchBranchSimplification.diff create mode 100644 tests/mir-opt/matches_reduce_branches.match_u8_i16_2.MatchBranchSimplification.diff create mode 100644 tests/mir-opt/matches_reduce_branches.match_u8_i16_failed.MatchBranchSimplification.diff create mode 100644 tests/mir-opt/matches_reduce_branches.match_u8_i16_fallback.MatchBranchSimplification.diff create mode 100644 tests/mir-opt/matches_reduce_branches.match_u8_u16.MatchBranchSimplification.diff diff --git a/tests/mir-opt/matches_reduce_branches.match_i128_u128.MatchBranchSimplification.diff b/tests/mir-opt/matches_reduce_branches.match_i128_u128.MatchBranchSimplification.diff new file mode 100644 index 0000000000000..9fb74eb670fda --- /dev/null +++ b/tests/mir-opt/matches_reduce_branches.match_i128_u128.MatchBranchSimplification.diff @@ -0,0 +1,42 @@ +- // MIR for `match_i128_u128` before MatchBranchSimplification ++ // MIR for `match_i128_u128` after MatchBranchSimplification + + fn match_i128_u128(_1: EnumAi128) -> u128 { + debug i => _1; + let mut _0: u128; + let mut _2: i128; + + bb0: { + _2 = discriminant(_1); + switchInt(move _2) -> [1: bb3, 2: bb4, 3: bb5, 340282366920938463463374607431768211455: bb1, otherwise: bb2]; + } + + bb1: { + _0 = const _; + goto -> bb6; + } + + bb2: { + unreachable; + } + + bb3: { + _0 = const 1_u128; + goto -> bb6; + } + + bb4: { + _0 = const 2_u128; + goto -> bb6; + } + + bb5: { + _0 = const 3_u128; + goto -> bb6; + } + + bb6: { + return; + } + } + diff --git a/tests/mir-opt/matches_reduce_branches.match_i16_i8.MatchBranchSimplification.diff b/tests/mir-opt/matches_reduce_branches.match_i16_i8.MatchBranchSimplification.diff new file mode 100644 index 0000000000000..4d069c1236b61 --- /dev/null +++ b/tests/mir-opt/matches_reduce_branches.match_i16_i8.MatchBranchSimplification.diff @@ -0,0 +1,37 @@ +- // MIR for `match_i16_i8` before MatchBranchSimplification ++ // MIR for `match_i16_i8` after MatchBranchSimplification + + fn match_i16_i8(_1: EnumAi16) -> i8 { + debug i => _1; + let mut _0: i8; + let mut _2: i16; + + bb0: { + _2 = discriminant(_1); + switchInt(move _2) -> [65535: bb3, 2: bb4, 65533: bb1, otherwise: bb2]; + } + + bb1: { + _0 = const -3_i8; + goto -> bb5; + } + + bb2: { + unreachable; + } + + bb3: { + _0 = const -1_i8; + goto -> bb5; + } + + bb4: { + _0 = const 2_i8; + goto -> bb5; + } + + bb5: { + return; + } + } + diff --git a/tests/mir-opt/matches_reduce_branches.match_i8_i16.MatchBranchSimplification.diff b/tests/mir-opt/matches_reduce_branches.match_i8_i16.MatchBranchSimplification.diff new file mode 100644 index 0000000000000..d934be4adc29d --- /dev/null +++ b/tests/mir-opt/matches_reduce_branches.match_i8_i16.MatchBranchSimplification.diff @@ -0,0 +1,37 @@ +- // MIR for `match_i8_i16` before MatchBranchSimplification ++ // MIR for `match_i8_i16` after MatchBranchSimplification + + fn match_i8_i16(_1: EnumAi8) -> i16 { + debug i => _1; + let mut _0: i16; + let mut _2: i8; + + bb0: { + _2 = discriminant(_1); + switchInt(move _2) -> [255: bb3, 2: bb4, 253: bb1, otherwise: bb2]; + } + + bb1: { + _0 = const -3_i16; + goto -> bb5; + } + + bb2: { + unreachable; + } + + bb3: { + _0 = const -1_i16; + goto -> bb5; + } + + bb4: { + _0 = const 2_i16; + goto -> bb5; + } + + bb5: { + return; + } + } + diff --git a/tests/mir-opt/matches_reduce_branches.match_i8_i16_failed.MatchBranchSimplification.diff b/tests/mir-opt/matches_reduce_branches.match_i8_i16_failed.MatchBranchSimplification.diff new file mode 100644 index 0000000000000..2cdc11b9cd574 --- /dev/null +++ b/tests/mir-opt/matches_reduce_branches.match_i8_i16_failed.MatchBranchSimplification.diff @@ -0,0 +1,37 @@ +- // MIR for `match_i8_i16_failed` before MatchBranchSimplification ++ // MIR for `match_i8_i16_failed` after MatchBranchSimplification + + fn match_i8_i16_failed(_1: EnumAi8) -> i16 { + debug i => _1; + let mut _0: i16; + let mut _2: i8; + + bb0: { + _2 = discriminant(_1); + switchInt(move _2) -> [255: bb3, 2: bb4, 253: bb1, otherwise: bb2]; + } + + bb1: { + _0 = const 3_i16; + goto -> bb5; + } + + bb2: { + unreachable; + } + + bb3: { + _0 = const -1_i16; + goto -> bb5; + } + + bb4: { + _0 = const 2_i16; + goto -> bb5; + } + + bb5: { + return; + } + } + diff --git a/tests/mir-opt/matches_reduce_branches.match_u8_i16.MatchBranchSimplification.diff b/tests/mir-opt/matches_reduce_branches.match_u8_i16.MatchBranchSimplification.diff new file mode 100644 index 0000000000000..e0328494be220 --- /dev/null +++ b/tests/mir-opt/matches_reduce_branches.match_u8_i16.MatchBranchSimplification.diff @@ -0,0 +1,32 @@ +- // MIR for `match_u8_i16` before MatchBranchSimplification ++ // MIR for `match_u8_i16` after MatchBranchSimplification + + fn match_u8_i16(_1: EnumAu8) -> i16 { + debug i => _1; + let mut _0: i16; + let mut _2: u8; + + bb0: { + _2 = discriminant(_1); + switchInt(move _2) -> [1: bb3, 2: bb1, otherwise: bb2]; + } + + bb1: { + _0 = const 2_i16; + goto -> bb4; + } + + bb2: { + unreachable; + } + + bb3: { + _0 = const 1_i16; + goto -> bb4; + } + + bb4: { + return; + } + } + diff --git a/tests/mir-opt/matches_reduce_branches.match_u8_i16_2.MatchBranchSimplification.diff b/tests/mir-opt/matches_reduce_branches.match_u8_i16_2.MatchBranchSimplification.diff new file mode 100644 index 0000000000000..69209b3c84faa --- /dev/null +++ b/tests/mir-opt/matches_reduce_branches.match_u8_i16_2.MatchBranchSimplification.diff @@ -0,0 +1,44 @@ +- // MIR for `match_u8_i16_2` before MatchBranchSimplification ++ // MIR for `match_u8_i16_2` after MatchBranchSimplification + + fn match_u8_i16_2(_1: EnumAu8) -> i16 { + debug i => _1; + let mut _0: i16; + let mut _2: i16; + let _3: (); + let mut _4: u8; + scope 1 { + debug r => _2; + } + + bb0: { + StorageLive(_2); + _2 = const 0_i16; + StorageLive(_3); + _4 = discriminant(_1); + switchInt(move _4) -> [1: bb3, 2: bb1, otherwise: bb2]; + } + + bb1: { + _2 = const 2_i16; + _3 = const (); + goto -> bb4; + } + + bb2: { + unreachable; + } + + bb3: { + _3 = const (); + goto -> bb4; + } + + bb4: { + StorageDead(_3); + _0 = _2; + StorageDead(_2); + return; + } + } + diff --git a/tests/mir-opt/matches_reduce_branches.match_u8_i16_failed.MatchBranchSimplification.diff b/tests/mir-opt/matches_reduce_branches.match_u8_i16_failed.MatchBranchSimplification.diff new file mode 100644 index 0000000000000..d4f7e596e3b90 --- /dev/null +++ b/tests/mir-opt/matches_reduce_branches.match_u8_i16_failed.MatchBranchSimplification.diff @@ -0,0 +1,32 @@ +- // MIR for `match_u8_i16_failed` before MatchBranchSimplification ++ // MIR for `match_u8_i16_failed` after MatchBranchSimplification + + fn match_u8_i16_failed(_1: EnumAu8) -> i16 { + debug i => _1; + let mut _0: i16; + let mut _2: u8; + + bb0: { + _2 = discriminant(_1); + switchInt(move _2) -> [1: bb3, 2: bb1, otherwise: bb2]; + } + + bb1: { + _0 = const 3_i16; + goto -> bb4; + } + + bb2: { + unreachable; + } + + bb3: { + _0 = const 1_i16; + goto -> bb4; + } + + bb4: { + return; + } + } + diff --git a/tests/mir-opt/matches_reduce_branches.match_u8_i16_fallback.MatchBranchSimplification.diff b/tests/mir-opt/matches_reduce_branches.match_u8_i16_fallback.MatchBranchSimplification.diff new file mode 100644 index 0000000000000..8fa497fe89002 --- /dev/null +++ b/tests/mir-opt/matches_reduce_branches.match_u8_i16_fallback.MatchBranchSimplification.diff @@ -0,0 +1,31 @@ +- // MIR for `match_u8_i16_fallback` before MatchBranchSimplification ++ // MIR for `match_u8_i16_fallback` after MatchBranchSimplification + + fn match_u8_i16_fallback(_1: u8) -> i16 { + debug i => _1; + let mut _0: i16; + + bb0: { + switchInt(_1) -> [1: bb2, 2: bb3, otherwise: bb1]; + } + + bb1: { + _0 = const 3_i16; + goto -> bb4; + } + + bb2: { + _0 = const 1_i16; + goto -> bb4; + } + + bb3: { + _0 = const 2_i16; + goto -> bb4; + } + + bb4: { + return; + } + } + diff --git a/tests/mir-opt/matches_reduce_branches.match_u8_u16.MatchBranchSimplification.diff b/tests/mir-opt/matches_reduce_branches.match_u8_u16.MatchBranchSimplification.diff new file mode 100644 index 0000000000000..6e14eb8c28110 --- /dev/null +++ b/tests/mir-opt/matches_reduce_branches.match_u8_u16.MatchBranchSimplification.diff @@ -0,0 +1,37 @@ +- // MIR for `match_u8_u16` before MatchBranchSimplification ++ // MIR for `match_u8_u16` after MatchBranchSimplification + + fn match_u8_u16(_1: EnumBu8) -> u16 { + debug i => _1; + let mut _0: u16; + let mut _2: u8; + + bb0: { + _2 = discriminant(_1); + switchInt(move _2) -> [1: bb3, 2: bb4, 5: bb1, otherwise: bb2]; + } + + bb1: { + _0 = const 5_u16; + goto -> bb5; + } + + bb2: { + unreachable; + } + + bb3: { + _0 = const 1_u16; + goto -> bb5; + } + + bb4: { + _0 = const 2_u16; + goto -> bb5; + } + + bb5: { + return; + } + } + diff --git a/tests/mir-opt/matches_reduce_branches.rs b/tests/mir-opt/matches_reduce_branches.rs index 13db797341435..505b6a7f6dc5d 100644 --- a/tests/mir-opt/matches_reduce_branches.rs +++ b/tests/mir-opt/matches_reduce_branches.rs @@ -1,18 +1,24 @@ -// skip-filecheck // unit-test: MatchBranchSimplification +#![feature(repr128)] // EMIT_MIR matches_reduce_branches.foo.MatchBranchSimplification.diff -// EMIT_MIR matches_reduce_branches.bar.MatchBranchSimplification.diff -// EMIT_MIR matches_reduce_branches.match_nested_if.MatchBranchSimplification.diff - fn foo(bar: Option<()>) { + // CHECK-LABEL: fn foo( + // CHECK: = Eq( + // CHECK: switchInt + // CHECK-NOT: switchInt if matches!(bar, None) { () } } +// EMIT_MIR matches_reduce_branches.bar.MatchBranchSimplification.diff fn bar(i: i32) -> (bool, bool, bool, bool) { + // CHECK-LABEL: fn bar( + // CHECK: = Ne( + // CHECK: = Eq( + // CHECK-NOT: switchInt let a; let b; let c; @@ -38,7 +44,10 @@ fn bar(i: i32) -> (bool, bool, bool, bool) { (a, b, c, d) } +// EMIT_MIR matches_reduce_branches.match_nested_if.MatchBranchSimplification.diff fn match_nested_if() -> bool { + // CHECK-LABEL: fn match_nested_if( + // CHECK-NOT: switchInt let val = match () { () if if if if true { true } else { false } { true } else { false } { true @@ -53,9 +62,153 @@ fn match_nested_if() -> bool { val } +#[repr(u8)] +enum EnumAu8 { + A = 1, + B = 2, +} + +// EMIT_MIR matches_reduce_branches.match_u8_i16.MatchBranchSimplification.diff +fn match_u8_i16(i: EnumAu8) -> i16 { + // CHECK-LABEL: fn match_u8_i16( + // CHECK: switchInt + match i { + EnumAu8::A => 1, + EnumAu8::B => 2, + } +} + +// EMIT_MIR matches_reduce_branches.match_u8_i16_2.MatchBranchSimplification.diff +// FIXME: prepare tests with different instruction lengths +fn match_u8_i16_2(i: EnumAu8) -> i16 { + // CHECK-LABEL: fn match_u8_i16_2( + // CHECK: switchInt + let mut r = 0; + match i { + EnumAu8::A => {}, + EnumAu8::B => { r = 2; }, + } + r +} + +// EMIT_MIR matches_reduce_branches.match_u8_i16_failed.MatchBranchSimplification.diff +fn match_u8_i16_failed(i: EnumAu8) -> i16 { + // CHECK-LABEL: fn match_u8_i16_failed( + // CHECK: switchInt + match i { + EnumAu8::A => 1, + EnumAu8::B => 3, + } +} + +// EMIT_MIR matches_reduce_branches.match_u8_i16_fallback.MatchBranchSimplification.diff +fn match_u8_i16_fallback(i: u8) -> i16 { + // CHECK-LABEL: fn match_u8_i16_fallback( + // CHECK: switchInt + match i { + 1 => 1, + 2 => 2, + _ => 3, + } +} + +#[repr(u8)] +enum EnumBu8 { + A = 1, + B = 2, + C = 5, +} + +// EMIT_MIR matches_reduce_branches.match_u8_u16.MatchBranchSimplification.diff +fn match_u8_u16(i: EnumBu8) -> u16 { + // CHECK-LABEL: fn match_u8_u16( + // CHECK: switchInt + match i { + EnumBu8::A => 1, + EnumBu8::B => 2, + EnumBu8::C => 5, + } +} + +#[repr(i8)] +enum EnumAi8 { + A = -1, + B = 2, + C = -3, +} + +// EMIT_MIR matches_reduce_branches.match_i8_i16.MatchBranchSimplification.diff +fn match_i8_i16(i: EnumAi8) -> i16 { + // CHECK-LABEL: fn match_i8_i16( + // CHECK: switchInt + match i { + EnumAi8::A => -1, + EnumAi8::B => 2, + EnumAi8::C => -3, + } +} + +// EMIT_MIR matches_reduce_branches.match_i8_i16_failed.MatchBranchSimplification.diff +fn match_i8_i16_failed(i: EnumAi8) -> i16 { + // CHECK-LABEL: fn match_i8_i16_failed( + // CHECK: switchInt + match i { + EnumAi8::A => -1, + EnumAi8::B => 2, + EnumAi8::C => 3, + } +} + +#[repr(i16)] +enum EnumAi16 { + A = -1, + B = 2, + C = -3, +} + +// EMIT_MIR matches_reduce_branches.match_i16_i8.MatchBranchSimplification.diff +fn match_i16_i8(i: EnumAi16) -> i8 { + // CHECK-LABEL: fn match_i16_i8( + // CHECK: switchInt + match i { + EnumAi16::A => -1, + EnumAi16::B => 2, + EnumAi16::C => -3, + } +} + +#[repr(i128)] +enum EnumAi128 { + A = 1, + B = 2, + C = 3, + D = -1, +} + +// EMIT_MIR matches_reduce_branches.match_i128_u128.MatchBranchSimplification.diff +fn match_i128_u128(i: EnumAi128) -> u128 { + // CHECK-LABEL: fn match_i128_u128( + // CHECK: switchInt + match i { + EnumAi128::A => 1, + EnumAi128::B => 2, + EnumAi128::C => 3, + EnumAi128::D => u128::MAX, + } +} + fn main() { let _ = foo(None); let _ = foo(Some(())); let _ = bar(0); let _ = match_nested_if(); + let _ = match_u8_i16(EnumAu8::A); + let _ = match_u8_i16_failed(EnumAu8::A); + let _ = match_u8_i16_fallback(1); + let _ = match_u8_u16(EnumBu8::A); + let _ = match_i8_i16(EnumAi8::A); + let _ = match_i8_i16_failed(EnumAi8::A); + let _ = match_i8_i16(EnumAi8::A); + let _ = match_i16_i8(EnumAi16::A); + let _ = match_i128_u128(EnumAi128::A); } From beadbcb10b9d195d2e78d540f58706dfb5c2a715 Mon Sep 17 00:00:00 2001 From: DianQK Date: Sat, 3 Feb 2024 22:44:11 +0800 Subject: [PATCH 2/4] Refactor `MatchBranchSimplification` --- .../rustc_mir_transform/src/match_branches.rs | 338 +++++++++++------- 1 file changed, 205 insertions(+), 133 deletions(-) diff --git a/compiler/rustc_mir_transform/src/match_branches.rs b/compiler/rustc_mir_transform/src/match_branches.rs index 6d4332793af31..be1158683ac42 100644 --- a/compiler/rustc_mir_transform/src/match_branches.rs +++ b/compiler/rustc_mir_transform/src/match_branches.rs @@ -1,11 +1,116 @@ +use rustc_index::IndexVec; use rustc_middle::mir::*; -use rustc_middle::ty::TyCtxt; +use rustc_middle::ty::{ParamEnv, Ty, TyCtxt}; use std::iter; use super::simplify::simplify_cfg; pub struct MatchBranchSimplification; +impl<'tcx> MirPass<'tcx> for MatchBranchSimplification { + fn is_enabled(&self, sess: &rustc_session::Session) -> bool { + sess.mir_opt_level() >= 1 + } + + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + let def_id = body.source.def_id(); + let param_env = tcx.param_env_reveal_all_normalized(def_id); + + let bbs = body.basic_blocks.as_mut(); + let mut should_cleanup = false; + for bb_idx in bbs.indices() { + if !tcx.consider_optimizing(|| format!("MatchBranchSimplification {def_id:?} ")) { + continue; + } + + match bbs[bb_idx].terminator().kind { + TerminatorKind::SwitchInt { + discr: ref _discr @ (Operand::Copy(_) | Operand::Move(_)), + ref targets, + .. + // We require that the possible target blocks don't contain this block. + } if !targets.all_targets().contains(&bb_idx) => {} + // Only optimize switch int statements + _ => continue, + }; + + if SimplifyToIf.simplify(tcx, &mut body.local_decls, bbs, bb_idx, param_env) { + should_cleanup = true; + continue; + } + } + + if should_cleanup { + simplify_cfg(body); + } + } +} + +trait SimplifyMatch<'tcx> { + fn simplify( + &self, + tcx: TyCtxt<'tcx>, + local_decls: &mut IndexVec>, + bbs: &mut IndexVec>, + switch_bb_idx: BasicBlock, + param_env: ParamEnv<'tcx>, + ) -> bool { + let (discr, targets) = match bbs[switch_bb_idx].terminator().kind { + TerminatorKind::SwitchInt { ref discr, ref targets, .. } => (discr, targets), + _ => unreachable!(), + }; + + if !self.can_simplify(tcx, targets, param_env, bbs) { + return false; + } + + // Take ownership of items now that we know we can optimize. + let discr = discr.clone(); + let discr_ty = discr.ty(local_decls, tcx); + + // Introduce a temporary for the discriminant value. + let source_info = bbs[switch_bb_idx].terminator().source_info; + let discr_local = local_decls.push(LocalDecl::new(discr_ty, source_info.span)); + + // We already checked that first and second are different blocks, + // and bb_idx has a different terminator from both of them. + let new_stmts = self.new_stmts(tcx, targets, param_env, bbs, discr_local.clone(), discr_ty); + let (_, first) = targets.iter().next().unwrap(); + let (from, first) = bbs.pick2_mut(switch_bb_idx, first); + from.statements + .push(Statement { source_info, kind: StatementKind::StorageLive(discr_local) }); + from.statements.push(Statement { + source_info, + kind: StatementKind::Assign(Box::new((Place::from(discr_local), Rvalue::Use(discr)))), + }); + from.statements.extend(new_stmts); + from.statements + .push(Statement { source_info, kind: StatementKind::StorageDead(discr_local) }); + from.terminator_mut().kind = first.terminator().kind.clone(); + true + } + + fn can_simplify( + &self, + tcx: TyCtxt<'tcx>, + targets: &SwitchTargets, + param_env: ParamEnv<'tcx>, + bbs: &IndexVec>, + ) -> bool; + + fn new_stmts( + &self, + tcx: TyCtxt<'tcx>, + targets: &SwitchTargets, + param_env: ParamEnv<'tcx>, + bbs: &IndexVec>, + discr_local: Local, + discr_ty: Ty<'tcx>, + ) -> Vec>; +} + +struct SimplifyToIf; + /// If a source block is found that switches between two blocks that are exactly /// the same modulo const bool assignments (e.g., one assigns true another false /// to the same place), merge a target block statements into the source block, @@ -37,144 +142,111 @@ pub struct MatchBranchSimplification; /// goto -> bb3; /// } /// ``` +impl<'tcx> SimplifyMatch<'tcx> for SimplifyToIf { + fn can_simplify( + &self, + tcx: TyCtxt<'tcx>, + targets: &SwitchTargets, + param_env: ParamEnv<'tcx>, + bbs: &IndexVec>, + ) -> bool { + if targets.iter().len() != 1 { + return false; + } + // We require that the possible target blocks all be distinct. + let (_, first) = targets.iter().next().unwrap(); + let second = targets.otherwise(); + if first == second { + return false; + } + // Check that destinations are identical, and if not, then don't optimize this block + if bbs[first].terminator().kind != bbs[second].terminator().kind { + return false; + } -impl<'tcx> MirPass<'tcx> for MatchBranchSimplification { - fn is_enabled(&self, sess: &rustc_session::Session) -> bool { - sess.mir_opt_level() >= 1 - } - - fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { - let def_id = body.source.def_id(); - let param_env = tcx.param_env_reveal_all_normalized(def_id); - - let bbs = body.basic_blocks.as_mut(); - let mut should_cleanup = false; - 'outer: for bb_idx in bbs.indices() { - if !tcx.consider_optimizing(|| format!("MatchBranchSimplification {def_id:?} ")) { - continue; - } - - let (discr, val, first, second) = match bbs[bb_idx].terminator().kind { - TerminatorKind::SwitchInt { - discr: ref discr @ (Operand::Copy(_) | Operand::Move(_)), - ref targets, - .. - } if targets.iter().len() == 1 => { - let (value, target) = targets.iter().next().unwrap(); - // We require that this block and the two possible target blocks all be - // distinct. - if target == targets.otherwise() - || bb_idx == target - || bb_idx == targets.otherwise() - { - continue; - } - (discr, value, target, targets.otherwise()) - } - // Only optimize switch int statements - _ => continue, - }; - - // Check that destinations are identical, and if not, then don't optimize this block - if bbs[first].terminator().kind != bbs[second].terminator().kind { - continue; + // Check that blocks are assignments of consts to the same place or same statement, + // and match up 1-1, if not don't optimize this block. + let first_stmts = &bbs[first].statements; + let second_stmts = &bbs[second].statements; + if first_stmts.len() != second_stmts.len() { + return false; + } + for (f, s) in iter::zip(first_stmts, second_stmts) { + match (&f.kind, &s.kind) { + // If two statements are exactly the same, we can optimize. + (f_s, s_s) if f_s == s_s => {} + + // If two statements are const bool assignments to the same place, we can optimize. + ( + StatementKind::Assign(box (lhs_f, Rvalue::Use(Operand::Constant(f_c)))), + StatementKind::Assign(box (lhs_s, Rvalue::Use(Operand::Constant(s_c)))), + ) if lhs_f == lhs_s + && f_c.const_.ty().is_bool() + && s_c.const_.ty().is_bool() + && f_c.const_.try_eval_bool(tcx, param_env).is_some() + && s_c.const_.try_eval_bool(tcx, param_env).is_some() => {} + + // Otherwise we cannot optimize. Try another block. + _ => return false, } + } + true + } - // Check that blocks are assignments of consts to the same place or same statement, - // and match up 1-1, if not don't optimize this block. - let first_stmts = &bbs[first].statements; - let scnd_stmts = &bbs[second].statements; - if first_stmts.len() != scnd_stmts.len() { - continue; - } - for (f, s) in iter::zip(first_stmts, scnd_stmts) { - match (&f.kind, &s.kind) { - // If two statements are exactly the same, we can optimize. - (f_s, s_s) if f_s == s_s => {} - - // If two statements are const bool assignments to the same place, we can optimize. - ( - StatementKind::Assign(box (lhs_f, Rvalue::Use(Operand::Constant(f_c)))), - StatementKind::Assign(box (lhs_s, Rvalue::Use(Operand::Constant(s_c)))), - ) if lhs_f == lhs_s - && f_c.const_.ty().is_bool() - && s_c.const_.ty().is_bool() - && f_c.const_.try_eval_bool(tcx, param_env).is_some() - && s_c.const_.try_eval_bool(tcx, param_env).is_some() => {} - - // Otherwise we cannot optimize. Try another block. - _ => continue 'outer, - } - } - // Take ownership of items now that we know we can optimize. - let discr = discr.clone(); - let discr_ty = discr.ty(&body.local_decls, tcx); - - // Introduce a temporary for the discriminant value. - let source_info = bbs[bb_idx].terminator().source_info; - let discr_local = body.local_decls.push(LocalDecl::new(discr_ty, source_info.span)); - - // We already checked that first and second are different blocks, - // and bb_idx has a different terminator from both of them. - let (from, first, second) = bbs.pick3_mut(bb_idx, first, second); - - let new_stmts = iter::zip(&first.statements, &second.statements).map(|(f, s)| { - match (&f.kind, &s.kind) { - (f_s, s_s) if f_s == s_s => (*f).clone(), - - ( - StatementKind::Assign(box (lhs, Rvalue::Use(Operand::Constant(f_c)))), - StatementKind::Assign(box (_, Rvalue::Use(Operand::Constant(s_c)))), - ) => { - // From earlier loop we know that we are dealing with bool constants only: - let f_b = f_c.const_.try_eval_bool(tcx, param_env).unwrap(); - let s_b = s_c.const_.try_eval_bool(tcx, param_env).unwrap(); - if f_b == s_b { - // Same value in both blocks. Use statement as is. - (*f).clone() - } else { - // Different value between blocks. Make value conditional on switch condition. - let size = tcx.layout_of(param_env.and(discr_ty)).unwrap().size; - let const_cmp = Operand::const_from_scalar( - tcx, - discr_ty, - rustc_const_eval::interpret::Scalar::from_uint(val, size), - rustc_span::DUMMY_SP, - ); - let op = if f_b { BinOp::Eq } else { BinOp::Ne }; - let rhs = Rvalue::BinaryOp( - op, - Box::new((Operand::Copy(Place::from(discr_local)), const_cmp)), - ); - Statement { - source_info: f.source_info, - kind: StatementKind::Assign(Box::new((*lhs, rhs))), - } + fn new_stmts( + &self, + tcx: TyCtxt<'tcx>, + targets: &SwitchTargets, + param_env: ParamEnv<'tcx>, + bbs: &IndexVec>, + discr_local: Local, + discr_ty: Ty<'tcx>, + ) -> Vec> { + let (val, first) = targets.iter().next().unwrap(); + let second = targets.otherwise(); + // We already checked that first and second are different blocks, + // and bb_idx has a different terminator from both of them. + let first = &bbs[first]; + let second = &bbs[second]; + + let new_stmts = iter::zip(&first.statements, &second.statements).map(|(f, s)| { + match (&f.kind, &s.kind) { + (f_s, s_s) if f_s == s_s => (*f).clone(), + + ( + StatementKind::Assign(box (lhs, Rvalue::Use(Operand::Constant(f_c)))), + StatementKind::Assign(box (_, Rvalue::Use(Operand::Constant(s_c)))), + ) => { + // From earlier loop we know that we are dealing with bool constants only: + let f_b = f_c.const_.try_eval_bool(tcx, param_env).unwrap(); + let s_b = s_c.const_.try_eval_bool(tcx, param_env).unwrap(); + if f_b == s_b { + // Same value in both blocks. Use statement as is. + (*f).clone() + } else { + // Different value between blocks. Make value conditional on switch condition. + let size = tcx.layout_of(param_env.and(discr_ty)).unwrap().size; + let const_cmp = Operand::const_from_scalar( + tcx, + discr_ty, + rustc_const_eval::interpret::Scalar::from_uint(val, size), + rustc_span::DUMMY_SP, + ); + let op = if f_b { BinOp::Eq } else { BinOp::Ne }; + let rhs = Rvalue::BinaryOp( + op, + Box::new((Operand::Copy(Place::from(discr_local)), const_cmp)), + ); + Statement { + source_info: f.source_info, + kind: StatementKind::Assign(Box::new((*lhs, rhs))), } } - - _ => unreachable!(), } - }); - - from.statements - .push(Statement { source_info, kind: StatementKind::StorageLive(discr_local) }); - from.statements.push(Statement { - source_info, - kind: StatementKind::Assign(Box::new(( - Place::from(discr_local), - Rvalue::Use(discr), - ))), - }); - from.statements.extend(new_stmts); - from.statements - .push(Statement { source_info, kind: StatementKind::StorageDead(discr_local) }); - from.terminator_mut().kind = first.terminator().kind.clone(); - should_cleanup = true; - } - if should_cleanup { - simplify_cfg(body); - } + _ => unreachable!(), + } + }); + new_stmts.collect() } } From eccc782e8a7aab92cc44b3751581fbaf9514b77f Mon Sep 17 00:00:00 2001 From: DianQK Date: Sat, 3 Feb 2024 22:44:20 +0800 Subject: [PATCH 3/4] Transforms match into an assignment statement --- compiler/rustc_middle/src/mir/terminator.rs | 6 + .../rustc_mir_transform/src/match_branches.rs | 222 +++++++++++++++++- tests/codegen/match-optimized.rs | 4 +- ...h_i128_u128.MatchBranchSimplification.diff | 61 ++--- ...atch_u8_i16.MatchBranchSimplification.diff | 41 ++-- ...atch_u8_u16.MatchBranchSimplification.diff | 51 ++-- tests/mir-opt/matches_reduce_branches.rs | 12 +- ...stive_match.MatchBranchSimplification.diff | 41 ++-- ...ve_match_i8.MatchBranchSimplification.diff | 41 ++-- 9 files changed, 364 insertions(+), 115 deletions(-) diff --git a/compiler/rustc_middle/src/mir/terminator.rs b/compiler/rustc_middle/src/mir/terminator.rs index 0fe33e441f430..6acee5d76fe1c 100644 --- a/compiler/rustc_middle/src/mir/terminator.rs +++ b/compiler/rustc_middle/src/mir/terminator.rs @@ -74,6 +74,12 @@ impl SwitchTargets { pub fn target_for_value(&self, value: u128) -> BasicBlock { self.iter().find_map(|(v, t)| (v == value).then_some(t)).unwrap_or_else(|| self.otherwise()) } + + /// Returns true if all targets (including the fallback target) are distinct. + #[inline] + pub fn is_distinct(&self) -> bool { + self.targets.iter().collect::>().len() == self.targets.len() + } } pub struct SwitchTargetsIter<'a> { diff --git a/compiler/rustc_mir_transform/src/match_branches.rs b/compiler/rustc_mir_transform/src/match_branches.rs index be1158683ac42..ecf7d6ffd3ae4 100644 --- a/compiler/rustc_mir_transform/src/match_branches.rs +++ b/compiler/rustc_mir_transform/src/match_branches.rs @@ -1,6 +1,6 @@ use rustc_index::IndexVec; use rustc_middle::mir::*; -use rustc_middle::ty::{ParamEnv, Ty, TyCtxt}; +use rustc_middle::ty::{ParamEnv, ScalarInt, Ty, TyCtxt}; use std::iter; use super::simplify::simplify_cfg; @@ -38,6 +38,11 @@ impl<'tcx> MirPass<'tcx> for MatchBranchSimplification { should_cleanup = true; continue; } + if SimplifyToExp::default().simplify(tcx, &mut body.local_decls, bbs, bb_idx, param_env) + { + should_cleanup = true; + continue; + } } if should_cleanup { @@ -48,7 +53,7 @@ impl<'tcx> MirPass<'tcx> for MatchBranchSimplification { trait SimplifyMatch<'tcx> { fn simplify( - &self, + &mut self, tcx: TyCtxt<'tcx>, local_decls: &mut IndexVec>, bbs: &mut IndexVec>, @@ -72,7 +77,7 @@ trait SimplifyMatch<'tcx> { let source_info = bbs[switch_bb_idx].terminator().source_info; let discr_local = local_decls.push(LocalDecl::new(discr_ty, source_info.span)); - // We already checked that first and second are different blocks, + // We already checked that targets are different blocks, // and bb_idx has a different terminator from both of them. let new_stmts = self.new_stmts(tcx, targets, param_env, bbs, discr_local.clone(), discr_ty); let (_, first) = targets.iter().next().unwrap(); @@ -91,7 +96,7 @@ trait SimplifyMatch<'tcx> { } fn can_simplify( - &self, + &mut self, tcx: TyCtxt<'tcx>, targets: &SwitchTargets, param_env: ParamEnv<'tcx>, @@ -144,7 +149,7 @@ struct SimplifyToIf; /// ``` impl<'tcx> SimplifyMatch<'tcx> for SimplifyToIf { fn can_simplify( - &self, + &mut self, tcx: TyCtxt<'tcx>, targets: &SwitchTargets, param_env: ParamEnv<'tcx>, @@ -250,3 +255,210 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToIf { new_stmts.collect() } } + +#[derive(Default)] +struct SimplifyToExp { + transfrom_types: Vec, +} + +#[derive(Clone, Copy)] +enum CompareType<'tcx, 'a> { + Same(&'a StatementKind<'tcx>), + Eq(&'a Place<'tcx>, Ty<'tcx>, ScalarInt), + Discr(&'a Place<'tcx>, Ty<'tcx>), +} + +enum TransfromType { + Same, + Eq, + Discr, +} + +impl From> for TransfromType { + fn from(compare_type: CompareType<'_, '_>) -> Self { + match compare_type { + CompareType::Same(_) => TransfromType::Same, + CompareType::Eq(_, _, _) => TransfromType::Eq, + CompareType::Discr(_, _) => TransfromType::Discr, + } + } +} + +/// If we find that the value of match is the same as the assignment, +/// merge a target block statements into the source block, +/// using cast to transform different integer types. +/// +/// For example: +/// +/// ```ignore (MIR) +/// bb0: { +/// switchInt(_1) -> [1: bb2, 2: bb3, 3: bb4, otherwise: bb1]; +/// } +/// +/// bb1: { +/// unreachable; +/// } +/// +/// bb2: { +/// _0 = const 1_i16; +/// goto -> bb5; +/// } +/// +/// bb3: { +/// _0 = const 2_i16; +/// goto -> bb5; +/// } +/// +/// bb4: { +/// _0 = const 3_i16; +/// goto -> bb5; +/// } +/// ``` +/// +/// into: +/// +/// ```ignore (MIR) +/// bb0: { +/// _0 = _3 as i16 (IntToInt); +/// goto -> bb5; +/// } +/// ``` +impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp { + fn can_simplify( + &mut self, + tcx: TyCtxt<'tcx>, + targets: &SwitchTargets, + param_env: ParamEnv<'tcx>, + bbs: &IndexVec>, + ) -> bool { + if targets.iter().len() < 2 || targets.iter().len() > 64 { + return false; + } + // We require that the possible target blocks all be distinct. + if !targets.is_distinct() { + return false; + } + if !bbs[targets.otherwise()].is_empty_unreachable() { + return false; + } + let mut iter = targets.iter(); + let (first_val, first_target) = iter.next().unwrap(); + let first_terminator_kind = &bbs[first_target].terminator().kind; + // Check that destinations are identical, and if not, then don't optimize this block + if !targets + .iter() + .all(|(_, other_target)| first_terminator_kind == &bbs[other_target].terminator().kind) + { + return false; + } + + let first_stmts = &bbs[first_target].statements; + let (second_val, second_target) = iter.next().unwrap(); + let second_stmts = &bbs[second_target].statements; + if first_stmts.len() != second_stmts.len() { + return false; + } + + let mut compare_types = Vec::new(); + for (f, s) in iter::zip(first_stmts, second_stmts) { + let compare_type = match (&f.kind, &s.kind) { + // If two statements are exactly the same, we can optimize. + (f_s, s_s) if f_s == s_s => CompareType::Same(f_s), + + // If two statements are assignments with the match values to the same place, we can optimize. + ( + StatementKind::Assign(box (lhs_f, Rvalue::Use(Operand::Constant(f_c)))), + StatementKind::Assign(box (lhs_s, Rvalue::Use(Operand::Constant(s_c)))), + ) if lhs_f == lhs_s + && f_c.const_.ty() == s_c.const_.ty() + && f_c.const_.ty().is_integral() => + { + match ( + f_c.const_.try_eval_scalar_int(tcx, param_env), + s_c.const_.try_eval_scalar_int(tcx, param_env), + ) { + (Some(f), Some(s)) if f == s => CompareType::Eq(lhs_f, f_c.const_.ty(), f), + (Some(f), Some(s)) + if Some(f) == ScalarInt::try_from_uint(first_val, f.size()) + && Some(s) == ScalarInt::try_from_uint(second_val, s.size()) => + { + CompareType::Discr(lhs_f, f_c.const_.ty()) + } + _ => return false, + } + } + + // Otherwise we cannot optimize. Try another block. + _ => return false, + }; + compare_types.push(compare_type); + } + + for (other_val, other_target) in iter { + let other_stmts = &bbs[other_target].statements; + if compare_types.len() != other_stmts.len() { + return false; + } + for (f, s) in iter::zip(&compare_types, other_stmts) { + match (*f, &s.kind) { + (CompareType::Same(f_s), s_s) if f_s == s_s => {} + ( + CompareType::Eq(lhs_f, f_ty, val), + StatementKind::Assign(box (lhs_s, Rvalue::Use(Operand::Constant(s_c)))), + ) if lhs_f == lhs_s + && s_c.const_.ty() == f_ty + && s_c.const_.try_eval_scalar_int(tcx, param_env) == Some(val) => {} + ( + CompareType::Discr(lhs_f, f_ty), + StatementKind::Assign(box (lhs_s, Rvalue::Use(Operand::Constant(s_c)))), + ) if lhs_f == lhs_s && s_c.const_.ty() == f_ty => { + let Some(f) = s_c.const_.try_eval_scalar_int(tcx, param_env) else { + return false; + }; + if Some(f) != ScalarInt::try_from_uint(other_val, f.size()) { + return false; + } + } + _ => return false, + } + } + } + self.transfrom_types = compare_types.into_iter().map(|c| c.into()).collect(); + true + } + + fn new_stmts( + &self, + _tcx: TyCtxt<'tcx>, + targets: &SwitchTargets, + _param_env: ParamEnv<'tcx>, + bbs: &IndexVec>, + discr_local: Local, + discr_ty: Ty<'tcx>, + ) -> Vec> { + let (_, first) = targets.iter().next().unwrap(); + let first = &bbs[first]; + + let new_stmts = + iter::zip(&self.transfrom_types, &first.statements).map(|(t, s)| match (t, &s.kind) { + (TransfromType::Same, _) | (TransfromType::Eq, _) => (*s).clone(), + ( + TransfromType::Discr, + StatementKind::Assign(box (lhs, Rvalue::Use(Operand::Constant(f_c)))), + ) => { + let operand = Operand::Copy(Place::from(discr_local)); + let r_val = if f_c.const_.ty() == discr_ty { + Rvalue::Use(operand) + } else { + Rvalue::Cast(CastKind::IntToInt, operand, f_c.const_.ty()) + }; + Statement { + source_info: s.source_info, + kind: StatementKind::Assign(Box::new((*lhs, r_val))), + } + } + _ => unreachable!(), + }); + new_stmts.collect() + } +} diff --git a/tests/codegen/match-optimized.rs b/tests/codegen/match-optimized.rs index e32a5e5450427..51db4e825d212 100644 --- a/tests/codegen/match-optimized.rs +++ b/tests/codegen/match-optimized.rs @@ -26,12 +26,12 @@ pub fn exhaustive_match(e: E) -> u8 { // CHECK-NEXT: store i8 1, ptr %_0, align 1 // CHECK-NEXT: br label %[[EXIT]] // CHECK: [[C]]: -// CHECK-NEXT: store i8 2, ptr %_0, align 1 +// CHECK-NEXT: store i8 3, ptr %_0, align 1 // CHECK-NEXT: br label %[[EXIT]] match e { E::A => 0, E::B => 1, - E::C => 2, + E::C => 3, } } diff --git a/tests/mir-opt/matches_reduce_branches.match_i128_u128.MatchBranchSimplification.diff b/tests/mir-opt/matches_reduce_branches.match_i128_u128.MatchBranchSimplification.diff index 9fb74eb670fda..6b2af6f0188e9 100644 --- a/tests/mir-opt/matches_reduce_branches.match_i128_u128.MatchBranchSimplification.diff +++ b/tests/mir-opt/matches_reduce_branches.match_i128_u128.MatchBranchSimplification.diff @@ -5,37 +5,42 @@ debug i => _1; let mut _0: u128; let mut _2: i128; ++ let mut _3: i128; bb0: { _2 = discriminant(_1); - switchInt(move _2) -> [1: bb3, 2: bb4, 3: bb5, 340282366920938463463374607431768211455: bb1, otherwise: bb2]; - } - - bb1: { - _0 = const _; - goto -> bb6; - } - - bb2: { - unreachable; - } - - bb3: { - _0 = const 1_u128; - goto -> bb6; - } - - bb4: { - _0 = const 2_u128; - goto -> bb6; - } - - bb5: { - _0 = const 3_u128; - goto -> bb6; - } - - bb6: { +- switchInt(move _2) -> [1: bb3, 2: bb4, 3: bb5, 340282366920938463463374607431768211455: bb1, otherwise: bb2]; +- } +- +- bb1: { +- _0 = const _; +- goto -> bb6; +- } +- +- bb2: { +- unreachable; +- } +- +- bb3: { +- _0 = const 1_u128; +- goto -> bb6; +- } +- +- bb4: { +- _0 = const 2_u128; +- goto -> bb6; +- } +- +- bb5: { +- _0 = const 3_u128; +- goto -> bb6; +- } +- +- bb6: { ++ StorageLive(_3); ++ _3 = move _2; ++ _0 = _3 as u128 (IntToInt); ++ StorageDead(_3); return; } } diff --git a/tests/mir-opt/matches_reduce_branches.match_u8_i16.MatchBranchSimplification.diff b/tests/mir-opt/matches_reduce_branches.match_u8_i16.MatchBranchSimplification.diff index e0328494be220..f96ec0ed2beb0 100644 --- a/tests/mir-opt/matches_reduce_branches.match_u8_i16.MatchBranchSimplification.diff +++ b/tests/mir-opt/matches_reduce_branches.match_u8_i16.MatchBranchSimplification.diff @@ -5,27 +5,32 @@ debug i => _1; let mut _0: i16; let mut _2: u8; ++ let mut _3: u8; bb0: { _2 = discriminant(_1); - switchInt(move _2) -> [1: bb3, 2: bb1, otherwise: bb2]; - } - - bb1: { - _0 = const 2_i16; - goto -> bb4; - } - - bb2: { - unreachable; - } - - bb3: { - _0 = const 1_i16; - goto -> bb4; - } - - bb4: { +- switchInt(move _2) -> [1: bb3, 2: bb1, otherwise: bb2]; +- } +- +- bb1: { +- _0 = const 2_i16; +- goto -> bb4; +- } +- +- bb2: { +- unreachable; +- } +- +- bb3: { +- _0 = const 1_i16; +- goto -> bb4; +- } +- +- bb4: { ++ StorageLive(_3); ++ _3 = move _2; ++ _0 = _3 as i16 (IntToInt); ++ StorageDead(_3); return; } } diff --git a/tests/mir-opt/matches_reduce_branches.match_u8_u16.MatchBranchSimplification.diff b/tests/mir-opt/matches_reduce_branches.match_u8_u16.MatchBranchSimplification.diff index 6e14eb8c28110..45a3460a6b021 100644 --- a/tests/mir-opt/matches_reduce_branches.match_u8_u16.MatchBranchSimplification.diff +++ b/tests/mir-opt/matches_reduce_branches.match_u8_u16.MatchBranchSimplification.diff @@ -5,32 +5,37 @@ debug i => _1; let mut _0: u16; let mut _2: u8; ++ let mut _3: u8; bb0: { _2 = discriminant(_1); - switchInt(move _2) -> [1: bb3, 2: bb4, 5: bb1, otherwise: bb2]; - } - - bb1: { - _0 = const 5_u16; - goto -> bb5; - } - - bb2: { - unreachable; - } - - bb3: { - _0 = const 1_u16; - goto -> bb5; - } - - bb4: { - _0 = const 2_u16; - goto -> bb5; - } - - bb5: { +- switchInt(move _2) -> [1: bb3, 2: bb4, 5: bb1, otherwise: bb2]; +- } +- +- bb1: { +- _0 = const 5_u16; +- goto -> bb5; +- } +- +- bb2: { +- unreachable; +- } +- +- bb3: { +- _0 = const 1_u16; +- goto -> bb5; +- } +- +- bb4: { +- _0 = const 2_u16; +- goto -> bb5; +- } +- +- bb5: { ++ StorageLive(_3); ++ _3 = move _2; ++ _0 = _3 as u16 (IntToInt); ++ StorageDead(_3); return; } } diff --git a/tests/mir-opt/matches_reduce_branches.rs b/tests/mir-opt/matches_reduce_branches.rs index 505b6a7f6dc5d..d2bc565eb36fd 100644 --- a/tests/mir-opt/matches_reduce_branches.rs +++ b/tests/mir-opt/matches_reduce_branches.rs @@ -71,7 +71,9 @@ enum EnumAu8 { // EMIT_MIR matches_reduce_branches.match_u8_i16.MatchBranchSimplification.diff fn match_u8_i16(i: EnumAu8) -> i16 { // CHECK-LABEL: fn match_u8_i16( - // CHECK: switchInt + // CHECK-NOT: switchInt + // CHECK: _0 = _3 as i16 (IntToInt); + // CHECH: return match i { EnumAu8::A => 1, EnumAu8::B => 2, @@ -122,7 +124,9 @@ enum EnumBu8 { // EMIT_MIR matches_reduce_branches.match_u8_u16.MatchBranchSimplification.diff fn match_u8_u16(i: EnumBu8) -> u16 { // CHECK-LABEL: fn match_u8_u16( - // CHECK: switchInt + // CHECK-NOT: switchInt + // CHECK: _0 = _3 as u16 (IntToInt); + // CHECH: return match i { EnumBu8::A => 1, EnumBu8::B => 2, @@ -188,7 +192,9 @@ enum EnumAi128 { // EMIT_MIR matches_reduce_branches.match_i128_u128.MatchBranchSimplification.diff fn match_i128_u128(i: EnumAi128) -> u128 { // CHECK-LABEL: fn match_i128_u128( - // CHECK: switchInt + // CHECK-NOT: switchInt + // CHECK: _0 = _3 as u128 (IntToInt); + // CHECH: return match i { EnumAi128::A => 1, EnumAi128::B => 2, diff --git a/tests/mir-opt/matches_u8.exhaustive_match.MatchBranchSimplification.diff b/tests/mir-opt/matches_u8.exhaustive_match.MatchBranchSimplification.diff index fec5855636605..5adfcf8207617 100644 --- a/tests/mir-opt/matches_u8.exhaustive_match.MatchBranchSimplification.diff +++ b/tests/mir-opt/matches_u8.exhaustive_match.MatchBranchSimplification.diff @@ -5,27 +5,32 @@ debug e => _1; let mut _0: u8; let mut _2: isize; ++ let mut _3: isize; bb0: { _2 = discriminant(_1); - switchInt(move _2) -> [0: bb3, 1: bb1, otherwise: bb2]; - } - - bb1: { - _0 = const 1_u8; - goto -> bb4; - } - - bb2: { - unreachable; - } - - bb3: { - _0 = const 0_u8; - goto -> bb4; - } - - bb4: { +- switchInt(move _2) -> [0: bb3, 1: bb1, otherwise: bb2]; +- } +- +- bb1: { +- _0 = const 1_u8; +- goto -> bb4; +- } +- +- bb2: { +- unreachable; +- } +- +- bb3: { +- _0 = const 0_u8; +- goto -> bb4; +- } +- +- bb4: { ++ StorageLive(_3); ++ _3 = move _2; ++ _0 = _3 as u8 (IntToInt); ++ StorageDead(_3); return; } } diff --git a/tests/mir-opt/matches_u8.exhaustive_match_i8.MatchBranchSimplification.diff b/tests/mir-opt/matches_u8.exhaustive_match_i8.MatchBranchSimplification.diff index 94d3ce6c97158..71d92e28f17d9 100644 --- a/tests/mir-opt/matches_u8.exhaustive_match_i8.MatchBranchSimplification.diff +++ b/tests/mir-opt/matches_u8.exhaustive_match_i8.MatchBranchSimplification.diff @@ -5,27 +5,32 @@ debug e => _1; let mut _0: i8; let mut _2: isize; ++ let mut _3: isize; bb0: { _2 = discriminant(_1); - switchInt(move _2) -> [0: bb3, 1: bb1, otherwise: bb2]; - } - - bb1: { - _0 = const 1_i8; - goto -> bb4; - } - - bb2: { - unreachable; - } - - bb3: { - _0 = const 0_i8; - goto -> bb4; - } - - bb4: { +- switchInt(move _2) -> [0: bb3, 1: bb1, otherwise: bb2]; +- } +- +- bb1: { +- _0 = const 1_i8; +- goto -> bb4; +- } +- +- bb2: { +- unreachable; +- } +- +- bb3: { +- _0 = const 0_i8; +- goto -> bb4; +- } +- +- bb4: { ++ StorageLive(_3); ++ _3 = move _2; ++ _0 = _3 as i8 (IntToInt); ++ StorageDead(_3); return; } } From 7a476356b92be052317d920dfa6302653cd72f8d Mon Sep 17 00:00:00 2001 From: DianQK Date: Sat, 3 Feb 2024 22:44:23 +0800 Subject: [PATCH 4/4] Transforms a match containing negative numbers into an assignment statement as well --- .../rustc_mir_transform/src/match_branches.rs | 55 +++++++++++++++---- ...atch_i16_i8.MatchBranchSimplification.diff | 51 +++++++++-------- ...atch_i8_i16.MatchBranchSimplification.diff | 51 +++++++++-------- tests/mir-opt/matches_reduce_branches.rs | 8 ++- 4 files changed, 106 insertions(+), 59 deletions(-) diff --git a/compiler/rustc_mir_transform/src/match_branches.rs b/compiler/rustc_mir_transform/src/match_branches.rs index ecf7d6ffd3ae4..22080d67ff3d6 100644 --- a/compiler/rustc_mir_transform/src/match_branches.rs +++ b/compiler/rustc_mir_transform/src/match_branches.rs @@ -65,13 +65,13 @@ trait SimplifyMatch<'tcx> { _ => unreachable!(), }; - if !self.can_simplify(tcx, targets, param_env, bbs) { + let discr_ty = discr.ty(local_decls, tcx); + if !self.can_simplify(tcx, targets, param_env, bbs, discr_ty) { return false; } // Take ownership of items now that we know we can optimize. let discr = discr.clone(); - let discr_ty = discr.ty(local_decls, tcx); // Introduce a temporary for the discriminant value. let source_info = bbs[switch_bb_idx].terminator().source_info; @@ -101,6 +101,7 @@ trait SimplifyMatch<'tcx> { targets: &SwitchTargets, param_env: ParamEnv<'tcx>, bbs: &IndexVec>, + discr_ty: Ty<'tcx>, ) -> bool; fn new_stmts( @@ -154,6 +155,7 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToIf { targets: &SwitchTargets, param_env: ParamEnv<'tcx>, bbs: &IndexVec>, + _discr_ty: Ty<'tcx>, ) -> bool { if targets.iter().len() != 1 { return false; @@ -265,7 +267,7 @@ struct SimplifyToExp { enum CompareType<'tcx, 'a> { Same(&'a StatementKind<'tcx>), Eq(&'a Place<'tcx>, Ty<'tcx>, ScalarInt), - Discr(&'a Place<'tcx>, Ty<'tcx>), + Discr(&'a Place<'tcx>, Ty<'tcx>, bool), } enum TransfromType { @@ -279,7 +281,7 @@ impl From> for TransfromType { match compare_type { CompareType::Same(_) => TransfromType::Same, CompareType::Eq(_, _, _) => TransfromType::Eq, - CompareType::Discr(_, _) => TransfromType::Discr, + CompareType::Discr(_, _, _) => TransfromType::Discr, } } } @@ -330,6 +332,7 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp { targets: &SwitchTargets, param_env: ParamEnv<'tcx>, bbs: &IndexVec>, + discr_ty: Ty<'tcx>, ) -> bool { if targets.iter().len() < 2 || targets.iter().len() > 64 { return false; @@ -352,6 +355,7 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp { return false; } + let discr_size = tcx.layout_of(param_env.and(discr_ty)).unwrap().size; let first_stmts = &bbs[first_target].statements; let (second_val, second_target) = iter.next().unwrap(); let second_stmts = &bbs[second_target].statements; @@ -379,12 +383,30 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp { ) { (Some(f), Some(s)) if f == s => CompareType::Eq(lhs_f, f_c.const_.ty(), f), (Some(f), Some(s)) - if Some(f) == ScalarInt::try_from_uint(first_val, f.size()) - && Some(s) == ScalarInt::try_from_uint(second_val, s.size()) => + if ((f_c.const_.ty().is_signed() || discr_ty.is_signed()) + && f.try_to_int(f.size()).unwrap() + == ScalarInt::try_from_uint(first_val, discr_size) + .unwrap() + .try_to_int(discr_size) + .unwrap() + && s.try_to_int(s.size()).unwrap() + == ScalarInt::try_from_uint(second_val, discr_size) + .unwrap() + .try_to_int(discr_size) + .unwrap()) + || (Some(f) == ScalarInt::try_from_uint(first_val, f.size()) + && Some(s) + == ScalarInt::try_from_uint(second_val, s.size())) => { - CompareType::Discr(lhs_f, f_c.const_.ty()) + CompareType::Discr( + lhs_f, + f_c.const_.ty(), + f_c.const_.ty().is_signed() || discr_ty.is_signed(), + ) + } + _ => { + return false; } - _ => return false, } } @@ -409,15 +431,26 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp { && s_c.const_.ty() == f_ty && s_c.const_.try_eval_scalar_int(tcx, param_env) == Some(val) => {} ( - CompareType::Discr(lhs_f, f_ty), + CompareType::Discr(lhs_f, f_ty, is_signed), StatementKind::Assign(box (lhs_s, Rvalue::Use(Operand::Constant(s_c)))), ) if lhs_f == lhs_s && s_c.const_.ty() == f_ty => { let Some(f) = s_c.const_.try_eval_scalar_int(tcx, param_env) else { return false; }; - if Some(f) != ScalarInt::try_from_uint(other_val, f.size()) { - return false; + if is_signed + && s_c.const_.ty().is_signed() + && f.try_to_int(f.size()).unwrap() + == ScalarInt::try_from_uint(other_val, discr_size) + .unwrap() + .try_to_int(discr_size) + .unwrap() + { + continue; + } + if Some(f) == ScalarInt::try_from_uint(other_val, f.size()) { + continue; } + return false; } _ => return false, } diff --git a/tests/mir-opt/matches_reduce_branches.match_i16_i8.MatchBranchSimplification.diff b/tests/mir-opt/matches_reduce_branches.match_i16_i8.MatchBranchSimplification.diff index 4d069c1236b61..9c616fa6349f3 100644 --- a/tests/mir-opt/matches_reduce_branches.match_i16_i8.MatchBranchSimplification.diff +++ b/tests/mir-opt/matches_reduce_branches.match_i16_i8.MatchBranchSimplification.diff @@ -5,32 +5,37 @@ debug i => _1; let mut _0: i8; let mut _2: i16; ++ let mut _3: i16; bb0: { _2 = discriminant(_1); - switchInt(move _2) -> [65535: bb3, 2: bb4, 65533: bb1, otherwise: bb2]; - } - - bb1: { - _0 = const -3_i8; - goto -> bb5; - } - - bb2: { - unreachable; - } - - bb3: { - _0 = const -1_i8; - goto -> bb5; - } - - bb4: { - _0 = const 2_i8; - goto -> bb5; - } - - bb5: { +- switchInt(move _2) -> [65535: bb3, 2: bb4, 65533: bb1, otherwise: bb2]; +- } +- +- bb1: { +- _0 = const -3_i8; +- goto -> bb5; +- } +- +- bb2: { +- unreachable; +- } +- +- bb3: { +- _0 = const -1_i8; +- goto -> bb5; +- } +- +- bb4: { +- _0 = const 2_i8; +- goto -> bb5; +- } +- +- bb5: { ++ StorageLive(_3); ++ _3 = move _2; ++ _0 = _3 as i8 (IntToInt); ++ StorageDead(_3); return; } } diff --git a/tests/mir-opt/matches_reduce_branches.match_i8_i16.MatchBranchSimplification.diff b/tests/mir-opt/matches_reduce_branches.match_i8_i16.MatchBranchSimplification.diff index d934be4adc29d..404dc75a207e6 100644 --- a/tests/mir-opt/matches_reduce_branches.match_i8_i16.MatchBranchSimplification.diff +++ b/tests/mir-opt/matches_reduce_branches.match_i8_i16.MatchBranchSimplification.diff @@ -5,32 +5,37 @@ debug i => _1; let mut _0: i16; let mut _2: i8; ++ let mut _3: i8; bb0: { _2 = discriminant(_1); - switchInt(move _2) -> [255: bb3, 2: bb4, 253: bb1, otherwise: bb2]; - } - - bb1: { - _0 = const -3_i16; - goto -> bb5; - } - - bb2: { - unreachable; - } - - bb3: { - _0 = const -1_i16; - goto -> bb5; - } - - bb4: { - _0 = const 2_i16; - goto -> bb5; - } - - bb5: { +- switchInt(move _2) -> [255: bb3, 2: bb4, 253: bb1, otherwise: bb2]; +- } +- +- bb1: { +- _0 = const -3_i16; +- goto -> bb5; +- } +- +- bb2: { +- unreachable; +- } +- +- bb3: { +- _0 = const -1_i16; +- goto -> bb5; +- } +- +- bb4: { +- _0 = const 2_i16; +- goto -> bb5; +- } +- +- bb5: { ++ StorageLive(_3); ++ _3 = move _2; ++ _0 = _3 as i16 (IntToInt); ++ StorageDead(_3); return; } } diff --git a/tests/mir-opt/matches_reduce_branches.rs b/tests/mir-opt/matches_reduce_branches.rs index d2bc565eb36fd..f5b6e22999e20 100644 --- a/tests/mir-opt/matches_reduce_branches.rs +++ b/tests/mir-opt/matches_reduce_branches.rs @@ -144,7 +144,9 @@ enum EnumAi8 { // EMIT_MIR matches_reduce_branches.match_i8_i16.MatchBranchSimplification.diff fn match_i8_i16(i: EnumAi8) -> i16 { // CHECK-LABEL: fn match_i8_i16( - // CHECK: switchInt + // CHECK-NOT: switchInt + // CHECK: _0 = _3 as i16 (IntToInt); + // CHECH: return match i { EnumAi8::A => -1, EnumAi8::B => 2, @@ -173,7 +175,9 @@ enum EnumAi16 { // EMIT_MIR matches_reduce_branches.match_i16_i8.MatchBranchSimplification.diff fn match_i16_i8(i: EnumAi16) -> i8 { // CHECK-LABEL: fn match_i16_i8( - // CHECK: switchInt + // CHECK-NOT: switchInt + // CHECK: _0 = _3 as i8 (IntToInt); + // CHECH: return match i { EnumAi16::A => -1, EnumAi16::B => 2,