├── CalcMetrics.m
├── MultiNMF_incomplete_original_l21.m
├── NMF.m
├── PerViewNMF_incomplete_original_l21.m
├── README.md
├── bestMap.m
├── hungarian.m
├── litekmeans.m
├── nmi.m
└── printResult.m
/CalcMetrics.m:
--------------------------------------------------------------------------------
1 | function [AC, nmi_value, error_cnt] = CalcMetrics(label, result)
2 | % Written by Jialu Liu
3 | result = bestMap(label, result); % align predicted labels with the ground truth via the Hungarian algorithm
4 | error_cnt = sum(label ~= result);
5 | AC = length(find(label == result))/length(label); % clustering accuracy (AC)
6 | 
7 | nmi_value = nmi(label, result);
8 | 
--------------------------------------------------------------------------------
/MultiNMF_incomplete_original_l21.m:
--------------------------------------------------------------------------------
1 | function [U, V, centroidU, log, ac] = MultiNMF_incomplete_original_l21(X, C, K, label, options)
2 | % This is a module of Multi-View Non-negative Matrix Factorization (MultiNMF) for incomplete views.
3 | %
4 | % Notation:
5 | % X ... a cell array containing all views of the data; X{i} is the (nSmp x mFea_i) data matrix of view i
6 | % C ... a cell array of (nSmp x nSmp) diagonal weight matrices; C{i} weights the instances of view i (incomplete instances receive smaller weights)
7 | % K ... number of hidden factors
8 | % label ... ground truth labels
9 | % options ... a structure containing the parameters; alpha and beta depend on the dataset and its normalization and need to be tuned
10 | % Implemented by Weixiang Shao (wshao4@uic.edu)
11 | 
12 | viewNum = length(X);
13 | Rounds = options.rounds;
14 | 
15 | U_ = [];
16 | V_ = [];
17 | 
18 | U = cell(1, viewNum);
19 | V = cell(1, viewNum);
20 | 
21 | j = 0;
22 | log = 0;
23 | ac = 0;
24 | 
25 | while j < 3 % initialization: a few passes of single-view NMF, each view warm-started from the previous one
26 |     j = j + 1;
27 |     if j == 1
28 |         [V{1}, U{1}] = NMF(X{1}', K, options, V_, U_);
29 |     else
30 |         [V{1}, U{1}] = NMF(X{1}', K, options, V_, U{viewNum});
31 |     end
32 |     printResult(U{1}, label, K, options.kmeans);
33 |     for i = 2:viewNum
34 |         [V{i}, U{i}] = NMF(X{i}', K, options, V_, U{i-1});
35 |         printResult(U{i}, label, K, options.kmeans);
36 |     end
37 | end
38 | 
39 | optionsForPerViewNMF = options;
40 | oldL = 10000000;
41 | oldU = U;
42 | oldV = V;
43 | tic
44 | j = 0;
45 | oldcentroidU = zeros(size(C{1},2), K);
46 | converge = 0;
47 | while j < Rounds % main loop: alternate the consensus update and the per-view weighted NMF
48 |     j = j + 1;
49 | 
50 |     CU = options.alpha(1)*(C{1}.^2)*U{1};
51 |     CC = options.alpha(1)*(C{1}.^2);
52 | 
53 |     for i = 2:viewNum
54 |         CU = CU + options.alpha(i)*(C{i}.^2)*U{i};
55 |         CC = CC + options.alpha(i)*(C{i}.^2);
56 |     end
57 |     CC_inv = diag(1./diag(CC)); % CC is diagonal, so the inverse is elementwise
58 |     centroidU = CC_inv*CU; % consensus factor: per-instance weighted average of the U{i}
59 |     logL = 0;
60 |     for i = 1:viewNum
61 |         tmp1 = C{i}*(X{i} - U{i}*V{i}');
62 |         tmp2 = C{i}*(U{i} - centroidU);
63 |         tmp3 = 0;
64 |         for k = 1:size(U{i},2)
65 |             tmp3 = tmp3 + norm(U{i}(:,k));
66 |         end
67 |         logL = logL + sum(sum(tmp1.^2)) + options.alpha(i) * sum(sum(tmp2.^2)) + options.beta(i)*tmp3; % reconstruction + disagreement + L2,1 terms
68 |     end
69 |     log(end+1) = logL;
70 |     % logL now holds the objective value of this round
71 |     if (oldL < logL)
72 |         % no rollback is performed; continue with the new factors
73 |         disp('objective function value increasing');
74 |         ac(end+1) = printResult(centroidU, label, K, options.kmeans);
75 |     else
76 |         ac(end+1) = printResult(centroidU, label, K, options.kmeans);
77 |     end
78 |     avg_diff = sum(sum(abs(oldcentroidU - centroidU).^2));
79 |     fprintf('The average diff is %g for iteration %d\n', avg_diff, j);
80 |     if avg_diff < 1e-12
81 |         converge = 1; % flag only; the loop still runs for all Rounds
82 |         fprintf('converged at iteration %d\n', j);
83 |     end
84 |     oldU = U;
85 |     oldV = V;
86 |     oldL = logL;
87 |     oldcentroidU = centroidU;
88 |     for i = 1:viewNum
89 |         optionsForPerViewNMF.alpha = options.alpha(i);
90 |         optionsForPerViewNMF.beta = options.beta(i);
91 |         [U{i}, V{i}] = PerViewNMF_incomplete_original_l21(X{i}, K, centroidU, optionsForPerViewNMF, U{i}, V{i}, C{i});
92 |     end
93 | end
94 | toc
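%
% Example (illustrative sketch, not from the original code): one simple way to
% build the per-view weight matrices C used above, assuming a hypothetical
% logical vector observed{i} (nSmp x 1) marking which instances are present in
% view i, and a small assumed weight w0 for the missing ones:
%
%   nSmp = size(X{1}, 1);
%   w0 = 0.01;                 % assumed small weight for missing instances
%   for i = 1:viewNum
%       w = w0 * ones(nSmp, 1);
%       w(observed{i}) = 1;    % observed instances get full weight
%       C{i} = diag(w);        % enters the updates above through C{i}.^2
%   end
%
% With diagonal C{i}, CC is diagonal and centroidU = CC_inv*CU reduces to a
% per-instance weighted average of the per-view factors:
%   centroidU(j,:) = sum_i( alpha(i)*C{i}(j,j)^2 * U{i}(j,:) ) / sum_i( alpha(i)*C{i}(j,j)^2 )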
95 | 
--------------------------------------------------------------------------------
/NMF.m:
--------------------------------------------------------------------------------
1 | function [U_final, V_final, nIter_final, elapse_final, bSuccess, objhistory_final] = NMF(X, k, options, U_, V_)
2 | % Non-negative Matrix Factorization (NMF) with multiplicative update
3 | %
4 | % Notation:
5 | % X ... (mFea x nSmp) data matrix
6 | % mFea ... number of features (e.g., vocabulary size)
7 | % nSmp ... number of samples (e.g., documents)
8 | % k ... number of hidden factors
9 | %
10 | % options ... Structure holding all settings
11 | %
12 | % U_ ... initialization for basis matrix
13 | % V_ ... initialization for coefficient matrix
14 | %
15 | %
16 | % Written by Deng Cai (dengcai AT gmail.com)
17 | % Modified by Jialu Liu (jliu64 AT illinois.edu)
18 | 
19 | differror = options.error;
20 | maxIter = options.maxIter;
21 | nRepeat = options.nRepeat;
22 | minIterOrig = options.minIter;
23 | minIter = minIterOrig-1;
24 | meanFitRatio = options.meanFitRatio;
25 | 
26 | Norm = 1;
27 | NormV = 0;
28 | 
29 | [mFea,nSmp]=size(X);
30 | 
31 | bSuccess.bSuccess = 1;
32 | 
33 | selectInit = 1;
34 | if isempty(U_)
35 |     U = abs(rand(mFea,k));
36 |     norms = sqrt(sum(U.^2,1));
37 |     norms = max(norms,1e-10);
38 |     U = U./repmat(norms,mFea,1);
39 |     if isempty(V_)
40 |         V = abs(rand(nSmp,k));
41 |         V = V/sum(sum(V));
42 |     else
43 |         V = V_;
44 |     end
45 | else
46 |     U = U_;
47 |     if isempty(V_)
48 |         V = abs(rand(nSmp,k));
49 |         V = V/sum(sum(V));
50 |     else
51 |         V = V_;
52 |     end
53 | end
54 | 
55 | [U,V] = NormalizeUV(U, V, NormV, Norm);
56 | if nRepeat == 1
57 |     selectInit = 0;
58 |     minIterOrig = 0;
59 |     minIter = 0;
60 |     if isempty(maxIter)
61 |         objhistory = CalculateObj(X, U, V);
62 |         meanFit = objhistory*10;
63 |     else
64 |         if isfield(options,'Converge') && options.Converge
65 |             objhistory = CalculateObj(X, U, V);
66 |         end
67 |     end
68 | else
69 |     if isfield(options,'Converge') && options.Converge
70 |         error('Not implemented!');
71 |     end
72 | end
73 | 
74 | 
75 | 
76 | tryNo = 0;
77 | while tryNo < nRepeat
78 |     tmp_T = cputime;
79 |     tryNo = tryNo+1;
80 |     nIter = 0;
81 |     maxErr = 1;
82 |     nStepTrial = 0;
83 |     while(maxErr > differror)
84 |         % ===================== update V ========================
85 |         XU = X'*U;  % mnk or pk (p<<mn)
86 |         UU = U'*U;  % mk^2
87 |         VUU = V*UU; % nk^2
88 | 
89 |         V = V.*(XU./max(VUU,1e-10));
90 | 
91 |         % ===================== update U ========================
92 |         XV = X*V;   % mnk or pk (p<<mn)
93 |         VV = V'*V;  % nk^2
94 |         UVV = U*VV; % mk^2
95 | 
96 |         U = U.*(XV./max(UVV,1e-10));
97 | 
98 |         nIter = nIter + 1;
99 |         if nIter > minIter
100 |             if selectInit
101 |                 objhistory = CalculateObj(X, U, V);
102 |                 maxErr = 0;
103 |             else
104 |                 if isempty(maxIter)
105 |                     newobj = CalculateObj(X, U, V);
106 |                     objhistory = [objhistory newobj]; %#ok
107 |                     meanFit = meanFitRatio*meanFit + (1-meanFitRatio)*newobj;
108 |                     maxErr = (meanFit-newobj)/meanFit;
109 |                 else
110 |                     if isfield(options,'Converge') && options.Converge
111 |                         newobj = CalculateObj(X, U, V);
112 |                         objhistory = [objhistory newobj]; %#ok
113 |                     end
114 |                     maxErr = 1;
115 |                     if nIter >= maxIter
116 |                         maxErr = 0;
117 |                         if isfield(options,'Converge') && options.Converge
118 |                         else
119 |                             objhistory = 0;
120 |                         end
121 |                     end
122 |                 end
123 |             end
124 |         end
125 |     end
126 | 
127 |     elapse = cputime - tmp_T;
128 | 
129 |     if tryNo == 1
130 |         U_final = U;
131 |         V_final = V;
132 |         nIter_final = nIter;
133 |         elapse_final = elapse;
134 |         objhistory_final = objhistory;
135 |         bSuccess.nStepTrial = nStepTrial;
136 |     else
137 |         if objhistory(end) < objhistory_final(end)
138 |             U_final = U;
139 |             V_final = V;
140 |             nIter_final = nIter;
141 |             objhistory_final = objhistory;
142 |             bSuccess.nStepTrial = nStepTrial;
143 |             if selectInit
144 |                 elapse_final = elapse;
145 |             else
146 |                 elapse_final = elapse_final+elapse;
147 |             end
148 |         end
149 |     end
150
| 151 | if selectInit 152 | if tryNo < nRepeat 153 | %re-start 154 | if isempty(U_) 155 | U = abs(rand(mFea,k)); 156 | norms = sqrt(sum(U.^2,1)); 157 | norms = max(norms,1e-10); 158 | U = U./repmat(norms,mFea,1); 159 | if isempty(V_) 160 | V = abs(rand(nSmp,k)); 161 | V = V/sum(sum(V)); 162 | else 163 | V = V_; 164 | end 165 | else 166 | U = U_; 167 | if isempty(V_) 168 | V = abs(rand(nSmp,k)); 169 | V = V/sum(sum(V)); 170 | else 171 | V = V_; 172 | end 173 | end 174 | 175 | [U,V] = NormalizeUV(U, V, NormV, Norm); 176 | else 177 | tryNo = tryNo - 1; 178 | minIter = 0; 179 | selectInit = 0; 180 | U = U_final; 181 | V = V_final; 182 | objhistory = objhistory_final; 183 | meanFit = objhistory*10; 184 | 185 | end 186 | end 187 | end 188 | 189 | nIter_final = nIter_final + minIterOrig; 190 | 191 | [U_final, V_final] = Normalize(U_final, V_final); 192 | 193 | 194 | %========================================================================== 195 | 196 | function [obj, dV] = CalculateObj(X, U, V, deltaVU, dVordU) 197 | if ~exist('deltaVU','var') 198 | deltaVU = 0; 199 | end 200 | if ~exist('dVordU','var') 201 | dVordU = 1; 202 | end 203 | dV = []; 204 | maxM = 62500000; 205 | [mFea, nSmp] = size(X); 206 | mn = numel(X); 207 | nBlock = floor(mn*3/maxM); 208 | 209 | if mn < maxM 210 | dX = U*V'-X; 211 | obj_NMF = sum(sum(dX.^2)); 212 | if deltaVU 213 | if dVordU 214 | dV = dX'*U; 215 | else 216 | dV = dX*V; 217 | end 218 | end 219 | else 220 | obj_NMF = 0; 221 | if deltaVU 222 | if dVordU 223 | dV = zeros(size(V)); 224 | else 225 | dV = zeros(size(U)); 226 | end 227 | end 228 | for i = 1:ceil(nSmp/nBlock) 229 | if i == ceil(nSmp/nBlock) 230 | smpIdx = (i-1)*nBlock+1:nSmp; 231 | else 232 | smpIdx = (i-1)*nBlock+1:i*nBlock; 233 | end 234 | dX = U*V(smpIdx,:)'-X(:,smpIdx); 235 | obj_NMF = obj_NMF + sum(sum(dX.^2)); 236 | if deltaVU 237 | if dVordU 238 | dV(smpIdx,:) = dX'*U; 239 | else 240 | dV = dU+dX*V(smpIdx,:); 241 | end 242 | end 243 | end 244 | if deltaVU 245 | if dVordU 246 | dV = dV ; 247 | end 248 | end 249 | end 250 | %obj_Lap = alpha*sum(sum((L*V).*V)); 251 | 252 | obj = obj_NMF; 253 | 254 | 255 | function [U, V] = Normalize(U, V) 256 | [U,V] = NormalizeUV(U, V, 0, 1); 257 | 258 | 259 | function [U, V] = NormalizeUV(U, V, NormV, Norm) 260 | nSmp = size(V,1); 261 | mFea = size(U,1); 262 | if Norm == 2 263 | if NormV 264 | norms = sqrt(sum(V.^2,1)); 265 | norms = max(norms,1e-10); 266 | V = V./repmat(norms,nSmp,1); 267 | U = U.*repmat(norms,mFea,1); 268 | else 269 | norms = sqrt(sum(U.^2,1)); 270 | norms = max(norms,1e-10); 271 | U = U./repmat(norms,mFea,1); 272 | V = V.*repmat(norms,nSmp,1); 273 | end 274 | else 275 | if NormV 276 | norms = sum(abs(V),1); 277 | norms = max(norms,1e-10); 278 | V = V./repmat(norms,nSmp,1); 279 | U = U.*repmat(norms,mFea,1); 280 | else 281 | norms = sum(abs(U),1); 282 | %norms = max(norms,1e-10); 283 | U = U./repmat(norms,mFea,1); 284 | V = V.*repmat(norms,nSmp,1); 285 | end 286 | end 287 | 288 | -------------------------------------------------------------------------------- /PerViewNMF_incomplete_original_l21.m: -------------------------------------------------------------------------------- 1 | function [U_final, V_final, nIter_final, elapse_final, bSuccess, objhistory_final] = PerViewNMF_incomplete_original_l21(X, k, Uo, options, U, V, C) 2 | % 3 | % Notation: 4 | % X ... (nSmp x mFea) data matrix of one view 5 | % mFea ... number of features 6 | % nSmp ... number of samples 7 | % k ... number of hidden factors 8 | % Uo... consunsus 9 | % options ... 
Structure holding all settings 10 | % U ... initialization for coefficient matrix 11 | % V ... initialization for basis matrix 12 | % 13 | % Originally written by Deng Cai (dengcai AT gmail.com) for GNMF 14 | % Modified by Weixiang Shao (wshao4@uic.edu) 15 | 16 | differror = options.error; 17 | maxIter = options.maxIter; 18 | nRepeat = options.nRepeat; 19 | minIterOrig = options.minIter; 20 | minIter = minIterOrig-1; 21 | meanFitRatio = options.meanFitRatio; 22 | 23 | alpha = options.alpha; 24 | beta = options.beta; 25 | 26 | [nSmp, mFea]=size(X); 27 | 28 | bSuccess.bSuccess = 1; 29 | 30 | selectInit = 1; 31 | if isempty(U) 32 | U = abs(rand(nSmp,k)); 33 | V = abs(rand(mFea,k)); 34 | else 35 | nRepeat = 1; 36 | end 37 | 38 | [U,V] = Normalize(U, V); 39 | if nRepeat == 1 40 | selectInit = 0; 41 | minIterOrig = 0; 42 | minIter = 0; 43 | if isempty(maxIter) 44 | objhistory = CalculateObj(X, U, V, Uo, C, alpha, beta); 45 | meanFit = objhistory*10; 46 | else 47 | if isfield(options,'Converge') && options.Converge 48 | objhistory = CalculateObj(X, U, V, Uo, C, alpha, beta); 49 | end 50 | end 51 | else 52 | if isfield(options,'Converge') && options.Converge 53 | error('Not implemented!'); 54 | end 55 | end 56 | 57 | 58 | 59 | tryNo = 0; 60 | while tryNo < nRepeat 61 | tmp_T = cputime; 62 | tryNo = tryNo+1; 63 | nIter = 0; 64 | maxErr = 1; 65 | nStepTrial = 0; 66 | %disp a 67 | while(maxErr > differror) 68 | % ===================== update U ======================== 69 | 70 | XV = (C.^2)*X*V; % mnk or pk (p< minIter 93 | if selectInit 94 | objhistory = CalculateObj(X, U, V, Uo, C, alpha, beta); 95 | maxErr = 0; 96 | else 97 | if isempty(maxIter) 98 | newobj = CalculateObj(X, U, V, Uo, C, alpha, beta); 99 | objhistory = [objhistory newobj]; 100 | meanFit = meanFitRatio*meanFit + (1-meanFitRatio)*newobj; 101 | maxErr = (meanFit-newobj)/meanFit; 102 | else 103 | if isfield(options,'Converge') && options.Converge 104 | newobj = CalculateObj(X, U, V, Uo, C, alpha, beta); 105 | objhistory = [objhistory newobj]; 106 | end 107 | maxErr = 1; 108 | if nIter >= maxIter 109 | maxErr = 0; 110 | if isfield(options,'Converge') && options.Converge 111 | else 112 | objhistory = 0; 113 | end 114 | end 115 | end 116 | end 117 | end 118 | end 119 | 120 | elapse = cputime - tmp_T; 121 | 122 | if tryNo == 1 123 | U_final = U; 124 | V_final = V; 125 | nIter_final = nIter; 126 | elapse_final = elapse; 127 | objhistory_final = objhistory; 128 | bSuccess.nStepTrial = nStepTrial; 129 | else 130 | if objhistory(end) < objhistory_final(end) 131 | U_final = U; 132 | V_final = V; 133 | nIter_final = nIter; 134 | objhistory_final = objhistory; 135 | bSuccess.nStepTrial = nStepTrial; 136 | if selectInit 137 | elapse_final = elapse; 138 | else 139 | elapse_final = elapse_final+elapse; 140 | end 141 | end 142 | end 143 | 144 | if selectInit 145 | if tryNo < nRepeat 146 | %re-start 147 | U = abs(rand(mFea,k)); 148 | V = abs(rand(nSmp,k)); 149 | [U,V] = Normalize(U, V); 150 | else 151 | tryNo = tryNo - 1; 152 | minIter = 0; 153 | selectInit = 0; 154 | U = U_final; 155 | V = V_final; 156 | objhistory = objhistory_final; 157 | meanFit = objhistory*10; 158 | 159 | end 160 | end 161 | end 162 | 163 | nIter_final = nIter_final + minIterOrig; 164 | [U_final, V_final] = Normalize(U_final, V_final); 165 | end 166 | 167 | %========================================================================== 168 | 169 | function [obj, dV] = CalculateObj(X, U, V, Uo, C, alpha, beta) 170 | tmp = C*(U-Uo); 171 | obj_Lap = sum(sum(tmp.^2)); 172 | dX = 
C*(U*V'-X); 173 | obj_NMF = sum(sum(dX.^2)); 174 | obj_L1 = sum(sum(abs(U))); 175 | tmp3 = 0; 176 | for k =1:size(U,2); 177 | tmp3 = tmp3 + norm(U(:,k)); 178 | end 179 | obj = obj_NMF+ alpha * obj_Lap + beta * tmp3; 180 | end 181 | 182 | function [U, V] = Normalize(U, V) 183 | nSmp = size(U,1); 184 | mFea = size(V,1); 185 | norms = sum(abs(V),1); 186 | norms = max(norms,1e-30); 187 | V = V./repmat(norms,mFea,1); 188 | U = bsxfun(@times, U, norms); 189 | end 190 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Multi-Incomplete-view-Clustering 2 | The MATLAB implementation for Multi-Incomplete-view Clustering (MIC) method proposed in Multiple Incomplete Views Clustering via Weighted Nonnegative Matrix Factorization with L2, 1 Regularization, ECML-PKDD 2015. 3 | http://link.springer.com/chapter/10.1007%2F978-3-319-23528-8_20 4 | 5 | Some of the code was originally written by Deng Cai ((dengcai AT gmail.com)), Jialu Liu (jliu64 AT illinois.edu) and Mo Chen (mochen AT ie.cuhk.edu.hk). 6 | -------------------------------------------------------------------------------- /bestMap.m: -------------------------------------------------------------------------------- 1 | function [newL2] = bestMap(L1,L2) 2 | %bestmap: permute labels of L2 match L1 as good as possible 3 | % [newL2] = bestMap(L1,L2); 4 | % Written by Deng Cai (dengcai AT gmail.com) 5 | 6 | %=========== 7 | L1 = L1(:); 8 | L2 = L2(:); 9 | if size(L1) ~= size(L2) 10 | error('size(L1) must == size(L2)'); 11 | end 12 | 13 | Label1 = unique(L1); 14 | nClass1 = length(Label1); 15 | Label2 = unique(L2); 16 | nClass2 = length(Label2); 17 | 18 | nClass = max(nClass1,nClass2); 19 | G = zeros(nClass); 20 | for i=1:nClass1 21 | for j=1:nClass2 22 | G(i,j) = length(find(L1 == Label1(i) & L2 == Label2(j))); 23 | end 24 | end 25 | [c,t] = hungarian(-G); 26 | newL2 = zeros(size(L2)); 27 | for i=1:nClass2 28 | newL2(L2 == Label2(i)) = Label1(c(i)); 29 | end 30 | 31 | 32 | return; 33 | 34 | %=======backup old=========== 35 | 36 | L1 = L1 - min(L1) + 1; % min (L1) <- 1; 37 | L2 = L2 - min(L2) + 1; % min (L2) <- 1; 38 | %=========== make bipartition graph ============ 39 | nClass = max(max(L1), max(L2)); 40 | G = zeros(nClass); 41 | for i=1:nClass 42 | for j=1:nClass 43 | G(i,j) = length(find(L1 == i & L2 == j)); 44 | end 45 | end 46 | %=========== assign with hungarian method ====== 47 | [c,t] = hungarian(-G); 48 | newL2 = zeros(nClass,1); 49 | for i=1:nClass 50 | newL2(L2 == i) = c(i); 51 | end 52 | -------------------------------------------------------------------------------- /hungarian.m: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/software-shao/Multi-Incomplete-view-Clustering/7562b88831e03fbfd304c02335fdda5c7eebae0c/hungarian.m -------------------------------------------------------------------------------- /litekmeans.m: -------------------------------------------------------------------------------- 1 | function [label, center, bCon, sumD, D] = litekmeans(X, k, varargin) 2 | %LITEKMEANS K-means clustering, accelerated by matlab matrix operations. 3 | % 4 | % label = LITEKMEANS(X, K) partitions the points in the N-by-P data matrix 5 | % X into K clusters. This partition minimizes the sum, over all 6 | % clusters, of the within-cluster sums of point-to-cluster-centroid 7 | % distances. 
Rows of X correspond to points, columns correspond to 8 | % variables. KMEANS returns an N-by-1 vector label containing the 9 | % cluster indices of each point. 10 | % 11 | % [label, center] = LITEKMEANS(X, K) returns the K cluster centroid 12 | % locations in the K-by-P matrix center. 13 | % 14 | % [label, center, bCon] = LITEKMEANS(X, K) returns the bool value bCon to 15 | % indicate whether the iteration is converged. 16 | % 17 | % [label, center, bCon, SUMD] = LITEKMEANS(X, K) returns the 18 | % within-cluster sums of point-to-centroid distances in the 1-by-K vector 19 | % sumD. 20 | % 21 | % [label, center, bCon, SUMD, D] = LITEKMEANS(X, K) returns 22 | % distances from each point to every centroid in the N-by-K matrix D. 23 | % 24 | % [ ... ] = LITEKMEANS(..., 'PARAM1',val1, 'PARAM2',val2, ...) specifies 25 | % optional parameter name/value pairs to control the iterative algorithm 26 | % used by KMEANS. Parameters are: 27 | % 28 | % 'Distance' - Distance measure, in P-dimensional space, that KMEANS 29 | % should minimize with respect to. Choices are: 30 | % {'sqEuclidean'} - Squared Euclidean distance (the default) 31 | % 'cosine' - One minus the cosine of the included angle 32 | % between points (treated as vectors). Each 33 | % row of X SHOULD be normalized to unit. If 34 | % the intial center matrix is provided, it 35 | % SHOULD also be normalized. 36 | % 37 | % 'Start' - Method used to choose initial cluster centroid positions, 38 | % sometimes known as "seeds". Choices are: 39 | % {'sample'} - Select K observations from X at random (the default) 40 | % matrix - A K-by-P matrix of starting locations; or a K-by-1 41 | % indicate vector indicating which K points in X 42 | % should be used as the initial center. In this case, 43 | % you can pass in [] for K, and KMEANS infers K from 44 | % the first dimension of the matrix. 45 | % 46 | % 'MaxIter' - Maximum number of iterations allowed. Default is 100. 47 | % 48 | % 'Replicates' - Number of times to repeat the clustering, each with a 49 | % new set of initial centroids. Default is 1. If the 50 | % initial centroids are provided, the replicate will be 51 | % automatically set to be 1. 
52 | % 53 | % 54 | % 55 | % Examples: 56 | % 57 | % fea = rand(500,10); 58 | % [label, center] = litekmeans(fea, 5, 'MaxIter', 50); 59 | % 60 | % fea = rand(500,10); 61 | % [label, center] = litekmeans(fea, 5, 'MaxIter', 50, 'Replicates', 10); 62 | % 63 | % fea = rand(500,10); 64 | % [label, center, bCon, sumD, D] = litekmeans(fea, 5, 'MaxIter', 50); 65 | % TSD = sum(sumD); 66 | % 67 | % fea = rand(500,10); 68 | % initcenter = rand(5,10); 69 | % [label, center] = litekmeans(fea, 5, 'MaxIter', 50, 'Start', initcenter); 70 | % 71 | % fea = rand(500,10); 72 | % idx=randperm(500); 73 | % [label, center] = litekmeans(fea, 5, 'MaxIter', 50, 'Start', idx(1:5)); 74 | % 75 | % 76 | % See also KMEANS 77 | % 78 | % version 2.0 --December/2011 79 | % version 1.0 --November/2011 80 | % 81 | % Written by Deng Cai (dengcai AT gmail.com) 82 | 83 | 84 | if nargin < 2 85 | error('litekmeans:TooFewInputs','At least two input arguments required.'); 86 | end 87 | 88 | [n, p] = size(X); 89 | 90 | 91 | pnames = { 'distance' 'start' 'maxiter' 'replicates' 'onlinephase'}; 92 | dflts = {'sqeuclidean' 'sample' [] [] 'off' }; 93 | [eid,errmsg,distance,start,maxit,reps,online] = getargs(pnames, dflts, varargin{:}); 94 | if ~isempty(eid) 95 | error(sprintf('litekmeans:%s',eid),errmsg); 96 | end 97 | 98 | if ischar(distance) 99 | distNames = {'sqeuclidean','cosine'}; 100 | j = strcmpi(distance, distNames); 101 | j = find(j); 102 | if length(j) > 1 103 | error('litekmeans:AmbiguousDistance', ... 104 | 'Ambiguous ''Distance'' parameter value: %s.', distance); 105 | elseif isempty(j) 106 | error('litekmeans:UnknownDistance', ... 107 | 'Unknown ''Distance'' parameter value: %s.', distance); 108 | end 109 | distance = distNames{j}; 110 | else 111 | error('litekmeans:InvalidDistance', ... 112 | 'The ''Distance'' parameter value must be a string.'); 113 | end 114 | 115 | 116 | center = []; 117 | if ischar(start) 118 | startNames = {'sample','cluster'}; 119 | j = find(strncmpi(start,startNames,length(start))); 120 | if length(j) > 1 121 | error(message('litekmeans:AmbiguousStart', start)); 122 | elseif isempty(j) 123 | error(message('litekmeans:UnknownStart', start)); 124 | elseif isempty(k) 125 | error('litekmeans:MissingK', ... 126 | 'You must specify the number of clusters, K.'); 127 | end 128 | if j == 2 129 | if floor(.1*n) < 5*k 130 | j = 1; 131 | end 132 | end 133 | start = startNames{j}; 134 | elseif isnumeric(start) 135 | if size(start,2) == p 136 | center = start; 137 | elseif (size(start,2) == 1 || size(start,1) == 1) 138 | center = X(start,:); 139 | else 140 | error('litekmeans:MisshapedStart', ... 141 | 'The ''Start'' matrix must have the same number of columns as X.'); 142 | end 143 | if isempty(k) 144 | k = size(center,1); 145 | elseif (k ~= size(center,1)) 146 | error('litekmeans:MisshapedStart', ... 147 | 'The ''Start'' matrix must have K rows.'); 148 | end 149 | start = 'numeric'; 150 | else 151 | error('litekmeans:InvalidStart', ... 152 | 'The ''Start'' parameter value must be a string or a numeric matrix or array.'); 153 | end 154 | 155 | % The maximum iteration number is default 100 156 | if isempty(maxit) 157 | maxit = 100; 158 | end 159 | 160 | % Assume one replicate 161 | if isempty(reps) || ~isempty(center) 162 | reps = 1; 163 | end 164 | 165 | if ~(isscalar(k) && isnumeric(k) && isreal(k) && k > 0 && (round(k)==k)) 166 | error('litekmeans:InvalidK', ... 167 | 'X must be a positive integer value.'); 168 | elseif n < k 169 | error('litekmeans:TooManyClusters', ... 
170 | 'X must have more rows than the number of clusters.'); 171 | end 172 | 173 | 174 | bestlabel = []; 175 | sumD = zeros(1,k); 176 | bCon = false; 177 | 178 | for t=1:reps 179 | switch start 180 | case 'sample' 181 | center = X(randsample(n,k),:); 182 | case 'cluster' 183 | Xsubset = X(randsample(n,floor(.1*n)),:); 184 | [dump, center] = litekmeans(Xsubset, k, varargin{:}, 'start','sample', 'replicates',1); 185 | case 'numeric' 186 | end 187 | 188 | last = 0;label=1; 189 | it=0; 190 | 191 | switch distance 192 | case 'sqeuclidean' 193 | while any(label ~= last) && it1 224 | if it>=maxit 225 | aa = full(sum(X.*X,2)); 226 | bb = full(sum(center.*center,2)); 227 | ab = full(X*center'); 228 | D = bsxfun(@plus,aa,bb') - 2*ab; 229 | D(D<0) = 0; 230 | else 231 | aa = full(sum(X.*X,2)); 232 | D = aa(:,ones(1,k)) + D; 233 | D(D<0) = 0; 234 | end 235 | D = sqrt(D); 236 | for j = 1:k 237 | sumD(j) = sum(D(label==j,j)); 238 | end 239 | bestsumD = sumD; 240 | bestD = D; 241 | end 242 | else 243 | if it>=maxit 244 | aa = full(sum(X.*X,2)); 245 | bb = full(sum(center.*center,2)); 246 | ab = full(X*center'); 247 | D = bsxfun(@plus,aa,bb') - 2*ab; 248 | D(D<0) = 0; 249 | else 250 | aa = full(sum(X.*X,2)); 251 | D = aa(:,ones(1,k)) + D; 252 | D(D<0) = 0; 253 | end 254 | D = sqrt(D); 255 | for j = 1:k 256 | sumD(j) = sum(D(label==j,j)); 257 | end 258 | if sum(sumD) < sum(bestsumD) 259 | bestlabel = label; 260 | bestcenter = center; 261 | bestsumD = sumD; 262 | bestD = D; 263 | end 264 | end 265 | case 'cosine' 266 | while any(label ~= last) && it1 291 | if any(label ~= last) 292 | W=full(X*center'); 293 | end 294 | D = 1-W; 295 | for j = 1:k 296 | sumD(j) = sum(D(label==j,j)); 297 | end 298 | bestsumD = sumD; 299 | bestD = D; 300 | end 301 | else 302 | if any(label ~= last) 303 | W=full(X*center'); 304 | end 305 | D = 1-W; 306 | for j = 1:k 307 | sumD(j) = sum(D(label==j,j)); 308 | end 309 | if sum(sumD) < sum(bestsumD) 310 | bestlabel = label; 311 | bestcenter = center; 312 | bestsumD = sumD; 313 | bestD = D; 314 | end 315 | end 316 | end 317 | end 318 | 319 | label = bestlabel; 320 | center = bestcenter; 321 | if reps>1 322 | sumD = bestsumD; 323 | D = bestD; 324 | elseif nargout > 3 325 | switch distance 326 | case 'sqeuclidean' 327 | if it>=maxit 328 | aa = full(sum(X.*X,2)); 329 | bb = full(sum(center.*center,2)); 330 | ab = full(X*center'); 331 | D = bsxfun(@plus,aa,bb') - 2*ab; 332 | D(D<0) = 0; 333 | else 334 | aa = full(sum(X.*X,2)); 335 | D = aa(:,ones(1,k)) + D; 336 | D(D<0) = 0; 337 | end 338 | D = sqrt(D); 339 | case 'cosine' 340 | if it>=maxit 341 | W=full(X*center'); 342 | end 343 | D = 1-W; 344 | end 345 | for j = 1:k 346 | sumD(j) = sum(D(label==j,j)); 347 | end 348 | end 349 | 350 | 351 | 352 | 353 | function [eid,emsg,varargout]=getargs(pnames,dflts,varargin) 354 | %GETARGS Process parameter name/value pairs 355 | % [EID,EMSG,A,B,...]=GETARGS(PNAMES,DFLTS,'NAME1',VAL1,'NAME2',VAL2,...) 356 | % accepts a cell array PNAMES of valid parameter names, a cell array 357 | % DFLTS of default values for the parameters named in PNAMES, and 358 | % additional parameter name/value pairs. Returns parameter values A,B,... 359 | % in the same order as the names in PNAMES. Outputs corresponding to 360 | % entries in PNAMES that are not specified in the name/value pairs are 361 | % set to the corresponding value from DFLTS. If nargout is equal to 362 | % length(PNAMES)+1, then unrecognized name/value pairs are an error. 
If 363 | % nargout is equal to length(PNAMES)+2, then all unrecognized name/value 364 | % pairs are returned in a single cell array following any other outputs. 365 | % 366 | % EID and EMSG are empty if the arguments are valid. If an error occurs, 367 | % EMSG is the text of an error message and EID is the final component 368 | % of an error message id. GETARGS does not actually throw any errors, 369 | % but rather returns EID and EMSG so that the caller may throw the error. 370 | % Outputs will be partially processed after an error occurs. 371 | % 372 | % This utility can be used for processing name/value pair arguments. 373 | % 374 | % Example: 375 | % pnames = {'color' 'linestyle', 'linewidth'} 376 | % dflts = { 'r' '_' '1'} 377 | % varargin = {{'linew' 2 'nonesuch' [1 2 3] 'linestyle' ':'} 378 | % [eid,emsg,c,ls,lw] = statgetargs(pnames,dflts,varargin{:}) % error 379 | % [eid,emsg,c,ls,lw,ur] = statgetargs(pnames,dflts,varargin{:}) % ok 380 | 381 | % We always create (nparams+2) outputs: 382 | % one each for emsg and eid 383 | % nparams varargs for values corresponding to names in pnames 384 | % If they ask for one more (nargout == nparams+3), it's for unrecognized 385 | % names/values 386 | 387 | % Original Copyright 1993-2008 The MathWorks, Inc. 388 | % Modified by Deng Cai (dengcai@gmail.com) 2011.11.27 389 | 390 | 391 | 392 | 393 | % Initialize some variables 394 | emsg = ''; 395 | eid = ''; 396 | nparams = length(pnames); 397 | varargout = dflts; 398 | unrecog = {}; 399 | nargs = length(varargin); 400 | 401 | % Must have name/value pairs 402 | if mod(nargs,2)~=0 403 | eid = 'WrongNumberArgs'; 404 | emsg = 'Wrong number of arguments.'; 405 | else 406 | % Process name/value pairs 407 | for j=1:2:nargs 408 | pname = varargin{j}; 409 | if ~ischar(pname) 410 | eid = 'BadParamName'; 411 | emsg = 'Parameter name must be text.'; 412 | break; 413 | end 414 | i = strcmpi(pname,pnames); 415 | i = find(i); 416 | if isempty(i) 417 | % if they've asked to get back unrecognized names/values, add this 418 | % one to the list 419 | if nargout > nparams+2 420 | unrecog((end+1):(end+2)) = {varargin{j} varargin{j+1}}; 421 | % otherwise, it's an error 422 | else 423 | eid = 'BadParamName'; 424 | emsg = sprintf('Invalid parameter name: %s.',pname); 425 | break; 426 | end 427 | elseif length(i)>1 428 | eid = 'BadParamName'; 429 | emsg = sprintf('Ambiguous parameter name: %s.',pname); 430 | break; 431 | else 432 | varargout{i} = varargin{j+1}; 433 | end 434 | end 435 | end 436 | 437 | varargout{nparams+1} = unrecog; 438 | -------------------------------------------------------------------------------- /nmi.m: -------------------------------------------------------------------------------- 1 | function v = nmi(label, result) 2 | % Nomalized mutual information 3 | % Written by Mo Chen (mochen@ie.cuhk.edu.hk). March 2009. 
4 | %assert(length(label) == length(result)); 5 | 6 | label = label(:); 7 | result = result(:); 8 | 9 | n = length(label); 10 | 11 | label_unique = unique(label); 12 | result_unique = unique(result); 13 | 14 | % check the integrity of result 15 | %if length(label_unique) ~= length(result_unique) 16 | % error('The clustering result is not consistent with label.'); 17 | %end; 18 | 19 | c = length(label_unique); 20 | 21 | % distribution of result and label 22 | Ml = double(repmat(label,1,c) == repmat(label_unique',n,1)); 23 | %Mr = double(repmat(result,1,c) == repmat(result_unique',n,1)); 24 | Mr = double(repmat(result,1,c) == repmat(label_unique',n,1)); 25 | Pl = sum(Ml)/n; 26 | Pr = sum(Mr)/n; 27 | 28 | % entropy of Pr and Pl 29 | Hl = -sum( Pl .* log2( Pl + eps ) ); 30 | Hr = -sum( Pr .* log2( Pr + eps ) ); 31 | 32 | 33 | % joint entropy of Pr and Pl 34 | % M = zeros(c); 35 | % for I = 1:c 36 | % for J = 1:c 37 | % M(I,J) = sum(result==result_unique(I)&label==label_unique(J)); 38 | % end; 39 | % end; 40 | % M = M / n; 41 | M = Ml'*Mr/n; 42 | Hlr = -sum( M(:) .* log2( M(:) + eps ) ); 43 | 44 | % mutual information 45 | MI = Hl + Hr - Hlr; 46 | 47 | % normalized mutual information 48 | v = sqrt((MI/Hl)*(MI/Hr)) ; -------------------------------------------------------------------------------- /printResult.m: -------------------------------------------------------------------------------- 1 | function [ac] = printResult(X, label, K, kmeansFlag) 2 | 3 | if kmeansFlag == 1 4 | indic = litekmeans(X, K, 'Replicates',20); 5 | else 6 | [~, indic] = max(X, [] ,2); 7 | end 8 | result = bestMap(label, indic); 9 | [ac, nmi_value, cnt] = CalcMetrics(label, indic); 10 | disp(sprintf('ac: %0.4f\t%d/%d\tnmi:%0.4f\t', ac, cnt, length(label), nmi_value)); --------------------------------------------------------------------------------
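A minimal driver showing how the pieces above fit together. This script is not part of the repository; the data layout, the construction of C, and every parameter value below are illustrative assumptions (alpha and beta in particular must be tuned per dataset, as noted in MultiNMF_incomplete_original_l21.m):

% Assumed inputs: X{i} is the nSmp x mFea_i matrix of view i, label is the
% nSmp x 1 ground truth, and C{i} is the nSmp x nSmp diagonal weight matrix
% of view i (one way to build C is sketched after the MultiNMF_incomplete_original_l21.m listing above).
K = length(unique(label));                      % number of clusters / hidden factors
viewNum = length(X);

options = struct();
options.rounds       = 30;                      % outer MultiNMF rounds (assumed)
options.maxIter      = 100;                     % inner NMF iterations (assumed)
options.minIter      = 30;                      % assumed
options.meanFitRatio = 0.1;                     % assumed
options.error        = 1e-6;                    % inner convergence tolerance (assumed)
options.nRepeat      = 10;                      % random restarts inside NMF (assumed)
options.Converge     = 0;
options.kmeans       = 1;                       % cluster factors with litekmeans in printResult
options.alpha        = 0.01 * ones(1, viewNum); % per-view disagreement weights (tune)
options.beta         = 0.01 * ones(1, viewNum); % per-view L2,1 weights (tune)

[U, V, centroidU, obj, acc] = MultiNMF_incomplete_original_l21(X, C, K, label, options);

% Cluster the consensus representation and evaluate against the ground truth.
idx = litekmeans(centroidU, K, 'Replicates', 20);
[AC, nmi_value] = CalcMetrics(label, idx);
fprintf('final AC = %.4f, NMI = %.4f\n', AC, nmi_value);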