# logadd_bug.jl (52 lines, 1.71 KB) — demonstrates a Zygote gradient issue with log-add-exp implementations.
using Zygote
############ Recurrence relation for state of the system
# Lower-triangular matrix of ones: (A*state)[k] == sum(state[1:k]), so each
# component of the new state accumulates all previous components.
# `const` so the global has a concrete type when read inside `loss` below
# (non-const globals are `Any`-typed and slow in hot code).
const A = [1 0 0;
           1 1 0;
           1 1 1]
################## Loss of the dynamical system
"""
    loss(x)

Run the linear recurrence `state ← (A * state) .* x[:, i]` over the columns
of `x` (assumed elementwise positive) and return the last component of the
final state. The state is initialized from the first column of `x`.
Reads the file-level matrix `A`.
"""
function loss(x) # Input x > 0
    state = x[:, 1]           # state after consuming column 1
    for i in 2:size(x, 2)     # size(x, 2) is the idiomatic form of size(x)[2]
        state = A * state     # accumulate previous state components...
        state = state .* x[:, i] # ...then scale by the current input column
    end
    return state[end]         # loss is a simple function of the final state
end
################################# Generate some simple data
# 3×7 integer test matrix with entries in 1:5 (column-major fill of
# (1:21 mod 5) + 1, written as a comprehension).
X = [mod(3 * (j - 1) + i, 5) + 1 for i in 1:3, j in 1:7]
display(X)
@show loss(X)
display(gradient(loss, X)[1])
################################# Do the same operation in the log-domain
logadd1(a, b) = log(exp(a)+exp(b))
logadd2(a, b) = b + log(1 + exp(a - b))
logadd3(a, b) = b > a ? b + log(1 + exp(a - b)) : a + log(1 + exp(b - a))
# Like logadd3, but with guard clauses that short-circuit the -Inf
# (log-of-zero) operands so log(0 + e^b) comes out as b rather than NaN.
function logadd4(a, b)
    a == -Inf && return b
    b == -Inf && return a
    if b > a
        return b + log(1 + exp(a - b))
    else
        return a + log(1 + exp(b - a))
    end
end
logadd = logadd3 # Pick one of the four implementations
"""
    logloss(x)

Log-domain version of `loss`: tracks `log.(state)` and replaces the
additions performed by `A * state` with the file-level `logadd`
(log-sum-exp). Hard-coded to inputs with exactly 3 rows; `x` must be
elementwise positive. Returns `exp` of the last log-state component, so
`logloss(x) ≈ loss(x)` up to floating-point error.
"""
function logloss(x)
    logstate = log.(x[:, 1])
    for i in 2:size(x, 2)     # size(x, 2) is the idiomatic form of size(x)[2]
        # Manual unrolling of log.(A*state) for the 3×3 lower-triangular
        # all-ones A: component k is the log-sum of components 1..k.
        logstate = [logstate[1],
                    logadd(logstate[1], logstate[2]),
                    logadd(logstate[1], logadd(logstate[2], logstate[3]))]
        logstate += log.(x[:, i])  # elementwise multiply by x[:, i], in log space
    end
    return exp(logstate[end])
end
@show logloss(X)
display(gradient(logloss, X)[1])
# Compare the four implementations, and their Zygote gradients, on random
# scalar pairs.
for trial in 1:10
    a, b = randn(2)
    println((logadd1(a, b), logadd2(a, b), logadd3(a, b), logadd4(a, b)))
    println((gradient(logadd1, a, b), gradient(logadd2, a, b),
             gradient(logadd3, a, b), gradient(logadd4, a, b)))
end