├── .DS_Store ├── .classpath ├── .project ├── .settings └── org.eclipse.jdt.core.prefs ├── README.md ├── bin ├── .DS_Store ├── Main.class ├── boosting │ ├── GradientBoostingRegressor$1.class │ └── GradientBoostingRegressor.class ├── data │ └── LabeledSample.class ├── decision_tree │ ├── .DS_Store │ ├── DecisionRegressionTree$SplitResult.class │ ├── DecisionRegressionTree.class │ ├── DecisionRegressionTreeLeafSpliter$1.class │ ├── DecisionRegressionTreeLeafSpliter$QueueNode.class │ ├── DecisionRegressionTreeLeafSpliter$SplitResult.class │ ├── DecisionRegressionTreeLeafSpliter.class │ ├── DecisionTree.class │ └── Node.class ├── objective │ ├── Estimator.class │ ├── LossFunction.class │ ├── QuantileEstimator.class │ ├── QuantileLossFunction.class │ ├── SquaresEstimator.class │ ├── SquaresLossFunction.class │ ├── Utils$Item.class │ └── Utils.class ├── rf │ ├── RandomForestRegressor$1.class │ └── RandomForestRegressor.class └── util │ ├── BoostingListener.class │ ├── DumpTree.class │ ├── ParamException.class │ ├── ParamReader.class │ └── dump │ └── TreeInfo.class └── src ├── .DS_Store ├── Main.java ├── boosting └── GradientBoostingRegressor.java ├── data └── LabeledSample.java ├── decision_tree ├── .DS_Store ├── DecisionRegressionTree.java ├── DecisionRegressionTreeLeafSpliter.java ├── DecisionTree.java └── Node.java ├── objective ├── Estimator.java ├── LossFunction.java ├── QuantileEstimator.java ├── QuantileLossFunction.java ├── SquaresEstimator.java ├── SquaresLossFunction.java └── Utils.java ├── rf └── RandomForestRegressor.java └── util ├── BoostingListener.java ├── DumpTree.java ├── ParamException.java ├── ParamReader.java └── dump └── TreeInfo.java /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xing89qs/GBRT-RF-Java/d0d09f09c52a09baceb26fc011503b092ca947a4/.DS_Store -------------------------------------------------------------------------------- /.classpath: 
-------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | -------------------------------------------------------------------------------- /.project: -------------------------------------------------------------------------------- 1 | 2 | 3 | GradientBoostingDecisionTree 4 | 5 | 6 | 7 | 8 | 9 | org.eclipse.jdt.core.javabuilder 10 | 11 | 12 | 13 | 14 | 15 | org.eclipse.jdt.core.javanature 16 | 17 | 18 | -------------------------------------------------------------------------------- /.settings/org.eclipse.jdt.core.prefs: -------------------------------------------------------------------------------- 1 | eclipse.preferences.version=1 2 | org.eclipse.jdt.core.compiler.codegen.inlineJsrBytecode=enabled 3 | org.eclipse.jdt.core.compiler.codegen.targetPlatform=1.7 4 | org.eclipse.jdt.core.compiler.codegen.unusedLocal=preserve 5 | org.eclipse.jdt.core.compiler.compliance=1.7 6 | org.eclipse.jdt.core.compiler.debug.lineNumber=generate 7 | org.eclipse.jdt.core.compiler.debug.localVariable=generate 8 | org.eclipse.jdt.core.compiler.debug.sourceFile=generate 9 | org.eclipse.jdt.core.compiler.problem.assertIdentifier=error 10 | org.eclipse.jdt.core.compiler.problem.enumIdentifier=error 11 | org.eclipse.jdt.core.compiler.source=1.7 12 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ### GBRT&RF Java版 2 | 3 | - Java版的GBRT及RF 4 | 5 | - 模型成绩 6 | - 天池-菜鸟物流 第二赛季 10th 7 | - 天池-机场客流 第二赛季 8th 8 |   9 | - 使用说明 10 | - 参见 `Main.java` 11 |   12 | - 模型说明 13 | - DecisionTree 支持默认分裂方式及选择最优叶子节点分裂方式 14 | 15 | 16 | #### 欢迎交流改进 17 | -------------------------------------------------------------------------------- /bin/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xing89qs/GBRT-RF-Java/d0d09f09c52a09baceb26fc011503b092ca947a4/bin/.DS_Store 
-------------------------------------------------------------------------------- /bin/Main.class: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xing89qs/GBRT-RF-Java/d0d09f09c52a09baceb26fc011503b092ca947a4/bin/Main.class -------------------------------------------------------------------------------- /bin/boosting/GradientBoostingRegressor$1.class: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xing89qs/GBRT-RF-Java/d0d09f09c52a09baceb26fc011503b092ca947a4/bin/boosting/GradientBoostingRegressor$1.class -------------------------------------------------------------------------------- /bin/boosting/GradientBoostingRegressor.class: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xing89qs/GBRT-RF-Java/d0d09f09c52a09baceb26fc011503b092ca947a4/bin/boosting/GradientBoostingRegressor.class -------------------------------------------------------------------------------- /bin/data/LabeledSample.class: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xing89qs/GBRT-RF-Java/d0d09f09c52a09baceb26fc011503b092ca947a4/bin/data/LabeledSample.class -------------------------------------------------------------------------------- /bin/decision_tree/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xing89qs/GBRT-RF-Java/d0d09f09c52a09baceb26fc011503b092ca947a4/bin/decision_tree/.DS_Store -------------------------------------------------------------------------------- /bin/decision_tree/DecisionRegressionTree$SplitResult.class: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/xing89qs/GBRT-RF-Java/d0d09f09c52a09baceb26fc011503b092ca947a4/bin/decision_tree/DecisionRegressionTree$SplitResult.class -------------------------------------------------------------------------------- /bin/decision_tree/DecisionRegressionTree.class: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xing89qs/GBRT-RF-Java/d0d09f09c52a09baceb26fc011503b092ca947a4/bin/decision_tree/DecisionRegressionTree.class -------------------------------------------------------------------------------- /bin/decision_tree/DecisionRegressionTreeLeafSpliter$1.class: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xing89qs/GBRT-RF-Java/d0d09f09c52a09baceb26fc011503b092ca947a4/bin/decision_tree/DecisionRegressionTreeLeafSpliter$1.class -------------------------------------------------------------------------------- /bin/decision_tree/DecisionRegressionTreeLeafSpliter$QueueNode.class: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xing89qs/GBRT-RF-Java/d0d09f09c52a09baceb26fc011503b092ca947a4/bin/decision_tree/DecisionRegressionTreeLeafSpliter$QueueNode.class -------------------------------------------------------------------------------- /bin/decision_tree/DecisionRegressionTreeLeafSpliter$SplitResult.class: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xing89qs/GBRT-RF-Java/d0d09f09c52a09baceb26fc011503b092ca947a4/bin/decision_tree/DecisionRegressionTreeLeafSpliter$SplitResult.class -------------------------------------------------------------------------------- /bin/decision_tree/DecisionRegressionTreeLeafSpliter.class: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/xing89qs/GBRT-RF-Java/d0d09f09c52a09baceb26fc011503b092ca947a4/bin/decision_tree/DecisionRegressionTreeLeafSpliter.class -------------------------------------------------------------------------------- /bin/decision_tree/DecisionTree.class: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xing89qs/GBRT-RF-Java/d0d09f09c52a09baceb26fc011503b092ca947a4/bin/decision_tree/DecisionTree.class -------------------------------------------------------------------------------- /bin/decision_tree/Node.class: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xing89qs/GBRT-RF-Java/d0d09f09c52a09baceb26fc011503b092ca947a4/bin/decision_tree/Node.class -------------------------------------------------------------------------------- /bin/objective/Estimator.class: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xing89qs/GBRT-RF-Java/d0d09f09c52a09baceb26fc011503b092ca947a4/bin/objective/Estimator.class -------------------------------------------------------------------------------- /bin/objective/LossFunction.class: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xing89qs/GBRT-RF-Java/d0d09f09c52a09baceb26fc011503b092ca947a4/bin/objective/LossFunction.class -------------------------------------------------------------------------------- /bin/objective/QuantileEstimator.class: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xing89qs/GBRT-RF-Java/d0d09f09c52a09baceb26fc011503b092ca947a4/bin/objective/QuantileEstimator.class -------------------------------------------------------------------------------- /bin/objective/QuantileLossFunction.class: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/xing89qs/GBRT-RF-Java/d0d09f09c52a09baceb26fc011503b092ca947a4/bin/objective/QuantileLossFunction.class -------------------------------------------------------------------------------- /bin/objective/SquaresEstimator.class: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xing89qs/GBRT-RF-Java/d0d09f09c52a09baceb26fc011503b092ca947a4/bin/objective/SquaresEstimator.class -------------------------------------------------------------------------------- /bin/objective/SquaresLossFunction.class: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xing89qs/GBRT-RF-Java/d0d09f09c52a09baceb26fc011503b092ca947a4/bin/objective/SquaresLossFunction.class -------------------------------------------------------------------------------- /bin/objective/Utils$Item.class: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xing89qs/GBRT-RF-Java/d0d09f09c52a09baceb26fc011503b092ca947a4/bin/objective/Utils$Item.class -------------------------------------------------------------------------------- /bin/objective/Utils.class: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xing89qs/GBRT-RF-Java/d0d09f09c52a09baceb26fc011503b092ca947a4/bin/objective/Utils.class -------------------------------------------------------------------------------- /bin/rf/RandomForestRegressor$1.class: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xing89qs/GBRT-RF-Java/d0d09f09c52a09baceb26fc011503b092ca947a4/bin/rf/RandomForestRegressor$1.class -------------------------------------------------------------------------------- /bin/rf/RandomForestRegressor.class: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/xing89qs/GBRT-RF-Java/d0d09f09c52a09baceb26fc011503b092ca947a4/bin/rf/RandomForestRegressor.class -------------------------------------------------------------------------------- /bin/util/BoostingListener.class: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xing89qs/GBRT-RF-Java/d0d09f09c52a09baceb26fc011503b092ca947a4/bin/util/BoostingListener.class -------------------------------------------------------------------------------- /bin/util/DumpTree.class: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xing89qs/GBRT-RF-Java/d0d09f09c52a09baceb26fc011503b092ca947a4/bin/util/DumpTree.class -------------------------------------------------------------------------------- /bin/util/ParamException.class: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xing89qs/GBRT-RF-Java/d0d09f09c52a09baceb26fc011503b092ca947a4/bin/util/ParamException.class -------------------------------------------------------------------------------- /bin/util/ParamReader.class: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xing89qs/GBRT-RF-Java/d0d09f09c52a09baceb26fc011503b092ca947a4/bin/util/ParamReader.class -------------------------------------------------------------------------------- /bin/util/dump/TreeInfo.class: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xing89qs/GBRT-RF-Java/d0d09f09c52a09baceb26fc011503b092ca947a4/bin/util/dump/TreeInfo.class -------------------------------------------------------------------------------- /src/.DS_Store: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/xing89qs/GBRT-RF-Java/d0d09f09c52a09baceb26fc011503b092ca947a4/src/.DS_Store -------------------------------------------------------------------------------- /src/Main.java: -------------------------------------------------------------------------------- 1 | import java.io.BufferedReader; 2 | import java.io.File; 3 | import java.io.FileReader; 4 | import java.util.Arrays; 5 | import java.util.HashMap; 6 | import java.util.Scanner; 7 | 8 | import boosting.GradientBoostingRegressor; 9 | import rf.RandomForestRegressor; 10 | 11 | public class Main { 12 | public static void main(String[] args) throws Exception { 13 | HashMap params = new HashMap<>(); 14 | params.put("learning_rate", 0.05); 15 | params.put("n_estimator", 100); 16 | params.put("max_depth", 3); 17 | params.put("objective", "ls"); 18 | params.put("feature_rate", 0.1); 19 | params.put("sample_rate", 0.8); 20 | params.put("spliter", "leaf"); 21 | params.put("num_leafs", 8); 22 | // GradientBoostingRegressor gbdt = new 23 | // GradientBoostingRegressor(params); 24 | RandomForestRegressor gbdt = new RandomForestRegressor(params); 25 | 26 | double trainX[][] = new double[26964][155]; 27 | double trainY[] = new double[26964]; 28 | double trainWeight[] = new double[26964]; 29 | Arrays.fill(trainWeight, 1.0); 30 | Scanner in = new Scanner(new BufferedReader(new FileReader(new File("/Users/mac/Desktop/train_frame1.csv")))); 31 | in.nextLine(); 32 | int numLines = 0; 33 | while (in.hasNextLine()) { 34 | String line = in.nextLine(); 35 | String[] str = line.split(",", -1); 36 | int featureIndex = 0; 37 | for (int i = 3; i < str.length; i++) { 38 | if (i == 138) 39 | continue; 40 | if (str[i].equals("") || str[i] == null) 41 | trainX[numLines][featureIndex++] = -1000; 42 | else 43 | trainX[numLines][featureIndex++] = Double.valueOf(str[i]); 44 | } 45 | trainY[numLines] = Double.valueOf(str[2]); 46 | numLines++; 47 | assert (featureIndex == 155); 48 | } 49 | 
System.out.println(numLines); 50 | in.close(); 51 | 52 | double testX[][] = new double[4464][155]; 53 | double testY[] = new double[4464]; 54 | in = new Scanner(new BufferedReader(new FileReader(new File("/Users/mac/Desktop/test_frame1.csv")))); 55 | in.nextLine(); 56 | numLines = 0; 57 | while (in.hasNextLine()) { 58 | String line = in.nextLine(); 59 | String[] str = line.split(",", -1); 60 | int featureIndex = 0; 61 | for (int i = 3; i < str.length; i++) { 62 | if (i == 138) 63 | continue; 64 | if (str[i].equals("") || str[i] == null) 65 | testX[numLines][featureIndex++] = -1000; 66 | else 67 | testX[numLines][featureIndex++] = Double.valueOf(str[i]); 68 | } 69 | assert (featureIndex == 155); 70 | testY[numLines] = Double.valueOf(str[2]); 71 | numLines++; 72 | } 73 | System.out.println(numLines); 74 | in.close(); 75 | gbdt.fit(trainX, trainY, trainWeight); 76 | double[][] ans = gbdt.predict(testX); 77 | double[] ret = new double[ans[0].length]; 78 | for (int i = 0; i < testX.length; i++) { 79 | ret[i] = 0; 80 | for (int j = 0; j < ans.length; j++) { 81 | ret[i] += ans[j][i]; 82 | } 83 | ret[i] /= ans.length; 84 | //System.out.println(ret[i] + " " + testY[i]); 85 | } 86 | 87 | // -------------------------------------------- 88 | // gbdt.fit(trainX, trainY, trainWeight,null); 89 | // double[] ret = gbdt.predict(testX); 90 | 91 | double all = 0; 92 | for (int i = 0; i < testX.length; i++) { 93 | double error = testY[i] - ret[i]; 94 | all += error * error; 95 | } 96 | System.out.println(all); 97 | } 98 | } 99 | -------------------------------------------------------------------------------- /src/boosting/GradientBoostingRegressor.java: -------------------------------------------------------------------------------- 1 | package boosting; 2 | 3 | import java.util.ArrayList; 4 | import java.util.Arrays; 5 | import java.util.Comparator; 6 | import java.util.HashMap; 7 | import java.util.HashSet; 8 | import java.util.Random; 9 | import java.util.TreeSet; 10 | 11 | import 
data.LabeledSample; 12 | import decision_tree.DecisionRegressionTree; 13 | import decision_tree.DecisionRegressionTreeLeafSpliter; 14 | import decision_tree.DecisionTree; 15 | import decision_tree.Node; 16 | import objective.Estimator; 17 | import objective.LossFunction; 18 | import objective.QuantileEstimator; 19 | import objective.QuantileLossFunction; 20 | import objective.SquaresEstimator; 21 | import objective.SquaresLossFunction; 22 | import util.BoostingListener; 23 | import util.ParamException; 24 | import util.ParamReader; 25 | import util.dump.TreeInfo; 26 | 27 | public class GradientBoostingRegressor { 28 | 29 | // ------------------------------------------------------------------------------------ 30 | // Boosting的参数 31 | private double learning_rate; 32 | public int n_estimator; 33 | private double alpha; // quantile loss才需要 34 | private double sample_rate; 35 | private double feature_rate; 36 | private Estimator init_; // 初始值估计 37 | private LossFunction loss; // lossFunction 38 | private int random_state; 39 | private DecisionTree[] trees; 40 | private int max_depth; 41 | private int num_leafs; // 基于叶子节点分裂时的叶子数 42 | private double baseValue; // 初始值 43 | private double[] residual; // 残差 44 | private String spliter; 45 | private double min_leaf_sample; 46 | 47 | public GradientBoostingRegressor(HashMap params) throws ParamException { 48 | 49 | // ----------------------------------------------------------------------- 50 | // 必需参数 51 | 52 | assert (params.containsKey("learning_rate")); 53 | this.learning_rate = ParamReader.readDouble("learning_rate", params); 54 | assert (params.containsKey("spliter")); 55 | this.spliter = ParamReader.readString("spliter", params); 56 | if (this.spliter.equals("leaf")) { 57 | assert (params.containsKey("num_leafs")); 58 | this.num_leafs = ParamReader.readInt("num_leafs", params); 59 | } else { 60 | assert (params.containsKey("max_depth")); 61 | this.max_depth = ParamReader.readInt("max_depth", params); 62 | } 63 | assert 
(params.containsKey("n_estimator")); 64 | this.n_estimator = ParamReader.readInt("n_estimator", params); 65 | 66 | // ------------------------------------------------------------------------- 67 | // 可选参数 68 | 69 | String objective = ParamReader.readString("objective", params); 70 | if (objective.equals("quantile")) { 71 | assert (params.containsKey("alpha")); 72 | this.alpha = ParamReader.readDouble("alpha", params); 73 | this.loss = new QuantileLossFunction(alpha); 74 | this.init_ = new QuantileEstimator(alpha); 75 | } else if (objective.equals("lad")) { 76 | } else { 77 | this.loss = new SquaresLossFunction(); 78 | this.init_ = new SquaresEstimator(); 79 | } 80 | 81 | if (params.containsKey("random_state")) 82 | this.random_state = ParamReader.readInt("random_state", params); 83 | else 84 | this.random_state = 0; 85 | if (params.containsKey("sample_rate")) 86 | this.sample_rate = ParamReader.readDouble("sample_rate", params); 87 | else 88 | this.sample_rate = 1.0; 89 | if (params.containsKey("feature_rate")) 90 | this.feature_rate = ParamReader.readDouble("feature_rate", params); 91 | else 92 | this.feature_rate = 1.0; 93 | 94 | if (params.containsKey("min_leaf_sample")) 95 | this.min_leaf_sample = ParamReader.readDouble("min_leaf_sample", params); 96 | else 97 | this.min_leaf_sample = 1; 98 | } 99 | 100 | private LabeledSample[] trainData; 101 | private LabeledSample[][] preSortedSampleArrays; 102 | private LabeledSample[][] copyOfPreSortedSampleArrays; 103 | private BoostingListener listener; 104 | 105 | private double[] _fit_stage(int i, double[][] X, double[] Y, double[] y_pred, double[] sample_weight) { 106 | this.loss.negative_gradient(Y, y_pred, this.residual); 107 | int featureNum = X[0].length; 108 | int sampleNum = X.length; 109 | for (int b = 0; b < sampleNum; b++) 110 | trainData[b].y = residual[b]; 111 | for (int a = 0; a < featureNum; a++) { 112 | for (int b = 0; b < sampleNum; b++) { 113 | preSortedSampleArrays[a][b] = 
copyOfPreSortedSampleArrays[a][b]; 114 | } 115 | } 116 | if (this.spliter.equals("leaf")) 117 | trees[i] = new DecisionRegressionTreeLeafSpliter(this.num_leafs, this.random_state, sample_rate, 118 | feature_rate, min_leaf_sample, preSortedSampleArrays); 119 | else 120 | trees[i] = new DecisionRegressionTree(i, this.max_depth, this.random_state, sample_rate, feature_rate, 121 | min_leaf_sample, preSortedSampleArrays); 122 | trees[i].fit(X, residual, sample_weight); 123 | 124 | this.loss.update_terminal_region(trees[i], X, Y, y_pred, sample_weight); 125 | for (int j = 0; j < X.length; j++) { 126 | Node leaf = trees[i].apply(X[j]); 127 | y_pred[j] += learning_rate * leaf.treeVal; 128 | // System.out.print(y_pred[j] + " "); 129 | } 130 | // System.out.println(); 131 | return y_pred; 132 | } 133 | 134 | public void registerListener(BoostingListener listener) { 135 | this.listener = listener; 136 | } 137 | 138 | private int _fit_stages(double[][] X, double[] Y, double[] y_pred, double[] sample_weight, TreeInfo[][] infoList) { 139 | int featureNum = X[0].length; 140 | int sampleNum = X.length; 141 | trainData = new LabeledSample[sampleNum]; 142 | for (int i = 0; i < sampleNum; i++) { 143 | trainData[i] = new LabeledSample(); 144 | trainData[i].x = X[i]; 145 | trainData[i].y = Y[i]; 146 | trainData[i].weight = sample_weight[i]; 147 | } 148 | 149 | preSortedSampleArrays = new LabeledSample[featureNum][sampleNum]; 150 | copyOfPreSortedSampleArrays = new LabeledSample[featureNum][sampleNum]; 151 | for (int a = 0; a < featureNum; a++) { 152 | for (int b = 0; b < sampleNum; b++) { 153 | copyOfPreSortedSampleArrays[a][b] = trainData[b]; 154 | } 155 | 156 | final int compareFeature = a; 157 | Arrays.sort(copyOfPreSortedSampleArrays[a], 0, sampleNum, new Comparator() { 158 | 159 | @Override 160 | public int compare(LabeledSample o1, LabeledSample o2) { 161 | return new Double(o1.x[compareFeature]).compareTo(o2.x[compareFeature]); 162 | } 163 | }); 164 | } 165 | this.residual = new 
double[sampleNum]; 166 | this.trees = new DecisionTree[this.n_estimator]; 167 | 168 | for (int i = 0; i < this.n_estimator; i++) { 169 | System.out.println("第" + i + "棵树"); 170 | if (listener != null) 171 | listener.done(i); 172 | y_pred = _fit_stage(i, X, Y, y_pred, sample_weight); 173 | } 174 | return this.n_estimator; 175 | } 176 | 177 | public void fit(double[][] X, double[] Y, double[] sample_weight, TreeInfo[][] infoList) { 178 | this.init_.fit(X, Y, sample_weight); 179 | double[] y_pred = this.init_.predict(X); 180 | this.baseValue = y_pred[0]; 181 | _fit_stages(X, Y, y_pred, sample_weight, infoList); 182 | } 183 | 184 | public double[] predict(double[][] X) { 185 | double[] ans = new double[X.length]; 186 | for (int i = 0; i < X.length; i++) { 187 | ans[i] = baseValue; 188 | for (int j = 0; j < n_estimator; j++) { 189 | Node leaf = trees[j].apply(X[i]); 190 | ans[i] += learning_rate * leaf.treeVal; 191 | } 192 | } 193 | return ans; 194 | } 195 | 196 | public double[] predict(double[][] X, int n) { 197 | double[] ans = new double[X.length]; 198 | for (int i = 0; i < X.length; i++) { 199 | ans[i] = baseValue; 200 | for (int j = 0; j < n; j++) { 201 | Node leaf = trees[j].apply(X[i]); 202 | ans[i] += learning_rate * leaf.treeVal; 203 | } 204 | } 205 | return ans; 206 | } 207 | } -------------------------------------------------------------------------------- /src/data/LabeledSample.java: -------------------------------------------------------------------------------- 1 | package data; 2 | 3 | public class LabeledSample { 4 | public double[] x; 5 | public double y; 6 | public double weight; 7 | public boolean isSampled; // 样本采样是否采到 8 | public boolean isSplitToLeft; // 是否划分到左半边的树 9 | } 10 | -------------------------------------------------------------------------------- /src/decision_tree/.DS_Store: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/xing89qs/GBRT-RF-Java/d0d09f09c52a09baceb26fc011503b092ca947a4/src/decision_tree/.DS_Store -------------------------------------------------------------------------------- /src/decision_tree/DecisionRegressionTree.java: -------------------------------------------------------------------------------- 1 | package decision_tree; 2 | 3 | import java.util.ArrayList; 4 | import java.util.Random; 5 | import java.util.Stack; 6 | 7 | import data.LabeledSample; 8 | 9 | public class DecisionRegressionTree extends DecisionTree { 10 | 11 | // 树的参数 12 | private int max_depth; 13 | private int random_state; 14 | private Random random; 15 | private Node root; 16 | private double sample_rate; 17 | private double feature_rate; 18 | private int estimator_num; 19 | private LabeledSample[][] sampleSortedByFeatureArrays; 20 | private double min_leaf_sample; 21 | 22 | public DecisionRegressionTree(int estimator_num, int max_depth, int random_state, double sample_rate, 23 | double feature_rate, double min_leaf_sample, LabeledSample[][] preSortedSampleArrays) { 24 | this.estimator_num = estimator_num; 25 | this.max_depth = max_depth; 26 | this.random_state = random_state; 27 | System.out.println(this.random_state); 28 | this.random = new Random(this.random_state); 29 | this.sample_rate = sample_rate; 30 | this.feature_rate = feature_rate; 31 | this.min_leaf_sample = min_leaf_sample; 32 | this.sampleSortedByFeatureArrays = preSortedSampleArrays; 33 | } 34 | 35 | // public DecisionRegressionTree(int estimator_num, TreeInfo[] infoList) { 36 | // this.estimator_num = estimator_num; 37 | // Node[] nodeList = new Node[infoList.length]; 38 | // for (int i = 0; i < infoList.length; i++) { 39 | // if (infoList[i] == null) 40 | // continue; 41 | // Node node = new Node(); 42 | // nodeList[(int) infoList[i].root_id] = node; 43 | // if ((int) infoList[i].left_son != -1) 44 | // node.leftNode = nodeList[(int) infoList[i].left_son]; 45 | // if ((int) 
infoList[i].right_son != -1) 46 | // node.rightNode = nodeList[(int) infoList[i].right_son]; 47 | // node.split_feature = (int) infoList[i].split_feature; 48 | // node.split_val = infoList[i].split_feature_value; 49 | // node.treeVal = infoList[i].node_value; 50 | // if (infoList[i].is_root == 1L) 51 | // this.root = node; 52 | // } 53 | // } 54 | 55 | public DecisionRegressionTree() { 56 | // TODO Auto-generated constructor stub 57 | } 58 | 59 | class SplitResult { 60 | Stack leftSample, rightSample; 61 | double split_error, left_error, right_error; 62 | int best_feature; 63 | double best_split_val; 64 | 65 | // error = var_left + var_right 66 | // var_left = sigma((y_i-y_bar)^2*w) 67 | // = sigma(y_i*y_i*w_i)- y_bar*sigma(2*w_i*y_i)+ y_bar*y_bar*sigma(w_i) 68 | // y_bar = sigma(y_i*w_i)/sigma(w_i) 69 | 70 | double left_yyw_sum, left_y_w_sum, left_w_sum, left_bar; 71 | double right_yyw_sum, right_y_w_sum, right_w_sum, right_bar; 72 | 73 | public SplitResult(int l, int r, LabeledSample[] samples) { 74 | leftSample = new Stack(); 75 | rightSample = new Stack(); 76 | init(l, r, samples); 77 | } 78 | 79 | void init(int l, int r, LabeledSample[] samples) { 80 | this.split_error = Double.MAX_VALUE; 81 | leftSample.clear(); 82 | rightSample.clear(); 83 | 84 | left_yyw_sum = left_y_w_sum = left_w_sum = left_bar = 0; 85 | right_yyw_sum = right_y_w_sum = right_w_sum = right_bar = 0; 86 | for (int i = r; i >= l; i--) { 87 | if (!samples[i].isSampled) 88 | continue; 89 | rightSample.push(samples[i]); 90 | right_yyw_sum += samples[i].y * samples[i].y * samples[i].weight; 91 | right_y_w_sum += samples[i].y * samples[i].weight; 92 | right_w_sum += samples[i].weight; 93 | } 94 | right_bar = right_y_w_sum / right_w_sum; 95 | left_error = left_yyw_sum - 2 * left_y_w_sum * left_bar + left_bar * left_bar * left_w_sum; 96 | right_error = right_yyw_sum - 2 * right_y_w_sum * right_bar + right_bar * right_bar * right_w_sum; 97 | split_error = left_error + right_error; 98 | } 99 | 100 | 
void moveSampleToLeft(int split_feature, double split_val) { 101 | while (!rightSample.empty()) { 102 | LabeledSample sample = rightSample.peek(); 103 | if (sample.x[split_feature] < split_val) { 104 | rightSample.pop(); 105 | right_yyw_sum -= sample.y * sample.y * sample.weight; 106 | right_y_w_sum -= sample.y * sample.weight; 107 | right_w_sum -= sample.weight; 108 | 109 | leftSample.push(sample); 110 | left_yyw_sum += sample.y * sample.y * sample.weight; 111 | left_y_w_sum += sample.y * sample.weight; 112 | left_w_sum += sample.weight; 113 | } else 114 | break; 115 | } 116 | left_bar = left_y_w_sum / left_w_sum; 117 | right_bar = right_y_w_sum / right_w_sum; 118 | left_error = left_yyw_sum - 2 * left_y_w_sum * left_bar + left_bar * left_bar * left_w_sum; 119 | right_error = right_yyw_sum - 2 * right_y_w_sum * right_bar + right_bar * right_bar * right_w_sum; 120 | split_error = left_error + right_error; 121 | } 122 | 123 | Stack getLeftTrainSamples() { 124 | return leftSample; 125 | } 126 | 127 | Stack getRightTrainSamples() { 128 | return rightSample; 129 | } 130 | 131 | void clear() { 132 | this.leftSample.clear(); 133 | this.rightSample.clear(); 134 | this.leftSample = this.rightSample = null; 135 | } 136 | 137 | } 138 | 139 | private static double[] splitValueArray = new double[400005]; 140 | private static LabeledSample[] samples = new LabeledSample[400005]; 141 | 142 | private SplitResult getBestSplit(int l, int r) { 143 | 144 | int featureNum = this.sampleSortedByFeatureArrays[0][0].x.length; 145 | double min_split_error = Double.MAX_VALUE; 146 | int best_feature = -1; 147 | double best_split_val = Double.MAX_VALUE; 148 | 149 | for (int i = l; i <= r; i++) { 150 | if (random.nextDouble() > sample_rate) 151 | this.sampleSortedByFeatureArrays[0][i].isSampled = false; 152 | else 153 | this.sampleSortedByFeatureArrays[0][i].isSampled = true; 154 | } 155 | SplitResult currentResult = new SplitResult(l, r, sampleSortedByFeatureArrays[0]); 156 | for (int i = 0; i 
< featureNum; i++) { 157 | if (random.nextDouble() > feature_rate) 158 | continue; 159 | int cnt = 0; 160 | for (int j = l; j < r; j++) { 161 | splitValueArray[cnt++] = (sampleSortedByFeatureArrays[i][j].x[i] 162 | + sampleSortedByFeatureArrays[i][j + 1].x[i]) / 2.0; 163 | } 164 | if (cnt == 0) 165 | continue; 166 | currentResult.init(l, r, sampleSortedByFeatureArrays[i]); 167 | for (int j = 0; j < cnt; j++) { 168 | double split_val = splitValueArray[j]; 169 | currentResult.moveSampleToLeft(i, split_val); 170 | if (currentResult.split_error < min_split_error 171 | && currentResult.getLeftTrainSamples().size() >= min_leaf_sample 172 | && currentResult.getRightTrainSamples().size() >= min_leaf_sample) { 173 | min_split_error = currentResult.split_error; 174 | best_feature = i; 175 | best_split_val = split_val; 176 | } 177 | } 178 | } 179 | if (best_feature == -1) 180 | return null; 181 | for (int i = l; i <= r; i++) 182 | sampleSortedByFeatureArrays[best_feature][i].isSampled = true; 183 | SplitResult result = new SplitResult(l, r, sampleSortedByFeatureArrays[best_feature]); 184 | result.moveSampleToLeft(best_feature, best_split_val); 185 | result.best_feature = best_feature; 186 | result.best_split_val = best_split_val; 187 | return result; 188 | 189 | } 190 | 191 | private Node createTree(int l, int r, int depth) { 192 | if (depth > max_depth) 193 | return null; 194 | Node root = new Node(); 195 | if (l >= r) 196 | return root; 197 | SplitResult result = getBestSplit(l, r); 198 | if (result == null) 199 | return root; 200 | root.split_feature = result.best_feature; 201 | root.split_val = result.best_split_val; 202 | Stack leftSamples = result.getLeftTrainSamples(); 203 | Stack rightSamples = result.getRightTrainSamples(); 204 | for (LabeledSample sample : leftSamples) 205 | sample.isSplitToLeft = true; 206 | for (LabeledSample sample : rightSamples) 207 | sample.isSplitToLeft = false; 208 | int leftSize = leftSamples.size(); 209 | int n_features = 
sampleSortedByFeatureArrays[0][0].x.length; 210 | for (int i = 0; i < n_features; i++) { 211 | int tot = l; 212 | for (int j = l; j <= r; j++) { 213 | if (sampleSortedByFeatureArrays[i][j].isSplitToLeft) { 214 | samples[tot++] = sampleSortedByFeatureArrays[i][j]; 215 | } 216 | } 217 | for (int j = l; j <= r; j++) { 218 | if (!sampleSortedByFeatureArrays[i][j].isSplitToLeft) 219 | samples[tot++] = sampleSortedByFeatureArrays[i][j]; 220 | } 221 | for (int j = l; j <= r; j++) 222 | sampleSortedByFeatureArrays[i][j] = samples[j]; 223 | 224 | } 225 | result.clear(); 226 | root.leftNode = createTree(l, l + leftSize - 1, depth + 1); 227 | root.rightNode = createTree(l + leftSize, r, depth + 1); 228 | return root; 229 | } 230 | 231 | public void fit(double[][] X, double[] Y, double[] sample_weight) { 232 | random_state = random.nextInt(); 233 | this.root = createTree(0, X.length - 1, 0); 234 | } 235 | 236 | private Node dfs(Node root, double[] x, int depth) { 237 | if (root.leftNode == null) 238 | return root; 239 | if (x[root.split_feature] < root.split_val) 240 | return dfs(root.leftNode, x, depth + 1); 241 | return dfs(root.rightNode, x, depth + 1); 242 | } 243 | 244 | public Node apply(double[] x) { 245 | return dfs(root, x, 0); 246 | } 247 | 248 | // private int nodeCount = 0; 249 | 250 | // private int dfsNode(Node node, ArrayList infoList, int is_root) 251 | // { 252 | // if (node == null) 253 | // return -1; 254 | // int left_son = dfsNode(node.leftNode, infoList, 0); 255 | // int right_son = dfsNode(node.rightNode, infoList, 0); 256 | // int id = nodeCount++; 257 | // TreeInfo info = new TreeInfo(id, left_son, right_son, node.split_feature, 258 | // node.split_val, estimator_num, 259 | // is_root, node.treeVal); 260 | // infoList.add(info); 261 | // return id; 262 | // } 263 | // 264 | // public ArrayList getTreeInfo() { 265 | // nodeCount = 0; 266 | // ArrayList infoList = new ArrayList(); 267 | // dfsNode(root, infoList, 1); 268 | // return infoList; 269 | // } 
270 | } -------------------------------------------------------------------------------- /src/decision_tree/DecisionRegressionTreeLeafSpliter.java: -------------------------------------------------------------------------------- 1 | package decision_tree; 2 | 3 | import java.util.Comparator; 4 | import java.util.Queue; 5 | import java.util.Random; 6 | import java.util.Stack; 7 | import java.util.concurrent.PriorityBlockingQueue; 8 | 9 | import data.LabeledSample; 10 | 11 | public class DecisionRegressionTreeLeafSpliter extends DecisionTree { 12 | 13 | // 树的参数 14 | private int max_depth; 15 | private int random_state; 16 | private Random random; 17 | private Node root; 18 | private double sample_rate; 19 | private double feature_rate; 20 | private int num_leafs; 21 | private LabeledSample[][] sampleSortedByFeatureArrays; 22 | private double min_leaf_sample; 23 | 24 | public DecisionRegressionTreeLeafSpliter(int num_leafs, int random_state, double sample_rate, double feature_rate, 25 | double min_leaf_sample, LabeledSample[][] preSortedSampleArrays) { 26 | this.num_leafs = num_leafs; 27 | this.random_state = random_state; 28 | this.random = new Random(this.random_state); 29 | this.sample_rate = sample_rate; 30 | this.feature_rate = feature_rate; 31 | this.min_leaf_sample = min_leaf_sample; 32 | this.sampleSortedByFeatureArrays = preSortedSampleArrays; 33 | } 34 | 35 | // public DecisionRegressionTree(int estimator_num, TreeInfo[] infoList) { 36 | // this.estimator_num = estimator_num; 37 | // Node[] nodeList = new Node[infoList.length]; 38 | // for (int i = 0; i < infoList.length; i++) { 39 | // if (infoList[i] == null) 40 | // continue; 41 | // Node node = new Node(); 42 | // nodeList[(int) infoList[i].root_id] = node; 43 | // if ((int) infoList[i].left_son != -1) 44 | // node.leftNode = nodeList[(int) infoList[i].left_son]; 45 | // if ((int) infoList[i].right_son != -1) 46 | // node.rightNode = nodeList[(int) infoList[i].right_son]; 47 | // node.split_feature = 
(int) infoList[i].split_feature; 48 | // node.split_val = infoList[i].split_feature_value; 49 | // node.treeVal = infoList[i].node_value; 50 | // if (infoList[i].is_root == 1L) 51 | // this.root = node; 52 | // } 53 | // } 54 | 55 | class SplitResult { 56 | Stack leftSample, rightSample; 57 | double split_error, left_error, right_error; 58 | int best_feature; 59 | double best_split_val; 60 | 61 | // error = var_left + var_right 62 | // var_left = sigma((y_i-y_bar)^2*w) 63 | // = sigma(y_i*y_i*w_i)- y_bar*sigma(2*w_i*y_i)+ y_bar*y_bar*sigma(w_i) 64 | // y_bar = sigma(y_i*w_i)/sigma(w_i) 65 | 66 | double left_yyw_sum, left_y_w_sum, left_w_sum, left_bar; 67 | double right_yyw_sum, right_y_w_sum, right_w_sum, right_bar; 68 | 69 | public SplitResult(int l, int r, LabeledSample[] samples) { 70 | leftSample = new Stack(); 71 | rightSample = new Stack(); 72 | init(l, r, samples); 73 | } 74 | 75 | void init(int l, int r, LabeledSample[] samples) { 76 | this.split_error = Double.MAX_VALUE; 77 | leftSample.clear(); 78 | rightSample.clear(); 79 | 80 | left_yyw_sum = left_y_w_sum = left_w_sum = left_bar = 0; 81 | right_yyw_sum = right_y_w_sum = right_w_sum = right_bar = 0; 82 | for (int i = r; i >= l; i--) { 83 | if (!samples[i].isSampled) 84 | continue; 85 | rightSample.push(samples[i]); 86 | right_yyw_sum += samples[i].y * samples[i].y * samples[i].weight; 87 | right_y_w_sum += samples[i].y * samples[i].weight; 88 | right_w_sum += samples[i].weight; 89 | } 90 | right_bar = right_y_w_sum / right_w_sum; 91 | left_error = left_yyw_sum - 2 * left_y_w_sum * left_bar + left_bar * left_bar * left_w_sum; 92 | right_error = right_yyw_sum - 2 * right_y_w_sum * right_bar + right_bar * right_bar * right_w_sum; 93 | split_error = left_error + right_error; 94 | } 95 | 96 | void moveSampleToLeft(int split_feature, double split_val) { 97 | while (!rightSample.empty()) { 98 | LabeledSample sample = rightSample.peek(); 99 | if (sample.x[split_feature] < split_val) { 100 | rightSample.pop(); 
101 | right_yyw_sum -= sample.y * sample.y * sample.weight; 102 | right_y_w_sum -= sample.y * sample.weight; 103 | right_w_sum -= sample.weight; 104 | 105 | leftSample.push(sample); 106 | left_yyw_sum += sample.y * sample.y * sample.weight; 107 | left_y_w_sum += sample.y * sample.weight; 108 | left_w_sum += sample.weight; 109 | } else 110 | break; 111 | } 112 | left_bar = left_y_w_sum / left_w_sum; 113 | right_bar = right_y_w_sum / right_w_sum; 114 | left_error = left_yyw_sum - 2 * left_y_w_sum * left_bar + left_bar * left_bar * left_w_sum; 115 | right_error = right_yyw_sum - 2 * right_y_w_sum * right_bar + right_bar * right_bar * right_w_sum; 116 | split_error = left_error + right_error; 117 | } 118 | 119 | Stack getLeftTrainSamples() { 120 | return leftSample; 121 | } 122 | 123 | Stack getRightTrainSamples() { 124 | return rightSample; 125 | } 126 | 127 | void clear() { 128 | this.leftSample.clear(); 129 | this.rightSample.clear(); 130 | this.leftSample = this.rightSample = null; 131 | } 132 | 133 | } 134 | 135 | private static double[] splitValueArray = new double[400005]; 136 | private static LabeledSample[] samples = new LabeledSample[400005]; 137 | 138 | private SplitResult getBestSplit(int l, int r) { 139 | 140 | int featureNum = this.sampleSortedByFeatureArrays[0][0].x.length; 141 | double min_split_error = Double.MAX_VALUE; 142 | int best_feature = -1; 143 | double best_split_val = Double.MAX_VALUE; 144 | 145 | for (int i = l; i <= r; i++) { 146 | if (random.nextDouble() > sample_rate) 147 | this.sampleSortedByFeatureArrays[0][i].isSampled = false; 148 | else 149 | this.sampleSortedByFeatureArrays[0][i].isSampled = true; 150 | } 151 | SplitResult currentResult = new SplitResult(l, r, sampleSortedByFeatureArrays[0]); 152 | for (int i = 0; i < featureNum; i++) { 153 | if (random.nextDouble() > feature_rate) 154 | continue; 155 | int cnt = 0; 156 | for (int j = l; j < r; j++) { 157 | splitValueArray[cnt++] = (sampleSortedByFeatureArrays[i][j].x[i] 158 | + 
sampleSortedByFeatureArrays[i][j + 1].x[i]) / 2.0; 159 | } 160 | if (cnt == 0) 161 | continue; 162 | currentResult.init(l, r, sampleSortedByFeatureArrays[i]); 163 | for (int j = 0; j < cnt; j++) { 164 | double split_val = splitValueArray[j]; 165 | currentResult.moveSampleToLeft(i, split_val); 166 | if (currentResult.split_error < min_split_error 167 | && currentResult.getLeftTrainSamples().size() >= min_leaf_sample 168 | && currentResult.getRightTrainSamples().size() >= min_leaf_sample) { 169 | min_split_error = currentResult.split_error; 170 | best_feature = i; 171 | best_split_val = split_val; 172 | } 173 | } 174 | } 175 | if (best_feature == -1) 176 | return null; 177 | for (int i = l; i <= r; i++) 178 | sampleSortedByFeatureArrays[best_feature][i].isSampled = true; 179 | SplitResult result = new SplitResult(l, r, sampleSortedByFeatureArrays[best_feature]); 180 | result.moveSampleToLeft(best_feature, best_split_val); 181 | result.best_feature = best_feature; 182 | result.best_split_val = best_split_val; 183 | return result; 184 | 185 | } 186 | 187 | private QueueNode createTree(int l, int r) { 188 | Node root = new Node(); 189 | QueueNode qNode = new QueueNode(); 190 | qNode.node = root; 191 | qNode.l = l; 192 | qNode.r = r; 193 | qNode.gain = 0; 194 | if (l >= r) 195 | return qNode; 196 | SplitResult result = getBestSplit(l, r); 197 | if (result == null) { 198 | return qNode; 199 | } 200 | qNode.splitResult = result; 201 | root.split_feature = result.best_feature; 202 | root.split_val = result.best_split_val; 203 | qNode.gain = calCurrentLoss(l, r) - result.split_error; 204 | return qNode; 205 | } 206 | 207 | private int doSplit(QueueNode node, int l, int r) { 208 | SplitResult result = node.splitResult; 209 | if (result == null) 210 | return 0; 211 | Stack leftSamples = result.getLeftTrainSamples(); 212 | Stack rightSamples = result.getRightTrainSamples(); 213 | 214 | for (LabeledSample sample : leftSamples) 215 | sample.isSplitToLeft = true; 216 | for 
(LabeledSample sample : rightSamples) 217 | sample.isSplitToLeft = false; 218 | int leftSize = leftSamples.size(); 219 | int n_features = sampleSortedByFeatureArrays[0][0].x.length; 220 | for (int i = 0; i < n_features; i++) { 221 | int tot = l; 222 | for (int j = l; j <= r; j++) { 223 | if (sampleSortedByFeatureArrays[i][j].isSplitToLeft) { 224 | samples[tot++] = sampleSortedByFeatureArrays[i][j]; 225 | } 226 | } 227 | for (int j = l; j <= r; j++) { 228 | if (!sampleSortedByFeatureArrays[i][j].isSplitToLeft) 229 | samples[tot++] = sampleSortedByFeatureArrays[i][j]; 230 | } 231 | for (int j = l; j <= r; j++) 232 | sampleSortedByFeatureArrays[i][j] = samples[j]; 233 | } 234 | result.clear(); 235 | return leftSize; 236 | } 237 | 238 | private double calCurrentLoss(int l, int r) { 239 | double sum = 0; 240 | double sumWeight = 0; 241 | for (int i = l; i <= r; i++) { 242 | sum += sampleSortedByFeatureArrays[0][i].y; 243 | sumWeight += sampleSortedByFeatureArrays[0][i].weight; 244 | } 245 | double avg = sum / sumWeight; 246 | double error = 0; 247 | for (int i = l; i <= r; i++) { 248 | error += (sampleSortedByFeatureArrays[0][i].y - avg) * (sampleSortedByFeatureArrays[0][i].y - avg) 249 | * sampleSortedByFeatureArrays[0][i].weight; 250 | } 251 | return error; 252 | } 253 | 254 | public static class QueueNode { 255 | public SplitResult splitResult; 256 | Node node; 257 | double gain; 258 | int l, r; 259 | } 260 | 261 | private Comparator cmp = new Comparator() { 262 | public int compare(QueueNode e1, QueueNode e2) { 263 | return Double.valueOf(e2.gain).compareTo(e1.gain); 264 | } 265 | }; 266 | 267 | public void fit(double[][] X, double[] Y, double[] sample_weight) { 268 | random_state = random.nextInt(); 269 | Queue queue = new PriorityBlockingQueue<>(this.num_leafs * 2, cmp); 270 | QueueNode qNode = createTree(0, X.length - 1); 271 | this.root = qNode.node; 272 | queue.add(qNode); 273 | int currentLeaf = 1; 274 | while (!queue.isEmpty()) { 275 | qNode = queue.poll(); 
276 | System.out.println(qNode.gain); 277 | if (qNode.l == qNode.r) 278 | continue; 279 | int leftSize = doSplit(qNode, qNode.l, qNode.r); 280 | QueueNode qNodeLeft = createTree(qNode.l, qNode.l + leftSize - 1); 281 | QueueNode qNodeRight = createTree(qNode.l + leftSize, qNode.r); 282 | queue.add(qNodeLeft); 283 | queue.add(qNodeRight); 284 | qNode.node.leftNode = qNodeLeft.node; 285 | qNode.node.rightNode = qNodeRight.node; 286 | currentLeaf += 1; 287 | if (currentLeaf == this.num_leafs) 288 | break; 289 | } 290 | } 291 | 292 | private Node dfs(Node root, double[] x, int depth) { 293 | if (root.leftNode == null) 294 | return root; 295 | if (x[root.split_feature] < root.split_val) 296 | return dfs(root.leftNode, x, depth + 1); 297 | return dfs(root.rightNode, x, depth + 1); 298 | } 299 | 300 | public Node apply(double[] x) { 301 | return dfs(root, x, 0); 302 | } 303 | 304 | // private int nodeCount = 0; 305 | 306 | // private int dfsNode(Node node, ArrayList infoList, int is_root) 307 | // { 308 | // if (node == null) 309 | // return -1; 310 | // int left_son = dfsNode(node.leftNode, infoList, 0); 311 | // int right_son = dfsNode(node.rightNode, infoList, 0); 312 | // int id = nodeCount++; 313 | // TreeInfo info = new TreeInfo(id, left_son, right_son, node.split_feature, 314 | // node.split_val, estimator_num, 315 | // is_root, node.treeVal); 316 | // infoList.add(info); 317 | // return id; 318 | // } 319 | // 320 | // public ArrayList getTreeInfo() { 321 | // nodeCount = 0; 322 | // ArrayList infoList = new ArrayList(); 323 | // dfsNode(root, infoList, 1); 324 | // return infoList; 325 | // } 326 | } -------------------------------------------------------------------------------- /src/decision_tree/DecisionTree.java: -------------------------------------------------------------------------------- 1 | package decision_tree; 2 | 3 | public abstract class DecisionTree { 4 | public abstract void fit(double[][] X, double[] Y, double[] sample_weight); 5 | 6 | public 
abstract Node apply(double[] x);
}
--------------------------------------------------------------------------------
/src/decision_tree/Node.java:
--------------------------------------------------------------------------------
package decision_tree;

