Skip to content

Commit

Permalink
Merge pull request #4844 from onflow/tarak/improve-shuffle-test
Browse files Browse the repository at this point in the history
[Crypto] improve shuffle test
  • Loading branch information
tarakby authored Nov 8, 2023
2 parents 417ec4a + 9ce23e0 commit fe6714e
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 52 deletions.
2 changes: 1 addition & 1 deletion crypto/random/rand.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ func (p *genericPRG) Shuffle(n int, swap func(i, j int)) error {
return p.Samples(n, n, swap)
}

// Samples picks randomly m elements out of n elemnts and places them
// Samples picks randomly m elements out of n elements and places them
// in random order at indices [0,m-1], the swapping being implemented in place.
//
// It implements the first (m) elements of Fisher-Yates Shuffle using `p` as a source of randoms.
Expand Down
79 changes: 36 additions & 43 deletions crypto/random/rand_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,6 @@ func TestUintN(t *testing.T) {
require.NoError(t, err)

t.Run("basic uniformity", func(t *testing.T) {

maxN := uint64(1000)
mod := mrand.Uint64()
var n, classWidth uint64
Expand All @@ -123,7 +122,6 @@ func TestUintN(t *testing.T) {
return uint64(rng.UintN(mod)), nil
}
BasicDistributionTest(t, n, classWidth, uintNf)

})

t.Run("zero n", func(t *testing.T) {
Expand Down Expand Up @@ -228,52 +226,47 @@ func TestShuffle(t *testing.T) {
require.NoError(t, err)

t.Run("basic uniformity", func(t *testing.T) {
listSize := 100
sampleSize := 80000
// the distribution of a particular element of the list, testElement
distribution := make([]float64, listSize)
testElement := rand.Intn(listSize)
// Slice to shuffle
list := make([]int, 0, listSize)
for i := 0; i < listSize; i++ {
list = append(list, i)
}

shuffleAndCount := func(t *testing.T) {
err = rng.Shuffle(listSize, func(i, j int) {
list[i], list[j] = list[j], list[i]
})
require.NoError(t, err)
has := make(map[int]struct{})
for j, e := range list {
// check for repetition
_, ok := has[e]
require.False(t, ok, "duplicated item")
has[e] = struct{}{}
// fill the distribution
if e == testElement {
distribution[j] += 1.0
}
// compute n!
fact := func(n int) int {
f := 1
for i := 1; i <= n; i++ {
f *= i
}
return f
}

t.Run("shuffle a random permutation", func(t *testing.T) {
for k := 0; k < sampleSize; k++ {
shuffleAndCount(t)
}
EvaluateDistributionUniformity(t, distribution)
})

t.Run("shuffle a same permutation", func(t *testing.T) {
for k := 0; k < sampleSize; k++ {
// reinit the permutation to the same value
for listSize := 2; listSize <= 6; listSize++ {
factN := uint64(fact(listSize))
t.Logf("permutation size is %d (factorial is %d)", listSize, factN)
t.Run("shuffle a random permutation", func(t *testing.T) {
list := make([]int, 0, listSize)
for i := 0; i < listSize; i++ {
list[i] = i
list = append(list, i)
}
shuffleAndCount(t)
}
EvaluateDistributionUniformity(t, distribution)
})
permEncoding := func() (uint64, error) {
err = rng.Shuffle(listSize, func(i, j int) {
list[i], list[j] = list[j], list[i]
})
return uint64(EncodePermutation(list)), err
}
BasicDistributionTest(t, factN, 1, permEncoding)
})

t.Run("shuffle a same permutation", func(t *testing.T) {
list := make([]int, listSize)
permEncoding := func() (uint64, error) {
// reinit the permutation to the same value
for i := 0; i < listSize; i++ {
list[i] = i
}
err = rng.Shuffle(listSize, func(i, j int) {
list[i], list[j] = list[j], list[i]
})
return uint64(EncodePermutation(list)), err
}
BasicDistributionTest(t, factN, 1, permEncoding)
})
}
})

t.Run("empty slice", func(t *testing.T) {
Expand Down
46 changes: 38 additions & 8 deletions crypto/random/rand_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,26 @@ import (
"gonum.org/v1/gonum/stat"
)

// this constant should be increased if tests are flakey, but it the higher the constant
// the slower the test
const sampleSizeConstant = 85000
const sampleCoefficient = sampleSizeConstant / 85

// BasicDistributionTest is a test function to run a basic statistic test on `randf` output.
// `randf` is a function that outputs random integers.
// It partitions all outputs into `n` continuous classes and computes the distribution
// over the partition. Each class has a width of `classWidth`: first class is [0..classWidth-1],
// secons class is [classWidth..2*classWidth-1], etc..
// second class is [classWidth..2*classWidth-1], etc..
// It computes the frequency of outputs in the `n` classes and computes the
// standard deviation of frequencies. A small standard deviation is a necessary
// condition for a uniform distribution of `randf` (though is not a guarantee of
// condition for a uniform distribution of `randf` (but is not a guarantee of
// uniformity)
func BasicDistributionTest(t *testing.T, n uint64, classWidth uint64, randf func() (uint64, error)) {
// sample size should ideally be a high number multiple of `n`
// but if `n` is too small, we could use a small sample size so that the test
// isn't too slow
sampleSize := 1000 * n
if n < 100 {
sampleSize = (80000 / n) * n // highest multiple of n less than 80000
sampleSize := sampleCoefficient * n
if n < 80 {
// but if `n` is too small, we use a "high enough" sample size
sampleSize = ((sampleSizeConstant) / n) * n // highest multiple of n less than 80000
}
distribution := make([]float64, n)
// populate the distribution
Expand All @@ -39,7 +43,7 @@ func BasicDistributionTest(t *testing.T, n uint64, classWidth uint64, randf func
EvaluateDistributionUniformity(t, distribution)
}

// EvaluateDistributionUniformity evaluates if the input distribution is close to uinform
// EvaluateDistributionUniformity evaluates if the input distribution is close to uniform
// through a basic quick test.
// The test computes the standard deviation and checks it is small enough compared
// to the distribution mean.
Expand All @@ -49,3 +53,29 @@ func EvaluateDistributionUniformity(t *testing.T, distribution []float64) {
mean := stat.Mean(distribution, nil)
assert.Greater(t, tolerance*mean, stdev, fmt.Sprintf("basic randomness test failed: n: %d, stdev: %v, mean: %v", len(distribution), stdev, mean))
}

// computes a bijection from the set of all permutations
// into the the set [0, n!-1] (where `n` is the size of input `perm`).
// input `perm` is assumed to be a correct permutation of the set [0,n-1]
// (not checked in this function).
func EncodePermutation(perm []int) int {
r := make([]int, len(perm))
// generate Lehmer code
// (for details https://en.wikipedia.org/wiki/Lehmer_code)
for i, x := range perm {
for _, y := range perm[i+1:] {
if y < x {
r[i]++
}
}
}
// Convert to an integer following the factorial number system
// (for details https://en.wikipedia.org/wiki/Factorial_number_system)
m := 0
fact := 1
for i := len(perm) - 1; i >= 0; i-- {
m += r[i] * fact
fact *= len(perm) - i
}
return m
}

0 comments on commit fe6714e

Please sign in to comment.