├── README.md
├── demo_multiclassSMM.m
├── multiclass_SMM.pdf
├── shrinkage.m
├── ssmm_learn.m
└── ssmm_pegasos_w.m

/README.md:
--------------------------------------------------------------------------------
The code is created based on the method described in the following paper:

[1] Zheng, Q., Zhu, F., Qin, J., & Heng, P. A. (2018). Multiclass support matrix machine for single trial EEG classification. Neurocomputing, 275, 869-880.

The code and the algorithm are for non-commercial use only.

Author: Qingqing Zheng (qqzheng@cse.cuhk.edu.hk)

Date : 02/20/2018

Version : 1.0

Copyright 2018, The Chinese University of Hong Kong.

--------------------------------------------------------------------------------
/demo_multiclassSMM.m:
--------------------------------------------------------------------------------
function demo_multiclassSMM
% A demo for our method: multiclass support matrix machine.
%
% Loads multiclassdata.mat, trains the model with ssmm_learn and prints the
% training/testing kappa, testing error rate, per-class accuracy and kappa,
% and macro-averaged precision, recall and F1 score on the test set.

%% -------------Load Data--------------------
% Variables expected in multiclassdata.mat:
%   X:      p x q x n        training samples
%   X_test: p x q x n_test   testing samples
%   y:      n x 1            training labels in {1,2,3,...}
%   y_test: n_test x 1       testing labels in {1,2,3,...}
% Loading into a struct avoids creating variables implicitly in the
% function workspace.
data = load('multiclassdata.mat');
X = data.X;
y = data.y;
X_test = data.X_test;
y_test = data.y_test;

%% ------------Parameter Setting-------------
C = 1;
tau = 0.5;
rho = 0.1;

param.X = X;
param.y = y;
param.class = length(unique(param.y));
param.dim = size(X(:,:,1));
param.train_num = size(X,3);
param.lossFn = @lossCB;
param.featureFn = @featureCB;
param.DelPsiFn = @DelPsiCB;
param.constraintFn = @constrainCB;
param.prediction = @predictionCB;
param.norm_nuc = @norm_nuc;

args = sprintf('-C %g -tau %g -rho %g',C,tau,rho);
tic;
[model,~] = ssmm_learn(args, param);
time_train = toc;

acc_train = param.prediction(param,model.W,X,y);
% kappa measured against the chance level 1/class
kappa_train = (acc_train - 1/param.class)/(1-1/param.class);

tic;
[acc_test,y_test_hat] = param.prediction(param,model.W,X_test,y_test);
time_test = toc;

kappa_test = (acc_test - 1/param.class)/(1-1/param.class);
fprintf('training kappa is %.4f \n', kappa_train);
fprintf('testing kappa is %.4f \n\n',kappa_test);

% Compare labels as column vectors so that a row-vector y_test cannot
% broadcast into an n_test x n_test comparison matrix.
y_test = y_test(:);
y_test_hat = y_test_hat(:);

% error rate in percent
err = sum(y_test ~= y_test_hat)/numel(y_test)*100;

% per-class statistics for each MI task
numClass = param.class;
tp = zeros(1,numClass);
fp = zeros(1,numClass);
fn = zeros(1,numClass);
precision_all = zeros(1,numClass);
recall_all = zeros(1,numClass);
acc_class = zeros(1,numClass);
kappa_class = zeros(1,numClass);
for i = 1:numClass
    tp(i) = sum(y_test_hat == i & y_test == i);
    fp(i) = sum(y_test_hat == i) - tp(i);
    fn(i) = sum(y_test == i) - tp(i);
    % A class that is never predicted (or never present) has tp = 0 and a
    % zero denominator; count it as 0 instead of NaN so that the macro
    % averages below stay defined. For a positive count max(.,1) is a no-op.
    precision_all(i) = tp(i)/max(tp(i)+fp(i),1);
    recall_all(i) = tp(i)/max(tp(i)+fn(i),1);
    acc_class(i) = tp(i)/sum(y_test == i);
    kappa_class(i) = (acc_class(i) - 1/numClass)/(1-1/numClass);
