Skip to content

Commit

Permalink
cmov: implement constant-time equality comparisons (#873)
Browse files Browse the repository at this point in the history
  • Loading branch information
brxken128 authored Apr 1, 2023
1 parent 83c7093 commit 734c63b
Show file tree
Hide file tree
Showing 5 changed files with 475 additions and 116 deletions.
114 changes: 68 additions & 46 deletions cmov/src/aarch64.rs
Original file line number Diff line number Diff line change
@@ -1,90 +1,112 @@
use crate::{Cmov, Condition};
use crate::{Cmov, CmovEq, Condition};
use core::arch::asm;

macro_rules! csel {
($cmp:expr, $csel:expr, $dst:expr, $src:expr, $condition:expr) => {
($csel:expr, $dst:expr, $src:expr, $condition:expr) => {
unsafe {
asm! {
$cmp,
"cmp {0:w}, 0",
$csel,
in(reg) $condition,
inlateout(reg) *$dst,
in(reg) $src,
in(reg) *$src,
in(reg) *$dst,
options(pure, nomem, nostack),
};
}
};
}

macro_rules! csel_eq {
($instruction:expr, $lhs:expr, $rhs:expr, $condition:expr, $dst:expr) => {
let mut tmp = *$dst as u16;
unsafe {
asm! {
"eor {0:w}, {1:w}, {2:w}",
"cmp {0:w}, 0",
$instruction,
out(reg) _,
in(reg) *$lhs,
in(reg) *$rhs,
inlateout(reg) tmp,
in(reg) $condition as u16,
in(reg) tmp,
options(pure, nomem, nostack),
};
};

*$dst = tmp as u8;
};
}

impl Cmov for u16 {
#[inline(always)]
fn cmovnz(&mut self, value: &Self, condition: Condition) {
csel!(
"cmp {0:w}, 0",
"csel {1:w}, {2:w}, {3:w}, NE",
self,
*value,
condition
);
csel!("csel {1:w}, {2:w}, {3:w}, NE", self, value, condition);
}

#[inline(always)]
fn cmovz(&mut self, value: &Self, condition: Condition) {
csel!(
"cmp {0:w}, 0",
"csel {1:w}, {2:w}, {3:w}, EQ",
self,
*value,
condition
);
csel!("csel {1:w}, {2:w}, {3:w}, EQ", self, value, condition);
}
}

impl CmovEq for u16 {
#[inline(always)]
fn cmovne(&self, rhs: &Self, input: Condition, output: &mut Condition) {
csel_eq!("csel {3:w}, {4:w}, {5:w}, NE", self, rhs, input, output);
}

#[inline(always)]
fn cmoveq(&self, rhs: &Self, input: Condition, output: &mut Condition) {
csel_eq!("csel {3:w}, {4:w}, {5:w}, EQ", self, rhs, input, output);
}
}

impl Cmov for u32 {
#[inline(always)]
fn cmovnz(&mut self, value: &Self, condition: Condition) {
csel!(
"cmp {0:w}, 0",
"csel {1:w}, {2:w}, {3:w}, NE",
self,
*value,
condition
);
csel!("csel {1:w}, {2:w}, {3:w}, NE", self, value, condition);
}

#[inline(always)]
fn cmovz(&mut self, value: &Self, condition: Condition) {
csel!(
"cmp {0:w}, 0",
"csel {1:w}, {2:w}, {3:w}, EQ",
self,
*value,
condition
);
csel!("csel {1:w}, {2:w}, {3:w}, EQ", self, value, condition);
}
}

impl CmovEq for u32 {
#[inline(always)]
fn cmovne(&self, rhs: &Self, input: Condition, output: &mut Condition) {
csel_eq!("csel {3:w}, {4:w}, {5:w}, NE", self, rhs, input, output);
}

#[inline(always)]
fn cmoveq(&self, rhs: &Self, input: Condition, output: &mut Condition) {
csel_eq!("csel {3:w}, {4:w}, {5:w}, EQ", self, rhs, input, output);
}
}

impl Cmov for u64 {
#[inline(always)]
fn cmovnz(&mut self, value: &Self, condition: Condition) {
csel!(
"cmp {0:x}, 0",
"csel {1:x}, {2:x}, {3:x}, NE",
self,
*value,
condition
);
csel!("csel {1:x}, {2:x}, {3:x}, NE", self, value, condition);
}

#[inline(always)]
fn cmovz(&mut self, value: &Self, condition: Condition) {
csel!(
"cmp {0:x}, 0",
"csel {1:x}, {2:x}, {3:x}, EQ",
self,
*value,
condition
);
csel!("csel {1:x}, {2:x}, {3:x}, EQ", self, value, condition);
}
}

impl CmovEq for u64 {
#[inline(always)]
fn cmovne(&self, rhs: &Self, input: Condition, output: &mut Condition) {
csel_eq!("csel {3:w}, {4:w}, {5:w}, NE", self, rhs, input, output);
}

#[inline(always)]
fn cmoveq(&self, rhs: &Self, input: Condition, output: &mut Condition) {
csel_eq!("csel {3:w}, {4:w}, {5:w}, EQ", self, rhs, input, output);
}
}
56 changes: 55 additions & 1 deletion cmov/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ mod x86;
pub type Condition = u8;

/// Conditional move
// TODO(tarcieri): make one of `cmovz`/`cmovnz` a provided method which calls the other?
pub trait Cmov {
/// Move if non-zero.
///
Expand All @@ -41,6 +40,25 @@ pub trait Cmov {
}
}

/// Conditional move with equality comparison
pub trait CmovEq {
/// Move if both inputs are equal.
///
/// Uses a `xor` instruction to compare the two values, and
/// conditionally moves `input` to `output` when they are equal.
fn cmoveq(&self, rhs: &Self, input: Condition, output: &mut Condition);

/// Move if both inputs are not equal.
///
/// Uses a `xor` instruction to compare the two values, and
/// conditionally moves `input` to `output` when they are not equal.
fn cmovne(&self, rhs: &Self, input: Condition, output: &mut Condition) {
let mut tmp = 1u8;
self.cmoveq(rhs, 0u8, &mut tmp);
tmp.cmoveq(&1u8, input, output);
}
}

impl Cmov for u8 {
#[inline(always)]
fn cmovnz(&mut self, value: &Self, condition: Condition) {
Expand All @@ -57,6 +75,18 @@ impl Cmov for u8 {
}
}

impl CmovEq for u8 {
#[inline(always)]
fn cmoveq(&self, rhs: &Self, input: Condition, output: &mut Condition) {
(*self as u16).cmoveq(&(*rhs as u16), input, output);
}

#[inline(always)]
fn cmovne(&self, rhs: &Self, input: Condition, output: &mut Condition) {
(*self as u16).cmovne(&(*rhs as u16), input, output);
}
}

impl Cmov for u128 {
#[inline(always)]
fn cmovnz(&mut self, value: &Self, condition: Condition) {
Expand All @@ -80,3 +110,27 @@ impl Cmov for u128 {
*self = (lo as u128) | (hi as u128) << 64;
}
}

impl CmovEq for u128 {
#[inline(always)]
fn cmovne(&self, rhs: &Self, input: Condition, output: &mut Condition) {
let lo = (*self & u64::MAX as u128) as u64;
let hi = (*self >> 64) as u64;

let mut tmp = 1u8;
lo.cmovne(&((*rhs & u64::MAX as u128) as u64), 0, &mut tmp);
hi.cmovne(&((*rhs >> 64) as u64), 0, &mut tmp);
tmp.cmoveq(&0, input, output);
}

#[inline(always)]
fn cmoveq(&self, rhs: &Self, input: Condition, output: &mut Condition) {
let lo = (*self & u64::MAX as u128) as u64;
let hi = (*self >> 64) as u64;

let mut tmp = 1u8;
lo.cmovne(&((*rhs & u64::MAX as u128) as u64), 0, &mut tmp);
hi.cmovne(&((*rhs >> 64) as u64), 0, &mut tmp);
tmp.cmoveq(&1, input, output);
}
}
38 changes: 37 additions & 1 deletion cmov/src/portable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
// TODO(tarcieri): more optimized implementation for small integers

use crate::{Cmov, Condition};
use crate::{Cmov, CmovEq, Condition};

impl Cmov for u16 {
#[inline(always)]
Expand All @@ -24,6 +24,18 @@ impl Cmov for u16 {
}
}

impl CmovEq for u16 {
#[inline(always)]
fn cmovne(&self, rhs: &Self, input: Condition, output: &mut Condition) {
(*self as u64).cmovne(&(*rhs as u64), input, output);
}

#[inline(always)]
fn cmoveq(&self, rhs: &Self, input: Condition, output: &mut Condition) {
(*self as u64).cmoveq(&(*rhs as u64), input, output);
}
}

impl Cmov for u32 {
#[inline(always)]
fn cmovnz(&mut self, value: &Self, condition: Condition) {
Expand All @@ -40,6 +52,18 @@ impl Cmov for u32 {
}
}

impl CmovEq for u32 {
#[inline(always)]
fn cmovne(&self, rhs: &Self, input: Condition, output: &mut Condition) {
(*self as u64).cmovne(&(*rhs as u64), input, output);
}

#[inline(always)]
fn cmoveq(&self, rhs: &Self, input: Condition, output: &mut Condition) {
(*self as u64).cmoveq(&(*rhs as u64), input, output);
}
}

impl Cmov for u64 {
#[inline(always)]
fn cmovnz(&mut self, value: &Self, condition: Condition) {
Expand All @@ -54,6 +78,18 @@ impl Cmov for u64 {
}
}

impl CmovEq for u64 {
#[inline(always)]
fn cmovne(&self, rhs: &Self, input: Condition, output: &mut Condition) {
output.cmovnz(&input, (self ^ rhs) as u8);
}

#[inline(always)]
fn cmoveq(&self, rhs: &Self, input: Condition, output: &mut Condition) {
output.cmovz(&input, (self ^ rhs) as u8);
}
}

/// Check if the given condition value is non-zero
///
/// # Returns
Expand Down
Loading

0 comments on commit 734c63b

Please sign in to comment.