This repository has been archived by the owner on Nov 3, 2020. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 11
/
Copy path: BidirectionalFixedLengthRunningTotal.kt
66 lines (56 loc) · 2.01 KB
/
BidirectionalFixedLengthRunningTotal.kt
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
package com.komputation.cpu.demos.runningtotal.bidirectional
import com.komputation.cpu.network.network
import com.komputation.demos.runningtotal.RunningTotalData
import com.komputation.initialization.uniformInitialization
import com.komputation.instructions.continuation.activation.RecurrentActivation
import com.komputation.instructions.continuation.projection.weighting
import com.komputation.instructions.entry.input
import com.komputation.instructions.loss.squaredLoss
import com.komputation.instructions.recurrent.ResultExtraction
import com.komputation.instructions.recurrent.bidirectionalRecurrent
import com.komputation.loss.printLoss
import com.komputation.optimization.stochasticGradientDescent
import java.util.*
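/*
    Worked example of how the training targets are built (one column per step;
    the demo below uses sequences of length `steps`):

    Input     1  2  3
    Forward   1  3  6   (running total from left to right)
    Backward  6  5  3   (running total from right to left)
    Sum       7  8  9   (forward + backward, element-wise: the training target)
*/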
/**
 * Demo: trains a bidirectional recurrent network on fixed-length integer sequences.
 *
 * Each step's target is the forward running total plus the backward running total
 * at that step. The network is an input instruction, then a bidirectional recurrent
 * layer (identity activation, results taken from all steps), then a weighting
 * projection. It is trained with squared loss, and the loss is reported via `printLoss`.
 *
 * NOTE(review): the bare numeric arguments below (e.g. the leading `1` to `network`
 * and the `2` passed to `training`) are presumably batch size and iteration count.
 * Confirm against the komputation API.
 */
fun main(args: Array<String>) {
    // Fixed seed so that data generation and weight initialization are reproducible.
    val seededRandom = Random(1)
    val weightInitialization = uniformInitialization(seededRandom, -0.01f, 0.01f)
    val gradientDescent = stochasticGradientDescent(0.001f)

    val steps = 2
    val numberExamples = 10_000

    val input = RunningTotalData.generateFixedLengthInput(seededRandom, steps, 0, 10, numberExamples)

    val forwardTargets = RunningTotalData.generateTargets(input)
    val backwardTargets = RunningTotalData.generateReversedTargets(input)

    // Element-wise sum of the forward and backward running totals, one array per example.
    val sumTargets = Array(numberExamples) { exampleIndex ->
        val forwardAndBackward = forwardTargets[exampleIndex].zip(backwardTargets[exampleIndex])
        forwardAndBackward
            .map { (forwardTotal, backwardTotal) -> forwardTotal + backwardTotal }
            .toFloatArray()
    }

    // The instructions are created in the same order in which the original passed them
    // as arguments to `network`, so that any construction-time side effects stay in sequence.
    val inputInstruction = input(1, steps)
    val recurrentInstruction = bidirectionalRecurrent(
        1,
        RecurrentActivation.Identity,
        ResultExtraction.AllSteps,
        weightInitialization,
        gradientDescent
    )
    val projectionInstruction = weighting(1, weightInitialization, gradientDescent)

    val trainer = network(
        1,
        inputInstruction,
        recurrentInstruction,
        projectionInstruction
    ).training(
        input,
        sumTargets,
        2,
        squaredLoss(),
        printLoss
    )

    trainer.run()
}