├── README.md
├── src
│   ├── auxiliary
│   │   ├── Classifier.java
│   │   ├── DataSet.java
│   │   ├── Evaluation.java
│   │   ├── RandomForest.java
│   │   └── AdaBoost.java
│   └── dm13
│       └── TestAss4.java
├── .gitattributes
├── .gitignore
└── breast-cancer.data

/README.md:
--------------------------------------------------------------------------------
Random Forest
=============

The decision tree is modified as follows: at each node with K available attributes, log(1+K) attributes are first drawn at random, and the attribute used to split the node is then chosen only from this subset.

The forest as a whole consists of many such attribute-randomized decision trees and is built by bagging: for each tree, a training set is drawn by uniform sampling with replacement and the tree is trained on it. At prediction time every tree casts a vote, and the class with the most votes is returned.
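To make the construction concrete, here is a minimal illustrative sketch (not part of the assignment code) of the two randomized ingredients, bootstrap sampling and majority voting; the actual implementation lives in `src/auxiliary/RandomForest.java`:

```java
import java.util.HashMap;
import java.util.Map;
import java.util.Random;

class ForestSketch {
    // Draw a bootstrap sample: n indices chosen uniformly with replacement.
    static int[] bootstrap(int n, Random rnd) {
        int[] idx = new int[n];
        for (int i = 0; i < n; i++) idx[i] = rnd.nextInt(n);
        return idx;
    }

    // Majority vote over the per-tree predictions.
    static double vote(double[] treePredictions) {
        Map<Double, Integer> counts = new HashMap<>();
        for (double p : treePredictions) counts.merge(p, 1, Integer::sum);
        double best = treePredictions[0];
        int bestCount = 0;
        for (Map.Entry<Double, Integer> e : counts.entrySet()) {
            if (e.getValue() > bestCount) {
                bestCount = e.getValue();
                best = e.getKey();
            }
        }
        return best;
    }
}
```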
AdaBoost
========

AdaBoost is a boosting algorithm that raises the accuracy of a combined classifier by training a sequence of weighted base classifiers. In each round, a training set is drawn by sampling with replacement according to the current sample weights; the resulting classifier is then run on the training data to obtain its error rate and, from that, its classifier weight, and the sample weights are updated so that later rounds concentrate on the examples misclassified so far. At prediction time every classifier votes, each vote weighted by the classifier's own weight, and the class with the largest total weight is returned.
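The weight arithmetic of one round can be sketched as follows (an illustrative excerpt mirroring `src/auxiliary/AdaBoost.java`; `correct[i]` marks whether this round's classifier got sample i right):

```java
class BoostSketch {
    // One AdaBoost round: compute the classifier weight from its weighted
    // error and re-weight the samples (misclassified ones gain weight).
    static double updateWeights(double[] w, boolean[] correct, double error) {
        double alpha = 0.5 * Math.log((1 - error) / error); // classifier weight
        double total = 0;
        for (int i = 0; i < w.length; i++) {
            w[i] *= Math.exp(correct[i] ? -alpha : alpha);
            total += w[i];
        }
        for (int i = 0; i < w.length; i++) w[i] /= total; // renormalize
        return alpha;
    }
}
```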
--------------------------------------------------------------------------------
/src/auxiliary/Classifier.java:
--------------------------------------------------------------------------------
/*
 * To change this template, choose Tools | Templates
 * and open the template in the editor.
 */
package auxiliary;

import java.io.Serializable;

/**
 *
 * @author daq
 */
public abstract class Classifier implements Cloneable, Serializable {

    public abstract void train(boolean[] isCategory, double[][] features, double[] labels);

    public abstract double predict(double[] features);
}

--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
# Auto detect text files and perform LF normalization
* text=auto

# Custom for Visual Studio
*.cs diff=csharp
*.sln merge=union
*.csproj merge=union
*.vbproj merge=union
*.fsproj merge=union
*.dbproj merge=union

# Standard to msysgit
*.doc diff=astextplain
*.DOC diff=astextplain
*.docx diff=astextplain
*.DOCX diff=astextplain
*.dot diff=astextplain
*.DOT diff=astextplain
*.pdf diff=astextplain
*.PDF diff=astextplain
*.rtf diff=astextplain
*.RTF diff=astextplain

--------------------------------------------------------------------------------
/src/dm13/TestAss4.java:
--------------------------------------------------------------------------------
/*
 * To change this template, choose Tools | Templates
 * and open the template in the editor.
 */
package dm13;

import auxiliary.DataSet;
import auxiliary.Evaluation;

/**
 *
 * @author daq
 */
public class TestAss4 {

    public static void main(String[] args) {
        // for RandomForest
        System.out.println("for RandomForest");
        String[] dataPaths = new String[]{"breast-cancer.data", "segment.data"};
        for (String path : dataPaths) {
            DataSet dataset = new DataSet(path);

            // conduct 10-fold cross-validation
            Evaluation eva = new Evaluation(dataset, "RandomForest");
            eva.crossValidation();

            // print mean and standard deviation of accuracy
            System.out.println("Dataset:" + path + ", mean and standard deviation of accuracy:" + eva.getAccMean() + "," + eva.getAccStd());
        }

        // for AdaBoost
        System.out.println("\nfor AdaBoost");
        for (String path : dataPaths) {
            DataSet dataset = new DataSet(path);

            // conduct 10-fold cross-validation
            Evaluation eva = new Evaluation(dataset, "AdaBoost");
            eva.crossValidation();

            // print mean and standard deviation of accuracy
            System.out.println("Dataset:" + path + ", mean and standard deviation of accuracy:" + eva.getAccMean() + "," + eva.getAccStd());
        }
    }
}

--------------------------------------------------------------------------------
/src/auxiliary/DataSet.java:
--------------------------------------------------------------------------------
package auxiliary;

import java.io.BufferedReader;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 *
 * @author daq
 */
public class DataSet {

    private boolean[] isCategory;
    private double[][] features;
    private double[] labels;
    private int numAttributes;
    private int numInstances;

    public DataSet(String path) {
        try {
            BufferedReader reader = new BufferedReader(new FileReader(path));
            // header line: one flag per column, 1 = categorical, 0 = numerical;
            // the last flag describes the label column (classification vs. regression)
            String[] attInfo = reader.readLine().split(",");
            numAttributes = attInfo.length - 1;
            isCategory = new boolean[numAttributes + 1];
            for (int i = 0; i < isCategory.length; i++) {
                isCategory[i] = Integer.parseInt(attInfo[i]) == 1;
            }

            // first pass: count the instances
            numInstances = 0;
            while (reader.readLine() != null) {
                numInstances++;
            }
            reader.close();

            features = new double[numInstances][numAttributes];
            labels = new double[numInstances];
            System.out.println("reading " + numInstances + " examples with " + numAttributes + " attributes");

            // second pass: parse features and labels (the label is the last column)
            reader = new BufferedReader(new FileReader(path));
            reader.readLine();
            String line;
            int ind = 0;
            while ((line = reader.readLine()) != null) {
                String[] atts = line.split(",");
                for (int i = 0; i < atts.length - 1; i++) {
                    features[ind][i] = Double.parseDouble(atts[i]);
                }
                labels[ind] = Double.parseDouble(atts[atts.length - 1]);
                ind++;
            }
            reader.close();
        } catch (FileNotFoundException ex) {
            Logger.getLogger(DataSet.class.getName()).log(Level.SEVERE, null, ex);
        } catch (IOException ex) {
            Logger.getLogger(DataSet.class.getName()).log(Level.SEVERE, null, ex);
        }
    }

    public boolean[] getIsCategory() {
        return isCategory;
    }

    public double[][] getFeatures() {
        return features;
    }

    public double[] getLabels() {
        return labels;
    }

    public int getNumAttributes() {
        return numAttributes;
    }

    public int getNumInstances() {
        return numInstances;
    }
}

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | ################# 2 | ## Eclipse 3 | ################# 4 | 5 | *.pydevproject 6 | .project 7 | .metadata 8 | bin/ 9 | tmp/ 10 | *.tmp 11 | *.bak 12 | *.swp 13 | *~.nib 14 | local.properties 15 | .classpath 16 | .settings/ 17 | .loadpath 18 | 19 | # External tool builders 20 | .externalToolBuilders/ 21 | 22 | # Locally stored "Eclipse launch configurations" 23 | *.launch 24 | 25 | # CDT-specific 26 | .cproject 27 | 28 | # PDT-specific 29 | .buildpath 30 | 31 | 32 | ################# 33 | ## Visual Studio 34 | ################# 35 | 36 | ## Ignore Visual Studio temporary files, build results, and 37 | ## files generated by popular Visual Studio add-ons.
38 | 39 | # User-specific files 40 | *.suo 41 | *.user 42 | *.sln.docstates 43 | 44 | # Build results 45 | 46 | [Dd]ebug/ 47 | [Rr]elease/ 48 | x64/ 49 | build/ 50 | [Bb]in/ 51 | [Oo]bj/ 52 | 53 | # MSTest test Results 54 | [Tt]est[Rr]esult*/ 55 | [Bb]uild[Ll]og.* 56 | 57 | *_i.c 58 | *_p.c 59 | *.ilk 60 | *.meta 61 | *.obj 62 | *.pch 63 | *.pdb 64 | *.pgc 65 | *.pgd 66 | *.rsp 67 | *.sbr 68 | *.tlb 69 | *.tli 70 | *.tlh 71 | *.tmp 72 | *.tmp_proj 73 | *.log 74 | *.vspscc 75 | *.vssscc 76 | .builds 77 | *.pidb 78 | *.log 79 | *.scc 80 | 81 | # Visual C++ cache files 82 | ipch/ 83 | *.aps 84 | *.ncb 85 | *.opensdf 86 | *.sdf 87 | *.cachefile 88 | 89 | # Visual Studio profiler 90 | *.psess 91 | *.vsp 92 | *.vspx 93 | 94 | # Guidance Automation Toolkit 95 | *.gpState 96 | 97 | # ReSharper is a .NET coding add-in 98 | _ReSharper*/ 99 | *.[Rr]e[Ss]harper 100 | 101 | # TeamCity is a build add-in 102 | _TeamCity* 103 | 104 | # DotCover is a Code Coverage Tool 105 | *.dotCover 106 | 107 | # NCrunch 108 | *.ncrunch* 109 | .*crunch*.local.xml 110 | 111 | # Installshield output folder 112 | [Ee]xpress/ 113 | 114 | # DocProject is a documentation generator add-in 115 | DocProject/buildhelp/ 116 | DocProject/Help/*.HxT 117 | DocProject/Help/*.HxC 118 | DocProject/Help/*.hhc 119 | DocProject/Help/*.hhk 120 | DocProject/Help/*.hhp 121 | DocProject/Help/Html2 122 | DocProject/Help/html 123 | 124 | # Click-Once directory 125 | publish/ 126 | 127 | # Publish Web Output 128 | *.Publish.xml 129 | *.pubxml 130 | 131 | # NuGet Packages Directory 132 | ## TODO: If you have NuGet Package Restore enabled, uncomment the next line 133 | #packages/ 134 | 135 | # Windows Azure Build Output 136 | csx 137 | *.build.csdef 138 | 139 | # Windows Store app package directory 140 | AppPackages/ 141 | 142 | # Others 143 | sql/ 144 | *.Cache 145 | ClientBin/ 146 | [Ss]tyle[Cc]op.* 147 | ~$* 148 | *~ 149 | *.dbmdl 150 | *.[Pp]ublish.xml 151 | *.pfx 152 | *.publishsettings 153 | 154 | # RIA/Silverlight projects 155 | Generated_Code/ 156 | 157 | # Backup & report files from converting an old project file to a newer 158 | # Visual Studio version. 
Backup files are not needed, because we have git ;-) 159 | _UpgradeReport_Files/ 160 | Backup*/ 161 | UpgradeLog*.XML 162 | UpgradeLog*.htm 163 | 164 | # SQL Server files 165 | App_Data/*.mdf 166 | App_Data/*.ldf 167 | 168 | ############# 169 | ## Windows detritus 170 | ############# 171 | 172 | # Windows image file caches 173 | Thumbs.db 174 | ehthumbs.db 175 | 176 | # Folder config file 177 | Desktop.ini 178 | 179 | # Recycle Bin used on file shares 180 | $RECYCLE.BIN/ 181 | 182 | # Mac crap 183 | .DS_Store 184 | 185 | 186 | ############# 187 | ## Python 188 | ############# 189 | 190 | *.py[co] 191 | 192 | # Packages 193 | *.egg 194 | *.egg-info 195 | dist/ 196 | build/ 197 | eggs/ 198 | parts/ 199 | var/ 200 | sdist/ 201 | develop-eggs/ 202 | .installed.cfg 203 | 204 | # Installer logs 205 | pip-log.txt 206 | 207 | # Unit test / coverage reports 208 | .coverage 209 | .tox 210 | 211 | #Translations 212 | *.mo 213 | 214 | #Mr Developer 215 | .mr.developer.cfg 216 |
--------------------------------------------------------------------------------
/src/auxiliary/Evaluation.java:
--------------------------------------------------------------------------------
package auxiliary;

import java.util.Random;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 *
 * @author daq
 */
public class Evaluation {

    private String clsName;
    private DataSet dataset;
    private double accMean;
    private double accStd;
    private double rmseMean;
    private double rmseStd;

    public Evaluation() {
    }

    public Evaluation(DataSet dataset, String clsName) {
        this.dataset = dataset;
        this.clsName = clsName;
    }

    public void crossValidation() {
        int fold = 10;

        // fixed seed so the folds are reproducible; the permutation array
        // assumes at most 10,000 instances
        Random random = new Random(2013);
        int[] permutation = new int[10000];
        for (int i = 0; i < permutation.length; i++) {
            permutation[i] = i;
        }
        // shuffle by repeated random swaps
        for (int i = 0; i < 10 * permutation.length; i++) {
            int repInd = random.nextInt(permutation.length);
            int ind = i % permutation.length;

            int tmp = permutation[ind];
            permutation[ind] = permutation[repInd];
            permutation[repInd] = tmp;
        }

        // keep only the indices that exist in this dataset
        int[] perm = new int[dataset.getNumInstances()];
        int ind = 0;
        for (int i = 0; i < permutation.length; i++) {
            if (permutation[i] < dataset.getNumInstances()) {
                perm[ind++] = permutation[i];
            }
        }

        int share = dataset.getNumInstances() / fold;

        boolean[] isCategory = dataset.getIsCategory();
        double[][] features = dataset.getFeatures();
        double[] labels = dataset.getLabels();

        boolean isClassification = isCategory[isCategory.length - 1];

        double[] measures = new double[fold];
        for (int f = 0; f < fold; f++) {
            try {
                // the last fold absorbs the remainder
                int numTest = f < fold - 1 ? share : dataset.getNumInstances() - (fold - 1) * share;
                double[][] trainFeatures = new double[dataset.getNumInstances() - numTest][dataset.getNumAttributes()];
                double[] trainLabels = new double[dataset.getNumInstances() - numTest];
                double[][] testFeatures = new double[numTest][dataset.getNumAttributes()];
                double[] testLabels = new double[numTest];

                int indTrain = 0, indTest = 0;
                for (int j = 0; j < dataset.getNumInstances(); j++) {
                    if ((f < fold - 1 && (j < f * share || j >= (f + 1) * share)) || (f == fold - 1 && j < f * share)) {
                        System.arraycopy(features[perm[j]], 0, trainFeatures[indTrain], 0, dataset.getNumAttributes());
                        trainLabels[indTrain] = labels[perm[j]];
                        indTrain++;
                    } else {
                        System.arraycopy(features[perm[j]], 0, testFeatures[indTest], 0, dataset.getNumAttributes());
                        testLabels[indTest] = labels[perm[j]];
                        indTest++;
                    }
                }

                Classifier c = (Classifier) Class.forName("auxiliary." + clsName).newInstance();
                c.train(isCategory, trainFeatures, trainLabels);

                double error = 0;
                for (int j = 0; j < testLabels.length; j++) {
                    double prediction = c.predict(testFeatures[j]);

                    if (isClassification) {
                        if (prediction != testLabels[j]) {
                            error = error + 1;
                        }
                    } else {
                        error = error + (prediction - testLabels[j]) * (prediction - testLabels[j]);
                    }
                }
                if (isClassification) {
                    measures[f] = 1 - error / testLabels.length; // accuracy = 1 - error
                } else {
                    measures[f] = Math.sqrt(error / testLabels.length);
                }
            } catch (ClassNotFoundException | InstantiationException | IllegalAccessException ex) {
                Logger.getLogger(Evaluation.class.getName()).log(Level.SEVERE, null, ex);
            }
        }

        double[] mean_std = mean_std(measures);
        if (isClassification) {
            accMean = mean_std[0];
            accStd = mean_std[1];
        } else {
            rmseMean = mean_std[0];
            rmseStd = mean_std[1];
        }
    }

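    // mean_std returns {mean, standard deviation}. The standard deviation uses
    // Bessel's correction (divide by N - 1, the unbiased sample estimator):
    //   std = sqrt( sum_i (x_i - mean)^2 / (N - 1) )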
    public double[] mean_std(double[] x) {
        double[] ms = new double[2];
        int N = x.length;

        ms[0] = 0;
        for (int i = 0; i < x.length; i++) {
            ms[0] += x[i];
        }
        ms[0] /= N;

        ms[1] = 0;
        for (int i = 0; i < x.length; i++) {
            ms[1] += (x[i] - ms[0]) * (x[i] - ms[0]);
        }
        ms[1] = Math.sqrt(ms[1] / (N - 1));

        return ms;
    }

    public double getAccMean() {
        return accMean;
    }

    public double getAccStd() {
        return accStd;
    }

    public double getRmseMean() {
        return rmseMean;
    }

    public double getRmseStd() {
        return rmseStd;
    }
}

--------------------------------------------------------------------------------
/breast-cancer.data:
--------------------------------------------------------------------------------
1 | 1,1,1,1,1,1,1,1,1,1 2 | 3.0,2.0,3.0,0.0,0.0,2.0,1.0,0.0,1.0,1.0 3 | 4.0,1.0,3.0,0.0,1.0,0.0,1.0,4.0,1.0,0.0 4 | 4.0,1.0,7.0,0.0,1.0,1.0,0.0,1.0,1.0,1.0 5 | 3.0,2.0,7.0,0.0,0.0,2.0,1.0,1.0,0.0,0.0 6 | 3.0,2.0,6.0,1.0,0.0,1.0,0.0,2.0,1.0,1.0 7 | 4.0,2.0,5.0,1.0,1.0,1.0,1.0,0.0,0.0,0.0 8 | 4.0,1.0,8.0,0.0,1.0,2.0,0.0,0.0,1.0,0.0 9 | 3.0,2.0,2.0,0.0,1.0,1.0,0.0,0.0,1.0,0.0 10 | 3.0,2.0,0.0,0.0,1.0,1.0,1.0,3.0,1.0,0.0 11 | 3.0,1.0,8.0,5.0,0.0,1.0,1.0,0.0,0.0,0.0 12 | 4.0,2.0,5.0,0.0,1.0,1.0,0.0,1.0,1.0,0.0 13 | 5.0,1.0,3.0,0.0,1.0,1.0,1.0,0.0,1.0,0.0 14 | 4.0,1.0,6.0,0.0,1.0,0.0,1.0,4.0,1.0,0.0 15 | 4.0,1.0,5.0,0.0,1.0,1.0,1.0,0.0,1.0,0.0 16 |
3.0,2.0,5.0,0.0,1.0,1.0,0.0,1.0,0.0,1.0 17 | 2.0,2.0,4.0,0.0,1.0,2.0,0.0,4.0,1.0,0.0 18 | 4.0,2.0,2.0,1.0,1.0,0.0,1.0,0.0,1.0,0.0 19 | 5.0,1.0,3.0,0.0,1.0,1.0,1.0,0.0,1.0,0.0 20 | 4.0,2.0,8.0,0.0,1.0,1.0,0.0,0.0,1.0,0.0 21 | 4.0,1.0,4.0,0.0,1.0,2.0,0.0,0.0,1.0,0.0 22 | 4.0,0.0,4.0,0.0,NaN,0.0,0.0,1.0,1.0,1.0 23 | 5.0,1.0,8.0,1.0,1.0,1.0,1.0,0.0,0.0,0.0 24 | 4.0,1.0,3.0,0.0,1.0,1.0,1.0,1.0,1.0,0.0 25 | 3.0,2.0,2.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0 26 | 2.0,2.0,3.0,2.0,0.0,2.0,0.0,1.0,0.0,1.0 27 | 4.0,1.0,4.0,1.0,0.0,1.0,1.0,0.0,1.0,0.0 28 | 4.0,1.0,2.0,0.0,1.0,1.0,1.0,1.0,1.0,0.0 29 | 3.0,2.0,2.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0 30 | 5.0,1.0,6.0,1.0,0.0,2.0,0.0,1.0,1.0,0.0 31 | 3.0,2.0,3.0,5.0,0.0,2.0,0.0,1.0,1.0,1.0 32 | 5.0,1.0,6.0,0.0,1.0,2.0,1.0,4.0,1.0,1.0 33 | 5.0,1.0,5.0,1.0,NaN,0.0,1.0,1.0,0.0,0.0 34 | 4.0,1.0,5.0,0.0,1.0,2.0,0.0,2.0,1.0,0.0 35 | 4.0,1.0,4.0,0.0,1.0,2.0,1.0,0.0,1.0,0.0 36 | 3.0,2.0,6.0,0.0,1.0,0.0,0.0,1.0,0.0,1.0 37 | 2.0,2.0,3.0,0.0,1.0,0.0,0.0,1.0,1.0,0.0 38 | 3.0,2.0,2.0,0.0,1.0,1.0,1.0,0.0,1.0,0.0 39 | 5.0,1.0,9.0,2.0,0.0,2.0,0.0,4.0,1.0,0.0 40 | 3.0,1.0,4.0,0.0,1.0,2.0,0.0,1.0,1.0,0.0 41 | 3.0,2.0,2.0,0.0,1.0,0.0,1.0,3.0,1.0,0.0 42 | 2.0,2.0,7.0,0.0,1.0,2.0,0.0,1.0,1.0,1.0 43 | 3.0,2.0,7.0,3.0,0.0,1.0,1.0,2.0,0.0,0.0 44 | 5.0,1.0,5.0,0.0,1.0,1.0,1.0,1.0,1.0,0.0 45 | 4.0,1.0,4.0,1.0,0.0,2.0,1.0,2.0,1.0,1.0 46 | 2.0,2.0,3.0,0.0,1.0,0.0,0.0,1.0,1.0,0.0 47 | 4.0,2.0,6.0,0.0,1.0,2.0,0.0,2.0,1.0,1.0 48 | 5.0,1.0,2.0,0.0,1.0,1.0,1.0,0.0,0.0,0.0 49 | 3.0,2.0,7.0,0.0,0.0,2.0,1.0,0.0,0.0,0.0 50 | 4.0,2.0,10.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0 51 | 4.0,1.0,8.0,0.0,1.0,2.0,1.0,0.0,1.0,0.0 52 | 6.0,1.0,3.0,3.0,NaN,0.0,0.0,1.0,0.0,1.0 53 | 4.0,0.0,6.0,0.0,1.0,2.0,1.0,0.0,1.0,0.0 54 | 3.0,2.0,0.0,0.0,1.0,2.0,0.0,4.0,1.0,0.0 55 | 6.0,1.0,8.0,0.0,1.0,0.0,1.0,2.0,1.0,0.0 56 | 3.0,2.0,5.0,0.0,NaN,1.0,0.0,3.0,0.0,0.0 57 | 4.0,1.0,5.0,5.0,0.0,2.0,1.0,0.0,1.0,0.0 58 | 4.0,2.0,4.0,0.0,1.0,0.0,0.0,1.0,1.0,0.0 59 | 4.0,1.0,7.0,5.0,1.0,2.0,0.0,1.0,1.0,0.0 60 | 4.0,1.0,10.0,0.0,1.0,0.0,1.0,2.0,1.0,0.0 61 | 2.0,2.0,0.0,0.0,1.0,1.0,1.0,4.0,1.0,1.0 62 | 4.0,1.0,8.0,2.0,0.0,2.0,0.0,1.0,0.0,1.0 63 | 3.0,2.0,6.0,0.0,1.0,1.0,1.0,2.0,0.0,0.0 64 | 3.0,1.0,4.0,0.0,1.0,2.0,0.0,0.0,1.0,0.0 65 | 3.0,2.0,6.0,5.0,0.0,2.0,0.0,1.0,1.0,1.0 66 | 3.0,1.0,4.0,0.0,1.0,1.0,1.0,0.0,1.0,1.0 67 | 4.0,1.0,3.0,0.0,1.0,0.0,1.0,4.0,1.0,0.0 68 | 2.0,2.0,5.0,0.0,1.0,1.0,1.0,1.0,1.0,0.0 69 | 5.0,1.0,3.0,0.0,1.0,1.0,0.0,1.0,1.0,0.0 70 | 4.0,2.0,10.0,3.0,0.0,1.0,1.0,0.0,1.0,1.0 71 | 2.0,2.0,2.0,0.0,1.0,0.0,1.0,1.0,1.0,0.0 72 | 4.0,2.0,5.0,1.0,0.0,2.0,0.0,1.0,0.0,1.0 73 | 5.0,1.0,5.0,1.0,NaN,0.0,1.0,0.0,0.0,0.0 74 | 5.0,1.0,2.0,0.0,1.0,0.0,1.0,1.0,1.0,0.0 75 | 4.0,1.0,6.0,2.0,0.0,2.0,0.0,3.0,1.0,1.0 76 | 2.0,2.0,5.0,2.0,0.0,2.0,0.0,3.0,0.0,1.0 77 | 4.0,1.0,2.0,0.0,1.0,0.0,0.0,1.0,1.0,0.0 78 | 4.0,2.0,3.0,0.0,1.0,0.0,0.0,1.0,1.0,0.0 79 | 3.0,2.0,5.0,0.0,1.0,1.0,1.0,4.0,1.0,0.0 80 | 3.0,2.0,5.0,0.0,1.0,2.0,0.0,2.0,1.0,1.0 81 | 5.0,1.0,6.0,2.0,0.0,1.0,1.0,2.0,1.0,0.0 82 | 4.0,0.0,3.0,0.0,1.0,1.0,0.0,1.0,1.0,0.0 83 | 3.0,2.0,5.0,0.0,1.0,1.0,1.0,1.0,1.0,0.0 84 | 3.0,2.0,6.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0 85 | 5.0,1.0,3.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0 86 | 2.0,2.0,0.0,0.0,1.0,1.0,1.0,4.0,1.0,0.0 87 | 4.0,1.0,7.0,0.0,1.0,2.0,0.0,0.0,1.0,0.0 88 | 3.0,2.0,8.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0 89 | 2.0,2.0,5.0,2.0,0.0,1.0,1.0,0.0,0.0,0.0 90 | 4.0,1.0,4.0,0.0,1.0,0.0,1.0,1.0,1.0,0.0 91 | 4.0,1.0,6.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0 92 | 5.0,1.0,4.0,0.0,1.0,0.0,1.0,0.0,1.0,1.0 93 | 2.0,2.0,6.0,1.0,1.0,2.0,1.0,0.0,0.0,1.0 94 | 4.0,0.0,4.0,0.0,NaN,0.0,0.0,0.0,1.0,1.0 
95 | 4.0,2.0,2.0,0.0,1.0,1.0,1.0,0.0,1.0,0.0 96 | 4.0,1.0,4.0,0.0,1.0,1.0,1.0,0.0,1.0,0.0 97 | 3.0,2.0,9.0,0.0,1.0,1.0,0.0,1.0,0.0,0.0 98 | 2.0,2.0,8.0,0.0,1.0,0.0,0.0,0.0,1.0,1.0 99 | 4.0,2.0,2.0,0.0,1.0,0.0,0.0,1.0,1.0,0.0 100 | 5.0,1.0,6.0,0.0,1.0,2.0,1.0,0.0,0.0,1.0 101 | 3.0,2.0,7.0,0.0,1.0,0.0,1.0,0.0,1.0,1.0 102 | 3.0,2.0,4.0,1.0,0.0,1.0,0.0,1.0,0.0,1.0 103 | 4.0,2.0,3.0,0.0,1.0,1.0,0.0,1.0,1.0,1.0 104 | 4.0,1.0,6.0,0.0,1.0,2.0,1.0,1.0,1.0,0.0 105 | 5.0,1.0,4.0,0.0,1.0,1.0,0.0,0.0,1.0,0.0 106 | 3.0,2.0,4.0,0.0,1.0,0.0,0.0,3.0,1.0,0.0 107 | 5.0,1.0,6.0,1.0,0.0,1.0,0.0,4.0,0.0,1.0 108 | 5.0,1.0,4.0,1.0,1.0,1.0,0.0,1.0,0.0,1.0 109 | 4.0,2.0,5.0,0.0,1.0,1.0,0.0,2.0,1.0,1.0 110 | 4.0,1.0,6.0,0.0,1.0,0.0,1.0,2.0,1.0,0.0 111 | 3.0,2.0,4.0,0.0,1.0,1.0,0.0,3.0,1.0,0.0 112 | 5.0,1.0,3.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0 113 | 5.0,1.0,6.0,0.0,1.0,1.0,0.0,1.0,0.0,0.0 114 | 2.0,2.0,6.0,0.0,1.0,1.0,0.0,0.0,1.0,0.0 115 | 2.0,2.0,8.0,1.0,1.0,2.0,1.0,2.0,0.0,0.0 116 | 5.0,1.0,1.0,0.0,1.0,0.0,0.0,4.0,1.0,0.0 117 | 5.0,1.0,2.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0 118 | 3.0,2.0,6.0,2.0,0.0,2.0,1.0,0.0,1.0,1.0 119 | 5.0,1.0,2.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0 120 | 3.0,2.0,7.0,3.0,0.0,1.0,1.0,0.0,0.0,0.0 121 | 3.0,2.0,4.0,0.0,1.0,0.0,1.0,1.0,1.0,0.0 122 | 3.0,2.0,6.0,0.0,0.0,2.0,1.0,2.0,1.0,1.0 123 | 4.0,2.0,5.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0 124 | 3.0,2.0,3.0,0.0,1.0,1.0,0.0,1.0,1.0,0.0 125 | 2.0,2.0,7.0,3.0,0.0,2.0,0.0,1.0,1.0,1.0 126 | 2.0,2.0,2.0,0.0,1.0,1.0,0.0,3.0,1.0,0.0 127 | 4.0,1.0,6.0,0.0,1.0,0.0,1.0,1.0,1.0,0.0 128 | 5.0,1.0,6.0,0.0,1.0,1.0,0.0,0.0,1.0,0.0 129 | 5.0,1.0,5.0,0.0,1.0,1.0,0.0,1.0,1.0,0.0 130 | 3.0,2.0,3.0,0.0,1.0,1.0,0.0,0.0,1.0,1.0 131 | 5.0,1.0,3.0,0.0,1.0,1.0,1.0,1.0,1.0,0.0 132 | 3.0,2.0,6.0,0.0,1.0,1.0,0.0,3.0,1.0,0.0 133 | 1.0,2.0,7.0,0.0,1.0,1.0,1.0,2.0,1.0,0.0 134 | 3.0,2.0,6.0,0.0,1.0,2.0,1.0,2.0,1.0,1.0 135 | 3.0,2.0,5.0,0.0,1.0,1.0,1.0,1.0,1.0,1.0 136 | 2.0,2.0,6.0,0.0,1.0,2.0,0.0,1.0,1.0,0.0 137 | 2.0,2.0,3.0,0.0,1.0,0.0,1.0,1.0,1.0,1.0 138 | 4.0,1.0,0.0,0.0,1.0,0.0,1.0,4.0,1.0,0.0 139 | 4.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,1.0,0.0 140 | 5.0,1.0,10.0,0.0,1.0,2.0,1.0,0.0,1.0,1.0 141 | 4.0,2.0,6.0,0.0,1.0,0.0,0.0,4.0,1.0,0.0 142 | 5.0,1.0,4.0,8.0,0.0,2.0,0.0,1.0,0.0,1.0 143 | 3.0,2.0,5.0,0.0,1.0,1.0,0.0,0.0,1.0,0.0 144 | 3.0,2.0,6.0,1.0,1.0,1.0,1.0,0.0,1.0,1.0 145 | 4.0,2.0,4.0,1.0,0.0,1.0,0.0,1.0,1.0,0.0 146 | 4.0,1.0,3.0,0.0,0.0,1.0,0.0,4.0,0.0,0.0 147 | 4.0,2.0,2.0,0.0,1.0,2.0,0.0,1.0,1.0,0.0 148 | 2.0,2.0,6.0,3.0,1.0,1.0,1.0,0.0,0.0,1.0 149 | 5.0,1.0,2.0,0.0,1.0,0.0,0.0,1.0,1.0,0.0 150 | 3.0,2.0,8.0,0.0,1.0,1.0,1.0,1.0,1.0,0.0 151 | 4.0,1.0,6.0,3.0,NaN,2.0,0.0,0.0,0.0,0.0 152 | 3.0,2.0,10.0,0.0,1.0,1.0,1.0,1.0,0.0,1.0 153 | 4.0,1.0,3.0,0.0,1.0,1.0,1.0,2.0,1.0,0.0 154 | 4.0,1.0,8.0,1.0,0.0,1.0,0.0,1.0,1.0,0.0 155 | 2.0,2.0,5.0,1.0,0.0,2.0,0.0,1.0,0.0,1.0 156 | 5.0,1.0,2.0,0.0,1.0,1.0,0.0,1.0,1.0,0.0 157 | 5.0,0.0,2.0,0.0,1.0,0.0,0.0,2.0,1.0,0.0 158 | 2.0,2.0,6.0,0.0,1.0,1.0,0.0,0.0,1.0,1.0 159 | 2.0,2.0,4.0,1.0,0.0,1.0,0.0,1.0,1.0,1.0 160 | 4.0,1.0,2.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0 161 | 5.0,1.0,5.0,0.0,1.0,2.0,1.0,0.0,1.0,0.0 162 | 4.0,1.0,5.0,1.0,0.0,2.0,1.0,0.0,1.0,0.0 163 | 3.0,2.0,6.0,2.0,1.0,1.0,0.0,0.0,1.0,0.0 164 | 5.0,1.0,10.0,0.0,1.0,1.0,0.0,1.0,1.0,0.0 165 | 4.0,2.0,6.0,0.0,1.0,2.0,0.0,1.0,1.0,0.0 166 | 3.0,1.0,4.0,1.0,1.0,2.0,1.0,1.0,0.0,1.0 167 | 4.0,1.0,6.0,2.0,0.0,1.0,0.0,3.0,0.0,1.0 168 | 5.0,1.0,5.0,1.0,1.0,1.0,1.0,2.0,1.0,1.0 169 | 3.0,2.0,4.0,0.0,1.0,1.0,0.0,4.0,1.0,0.0 170 | 3.0,2.0,4.0,0.0,1.0,1.0,0.0,0.0,1.0,0.0 171 | 3.0,2.0,10.0,0.0,1.0,1.0,0.0,1.0,1.0,0.0 172 | 
4.0,1.0,4.0,0.0,1.0,1.0,1.0,4.0,1.0,1.0 173 | 4.0,1.0,6.0,1.0,1.0,2.0,1.0,0.0,1.0,1.0 174 | 3.0,1.0,5.0,0.0,1.0,1.0,0.0,1.0,1.0,0.0 175 | 4.0,2.0,5.0,0.0,1.0,0.0,1.0,0.0,1.0,1.0 176 | 3.0,2.0,8.0,1.0,0.0,2.0,1.0,0.0,0.0,0.0 177 | 3.0,2.0,4.0,0.0,1.0,1.0,1.0,0.0,1.0,0.0 178 | 3.0,2.0,4.0,1.0,1.0,1.0,1.0,0.0,1.0,0.0 179 | 3.0,2.0,5.0,3.0,0.0,2.0,1.0,0.0,1.0,1.0 180 | 3.0,2.0,5.0,0.0,1.0,1.0,1.0,1.0,1.0,1.0 181 | 3.0,2.0,4.0,0.0,1.0,0.0,1.0,2.0,1.0,0.0 182 | 2.0,2.0,8.0,0.0,1.0,1.0,1.0,2.0,1.0,0.0 183 | 5.0,1.0,2.0,2.0,0.0,2.0,0.0,0.0,0.0,1.0 184 | 3.0,2.0,7.0,0.0,1.0,0.0,0.0,1.0,1.0,0.0 185 | 4.0,1.0,6.0,1.0,1.0,2.0,0.0,1.0,1.0,1.0 186 | 3.0,2.0,1.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0 187 | 5.0,1.0,3.0,0.0,1.0,0.0,0.0,3.0,1.0,0.0 188 | 3.0,2.0,6.0,0.0,1.0,2.0,1.0,2.0,1.0,0.0 189 | 3.0,2.0,5.0,0.0,1.0,2.0,0.0,0.0,1.0,1.0 190 | 4.0,1.0,1.0,0.0,1.0,1.0,1.0,2.0,1.0,0.0 191 | 4.0,2.0,5.0,0.0,1.0,1.0,1.0,3.0,1.0,0.0 192 | 4.0,2.0,5.0,0.0,1.0,1.0,0.0,2.0,1.0,1.0 193 | 3.0,2.0,2.0,0.0,1.0,1.0,0.0,1.0,0.0,0.0 194 | 5.0,1.0,7.0,2.0,0.0,2.0,0.0,1.0,1.0,1.0 195 | 5.0,1.0,10.0,0.0,1.0,1.0,1.0,0.0,0.0,0.0 196 | 3.0,2.0,5.0,0.0,1.0,1.0,1.0,0.0,1.0,0.0 197 | 2.0,2.0,4.0,1.0,1.0,1.0,1.0,4.0,1.0,0.0 198 | 2.0,2.0,6.0,0.0,1.0,0.0,1.0,0.0,1.0,1.0 199 | 5.0,0.0,6.0,0.0,1.0,0.0,0.0,1.0,1.0,0.0 200 | 3.0,2.0,3.0,4.0,1.0,2.0,1.0,3.0,0.0,0.0 201 | 5.0,1.0,4.0,0.0,1.0,2.0,1.0,1.0,1.0,1.0 202 | 2.0,2.0,1.0,0.0,1.0,1.0,0.0,3.0,1.0,0.0 203 | 3.0,2.0,6.0,0.0,1.0,2.0,0.0,0.0,1.0,0.0 204 | 5.0,1.0,6.0,0.0,1.0,2.0,0.0,1.0,1.0,0.0 205 | 3.0,2.0,5.0,0.0,1.0,0.0,1.0,3.0,1.0,0.0 206 | 3.0,2.0,5.0,0.0,1.0,0.0,0.0,3.0,1.0,0.0 207 | 5.0,1.0,8.0,1.0,0.0,2.0,1.0,1.0,1.0,1.0 208 | 4.0,1.0,5.0,0.0,1.0,1.0,0.0,1.0,1.0,0.0 209 | 4.0,2.0,6.0,0.0,1.0,2.0,1.0,0.0,0.0,1.0 210 | 3.0,1.0,6.0,1.0,1.0,2.0,0.0,1.0,1.0,1.0 211 | 3.0,2.0,5.0,0.0,1.0,0.0,1.0,1.0,0.0,0.0 212 | 3.0,1.0,5.0,4.0,0.0,2.0,0.0,3.0,0.0,1.0 213 | 3.0,2.0,8.0,0.0,1.0,0.0,0.0,1.0,1.0,1.0 214 | 3.0,2.0,4.0,0.0,1.0,1.0,0.0,1.0,1.0,0.0 215 | 4.0,1.0,5.0,0.0,1.0,0.0,0.0,3.0,1.0,0.0 216 | 3.0,2.0,4.0,0.0,1.0,1.0,1.0,0.0,1.0,0.0 217 | 6.0,1.0,8.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0 218 | 5.0,1.0,5.0,0.0,1.0,2.0,0.0,0.0,1.0,1.0 219 | 4.0,2.0,5.0,0.0,1.0,1.0,0.0,1.0,1.0,0.0 220 | 5.0,1.0,9.0,0.0,1.0,0.0,1.0,2.0,0.0,1.0 221 | 4.0,1.0,4.0,0.0,0.0,1.0,1.0,0.0,1.0,0.0 222 | 4.0,1.0,5.0,0.0,1.0,0.0,0.0,1.0,1.0,0.0 223 | 4.0,1.0,4.0,0.0,1.0,2.0,0.0,0.0,1.0,0.0 224 | 3.0,2.0,4.0,1.0,1.0,1.0,1.0,1.0,1.0,0.0 225 | 4.0,1.0,7.0,0.0,1.0,1.0,0.0,0.0,1.0,0.0 226 | 2.0,2.0,4.0,0.0,1.0,2.0,0.0,0.0,0.0,1.0 227 | 5.0,1.0,6.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0 228 | 5.0,1.0,5.0,0.0,1.0,2.0,1.0,1.0,1.0,0.0 229 | 3.0,1.0,6.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0 230 | 2.0,2.0,5.0,0.0,1.0,1.0,0.0,1.0,1.0,0.0 231 | 3.0,2.0,4.0,0.0,1.0,1.0,0.0,1.0,1.0,1.0 232 | 2.0,2.0,4.0,0.0,1.0,1.0,0.0,3.0,1.0,0.0 233 | 3.0,2.0,2.0,0.0,1.0,1.0,1.0,1.0,1.0,0.0 234 | 4.0,2.0,3.0,0.0,1.0,1.0,1.0,3.0,1.0,0.0 235 | 4.0,2.0,5.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0 236 | 5.0,1.0,4.0,0.0,1.0,1.0,1.0,0.0,1.0,0.0 237 | 5.0,1.0,8.0,0.0,1.0,1.0,1.0,1.0,1.0,1.0 238 | 2.0,0.0,3.0,0.0,1.0,2.0,1.0,0.0,1.0,0.0 239 | 3.0,2.0,6.0,4.0,0.0,2.0,0.0,0.0,0.0,1.0 240 | 5.0,1.0,6.0,0.0,0.0,1.0,1.0,2.0,0.0,1.0 241 | 4.0,1.0,8.0,2.0,0.0,2.0,0.0,1.0,0.0,1.0 242 | 4.0,1.0,6.0,0.0,1.0,2.0,0.0,NaN,1.0,1.0 243 | 6.0,1.0,2.0,0.0,1.0,1.0,0.0,4.0,1.0,0.0 244 | 2.0,2.0,8.0,0.0,1.0,1.0,0.0,1.0,0.0,0.0 245 | 3.0,2.0,6.0,0.0,1.0,1.0,1.0,3.0,1.0,0.0 246 | 3.0,2.0,6.0,0.0,1.0,0.0,0.0,1.0,1.0,0.0 247 | 5.0,1.0,3.0,0.0,1.0,1.0,0.0,1.0,1.0,0.0 248 | 3.0,2.0,2.0,0.0,1.0,1.0,0.0,1.0,1.0,0.0 249 | 
5.0,1.0,4.0,0.0,1.0,0.0,0.0,1.0,1.0,0.0 250 | 4.0,1.0,2.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0 251 | 4.0,2.0,5.0,0.0,1.0,0.0,0.0,1.0,1.0,0.0 252 | 4.0,1.0,6.0,3.0,0.0,2.0,0.0,3.0,0.0,1.0 253 | 4.0,1.0,2.0,0.0,1.0,1.0,0.0,1.0,1.0,0.0 254 | 3.0,2.0,6.0,0.0,1.0,0.0,0.0,2.0,1.0,0.0 255 | 6.0,1.0,0.0,0.0,1.0,0.0,0.0,3.0,1.0,0.0 256 | 3.0,2.0,5.0,0.0,1.0,2.0,1.0,0.0,0.0,0.0 257 | 4.0,2.0,5.0,0.0,1.0,2.0,1.0,1.0,0.0,1.0 258 | 4.0,1.0,8.0,0.0,1.0,1.0,0.0,1.0,1.0,0.0 259 | 5.0,1.0,5.0,0.0,1.0,2.0,0.0,3.0,0.0,1.0 260 | 3.0,2.0,6.0,1.0,0.0,1.0,1.0,1.0,1.0,0.0 261 | 4.0,1.0,4.0,0.0,1.0,1.0,0.0,0.0,1.0,1.0 262 | 6.0,1.0,4.0,0.0,1.0,2.0,0.0,0.0,1.0,0.0 263 | 2.0,2.0,5.0,0.0,1.0,0.0,0.0,4.0,1.0,0.0 264 | 5.0,1.0,6.0,0.0,1.0,1.0,0.0,1.0,1.0,0.0 265 | 3.0,2.0,4.0,1.0,0.0,1.0,1.0,2.0,0.0,1.0 266 | 4.0,1.0,6.0,3.0,NaN,2.0,0.0,1.0,0.0,0.0 267 | 4.0,1.0,0.0,0.0,1.0,1.0,0.0,4.0,1.0,0.0 268 | 3.0,2.0,4.0,0.0,1.0,2.0,1.0,1.0,0.0,0.0 269 | 2.0,2.0,7.0,0.0,1.0,2.0,0.0,1.0,1.0,1.0 270 | 5.0,1.0,6.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0 271 | 5.0,1.0,4.0,0.0,1.0,0.0,0.0,1.0,1.0,0.0 272 | 4.0,1.0,5.0,2.0,1.0,2.0,0.0,1.0,0.0,1.0 273 | 4.0,2.0,7.0,5.0,0.0,2.0,1.0,2.0,1.0,1.0 274 | 2.0,2.0,4.0,1.0,0.0,1.0,1.0,0.0,0.0,0.0 275 | 3.0,2.0,4.0,2.0,1.0,1.0,1.0,1.0,0.0,0.0 276 | 4.0,1.0,7.0,0.0,1.0,2.0,0.0,1.0,1.0,0.0 277 | 4.0,2.0,7.0,0.0,1.0,1.0,1.0,0.0,1.0,0.0 278 | 3.0,2.0,5.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0 279 | 3.0,2.0,7.0,0.0,1.0,1.0,1.0,2.0,1.0,0.0 280 | 4.0,2.0,6.0,1.0,0.0,1.0,0.0,1.0,0.0,0.0 281 | 3.0,2.0,4.0,0.0,1.0,1.0,1.0,2.0,1.0,0.0 282 | 5.0,1.0,3.0,0.0,1.0,2.0,1.0,0.0,0.0,0.0 283 | 4.0,1.0,6.0,2.0,0.0,1.0,0.0,1.0,1.0,0.0 284 | 4.0,2.0,5.0,1.0,0.0,1.0,0.0,1.0,0.0,0.0 285 | 2.0,2.0,6.0,2.0,0.0,1.0,1.0,2.0,1.0,0.0 286 | 4.0,2.0,3.0,0.0,1.0,1.0,1.0,1.0,1.0,0.0 287 | 4.0,1.0,8.0,0.0,1.0,2.0,0.0,2.0,1.0,0.0 288 |
--------------------------------------------------------------------------------
/src/auxiliary/RandomForest.java:
--------------------------------------------------------------------------------
package auxiliary;

import java.util.*;

/**
 *
 * @author 孔繁宇 MF1333020
 */

public class RandomForest extends Classifier {
    // Holds a sample set drawn with replacement (a bootstrap sample).
    class RepickSamples {
        double[][] features;
        double[] labels;
        int[] index;
    }

    private static int numTrees = 9; // number of trees in the forest
    private RandomDecisionTree[] forest;

    public RandomForest() {
    }

    @Override
    public void train(boolean[] isCategory, double[][] features, double[] labels) {
        forest = new RandomDecisionTree[numTrees];
        for (int i = 0; i < numTrees; ++i) {
            RepickSamples samples = repickSamples(features, labels);
            forest[i] = new RandomDecisionTree();
            forest[i].train(isCategory, samples.features, samples.labels);
        }
    }

    // Draw a new sample set by uniform sampling with replacement (bagging).
    private RepickSamples repickSamples(double[][] features, double[] labels) {
        RepickSamples samples = new RepickSamples();
        int size = labels.length;
        Random random = new Random();

        samples.features = new double[size][];
        samples.labels = new double[size];
        samples.index = new int[size];
        for (int i = 0; i < size; ++i) {
            int index = random.nextInt(size);
            samples.features[i] = features[index].clone();
            samples.labels[i] = labels[index];
            samples.index[i] = index;
        }

        return samples;
    }

    @Override
    public double predict(double[] features) {
        // majority vote over all trees
        HashMap<Double, Integer> counter = new HashMap<>();
        for (int i = 0; i < forest.length; ++i) {
            double label = forest[i].predict(features);
            if (counter.get(label) == null) {
                counter.put(label, 1);
            } else {
                int count = counter.get(label) + 1;
                counter.put(label, count);
            }
        }

        int temp_max = 0;
        double label = 0;
        Iterator<Double> iterator = counter.keySet().iterator();
        while (iterator.hasNext()) {
            double key = iterator.next();
            int count = counter.get(key);
            if (count > temp_max) {
                temp_max = count;
                label = key;
            }
        }

        return label;
    }
}

// <<<<-------------------- divider: the random decision tree implementation follows -------------------->>>>

class RandomDecisionTree extends Classifier {
    // Decision tree node.
    class TreeNode {
        int[] set; // indices of the samples at this node
        int[] attr_index; // indices of the attributes still available
        double label; // label assigned to this node
        int split_attr; // index of the attribute this node splits on
        double[] split_points; // split points: one per value for a categorical attribute, a single threshold for a numerical one
        TreeNode[] childrenNodes; // child nodes
    }

    // Holds the result of a split.
    class SplitData {
        int split_attr;
        double[] split_points;
        int[][] split_sets; // the sample subsets produced by the split
    }

    class BundleData {
        double floatValue; // holds the gain ratio or the MSE
        SplitData split_info;
    }

    // Thrown when an attribute cannot produce a valid split.
    class SplitException extends Exception {
    }

    private boolean _isClassification;
    private double[][] _features;
    private boolean[] _isCategory;
    private double[] _labels;
    private double[] _defaults;

    private TreeNode root;

    public RandomDecisionTree() {
    }

    @Override
    public void train(boolean[] isCategory, double[][] features, double[] labels) {
        _isClassification = isCategory[isCategory.length - 1];
        _features = features;
        _isCategory = isCategory;
        _labels = labels;

        int[] set = new int[_features.length];
        for (int i = 0; i < set.length; ++i) {
            set[i] = i;
        }

        int[] attr_index = new int[_features[0].length];
        for (int i = 0; i < attr_index.length; ++i) {
            attr_index[i] = i;
        }

        // impute missing attribute values
        _defaults = kill_missing_data();

        root = build_decision_tree(set, attr_index);
    }

    private double[] kill_missing_data() {
        int num = _isCategory.length - 1;
        double[] defaults = new double[num];

        for (int i = 0; i < defaults.length; ++i) {
            if (_isCategory[i]) {
                // categorical attribute: use the most frequent value
                HashMap<Double, Integer> counter = new HashMap<>();
                for (int j = 0; j < _features.length; ++j) {
                    double feature = _features[j][i];
                    if (!Double.isNaN(feature)) {
                        if (counter.get(feature) == null) {
                            counter.put(feature, 1);
                        } else {
                            int count = counter.get(feature) + 1;
                            counter.put(feature, count);
                        }
                    }
                }

                int max_time = 0;
                double value = 0;
                Iterator<Double> iterator = counter.keySet().iterator();
                while (iterator.hasNext()) {
                    double key = iterator.next();
                    int count = counter.get(key);
                    if (count > max_time) {
                        max_time = count;
                        value = key;
                    }
                }
                defaults[i] = value;
            } else {
                // numerical attribute: use the mean
                int count = 0;
                double total = 0;
                for (int j = 0; j < _features.length; ++j) {
                    if (!Double.isNaN(_features[j][i])) {
                        count++;
                        total += _features[j][i];
                    }
                }
                defaults[i] = total / count;
            }
        }

        // substitute the defaults for the NaNs
        for (int i = 0; i < _features.length; ++i) {
            for (int j = 0; j < defaults.length; ++j) {
                if (Double.isNaN(_features[i][j])) {
                    _features[i][j] = defaults[j];
                }
            }
        }
        return defaults;
    }

    @Override
    public double predict(double[] features) {
        // impute missing attribute values
        for (int i = 0; i < features.length; ++i) {
            if (Double.isNaN(features[i])) {
                features[i] = _defaults[i];
            }
        }

        return predict_with_decision_tree(features, root);
    }

    private double predict_with_decision_tree(double[] features, TreeNode node) {
        if (node.childrenNodes == null) {
            return node.label;
        }

        double feature = features[node.split_attr];

        if (_isCategory[node.split_attr]) {
            // categorical attribute
            for (int i = 0; i < node.split_points.length; ++i) {
                if (node.split_points[i] == feature) {
                    return predict_with_decision_tree(features, node.childrenNodes[i]);
                }
            }

            return node.label; // unseen value: fall back to this node's label, which keeps the tree small
        } else {
            // numerical attribute
            if (feature < node.split_points[0]) {
                return predict_with_decision_tree(features, node.childrenNodes[0]);
            } else {
                return predict_with_decision_tree(features, node.childrenNodes[1]);
            }
        }

    }

    private TreeNode build_decision_tree(int[] set, int[] attr_index) {
        TreeNode node = new TreeNode();
        node.set = set;
        node.attr_index = attr_index;
        node.label = 0;
        node.childrenNodes = null;

        // if all samples share one label, return a leaf
        double label = _labels[node.set[0]];
        boolean flag = true;
        for (int i = 0; i < node.set.length; ++i) {
            if (_labels[node.set[i]] != label) {
                flag = false;
                break;
            }
        }
        if (flag) {
            node.label = label;
            return node;
        }

        // default label: majority class (classification) or mean value (regression)
        if (_isClassification) {
            node.label = most_label(set);
        } else {
            node.label = mean_value(set);
        }
        if (node.attr_index == null || node.attr_index.length == 0) {
            return node;
        }

        // find the best attribute to split on
        SplitData split_info = attribute_selection(node);
        node.split_attr = split_info.split_attr;
        // no attribute can split this node
        if (node.split_attr < 0) {
            return node;
        }

        node.split_points = split_info.split_points;

        // remove a used categorical attribute; numerical attributes stay available
        int[] child_attr_index = null;
        if (_isCategory[node.split_attr]) {
            child_attr_index = new int[attr_index.length - 1];
            int t = 0;
            for (int index : attr_index) {
                if (index != node.split_attr) {
                    child_attr_index[t++] = index;
                }
            }
        } else {
            child_attr_index = node.attr_index.clone();
        }

        // recursively build the children
        node.childrenNodes = new TreeNode[split_info.split_sets.length];
        for (int i = 0; i < split_info.split_sets.length; ++i) {
            node.childrenNodes[i] = build_decision_tree(split_info.split_sets[i], child_attr_index);
        }

        return node;
    }

    // most frequent label in the given sample set
    private double most_label(int[] set) {
        HashMap<Double, Integer> counter = new HashMap<>();
        for (int item : set) {
            double label = _labels[item];
            if (counter.get(label) == null) {
                counter.put(label, 1);
            } else {
                int count = counter.get(label) + 1;
                counter.put(label, count);
            }
        }

        int max_time = 0;
        double label = 0;
        Iterator<Double> iterator = counter.keySet().iterator();
        while (iterator.hasNext()) {
            double key = iterator.next();
            int count = counter.get(key);
            if (count > max_time) {
                max_time = count;
                label = key;
            }
        }
        return label;
    }

    // mean label of the given sample set
    private double mean_value(int[] set) {
        double temp = 0;
        for (int index : set) {
            temp += _labels[index];
        }
        return temp / set.length;
    }

    private SplitData attribute_selection(TreeNode node) {
        SplitData result = new SplitData();
        result.split_attr = -1;

        // pre-pruning: skip nodes that are too small or gain too little
        double reference_value = _isClassification ? 0.05 : -1;
        if (node.set.length < 7) return result;

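        // Random-subspace step, the difference from a plain C4.5 tree: with K
        // attributes still available, only floor(log2(1 + K)) randomly chosen
        // candidates are evaluated below (e.g. K = 9 gives 3 candidates).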
        // draw the candidate attributes at random, without repetition
        int n = (int) (Math.log(1 + node.attr_index.length) / Math.log(2));
        int[] attrs = new int[n];
        Random random = new Random();
        HashSet<Integer> hash = new HashSet<>();
        for (int i = 0; i < n; ++i) {
            int index = 0;
            do {
                index = random.nextInt(node.attr_index.length);
            } while (hash.contains(index));
            hash.add(index);
            attrs[i] = node.attr_index[index];
        }
        if (_isClassification) {
            for (int attribute : attrs) {
                try {
                    BundleData gain_ratio_info = gain_ratio_use_attribute(node.set, attribute); // a failed split throws SplitException
                    if (gain_ratio_info.floatValue > reference_value) {
                        reference_value = gain_ratio_info.floatValue;
                        result = gain_ratio_info.split_info;
                    }
                } catch (SplitException ex) { // ignore attributes that cannot split
                }
            }
        } else {
            for (int attribute : attrs) {
                try {
                    BundleData mse_info = mse_use_attribute(node.set, attribute);
                    if (reference_value < 0 || mse_info.floatValue < reference_value) {
                        reference_value = mse_info.floatValue;
                        result = mse_info.split_info;
                    }
                } catch (SplitException ex) {
                }
            }
        }
        return result;
    }

    private SplitData split_with_attribute(int[] set, int attribute) throws SplitException {
        SplitData result = new SplitData();
        result.split_attr = attribute;

        if (_isCategory[attribute]) {
            // categorical attribute: one branch per distinct value
            int amount_of_features = 0;
            HashMap<Double, Integer> counter = new HashMap<>();
            HashMap<Double, Integer> index_recorder = new HashMap<>();
            for (int item : set) {
                double feature = _features[item][attribute];
                if (counter.get(feature) == null) {
                    counter.put(feature, 1);
                    index_recorder.put(feature, amount_of_features++);
                } else {
                    int count = counter.get(feature) + 1;
                    counter.put(feature, count);
                }
            }

            // record the split points
            result.split_points = new double[amount_of_features];
            Iterator<Double> iterator = index_recorder.keySet().iterator();

            while (iterator.hasNext()) {
                double key = iterator.next();
                int value = index_recorder.get(key);
                result.split_points[value] = key;
            }

            result.split_sets = new int[amount_of_features][];
            int[] t_index = new int[amount_of_features];
            for (int i = 0; i < amount_of_features; ++i) t_index[i] = 0;

            for (int item : set) {
                int index = index_recorder.get(_features[item][attribute]);
                if (result.split_sets[index] == null) {
                    result.split_sets[index] = new int[counter.get(_features[item][attribute])];
                }
                result.split_sets[index][t_index[index]++] = item;
            }
        } else {
            // numerical attribute: binary split at the best threshold
            double[] features = new double[set.length];
            for (int i = 0; i < features.length; ++i) {
                features[i] = _features[set[i]][attribute];
            }
            Arrays.sort(features);

            double reference_value = _isClassification ? 0 : -1;
            double best_split_point = 0;
            result.split_sets = new int[2][];
            for (int i = 0; i < features.length - 1; ++i) {
                if (features[i] == features[i + 1]) continue;
                double split_point = (features[i] + features[i + 1]) / 2;
                int[] sub_set_a = new int[i + 1];
                int[] sub_set_b = new int[set.length - i - 1];

                int a_index = 0;
                int b_index = 0;
                for (int j = 0; j < set.length; ++j) {
                    if (_features[set[j]][attribute] < split_point) {
                        sub_set_a[a_index++] = set[j];
                    } else {
                        sub_set_b[b_index++] = set[j];
                    }
                }

                if (_isClassification) {
                    double temp = gain_ratio_use_numerical_attribute(set, attribute, sub_set_a, sub_set_b);
                    if (temp > reference_value) {
                        reference_value = temp;
                        best_split_point = split_point;
                        result.split_sets[0] = sub_set_a;
                        result.split_sets[1] = sub_set_b;
                    }
                } else {
                    double temp = (sub_set_a.length * mse(sub_set_a) + sub_set_b.length * mse(sub_set_b)) / set.length;
                    if (reference_value < 0 || temp < reference_value) {
                        reference_value = temp;
                        best_split_point = split_point;
                        result.split_sets[0] = sub_set_a;
                        result.split_sets[1] = sub_set_b;
                    }
                }
            }
            // no usable split point: report failure
            if (result.split_sets[0] == null && result.split_sets[1] == null) throw new SplitException();
            result.split_points = new double[1];
            result.split_points[0] = best_split_point;
        }
        return result;
    }

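    // Quantities computed below, for a sample set S split into subsets S_i
    // (natural logarithms, matching Math.log):
    //   entropy            H(S) = -sum_c p_c * ln(p_c), p_c = fraction of class c
    //   split information  -sum_i |S_i|/|S| * ln(|S_i|/|S|)
    //   gain ratio         ( H(S) - sum_i |S_i|/|S| * H(S_i) ) / split information
    // C4.5 divides the information gain by the split information so that
    // attributes with many values are not automatically preferred.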
    // entropy of the labels in the given sample set
    private double entropy(int[] set) {
        HashMap<Double, Integer> counter = new HashMap<>();
        for (int item : set) {
            double label = _labels[item];
            if (counter.get(label) == null) {
                counter.put(label, 1);
            } else {
                int count = counter.get(label) + 1;
                counter.put(label, count);
            }
        }

        double result = 0;
        Iterator<Double> iterator = counter.keySet().iterator();
        while (iterator.hasNext()) {
            int count = counter.get(iterator.next());
            double p = (double) count / set.length;
            result += -p * Math.log(p);
        }

        return result;
    }

    // gain ratio (C4.5)
    private BundleData gain_ratio_use_attribute(int[] set, int attribute) throws SplitException {
        BundleData result = new BundleData();
        double entropy_before_split = entropy(set);

        double entropy_after_split = 0;
        double split_information = 0;
        result.split_info = split_with_attribute(set, attribute);
        for (int[] sub_set : result.split_info.split_sets) {
            entropy_after_split += (double) sub_set.length / set.length * entropy(sub_set);
            double p = (double) sub_set.length / set.length;
            split_information += -p * Math.log(p);
        }
        result.floatValue = (entropy_before_split - entropy_after_split) / split_information;
        return result;
    }

    private double gain_ratio_use_numerical_attribute(int[] set, int attribute, int[] part_a, int[] part_b) {
        double entropy_before_split = entropy(set);
        double entropy_after_split = (part_a.length * entropy(part_a) + part_b.length * entropy(part_b)) / set.length;

        double split_information = 0;
        double p = (double) part_a.length / set.length;
        split_information += -p * Math.log(p);
        p = (double) part_b.length / set.length;
        split_information += -p * Math.log(p);

        return (entropy_before_split - entropy_after_split) / split_information;
    }

    private double mse(int[] set) {
        double mean = mean_value(set);

        double temp = 0;
        for (int index : set) {
            double t = _labels[index] - mean;
            temp += t * t;
        }
        return temp / set.length;
    }

    private BundleData mse_use_attribute(int[] set, int attribute) throws SplitException {
        BundleData mse_info = new BundleData();
        mse_info.floatValue = 0;
        mse_info.split_info = split_with_attribute(set, attribute);
        for (int[] sub_set : mse_info.split_info.split_sets) {
            mse_info.floatValue += (double) sub_set.length / set.length * mse(sub_set);
        }
        return mse_info;
    }
}

--------------------------------------------------------------------------------
/src/auxiliary/AdaBoost.java:
--------------------------------------------------------------------------------
package auxiliary;

import java.util.*;

/**
 *
 * @author 孔繁宇 MF1333020
 */

public class AdaBoost extends Classifier {
    // Holds a sample set drawn with replacement.
    class RepickSamples {
        double[][] features;
        double[] labels;
        int[] index;
    }

    // Result of one training round.
    class LearningResult {
        double error; // weighted training error
        boolean[] correct;
    }

    private static int numClassifiers = 9; // number of base classifiers to train
    private Classifier[] classifiers;
    private double[] weightsOfClassifiers;
    private double[] weightsOfSamples;

    public AdaBoost() {
    }

    @Override
    public void train(boolean[] isCategory, double[][] features, double[] labels) {
        classifiers = new Classifier[numClassifiers];
        weightsOfClassifiers = new double[numClassifiers];

        int samplesCount = labels.length;
        weightsOfSamples = new double[samplesCount];

        // start with uniform sample weights
        for (int i = 0; i < samplesCount; ++i) {
            weightsOfSamples[i] = (double) 1 / samplesCount;
        }

        for (int i = 0; i < numClassifiers; ++i) {
            LearningResult result = null;
            RepickSamples samples = null;
            do {
                samples = repickSamplesWithWeights(features, labels);
                classifiers[i] = new DecisionTree();
                classifiers[i].train(isCategory, samples.features, samples.labels);

                result = validateLearningPerformance(classifiers[i], samples);
            } while (result.error > 0.5); // retrain until the round beats chance

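            // Classifier weight: alpha = 0.5 * ln((1 - error) / error). The
            // do-while above guarantees error <= 0.5, so alpha >= 0 and the
            // e^alpha factor below strictly increases the weight of
            // misclassified samples relative to correctly classified ones.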
            // adjust the sample weights
            weightsOfClassifiers[i] = 0.5 * Math.log((1 - result.error) / result.error);
            HashSet<Integer> checker = new HashSet<>();
            for (int j = 0; j < result.correct.length; ++j) {
                // do not re-weight the same sample twice
                if (checker.contains(samples.index[j])) continue;
                checker.add(samples.index[j]);

                if (result.correct[j]) {
                    weightsOfSamples[samples.index[j]] *= Math.exp(-weightsOfClassifiers[i]);
                } else {
                    weightsOfSamples[samples.index[j]] *= Math.exp(weightsOfClassifiers[i]);
                }
            }

            // normalize so the weights sum to one
            double total = 0;
            for (int j = 0; j < samplesCount; ++j) {
                total += weightsOfSamples[j];
            }
            for (int j = 0; j < samplesCount; ++j) {
                weightsOfSamples[j] /= total;
            }
        }
    }

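    // Weighted sampling with replacement ("roulette wheel"): a uniform draw in
    // [0, 1) is mapped to the first index whose cumulative weight exceeds it,
    // so each sample is picked with probability proportional to its weight.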
    private RepickSamples repickSamplesWithWeights(double[][] features, double[] labels) {
        RepickSamples samples = new RepickSamples();
        int size = labels.length;
        samples.features = new double[size][];
        samples.index = new int[size];
        samples.labels = new double[size];

        Random random = new Random();
        for (int i = 0; i < size; ++i) {
            double weight = random.nextDouble();
            double temp = 0;
            int j;
            for (j = 0; j < size; ++j) {
                temp += weightsOfSamples[j];
                if (temp > weight) break;
            }
            if (j == size) j--;

            samples.features[i] = features[j].clone();
            samples.labels[i] = labels[j];
            samples.index[i] = j;
        }

        return samples;
    }

    private LearningResult validateLearningPerformance(Classifier classifier, RepickSamples samples) {
        LearningResult result = new LearningResult();
        result.error = 0;
        result.correct = new boolean[samples.labels.length];

        HashSet<Integer> checker = new HashSet<>();
        for (int i = 0; i < samples.labels.length; ++i) {
            if (samples.labels[i] == classifier.predict(samples.features[i])) {
                // prediction correct
                result.correct[i] = true;
            } else {
                result.correct[i] = false;
                // count each misclassified sample's weight only once
                if (checker.contains(samples.index[i])) continue;
                checker.add(samples.index[i]);
                result.error += weightsOfSamples[samples.index[i]];
            }
        }
        return result;
    }

    @Override
    public double predict(double[] features) {
        // weighted vote: each classifier votes with its own weight
        HashMap<Double, Double> counter = new HashMap<>();
        for (int i = 0; i < classifiers.length; ++i) {
            double label = classifiers[i].predict(features);
            if (counter.get(label) == null) {
                counter.put(label, weightsOfClassifiers[i]);
            } else {
                double weight = counter.get(label) + weightsOfClassifiers[i];
                counter.put(label, weight);
            }
        }

        double temp_max = 0;
        double label = 0;
        Iterator<Double> iterator = counter.keySet().iterator();
        while (iterator.hasNext()) {
            double key = iterator.next();
            double weight = counter.get(key);
            if (weight > temp_max) {
                temp_max = weight;
                label = key;
            }
        }

        return label;
    }
}

// <<<<-------------------- divider: the (non-randomized) decision tree implementation follows -------------------->>>>

class DecisionTree extends Classifier {
    // Decision tree node.
    class TreeNode {
        int[] set; // indices of the samples at this node
        int[] attr_index; // indices of the attributes still available
        double label; // label assigned to this node
        int split_attr; // index of the attribute this node splits on
        double[] split_points; // split points: one per value for a categorical attribute, a single threshold for a numerical one
        TreeNode[] childrenNodes; // child nodes
    }

    // Holds the result of a split.
    class SplitData {
        int split_attr;
        double[] split_points;
        int[][] split_sets; // the sample subsets produced by the split
    }

    class BundleData {
        double floatValue; // holds the gain ratio or the MSE
        SplitData split_info;
    }

    // Thrown when an attribute cannot produce a valid split.
    class SplitException extends Exception {
    }

    private boolean _isClassification;
    private double[][] _features;
    private boolean[] _isCategory;
    private double[] _labels;
    private double[] _defaults;

    private TreeNode root;

    public DecisionTree() {
    }

    @Override
    public void train(boolean[] isCategory, double[][] features, double[] labels) {
        _isClassification = isCategory[isCategory.length - 1];
        _features = features;
        _isCategory = isCategory;
        _labels = labels;

        int[] set = new int[_features.length];
        for (int i = 0; i < set.length; ++i) {
            set[i] = i;
        }

        int[] attr_index = new int[_features[0].length];
        for (int i = 0; i < attr_index.length; ++i) {
            attr_index[i] = i;
        }

        // impute missing attribute values
        _defaults = kill_missing_data();

        root = build_decision_tree(set, attr_index);
    }

    private double[] kill_missing_data() {
        int num = _isCategory.length - 1;
        double[] defaults = new double[num];

        for (int i = 0; i < defaults.length; ++i) {
            if (_isCategory[i]) {
                // categorical attribute: use the most frequent value
                HashMap<Double, Integer> counter = new HashMap<>();
                for (int j = 0; j < _features.length; ++j) {
                    double feature = _features[j][i];
                    if (!Double.isNaN(feature)) {
                        if (counter.get(feature) == null) {
                            counter.put(feature, 1);
                        } else {
                            int count = counter.get(feature) + 1;
                            counter.put(feature, count);
                        }
                    }
                }

                int max_time = 0;
                double value = 0;
                Iterator<Double> iterator = counter.keySet().iterator();
                while (iterator.hasNext()) {
                    double key = iterator.next();
                    int count = counter.get(key);
                    if (count > max_time) {
                        max_time = count;
                        value = key;
                    }
                }
                defaults[i] = value;
            } else {
                // numerical attribute: use the mean
                int count = 0;
                double total = 0;
                for (int j = 0; j < _features.length; ++j) {
                    if (!Double.isNaN(_features[j][i])) {
                        count++;
                        total += _features[j][i];
                    }
                }
                defaults[i] = total / count;
            }
        }

        // substitute the defaults for the NaNs
        for (int i = 0; i < _features.length; ++i) {
            for (int j = 0; j < defaults.length; ++j) {
                if (Double.isNaN(_features[i][j])) {
                    _features[i][j] = defaults[j];
                }
            }
        }
        return defaults;
    }

    @Override
    public double predict(double[] features) {
        // impute missing attribute values
        for (int i = 0; i < features.length; ++i) {
            if (Double.isNaN(features[i])) {
                features[i] = _defaults[i];
            }
        }

        return predict_with_decision_tree(features, root);
    }

    private double predict_with_decision_tree(double[] features, TreeNode node) {
        if (node.childrenNodes == null) {
            return node.label;
        }

        double feature = features[node.split_attr];

        if (_isCategory[node.split_attr]) {
            // categorical attribute
            for (int i = 0; i < node.split_points.length; ++i) {
                if (node.split_points[i] == feature) {
                    return predict_with_decision_tree(features, node.childrenNodes[i]);
                }
            }

            return node.label; // unseen value: fall back to this node's label, which keeps the tree small
        } else {
            // numerical attribute
            if (feature < node.split_points[0]) {
                return predict_with_decision_tree(features, node.childrenNodes[0]);
            } else {
                return predict_with_decision_tree(features, node.childrenNodes[1]);
            }
        }

    }

    private TreeNode build_decision_tree(int[] set, int[] attr_index) {
        TreeNode node = new TreeNode();
        node.set = set;
        node.attr_index = attr_index;
        node.label = 0;
        node.childrenNodes = null;

        // if all samples share one label, return a leaf
        double label = _labels[node.set[0]];
        boolean flag = true;
        for (int i = 0; i < node.set.length; ++i) {
            if (_labels[node.set[i]] != label) {
                flag = false;
                break;
            }
        }
        if (flag) {
            node.label = label;
            return node;
        }

        // default label: majority class (classification) or mean value (regression)
        if (_isClassification) {
            node.label = most_label(set);
        } else {
            node.label = mean_value(set);
        }
        if (node.attr_index == null || node.attr_index.length == 0) {
            return node;
        }

        // find the best attribute to split on
        SplitData split_info = attribute_selection(node);
        node.split_attr = split_info.split_attr;
        // no attribute can split this node
        if (node.split_attr < 0) {
            return node;
        }

        node.split_points = split_info.split_points;

        // remove a used categorical attribute; numerical attributes stay available
        int[] child_attr_index = null;
        if (_isCategory[node.split_attr]) {
            child_attr_index = new int[attr_index.length - 1];
            int t = 0;
            for (int index : attr_index) {
                if (index != node.split_attr) {
                    child_attr_index[t++] = index;
                }
            }
        } else {
            child_attr_index = node.attr_index.clone();
        }

        // recursively build the children
        node.childrenNodes = new TreeNode[split_info.split_sets.length];
        for (int i = 0; i < split_info.split_sets.length; ++i) {
            node.childrenNodes[i] = build_decision_tree(split_info.split_sets[i], child_attr_index);
        }

        return node;
    }

    // most frequent label in the given sample set
    private double most_label(int[] set) {
        HashMap<Double, Integer> counter = new HashMap<>();
        for (int item : set) {
            double label = _labels[item];
            if (counter.get(label) == null) {
                counter.put(label, 1);
            } else {
                int count = counter.get(label) + 1;
                counter.put(label, count);
            }
        }

        int max_time = 0;
        double label = 0;
        Iterator<Double> iterator = counter.keySet().iterator();
        while (iterator.hasNext()) {
            double key = iterator.next();
            int count = counter.get(key);
            if (count > max_time) {
                max_time = count;
                label = key;
            }
        }
        return label;
    }

    // mean label of the given sample set
    private double mean_value(int[] set) {
        double temp = 0;
        for (int index : set) {
            temp += _labels[index];
        }
        return temp / set.length;
    }

    private SplitData attribute_selection(TreeNode node) {
        SplitData result = new SplitData();
        result.split_attr = -1;

        // pre-pruning: skip nodes that are too small or gain too little
        double reference_value = _isClassification ? 0.05 : -1;
        if (node.set.length < 7) return result;

        if (_isClassification) {
            for (int attribute : node.attr_index) {
                try {
                    BundleData gain_ratio_info = gain_ratio_use_attribute(node.set, attribute); // a failed split throws SplitException
                    if (gain_ratio_info.floatValue > reference_value) {
                        reference_value = gain_ratio_info.floatValue;
                        result = gain_ratio_info.split_info;
                    }
                } catch (SplitException ex) { // ignore attributes that cannot split
                }
            }
        } else {
            for (int attribute : node.attr_index) {
                try {
                    BundleData mse_info = mse_use_attribute(node.set, attribute);
                    if (reference_value < 0 || mse_info.floatValue < reference_value) {
                        reference_value = mse_info.floatValue;
                        result = mse_info.split_info;
                    }
                } catch (SplitException ex) {
                }
            }
        }
        return result;
    }

    private SplitData split_with_attribute(int[] set, int attribute) throws SplitException {
        SplitData result = new SplitData();
        result.split_attr = attribute;

        if (_isCategory[attribute]) {
            // categorical attribute: one branch per distinct value
            int amount_of_features = 0;
            HashMap<Double, Integer> counter = new HashMap<>();
            HashMap<Double, Integer> index_recorder = new HashMap<>();
            for (int item : set) {
                double feature = _features[item][attribute];
                if (counter.get(feature) == null) {
                    counter.put(feature, 1);
                    index_recorder.put(feature, amount_of_features++);
                } else {
                    int count = counter.get(feature) + 1;
                    counter.put(feature, count);
                }
            }

            // record the split points
            result.split_points = new double[amount_of_features];
            Iterator<Double> iterator = index_recorder.keySet().iterator();

            while (iterator.hasNext()) {
                double key = iterator.next();
                int value = index_recorder.get(key);
                result.split_points[value] = key;
            }

            result.split_sets = new int[amount_of_features][];
            int[] t_index = new int[amount_of_features];
            for (int i = 0; i < amount_of_features; ++i) t_index[i] = 0;

            for (int item : set) {
                int index = index_recorder.get(_features[item][attribute]);
                if (result.split_sets[index] == null) {
                    result.split_sets[index] = new int[counter.get(_features[item][attribute])];
                }
                result.split_sets[index][t_index[index]++] = item;
            }
        } else {
            // numerical attribute: binary split at the best threshold
            double[] features = new double[set.length];
            for (int i = 0; i < features.length; ++i) {
                features[i] = _features[set[i]][attribute];
            }
            Arrays.sort(features);

            double reference_value = _isClassification ? 0 : -1;
            double best_split_point = 0;
            result.split_sets = new int[2][];
            for (int i = 0; i < features.length - 1; ++i) {
                if (features[i] == features[i + 1]) continue;
                double split_point = (features[i] + features[i + 1]) / 2;
                int[] sub_set_a = new int[i + 1];
                int[] sub_set_b = new int[set.length - i - 1];

                int a_index = 0;
                int b_index = 0;
                for (int j = 0; j < set.length; ++j) {
                    if (_features[set[j]][attribute] < split_point) {
                        sub_set_a[a_index++] = set[j];
                    } else {
                        sub_set_b[b_index++] = set[j];
                    }
                }

                if (_isClassification) {
                    double temp = gain_ratio_use_numerical_attribute(set, attribute, sub_set_a, sub_set_b);
                    if (temp > reference_value) {
                        reference_value = temp;
                        best_split_point = split_point;
                        result.split_sets[0] = sub_set_a;
                        result.split_sets[1] = sub_set_b;
                    }
                } else {
                    double temp = (sub_set_a.length * mse(sub_set_a) + sub_set_b.length * mse(sub_set_b)) / set.length;
                    if (reference_value < 0 || temp < reference_value) {
                        reference_value = temp;
                        best_split_point = split_point;
                        result.split_sets[0] = sub_set_a;
                        result.split_sets[1] = sub_set_b;
                    }
                }
            }
            // no usable split point: report failure
            if (result.split_sets[0] == null && result.split_sets[1] == null) throw new SplitException();
            result.split_points = new double[1];
            result.split_points[0] = best_split_point;
        }
        return result;
    }

    // entropy of the labels in the given sample set
    private double entropy(int[] set) {
        HashMap<Double, Integer> counter = new HashMap<>();
        for (int item : set) {
            double label = _labels[item];
            if (counter.get(label) == null) {
                counter.put(label, 1);
            } else {
                int count = counter.get(label) + 1;
                counter.put(label, count);
            }
        }

        double result = 0;
        Iterator<Double> iterator = counter.keySet().iterator();
        while (iterator.hasNext()) {
            int count = counter.get(iterator.next());
            double p = (double) count / set.length;
            result += -p * Math.log(p);
        }

        return result;
    }

    // gain ratio (C4.5)
    private BundleData gain_ratio_use_attribute(int[] set, int attribute) throws SplitException {
        BundleData result = new BundleData();
        double entropy_before_split = entropy(set);

        double entropy_after_split = 0;
        double split_information = 0;
        result.split_info = split_with_attribute(set, attribute);
        for (int[] sub_set : result.split_info.split_sets) {
            entropy_after_split += (double) sub_set.length / set.length * entropy(sub_set);
            double p = (double) sub_set.length / set.length;
            split_information += -p * Math.log(p);
        }
        result.floatValue = (entropy_before_split - entropy_after_split) / split_information;
        return result;
    }

    private double gain_ratio_use_numerical_attribute(int[] set, int attribute, int[] part_a, int[] part_b) {
        double entropy_before_split = entropy(set);
        double entropy_after_split = (part_a.length * entropy(part_a) + part_b.length * entropy(part_b)) / set.length;

        double split_information = 0;
        double p = (double) part_a.length / set.length;
        split_information += -p * Math.log(p);
        p = (double) part_b.length / set.length;
        split_information += -p * Math.log(p);

        return (entropy_before_split - entropy_after_split) / split_information;
    }

    private double mse(int[] set) {
        double mean = mean_value(set);

        double temp = 0;
        for (int index : set) {
            double t = _labels[index] - mean;
            temp += t * t;
        }
        return temp / set.length;
    }

    private BundleData mse_use_attribute(int[] set, int attribute) throws SplitException {
        BundleData mse_info = new BundleData();
        mse_info.floatValue = 0;
        mse_info.split_info = split_with_attribute(set, attribute);
        for (int[] sub_set : mse_info.split_info.split_sets) {
            mse_info.floatValue += (double) sub_set.length / set.length * mse(sub_set);
        }
        return mse_info;
    }
}
--------------------------------------------------------------------------------