├── .DS_Store
├── .classpath
├── .project
├── .settings
└── org.eclipse.jdt.core.prefs
├── README.md
├── bin
├── .DS_Store
├── Main.class
├── boosting
│ ├── GradientBoostingRegressor$1.class
│ └── GradientBoostingRegressor.class
├── data
│ └── LabeledSample.class
├── decision_tree
│ ├── .DS_Store
│ ├── DecisionRegressionTree$SplitResult.class
│ ├── DecisionRegressionTree.class
│ ├── DecisionRegressionTreeLeafSpliter$1.class
│ ├── DecisionRegressionTreeLeafSpliter$QueueNode.class
│ ├── DecisionRegressionTreeLeafSpliter$SplitResult.class
│ ├── DecisionRegressionTreeLeafSpliter.class
│ ├── DecisionTree.class
│ └── Node.class
├── objective
│ ├── Estimator.class
│ ├── LossFunction.class
│ ├── QuantileEstimator.class
│ ├── QuantileLossFunction.class
│ ├── SquaresEstimator.class
│ ├── SquaresLossFunction.class
│ ├── Utils$Item.class
│ └── Utils.class
├── rf
│ ├── RandomForestRegressor$1.class
│ └── RandomForestRegressor.class
└── util
│ ├── BoostingListener.class
│ ├── DumpTree.class
│ ├── ParamException.class
│ ├── ParamReader.class
│ └── dump
│ └── TreeInfo.class
└── src
├── .DS_Store
├── Main.java
├── boosting
└── GradientBoostingRegressor.java
├── data
└── LabeledSample.java
├── decision_tree
├── .DS_Store
├── DecisionRegressionTree.java
├── DecisionRegressionTreeLeafSpliter.java
├── DecisionTree.java
└── Node.java
├── objective
├── Estimator.java
├── LossFunction.java
├── QuantileEstimator.java
├── QuantileLossFunction.java
├── SquaresEstimator.java
├── SquaresLossFunction.java
└── Utils.java
├── rf
└── RandomForestRegressor.java
└── util
├── BoostingListener.java
├── DumpTree.java
├── ParamException.java
├── ParamReader.java
└── dump
└── TreeInfo.java
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xing89qs/GBRT-RF-Java/d0d09f09c52a09baceb26fc011503b092ca947a4/.DS_Store
--------------------------------------------------------------------------------
/.classpath:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
--------------------------------------------------------------------------------
/.project:
--------------------------------------------------------------------------------
1 |
2 |
3 | GradientBoostingDecisionTree
4 |
5 |
6 |
7 |
8 |
9 | org.eclipse.jdt.core.javabuilder
10 |
11 |
12 |
13 |
14 |
15 | org.eclipse.jdt.core.javanature
16 |
17 |
18 |
--------------------------------------------------------------------------------
/.settings/org.eclipse.jdt.core.prefs:
--------------------------------------------------------------------------------
1 | eclipse.preferences.version=1
2 | org.eclipse.jdt.core.compiler.codegen.inlineJsrBytecode=enabled
3 | org.eclipse.jdt.core.compiler.codegen.targetPlatform=1.7
4 | org.eclipse.jdt.core.compiler.codegen.unusedLocal=preserve
5 | org.eclipse.jdt.core.compiler.compliance=1.7
6 | org.eclipse.jdt.core.compiler.debug.lineNumber=generate
7 | org.eclipse.jdt.core.compiler.debug.localVariable=generate
8 | org.eclipse.jdt.core.compiler.debug.sourceFile=generate
9 | org.eclipse.jdt.core.compiler.problem.assertIdentifier=error
10 | org.eclipse.jdt.core.compiler.problem.enumIdentifier=error
11 | org.eclipse.jdt.core.compiler.source=1.7
12 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ### GBRT&RF Java版
2 |
3 | - Java版的GBRT及RF
4 |
5 | - 模型成绩
6 | - 天池-菜鸟物流 第二赛季 10th
7 | - 天池-机场客流 第二赛季 8th
8 |
9 | - 使用说明
10 | - 参见 `Main.java`
11 |
12 | - 模型说明
13 | - DecisionTree 支持默认分裂方式及选择最优叶子节点分裂方式
14 |
15 |
16 | #### 欢迎交流改进
17 |
--------------------------------------------------------------------------------
/bin/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xing89qs/GBRT-RF-Java/d0d09f09c52a09baceb26fc011503b092ca947a4/bin/.DS_Store
--------------------------------------------------------------------------------
/bin/Main.class:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xing89qs/GBRT-RF-Java/d0d09f09c52a09baceb26fc011503b092ca947a4/bin/Main.class
--------------------------------------------------------------------------------
/bin/boosting/GradientBoostingRegressor$1.class:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xing89qs/GBRT-RF-Java/d0d09f09c52a09baceb26fc011503b092ca947a4/bin/boosting/GradientBoostingRegressor$1.class
--------------------------------------------------------------------------------
/bin/boosting/GradientBoostingRegressor.class:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xing89qs/GBRT-RF-Java/d0d09f09c52a09baceb26fc011503b092ca947a4/bin/boosting/GradientBoostingRegressor.class
--------------------------------------------------------------------------------
/bin/data/LabeledSample.class:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xing89qs/GBRT-RF-Java/d0d09f09c52a09baceb26fc011503b092ca947a4/bin/data/LabeledSample.class
--------------------------------------------------------------------------------
/bin/decision_tree/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xing89qs/GBRT-RF-Java/d0d09f09c52a09baceb26fc011503b092ca947a4/bin/decision_tree/.DS_Store
--------------------------------------------------------------------------------
/bin/decision_tree/DecisionRegressionTree$SplitResult.class:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xing89qs/GBRT-RF-Java/d0d09f09c52a09baceb26fc011503b092ca947a4/bin/decision_tree/DecisionRegressionTree$SplitResult.class
--------------------------------------------------------------------------------
/bin/decision_tree/DecisionRegressionTree.class:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xing89qs/GBRT-RF-Java/d0d09f09c52a09baceb26fc011503b092ca947a4/bin/decision_tree/DecisionRegressionTree.class
--------------------------------------------------------------------------------
/bin/decision_tree/DecisionRegressionTreeLeafSpliter$1.class:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xing89qs/GBRT-RF-Java/d0d09f09c52a09baceb26fc011503b092ca947a4/bin/decision_tree/DecisionRegressionTreeLeafSpliter$1.class
--------------------------------------------------------------------------------
/bin/decision_tree/DecisionRegressionTreeLeafSpliter$QueueNode.class:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xing89qs/GBRT-RF-Java/d0d09f09c52a09baceb26fc011503b092ca947a4/bin/decision_tree/DecisionRegressionTreeLeafSpliter$QueueNode.class
--------------------------------------------------------------------------------
/bin/decision_tree/DecisionRegressionTreeLeafSpliter$SplitResult.class:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xing89qs/GBRT-RF-Java/d0d09f09c52a09baceb26fc011503b092ca947a4/bin/decision_tree/DecisionRegressionTreeLeafSpliter$SplitResult.class
--------------------------------------------------------------------------------
/bin/decision_tree/DecisionRegressionTreeLeafSpliter.class:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xing89qs/GBRT-RF-Java/d0d09f09c52a09baceb26fc011503b092ca947a4/bin/decision_tree/DecisionRegressionTreeLeafSpliter.class
--------------------------------------------------------------------------------
/bin/decision_tree/DecisionTree.class:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xing89qs/GBRT-RF-Java/d0d09f09c52a09baceb26fc011503b092ca947a4/bin/decision_tree/DecisionTree.class
--------------------------------------------------------------------------------
/bin/decision_tree/Node.class:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xing89qs/GBRT-RF-Java/d0d09f09c52a09baceb26fc011503b092ca947a4/bin/decision_tree/Node.class
--------------------------------------------------------------------------------
/bin/objective/Estimator.class:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xing89qs/GBRT-RF-Java/d0d09f09c52a09baceb26fc011503b092ca947a4/bin/objective/Estimator.class
--------------------------------------------------------------------------------
/bin/objective/LossFunction.class:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xing89qs/GBRT-RF-Java/d0d09f09c52a09baceb26fc011503b092ca947a4/bin/objective/LossFunction.class
--------------------------------------------------------------------------------
/bin/objective/QuantileEstimator.class:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xing89qs/GBRT-RF-Java/d0d09f09c52a09baceb26fc011503b092ca947a4/bin/objective/QuantileEstimator.class
--------------------------------------------------------------------------------
/bin/objective/QuantileLossFunction.class:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xing89qs/GBRT-RF-Java/d0d09f09c52a09baceb26fc011503b092ca947a4/bin/objective/QuantileLossFunction.class
--------------------------------------------------------------------------------
/bin/objective/SquaresEstimator.class:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xing89qs/GBRT-RF-Java/d0d09f09c52a09baceb26fc011503b092ca947a4/bin/objective/SquaresEstimator.class
--------------------------------------------------------------------------------
/bin/objective/SquaresLossFunction.class:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xing89qs/GBRT-RF-Java/d0d09f09c52a09baceb26fc011503b092ca947a4/bin/objective/SquaresLossFunction.class
--------------------------------------------------------------------------------
/bin/objective/Utils$Item.class:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xing89qs/GBRT-RF-Java/d0d09f09c52a09baceb26fc011503b092ca947a4/bin/objective/Utils$Item.class
--------------------------------------------------------------------------------
/bin/objective/Utils.class:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xing89qs/GBRT-RF-Java/d0d09f09c52a09baceb26fc011503b092ca947a4/bin/objective/Utils.class
--------------------------------------------------------------------------------
/bin/rf/RandomForestRegressor$1.class:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xing89qs/GBRT-RF-Java/d0d09f09c52a09baceb26fc011503b092ca947a4/bin/rf/RandomForestRegressor$1.class
--------------------------------------------------------------------------------
/bin/rf/RandomForestRegressor.class:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xing89qs/GBRT-RF-Java/d0d09f09c52a09baceb26fc011503b092ca947a4/bin/rf/RandomForestRegressor.class
--------------------------------------------------------------------------------
/bin/util/BoostingListener.class:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xing89qs/GBRT-RF-Java/d0d09f09c52a09baceb26fc011503b092ca947a4/bin/util/BoostingListener.class
--------------------------------------------------------------------------------
/bin/util/DumpTree.class:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xing89qs/GBRT-RF-Java/d0d09f09c52a09baceb26fc011503b092ca947a4/bin/util/DumpTree.class
--------------------------------------------------------------------------------
/bin/util/ParamException.class:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xing89qs/GBRT-RF-Java/d0d09f09c52a09baceb26fc011503b092ca947a4/bin/util/ParamException.class
--------------------------------------------------------------------------------
/bin/util/ParamReader.class:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xing89qs/GBRT-RF-Java/d0d09f09c52a09baceb26fc011503b092ca947a4/bin/util/ParamReader.class
--------------------------------------------------------------------------------
/bin/util/dump/TreeInfo.class:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xing89qs/GBRT-RF-Java/d0d09f09c52a09baceb26fc011503b092ca947a4/bin/util/dump/TreeInfo.class
--------------------------------------------------------------------------------
/src/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xing89qs/GBRT-RF-Java/d0d09f09c52a09baceb26fc011503b092ca947a4/src/.DS_Store
--------------------------------------------------------------------------------
/src/Main.java:
--------------------------------------------------------------------------------
1 | import java.io.BufferedReader;
2 | import java.io.File;
3 | import java.io.FileReader;
4 | import java.util.Arrays;
5 | import java.util.HashMap;
6 | import java.util.Scanner;
7 |
8 | import boosting.GradientBoostingRegressor;
9 | import rf.RandomForestRegressor;
10 |
11 | public class Main {
12 | public static void main(String[] args) throws Exception {
13 | HashMap params = new HashMap<>();
14 | params.put("learning_rate", 0.05);
15 | params.put("n_estimator", 100);
16 | params.put("max_depth", 3);
17 | params.put("objective", "ls");
18 | params.put("feature_rate", 0.1);
19 | params.put("sample_rate", 0.8);
20 | params.put("spliter", "leaf");
21 | params.put("num_leafs", 8);
22 | // GradientBoostingRegressor gbdt = new
23 | // GradientBoostingRegressor(params);
24 | RandomForestRegressor gbdt = new RandomForestRegressor(params);
25 |
26 | double trainX[][] = new double[26964][155];
27 | double trainY[] = new double[26964];
28 | double trainWeight[] = new double[26964];
29 | Arrays.fill(trainWeight, 1.0);
30 | Scanner in = new Scanner(new BufferedReader(new FileReader(new File("/Users/mac/Desktop/train_frame1.csv"))));
31 | in.nextLine();
32 | int numLines = 0;
33 | while (in.hasNextLine()) {
34 | String line = in.nextLine();
35 | String[] str = line.split(",", -1);
36 | int featureIndex = 0;
37 | for (int i = 3; i < str.length; i++) {
38 | if (i == 138)
39 | continue;
40 | if (str[i].equals("") || str[i] == null)
41 | trainX[numLines][featureIndex++] = -1000;
42 | else
43 | trainX[numLines][featureIndex++] = Double.valueOf(str[i]);
44 | }
45 | trainY[numLines] = Double.valueOf(str[2]);
46 | numLines++;
47 | assert (featureIndex == 155);
48 | }
49 | System.out.println(numLines);
50 | in.close();
51 |
52 | double testX[][] = new double[4464][155];
53 | double testY[] = new double[4464];
54 | in = new Scanner(new BufferedReader(new FileReader(new File("/Users/mac/Desktop/test_frame1.csv"))));
55 | in.nextLine();
56 | numLines = 0;
57 | while (in.hasNextLine()) {
58 | String line = in.nextLine();
59 | String[] str = line.split(",", -1);
60 | int featureIndex = 0;
61 | for (int i = 3; i < str.length; i++) {
62 | if (i == 138)
63 | continue;
64 | if (str[i].equals("") || str[i] == null)
65 | testX[numLines][featureIndex++] = -1000;
66 | else
67 | testX[numLines][featureIndex++] = Double.valueOf(str[i]);
68 | }
69 | assert (featureIndex == 155);
70 | testY[numLines] = Double.valueOf(str[2]);
71 | numLines++;
72 | }
73 | System.out.println(numLines);
74 | in.close();
75 | gbdt.fit(trainX, trainY, trainWeight);
76 | double[][] ans = gbdt.predict(testX);
77 | double[] ret = new double[ans[0].length];
78 | for (int i = 0; i < testX.length; i++) {
79 | ret[i] = 0;
80 | for (int j = 0; j < ans.length; j++) {
81 | ret[i] += ans[j][i];
82 | }
83 | ret[i] /= ans.length;
84 | //System.out.println(ret[i] + " " + testY[i]);
85 | }
86 |
87 | // --------------------------------------------
88 | // gbdt.fit(trainX, trainY, trainWeight,null);
89 | // double[] ret = gbdt.predict(testX);
90 |
91 | double all = 0;
92 | for (int i = 0; i < testX.length; i++) {
93 | double error = testY[i] - ret[i];
94 | all += error * error;
95 | }
96 | System.out.println(all);
97 | }
98 | }
99 |
--------------------------------------------------------------------------------
/src/boosting/GradientBoostingRegressor.java:
--------------------------------------------------------------------------------
1 | package boosting;
2 |
3 | import java.util.ArrayList;
4 | import java.util.Arrays;
5 | import java.util.Comparator;
6 | import java.util.HashMap;
7 | import java.util.HashSet;
8 | import java.util.Random;
9 | import java.util.TreeSet;
10 |
11 | import data.LabeledSample;
12 | import decision_tree.DecisionRegressionTree;
13 | import decision_tree.DecisionRegressionTreeLeafSpliter;
14 | import decision_tree.DecisionTree;
15 | import decision_tree.Node;
16 | import objective.Estimator;
17 | import objective.LossFunction;
18 | import objective.QuantileEstimator;
19 | import objective.QuantileLossFunction;
20 | import objective.SquaresEstimator;
21 | import objective.SquaresLossFunction;
22 | import util.BoostingListener;
23 | import util.ParamException;
24 | import util.ParamReader;
25 | import util.dump.TreeInfo;
26 |
27 | public class GradientBoostingRegressor {
28 |
29 | // ------------------------------------------------------------------------------------
30 | // Boosting的参数
31 | private double learning_rate;
32 | public int n_estimator;
33 | private double alpha; // quantile loss才需要
34 | private double sample_rate;
35 | private double feature_rate;
36 | private Estimator init_; // 初始值估计
37 | private LossFunction loss; // lossFunction
38 | private int random_state;
39 | private DecisionTree[] trees;
40 | private int max_depth;
41 | private int num_leafs; // 基于叶子节点分裂时的叶子数
42 | private double baseValue; // 初始值
43 | private double[] residual; // 残差
44 | private String spliter;
45 | private double min_leaf_sample;
46 |
47 | public GradientBoostingRegressor(HashMap params) throws ParamException {
48 |
49 | // -----------------------------------------------------------------------
50 | // 必需参数
51 |
52 | assert (params.containsKey("learning_rate"));
53 | this.learning_rate = ParamReader.readDouble("learning_rate", params);
54 | assert (params.containsKey("spliter"));
55 | this.spliter = ParamReader.readString("spliter", params);
56 | if (this.spliter.equals("leaf")) {
57 | assert (params.containsKey("num_leafs"));
58 | this.num_leafs = ParamReader.readInt("num_leafs", params);
59 | } else {
60 | assert (params.containsKey("max_depth"));
61 | this.max_depth = ParamReader.readInt("max_depth", params);
62 | }
63 | assert (params.containsKey("n_estimator"));
64 | this.n_estimator = ParamReader.readInt("n_estimator", params);
65 |
66 | // -------------------------------------------------------------------------
67 | // 可选参数
68 |
69 | String objective = ParamReader.readString("objective", params);
70 | if (objective.equals("quantile")) {
71 | assert (params.containsKey("alpha"));
72 | this.alpha = ParamReader.readDouble("alpha", params);
73 | this.loss = new QuantileLossFunction(alpha);
74 | this.init_ = new QuantileEstimator(alpha);
75 | } else if (objective.equals("lad")) {
76 | } else {
77 | this.loss = new SquaresLossFunction();
78 | this.init_ = new SquaresEstimator();
79 | }
80 |
81 | if (params.containsKey("random_state"))
82 | this.random_state = ParamReader.readInt("random_state", params);
83 | else
84 | this.random_state = 0;
85 | if (params.containsKey("sample_rate"))
86 | this.sample_rate = ParamReader.readDouble("sample_rate", params);
87 | else
88 | this.sample_rate = 1.0;
89 | if (params.containsKey("feature_rate"))
90 | this.feature_rate = ParamReader.readDouble("feature_rate", params);
91 | else
92 | this.feature_rate = 1.0;
93 |
94 | if (params.containsKey("min_leaf_sample"))
95 | this.min_leaf_sample = ParamReader.readDouble("min_leaf_sample", params);
96 | else
97 | this.min_leaf_sample = 1;
98 | }
99 |
100 | private LabeledSample[] trainData;
101 | private LabeledSample[][] preSortedSampleArrays;
102 | private LabeledSample[][] copyOfPreSortedSampleArrays;
103 | private BoostingListener listener;
104 |
105 | private double[] _fit_stage(int i, double[][] X, double[] Y, double[] y_pred, double[] sample_weight) {
106 | this.loss.negative_gradient(Y, y_pred, this.residual);
107 | int featureNum = X[0].length;
108 | int sampleNum = X.length;
109 | for (int b = 0; b < sampleNum; b++)
110 | trainData[b].y = residual[b];
111 | for (int a = 0; a < featureNum; a++) {
112 | for (int b = 0; b < sampleNum; b++) {
113 | preSortedSampleArrays[a][b] = copyOfPreSortedSampleArrays[a][b];
114 | }
115 | }
116 | if (this.spliter.equals("leaf"))
117 | trees[i] = new DecisionRegressionTreeLeafSpliter(this.num_leafs, this.random_state, sample_rate,
118 | feature_rate, min_leaf_sample, preSortedSampleArrays);
119 | else
120 | trees[i] = new DecisionRegressionTree(i, this.max_depth, this.random_state, sample_rate, feature_rate,
121 | min_leaf_sample, preSortedSampleArrays);
122 | trees[i].fit(X, residual, sample_weight);
123 |
124 | this.loss.update_terminal_region(trees[i], X, Y, y_pred, sample_weight);
125 | for (int j = 0; j < X.length; j++) {
126 | Node leaf = trees[i].apply(X[j]);
127 | y_pred[j] += learning_rate * leaf.treeVal;
128 | // System.out.print(y_pred[j] + " ");
129 | }
130 | // System.out.println();
131 | return y_pred;
132 | }
133 |
134 | public void registerListener(BoostingListener listener) {
135 | this.listener = listener;
136 | }
137 |
138 | private int _fit_stages(double[][] X, double[] Y, double[] y_pred, double[] sample_weight, TreeInfo[][] infoList) {
139 | int featureNum = X[0].length;
140 | int sampleNum = X.length;
141 | trainData = new LabeledSample[sampleNum];
142 | for (int i = 0; i < sampleNum; i++) {
143 | trainData[i] = new LabeledSample();
144 | trainData[i].x = X[i];
145 | trainData[i].y = Y[i];
146 | trainData[i].weight = sample_weight[i];
147 | }
148 |
149 | preSortedSampleArrays = new LabeledSample[featureNum][sampleNum];
150 | copyOfPreSortedSampleArrays = new LabeledSample[featureNum][sampleNum];
151 | for (int a = 0; a < featureNum; a++) {
152 | for (int b = 0; b < sampleNum; b++) {
153 | copyOfPreSortedSampleArrays[a][b] = trainData[b];
154 | }
155 |
156 | final int compareFeature = a;
157 | Arrays.sort(copyOfPreSortedSampleArrays[a], 0, sampleNum, new Comparator() {
158 |
159 | @Override
160 | public int compare(LabeledSample o1, LabeledSample o2) {
161 | return new Double(o1.x[compareFeature]).compareTo(o2.x[compareFeature]);
162 | }
163 | });
164 | }
165 | this.residual = new double[sampleNum];
166 | this.trees = new DecisionTree[this.n_estimator];
167 |
168 | for (int i = 0; i < this.n_estimator; i++) {
169 | System.out.println("第" + i + "棵树");
170 | if (listener != null)
171 | listener.done(i);
172 | y_pred = _fit_stage(i, X, Y, y_pred, sample_weight);
173 | }
174 | return this.n_estimator;
175 | }
176 |
177 | public void fit(double[][] X, double[] Y, double[] sample_weight, TreeInfo[][] infoList) {
178 | this.init_.fit(X, Y, sample_weight);
179 | double[] y_pred = this.init_.predict(X);
180 | this.baseValue = y_pred[0];
181 | _fit_stages(X, Y, y_pred, sample_weight, infoList);
182 | }
183 |
184 | public double[] predict(double[][] X) {
185 | double[] ans = new double[X.length];
186 | for (int i = 0; i < X.length; i++) {
187 | ans[i] = baseValue;
188 | for (int j = 0; j < n_estimator; j++) {
189 | Node leaf = trees[j].apply(X[i]);
190 | ans[i] += learning_rate * leaf.treeVal;
191 | }
192 | }
193 | return ans;
194 | }
195 |
196 | public double[] predict(double[][] X, int n) {
197 | double[] ans = new double[X.length];
198 | for (int i = 0; i < X.length; i++) {
199 | ans[i] = baseValue;
200 | for (int j = 0; j < n; j++) {
201 | Node leaf = trees[j].apply(X[i]);
202 | ans[i] += learning_rate * leaf.treeVal;
203 | }
204 | }
205 | return ans;
206 | }
207 | }
--------------------------------------------------------------------------------
/src/data/LabeledSample.java:
--------------------------------------------------------------------------------
1 | package data;
2 |
/**
 * One training sample: feature vector, target value and weight, plus two
 * mutable flags used by the tree builders during split search.
 */
