├── util
│   ├── README.txt
│   └── data2batch.m
├── clustering
│   ├── getclusterlabel.m
│   ├── netrecons.m
│   ├── netcomput.m
│   ├── savetxtdata.m
│   ├── updategroup.m
│   ├── pur.m
│   ├── acc.m
│   ├── randinitial.m
│   ├── nmi.m
│   ├── runclustering.m
│   ├── CG_CLUSTER.m
│   ├── minimize.m
│   └── hungarian.m
├── data
│   ├── loadmnist.m
│   ├── batch2data.m
│   ├── README.txt
│   ├── converter.m
│   └── makebatches.m
├── show
│   ├── showresults.m
│   ├── outcomput.m
│   └── mnistdisp.m
├── initialize
│   ├── pretrainnet.m
│   ├── randweights.m
│   ├── finetuning.m
│   ├── CG_DATA.m
│   ├── rbm.m
│   └── minimize.m
├── demo.m
└── README.txt


/util/README.txt:
--------------------------------------------------------------------------------
1 | You need to download VLFeat and save the toolbox under this util folder; we have tested version 0.9.16.
2 | Then edit VL_path in demo.m.
3 | 
4 | data2batch.m helps you train our code on other databases: it converts ordinary 2-D data (one sample per row) into the 3-D batch format used by the training scripts.
--------------------------------------------------------------------------------

/clustering/getclusterlabel.m:
--------------------------------------------------------------------------------
1 | function CL = getclusterlabel(clusterdata,centro,w1,w2,w3,w4)
2 | %get clustering label from feature through the encoder nets:
3 | %encode every row of clusterdata with w1..w4 and return, for each row, the
4 | %index (1..K) of the nearest row of centro (the K cluster centres in code space).
5 | targetout = netcomput(clusterdata,w1,w2,w3,w4);%get hidden layer value
6 | value = vl_alldist(targetout',centro'); %VL_Feat: N-by-K distances between samples and centres
7 | [~, CL] = min(value,[],2); %use all K centres (previously value(:,1:10), i.e. fixed to 10 clusters)
8 | 
--------------------------------------------------------------------------------

/clustering/netrecons.m:
--------------------------------------------------------------------------------
1 | function out = netrecons(feature,w5,w6,w7,w8)
2 | %decode code-layer features back to input space through w5..w8 (sigmoid units; the last row of each weight matrix is the bias).
3 | N =size(feature,1);
4 | w4probs = [feature ones(N,1)];
5 | w5probs = 1./(1 + exp(-w4probs*w5)); w5probs = [w5probs ones(N,1)];
6 | w6probs = 1./(1 + exp(-w5probs*w6)); w6probs = [w6probs ones(N,1)];
7 | w7probs = 1./(1 + exp(-w6probs*w7)); w7probs = [w7probs ones(N,1)];
8 | out = 1./(1 + exp(-w7probs*w8));
9 | 
10 | end
--------------------------------------------------------------------------------

/clustering/netcomput.m:
--------------------------------------------------------------------------------
1 | function out = netcomput(data,w1,w2,w3,w4)
2 | %encode the rows of data into the top (code) layer through w1..w4 (sigmoid units; the last row of each weight matrix is the bias).
3 | N =size(data,1);
4 | data = [data ones(N,1)];
5 | w1probs = 1./(1 + exp(-data*w1)); w1probs = [w1probs ones(N,1)];
6 | w2probs = 1./(1 + exp(-w1probs*w2)); w2probs = [w2probs ones(N,1)];
7 | w3probs = 1./(1 + exp(-w2probs*w3)); w3probs = [w3probs ones(N,1)];
8 | out = 1./(1 + exp(-w3probs*w4));
9 | end
--------------------------------------------------------------------------------

/data/loadmnist.m:
--------------------------------------------------------------------------------
1 | % Load MNIST digit dataset and convert into the format for clustering code.
2 | data_path = fullfile(cur_path,'data','mnistdata.mat');   %fullfile: valid on Windows and Linux
3 | batch_path = fullfile(cur_path,'data','mnistbatch.mat');
4 | if ~exist(batch_path, 'file')
5 |     converter;
6 |     makebatches;
7 |     save(batch_path,'batchdata','batchtargets');
8 |     [clusterdata,clustertargets]=batch2data(batchdata,batchtargets);
9 |     save(data_path, 'clusterdata', 'clustertargets');
10 | else
11 |     load(batch_path);
12 |     load(data_path);
13 | end
14 | 
--------------------------------------------------------------------------------

/data/batch2data.m:
--------------------------------------------------------------------------------
1 | function [clusterdata,clustertargets]=batch2data(batchdata,batchtargets)
2 | %stack a numcases-by-numdims-by-numbatches array into a (numcases*numbatches)-by-numdims matrix (inverse of util/data2batch.m).
3 | [numcases numdims numbatches]=size(batchdata);
4 | numcluster = size(batchtargets,2);
5 | clusterdata = zeros(numcases*numbatches,numdims);
6 | clustertargets = zeros(numcases*numbatches,numcluster);
7 | for i = 1:numbatches
8 |     clusterdata((i-1)*numcases+1:i*numcases,:) = batchdata(:,:,i);
9 |     clustertargets((i-1)*numcases+1:i*numcases,:) = batchtargets(:,:,i);
10 | end
11 | 
--------------------------------------------------------------------------------

/data/README.txt:
--------------------------------------------------------------------------------
1 | The MNIST data can be downloaded from: http://yann.lecun.com/exdb/mnist/. This database of handwritten digits has a training set of 60,000 examples and a test set of 10,000 examples. It is a subset of a larger set available from NIST. The digits have been size-normalized and centered in a fixed-size image.
2 | 
3 | You can download the datasets and unzip them here. In our clustering code, we only use the training set with 60k objects.
4 | 
5 | Other databases can easily be converted with the data2batch function in ../util/data2batch.m.
--------------------------------------------------------------------------------

/util/data2batch.m:
--------------------------------------------------------------------------------
1 | function [batchdata,batchtargets]=data2batch(clusterdata,clustertargets,num_case)
2 | %convert data-like (2 dims, one sample per row) into batch-like (3 dims) data.
3 | %   [batchdata,batchtargets] = data2batch(clusterdata,clustertargets) splits the
4 | %   rows into consecutive batches of 100 rows; data2batch(...,num_case) uses
5 | %   num_case rows per batch instead.
6 | %   batchdata is num_case-by-dims-by-numbatches, batchtargets is
7 | %   num_case-by-clusternum-by-numbatches, numbatches = floor(num/num_case);
8 | %   trailing rows that do not fill a whole batch are dropped.
9 | %   (Declared as data2batch to match this file; it used to say batch2data,
10 | %   the name of the inverse function in data/batch2data.m.)
11 | if nargin < 3 || isempty(num_case)
12 |     num_case = 100; %default number of cases per batch
13 | end
14 | [num,dims]= size(clusterdata);
15 | clusternum = size(clustertargets,2);
16 | numbatches = floor(num/num_case);
17 | batchdata = zeros(num_case,dims,numbatches);
18 | batchtargets = zeros(num_case,clusternum,numbatches);
19 | for i = 1:numbatches
20 |     rows = (i-1)*num_case+1:i*num_case;
21 |     batchdata(:,:,i) = clusterdata(rows,:);
22 |     batchtargets(:,:,i) = clustertargets(rows,:);
23 | end
24 | end
--------------------------------------------------------------------------------

/clustering/savetxtdata.m:
--------------------------------------------------------------------------------
1 | function [] = savetxtdata(epoch,RL,CL,cur_path)
2 | %quickly get pur, acc and nmi values, then append them in '.txt' format:
3 | %one tab-separated row per call in <cur_path>/tmp/ClusterResult.txt; when
4 | %epoch==1 the file is (re)created and a header row is written first.
5 | cluster_nmi = nmi(CL',RL');%get NMI
6 | cluster_pur = pur(CL',RL');%get PUR
7 | cluster_acc = acc(RL,CL);%get ACC
8 | cluster_results_path = fullfile(cur_path,'tmp','ClusterResult.txt');
9 | if epoch==1
10 |     fid=fopen(cluster_results_path,'wt');
11 |     if fid < 0 %fopen returns -1 instead of throwing
12 |         error('Cannot create %s',cluster_results_path);
13 |     end
14 |     fprintf(fid,'%s\t','epoch');
15 |     fprintf(fid,'%s\t','NMI');
16 |     fprintf(fid,'%s\t','PUR');
17 |     fprintf(fid,'%s\n','ACC');
18 |     fclose(fid);
19 | end
20 | fid=fopen(cluster_results_path,'at+');
21 | if fid < 0
22 |     error('Cannot open %s',cluster_results_path);
23 | end
24 | fprintf(fid,'%g\t',epoch);
25 | fprintf(fid,'%g\t',cluster_nmi);
26 | fprintf(fid,'%g\t',cluster_pur);
27 | fprintf(fid,'%g\n',cluster_acc);
28 | fclose(fid);
29 | end
--------------------------------------------------------------------------------

/clustering/updategroup.m:
--------------------------------------------------------------------------------
1 | function [batchBasis,C] = updategroup(clusterdata,CL,clustertargets,w1,w2,w3,w4)
2 | %recompute the K cluster centres C (code layer) from the labels CL, and copy them into batchBasis (row i = centre of cluster CL(i)).
3 | N =size(clusterdata,1);
4 | K = size(clustertargets,2);
5 | C_dims = size(w4,2);
6 | batchBasis = zeros(N,C_dims);
7 | targetout = netcomput(clusterdata,w1,w2,w3,w4);%get hidden layer value
8 | C = zeros(K,C_dims);
9 | counter = zeros(1,K);
10 | %directly get K cluster centre 'C' in hidden layer (code layer); a cluster without samples becomes a NaN row (0/0)
11 | for i = 1:N
12 |     C(CL(i),:) = C(CL(i),:)+ targetout(i,:);
13 |     counter(CL(i)) = counter(CL(i))+1;
14 | end
15 | for i = 1:K
16 |     C(i,:)= C(i,:)./ counter(i);
17 | end
18 | 
19 | batchBasis(:, :) = C(CL, :);%row i = centre of sample i's cluster, used as target in the next clustering epoch
20 | end
21 | 
22 | 
23 | 
--------------------------------------------------------------------------------

/clustering/pur.m:
--------------------------------------------------------------------------------
1 | function purity = pur( A, B )
2 | %get purity of clustering
3 | %A for cluster label, B for real label (positive integers, same length).
4 | %Each cluster is credited with its most frequent real label; purity is the
5 | %fraction of samples credited this way.
6 | if length( A ) ~= length( B)
7 |     error('length( A ) must == length( B)');
8 | end
9 | 
10 | totnum = length(A);
11 | %contingency table: rows are clusters, columns are real labels. It is sized
12 | %from the labels themselves, so the number of clusters may differ from the
13 | %number of real classes (the old fixed-size table could index out of range).
14 | pure_train = accumarray([A(:) B(:)], 1);
15 | accnum = sum(max(pure_train,[],2));
16 | purity = accnum/totnum;
17 | 
--------------------------------------------------------------------------------

