-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathleaf.go
137 lines (123 loc) · 3.93 KB
/
leaf.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
package confeito
import "fmt"
// _FEATURE_ID_TERMINAL_LEAF is the feature ID stored in a terminal leaf.
// It is an alias of _FEATURE_ID_ILLEGAL: NewTerminalLeaf writes it into
// Leaf.featureID, IsTerminal compares against it, and NewLeaf rejects the
// same value so that a non-terminal leaf can never carry it.
const _FEATURE_ID_TERMINAL_LEAF = _FEATURE_ID_ILLEGAL
// Leaf is a single node of a binary decision tree.
//
// A node is either non-terminal or terminal. A non-terminal node holds a
// feature ID, a threshold and two children; a terminal node holds a value,
// which may be any object (interface{}).
//
// Prediction walks the tree from a node: while the current node is
// non-terminal, the feature value selected by featureID is compared with
// threshold; a value greater than threshold moves to the right child, any
// other value moves to the left child. The value of the terminal node that is
// finally reached is the prediction.
//
// Leaf is meant for manipulating the tree structure during training and for
// checking correctness, not for fast prediction.
type Leaf struct {
	// featureID selects the feature compared at this node; it equals
	// _FEATURE_ID_TERMINAL_LEAF when the node is terminal.
	featureID FeatureID
	// threshold is the split point used by non-terminal nodes.
	threshold float32
	// value is the payload returned by terminal nodes.
	value interface{}
	// left is followed when the feature value is not greater than threshold.
	left *Leaf
	// right is followed when the feature value is greater than threshold.
	right *Leaf
}
// NewLeaf returns a new non-terminal leaf that splits on featureID at
// threshold. Its left and right children are initialised to fresh terminal
// leaves holding leftValue and rightValue respectively.
//
// An error is returned, together with a nil leaf, when featureID equals
// _FEATURE_ID_ILLEGAL (the value reserved for terminal leaves) or when either
// child cannot be created by NewTerminalLeaf.
func NewLeaf(featureID FeatureID, threshold float32, leftValue, rightValue interface{}) (*Leaf, error) {
	if featureID == _FEATURE_ID_ILLEGAL {
		return nil, fmt.Errorf("featureID must be valid")
	}

	// Propagate child-construction errors instead of discarding them, so a
	// half-built node (nil child) is never handed to the caller.
	left, err := NewTerminalLeaf(leftValue)
	if err != nil {
		return nil, err
	}
	right, err := NewTerminalLeaf(rightValue)
	if err != nil {
		return nil, err
	}

	return &Leaf{
		featureID: featureID,
		threshold: threshold,
		left:      left,
		right:     right,
	}, nil
}
// NewTerminalLeaf builds a terminal leaf that carries value.
//
// The leaf is tagged with _FEATURE_ID_TERMINAL_LEAF and has no children.
// The returned error is always nil at present.
func NewTerminalLeaf(value interface{}) (*Leaf, error) {
	leaf := new(Leaf)
	leaf.featureID = _FEATURE_ID_TERMINAL_LEAF
	leaf.value = value
	return leaf, nil
}
// IsTerminal reports whether l is a terminal leaf, i.e. whether its feature
// ID is the reserved _FEATURE_ID_TERMINAL_LEAF marker.
func (l *Leaf) IsTerminal() bool {
	terminal := l.featureID == _FEATURE_ID_TERMINAL_LEAF
	return terminal
}
// Left returns the left child of l.
// Terminal leaves created by NewTerminalLeaf have no children, so the result
// is nil for them.
func (l *Leaf) Left() *Leaf {
	child := l.left
	return child
}
// Predict returns the value predicted for the feature vector x.
//
// Starting at l, the tree is descended until a terminal leaf is reached: at
// each non-terminal leaf the feature selected by its feature ID is read from
// x; a value greater than the threshold moves to the right child, any other
// value moves to the left child. The value of the terminal leaf is returned.
//
// If reading a feature from x fails, that error is returned and value is nil.
// A terminal leaf never consults x.
func (l *Leaf) Predict(x FeatureVector) (value interface{}, err error) {
	node := l
	for !node.IsTerminal() {
		fvalue, getErr := x.Get(node.featureID)
		if getErr != nil {
			// Do not guess a branch from a feature value that could not be read.
			return nil, getErr
		}
		if fvalue > node.threshold {
			node = node.right
		} else {
			node = node.left
		}
	}
	return node.value, nil
}
// Right returns the right child of l.
// Terminal leaves created by NewTerminalLeaf have no children, so the result
// is nil for them.
func (l *Leaf) Right() *Leaf {
	child := l.right
	return child
}
// SetLeft replaces the left child of l with left.
//
// It fails, leaving l unchanged, when l is a terminal leaf (terminal leaves
// have no children) or when left is nil (a non-terminal leaf must always have
// a left child).
func (l *Leaf) SetLeft(left *Leaf) error {
	switch {
	case l.IsTerminal():
		return fmt.Errorf("terminal leaf cannot have left leaf")
	case left == nil:
		return fmt.Errorf("left leaf of non-terminal leaf must not be nil")
	}
	l.left = left
	return nil
}
// SetRight replaces the right child of l with right.
//
// It fails, leaving l unchanged, when l is a terminal leaf (terminal leaves
// have no children) or when right is nil (a non-terminal leaf must always
// have a right child).
func (l *Leaf) SetRight(right *Leaf) error {
	switch {
	case l.IsTerminal():
		return fmt.Errorf("terminal leaf cannot have right leaf")
	case right == nil:
		return fmt.Errorf("right leaf of non-terminal leaf must not be nil")
	}
	l.right = right
	return nil
}
// String returns a human-readable representation of the tree rooted at l.
//
// A terminal leaf is rendered as its value in fmt's default format. A
// non-terminal leaf is rendered as
//
//	(feature[ID] <= THRESHOLD ? LEFT : RIGHT)
//
// where LEFT and RIGHT are the renderings of its children.
func (l *Leaf) String() string {
	if l.IsTerminal() {
		// The value may be any object, so use %v rather than %g: floats are
		// printed exactly as %g would print them, while non-float values no
		// longer come out as "%!g(type=value)".
		return fmt.Sprintf("%v", l.value)
	}
	return fmt.Sprintf("(feature[%d] <= %g ? %s : %s)", l.featureID, l.threshold, l.left, l.right)
}
// Threshold returns the feature ID and the threshold that l splits on.
//
// A terminal leaf has neither; for such a leaf an error is returned and the
// other results are left at their zero values.
func (l *Leaf) Threshold() (featureID FeatureID, threshold float32, err error) {
	if !l.IsTerminal() {
		featureID, threshold = l.featureID, l.threshold
		return featureID, threshold, nil
	}
	return featureID, threshold, fmt.Errorf("terminal leaf does not have threshold")
}
// Value returns the value carried by the terminal leaf l.
//
// Only terminal leaves carry a value; for a non-terminal leaf an error is
// returned and value is nil.
func (l *Leaf) Value() (value interface{}, err error) {
	if l.IsTerminal() {
		return l.value, nil
	}
	return nil, fmt.Errorf("non-terminal leaf does not have value")
}