function [W,M,R,varargout] = gmmbvl_em(X,kmax,nr_of_cand,plo,dia, logging)
%GMMBVL_EM  Fit a Gaussian mixture to data with (greedy or non-greedy) EM.
%
%   [W,M,R] = GMMBVL_EM(X,KMAX,NR_OF_CAND,PLO,DIA)
%   [W,M,R,STATS] = GMMBVL_EM(X,KMAX,NR_OF_CAND,PLO,DIA,LOGGING)
%
%   Inputs
%     X           n-by-d data matrix, one sample per row.
%     KMAX        maximum number of mixture components.
%     NR_OF_CAND  nonzero: greedy EM. Start from a single component and
%                 try to add one component at a time until KMAX is reached
%                 or a new component no longer helps. The value is passed
%                 on to GMMBVL_RAND_SPLIT (presumably the number of
%                 candidate components tried per insertion; confirm there).
%                 zero: plain EM with KMAX components.
%     PLO         nonzero to plot during full EM steps (forced off when d > 2).
%     DIA         nonzero to print progress messages.
%     LOGGING     optional, default 0. If > 0, the total log-likelihood
%                 after every EM step (full and partial) is recorded and
%                 returned in STATS.
%
%   Outputs
%     W      k-by-1 mixing weights (k <= KMAX components are kept).
%     M, R   component parameters, one row per component, in the format
%            produced by GMMBVL_EM_INIT_KM and consumed by GMMBVL_EM_GAUSS.
%     STATS  (4th output) [] when LOGGING is 0, otherwise a struct with
%            fields 'iterations' (total number of EM steps) and
%            'loglikes' (column vector of per-step total log-likelihoods).
%
%   EM stops when the relative change of the mean log-likelihood drops
%   below THRESHOLD.

THRESHOLD = 1e-5;   % relative change in mean log-likelihood that ends an EM phase

% LOGGING is optional so the function can be called with five inputs.
if nargin < 6
    logging = 0;
end

[n,d] = size(X);

% Plotting is only done for 1-D and 2-D data.
if d > 2
    plo = 0;
end

if plo
    figure(1);
    set(1,'Double','on');
end

% Greedy EM grows the mixture from one component; non-greedy EM uses KMAX at once.
if nr_of_cand
    k = 1;
    if dia
        fprintf('Greedy ');
    end
else
    k = kmax;
    if dia
        fprintf('Non-greedy ');
    end
end

if dia
    fprintf('EM initialization\n');
end

if nargout == 4
    varargout{1} = [];
end
log_loglikes = {};
full_em_loops = 0;   % EM steps on the whole mixture
part_em_loops = 0;   % EM steps on a candidate component only

[W,M,R,P,sigma] = gmmbvl_em_init_km(X,k,0);
sigma = sigma^2;

oldlogl = -realmax;

while 1

    % ---- Full EM on the current k-component mixture until convergence.
    if dia
        fprintf('EM steps');
    end

    while 1
        [W,M,R] = gmmbvl_em_step(X,W,M,R,P,plo);
        full_em_loops = full_em_loops + 1;

        if dia
            fprintf('.');
        end

        % Per-sample mixture density F (n-by-1). It is clamped away from
        % zero so that log(F) stays finite.
        L = gmmbvl_em_gauss(X,M,R);
        F = L * W;
        F(F < realmin) = realmin;
        logl = mean(log(F));

        % Posterior responsibility of each component for each sample.
        P = L .* (ones(n,1)*W') ./ (F*ones(1,k));

        if logging > 0
            log_loglikes{full_em_loops+part_em_loops} = sum(log(F));
        end

        if abs(logl/oldlogl-1) < THRESHOLD
            if dia
                fprintf('\n');
                fprintf('Logl = %g\n', logl);
            end
            break;
        end
        oldlogl = logl;
    end

    if k == kmax
        break;
    end

    % ---- Greedy step: propose one new component.
    if dia
        fprintf('Trying component allocation');
    end

    [Mnew,Rnew,alpha] = gmmbvl_rand_split(P,X,M,R,sigma,F,W,nr_of_cand);
    if alpha == 0
        if dia
            fprintf('\n');   % terminate the progress line before leaving
        end
        break;
    end

    % Mixture density with the candidate blended in at weight alpha.
    K = gmmbvl_em_gauss(X,Mnew,Rnew);
    PP = F*(1-alpha) + K*alpha;
    LOGL = mean(log(PP));

    veryoldlogl = logl;   % converged log-likelihood without the candidate
    oldlogl = LOGL;
    done_here = 0;

    Pnew = (K.*(ones(n,1)*alpha))./PP;

    % ---- Partial EM: update only the candidate (alpha, Mnew, Rnew) while
    %      the density F of the existing mixture is held fixed.
    while ~done_here
        if dia
            fprintf('*');
        end

        [alpha,Mnew,Rnew] = gmmbvl_em_step(X,alpha,Mnew,Rnew,Pnew,0);
        part_em_loops = part_em_loops + 1;

        K = gmmbvl_em_gauss(X,Mnew,Rnew);

        Fnew = F*(1-alpha) + K*alpha;
        Pnew = K*alpha./Fnew;
        logl = mean(log(Fnew));

        if logging > 0
            log_loglikes{full_em_loops+part_em_loops} = sum(log(Fnew));
        end

        if abs(logl/oldlogl-1) < THRESHOLD
            done_here = 1;
        end

        oldlogl = logl;
    end

    % Reject the candidate if it does not beat the current mixture.
    if logl <= veryoldlogl
        if dia
            fprintf('Mixture uses only %d components\n', k);
        end
        break;
    end

    % Accept: append the new component and rescale the old weights.
    M = [M; Mnew];
    R = [R; Rnew];
    W = [(1-alpha)*W; alpha];
    k = k + 1;

    if dia
        fprintf(' k = %d\n', k);
        fprintf('LogL = %g\n', logl);
    end

    % Refresh F and P for the enlarged mixture before the next full EM phase.
    L = gmmbvl_em_gauss(X,M,R);
    F = L * W;
    F(F < realmin) = realmin;
    P = L .* (ones(n,1)*W') ./ (F*ones(1,k));
end

% Return the recorded log-likelihoods whenever they were collected
% (same condition as the recording above).
if logging > 0
    varargout{1} = struct( ...
        'iterations', {full_em_loops + part_em_loops}, ...
        'loglikes', {cat(1,log_loglikes{:})} ...
        );
end