├── .gitignore ├── .travis.yml ├── LICENSE ├── README.md ├── RELEASENOTES.md ├── build.gradle ├── ci ├── Dockerfile ├── Jenkinsfile.groovy └── Makefile ├── gradle.properties ├── gradle ├── publish.gradle └── wrapper │ ├── gradle-wrapper.jar │ └── gradle-wrapper.properties ├── gradlew ├── gradlew.bat ├── h2o-tree-api ├── build.gradle └── src │ └── main │ └── java │ └── ai │ └── h2o │ └── algos │ └── tree │ ├── INode.java │ └── INodeStat.java ├── settings.gradle └── xgboost-predictor ├── build.gradle └── src ├── main └── java │ └── biz │ └── k11i │ └── xgboost │ ├── Predictor.java │ ├── config │ └── PredictorConfiguration.java │ ├── gbm │ ├── Dart.java │ ├── GBLinear.java │ ├── GBTree.java │ └── GradBooster.java │ ├── learner │ └── ObjFunction.java │ ├── spark │ └── SparkModelParam.java │ ├── tree │ ├── DefaultRegTreeFactory.java │ ├── RegTree.java │ ├── RegTreeFactory.java │ ├── RegTreeImpl.java │ ├── RegTreeNode.java │ └── RegTreeNodeStat.java │ └── util │ ├── FVec.java │ └── ModelReader.java └── test ├── java └── biz │ └── k11i │ └── xgboost │ ├── PredictorSmokeTest.java │ └── tree │ └── PredictorPredictLeafTest.java └── resources ├── boosterBytes.bin └── prostate ├── boosterBytesProstateDart.bin ├── boosterBytesProstateLinear.bin ├── boosterBytesProstateTree.bin ├── boosterBytesProstateTreeVersion12.bin ├── recreate.txt └── recreate_darth.txt /.gitignore: -------------------------------------------------------------------------------- 1 | # Gradle 2 | .gradle/ 3 | build/ 4 | 5 | # IntelliJ IDEA 6 | .idea/ 7 | *.iml 8 | 9 | # Spark 10 | derby.log 11 | 12 | tmp/ 13 | 14 | ci/Dockerfile.tag 15 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: java 2 | script: ./gradlew clean check 3 | dist: precise 4 | jdk: 5 | - oraclejdk7 6 | - openjdk7 7 | addons: 8 | hosts: 9 | - xgboosthost 10 | hostname: xgboosthost 
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "{}" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright {yyyy} {name of copyright owner} 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | 203 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | xgboost-predictor-java 2 | ====================== 3 | [![Maven Central](https://maven-badges.herokuapp.com/maven-central/ai.h2o/xgboost-predictor/badge.svg)](https://maven-badges.herokuapp.com/maven-central/ai.h2o/xgboost-predictor) 4 | 5 | Pure Java implementation of [XGBoost](https://github.com/dmlc/xgboost/) predictor for online prediction tasks. 
6 | 7 | 8 | # Getting started 9 | 10 | ## Adding to dependencies 11 | 12 | If you use **Maven**: 13 | 14 | ```xml 15 | 16 | 17 | ai.h2o 18 | xgboost-predictor 19 | 0.3.1 20 | 21 | 22 | ``` 23 | 24 | Or **Gradle**: 25 | 26 | ```groovy 27 | repositories { 28 | mavenCentral() 29 | } 30 | 31 | dependencies { 32 | compile group: 'ai.h2o', name: 'xgboost-predictor', version: '0.3.1' 33 | } 34 | ``` 35 | 36 | Or **sbt**: 37 | 38 | ```scala 39 | resolvers += DefaultMavenRepository 40 | 41 | libraryDependencies ++= Seq( 42 | "ai.h2o" % "xgboost-predictor" % "0.3.1" 43 | ) 44 | ``` 45 | 46 | 47 | ## Using Predictor in Java 48 | 49 | ```java 50 | package biz.k11i.xgboost.demo; 51 | 52 | import biz.k11i.xgboost.Predictor; 53 | import biz.k11i.xgboost.util.FVec; 54 | 55 | public class HowToUseXgboostPredictor { 56 | public static void main(String[] args) throws java.io.IOException { 57 | // If you want to use faster exp() calculation, uncomment the line below 58 | // ObjFunction.useFastMathExp(true); 59 | 60 | // Load model and create Predictor 61 | Predictor predictor = new Predictor( 62 | new java.io.FileInputStream("/path/to/xgboost-model-file")); 63 | 64 | // Create feature vector from dense representation by array 65 | double[] denseArray = {0, 0, 32, 0, 0, 16, -8, 0, 0, 0}; 66 | FVec fVecDense = FVec.Transformer.fromArray( 67 | denseArray, 68 | true /* treat zero element as N/A */); 69 | 70 | // Create feature vector from sparse representation by map 71 | FVec fVecSparse = FVec.Transformer.fromMap( 72 | new java.util.HashMap() {{ 73 | put(2, 32.); 74 | put(5, 16.); 75 | put(6, -8.); 76 | }}); 77 | 78 | // Predict probability or classification 79 | double[] prediction = predictor.predict(fVecDense); 80 | 81 | // prediction[0] has 82 | // - probability ("binary:logistic") 83 | // - class label ("multi:softmax") 84 | 85 | // Predict leaf index of each tree 86 | int[] leafIndexes = predictor.predictLeaf(fVecDense); 87 | 88 | // leafIndexes[i] has a leaf index of i-th tree 89 
| } 90 | } 91 | ``` 92 | 93 | 94 | # Benchmark 95 | 96 | Throughput comparison to [xgboost4j 1.1](https://github.com/dmlc/xgboost/tree/master/java/xgboost4j) by [xgboost-predictor-benchmark](https://github.com/komiya-atsushi/xgboost-predictor-benchmark). 97 | 98 | | Feature | xgboost-predictor | xgboost4j | 99 | | ----------------- | ----------------: | -------------: | 100 | | Model loading | 49017.60 ops/s | 39669.36 ops/s | 101 | | Single prediction | 6016955.46 ops/s | 1018.01 ops/s | 102 | | Batch prediction | 44985.71 ops/s | 5.04 ops/s | 103 | | Leaf prediction | 11115853.34 ops/s | 1076.54 ops/s | 104 | 105 | Xgboost-predictor-java is about **6,000 to 10,000 times faster than** xgboost4j on prediction tasks. 106 | 107 | 108 | # Supported models, objective functions and API 109 | 110 | - Models 111 | - "gblinear" 112 | - "gbtree" 113 | - "dart" 114 | - Objective functions 115 | - "binary:logistic" 116 | - "binary:logitraw" 117 | - "multi:softmax" 118 | - "multi:softprob" 119 | - "reg:linear" 120 | - "reg:logistic" 121 | - "rank:pairwise" 122 | - "rank:ndcg" 123 | - API 124 | - Predicts probability or classification 125 | - `Predictor#predict(FVec)` 126 | - Outputs margin 127 | - `Predictor#predict(FVec, true /* output margin */)` 128 | - Predicts leaf index 129 | - `Predictor#predictLeaf(FVec)` 130 | -------------------------------------------------------------------------------- /RELEASENOTES.md: -------------------------------------------------------------------------------- 1 | # Release notes 2 | 3 | ## 0.3.20 4 | 5 | - Make the fix in [PR](https://github.com/h2oai/xgboost-predictor/pull/27) backwards compatible in order to use it in H2O-3 and provide same predictions for older MOJOs and pass the compatibility test. 
[PR](https://github.com/h2oai/xgboost-predictor/pull/28) 6 | 7 | ## 0.3.19 8 | 9 | - Change the order of the floating point operation to match prediction in native xgboost version 1.3.0 and newer [PR](https://github.com/h2oai/xgboost-predictor/pull/27) 10 | 11 | ## 0.3.18 12 | 13 | - Add support for `reg:logistic` and `rank:ndcg` oobjectives. [PR](https://github.com/h2oai/xgboost-predictor/pull/20) 14 | 15 | ## 0.3.17 16 | 17 | - Revert renaming of getWeight. [PR](https://github.com/h2oai/xgboost-predictor/pull/19) 18 | 19 | ## 0.3.16 20 | 21 | - Expose `RegTreeNode` stats. [PR](https://github.com/h2oai/xgboost-predictor/pull/18) 22 | 23 | ## 0.3.15 24 | 25 | - Fix loading an empty gblinear booster [PR](https://github.com/h2oai/xgboost-predictor/pull/16) 26 | 27 | ## 0.3.14 28 | 29 | - Upgrade to XGBoost v1.0.0 support. [PR1](https://github.com/h2oai/xgboost-predictor/pull/14), [PR2](https://github.com/h2oai/xgboost-predictor/pull/15) 30 | 31 | ## 0.3.0 32 | 33 | - [#27](https://github.com/komiya-atsushi/xgboost-predictor-java/pull/27) Support DART model. 34 | 35 | ## 0.2.1 36 | 37 | - Support an objective function: `"rank:pairwise"` 38 | 39 | ## 0.2.0 40 | 41 | - Support XGBoost4J-Spark-generated model file format. 42 | - Introduce [xgboost-predictor-spark](https://github.com/komiya-atsushi/xgboost-predictor-java/tree/master/xgboost-predictor-spark). 43 | 44 | 45 | ## 0.1.8 46 | 47 | - Make `Predictor` Spark-friendly (implement `Serializable` interface, [#11](https://github.com/komiya-atsushi/xgboost-predictor-java/issues/11) ) 48 | 49 | ## 0.1.7 50 | 51 | - Support latest model file format. 52 | - [Commit log of xgboost](https://github.com/dmlc/xgboost/commit/0d95e863c981548b5a7ca363310fc359a9165d85#diff-53a3a623be5ce5a351a89012c7b03a31R193) 53 | 54 | ## 0.1.6 55 | 56 | - Improve the speed performance of prediction: 57 | - Optimize tree retrieval performance. 
58 | 59 | ## 0.1.5 60 | 61 | - Support an objective function: `"reg:linear"` 62 | 63 | ## 0.1.4 64 | 65 | - Improve the speed performance of prediction: 66 | - Introduce methods `Predictor#predictSingle()` for predicting single value efficiently. 67 | 68 | ## 0.1.3 69 | 70 | - Improve the speed performance of prediction: 71 | - Use [Jafama](https://github.com/jeffhain/jafama/) for calculating sigmoid function faster. 72 | - Calling `ObjFunction.useFastMathExp(true)` you can use Jafama's `FastMath.exp()`. 73 | 74 | ## 0.1.2 75 | 76 | - #2 Add linear models (`GBLinear`). 77 | 78 | ## 0.1.1 79 | 80 | - #1 Allow users to register their `ObjFunction`. 81 | 82 | ## 0.1.0 83 | 84 | - Initial release. 85 | -------------------------------------------------------------------------------- /build.gradle: -------------------------------------------------------------------------------- 1 | buildscript { 2 | repositories { 3 | jcenter() 4 | } 5 | } 6 | 7 | plugins { 8 | id 'base' 9 | } 10 | 11 | clean.doFirst { 12 | delete "${rootDir}/build" 13 | println "${rootDir}/build" 14 | } 15 | 16 | subprojects { 17 | configurations { 18 | publishArchives 19 | } 20 | 21 | repositories { 22 | jcenter() 23 | } 24 | 25 | apply plugin: 'java' 26 | 27 | group 'ai.h2o' 28 | 29 | archivesBaseName = project.name 30 | if (project.findProperty('archivesNameSuffix')) { 31 | archivesBaseName += "-${project.findProperty('archivesNameSuffix')}" 32 | } 33 | 34 | compileJava { 35 | sourceCompatibility = JavaVersion.VERSION_1_7 36 | targetCompatibility = JavaVersion.VERSION_1_7 37 | } 38 | 39 | javadoc { 40 | options.locale = 'en_US' 41 | } 42 | 43 | task checkJavaVersion { 44 | doLast { 45 | def prop = System.getenv('CHECK_JAVA_VERSION'); 46 | if (prop != "false" && !JavaVersion.current().isJava7()) { 47 | String message = "ERROR: Java 7 required but " + 48 | JavaVersion.current() + 49 | " found. Change your JAVA_HOME environment variable." 
50 | throw new IllegalStateException(message) 51 | } 52 | } 53 | } 54 | compileJava.dependsOn checkJavaVersion 55 | 56 | task sourcesJar(type: Jar, dependsOn: classes) { 57 | classifier = 'sources' 58 | from sourceSets.main.allSource 59 | } 60 | 61 | task javadocJar(type: Jar, dependsOn: javadoc) { 62 | classifier = 'javadoc' 63 | from javadoc.destinationDir 64 | } 65 | 66 | artifacts { 67 | archives jar 68 | publishArchives sourcesJar 69 | publishArchives javadocJar 70 | } 71 | } 72 | -------------------------------------------------------------------------------- /ci/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM ubuntu:16.04 2 | 3 | RUN apt-get update && \ 4 | apt-get install -y \ 5 | make \ 6 | s3cmd \ 7 | unzip \ 8 | git && \ 9 | rm -rf /var/lib/apt/lists/* 10 | 11 | # Add the Jenkins user 12 | RUN \ 13 | groupadd -g 2117 jenkins && \ 14 | useradd jenkins -m -u 2117 -g jenkins 15 | 16 | COPY ci/jdk1.7.0_80.zip /usr/opt/ 17 | RUN cd /usr/opt && \ 18 | unzip -q jdk1.7.0_80.zip && \ 19 | rm -rf jdk1.7.0_80.zip 20 | 21 | ENV JAVA_HOME=/usr/opt/jdk1.7.0_80 22 | ENV PATH=${PATH}:/usr/opt/jdk1.7.0_80/bin 23 | 24 | RUN mkdir /gradle-home && \ 25 | chmod a+w /gradle-home 26 | 27 | ENV GRADLE_USER_HOME /gradle-home 28 | ENV GRADLE_OPTS -Dorg.gradle.daemon=false 29 | -------------------------------------------------------------------------------- /ci/Jenkinsfile.groovy: -------------------------------------------------------------------------------- 1 | @Library('test-shared-library@1.9') _ 2 | 3 | import ai.h2o.ci.buildsummary.StagesSummary 4 | 5 | // initialize build summary 6 | buildSummary('https://github.com/h2oai/xgboost-predictor', true) 7 | // use default StagesSummary implementation 8 | buildSummary.get().addStagesSummary(this, new StagesSummary()) 9 | 10 | properties([ 11 | parameters([ 12 | choice(choices: ['none', 'private', 'public'], description: 'Nexus to publish to', name: 'targetNexus') 13 | ]) 14 | 
]) 15 | 16 | def makeOpts = 'CI=1' 17 | 18 | node ('master') { 19 | buildSummary.stageWithSummary('Checkout') { 20 | cleanWs() 21 | def scmEnv = checkout scm 22 | env.BRANCH_NAME = scmEnv.GIT_BRANCH.replaceAll('origin/', '') 23 | final def version = sh(script: 'cat gradle.properties | grep version | sed "s/version=//"', returnStdout: true).trim() 24 | 25 | def archivesNameSuffix = null 26 | if (env.BRANCH_NAME != 'master') { 27 | archivesNameSuffix = env.BRANCH_NAME.replaceAll('/|\\ ','-') 28 | makeOpts += " ARCHIVES_NAME_SUFFIX=${archivesNameSuffix}" 29 | } 30 | currentBuild.description = "ai.h2o:xgboost-predictor" 31 | if (archivesNameSuffix) { 32 | currentBuild.description += "-${archivesNameSuffix}" 33 | } 34 | currentBuild.description += ":${version}" 35 | } 36 | buildSummary.stageWithSummary('Prepare Docker Image and clean') { 37 | withCredentials([[$class: 'AmazonWebServicesCredentialsBinding', accessKeyVariable: 'AWS_ACCESS_KEY_ID', credentialsId: 'AWS S3 Credentials', secretKeyVariable: 'AWS_SECRET_ACCESS_KEY']]) { 38 | docker.image('harbor.h2o.ai/opsh2oai/s3cmd').inside("-e AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID} -e AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY}") { 39 | sh """ 40 | cd ci 41 | s3cmd get s3://artifacts.h2o.ai/releases/oracle/jdk-7/x64-linux/jdk1.7.0_80.zip 42 | """ 43 | } 44 | } 45 | sh "make ${makeOpts} -f ci/Makefile clean_in_docker" 46 | } 47 | 48 | buildSummary.stageWithSummary('Build') { 49 | sh "make ${makeOpts} -f ci/Makefile build_in_docker" 50 | archiveArtifacts artifacts: 'xgboost-predictor/build/libs/*.jar' 51 | } 52 | 53 | if (params.targetNexus == 'private' || params.targetNexus == 'public') { 54 | buildSummary.stageWithSummary("Publish to ${params.targetNexus.capitalize()} Nexus") { 55 | def credentialsId 56 | if (params.targetNexus == 'private') { 57 | credentialsId = 'LOCAL_NEXUS' 58 | } else if (params.targetNexus == 'public') { 59 | credentialsId = 'PUBLIC_NEXUS' 60 | } else { 61 | error "Cannot find credentials for 
targetNexus=${params.targetNexus}" 62 | } 63 | withCredentials([usernamePassword(credentialsId: credentialsId, usernameVariable: 'NEXUS_USERNAME', passwordVariable: 'NEXUS_PASSWORD'), 64 | file(credentialsId: 'release-secret-key-ring-file', variable: 'SECRING_PATH')]) { 65 | sh "make ${makeOpts} TARGET_NEXUS=${params.targetNexus} DO_SIGN=true -f ci/Makefile publish_in_docker" 66 | } 67 | } 68 | } 69 | } 70 | -------------------------------------------------------------------------------- /ci/Makefile: -------------------------------------------------------------------------------- 1 | THIS_FILE := $(lastword $(MAKEFILE_LIST)) 2 | 3 | .PHONY: clean build publish 4 | SHELL := /bin/bash 5 | 6 | # if no CI flag is specified, make docker calls interactive 7 | DOCKER_ARGS := $(ADDITIONAL_DOCKER_ARGS) 8 | ifeq ($(CI),) 9 | DOCKER_ARGS += -it 10 | endif 11 | ifneq ($(TARGET_NEXUS),) 12 | DOCKER_ARGS += -e TARGET_NEXUS=$(TARGET_NEXUS) 13 | endif 14 | ifneq ($(DO_SIGN),) 15 | DOCKER_ARGS += -e DO_SIGN=$(DO_SIGN) 16 | endif 17 | ifneq ($(ARCHIVES_NAME_SUFFIX),) 18 | DOCKER_ARGS += -e ARCHIVES_NAME_SUFFIX=$(ARCHIVES_NAME_SUFFIX) 19 | endif 20 | 21 | UID ?= $(shell id -u) 22 | GID ?= $(shell id -g) 23 | 24 | clean: 25 | ./gradlew clean 26 | 27 | build: 28 | ifneq ($(ARCHIVES_NAME_SUFFIX),) 29 | $(eval NAME_SUFFIX_FLAG := -ParchivesNameSuffix=$(ARCHIVES_NAME_SUFFIX)) 30 | endif 31 | ./gradlew build $(NAME_SUFFIX_FLAG) 32 | 33 | ifeq ($(DO_SIGN), true) 34 | SIGN_FLAG = -PdoSign 35 | SIGN_FLAG += -Psigning.keyId=$(shell gpg -K | grep -v secring.gpg | grep sec | awk '{print $$2}' | sed 's|.*/||g') 36 | SIGN_FLAG += -Psigning.secretKeyRingFile=/home/jenkins/.gnupg/secring.gpg 37 | SIGN_FLAG += -Psigning.password= 38 | endif 39 | publish: 40 | ifeq ($(TARGET_NEXUS),) 41 | $(error TARGET_NEXUS must be set) 42 | endif 43 | ifneq ($(ARCHIVES_NAME_SUFFIX),) 44 | $(eval NAME_SUFFIX_FLAG := -ParchivesNameSuffix=$(ARCHIVES_NAME_SUFFIX)) 45 | endif 46 | ./gradlew publish $(NAME_SUFFIX_FLAG) 
-PtargetNexus=$(TARGET_NEXUS) $(SIGN_FLAG) -PnexusUsername=$(NEXUS_USERNAME) -PnexusPassword=$(NEXUS_PASSWORD) --stacktrace 47 | 48 | clean_build: clean build 49 | 50 | bash: 51 | bash 52 | 53 | clean_docker: 54 | rm -f ci/Dockerfile.tag 55 | 56 | DOCKER_IMAGE := docker.h2o.ai/opsh2oai/xgboost-predictor-build 57 | ci/Dockerfile.tag: ci/Dockerfile 58 | $(info Building docker image, git credentials will be required.) 59 | echo $(DOCKER_IMAGE) > ci/Dockerfile.tag 60 | docker build \ 61 | -t $(DOCKER_IMAGE) \ 62 | -f ci/Dockerfile \ 63 | . 64 | 65 | publish_in_docker: ci/Dockerfile.tag 66 | ifeq ($(SECRING_PATH),) 67 | $(error SECRING_PATH must be set.) 68 | endif 69 | ifeq ($(NEXUS_USERNAME),) 70 | $(error NEXUS_USERNAME must be set.) 71 | endif 72 | ifeq ($(NEXUS_PASSWORD),) 73 | $(error NEXUS_PASSWORD must be set.) 74 | endif 75 | docker run --rm \ 76 | -u jenkins:jenkins \ 77 | -v `pwd`:/workspace \ 78 | -v $(SECRING_PATH):/secring.gpg:ro \ 79 | -w /workspace \ 80 | --entrypoint /bin/bash \ 81 | --add-host=nexus:172.17.0.53 \ 82 | -e NEXUS_USERNAME=$(NEXUS_USERNAME) \ 83 | -e NEXUS_PASSWORD=$(NEXUS_PASSWORD) \ 84 | $(DOCKER_ARGS) \ 85 | $(DOCKER_IMAGE) \ 86 | -c "gpg --import /secring.gpg && \ 87 | make -f ci/Makefile publish" 88 | 89 | %_in_docker: ci/Dockerfile.tag 90 | docker run --rm \ 91 | -u $(UID):$(GID) \ 92 | -v `pwd`:/workspace \ 93 | -w /workspace \ 94 | --entrypoint /bin/bash \ 95 | $(DOCKER_ARGS) \ 96 | $(DOCKER_IMAGE) \ 97 | -c "make -f ci/Makefile $*" 98 | -------------------------------------------------------------------------------- /gradle.properties: -------------------------------------------------------------------------------- 1 | # For building on Java 7 we need to explicitly enable TLSv1 (this is mainly for MOJO compatibility tests) 2 | systemProp.https.protocols=TLSv1,TLSv1.1,TLSv1.2 3 | 4 | version=0.3.20 5 | localNexusLocation=http://nexus:8081/nexus/repository 6 | 
-------------------------------------------------------------------------------- /gradle/publish.gradle: -------------------------------------------------------------------------------- 1 | apply plugin: 'maven-publish' 2 | apply plugin: 'signing' 3 | 4 | /** 5 | * This plugin works by naming convention. 6 | * 7 | * It publishes by default java jar package and it also 8 | * expects publishArchives configuration and publishes whatever is inside it. 9 | */ 10 | 11 | 12 | // Container for generated POMs (by default added into archive configuration) 13 | configurations { 14 | pom { 15 | transitive = false 16 | } 17 | } 18 | 19 | // 20 | // Record all generated POM files for given publishing task 21 | // 22 | project.tasks.whenTaskAdded({ t -> 23 | // This is adhoc specific task for defined publication 24 | if (t.name.contains('generatePomFileForMavenAll')) { 25 | t.doLast({ tt -> 26 | artifacts { 27 | pom file(t.destination) 28 | } 29 | }) 30 | } 31 | }) 32 | 33 | publishing { 34 | publications { 35 | 36 | mavenAll(MavenPublication) { 37 | artifactId archivesBaseName 38 | // Publish all artifacts 39 | // NOTE: needs to be here to create a POM file with correct dependencies 40 | from components.java 41 | 42 | // Publish additional artifacts as documentation or source code 43 | configurations.publishArchives.allArtifacts.each { art -> 44 | logger.debug("Publishing artifact for: " + art) 45 | artifact art 46 | } 47 | 48 | 49 | pom { 50 | name = archivesBaseName 51 | description = "Pure Java implementation of XGBoost predictor for online prediction tasks" 52 | url = 'https://github.com/h2oai/xgboost-predictor' 53 | inceptionYear = '2018' 54 | 55 | organization { 56 | name = 'H2O.ai' 57 | url = 'http://h2o.ai/' 58 | } 59 | licenses { 60 | license { 61 | name = 'The Apache Software License, Version 2.0' 62 | url = 'http://www.apache.org/licenses/LICENSE-2.0.txt' 63 | distribution = 'repo' 64 | } 65 | } 66 | scm { 67 | url = 'https://github.com/h2oai/xgboost-predictor' 68 | 
connection = 'scm:git:https://github.com/h2oai/xgboost-predictor.git' 69 | developerConnection = 'scm:git:git@github.com:h2oai/xgboost-predictor.git' 70 | } 71 | issueManagement { 72 | system 'GitHub issues' 73 | url 'https://github.com/h2oai/h2o-3/issues' 74 | } 75 | developers { 76 | developer { 77 | id = 'support' 78 | name = 'H2O.ai Support' 79 | email = 'support@h2o.ai' 80 | } 81 | } 82 | } 83 | } 84 | } 85 | 86 | repositories { 87 | // Release to local repo 88 | maven { 89 | name "BuildRepo" 90 | url "$rootDir/build/repo" 91 | } 92 | 93 | def targetNexus = project.findProperty('targetNexus') 94 | def targetPrivate = false 95 | def targetPublic = false 96 | switch(targetNexus) { 97 | case 'private': 98 | targetPrivate = true 99 | break 100 | case 'public': 101 | targetPublic = true 102 | break 103 | } 104 | 105 | // Release to private nexus 106 | if (targetPrivate) { 107 | maven { 108 | name "LocalNexusRepo" 109 | url "${localNexusLocation}/snapshots" 110 | 111 | credentials { 112 | username project.findProperty("nexusUsername") ?: "" 113 | password project.findProperty("nexusPassword") ?: "" 114 | } 115 | } 116 | } 117 | 118 | // Release to public nexus 119 | if (targetPublic) { 120 | maven { 121 | name 'Public Nexus' 122 | url "https://oss.sonatype.org/service/local/staging/deploy/maven2/" 123 | 124 | credentials { 125 | username project.findProperty("nexusUsername") ?: "" 126 | password project.findProperty("nexusPassword") ?: "" 127 | } 128 | } 129 | } 130 | } 131 | } 132 | 133 | signing { 134 | required { 135 | project.hasProperty("doSign") 136 | } 137 | sign publishing.publications 138 | } 139 | -------------------------------------------------------------------------------- /gradle/wrapper/gradle-wrapper.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/h2oai/xgboost-predictor/4ee6ddfd5352abdadc02cc417d02b7db15da3982/gradle/wrapper/gradle-wrapper.jar 
-------------------------------------------------------------------------------- /gradle/wrapper/gradle-wrapper.properties: -------------------------------------------------------------------------------- 1 | #Thu Aug 16 11:09:22 CEST 2018 2 | distributionBase=GRADLE_USER_HOME 3 | distributionPath=wrapper/dists 4 | zipStoreBase=GRADLE_USER_HOME 5 | zipStorePath=wrapper/dists 6 | distributionUrl=https\://services.gradle.org/distributions/gradle-4.10.3-all.zip 7 | -------------------------------------------------------------------------------- /gradlew: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | ############################################################################## 4 | ## 5 | ## Gradle start up script for UN*X 6 | ## 7 | ############################################################################## 8 | 9 | # Attempt to set APP_HOME 10 | # Resolve links: $0 may be a link 11 | PRG="$0" 12 | # Need this for relative symlinks. 13 | while [ -h "$PRG" ] ; do 14 | ls=`ls -ld "$PRG"` 15 | link=`expr "$ls" : '.*-> \(.*\)$'` 16 | if expr "$link" : '/.*' > /dev/null; then 17 | PRG="$link" 18 | else 19 | PRG=`dirname "$PRG"`"/$link" 20 | fi 21 | done 22 | SAVED="`pwd`" 23 | cd "`dirname \"$PRG\"`/" >/dev/null 24 | APP_HOME="`pwd -P`" 25 | cd "$SAVED" >/dev/null 26 | 27 | APP_NAME="Gradle" 28 | APP_BASE_NAME=`basename "$0"` 29 | 30 | # Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. 31 | DEFAULT_JVM_OPTS="" 32 | 33 | # Use the maximum available, or set MAX_FD != -1 to use that value. 34 | MAX_FD="maximum" 35 | 36 | warn ( ) { 37 | echo "$*" 38 | } 39 | 40 | die ( ) { 41 | echo 42 | echo "$*" 43 | echo 44 | exit 1 45 | } 46 | 47 | # OS specific support (must be 'true' or 'false'). 
48 | cygwin=false 49 | msys=false 50 | darwin=false 51 | nonstop=false 52 | case "`uname`" in 53 | CYGWIN* ) 54 | cygwin=true 55 | ;; 56 | Darwin* ) 57 | darwin=true 58 | ;; 59 | MINGW* ) 60 | msys=true 61 | ;; 62 | NONSTOP* ) 63 | nonstop=true 64 | ;; 65 | esac 66 | 67 | CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar 68 | 69 | # Determine the Java command to use to start the JVM. 70 | if [ -n "$JAVA_HOME" ] ; then 71 | if [ -x "$JAVA_HOME/jre/sh/java" ] ; then 72 | # IBM's JDK on AIX uses strange locations for the executables 73 | JAVACMD="$JAVA_HOME/jre/sh/java" 74 | else 75 | JAVACMD="$JAVA_HOME/bin/java" 76 | fi 77 | if [ ! -x "$JAVACMD" ] ; then 78 | die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME 79 | 80 | Please set the JAVA_HOME variable in your environment to match the 81 | location of your Java installation." 82 | fi 83 | else 84 | JAVACMD="java" 85 | which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. 86 | 87 | Please set the JAVA_HOME variable in your environment to match the 88 | location of your Java installation." 89 | fi 90 | 91 | # Increase the maximum file descriptors if we can. 92 | if [ "$cygwin" = "false" -a "$darwin" = "false" -a "$nonstop" = "false" ] ; then 93 | MAX_FD_LIMIT=`ulimit -H -n` 94 | if [ $? -eq 0 ] ; then 95 | if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then 96 | MAX_FD="$MAX_FD_LIMIT" 97 | fi 98 | ulimit -n $MAX_FD 99 | if [ $? 
-ne 0 ] ; then 100 | warn "Could not set maximum file descriptor limit: $MAX_FD" 101 | fi 102 | else 103 | warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT" 104 | fi 105 | fi 106 | 107 | # For Darwin, add options to specify how the application appears in the dock 108 | if $darwin; then 109 | GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\"" 110 | fi 111 | 112 | # For Cygwin, switch paths to Windows format before running java 113 | if $cygwin ; then 114 | APP_HOME=`cygpath --path --mixed "$APP_HOME"` 115 | CLASSPATH=`cygpath --path --mixed "$CLASSPATH"` 116 | JAVACMD=`cygpath --unix "$JAVACMD"` 117 | 118 | # We build the pattern for arguments to be converted via cygpath 119 | ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null` 120 | SEP="" 121 | for dir in $ROOTDIRSRAW ; do 122 | ROOTDIRS="$ROOTDIRS$SEP$dir" 123 | SEP="|" 124 | done 125 | OURCYGPATTERN="(^($ROOTDIRS))" 126 | # Add a user-defined pattern to the cygpath arguments 127 | if [ "$GRADLE_CYGPATTERN" != "" ] ; then 128 | OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)" 129 | fi 130 | # Now convert the arguments - kludge to limit ourselves to /bin/sh 131 | i=0 132 | for arg in "$@" ; do 133 | CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -` 134 | CHECK2=`echo "$arg"|egrep -c "^-"` ### Determine if an option 135 | 136 | if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### Added a condition 137 | eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"` 138 | else 139 | eval `echo args$i`="\"$arg\"" 140 | fi 141 | i=$((i+1)) 142 | done 143 | case $i in 144 | (0) set -- ;; 145 | (1) set -- "$args0" ;; 146 | (2) set -- "$args0" "$args1" ;; 147 | (3) set -- "$args0" "$args1" "$args2" ;; 148 | (4) set -- "$args0" "$args1" "$args2" "$args3" ;; 149 | (5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;; 150 | (6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;; 151 | (7) set -- "$args0" "$args1" "$args2" "$args3" 
"$args4" "$args5" "$args6" ;; 152 | (8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;; 153 | (9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;; 154 | esac 155 | fi 156 | 157 | # Split up the JVM_OPTS And GRADLE_OPTS values into an array, following the shell quoting and substitution rules 158 | function splitJvmOpts() { 159 | JVM_OPTS=("$@") 160 | } 161 | eval splitJvmOpts $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS 162 | JVM_OPTS[${#JVM_OPTS[*]}]="-Dorg.gradle.appname=$APP_BASE_NAME" 163 | 164 | exec "$JAVACMD" "${JVM_OPTS[@]}" -classpath "$CLASSPATH" org.gradle.wrapper.GradleWrapperMain "$@" 165 | -------------------------------------------------------------------------------- /gradlew.bat: -------------------------------------------------------------------------------- 1 | @if "%DEBUG%" == "" @echo off 2 | @rem ########################################################################## 3 | @rem 4 | @rem Gradle startup script for Windows 5 | @rem 6 | @rem ########################################################################## 7 | 8 | @rem Set local scope for the variables with windows NT shell 9 | if "%OS%"=="Windows_NT" setlocal 10 | 11 | set DIRNAME=%~dp0 12 | if "%DIRNAME%" == "" set DIRNAME=. 13 | set APP_BASE_NAME=%~n0 14 | set APP_HOME=%DIRNAME% 15 | 16 | @rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. 17 | set DEFAULT_JVM_OPTS= 18 | 19 | @rem Find java.exe 20 | if defined JAVA_HOME goto findJavaFromJavaHome 21 | 22 | set JAVA_EXE=java.exe 23 | %JAVA_EXE% -version >NUL 2>&1 24 | if "%ERRORLEVEL%" == "0" goto init 25 | 26 | echo. 27 | echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. 28 | echo. 29 | echo Please set the JAVA_HOME variable in your environment to match the 30 | echo location of your Java installation. 
31 | 32 | goto fail 33 | 34 | :findJavaFromJavaHome 35 | set JAVA_HOME=%JAVA_HOME:"=% 36 | set JAVA_EXE=%JAVA_HOME%/bin/java.exe 37 | 38 | if exist "%JAVA_EXE%" goto init 39 | 40 | echo. 41 | echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% 42 | echo. 43 | echo Please set the JAVA_HOME variable in your environment to match the 44 | echo location of your Java installation. 45 | 46 | goto fail 47 | 48 | :init 49 | @rem Get command-line arguments, handling Windows variants 50 | 51 | if not "%OS%" == "Windows_NT" goto win9xME_args 52 | if "%@eval[2+2]" == "4" goto 4NT_args 53 | 54 | :win9xME_args 55 | @rem Slurp the command line arguments. 56 | set CMD_LINE_ARGS= 57 | set _SKIP=2 58 | 59 | :win9xME_args_slurp 60 | if "x%~1" == "x" goto execute 61 | 62 | set CMD_LINE_ARGS=%* 63 | goto execute 64 | 65 | :4NT_args 66 | @rem Get arguments from the 4NT Shell from JP Software 67 | set CMD_LINE_ARGS=%$ 68 | 69 | :execute 70 | @rem Setup the command line 71 | 72 | set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar 73 | 74 | @rem Execute Gradle 75 | "%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %CMD_LINE_ARGS% 76 | 77 | :end 78 | @rem End local scope for the variables with windows NT shell 79 | if "%ERRORLEVEL%"=="0" goto mainEnd 80 | 81 | :fail 82 | rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of 83 | rem the _cmd.exe /c_ return code! 
84 | if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1 85 | exit /b 1 86 | 87 | :mainEnd 88 | if "%OS%"=="Windows_NT" endlocal 89 | 90 | :omega 91 | -------------------------------------------------------------------------------- /h2o-tree-api/build.gradle: -------------------------------------------------------------------------------- 1 | apply from: "$rootDir/gradle/publish.gradle" 2 | 3 | description = "Minimal API for accessing information about individual nodes of trees used in tree-based ML models" 4 | 5 | dependencies { 6 | } 7 | -------------------------------------------------------------------------------- /h2o-tree-api/src/main/java/ai/h2o/algos/tree/INode.java: -------------------------------------------------------------------------------- 1 | package ai.h2o.algos.tree; 2 | 3 | public interface INode { 4 | 5 | boolean isLeaf(); 6 | 7 | float getLeafValue(); 8 | 9 | int getSplitIndex(); 10 | 11 | int next(T value); 12 | 13 | int getLeftChildIndex(); 14 | 15 | int getRightChildIndex(); 16 | 17 | } -------------------------------------------------------------------------------- /h2o-tree-api/src/main/java/ai/h2o/algos/tree/INodeStat.java: -------------------------------------------------------------------------------- 1 | package ai.h2o.algos.tree; 2 | 3 | public interface INodeStat { 4 | 5 | float getWeight(); 6 | 7 | } 8 | -------------------------------------------------------------------------------- /settings.gradle: -------------------------------------------------------------------------------- 1 | rootProject.name = 'xgboost-predictor-java' 2 | include 'h2o-tree-api' 3 | include 'xgboost-predictor' -------------------------------------------------------------------------------- /xgboost-predictor/build.gradle: -------------------------------------------------------------------------------- 1 | apply from: "$rootDir/gradle/publish.gradle" 2 | 3 | description = "Pure Java implementation of XGBoost predictor for online prediction tasks" 4 | 5 | 
dependencies { 6 | compile project(':h2o-tree-api') 7 | compile 'net.jafama:jafama:2.1.0' 8 | testCompile "junit:junit:4.12" 9 | } 10 | -------------------------------------------------------------------------------- /xgboost-predictor/src/main/java/biz/k11i/xgboost/Predictor.java: -------------------------------------------------------------------------------- 1 | package biz.k11i.xgboost; 2 | 3 | import biz.k11i.xgboost.config.PredictorConfiguration; 4 | import biz.k11i.xgboost.gbm.GradBooster; 5 | import biz.k11i.xgboost.learner.ObjFunction; 6 | import biz.k11i.xgboost.spark.SparkModelParam; 7 | import biz.k11i.xgboost.util.FVec; 8 | import biz.k11i.xgboost.util.ModelReader; 9 | 10 | import java.io.IOException; 11 | import java.io.InputStream; 12 | import java.io.Serializable; 13 | import java.util.Arrays; 14 | 15 | /** 16 | * Predicts using the Xgboost model. 17 | */ 18 | public class Predictor implements Serializable { 19 | private ModelParam mparam; 20 | private SparkModelParam sparkModelParam; 21 | private String name_obj; 22 | private String name_gbm; 23 | private ObjFunction obj; 24 | private GradBooster gbm; 25 | 26 | private float base_score; 27 | 28 | public Predictor(InputStream in) throws IOException { 29 | this(in, null); 30 | } 31 | 32 | /** 33 | * Instantiates with the Xgboost model 34 | * 35 | * @param in input stream 36 | * @param configuration configuration 37 | * @throws IOException If an I/O error occurs 38 | */ 39 | public Predictor(InputStream in, PredictorConfiguration configuration) throws IOException { 40 | if (configuration == null) { 41 | configuration = PredictorConfiguration.DEFAULT; 42 | } 43 | 44 | ModelReader reader = new ModelReader(in); 45 | 46 | readParam(reader); 47 | initObjFunction(configuration); 48 | initObjGbm(); 49 | 50 | gbm.loadModel(configuration, reader, mparam.saved_with_pbuffer != 0); 51 | 52 | if (mparam.major_version >= 1) { 53 | base_score = obj.probToMargin(mparam.base_score); 54 | } else { 55 | base_score = 
mparam.base_score; 56 | } 57 | } 58 | 59 | void readParam(ModelReader reader) throws IOException { 60 | byte[] first4Bytes = reader.readByteArray(4); 61 | byte[] next4Bytes = reader.readByteArray(4); 62 | 63 | float base_score; 64 | int num_feature; 65 | 66 | if (first4Bytes[0] == 0x62 && 67 | first4Bytes[1] == 0x69 && 68 | first4Bytes[2] == 0x6e && 69 | first4Bytes[3] == 0x66) { 70 | 71 | // Old model file format has a signature "binf" (62 69 6e 66) 72 | base_score = reader.asFloat(next4Bytes); 73 | num_feature = reader.readUnsignedInt(); 74 | 75 | } else if (first4Bytes[0] == 0x00 && 76 | first4Bytes[1] == 0x05 && 77 | first4Bytes[2] == 0x5f) { 78 | 79 | // Model generated by xgboost4j-spark? 80 | String modelType = null; 81 | if (first4Bytes[3] == 0x63 && 82 | next4Bytes[0] == 0x6c && 83 | next4Bytes[1] == 0x73 && 84 | next4Bytes[2] == 0x5f) { 85 | // classification model 86 | modelType = SparkModelParam.MODEL_TYPE_CLS; 87 | 88 | } else if (first4Bytes[3] == 0x72 && 89 | next4Bytes[0] == 0x65 && 90 | next4Bytes[1] == 0x67 && 91 | next4Bytes[2] == 0x5f) { 92 | // regression model 93 | modelType = SparkModelParam.MODEL_TYPE_REG; 94 | } 95 | 96 | if (modelType != null) { 97 | int len = (next4Bytes[3] << 8) + (reader.readByteAsInt()); 98 | String featuresCol = reader.readUTF(len); 99 | 100 | this.sparkModelParam = new SparkModelParam(modelType, featuresCol, reader); 101 | 102 | base_score = reader.readFloat(); 103 | num_feature = reader.readUnsignedInt(); 104 | 105 | } else { 106 | base_score = reader.asFloat(first4Bytes); 107 | num_feature = reader.asUnsignedInt(next4Bytes); 108 | } 109 | 110 | } else { 111 | base_score = reader.asFloat(first4Bytes); 112 | num_feature = reader.asUnsignedInt(next4Bytes); 113 | } 114 | 115 | mparam = new ModelParam(base_score, num_feature, reader); 116 | 117 | name_obj = reader.readString(); 118 | name_gbm = reader.readString(); 119 | } 120 | 121 | void initObjFunction(PredictorConfiguration configuration) { 122 | obj = 
configuration.getObjFunction(); 123 | 124 | if (obj == null) { 125 | obj = ObjFunction.fromName(name_obj); 126 | } 127 | } 128 | 129 | void initObjGbm() { 130 | obj = ObjFunction.fromName(name_obj); 131 | gbm = GradBooster.Factory.createGradBooster(name_gbm); 132 | gbm.setNumClass(mparam.num_class); 133 | gbm.setNumFeature(mparam.num_feature); 134 | } 135 | 136 | /** 137 | * Generates predictions for given feature vector. 138 | * 139 | * @param feat feature vector 140 | * @return prediction values 141 | */ 142 | public float[] predict(FVec feat) { 143 | return predict(feat, false); 144 | } 145 | 146 | /** 147 | * Generates predictions for given feature vector. 148 | * 149 | * @param feat feature vector 150 | * @param output_margin whether to only predict margin value instead of transformed prediction 151 | * @return prediction values 152 | */ 153 | public float[] predict(FVec feat, boolean output_margin) { 154 | return predict(feat, output_margin, 0); 155 | } 156 | 157 | /** 158 | * Generates predictions for given feature vector. 159 | * 160 | * @param feat feature vector 161 | * @param base_margin predict with base margin for each prediction 162 | * @return prediction values 163 | */ 164 | public float[] predict(FVec feat, float base_margin) { 165 | return predict(feat, base_margin, 0); 166 | } 167 | 168 | /** 169 | * Generates predictions for given feature vector. 170 | * 171 | * @param feat feature vector 172 | * @param base_margin predict with base margin for each prediction 173 | * @param ntree_limit limit the number of trees used in prediction 174 | * @return prediction values 175 | */ 176 | public float[] predict(FVec feat, float base_margin, int ntree_limit) { 177 | float[] preds = predictRaw(feat, ntree_limit, base_margin); 178 | preds = obj.predTransform(preds); 179 | return preds; 180 | } 181 | 182 | /** 183 | * Generates predictions for given feature vector. 
184 | * 185 | * @param feat feature vector 186 | * @param output_margin whether to only predict margin value instead of transformed prediction 187 | * @param ntree_limit limit the number of trees used in prediction 188 | * @return prediction values 189 | */ 190 | public float[] predict(FVec feat, boolean output_margin, int ntree_limit) { 191 | float[] preds = predictRaw(feat, ntree_limit, base_score); 192 | if (! output_margin) { 193 | preds = obj.predTransform(preds); 194 | } 195 | return preds; 196 | } 197 | 198 | float[] predictRaw(FVec feat, int ntree_limit, float base_score) { 199 | if (isBeforeOrEqual12()) { 200 | float[] preds = gbm.predict(feat, ntree_limit, 0 /* intentionally use 0 and add base score after to have the same floating point order of operation */); 201 | for (int i = 0; i < preds.length; i++) { 202 | preds[i] += base_score; 203 | } 204 | return preds; 205 | } else { 206 | // Since xgboost 1.3 the floating point operations order has changed - add base_score as first and predictions after 207 | return gbm.predict(feat, ntree_limit, base_score); 208 | } 209 | } 210 | 211 | /** 212 | * Generates a prediction for given feature vector. 213 | *

214 | * This method only works when the model outputs single value. 215 | *

216 | * 217 | * @param feat feature vector 218 | * @return prediction value 219 | */ 220 | public float predictSingle(FVec feat) { 221 | return predictSingle(feat, false); 222 | } 223 | 224 | /** 225 | * Generates a prediction for given feature vector. 226 | *

227 | * This method only works when the model outputs single value. 228 | *

229 | * 230 | * @param feat feature vector 231 | * @param output_margin whether to only predict margin value instead of transformed prediction 232 | * @return prediction value 233 | */ 234 | public float predictSingle(FVec feat, boolean output_margin) { 235 | return predictSingle(feat, output_margin, 0); 236 | } 237 | 238 | /** 239 | * Generates a prediction for given feature vector. 240 | *

241 | * This method only works when the model outputs single value. 242 | *

243 | * 244 | * @param feat feature vector 245 | * @param output_margin whether to only predict margin value instead of transformed prediction 246 | * @param ntree_limit limit the number of trees used in prediction 247 | * @return prediction value 248 | */ 249 | public float predictSingle(FVec feat, boolean output_margin, int ntree_limit) { 250 | float pred = predictSingleRaw(feat, ntree_limit); 251 | if (!output_margin) { 252 | pred = obj.predTransform(pred); 253 | } 254 | return pred; 255 | } 256 | 257 | float predictSingleRaw(FVec feat, int ntree_limit) { 258 | if (isBeforeOrEqual12()) { 259 | return gbm.predictSingle(feat, ntree_limit, 0) + base_score; 260 | } else { 261 | return gbm.predictSingle(feat, ntree_limit, base_score); 262 | } 263 | } 264 | 265 | /** 266 | * Predicts leaf index of each tree. 267 | * 268 | * @param feat feature vector 269 | * @return leaf indexes 270 | */ 271 | public int[] predictLeaf(FVec feat) { 272 | return predictLeaf(feat, 0); 273 | } 274 | 275 | /** 276 | * Predicts leaf index of each tree. 277 | * 278 | * @param feat feature vector 279 | * @param ntree_limit limit, 0 for all 280 | * @return leaf indexes 281 | */ 282 | public int[] predictLeaf(FVec feat, int ntree_limit) { 283 | return gbm.predictLeaf(feat, ntree_limit); 284 | } 285 | 286 | /** 287 | * Predicts path to leaf of each tree. 288 | * 289 | * @param feat feature vector 290 | * @return leaf paths 291 | */ 292 | public String[] predictLeafPath(FVec feat) { 293 | return predictLeafPath(feat, 0); 294 | } 295 | 296 | /** 297 | * Predicts path to leaf of each tree. 298 | * 299 | * @param feat feature vector 300 | * @param ntree_limit limit, 0 for all 301 | * @return leaf paths 302 | */ 303 | public String[] predictLeafPath(FVec feat, int ntree_limit) { 304 | return gbm.predictLeafPath(feat, ntree_limit); 305 | } 306 | 307 | public SparkModelParam getSparkModelParam() { 308 | return sparkModelParam; 309 | } 310 | 311 | /** 312 | * Returns number of class. 
313 | * 314 | * @return number of class 315 | */ 316 | public int getNumClass() { 317 | return mparam.num_class; 318 | } 319 | 320 | /** 321 | * Used e.g. for the change of floating point operation order in between xgboost 1.2 and 1.3 322 | * 323 | * @return True if the booster was built with xgboost version <= 1.2. 324 | */ 325 | private boolean isBeforeOrEqual12() { 326 | return mparam.major_version < 1 || (mparam.major_version == 1 && mparam.minor_version <= 2); 327 | } 328 | 329 | /** 330 | * Parameters. 331 | */ 332 | static class ModelParam implements Serializable { 333 | /* \brief global bias */ 334 | final float base_score; 335 | /* \brief number of features */ 336 | final /* unsigned */ int num_feature; 337 | /* \brief number of class, if it is multi-class classification */ 338 | final int num_class; 339 | /*! \brief whether the model itself is saved with pbuffer */ 340 | final int saved_with_pbuffer; 341 | /*! \brief Model contain eval metrics */ 342 | private final int contain_eval_metrics; 343 | /*! \brief the version of XGBoost. */ 344 | private final int major_version; 345 | private final int minor_version; 346 | /*!
\brief reserved field */ 347 | final int[] reserved; 348 | 349 | ModelParam(float base_score, int num_feature, ModelReader reader) throws IOException { 350 | this.base_score = base_score; 351 | this.num_feature = num_feature; 352 | this.num_class = reader.readInt(); 353 | this.saved_with_pbuffer = reader.readInt(); 354 | this.contain_eval_metrics = reader.readInt(); 355 | this.major_version = reader.readUnsignedInt(); 356 | this.minor_version = reader.readUnsignedInt(); 357 | this.reserved = reader.readIntArray(27); 358 | } 359 | } 360 | 361 | public GradBooster getBooster(){ 362 | return gbm; 363 | } 364 | 365 | public String getObjName() { 366 | return name_obj; 367 | } 368 | 369 | public float getBaseScore() { 370 | return base_score; 371 | } 372 | 373 | } 374 | -------------------------------------------------------------------------------- /xgboost-predictor/src/main/java/biz/k11i/xgboost/config/PredictorConfiguration.java: -------------------------------------------------------------------------------- 1 | package biz.k11i.xgboost.config; 2 | 3 | import biz.k11i.xgboost.learner.ObjFunction; 4 | import biz.k11i.xgboost.tree.DefaultRegTreeFactory; 5 | import biz.k11i.xgboost.tree.RegTreeFactory; 6 | 7 | public class PredictorConfiguration { 8 | public static class Builder { 9 | private PredictorConfiguration predictorConfiguration; 10 | 11 | Builder() { 12 | predictorConfiguration = new PredictorConfiguration(); 13 | } 14 | 15 | public Builder objFunction(ObjFunction objFunction) { 16 | predictorConfiguration.objFunction = objFunction; 17 | return this; 18 | } 19 | 20 | public Builder regTreeFactory(RegTreeFactory regTreeFactory) { 21 | predictorConfiguration.regTreeFactory = regTreeFactory; 22 | return this; 23 | } 24 | 25 | public PredictorConfiguration build() { 26 | PredictorConfiguration result = predictorConfiguration; 27 | predictorConfiguration = null; 28 | return result; 29 | } 30 | } 31 | 32 | public static final PredictorConfiguration DEFAULT = new 
PredictorConfiguration(); 33 | 34 | private ObjFunction objFunction; 35 | private RegTreeFactory regTreeFactory; 36 | 37 | public PredictorConfiguration() { 38 | this.regTreeFactory = DefaultRegTreeFactory.INSTANCE; 39 | } 40 | 41 | public ObjFunction getObjFunction() { 42 | return objFunction; 43 | } 44 | 45 | public RegTreeFactory getRegTreeFactory() { 46 | return regTreeFactory; 47 | } 48 | 49 | public static Builder builder() { 50 | return new Builder(); 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /xgboost-predictor/src/main/java/biz/k11i/xgboost/gbm/Dart.java: -------------------------------------------------------------------------------- 1 | package biz.k11i.xgboost.gbm; 2 | 3 | import biz.k11i.xgboost.config.PredictorConfiguration; 4 | import biz.k11i.xgboost.tree.RegTree; 5 | import biz.k11i.xgboost.util.FVec; 6 | import biz.k11i.xgboost.util.ModelReader; 7 | 8 | import java.io.IOException; 9 | import java.util.Arrays; 10 | 11 | /** 12 | * Gradient boosted DART tree implementation. 13 | */ 14 | public class Dart extends GBTree { 15 | private float[] weightDrop; 16 | 17 | Dart() { 18 | // do nothing 19 | } 20 | 21 | @Override 22 | public void loadModel(PredictorConfiguration config, ModelReader reader, boolean with_pbuffer) throws IOException { 23 | super.loadModel(config, reader, with_pbuffer); 24 | if (mparam.num_trees != 0) { 25 | long size = reader.readLong(); 26 | weightDrop = reader.readFloatArray((int)size); 27 | } 28 | } 29 | 30 | @Override 31 | float pred(FVec feat, int bst_group, int root_index, int ntree_limit, float base_score) { 32 | RegTree[] trees = _groupTrees[bst_group]; 33 | int treeleft = ntree_limit == 0 ? 
trees.length : ntree_limit; 34 | 35 | float psum = base_score; 36 | for (int i = 0; i < treeleft; i++) { 37 | psum += weightDrop[i] * trees[i].getLeafValue(feat, root_index); 38 | } 39 | 40 | return psum; 41 | } 42 | 43 | public float weight(int tidx) { 44 | return weightDrop[tidx]; 45 | } 46 | 47 | } 48 | -------------------------------------------------------------------------------- /xgboost-predictor/src/main/java/biz/k11i/xgboost/gbm/GBLinear.java: -------------------------------------------------------------------------------- 1 | package biz.k11i.xgboost.gbm; 2 | 3 | import biz.k11i.xgboost.config.PredictorConfiguration; 4 | import biz.k11i.xgboost.util.FVec; 5 | import biz.k11i.xgboost.util.ModelReader; 6 | 7 | import java.io.IOException; 8 | import java.io.Serializable; 9 | 10 | /** 11 | * Linear booster implementation 12 | */ 13 | public class GBLinear extends GBBase { 14 | 15 | private float[] weights; 16 | 17 | @Override 18 | public void loadModel(PredictorConfiguration config, ModelReader reader, boolean ignored_with_pbuffer) throws IOException { 19 | new ModelParam(reader); 20 | long len = reader.readLong(); 21 | if (len == 0) { 22 | weights = new float[(num_feature + 1) * num_output_group]; 23 | } else { 24 | weights = reader.readFloatArray((int) len); 25 | } 26 | } 27 | 28 | @Override 29 | public float[] predict(FVec feat, int ntree_limit, float base_score) { 30 | float[] preds = new float[num_output_group]; 31 | for (int gid = 0; gid < num_output_group; ++gid) { 32 | preds[gid] = pred(feat, gid, base_score); 33 | } 34 | return preds; 35 | } 36 | 37 | @Override 38 | public float predictSingle(FVec feat, int ntree_limit, float base_score) { 39 | if (num_output_group != 1) { 40 | throw new IllegalStateException( 41 | "Can't invoke predictSingle() because this model outputs multiple values: " 42 | + num_output_group); 43 | } 44 | return pred(feat, 0, base_score); 45 | } 46 | 47 | float pred(FVec feat, int gid, float base_score) { 48 | float psum = 
bias(gid) + base_score; 49 | float featValue; 50 | for (int fid = 0; fid < num_feature; ++fid) { 51 | featValue = feat.fvalue(fid); 52 | if (!Float.isNaN(featValue)) { 53 | psum += featValue * weight(fid, gid); 54 | } 55 | } 56 | return psum; 57 | } 58 | 59 | @Override 60 | public int[] predictLeaf(FVec feat, int ntree_limit) { 61 | throw new UnsupportedOperationException("gblinear does not support predict leaf index"); 62 | } 63 | 64 | @Override 65 | public String[] predictLeafPath(FVec feat, int ntree_limit) { 66 | throw new UnsupportedOperationException("gblinear does not support predict leaf path"); 67 | } 68 | 69 | public float weight(int fid, int gid) { 70 | return weights[(fid * num_output_group) + gid]; 71 | } 72 | 73 | public float bias(int gid) { 74 | return weights[(num_feature * num_output_group) + gid]; 75 | } 76 | 77 | static class ModelParam implements Serializable { 78 | /*! \brief reserved space */ 79 | final int[] reserved; 80 | 81 | ModelParam(ModelReader reader) throws IOException { 82 | reader.readUnsignedInt(); // num_feature deprecated 83 | reader.readInt(); // num_output_group deprecated 84 | reserved = reader.readIntArray(32); 85 | } 86 | } 87 | 88 | public int getNumFeature() { 89 | return num_feature; 90 | } 91 | 92 | public int getNumOutputGroup() { 93 | return num_output_group; 94 | } 95 | 96 | } 97 | -------------------------------------------------------------------------------- /xgboost-predictor/src/main/java/biz/k11i/xgboost/gbm/GBTree.java: -------------------------------------------------------------------------------- 1 | package biz.k11i.xgboost.gbm; 2 | 3 | import biz.k11i.xgboost.config.PredictorConfiguration; 4 | import biz.k11i.xgboost.tree.RegTree; 5 | import biz.k11i.xgboost.util.FVec; 6 | import biz.k11i.xgboost.util.ModelReader; 7 | 8 | import java.io.IOException; 9 | import java.io.Serializable; 10 | import java.util.Arrays; 11 | 12 | /** 13 | * Gradient boosted tree implementation. 
14 | */ 15 | public class GBTree extends GBBase { 16 | 17 | ModelParam mparam; 18 | private RegTree[] trees; 19 | 20 | RegTree[][] _groupTrees; 21 | 22 | @Override 23 | public void loadModel(PredictorConfiguration config, ModelReader reader, boolean with_pbuffer) throws IOException { 24 | mparam = new ModelParam(reader); 25 | 26 | trees = new RegTree[mparam.num_trees]; 27 | for (int i = 0; i < mparam.num_trees; i++) { 28 | trees[i] = config.getRegTreeFactory().loadTree(reader); 29 | } 30 | 31 | int[] tree_info = mparam.num_trees > 0 ? reader.readIntArray(mparam.num_trees) : new int[0]; 32 | 33 | if (mparam.num_pbuffer != 0 && with_pbuffer) { 34 | reader.skip(4 * predBufferSize()); 35 | reader.skip(4 * predBufferSize()); 36 | } 37 | 38 | _groupTrees = new RegTree[num_output_group][]; 39 | for (int i = 0; i < num_output_group; i++) { 40 | int treeCount = 0; 41 | for (int j = 0; j < tree_info.length; j++) { 42 | if (tree_info[j] == i) { 43 | treeCount++; 44 | } 45 | } 46 | 47 | _groupTrees[i] = new RegTree[treeCount]; 48 | treeCount = 0; 49 | 50 | for (int j = 0; j < tree_info.length; j++) { 51 | if (tree_info[j] == i) { 52 | _groupTrees[i][treeCount++] = trees[j]; 53 | } 54 | } 55 | } 56 | } 57 | 58 | @Override 59 | public float[] predict(FVec feat, int ntree_limit, float base_score) { 60 | float[] preds = new float[num_output_group]; 61 | for (int gid = 0; gid < num_output_group; gid++) { 62 | preds[gid] += pred(feat, gid, 0, ntree_limit, base_score); 63 | } 64 | return preds; 65 | } 66 | 67 | @Override 68 | public float predictSingle(FVec feat, int ntree_limit, float base_score) { 69 | if (num_output_group != 1) { 70 | throw new IllegalStateException( 71 | "Can't invoke predictSingle() because this model outputs multiple values: " 72 | + num_output_group); 73 | } 74 | return pred(feat, 0, 0, ntree_limit, base_score); 75 | } 76 | 77 | float pred(FVec feat, int bst_group, int root_index, int ntree_limit, float base_score) { 78 | RegTree[] trees = 
_groupTrees[bst_group]; 79 | int treeleft = ntree_limit == 0 ? trees.length : ntree_limit; 80 | 81 | float psum = base_score; 82 | for (int i = 0; i < treeleft; i++) { 83 | psum += trees[i].getLeafValue(feat, root_index); 84 | } 85 | return psum; 86 | } 87 | 88 | @Override 89 | public int[] predictLeaf(FVec feat, int ntree_limit) { 90 | int treeleft = ntree_limit == 0 ? trees.length : ntree_limit; 91 | int[] leafIndex = new int[treeleft]; 92 | for (int i = 0; i < treeleft; i++) { 93 | leafIndex[i] = trees[i].getLeafIndex(feat); 94 | } 95 | return leafIndex; 96 | } 97 | 98 | @Override 99 | public String[] predictLeafPath(FVec feat, int ntree_limit) { 100 | int treeleft = ntree_limit == 0 ? trees.length : ntree_limit; 101 | String[] leafPath = new String[treeleft]; 102 | StringBuilder sb = new StringBuilder(64); 103 | for (int i = 0; i < treeleft; i++) { 104 | trees[i].getLeafPath(feat, sb); 105 | leafPath[i] = sb.toString(); 106 | sb.setLength(0); 107 | } 108 | return leafPath; 109 | } 110 | 111 | private long predBufferSize() { 112 | return num_output_group * mparam.num_pbuffer * (mparam.size_leaf_vector + 1); 113 | } 114 | 115 | static class ModelParam implements Serializable { 116 | /*! \brief number of trees */ 117 | final int num_trees; 118 | /*! \brief number of root: default 0, means single tree */ 119 | final int num_roots; 120 | /*! \brief size of predicton buffer allocated used for buffering */ 121 | final long num_pbuffer; 122 | /*! \brief size of leaf vector needed in tree */ 123 | final int size_leaf_vector; 124 | /*! 
\brief reserved space */ 125 | final int[] reserved; 126 | 127 | ModelParam(ModelReader reader) throws IOException { 128 | num_trees = reader.readInt(); 129 | num_roots = reader.readInt(); 130 | reader.readInt(); // num_feature deprecated 131 | reader.readInt(); // read padding 132 | num_pbuffer = reader.readLong(); 133 | reader.readInt(); // num_output_group not used anymore 134 | size_leaf_vector = reader.readInt(); 135 | reserved = reader.readIntArray(31); 136 | reader.readInt(); // read padding 137 | } 138 | 139 | } 140 | 141 | /** 142 | * 143 | * @return A two-dim array, with trees grouped into classes. 144 | */ 145 | public RegTree[][] getGroupedTrees(){ 146 | return _groupTrees; 147 | } 148 | } 149 | -------------------------------------------------------------------------------- /xgboost-predictor/src/main/java/biz/k11i/xgboost/gbm/GradBooster.java: -------------------------------------------------------------------------------- 1 | package biz.k11i.xgboost.gbm; 2 | 3 | import biz.k11i.xgboost.config.PredictorConfiguration; 4 | import biz.k11i.xgboost.util.FVec; 5 | import biz.k11i.xgboost.util.ModelReader; 6 | 7 | import java.io.IOException; 8 | import java.io.Serializable; 9 | 10 | /** 11 | * Interface of gradient boosting model. 12 | */ 13 | public interface GradBooster extends Serializable { 14 | 15 | class Factory { 16 | /** 17 | * Creates a gradient booster from given name. 18 | * 19 | * @param name name of gradient booster 20 | * @return created gradient booster 21 | */ 22 | public static GradBooster createGradBooster(String name) { 23 | if ("gbtree".equals(name)) { 24 | return new GBTree(); 25 | } else if ("gblinear".equals(name)) { 26 | return new GBLinear(); 27 | } else if ("dart".equals(name)) { 28 | return new Dart(); 29 | } 30 | 31 | throw new IllegalArgumentException(name + " is not supported model."); 32 | } 33 | } 34 | 35 | void setNumClass(int numClass); 36 | void setNumFeature(int numFeature); 37 | 38 | /** 39 | * Loads model from stream. 
40 | * 41 | * @param config predictor configuration 42 | * @param reader input stream 43 | * @param with_pbuffer whether the incoming data contains pbuffer 44 | * @throws IOException If an I/O error occurs 45 | */ 46 | void loadModel(PredictorConfiguration config, ModelReader reader, boolean with_pbuffer) throws IOException; 47 | 48 | /** 49 | * Generates predictions for given feature vector. 50 | * 51 | * @param feat feature vector 52 | * @param ntree_limit limit the number of trees used in prediction 53 | * @param base_score base score to initialize prediction 54 | * @return prediction result 55 | */ 56 | float[] predict(FVec feat, int ntree_limit, float base_score); 57 | 58 | /** 59 | * Generates a prediction for given feature vector. 60 | *

61 | * This method only works when the model outputs single value. 62 | *

63 | * 64 | * @param feat feature vector 65 | * @param ntree_limit limit the number of trees used in prediction 66 | * @param base_score base score to initialize prediction 67 | * @return prediction result 68 | */ 69 | float predictSingle(FVec feat, int ntree_limit, float base_score); 70 | 71 | /** 72 | * Predicts the leaf index of each tree. This is only valid in gbtree predictor. 73 | * 74 | * @param feat feature vector 75 | * @param ntree_limit limit the number of trees used in prediction 76 | * @return predicted leaf indexes 77 | */ 78 | int[] predictLeaf(FVec feat, int ntree_limit); 79 | 80 | /** 81 | * Predicts the path to leaf of each tree. This is only valid in gbtree predictor. 82 | * 83 | * @param feat feature vector 84 | * @param ntree_limit limit the number of trees used in prediction 85 | * @return predicted path to leaves 86 | */ 87 | String[] predictLeafPath(FVec feat, int ntree_limit); 88 | 89 | } 90 | 91 | abstract class GBBase implements GradBooster { 92 | protected int num_class; 93 | protected int num_feature; 94 | protected int num_output_group; 95 | 96 | @Override 97 | public void setNumClass(int numClass) { 98 | this.num_class = numClass; 99 | this.num_output_group = (num_class == 0) ? 1 : num_class; 100 | } 101 | 102 | @Override 103 | public void setNumFeature(int numFeature) { 104 | this.num_feature = numFeature; 105 | } 106 | } -------------------------------------------------------------------------------- /xgboost-predictor/src/main/java/biz/k11i/xgboost/learner/ObjFunction.java: -------------------------------------------------------------------------------- 1 | package biz.k11i.xgboost.learner; 2 | 3 | import biz.k11i.xgboost.config.PredictorConfiguration; 4 | import net.jafama.FastMath; 5 | 6 | import java.io.Serializable; 7 | import java.util.HashMap; 8 | import java.util.Map; 9 | 10 | /** 11 | * Objective function implementations. 
12 | */ 13 | public class ObjFunction implements Serializable { 14 | 15 | private static final Map FUNCTIONS = new HashMap<>(); 16 | 17 | static { 18 | register("rank:pairwise", new ObjFunction()); 19 | register("rank:ndcg", new ObjFunction()); 20 | register("binary:logistic", new RegLossObjLogistic()); 21 | register("reg:logistic", new RegLossObjLogistic()); 22 | register("binary:logitraw", new ObjFunction()); 23 | register("multi:softmax", new SoftmaxMultiClassObjClassify()); 24 | register("multi:softprob", new SoftmaxMultiClassObjProb()); 25 | register("reg:linear", new ObjFunction()); 26 | register("reg:squarederror", new ObjFunction()); 27 | register("reg:gamma", new RegObjFunction()); 28 | register("reg:tweedie", new RegObjFunction()); 29 | register("count:poisson", new RegObjFunction()); 30 | } 31 | 32 | /** 33 | * Gets {@link ObjFunction} from given name. 34 | * 35 | * @param name name of objective function 36 | * @return objective function 37 | */ 38 | public static ObjFunction fromName(String name) { 39 | ObjFunction result = FUNCTIONS.get(name); 40 | if (result == null) { 41 | throw new IllegalArgumentException(name + " is not supported objective function."); 42 | } 43 | return result; 44 | } 45 | 46 | /** 47 | * Register an {@link ObjFunction} for a given name. 48 | * 49 | * @param name name of objective function 50 | * @param objFunction objective function 51 | * @deprecated This method will be made private. Please use {@link PredictorConfiguration.Builder#objFunction(ObjFunction)} instead. 52 | */ 53 | public static void register(String name, ObjFunction objFunction) { 54 | FUNCTIONS.put(name, objFunction); 55 | } 56 | 57 | /** 58 | * Uses Jafama's {@link FastMath#exp(double)} instead of {@link Math#exp(double)}. 59 | * 60 | * @param useJafama {@code true} if you want to use Jafama's {@link FastMath#exp(double)}, 61 | * or {@code false} if you don't want to use it but JDK's {@link Math#exp(double)}. 
62 | */ 63 | public static void useFastMathExp(boolean useJafama) { 64 | if (useJafama) { 65 | register("binary:logistic", new RegLossObjLogistic_Jafama()); 66 | register("multi:softprob", new SoftmaxMultiClassObjProb_Jafama()); 67 | 68 | } else { 69 | register("binary:logistic", new RegLossObjLogistic()); 70 | register("multi:softprob", new SoftmaxMultiClassObjProb()); 71 | } 72 | } 73 | 74 | /** 75 | * Transforms prediction values. 76 | * 77 | * @param preds prediction 78 | * @return transformed values 79 | */ 80 | public float[] predTransform(float[] preds) { 81 | // do nothing 82 | return preds; 83 | } 84 | 85 | /** 86 | * Transforms a prediction value. 87 | * 88 | * @param pred prediction 89 | * @return transformed value 90 | */ 91 | public float predTransform(float pred) { 92 | // do nothing 93 | return pred; 94 | } 95 | 96 | public float probToMargin(float prob) { 97 | // do nothing 98 | return prob; 99 | } 100 | 101 | /** 102 | * Regression. 103 | */ 104 | static class RegObjFunction extends ObjFunction { 105 | @Override 106 | public float[] predTransform(float[] preds) { 107 | if (preds.length != 1) 108 | throw new IllegalStateException( 109 | "Regression problem is supposed to have just a single predicted value, got " + preds.length + " instead." 110 | ); 111 | preds[0] = (float) Math.exp(preds[0]); 112 | return preds; 113 | } 114 | 115 | @Override 116 | public float predTransform(float pred) { 117 | return (float) Math.exp(pred); 118 | } 119 | 120 | @Override 121 | public float probToMargin(float prob) { 122 | return (float) Math.log(prob); 123 | } 124 | } 125 | 126 | /** 127 | * Logistic regression. 
128 | */ 129 | static class RegLossObjLogistic extends ObjFunction { 130 | @Override 131 | public float[] predTransform(float[] preds) { 132 | for (int i = 0; i < preds.length; i++) { 133 | preds[i] = sigmoid(preds[i]); 134 | } 135 | return preds; 136 | } 137 | 138 | @Override 139 | public float predTransform(float pred) { 140 | return sigmoid(pred); 141 | } 142 | 143 | float sigmoid(float x) { 144 | return (1f / (1f + (float) Math.exp(-x))); 145 | } 146 | 147 | @Override 148 | public float probToMargin(float prob) { 149 | return (float) -Math.log(1.0f / prob - 1.0f); 150 | } 151 | } 152 | 153 | /** 154 | * Logistic regression. 155 | *

156 | * Jafama's {@link FastMath#exp(double)} version. 157 | *

158 | */ 159 | static class RegLossObjLogistic_Jafama extends RegLossObjLogistic { 160 | @Override 161 | float sigmoid(float x) { 162 | return (float) (1 / (1 + FastMath.exp(-x))); 163 | } 164 | } 165 | 166 | /** 167 | * Multiclass classification. 168 | */ 169 | static class SoftmaxMultiClassObjClassify extends ObjFunction { 170 | @Override 171 | public float[] predTransform(float[] preds) { 172 | int maxIndex = 0; 173 | float max = preds[0]; 174 | for (int i = 1; i < preds.length; i++) { 175 | if (max < preds[i]) { 176 | maxIndex = i; 177 | max = preds[i]; 178 | } 179 | } 180 | 181 | return new float[]{maxIndex}; 182 | } 183 | 184 | @Override 185 | public float predTransform(float pred) { 186 | throw new UnsupportedOperationException(); 187 | } 188 | } 189 | 190 | /** 191 | * Multiclass classification (predicted probability). 192 | */ 193 | static class SoftmaxMultiClassObjProb extends ObjFunction { 194 | @Override 195 | public float[] predTransform(float[] preds) { 196 | float max = preds[0]; 197 | for (int i = 1; i < preds.length; i++) { 198 | max = Math.max(preds[i], max); 199 | } 200 | 201 | double sum = 0; 202 | for (int i = 0; i < preds.length; i++) { 203 | preds[i] = exp(preds[i] - max); 204 | sum += preds[i]; 205 | } 206 | 207 | for (int i = 0; i < preds.length; i++) { 208 | preds[i] /= (float) sum; 209 | } 210 | 211 | return preds; 212 | } 213 | 214 | @Override 215 | public float predTransform(float pred) { 216 | throw new UnsupportedOperationException(); 217 | } 218 | 219 | float exp(float x) { 220 | return (float) Math.exp(x); 221 | } 222 | } 223 | 224 | /** 225 | * Multiclass classification (predicted probability). 226 | *

227 | * Jafama's {@link FastMath#exp(double)} version. 228 | *

229 | */ 230 | static class SoftmaxMultiClassObjProb_Jafama extends SoftmaxMultiClassObjProb { 231 | @Override 232 | float exp(float x) { 233 | return (float) FastMath.exp(x); 234 | } 235 | } 236 | } 237 | -------------------------------------------------------------------------------- /xgboost-predictor/src/main/java/biz/k11i/xgboost/spark/SparkModelParam.java: -------------------------------------------------------------------------------- 1 | package biz.k11i.xgboost.spark; 2 | 3 | import biz.k11i.xgboost.util.ModelReader; 4 | 5 | import java.io.IOException; 6 | import java.io.Serializable; 7 | 8 | public class SparkModelParam implements Serializable { 9 | public static final String MODEL_TYPE_CLS = "_cls_"; 10 | public static final String MODEL_TYPE_REG = "_reg_"; 11 | 12 | final String modelType; 13 | final String featureCol; 14 | 15 | final String labelCol; 16 | final String predictionCol; 17 | 18 | // classification model only 19 | final String rawPredictionCol; 20 | final double[] thresholds; 21 | 22 | public SparkModelParam(String modelType, String featureCol, ModelReader reader) throws IOException { 23 | this.modelType = modelType; 24 | this.featureCol = featureCol; 25 | this.labelCol = reader.readUTF(); 26 | this.predictionCol = reader.readUTF(); 27 | 28 | if (MODEL_TYPE_CLS.equals(modelType)) { 29 | this.rawPredictionCol = reader.readUTF(); 30 | int thresholdLength = reader.readIntBE(); 31 | this.thresholds = thresholdLength > 0 ? 
reader.readDoubleArrayBE(thresholdLength) : null; 32 | 33 | } else if (MODEL_TYPE_REG.equals(modelType)) { 34 | this.rawPredictionCol = null; 35 | this.thresholds = null; 36 | 37 | } else { 38 | throw new UnsupportedOperationException("Unknown modelType: " + modelType); 39 | } 40 | } 41 | 42 | public String getModelType() { 43 | return modelType; 44 | } 45 | 46 | public String getFeatureCol() { 47 | return featureCol; 48 | } 49 | 50 | public String getLabelCol() { 51 | return labelCol; 52 | } 53 | 54 | public String getPredictionCol() { 55 | return predictionCol; 56 | } 57 | 58 | public String getRawPredictionCol() { 59 | return rawPredictionCol; 60 | } 61 | 62 | public double[] getThresholds() { 63 | return thresholds; 64 | } 65 | } 66 | -------------------------------------------------------------------------------- /xgboost-predictor/src/main/java/biz/k11i/xgboost/tree/DefaultRegTreeFactory.java: -------------------------------------------------------------------------------- 1 | package biz.k11i.xgboost.tree; 2 | 3 | import biz.k11i.xgboost.util.ModelReader; 4 | 5 | import java.io.IOException; 6 | 7 | public final class DefaultRegTreeFactory implements RegTreeFactory { 8 | 9 | public static RegTreeFactory INSTANCE = new DefaultRegTreeFactory(); 10 | 11 | @Override 12 | public final RegTree loadTree(ModelReader reader) throws IOException { 13 | RegTreeImpl regTree = new RegTreeImpl(); 14 | regTree.loadModel(reader); 15 | return regTree; 16 | } 17 | 18 | } 19 | -------------------------------------------------------------------------------- /xgboost-predictor/src/main/java/biz/k11i/xgboost/tree/RegTree.java: -------------------------------------------------------------------------------- 1 | package biz.k11i.xgboost.tree; 2 | 3 | import biz.k11i.xgboost.util.FVec; 4 | 5 | import java.io.Serializable; 6 | 7 | /** 8 | * Regression tree. 
9 | */ 10 | public interface RegTree extends Serializable { 11 | 12 | /** 13 | * Retrieves nodes from root to leaf and returns leaf index. 14 | * 15 | * @param feat feature vector 16 | * @return leaf index 17 | */ 18 | int getLeafIndex(FVec feat); 19 | 20 | /** 21 | * Retrieves nodes from root to leaf and returns path to leaf. 22 | * 23 | * @param feat feature vector 24 | * @param sb output param, will write path path to leaf into this buffer 25 | */ 26 | void getLeafPath(FVec feat, StringBuilder sb); 27 | 28 | /** 29 | * Retrieves nodes from root to leaf and returns leaf value. 30 | * 31 | * @param feat feature vector 32 | * @param root_id starting root index 33 | * @return leaf value 34 | */ 35 | float getLeafValue(FVec feat, int root_id); 36 | 37 | /** 38 | * 39 | * @return Tree's nodes 40 | */ 41 | RegTreeNode[] getNodes(); 42 | 43 | /** 44 | * @return Tree's nodes stats 45 | */ 46 | RegTreeNodeStat[] getStats(); 47 | 48 | } 49 | -------------------------------------------------------------------------------- /xgboost-predictor/src/main/java/biz/k11i/xgboost/tree/RegTreeFactory.java: -------------------------------------------------------------------------------- 1 | package biz.k11i.xgboost.tree; 2 | 3 | import biz.k11i.xgboost.util.ModelReader; 4 | 5 | import java.io.IOException; 6 | 7 | public interface RegTreeFactory { 8 | 9 | RegTree loadTree(ModelReader reader) throws IOException; 10 | 11 | } 12 | -------------------------------------------------------------------------------- /xgboost-predictor/src/main/java/biz/k11i/xgboost/tree/RegTreeImpl.java: -------------------------------------------------------------------------------- 1 | package biz.k11i.xgboost.tree; 2 | 3 | import ai.h2o.algos.tree.INodeStat; 4 | import biz.k11i.xgboost.util.FVec; 5 | import biz.k11i.xgboost.util.ModelReader; 6 | 7 | import java.io.IOException; 8 | import java.io.Serializable; 9 | 10 | /** 11 | * Regression tree. 
12 | */ 13 | public class RegTreeImpl implements RegTree { 14 | 15 | private Param param; 16 | private Node[] nodes; 17 | private RegTreeNodeStat[] stats; 18 | 19 | /** 20 | * Loads model from stream. 21 | * 22 | * @param reader input stream 23 | * @throws IOException If an I/O error occurs 24 | */ 25 | public void loadModel(ModelReader reader) throws IOException { 26 | param = new Param(reader); 27 | 28 | nodes = new Node[param.num_nodes]; 29 | for (int i = 0; i < param.num_nodes; i++) { 30 | nodes[i] = new Node(reader); 31 | } 32 | 33 | stats = new RegTreeNodeStat[param.num_nodes]; 34 | for (int i = 0; i < param.num_nodes; i++) { 35 | stats[i] = new RegTreeNodeStat(reader); 36 | } 37 | } 38 | 39 | /** 40 | * Retrieves nodes from root to leaf and returns leaf index. 41 | * 42 | * @param feat feature vector 43 | * @return leaf index 44 | */ 45 | @Override 46 | public int getLeafIndex(FVec feat) { 47 | int id = 0; 48 | Node n; 49 | while (!(n = nodes[id])._isLeaf) { 50 | id = n.next(feat); 51 | } 52 | return id; 53 | } 54 | 55 | /** 56 | * Retrieves nodes from root to leaf and returns path to leaf. 57 | * 58 | * @param feat feature vector 59 | * @param sb output param, will write path path to leaf into this buffer 60 | */ 61 | @Override 62 | public void getLeafPath(FVec feat, StringBuilder sb) { 63 | int id = 0; 64 | Node n; 65 | while (!(n = nodes[id])._isLeaf) { 66 | id = n.next(feat); 67 | sb.append(id == n.cleft_ ? "L" : "R"); 68 | } 69 | } 70 | 71 | /** 72 | * Retrieves nodes from root to leaf and returns leaf value. 
73 | * 74 | * @param feat feature vector 75 | * @param root_id starting root index 76 | * @return leaf value 77 | */ 78 | @Override 79 | public float getLeafValue(FVec feat, int root_id) { 80 | Node n = nodes[root_id]; 81 | while (!n._isLeaf) { 82 | n = nodes[n.next(feat)]; 83 | } 84 | 85 | return n.leaf_value; 86 | } 87 | 88 | @Override 89 | public Node[] getNodes() { 90 | return nodes; 91 | } 92 | 93 | @Override 94 | public RegTreeNodeStat[] getStats() { 95 | return stats; 96 | } 97 | 98 | /** 99 | * Parameters. 100 | */ 101 | static class Param implements Serializable { 102 | /*! \brief number of start root */ 103 | final int num_roots; 104 | /*! \brief total number of nodes */ 105 | final int num_nodes; 106 | /*!\brief number of deleted nodes */ 107 | final int num_deleted; 108 | /*! \brief maximum depth, this is a statistics of the tree */ 109 | final int max_depth; 110 | /*! \brief number of features used for tree construction */ 111 | final int num_feature; 112 | /*! 113 | * \brief leaf vector size, used for vector tree 114 | * used to store more than one dimensional information in tree 115 | */ 116 | final int size_leaf_vector; 117 | /*! 
\brief reserved part */ 118 | final int[] reserved; 119 | 120 | Param(ModelReader reader) throws IOException { 121 | num_roots = reader.readInt(); 122 | num_nodes = reader.readInt(); 123 | num_deleted = reader.readInt(); 124 | max_depth = reader.readInt(); 125 | num_feature = reader.readInt(); 126 | 127 | size_leaf_vector = reader.readInt(); 128 | reserved = reader.readIntArray(31); 129 | } 130 | } 131 | 132 | public static class Node extends RegTreeNode implements Serializable { 133 | // pointer to parent, highest bit is used to 134 | // indicate whether it's a left child or not 135 | final int parent_; 136 | // pointer to left, right 137 | final int cleft_, cright_; 138 | // split feature index, left split or right split depends on the highest bit 139 | final /* unsigned */ int sindex_; 140 | // extra info (leaf_value or split_cond) 141 | final float leaf_value; 142 | final float split_cond; 143 | 144 | private final int _defaultNext; 145 | private final int _splitIndex; 146 | final boolean _isLeaf; 147 | 148 | // set parent 149 | Node(ModelReader reader) throws IOException { 150 | parent_ = reader.readInt(); 151 | cleft_ = reader.readInt(); 152 | cright_ = reader.readInt(); 153 | sindex_ = reader.readInt(); 154 | 155 | if (isLeaf()) { 156 | leaf_value = reader.readFloat(); 157 | split_cond = Float.NaN; 158 | } else { 159 | split_cond = reader.readFloat(); 160 | leaf_value = Float.NaN; 161 | } 162 | 163 | _defaultNext = cdefault(); 164 | _splitIndex = getSplitIndex(); 165 | _isLeaf = isLeaf(); 166 | } 167 | 168 | public boolean isLeaf() { 169 | return cleft_ == -1; 170 | } 171 | 172 | @Override 173 | public int getSplitIndex() { 174 | return (int) (sindex_ & ((1l << 31) - 1l)); 175 | } 176 | 177 | public int cdefault() { 178 | return default_left() ? 
cleft_ : cright_; 179 | } 180 | 181 | @Override 182 | public boolean default_left() { 183 | return (sindex_ >>> 31) != 0; 184 | } 185 | 186 | @Override 187 | public int next(FVec feat) { 188 | float value = feat.fvalue(_splitIndex); 189 | if (value != value) { // is NaN? 190 | return _defaultNext; 191 | } 192 | return (value < split_cond) ? cleft_ : cright_; 193 | } 194 | 195 | @Override 196 | public int getParentIndex() { 197 | return parent_; 198 | } 199 | 200 | @Override 201 | public int getLeftChildIndex() { 202 | return cleft_; 203 | } 204 | 205 | @Override 206 | public int getRightChildIndex() { 207 | return cright_; 208 | } 209 | 210 | @Override 211 | public float getSplitCondition() { 212 | return split_cond; 213 | } 214 | 215 | @Override 216 | public float getLeafValue(){ 217 | return leaf_value; 218 | } 219 | } 220 | 221 | } 222 | -------------------------------------------------------------------------------- /xgboost-predictor/src/main/java/biz/k11i/xgboost/tree/RegTreeNode.java: -------------------------------------------------------------------------------- 1 | package biz.k11i.xgboost.tree; 2 | 3 | import ai.h2o.algos.tree.INode; 4 | import biz.k11i.xgboost.util.FVec; 5 | 6 | import java.io.Serializable; 7 | 8 | public abstract class RegTreeNode implements INode, Serializable { 9 | 10 | /** 11 | * 12 | * @return Index of node's parent 13 | */ 14 | public abstract int getParentIndex(); 15 | 16 | /** 17 | * 18 | * @return Index of node's left child node 19 | */ 20 | public abstract int getLeftChildIndex(); 21 | 22 | /** 23 | * 24 | * @return Index of node's right child node 25 | */ 26 | public abstract int getRightChildIndex(); 27 | 28 | /** 29 | * 30 | * @return Split condition on the node, if the node is a split node. Leaf nodes have this value set to NaN 31 | */ 32 | public abstract float getSplitCondition(); 33 | 34 | /** 35 | * 36 | * @return Predicted value on the leaf node, if the node is leaf. 
Otherwise NaN 37 | */ 38 | public abstract float getLeafValue(); 39 | 40 | /** 41 | * 42 | * @return True if default direction for unrecognized values is the LEFT child, otherwise false. 43 | */ 44 | public abstract boolean default_left(); 45 | 46 | /** 47 | * 48 | * @return Index of domain category used to split on the node 49 | */ 50 | public abstract int getSplitIndex(); 51 | 52 | } 53 | -------------------------------------------------------------------------------- /xgboost-predictor/src/main/java/biz/k11i/xgboost/tree/RegTreeNodeStat.java: -------------------------------------------------------------------------------- 1 | package biz.k11i.xgboost.tree; 2 | 3 | import ai.h2o.algos.tree.INodeStat; 4 | import biz.k11i.xgboost.util.ModelReader; 5 | 6 | import java.io.IOException; 7 | import java.io.Serializable; 8 | 9 | /** 10 | * Statistics for node in tree. 11 | */ 12 | public class RegTreeNodeStat implements INodeStat, Serializable { 13 | 14 | final float loss_chg; 15 | final float sum_hess; 16 | final float base_weight; 17 | final int leaf_child_cnt; 18 | 19 | RegTreeNodeStat(ModelReader reader) throws IOException { 20 | loss_chg = reader.readFloat(); 21 | sum_hess = reader.readFloat(); 22 | base_weight = reader.readFloat(); 23 | leaf_child_cnt = reader.readInt(); 24 | } 25 | 26 | @Override 27 | public float getWeight() { 28 | return getCover(); 29 | } 30 | 31 | /** 32 | * @return loss chg caused by current split 33 | */ 34 | public float getGain() { 35 | return loss_chg; 36 | } 37 | 38 | /** 39 | * @return sum of hessian values, used to measure coverage of data 40 | */ 41 | public float getCover() { 42 | return sum_hess; 43 | } 44 | 45 | /** 46 | * @return weight of current node 47 | */ 48 | public float getBaseWeight() { 49 | return base_weight; 50 | } 51 | 52 | /** 53 | * @return number of child that is leaf node known up to now 54 | */ 55 | public int getLeafCount() { 56 | return leaf_child_cnt; 57 | } 58 | 59 | } 60 | 
-------------------------------------------------------------------------------- /xgboost-predictor/src/main/java/biz/k11i/xgboost/util/FVec.java: --------------------------------------------------------------------------------
package biz.k11i.xgboost.util;

import java.io.Serializable;
import java.util.Map;

/**
 * Interface of feature vector.
 */
public interface FVec extends Serializable {
    /**
     * Gets index-th value.
     *
     * @param index index
     * @return value, or NaN when the index holds no usable value
     */
    float fvalue(int index);

    class Transformer {
        private Transformer() {
            // do nothing
        }

        /**
         * Builds FVec from dense vector.
         *
         * @param values float values
         * @param treatsZeroAsNA treat zero as N/A if true
         * @return FVec
         */
        public static FVec fromArray(float[] values, boolean treatsZeroAsNA) {
            return new FVecArrayImpl.FVecFloatArrayImpl(values, treatsZeroAsNA);
        }

        /**
         * Builds FVec from dense vector.
         *
         * @param values double values
         * @param treatsZeroAsNA treat zero as N/A if true
         * @return FVec
         */
        public static FVec fromArray(double[] values, boolean treatsZeroAsNA) {
            return new FVecArrayImpl.FVecDoubleArrayImpl(values, treatsZeroAsNA);
        }

        /**
         * Builds FVec from map.
         *
         * @param map map containing non-zero values
         * @return FVec
         */
        public static FVec fromMap(Map<Integer, ? extends Number> map) {
            return new FVecMapImpl(map);
        }
    }
}

/**
 * Sparse implementation backed by a map; absent keys read as NaN.
 */
class FVecMapImpl implements FVec {
    private final Map<Integer, ? extends Number> values;

    FVecMapImpl(Map<Integer, ? extends Number> values) {
        this.values = values;
    }

    @Override
    public float fvalue(int index) {
        Number number = values.get(index);
        if (number == null) {
            return Float.NaN;
        }

        return number.floatValue();
    }
}

/**
 * Dense implementations backed by primitive arrays; reads past the end of the
 * array yield NaN.
 */
class FVecArrayImpl {
    static class FVecFloatArrayImpl implements FVec {
        private final float[] values;
        private final boolean treatsZeroAsNA;

        FVecFloatArrayImpl(float[] values, boolean treatsZeroAsNA) {
            this.values = values;
            this.treatsZeroAsNA = treatsZeroAsNA;
        }

        @Override
        public float fvalue(int index) {
            // NOTE(review): a negative index still throws — callers pass
            // non-negative split indices; confirm before relaxing.
            if (values.length <= index) {
                return Float.NaN;
            }

            float result = values[index];
            if (treatsZeroAsNA && result == 0) {
                return Float.NaN;
            }

            return result;
        }
    }

    static class FVecDoubleArrayImpl implements FVec {
        private final double[] values;
        private final boolean treatsZeroAsNA;

        FVecDoubleArrayImpl(double[] values, boolean treatsZeroAsNA) {
            this.values = values;
            this.treatsZeroAsNA = treatsZeroAsNA;
        }

        @Override
        public float fvalue(int index) {
            if (values.length <= index) {
                return Float.NaN;
            }

            final double result = values[index];
            if (treatsZeroAsNA && result == 0) {
                return Float.NaN;
            }

            return (float) result;
        }
    }
}
-------------------------------------------------------------------------------- /xgboost-predictor/src/main/java/biz/k11i/xgboost/util/ModelReader.java: --------------------------------------------------------------------------------
package
biz.k11i.xgboost.util; 2 | 3 | import java.io.Closeable; 4 | import java.io.EOFException; 5 | import java.io.FileInputStream; 6 | import java.io.IOException; 7 | import java.io.InputStream; 8 | import java.io.UTFDataFormatException; 9 | import java.nio.ByteBuffer; 10 | import java.nio.ByteOrder; 11 | import java.nio.charset.Charset; 12 | 13 | /** 14 | * Reads the Xgboost model from stream. 15 | */ 16 | public class ModelReader implements Closeable { 17 | private final InputStream stream; 18 | private byte[] buffer; 19 | 20 | @Deprecated 21 | public ModelReader(String filename) throws IOException { 22 | this(new FileInputStream(filename)); 23 | } 24 | 25 | public ModelReader(InputStream in) { 26 | stream = in; 27 | } 28 | 29 | private int fillBuffer(int numBytes) throws IOException { 30 | if (buffer == null || buffer.length < numBytes) { 31 | buffer = new byte[numBytes]; 32 | } 33 | 34 | int numBytesRead = 0; 35 | while (numBytesRead < numBytes) { 36 | int count = stream.read(buffer, numBytesRead, numBytes - numBytesRead); 37 | if (count < 0) { 38 | return numBytesRead; 39 | } 40 | numBytesRead += count; 41 | } 42 | 43 | return numBytesRead; 44 | } 45 | 46 | public int readByteAsInt() throws IOException { 47 | return stream.read(); 48 | } 49 | 50 | public byte[] readByteArray(int numBytes) throws IOException { 51 | int numBytesRead = fillBuffer(numBytes); 52 | if (numBytesRead < numBytes) { 53 | throw new EOFException( 54 | String.format("Cannot read byte array (shortage): expected = %d, actual = %d", 55 | numBytes, numBytesRead)); 56 | } 57 | 58 | byte[] result = new byte[numBytes]; 59 | System.arraycopy(buffer, 0, result, 0, numBytes); 60 | 61 | return result; 62 | } 63 | 64 | public int readInt() throws IOException { 65 | return readInt(ByteOrder.LITTLE_ENDIAN); 66 | } 67 | 68 | public int readIntBE() throws IOException { 69 | return readInt(ByteOrder.BIG_ENDIAN); 70 | } 71 | 72 | private int readInt(ByteOrder byteOrder) throws IOException { 73 | int numBytesRead 
= fillBuffer(4); 74 | if (numBytesRead < 4) { 75 | throw new EOFException("Cannot read int value (shortage): " + numBytesRead); 76 | } 77 | 78 | return ByteBuffer.wrap(buffer).order(byteOrder).getInt(); 79 | } 80 | 81 | public int[] readIntArray(int numValues) throws IOException { 82 | int numBytesRead = fillBuffer(numValues * 4); 83 | if (numBytesRead < numValues * 4) { 84 | throw new EOFException( 85 | String.format("Cannot read int array (shortage): expected = %d, actual = %d", 86 | numValues * 4, numBytesRead)); 87 | } 88 | 89 | ByteBuffer byteBuffer = ByteBuffer.wrap(buffer).order(ByteOrder.LITTLE_ENDIAN); 90 | 91 | int[] result = new int[numValues]; 92 | for (int i = 0; i < numValues; i++) { 93 | result[i] = byteBuffer.getInt(); 94 | } 95 | 96 | return result; 97 | } 98 | 99 | public int readUnsignedInt() throws IOException { 100 | int result = readInt(); 101 | if (result < 0) { 102 | throw new IOException("Cannot read unsigned int (overflow): " + result); 103 | } 104 | 105 | return result; 106 | } 107 | 108 | public long readLong() throws IOException { 109 | int numBytesRead = fillBuffer(8); 110 | if (numBytesRead < 8) { 111 | throw new IOException("Cannot read long value (shortage): " + numBytesRead); 112 | } 113 | 114 | return ByteBuffer.wrap(buffer).order(ByteOrder.LITTLE_ENDIAN).getLong(); 115 | } 116 | 117 | public float asFloat(byte[] bytes) { 118 | return ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN).getFloat(); 119 | } 120 | 121 | public int asUnsignedInt(byte[] bytes) throws IOException { 122 | int result = ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN).getInt(); 123 | if (result < 0) { 124 | throw new IOException("Cannot treat as unsigned int (overflow): " + result); 125 | } 126 | 127 | return result; 128 | } 129 | 130 | public float readFloat() throws IOException { 131 | int numBytesRead = fillBuffer(4); 132 | if (numBytesRead < 4) { 133 | throw new IOException("Cannot read float value (shortage): " + numBytesRead); 134 | } 135 | 
136 | return ByteBuffer.wrap(buffer).order(ByteOrder.LITTLE_ENDIAN).getFloat(); 137 | } 138 | 139 | public float[] readFloatArray(int numValues) throws IOException { 140 | int numBytesRead = fillBuffer(numValues * 4); 141 | if (numBytesRead < numValues * 4) { 142 | throw new EOFException( 143 | String.format("Cannot read float array (shortage): expected = %d, actual = %d", 144 | numValues * 4, numBytesRead)); 145 | } 146 | 147 | ByteBuffer byteBuffer = ByteBuffer.wrap(buffer).order(ByteOrder.LITTLE_ENDIAN); 148 | 149 | float[] result = new float[numValues]; 150 | for (int i = 0; i < numValues; i++) { 151 | result[i] = byteBuffer.getFloat(); 152 | } 153 | 154 | return result; 155 | } 156 | 157 | public double[] readDoubleArrayBE(int numValues) throws IOException { 158 | int numBytesRead = fillBuffer(numValues * 8); 159 | if (numBytesRead < numValues * 8) { 160 | throw new EOFException( 161 | String.format("Cannot read double array (shortage): expected = %d, actual = %d", 162 | numValues * 8, numBytesRead)); 163 | } 164 | 165 | ByteBuffer byteBuffer = ByteBuffer.wrap(buffer).order(ByteOrder.BIG_ENDIAN); 166 | 167 | double[] result = new double[numValues]; 168 | for (int i = 0; i < numValues; i++) { 169 | result[i] = byteBuffer.getDouble(); 170 | } 171 | 172 | return result; 173 | } 174 | 175 | public void skip(long numBytes) throws IOException { 176 | long numBytesRead = stream.skip(numBytes); 177 | if (numBytesRead < numBytes) { 178 | throw new IOException("Cannot skip bytes: " + numBytesRead); 179 | } 180 | } 181 | 182 | public String readString() throws IOException { 183 | long length = readLong(); 184 | if (length > Integer.MAX_VALUE) { 185 | throw new IOException("Too long string: " + length); 186 | } 187 | 188 | return readString((int) length); 189 | } 190 | 191 | public String readString(int numBytes) throws IOException { 192 | int numBytesRead = fillBuffer(numBytes); 193 | if (numBytesRead < numBytes) { 194 | throw new IOException(String.format("Cannot read 
string(%d) (shortage): %d", numBytes, numBytesRead)); 195 | } 196 | 197 | return new String(buffer, 0, numBytes, Charset.forName("UTF-8")); 198 | } 199 | 200 | public String readUTF() throws IOException { 201 | int utflen = readByteAsInt(); 202 | utflen = (short)((utflen << 8) | readByteAsInt()); 203 | return readUTF(utflen); 204 | } 205 | 206 | public String readUTF(int utflen) throws IOException { 207 | int numBytesRead = fillBuffer(utflen); 208 | if (numBytesRead < utflen) { 209 | throw new EOFException( 210 | String.format("Cannot read UTF string bytes: expected = %d, actual = %d", 211 | utflen, numBytesRead)); 212 | } 213 | 214 | char[] chararr = new char[utflen]; 215 | 216 | int c, char2, char3; 217 | int count = 0; 218 | int chararr_count=0; 219 | 220 | while (count < utflen) { 221 | c = (int) buffer[count] & 0xff; 222 | if (c > 127) break; 223 | count++; 224 | chararr[chararr_count++]=(char)c; 225 | } 226 | 227 | while (count < utflen) { 228 | c = (int) buffer[count] & 0xff; 229 | switch (c >> 4) { 230 | case 0: case 1: case 2: case 3: case 4: case 5: case 6: case 7: 231 | /* 0xxxxxxx*/ 232 | count++; 233 | chararr[chararr_count++]=(char)c; 234 | break; 235 | case 12: case 13: 236 | /* 110x xxxx 10xx xxxx*/ 237 | count += 2; 238 | if (count > utflen) 239 | throw new UTFDataFormatException( 240 | "malformed input: partial character at end"); 241 | char2 = (int) buffer[count-1]; 242 | if ((char2 & 0xC0) != 0x80) 243 | throw new UTFDataFormatException( 244 | "malformed input around byte " + count); 245 | chararr[chararr_count++]=(char)(((c & 0x1F) << 6) | 246 | (char2 & 0x3F)); 247 | break; 248 | case 14: 249 | /* 1110 xxxx 10xx xxxx 10xx xxxx */ 250 | count += 3; 251 | if (count > utflen) 252 | throw new UTFDataFormatException( 253 | "malformed input: partial character at end"); 254 | char2 = (int) buffer[count-2]; 255 | char3 = (int) buffer[count-1]; 256 | if (((char2 & 0xC0) != 0x80) || ((char3 & 0xC0) != 0x80)) 257 | throw new UTFDataFormatException( 258 | 
"malformed input around byte " + (count-1)); 259 | chararr[chararr_count++]=(char)(((c & 0x0F) << 12) | 260 | ((char2 & 0x3F) << 6) | 261 | ((char3 & 0x3F) << 0)); 262 | break; 263 | default: 264 | /* 10xx xxxx, 1111 xxxx */ 265 | throw new UTFDataFormatException( 266 | "malformed input around byte " + count); 267 | } 268 | } 269 | // The number of chars produced may be less than utflen 270 | return new String(chararr, 0, chararr_count); 271 | } 272 | 273 | @Override 274 | public void close() throws IOException { 275 | stream.close(); 276 | } 277 | } 278 | -------------------------------------------------------------------------------- /xgboost-predictor/src/test/java/biz/k11i/xgboost/PredictorSmokeTest.java: -------------------------------------------------------------------------------- 1 | package biz.k11i.xgboost; 2 | 3 | import biz.k11i.xgboost.util.FVec; 4 | import org.junit.Assert; 5 | import org.junit.Test; 6 | 7 | import java.io.IOException; 8 | 9 | public class PredictorSmokeTest { 10 | 11 | @Test 12 | public void shouldProvideEqualPredictionWithDifferentAPI_GBTree() throws IOException { 13 | Predictor predictor = new Predictor(getClass().getResourceAsStream("/prostate/boosterBytesProstateTree.bin")); 14 | checkAPIForEachTreeImplementation(predictor, 63.56952667f /* obtain from xgboost version 1.2.0 - not affected with floating point operation order */); 15 | } 16 | 17 | @Test 18 | public void shouldProvideEqualPredictionWithDifferentAPI_GBLinear() throws IOException { 19 | Predictor predictor = new Predictor(getClass().getResourceAsStream("/prostate/boosterBytesProstateLinear.bin")); 20 | checkAPIForEachTreeImplementation(predictor, 61.912136f /* obtain from xgboost version 1.2.0 - not affected with floating point operation order */); 21 | } 22 | 23 | @Test 24 | public void shouldProvideEqualPredictionWithDifferentAPI_Dart() throws IOException { 25 | Predictor predictor = new 
Predictor(getClass().getResourceAsStream("/prostate/boosterBytesProstateDart.bin")); 26 | checkAPIForEachTreeImplementation(predictor, 66.0059433f /* obtain from xgboost version 1.2.0 - not affected with floating point operation order */); 27 | } 28 | 29 | private void checkAPIForEachTreeImplementation(Predictor predictor, float expected) { 30 | float[] rowData = new float[] { 13.2f, 23.6f}; 31 | FVec row = FVec.Transformer.fromArray(rowData, false); 32 | float[] prediction = predictor.predict(row); 33 | float[] prediction2 = predictor.predict(row, false); 34 | float[] prediction3 = predictor.predict(row, predictor.getBaseScore()); 35 | float[] prediction4 = predictor.predict(row, predictor.getBaseScore(), 0); 36 | float[] prediction5 = predictor.predict(row, false, 0); 37 | float predictionSingle = predictor.predictSingle(row); 38 | float predictionSingle2 = predictor.predictSingle(row, false); 39 | float predictionSingle3 = predictor.predictSingle(row, false, 0); 40 | 41 | Assert.assertEquals(1, prediction.length); 42 | Assert.assertEquals(expected, prediction[0], 0); 43 | Assert.assertEquals(expected, prediction2[0], 0); 44 | Assert.assertEquals(expected, prediction3[0], 0); 45 | Assert.assertEquals(expected, prediction4[0], 0); 46 | Assert.assertEquals(expected, prediction5[0], 0); 47 | Assert.assertEquals(expected, predictionSingle, 0); 48 | Assert.assertEquals(expected, predictionSingle2, 0); 49 | Assert.assertEquals(expected, predictionSingle3, 0); 50 | } 51 | 52 | @Test 53 | public void shouldProvideDifferentPredictionPrecisionDueTheFloatingPointError() throws IOException { 54 | Predictor predictor = new Predictor(getClass().getResourceAsStream("/prostate/boosterBytesProstateTree.bin")); 55 | float[] rowData = new float[] { 15.2f, 0}; 56 | FVec row = FVec.Transformer.fromArray(rowData, false); 57 | float[] prediction = predictor.predict(row); 58 | float predictionSingle = predictor.predictSingle(row); 59 | 60 | /* 61 | * obtained from xgboost version 1.3.0 
and newer (Current upper bound is 1.6.1) 62 | * Difference is due the different sequence of floating point addition 63 | * see: https://github.com/dmlc/xgboost/issues/6350 64 | * Previous result was 64.41069031f obtained from xgboost version 1.2.0 65 | */ 66 | float previousExpected = 64.41069031f; 67 | float expected = 64.41068268f; 68 | 69 | Assert.assertNotEquals(previousExpected, prediction[0], 0); 70 | Assert.assertNotEquals(previousExpected, predictionSingle, 0); 71 | Assert.assertEquals(1, prediction.length); 72 | Assert.assertEquals(expected, prediction[0], 0); 73 | Assert.assertEquals(expected, predictionSingle, 0); 74 | } 75 | 76 | @Test 77 | public void shouldProvidePredictionWithCustomBaseMargin() throws IOException { 78 | Predictor predictor = new Predictor(getClass().getResourceAsStream("/prostate/boosterBytesProstateTree.bin")); 79 | float[] rowData = new float[] { 13.2f, 23.6f}; 80 | FVec row = FVec.Transformer.fromArray(rowData, false); 81 | float[] prediction = predictor.predict(row); 82 | float[] predictionWithCustomBaseMargin = predictor.predict(row, predictor.getBaseScore() + 0.2f); 83 | 84 | Assert.assertEquals(1, prediction.length); 85 | Assert.assertEquals(1, predictionWithCustomBaseMargin.length); 86 | Assert.assertEquals(prediction[0] + 0.2f, predictionWithCustomBaseMargin[0], 0); 87 | } 88 | 89 | @Test 90 | public void testThatPredictorCanSimulateDifferenceInFloatingPointOperationOrderAccordingToTheVersionInsideOfBinaryFile() throws IOException { 91 | Predictor predictor12 = new Predictor(getClass().getResourceAsStream("/prostate/boosterBytesProstateTreeVersion12.bin")); 92 | Predictor predictor16 = new Predictor(getClass().getResourceAsStream("/prostate/boosterBytesProstateTree.bin")); 93 | float[] rowData = new float[] { 15.2f, 0}; 94 | FVec row = FVec.Transformer.fromArray(rowData, false); 95 | float[] prediction12 = predictor12.predict(row); 96 | float predictionSingle12 = predictor12.predictSingle(row); 97 | float[] prediction16 = 
predictor16.predict(row); 98 | float predictionSingle16 = predictor16.predictSingle(row); 99 | 100 | float expected12 = 64.41069031f; 101 | float expected16 = 64.41068268f; 102 | 103 | Assert.assertEquals(expected12, prediction12[0], 0); 104 | Assert.assertEquals(expected12, predictionSingle12, 0); 105 | Assert.assertEquals(1, prediction12.length); 106 | Assert.assertEquals(expected12, prediction12[0], 0); 107 | Assert.assertEquals(expected12, predictionSingle12, 0); 108 | 109 | Assert.assertEquals(expected16, prediction16[0], 0); 110 | Assert.assertEquals(expected16, predictionSingle16, 0); 111 | Assert.assertEquals(1, prediction16.length); 112 | Assert.assertEquals(expected16, prediction16[0], 0); 113 | Assert.assertEquals(expected16, predictionSingle16, 0); 114 | } 115 | 116 | //TODO add test that each parameter in each GradBooster implementation works properly 117 | //TODO add test for multivariate prediction 118 | } 119 | -------------------------------------------------------------------------------- /xgboost-predictor/src/test/java/biz/k11i/xgboost/tree/PredictorPredictLeafTest.java: -------------------------------------------------------------------------------- 1 | package biz.k11i.xgboost.tree; 2 | 3 | import biz.k11i.xgboost.Predictor; 4 | import biz.k11i.xgboost.util.FVec; 5 | import org.junit.Test; 6 | 7 | import java.io.IOException; 8 | 9 | import static org.junit.Assert.assertArrayEquals; 10 | 11 | public class PredictorPredictLeafTest { 12 | 13 | @Test 14 | public void shouldPredictLeafIds() throws IOException { 15 | Predictor p = new Predictor(getClass().getResourceAsStream("/boosterBytes.bin")); 16 | float[] input = new float[] { 10, 20, 30, 5, 7, 10 }; 17 | FVec vec = FVec.Transformer.fromArray(input, false); 18 | int[] preds = p.predictLeaf(vec); 19 | int[] exp = new int[] { 33, 47, 41, 41, 49 }; 20 | assertArrayEquals(exp, preds); 21 | } 22 | 23 | @Test 24 | public void shouldPredictLeafPaths() throws IOException { 25 | Predictor p = new 
Predictor(getClass().getResourceAsStream("/boosterBytes.bin")); 26 | float[] input = new float[] { 10, 20, 30, 5, 7, 10 }; 27 | FVec vec = FVec.Transformer.fromArray(input, false); 28 | String[] preds = p.predictLeafPath(vec); 29 | String[] exp = new String[] { "LRRLL", "LLRRLL", "LLRRLL", "LLRRLL", "LLRRLL" }; 30 | assertArrayEquals(exp, preds); 31 | } 32 | 33 | } 34 | -------------------------------------------------------------------------------- /xgboost-predictor/src/test/resources/boosterBytes.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/h2oai/xgboost-predictor/4ee6ddfd5352abdadc02cc417d02b7db15da3982/xgboost-predictor/src/test/resources/boosterBytes.bin -------------------------------------------------------------------------------- /xgboost-predictor/src/test/resources/prostate/boosterBytesProstateDart.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/h2oai/xgboost-predictor/4ee6ddfd5352abdadc02cc417d02b7db15da3982/xgboost-predictor/src/test/resources/prostate/boosterBytesProstateDart.bin -------------------------------------------------------------------------------- /xgboost-predictor/src/test/resources/prostate/boosterBytesProstateLinear.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/h2oai/xgboost-predictor/4ee6ddfd5352abdadc02cc417d02b7db15da3982/xgboost-predictor/src/test/resources/prostate/boosterBytesProstateLinear.bin -------------------------------------------------------------------------------- /xgboost-predictor/src/test/resources/prostate/boosterBytesProstateTree.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/h2oai/xgboost-predictor/4ee6ddfd5352abdadc02cc417d02b7db15da3982/xgboost-predictor/src/test/resources/prostate/boosterBytesProstateTree.bin 
-------------------------------------------------------------------------------- /xgboost-predictor/src/test/resources/prostate/boosterBytesProstateTreeVersion12.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/h2oai/xgboost-predictor/4ee6ddfd5352abdadc02cc417d02b7db15da3982/xgboost-predictor/src/test/resources/prostate/boosterBytesProstateTreeVersion12.bin -------------------------------------------------------------------------------- /xgboost-predictor/src/test/resources/prostate/recreate.txt: -------------------------------------------------------------------------------- 1 | boosterBytesProstateTree.bin and boosterBytesProstateLinear.bin were taken from the MOJO of the test XGBoostTest.ProstateRegressionSmall() in the H2O-3 repository but can also be obtained with this Python code 2 | boosterBytesProstateTreeVersion12.bin is the same as boosterBytesProstateTree.bin but the version was manually changed in a binary editor 3 | 4 | !pip install xgboost==1.6.1 5 | !pip install scikit-learn 6 | !pip install pandas 7 | !wget "https://h2o-public-test-data.s3.amazonaws.com/smalldata/prostate/prostate54-2.csv" 8 | 9 | import xgboost as xgb 10 | from xgboost import XGBRegressor 11 | import sklearn 12 | import pandas 13 | 14 | data = pandas.read_csv("prostate54-2.csv") 15 | y_train = data["AGE"] 16 | X_train = data[["PSA", "VOL"]] 17 | # create model instance 18 | bst = XGBRegressor(n_estimators=8, max_depth=5, learning_rate=0.3, objective='reg:squarederror', seed=847020, booster="gbtree") # booster=gblinear 19 | # fit model 20 | bst.fit(X_train, y_train) 21 | bst.save_model("boosterBytesProstate.bin") -------------------------------------------------------------------------------- /xgboost-predictor/src/test/resources/prostate/recreate_darth.txt: -------------------------------------------------------------------------------- 1 | boosterBytesProstate.bin was taken from the MOJO of the test
XGBoostTest.ProstateRegressionSmall() in the H2O-3 repository but can also be obtained with this Python code 2 | 3 | !pip install xgboost==1.6.1 4 | !pip install scikit-learn 5 | !pip install pandas 6 | !wget "https://h2o-public-test-data.s3.amazonaws.com/smalldata/prostate/prostate54-2.csv" 7 | 8 | import xgboost as xgb 9 | from xgboost import XGBRegressor 10 | import sklearn 11 | import pandas 12 | 13 | data = pandas.read_csv("prostate54-2.csv") 14 | y_train = data["AGE"] 15 | X_train = data[["PSA", "VOL"]] 16 | # create model instance 17 | bst = XGBRegressor(n_estimators=20, max_depth=5, learning_rate=0.3, objective='reg:squarederror', seed=847020, booster="dart", rate_drop=0.1) 18 | # fit model 19 | bst.fit(X_train, y_train) 20 | bst.save_model("boosterBytesProstate.bin") --------------------------------------------------------------------------------