From be1962878cca94326f16f2a74f81c156a6d84bc1 Mon Sep 17 00:00:00 2001 From: Nathan <148575555+nathan-artie@users.noreply.github.com> Date: Wed, 10 Apr 2024 17:42:31 -0700 Subject: [PATCH] Clean up `computeJitterUpperBoundMs` (#427) --- lib/jitter/sleep.go | 21 +++++++++++++++------ lib/jitter/sleep_test.go | 12 ++++++++++++ 2 files changed, 27 insertions(+), 6 deletions(-) diff --git a/lib/jitter/sleep.go b/lib/jitter/sleep.go index 58224ef88..e1f603c16 100644 --- a/lib/jitter/sleep.go +++ b/lib/jitter/sleep.go @@ -1,25 +1,34 @@ package jitter import ( + "math" "math/rand" "time" ) const DefaultMaxMs = 3500 +// safePowerOfTwo calculates 2 ** n without panicking for values of n below 0 or above 62. +func safePowerOfTwo(n int64) int64 { + if n < 0 { + return 0 + } else if n > 62 { + return math.MaxInt64 // 2 ** n will overflow + } + return 1 << n // equal to 2 ** n +} + // computeJitterUpperBoundMs calculates min(maxMs, baseMs * 2 ** attempt). func computeJitterUpperBoundMs(baseMs, maxMs, attempts int64) int64 { if maxMs <= 0 { return 0 } - // Check for overflows when computing base * 2 ** attempts. - // 2 ** x == 1 << x - if attemptsMaxMs := baseMs * (1 << attempts); attemptsMaxMs > 0 { - maxMs = min(maxMs, attemptsMaxMs) + powerOfTwo := safePowerOfTwo(attempts) + if powerOfTwo > math.MaxInt64/baseMs { // check for overflow + return maxMs } - - return maxMs + return min(maxMs, baseMs*powerOfTwo) } // Jitter implements exponential backoff + jitter. diff --git a/lib/jitter/sleep_test.go b/lib/jitter/sleep_test.go index f62201df3..d97c5e462 100644 --- a/lib/jitter/sleep_test.go +++ b/lib/jitter/sleep_test.go @@ -8,6 +8,18 @@ import ( "github.com/stretchr/testify/assert" ) +func TestSafePowerOfTwo(t *testing.T) { + assert.Equal(t, int64(0), safePowerOfTwo(-2)) + assert.Equal(t, int64(0), safePowerOfTwo(-1)) + assert.Equal(t, int64(1), safePowerOfTwo(0)) + assert.Equal(t, int64(2), safePowerOfTwo(1)) + assert.Equal(t, int64(4), safePowerOfTwo(2)) + assert.Equal(t, int64(4611686018427387904), safePowerOfTwo(62)) + assert.Equal(t, int64(math.MaxInt64), safePowerOfTwo(63)) + assert.Equal(t, int64(math.MaxInt64), safePowerOfTwo(64)) + assert.Equal(t, int64(math.MaxInt64), safePowerOfTwo(100)) +} + func TestComputeJitterUpperBoundMs(t *testing.T) { // A maxMs that is <= 0 returns 0. assert.Equal(t, int64(0), computeJitterUpperBoundMs(0, 0, 0))