Skip to content

Commit

Permalink
chacha20: Remove mutable borrows from AVX2 backend (#268)
Browse files Browse the repository at this point in the history
The use of `&mut StateWord` everywhere caused the compiler to insert a
`vmovdqa` after almost every operation, and also caused the diagonalization
to use `vpermilps` instead of being optimised to `vpshufd`.

The new `State` struct helps to manage the passing-around of owned
`StateWord`s.
  • Loading branch information
str4d authored Aug 29, 2021
1 parent 818c4ac commit 7c86911
Showing 1 changed file with 146 additions and 118 deletions.
264 changes: 146 additions & 118 deletions chacha20/src/backend/avx2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,90 +35,115 @@ union StateWord {

impl StateWord {
#[inline]
#[must_use]
#[target_feature(enable = "avx2")]
unsafe fn add_assign_epi32(&mut self, rhs: &Self) {
self.avx = [
_mm256_add_epi32(self.avx[0], rhs.avx[0]),
_mm256_add_epi32(self.avx[1], rhs.avx[1]),
];
unsafe fn add_epi32(self, rhs: Self) -> Self {
StateWord {
avx: [
_mm256_add_epi32(self.avx[0], rhs.avx[0]),
_mm256_add_epi32(self.avx[1], rhs.avx[1]),
],
}
}

#[inline]
#[must_use]
#[target_feature(enable = "avx2")]
unsafe fn xor_assign(&mut self, rhs: &Self) {
self.avx = [
_mm256_xor_si256(self.avx[0], rhs.avx[0]),
_mm256_xor_si256(self.avx[1], rhs.avx[1]),
];
unsafe fn xor(self, rhs: Self) -> Self {
StateWord {
avx: [
_mm256_xor_si256(self.avx[0], rhs.avx[0]),
_mm256_xor_si256(self.avx[1], rhs.avx[1]),
],
}
}

#[inline]
#[must_use]
#[target_feature(enable = "avx2")]
unsafe fn shuffle_epi32<const MASK: i32>(&mut self) {
self.avx = [
_mm256_shuffle_epi32(self.avx[0], MASK),
_mm256_shuffle_epi32(self.avx[1], MASK),
];
unsafe fn shuffle_epi32<const MASK: i32>(self) -> Self {
StateWord {
avx: [
_mm256_shuffle_epi32(self.avx[0], MASK),
_mm256_shuffle_epi32(self.avx[1], MASK),
],
}
}

#[inline]
#[must_use]
#[target_feature(enable = "avx2")]
unsafe fn rol<const BY: i32, const REST: i32>(&mut self) {
self.avx = [
_mm256_xor_si256(
_mm256_slli_epi32(self.avx[0], BY),
_mm256_srli_epi32(self.avx[0], REST),
),
_mm256_xor_si256(
_mm256_slli_epi32(self.avx[1], BY),
_mm256_srli_epi32(self.avx[1], REST),
),
];
unsafe fn rol<const BY: i32, const REST: i32>(self) -> Self {
StateWord {
avx: [
_mm256_xor_si256(
_mm256_slli_epi32(self.avx[0], BY),
_mm256_srli_epi32(self.avx[0], REST),
),
_mm256_xor_si256(
_mm256_slli_epi32(self.avx[1], BY),
_mm256_srli_epi32(self.avx[1], REST),
),
],
}
}

#[inline]
#[must_use]
#[target_feature(enable = "avx2")]
unsafe fn rol_8(&mut self) {
self.avx = [
_mm256_shuffle_epi8(
self.avx[0],
_mm256_set_epi8(
14, 13, 12, 15, 10, 9, 8, 11, 6, 5, 4, 7, 2, 1, 0, 3, 14, 13, 12, 15, 10, 9, 8,
11, 6, 5, 4, 7, 2, 1, 0, 3,
unsafe fn rol_8(self) -> Self {
StateWord {
avx: [
_mm256_shuffle_epi8(
self.avx[0],
_mm256_set_epi8(
14, 13, 12, 15, 10, 9, 8, 11, 6, 5, 4, 7, 2, 1, 0, 3, 14, 13, 12, 15, 10,
9, 8, 11, 6, 5, 4, 7, 2, 1, 0, 3,
),
),
),
_mm256_shuffle_epi8(
self.avx[1],
_mm256_set_epi8(
14, 13, 12, 15, 10, 9, 8, 11, 6, 5, 4, 7, 2, 1, 0, 3, 14, 13, 12, 15, 10, 9, 8,
11, 6, 5, 4, 7, 2, 1, 0, 3,
_mm256_shuffle_epi8(
self.avx[1],
_mm256_set_epi8(
14, 13, 12, 15, 10, 9, 8, 11, 6, 5, 4, 7, 2, 1, 0, 3, 14, 13, 12, 15, 10,
9, 8, 11, 6, 5, 4, 7, 2, 1, 0, 3,
),
),
),
];
],
}
}

#[inline]
#[must_use]
#[target_feature(enable = "avx2")]
unsafe fn rol_16(&mut self) {
self.avx = [
_mm256_shuffle_epi8(
self.avx[0],
_mm256_set_epi8(
13, 12, 15, 14, 9, 8, 11, 10, 5, 4, 7, 6, 1, 0, 3, 2, 13, 12, 15, 14, 9, 8, 11,
10, 5, 4, 7, 6, 1, 0, 3, 2,
unsafe fn rol_16(self) -> Self {
StateWord {
avx: [
_mm256_shuffle_epi8(
self.avx[0],
_mm256_set_epi8(
13, 12, 15, 14, 9, 8, 11, 10, 5, 4, 7, 6, 1, 0, 3, 2, 13, 12, 15, 14, 9, 8,
11, 10, 5, 4, 7, 6, 1, 0, 3, 2,
),
),
),
_mm256_shuffle_epi8(
self.avx[1],
_mm256_set_epi8(
13, 12, 15, 14, 9, 8, 11, 10, 5, 4, 7, 6, 1, 0, 3, 2, 13, 12, 15, 14, 9, 8, 11,
10, 5, 4, 7, 6, 1, 0, 3, 2,
_mm256_shuffle_epi8(
self.avx[1],
_mm256_set_epi8(
13, 12, 15, 14, 9, 8, 11, 10, 5, 4, 7, 6, 1, 0, 3, 2, 13, 12, 15, 14, 9, 8,
11, 10, 5, 4, 7, 6, 1, 0, 3, 2,
),
),
),
];
],
}
}
}

/// Bundle of the four `StateWord`s (`a`, `b`, `c`, `d`) that the round
/// functions take and return by value, so that owned `StateWord`s can be
/// passed around together instead of through `&mut` borrows.
struct State {
a: StateWord,
b: StateWord,
c: StateWord,
d: StateWord,
}

/// The ChaCha20 core function (AVX2 accelerated implementation for x86/x86_64)
// TODO(tarcieri): zeroize?
#[derive(Clone)]
Expand Down Expand Up @@ -152,10 +177,14 @@ impl<R: Rounds> Core<R> {
#[inline]
pub fn generate(&self, counter: u64, output: &mut [u8]) {
unsafe {
let (mut v0, mut v1, mut v2) = (self.v0, self.v1, self.v2);
let mut v3 = iv_setup(self.iv, counter);
self.rounds(&mut v0, &mut v1, &mut v2, &mut v3);
store(v0, v1, v2, v3, output);
let state = State {
a: self.v0,
b: self.v1,
c: self.v2,
d: iv_setup(self.iv, counter),
};
let state = self.rounds(state);
store(state.a, state.b, state.c, state.d, output);
}
}

Expand All @@ -166,14 +195,22 @@ impl<R: Rounds> Core<R> {
debug_assert_eq!(output.len(), BUFFER_SIZE);

unsafe {
let (mut v0, mut v1, mut v2) = (self.v0, self.v1, self.v2);
let mut v3 = iv_setup(self.iv, counter);
self.rounds(&mut v0, &mut v1, &mut v2, &mut v3);
let state = State {
a: self.v0,
b: self.v1,
c: self.v2,
d: iv_setup(self.iv, counter),
};
let state = self.rounds(state);

for i in 0..BLOCKS {
for (chunk, a) in output[i * BLOCK_SIZE..(i + 1) * BLOCK_SIZE]
.chunks_mut(0x10)
.zip([v0, v1, v2, v3].iter().map(|s| s.blocks[i]))
.zip(
[state.a, state.b, state.c, state.d]
.iter()
.map(|s| s.blocks[i]),
)
{
let b = _mm_loadu_si128(chunk.as_ptr() as *const __m128i);
let out = _mm_xor_si128(a, b);
Expand All @@ -185,23 +222,19 @@ impl<R: Rounds> Core<R> {

#[inline]
#[target_feature(enable = "avx2")]
unsafe fn rounds(
&self,
v0: &mut StateWord,
v1: &mut StateWord,
v2: &mut StateWord,
v3: &mut StateWord,
) {
let v3_orig = *v3;
unsafe fn rounds(&self, mut state: State) -> State {
let d_orig = state.d;

for _ in 0..(R::COUNT / 2) {
double_quarter_round(v0, v1, v2, v3);
state = double_quarter_round(state);
}

v0.add_assign_epi32(&self.v0);
v1.add_assign_epi32(&self.v1);
v2.add_assign_epi32(&self.v2);
v3.add_assign_epi32(&v3_orig);
State {
a: state.a.add_epi32(self.v0),
b: state.b.add_epi32(self.v1),
c: state.c.add_epi32(self.v2),
d: state.d.add_epi32(d_orig),
}
}
}

Expand Down Expand Up @@ -264,16 +297,9 @@ unsafe fn store(v0: StateWord, v1: StateWord, v2: StateWord, v3: StateWord, outp

#[inline]
#[target_feature(enable = "avx2")]
unsafe fn double_quarter_round(
a: &mut StateWord,
b: &mut StateWord,
c: &mut StateWord,
d: &mut StateWord,
) {
add_xor_rot(a, b, c, d);
rows_to_cols(a, b, c, d);
add_xor_rot(a, b, c, d);
cols_to_rows(a, b, c, d);
unsafe fn double_quarter_round(state: State) -> State {
let state = add_xor_rot(state);
cols_to_rows(add_xor_rot(rows_to_cols(state)))
}

/// The goal of this function is to transform the state words from:
Expand Down Expand Up @@ -313,16 +339,18 @@ unsafe fn double_quarter_round(
/// - https://github.com/floodyberry/chacha-opt/blob/0ab65cb99f5016633b652edebaf3691ceb4ff753/chacha_blocks_ssse3-64.S#L639-L643
#[inline]
#[target_feature(enable = "avx2")]
unsafe fn rows_to_cols(
a: &mut StateWord,
_b: &mut StateWord,
c: &mut StateWord,
d: &mut StateWord,
) {
unsafe fn rows_to_cols(state: State) -> State {
// c = ROR256_B(c); d = ROR256_C(d); a = ROR256_D(a);
c.shuffle_epi32::<0b_00_11_10_01>(); // _MM_SHUFFLE(0, 3, 2, 1)
d.shuffle_epi32::<0b_01_00_11_10>(); // _MM_SHUFFLE(1, 0, 3, 2)
a.shuffle_epi32::<0b_10_01_00_11>(); // _MM_SHUFFLE(2, 1, 0, 3)
let c = state.c.shuffle_epi32::<0b_00_11_10_01>(); // _MM_SHUFFLE(0, 3, 2, 1)
let d = state.d.shuffle_epi32::<0b_01_00_11_10>(); // _MM_SHUFFLE(1, 0, 3, 2)
let a = state.a.shuffle_epi32::<0b_10_01_00_11>(); // _MM_SHUFFLE(2, 1, 0, 3)

State {
a,
b: state.b,
c,
d,
}
}

/// The goal of this function is to transform the state words from:
Expand All @@ -344,38 +372,38 @@ unsafe fn rows_to_cols(
/// reversing the transformation of [`rows_to_cols`].
#[inline]
#[target_feature(enable = "avx2")]
unsafe fn cols_to_rows(
a: &mut StateWord,
_b: &mut StateWord,
c: &mut StateWord,
d: &mut StateWord,
) {
unsafe fn cols_to_rows(state: State) -> State {
// c = ROR256_D(c); d = ROR256_C(d); a = ROR256_B(a);
c.shuffle_epi32::<0b_10_01_00_11>(); // _MM_SHUFFLE(2, 1, 0, 3)
d.shuffle_epi32::<0b_01_00_11_10>(); // _MM_SHUFFLE(1, 0, 3, 2)
a.shuffle_epi32::<0b_00_11_10_01>(); // _MM_SHUFFLE(0, 3, 2, 1)
let c = state.c.shuffle_epi32::<0b_10_01_00_11>(); // _MM_SHUFFLE(2, 1, 0, 3)
let d = state.d.shuffle_epi32::<0b_01_00_11_10>(); // _MM_SHUFFLE(1, 0, 3, 2)
let a = state.a.shuffle_epi32::<0b_00_11_10_01>(); // _MM_SHUFFLE(0, 3, 2, 1)

State {
a,
b: state.b,
c,
d,
}
}

#[inline]
#[target_feature(enable = "avx2")]
unsafe fn add_xor_rot(a: &mut StateWord, b: &mut StateWord, c: &mut StateWord, d: &mut StateWord) {
unsafe fn add_xor_rot(state: State) -> State {
// a = ADD256_32(a,b); d = XOR256(d,a); d = ROL256_16(d);
a.add_assign_epi32(b);
d.xor_assign(a);
d.rol_16();
let a = state.a.add_epi32(state.b);
let d = state.d.xor(a).rol_16();

// c = ADD256_32(c,d); b = XOR256(b,c); b = ROL256_12(b);
c.add_assign_epi32(d);
b.xor_assign(c);
b.rol::<12, 20>();
let c = state.c.add_epi32(d);
let b = state.b.xor(c).rol::<12, 20>();

// a = ADD256_32(a,b); d = XOR256(d,a); d = ROL256_8(d);
a.add_assign_epi32(b);
d.xor_assign(a);
d.rol_8();
let a = a.add_epi32(b);
let d = d.xor(a).rol_8();

// c = ADD256_32(c,d); b = XOR256(b,c); b = ROL256_7(b);
c.add_assign_epi32(d);
b.xor_assign(c);
b.rol::<7, 25>();
let c = c.add_epi32(d);
let b = b.xor(c).rol::<7, 25>();

State { a, b, c, d }
}

0 comments on commit 7c86911

Please sign in to comment.