├── Dataset ├── Parliment1984.mat ├── breast-cancer-wisconsin.mat ├── carevaluation.mat ├── heartdata.mat ├── ionosphere.mat ├── lymphography.mat ├── useless └── winedata.mat ├── README.md ├── binaryGOA ├── S_func.m ├── binaryGOA.m ├── distance.m ├── func_plot.m └── initialization.m ├── main.m ├── objfun.m └── utility ├── checkempty.m ├── finalEval.m └── normalize.m /Dataset/Parliment1984.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/earthat/Optimal-Feature-selection-for-KNN-classifier/f51b7346cd1d2abcc4037dc29a75709a9ea74595/Dataset/Parliment1984.mat -------------------------------------------------------------------------------- /Dataset/breast-cancer-wisconsin.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/earthat/Optimal-Feature-selection-for-KNN-classifier/f51b7346cd1d2abcc4037dc29a75709a9ea74595/Dataset/breast-cancer-wisconsin.mat -------------------------------------------------------------------------------- /Dataset/carevaluation.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/earthat/Optimal-Feature-selection-for-KNN-classifier/f51b7346cd1d2abcc4037dc29a75709a9ea74595/Dataset/carevaluation.mat -------------------------------------------------------------------------------- /Dataset/heartdata.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/earthat/Optimal-Feature-selection-for-KNN-classifier/f51b7346cd1d2abcc4037dc29a75709a9ea74595/Dataset/heartdata.mat -------------------------------------------------------------------------------- /Dataset/ionosphere.mat: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/earthat/Optimal-Feature-selection-for-KNN-classifier/f51b7346cd1d2abcc4037dc29a75709a9ea74595/Dataset/ionosphere.mat -------------------------------------------------------------------------------- /Dataset/lymphography.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/earthat/Optimal-Feature-selection-for-KNN-classifier/f51b7346cd1d2abcc4037dc29a75709a9ea74595/Dataset/lymphography.mat -------------------------------------------------------------------------------- /Dataset/useless: -------------------------------------------------------------------------------- 1 | random 2 | -------------------------------------------------------------------------------- /Dataset/winedata.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/earthat/Optimal-Feature-selection-for-KNN-classifier/f51b7346cd1d2abcc4037dc29a75709a9ea74595/Dataset/winedata.mat -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Optimal-Feature-selection-for-KNN-classifier 2 | This MATLAB code implements the binary Grass hopper optimization algorithm to select the features and train with KNN 3 | 4 | The description of this can be checked at https://free-thesis.com/product/feature-selection-and-classification-by-hybrid-optimization/ 5 | 6 | ![Accuracy vs Selected Features](https://free-thesis.com/wp-content/uploads/2019/04/feature-selection-324x243.png) 7 | -------------------------------------------------------------------------------- /binaryGOA/S_func.m: -------------------------------------------------------------------------------- 1 | % S_func: evaluates o = f*exp(-r/l) - exp(-r) elementwise on r, with f=0.5 and l=1.5 2 | function o=S_func(r) 3 | f=0.5; 4 | l=1.5; 5 | o=f*exp(-r/l)-exp(-r); % Eq.
(2.3) in the paper 6 | end -------------------------------------------------------------------------------- /binaryGOA/binaryGOA.m: -------------------------------------------------------------------------------- 1 | 2 | % The Grasshopper Optimization Algorithm (binary variant): positions are rounded to integers and scored with objfun; lower fitness is better (the target is the first entry after an ascending sort) 3 | function [TargetFitness,TargetPosition,Convergence_curve,Trajectories,... 4 | fitness_history, position_history]=binaryGOA(N, Max_iter, lb,ub, dim,... 5 | trainData,testData,trainlabel,testlabel) 6 | 7 | disp('GOA is now estimating the global optimum for your problem....') 8 | flag=0; 9 | if size(ub,1)==1 10 | ub=ones(dim,1)*ub; 11 | lb=ones(dim,1)*lb; 12 | end 13 | 14 | if (rem(dim,2)~=0) % this algorithm should be run with an even number of variables. 15 | %Handle an odd number of variables: append one dummy variable bounded by [-100, 100] and set flag=1 so the last column is left out when calling objfun 16 | dim = dim+1; 17 | ub = [ub; 100]; 18 | lb = [lb; -100]; 19 | flag=1; 20 | end 21 | 22 | %Initialize the population of grasshoppers (positions rounded to integers) 23 | 24 | GrassHopperPositions=round(initialization(N,dim,ub,lb)); 25 | GrassHopperFitness = zeros(1,N); 26 | 27 | fitness_history=zeros(N,Max_iter); 28 | position_history=zeros(N,Max_iter,dim); 29 | Convergence_curve=zeros(1,Max_iter); 30 | Trajectories=zeros(N,Max_iter); 31 | 32 | % cMax=1; 33 | % cMin=0.00004; 34 | cMax=2.079; 35 | cMin=0.00004; 36 | %Calculate the fitness of initial grasshoppers. NOTE(review): when flag==1 the slice passed to checkempty is 1-by-(dim-1) while dim is already the padded size, so checkempty must return a vector the same size as its input for the assignment on line 40 to work - confirm. objfun likewise receives the padded dim as its N. 37 | 38 | for i=1:size(GrassHopperPositions,1) 39 | if flag == 1 40 | GrassHopperPositions(i,1:end-1) = checkempty(GrassHopperPositions(i,1:end-1),dim); 41 | GrassHopperFitness(1,i)=objfun(GrassHopperPositions(i,1:end-1),... 42 | trainData,testData,trainlabel,testlabel,dim); 43 | else 44 | GrassHopperPositions(i,:) = checkempty(GrassHopperPositions(i,:),dim); 45 | GrassHopperFitness(1,i)=objfun(GrassHopperPositions(i,:),...
46 | trainData,testData,trainlabel,testlabel,dim); 47 | end 48 | fitness_history(i,1)=GrassHopperFitness(1,i); 49 | position_history(i,1,:)=GrassHopperPositions(i,:); 50 | Trajectories(:,1)=GrassHopperPositions(:,1); 51 | end 52 | 53 | [sorted_fitness,sorted_indexes]=sort(GrassHopperFitness); 54 | 55 | % Find the best grasshopper (target, lowest fitness) in the first population 56 | for newindex=1:N 57 | Sorted_grasshopper(newindex,:)=GrassHopperPositions(sorted_indexes(newindex),:); 58 | end 59 | 60 | TargetPosition=Sorted_grasshopper(1,:); 61 | TargetFitness=sorted_fitness(1); 62 | 63 | % Main loop 64 | l=2; % Start from the second iteration since the first iteration was 65 | %dedicated to calculating the fitness of the initial grasshoppers. NOTE(review): the dump is corrupted right after this line. Line 66 `while l` runs straight into `1 10 | for i=1:dim`, which is initialization.m text. The rest of binaryGOA.m, all of distance.m and func_plot.m (both listed in the tree), and the header and first lines of initialization.m are missing, presumably stripped as markup between a `<` and a `>`. Restore them from the repository before relying on this section. 66 | while l1 10 | for i=1:dim 11 | high=up(i);low=down(i); 12 | X(:,i)=rand(1,N).*(high-low)+low; 13 | end 14 | end -------------------------------------------------------------------------------- /main.m: -------------------------------------------------------------------------------- 1 | close all 2 | clear 3 | clc 4 | addpath(genpath(cd)) 5 | %% load the data 6 | % load winedata.mat 7 | load breast-cancer-wisconsin 8 | % load ionosphere 9 | % load Parliment1984 10 | % load heartdata 11 | load lymphography % NOTE(review): line 7 also loads a dataset; if both files define Tdata, this load silently replaces it. Keep exactly one load active. 12 | %% 13 | % preprocess data: drop every row that has a NaN in any column 14 | for ii=1:size(Tdata,2) 15 | nanindex=isnan(Tdata(:,ii)); 16 | Tdata(nanindex,:)=[]; 17 | end 18 | labels=Tdata(:,end); %classes 19 | attributesData=Tdata(:,1:end-1); %attributes: every column except the last (label) column 20 | % for ii=1:size(attributesData,2) %normalize the data 21 | % attributesData(:,ii)=normalize(attributesData(:,ii)); 22 | % end 23 | [rows,colms]=size(attributesData); %size of data 24 | %% separate the data into training (80%) and testing (20%) 25 | [trainIdx,~,testIdx]=dividerand(rows,0.8,0,0.2); 26 | trainData=attributesData(trainIdx,:); %training data 27 | testData=attributesData(testIdx,:); %testing data 28 | trainlabel=labels(trainIdx); %training labels 29 | testlabel=labels(testIdx); %testing labels 30 | %% KNN classification
31 | Mdl = fitcknn(trainData,trainlabel,'NumNeighbors',5,'Standardize',1); % baseline: 5-NN on all features
32 | predictedLables_KNN=predict(Mdl,testData);
33 | cp=classperf(testlabel,predictedLables_KNN);
34 | err=cp.ErrorRate;
35 | accuracy=cp.CorrectRate;
36 | %% SA optimisation for feature selection
37 | dim=size(attributesData,2);
38 | lb=0;ub=1;
39 | x0=round(rand(1,dim));
40 | fun=@(x) objfun(x,trainData,testData,trainlabel,testlabel,dim);
41 | options = optimoptions(@simulannealbnd,'MaxIterations',150,...
42 |     'PlotFcn','saplotbestf');
43 | [x,fval,exitflag,output] = simulannealbnd(fun,x0,zeros(1,dim),ones(1,dim),options);
44 | Target_pos_SA=round(x);
45 | % final evaluation for SA tuned selected features
46 | [error_SA,accuracy_SA,predictedLables_SA]=finalEval(Target_pos_SA,trainData,testData,...
47 |     trainlabel,testlabel);
48 | %% GOA optimisation for feature selection
49 | SearchAgents_no=10; % Number of search agents
50 | Max_iteration=100; % Maximum number of iterations
51 | [Target_score,Target_pos,GOA_cg_curve, Trajectories,fitness_history,...
52 |     position_history]=binaryGOA(SearchAgents_no,Max_iteration,lb,ub,dim,...
53 |     trainData,testData,trainlabel,testlabel);
54 | % When dim is odd, binaryGOA appends one padding variable (bounds -100..100)
55 | % and builds its target position from that padded population. Keep only the
56 | % real feature columns so the padding entry can never select a column past the
57 | % end of trainData. This is a no-op if binaryGOA already returns 1-by-dim.
58 | Target_pos=Target_pos(1:dim);
59 | % final evaluation for GOA tuned selected features
60 | [error_GOA,accuracy_GOA,predictedLables_GOA]=finalEval(Target_pos,trainData,testData,trainlabel,testlabel);
61 | 
62 | %%
63 | % plot for Predicted classes
64 | figure
65 | plot(testlabel,'s','LineWidth',1,'MarkerSize',12)
66 | hold on % remains in effect for the three overlays below
67 | plot(predictedLables_KNN,'o','LineWidth',1,'MarkerSize',6)
68 | plot(predictedLables_GOA,'x','LineWidth',1,'MarkerSize',6)
69 | plot(predictedLables_SA,'^','LineWidth',1,'MarkerSize',6)
70 | legend('Original Labels','Predicted by All','Predicted by GOA Tuned',...
71 |     'Predicted by SA Tuned','Location','best')
72 | title('Output Label comparison of testing Data')
73 | xlabel('-->No of test points')
74 | ylabel('Test Data Labels' )
75 | axis tight
76 | 
77 | % pie charts: number of features kept (left) and test accuracy in percent
78 | % (right) for all features, GOA-selected features and SA-selected features
79 | nFeatures=[size(testData,2),nnz(Target_pos),nnz(Target_pos_SA)];
80 | accuracyPct=[accuracy,accuracy_GOA,accuracy_SA].*100;
81 | legendlabels={'Total Features','Features after GOA Selection',...
82 |     'Features after SA Selection'};
83 | % pieLabels (not "labels") so the class-label vector is not overwritten
84 | figure
85 | subplot(1,2,1)
86 | pieLabels=arrayfun(@num2str,nFeatures,'UniformOutput',false);
87 | pie(nFeatures,pieLabels)
88 | title('Number of features selected')
89 | legend(legendlabels,'Location','southoutside','Orientation','vertical')
90 | 
91 | subplot(1,2,2)
92 | pieLabels=arrayfun(@num2str,accuracyPct,'UniformOutput',false);
93 | pie(accuracyPct,pieLabels)
94 | title('Accuracy for features selected')
95 | legend(legendlabels,'Location','southoutside','Orientation','vertical')
96 | 
--------------------------------------------------------------------------------
/objfun.m:
--------------------------------------------------------------------------------
1 | function objval=objfun(index,trainigdata ,testingdata,trainiglabels,testinglabels,N,errWeight)
2 | % OBJFUN Fitness of a feature mask for wrapper feature selection (lower is better).
3 | %   objval = OBJFUN(index, trainigdata, testingdata, trainiglabels,
4 | %                   testinglabels, N, errWeight)
5 | %   rounds INDEX and replaces an all-zero mask via checkempty. It then trains
6 | %   a 5-NN classifier (standardized predictors) on the columns where INDEX is
7 | %   non-zero and returns
8 | %       objval = errWeight*ErrorRate + (1-errWeight)*(R/N)
9 | %   where ErrorRate is the classperf error on the testing data, R is the
10 | %   number of selected features and N is the feature count used for
11 | %   normalization. errWeight is optional and defaults to 0.7.
12 | if nargin < 7 || isempty(errWeight)
13 |     errWeight = 0.7;
14 | end
15 | index=round(index);
16 | index = checkempty(index,N);
17 | selected=find(index); % feature columns used by the classifier
18 | Mdl = fitcknn(trainigdata(:,selected),trainiglabels,'NumNeighbors',5,'Standardize',1);
19 | Y=predict(Mdl,testingdata(:,selected));
20 | cp=classperf(testinglabels,Y);
21 | R=numel(selected); % count exactly the features the classifier saw
22 | objval=errWeight*cp.ErrorRate+(1-errWeight)*(R/N);
--------------------------------------------------------------------------------
/utility/checkempty.m:
--------------------------------------------------------------------------------
1 | function GrassHopperPositions = checkempty(GrassHopperPositions,dim)
2 | % CHECKEMPTY Replace an all-zero feature mask with a random non-zero one.
3 | %   P = CHECKEMPTY(P, DIM) returns P unchanged when any element is
4 | %   non-zero. Otherwise it redraws P = round(rand(...)) until at least one
5 | %   element is non-zero.
6 | %   The redraw keeps the size of P, so the result always fits back into the
7 | %   slice it came from. binaryGOA passes the 1-by-(dim-1) slice P(i,1:end-1)
8 | %   together with the padded dim, and a 1-by-DIM redraw would not fit.
9 | %   DIM is only used when P is empty, in which case a 1-by-DIM mask is drawn.
10 | drawSize = size(GrassHopperPositions);
11 | if isempty(GrassHopperPositions)
12 |     drawSize = [1 dim];
13 | end
14 | if prod(drawSize) == 0
15 |     return % nothing to draw (e.g. dim == 0); avoids an endless loop
16 | end
17 | while all(GrassHopperPositions(:) == 0)
18 |     GrassHopperPositions = round(rand(drawSize));
19 | end
20 | 
--------------------------------------------------------------------------------
/utility/finalEval.m:
--------------------------------------------------------------------------------
1 | function [err,accuracy,Y]=finalEval(index,trainigdata ,testingdata,trainiglabels,testinglabels)
2 | % FINALEVAL Score the feature subset selected by INDEX with a 5-NN classifier.
3 | %   [err, accuracy, Y] = FINALEVAL(index, trainigdata, testingdata,
4 | %                                  trainiglabels, testinglabels)
5 | %   trains fitcknn (5 neighbours, standardized predictors) on the columns
6 | %   where INDEX is non-zero, then predicts the testing rows.
7 | %   err      - classperf ErrorRate on the testing data
8 | %   accuracy - classperf CorrectRate on the testing data
9 | %   Y        - predicted testing labels
10 | %   The first output is not named "error", so MATLAB's error() stays usable
11 | %   here. Outputs are positional, so callers are unaffected.
12 | selected=find(index);
13 | if isempty(selected)
14 |     error('finalEval:noFeatures', ...
15 |         'The feature mask selects no features; at least one is required.');
16 | end
17 | Mdl = fitcknn(trainigdata(:,selected),trainiglabels,'NumNeighbors',5,'Standardize',1);
18 | Y=predict(Mdl,testingdata(:,selected));
19 | cp=classperf(testinglabels,Y);
20 | err=cp.ErrorRate;
21 | accuracy=cp.CorrectRate;
--------------------------------------------------------------------------------
/utility/normalize.m:
--------------------------------------------------------------------------------
1 | function output= normalize(input)
2 | % NORMALIZE Min-max scale each column of INPUT to the range [0, 1].
3 | %   A constant column (max == min) maps to 0 instead of NaN from 0/0.
4 | %   NOTE(review): this file shadows MATLAB's built-in normalize (R2018a+)
5 | %   for any code that has utility/ on the path.
6 | output = zeros(size(input), 'like', input);
7 | for ii=1:size(input,2)
8 |     col = input(:,ii);
9 |     shifted = col - min(col);
10 |     span = max(col) - min(col);
11 |     if span ~= 0
12 |         shifted = shifted / span;
13 |     end
14 |     output(:,ii) = shifted;
15 | end
--------------------------------------------------------------------------------