-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlhd_lin_sing_tr_gauss_clicks.m
54 lines (50 loc) · 2.05 KB
/
lhd_lin_sing_tr_gauss_clicks.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
function lklh = lhd_lin_sing_tr_gauss_clicks(dec, noise_stdev, kappa,...
    T, left_clicks, right_clicks, gammas)
% LHD_LIN_SING_TR_GAUSS_CLICKS  Log-likelihood of a single-trial decision
% for a linear accumulator driven by Gaussian-perturbed clicks.
%
% For each discounting rate gamma, each click at time t adds a jump whose
% mean is +kappa (right click) or -kappa (left click) and whose stdev is
% noise_stdev, discounted by exp(-gamma*(T - t)). The accumulator at time T
% is therefore Gaussian with
%   mean = kappa * sum_i s_i * exp(-gamma*(T - t_i))      (s_i = +1 / -1)
%   var  = noise_stdev^2 * sum_i exp(-2*gamma*(T - t_i))
% and the function returns log(Phi(dec * mean / sqrt(var))), where Phi is
% the standard normal CDF. Code based on appendix D.2 from file
% clicks_ZK3.pdf
%
% ARGS:
%   dec:          1 or -1; synthetic decision
%   noise_stdev:  stdev of Gaussian jump height at clicks
%   kappa:        mean of jump height for right click
%   T:            trial duration in sec. exp(-gamma*T) scales both the mean
%                 and the stdev above, so it cancels in their ratio; T is
%                 kept in the signature for interface compatibility.
%   left_clicks:  vector of left click times (at most one of the two click
%                 trains may be empty)
%   right_clicks: vector of right click times; see above
%   gammas:       column vector of discounting rates to use
% RETURNS:
%   column vector (double) of log probabilities, one entry per gamma value.
%   Probabilities below eps are floored at eps before taking the log.
% RAISES:
%   error if gammas is not a column vector or if both click trains are empty.
% NOTES:
%   Called by: fit_linear_model()

if size(gammas,2) > 1
    error('gammas should be a column vector')
end

% Work in double precision throughout (without casting the click trains to
% double, the log-likelihood was observed to be constant for gamma > ~22).
% reshape() also turns 0x0 empties into 1x0 so they concatenate cleanly.
right_clicks = reshape(double(right_clicks), 1, []);
left_clicks  = reshape(double(left_clicks), 1, []);
gammas = double(gammas);

if isempty(right_clicks) && isempty(left_clicks)
    error('both click trains are empty')
end

% One row vector of all click times with their jump signs (+1 right, -1 left).
click_times = [right_clicks, left_clicks];
signs = [ones(1, numel(right_clicks)), -ones(1, numel(left_clicks))];

% Exponents gamma*t_i, one row per gamma. Subtracting each row's maximum
% multiplies the mean and the stdev by the same positive factor, so their
% ratio is unchanged, but exp() can no longer overflow (largest term is 1).
% Without the shift, exp(2*gamma*t) overflows to Inf for large gamma*t.
expo = gammas * click_times;
expo = bsxfun(@minus, expo, max(expo, [], 2));
w = exp(expo);

signed_sum = w * signs.';        % sum_i s_i * exp(gamma*t_i - max)
sq_sum = sum(w.^2, 2);           % sum_i exp(2*(gamma*t_i - max))

z = double(dec) * kappa * signed_sum ./ sqrt(noise_stdev^2 * sq_sum);

% Standard normal CDF, Phi(z) = 0.5*erfc(-z/sqrt(2)) (same as normcdf).
prob = 0.5 * erfc(-z / sqrt(2));
prob(prob < eps) = eps;
lklh = log(prob);
end