0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018
0019
function [] = gmmb_demo01
%GMMB_DEMO01 Interactive demo: train and apply a GMM-based Bayesian classifier.
%
%   gmmb_demo01
%
%   Steps, each followed by a "<press enter>" prompt:
%     1. Draw a synthetic 2-D data set of three classes from six Gaussians.
%     2. Split it into a training set (70%) and a test set (30%).
%     3. Train with gmmb_create using the 'FJ' method.
%     4. Classify the test set and circle the misclassified points.
%     5. Classify again with low-density samples marked as outliers.
%
%   The gmmb_* functions and mvnrnd are defined outside this file.

disp('Generating data from three classes with 3, 1 and 2 Gaussian components...');

% One mvnrnd call per Gaussian component. The calls stay in this order so
% that the random number stream is consumed in the same sequence.
samples = [ ...
    mvnrnd([ 2    1  ], covrot(1,   0.7, 1),    200); ...   % class 1
    mvnrnd([-2    1  ], covrot(0.4, 1.2, pi/3), 200); ...   % class 1
    mvnrnd([ 0    1.5], covrot(0.5, 0.5, 0),    150); ...   % class 2
    mvnrnd([-3   -1.5], covrot(0.5, 0.5, 0),    150); ...   % class 3
    mvnrnd([ 3   -1.5], covrot(0.5, 0.5, 0),    150); ...   % class 3
    mvnrnd([ 0   -2.5], covrot(2.5, 1.5, 0),    200)  ...   % class 1
    ];

% Class label of every row of samples, in the same block order as above.
labels = [ ...
    ones(400, 1); ...
    2 * ones(150, 1); ...
    3 * ones(300, 1); ...
    ones(200, 1) ...
    ];

prompt     = '<press enter>';
classMarks = ['xr'; 'xb'; 'xg'];   % line style for classes 1, 2, 3

disp('Separating test set (30%) and training set (70%)...');
trainCount = round(size(labels, 1) * 0.70);
[Ptrain, Ttrain, Ptest, Ttest] = subset(samples, labels, trainCount);

figH = figure;
plot_data(Ptrain, Ttrain, classMarks);
disp('Now we have this kind of training set, three classes.');
disp('Next we will use the FJ algorithm to learn those classes.');
input(prompt);

% Deliberately left without a terminating semicolon: the parameter list is
% echoed to the command window.
FJ_params = { 'Cmax', 25, 'thr', 1e-3, 'animate', 1 }
disp('Running FJ...');

bayesS = gmmb_create(Ptrain, Ttrain, 'FJ', FJ_params{:});

disp('Training complete.');
disp('There are now 3 more figures open, in those you can see how the FJ learned the distributions.');
input(prompt);

% Go back to the first figure and show the test set with its true labels.
figure(figH);
disp('This is our test set. Let''s forget the class labels and classify the samples.');
plot_data(Ptest, Ttest, classMarks);
input(prompt);

% Classification chain: densities -> prior weighting -> normalisation ->
% decision (reading of the gmmb_* helper names; TODO confirm against them).
pdfmat   = gmmb_pdf(Ptest, bayesS);
weighted = gmmb_weightprior(pdfmat, bayesS);
postprob = gmmb_normalize(weighted);
result   = gmmb_decide(postprob);

disp('Done classifying. We used the Bayesian classifier.');

plot_data(Ptest, result, classMarks);
isCorrect = (result == Ttest);
rat = sum(isCorrect) / length(Ttest);
disp(['We got ' num2str(rat*100) ' percent correct.']);
disp('The misclassified points are circled.');
miss = Ptest(~isCorrect, :);
hold on
plot(miss(:,1), miss(:,2), 'ok');
input(prompt);

% Second pass in a fresh figure: entries of postprob flagged by the mask
% are zeroed before deciding again. The plotting below treats a decision
% of 0 as "outlier".
figure

histS        = gmmb_generatehist(bayesS, 1000);
outlier_mask = gmmb_fracthresh(pdfmat, histS, 0.9);
postprob(outlier_mask) = 0;
result = gmmb_decide(postprob);

% Shift labels by one so that 0 (outlier) selects the first style, '.k'.
plot_data(Ptest, result + 1, ['.k'; classMarks]);
wrongNotOutlier = (result ~= Ttest) & (result ~= 0);
miss = Ptest(wrongNotOutlier, :);
hold on
plot(miss(:,1), miss(:,2), 'ok');
disp('Here we classified the test data again using threshold of density quantile=0.9.');
disp('The points classified as outliers are black dots.');
disp('The misclassified points, that are not outliers, are circled.');
input(prompt);

