214 | * This method only works when the model outputs single value. 215 | *
216 | * 217 | * @param feat feature vector 218 | * @return prediction value 219 | */ 220 | public float predictSingle(FVec feat) { 221 | return predictSingle(feat, false); 222 | } 223 | 224 | /** 225 | * Generates a prediction for given feature vector. 226 | *227 | * This method only works when the model outputs single value. 228 | *
229 | * 230 | * @param feat feature vector 231 | * @param output_margin whether to only predict margin value instead of transformed prediction 232 | * @return prediction value 233 | */ 234 | public float predictSingle(FVec feat, boolean output_margin) { 235 | return predictSingle(feat, output_margin, 0); 236 | } 237 | 238 | /** 239 | * Generates a prediction for given feature vector. 240 | *241 | * This method only works when the model outputs single value. 242 | *
243 | * 244 | * @param feat feature vector 245 | * @param output_margin whether to only predict margin value instead of transformed prediction 246 | * @param ntree_limit limit the number of trees used in prediction 247 | * @return prediction value 248 | */ 249 | public float predictSingle(FVec feat, boolean output_margin, int ntree_limit) { 250 | float pred = predictSingleRaw(feat, ntree_limit); 251 | if (!output_margin) { 252 | pred = obj.predTransform(pred); 253 | } 254 | return pred; 255 | } 256 | 257 | float predictSingleRaw(FVec feat, int ntree_limit) { 258 | if (isBeforeOrEqual12()) { 259 | return gbm.predictSingle(feat, ntree_limit, 0) + base_score; 260 | } else { 261 | return gbm.predictSingle(feat, ntree_limit, base_score); 262 | } 263 | } 264 | 265 | /** 266 | * Predicts leaf index of each tree. 267 | * 268 | * @param feat feature vector 269 | * @return leaf indexes 270 | */ 271 | public int[] predictLeaf(FVec feat) { 272 | return predictLeaf(feat, 0); 273 | } 274 | 275 | /** 276 | * Predicts leaf index of each tree. 277 | * 278 | * @param feat feature vector 279 | * @param ntree_limit limit, 0 for all 280 | * @return leaf indexes 281 | */ 282 | public int[] predictLeaf(FVec feat, int ntree_limit) { 283 | return gbm.predictLeaf(feat, ntree_limit); 284 | } 285 | 286 | /** 287 | * Predicts path to leaf of each tree. 288 | * 289 | * @param feat feature vector 290 | * @return leaf paths 291 | */ 292 | public String[] predictLeafPath(FVec feat) { 293 | return predictLeafPath(feat, 0); 294 | } 295 | 296 | /** 297 | * Predicts path to leaf of each tree. 298 | * 299 | * @param feat feature vector 300 | * @param ntree_limit limit, 0 for all 301 | * @return leaf paths 302 | */ 303 | public String[] predictLeafPath(FVec feat, int ntree_limit) { 304 | return gbm.predictLeafPath(feat, ntree_limit); 305 | } 306 | 307 | public SparkModelParam getSparkModelParam() { 308 | return sparkModelParam; 309 | } 310 | 311 | /** 312 | * Returns number of class. 313 | * 314 | * @return number of class 315 | */ 316 | public int getNumClass() { 317 | return mparam.num_class; 318 | } 319 | 320 | /** 321 | * Used e.g. for the change od floating point operation order in between xgboost 1.2 and 1.3 322 | * 323 | * @return True if the booster was build with xgboost version <= 1.2. 324 | */ 325 | private boolean isBeforeOrEqual12() { 326 | return mparam.major_version < 1 || (mparam.major_version == 1 && mparam.minor_version <= 2); 327 | } 328 | 329 | /** 330 | * Parameters. 331 | */ 332 | static class ModelParam implements Serializable { 333 | /* \brief global bias */ 334 | final float base_score; 335 | /* \brief number of features */ 336 | final /* unsigned */ int num_feature; 337 | /* \brief number of class, if it is multi-class classification */ 338 | final int num_class; 339 | /*! \brief whether the model itself is saved with pbuffer */ 340 | final int saved_with_pbuffer; 341 | /*! \brief Model contain eval metrics */ 342 | private final int contain_eval_metrics; 343 | /*! \brief the version of XGBoost. */ 344 | private final int major_version; 345 | private final int minor_version; 346 | /*! \brief reserved field */ 347 | final int[] reserved; 348 | 349 | ModelParam(float base_score, int num_feature, ModelReader reader) throws IOException { 350 | this.base_score = base_score; 351 | this.num_feature = num_feature; 352 | this.num_class = reader.readInt(); 353 | this.saved_with_pbuffer = reader.readInt(); 354 | this.contain_eval_metrics = reader.readInt(); 355 | this.major_version = reader.readUnsignedInt(); 356 | this.minor_version = reader.readUnsignedInt(); 357 | this.reserved = reader.readIntArray(27); 358 | } 359 | } 360 | 361 | public GradBooster getBooster(){ 362 | return gbm; 363 | } 364 | 365 | public String getObjName() { 366 | return name_obj; 367 | } 368 | 369 | public float getBaseScore() { 370 | return base_score; 371 | } 372 | 373 | } 374 | -------------------------------------------------------------------------------- /xgboost-predictor/src/main/java/biz/k11i/xgboost/config/PredictorConfiguration.java: -------------------------------------------------------------------------------- 1 | package biz.k11i.xgboost.config; 2 | 3 | import biz.k11i.xgboost.learner.ObjFunction; 4 | import biz.k11i.xgboost.tree.DefaultRegTreeFactory; 5 | import biz.k11i.xgboost.tree.RegTreeFactory; 6 | 7 | public class PredictorConfiguration { 8 | public static class Builder { 9 | private PredictorConfiguration predictorConfiguration; 10 | 11 | Builder() { 12 | predictorConfiguration = new PredictorConfiguration(); 13 | } 14 | 15 | public Builder objFunction(ObjFunction objFunction) { 16 | predictorConfiguration.objFunction = objFunction; 17 | return this; 18 | } 19 | 20 | public Builder regTreeFactory(RegTreeFactory regTreeFactory) { 21 | predictorConfiguration.regTreeFactory = regTreeFactory; 22 | return this; 23 | } 24 | 25 | public PredictorConfiguration build() { 26 | PredictorConfiguration result = predictorConfiguration; 27 | predictorConfiguration = null; 28 | return result; 29 | } 30 | } 31 | 32 | public static final PredictorConfiguration DEFAULT = new PredictorConfiguration(); 33 | 34 | private ObjFunction objFunction; 35 | private RegTreeFactory regTreeFactory; 36 | 37 | public PredictorConfiguration() { 38 | this.regTreeFactory = DefaultRegTreeFactory.INSTANCE; 39 | } 40 | 41 | public ObjFunction getObjFunction() { 42 | return objFunction; 43 | } 44 | 45 | public RegTreeFactory getRegTreeFactory() { 46 | return regTreeFactory; 47 | } 48 | 49 | public static Builder builder() { 50 | return new Builder(); 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /xgboost-predictor/src/main/java/biz/k11i/xgboost/gbm/Dart.java: -------------------------------------------------------------------------------- 1 | package biz.k11i.xgboost.gbm; 2 | 3 | import biz.k11i.xgboost.config.PredictorConfiguration; 4 | import biz.k11i.xgboost.tree.RegTree; 5 | import biz.k11i.xgboost.util.FVec; 6 | import biz.k11i.xgboost.util.ModelReader; 7 | 8 | import java.io.IOException; 9 | import java.util.Arrays; 10 | 11 | /** 12 | * Gradient boosted DART tree implementation. 13 | */ 14 | public class Dart extends GBTree { 15 | private float[] weightDrop; 16 | 17 | Dart() { 18 | // do nothing 19 | } 20 | 21 | @Override 22 | public void loadModel(PredictorConfiguration config, ModelReader reader, boolean with_pbuffer) throws IOException { 23 | super.loadModel(config, reader, with_pbuffer); 24 | if (mparam.num_trees != 0) { 25 | long size = reader.readLong(); 26 | weightDrop = reader.readFloatArray((int)size); 27 | } 28 | } 29 | 30 | @Override 31 | float pred(FVec feat, int bst_group, int root_index, int ntree_limit, float base_score) { 32 | RegTree[] trees = _groupTrees[bst_group]; 33 | int treeleft = ntree_limit == 0 ? trees.length : ntree_limit; 34 | 35 | float psum = base_score; 36 | for (int i = 0; i < treeleft; i++) { 37 | psum += weightDrop[i] * trees[i].getLeafValue(feat, root_index); 38 | } 39 | 40 | return psum; 41 | } 42 | 43 | public float weight(int tidx) { 44 | return weightDrop[tidx]; 45 | } 46 | 47 | } 48 | -------------------------------------------------------------------------------- /xgboost-predictor/src/main/java/biz/k11i/xgboost/gbm/GBLinear.java: -------------------------------------------------------------------------------- 1 | package biz.k11i.xgboost.gbm; 2 | 3 | import biz.k11i.xgboost.config.PredictorConfiguration; 4 | import biz.k11i.xgboost.util.FVec; 5 | import biz.k11i.xgboost.util.ModelReader; 6 | 7 | import java.io.IOException; 8 | import java.io.Serializable; 9 | 10 | /** 11 | * Linear booster implementation 12 | */ 13 | public class GBLinear extends GBBase { 14 | 15 | private float[] weights; 16 | 17 | @Override 18 | public void loadModel(PredictorConfiguration config, ModelReader reader, boolean ignored_with_pbuffer) throws IOException { 19 | new ModelParam(reader); 20 | long len = reader.readLong(); 21 | if (len == 0) { 22 | weights = new float[(num_feature + 1) * num_output_group]; 23 | } else { 24 | weights = reader.readFloatArray((int) len); 25 | } 26 | } 27 | 28 | @Override 29 | public float[] predict(FVec feat, int ntree_limit, float base_score) { 30 | float[] preds = new float[num_output_group]; 31 | for (int gid = 0; gid < num_output_group; ++gid) { 32 | preds[gid] = pred(feat, gid, base_score); 33 | } 34 | return preds; 35 | } 36 | 37 | @Override 38 | public float predictSingle(FVec feat, int ntree_limit, float base_score) { 39 | if (num_output_group != 1) { 40 | throw new IllegalStateException( 41 | "Can't invoke predictSingle() because this model outputs multiple values: " 42 | + num_output_group); 43 | } 44 | return pred(feat, 0, base_score); 45 | } 46 | 47 | float pred(FVec feat, int gid, float base_score) { 48 | float psum = bias(gid) + base_score; 49 | float featValue; 50 | for (int fid = 0; fid < num_feature; ++fid) { 51 | featValue = feat.fvalue(fid); 52 | if (!Float.isNaN(featValue)) { 53 | psum += featValue * weight(fid, gid); 54 | } 55 | } 56 | return psum; 57 | } 58 | 59 | @Override 60 | public int[] predictLeaf(FVec feat, int ntree_limit) { 61 | throw new UnsupportedOperationException("gblinear does not support predict leaf index"); 62 | } 63 | 64 | @Override 65 | public String[] predictLeafPath(FVec feat, int ntree_limit) { 66 | throw new UnsupportedOperationException("gblinear does not support predict leaf path"); 67 | } 68 | 69 | public float weight(int fid, int gid) { 70 | return weights[(fid * num_output_group) + gid]; 71 | } 72 | 73 | public float bias(int gid) { 74 | return weights[(num_feature * num_output_group) + gid]; 75 | } 76 | 77 | static class ModelParam implements Serializable { 78 | /*! \brief reserved space */ 79 | final int[] reserved; 80 | 81 | ModelParam(ModelReader reader) throws IOException { 82 | reader.readUnsignedInt(); // num_feature deprecated 83 | reader.readInt(); // num_output_group deprecated 84 | reserved = reader.readIntArray(32); 85 | } 86 | } 87 | 88 | public int getNumFeature() { 89 | return num_feature; 90 | } 91 | 92 | public int getNumOutputGroup() { 93 | return num_output_group; 94 | } 95 | 96 | } 97 | -------------------------------------------------------------------------------- /xgboost-predictor/src/main/java/biz/k11i/xgboost/gbm/GBTree.java: -------------------------------------------------------------------------------- 1 | package biz.k11i.xgboost.gbm; 2 | 3 | import biz.k11i.xgboost.config.PredictorConfiguration; 4 | import biz.k11i.xgboost.tree.RegTree; 5 | import biz.k11i.xgboost.util.FVec; 6 | import biz.k11i.xgboost.util.ModelReader; 7 | 8 | import java.io.IOException; 9 | import java.io.Serializable; 10 | import java.util.Arrays; 11 | 12 | /** 13 | * Gradient boosted tree implementation. 14 | */ 15 | public class GBTree extends GBBase { 16 | 17 | ModelParam mparam; 18 | private RegTree[] trees; 19 | 20 | RegTree[][] _groupTrees; 21 | 22 | @Override 23 | public void loadModel(PredictorConfiguration config, ModelReader reader, boolean with_pbuffer) throws IOException { 24 | mparam = new ModelParam(reader); 25 | 26 | trees = new RegTree[mparam.num_trees]; 27 | for (int i = 0; i < mparam.num_trees; i++) { 28 | trees[i] = config.getRegTreeFactory().loadTree(reader); 29 | } 30 | 31 | int[] tree_info = mparam.num_trees > 0 ? reader.readIntArray(mparam.num_trees) : new int[0]; 32 | 33 | if (mparam.num_pbuffer != 0 && with_pbuffer) { 34 | reader.skip(4 * predBufferSize()); 35 | reader.skip(4 * predBufferSize()); 36 | } 37 | 38 | _groupTrees = new RegTree[num_output_group][]; 39 | for (int i = 0; i < num_output_group; i++) { 40 | int treeCount = 0; 41 | for (int j = 0; j < tree_info.length; j++) { 42 | if (tree_info[j] == i) { 43 | treeCount++; 44 | } 45 | } 46 | 47 | _groupTrees[i] = new RegTree[treeCount]; 48 | treeCount = 0; 49 | 50 | for (int j = 0; j < tree_info.length; j++) { 51 | if (tree_info[j] == i) { 52 | _groupTrees[i][treeCount++] = trees[j]; 53 | } 54 | } 55 | } 56 | } 57 | 58 | @Override 59 | public float[] predict(FVec feat, int ntree_limit, float base_score) { 60 | float[] preds = new float[num_output_group]; 61 | for (int gid = 0; gid < num_output_group; gid++) { 62 | preds[gid] += pred(feat, gid, 0, ntree_limit, base_score); 63 | } 64 | return preds; 65 | } 66 | 67 | @Override 68 | public float predictSingle(FVec feat, int ntree_limit, float base_score) { 69 | if (num_output_group != 1) { 70 | throw new IllegalStateException( 71 | "Can't invoke predictSingle() because this model outputs multiple values: " 72 | + num_output_group); 73 | } 74 | return pred(feat, 0, 0, ntree_limit, base_score); 75 | } 76 | 77 | float pred(FVec feat, int bst_group, int root_index, int ntree_limit, float base_score) { 78 | RegTree[] trees = _groupTrees[bst_group]; 79 | int treeleft = ntree_limit == 0 ? trees.length : ntree_limit; 80 | 81 | float psum = base_score; 82 | for (int i = 0; i < treeleft; i++) { 83 | psum += trees[i].getLeafValue(feat, root_index); 84 | } 85 | return psum; 86 | } 87 | 88 | @Override 89 | public int[] predictLeaf(FVec feat, int ntree_limit) { 90 | int treeleft = ntree_limit == 0 ? trees.length : ntree_limit; 91 | int[] leafIndex = new int[treeleft]; 92 | for (int i = 0; i < treeleft; i++) { 93 | leafIndex[i] = trees[i].getLeafIndex(feat); 94 | } 95 | return leafIndex; 96 | } 97 | 98 | @Override 99 | public String[] predictLeafPath(FVec feat, int ntree_limit) { 100 | int treeleft = ntree_limit == 0 ? trees.length : ntree_limit; 101 | String[] leafPath = new String[treeleft]; 102 | StringBuilder sb = new StringBuilder(64); 103 | for (int i = 0; i < treeleft; i++) { 104 | trees[i].getLeafPath(feat, sb); 105 | leafPath[i] = sb.toString(); 106 | sb.setLength(0); 107 | } 108 | return leafPath; 109 | } 110 | 111 | private long predBufferSize() { 112 | return num_output_group * mparam.num_pbuffer * (mparam.size_leaf_vector + 1); 113 | } 114 | 115 | static class ModelParam implements Serializable { 116 | /*! \brief number of trees */ 117 | final int num_trees; 118 | /*! \brief number of root: default 0, means single tree */ 119 | final int num_roots; 120 | /*! \brief size of predicton buffer allocated used for buffering */ 121 | final long num_pbuffer; 122 | /*! \brief size of leaf vector needed in tree */ 123 | final int size_leaf_vector; 124 | /*! \brief reserved space */ 125 | final int[] reserved; 126 | 127 | ModelParam(ModelReader reader) throws IOException { 128 | num_trees = reader.readInt(); 129 | num_roots = reader.readInt(); 130 | reader.readInt(); // num_feature deprecated 131 | reader.readInt(); // read padding 132 | num_pbuffer = reader.readLong(); 133 | reader.readInt(); // num_output_group not used anymore 134 | size_leaf_vector = reader.readInt(); 135 | reserved = reader.readIntArray(31); 136 | reader.readInt(); // read padding 137 | } 138 | 139 | } 140 | 141 | /** 142 | * 143 | * @return A two-dim array, with trees grouped into classes. 144 | */ 145 | public RegTree[][] getGroupedTrees(){ 146 | return _groupTrees; 147 | } 148 | } 149 | -------------------------------------------------------------------------------- /xgboost-predictor/src/main/java/biz/k11i/xgboost/gbm/GradBooster.java: -------------------------------------------------------------------------------- 1 | package biz.k11i.xgboost.gbm; 2 | 3 | import biz.k11i.xgboost.config.PredictorConfiguration; 4 | import biz.k11i.xgboost.util.FVec; 5 | import biz.k11i.xgboost.util.ModelReader; 6 | 7 | import java.io.IOException; 8 | import java.io.Serializable; 9 | 10 | /** 11 | * Interface of gradient boosting model. 12 | */ 13 | public interface GradBooster extends Serializable { 14 | 15 | class Factory { 16 | /** 17 | * Creates a gradient booster from given name. 18 | * 19 | * @param name name of gradient booster 20 | * @return created gradient booster 21 | */ 22 | public static GradBooster createGradBooster(String name) { 23 | if ("gbtree".equals(name)) { 24 | return new GBTree(); 25 | } else if ("gblinear".equals(name)) { 26 | return new GBLinear(); 27 | } else if ("dart".equals(name)) { 28 | return new Dart(); 29 | } 30 | 31 | throw new IllegalArgumentException(name + " is not supported model."); 32 | } 33 | } 34 | 35 | void setNumClass(int numClass); 36 | void setNumFeature(int numFeature); 37 | 38 | /** 39 | * Loads model from stream. 40 | * 41 | * @param config predictor configuration 42 | * @param reader input stream 43 | * @param with_pbuffer whether the incoming data contains pbuffer 44 | * @throws IOException If an I/O error occurs 45 | */ 46 | void loadModel(PredictorConfiguration config, ModelReader reader, boolean with_pbuffer) throws IOException; 47 | 48 | /** 49 | * Generates predictions for given feature vector. 50 | * 51 | * @param feat feature vector 52 | * @param ntree_limit limit the number of trees used in prediction 53 | * @param base_score base score to initialize prediction 54 | * @return prediction result 55 | */ 56 | float[] predict(FVec feat, int ntree_limit, float base_score); 57 | 58 | /** 59 | * Generates a prediction for given feature vector. 60 | *61 | * This method only works when the model outputs single value. 62 | *
63 | * 64 | * @param feat feature vector 65 | * @param ntree_limit limit the number of trees used in prediction 66 | * @param base_score base score to initialize prediction 67 | * @return prediction result 68 | */ 69 | float predictSingle(FVec feat, int ntree_limit, float base_score); 70 | 71 | /** 72 | * Predicts the leaf index of each tree. This is only valid in gbtree predictor. 73 | * 74 | * @param feat feature vector 75 | * @param ntree_limit limit the number of trees used in prediction 76 | * @return predicted leaf indexes 77 | */ 78 | int[] predictLeaf(FVec feat, int ntree_limit); 79 | 80 | /** 81 | * Predicts the path to leaf of each tree. This is only valid in gbtree predictor. 82 | * 83 | * @param feat feature vector 84 | * @param ntree_limit limit the number of trees used in prediction 85 | * @return predicted path to leaves 86 | */ 87 | String[] predictLeafPath(FVec feat, int ntree_limit); 88 | 89 | } 90 | 91 | abstract class GBBase implements GradBooster { 92 | protected int num_class; 93 | protected int num_feature; 94 | protected int num_output_group; 95 | 96 | @Override 97 | public void setNumClass(int numClass) { 98 | this.num_class = numClass; 99 | this.num_output_group = (num_class == 0) ? 1 : num_class; 100 | } 101 | 102 | @Override 103 | public void setNumFeature(int numFeature) { 104 | this.num_feature = numFeature; 105 | } 106 | } -------------------------------------------------------------------------------- /xgboost-predictor/src/main/java/biz/k11i/xgboost/learner/ObjFunction.java: -------------------------------------------------------------------------------- 1 | package biz.k11i.xgboost.learner; 2 | 3 | import biz.k11i.xgboost.config.PredictorConfiguration; 4 | import net.jafama.FastMath; 5 | 6 | import java.io.Serializable; 7 | import java.util.HashMap; 8 | import java.util.Map; 9 | 10 | /** 11 | * Objective function implementations. 12 | */ 13 | public class ObjFunction implements Serializable { 14 | 15 | private static final Map156 | * Jafama's {@link FastMath#exp(double)} version. 157 | *
158 | */ 159 | static class RegLossObjLogistic_Jafama extends RegLossObjLogistic { 160 | @Override 161 | float sigmoid(float x) { 162 | return (float) (1 / (1 + FastMath.exp(-x))); 163 | } 164 | } 165 | 166 | /** 167 | * Multiclass classification. 168 | */ 169 | static class SoftmaxMultiClassObjClassify extends ObjFunction { 170 | @Override 171 | public float[] predTransform(float[] preds) { 172 | int maxIndex = 0; 173 | float max = preds[0]; 174 | for (int i = 1; i < preds.length; i++) { 175 | if (max < preds[i]) { 176 | maxIndex = i; 177 | max = preds[i]; 178 | } 179 | } 180 | 181 | return new float[]{maxIndex}; 182 | } 183 | 184 | @Override 185 | public float predTransform(float pred) { 186 | throw new UnsupportedOperationException(); 187 | } 188 | } 189 | 190 | /** 191 | * Multiclass classification (predicted probability). 192 | */ 193 | static class SoftmaxMultiClassObjProb extends ObjFunction { 194 | @Override 195 | public float[] predTransform(float[] preds) { 196 | float max = preds[0]; 197 | for (int i = 1; i < preds.length; i++) { 198 | max = Math.max(preds[i], max); 199 | } 200 | 201 | double sum = 0; 202 | for (int i = 0; i < preds.length; i++) { 203 | preds[i] = exp(preds[i] - max); 204 | sum += preds[i]; 205 | } 206 | 207 | for (int i = 0; i < preds.length; i++) { 208 | preds[i] /= (float) sum; 209 | } 210 | 211 | return preds; 212 | } 213 | 214 | @Override 215 | public float predTransform(float pred) { 216 | throw new UnsupportedOperationException(); 217 | } 218 | 219 | float exp(float x) { 220 | return (float) Math.exp(x); 221 | } 222 | } 223 | 224 | /** 225 | * Multiclass classification (predicted probability). 226 | *227 | * Jafama's {@link FastMath#exp(double)} version. 228 | *
229 | */ 230 | static class SoftmaxMultiClassObjProb_Jafama extends SoftmaxMultiClassObjProb { 231 | @Override 232 | float exp(float x) { 233 | return (float) FastMath.exp(x); 234 | } 235 | } 236 | } 237 | -------------------------------------------------------------------------------- /xgboost-predictor/src/main/java/biz/k11i/xgboost/spark/SparkModelParam.java: -------------------------------------------------------------------------------- 1 | package biz.k11i.xgboost.spark; 2 | 3 | import biz.k11i.xgboost.util.ModelReader; 4 | 5 | import java.io.IOException; 6 | import java.io.Serializable; 7 | 8 | public class SparkModelParam implements Serializable { 9 | public static final String MODEL_TYPE_CLS = "_cls_"; 10 | public static final String MODEL_TYPE_REG = "_reg_"; 11 | 12 | final String modelType; 13 | final String featureCol; 14 | 15 | final String labelCol; 16 | final String predictionCol; 17 | 18 | // classification model only 19 | final String rawPredictionCol; 20 | final double[] thresholds; 21 | 22 | public SparkModelParam(String modelType, String featureCol, ModelReader reader) throws IOException { 23 | this.modelType = modelType; 24 | this.featureCol = featureCol; 25 | this.labelCol = reader.readUTF(); 26 | this.predictionCol = reader.readUTF(); 27 | 28 | if (MODEL_TYPE_CLS.equals(modelType)) { 29 | this.rawPredictionCol = reader.readUTF(); 30 | int thresholdLength = reader.readIntBE(); 31 | this.thresholds = thresholdLength > 0 ? reader.readDoubleArrayBE(thresholdLength) : null; 32 | 33 | } else if (MODEL_TYPE_REG.equals(modelType)) { 34 | this.rawPredictionCol = null; 35 | this.thresholds = null; 36 | 37 | } else { 38 | throw new UnsupportedOperationException("Unknown modelType: " + modelType); 39 | } 40 | } 41 | 42 | public String getModelType() { 43 | return modelType; 44 | } 45 | 46 | public String getFeatureCol() { 47 | return featureCol; 48 | } 49 | 50 | public String getLabelCol() { 51 | return labelCol; 52 | } 53 | 54 | public String getPredictionCol() { 55 | return predictionCol; 56 | } 57 | 58 | public String getRawPredictionCol() { 59 | return rawPredictionCol; 60 | } 61 | 62 | public double[] getThresholds() { 63 | return thresholds; 64 | } 65 | } 66 | -------------------------------------------------------------------------------- /xgboost-predictor/src/main/java/biz/k11i/xgboost/tree/DefaultRegTreeFactory.java: -------------------------------------------------------------------------------- 1 | package biz.k11i.xgboost.tree; 2 | 3 | import biz.k11i.xgboost.util.ModelReader; 4 | 5 | import java.io.IOException; 6 | 7 | public final class DefaultRegTreeFactory implements RegTreeFactory { 8 | 9 | public static RegTreeFactory INSTANCE = new DefaultRegTreeFactory(); 10 | 11 | @Override 12 | public final RegTree loadTree(ModelReader reader) throws IOException { 13 | RegTreeImpl regTree = new RegTreeImpl(); 14 | regTree.loadModel(reader); 15 | return regTree; 16 | } 17 | 18 | } 19 | -------------------------------------------------------------------------------- /xgboost-predictor/src/main/java/biz/k11i/xgboost/tree/RegTree.java: -------------------------------------------------------------------------------- 1 | package biz.k11i.xgboost.tree; 2 | 3 | import biz.k11i.xgboost.util.FVec; 4 | 5 | import java.io.Serializable; 6 | 7 | /** 8 | * Regression tree. 9 | */ 10 | public interface RegTree extends Serializable { 11 | 12 | /** 13 | * Retrieves nodes from root to leaf and returns leaf index. 14 | * 15 | * @param feat feature vector 16 | * @return leaf index 17 | */ 18 | int getLeafIndex(FVec feat); 19 | 20 | /** 21 | * Retrieves nodes from root to leaf and returns path to leaf. 22 | * 23 | * @param feat feature vector 24 | * @param sb output param, will write path path to leaf into this buffer 25 | */ 26 | void getLeafPath(FVec feat, StringBuilder sb); 27 | 28 | /** 29 | * Retrieves nodes from root to leaf and returns leaf value. 30 | * 31 | * @param feat feature vector 32 | * @param root_id starting root index 33 | * @return leaf value 34 | */ 35 | float getLeafValue(FVec feat, int root_id); 36 | 37 | /** 38 | * 39 | * @return Tree's nodes 40 | */ 41 | RegTreeNode[] getNodes(); 42 | 43 | /** 44 | * @return Tree's nodes stats 45 | */ 46 | RegTreeNodeStat[] getStats(); 47 | 48 | } 49 | -------------------------------------------------------------------------------- /xgboost-predictor/src/main/java/biz/k11i/xgboost/tree/RegTreeFactory.java: -------------------------------------------------------------------------------- 1 | package biz.k11i.xgboost.tree; 2 | 3 | import biz.k11i.xgboost.util.ModelReader; 4 | 5 | import java.io.IOException; 6 | 7 | public interface RegTreeFactory { 8 | 9 | RegTree loadTree(ModelReader reader) throws IOException; 10 | 11 | } 12 | -------------------------------------------------------------------------------- /xgboost-predictor/src/main/java/biz/k11i/xgboost/tree/RegTreeImpl.java: -------------------------------------------------------------------------------- 1 | package biz.k11i.xgboost.tree; 2 | 3 | import ai.h2o.algos.tree.INodeStat; 4 | import biz.k11i.xgboost.util.FVec; 5 | import biz.k11i.xgboost.util.ModelReader; 6 | 7 | import java.io.IOException; 8 | import java.io.Serializable; 9 | 10 | /** 11 | * Regression tree. 12 | */ 13 | public class RegTreeImpl implements RegTree { 14 | 15 | private Param param; 16 | private Node[] nodes; 17 | private RegTreeNodeStat[] stats; 18 | 19 | /** 20 | * Loads model from stream. 21 | * 22 | * @param reader input stream 23 | * @throws IOException If an I/O error occurs 24 | */ 25 | public void loadModel(ModelReader reader) throws IOException { 26 | param = new Param(reader); 27 | 28 | nodes = new Node[param.num_nodes]; 29 | for (int i = 0; i < param.num_nodes; i++) { 30 | nodes[i] = new Node(reader); 31 | } 32 | 33 | stats = new RegTreeNodeStat[param.num_nodes]; 34 | for (int i = 0; i < param.num_nodes; i++) { 35 | stats[i] = new RegTreeNodeStat(reader); 36 | } 37 | } 38 | 39 | /** 40 | * Retrieves nodes from root to leaf and returns leaf index. 41 | * 42 | * @param feat feature vector 43 | * @return leaf index 44 | */ 45 | @Override 46 | public int getLeafIndex(FVec feat) { 47 | int id = 0; 48 | Node n; 49 | while (!(n = nodes[id])._isLeaf) { 50 | id = n.next(feat); 51 | } 52 | return id; 53 | } 54 | 55 | /** 56 | * Retrieves nodes from root to leaf and returns path to leaf. 57 | * 58 | * @param feat feature vector 59 | * @param sb output param, will write path path to leaf into this buffer 60 | */ 61 | @Override 62 | public void getLeafPath(FVec feat, StringBuilder sb) { 63 | int id = 0; 64 | Node n; 65 | while (!(n = nodes[id])._isLeaf) { 66 | id = n.next(feat); 67 | sb.append(id == n.cleft_ ? "L" : "R"); 68 | } 69 | } 70 | 71 | /** 72 | * Retrieves nodes from root to leaf and returns leaf value. 73 | * 74 | * @param feat feature vector 75 | * @param root_id starting root index 76 | * @return leaf value 77 | */ 78 | @Override 79 | public float getLeafValue(FVec feat, int root_id) { 80 | Node n = nodes[root_id]; 81 | while (!n._isLeaf) { 82 | n = nodes[n.next(feat)]; 83 | } 84 | 85 | return n.leaf_value; 86 | } 87 | 88 | @Override 89 | public Node[] getNodes() { 90 | return nodes; 91 | } 92 | 93 | @Override 94 | public RegTreeNodeStat[] getStats() { 95 | return stats; 96 | } 97 | 98 | /** 99 | * Parameters. 100 | */ 101 | static class Param implements Serializable { 102 | /*! \brief number of start root */ 103 | final int num_roots; 104 | /*! \brief total number of nodes */ 105 | final int num_nodes; 106 | /*!\brief number of deleted nodes */ 107 | final int num_deleted; 108 | /*! \brief maximum depth, this is a statistics of the tree */ 109 | final int max_depth; 110 | /*! \brief number of features used for tree construction */ 111 | final int num_feature; 112 | /*! 113 | * \brief leaf vector size, used for vector tree 114 | * used to store more than one dimensional information in tree 115 | */ 116 | final int size_leaf_vector; 117 | /*! \brief reserved part */ 118 | final int[] reserved; 119 | 120 | Param(ModelReader reader) throws IOException { 121 | num_roots = reader.readInt(); 122 | num_nodes = reader.readInt(); 123 | num_deleted = reader.readInt(); 124 | max_depth = reader.readInt(); 125 | num_feature = reader.readInt(); 126 | 127 | size_leaf_vector = reader.readInt(); 128 | reserved = reader.readIntArray(31); 129 | } 130 | } 131 | 132 | public static class Node extends RegTreeNode implements Serializable { 133 | // pointer to parent, highest bit is used to 134 | // indicate whether it's a left child or not 135 | final int parent_; 136 | // pointer to left, right 137 | final int cleft_, cright_; 138 | // split feature index, left split or right split depends on the highest bit 139 | final /* unsigned */ int sindex_; 140 | // extra info (leaf_value or split_cond) 141 | final float leaf_value; 142 | final float split_cond; 143 | 144 | private final int _defaultNext; 145 | private final int _splitIndex; 146 | final boolean _isLeaf; 147 | 148 | // set parent 149 | Node(ModelReader reader) throws IOException { 150 | parent_ = reader.readInt(); 151 | cleft_ = reader.readInt(); 152 | cright_ = reader.readInt(); 153 | sindex_ = reader.readInt(); 154 | 155 | if (isLeaf()) { 156 | leaf_value = reader.readFloat(); 157 | split_cond = Float.NaN; 158 | } else { 159 | split_cond = reader.readFloat(); 160 | leaf_value = Float.NaN; 161 | } 162 | 163 | _defaultNext = cdefault(); 164 | _splitIndex = getSplitIndex(); 165 | _isLeaf = isLeaf(); 166 | } 167 | 168 | public boolean isLeaf() { 169 | return cleft_ == -1; 170 | } 171 | 172 | @Override 173 | public int getSplitIndex() { 174 | return (int) (sindex_ & ((1l << 31) - 1l)); 175 | } 176 | 177 | public int cdefault() { 178 | return default_left() ? cleft_ : cright_; 179 | } 180 | 181 | @Override 182 | public boolean default_left() { 183 | return (sindex_ >>> 31) != 0; 184 | } 185 | 186 | @Override 187 | public int next(FVec feat) { 188 | float value = feat.fvalue(_splitIndex); 189 | if (value != value) { // is NaN? 190 | return _defaultNext; 191 | } 192 | return (value < split_cond) ? cleft_ : cright_; 193 | } 194 | 195 | @Override 196 | public int getParentIndex() { 197 | return parent_; 198 | } 199 | 200 | @Override 201 | public int getLeftChildIndex() { 202 | return cleft_; 203 | } 204 | 205 | @Override 206 | public int getRightChildIndex() { 207 | return cright_; 208 | } 209 | 210 | @Override 211 | public float getSplitCondition() { 212 | return split_cond; 213 | } 214 | 215 | @Override 216 | public float getLeafValue(){ 217 | return leaf_value; 218 | } 219 | } 220 | 221 | } 222 | -------------------------------------------------------------------------------- /xgboost-predictor/src/main/java/biz/k11i/xgboost/tree/RegTreeNode.java: -------------------------------------------------------------------------------- 1 | package biz.k11i.xgboost.tree; 2 | 3 | import ai.h2o.algos.tree.INode; 4 | import biz.k11i.xgboost.util.FVec; 5 | 6 | import java.io.Serializable; 7 | 8 | public abstract class RegTreeNode implements INode