public class LabeledSample {
	public double[] x; // feature vector
	public double y; // regression target (overwritten with the residual during boosting)
	public double weight; // sample weight
	public boolean isSampled; // whether row subsampling picked this sample
	public boolean isSplitToLeft; // whether the current split sends this sample to the left subtree
}
10 |
--------------------------------------------------------------------------------
/src/decision_tree/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xing89qs/GBRT-RF-Java/d0d09f09c52a09baceb26fc011503b092ca947a4/src/decision_tree/.DS_Store
--------------------------------------------------------------------------------
/src/decision_tree/DecisionRegressionTree.java:
--------------------------------------------------------------------------------
1 | package decision_tree;
2 |
3 | import java.util.ArrayList;
4 | import java.util.Random;
5 | import java.util.Stack;
6 |
7 | import data.LabeledSample;
8 |
9 | public class DecisionRegressionTree extends DecisionTree {
10 |
11 | // 树的参数
12 | private int max_depth;
13 | private int random_state;
14 | private Random random;
15 | private Node root;
16 | private double sample_rate;
17 | private double feature_rate;
18 | private int estimator_num;
19 | private LabeledSample[][] sampleSortedByFeatureArrays;
20 | private double min_leaf_sample;
21 |
22 | public DecisionRegressionTree(int estimator_num, int max_depth, int random_state, double sample_rate,
23 | double feature_rate, double min_leaf_sample, LabeledSample[][] preSortedSampleArrays) {
24 | this.estimator_num = estimator_num;
25 | this.max_depth = max_depth;
26 | this.random_state = random_state;
27 | System.out.println(this.random_state);
28 | this.random = new Random(this.random_state);
29 | this.sample_rate = sample_rate;
30 | this.feature_rate = feature_rate;
31 | this.min_leaf_sample = min_leaf_sample;
32 | this.sampleSortedByFeatureArrays = preSortedSampleArrays;
33 | }
34 |
35 | // public DecisionRegressionTree(int estimator_num, TreeInfo[] infoList) {
36 | // this.estimator_num = estimator_num;
37 | // Node[] nodeList = new Node[infoList.length];
38 | // for (int i = 0; i < infoList.length; i++) {
39 | // if (infoList[i] == null)
40 | // continue;
41 | // Node node = new Node();
42 | // nodeList[(int) infoList[i].root_id] = node;
43 | // if ((int) infoList[i].left_son != -1)
44 | // node.leftNode = nodeList[(int) infoList[i].left_son];
45 | // if ((int) infoList[i].right_son != -1)
46 | // node.rightNode = nodeList[(int) infoList[i].right_son];
47 | // node.split_feature = (int) infoList[i].split_feature;
48 | // node.split_val = infoList[i].split_feature_value;
49 | // node.treeVal = infoList[i].node_value;
50 | // if (infoList[i].is_root == 1L)
51 | // this.root = node;
52 | // }
53 | // }
54 |
	/**
	 * No-arg constructor; leaves all fields at defaults. Appears intended for
	 * rebuilding a tree from dumped TreeInfo (see the commented-out constructor
	 * above) — TODO confirm before relying on it.
	 */
	public DecisionRegressionTree() {
		// intentionally empty
	}
58 |
59 | class SplitResult {
60 | Stack leftSample, rightSample;
61 | double split_error, left_error, right_error;
62 | int best_feature;
63 | double best_split_val;
64 |
65 | // error = var_left + var_right
66 | // var_left = sigma((y_i-y_bar)^2*w)
67 | // = sigma(y_i*y_i*w_i)- y_bar*sigma(2*w_i*y_i)+ y_bar*y_bar*sigma(w_i)
68 | // y_bar = sigma(y_i*w_i)/sigma(w_i)
69 |
70 | double left_yyw_sum, left_y_w_sum, left_w_sum, left_bar;
71 | double right_yyw_sum, right_y_w_sum, right_w_sum, right_bar;
72 |
73 | public SplitResult(int l, int r, LabeledSample[] samples) {
74 | leftSample = new Stack();
75 | rightSample = new Stack();
76 | init(l, r, samples);
77 | }
78 |
79 | void init(int l, int r, LabeledSample[] samples) {
80 | this.split_error = Double.MAX_VALUE;
81 | leftSample.clear();
82 | rightSample.clear();
83 |
84 | left_yyw_sum = left_y_w_sum = left_w_sum = left_bar = 0;
85 | right_yyw_sum = right_y_w_sum = right_w_sum = right_bar = 0;
86 | for (int i = r; i >= l; i--) {
87 | if (!samples[i].isSampled)
88 | continue;
89 | rightSample.push(samples[i]);
90 | right_yyw_sum += samples[i].y * samples[i].y * samples[i].weight;
91 | right_y_w_sum += samples[i].y * samples[i].weight;
92 | right_w_sum += samples[i].weight;
93 | }
94 | right_bar = right_y_w_sum / right_w_sum;
95 | left_error = left_yyw_sum - 2 * left_y_w_sum * left_bar + left_bar * left_bar * left_w_sum;
96 | right_error = right_yyw_sum - 2 * right_y_w_sum * right_bar + right_bar * right_bar * right_w_sum;
97 | split_error = left_error + right_error;
98 | }
99 |
100 | void moveSampleToLeft(int split_feature, double split_val) {
101 | while (!rightSample.empty()) {
102 | LabeledSample sample = rightSample.peek();
103 | if (sample.x[split_feature] < split_val) {
104 | rightSample.pop();
105 | right_yyw_sum -= sample.y * sample.y * sample.weight;
106 | right_y_w_sum -= sample.y * sample.weight;
107 | right_w_sum -= sample.weight;
108 |
109 | leftSample.push(sample);
110 | left_yyw_sum += sample.y * sample.y * sample.weight;
111 | left_y_w_sum += sample.y * sample.weight;
112 | left_w_sum += sample.weight;
113 | } else
114 | break;
115 | }
116 | left_bar = left_y_w_sum / left_w_sum;
117 | right_bar = right_y_w_sum / right_w_sum;
118 | left_error = left_yyw_sum - 2 * left_y_w_sum * left_bar + left_bar * left_bar * left_w_sum;
119 | right_error = right_yyw_sum - 2 * right_y_w_sum * right_bar + right_bar * right_bar * right_w_sum;
120 | split_error = left_error + right_error;
121 | }
122 |
123 | Stack getLeftTrainSamples() {
124 | return leftSample;
125 | }
126 |
127 | Stack getRightTrainSamples() {
128 | return rightSample;
129 | }
130 |
131 | void clear() {
132 | this.leftSample.clear();
133 | this.rightSample.clear();
134 | this.leftSample = this.rightSample = null;
135 | }
136 |
137 | }
138 |
	// Shared scratch buffers sized for up to ~400k samples; static, so tree
	// construction is NOT thread-safe across trees — TODO confirm single-threaded use.
	private static double[] splitValueArray = new double[400005];
	private static LabeledSample[] samples = new LabeledSample[400005];

	/**
	 * Finds the lowest-weighted-variance split over samples [l, r] of the
	 * pre-sorted arrays, with per-node row subsampling (sample_rate) and feature
	 * subsampling (feature_rate).
	 *
	 * Returns a SplitResult positioned at the best split (left/right stacks
	 * populated, best_feature/best_split_val set), or null when no feature yields
	 * a valid split respecting min_leaf_sample.
	 */
	private SplitResult getBestSplit(int l, int r) {

		int featureNum = this.sampleSortedByFeatureArrays[0][0].x.length;
		double min_split_error = Double.MAX_VALUE;
		int best_feature = -1;
		double best_split_val = Double.MAX_VALUE;

		// Row subsampling: flags live on the shared LabeledSample objects, so
		// marking them via feature-0's array marks them for every feature array.
		for (int i = l; i <= r; i++) {
			if (random.nextDouble() > sample_rate)
				this.sampleSortedByFeatureArrays[0][i].isSampled = false;
			else
				this.sampleSortedByFeatureArrays[0][i].isSampled = true;
		}
		SplitResult currentResult = new SplitResult(l, r, sampleSortedByFeatureArrays[0]);
		for (int i = 0; i < featureNum; i++) {
			// Feature subsampling: skip this feature with probability 1 - feature_rate.
			if (random.nextDouble() > feature_rate)
				continue;
			int cnt = 0;
			// Candidate thresholds: midpoints of adjacent sorted feature values.
			for (int j = l; j < r; j++) {
				splitValueArray[cnt++] = (sampleSortedByFeatureArrays[i][j].x[i]
						+ sampleSortedByFeatureArrays[i][j + 1].x[i]) / 2.0;
			}
			if (cnt == 0)
				continue;
			currentResult.init(l, r, sampleSortedByFeatureArrays[i]);
			// Thresholds are increasing, so each step only moves a few samples left.
			for (int j = 0; j < cnt; j++) {
				double split_val = splitValueArray[j];
				currentResult.moveSampleToLeft(i, split_val);
				if (currentResult.split_error < min_split_error
						&& currentResult.getLeftTrainSamples().size() >= min_leaf_sample
						&& currentResult.getRightTrainSamples().size() >= min_leaf_sample) {
					min_split_error = currentResult.split_error;
					best_feature = i;
					best_split_val = split_val;
				}
			}
		}
		if (best_feature == -1)
			return null;
		// Re-include every sample (subsampling only guided the search) and rebuild
		// the winning split so both child partitions cover all samples in [l, r].
		for (int i = l; i <= r; i++)
			sampleSortedByFeatureArrays[best_feature][i].isSampled = true;
		SplitResult result = new SplitResult(l, r, sampleSortedByFeatureArrays[best_feature]);
		result.moveSampleToLeft(best_feature, best_split_val);
		result.best_feature = best_feature;
		result.best_split_val = best_split_val;
		return result;

	}
190 |
// Grows the tree depth-first over sample index range [l, r]. After choosing
// a split, every per-feature sorted array is stably repartitioned in place so
// left-child samples occupy [l, l+leftSize-1] and right-child samples
// [l+leftSize, r], each side preserving its feature-sorted order.
// Returns null past max_depth (the parent then acts as a leaf in dfs).
191 | private Node createTree(int l, int r, int depth) {
192 | if (depth > max_depth)
193 | return null;
194 | Node root = new Node();
195 | if (l >= r)
196 | return root;
197 | SplitResult result = getBestSplit(l, r);
198 | if (result == null)
199 | return root;
200 | root.split_feature = result.best_feature;
201 | root.split_val = result.best_split_val;
// NOTE(review): raw Stack iterated as LabeledSample — generics were stripped
// by the dump; presumably Stack<LabeledSample> in the original source.
202 | Stack leftSamples = result.getLeftTrainSamples();
203 | Stack rightSamples = result.getRightTrainSamples();
204 | for (LabeledSample sample : leftSamples)
205 | sample.isSplitToLeft = true;
206 | for (LabeledSample sample : rightSamples)
207 | sample.isSplitToLeft = false;
208 | int leftSize = leftSamples.size();
209 | int n_features = sampleSortedByFeatureArrays[0][0].x.length;
// Stable two-pass partition into the static scratch buffer, then copy back.
210 | for (int i = 0; i < n_features; i++) {
211 | int tot = l;
212 | for (int j = l; j <= r; j++) {
213 | if (sampleSortedByFeatureArrays[i][j].isSplitToLeft) {
214 | samples[tot++] = sampleSortedByFeatureArrays[i][j];
215 | }
216 | }
217 | for (int j = l; j <= r; j++) {
218 | if (!sampleSortedByFeatureArrays[i][j].isSplitToLeft)
219 | samples[tot++] = sampleSortedByFeatureArrays[i][j];
220 | }
221 | for (int j = l; j <= r; j++)
222 | sampleSortedByFeatureArrays[i][j] = samples[j];
223 |
224 | }
225 | result.clear();
226 | root.leftNode = createTree(l, l + leftSize - 1, depth + 1);
227 | root.rightNode = createTree(l + leftSize, r, depth + 1);
228 | return root;
229 | }
230 |
// Builds the tree over all rows. X/Y/sample_weight are not read directly:
// training uses the pre-sorted sample arrays supplied at construction,
// which wrap the same data; only X.length is used here.
231 | public void fit(double[][] X, double[] Y, double[] sample_weight) {
// NOTE(review): reassigning random_state from the RNG has no visible effect
// on later logic — looks vestigial; confirm before removing.
232 | random_state = random.nextInt();
233 | this.root = createTree(0, X.length - 1, 0);
234 | }
235 |
236 | private Node dfs(Node root, double[] x, int depth) {
237 | if (root.leftNode == null)
238 | return root;
239 | if (x[root.split_feature] < root.split_val)
240 | return dfs(root.leftNode, x, depth + 1);
241 | return dfs(root.rightNode, x, depth + 1);
242 | }
243 |
// Returns the leaf node that input vector x is routed to.
244 | public Node apply(double[] x) {
245 | return dfs(root, x, 0);
246 | }
247 |
248 | // private int nodeCount = 0;
249 |
250 | // private int dfsNode(Node node, ArrayList infoList, int is_root)
251 | // {
252 | // if (node == null)
253 | // return -1;
254 | // int left_son = dfsNode(node.leftNode, infoList, 0);
255 | // int right_son = dfsNode(node.rightNode, infoList, 0);
256 | // int id = nodeCount++;
257 | // TreeInfo info = new TreeInfo(id, left_son, right_son, node.split_feature,
258 | // node.split_val, estimator_num,
259 | // is_root, node.treeVal);
260 | // infoList.add(info);
261 | // return id;
262 | // }
263 | //
264 | // public ArrayList getTreeInfo() {
265 | // nodeCount = 0;
266 | // ArrayList infoList = new ArrayList();
267 | // dfsNode(root, infoList, 1);
268 | // return infoList;
269 | // }
270 | }
--------------------------------------------------------------------------------
/src/decision_tree/DecisionRegressionTreeLeafSpliter.java:
--------------------------------------------------------------------------------
1 | package decision_tree;
2 |
3 | import java.util.Comparator;
4 | import java.util.Queue;
5 | import java.util.Random;
6 | import java.util.Stack;
7 | import java.util.concurrent.PriorityBlockingQueue;
8 |
9 | import data.LabeledSample;
10 |
11 | public class DecisionRegressionTreeLeafSpliter extends DecisionTree {
12 |
13 | // 树的参数
14 | private int max_depth;
15 | private int random_state;
16 | private Random random;
17 | private Node root;
18 | private double sample_rate;
19 | private double feature_rate;
20 | private int num_leafs;
21 | private LabeledSample[][] sampleSortedByFeatureArrays;
22 | private double min_leaf_sample;
23 |
24 | public DecisionRegressionTreeLeafSpliter(int num_leafs, int random_state, double sample_rate, double feature_rate,
25 | double min_leaf_sample, LabeledSample[][] preSortedSampleArrays) {
26 | this.num_leafs = num_leafs;
27 | this.random_state = random_state;
28 | this.random = new Random(this.random_state);
29 | this.sample_rate = sample_rate;
30 | this.feature_rate = feature_rate;
31 | this.min_leaf_sample = min_leaf_sample;
32 | this.sampleSortedByFeatureArrays = preSortedSampleArrays;
33 | }
34 |
35 | // public DecisionRegressionTree(int estimator_num, TreeInfo[] infoList) {
36 | // this.estimator_num = estimator_num;
37 | // Node[] nodeList = new Node[infoList.length];
38 | // for (int i = 0; i < infoList.length; i++) {
39 | // if (infoList[i] == null)
40 | // continue;
41 | // Node node = new Node();
42 | // nodeList[(int) infoList[i].root_id] = node;
43 | // if ((int) infoList[i].left_son != -1)
44 | // node.leftNode = nodeList[(int) infoList[i].left_son];
45 | // if ((int) infoList[i].right_son != -1)
46 | // node.rightNode = nodeList[(int) infoList[i].right_son];
47 | // node.split_feature = (int) infoList[i].split_feature;
48 | // node.split_val = infoList[i].split_feature_value;
49 | // node.treeVal = infoList[i].node_value;
50 | // if (infoList[i].is_root == 1L)
51 | // this.root = node;
52 | // }
53 | // }
54 |
55 | class SplitResult {
56 | Stack leftSample, rightSample;
57 | double split_error, left_error, right_error;
58 | int best_feature;
59 | double best_split_val;
60 |
61 | // error = var_left + var_right
62 | // var_left = sigma((y_i-y_bar)^2*w)
63 | // = sigma(y_i*y_i*w_i)- y_bar*sigma(2*w_i*y_i)+ y_bar*y_bar*sigma(w_i)
64 | // y_bar = sigma(y_i*w_i)/sigma(w_i)
65 |
66 | double left_yyw_sum, left_y_w_sum, left_w_sum, left_bar;
67 | double right_yyw_sum, right_y_w_sum, right_w_sum, right_bar;
68 |
69 | public SplitResult(int l, int r, LabeledSample[] samples) {
70 | leftSample = new Stack();
71 | rightSample = new Stack();
72 | init(l, r, samples);
73 | }
74 |
75 | void init(int l, int r, LabeledSample[] samples) {
76 | this.split_error = Double.MAX_VALUE;
77 | leftSample.clear();
78 | rightSample.clear();
79 |
80 | left_yyw_sum = left_y_w_sum = left_w_sum = left_bar = 0;
81 | right_yyw_sum = right_y_w_sum = right_w_sum = right_bar = 0;
82 | for (int i = r; i >= l; i--) {
83 | if (!samples[i].isSampled)
84 | continue;
85 | rightSample.push(samples[i]);
86 | right_yyw_sum += samples[i].y * samples[i].y * samples[i].weight;
87 | right_y_w_sum += samples[i].y * samples[i].weight;
88 | right_w_sum += samples[i].weight;
89 | }
90 | right_bar = right_y_w_sum / right_w_sum;
91 | left_error = left_yyw_sum - 2 * left_y_w_sum * left_bar + left_bar * left_bar * left_w_sum;
92 | right_error = right_yyw_sum - 2 * right_y_w_sum * right_bar + right_bar * right_bar * right_w_sum;
93 | split_error = left_error + right_error;
94 | }
95 |
96 | void moveSampleToLeft(int split_feature, double split_val) {
97 | while (!rightSample.empty()) {
98 | LabeledSample sample = rightSample.peek();
99 | if (sample.x[split_feature] < split_val) {
100 | rightSample.pop();
101 | right_yyw_sum -= sample.y * sample.y * sample.weight;
102 | right_y_w_sum -= sample.y * sample.weight;
103 | right_w_sum -= sample.weight;
104 |
105 | leftSample.push(sample);
106 | left_yyw_sum += sample.y * sample.y * sample.weight;
107 | left_y_w_sum += sample.y * sample.weight;
108 | left_w_sum += sample.weight;
109 | } else
110 | break;
111 | }
112 | left_bar = left_y_w_sum / left_w_sum;
113 | right_bar = right_y_w_sum / right_w_sum;
114 | left_error = left_yyw_sum - 2 * left_y_w_sum * left_bar + left_bar * left_bar * left_w_sum;
115 | right_error = right_yyw_sum - 2 * right_y_w_sum * right_bar + right_bar * right_bar * right_w_sum;
116 | split_error = left_error + right_error;
117 | }
118 |
119 | Stack getLeftTrainSamples() {
120 | return leftSample;
121 | }
122 |
123 | Stack getRightTrainSamples() {
124 | return rightSample;
125 | }
126 |
127 | void clear() {
128 | this.leftSample.clear();
129 | this.rightSample.clear();
130 | this.leftSample = this.rightSample = null;
131 | }
132 |
133 | }
134 |
135 | private static double[] splitValueArray = new double[400005];
136 | private static LabeledSample[] samples = new LabeledSample[400005];
137 |
138 | private SplitResult getBestSplit(int l, int r) {
139 |
140 | int featureNum = this.sampleSortedByFeatureArrays[0][0].x.length;
141 | double min_split_error = Double.MAX_VALUE;
142 | int best_feature = -1;
143 | double best_split_val = Double.MAX_VALUE;
144 |
145 | for (int i = l; i <= r; i++) {
146 | if (random.nextDouble() > sample_rate)
147 | this.sampleSortedByFeatureArrays[0][i].isSampled = false;
148 | else
149 | this.sampleSortedByFeatureArrays[0][i].isSampled = true;
150 | }
151 | SplitResult currentResult = new SplitResult(l, r, sampleSortedByFeatureArrays[0]);
152 | for (int i = 0; i < featureNum; i++) {
153 | if (random.nextDouble() > feature_rate)
154 | continue;
155 | int cnt = 0;
156 | for (int j = l; j < r; j++) {
157 | splitValueArray[cnt++] = (sampleSortedByFeatureArrays[i][j].x[i]
158 | + sampleSortedByFeatureArrays[i][j + 1].x[i]) / 2.0;
159 | }
160 | if (cnt == 0)
161 | continue;
162 | currentResult.init(l, r, sampleSortedByFeatureArrays[i]);
163 | for (int j = 0; j < cnt; j++) {
164 | double split_val = splitValueArray[j];
165 | currentResult.moveSampleToLeft(i, split_val);
166 | if (currentResult.split_error < min_split_error
167 | && currentResult.getLeftTrainSamples().size() >= min_leaf_sample
168 | && currentResult.getRightTrainSamples().size() >= min_leaf_sample) {
169 | min_split_error = currentResult.split_error;
170 | best_feature = i;
171 | best_split_val = split_val;
172 | }
173 | }
174 | }
175 | if (best_feature == -1)
176 | return null;
177 | for (int i = l; i <= r; i++)
178 | sampleSortedByFeatureArrays[best_feature][i].isSampled = true;
179 | SplitResult result = new SplitResult(l, r, sampleSortedByFeatureArrays[best_feature]);
180 | result.moveSampleToLeft(best_feature, best_split_val);
181 | result.best_feature = best_feature;
182 | result.best_split_val = best_split_val;
183 | return result;
184 |
185 | }
186 |
187 | private QueueNode createTree(int l, int r) {
188 | Node root = new Node();
189 | QueueNode qNode = new QueueNode();
190 | qNode.node = root;
191 | qNode.l = l;
192 | qNode.r = r;
193 | qNode.gain = 0;
194 | if (l >= r)
195 | return qNode;
196 | SplitResult result = getBestSplit(l, r);
197 | if (result == null) {
198 | return qNode;
199 | }
200 | qNode.splitResult = result;
201 | root.split_feature = result.best_feature;
202 | root.split_val = result.best_split_val;
203 | qNode.gain = calCurrentLoss(l, r) - result.split_error;
204 | return qNode;
205 | }
206 |
207 | private int doSplit(QueueNode node, int l, int r) {
208 | SplitResult result = node.splitResult;
209 | if (result == null)
210 | return 0;
211 | Stack leftSamples = result.getLeftTrainSamples();
212 | Stack rightSamples = result.getRightTrainSamples();
213 |
214 | for (LabeledSample sample : leftSamples)
215 | sample.isSplitToLeft = true;
216 | for (LabeledSample sample : rightSamples)
217 | sample.isSplitToLeft = false;
218 | int leftSize = leftSamples.size();
219 | int n_features = sampleSortedByFeatureArrays[0][0].x.length;
220 | for (int i = 0; i < n_features; i++) {
221 | int tot = l;
222 | for (int j = l; j <= r; j++) {
223 | if (sampleSortedByFeatureArrays[i][j].isSplitToLeft) {
224 | samples[tot++] = sampleSortedByFeatureArrays[i][j];
225 | }
226 | }
227 | for (int j = l; j <= r; j++) {
228 | if (!sampleSortedByFeatureArrays[i][j].isSplitToLeft)
229 | samples[tot++] = sampleSortedByFeatureArrays[i][j];
230 | }
231 | for (int j = l; j <= r; j++)
232 | sampleSortedByFeatureArrays[i][j] = samples[j];
233 | }
234 | result.clear();
235 | return leftSize;
236 | }
237 |
238 | private double calCurrentLoss(int l, int r) {
239 | double sum = 0;
240 | double sumWeight = 0;
241 | for (int i = l; i <= r; i++) {
242 | sum += sampleSortedByFeatureArrays[0][i].y;
243 | sumWeight += sampleSortedByFeatureArrays[0][i].weight;
244 | }
245 | double avg = sum / sumWeight;
246 | double error = 0;
247 | for (int i = l; i <= r; i++) {
248 | error += (sampleSortedByFeatureArrays[0][i].y - avg) * (sampleSortedByFeatureArrays[0][i].y - avg)
249 | * sampleSortedByFeatureArrays[0][i].weight;
250 | }
251 | return error;
252 | }
253 |
254 | public static class QueueNode {
255 | public SplitResult splitResult;
256 | Node node;
257 | double gain;
258 | int l, r;
259 | }
260 |
261 | private Comparator cmp = new Comparator() {
262 | public int compare(QueueNode e1, QueueNode e2) {
263 | return Double.valueOf(e2.gain).compareTo(e1.gain);
264 | }
265 | };
266 |
267 | public void fit(double[][] X, double[] Y, double[] sample_weight) {
268 | random_state = random.nextInt();
269 | Queue queue = new PriorityBlockingQueue<>(this.num_leafs * 2, cmp);
270 | QueueNode qNode = createTree(0, X.length - 1);
271 | this.root = qNode.node;
272 | queue.add(qNode);
273 | int currentLeaf = 1;
274 | while (!queue.isEmpty()) {
275 | qNode = queue.poll();
276 | System.out.println(qNode.gain);
277 | if (qNode.l == qNode.r)
278 | continue;
279 | int leftSize = doSplit(qNode, qNode.l, qNode.r);
280 | QueueNode qNodeLeft = createTree(qNode.l, qNode.l + leftSize - 1);
281 | QueueNode qNodeRight = createTree(qNode.l + leftSize, qNode.r);
282 | queue.add(qNodeLeft);
283 | queue.add(qNodeRight);
284 | qNode.node.leftNode = qNodeLeft.node;
285 | qNode.node.rightNode = qNodeRight.node;
286 | currentLeaf += 1;
287 | if (currentLeaf == this.num_leafs)
288 | break;
289 | }
290 | }
291 |
292 | private Node dfs(Node root, double[] x, int depth) {
293 | if (root.leftNode == null)
294 | return root;
295 | if (x[root.split_feature] < root.split_val)
296 | return dfs(root.leftNode, x, depth + 1);
297 | return dfs(root.rightNode, x, depth + 1);
298 | }
299 |
300 | public Node apply(double[] x) {
301 | return dfs(root, x, 0);
302 | }
303 |
304 | // private int nodeCount = 0;
305 |
306 | // private int dfsNode(Node node, ArrayList infoList, int is_root)
307 | // {
308 | // if (node == null)
309 | // return -1;
310 | // int left_son = dfsNode(node.leftNode, infoList, 0);
311 | // int right_son = dfsNode(node.rightNode, infoList, 0);
312 | // int id = nodeCount++;
313 | // TreeInfo info = new TreeInfo(id, left_son, right_son, node.split_feature,
314 | // node.split_val, estimator_num,
315 | // is_root, node.treeVal);
316 | // infoList.add(info);
317 | // return id;
318 | // }
319 | //
320 | // public ArrayList getTreeInfo() {
321 | // nodeCount = 0;
322 | // ArrayList infoList = new ArrayList();
323 | // dfsNode(root, infoList, 1);
324 | // return infoList;
325 | // }
326 | }
--------------------------------------------------------------------------------
/src/decision_tree/DecisionTree.java:
--------------------------------------------------------------------------------
1 | package decision_tree;
2 |
// Minimal contract shared by the regression trees in this package.
3 | public abstract class DecisionTree {
// Trains the tree on rows X with targets Y and per-sample weights.
4 | public abstract void fit(double[][] X, double[] Y, double[] sample_weight);
5 |
// Returns the leaf node that input vector x is routed to.
6 | public abstract Node apply(double[] x);
7 | }
8 |
--------------------------------------------------------------------------------
/src/decision_tree/Node.java:
--------------------------------------------------------------------------------
1 | package decision_tree;
2 |
3 | import java.util.ArrayList;
4 |
/**
 * One node of a regression tree. Internal nodes route on
 * {@code x[split_feature] < split_val}; leaves carry the prediction
 * {@code treeVal}. Generic type parameters on the residual buffers were
 * restored (the dump had raw {@code ArrayList}).
 */