disp('The End.');
0111
0112
0113
0114
0115
0116
0117
function [tdata, ttype, left_data, left_type] = subset(data, type, n)
%SUBSET Randomly split labelled samples into two class-stratified parts.
%
%   [tdata, ttype, left_data, left_type] = subset(data, type, n)
%
%   Inputs:
%     data - N-by-D matrix, one sample per row.
%     type - N class labels, one per row of data (row or column vector).
%     n    - number of samples to place in the first part.
%
%   Outputs:
%     tdata, ttype         - the n selected samples and their labels (n-by-1).
%     left_data, left_type - the remaining N-n samples and their labels.
%
%   Each class contributes about n*cN/N samples, where cN is the class size.
%   The rounding error is carried from one class to the next so that the
%   per-class counts follow the class proportions. Samples inside a class
%   are chosen with randperm. Both parts are ordered by ascending label.
%
%   If n > N, all of data/type is returned in tdata/ttype unchanged and
%   left_data/left_type are empty.

N = size(data, 1);

if n > N
    % More samples requested than available: hand everything back.
    tdata     = data;
    ttype     = type;
    left_data = [];
    left_type = [];
    return;
end

% Force the labels into a column. A for-loop iterates over COLUMNS of its
% range, so unique(...)' must be a row vector; with a row-vector 'type' the
% original unique(type)' was a column and the loop ran once with k holding
% every label at the same time.
labels = type(:);

D = size(data, 2);
tdata     = zeros(n, D);
ttype     = zeros(n, 1);
left_data = zeros(N - n, D);
left_type = zeros(N - n, 1);

done = 0;   % rows of tdata filled so far
over = 0;   % rows of left_data filled so far
e    = 0;   % accumulated rounding error (selected minus ideal)

for k = unique(labels)'
    cdata = data(labels == k, :);
    cN    = size(cdata, 1);
    ideal = n * cN / N;

    % Rounded share of this class, never more than what is still missing.
    sn = min(round(ideal), n - done);
    e  = e + sn - ideal;

    % Pay back one whole sample once the accumulated error reaches it.
    if e >= 1
        e  = e - 1;
        sn = sn - 1;
    end
    if e <= -1
        e  = e + 1;
        sn = sn + 1;
    end

    perm = randperm(cN);
    tdata((done+1):(done+sn), :)         = cdata(perm(1:sn), :);
    left_data((over+1):(over+cN-sn), :)  = cdata(perm((sn+1):cN), :);
    ttype((done+1):(done+sn), 1)         = k;
    left_type((over+1):(over+cN-sn), 1)  = k;

    done = done + sn;
    over = over + cN - sn;
end
0168
0169
0170
function C = covrot(x, y, th)
%COVROT Build a 2-by-2 covariance matrix from axis scales and a rotation.
%
%   C = covrot(x, y, th)
%
%   x, y - scale factors along the first and second axis before rotation.
%   th   - rotation angle in radians.
%
%   Returns C = (R*S)*(R*S)', where S = [x 0; 0 y] and R is the 2-D
%   rotation matrix for angle th.

cosTh = cos(th);
sinTh = sin(th);

axisScale = [x, 0; ...
             0, y];
rotation  = [cosTh, -sinTh; ...
             sinTh,  cosTh];

rotatedScale = rotation * axisScale;
C = rotatedScale * rotatedScale';
0181
0182
function plot_data(data, type, colors)
%PLOT_DATA Plot 2-D samples using one line style per class.
%
%   plot_data(data, type, colors)
%
%   data   - matrix with x values in column 1 and y values in column 2.
%   type   - class label per row of data; only labels 1..max(type) are
%            drawn.
%   colors - char matrix with one plot line specification per row. Class k
%            uses row mod(k-1, rows)+1, so the rows are reused cyclically.
%
%   The first class that has samples is drawn with the hold state that is
%   in effect on entry; later classes are added with hold on. Hold is
%   switched off before returning.

for k = 1:max(type)
    x = data(type == k, 1);
    y = data(type == k, 2);
    if isempty(x)
        % Nothing to draw for this class. Do not touch the hold state:
        % enabling hold here would stop the first real plot from clearing
        % old axes content.
        continue;
    end

    spec = colors(mod(k - 1, size(colors, 1)) + 1, :);
    plot(x, y, spec);

    % Keep what has been drawn while the remaining classes are added.
    hold on
end
hold off