-
Notifications
You must be signed in to change notification settings - Fork 2
/
coster_test.go
50 lines (43 loc) · 1.37 KB
/
coster_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
package automata
import (
"testing"
)
// TestMeanSquaredErrorCost checks MeanSquaredErrorCost.Cost on a pair of
// mirrored vectors and on a pair of identical vectors.
//
// Results are compared within an absolute tolerance instead of with ==, so the
// test does not depend on the exact order of floating-point operations inside
// Cost. Failure messages print values with %v (shortest exact representation)
// because %f rounds to six decimals and can show "want" and "got" as equal.
func TestMeanSquaredErrorCost(t *testing.T) {
	const tol = 1e-9
	// near reports whether got is within tol of want. A NaN result makes both
	// comparisons false, so NaN is still reported as a failure.
	near := func(got, want float64) bool {
		d := got - want
		return d <= tol && d >= -tol
	}

	mse := MeanSquaredErrorCost{}
	cases := []struct {
		name string
		a, b []float64
		want float64
	}{
		// 2/3 is the same float64 as the previously hard-coded 0.666… literal
		// (presumably (1² + 0² + 1²) / 3 if Cost averages squared differences).
		{"mirrored", []float64{0, 0.5, 1}, []float64{1, 0.5, 0}, 2.0 / 3.0},
		{"identical", []float64{0, 0.5, 1}, []float64{0, 0.5, 1}, 0},
	}
	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			if got := mse.Cost(c.a, c.b); !near(got, c.want) {
				t.Errorf("Cost(%v, %v): want %v, got %v", c.a, c.b, c.want, got)
			}
		})
	}
}
// TestCrossEntropyCost checks CrossEntropyCost.Cost on a pair of mirrored
// vectors and on a pair of identical vectors.
//
// Results are compared within an absolute tolerance instead of with ==, so the
// test does not depend on the exact order of floating-point operations inside
// Cost. Failure messages print values with %v (shortest exact representation)
// because %f rounds to six decimals and can show "want" and "got" as equal.
func TestCrossEntropyCost(t *testing.T) {
	const tol = 1e-9
	// near reports whether got is within tol of want. A NaN result makes both
	// comparisons false, so NaN is still reported as a failure.
	near := func(got, want float64) bool {
		d := got - want
		return d <= tol && d >= -tol
	}

	ce := CrossEntropyCost{}
	cases := []struct {
		name string
		a, b []float64
		want float64
	}{
		// NOTE(review): these literals look like captured outputs of Cost rather
		// than closed-form values — confirm. The second is within ~4e-15 of ln 2
		// (0.6931471805599453), which exact == would have rejected.
		{"mirrored", []float64{0, 0.5, 1}, []float64{1, 0.5, 0}, 69.66613905215368},
		{"identical", []float64{0, 0.5, 1}, []float64{0, 0.5, 1}, 0.693147180559941},
	}
	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			if got := ce.Cost(c.a, c.b); !near(got, c.want) {
				t.Errorf("Cost(%v, %v): want %v, got %v", c.a, c.b, c.want, got)
			}
		})
	}
}
// TestBinaryCost exercises BinaryCost.Cost on mirrored inputs and on identical
// inputs. Every case is checked (t.Errorf, not t.Fatalf), so all mismatches are
// reported in a single run.
func TestBinaryCost(t *testing.T) {
	cost := BinaryCost{}
	for _, tc := range [...]struct {
		label    string
		lhs, rhs []float64
		expected float64
	}{
		{"Cost([0, 0.5, 1], [1, 0.5, 0])", []float64{0, 0.5, 1}, []float64{1, 0.5, 0}, 2.0},
		{"Cost([0, 0.5, 1], [0, 0.5, 1])", []float64{0, 0.5, 1}, []float64{0, 0.5, 1}, 0.0},
	} {
		got := cost.Cost(tc.lhs, tc.rhs)
		if got == tc.expected {
			continue
		}
		t.Errorf("%s : want %f, got %f", tc.label, tc.expected, got)
	}
}