public class Node {
    double split_val;     // threshold: go left when x[split_feature] < split_val
    int split_feature;    // index of the feature this node splits on
    public Node leftNode;
    public Node rightNode;
    public double treeVal; // prediction emitted when this node acts as a leaf

    // Residual statistics of the samples that fall into this leaf, gathered
    // while fitting terminal-region values; released via clear() afterwards.
    public ArrayList<Double> diff = new ArrayList<Double>();
    public ArrayList<Double> diff_sample_weight = new ArrayList<Double>();

    /** Sets the leaf's prediction value. */
    public void setTreeValue(double value) {
        this.treeVal = value;
    }

    /** Frees the per-leaf residual buffers; diff fields are null afterwards. */
    public void clear() {
        this.diff.clear();
        this.diff.trimToSize();
        this.diff_sample_weight.clear();
        this.diff_sample_weight.trimToSize();
        this.diff = null;
        this.diff_sample_weight = null;
    }
}
--------------------------------------------------------------------------------
/src/objective/Estimator.java:
--------------------------------------------------------------------------------
1 | package objective;
2 |
3 | import objective.Utils.Item;
4 |
// Base class for the initial constant predictors that seed boosting.
5 | public abstract class Estimator {
// Fits to targets y with per-sample weights; implementations may ignore X.
6 | public abstract void fit(double[][] X, double[] y, double[] sample_weight);
7 |
// Returns one prediction per row of X.
8 | public abstract double[] predict(double[][] X);
9 | }
--------------------------------------------------------------------------------
/src/objective/LossFunction.java:
--------------------------------------------------------------------------------
1 | package objective;
2 |
3 | import decision_tree.DecisionTree;
4 |
// Contract for boosting losses: pseudo-residual computation plus assignment
// of leaf values to a freshly fitted tree.
5 | public abstract class LossFunction {
// Writes the negative gradient of the loss at y_pred into residual and returns it.
6 | public abstract double[] negative_gradient(double[] y_true, double[] y_pred, double[] residual);
7 |
// Recomputes each leaf value of `tree` from the samples routed to that leaf.
8 | public abstract void update_terminal_region(DecisionTree tree, double[][] sample, double[] label, double[] y_pred,
9 | double[] sample_weight);
10 | }
--------------------------------------------------------------------------------
/src/objective/QuantileEstimator.java:
--------------------------------------------------------------------------------
1 | package objective;
2 |
3 | public class QuantileEstimator extends Estimator {
4 | private double alpha;
5 | private double quantile;
6 |
7 | public QuantileEstimator(double alpha) {
8 | this.alpha = alpha;
9 | }
10 |
11 | public void fit(double[][] X, double[] y, double[] sample_weight) {
12 | this.quantile = Utils.weighted_percentile(y, sample_weight, this.alpha * 100.0);
13 |
14 | }
15 |
16 | public double[] predict(double[][] X) {
17 | int size = X.length;
18 | double[] ans = new double[size];
19 | for (int i = 0; i < size; i++)
20 | ans[i] = this.quantile;
21 | return ans;
22 | }
23 | }
--------------------------------------------------------------------------------
/src/objective/QuantileLossFunction.java:
--------------------------------------------------------------------------------
1 | package objective;
2 |
3 | import java.util.HashSet;
4 |
5 | import decision_tree.DecisionRegressionTree;
6 | import decision_tree.DecisionTree;
7 | import decision_tree.Node;
8 |
9 | public class QuantileLossFunction extends LossFunction {
10 |
11 | private double alpha;
12 | private double percentile;
13 |
14 | public QuantileLossFunction(double alpha) {
15 | this.alpha = alpha;
16 | this.percentile = alpha * 100.0;
17 | }
18 |
19 | public double[] negative_gradient(double[] y_true, double[] y_pred, double[] residual) {
20 | int size = y_true.length;
21 | for (int i = 0; i < size; i++) {
22 | if (y_true[i] > y_pred[i])
23 | residual[i] = alpha;
24 | else
25 | residual[i] = -1.0 + alpha;
26 | }
27 | return residual;
28 | }
29 |
30 | @Override
31 | public void update_terminal_region(DecisionTree tree, double[][] x, double[] y, double[] y_pred,
32 | double[] sample_weight) {
33 | int n_samples = x.length;
34 | HashSet nodeSet = new HashSet();
35 | for (int i = 0; i < n_samples; i++) {
36 | Node leaf = tree.apply(x[i]);
37 | leaf.diff.add(y[i] - y_pred[i]);
38 | leaf.diff_sample_weight.add(sample_weight[i]);
39 | nodeSet.add(leaf);
40 | }
41 | for (Node leaf : nodeSet) {
42 | leaf.setTreeValue(Utils.weighted_percentile(leaf.diff, leaf.diff_sample_weight, this.percentile));
43 | // System.out.println(leaf+" "+leaf.treeVal);
44 | leaf.clear();
45 | }
46 | }
47 | }
--------------------------------------------------------------------------------
/src/objective/SquaresEstimator.java:
--------------------------------------------------------------------------------
1 | package objective;
2 |
3 | import java.util.Arrays;
4 |
5 | public class SquaresEstimator extends Estimator {
6 |
7 | private double mean;
8 |
9 | @Override
10 | public void fit(double[][] X, double[] y, double[] sample_weight) {
11 | // TODO Auto-generated method stub
12 | this.mean = Utils.averavge(y, sample_weight);
13 | }
14 |
15 | @Override
16 | public double[] predict(double[][] X) {
17 | double[] ans = new double[X.length];
18 | Arrays.fill(ans, mean);
19 | return ans;
20 | }
21 |
22 | }
23 |
--------------------------------------------------------------------------------
/src/objective/SquaresLossFunction.java:
--------------------------------------------------------------------------------
1 | package objective;
2 |
3 | import java.util.HashSet;
4 |
5 | import decision_tree.DecisionRegressionTree;
6 | import decision_tree.DecisionTree;
7 | import decision_tree.Node;
8 |
9 | public class SquaresLossFunction extends LossFunction {
10 |
11 | public SquaresLossFunction() {
12 | super();
13 | // TODO Auto-generated constructor stub
14 | }
15 |
16 | @Override
17 | public double[] negative_gradient(double[] y_true, double[] y_pred, double[] residual) {
18 | // TODO Auto-generated method stub
19 | for (int i = 0; i < y_true.length; i++)
20 | residual[i] = y_true[i] - y_pred[i];
21 | return residual;
22 | }
23 |
24 | @Override
25 | public void update_terminal_region(DecisionTree tree, double[][] x, double[] y, double[] y_pred,
26 | double[] sample_weight) {
27 | // TODO Auto-generated method stub
28 | int n_samples = x.length;
29 | HashSet nodeSet = new HashSet();
30 | for (int i = 0; i < n_samples; i++) {
31 | Node leaf = tree.apply(x[i]);
32 | leaf.diff.add(y[i] - y_pred[i]);
33 | leaf.diff_sample_weight.add(sample_weight[i]);
34 | nodeSet.add(leaf);
35 | }
36 | for (Node leaf : nodeSet) {
37 | leaf.setTreeValue(Utils.averavge(leaf.diff, leaf.diff_sample_weight));
38 | // System.out.println(leaf+" "+leaf.treeVal);
39 | leaf.clear();
40 | }
41 | }
42 |
43 | }
44 |
--------------------------------------------------------------------------------
/src/objective/Utils.java:
--------------------------------------------------------------------------------
1 | package objective;
2 |
3 | import java.util.ArrayList;
4 | import java.util.Arrays;
5 | import java.util.Collections;
6 | import java.util.Comparator;
7 | import java.util.LinkedList;
8 | import java.util.Random;
9 |
/**
 * Weighted-statistics helpers shared by the loss functions and estimators.
 * Generic type parameters were restored (the dump garbled
 * {@code Comparable<Item>} and stripped the list generics).
 */