/show/showresults.m:
--------------------------------------------------------------------------------
1 | % show clustering result of mnist (the figure layout below assumes 10 clusters and 10 digit classes)
2 | data_path = fullfile(cur_path,'data','mnistdata.mat');
3 | cluster_path = fullfile(cur_path,'tmp','cluster_status.mat');
4 | load(data_path);
5 | load(cluster_path);
6 | 
7 | [C,S,output] = outcomput(clusterdata,CL,RL,w1,w2,w3,w4,w5,w6,w7,w8);
8 | %show 10 clustering centre
9 | figure('Position',[100,600,1000,200]);
10 | mnistdisp(output');
11 | hold on;
12 | 
13 | %show distribution of the clustering results
14 | figure(2)
15 | X = 0:9;
16 | for i = 1:5
17 |     subplot(4,5,i);
18 |     bar(X,S(i,:));
19 | end
20 | for i = 6:10
21 |     subplot(4,5,i);
22 |     mnistdisp(output(i-5,:)');
23 | end
24 | for i = 11:15
25 |     subplot(4,5,i);
26 |     bar(X,S(i-5,:));
27 | end
28 | for i = 16:20
29 |     subplot(4,5,i);
30 |     mnistdisp(output(i-10,:)');
31 | end
32 | hold on;
33 | 
34 | 
--------------------------------------------------------------------------------

/initialize/pretrainnet.m:
--------------------------------------------------------------------------------
1 | %pretrain the nets with RBM
2 | weiths_path=strcat(save_path, 'rbm_weights.mat');
3 | if exist(weiths_path, 'file')
4 |     load(weiths_path); %load the file checked above, not whatever 'rbm_weights' is first on the path
5 |     return;
6 | end
7 | epochnow = 1;
8 | maxepoch=10;%Set the max epoch for training RBM weights here
9 | Layernum = length(netStru);
10 | fprintf(1,'Pretraining a deep autoencoder. \n');
11 | 
12 | [numcases numdims numbatches]=size(batchdata);
13 | for i = 1:Layernum
14 |     fprintf(1,'Pretraining Layer %d with RBM: %d-%d \n',i, numdims,netStru(i));
15 |     restart=1;
16 |     numhid=netStru(i);
17 |     rbm;
18 |     eval(['w',num2str(i), ' = ', '[vishid ;hidbiases ];']);
19 |     vishid_T= vishid';
20 |     eval(['w',num2str(Layernum*2 - i + 1), ' = ', '[vishid_T;visbiases ];']);
21 |     batchdata=batchposhidprobs;
22 | end
23 | 
24 | save(weiths_path, 'w1', 'w2', 'w3', 'w4', 'w5', 'w6', 'w7', 'w8');
25 | 
26 | 
--------------------------------------------------------------------------------

/clustering/acc.m:
--------------------------------------------------------------------------------
1 | function score = acc(true_labels, cluster_labels)
2 | % ACCURACY Compute clustering accuracy using the true and cluster labels and
3 | %   return the value in 'score'.
4 | %
5 | %   Input  : true_labels    : N-by-1 vector containing true labels
6 | %            cluster_labels : N-by-1 vector containing cluster labels
7 | %
8 | %   Output : score          : clustering accuracy
9 | %
10 | %   Author : Wen-Yen Chen (wychen@alumni.cs.ucsb.edu)
11 | %            Chih-Jen Lin (cjlin@csie.ntu.edu.tw)
12 | 
13 | % Compute the confusion matrix 'cmat', where
14 | %   col index is for true label (CAT),
15 | %   row index is for cluster label (CLS).
16 | n = length(true_labels);
17 | cat = spconvert([(1:n)' true_labels ones(n,1)]);
18 | cls = spconvert([(1:n)' cluster_labels ones(n,1)]);
19 | cls = cls';
20 | cmat = full(cls * cat);
21 | 
22 | %
23 | % Calculate accuracy
24 | %
25 | [match, cost] = hungarian(-cmat);
26 | score = (-cost/n);
27 | 
--------------------------------------------------------------------------------

/clustering/randinitial.m:
--------------------------------------------------------------------------------
1 | function [batchBasis, centro, RL] = randinitial(clusterdata,clustertargets,w1,w2,w3,w4)
2 | %Get rand label and centers for clustering as well as the real label for evaluating.
3 | N = size(clusterdata,1);
4 | K = size(clustertargets,2);
5 | D_dims = size(clusterdata,2);
6 | C_dims = size(w4,2);
7 | batchBasis = zeros(N,C_dims);
8 | [~, RL] = max(clustertargets, [], 2);%get real label from clustertargets, only used for computing pur, acc & nmi.
9 | CLK = randperm(N); %compute rand label
10 | CLK = mod(CLK,K)+1;
11 | CLK=CLK';
12 | 
13 | %compute C across all dataset
14 | %get K cluster centre 'C' in visual layer(or bottom layer)
15 | C = zeros(K,D_dims);
16 | counter = zeros(1,K);
17 | for i = 1:N
18 |     C(CLK(i),:) = C(CLK(i),:)+clusterdata(i,:);
19 |     counter(CLK(i)) = counter(CLK(i))+1;
20 | end
21 | for i = 1:K
22 |     C(i,:)= C(i,:)./ counter(i);
23 | end
24 | 
25 | centro = netcomput(C,w1,w2,w3,w4);%get centre from 'C' through the encoder nets
26 | batchBasis(:, :) = centro(CLK, :);%copy centre by CLK series for rand initial training.
27 | end
28 | 
29 | 
--------------------------------------------------------------------------------

/initialize/randweights.m:
--------------------------------------------------------------------------------
1 | %rand weights to achieve a quick start of clustering
2 | weiths_path=strcat(save_path, 'rand_weights.mat');
3 | if exist(weiths_path, 'file')
4 |     load(weiths_path); %load the file checked above, not whatever 'rand_weights' is first on the path
5 |     return;
6 | end
7 | 
8 | l1=numdims;
9 | l2=netStru(1);
10 | l3=netStru(2);
11 | l4=netStru(3);
12 | l5=netStru(4);
13 | 
14 | bound = 4*sqrt(6/(l1+l2));
15 | w1 = -bound + 2*bound*rand(l1, l2);
16 | w8 = w1';
17 | bound = 4*sqrt(6/(l2+l3));
18 | w2 = -bound + 2*bound*rand(l2, l3);
19 | w7 = w2';
20 | bound = 4*sqrt(6/(l3+l4));
21 | w3 = -bound + 2*bound*rand(l3, l4);
22 | w6 = w3';
23 | bound = 4*sqrt(6/(l4+l5));
24 | w4 = -bound + 2*bound*rand(l4, l5);
25 | w5 = w4';
26 | 
27 | s = rand(1,size(w1,2));
28 | w1 = [w1;s];
29 | s = rand(1,size(w2,2));
30 | w2 = [w2;s];
31 | s = rand(1,size(w3,2));
32 | w3 = [w3;s];
33 | s = rand(1,size(w4,2));
34 | w4 = [w4;s];
35 | 
36 | s = rand(1,size(w5,2));
37 | w5 = [w5;s];
38 | s = rand(1,size(w6,2));
39 | w6 = [w6;s];
40 | s = rand(1,size(w7,2));
41 | w7 = [w7;s];
42 | s = rand(1,size(w8,2));
43 | w8 = [w8;s];
44 | 
45 | save(weiths_path, 'w1', 'w2', 'w3', 'w4', 'w5', 'w6', 'w7', 'w8');
--------------------------------------------------------------------------------

/show/outcomput.m:
--------------------------------------------------------------------------------
1 | function [C,S,output] = outcomput(digitdata,CL,RL,w1,w2,w3,w4,w5,w6,w7,w8)
2 | %compute the cluster centres and the distribution of a clustering result.
3 | %   C      : K-by-codeDims mean encoder output of every cluster
4 | %            (a cluster without samples gets a NaN row, 0/0).
5 | %   S      : K-by-K counts, S(k,r) = samples in cluster k with real label r.
6 | %   output : the centres C decoded back to input space through w5..w8.
7 | %   K is the largest label found in CL or RL (10 for MNIST) and codeDims is
8 | %   read from the encoder output, instead of both being fixed to 10.
9 | N =size(digitdata,1);
10 | targetout = netcomput(digitdata,w1,w2,w3,w4);%get hidden layer value
11 | K = max([CL(:); RL(:)]);
12 | 
13 | %get K cluster centres 'C' in the code layer
14 | C = zeros(K,size(targetout,2));
15 | counter = zeros(1,K);
16 | for i = 1:N
17 |     C(CL(i),:) = C(CL(i),:)+ targetout(i,:);
18 |     counter(CL(i)) = counter(CL(i))+1;
19 | end
20 | for i = 1:K
21 |     C(i,:)= C(i,:)./ counter(i);
22 | end
23 | 
24 | %count the distribution information
25 | S = zeros(K,K);
26 | for i = 1:N
27 |     S(CL(i),RL(i)) = S(CL(i),RL(i))+1;
28 | end
29 | 
30 | %compute the cluster centre images via the decode nets (same as clustering/netrecons.m)
31 | output = netrecons(C,w5,w6,w7,w8);
32 | 
--------------------------------------------------------------------------------

/show/mnistdisp.m:
--------------------------------------------------------------------------------
1 | % Version 1.000
2 | %
3 | % Code provided by Ruslan Salakhutdinov and Geoff Hinton
4 | %
5 | % Permission is granted for anyone to copy, use, modify, or distribute this
6 | % program and accompanying programs and documents for any purpose, provided
7 | % this copyright notice is retained and prominently displayed, along with
8 | % a note saying that the original programs are available from our
9 | % web page.
10 | % The programs and documents are distributed without any warranty, express or
11 | % implied. As the programs were written for research purposes only, they have
12 | % not been tested to the degree that would be advisable in any important
13 | % application. All use of these programs is entirely at the user's own risk.
14 | 
15 | function [err] = mnistdisp(digits)
16 | % display a group of MNIST images.
17 | %   Each column of digits holds one 28x28 image (784 values, row-major: it is
18 | %   reshaped to 28x28 and transposed). The images are tiled two per grid
19 | %   column, ceil(N/2) grid columns wide, and drawn in gray with range [0 1].
20 | %   err is always 0.
21 | col=28;
22 | row=28;
23 | 
24 | N = size(digits,2);
25 | %preallocate the canvas that is actually drawn: one image row when N==1,
26 | %otherwise two (the old code preallocated an unused 'imdisp' and grew img2).
27 | img2 = zeros(min(N,2)*row, ceil(N/2)*col);
28 | 
29 | for nn=1:N
30 |     ii=rem(nn,2); if(ii==0) ii=2; end
31 |     jj=ceil(nn/2);
32 | 
33 |     img1 = reshape(digits(:,nn),row,col);
34 |     img2(((ii-1)*row+1):(ii*row),((jj-1)*col+1):(jj*col))=img1';
35 | end
36 | 
37 | imagesc(img2,[0 1]); colormap gray; axis equal; axis off;
38 | drawnow;
39 | err=0;
40 | 
41 | 
--------------------------------------------------------------------------------

/clustering/nmi.m:
--------------------------------------------------------------------------------
1 | function MIhat = nmi( A, B )
2 | % NMI Normalized mutual information
3 | % http://en.wikipedia.org/wiki/Mutual_information
4 | % http://nlp.stanford.edu/IR-book/html/htmledition/evaluation-of-clustering-1.html
5 | % Author: http://www.cnblogs.com/ziqiao/ [2011/12/13]
6 | %   MIhat = 2*I(A;B) / (H(A)+H(B)) for two label vectors of equal length.
7 | %   A and B may be given as row or column vectors.
8 | if length( A ) ~= length( B)
9 |     error('length( A ) must == length( B)');
10 | end
11 | %work on row vectors: 'for id = ids' iterates over the COLUMNS of ids, so a
12 | %column of label ids would be visited once as a whole instead of per label.
13 | A = A(:)';
14 | B = B(:)';
15 | total = length(A);
16 | A_ids = unique(A);
17 | B_ids = unique(B);
18 | 
19 | % Mutual information
20 | MI = 0;
21 | for idA = A_ids
22 |     idAOccur = find( A == idA );
23 |     px = length(idAOccur)/total;
24 |     for idB = B_ids
25 |         idBOccur = find( B == idB );
26 |         idABOccur = intersect(idAOccur,idBOccur);
27 | 
28 |         py = length(idBOccur)/total;
29 |         pxy = length(idABOccur)/total;
30 | 
31 |         MI = MI + pxy*log2(pxy/(px*py)+eps); % eps (2^-52) keeps log2 finite when pxy is 0
32 |     end
33 | end
34 | 
35 | % Normalized Mutual information
36 | Hx = 0; % Entropies
37 | for idA = A_ids
38 |     idAOccurCount = length( find( A == idA ) );
39 |     Hx = Hx - (idAOccurCount/total) * log2(idAOccurCount/total + eps);
40 | end
41 | Hy = 0; % Entropies
42 | for idB = B_ids
43 |     idBOccurCount = length( find( B == idB ) );
44 |     Hy = Hy - (idBOccurCount/total) * log2(idBOccurCount/total + eps);
45 | end
46 | 
47 | MIhat = 2 * MI / (Hx+Hy);
48 | end
49 | 
--------------------------------------------------------------------------------

/demo.m:
--------------------------------------------------------------------------------
1 | % Auto-encoder Based Data Clustering
2 | % Demo Code Version 1.0
3 | % 2014.08.13, by Chunfeng Song and Yongzhen Huang
4 | 
5 | clear all; clc;
6 | %%%% STEP 1. Set the code params %%%%
7 | forceRestart = false; %if you need to retrain the clustering nets
8 | if forceRestart
9 |     delete('./tmp/*');
10 | end
11 | addpath(genpath(pwd));
12 | Max_epoch = 60;
13 | R_data = 1;
14 | R_cluster = 0.1;%This is default value for MNIST database
15 | netStru = [1000 250 50 10];
16 | % If you set a netStru = [200 100 30 10], this will ...
17 | % make an 8 layer deep net with a structure of
18 | % [inputsDims-200-100-30-10-30-100-200-inputsDims]
19 | withRBMinit = true;
20 | fine_tunning = true;
21 | cur_path = pwd;
22 | %Set up VL_feat to boost the training; fullfile keeps every path valid on Windows and Linux
23 | VL_path = fullfile(cur_path,'util','vlfeat-0.9.16-bin','vlfeat-0.9.16','toolbox','vl_setup.m');
24 | run(VL_path);
25 | data_path = fullfile(cur_path,'data','data.mat');
26 | %keep the trailing separator: pretrainnet/randweights/finetuning build file names with strcat(save_path, ...)
27 | save_path = [fullfile(cur_path,'tmp') filesep];
28 | if ~exist(data_path, 'file')
29 |     loadmnist; %defines batchdata/batchtargets and clusterdata/clustertargets
30 | else
31 |     load(data_path);
32 | end
33 | [numcases numdims numbatches]=size(batchdata);
34 | 
35 | %%%% STEP 2. Pretrain with RBM or randweights %%%%
36 | if withRBMinit
37 |     pretrainnet;
38 | else
39 |     randweights;
40 | end
41 | 
42 | %%%% STEP 3. Fine-tuning, this step will cost about 5 hours(for mnist)%%%%
43 | if fine_tunning
44 |     finetuning;
45 | end
46 | 
47 | %%%% STEP 4. Clustering %%%%
48 | runclustering;
49 | 
50 | %%%% STEP 5. Show clustering results %%%%
51 | showresults;
52 | 
--------------------------------------------------------------------------------

/README.txt:
--------------------------------------------------------------------------------
1 | ------------------------------------------------------------------------
2 | * Auto-encoder Based Data Clustering *
3 | * Demo Code Version 1.0 *
4 | * By Chunfeng Song and Yongzhen Huang *
5 | * E-mail: developfeng@gmail.com *
6 | ------------------------------------------------------------------------
7 | 
8 | i. Overview
9 | ii. Copying
10 | iii. Use
11 | 
12 | i. OVERVIEW
13 | -----------------------------
14 | This demo applies the auto-encoder based data clustering method to the MNIST
15 | database. It accompanies our paper:
16 | 
17 | Chunfeng Song, Feng Liu, Yongzhen Huang, Liang Wang, Tieniu Tan:
18 | Auto-encoder Based Data Clustering. CIARP2013.
19 | 
20 | A small part of this code is adapted from an earlier package released by
21 | Hinton et al.; their copyright notices are kept in the corresponding files.
22 | 
23 | ii. COPYING
24 | -----------------------------
25 | We share this code only for research use. We neither warrant
26 | correctness nor take any responsibility for the consequences of
27 | using this code. If you find any problem or inappropriate content
28 | in this code, feel free to contact us.
29 | 
30 | iii. USE
31 | -----------------------------
32 | This code should work on Windows or Linux, with MATLAB.
33 | Taking MNIST as an example, follow these steps:
34 | 
35 | 1) Make sure your MATLAB has access to VLFeat, which you can find at:
36 | http://www.vlfeat.org/. Download it and save it under ./util/.
37 | 2) Download the MNIST database from: http://yann.lecun.com/exdb/mnist/,
38 | and then unzip the files into ./data/.
39 | 3) Set up the parameters in demo.m. The default settings can achieve the
40 | results reported in our CIARP13 paper.
41 | 4) Run demo.m.
42 | 5) You can see the clustering results in ./tmp/ClusterResult.txt.
--------------------------------------------------------------------------------

/clustering/runclustering.m:
--------------------------------------------------------------------------------
1 | %Run clustering now
2 | %Script: reads variables set by demo.m (cur_path, clusterdata, w1..w8, ...) and
3 | %resumes from tmp/cluster_status.mat when that file exists.
4 | cluster_staus_path = fullfile(cur_path,'tmp','cluster_status.mat');
5 | if ~exist(cluster_staus_path, 'file')
6 |     epochnow=1;
7 |     [basis,centro,RL] = randinitial(clusterdata,clustertargets,w1,w2,w3,w4);
8 | else
9 |     load(cluster_staus_path);
10 | end
11 | % initialization
12 | N = size(clusterdata,1);
13 | cl_batchsize = min(1000,N); %cases per CG mini-batch (1000 for MNIST)
14 | cl_numbatches = floor(N/cl_batchsize); %explicit integer count; cases after the last full batch are not used for CG
15 | l1=size(w1,1)-1;
16 | l2=size(w2,1)-1;
17 | l3=size(w3,1)-1;
18 | l4=size(w4,1)-1;
19 | l5=size(w5,1)-1;
20 | l6=size(w6,1)-1;
21 | l7=size(w7,1)-1;
22 | l8=size(w8,1)-1;
23 | l9=l1;
24 | 
25 | for epoch = epochnow:Max_epoch
26 |     for batch = 1:cl_numbatches
27 |         fprintf(1,'Clustering epoch %d batch %d\r',epoch,batch);
28 |         % each CG step sees one mini-batch of cl_batchsize consecutive cases %
29 |         rows = (batch-1)*cl_batchsize+1 : batch*cl_batchsize;
30 |         cur_data = clusterdata(rows,:);
31 |         cur_basis = basis(rows,:);
32 |         % Perform CG with 3 linesearches %
33 |         max_iter=3;
34 |         VV = [w1(:)' w2(:)' w3(:)' w4(:)' w5(:)' w6(:)' w7(:)' w8(:)']';
35 |         Dim = [l1; l2; l3; l4; l5; l6; l7; l8; l9];
36 |         [X, fX] = minimize(VV,'CG_CLUSTER',max_iter,Dim,cur_data,cur_basis,R_data,R_cluster);
37 |         % update cluster weights %
38 |         w1 = reshape(X(1:(l1+1)*l2),l1+1,l2);
39 |         xxx = (l1+1)*l2;
40 |         w2 = reshape(X(xxx+1:xxx+(l2+1)*l3),l2+1,l3);
41 |         xxx = xxx+(l2+1)*l3;
42 |         w3 = reshape(X(xxx+1:xxx+(l3+1)*l4),l3+1,l4);
43 |         xxx = xxx+(l3+1)*l4;
44 |         w4 = reshape(X(xxx+1:xxx+(l4+1)*l5),l4+1,l5);
45 |         xxx = xxx+(l4+1)*l5;
46 |         w5 = reshape(X(xxx+1:xxx+(l5+1)*l6),l5+1,l6);
47 |         xxx = xxx+(l5+1)*l6;
48 |         w6 = reshape(X(xxx+1:xxx+(l6+1)*l7),l6+1,l7);
49 |         xxx = xxx+(l6+1)*l7;
50 |         w7 = reshape(X(xxx+1:xxx+(l7+1)*l8),l7+1,l8);
51 |         xxx = xxx+(l7+1)*l8;
52 |         w8 = reshape(X(xxx+1:xxx+(l8+1)*l9),l8+1,l9);
53 |     end
54 |     epochnow = epoch+1;
55 |     CL = getclusterlabel(clusterdata,centro,w1,w2,w3,w4);%get clustering label
56 |     savetxtdata(epoch,RL,CL,cur_path);%computing NMI,Purity,Accuracy and then saved as '*.txt' file
57 |     [basis,centro] = updategroup(clusterdata,CL,clustertargets,w1,w2,w3,w4);%updating clustering centre for next epoch
58 |     save(cluster_staus_path,'w1', 'w2', 'w3', 'w4', 'w5', 'w6', 'w7', 'w8', 'RL', 'CL', 'epochnow', 'basis', 'centro');
59 | end
60 | 
61 | 
62 | 
--------------------------------------------------------------------------------

/clustering/CG_CLUSTER.m:
--------------------------------------------------------------------------------
1 | %CG_CLUSTER objective f and gradient df (w.r.t. the packed weights VV) used by minimize() in the clustering stage
2 | function [f, df] = CG_CLUSTER(VV,Dim,XX, Hcen,R_data,R_cluster)
3 | 
4 | l1 = Dim(1);
5 | l2 = Dim(2);
6 | l3 = Dim(3);
7 | l4= Dim(4);
8 | l5= Dim(5);
9 | l6= Dim(6);
10 | l7= Dim(7);
11 | l8= Dim(8);
12 | l9= Dim(9);
13 | N = size(XX,1);
14 | 
15 | % Unpack VV into the weight matrices w1..w8 (the last row of each is the bias).
16 | w1 = reshape(VV(1:(l1+1)*l2),l1+1,l2);
17 | xxx = (l1+1)*l2;
18 | w2 = reshape(VV(xxx+1:xxx+(l2+1)*l3),l2+1,l3);
19 | xxx = xxx+(l2+1)*l3;
20 | w3 = reshape(VV(xxx+1:xxx+(l3+1)*l4),l3+1,l4);
21 | xxx = xxx+(l3+1)*l4;
22 | w4 = reshape(VV(xxx+1:xxx+(l4+1)*l5),l4+1,l5);
23 | xxx = xxx+(l4+1)*l5;
24 | w5 = reshape(VV(xxx+1:xxx+(l5+1)*l6),l5+1,l6);
25 | xxx = xxx+(l5+1)*l6;
26 | w6 = reshape(VV(xxx+1:xxx+(l6+1)*l7),l6+1,l7);
27 | xxx = xxx+(l6+1)*l7;
28 | w7 = reshape(VV(xxx+1:xxx+(l7+1)*l8),l7+1,l8);
29 | xxx = xxx+(l7+1)*l8;
30 | w8 = reshape(VV(xxx+1:xxx+(l8+1)*l9),l8+1,l9);
31 | %get targetout,and reconstructed data
32 | XX = [XX ones(N,1)];
33 | w1probs = 1./(1 + exp(-XX*w1)); w1probs = [w1probs ones(N,1)];
34 | w2probs = 1./(1 + exp(-w1probs*w2)); w2probs = [w2probs ones(N,1)];
35 | w3probs = 1./(1 + exp(-w2probs*w3)); w3probs = [w3probs ones(N,1)];
36 | w4probs = 1./(1 + exp(-w3probs*w4));
37 | targetout = w4probs;
38 | w4probs = [w4probs ones(N,1)];
39 | w5probs = 1./(1 + exp(-w4probs*w5)); w5probs = [w5probs ones(N,1)];
40 | w6probs = 1./(1 + exp(-w5probs*w6)); w6probs = [w6probs ones(N,1)];
41 | w7probs = 1./(1 + exp(-w6probs*w7)); w7probs = [w7probs ones(N,1)];
42 | XXout = 1./(1 + exp(-w7probs*w8));
43 | 
44 | f0 = -R_cluster/N*sum(sum( Hcen.*log(targetout))) ;
45 | %obj fun f0 for cluster restrain (pulls the code layer towards the targets Hcen)
46 | f1 = -R_data/N*sum(sum( XX(:,1:end-1).*log(XXout) + (1-XX(:,1:end-1)).*log(1-XXout)));
47 | %obj fun f1 for data restrain (cross-entropy reconstruction error)
48 | f = f0 + f1;
49 | %main obj fun
50 | 
51 | %get gradient of w1--w8
52 | IO = R_data/N*(XXout-XX(:,1:end-1));
53 | Ix8=IO;
54 | dw8 = w7probs'*Ix8;
55 | 
56 | Ix7 = (Ix8*w8').*w7probs.*(1-w7probs);
57 | Ix7 = Ix7(:,1:end-1);
58 | dw7 = w6probs'*Ix7;
59 | 
60 | Ix6 = (Ix7*w7').*w6probs.*(1-w6probs);
61 | Ix6 = Ix6(:,1:end-1);
62 | dw6 = w5probs'*Ix6;
63 | 
64 | Ix5 = (Ix6*w6').*w5probs.*(1-w5probs);
65 | Ix5 = Ix5(:,1:end-1);
66 | dw5 = w4probs'*Ix5;
67 | 
68 | Ix4 = (Ix5*w5').*w4probs.*(1-w4probs);
69 | Ix4 = Ix4(:,1:end-1)+R_cluster/N*Hcen.*(targetout-1);
70 | dw4 = w3probs'*Ix4;
71 | 
72 | Ix3 = (Ix4*w4').*w3probs.*(1-w3probs);
73 | Ix3 = Ix3(:,1:end-1);
74 | dw3 = w2probs'*Ix3;
75 | 
76 | Ix2 = (Ix3*w3').*w2probs.*(1-w2probs);
77 | Ix2 = Ix2(:,1:end-1);
78 | dw2 = w1probs'*Ix2;
79 | 
80 | Ix1 = (Ix2*w2').*w1probs.*(1-w1probs);
81 | Ix1 = Ix1(:,1:end-1);
82 | dw1 = XX'*Ix1;
83 | 
84 | df = [dw1(:)' dw2(:)' dw3(:)' dw4(:)' dw5(:)' dw6(:)' dw7(:)' dw8(:)' ]';
85 | 
86 | 
--------------------------------------------------------------------------------

/initialize/finetuning.m:
--------------------------------------------------------------------------------
1 | % Optional additional fine-tuning; this code is adapted from the code shared by Geoff Hinton and Ruslan Salakhutdinov.
2 | 
3 | % Version 1.000
4 | %
5 | % Code provided by Geoff Hinton and Ruslan Salakhutdinov
6 | %
7 | % Permission is granted for anyone to copy, use, modify, or distribute this
8 | % program and accompanying programs and documents for any purpose, provided
9 | % this copyright notice is retained and prominently displayed, along with
10 | % a note saying that the original programs are available from our
11 | % web page.
12 | % The programs and documents are distributed without any warranty, express or
13 | % implied. As the programs were written for research purposes only, they have
14 | % not been tested to the degree that would be advisable in any important
15 | % application. All use of these programs is entirely at the user's own risk.
16 | 
17 | % We slightly edit this version here to fit our code.
18 | batch_path = fullfile(cur_path,'data','mnistbatch.mat');
19 | weiths_path=strcat(save_path, 'fine_weights.mat');
20 | maxepoch=200;
21 | epochnow = 1;
22 | if exist(weiths_path, 'file')
23 |     load(weiths_path); %restores w1..w8 and epochnow (saved as epoch+1, the next epoch to run)
24 |     if epochnow > maxepoch %all epochs done; the old '== maxepoch' test skipped the last epoch instead
25 |         return;
26 |     end
27 | end
28 | load(batch_path);
29 | [numcases numdims numbatches]=size(batchdata);
30 | numbigbatches = floor(numbatches/10); %10 minibatches are combined per CG step; leftover minibatches are skipped
31 | 
32 | fprintf(1,'\nFine-tuning deep autoencoder by minimizing cross entropy error. \n');
33 | fprintf(1,'%d batches of %d cases each. \n',numbigbatches,10*numcases);
34 | %%%%%%%%%% END OF PREINITIALIZATION OF WEIGHTS %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
35 | l1=size(w1,1)-1;
36 | l2=size(w2,1)-1;
37 | l3=size(w3,1)-1;
38 | l4=size(w4,1)-1;
39 | l5=size(w5,1)-1;
40 | l6=size(w6,1)-1;
41 | l7=size(w7,1)-1;
42 | l8=size(w8,1)-1;
43 | l9=l1;
44 | for epoch = epochnow:maxepoch
45 |     for batch = 1:numbigbatches
46 |         fprintf(1,'Fine-tuning epoch %d batch %d\r',epoch,batch);
47 |         %%%%%%%%%%% COMBINE 10 MINIBATCHES INTO 1 LARGER MINIBATCH %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
48 |         data=[];
49 |         for kk=1:10
50 |             data=[data;batchdata(:,:,(batch-1)*10+kk)];
51 |         end
52 |         max_iter=3;
53 |         VV = [w1(:)' w2(:)' w3(:)' w4(:)' w5(:)' w6(:)' w7(:)' w8(:)']';
54 |         Dim = [l1; l2; l3; l4; l5; l6; l7; l8; l9];
55 |         [X, fX] = minimize(VV,'CG_DATA',max_iter,Dim,data);
56 |         w1 = reshape(X(1:(l1+1)*l2),l1+1,l2);
57 |         xxx = (l1+1)*l2;
58 |         w2 = reshape(X(xxx+1:xxx+(l2+1)*l3),l2+1,l3);
59 |         xxx = xxx+(l2+1)*l3;
60 |         w3 = reshape(X(xxx+1:xxx+(l3+1)*l4),l3+1,l4);
61 |         xxx = xxx+(l3+1)*l4;
62 |         w4 = reshape(X(xxx+1:xxx+(l4+1)*l5),l4+1,l5);
63 |         xxx = xxx+(l4+1)*l5;
64 |         w5 = reshape(X(xxx+1:xxx+(l5+1)*l6),l5+1,l6);
65 |         xxx = xxx+(l5+1)*l6;
66 |         w6 = reshape(X(xxx+1:xxx+(l6+1)*l7),l6+1,l7);
67 |         xxx = xxx+(l6+1)*l7;
68 |         w7 = reshape(X(xxx+1:xxx+(l7+1)*l8),l7+1,l8);
69 |         xxx = xxx+(l7+1)*l8;
70 |         w8 = reshape(X(xxx+1:xxx+(l8+1)*l9),l8+1,l9);
71 |     end
72 |     epochnow=epoch+1;
73 |     save(weiths_path, 'w1', 'w2', 'w3', 'w4', 'w5', 'w6', 'w7', 'w8','epochnow');
74 | end
75 | 
76 | 
77 | 
78 | 
79 | 
--------------------------------------------------------------------------------

/initialize/CG_DATA.m:
--------------------------------------------------------------------------------
1 | % Version 1.000
2 | %
3 | % Code provided by Ruslan Salakhutdinov and Geoff Hinton
4 | %
5 | % Permission is granted for anyone to copy, use, modify, or distribute this
6 | % program and accompanying programs and documents for any purpose, provided
7 | % this copyright notice is retained and prominently displayed, along with
8 | % a note saying that the original programs are available from our
9 | % web page.
10 | % The programs and documents are distributed without any warranty, express or
11 | % implied. As the programs were written for research purposes only, they have
12 | % not been tested to the degree that would be advisable in any important
13 | % application. All use of these programs is entirely at the user's own risk.
14 | 
15 | function [f, df] = CG_DATA(VV,Dim,XX);
16 | % We slightly edit this version here to fit our code. Returns the cross-entropy reconstruction error f and its gradient df w.r.t. the packed weights VV.
17 | l1 = Dim(1);
18 | l2 = Dim(2);
19 | l3 = Dim(3);
20 | l4= Dim(4);
21 | l5= Dim(5);
22 | l6= Dim(6);
23 | l7= Dim(7);
24 | l8= Dim(8);
25 | l9= Dim(9);
26 | N = size(XX,1);
27 | 
28 | % Unpack VV into the weight matrices w1..w8 (the last row of each is the bias).
29 | w1 = reshape(VV(1:(l1+1)*l2),l1+1,l2);
30 | xxx = (l1+1)*l2;
31 | w2 = reshape(VV(xxx+1:xxx+(l2+1)*l3),l2+1,l3);
32 | xxx = xxx+(l2+1)*l3;
33 | w3 = reshape(VV(xxx+1:xxx+(l3+1)*l4),l3+1,l4);
34 | xxx = xxx+(l3+1)*l4;
35 | w4 = reshape(VV(xxx+1:xxx+(l4+1)*l5),l4+1,l5);
36 | xxx = xxx+(l4+1)*l5;
37 | w5 = reshape(VV(xxx+1:xxx+(l5+1)*l6),l5+1,l6);
38 | xxx = xxx+(l5+1)*l6;
39 | w6 = reshape(VV(xxx+1:xxx+(l6+1)*l7),l6+1,l7);
40 | xxx = xxx+(l6+1)*l7;
41 | w7 = reshape(VV(xxx+1:xxx+(l7+1)*l8),l7+1,l8);
42 | xxx = xxx+(l7+1)*l8;
43 | w8 = reshape(VV(xxx+1:xxx+(l8+1)*l9),l8+1,l9);
44 | 
45 | XX = [XX ones(N,1)];
46 | w1probs = 1./(1 + exp(-XX*w1)); w1probs = [w1probs ones(N,1)];
47 | w2probs = 1./(1 + exp(-w1probs*w2)); w2probs = [w2probs ones(N,1)];
48 | w3probs = 1./(1 + exp(-w2probs*w3)); w3probs = [w3probs ones(N,1)];
49 | w4probs = 1./(1 + exp(-w3probs*w4)); w4probs = [w4probs ones(N,1)];
50 | w5probs = 1./(1 + exp(-w4probs*w5)); w5probs = [w5probs ones(N,1)];
51 | w6probs = 1./(1 + exp(-w5probs*w6)); w6probs = [w6probs ones(N,1)];
52 | w7probs = 1./(1 + exp(-w6probs*w7)); w7probs = [w7probs ones(N,1)];
53 | XXout = 1./(1 + exp(-w7probs*w8));
54 | 
55 | % Or you can use SSE loss if needed
56 | % f = 0.5/N*sum(sum( (XX(:,1:end-1) -XXout).^2));
57 | f = -1/N*sum(sum( XX(:,1:end-1).*log(XXout) + (1-XX(:,1:end-1)).*log(1-XXout)));
58 | IO = 1/N*(XXout-XX(:,1:end-1));
59 | Ix8=IO;
60 | dw8 = w7probs'*Ix8;
61 | 
62 | Ix7 = (Ix8*w8').*w7probs.*(1-w7probs);
63 | Ix7 = Ix7(:,1:end-1);
64 | dw7 = w6probs'*Ix7;
65 | 
66 | Ix6 = (Ix7*w7').*w6probs.*(1-w6probs);
67 | Ix6 = Ix6(:,1:end-1);
68 | dw6 = w5probs'*Ix6;
69 | 
70 | Ix5 = (Ix6*w6').*w5probs.*(1-w5probs);
71 | Ix5 = Ix5(:,1:end-1);
72 | dw5 = w4probs'*Ix5;
73 | 
74 | Ix4 = (Ix5*w5').*w4probs.*(1-w4probs);
75 | Ix4 = Ix4(:,1:end-1);
76 | dw4 = w3probs'*Ix4;
77 | 
78 | Ix3 = (Ix4*w4').*w3probs.*(1-w3probs);
79 | Ix3 = Ix3(:,1:end-1);
80 | dw3 = w2probs'*Ix3;
81 | 
82 | Ix2 = (Ix3*w3').*w2probs.*(1-w2probs);
83 | Ix2 = Ix2(:,1:end-1);
84 | dw2 = w1probs'*Ix2;
85 | 
86 | Ix1 = (Ix2*w2').*w1probs.*(1-w1probs);
87 | Ix1 = Ix1(:,1:end-1);
88 | dw1 = XX'*Ix1;
89 | 
90 | df = [dw1(:)' dw2(:)' dw3(:)' dw4(:)' dw5(:)' dw6(:)' dw7(:)' dw8(:)' ]';
91 | 
92 | 
93 | 
--------------------------------------------------------------------------------

/data/converter.m:
--------------------------------------------------------------------------------
1 | % Version 1.000
2 | %
3 | % Code provided by Ruslan Salakhutdinov
4 | %
5 | % Permission is granted for anyone to copy, use, modify, or distribute this
6 | % program and accompanying programs and documents for any purpose, provided
7 | % this copyright notice is retained and prominently displayed, along with
8 | % a note saying that the original programs are available from our
9 | % web page.
10 | % The programs and documents are distributed without any warranty, express or
11 | % implied. As the programs were written for research purposes only, they have
12 | % not been tested to the degree that would be advisable in any important
13 | % application. All use of these programs is entirely at the user's own risk.
14 | 15 | % This program reads raw MNIST files available at 16 | % http://yann.lecun.com/exdb/mnist/ 17 | % and converts them to files in matlab format 18 | % Before using this program you first need to download files: 19 | % train-images-idx3-ubyte.gz train-labels-idx1-ubyte.gz 20 | % t10k-images-idx3-ubyte.gz t10k-labels-idx1-ubyte.gz 21 | % and gunzip them. You need to allocate some space for this. 22 | 23 | % This program was originally written by Yee Whye Teh 24 | 25 | % Work with test files first 26 | 27 | % We slightly edit this version here to fit our code. 28 | fprintf(1,'You first need to download files:\n train-images-idx3-ubyte.gz\n train-labels-idx1-ubyte.gz\n t10k-images-idx3-ubyte.gz\n t10k-labels-idx1-ubyte.gz\n from http://yann.lecun.com/exdb/mnist/\n and gunzip them \n'); 29 | f = fopen('t10k-images-idx3-ubyte','r'); if f < 0, error('converter: cannot open t10k-images-idx3-ubyte'); end 30 | [a,count] = fread(f,4,'int32'); % skip the 16-byte image-file header 31 | 32 | g = fopen('t10k-labels-idx1-ubyte','r'); if g < 0, error('converter: cannot open t10k-labels-idx1-ubyte'); end 33 | [l,count] = fread(g,2,'int32'); % skip the 8-byte label-file header 34 | 35 | fprintf(1,'Starting to convert Test MNIST images (prints 10 dots) \n'); 36 | n = 1000; 37 | 38 | Df = cell(1,10); 39 | for d=0:9, 40 | Df{d+1} = fopen(['test' num2str(d) '.ascii'],'w'); 41 | end; 42 | 43 | for i=1:10, 44 | fprintf('.'); 45 | rawimages = fread(f,28*28*n,'uchar'); 46 | rawlabels = fread(g,n,'uchar'); 47 | rawimages = reshape(rawimages,28*28,n); 48 | 49 | for j=1:n, 50 | fprintf(Df{rawlabels(j)+1},'%3d ',rawimages(:,j)); 51 | fprintf(Df{rawlabels(j)+1},'\n'); 52 | end; 53 | end; 54 | fclose(f); fclose(g); % release the test image/label file handles before they are reassigned below 55 | fprintf(1,'\n'); 56 | for d=0:9, 57 | fclose(Df{d+1}); 58 | D = load(['test' num2str(d) '.ascii'],'-ascii'); 59 | fprintf('%5d Digits of class %d\n',size(D,1),d); 60 | save(['test' num2str(d) '.mat'],'D','-mat'); 61 | end; 62 | 63 | 64 | % Work with training files second 65 | f = fopen('train-images-idx3-ubyte','r'); if f < 0, error('converter: cannot open train-images-idx3-ubyte'); end 66 | [a,count] = fread(f,4,'int32'); % skip the 16-byte image-file header 67 | 68 | g = fopen('train-labels-idx1-ubyte','r'); if g < 0, error('converter: cannot open train-labels-idx1-ubyte'); end 69 | [l,count] = fread(g,2,'int32'); % skip the 8-byte label-file header 70 | 71 | fprintf(1,'Starting to convert Training MNIST images
(prints 60 dots)\n'); 72 | n = 1000; 73 | 74 | Df = cell(1,10); 75 | for d=0:9, 76 | Df{d+1} = fopen(['digit' num2str(d) '.ascii'],'w'); 77 | end; 78 | 79 | for i=1:60, 80 | fprintf('.'); 81 | rawimages = fread(f,28*28*n,'uchar'); 82 | rawlabels = fread(g,n,'uchar'); 83 | rawimages = reshape(rawimages,28*28,n); 84 | 85 | for j=1:n, 86 | fprintf(Df{rawlabels(j)+1},'%3d ',rawimages(:,j)); 87 | fprintf(Df{rawlabels(j)+1},'\n'); 88 | end; 89 | end; 90 | fclose(f); fclose(g); % release the training image/label file handles 91 | fprintf(1,'\n'); 92 | for d=0:9, 93 | fclose(Df{d+1}); 94 | D = load(['digit' num2str(d) '.ascii'],'-ascii'); 95 | fprintf('%5d Digits of class %d\n',size(D,1),d); 96 | save(['digit' num2str(d) '.mat'],'D','-mat'); 97 | end; 98 | delete('*.ascii'); %edit to fit the matlab command, 20150203 99 | 100 | -------------------------------------------------------------------------------- /initialize/rbm.m: -------------------------------------------------------------------------------- 1 | % Version 1.000 2 | % 3 | % Code provided by Geoff Hinton and Ruslan Salakhutdinov 4 | % 5 | % Permission is granted for anyone to copy, use, modify, or distribute this 6 | % program and accompanying programs and documents for any purpose, provided 7 | % this copyright notice is retained and prominently displayed, along with 8 | % a note saying that the original programs are available from our 9 | % web page. 10 | % The programs and documents are distributed without any warranty, express or 11 | % implied. As the programs were written for research purposes only, they have 12 | % not been tested to the degree that would be advisable in any important 13 | % application. All use of these programs is entirely at the user's own risk. 14 | 15 | % This program trains Restricted Boltzmann Machine in which 16 | % visible, binary, stochastic pixels are connected to 17 | % hidden, binary, stochastic feature detectors using symmetrically 18 | % weighted connections. 
Learning is done with 1-step Contrastive Divergence.
19 | % The program assumes that the following variables are set externally: 20 | % maxepoch -- maximum number of epochs 21 | % numhid -- number of hidden units 22 | % batchdata -- the data that is divided into batches (numcases numdims numbatches) 23 | % restart -- set to 1 if learning starts from beginning 24 | % On exit the workspace holds vishid, visbiases, hidbiases and batchposhidprobs (positive-phase hidden probabilities of every batch, from the last epoch run). 25 | epsilonw = 0.1; % Learning rate for weights 26 | epsilonvb = 0.1; % Learning rate for biases of visible units 27 | epsilonhb = 0.1; % Learning rate for biases of hidden units 28 | weightcost = 0.0002; % weight-decay coefficient, applied to vishid only (not to the biases) 29 | initialmomentum = 0.5; 30 | finalmomentum = 0.9; 31 | 32 | [numcases numdims numbatches]=size(batchdata); 33 | 34 | if restart ==1, 35 | restart=0; 36 | epoch=1; 37 | 38 | % Initializing symmetric weights and biases. 39 | vishid = 0.1*randn(numdims, numhid); 40 | hidbiases = zeros(1,numhid); 41 | visbiases = zeros(1,numdims); 42 | 43 | poshidprobs = zeros(numcases,numhid); 44 | neghidprobs = zeros(numcases,numhid); 45 | posprods = zeros(numdims,numhid); 46 | negprods = zeros(numdims,numhid); 47 | vishidinc = zeros(numdims,numhid); 48 | hidbiasinc = zeros(1,numhid); 49 | visbiasinc = zeros(1,numdims); 50 | batchposhidprobs=zeros(numcases,numhid,numbatches); 51 | end 52 | 53 | for epoch = epoch:maxepoch, 54 | fprintf(1,'epoch %d\r',epoch); 55 | errsum=0; 56 | for batch = 1:numbatches, 57 | fprintf(1,'epoch %d batch %d\r',epoch,batch); 58 | 59 | %%%%%%%%% START POSITIVE PHASE %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 60 | data = batchdata(:,:,batch); 61 | poshidprobs = 1./(1 + exp(-data*vishid - repmat(hidbiases,numcases,1))); 62 | batchposhidprobs(:,:,batch)=poshidprobs; 63 | posprods = data' * poshidprobs; 64 | poshidact = sum(poshidprobs); 65 | posvisact = sum(data); 66 | 67 | %%%%%%%%% END OF POSITIVE PHASE %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 68 | poshidstates = poshidprobs > rand(numcases,numhid); % sample binary hidden states from their probabilities 69 | 70 | %%%%%%%%% START NEGATIVE PHASE %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 71 | negdata = 1./(1 +
exp(-poshidstates*vishid' - repmat(visbiases,numcases,1))); 72 | neghidprobs = 1./(1 + exp(-negdata*vishid - repmat(hidbiases,numcases,1))); 73 | negprods = negdata'*neghidprobs; 74 | neghidact = sum(neghidprobs); 75 | negvisact = sum(negdata); 76 | 77 | %%%%%%%%% END OF NEGATIVE PHASE %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 78 | err= sum(sum( (data-negdata).^2 )); % squared reconstruction error; only accumulated into errsum for the progress printout 79 | errsum = err + errsum; 80 | 81 | if epoch>5, 82 | momentum=finalmomentum; 83 | else 84 | momentum=initialmomentum; 85 | end; 86 | 87 | %%%%%%%%% UPDATE WEIGHTS AND BIASES %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 88 | vishidinc = momentum*vishidinc + ... 89 | epsilonw*( (posprods-negprods)/numcases - weightcost*vishid); 90 | visbiasinc = momentum*visbiasinc + (epsilonvb/numcases)*(posvisact-negvisact); 91 | hidbiasinc = momentum*hidbiasinc + (epsilonhb/numcases)*(poshidact-neghidact); 92 | 93 | vishid = vishid + vishidinc; 94 | visbiases = visbiases + visbiasinc; 95 | hidbiases = hidbiases + hidbiasinc; 96 | 97 | %%%%%%%%%%%%%%%% END OF UPDATES %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 98 | 99 | end 100 | fprintf(1, 'epoch %4i error %6.1f \n', epoch, errsum); 101 | end; 102 | -------------------------------------------------------------------------------- /data/makebatches.m: -------------------------------------------------------------------------------- 1 | % Version 1.000 2 | % 3 | % Code provided by Ruslan Salakhutdinov and Geoff Hinton 4 | % 5 | % Permission is granted for anyone to copy, use, modify, or distribute this 6 | % program and accompanying programs and documents for any purpose, provided 7 | % this copyright notice is retained and prominently displayed, along with 8 | % a note saying that the original programs are available from our 9 | % web page. 10 | % The programs and documents are distributed without any warranty, express or 11 | % implied.
As the programs were written for research purposes only, they have 12 | % not been tested to the degree that would be advisable in any important 13 | % application. All use of these programs is entirely at the user's own risk. 14 | 15 | % We slightly edit this version here to fit our clustering code. 16 | digitdata=[]; 17 | targets=[]; 18 | load digit0; digitdata = [digitdata; D]; targets = [targets; repmat([1 0 0 0 0 0 0 0 0 0], size(D,1), 1)]; 19 | load digit1; digitdata = [digitdata; D]; targets = [targets; repmat([0 1 0 0 0 0 0 0 0 0], size(D,1), 1)]; 20 | load digit2; digitdata = [digitdata; D]; targets = [targets; repmat([0 0 1 0 0 0 0 0 0 0], size(D,1), 1)]; 21 | load digit3; digitdata = [digitdata; D]; targets = [targets; repmat([0 0 0 1 0 0 0 0 0 0], size(D,1), 1)]; 22 | load digit4; digitdata = [digitdata; D]; targets = [targets; repmat([0 0 0 0 1 0 0 0 0 0], size(D,1), 1)]; 23 | load digit5; digitdata = [digitdata; D]; targets = [targets; repmat([0 0 0 0 0 1 0 0 0 0], size(D,1), 1)]; 24 | load digit6; digitdata = [digitdata; D]; targets = [targets; repmat([0 0 0 0 0 0 1 0 0 0], size(D,1), 1)]; 25 | load digit7; digitdata = [digitdata; D]; targets = [targets; repmat([0 0 0 0 0 0 0 1 0 0], size(D,1), 1)]; 26 | load digit8; digitdata = [digitdata; D]; targets = [targets; repmat([0 0 0 0 0 0 0 0 1 0], size(D,1), 1)]; 27 | load digit9; digitdata = [digitdata; D]; targets = [targets; repmat([0 0 0 0 0 0 0 0 0 1], size(D,1), 1)]; 28 | digitdata = digitdata/255; 29 | 30 | totnum=size(digitdata,1); 31 | fprintf(1, 'Size of the training dataset= %5d \n', totnum); 32 | 33 | rand('state',0); %so we know the permutation of the training data 34 | randomorder=randperm(totnum); 35 | 36 | numbatches=totnum/100; % NOTE(review): assumes totnum is a multiple of the batch size 100 (true for MNIST) 37 | numdims = size(digitdata,2); 38 | batchsize = 100; 39 | batchdata = zeros(batchsize, numdims, numbatches); 40 | batchtargets = zeros(batchsize, 10, numbatches); 41 | 42 | for b=1:numbatches 43 | batchdata(:,:,b) =
digitdata(randomorder(1+(b-1)*batchsize:b*batchsize), :); 44 | batchtargets(:,:,b) = targets(randomorder(1+(b-1)*batchsize:b*batchsize), :); 45 | end; 46 | clear digitdata targets; 47 | 48 | digitdata=[]; 49 | targets=[]; 50 | load test0; digitdata = [digitdata; D]; targets = [targets; repmat([1 0 0 0 0 0 0 0 0 0], size(D,1), 1)]; 51 | load test1; digitdata = [digitdata; D]; targets = [targets; repmat([0 1 0 0 0 0 0 0 0 0], size(D,1), 1)]; 52 | load test2; digitdata = [digitdata; D]; targets = [targets; repmat([0 0 1 0 0 0 0 0 0 0], size(D,1), 1)]; 53 | load test3; digitdata = [digitdata; D]; targets = [targets; repmat([0 0 0 1 0 0 0 0 0 0], size(D,1), 1)]; 54 | load test4; digitdata = [digitdata; D]; targets = [targets; repmat([0 0 0 0 1 0 0 0 0 0], size(D,1), 1)]; 55 | load test5; digitdata = [digitdata; D]; targets = [targets; repmat([0 0 0 0 0 1 0 0 0 0], size(D,1), 1)]; 56 | load test6; digitdata = [digitdata; D]; targets = [targets; repmat([0 0 0 0 0 0 1 0 0 0], size(D,1), 1)]; 57 | load test7; digitdata = [digitdata; D]; targets = [targets; repmat([0 0 0 0 0 0 0 1 0 0], size(D,1), 1)]; 58 | load test8; digitdata = [digitdata; D]; targets = [targets; repmat([0 0 0 0 0 0 0 0 1 0], size(D,1), 1)]; 59 | load test9; digitdata = [digitdata; D]; targets = [targets; repmat([0 0 0 0 0 0 0 0 0 1], size(D,1), 1)]; 60 | digitdata = digitdata/255; 61 | 62 | totnum=size(digitdata,1); 63 | fprintf(1, 'Size of the test dataset= %5d \n', totnum); 64 | 65 | rand('state',0); %so we know the permutation of the test data 66 | randomorder=randperm(totnum); 67 | 68 | numbatches=totnum/100; % NOTE(review): same multiple-of-100 assumption as for the training set 69 | numdims = size(digitdata,2); 70 | batchsize = 100; 71 | testbatchdata = zeros(batchsize, numdims, numbatches); 72 | testbatchtargets = zeros(batchsize, 10, numbatches); 73 | 74 | for b=1:numbatches 75 | testbatchdata(:,:,b) = digitdata(randomorder(1+(b-1)*batchsize:b*batchsize), :); 76 | testbatchtargets(:,:,b) = targets(randomorder(1+(b-1)*batchsize:b*batchsize), :); 77 | end; 78 |
clear digitdata targets; 79 | 80 | delete('test*.mat'); 81 | delete('digit*.mat'); 82 | %%% Reset random seeds from the clock, undoing the fixed rand('state',0) used for the permutations above 83 | rand('state',sum(100*clock)); 84 | randn('state',sum(100*clock)); 85 | 86 | 87 | 88 | -------------------------------------------------------------------------------- /clustering/minimize.m: -------------------------------------------------------------------------------- 1 | function [X, fX, i] = minimize(X, f, length, varargin) 2 | 3 | % Minimize a differentiable multivariate function. 4 | % 5 | % Usage: [X, fX, i] = minimize(X, f, length, P1, P2, P3, ... ) 6 | % 7 | % where the starting point is given by "X" (D by 1), and the function named in 8 | % the string "f", must return a function value and a vector of partial 9 | % derivatives of f wrt X, the "length" gives the length of the run: if it is 10 | % positive, it gives the maximum number of line searches, if negative its 11 | % absolute gives the maximum allowed number of function evaluations. You can 12 | % (optionally) give "length" a second component, which will indicate the 13 | % reduction in function value to be expected in the first line-search (defaults 14 | % to 1.0). The parameters P1, P2, P3, ... are passed on to the function f. 15 | % 16 | % The function returns when either its length is up, or if no further progress 17 | % can be made (ie, we are at a (local) minimum, or so close that due to 18 | % numerical problems, we cannot get any closer). NOTE: If the function 19 | % terminates within a few iterations, it could be an indication that the 20 | % function values and derivatives are not consistent (ie, there may be a bug in 21 | % the implementation of your "f" function). The function returns the found 22 | % solution "X", a vector of function values "fX" indicating the progress made 23 | % and "i" the number of iterations (line searches or function evaluations, 24 | % depending on the sign of "length") used.
25 | % 26 | % The Polack-Ribiere flavour of conjugate gradients is used to compute search 27 | % directions, and a line search using quadratic and cubic polynomial 28 | % approximations and the Wolfe-Powell stopping criteria is used together with 29 | % the slope ratio method for guessing initial step sizes. Additionally a bunch 30 | % of checks are made to make sure that exploration is taking place and that 31 | % extrapolation will not be unboundedly large. 32 | % 33 | % See also: checkgrad 34 | % 35 | % Copyright (C) 2001 - 2006 by Carl Edward Rasmussen (2006-09-08). 36 | 37 | INT = 0.1; % don't reevaluate within 0.1 of the limit of the current bracket 38 | EXT = 3.0; % extrapolate maximum 3 times the current step-size 39 | MAX = 20; % max 20 function evaluations per line search 40 | RATIO = 10; % maximum allowed slope ratio 41 | SIG = 0.1; RHO = SIG/2; % SIG and RHO are the constants controlling the Wolfe- 42 | % Powell conditions. SIG is the maximum allowed absolute ratio between 43 | % previous and new slopes (derivatives in the search direction), thus setting 44 | % SIG to low (positive) values forces higher precision in the line-searches. 45 | % RHO is the minimum allowed fraction of the expected (from the slope at the 46 | % initial point in the linesearch). Constants must satisfy 0 < RHO < SIG < 1. 47 | % Tuning of SIG (depending on the nature of the function to be optimized) may 48 | % speed up the minimization; it is probably not worth playing much with RHO. 49 | 50 | % The code falls naturally into 3 parts, after the initial line search is 51 | % started in the direction of steepest descent. 1) we first enter a while loop 52 | % which uses point 1 (p1) and (p2) to compute an extrapolation (p3), until we 53 | % have extrapolated far enough (Wolfe-Powell conditions).
2) if necessary, we 54 | % enter the second loop which takes p2, p3 and p4 chooses the subinterval 55 | % containing a (local) minimum, and interpolates it, until an acceptable point 56 | % is found (Wolfe-Powell conditions). Note, that points are always maintained 57 | % in order p0 <= p1 <= p2 < p3 < p4. 3) compute a new search direction using 58 | % conjugate gradients (Polack-Ribiere flavour), or revert to steepest if there 59 | % was a problem in the previous line-search. Return the best value so far, if 60 | % two consecutive line-searches fail, or whenever we run out of function 61 | % evaluations or line-searches. During extrapolation, the "f" function may fail 62 | % either with an error or returning Nan or Inf, and minimize should handle this 63 | % gracefully. 64 | 65 | if max(size(length)) == 2, red=length(2); length=length(1); else red=1; end 66 | if length>0, S='Linesearch'; else S='Function evaluation'; end 67 | 68 | i = 0; % zero the run length counter 69 | ls_failed = 0; % no previous line search has failed 70 | [f0 df0] = feval(f, X, varargin{:}); % get function value and gradient 71 | fX = f0; 72 | i = i + (length<0); % count function evaluations (length<0) 73 | s = -df0; d0 = -s'*s; % initial search direction (steepest) and slope 74 | x3 = red/(1-d0); % initial step is red/(|s|+1) 75 | 76 | while i < abs(length) % while not finished 77 | i = i + (length>0); % count line searches (length>0) 78 | 79 | X0 = X; F0 = f0; dF0 = df0; % make a copy of current values 80 | if length>0, M = MAX; else M = min(MAX, -length-i); end 81 | 82 | while 1 % keep extrapolating as long as necessary 83 | x2 = 0; f2 = f0; d2 = d0; f3 = f0; df3 = df0; 84 | success = 0; 85 | while ~success && M > 0 86 | try 87 | M = M - 1; i = i + (length<0); % count function evaluations (length<0)
88 | [f3 df3] = feval(f, X+x3*s, varargin{:}); 89 | if isnan(f3) || isinf(f3) || any(isnan(df3)+isinf(df3)), error(''), end 90 | success = 1; 91 | catch % catch any error which occurred in f 92 | x3 = (x2+x3)/2; % bisect and try again 93 | end 94 | end 95 | if f3 < F0, X0 = X+x3*s; F0 = f3; dF0 = df3; end % keep best values 96 | d3 = df3'*s; % new slope 97 | if d3 > SIG*d0 || f3 > f0+x3*RHO*d0 || M == 0 % are we done extrapolating? 98 | break 99 | end 100 | x1 = x2; f1 = f2; d1 = d2; % move point 2 to point 1 101 | x2 = x3; f2 = f3; d2 = d3; % move point 3 to point 2 102 | A = 6*(f1-f2)+3*(d2+d1)*(x2-x1); % make cubic extrapolation 103 | B = 3*(f2-f1)-(2*d1+d2)*(x2-x1); 104 | x3 = x1-d1*(x2-x1)^2/(B+sqrt(B*B-A*d1*(x2-x1))); % num. error possible, ok! 105 | if ~isreal(x3) || isnan(x3) || isinf(x3) || x3 < 0 % num prob | wrong sign? 106 | x3 = x2*EXT; % extrapolate maximum amount 107 | elseif x3 > x2*EXT % new point beyond extrapolation limit? 108 | x3 = x2*EXT; % extrapolate maximum amount 109 | elseif x3 < x2+INT*(x2-x1) % new point too close to previous point? 110 | x3 = x2+INT*(x2-x1); 111 | end 112 | end % end extrapolation 113 | 114 | while (abs(d3) > -SIG*d0 || f3 > f0+x3*RHO*d0) && M > 0 % keep interpolating 115 | if d3 > 0 || f3 > f0+x3*RHO*d0 % choose subinterval 116 | x4 = x3; f4 = f3; d4 = d3; % move point 3 to point 4 117 | else 118 | x2 = x3; f2 = f3; d2 = d3; % move point 3 to point 2 119 | end 120 | if f4 > f0 121 | x3 = x2-(0.5*d2*(x4-x2)^2)/(f4-f2-d2*(x4-x2)); % quadratic interpolation 122 | else 123 | A = 6*(f2-f4)/(x4-x2)+3*(d4+d2); % cubic interpolation 124 | B = 3*(f4-f2)-(2*d2+d4)*(x4-x2); 125 | x3 = x2+(sqrt(B*B-A*d2*(x4-x2)^2)-B)/A; % num. error possible, ok!
126 | end 127 | if isnan(x3) || isinf(x3) 128 | x3 = (x2+x4)/2; % if we had a numerical problem then bisect 129 | end 130 | x3 = max(min(x3, x4-INT*(x4-x2)),x2+INT*(x4-x2)); % don't accept too close 131 | [f3 df3] = feval(f, X+x3*s, varargin{:}); 132 | if f3 < F0, X0 = X+x3*s; F0 = f3; dF0 = df3; end % keep best values 133 | M = M - 1; i = i + (length<0); % count function evaluations (length<0) 134 | d3 = df3'*s; % new slope 135 | end % end interpolation 136 | 137 | if abs(d3) < -SIG*d0 && f3 < f0+x3*RHO*d0 % if line search succeeded 138 | X = X+x3*s; f0 = f3; fX = [fX' f0]'; % update variables 139 | fprintf('%s %6i; Value %4.6e\r', S, i, f0); 140 | s = (df3'*df3-df0'*df3)/(df0'*df0)*s - df3; % Polack-Ribiere CG direction 141 | df0 = df3; % swap derivatives 142 | d3 = d0; d0 = df0'*s; 143 | if d0 > 0 % new slope must be negative 144 | s = -df0; d0 = -s'*s; % otherwise use steepest direction 145 | end 146 | x3 = x3 * min(RATIO, d3/(d0-realmin)); % slope ratio but max RATIO 147 | ls_failed = 0; % this line search did not fail 148 | else 149 | X = X0; f0 = F0; df0 = dF0; % restore best point so far 150 | if ls_failed || i > abs(length) % line search failed twice in a row 151 | break; % or we ran out of time, so we give up 152 | end 153 | s = -df0; d0 = -s'*s; % try steepest 154 | x3 = 1/(1-d0); 155 | ls_failed = 1; % this line search failed 156 | end 157 | end 158 | fprintf('\n'); -------------------------------------------------------------------------------- /initialize/minimize.m: -------------------------------------------------------------------------------- 1 | function [X, fX, i] = minimize(X, f, length, varargin) 2 | 3 | % Minimize a differentiable multivariate function. 4 | % 5 | % Usage: [X, fX, i] = minimize(X, f, length, P1, P2, P3, ...
) 6 | % 7 | % where the starting point is given by "X" (D by 1), and the function named in 8 | % the string "f", must return a function value and a vector of partial 9 | % derivatives of f wrt X, the "length" gives the length of the run: if it is 10 | % positive, it gives the maximum number of line searches, if negative its 11 | % absolute gives the maximum allowed number of function evaluations. You can 12 | % (optionally) give "length" a second component, which will indicate the 13 | % reduction in function value to be expected in the first line-search (defaults 14 | % to 1.0). The parameters P1, P2, P3, ... are passed on to the function f. 15 | % 16 | % The function returns when either its length is up, or if no further progress 17 | % can be made (ie, we are at a (local) minimum, or so close that due to 18 | % numerical problems, we cannot get any closer). NOTE: If the function 19 | % terminates within a few iterations, it could be an indication that the 20 | % function values and derivatives are not consistent (ie, there may be a bug in 21 | % the implementation of your "f" function). The function returns the found 22 | % solution "X", a vector of function values "fX" indicating the progress made 23 | % and "i" the number of iterations (line searches or function evaluations, 24 | % depending on the sign of "length") used. 25 | % 26 | % The Polack-Ribiere flavour of conjugate gradients is used to compute search 27 | % directions, and a line search using quadratic and cubic polynomial 28 | % approximations and the Wolfe-Powell stopping criteria is used together with 29 | % the slope ratio method for guessing initial step sizes. Additionally a bunch 30 | % of checks are made to make sure that exploration is taking place and that 31 | % extrapolation will not be unboundedly large. 32 | % 33 | % See also: checkgrad 34 | % 35 | % Copyright (C) 2001 - 2006 by Carl Edward Rasmussen (2006-09-08).
36 | 37 | INT = 0.1; % don't reevaluate within 0.1 of the limit of the current bracket 38 | EXT = 3.0; % extrapolate maximum 3 times the current step-size 39 | MAX = 20; % max 20 function evaluations per line search 40 | RATIO = 10; % maximum allowed slope ratio 41 | SIG = 0.1; RHO = SIG/2; % SIG and RHO are the constants controlling the Wolfe- 42 | % Powell conditions. SIG is the maximum allowed absolute ratio between 43 | % previous and new slopes (derivatives in the search direction), thus setting 44 | % SIG to low (positive) values forces higher precision in the line-searches. 45 | % RHO is the minimum allowed fraction of the expected (from the slope at the 46 | % initial point in the linesearch). Constants must satisfy 0 < RHO < SIG < 1. 47 | % Tuning of SIG (depending on the nature of the function to be optimized) may 48 | % speed up the minimization; it is probably not worth playing much with RHO. 49 | 50 | % The code falls naturally into 3 parts, after the initial line search is 51 | % started in the direction of steepest descent. 1) we first enter a while loop 52 | % which uses point 1 (p1) and (p2) to compute an extrapolation (p3), until we 53 | % have extrapolated far enough (Wolfe-Powell conditions). 2) if necessary, we 54 | % enter the second loop which takes p2, p3 and p4 chooses the subinterval 55 | % containing a (local) minimum, and interpolates it, until an acceptable point 56 | % is found (Wolfe-Powell conditions). Note, that points are always maintained 57 | % in order p0 <= p1 <= p2 < p3 < p4. 3) compute a new search direction using 58 | % conjugate gradients (Polack-Ribiere flavour), or revert to steepest if there 59 | % was a problem in the previous line-search. Return the best value so far, if 60 | % two consecutive line-searches fail, or whenever we run out of function 61 | % evaluations or line-searches.
During extrapolation, the "f" function may fail 62 | % either with an error or returning Nan or Inf, and minimize should handle this 63 | % gracefully. 64 | 65 | if max(size(length)) == 2, red=length(2); length=length(1); else red=1; end 66 | if length>0, S='Linesearch'; else S='Function evaluation'; end 67 | 68 | i = 0; % zero the run length counter 69 | ls_failed = 0; % no previous line search has failed 70 | [f0 df0] = feval(f, X, varargin{:}); % get function value and gradient 71 | fX = f0; 72 | i = i + (length<0); % count function evaluations (length<0) 73 | s = -df0; d0 = -s'*s; % initial search direction (steepest) and slope 74 | x3 = red/(1-d0); % initial step is red/(|s|+1) 75 | 76 | while i < abs(length) % while not finished 77 | i = i + (length>0); % count line searches (length>0) 78 | 79 | X0 = X; F0 = f0; dF0 = df0; % make a copy of current values 80 | if length>0, M = MAX; else M = min(MAX, -length-i); end 81 | 82 | while 1 % keep extrapolating as long as necessary 83 | x2 = 0; f2 = f0; d2 = d0; f3 = f0; df3 = df0; 84 | success = 0; 85 | while ~success && M > 0 86 | try 87 | M = M - 1; i = i + (length<0); % count function evaluations (length<0) 88 | [f3 df3] = feval(f, X+x3*s, varargin{:}); 89 | if isnan(f3) || isinf(f3) || any(isnan(df3)+isinf(df3)), error(''), end 90 | success = 1; 91 | catch % catch any error which occurred in f 92 | x3 = (x2+x3)/2; % bisect and try again 93 | end 94 | end 95 | if f3 < F0, X0 = X+x3*s; F0 = f3; dF0 = df3; end % keep best values 96 | d3 = df3'*s; % new slope 97 | if d3 > SIG*d0 || f3 > f0+x3*RHO*d0 || M == 0 % are we done extrapolating? 98 | break 99 | end 100 | x1 = x2; f1 = f2; d1 = d2; % move point 2 to point 1 101 | x2 = x3; f2 = f3; d2 = d3; % move point 3 to point 2 102 | A = 6*(f1-f2)+3*(d2+d1)*(x2-x1); % make cubic extrapolation 103 | B = 3*(f2-f1)-(2*d1+d2)*(x2-x1); 104 | x3 = x1-d1*(x2-x1)^2/(B+sqrt(B*B-A*d1*(x2-x1))); % num. error possible, ok! 105 | if ~isreal(x3) || isnan(x3) || isinf(x3) || x3 < 0 % num prob | wrong sign?
106 | x3 = x2*EXT; % extrapolate maximum amount 107 | elseif x3 > x2*EXT % new point beyond extrapolation limit? 108 | x3 = x2*EXT; % extrapolate maximum amount 109 | elseif x3 < x2+INT*(x2-x1) % new point too close to previous point? 110 | x3 = x2+INT*(x2-x1); 111 | end 112 | end % end extrapolation 113 | 114 | while (abs(d3) > -SIG*d0 || f3 > f0+x3*RHO*d0) && M > 0 % keep interpolating 115 | if d3 > 0 || f3 > f0+x3*RHO*d0 % choose subinterval 116 | x4 = x3; f4 = f3; d4 = d3; % move point 3 to point 4 117 | else 118 | x2 = x3; f2 = f3; d2 = d3; % move point 3 to point 2 119 | end 120 | if f4 > f0 121 | x3 = x2-(0.5*d2*(x4-x2)^2)/(f4-f2-d2*(x4-x2)); % quadratic interpolation 122 | else 123 | A = 6*(f2-f4)/(x4-x2)+3*(d4+d2); % cubic interpolation 124 | B = 3*(f4-f2)-(2*d2+d4)*(x4-x2); 125 | x3 = x2+(sqrt(B*B-A*d2*(x4-x2)^2)-B)/A; % num. error possible, ok! 126 | end 127 | if isnan(x3) || isinf(x3) 128 | x3 = (x2+x4)/2; % if we had a numerical problem then bisect 129 | end 130 | x3 = max(min(x3, x4-INT*(x4-x2)),x2+INT*(x4-x2)); % don't accept too close 131 | [f3 df3] = feval(f, X+x3*s, varargin{:}); 132 | if f3 < F0, X0 = X+x3*s; F0 = f3; dF0 = df3; end % keep best values 133 | M = M - 1; i = i + (length<0); % count function evaluations (length<0)
134 | d3 = df3'*s; % new slope 135 | end % end interpolation 136 | 137 | if abs(d3) < -SIG*d0 && f3 < f0+x3*RHO*d0 % if line search succeeded 138 | X = X+x3*s; f0 = f3; fX = [fX' f0]'; % update variables 139 | fprintf('%s %6i; Value %4.6e\r', S, i, f0); 140 | s = (df3'*df3-df0'*df3)/(df0'*df0)*s - df3; % Polack-Ribiere CG direction 141 | df0 = df3; % swap derivatives 142 | d3 = d0; d0 = df0'*s; 143 | if d0 > 0 % new slope must be negative 144 | s = -df0; d0 = -s'*s; % otherwise use steepest direction 145 | end 146 | x3 = x3 * min(RATIO, d3/(d0-realmin)); % slope ratio but max RATIO 147 | ls_failed = 0; % this line search did not fail 148 | else 149 | X = X0; f0 = F0; df0 = dF0; % restore best point so far 150 | if ls_failed || i > abs(length) % line search failed twice in a row 151 | break; % or we ran out of time, so we give up 152 | end 153 | s = -df0; d0 = -s'*s; % try steepest 154 | x3 = 1/(1-d0); 155 | ls_failed = 1; % this line search failed 156 | end 157 | end 158 | fprintf('\n'); -------------------------------------------------------------------------------- /clustering/hungarian.m: -------------------------------------------------------------------------------- 1 | function [Matching,Cost] = Hungarian(Perf) 2 | % 3 | % [MATCHING,COST] = Hungarian(WEIGHTS) 4 | % 5 | % A function for finding a minimum edge weight matching given a MxN Edge 6 | % weight matrix WEIGHTS using the Hungarian Algorithm. 7 | % 8 | % An edge weight of Inf indicates that the pair of vertices given by its 9 | % position have no adjacent edge. 10 | % 11 | % MATCHING returns a MxN matrix with ones in the place of the matchings and 12 | % zeros elsewhere.
13 | % 14 | % COST returns the cost of the minimum matching 15 | 16 | % Written by: Alex Melin 30 June 2006 17 | 18 | 19 | % Initialize Variables 20 | Matching = zeros(size(Perf)); 21 | 22 | % Condense the Performance Matrix by removing any unconnected vertices to 23 | % increase the speed of the algorithm 24 | 25 | % Find the number in each column that are connected 26 | num_y = sum(~isinf(Perf),1); 27 | % Find the number in each row that are connected 28 | num_x = sum(~isinf(Perf),2); 29 | 30 | % Find the columns(vertices) and rows(vertices) that are isolated 31 | x_con = find(num_x~=0); 32 | y_con = find(num_y~=0); 33 | 34 | % Assemble Condensed Performance Matrix 35 | P_size = max(length(x_con),length(y_con)); 36 | P_cond = zeros(P_size); 37 | P_cond(1:length(x_con),1:length(y_con)) = Perf(x_con,y_con); 38 | if isempty(P_cond) 39 | Cost = 0; 40 | return 41 | end 42 | 43 | % Ensure that a perfect matching exists 44 | % Calculate a form of the Edge Matrix 45 | Edge = P_cond; 46 | Edge(P_cond~=Inf) = 0; 47 | % Find the deficiency(CNUM) in the Edge Matrix 48 | cnum = min_line_cover(Edge); 49 | 50 | % Project additional vertices and edges so that a perfect matching 51 | % exists 52 | Pmax = max(max(P_cond(P_cond~=Inf))); 53 | P_size = length(P_cond)+cnum; 54 | P_cond = ones(P_size)*Pmax; 55 | P_cond(1:length(x_con),1:length(y_con)) = Perf(x_con,y_con); 56 | 57 | %************************************************* 58 | % MAIN PROGRAM: CONTROLS WHICH STEP IS EXECUTED 59 | %************************************************* 60 | exit_flag = 1; 61 | stepnum = 1; 62 | while exit_flag 63 | switch stepnum 64 | case 1 65 | [P_cond,stepnum] = step1(P_cond); 66 | case 2 67 | [r_cov,c_cov,M,stepnum] = step2(P_cond); 68 | case 3 69 | [c_cov,stepnum] = step3(M,P_size); 70 | case 4 71 | [M,r_cov,c_cov,Z_r,Z_c,stepnum] = step4(P_cond,r_cov,c_cov,M); 72 | case 5 73 | [M,r_cov,c_cov,stepnum] = step5(M,Z_r,Z_c,r_cov,c_cov); 74 | case 6 75 | [P_cond,stepnum] = step6(P_cond,r_cov,c_cov); 
76 | case 7 77 | exit_flag = 0; 78 | end 79 | end 80 | 81 | % Remove all the virtual satellites and targets and uncondense the 82 | % Matching to the size of the original performance matrix. 83 | Matching(x_con,y_con) = M(1:length(x_con),1:length(y_con)); 84 | Cost = sum(sum(Perf(Matching==1))); 85 | 86 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 87 | % STEP 1: Find the smallest number of zeros in each row 88 | % and subtract that minimum from its row 89 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 90 | 91 | function [P_cond,stepnum] = step1(P_cond) 92 | 93 | P_size = length(P_cond); 94 | 95 | % Loop throught each row 96 | for ii = 1:P_size 97 | rmin = min(P_cond(ii,:)); 98 | P_cond(ii,:) = P_cond(ii,:)-rmin; 99 | end 100 | 101 | stepnum = 2; 102 | 103 | %************************************************************************** 104 | % STEP 2: Find a zero in P_cond. If there are no starred zeros in its 105 | % column or row start the zero. Repeat for each zero 106 | %************************************************************************** 107 | 108 | function [r_cov,c_cov,M,stepnum] = step2(P_cond) 109 | 110 | % Define variables 111 | P_size = length(P_cond); 112 | r_cov = zeros(P_size,1); % A vector that shows if a row is covered 113 | c_cov = zeros(P_size,1); % A vector that shows if a column is covered 114 | M = zeros(P_size); % A mask that shows if a position is starred or primed 115 | 116 | for ii = 1:P_size 117 | for jj = 1:P_size 118 | if P_cond(ii,jj) == 0 && r_cov(ii) == 0 && c_cov(jj) == 0 119 | M(ii,jj) = 1; 120 | r_cov(ii) = 1; 121 | c_cov(jj) = 1; 122 | end 123 | end 124 | end 125 | 126 | % Re-initialize the cover vectors 127 | r_cov = zeros(P_size,1); % A vector that shows if a row is covered 128 | c_cov = zeros(P_size,1); % A vector that shows if a column is covered 129 | stepnum = 3; 130 | 131 | %************************************************************************** 132 | % STEP 3: Cover each column with a 
%         starred zero. If all the columns are covered then the matching
%         is maximum.
%**************************************************************************

function [c_cov,stepnum] = step3(M,P_size)
% Cover every column of the mask that contains a starred zero.
%   M      - mask matrix; callers pass it with primes already erased, so
%            its entries are 1 (starred zero) or 0
%   P_size - dimension of the square cost matrix
% Returns c_cov as a 1-by-P_size row of per-column star counts, and
% stepnum = 7 when the number of covered columns equals P_size (the
% matching is complete), otherwise stepnum = 4.

c_cov = sum(M,1);
n_covered = sum(c_cov);
if n_covered ~= P_size
    stepnum = 4;
else
    stepnum = 7;
end

%**************************************************************************
% STEP 4: Find a noncovered zero and prime it. If there is no starred
%         zero in the row containing this primed zero, go to Step 5.
%         Otherwise, cover this row and uncover the column containing
%         the starred zero. Continue in this manner until there are no
%         uncovered zeros left, then go to Step 6 (which computes the
%         smallest uncovered value).
%**************************************************************************
function [M,r_cov,c_cov,Z_r,Z_c,stepnum] = step4(P_cond,r_cov,c_cov,M)
% Repeatedly prime uncovered zeros of P_cond.
%   P_cond      - square reduced cost matrix
%   r_cov,c_cov - row/column cover vectors (nonzero entry = covered)
%   M           - mask matrix: 1 = starred zero, 2 = primed zero
% Returns the updated mask and covers. When a primed zero is found whose
% row holds no starred zero, stepnum = 5 and (Z_r,Z_c) is its position.
% When no uncovered zero is left, stepnum = 6 and Z_r = Z_c = 0.

while true
    % Mark every zero that lies in neither a covered row nor a covered
    % column.
    uncovered = (P_cond == 0);
    uncovered(r_cov ~= 0,:) = false;
    uncovered(:,c_cov ~= 0) = false;
    % Zeros must be visited row by row (left to right). Column-major order
    % of the transpose is exactly that order, so search the transpose.
    [col,row] = find(uncovered.',1);

    if isempty(row)
        % Nothing left to prime: let step 6 create new zeros.
        Z_r = 0;
        Z_c = 0;
        stepnum = 6;
        return
    end

    % Prime the zero and look for a starred zero in the same row.
    M(row,col) = 2;
    star_col = find(M(row,:)==1);
    if isempty(star_col)
        % This primed zero starts an augmenting path (step 5).
        Z_r = row;
        Z_c = col;
        stepnum = 5;
        return
    end

    % Cover the row and uncover the column holding its starred zero.
    r_cov(row) = 1;
    c_cov(star_col) = 0;
end
%**************************************************************************
% STEP 5: Construct a series of alternating primed and starred zeros as
%         follows. Let Z0 represent the uncovered primed zero found in
%         Step 4. Let Z1 denote the starred zero in the column of Z0 (if
%         any). Let Z2 denote the primed zero in the row of Z1 (there will
%         always be one). Continue until the series terminates at a primed
%         zero that has no starred zero in its column. Unstar each starred
%         zero of the series, star each primed zero of the series, erase
%         all primes and uncover every line in the matrix. Return to Step 3.
%**************************************************************************

function [M,r_cov,c_cov,stepnum] = step5(M,Z_r,Z_c,r_cov,c_cov)
% Augment the set of starred zeros along the alternating path from step 4.
%   M           - mask matrix: 1 = starred zero, 2 = primed zero, 0 = other
%   Z_r, Z_c    - row/column of the primed zero that starts the path
%   r_cov,c_cov - cover vectors; returned as zeros of the same size
% Returns the mask with the path flipped and all primes erased, the
% cleared covers, and stepnum = 3.

% (path_r(k), path_c(k)) is the k-th zero on the path, beginning with Z0.
path_r = Z_r;
path_c = Z_c;
last = 1;
while true
    % Is there a starred zero in the column of the latest primed zero?
    star_row = find(M(:,path_c(last))==1);
    if isempty(star_row)
        break
    end

    % Extend with that starred zero: it shares the primed zero's column.
    last = last+1;
    path_r(last,1) = star_row;
    path_c(last,1) = path_c(last-1);

    % Extend with the primed zero in the starred zero's row.
    prime_col = find(M(path_r(last),:)==2);
    last = last+1;
    path_r(last,1) = path_r(last-1);
    path_c(last,1) = prime_col;
end

% Flip the path: starred zeros (1) become 0, primed zeros become stars (1).
for k = 1:length(path_r)
    M(path_r(k),path_c(k)) = double(M(path_r(k),path_c(k)) ~= 1);
end

% Uncover every row and column, and erase the remaining primes.
r_cov = zeros(size(r_cov));
c_cov = zeros(size(c_cov));
M(M==2) = 0;

stepnum = 3;

%
%**************************************************************************
% STEP 6: Add the minimum uncovered value to every element of each covered
%         row, and subtract it from every element of each uncovered column.
%         Return to Step 4 without altering any stars, primes, or covered
%         lines.
%**************************************************************************

function [P_cond,stepnum] = step6(P_cond,r_cov,c_cov)
% Shift values so that at least one new uncovered zero appears.
%   P_cond      - square reduced cost matrix
%   r_cov,c_cov - row/column cover vectors (1 = covered, 0 = uncovered)
% Returns the adjusted matrix and stepnum = 4.
covered_rows   = (r_cov == 1);
uncovered_rows = (r_cov == 0);
uncovered_cols = (c_cov == 0);

% Smallest entry lying in neither a covered row nor a covered column.
minval = min(min(P_cond(uncovered_rows,uncovered_cols)));

P_cond(covered_rows,:)   = P_cond(covered_rows,:)   + minval;
P_cond(:,uncovered_cols) = P_cond(:,uncovered_cols) - minval;

stepnum = 4;

function cnum = min_line_cover(Edge)
% Deficiency of the zero pattern of Edge: the number of dummy rows/columns
% that must be added so that a perfect matching on the zeros exists.
%   Edge - square matrix; the caller sets it to 0 where an edge exists and
%          leaves Inf elsewhere.
% Steps 3-5 are iterated until either every column is starred (perfect
% matching, deficiency 0) or step 4 finds no uncovered zero. In the latter
% case the stars form a maximum matching and the covered lines a minimum
% cover of the zeros (Konig's theorem), so
%   cnum = length(Edge) - (number of covered lines).
% A single step-4 pass without augmentation (as done previously) only
% measures the greedy step-2 matching and can overestimate the deficiency.

n = length(Edge);

% Step 2: greedy initial starring (covers come back cleared).
[r_cov,c_cov,M] = step2(Edge);
while true
    % Step 3: cover starred columns; stop once the matching is perfect.
    [c_cov,stepnum] = step3(M,n);
    if stepnum == 7
        break
    end
    % Step 4: prime uncovered zeros; stepnum == 6 means no augmenting path.
    [M,r_cov,c_cov,Z_r,Z_c,stepnum] = step4(Edge,r_cov,c_cov,M);
    if stepnum == 6
        break
    end
    % Step 5: augment along the alternating path (clears the covers).
    [M,r_cov,c_cov] = step5(M,Z_r,Z_c,r_cov,c_cov);
end

% Calculate the deficiency
cnum = n-sum(r_cov)-sum(c_cov);
--------------------------------------------------------------------------------