-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathcounters.go
64 lines (52 loc) · 1.41 KB
/
counters.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
package bandit
import (
"fmt"
"math/rand"
"sync"
"time"
)
// Counters holds per-arm bookkeeping for a bandit strategy: how many times
// each arm has been pulled and the running average reward observed for it.
// Construct one with NewCounters; the zero value has no arms.
type Counters struct {
	sync.Mutex // Update and Init hold this while they mutate the fields below

	arms   int        // number of arms present in this strategy; set by NewCounters
	counts []int      // number of pulls per arm, indexed arm-1; len(counts) == arms
	rand   *rand.Rand // random source seeded in NewCounters; not read in this file
	values []float64  // running average reward per arm, indexed arm-1; len(values) == arms
}

// NewCounters returns Counters sized for the given number of arms. Every pull
// count and running average starts at zero, and the random source is seeded
// from the current wall-clock time in nanoseconds. A negative arms value
// panics in make.
func NewCounters(arms int) Counters {
	seed := time.Now().UnixNano()
	return Counters{
		arms:   arms,
		rand:   rand.New(rand.NewSource(seed)),
		counts: make([]int, arms),
		values: make([]float64, arms),
	}
}
// Update records one pull of arm and folds reward into that arm's running
// average. arm is 1-indexed (1..arms); a value outside that range panics
// with an index out of range error.
func (c *Counters) Update(arm int, reward float64) {
	c.Lock()
	defer c.Unlock()

	i := arm - 1 // counts and values are 0-indexed

	// Count this pull before averaging: the formula treats count as the
	// number of rewards including this one. Without the increment an arm's
	// first update divides by zero and the average becomes +-Inf/NaN.
	c.counts[i]++
	count := float64(c.counts[i])
	c.values[i] = (c.values[i]*(count-1) + reward) / count
}
// Init overwrites the receiver's pull counts and running averages with
// copies of snapshot's. It also adopts snapshot's random number generator
// when snapshot has one.
//
// It returns an error, leaving the receiver unchanged, when:
//   - snapshot is nil,
//   - the two arm counts differ,
//   - snapshot has no arms, or
//   - snapshot's counts/values slices do not hold exactly one entry per arm.
//
// The slices are copied so that later updates on the receiver do not write
// through to snapshot, and vice versa. The *rand.Rand pointer is shared, not
// copied. snapshot's own lock is not taken, so snapshot must not be mutated
// concurrently with Init.
func (c *Counters) Init(snapshot *Counters) error {
	if snapshot == nil {
		return fmt.Errorf("cannot init from a nil snapshot")
	}

	c.Lock()
	defer c.Unlock()

	if c.arms != snapshot.arms {
		return fmt.Errorf("cannot init %d arms with %d arms", c.arms, snapshot.arms)
	}
	if snapshot.arms == 0 {
		return fmt.Errorf("need at least 1 arm")
	}
	if len(snapshot.counts) != snapshot.arms || len(snapshot.values) != snapshot.arms {
		return fmt.Errorf("snapshot has %d counts and %d values for %d arms",
			len(snapshot.counts), len(snapshot.values), snapshot.arms)
	}

	counts := make([]int, snapshot.arms)
	copy(counts, snapshot.counts)
	values := make([]float64, snapshot.arms)
	copy(values, snapshot.values)

	c.counts = counts
	c.values = values
	if snapshot.rand != nil { // never replace a working generator with nil
		c.rand = snapshot.rand
	}
	return nil
}
// Reset discards all recorded pulls and averages: every arm gets a zero
// count and a zero running average. The arm count and the random number
// generator are left as they are. Fresh slices are allocated rather than
// zeroing in place, so any other holder of the old slices is unaffected.
//
// The embedded mutex is held, matching Update and Init, so Reset does not
// race with a concurrent Update.
func (c *Counters) Reset() {
	c.Lock()
	defer c.Unlock()

	c.counts = make([]int, c.arms)
	c.values = make([]float64, c.arms)
}