public class Utils {

    /** A (y, weight) pair ordered by y ascending. */
    public static class Item implements Comparable<Item> {
        double y;
        double sample_weight;

        public Item(double y, double sample_weight) {
            this.y = y;
            this.sample_weight = sample_weight;
        }

        @Override
        public int compareTo(Item o) {
            // Double.compare avoids boxing (was: new Double(y).compareTo(o.y)).
            return Double.compare(y, o.y);
        }
    }

    /**
     * Weighted percentile of y (list form). Sorts by value, accumulates the
     * weights, and returns the value just below the first cumulative weight
     * exceeding {@code totalWeight * percentile / 100} — i.e. a
     * lower-biased weighted percentile; the largest value is returned when
     * no prefix exceeds the target.
     */
    public static double weighted_percentile(ArrayList<Double> y, ArrayList<Double> sample_weight, double percentile) {
        Item[] items = new Item[y.size()];
        for (int i = 0; i < y.size(); i++)
            items[i] = new Item(y.get(i), sample_weight.get(i));
        Arrays.sort(items);
        double[] weight_sum = new double[y.size()];
        for (int i = 0; i < y.size(); i++) {
            weight_sum[i] = items[i].sample_weight;
            if (i > 0)
                weight_sum[i] += weight_sum[i - 1];
        }
        double target = weight_sum[y.size() - 1] * percentile / 100.0;
        for (int i = 1; i < y.size(); i++)
            if (weight_sum[i] > target) {
                return items[i - 1].y;
            }
        return items[y.size() - 1].y;
    }

