Skip to content

Commit

Permalink
optimise some of the bounds checks (#34)
Browse files Browse the repository at this point in the history
* optimise bounds checks

This brings a performance improvement of 40-100%,
making this implementation as fast as the C++ alternative in kagome.

Where possible, compiler is aided to optimise away the bounds checks without
any unsafe code. However, a fair amount of unsafe code was needed,
but it doesn't lower the security posture as the needed assertions
were already being made.

Signed-off-by: alindima <alin@parity.io>

* fix clippy

* switch to using safe optimisations

* revert some changes

---------

Signed-off-by: alindima <alin@parity.io>
  • Loading branch information
alindima authored Jan 11, 2024
1 parent be37510 commit 886be0e
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 36 deletions.
48 changes: 30 additions & 18 deletions reed-solomon-novelpoly/src/field/inc_afft.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,16 @@ pub struct AdditiveFFT {
}

/// Formal derivative of polynomial in the new?? basis
pub fn formal_derivative(cos: &mut [Additive], size: usize) {
for i in 1..size {
pub fn formal_derivative(cos: &mut [Additive]) {
for i in 1..cos.len() {
let length = ((i ^ (i - 1)) + 1) >> 1;
for j in (i - length)..i {
cos[j] ^= cos.get(j + length).copied().unwrap_or(Additive::ZERO);
}
}
let mut i = size;
let mut i = cos.len();
while i < FIELD_SIZE && i < cos.len() {
for j in 0..size {
for j in 0..cos.len() {
cos[j] ^= cos.get(j + i).copied().unwrap_or(Additive::ZERO);
}
i <<= 1;
Expand All @@ -32,9 +32,11 @@ pub fn formal_derivative(cos: &mut [Additive], size: usize) {

/// Formal derivative of polynomial in tweaked?? basis
#[allow(non_snake_case)]
pub fn tweaked_formal_derivative(codeword: &mut [Additive], n: usize) {
pub fn tweaked_formal_derivative(codeword: &mut [Additive]) {
#[cfg(b_is_not_one)]
let B = unsafe { &AFFT.B };
#[cfg(b_is_not_one)]
let n = codeword.len();

// We change nothing when multiplying by b from B.
#[cfg(b_is_not_one)]
Expand All @@ -44,7 +46,7 @@ pub fn tweaked_formal_derivative(codeword: &mut [Additive], n: usize) {
codeword[i + 1] = codeword[i + 1].mul(b);
}

formal_derivative(codeword, n);
formal_derivative(codeword);

// Again changes nothing by multiplying by b although b differs here.
#[cfg(b_is_not_one)]
Expand Down Expand Up @@ -86,21 +88,25 @@ fn b_is_one() {
// We're hunting for the differences and trying to undersrtand the algorithm.

/// Inverse additive FFT in the "novel polynomial basis"
#[inline(always)]
pub fn inverse_afft(data: &mut [Additive], size: usize, index: usize) {
unsafe { &AFFT }.inverse_afft(data, size, index)
}

#[cfg(all(target_feature = "avx", feature = "avx"))]
#[inline(always)]
pub fn inverse_afft_faster8(data: &mut [Additive], size: usize, index: usize) {
unsafe { &AFFT }.inverse_afft_faster8(data, size, index)
}

/// Additive FFT in the "novel polynomial basis"
#[inline(always)]
pub fn afft(data: &mut [Additive], size: usize, index: usize) {
unsafe { &AFFT }.afft(data, size, index)
}

#[cfg(all(target_feature = "avx", feature = "avx"))]
#[inline(always)]
/// Additive FFT in the "novel polynomial basis"
pub fn afft_faster8(data: &mut [Additive], size: usize, index: usize) {
unsafe { &AFFT }.afft_faster8(data, size, index)
Expand Down Expand Up @@ -141,6 +147,8 @@ impl AdditiveFFT {
// After this, we start at depth (i of Algorithm 2) = (k of Algorithm 2) - 1
// and progress through FIELD_BITS-1 steps, obtaining \Psi_\beta(0,0).
let mut depart_no = 1_usize;
assert!(data.len() >= size);

while depart_no < size {
// if depart_no >= 8 {
// println!("\n\n\nplain/Round depart_no={depart_no}");
Expand All @@ -167,20 +175,16 @@ impl AdditiveFFT {
// if depart_no >= 8 && false{
// data[i + depart_no] ^= dbg!(data[dbg!(i)]);
// } else {

// TODO: Optimising bounds checks on this line will yield a great performance improvement.
data[i + depart_no] ^= data[i];
// }
}

// Algorithm 2 indexs the skew factor in line 5 page 6288
// by i and \omega_{j 2^{i+1}}, but not by r explicitly.
// We further explore this confusion below. (TODO)
let skew =
// if depart_no >= 8 && false {
// dbg!(self.skews[j + index - 1])
// } else {
self.skews[j + index - 1]
// }
;
let skew = self.skews[j + index - 1];

// It's reasonale to skip the loop if skew is zero, but doing so with
// all bits set requires justification. (TODO)
if skew.0 != ONEMASK {
Expand All @@ -191,8 +195,9 @@ impl AdditiveFFT {
// if depart_no >= 8 && false{
// data[i] ^= dbg!(dbg!(data[dbg!(i + depart_no)]).mul(skew));
// } else {

// TODO: Optimising bounds checks on this line will yield a great performance improvement.
data[i] ^= data[i + depart_no].mul(skew);
// }
}
}

Expand Down Expand Up @@ -270,6 +275,8 @@ impl AdditiveFFT {
// After this, we start at depth (i of Algorithm 1) = (k of Algorithm 1) - 1
// and progress through FIELD_BITS-1 steps, obtaining \Psi_\beta(0,0).
let mut depart_no = size >> 1_usize;
assert!(data.len() >= size);

while depart_no > 0 {
// Agrees with for loop (j of Algorithm 1) in (0..2^{k-i-1}) from line 5,
// except we've j in (depart_no..size).step_by(2*depart_no), meaning
Expand All @@ -291,6 +298,7 @@ impl AdditiveFFT {
// we think r actually appears but the skew factor repeats itself
// like in (19) in the proof of Lemma 4. (TODO)
// We should understand the rest of this basis story, like (8) too. (TODO)

let skew = self.skews[j + index - 1];

// It's reasonale to skip the loop if skew is zero, but doing so with
Expand All @@ -300,6 +308,8 @@ impl AdditiveFFT {
for i in (j - depart_no)..j {
// Line 6, explained by (28) page 6287, but
// adding depart_no acts like the r+2^i superscript.

// TODO: Optimising bounds checks on this line will yield a great performance improvement.
data[i] ^= data[i + depart_no].mul(skew);
}
}
Expand All @@ -308,6 +318,8 @@ impl AdditiveFFT {
for i in (j - depart_no)..j {
// Line 7, explained by (31) page 6287, but
// adding depart_no acts like the r+2^i superscript.

// TODO: Optimising bounds checks on this line will yield a great performance improvement.
data[i + depart_no] ^= data[i];
}

Expand Down Expand Up @@ -484,7 +496,7 @@ pub mod test_utils {
let data = gen_plain::<R>(size);
gen_faster8_from_plain(data)
}

#[cfg(all(target_feature = "avx", feature = "avx"))]
pub fn assert_plain_eq_faster8(plain: impl AsRef<[Additive]>, faster8: impl AsRef<[Additive]>) {
let plain = plain.as_ref();
Expand All @@ -502,7 +514,7 @@ mod afft_tests {
use super::super::*;
use super::super::test_utils::*;
use rand::rngs::SmallRng;

#[cfg(all(target_feature = "avx", feature = "avx"))]
#[test]
fn afft_output_plain_eq_faster8_size_16() {
Expand Down Expand Up @@ -544,7 +556,7 @@ mod afft_tests {
println!(">>>>");
assert_plain_eq_faster8(data_plain, data_faster8);
}

#[cfg(all(target_feature = "avx", feature = "avx"))]
#[test]
fn afft_output_plain_eq_faster8_impulse_data() {
Expand Down
23 changes: 10 additions & 13 deletions reed-solomon-novelpoly/src/field/inc_encode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ pub fn encode_low(data: &[Additive], k: usize, codeword: &mut [Additive], n: usi
encode_low_plain(data, k, codeword, n);
}

#[cfg(not(target_feature = "avx"))]
#[cfg(not(all(target_feature = "avx", feature = "avx")))]
encode_low_plain(data, k, codeword, n);
}

Expand Down Expand Up @@ -37,12 +37,10 @@ pub fn encode_low_plain(data: &[Additive], k: usize, codeword: &mut [Additive],

for shift in (k..n).step_by(k) {
let codeword_at_shift = &mut codeword_skip_first_k[(shift - k)..shift];

// copy `M_topdash` to the position we are currently at, the n transform
codeword_at_shift.copy_from_slice(codeword_first_k);
// dbg!(&codeword_at_shift);
afft(codeword_at_shift, k, shift);
// let post = &codeword_at_shift;
// dbg!(post);
}

// restore `M` from the derived ones
Expand Down Expand Up @@ -79,11 +77,10 @@ pub fn encode_low_faster8(data: &[Additive], k: usize, codeword: &mut [Additive]

for shift in (k..n).step_by(k) {
let codeword_at_shift = &mut codeword_skip_first_k[(shift - k)..shift];

// copy `M_topdash` to the position we are currently at, the n transform
codeword_at_shift.copy_from_slice(codeword_first_k);

afft_faster8(codeword_at_shift, k, shift);
// let post = &codeword8x_at_shift;
}

// restore `M` from the derived ones
Expand All @@ -108,6 +105,8 @@ pub fn encode_high(data: &[Additive], k: usize, parity: &mut [Additive], mem: &m
//data: message array. parity: parity array. mem: buffer(size>= n-k)
//Encoding alg for k/n>0.5: parity is a power of two.
pub fn encode_high_plain(data: &[Additive], k: usize, parity: &mut [Additive], mem: &mut [Additive], n: usize) {
assert!(is_power_of_2(n));

let t: usize = n - k;

// mem_zero(&mut parity[0..t]);
Expand Down Expand Up @@ -158,7 +157,7 @@ pub fn encode_sub(bytes: &[u8], n: usize, k: usize) -> Result<Vec<Additive>> {
} else {
encode_sub_plain(bytes, n, k)
}
#[cfg(not(target_feature = "avx"))]
#[cfg(not(all(target_feature = "avx", feature = "avx")))]
encode_sub_plain(bytes, n, k)
}

Expand Down Expand Up @@ -194,13 +193,11 @@ pub fn encode_sub_plain(bytes: &[u8], n: usize, k: usize) -> Result<Vec<Additive
elm_data[i] = Additive(Elt::from_be_bytes([
bytes.get(2 * i).copied().unwrap_or_default(),
bytes.get(2 * i + 1).copied().unwrap_or_default(),
]))
]));
}

// update new data bytes with zero padded bytes
// `l` is now `GF(2^16)` symbols
let elm_len = elm_data.len();
assert_eq!(elm_len, n);

let mut codeword = elm_data.clone();
assert_eq!(codeword.len(), n);
Expand Down Expand Up @@ -243,9 +240,9 @@ pub fn encode_sub_faster8(bytes: &[u8], n: usize, k: usize) -> Result<Vec<Additi

for i in 0..((bytes_len + 1) / 2) {
elm_data[i] = Additive(Elt::from_be_bytes([
bytes.get(2 * i).map(|x| *x).unwrap_or_default(),
bytes.get(2 * i + 1).map(|x| *x).unwrap_or_default(),
]))
bytes.get(2 * i).copied().unwrap_or_default(),
bytes.get(2 * i + 1).copied().unwrap_or_default(),
]));
}

// update new data bytes with zero padded bytes
Expand Down
8 changes: 6 additions & 2 deletions reed-solomon-novelpoly/src/field/inc_log_mul.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ impl Additive {

/// Multiplicaiton friendly LOG form of f2e16
#[derive(Clone, Debug, Copy, Add, AddAssign, Sub, SubAssign, PartialEq, Eq)] // Default, PartialOrd,Ord
#[repr(transparent)]
pub struct Multiplier(pub Elt);

impl Multiplier {
Expand All @@ -81,13 +82,16 @@ impl std::fmt::Display for Multiplier {
/// Fast Walsh–Hadamard transform over modulo `ONEMASK`
#[inline(always)]
pub fn walsh(data: &mut [Multiplier], size: usize) {
#[cfg(all(target_feature = "avx", table_bootstrap_complete))]
#[cfg(all(target_feature = "avx", table_bootstrap_complete, feature = "avx"))]
walsh_faster8(data, size);
#[cfg(not(all(target_feature = "avx", table_bootstrap_complete)))]
#[cfg(not(all(target_feature = "avx", table_bootstrap_complete, feature = "avx")))]
walsh_plain(data, size);
}

#[inline(always)]
pub fn walsh_plain(data: &mut [Multiplier], size: usize) {
assert!(data.len() >= size);

let mask = ONEMASK as Wide;
let mut depart_no = 1_usize;
while depart_no < size {
Expand Down
8 changes: 6 additions & 2 deletions reed-solomon-novelpoly/src/field/inc_reconstruct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,13 @@ pub(crate) fn decode_main(
assert!(n >= recover_up_to);
assert_eq!(erasure.len(), n);

for i in 0..n {
for i in 0..codeword.len() {
codeword[i] = if erasure[i] { Additive(0) } else { codeword[i].mul(log_walsh2[i]) };
}

inverse_afft(codeword, n, 0);

tweaked_formal_derivative(codeword, n);
tweaked_formal_derivative(codeword);

afft(codeword, n, 0);

Expand All @@ -89,6 +89,10 @@ pub(crate) fn decode_main(
// since this has only to be called once per reconstruction
pub fn eval_error_polynomial(erasure: &[bool], log_walsh2: &mut [Multiplier], n: usize) {
let z = std::cmp::min(n, erasure.len());
assert!(z <= erasure.len());
assert!(n <= log_walsh2.len());
assert!(z <= log_walsh2.len());

for i in 0..z {
log_walsh2[i] = Multiplier(erasure[i] as Elt);
}
Expand Down
2 changes: 1 addition & 1 deletion reed-solomon-novelpoly/src/novel_poly_basis/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ impl CodeParams {
{
self.k >= (Additive8x::LANE << 1) && self.n % Additive8x::LANE == 0
}
#[cfg(not(target_feature = "avx"))]
#[cfg(not(all(target_feature = "avx", feature = "avx")))]
false
}

Expand Down

0 comments on commit 886be0e

Please sign in to comment.