diff --git a/src/bn/relic_bn_smb.c b/src/bn/relic_bn_smb.c index 54441cd20..7e0e610f9 100644 --- a/src/bn/relic_bn_smb.c +++ b/src/bn/relic_bn_smb.c @@ -79,19 +79,18 @@ int bn_smb_leg(const bn_t a, const bn_t b) { int bn_smb_jac(const bn_t a, const bn_t b) { dis_t ai, bi, ci, di; dig_t n, d, t; - bn_t r, t0, t1, t2, t3; + bn_t t0, t1, t2, t3; uint_t z, i, s = (RLC_DIG >> 1) - 2; bn_null(t0); bn_null(t1); bn_null(t2); bn_null(t3); - bn_null(r); /* Optimized Pornin's Algorithm by Aleksei Vambol from * https://github.com/privacy-scaling-explorations/halo2curves/pull/95 */ - /* Argument b must be odd. */ + /* Argument b must be odd for Jacobi symbol. */ if (bn_is_even(b) || bn_sign(b) == RLC_NEG) { RLC_THROW(ERR_NO_VALID); return 0; @@ -102,7 +101,6 @@ int bn_smb_jac(const bn_t a, const bn_t b) { bn_new(t1); bn_new(t2); bn_new(t3); - bn_new(r); bn_mod(t0, a, b); bn_copy(t1, b); @@ -111,8 +109,10 @@ int bn_smb_jac(const bn_t a, const bn_t b) { while (1) { ai = di = 1; bi = ci = 0; - + i = RLC_MAX(t0->used, t1->used); + dv_zero(t0->dp + t0->used, i - t0->used); + dv_zero(t1->dp + t1->used, i - t1->used); if (i == 1) { n = t0->dp[0]; d = t1->dp[0]; @@ -132,6 +132,7 @@ int bn_smb_jac(const bn_t a, const bn_t b) { } return (d == 1 ? 1 - (t & 2) : 0); } + z = RLC_MIN(arch_lzcnt(t0->dp[i - 1]), arch_lzcnt(t1->dp[i - 1])); n = t0->dp[i - 1] << z; d = t1->dp[i - 1] << z; @@ -167,6 +168,7 @@ int bn_smb_jac(const bn_t a, const bn_t b) { i -= z; } } + if (ai < 0) { bn_mul_dig(t2, t0, -ai); bn_neg(t2, t2); @@ -179,8 +181,7 @@ int bn_smb_jac(const bn_t a, const bn_t b) { } else { bn_mul_dig(t3, t1, bi); } - bn_add(r, t2, t3); - bn_rsh(r, r, s); + bn_add(t3, t3, t2); if (ci < 0) { bn_mul_dig(t2, t0, -ci); @@ -189,14 +190,14 @@ int bn_smb_jac(const bn_t a, const bn_t b) { bn_mul_dig(t2, t0, ci); } if (di < 0) { - bn_mul_dig(t3, t1, -di); - bn_neg(t3, t3); + bn_mul_dig(t1, t1, -di); + bn_neg(t1, t1); } else { - bn_mul_dig(t3, t1, di); + bn_mul_dig(t1, t1, di); } - bn_add(t1, t2, t3); + bn_add(t1, t1, t2); bn_rsh(t1, t1, s); - bn_copy(t0, r); + bn_rsh(t0, t3, s); if (bn_is_zero(t0)) { return (bn_cmp_dig(t1, 1) == RLC_EQ ? 1 - (t & 2) : 0); @@ -218,7 +219,6 @@ int bn_smb_jac(const bn_t a, const bn_t b) { bn_free(t1); bn_free(t2); bn_free(t3); - bn_free(r); } return t;