    /** Weighted percentile of y (array form); same semantics as the list overload. */
    public static double weighted_percentile(double[] y, double[] sample_weight, double percentile) {
        Item[] items = new Item[y.length];
        for (int i = 0; i < y.length; i++) {
            items[i] = new Item(y[i], sample_weight[i]);
        }
        Arrays.sort(items);
        double[] weight_sum = new double[y.length];
        for (int i = 0; i < y.length; i++) {
            weight_sum[i] = items[i].sample_weight;
            if (i > 0)
                weight_sum[i] += weight_sum[i - 1];
        }
        double target = weight_sum[y.length - 1] * percentile / 100.0;
        for (int i = 1; i < y.length; i++) {
            if (weight_sum[i] > target) {
                return items[i - 1].y;
            }
        }
        return items[y.length - 1].y;
    }

    /** Writes m distinct values drawn uniformly from {0, ..., n-1} into ret[0..m-1]. */
    public static void sample(int n, int m, int[] ret) {
        LinkedList<Integer> list = new LinkedList<Integer>();
        for (int i = 0; i < n; i++)
            list.add(i);
        Collections.shuffle(list);
        for (int i = 0; i < m; i++)
            ret[i] = list.pop();
    }

    /**
     * Weighted average sigma(diff_i * w_i) / sigma(w_i) (list form).
     * Name kept as "averavge" (sic) for compatibility with existing callers.
     */
    public static double averavge(ArrayList<Double> diff, ArrayList<Double> diff_sample_weight) {
        double sum = 0, sumWeight = 0;
        for (int i = 0; i < diff.size(); i++) {
            sum += diff.get(i) * diff_sample_weight.get(i);
            sumWeight += diff_sample_weight.get(i);
        }
        return sum / sumWeight;
    }

