Skip to content

Commit

Permalink
Tag vartime the bithacks that are not constant-time
Browse files Browse the repository at this point in the history
  • Loading branch information
mratsim committed Feb 6, 2022
1 parent 404a966 commit c02e6bd
Show file tree
Hide file tree
Showing 9 changed files with 45 additions and 45 deletions.
2 changes: 1 addition & 1 deletion constantine/arithmetic/bigints.nim
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ func bit*[bits: static int](a: BigInt[bits], index: int): Ct[uint8] =
## (b7, b6, b5, b4, b3, b2, b1, b0)
## for a 256-bit big-integer
## (b255, b254, ..., b1, b0)
const SlotShift = log2(WordBitWidth.uint32)
const SlotShift = log2_vartime(WordBitWidth.uint32)
const SelectMask = WordBitWidth - 1
const BitMask = One

Expand Down
6 changes: 3 additions & 3 deletions constantine/arithmetic/limbs.nim
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ func setUint*(a: var Limbs, n: SomeUnsignedInt) =
"in ", a.len, " limb of size ", sizeof(SecretWord), "."

a[0] = SecretWord(n) # Truncate the upper part
a[1] = SecretWord(n shr log2(sizeof(SecretWord)))
a[1] = SecretWord(n shr static(log2_vartime(sizeof(SecretWord))))
when a.len > 2:
zeroMem(a[2].addr, (a.len - 2) * sizeof(SecretWord))

Expand Down Expand Up @@ -340,8 +340,8 @@ func div10*(a: var Limbs): SecretWord =
## TODO constant-time
result = Zero

let clz = WordBitWidth - 1 - log2(10)
let norm10 = SecretWord(10) shl clz
const clz = WordBitWidth - 1 - log2_vartime(10)
const norm10 = SecretWord(10) shl clz

for i in countdown(a.len-1, 0):
# dividend = 2^64 * remainder + a[i]
Expand Down
2 changes: 1 addition & 1 deletion constantine/arithmetic/limbs_modular.nim
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ func csub(a: LimbsViewMut, b: LimbsViewAny, ctl: SecretBool, len: int): Borrow =
# ------------------------------------------------------------

func numWordsFromBits(bits: int): int {.inline.} =
const divShiftor = log2(uint32(WordBitWidth))
const divShiftor = log2_vartime(uint32(WordBitWidth))
result = (bits + WordBitWidth - 1) shr divShiftor

