0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018
0019
0020
0021
0022
0023
0024
0025
0026
0027
0028
0029
0030
0031
0032
0033
0034
0035
0036
0037
0038
0039
0040
0041
0042
0043
0044
0045
0046
0047
0048
0049
0050
0051
0052
0053
0054
0055
0056
0057
0058
0059
0060
0061
0062
0063
0064
0065
0066
0067
0068
0069
0070
function [estimate, varargout] = gmmb_fj(data, varargin)
%GMMB_FJ  Fit a Gaussian mixture, choosing the number of components itself.
%
%   estimate = gmmb_fj(data, ...)
%   [estimate, stats] = gmmb_fj(data, ...)
%
% Starts from conf.Cmax components and repeatedly runs a component-wise
% EM sweep in which components whose weight drops to zero are removed.
% After each converged run the weakest remaining component is removed
% by force and the fit is repeated, down to conf.Cmin components.  The
% mixture with the smallest cost L (see the main loop) is returned.
% NOTE(review): the weight update and cost look like the Figueiredo-Jain
% scheme the function name hints at -- confirm against the project docs.
%
% Input
%   data      N-by-D matrix, one sample per row.
%   varargin  option overrides, merged into the defaults by getargs
%             (presumably name/value pairs -- see getargs).
%
% Options (defaults in parentheses)
%   maxloops (500)   an EM run stops once its loop count exceeds this
%   Cmax             initial component count, default
%                    ceil(min(50, N/(D*D)/3)); a value < 1 or > N makes
%                    every sample an initial centre
%   Cmin (1)         smallest component count that is still fitted;
%                    values below 1 are treated as 1
%   verbose (0)      nonzero prints progress with disp
%   thr (1e-6)       relative log-likelihood change that ends an EM run
%   animate (0)      nonzero plots the mixture each iteration (2-D only)
%   covtype (0)      1 selects diagonal covariances, anything else full
%   broken (1)       nonzero counts parameters as for real data even
%                    when data is complex
%   logging (0)      forced to 1 when a second output is requested
%                    (unless overridden), forced to 0 when it is not;
%                    values above 1 also record every intermediate mixture
%
% Output
%   estimate  struct: mu (D-by-C), sigma (D-by-D-by-C), weight (C-by-1)
%   stats     struct with fields iterations, costs, annihilations,
%             covfixer2, loglikes and, for logging > 1, initialmix and
%             mixtures.  [] when logging is switched off.
%
% Errors
%   - the initial component count is below the minimum to fit
%   - all components were annihilated during a sweep

[N, D] = size(data);

conf = struct(...
    'maxloops', 500, ...
    'Cmax', ceil(min(50, N/(D*D)/3)), ...
    'Cmin', 1, ...
    'verbose', 0, ...
    'thr', 1e-6, ...
    'animate', 0, ...
    'covtype', 0, ...
    'broken', 1, ...
    'logging', 0 ...
    );

% Asking for the second output switches logging on by default; the
% caller may still override the level through varargin.
if nargout > 1
    conf.logging = 1;
    varargout{1} = [];
end

conf = getargs(conf, varargin);

% Without a second output nobody can receive the log.
if nargout < 2
    conf.logging = 0;
end

log_covfixer2 = {};
log_loglikes = {};
log_costs = {};
log_annih = {};
log_mixtures = {};

% Initial centres: C random samples, or every sample when Cmax is unusable.
C = conf.Cmax;
if (C < 1) | (C > N)
    C = N;
    mu = data.';
else
    permi = randperm(N);
    mu = data(permi(1:C),:).';
end

% A mixture needs at least one component, so Cmin below 1 is clamped.
% Without the clamp the outer loop re-entered with zero components and
% always ended in an error.
min_components = max(conf.Cmin, 1);
if C < min_components
    % Previously the main loop was skipped and the function died later
    % with an undefined-variable error on 'estimate'.
    error(['Cannot start estimation: initial number of components (' ...
        num2str(C) ') is below the required minimum (' ...
        num2str(min_components) '). Check Cmax, Cmin and the amount of data.']);
end

% Every component starts with the same spherical covariance: a tenth of
% the largest variance on the diagonal of the fixed data covariance.
s2 = max(diag(gmmb_covfixer(cov(data,1))/10));
sigma = repmat(s2*eye(D), [1 1 C]);

% Uniform initial weights.
alpha = ones(1,C) * (1/C);

log_initialmix = struct(...
    'weight', alpha, ...
    'mu', mu, ...
    'sigma', sigma);

% Nparc: number of free parameters per component.
if isreal(data) | (conf.broken ~= 0)
    if conf.covtype == 1
        Nparc = D+D;
    else
        Nparc = D+D*(D+1)/2;
    end
else
    % complex-valued data
    if conf.covtype == 1
        Nparc = 2*D + D;
    else
        Nparc = 2*D + D*D;
    end
end
Nparc2 = Nparc/2;