    /** Weighted average sigma(y_i * w_i) / sigma(w_i) (array form). */
    public static double averavge(double[] y, double[] sample_weight) {
        double sum = 0, sumWeight = 0;
        for (int i = 0; i < y.length; i++) {
            sum += y[i] * sample_weight[i];
            sumWeight += sample_weight[i];
        }
        return sum / sumWeight;
    }
}
--------------------------------------------------------------------------------
/src/rf/RandomForestRegressor.java:
--------------------------------------------------------------------------------
1 | package rf;
2 |
3 | import java.util.Arrays;
4 | import java.util.Comparator;
5 | import java.util.HashMap;
6 |
7 | import data.LabeledSample;
8 | import decision_tree.DecisionRegressionTree;
9 | import decision_tree.DecisionRegressionTreeLeafSpliter;
10 | import decision_tree.DecisionTree;
11 | import decision_tree.Node;
12 | import objective.LossFunction;
13 | import objective.SquaresLossFunction;
14 | import util.ParamException;
15 | import util.ParamReader;
16 |
17 | public class RandomForestRegressor {
18 |
19 | private String spliter;
20 | private int num_leafs;
21 | private int max_depth;
22 | private int n_estimator;
23 | private int random_state;
24 | private double sample_rate;
25 | private double feature_rate;
26 | private double min_leaf_sample;
27 | private DecisionTree[] tree;
28 | private LabeledSample[] trainData;
29 | private LabeledSample[][] preSortedSampleArrays;
30 | private LabeledSample[][] copyOfPreSortedSampleArrays;
31 | private LossFunction loss;
32 |
33 | public RandomForestRegressor(HashMap params) throws ParamException {
34 |
35 | // -----------------------------------------------------------------------
36 | // 必需参数
37 |
38 | assert (params.containsKey("spliter"));
39 | this.spliter = ParamReader.readString("spliter", params);
40 | if (this.spliter.equals("leaf")) {
41 | assert (params.containsKey("num_leafs"));
42 | this.num_leafs = ParamReader.readInt("num_leafs", params);
43 | } else {
44 | assert (params.containsKey("max_depth"));
45 | this.max_depth = ParamReader.readInt("max_depth", params);
46 | }
47 | assert (params.containsKey("n_estimator"));
48 | this.n_estimator = ParamReader.readInt("n_estimator", params);
49 |
50 | // -------------------------------------------------------------------------
51 | if (params.containsKey("random_state"))
52 | this.random_state = ParamReader.readInt("random_state", params);
53 | else
54 | this.random_state = 0;
55 | if (params.containsKey("sample_rate"))
56 | this.sample_rate = ParamReader.readDouble("sample_rate", params);
57 | else
58 | this.sample_rate = 1.0;
59 | if (params.containsKey("feature_rate"))
60 | this.feature_rate = ParamReader.readDouble("feature_rate", params);
61 | else
62 | this.feature_rate = 1.0;
63 |
64 | if (params.containsKey("min_leaf_sample"))
65 | this.min_leaf_sample = ParamReader.readDouble("min_leaf_sample", params);
66 | else
67 | this.min_leaf_sample = 1;
68 | }
69 |
70 | public void fit(double[][] X, double[] Y, double[] sample_weight) {
71 |
72 | int featureNum = X[0].length;
73 | int sampleNum = X.length;
74 | this.trainData = new LabeledSample[sampleNum];
75 | for (int i = 0; i < sampleNum; i++) {
76 | trainData[i] = new LabeledSample();
77 | trainData[i].x = X[i];
78 | trainData[i].y = Y[i];
79 | trainData[i].weight = sample_weight[i];
80 | }
81 | this.loss = new SquaresLossFunction();
82 |
83 | this.preSortedSampleArrays = new LabeledSample[featureNum][sampleNum];
84 | this.copyOfPreSortedSampleArrays = new LabeledSample[featureNum][sampleNum];
85 | for (int a = 0; a < featureNum; a++) {
86 | for (int b = 0; b < sampleNum; b++) {
87 | copyOfPreSortedSampleArrays[a][b] = trainData[b];
88 | }
89 |
90 | final int compareFeature = a;
91 | Arrays.sort(copyOfPreSortedSampleArrays[a], 0, sampleNum, new Comparator() {
92 |
93 | @Override
94 | public int compare(LabeledSample o1, LabeledSample o2) {
95 | return new Double(o1.x[compareFeature]).compareTo(o2.x[compareFeature]);
96 | }
97 | });
98 | }
99 |
100 | this.tree = new DecisionTree[this.n_estimator];
101 | for (int i = 0; i < n_estimator; i++) {
102 | for (int a = 0; a < featureNum; a++) {
103 | for (int b = 0; b < sampleNum; b++) {
104 | preSortedSampleArrays[a][b] = copyOfPreSortedSampleArrays[a][b];
105 | }
106 | }
107 | if (this.spliter.equals("leaf"))
108 | this.tree[i] = new DecisionRegressionTreeLeafSpliter(this.num_leafs, i, sample_rate, feature_rate,
109 | min_leaf_sample, preSortedSampleArrays);
110 | else
111 | this.tree[i] = new DecisionRegressionTree(i, this.max_depth, i, sample_rate, feature_rate,
112 | min_leaf_sample, preSortedSampleArrays);
113 | System.out.println("第" + i + "棵树");
114 | this.tree[i].fit(X, Y, sample_weight);
115 | double[] pred = new double[X.length];
116 | this.loss.update_terminal_region(this.tree[i], X, Y, pred, sample_weight);
117 | }
118 | }
119 |
120 | public double[][] predict(double[][] X) {
121 | double ans[][] = new double[this.n_estimator][X.length];
122 | for (int i = 0; i < this.n_estimator; i++) {
123 | for (int j = 0; j < X.length; j++) {
124 | ans[i][j] = this.tree[i].apply(X[j]).treeVal;
125 | }
126 | }
127 | return ans;
128 | }
129 | }
130 |
--------------------------------------------------------------------------------
/src/util/BoostingListener.java:
--------------------------------------------------------------------------------
1 | package util;
2 |
// Callback notified as boosting progresses.
3 | public interface BoostingListener {
// Invoked after a tree finishes training; treeNum is that tree's index.
4 | public void done(int treeNum);
5 | }
6 |
--------------------------------------------------------------------------------
/src/util/DumpTree.java:
--------------------------------------------------------------------------------
1 | package util;
2 |
// Placeholder for tree-serialization utilities; currently empty. The related
// dump logic exists only as commented-out code in the tree classes.
3 | public class DumpTree {
4 |
5 | }
6 |
--------------------------------------------------------------------------------
/src/util/ParamException.java:
--------------------------------------------------------------------------------
1 | package util;
2 |
/**
 * Thrown by {@link ParamReader} when a required configuration key is absent.
 */