func shlAddMod_estimate(a: LimbsViewMut, aLen: int,
Expand Down
6 changes: 3 additions & 3 deletions constantine/config/precompute.nim
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ func checkOdd(M: BigInt) =

func checkValidModulus(M: BigInt) =
const expectedMsb = M.bits-1 - WordBitWidth * (M.limbs.len - 1)
let msb = log2(BaseType(M.limbs[^1]))
let msb = log2_vartime(BaseType(M.limbs[^1]))

doAssert msb == expectedMsb, "Internal Error: the modulus must use all declared bits and only those:\n" &
" Modulus '" & M.toHex() & "' is declared with " & $M.bits &
Expand All @@ -252,7 +252,7 @@ func countSpareBits*(M: BigInt): int =
## - [0, 8p) if 3 bits are available
## - ...
checkValidModulus(M)
let msb = log2(BaseType(M.limbs[^1]))
let msb = log2_vartime(BaseType(M.limbs[^1]))
result = WordBitWidth - 1 - msb.int

func invModBitwidth[T: SomeUnsignedInt](a: T): T =
Expand Down Expand Up @@ -280,7 +280,7 @@ func invModBitwidth[T: SomeUnsignedInt](a: T): T =
# which grows in O(log(log(a)))
checkOdd(a)

let k = log2(T.sizeof() * 8)
let k = log2_vartime(T.sizeof() * 8)
result = a # Start from an inverse of M0 modulo 2, M0 is odd and it's own inverse
for _ in 0 ..< k: # at each iteration we get the inverse mod(2^2k)
result *= 2 - a * result # x' = x(2 - ax)
Expand Down
6 changes: 3 additions & 3 deletions constantine/elliptic/ec_endomorphism_accel.nim
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ func buildLookupTable[M: static int, EC](
lut[0] = P
for u in 1'u32 ..< 1 shl (M-1):
# The recoding allows usage of 2^(n-1) table instead of the usual 2^n with NAF
let msb = u.log2() # No undefined, u != 0
let msb = u.log2_vartime() # No undefined, u != 0
lut[u].sum(lut[u.clearBit(msb)], endomorphisms[msb])

func tableIndex(glv: GLV_SAC, bit: int): SecretWord =
Expand Down Expand Up @@ -599,7 +599,7 @@ when isMainModule:
## per new entries
lut[0] = P
for u in 1'u32 ..< 1 shl (M-1):
let msb = u.log2() # No undefined, u != 0
let msb = u.log2_vartime() # No undefined, u != 0
lut[u] = lut[u.clearBit(msb)] & " + " & endomorphisms[msb]

proc main_lut() =
Expand Down Expand Up @@ -716,7 +716,7 @@ when isMainModule:
## per new entries
lut[0].incl P
for u in 1'u32 ..< 1 shl (M-1):
let msb = u.log2() # No undefined, u != 0
let msb = u.log2_vartime() # No undefined, u != 0
lut[u] = lut[u.clearBit(msb)] + {endomorphisms[msb]}


Expand Down
42 changes: 29 additions & 13 deletions constantine/primitives/bithacks.nim
Original file line number Diff line number Diff line change
Expand Up @@ -33,18 +33,19 @@
# and https://graphics.stanford.edu/%7Eseander/bithacks.html
# for compendiums of bit manipulation

proc clearMask[T: SomeInteger](v: T, mask: T): T {.inline.} =
func clearMask[T: SomeInteger](v: T, mask: T): T {.inline.} =
## Returns ``v``, with all the ``1`` bits from ``mask`` set to 0
v and not mask

proc clearBit*[T: SomeInteger](v: T, bit: T): T {.inline.} =
func clearBit*[T: SomeInteger](v: T, bit: T): T {.inline.} =
## Returns ``v``, with the bit at position ``bit`` set to 0
v.clearMask(1.T shl bit)

func log2Impl(x: uint32): uint32 =
func log2impl_vartime(x: uint32): uint32 =
## Find the log base 2 of a 32-bit or less integer.
## using De Bruijn multiplication
## Works at compile-time, guaranteed constant-time.
## Works at compile-time.
## ⚠️ not constant-time, table accesses are not uniform.
## TODO: at runtime BitScanReverse or CountLeadingZero are more efficient
# https://graphics.stanford.edu/%7Eseander/bithacks.html#IntegerLogDeBruijn
const lookup: array[32, uint8] = [0'u8, 9, 1, 10, 13, 21, 2, 29, 11, 14, 16, 18,
Expand All @@ -57,10 +58,11 @@ func log2Impl(x: uint32): uint32 =
v = v or v shr 16
lookup[(v * 0x07C4ACDD'u32) shr 27]

func log2Impl(x: uint64): uint64 {.inline, noSideEffect.} =
func log2impl_vartime(x: uint64): uint64 {.inline.} =
## Find the log base 2 of a 32-bit or less integer.
## using De Bruijn multiplication
## Works at compile-time, guaranteed constant-time.
## Works at compile-time.
## ⚠️ not constant-time, table accesses are not uniform.
## TODO: at runtime BitScanReverse or CountLeadingZero are more efficient
# https://graphics.stanford.edu/%7Eseander/bithacks.html#IntegerLogDeBruijn
const lookup: array[64, uint8] = [0'u8, 58, 1, 59, 47, 53, 2, 60, 39, 48, 27, 54,
Expand All @@ -76,27 +78,41 @@ func log2Impl(x: uint64): uint64 {.inline, noSideEffect.} =
v = v or v shr 32
lookup[(v * 0x03F6EAF2CD271461'u64) shr 58]

func log2*[T: SomeUnsignedInt](n: T): T =
func log2_vartime*[T: SomeUnsignedInt](n: T): T {.inline.} =
## Find the log base 2 of an integer
when sizeof(T) == sizeof(uint64):
T(log2Impl(uint64(n)))
T(log2impl_vartime(uint64(n)))
else:
static: doAssert sizeof(T) <= sizeof(uint32)
T(log2Impl(uint32(n)))
T(log2impl_vartime(uint32(n)))

func hammingWeight*(x: uint32): int {.inline.} =
func hammingWeight*(x: uint32): uint {.inline.} =
## Counts the set bits in integer.
# https://graphics.stanford.edu/~seander/bithacks.html#CountBitsSetParallel
var v = x
v = v - ((v shr 1) and 0x55555555)
v = (v and 0x33333333) + ((v shr 2) and 0x33333333)
cast[int](((v + (v shr 4) and 0xF0F0F0F) * 0x1010101) shr 24)
uint(((v + (v shr 4) and 0xF0F0F0F) * 0x1010101) shr 24)

func hammingWeight*(x: uint64): int {.inline.} =
func hammingWeight*(x: uint64): uint {.inline.} =
## Counts the set bits in integer.
# https://graphics.stanford.edu/~seander/bithacks.html#CountBitsSetParallel
var v = x
v = v - ((v shr 1'u64) and 0x5555555555555555'u64)
v = (v and 0x3333333333333333'u64) + ((v shr 2'u64) and 0x3333333333333333'u64)
v = (v + (v shr 4'u64) and 0x0F0F0F0F0F0F0F0F'u64)
cast[int]((v * 0x0101010101010101'u64) shr 56'u64)
uint((v * 0x0101010101010101'u64) shr 56'u64)

func countLeadingZeros_vartime*[T: SomeUnsignedInt](x: T): T {.inline.} =
(8*sizeof(T)) - 1 - log2_vartime(x)

func isPowerOf2_vartime*(n: SomeUnsignedInt): bool {.inline.} =
## Returns true if n is a power of 2
## ⚠️ Result is bool instead of Secretbool,
## for compile-time or explicit vartime proc only.
(n and (n - 1)) == 0

func nextPowerOf2_vartime*(n: uint64): uint64 {.inline.} =
## Returns x if x is a power of 2
## or the next biggest power of 2
1'u64 shl (log2_vartime(n-1) + 1)
2 changes: 1 addition & 1 deletion docs/optimizations.md
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ The optimizations can be of algebraic, algorithmic or "implementation details" n
- Inversion (constant-time baseline, Little-Fermat inversion via a^(p-2))
- [x] Constant-time binary GCD algorithm by Möller, algorithm 5 in https://link.springer.com/content/pdf/10.1007%2F978-3-642-40588-4_10.pdf
- [x] Addition-chain for a^(p-2)
- [ ] Constant-time binary GCD algorithm by Bernstein-Young, https://eprint.iacr.org/2019/266
- [ ] Constant-time binary GCD algorithm by Bernstein-Yang, https://eprint.iacr.org/2019/266
- [ ] Constant-time binary GCD algorithm by Pornin, https://eprint.iacr.org/2020/972
- [ ] Simultaneous inversion

Expand Down
12 changes: 2 additions & 10 deletions research/kzg_poly_commit/fft_fr.nim
Original file line number Diff line number Diff line change
Expand Up @@ -64,14 +64,6 @@ type
expandedRootsOfUnity: seq[F]
## domain, starting and ending with 1

func isPowerOf2(n: SomeUnsignedInt): bool =
(n and (n - 1)) == 0

func nextPowerOf2(n: uint64): uint64 =
## Returns x if x is a power of 2
## or the next biggest power of 2
1'u64 shl (log2(n-1) + 1)

func expandRootOfUnity[F](rootOfUnity: F): seq[F] =
## From a generator root of unity
## expand to width + 1 values.
Expand Down Expand Up @@ -145,7 +137,7 @@ func fft*[F](
vals: openarray[F]): FFT_Status =
if vals.len > desc.maxWidth:
return FFTS_TooManyValues
if not vals.len.uint64.isPowerOf2():
if not vals.len.uint64.isPowerOf2_vartime():
return FFTS_SizeNotPowerOfTwo

let rootz = desc.expandedRootsOfUnity
Expand All @@ -163,7 +155,7 @@ func ifft*[F](
## Inverse FFT
if vals.len > desc.maxWidth:
return FFTS_TooManyValues
if not vals.len.uint64.isPowerOf2():
if not vals.len.uint64.isPowerOf2_vartime():
return FFTS_SizeNotPowerOfTwo

let rootz = desc.expandedRootsOfUnity
Expand Down
12 changes: 2 additions & 10 deletions research/kzg_poly_commit/fft_g1.nim
Original file line number Diff line number Diff line change
Expand Up @@ -70,14 +70,6 @@ type
expandedRootsOfUnity: seq[matchingOrderBigInt(EC.F.C)]
## domain, starting and ending with 1

func isPowerOf2(n: SomeUnsignedInt): bool =
(n and (n - 1)) == 0

func nextPowerOf2(n: uint64): uint64 =
## Returns x if x is a power of 2
## or the next biggest power of 2
1'u64 shl (log2(n-1) + 1)

func expandRootOfUnity[F](rootOfUnity: F): auto {.noInit.} =
## From a generator root of unity
## expand to width + 1 values.
Expand Down Expand Up @@ -157,7 +149,7 @@ func fft*[EC](
vals: openarray[EC]): FFT_Status =
if vals.len > desc.maxWidth:
return FFTS_TooManyValues
if not vals.len.uint64.isPowerOf2():
if not vals.len.uint64.isPowerOf2_vartime():
return FFTS_SizeNotPowerOfTwo

let rootz = desc.expandedRootsOfUnity
Expand All @@ -175,7 +167,7 @@ func ifft*[EC](
## Inverse FFT
if vals.len > desc.maxWidth:
return FFTS_TooManyValues
if not vals.len.uint64.isPowerOf2():
if not vals.len.uint64.isPowerOf2_vartime():
return FFTS_SizeNotPowerOfTwo

let rootz = desc.expandedRootsOfUnity
Expand Down

0 comments on commit c02e6bd

Please sign in to comment.