├── Dataset ├── HW.mat ├── BBC.mat ├── Mfeat.mat ├── NGs.mat ├── WebKB.mat ├── Hdigit.mat ├── TwoMoon.mat ├── 100leaves.mat ├── 3sources.mat ├── BBCSport.mat ├── HW2sources.mat └── ThreeRing.mat ├── CalcMeasures ├── hungarian.m ├── CalcMeasures.m ├── bestMap.m ├── compute_nmi.m └── valid_RandIndex.m ├── README.md ├── funs ├── L2_distance_1.m ├── eig1.m ├── InitializeSIGs.m ├── SloutionToP19.m └── EuDist2.m ├── Run_GMC.m ├── GMC.m └── Run_GMC_ToyExamples.m /Dataset/HW.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshaowang/gmc/HEAD/Dataset/HW.mat -------------------------------------------------------------------------------- /Dataset/BBC.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshaowang/gmc/HEAD/Dataset/BBC.mat -------------------------------------------------------------------------------- /Dataset/Mfeat.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshaowang/gmc/HEAD/Dataset/Mfeat.mat -------------------------------------------------------------------------------- /Dataset/NGs.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshaowang/gmc/HEAD/Dataset/NGs.mat -------------------------------------------------------------------------------- /Dataset/WebKB.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshaowang/gmc/HEAD/Dataset/WebKB.mat -------------------------------------------------------------------------------- /Dataset/Hdigit.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshaowang/gmc/HEAD/Dataset/Hdigit.mat -------------------------------------------------------------------------------- /Dataset/TwoMoon.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshaowang/gmc/HEAD/Dataset/TwoMoon.mat -------------------------------------------------------------------------------- /Dataset/100leaves.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshaowang/gmc/HEAD/Dataset/100leaves.mat -------------------------------------------------------------------------------- /Dataset/3sources.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshaowang/gmc/HEAD/Dataset/3sources.mat -------------------------------------------------------------------------------- /Dataset/BBCSport.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshaowang/gmc/HEAD/Dataset/BBCSport.mat -------------------------------------------------------------------------------- /Dataset/HW2sources.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshaowang/gmc/HEAD/Dataset/HW2sources.mat -------------------------------------------------------------------------------- /Dataset/ThreeRing.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshaowang/gmc/HEAD/Dataset/ThreeRing.mat -------------------------------------------------------------------------------- /CalcMeasures/hungarian.m: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshaowang/gmc/HEAD/CalcMeasures/hungarian.m -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GMC: Graph-based Multi-view Clustering 2 | 3 | This repo hosts the code for paper "GMC: Graph-based Multi-view Clustering", IEEE TKDE, 2019. 4 | -------------------------------------------------------------------------------- /funs/L2_distance_1.m: -------------------------------------------------------------------------------- 1 | % compute squared Euclidean distance 2 | % ||A-B||^2 = ||A||^2 + ||B||^2 - 2*A'*B 3 | function d = L2_distance_1(a,b) 4 | % a,b: two matrices. each column is a data 5 | % d: distance matrix of a and b 6 | 7 | if (size(a,1) == 1) 8 | a = [a; zeros(1,size(a,2))]; 9 | b = [b; zeros(1,size(b,2))]; 10 | end; 11 | 12 | aa=sum(a.*a); bb=sum(b.*b); ab=a'*b; 13 | d = repmat(aa',[1 size(bb,2)]) + repmat(bb,[size(aa,2) 1]) - 2*ab; 14 | 15 | d = real(d); 16 | d = max(d,0); 17 | d = d - diag(diag(d)); % force 0 on the diagonal 18 | -------------------------------------------------------------------------------- /CalcMeasures/CalcMeasures.m: -------------------------------------------------------------------------------- 1 | function result = CalcMeasures(Y, predY) 2 | % result = [ACC, NMI, ARI, error_cnt]; 3 | if size(Y,2) ~= 1 4 | Y = Y'; 5 | end; 6 | if size(predY,2) ~= 1 7 | predY = predY'; 8 | end; 9 | 10 | % bestMap 11 | predY = bestMap(Y, predY); 12 | if size(Y)~=size(predY) 13 | predY=predY'; 14 | end 15 | 16 | error_cnt = sum(Y ~= predY); 17 | AC = length(find(Y == predY))/length(Y); 18 | [~,nmi_value,~] = compute_nmi(Y', predY'); 19 | [ARI,~,~,~] = valid_RandIndex(Y', predY'); 20 | 21 | result = [AC, nmi_value, ARI, error_cnt]; 22 | -------------------------------------------------------------------------------- /funs/eig1.m: -------------------------------------------------------------------------------- 1 | function [eigvec, eigval, eigval_full] = eig1(A, c, isMax, isSym) 2 | 3 | if nargin < 2 4 | c = size(A,1); 5 | isMax = 1; 6 | isSym = 1; 7 | elseif c > size(A,1) 8 | c = size(A,1); 9 | end; 10 | 11 | if nargin < 3 12 | isMax = 1; 13 | isSym = 1; 14 | end; 15 | 16 | if nargin < 4 17 | isSym = 1; 18 | end; 19 | 20 | if isSym == 1 21 | A = max(A,A'); 22 | end; 23 | [v, d] = eig(A); 24 | d = diag(d); 25 | %d = real(d); 26 | if isMax == 0 27 | [d1, idx] = sort(d); 28 | else 29 | [d1, idx] = sort(d,'descend'); 30 | end; 31 | 32 | idx1 = idx(1:c); 33 | eigval = d(idx1); 34 | eigvec = v(:,idx1); 35 | 36 | eigval_full = d(idx); -------------------------------------------------------------------------------- /funs/InitializeSIGs.m: -------------------------------------------------------------------------------- 1 | function [S, D] = InitializeSIGs(X, k, issymmetric) 2 | % X: each column is a data point 3 | % k: number of neighbors 4 | % issymmetric: set S = (S+S')/2 if issymmetric=1 5 | % S: similarity matrix, each row is a data point 6 | % Ref: F. Nie, X. Wang, M. I. Jordan, and H. Huang, The constrained 7 | % Laplacian rank algorithm for graph-based clustering, in AAAI, 2016. 8 | 9 | if nargin < 3 10 | issymmetric = 1; 11 | end; 12 | if nargin < 2 13 | k = 5; 14 | end; 15 | 16 | [~, n] = size(X); 17 | D = L2_distance_1(X, X); 18 | [~, idx] = sort(D, 2); % sort each row 19 | 20 | S = zeros(n); 21 | for i = 1:n 22 | id = idx(i,2:k+2); 23 | di = D(i, id); 24 | S(i,id) = (di(k+1)-di)/(k*di(k+1)-sum(di(1:k))+eps); 25 | end; 26 | 27 | if issymmetric == 1 28 | S = (S+S')/2; 29 | end; -------------------------------------------------------------------------------- /funs/SloutionToP19.m: -------------------------------------------------------------------------------- 1 | % 2 | % min 1/2 sum_v|| s - qv||^2 3 | % s.t. s>=0, 1's=1 4 | function [x, ft] = SloutionToP19(q0, m) 5 | 6 | if nargin < 2 7 | m = 1; 8 | end; 9 | ft=1; 10 | n = length(q0); 11 | p0 = sum(q0,1)/m-mean(sum(q0,1))/m + 1/n; 12 | vmin = min(p0); 13 | if vmin < 0 14 | f = 1; 15 | lambda_m = 0; 16 | while abs(f) > 10^-10 17 | v1 = lambda_m-p0; 18 | posidx = v1>0; 19 | npos = sum(posidx); 20 | g = npos/n-1; 21 | if 0 == g 22 | g = eps; 23 | end; 24 | f = sum(v1(posidx))/n - lambda_m; 25 | lambda_m = lambda_m - f/g; 26 | ft=ft+1; 27 | if ft > 100 28 | x = max(-v1,0); 29 | break; 30 | end; 31 | end; 32 | x = max(-v1,0); 33 | else 34 | x = p0; 35 | end; 36 | -------------------------------------------------------------------------------- /CalcMeasures/bestMap.m: -------------------------------------------------------------------------------- 1 | function [newL2] = bestMap(L1,L2) 2 | %bestmap: permute labels of L2 match L1 as good as possible 3 | % [newL2] = bestMap(L1,L2); 4 | 5 | %=========== 6 | L1 = L1(:); 7 | L2 = L2(:); 8 | if size(L1) ~= size(L2) 9 | error('size(L1) must == size(L2)'); 10 | end 11 | 12 | Label1 = unique(L1); 13 | nClass1 = length(Label1); 14 | Label2 = unique(L2); 15 | nClass2 = length(Label2); 16 | 17 | nClass = max(nClass1,nClass2); 18 | G = zeros(nClass); 19 | for i=1:nClass1 20 | for j=1:nClass2 21 | G(i,j) = length(find(L1 == Label1(i) & L2 == Label2(j))); 22 | end 23 | end 24 | [c,t] = hungarian(-G); 25 | newL2 = zeros(size(L2)); 26 | for i=1:nClass2 27 | newL2(L2 == Label2(i)) = Label1(c(i)); 28 | end 29 | 30 | 31 | return; 32 | 33 | %=======backup old=========== 34 | 35 | L1 = L1 - min(L1) + 1; % min (L1) <- 1; 36 | L2 = L2 - min(L2) + 1; % min (L2) <- 1; 37 | %=========== make bipartition graph ============ 38 | nClass = max(max(L1), max(L2)); 39 | G = zeros(nClass); 40 | for i=1:nClass 41 | for j=1:nClass 42 | G(i,j) = length(find(L1 == i & L2 == j)); 43 | end 44 | end 45 | %=========== assign with hungarian method ====== 46 | [c,t] = hungarian(-G); 47 | newL2 = zeros(nClass,1); 48 | for i=1:nClass 49 | newL2(L2 == i) = c(i); 50 | end -------------------------------------------------------------------------------- /funs/EuDist2.m: -------------------------------------------------------------------------------- 1 | function D = EuDist2(fea_a,fea_b,bSqrt) 2 | %EUDIST2 Efficiently Compute the Euclidean Distance Matrix by Exploring the 3 | %Matlab matrix operations. 4 | % 5 | % D = EuDist(fea_a,fea_b) 6 | % fea_a: nSample_a * nFeature 7 | % fea_b: nSample_b * nFeature 8 | % D: nSample_a * nSample_a 9 | % or nSample_a * nSample_b 10 | % 11 | % Examples: 12 | % 13 | % a = rand(500,10); 14 | % b = rand(1000,10); 15 | % 16 | % A = EuDist2(a); % A: 500*500 17 | % D = EuDist2(a,b); % D: 500*1000 18 | % 19 | % version 2.1 --November/2011 20 | % version 2.0 --May/2009 21 | % version 1.0 --November/2005 22 | % 23 | % Written by Deng Cai (dengcai AT gmail.com) 24 | 25 | 26 | if ~exist('bSqrt','var') 27 | bSqrt = 1; 28 | end 29 | 30 | if (~exist('fea_b','var')) || isempty(fea_b) 31 | aa = sum(fea_a.*fea_a,2); 32 | ab = fea_a*fea_a'; 33 | 34 | if issparse(aa) 35 | aa = full(aa); 36 | end 37 | 38 | D = bsxfun(@plus,aa,aa') - 2*ab; 39 | D(D<0) = 0; 40 | if bSqrt 41 | D = sqrt(D); 42 | end 43 | D = max(D,D'); 44 | else 45 | aa = sum(fea_a.*fea_a,2); 46 | bb = sum(fea_b.*fea_b,2); 47 | ab = fea_a*fea_b'; 48 | 49 | if issparse(aa) 50 | aa = full(aa); 51 | bb = full(bb); 52 | end 53 | 54 | D = bsxfun(@plus,aa,bb') - 2*ab; 55 | D(D<0) = 0; 56 | if bSqrt 57 | D = sqrt(D); 58 | end 59 | end 60 | 61 | -------------------------------------------------------------------------------- /Run_GMC.m: -------------------------------------------------------------------------------- 1 | %% Experiments on real-world data sets 2 | % Graph-based Multi-view Clustering (GMC) 3 | % 4 | %% 5 | clc; close all; clear all; 6 | currentFolder = pwd; 7 | addpath(genpath(currentFolder)); 8 | resultdir = 'Results/'; 9 | if(~exist('Results','file')) 10 | mkdir('Results'); 11 | addpath(genpath('Results/')); 12 | end 13 | dataname = {'100leaves','3sources','BBC','BBCSport','HW','HW2sources','NGs','WebKB','Hdigit','Mfeat'}; 14 | runtimes = 1; % run-times on each dataset, default: 1 15 | numdata = length(dataname); 16 | 17 | for cdata = 1:numdata 18 | %% read dataset 19 | idata = cdata; 20 | datadir = 'Dataset/'; 21 | dataf = [datadir, cell2mat(dataname(idata))]; 22 | load(dataf); 23 | 24 | X = data; 25 | y0 = truelabel{1}; 26 | c = length(unique(truelabel{1})); 27 | %% iteration ... 28 | for rtimes = 1:runtimes 29 | [y, U, S0, S0_initial, F, evs] = GMC(X, c); % c: the # of clusters 30 | metric = CalcMeasures(y0, y); 31 | ACC(rtimes) = metric(1); 32 | NMI(rtimes) = metric(2); 33 | ARI(rtimes) = metric(3); 34 | error_cnt(rtimes) = metric(4); 35 | disp(char(dataname(idata))); 36 | fprintf('=====In iteration %d=====\nACC:%.4f\tNMI:%.4f\tARI:%.4f\terror_cnt:%d\n',rtimes,metric(1),metric(2),metric(3),metric(4)); 37 | end; 38 | Result(1,:) = ACC; 39 | Result(2,:) = NMI; 40 | Result(3,:) = ARI; 41 | Result(4,1) = mean(ACC); 42 | Result(4,2) = mean(NMI); 43 | Result(4,3) = mean(ARI); 44 | Result(5,1) = std(ACC); 45 | Result(5,2) = std(NMI); 46 | Result(5,3) = std(ARI); 47 | save([resultdir,char(dataname(idata)),'_result.mat'],'Result','U','y0','y'); 48 | clear ACC NMI ARI metric Result U y0 y; 49 | end; 50 | -------------------------------------------------------------------------------- /CalcMeasures/compute_nmi.m: -------------------------------------------------------------------------------- 1 | function [A nmi avgent] = compute_nmi (T, H) 2 | 3 | N = length(T); 4 | classes = unique(T); 5 | clusters = unique(H); 6 | num_class = length(classes); 7 | num_clust = length(clusters); 8 | 9 | %%compute number of points in each class 10 | for j=1:num_class 11 | index_class = (T(:)==classes(j)); 12 | D(j) = sum(index_class); 13 | end 14 | 15 | %%mutual information 16 | mi = 0; 17 | A = zeros(num_clust, num_class); 18 | avgent = 0; 19 | for i=1:num_clust 20 | %number of points in cluster 'i' 21 | index_clust = (H(:)==clusters(i)); 22 | B(i) = sum(index_clust); 23 | for j=1:num_class 24 | index_class = (T(:)==classes(j)); 25 | %%compute number of points in class 'j' that end up in cluster 'i' 26 | A(i,j) = sum(index_class.*index_clust); 27 | if (A(i,j) ~= 0) 28 | miarr(i,j) = A(i,j)/N * log2 (N*A(i,j)/(B(i)*D(j))); 29 | %%average entropy calculation 30 | avgent = avgent - (B(i)/N) * (A(i,j)/B(i)) * log2 (A(i,j)/B(i)); 31 | else 32 | miarr(i,j) = 0; 33 | end 34 | mi = mi + miarr(i,j); 35 | 36 | 37 | 38 | end 39 | end 40 | 41 | %%class entropy 42 | class_ent = 0; 43 | for i=1:num_class 44 | class_ent = class_ent + D(i)/N * log2(N/D(i)); 45 | end 46 | 47 | %%clustering entropy 48 | clust_ent = 0; 49 | for i=1:num_clust 50 | clust_ent = clust_ent + B(i)/N * log2(N/B(i)); 51 | end 52 | 53 | %%normalized mutual information 54 | nmi = 2*mi / (clust_ent + class_ent); -------------------------------------------------------------------------------- /CalcMeasures/valid_RandIndex.m: -------------------------------------------------------------------------------- 1 | function [AR,RI,MI,HI]=valid_RandIndex(c1,c2) 2 | % ========================================================================= 3 | % RANDINDEX - calculates Rand Indices to compare two partitions 4 | % ARI=RANDINDEX(c1,c2), where c1,c2 are vectors listing the 5 | % class membership, returns the "Hubert & Arabie adjusted Rand index". 6 | % [AR,RI,MI,HI]=RANDINDEX(c1,c2) returns the adjusted Rand index, 7 | % the unadjusted Rand index, "Mirkin's" index and "Hubert's" index. 8 | % 9 | % See L. Hubert and P. Arabie (1985) "Comparing Partitions" Journal of 10 | % Classification 2:193-218 11 | % ========================================================================= 12 | %(C) David Corney (2000) D.Corney@cs.ucl.ac.uk 13 | % ========================================================================= 14 | 15 | if nargin < 2 | min(size(c1)) > 1 | min(size(c2)) > 1 16 | error('RandIndex: Requires two vector arguments') 17 | return 18 | end 19 | 20 | C=Contingency(c1,c2); %form contingency matrix 21 | 22 | n=sum(sum(C)); 23 | nis=sum(sum(C,2).^2); %sum of squares of sums of rows 24 | njs=sum(sum(C,1).^2); %sum of squares of sums of columns 25 | 26 | t1=nchoosek(n,2); %total number of pairs of entities 27 | t2=sum(sum(C.^2)); %sum over rows & columnns of nij^2 28 | t3=.5*(nis+njs); 29 | 30 | %Expected index (for adjustment) 31 | nc=(n*(n^2+1)-(n+1)*nis-(n+1)*njs+2*(nis*njs)/n)/(2*(n-1)); 32 | 33 | A=t1+t2-t3; %no. agreements 34 | D= -t2+t3; %no. disagreements 35 | 36 | if t1==nc 37 | AR=0; %avoid division by zero; if k=1, define Rand = 0 38 | else 39 | AR=(A-nc)/(t1-nc); %adjusted Rand - Hubert & Arabie 1985 40 | end 41 | 42 | RI=A/t1; %Rand 1971 %Probability of agreement 43 | MI=D/t1; %Mirkin 1970 %p(disagreement) 44 | HI=(A-D)/t1; %Hubert 1977 %p(agree)-p(disagree) 45 | 46 | function Cont=Contingency(Mem1,Mem2) 47 | 48 | if nargin < 2 | min(size(Mem1)) > 1 | min(size(Mem2)) > 1 49 | error('Contingency: Requires two vector arguments') 50 | return 51 | end 52 | 53 | Cont=zeros(max(Mem1),max(Mem2)); 54 | 55 | for i = 1:length(Mem1); 56 | Cont(Mem1(i),Mem2(i))=Cont(Mem1(i),Mem2(i))+1; 57 | end 58 | -------------------------------------------------------------------------------- /GMC.m: -------------------------------------------------------------------------------- 1 | % 2 | %% min sum_v{sum_i{||x_i - x_j||^2*s_ij + alpha*||s_i||^2} + w_v||U - Sv||^2 + lambda*trace(F'*Lu*F)} 3 | % s.t Sv>=0, 1^T*Sv_i=1, U>=0, 1^T*Ui=1, F'*F=I 4 | % 5 | function [y, U, S0, S0_initial, F, evs] = GMC(X, c, lambda, normData) 6 | %% input: 7 | % X{}: multi-view dataset, each cell is a view, each column is a data point 8 | % c: cluster number 9 | % lambda: parameter (default 1) 10 | %% output: 11 | % S0: similarity-induced graph (SIG) matrix for each view 12 | % y: the final clustering result, i.e., cluster indicator vector 13 | % U: the learned unified matrix 14 | % F: the embedding representation 15 | % evs: eigenvalues of learned graph Laplacian matrix 16 | 17 | NITER = 20; 18 | zr = 10e-11; 19 | pn = 15; % number of neighbours for constructS_PNG 20 | islocal = 1; % only update the similarities of neighbors if islocal=1 21 | if nargin < 3 22 | lambda = 1; 23 | end; 24 | if nargin < 4 25 | normData = 1; 26 | end; 27 | 28 | num = size(X{1},2); % number of instances 29 | m = length(X); % number of views 30 | %% Normalization: Z-score 31 | if normData == 1 32 | for i = 1:m 33 | for j = 1:num 34 | normItem = std(X{i}(:,j)); 35 | if (0 == normItem) 36 | normItem = eps; 37 | end; 38 | X{i}(:,j) = (X{i}(:,j)-mean(X{i}(:,j)))/(normItem); 39 | end; 40 | end; 41 | end; 42 | 43 | %% initialize S0: Constructing the SIG matrices 44 | S0 = cell(1,m); 45 | for i = 1:m 46 | [S0{i}, ~] = InitializeSIGs(X{i}, pn, 0); 47 | end; 48 | S0_initial = S0; 49 | 50 | %% initialize U, F and w 51 | U = zeros(num); 52 | for i = 1:m 53 | U = U + S0{i}; 54 | end; 55 | U = U/m; 56 | for j = 1:num 57 | U(j,:) = U(j,:)/sum(U(j,:)); 58 | end; 59 | % % choose the top-k neighbors 60 | % [~, ids] = sort(U,2,'descend'); 61 | % ts = zeros(num); 62 | % for i =1:num 63 | % ts(i,ids(i,1:pn)) = U(i,ids(i,1:pn)); 64 | % end 65 | % for j = 1:num 66 | % ts(j,:) = ts(j,:)/sum(ts(j,:)); 67 | % end 68 | % U = ts; 69 | 70 | sU = (U+U')/2; 71 | D = diag(sum(sU)); 72 | L = D - sU; 73 | [F, ~, evs]=eig1(L, c, 0); 74 | 75 | w = ones(1,m)/m; 76 | 77 | idxx = cell(1,m); 78 | ed = cell(1,m); 79 | for v = 1:m 80 | ed{v} = L2_distance_1(X{v}, X{v}); 81 | [~, idxx{v}] = sort(ed{v}, 2); % sort each row 82 | end; 83 | 84 | %% update ... 85 | for iter = 1:NITER 86 | % update S^v 87 | for v = 1:m 88 | S0{v} = zeros(num); 89 | for i = 1:num 90 | id = idxx{v}(i,2:pn+2); 91 | di = ed{v}(i, id); 92 | numerator = di(pn+1)-di+2*w(v)*U(i,id(:))-2*w(v)*U(i,id(pn+1)); 93 | denominator1 = pn*di(pn+1)-sum(di(1:pn)); 94 | denominator2 = 2*w(v)*sum(U(i,id(1:pn)))-2*pn*w(v)*U(i,id(pn+1)); 95 | S0{v}(i,id) = max(numerator/(denominator1+denominator2+eps),0); 96 | end; 97 | % for j = 1:num 98 | % normItem = sum(S0{v}(j,:)); 99 | % if normItem == 0 100 | % normItem = eps; 101 | % end; 102 | % S0{v}(j,:) = S0{v}(j,:)/normItem; 103 | % end; 104 | end; 105 | % update w 106 | for v = 1:m 107 | US = U - S0{v}; 108 | distUS = norm(US, 'fro')^2; 109 | if distUS == 0 110 | distUS = eps; 111 | end; 112 | w(v) = 0.5/sqrt(distUS); 113 | end; 114 | % disp(['weights: ',num2str(w)]); 115 | % update U 116 | dist = L2_distance_1(F',F'); 117 | U = zeros(num); 118 | for i=1:num 119 | idx = zeros(); 120 | for v = 1:m 121 | s0 = S0{v}(i,:); 122 | idx = [idx,find(s0>0)]; 123 | end; 124 | idxs = unique(idx(2:end)); 125 | if islocal == 1 126 | idxs0 = idxs; 127 | else 128 | idxs0 = 1:num; 129 | end; 130 | for v = 1:m 131 | s1 = S0{v}(i,:); 132 | si = s1(idxs0); 133 | di = dist(i,idxs0); 134 | mw = m*w(v); 135 | lmw = lambda/mw; 136 | q(v,:) = si-0.5*lmw*di; 137 | end; 138 | U(i,idxs0) = SloutionToP19(q,m); 139 | clear q; 140 | end; 141 | % % choose the top-k neighbors 142 | % [~, ids] = sort(U,2,'descend'); 143 | % ts = zeros(num); 144 | % for i =1:num 145 | % ts(i,ids(i,1:pn)) = U(i,ids(i,1:pn)); 146 | % end 147 | % for j = 1:num 148 | % ts(j,:) = ts(j,:)/sum(ts(j,:)); 149 | % end 150 | % sU = ts; 151 | % update F 152 | sU = U; 153 | sU = (sU+sU')/2; 154 | D = diag(sum(sU)); 155 | L = D-sU; 156 | F_old = F; 157 | [F, ~, ev]=eig1(L, c, 0, 0); 158 | evs(:,iter+1) = ev; 159 | % update lambda and the stopping criterion 160 | fn1 = sum(ev(1:c)); 161 | fn2 = sum(ev(1:c+1)); 162 | if fn1 > zr 163 | lambda = 2*lambda; 164 | elseif fn2 < zr 165 | lambda = lambda/2; 166 | F = F_old; 167 | else 168 | disp(['iter = ',num2str(iter),' lambda:',num2str(lambda)]); 169 | break; 170 | end; 171 | end; 172 | %% generating the clustering result 173 | [clusternum, y]=graphconncomp(sparse(sU)); y = y'; 174 | if clusternum ~= c 175 | fprintf('Can not find the correct cluster number: %d\n', c) 176 | end; 177 | 178 | 179 | -------------------------------------------------------------------------------- /Run_GMC_ToyExamples.m: -------------------------------------------------------------------------------- 1 | %% Toy examples on Two-Moon data set and Three-Ring data set 2 | % Graph-based Multi-view Clustering (GMC) 3 | % 4 | %% 5 | clc; close all; clear all; 6 | currentFolder = pwd; 7 | addpath(genpath(currentFolder)); 8 | 9 | %% Load toy data set 10 | datadir = 'Dataset/'; 11 | m = 2; % number of views 12 | X = cell(1,m); 13 | dataname = input('Input the name of data set: (TwoMoon or ThreeRing)\n','s'); 14 | while(1) 15 | if strcmp(dataname,'TwoMoon') 16 | dataf = [datadir, dataname]; 17 | c = 2; 18 | load(dataf); 19 | break; 20 | elseif strcmp(dataname,'ThreeRing') 21 | dataf = [datadir, dataname]; 22 | c = 3; 23 | load(dataf); 24 | flag = 1; 25 | break; 26 | else 27 | dataType = input('Please only input TwoMoon or ThreeRing\n','s'); 28 | end; 29 | 30 | end 31 | 32 | %% Call GMC algorithm 33 | num = size(X{1},1); % the number of samples 34 | data = cell(1,m); 35 | for i = 1:m 36 | data{i} = X{i}'; 37 | end 38 | [predY, U, S0, S0_initial, F, evs] = GMC(data, c, 1, 0); 39 | metric = CalcMeasures(y0(:,1), predY); 40 | fprintf('Data set %s-> ACC:%.4f\tNMI:%.4f\tARI:%.4f\terror_cnt:%d\n',dataname,metric(1),metric(2),metric(3),metric(4)); 41 | 42 | markerSize = 20; 43 | %% Original data 44 | for v = 1:m 45 | lab = y0(:,v); 46 | cLab = unique(lab); 47 | figure; 48 | plot(X{v}(:,1),X{v}(:,2),'.k', 'MarkerSize', markerSize); hold on; 49 | plot(X{v}(lab==cLab(1),1),X{v}(lab==cLab(1),2),'.r', 'MarkerSize', 20); hold on; 50 | plot(X{v}(lab==cLab(2),1),X{v}(lab==cLab(2),2),'.', 'MarkerSize', markerSize); hold on; 51 | if flag 52 | plot(X{v}(lab==cLab(3),1),X{v}(lab==cLab(3),2),'.', 'Color', [79 79 79]/255, 'MarkerSize', markerSize); hold on; 53 | end; 54 | % set(gca,'xlim',[-1.7,1.7],'xtick',[-1.5:0.5:1.5]) % set x-axis 55 | % set(gca,'ylim',[-1.2,1.2],'ytick',[-1:0.5:1]) % set y-axix 56 | set(gca,'FontName','Times New Roman','FontSize',20,'LineWidth',1.2); 57 | % str = 'The orighnal data'; 58 | % titlename = sprintf('%s: View-%d',str,v); 59 | % t = title(titlename); 60 | % set(t,'FontName','Times New Roman','FontSize',20.0); 61 | axis equal; 62 | end; 63 | 64 | %% Original connected graph with probabilistic neighbors, line width denotes similarity 65 | S1 = cell(1,m); 66 | for v = 1:m 67 | S1{v} = S0_initial{v}; 68 | lab = y0(:,v); 69 | cLab = unique(lab); 70 | figure; 71 | plot(X{v}(:,1),X{v}(:,2),'.k', 'MarkerSize', markerSize); hold on; 72 | plot(X{v}(lab==cLab(1),1),X{v}(lab==cLab(1),2),'.r', 'MarkerSize', markerSize); hold on; 73 | plot(X{v}(lab==cLab(2),1),X{v}(lab==cLab(2),2),'.', 'MarkerSize', markerSize); hold on; 74 | if flag 75 | plot(X{v}(lab==cLab(3),1),X{v}(lab==cLab(3),2),'.', 'Color', [79 79 79]/255, 'MarkerSize', markerSize); hold on; 76 | end; 77 | for ii = 1 : num 78 | for jj = 1 : ii 79 | weight = S1{v}(ii, jj); 80 | if weight > 0 81 | plot([X{v}(ii, 1), X{v}(jj, 1)], [X{v}(ii, 2), X{v}(jj, 2)], '-', 'Color', [0 197 205]/255, 'LineWidth', 5*weight), hold on; 82 | end; 83 | end; 84 | end; 85 | % set(gca,'xlim',[-1.7,1.7],'xtick',[-1.5:0.5:1.5]) % set x-axis 86 | % set(gca,'ylim',[-1.2,1.2],'ytick',[-1:0.5:1]) %set y-axis 87 | set(gca,'FontName','Times New Roman','FontSize',20,'LineWidth',1.2); 88 | % str = 'Connected graph with probabilistic neighbors'; 89 | % titlename = sprintf('%s: View-%d',str,v); 90 | % t = title(titlename); 91 | % set(t,'FontName','Times New Roman','FontSize',20.0); 92 | axis equal; 93 | end; 94 | 95 | %% the learned graph of each view by GMC, line width denotes similarity 96 | S2 = cell(1,m); 97 | for v = 1:m 98 | S2{v} = S0{v}; 99 | lab = y0(:,v); 100 | cLab = unique(lab); 101 | figure; 102 | plot(X{v}(:,1),X{v}(:,2),'.k', 'MarkerSize', markerSize); hold on; 103 | plot(X{v}(lab==cLab(1),1),X{v}(lab==cLab(1),2),'.r', 'MarkerSize', markerSize); hold on; 104 | plot(X{v}(lab==cLab(2),1),X{v}(lab==cLab(2),2),'.', 'MarkerSize', markerSize); hold on; 105 | if flag 106 | plot(X{v}(lab==cLab(3),1),X{v}(lab==cLab(3),2),'.', 'Color', [79 79 79]/255, 'MarkerSize', markerSize); hold on; 107 | end; 108 | for ii = 1 : num 109 | for jj = 1 : ii 110 | weight = S2{v}(ii, jj); 111 | if weight > 0 112 | plot([X{v}(ii, 1), X{v}(jj, 1)], [X{v}(ii, 2), X{v}(jj, 2)], '-', 'Color', [0 197 205]/255, 'LineWidth', 5*weight), hold on; 113 | end; 114 | end; 115 | end; 116 | % set(gca,'xlim',[-1.7,1.7],'xtick',[-1.5:0.5:1.5]) % set x-axis 117 | % set(gca,'ylim',[-1.2,1.2],'ytick',[-1:0.5:1]) % set y-axis 118 | set(gca,'FontName','Times New Roman','FontSize',20,'LineWidth',1.2); 119 | % str = 'Connected graph with probabilistic neighbors'; 120 | % titlename = sprintf('%s: View-%d',str,v); 121 | % t = title(titlename); 122 | % set(t,'FontName','Times New Roman','FontSize',20.0); 123 | axis equal; 124 | end; 125 | 126 | %% the learned unified graph by GMC, line width denotes similarity 127 | U2 = U; 128 | for v = 1:m 129 | lab = y0(:,v); 130 | cLab = unique(lab); 131 | figure; 132 | plot(X{v}(:,1),X{v}(:,2),'.k', 'MarkerSize', markerSize); hold on; 133 | plot(X{v}(lab==cLab(1),1),X{v}(lab==cLab(1),2),'.r', 'MarkerSize', markerSize); hold on; 134 | plot(X{v}(lab==cLab(2),1),X{v}(lab==cLab(2),2),'.', 'MarkerSize', markerSize); hold on; 135 | if flag 136 | plot(X{v}(lab==cLab(3),1),X{v}(lab==cLab(3),2),'.', 'Color', [79 79 79]/255, 'MarkerSize', markerSize); hold on; 137 | end; 138 | for ii = 1 : num 139 | for jj = 1 : ii 140 | weight = U2(ii, jj); 141 | if weight > 0 142 | plot([X{v}(ii, 1), X{v}(jj, 1)], [X{v}(ii, 2), X{v}(jj, 2)], '-', 'Color', [0 197 205]/255, 'LineWidth', 5*weight), hold on; 143 | end; 144 | end; 145 | end; 146 | % set(gca,'xlim',[-1.7,1.7],'xtick',[-1.5:0.5:1.5]) % set x-axis 147 | % set(gca,'ylim',[-1.2,1.2],'ytick',[-1:0.5:1]) % set y-axix 148 | set(gca,'FontName','Times New Roman','FontSize',20,'LineWidth',1.2); 149 | % str = 'Learnt connected graph'; 150 | % titlename = sprintf('%s: View-%d',str,v); 151 | % t = title(titlename); 152 | % set(t,'FontName','Times New Roman','FontSize',20.0); 153 | axis equal; 154 | end; 155 | --------------------------------------------------------------------------------