end
% one value per class, whatever the number of classes
acc_str = sprintf('%.3f,', acc_class);
fprintf('testing accuracy in each class %s\n', acc_str(1:end-1));

% macro-averaged precision: mean over classes of tp/(tp+fp)
precision = sum(precision_all)/numClass;
% macro-averaged recall: mean over classes of tp/(tp+fn)
recall = sum(recall_all)/numClass;
% F1 score of the macro precision and recall (0 when both are 0)
if precision + recall > 0
    F1score = 2*precision*recall/(precision+recall);
else
    F1score = 0;
end

fprintf('%s\n',[' training time is: ' num2str(time_train)]);
fprintf('%s\n',[' testing time is: ' num2str(time_test)]);
fprintf('%s\n',[' testing error rate(%) is: ' num2str(err)]);
fprintf('%s\n',[' testing kappa value is: ' num2str(kappa_test)]);

fprintf('%s\n',[' testing acc on each MI are ' num2str(acc_class)]);
fprintf('%s\n',[' testing kappa on each MI are ' num2str(kappa_class)]);

fprintf('%s\n',[' testing precision is: ' num2str(precision)]);
fprintf('%s\n',[' testing recall is: ' num2str(recall)]);
fprintf('%s\n\n\n',[' testing F1 score is: ' num2str(F1score)]);
end



% ----- Calculate the loss cost ------------------
function delta = lossCB(y, ybar)
% Number of positions where the two label vectors differ (0/1 loss summed).
% Both inputs are flattened so row/column orientation does not matter.
delta = sum(double(y(:) ~= ybar(:)));
end

% ----- Calculate the most
% ----- Calculate the most likely yhat_i (predicted class label) ---------
function yhat_fn = constrainCB(param,W,X)
% Predict one label per sample: yhat(i) = argmax_j <W(:,:,j), X(:,:,i)>.
%   W : p x q x param.class weight tensor
%   X : p x q x num samples (num may be 1 or 0)
% Returns a num x 1 column of class indices (a scalar when num == 1).
% Ties resolve to the smallest class index, as max returns the first maximum.
num = size(X,3);
d = size(W,1)*size(W,2);
% scores(j,i) = <W(:,:,j), X(:,:,i)>, computed as one matrix product
scores = reshape(W, d, param.class).' * reshape(X, d, num);
[~, yhat_fn] = max(scores, [], 1);
yhat_fn = yhat_fn(:);
end

% ------ Calculate the feature tensor --------------
function psi = featureCB(param,X,y)
% Joint feature map: psi(:,:,y(i),i) = X(:,:,i), every other slice is zero.
% A single sample returns a p x q x class tensor; several samples return a
% p x q x class x num tensor.
num = size(X,3);
if num == 1
    psi = zeros(param.dim(1),param.dim(2),param.class);
    psi(:,:,y) = X;
else
    psi = zeros(param.dim(1),param.dim(2),param.class,num);
    for i = 1:num
        psi(:,:,y(i),i) = X(:,:,i);
    end
end
end

% ------ Calculate the delta feature tensor ----------
function del_psi = DelPsiCB(param,X,yhat,y)
% Difference psi(X,yhat) - psi(X,y) for each sample: +X in slice yhat,
% -X in slice y, and all zeros when the prediction is correct.
% A single sample returns p x q x class; several return p x q x class x num.
num = size(X,3);
if num == 1
    del_psi = zeros(param.dim(1),param.dim(2),param.class);
    if yhat ~= y
        del_psi(:,:,yhat) = X;
        del_psi(:,:,y) = -X;
    end
else
    del_psi = zeros(param.dim(1),param.dim(2),param.class,num);
    for i = 1:num
        if yhat(i) ~= y(i)
            del_psi(:,:,yhat(i),i) = X(:,:,i);
            del_psi(:,:,y(i),i) = -X(:,:,i);
        end
    end
end
end

% ------- Calculate the yhat ---------------------------
function [acc_rate,y_pred] = predictionCB(param,W,X,y)
% Predict labels with param.constraintFn and return the fraction that
% match y, together with the predictions themselves.
y_pred = param.constraintFn(param,W,X);
% Compare as columns so a row-vector y does not broadcast into an n x n
% comparison matrix.
acc_rate = sum(y_pred(:) == y(:))/numel(y);
end