import java.util.ArrayList;

/**
 * Tree node. Internal nodes route on split_feature/split_val; a node with
 * leftNode == null acts as a leaf and carries the prediction treeVal.
 */
public class Node {
    double split_val;
    int split_feature;
    public Node leftNode;
    public Node rightNode;
    public double treeVal;

    // 落在当前叶子节点的样本的统计信息
    // (residuals and weights of the samples routed to this leaf; consumed by
    // the loss function to compute treeVal, then released via clear()).
    public ArrayList<Double> diff = new ArrayList<Double>();
    public ArrayList<Double> diff_sample_weight = new ArrayList<Double>();

    public void setTreeValue(double value) {
        this.treeVal = value;
    }

    /** Releases the per-leaf statistics once treeVal has been computed. */
    public void clear() {
        this.diff.clear();
        this.diff.trimToSize();
        this.diff_sample_weight.clear();
        this.diff_sample_weight.trimToSize();
        this.diff = null;
        this.diff_sample_weight = null;
    }
}
--------------------------------------------------------------------------------
/src/objective/Estimator.java:
--------------------------------------------------------------------------------
package objective;

import objective.Utils.Item;

/** Base-line model producing the initial constant prediction of an ensemble. */
public abstract class Estimator {
    public abstract void fit(double[][] X, double[] y, double[] sample_weight);

