From 7c86911857e2568a1cada95ace5e300c89fae129 Mon Sep 17 00:00:00 2001 From: str4d Date: Sun, 29 Aug 2021 04:13:58 +0100 Subject: [PATCH] chacha20: Remove mutable borrows from AVX2 backend (#268) The use of `&mut StateWord` everywhere caused a `vmovdqa` to be inserted after almost every operation, and also caused the diagonalization to use `vpermilps` instead of seeing the optimisation to `vpshufd`. The new `State` struct helps to manage the passing-around of owned `StateWord`s. --- chacha20/src/backend/avx2.rs | 264 +++++++++++++++++++---------------- 1 file changed, 146 insertions(+), 118 deletions(-) diff --git a/chacha20/src/backend/avx2.rs b/chacha20/src/backend/avx2.rs index 94b6908f..d31daad2 100644 --- a/chacha20/src/backend/avx2.rs +++ b/chacha20/src/backend/avx2.rs @@ -35,90 +35,115 @@ union StateWord { impl StateWord { #[inline] + #[must_use] #[target_feature(enable = "avx2")] - unsafe fn add_assign_epi32(&mut self, rhs: &Self) { - self.avx = [ - _mm256_add_epi32(self.avx[0], rhs.avx[0]), - _mm256_add_epi32(self.avx[1], rhs.avx[1]), - ]; + unsafe fn add_epi32(self, rhs: Self) -> Self { + StateWord { + avx: [ + _mm256_add_epi32(self.avx[0], rhs.avx[0]), + _mm256_add_epi32(self.avx[1], rhs.avx[1]), + ], + } } #[inline] + #[must_use] #[target_feature(enable = "avx2")] - unsafe fn xor_assign(&mut self, rhs: &Self) { - self.avx = [ - _mm256_xor_si256(self.avx[0], rhs.avx[0]), - _mm256_xor_si256(self.avx[1], rhs.avx[1]), - ]; + unsafe fn xor(self, rhs: Self) -> Self { + StateWord { + avx: [ + _mm256_xor_si256(self.avx[0], rhs.avx[0]), + _mm256_xor_si256(self.avx[1], rhs.avx[1]), + ], + } } #[inline] + #[must_use] #[target_feature(enable = "avx2")] - unsafe fn shuffle_epi32(&mut self) { - self.avx = [ - _mm256_shuffle_epi32(self.avx[0], MASK), - _mm256_shuffle_epi32(self.avx[1], MASK), - ]; + unsafe fn shuffle_epi32(self) -> Self { + StateWord { + avx: [ + _mm256_shuffle_epi32(self.avx[0], MASK), + _mm256_shuffle_epi32(self.avx[1], MASK), + ], + } } #[inline] + #[must_use] #[target_feature(enable = "avx2")] - unsafe fn rol(&mut self) { - self.avx = [ - _mm256_xor_si256( - _mm256_slli_epi32(self.avx[0], BY), - _mm256_srli_epi32(self.avx[0], REST), - ), - _mm256_xor_si256( - _mm256_slli_epi32(self.avx[1], BY), - _mm256_srli_epi32(self.avx[1], REST), - ), - ]; + unsafe fn rol(self) -> Self { + StateWord { + avx: [ + _mm256_xor_si256( + _mm256_slli_epi32(self.avx[0], BY), + _mm256_srli_epi32(self.avx[0], REST), + ), + _mm256_xor_si256( + _mm256_slli_epi32(self.avx[1], BY), + _mm256_srli_epi32(self.avx[1], REST), + ), + ], + } } #[inline] + #[must_use] #[target_feature(enable = "avx2")] - unsafe fn rol_8(&mut self) { - self.avx = [ - _mm256_shuffle_epi8( - self.avx[0], - _mm256_set_epi8( - 14, 13, 12, 15, 10, 9, 8, 11, 6, 5, 4, 7, 2, 1, 0, 3, 14, 13, 12, 15, 10, 9, 8, - 11, 6, 5, 4, 7, 2, 1, 0, 3, + unsafe fn rol_8(self) -> Self { + StateWord { + avx: [ + _mm256_shuffle_epi8( + self.avx[0], + _mm256_set_epi8( + 14, 13, 12, 15, 10, 9, 8, 11, 6, 5, 4, 7, 2, 1, 0, 3, 14, 13, 12, 15, 10, + 9, 8, 11, 6, 5, 4, 7, 2, 1, 0, 3, + ), ), - ), - _mm256_shuffle_epi8( - self.avx[1], - _mm256_set_epi8( - 14, 13, 12, 15, 10, 9, 8, 11, 6, 5, 4, 7, 2, 1, 0, 3, 14, 13, 12, 15, 10, 9, 8, - 11, 6, 5, 4, 7, 2, 1, 0, 3, + _mm256_shuffle_epi8( + self.avx[1], + _mm256_set_epi8( + 14, 13, 12, 15, 10, 9, 8, 11, 6, 5, 4, 7, 2, 1, 0, 3, 14, 13, 12, 15, 10, + 9, 8, 11, 6, 5, 4, 7, 2, 1, 0, 3, + ), ), - ), - ]; + ], + } } #[inline] + #[must_use] #[target_feature(enable = "avx2")] - unsafe fn rol_16(&mut self) { - self.avx = [ - _mm256_shuffle_epi8( - self.avx[0], - _mm256_set_epi8( - 13, 12, 15, 14, 9, 8, 11, 10, 5, 4, 7, 6, 1, 0, 3, 2, 13, 12, 15, 14, 9, 8, 11, - 10, 5, 4, 7, 6, 1, 0, 3, 2, + unsafe fn rol_16(self) -> Self { + StateWord { + avx: [ + _mm256_shuffle_epi8( + self.avx[0], + _mm256_set_epi8( + 13, 12, 15, 14, 9, 8, 11, 10, 5, 4, 7, 6, 1, 0, 3, 2, 13, 12, 15, 14, 9, 8, + 11, 10, 5, 4, 7, 6, 1, 0, 3, 2, + ), ), - ), - _mm256_shuffle_epi8( - self.avx[1], - _mm256_set_epi8( - 13, 12, 15, 14, 9, 8, 11, 10, 5, 4, 7, 6, 1, 0, 3, 2, 13, 12, 15, 14, 9, 8, 11, - 10, 5, 4, 7, 6, 1, 0, 3, 2, + _mm256_shuffle_epi8( + self.avx[1], + _mm256_set_epi8( + 13, 12, 15, 14, 9, 8, 11, 10, 5, 4, 7, 6, 1, 0, 3, 2, 13, 12, 15, 14, 9, 8, + 11, 10, 5, 4, 7, 6, 1, 0, 3, 2, + ), ), - ), - ]; + ], + } } } +struct State { + a: StateWord, + b: StateWord, + c: StateWord, + d: StateWord, +} + /// The ChaCha20 core function (AVX2 accelerated implementation for x86/x86_64) // TODO(tarcieri): zeroize? #[derive(Clone)] @@ -152,10 +177,14 @@ impl Core { #[inline] pub fn generate(&self, counter: u64, output: &mut [u8]) { unsafe { - let (mut v0, mut v1, mut v2) = (self.v0, self.v1, self.v2); - let mut v3 = iv_setup(self.iv, counter); - self.rounds(&mut v0, &mut v1, &mut v2, &mut v3); - store(v0, v1, v2, v3, output); + let state = State { + a: self.v0, + b: self.v1, + c: self.v2, + d: iv_setup(self.iv, counter), + }; + let state = self.rounds(state); + store(state.a, state.b, state.c, state.d, output); } } @@ -166,14 +195,22 @@ impl Core { debug_assert_eq!(output.len(), BUFFER_SIZE); unsafe { - let (mut v0, mut v1, mut v2) = (self.v0, self.v1, self.v2); - let mut v3 = iv_setup(self.iv, counter); - self.rounds(&mut v0, &mut v1, &mut v2, &mut v3); + let state = State { + a: self.v0, + b: self.v1, + c: self.v2, + d: iv_setup(self.iv, counter), + }; + let state = self.rounds(state); for i in 0..BLOCKS { for (chunk, a) in output[i * BLOCK_SIZE..(i + 1) * BLOCK_SIZE] .chunks_mut(0x10) - .zip([v0, v1, v2, v3].iter().map(|s| s.blocks[i])) + .zip( + [state.a, state.b, state.c, state.d] + .iter() + .map(|s| s.blocks[i]), + ) { let b = _mm_loadu_si128(chunk.as_ptr() as *const __m128i); let out = _mm_xor_si128(a, b); @@ -185,23 +222,19 @@ impl Core { #[inline] #[target_feature(enable = "avx2")] - unsafe fn rounds( - &self, - v0: &mut StateWord, - v1: &mut StateWord, - v2: &mut StateWord, - v3: &mut StateWord, - ) { - let v3_orig = *v3; + unsafe fn rounds(&self, mut state: State) -> State { + let d_orig = state.d; for _ in 0..(R::COUNT / 2) { - double_quarter_round(v0, v1, v2, v3); + state = double_quarter_round(state); } - v0.add_assign_epi32(&self.v0); - v1.add_assign_epi32(&self.v1); - v2.add_assign_epi32(&self.v2); - v3.add_assign_epi32(&v3_orig); + State { + a: state.a.add_epi32(self.v0), + b: state.b.add_epi32(self.v1), + c: state.c.add_epi32(self.v2), + d: state.d.add_epi32(d_orig), + } } } @@ -264,16 +297,9 @@ unsafe fn store(v0: StateWord, v1: StateWord, v2: StateWord, v3: StateWord, outp #[inline] #[target_feature(enable = "avx2")] -unsafe fn double_quarter_round( - a: &mut StateWord, - b: &mut StateWord, - c: &mut StateWord, - d: &mut StateWord, -) { - add_xor_rot(a, b, c, d); - rows_to_cols(a, b, c, d); - add_xor_rot(a, b, c, d); - cols_to_rows(a, b, c, d); +unsafe fn double_quarter_round(state: State) -> State { + let state = add_xor_rot(state); + cols_to_rows(add_xor_rot(rows_to_cols(state))) } /// The goal of this function is to transform the state words from: @@ -313,16 +339,18 @@ unsafe fn double_quarter_round( /// - https://github.com/floodyberry/chacha-opt/blob/0ab65cb99f5016633b652edebaf3691ceb4ff753/chacha_blocks_ssse3-64.S#L639-L643 #[inline] #[target_feature(enable = "avx2")] -unsafe fn rows_to_cols( - a: &mut StateWord, - _b: &mut StateWord, - c: &mut StateWord, - d: &mut StateWord, -) { +unsafe fn rows_to_cols(state: State) -> State { // c = ROR256_B(c); d = ROR256_C(d); a = ROR256_D(a); - c.shuffle_epi32::<0b_00_11_10_01>(); // _MM_SHUFFLE(0, 3, 2, 1) - d.shuffle_epi32::<0b_01_00_11_10>(); // _MM_SHUFFLE(1, 0, 3, 2) - a.shuffle_epi32::<0b_10_01_00_11>(); // _MM_SHUFFLE(2, 1, 0, 3) + let c = state.c.shuffle_epi32::<0b_00_11_10_01>(); // _MM_SHUFFLE(0, 3, 2, 1) + let d = state.d.shuffle_epi32::<0b_01_00_11_10>(); // _MM_SHUFFLE(1, 0, 3, 2) + let a = state.a.shuffle_epi32::<0b_10_01_00_11>(); // _MM_SHUFFLE(2, 1, 0, 3) + + State { + a, + b: state.b, + c, + d, + } } /// The goal of this function is to transform the state words from: @@ -344,38 +372,38 @@ unsafe fn rows_to_cols( /// reversing the transformation of [`rows_to_cols`]. #[inline] #[target_feature(enable = "avx2")] -unsafe fn cols_to_rows( - a: &mut StateWord, - _b: &mut StateWord, - c: &mut StateWord, - d: &mut StateWord, -) { +unsafe fn cols_to_rows(state: State) -> State { // c = ROR256_D(c); d = ROR256_C(d); a = ROR256_B(a); - c.shuffle_epi32::<0b_10_01_00_11>(); // _MM_SHUFFLE(2, 1, 0, 3) - d.shuffle_epi32::<0b_01_00_11_10>(); // _MM_SHUFFLE(1, 0, 3, 2) - a.shuffle_epi32::<0b_00_11_10_01>(); // _MM_SHUFFLE(0, 3, 2, 1) + let c = state.c.shuffle_epi32::<0b_10_01_00_11>(); // _MM_SHUFFLE(2, 1, 0, 3) + let d = state.d.shuffle_epi32::<0b_01_00_11_10>(); // _MM_SHUFFLE(1, 0, 3, 2) + let a = state.a.shuffle_epi32::<0b_00_11_10_01>(); // _MM_SHUFFLE(0, 3, 2, 1) + + State { + a, + b: state.b, + c, + d, + } } #[inline] #[target_feature(enable = "avx2")] -unsafe fn add_xor_rot(a: &mut StateWord, b: &mut StateWord, c: &mut StateWord, d: &mut StateWord) { +unsafe fn add_xor_rot(state: State) -> State { // a = ADD256_32(a,b); d = XOR256(d,a); d = ROL256_16(d); - a.add_assign_epi32(b); - d.xor_assign(a); - d.rol_16(); + let a = state.a.add_epi32(state.b); + let d = state.d.xor(a).rol_16(); // c = ADD256_32(c,d); b = XOR256(b,c); b = ROL256_12(b); - c.add_assign_epi32(d); - b.xor_assign(c); - b.rol::<12, 20>(); + let c = state.c.add_epi32(d); + let b = state.b.xor(c).rol::<12, 20>(); // a = ADD256_32(a,b); d = XOR256(d,a); d = ROL256_8(d); - a.add_assign_epi32(b); - d.xor_assign(a); - d.rol_8(); + let a = a.add_epi32(b); + let d = d.xor(a).rol_8(); // c = ADD256_32(c,d); b = XOR256(b,c); b = ROL256_7(b); - c.add_assign_epi32(d); - b.xor_assign(c); - b.rol::<7, 25>(); + let c = c.add_epi32(d); + let b = b.xor(c).rol::<7, 25>(); + + State { a, b, c, d } }