function z = norm_nuc(X)
% Nuclear norm: sum of the singular values of the matrix X.
z = sum(svd(X));
end
--------------------------------------------------------------------------------
/multiclass_SMM.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhengqq/multiclass-support-matrix-machine-/2ad9bd05df1dd192f8fb6894e20cecf7433f2464/multiclass_SMM.pdf
--------------------------------------------------------------------------------
/shrinkage.m:
--------------------------------------------------------------------------------
function [D, nuc, rk] = shrinkage(X, tau)
% SHRINKAGE  Singular value soft-thresholding, applied slice by slice.
%   For each slice X(:,:,i) with [U,S,V] = svd(X(:,:,i)):
%       D(:,:,i) = U * max(S - tau, 0) * V'
%   nuc(i) is the nuclear norm of D(:,:,i) and rk(i) its rank (the number
%   of singular values that survive the threshold).
num = size(X,3);
D = zeros(size(X));
nuc = zeros(1,num);
rk = zeros(1,num);
for i = 1:num
    % The economy SVD gives the same reconstruction without forming the
    % full square U/V of a non-square slice.
    [U, S, V] = svd(X(:,:,i), 'econ');
    s = max(0, S - tau);
    nuc(i) = sum(diag(s));
    D(:,:,i) = U * s * V';
    rk(i) = sum(diag(s) > 0);
end
end
--------------------------------------------------------------------------------
/ssmm_learn.m:
--------------------------------------------------------------------------------
function [model,obj] = ssmm_learn(args,param)
% SSMM_LEARN  Train a multiclass support matrix machine with ADMM.
%
%   [model,obj] = ssmm_learn(args,param)
%
%   args  : option string of '-name value' pairs, e.g. '-C 1 -tau 0.5'.
%           Options: -C, -tau, -rho, -eps (stopping tolerance), -max_iter
%           (outer iterations, default 100), -inner_iter (subgradient steps
%           per W update, default 3000), -eta (subgradient step size,
%           default 1e-6).
%   param : struct with X (p x q x n), y (n labels), dim ([p q]), class,
%           and the callbacks lossFn, DelPsiFn, constraintFn, norm_nuc.
%
%   model : struct holding the hyper-parameters, the weight tensor W
%           (p x q x class), the auxiliary variable S, the multiplier
%           Lambda and the slack xi.
%   obj   : 1 x (number of outer iterations run) objective values.

% Default parameters of the msmm model
model.C = 1;
model.tau = 0.1;
model.rho = 1;
model.eps = 1e-5;
model.max_iter = 100;
model.inner_iter = 3000;
model.eta = 0.000001;

% Parse '-name value' pairs. regexp returns an empty cell for an empty
% string, so args = '' keeps the defaults.
known = {'-C', '-tau', '-rho', '-eps', '-max_iter', '-inner_iter', '-eta'};
fields = {'C', 'tau', 'rho', 'eps', 'max_iter', 'inner_iter', 'eta'};
tokens = regexp(args, '\S+', 'match');
for k = 1:2:numel(tokens)
    idx = find(strcmp(tokens{k}, known), 1);
    if isempty(idx)
        error('Unknown parameters');
    end
    if k == numel(tokens)
        error('Missing value for parameter %s', tokens{k});
    end
    % str2double parses numbers only (str2num would eval the text)
    value = str2double(tokens{k+1});
    if isnan(value)
        error('Invalid value ''%s'' for parameter %s', tokens{k+1}, tokens{k});
    end
    model.(fields{idx}) = value;
end
fprintf('C=%.2f, tau = %.2f, rho = %.2f\n',model.C,model.tau,model.rho);

% Number of samples and labels must agree
if(size(param.X,3)~=length(param.y))
    error('Number of samples should equal to the number of labels');
end

X = param.X;
y = param.y;
n = size(X,3);

sz = [param.dim(1),param.dim(2),param.class];
model.W = zeros(sz);
model.S = zeros(sz);
model.xi = 0;
model.Lambda = zeros(sz);

