-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathdirichlet_fit.m
55 lines (52 loc) · 1.33 KB
/
dirichlet_fit.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
function [a,run] = dirichlet_fit(data,a,bar_p)
% DIRICHLET_FIT   Maximum-likelihood Dirichlet distribution.
%
% DIRICHLET_FIT(data) returns the MLE (a) for the matrix DATA.
% Each row of DATA is a probability vector.
% DIRICHLET_FIT(data,a) provides an initial guess A to speed up the search.
% DIRICHLET_FIT(data,a,bar_p) additionally supplies BAR_P, the mean of
% log(DATA) taken over rows (one entry per column), so that it is not
% recomputed here.
% [a,run] = DIRICHLET_FIT(...) also returns a trace struct RUN where
%   run.e(iter)     is dirichlet_logProb_fast(a,bar_p) after iteration iter
%   run.flops(iter) is the value of the flops counter after iteration iter
%
% The Dirichlet distribution is parameterized as
%   p(p) = (Gamma(sum_k a_k)/prod_k Gamma(a_k)) prod_k p_k^(a_k-1)
%
% The algorithm is an alternating optimization for m and for s, described in
% "Estimating a Dirichlet distribution" by T. Minka.

% Written by Tom Minka

if nargin < 3
  % Average along dimension 1 explicitly: plain mean(x) would collapse a
  % single-row DATA to a scalar instead of returning one entry per column.
  bar_p = mean(log(data), 1);
  addflops(numel(data)*(flops_exp + 1));
end
K = length(bar_p);

if nargin < 2
  a = dirichlet_moment_match(data);
end

s = sum(a);
if s <= 0
  % bad initial guess; fix it
  disp('fixing initial guess')
  if s == 0
    % nothing to rescale; fall back to equal entries that sum to 1
    a = ones(size(a))/length(a);
  else
    % rescale so that the entries sum to 1
    a = a/s;
  end
end

% At most 100 alternating passes; stop early once a has settled.
for iter = 1:100
  old_a = a;
  % time for fit_s is negligible compared to fit_m
  a = dirichlet_fit_s(data, a, bar_p);
  a = dirichlet_fit_m(data, a, bar_p, 1);
  % Flop accounting is kept as in earlier versions so that run.flops
  % remains comparable.
  addflops(2*K-1);
  if nargout > 1
    % record the trace only when the caller asked for the second output
    run.e(iter) = dirichlet_logProb_fast(a, bar_p);
    run.flops(iter) = flops;
  end
  % converged when every component moved by less than 1e-4
  if all(abs(a - old_a) < 1e-4)
    break
  end
end
%flops(flops + iter*(2*K-1));