Skip to content

Commit

Permalink
pool: Fix unsafe randomizers (#643)
Browse files Browse the repository at this point in the history
  • Loading branch information
cthulhu-rider authored Dec 5, 2024
2 parents 591dd25 + fd1e188 commit 335d9fe
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 16 deletions.
9 changes: 3 additions & 6 deletions pool/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"context"
"errors"
"fmt"
"math/rand"
"sort"
"sync"
"sync/atomic"
Expand Down Expand Up @@ -769,8 +768,7 @@ func (p *Pool) Dial(ctx context.Context) error {

atLeastOneHealthy = true
}
source := rand.NewSource(time.Now().UnixNano())
sampl := newSampler(params.weights, source)
sampl := newSampler(params.weights, safeRand{})

inner[i] = &innerPool{
sampler: sampl,
Expand Down Expand Up @@ -940,9 +938,8 @@ func (p *Pool) updateInnerNodesHealth(ctx context.Context, i int, bufferWeights

if healthyChanged.Load() {
probabilities := adjustWeights(bufferWeights)
source := rand.NewSource(time.Now().UnixNano())
pool.lock.Lock()
pool.sampler = newSampler(probabilities, source)
pool.sampler = newSampler(probabilities, safeRand{})
pool.lock.Unlock()
}
}
Expand Down Expand Up @@ -985,7 +982,7 @@ func (p *innerPool) connection() (internalClient, error) {
}
attempts := 3 * len(p.clients)
for range attempts {
i := p.sampler.Next()
i := p.sampler.next()
if cp := p.clients[i]; cp.isHealthy() {
return cp, nil
}
Expand Down
22 changes: 17 additions & 5 deletions pool/sampler.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,37 @@ package pool

import "math/rand"

// [rand.Rand] interface.
type rander interface {
Intn(n int) int
Float64() float64
}

// replacement of [rand.Rand] safe for concurrent use.
type safeRand struct{}

func (safeRand) Intn(n int) int { return rand.Intn(n) }
func (safeRand) Float64() float64 { return rand.Float64() }

// sampler implements weighted random number generation using Vose's Alias
// Method (https://www.keithschwarz.com/darts-dice-coins/).
type sampler struct {
randomGenerator *rand.Rand
randomGenerator rander
probabilities []float64
alias []int
}

// newSampler creates new sampler with a given set of probabilities using
// given source of randomness. Created sampler will produce numbers from
// 0 to len(probabilities).
func newSampler(probabilities []float64, source rand.Source) *sampler {
func newSampler(probabilities []float64, r rander) *sampler {
sampler := &sampler{}
var (
small workList
large workList
)
n := len(probabilities)
sampler.randomGenerator = rand.New(source)
sampler.randomGenerator = r
sampler.probabilities = make([]float64, n)
sampler.alias = make([]int, n)
// Compute scaled probabilities.
Expand Down Expand Up @@ -57,8 +69,8 @@ func newSampler(probabilities []float64, source rand.Source) *sampler {
return sampler
}

// Next returns the next (not so) random number from sampler.
func (g *sampler) Next() int {
// returns the next (not so) random number from sampler.
func (g *sampler) next() int {
n := len(g.alias)
i := g.randomGenerator.Intn(n)
if g.randomGenerator.Float64() < g.probabilities[i] {
Expand Down
57 changes: 52 additions & 5 deletions pool/sampler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@ package pool

import (
"context"
"fmt"
"math/rand"
"runtime/debug"
"sync"
"sync/atomic"
"testing"

neofscryptotest "github.com/nspcc-dev/neofs-sdk-go/crypto/test"
Expand Down Expand Up @@ -32,10 +36,10 @@ func TestSamplerStability(t *testing.T) {
}

for _, tc := range cases {
sampler := newSampler(tc.probabilities, rand.NewSource(0))
sampler := newSampler(tc.probabilities, rand.New(rand.NewSource(0)))
res := make([]int, len(tc.probabilities))
for range COUNT {
res[sampler.Next()]++
res[sampler.next()]++
}

require.Equal(t, tc.expected, res, "probabilities: %v", tc.probabilities)
Expand All @@ -58,7 +62,7 @@ func TestHealthyReweight(t *testing.T) {
client2 := newMockClient(names[1], neofscryptotest.Signer())

inner := &innerPool{
sampler: newSampler(weights, rand.NewSource(0)),
sampler: newSampler(weights, rand.New(rand.NewSource(0))),
clients: []internalClient{client1, client2},
}
p := &Pool{
Expand Down Expand Up @@ -87,7 +91,7 @@ func TestHealthyReweight(t *testing.T) {
inner.lock.Unlock()

p.updateInnerNodesHealth(context.TODO(), 0, buffer)
inner.sampler = newSampler(weights, rand.NewSource(0))
inner.sampler = newSampler(weights, rand.New(rand.NewSource(0)))

connection0, err = p.connection()
require.NoError(t, err)
Expand All @@ -102,7 +106,7 @@ func TestHealthyNoReweight(t *testing.T) {
buffer = make([]float64, len(weights))
)

sampl := newSampler(weights, rand.NewSource(0))
sampl := newSampler(weights, rand.New(rand.NewSource(0)))
inner := &innerPool{
sampler: sampl,
clients: []internalClient{
Expand All @@ -121,3 +125,46 @@ func TestHealthyNoReweight(t *testing.T) {
defer inner.lock.RUnlock()
require.Equal(t, inner.sampler, sampl)
}

func TestSamplerSafety(t *testing.T) {
// https://github.com/nspcc-dev/neofs-sdk-go/issues/631
// Note that this test is not 100% consistent so it may PASS, but it FAILs more
// often when bugs.
type panicInfo = struct {
cause any
stack []byte
}
s := newSampler([]float64{0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1}, safeRand{})
var pr atomic.Value // panicInfo
var wg sync.WaitGroup
for range 1000 {
const rn = 1000
wg.Add(rn)
for range rn {
go func() {
defer func() {
wg.Done()
// [require.NotPanics] should be called in a test func routine, so we simulate it
if r := recover(); r != nil {
// in theory, various causes may happen. With this, only the "last" one is
// caught. In practice, we are chasing the exact one.
pr.Store(panicInfo{r, debug.Stack()})
}
}()
s.next()
}()
}
wg.Wait()
if v := pr.Load(); v != nil {
p := v.(panicInfo)
require.Fail(t, fmt.Sprintf("should not panic\n\tPanic value:\t%v\n\tPanic stack:\t%s", p.cause, p.stack))
}
}
}

func BenchmarkSampler(b *testing.B) {
s := newSampler([]float64{0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1}, safeRand{})
for range b.N {
s.next()
}
}

0 comments on commit 335d9fe

Please sign in to comment.