public class ParamException extends Exception {

    // Exception implements Serializable; pin the ID explicitly.
    private static final long serialVersionUID = 1L;

    /**
     * @param string description of the missing or invalid parameter
     */
    public ParamException(String string) {
        super(string);
    }
}
11 |
--------------------------------------------------------------------------------
/src/util/ParamReader.java:
--------------------------------------------------------------------------------
1 | package util;
2 |
3 | import java.util.HashMap;
4 |
5 | public class ParamReader {
6 | public static int readInt(String key, HashMap map) throws ParamException {
7 | if (!map.containsKey(key))
8 | throw new ParamException("Params doesn't have key " + key);
9 | return Integer.valueOf(map.get(key).toString());
10 | }
11 |
12 | public static double readDouble(String key, HashMap map) throws ParamException {
13 | if (!map.containsKey(key))
14 | throw new ParamException("Params doesn't have key " + key);
15 | return Double.valueOf(map.get(key).toString());
16 | }
17 |
18 | public static String readString(String key, HashMap map) throws ParamException {
19 | if (!map.containsKey(key))
20 | throw new ParamException("Params doesn't have key " + key);
21 | return map.get(key).toString();
22 | }
23 | }
24 |
--------------------------------------------------------------------------------
/src/util/dump/TreeInfo.java:
--------------------------------------------------------------------------------
1 | package util.dump;
2 |
3 | // Based on the corresponding database table.
4 |
// Flat record of one tree node, mirroring a row of the backing data table.
// NOTE(review): per the commented-out dump code in the tree classes,
// left_son/right_son are node ids (-1 = no child) and is_root is 0/1 —
// confirm against the actual table schema.
5 | public class TreeInfo {
6 | public long root_id;
7 | public long left_son;
8 | public long right_son;
9 | public long split_feature;
10 | public double split_feature_value;
11 | public long estimator_num;
12 | public long is_root;
13 | public double node_value;
14 |
15 | public TreeInfo(long root_id, long left_son, long right_son, long split_feature, double split_feature_value,
16 | long estimator_num, long is_root, double node_value) {
17 | this.root_id = root_id;
18 | this.left_son = left_son;
19 | this.right_son = right_son;
20 | this.split_feature = split_feature;
21 | this.split_feature_value = split_feature_value;
22 | this.estimator_num = estimator_num;
23 | this.is_root = is_root;
24 | this.node_value = node_value;
25 | }
26 |
27 | }
28 |
--------------------------------------------------------------------------------