N_limit = (Nparc+1)*3*conf.Cmin;
if N < N_limit
    warning_wrap('gmmb_fj:data_amount', ...
        ['Training data may be insufficient for selected ' ...
        'minimum number of components. ' ...
        'Have: ' num2str(N) ', recommended: >' num2str(N_limit) ...
        ' points.']);
end

if conf.animate ~= 0
    aniH = my_plot_init;
    my_plot_ellipses(aniH, data, mu, sigma, alpha);
end

t = 0;          % global iteration counter, indexes the log cells
Cnz = C;        % number of components with nonzero weight
Lmin = NaN;     % smallest cost seen so far

u = fj_component_pdfs(data, mu, sigma);
old_loglike = fj_loglike(u, alpha);

while Cnz >= min_components
    repeating = 1;
    fixing_cycles = 0;
    loops = 0;

    while repeating
        t = t + 1;
        loops = loops + 1;

        fixed_on_this_round = 0;
        if conf.logging > 0
            log_covfixer2{t,1} = 0;
        end

        % Component-wise sweep: each component is re-estimated with the
        % weights and densities already updated for the previous ones.
        for c = 1:C
            indic = u .* repmat(alpha, N, 1);
            % Responsibility of component c only; the other columns of
            % the normalised matrix are never used.
            resp = indic(:,c) ./ (realmin + sum(indic, 2));
            resp_sum = sum(resp);

            normf = 1/resp_sum;
            aux = repmat(resp, 1, D) .* data;

            nmu = normf * sum(aux,1);
            mu(:,c) = nmu.';

            if conf.covtype == 1
                nsigma = normf*diag(sum(aux .* conj(data), 1)) - diag(nmu.*conj(nmu));
            else
                nsigma = normf*(aux' * data) - nmu'*nmu;
            end
            [sigma(:,:,c), log_fixcount] = gmmb_covfixer(nsigma);

            if conf.logging > 0
                log_covfixer2{t,1} = log_covfixer2{t,1} + log_fixcount;
            end

            % Weight update; the -Nparc2 term drives weak components to
            % exactly zero.
            alpha(c) = max(0, resp_sum - Nparc2) / N;
            alpha = alpha / sum(alpha);

            failed = 0;
            if ~all( isfinite(alpha(:)) )
                % Every weight became zero, so the normalisation above
                % divided by zero.
                warning_wrap('gmmb_fj:weight_finity', 'Mixture weights are no longer finite, aborting estimation.');
                failed = 1;
            elseif alpha(c) == 0
                % Component annihilated; it is pruned after the sweep.
                Cnz = Cnz - 1;
            else
                if log_fixcount ~= 0
                    % The covariance had to be repaired, so this round
                    % must not count as converged.
                    fixed_on_this_round = 1;
                end
                try
                    u(:,c) = gmmb_cmvnpdf(data, mu(:,c).', sigma(:,:,c));
                catch
                    disp('covariance went bzrk !!!');
                    disp(sigma(:,:,c));
                    failed = 1;
                end
            end

            if failed | (Cnz <= 0)
                error('Estimation failed, number of components fell to zero. Not enough training data?');
            end
        end

        % Prune the components annihilated during the sweep.
        annihilated_count = sum(alpha == 0);
        if annihilated_count > 0
            [alpha, mu, sigma] = fj_drop_annihilated(alpha, mu, sigma);
            C = length(alpha);
        end

        if conf.animate ~= 0
            my_plot_ellipses(aniH, data, mu, sigma, alpha);
        end

        u = fj_component_pdfs(data, mu, sigma);
        loglike = fj_loglike(u, alpha);
        % Cost to minimise: parameter/weight penalty minus log-likelihood.
        L = Nparc2*sum(log(alpha)) + (Nparc2+0.5)*Cnz*log(N) - loglike;

        if conf.verbose ~= 0
            disp(['Cnz=' num2str(Cnz) ' t=' num2str(t) ' '...
                num2str(abs(loglike - old_loglike)) ...
                ' <? ' num2str(conf.thr*abs(old_loglike))]);
            disp(['t=' num2str(t) ' L= ' num2str(L)]);
        end

        if conf.logging > 0
            log_loglikes{t} = loglike;
            log_costs{t} = L;
            % Second element is set below if a forced annihilation follows.
            log_annih{t} = [annihilated_count, 0];
        end
        if conf.logging > 1
            log_mixtures{t} = struct(...
                'weight', alpha, ...
                'mu', mu, ...
                'sigma', sigma);
        end

        if fixed_on_this_round ~= 0
            % Covariances were repaired: keep iterating, but only for a
            % limited number of consecutive rounds.
            fixing_cycles = fixing_cycles + 1;
            if conf.verbose ~= 0
                disp(['fix cycle ' num2str(fixing_cycles)]);
            end
        else
            fixing_cycles = 0;
            if (abs(loglike/old_loglike - 1) < conf.thr)
                repeating = 0;
            end
        end

        old_loglike = loglike;

        if fixing_cycles > 20
            repeating = 0;
        end
        if loops > conf.maxloops
            repeating = 0;
        end
    end

    % Keep the cheapest mixture seen so far.
    if isnan(Lmin) | (L <= Lmin)
        Lmin = L;
        estimate = struct('mu', mu,...
            'sigma', sigma,...
            'weight', alpha.');
    end
    if conf.verbose ~= 0
        disp(['Cnz = ' num2str(Cnz)]);
    end

    % Force-annihilate the weakest component and fit again.
    m = find(alpha == min(alpha(alpha>0)));
    alpha(m(1)) = 0;
    Cnz = Cnz - 1;

    if conf.logging > 0
        log_annih{t}(2) = 1;
    end

    if Cnz > 0
        alpha = alpha / sum(alpha);

        if any(alpha == 0)
            [alpha, mu, sigma] = fj_drop_annihilated(alpha, mu, sigma);
            C = length(alpha);
        end

        u = fj_component_pdfs(data, mu, sigma);
        old_loglike = fj_loglike(u, alpha);
    end
end

% Level 1 (or above) returns the per-iteration statistics; levels above
% 1 add the initial and all intermediate mixtures.
if conf.logging >= 1
    stats = struct(...
        'iterations', {t}, ...
        'costs', {cat(1,log_costs{:})}, ...
        'annihilations', {sparse(cat(1,log_annih{:}))}, ...
        'covfixer2', {cat(1,log_covfixer2{:})}, ...
        'loglikes', {cat(1,log_loglikes{:})} ...
        );
    if conf.logging > 1
        stats.initialmix = log_initialmix;
        stats.mixtures = log_mixtures;
    end
    varargout{1} = stats;
end

% Return only components with positive weight.
e = estimate;
inds = find(e.weight>0);
estimate.mu = e.mu(:,inds);
estimate.sigma = e.sigma(:,:,inds);
estimate.weight = e.weight(inds);

if conf.animate ~= 0
    my_plot_ellipses(aniH, data, estimate.mu, estimate.sigma, estimate.weight);
end


function u = fj_component_pdfs(data, mu, sigma)
% Evaluate every component density at every sample; returns N-by-C.
C = size(mu, 2);
u = zeros(size(data, 1), C);
for c = 1:C
    u(:,c) = gmmb_cmvnpdf(data, mu(:,c).', sigma(:,:,c));
end


function loglike = fj_loglike(u, alpha)
% Log-likelihood of the mixture given densities u (N-by-C) and row
% weights alpha; realmin keeps log() finite for zero-density rows.
indic = u .* repmat(alpha, size(u, 1), 1);
loglike = sum(log(realmin + sum(indic, 2)));


function [alpha, mu, sigma] = fj_drop_annihilated(alpha, mu, sigma)
% Remove the components whose weight is zero.
nz = find(alpha > 0);
alpha = alpha(nz);
mu = mu(:,nz);
sigma = sigma(:,:,nz);
0411
0412
0413
0414
function h = my_plot_init
% MY_PLOT_INIT  Open the figure used for the animation and return its handle.
%
% Labels the axes, selects the 2-D view and starts the tic/toc timer
% that my_plot_ellipses reads to pace the frames.
label_size = 14;

h = figure;
figure(h);

title('Distribution of x_1 and x_2 values', 'FontSize', label_size);
xlabel('x_1 value', 'FontSize', label_size);
ylabel('x_2 value', 'FontSize', label_size);
zlabel('weight', 'FontSize', label_size);
view(2)

tic;
0424
function my_plot_ellipses(h, data, mu, sigma, weight)
% MY_PLOT_ELLIPSES  Draw the samples and one ellipse per mixture component.
%
%   h       figure handle to draw into
%   data    samples, one per row; only the first two columns are plotted
%   mu      2-by-C component means
%   sigma   2-by-2-by-C component covariances
%   weight  component weights, used as the z height of each ellipse
%
% Raises an error when mu does not have exactly two rows.
min_frame_time = 0.3;   % seconds each frame should stay on screen

if size(mu, 1) ~= 2
    error('Can plot only 2D objects.');
end

% Circle of radius 2, later mapped through each covariance factor.
[cx, cy] = cylinder([2 2], 40);
circle = [ cx(1,:) ; cy(1,:) ];

figure(h);
plot(data(:,1), data(:,2), 'rx');

hold on
for k = 1:size(mu, 2)
    outline = chol(sigma(:,:,k))' * circle;
    height = ones(1, size(outline, 2)) * weight(k);
    plot3(outline(1,:) + mu(1,k), outline(2,:) + mu(2,k), height, 'k-');
end
drawnow;
hold off

% Wait out the rest of the frame time, then restart the timer.
elapsed = toc;
if elapsed + 0.01 < min_frame_time
    pause(min_frame_time - elapsed);
end
tic
0458