% Working set of cutting planes: at most one is added per outer iteration
fdiffs = zeros([sz, model.max_iter]);
margins = zeros(model.max_iter,1);
activeCons_num = 0;

obj = zeros(1, model.max_iter);
n_done = 0;
for iter = 1:model.max_iter
    if mod(iter,5) == 0
        fprintf('*');
    end

    % S-update: prox of tau*||S||_* + 0.5*||S||_F^2 at (rho*W - Lambda)
    model.S = shrinkage(model.rho*model.W - model.Lambda, model.tau)/(1+model.rho);

    % Most violated labels under the current W, their mean loss, and the
    % mean joint-feature difference they induce
    yhat = param.constraintFn(param,model.W,X);
    loss = param.lossFn(yhat,y)/n;
    fd = sum(param.DelPsiFn(param,X,yhat,y),4)/n;

    w_fd = model.W .* fd;
    cost = loss + sum(w_fd(:));
    tmp_nuc = 0;
    for cla = 1:param.class
        tmp_nuc = tmp_nuc + param.norm_nuc(model.W(:,:,cla));
    end
    obj(iter) = model.C*cost + 0.5*norm(model.W(:),2)^2+ model.tau*tmp_nuc;
    n_done = iter;

    % Written as a negation so a NaN cost stops the loop
    converged = ~(cost > model.xi + model.eps);
    if converged
        fprintf('the stop iteration is %d \n',iter);
    else
        % W-update: add the new cutting plane and re-solve over the set
        activeCons_num = activeCons_num + 1;
        fdiffs(:,:,:,activeCons_num) = fd;
        margins(activeCons_num) = loss;
        [model.W, model.xi] = ssmm_pegasos_w(fdiffs,margins,activeCons_num,model, ...
            model.inner_iter,model.eta);
    end

    % Multiplier update
    model.Lambda = model.Lambda + model.rho*(model.S - model.W);
    if converged
        break;
    end
end
obj = obj(1:n_done);
end
--------------------------------------------------------------------------------
/ssmm_pegasos_w.m:
--------------------------------------------------------------------------------
function [W,xi] = ssmm_pegasos_w(fdiffs,margins,activeCons_num,model,inner_iter,eta)
% SSMM_PEGASOS_W  Subgradient solver for the W-subproblem of ssmm_learn.
%   Starting from model.W, takes fixed-size subgradient steps on
%       C * max_k (margins(k) + <W, fdiffs(:,:,:,k)>)
%         + rho/2 * ||W - S||_F^2 - <Lambda, W>
%   over the first activeCons_num cutting planes. It stops early once no
%   constraint is violated.
%
%   inner_iter : maximum number of subgradient steps (default 3000)
%   eta        : fixed step size (default 1e-6)
%
%   xi is max(0, max_k(margins(k) + <W, fdiffs(:,:,:,k)>)) evaluated at
%   the returned W.
if nargin < 5 || isempty(inner_iter)
    inner_iter = 3000;
end
if nargin < 6 || isempty(eta)
    eta = 0.000001;
end

W = model.W;
% Flatten the active cutting planes once (loop invariant): column k of F is
% fdiffs(:,:,:,k)(:), and b is the matching row of margins.
d = size(fdiffs,1)*size(fdiffs,2)*size(fdiffs,3);
F = reshape(fdiffs(:,:,:,1:activeCons_num), d, activeCons_num);
b = reshape(margins(1:activeCons_num), 1, []);

for t = 1:inner_iter
    % dis(k) = margins(k) + <W, fdiffs(:,:,:,k)>
    dis = b + W(:).' * F;
    [mvc_val,mvc_idx] = max(dis);
    % Written as a negation so NaN values also stop the iteration
    if isempty(mvc_val) || ~(mvc_val > 0)
        break;
    end
    subgradient = model.rho*(W-model.S) - model.Lambda + model.C*fdiffs(:,:,:,mvc_idx);
    W = W - eta*subgradient;
end

% Slack of the most violated constraint at the W actually returned
xi = max([0, b + W(:).' * F]);
end
--------------------------------------------------------------------------------