    public abstract double[] predict(double[][] X);
}
--------------------------------------------------------------------------------
/src/objective/LossFunction.java:
--------------------------------------------------------------------------------
package objective;

import decision_tree.DecisionTree;

/**
 * Loss abstraction: pseudo-residual computation for boosting plus refitting
 * of the fitted tree's leaf values.
 */
public abstract class LossFunction {
    public abstract double[] negative_gradient(double[] y_true, double[] y_pred, double[] residual);

    public abstract void update_terminal_region(DecisionTree tree, double[][] sample, double[] label, double[] y_pred,
            double[] sample_weight);
}
--------------------------------------------------------------------------------
/src/objective/QuantileEstimator.java: -------------------------------------------------------------------------------- 1 | package objective; 2 | 3 | public class QuantileEstimator extends Estimator { 4 | private double alpha; 5 | private double quantile; 6 | 7 | public QuantileEstimator(double alpha) { 8 | this.alpha = alpha; 9 | } 10 | 11 | public void fit(double[][] X, double[] y, double[] sample_weight) { 12 | this.quantile = Utils.weighted_percentile(y, sample_weight, this.alpha * 100.0); 13 | 14 | } 15 | 16 | public double[] predict(double[][] X) { 17 | int size = X.length; 18 | double[] ans = new double[size]; 19 | for (int i = 0; i < size; i++) 20 | ans[i] = this.quantile; 21 | return ans; 22 | } 23 | } -------------------------------------------------------------------------------- /src/objective/QuantileLossFunction.java: -------------------------------------------------------------------------------- 1 | package objective; 2 | 3 | import java.util.HashSet; 4 | 5 | import decision_tree.DecisionRegressionTree; 6 | import decision_tree.DecisionTree; 7 | import decision_tree.Node; 8 | 9 | public class QuantileLossFunction extends LossFunction { 10 | 11 | private double alpha; 12 | private double percentile; 13 | 14 | public QuantileLossFunction(double alpha) { 15 | this.alpha = alpha; 16 | this.percentile = alpha * 100.0; 17 | } 18 | 19 | public double[] negative_gradient(double[] y_true, double[] y_pred, double[] residual) { 20 | int size = y_true.length; 21 | for (int i = 0; i < size; i++) { 22 | if (y_true[i] > y_pred[i]) 23 | residual[i] = alpha; 24 | else 25 | residual[i] = -1.0 + alpha; 26 | } 27 | return residual; 28 | } 29 | 30 | @Override 31 | public void update_terminal_region(DecisionTree tree, double[][] x, double[] y, double[] y_pred, 32 | double[] sample_weight) { 33 | int n_samples = x.length; 34 | HashSet nodeSet = new HashSet(); 35 | for (int i = 0; i < n_samples; i++) { 36 | Node leaf = tree.apply(x[i]); 37 | leaf.diff.add(y[i] - 
y_pred[i]); 38 | leaf.diff_sample_weight.add(sample_weight[i]); 39 | nodeSet.add(leaf); 40 | } 41 | for (Node leaf : nodeSet) { 42 | leaf.setTreeValue(Utils.weighted_percentile(leaf.diff, leaf.diff_sample_weight, this.percentile)); 43 | // System.out.println(leaf+" "+leaf.treeVal); 44 | leaf.clear(); 45 | } 46 | } 47 | } -------------------------------------------------------------------------------- /src/objective/SquaresEstimator.java: -------------------------------------------------------------------------------- 1 | package objective; 2 | 3 | import java.util.Arrays; 4 | 5 | public class SquaresEstimator extends Estimator { 6 | 7 | private double mean; 8 | 9 | @Override 10 | public void fit(double[][] X, double[] y, double[] sample_weight) { 11 | // TODO Auto-generated method stub 12 | this.mean = Utils.averavge(y, sample_weight); 13 | } 14 | 15 | @Override 16 | public double[] predict(double[][] X) { 17 | double[] ans = new double[X.length]; 18 | Arrays.fill(ans, mean); 19 | return ans; 20 | } 21 | 22 | } 23 | -------------------------------------------------------------------------------- /src/objective/SquaresLossFunction.java: -------------------------------------------------------------------------------- 1 | package objective; 2 | 3 | import java.util.HashSet; 4 | 5 | import decision_tree.DecisionRegressionTree; 6 | import decision_tree.DecisionTree; 7 | import decision_tree.Node; 8 | 9 | public class SquaresLossFunction extends LossFunction { 10 | 11 | public SquaresLossFunction() { 12 | super(); 13 | // TODO Auto-generated constructor stub 14 | } 15 | 16 | @Override 17 | public double[] negative_gradient(double[] y_true, double[] y_pred, double[] residual) { 18 | // TODO Auto-generated method stub 19 | for (int i = 0; i < y_true.length; i++) 20 | residual[i] = y_true[i] - y_pred[i]; 21 | return residual; 22 | } 23 | 24 | @Override 25 | public void update_terminal_region(DecisionTree tree, double[][] x, double[] y, double[] y_pred, 26 | double[] 
sample_weight) { 27 | // TODO Auto-generated method stub 28 | int n_samples = x.length; 29 | HashSet nodeSet = new HashSet(); 30 | for (int i = 0; i < n_samples; i++) { 31 | Node leaf = tree.apply(x[i]); 32 | leaf.diff.add(y[i] - y_pred[i]); 33 | leaf.diff_sample_weight.add(sample_weight[i]); 34 | nodeSet.add(leaf); 35 | } 36 | for (Node leaf : nodeSet) { 37 | leaf.setTreeValue(Utils.averavge(leaf.diff, leaf.diff_sample_weight)); 38 | // System.out.println(leaf+" "+leaf.treeVal); 39 | leaf.clear(); 40 | } 41 | } 42 | 43 | } 44 | -------------------------------------------------------------------------------- /src/objective/Utils.java: -------------------------------------------------------------------------------- 1 | package objective; 2 | 3 | import java.util.ArrayList; 4 | import java.util.Arrays; 5 | import java.util.Collections; 6 | import java.util.Comparator; 7 | import java.util.LinkedList; 8 | import java.util.Random; 9 | 10 | public class Utils { 11 | 12 | public static class Item implements Comparable { 13 | double y; 14 | double sample_weight; 15 | 16 | public Item(double y, double sample_weight) { 17 | this.y = y; 18 | this.sample_weight = sample_weight; 19 | } 20 | 21 | @Override 22 | public int compareTo(Item o) { 23 | return new Double(y).compareTo(o.y); 24 | } 25 | } 26 | 27 | public static double weighted_percentile(ArrayList y, ArrayList sample_weight, double percentile) { 28 | Item[] items = new Item[y.size()]; 29 | for (int i = 0; i < y.size(); i++) 30 | items[i] = new Item(y.get(i), sample_weight.get(i)); 31 | Arrays.sort(items); 32 | double[] weight_sum = new double[y.size()]; 33 | for (int i = 0; i < y.size(); i++) { 34 | weight_sum[i] = items[i].sample_weight; 35 | if (i > 0) 36 | weight_sum[i] += weight_sum[i - 1]; 37 | } 38 | double target = weight_sum[y.size() - 1] * percentile / 100.0; 39 | for (int i = 1; i < y.size(); i++) 40 | if (weight_sum[i] > target) { 41 | // System.out.println(y.size() + " " + weight_sum[i] + " " + 42 | // 
y.get(i - 1) + " " + target); 43 | return items[i - 1].y; 44 | } 45 | return items[y.size() - 1].y; 46 | } 47 | 48 | public static double weighted_percentile(double[] y, double[] sample_weight, double percentile) { 49 | Item[] items = new Item[y.length]; 50 | for (int i = 0; i < y.length; i++) { 51 | items[i] = new Item(y[i], sample_weight[i]); 52 | } 53 | Arrays.sort(items); 54 | double[] weight_sum = new double[y.length]; 55 | for (int i = 0; i < y.length; i++) { 56 | weight_sum[i] = items[i].sample_weight; 57 | if (i > 0) 58 | weight_sum[i] += weight_sum[i - 1]; 59 | } 60 | double target = weight_sum[y.length - 1] * percentile / 100.0; 61 | for (int i = 1; i < y.length; i++) { 62 | if (weight_sum[i] > target) { 63 | // System.out.println(y.length + " " + weight_sum[i] + " " + y[i 64 | // - 1] + " " + target); 65 | return items[i - 1].y; 66 | } 67 | } 68 | return items[y.length - 1].y; 69 | } 70 | 71 | public static void sample(int n, int m, int[] ret) { 72 | LinkedList list = new LinkedList(); 73 | for (int i = 0; i < n; i++) 74 | list.add(i); 75 | Collections.shuffle(list); 76 | for (int i = 0; i < m; i++) 77 | ret[i] = list.pop(); 78 | } 79 | 80 | public static double averavge(ArrayList diff, ArrayList diff_sample_weight) { 81 | // TODO Auto-generated method stub 82 | double sum = 0, sumWeight = 0; 83 | for (int i = 0; i < diff.size(); i++) { 84 | sum += diff.get(i) * diff_sample_weight.get(i); 85 | sumWeight += diff_sample_weight.get(i); 86 | } 87 | return sum / sumWeight; 88 | } 89 | 90 | public static double averavge(double[] y, double[] sample_weight) { 91 | double sum = 0, sumWeight = 0; 92 | for (int i = 0; i < y.length; i++) { 93 | sum += y[i] * sample_weight[i]; 94 | sumWeight += sample_weight[i]; 95 | } 96 | return sum / sumWeight; 97 | } 98 | } -------------------------------------------------------------------------------- /src/rf/RandomForestRegressor.java: -------------------------------------------------------------------------------- 1 | package 
rf;

import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;

import data.LabeledSample;
import decision_tree.DecisionRegressionTree;
import decision_tree.DecisionRegressionTreeLeafSpliter;
import decision_tree.DecisionTree;
import decision_tree.Node;
import objective.LossFunction;
import objective.SquaresLossFunction;
import util.ParamException;
import util.ParamReader;

/**
 * Random forest of regression trees. Each tree is trained on the full data
 * with row/column subsampling handled inside the tree's split search; leaf
 * values are refit to the weighted mean via the squared-loss update.
 */
public class RandomForestRegressor {

    private String spliter;          // "leaf" => best-first trees, else depth-wise
    private int num_leafs;
    private int max_depth;
    private int n_estimator;
    private int random_state;
    private double sample_rate;
    private double feature_rate;
    private double min_leaf_sample;
    private DecisionTree[] tree;
    private LabeledSample[] trainData;
    private LabeledSample[][] preSortedSampleArrays;      // working copy, mutated per tree
    private LabeledSample[][] copyOfPreSortedSampleArrays; // pristine sorted order
    private LossFunction loss;

    /**
     * @param params required: "spliter", "n_estimator", and "num_leafs" (leaf
     *               mode) or "max_depth" (depth mode); optional: "random_state",
     *               "sample_rate", "feature_rate", "min_leaf_sample".
     * @throws ParamException when a required parameter is missing or malformed
     */
    public RandomForestRegressor(HashMap<String, Object> params) throws ParamException {

        // -----------------------------------------------------------------------
        // 必需参数 (required parameters)

        assert (params.containsKey("spliter"));
        this.spliter = ParamReader.readString("spliter", params);
        if (this.spliter.equals("leaf")) {
            assert (params.containsKey("num_leafs"));
            this.num_leafs = ParamReader.readInt("num_leafs", params);
        } else {
            assert (params.containsKey("max_depth"));
            this.max_depth = ParamReader.readInt("max_depth", params);
        }
        assert (params.containsKey("n_estimator"));
        this.n_estimator = ParamReader.readInt("n_estimator", params);

        // -------------------------------------------------------------------------
        // Optional parameters with defaults.
        if (params.containsKey("random_state"))
            this.random_state = ParamReader.readInt("random_state", params);
        else
            this.random_state = 0;
        if (params.containsKey("sample_rate"))
            this.sample_rate = ParamReader.readDouble("sample_rate", params);
        else
            this.sample_rate = 1.0;
        if (params.containsKey("feature_rate"))
            this.feature_rate = ParamReader.readDouble("feature_rate", params);
        else
            this.feature_rate = 1.0;

        if (params.containsKey("min_leaf_sample"))
            this.min_leaf_sample = ParamReader.readDouble("min_leaf_sample", params);
        else
            this.min_leaf_sample = 1;
    }

    /** Trains n_estimator trees on (X, Y) with per-sample weights. */
    public void fit(double[][] X, double[] Y, double[] sample_weight) {

        int featureNum = X[0].length;
        int sampleNum = X.length;
        this.trainData = new LabeledSample[sampleNum];
        for (int i = 0; i < sampleNum; i++) {
            trainData[i] = new LabeledSample();
            trainData[i].x = X[i];
            trainData[i].y = Y[i];
            trainData[i].weight = sample_weight[i];
        }
        this.loss = new SquaresLossFunction();

        // Pre-sort the samples once per feature; trees receive a fresh copy
        // because tree construction permutes the working arrays in place.
        this.preSortedSampleArrays = new LabeledSample[featureNum][sampleNum];
        this.copyOfPreSortedSampleArrays = new LabeledSample[featureNum][sampleNum];
        for (int a = 0; a < featureNum; a++) {
            for (int b = 0; b < sampleNum; b++) {
                copyOfPreSortedSampleArrays[a][b] = trainData[b];
            }

            final int compareFeature = a;
            Arrays.sort(copyOfPreSortedSampleArrays[a], 0, sampleNum, new Comparator<LabeledSample>() {

                @Override
                public int compare(LabeledSample o1, LabeledSample o2) {
                    // FIX: boxing-free comparison (was new Double(..).compareTo(..)).
                    return Double.compare(o1.x[compareFeature], o2.x[compareFeature]);
                }
            });
        }

        this.tree = new DecisionTree[this.n_estimator];
        for (int i = 0; i < n_estimator; i++) {
            // Restore the pristine sorted order before growing each tree.
            for (int a = 0; a < featureNum; a++) {
                for (int b = 0; b < sampleNum; b++) {
                    preSortedSampleArrays[a][b] = copyOfPreSortedSampleArrays[a][b];
                }
            }
            if (this.spliter.equals("leaf"))
                this.tree[i] = new DecisionRegressionTreeLeafSpliter(this.num_leafs, i, sample_rate, feature_rate,
                        min_leaf_sample, preSortedSampleArrays);
            else
                this.tree[i] = new DecisionRegressionTree(i, this.max_depth, i, sample_rate, feature_rate,
                        min_leaf_sample, preSortedSampleArrays);
            System.out.println("第" + i + "棵树");
            this.tree[i].fit(X, Y, sample_weight);
            // With y_pred == 0, the squared-loss update sets each leaf to the
            // weighted mean of the raw targets falling into it.
            double[] pred = new double[X.length];
            this.loss.update_terminal_region(this.tree[i], X, Y, pred, sample_weight);
        }
    }

    /** Returns per-tree predictions: ans[tree][sample]. */
    public double[][] predict(double[][] X) {
        double ans[][] = new double[this.n_estimator][X.length];
        for (int i = 0; i < this.n_estimator; i++) {
            for (int j = 0; j < X.length; j++) {
                ans[i][j] = this.tree[i].apply(X[j]).treeVal;
            }
        }
        return ans;
    }
}
--------------------------------------------------------------------------------
/src/util/BoostingListener.java:
--------------------------------------------------------------------------------
package util;

/** Callback fired after each boosting iteration completes. */
public interface BoostingListener {
    public void done(int treeNum);
}
--------------------------------------------------------------------------------
/src/util/DumpTree.java:
--------------------------------------------------------------------------------
package util;

// Placeholder for tree serialization utilities (not yet implemented).
public class DumpTree {

}
--------------------------------------------------------------------------------
/src/util/ParamException.java:
--------------------------------------------------------------------------------
package util;

/** Thrown when a required model parameter is missing or malformed. */
public class ParamException extends Exception {

    public ParamException(String string) {
        super(string);
    }

}
--------------------------------------------------------------------------------
/src/util/ParamReader.java:
--------------------------------------------------------------------------------
package util;

import java.util.HashMap;

/** Typed accessors for a String->Object parameter map. */
public class ParamReader {
    public static int readInt(String key, HashMap<String, Object> map) throws ParamException {
        if (!map.containsKey(key))
            throw new ParamException("Params doesn't have key " + key);
        return Integer.valueOf(map.get(key).toString());
    }

    public static double readDouble(String key, HashMap<String, Object> map) throws ParamException {
        if (!map.containsKey(key))
            throw new ParamException("Params doesn't have key " + key);
        return Double.valueOf(map.get(key).toString());
    }

    public static String readString(String key, HashMap<String, Object> map) throws ParamException {
        if (!map.containsKey(key))
            throw new ParamException("Params doesn't have key " + key);
        return map.get(key).toString();
    }
}
--------------------------------------------------------------------------------
/src/util/dump/TreeInfo.java:
--------------------------------------------------------------------------------
package util.dump;

// 基于对应的数据表 (mirrors the corresponding database table row)

/** Flat serialized form of one tree node, keyed by root_id. */
public class TreeInfo {
    public long root_id;
    public long left_son;           // child id, -1 when absent
    public long right_son;          // child id, -1 when absent
    public long split_feature;
    public double split_feature_value;
    public long estimator_num;      // which tree of the ensemble this node belongs to
    public long is_root;            // 1 when this node is the tree root
    public double node_value;

    public TreeInfo(long root_id, long left_son, long right_son, long split_feature, double split_feature_value,
            long estimator_num, long is_root, double node_value) {
        this.root_id = root_id;
        this.left_son = left_son;
        this.right_son = right_son;
        this.split_feature = split_feature;
        this.split_feature_value = split_feature_value;
        this.estimator_num = estimator_num;
        this.is_root = is_root;
        this.node_value = node_value;
    }

}
--------------------------------------------------------------------------------