├── docs ├── imgs │ └── benchmark.png ├── release │ └── release.md └── benchmark │ └── result.md ├── gradle └── wrapper │ ├── gradle-wrapper.jar │ └── gradle-wrapper.properties ├── .github ├── ISSUE_TEMPLATE │ ├── custom.md │ ├── feature_request.md │ └── bug-report.md └── workflows │ └── gradle.yml ├── gradle.properties ├── settings.gradle ├── treetops-benchmark ├── src │ └── main │ │ └── java │ │ └── io │ │ └── github │ │ └── horoc │ │ └── treetops │ │ └── benchmark │ │ ├── wine │ │ ├── SimplePredictorBenchmark.java │ │ └── GeneratedPredictorBenchmark.java │ │ ├── diabetes │ │ ├── SimplePredictorBenchmark.java │ │ └── GeneratedPredictorBenchmark.java │ │ ├── breastcancer │ │ ├── SimplePredictorBenchmark.java │ │ └── GeneratedPredictorBenchmark.java │ │ ├── californiahousing │ │ ├── SimplePredictorBenchmark.java │ │ └── GeneratedPredictorBenchmark.java │ │ └── common │ │ ├── ClassPathLoader.java │ │ └── AverageTimeBenchmarkTemplate.java └── build.gradle ├── treetops-core ├── build.gradle └── src │ ├── main │ └── java │ │ └── io │ │ └── github │ │ └── horoc │ │ └── treetops │ │ └── core │ │ ├── predictor │ │ ├── objective │ │ │ ├── RegressionObjectiveDecorator.java │ │ │ ├── CrossEntropyObjectiveConvertor.java │ │ │ ├── CrossEntropyLambdaObjectiveDecorator.java │ │ │ ├── RegressionL2ObjectiveDecorator.java │ │ │ ├── MultiClassObjectiveDecorator.java │ │ │ ├── AbstractOutputConvertor.java │ │ │ └── BinaryObjectiveDecorator.java │ │ ├── Predictor.java │ │ ├── MetaDataHolder.java │ │ ├── PredictorWrapper.java │ │ └── SimplePredictor.java │ │ ├── generator │ │ ├── Generator.java │ │ └── PredictorClassGenerator.java │ │ ├── model │ │ ├── MissingType.java │ │ ├── TreeModel.java │ │ ├── TreeNode.java │ │ └── RawTreeBlock.java │ │ ├── factory │ │ ├── FileTreeModelFactory.java │ │ ├── ObjectiveDecoratorFactory.java │ │ └── TreePredictorFactory.java │ │ ├── loader │ │ ├── FileTreeModelLoader.java │ │ └── AbstractLoader.java │ │ └── parser │ │ └── TreeModelParser.java │ └── 
test │ └── java │ ├── org │ └── treetops │ │ └── core │ │ └── validation │ │ ├── LoadModelTemplate.java │ │ ├── DiabetesModelTest.java │ │ ├── CaliforniaHousingModelTest.java │ │ ├── BreastCancerModelTest.java │ │ └── WineModelTest.java │ └── io │ └── github │ └── horoc │ └── treetops │ └── core │ └── validation │ ├── LoadModelTemplate.java │ ├── DiabetesModelTest.java │ ├── CaliforniaHousingModelTest.java │ ├── BreastCancerModelTest.java │ └── WineModelTest.java ├── .gitignore ├── LICENSE ├── gradlew.bat ├── README.md ├── gradlew └── checkstyle └── checkstyle.xml /docs/imgs/benchmark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/horoc/treetops/HEAD/docs/imgs/benchmark.png -------------------------------------------------------------------------------- /gradle/wrapper/gradle-wrapper.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/horoc/treetops/HEAD/gradle/wrapper/gradle-wrapper.jar -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/custom.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Custom issue template 3 | about: Describe this issue template's purpose here. 
4 | title: "[ISSUE]" 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | 11 | -------------------------------------------------------------------------------- /gradle.properties: -------------------------------------------------------------------------------- 1 | group=io.github.horoc 2 | version=1.0-SNAPSHOT 3 | 4 | ossrhUsername=UserName 5 | ossrhPassword=Password 6 | 7 | signing.keyId=KeyId 8 | signing.password=PublicKeyPassword 9 | signing.secretKeyRingFile=PathToYourKeyRingFile -------------------------------------------------------------------------------- /settings.gradle: -------------------------------------------------------------------------------- 1 | rootProject.name = 'treetops' 2 | include 'treetops-core' 3 | findProject(':treetops-core')?.name = 'treetops-core' 4 | include 'treetops-benchmark' 5 | findProject(':treetops-benchmark')?.name = 'treetops-benchmark' 6 | 7 | -------------------------------------------------------------------------------- /gradle/wrapper/gradle-wrapper.properties: -------------------------------------------------------------------------------- 1 | distributionBase=GRADLE_USER_HOME 2 | distributionPath=wrapper/dists 3 | distributionUrl=https\://services.gradle.org/distributions/gradle-7.1-bin.zip 4 | zipStoreBase=GRADLE_USER_HOME 5 | zipStorePath=wrapper/dists 6 | -------------------------------------------------------------------------------- /treetops-benchmark/src/main/java/io/github/horoc/treetops/benchmark/wine/SimplePredictorBenchmark.java: -------------------------------------------------------------------------------- 1 | package io.github.horoc.treetops.benchmark.wine; 2 | 3 | /** 4 | * @author chenzhou@apache.org 5 | * created on 2023/2/18 6 | */ 7 | public class SimplePredictorBenchmark extends GeneratedPredictorBenchmark { 8 | 9 | @Override 10 | protected boolean isGenerated() { 11 | return false; 12 | } 13 | } 14 | -------------------------------------------------------------------------------- 
/treetops-benchmark/src/main/java/io/github/horoc/treetops/benchmark/diabetes/SimplePredictorBenchmark.java: -------------------------------------------------------------------------------- 1 | package io.github.horoc.treetops.benchmark.diabetes; 2 | 3 | /** 4 | * @author chenzhou@apache.org 5 | * created on 2023/2/18 6 | */ 7 | public class SimplePredictorBenchmark extends GeneratedPredictorBenchmark { 8 | 9 | @Override 10 | protected boolean isGenerated() { 11 | return false; 12 | } 13 | } 14 | -------------------------------------------------------------------------------- /treetops-benchmark/src/main/java/io/github/horoc/treetops/benchmark/breastcancer/SimplePredictorBenchmark.java: -------------------------------------------------------------------------------- 1 | package io.github.horoc.treetops.benchmark.breastcancer; 2 | 3 | /** 4 | * @author chenzhou@apache.org 5 | * created on 2023/2/18 6 | */ 7 | public class SimplePredictorBenchmark extends GeneratedPredictorBenchmark { 8 | 9 | @Override 10 | protected boolean isGenerated() { 11 | return false; 12 | } 13 | } 14 | -------------------------------------------------------------------------------- /treetops-benchmark/src/main/java/io/github/horoc/treetops/benchmark/californiahousing/SimplePredictorBenchmark.java: -------------------------------------------------------------------------------- 1 | package io.github.horoc.treetops.benchmark.californiahousing; 2 | 3 | /** 4 | * @author chenzhou@apache.org 5 | * created on 2023/2/18 6 | */ 7 | public class SimplePredictorBenchmark extends GeneratedPredictorBenchmark { 8 | 9 | @Override 10 | protected boolean isGenerated() { 11 | return false; 12 | } 13 | } 14 | -------------------------------------------------------------------------------- /treetops-core/build.gradle: -------------------------------------------------------------------------------- 1 | dependencies { 2 | implementation 'commons-io:commons-io:2.11.0' 3 | implementation 
'org.apache.commons:commons-lang3:3.12.0' 4 | implementation 'org.ow2.asm:asm:9.4' 5 | implementation 'org.ow2.asm:asm-util:9.4' 6 | implementation 'com.google.code.findbugs:jsr305:3.0.2' 7 | } 8 | 9 | test { 10 | useJUnitPlatform() 11 | testLogging { 12 | events "PASSED", "SKIPPED", "FAILED", "STANDARD_OUT", "STANDARD_ERROR" 13 | } 14 | } -------------------------------------------------------------------------------- /treetops-benchmark/src/main/java/io/github/horoc/treetops/benchmark/common/ClassPathLoader.java: -------------------------------------------------------------------------------- 1 | package io.github.horoc.treetops.benchmark.common; 2 | 3 | import io.github.horoc.treetops.core.loader.AbstractLoader; 4 | import java.io.InputStream; 5 | 6 | /** 7 | * @author chenzhou@apache.org 8 | * created on 2023/2/18 9 | */ 10 | public class ClassPathLoader extends AbstractLoader { 11 | 12 | @Override 13 | protected InputStream loadStream(final String resource) throws Exception { 14 | return this.getClass().getResourceAsStream(resource); 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: "[FEATURE]" 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 
21 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.class 2 | *.log 3 | *.ctxt 4 | .mtj.tmp/ 5 | *.war 6 | *.nar 7 | *.ear 8 | *.zip 9 | *.tar.gz 10 | *.rar 11 | hs_err_pid* 12 | target/ 13 | pom.xml.tag 14 | pom.xml.releaseBackup 15 | pom.xml.versionsBackup 16 | pom.xml.next 17 | release.properties 18 | dependency-reduced-pom.xml 19 | buildNumber.properties 20 | .mvn/timing.properties 21 | !/.mvn/wrapper/maven-wrapper.jar 22 | .idea/* 23 | .idea/compiler.xml 24 | .idea/encodings.xml 25 | .idea/modules.xml 26 | *.iml 27 | *.sw? 28 | .#* 29 | *# 30 | *~ 31 | .classpath 32 | .project 33 | .settings/ 34 | bin 35 | build 36 | target 37 | dependency-reduced-pom.xml 38 | *.sublime-* 39 | /scratch 40 | .gradle 41 | Guardfile 42 | README.html 43 | *.iml 44 | .idea 45 | Mock* 46 | exclude -------------------------------------------------------------------------------- /treetops-benchmark/build.gradle: -------------------------------------------------------------------------------- 1 | configurations.all { 2 | resolutionStrategy { 3 | force 'org.ow2.asm:asm:9.4' 4 | } 5 | } 6 | 7 | dependencies { 8 | implementation project(':treetops-core') 9 | implementation 'org.openjdk.jmh:jmh-core:1.35' 10 | implementation 'org.openjdk.jmh:jmh-generator-annprocess:1.35' 11 | testImplementation 'org.junit.jupiter:junit-jupiter-api:5.7.0' 12 | testRuntimeOnly 'org.junit.jupiter:junit-jupiter-engine:5.7.0' 13 | } 14 | 15 | test { 16 | useJUnitPlatform() 17 | } 18 | 19 | sourceSets { 20 | jmh { 21 | java.srcDirs = ['src/main/java'] 22 | resources.srcDirs = ['src/main/resources'] 23 | compileClasspath += sourceSets.main.runtimeClasspath 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /treetops-core/src/main/java/io/github/horoc/treetops/core/predictor/objective/RegressionObjectiveDecorator.java: 
-------------------------------------------------------------------------------- 1 | package io.github.horoc.treetops.core.predictor.objective; 2 | 3 | import io.github.horoc.treetops.core.model.TreeModel; 4 | import io.github.horoc.treetops.core.predictor.Predictor; 5 | 6 | /** 7 | * @author chenzhou@apache.org 8 | * created on 2023/2/19 9 | */ 10 | public class RegressionObjectiveDecorator extends AbstractOutputConvertor { 11 | 12 | @Override 13 | public double[] convert(double[] input) { 14 | return new double[] {Math.exp(input[0])}; 15 | } 16 | 17 | @Override 18 | public Predictor decorate(Predictor predictor, TreeModel treeModel) { 19 | this.predictor = predictor; 20 | return this; 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /treetops-core/src/main/java/io/github/horoc/treetops/core/predictor/objective/CrossEntropyObjectiveConvertor.java: -------------------------------------------------------------------------------- 1 | package io.github.horoc.treetops.core.predictor.objective; 2 | 3 | import io.github.horoc.treetops.core.model.TreeModel; 4 | import io.github.horoc.treetops.core.predictor.Predictor; 5 | 6 | /** 7 | * @author chenzhou@apache.org 8 | * created on 2023/2/19 9 | */ 10 | public class CrossEntropyObjectiveConvertor extends AbstractOutputConvertor { 11 | 12 | @Override 13 | public double[] convert(double[] input) { 14 | return new double[] {1.0 / (1.0 + Math.exp(-input[0]))}; 15 | } 16 | 17 | @Override 18 | public Predictor decorate(Predictor predictor, TreeModel treeModel) { 19 | this.predictor = predictor; 20 | return this; 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /treetops-core/src/main/java/io/github/horoc/treetops/core/predictor/objective/CrossEntropyLambdaObjectiveDecorator.java: -------------------------------------------------------------------------------- 1 | package io.github.horoc.treetops.core.predictor.objective; 2 | 3 | 
import io.github.horoc.treetops.core.model.TreeModel; 4 | import io.github.horoc.treetops.core.predictor.Predictor; 5 | 6 | /** 7 | * @author chenzhou@apache.org 8 | * created on 2023/2/19 9 | */ 10 | public class CrossEntropyLambdaObjectiveDecorator extends AbstractOutputConvertor { 11 | 12 | @Override 13 | public double[] convert(double[] input) { 14 | return new double[] {Math.log1p(Math.exp(-input[0]))}; 15 | } 16 | 17 | @Override 18 | public Predictor decorate(Predictor predictor, TreeModel treeModel) { 19 | this.predictor = predictor; 20 | return this; 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /treetops-core/src/main/java/io/github/horoc/treetops/core/generator/Generator.java: -------------------------------------------------------------------------------- 1 | package io.github.horoc.treetops.core.generator; 2 | 3 | import io.github.horoc.treetops.core.model.TreeModel; 4 | 5 | /** 6 | * @author chenzhou@apache.org 7 | * created on 2023/2/14 8 | */ 9 | public interface Generator { 10 | 11 | /** 12 | * Load generated class bytecode into memory. 13 | * 14 | * @param className class name 15 | * @param code bytecode data 16 | * @return defined class 17 | */ 18 | Class defineClassFromCode(String className, byte[] code); 19 | 20 | /** 21 | * Generate predictor bytecode. 
22 | * 23 | * @param className class name 24 | * @param model tree model 25 | * @return bytecode data 26 | */ 27 | byte[] generateCode(String className, TreeModel model); 28 | } 29 | -------------------------------------------------------------------------------- /treetops-core/src/test/java/org/treetops/core/validation/LoadModelTemplate.java: -------------------------------------------------------------------------------- 1 | package org.treetops.core.validation; 2 | 3 | import io.github.horoc.treetops.core.factory.TreePredictorFactory; 4 | import io.github.horoc.treetops.core.predictor.Predictor; 5 | 6 | /** 7 | * @author chenzhou@apache.org 8 | * created on 2023/2/18 9 | */ 10 | public abstract class LoadModelTemplate { 11 | /** 12 | * Get test feature. 13 | * 14 | * @return features 15 | */ 16 | protected abstract double[] getFeature(); 17 | 18 | protected Predictor loadModel(String resource, boolean isGenerated) { 19 | ClassLoader classLoader = getClass().getClassLoader(); 20 | String path = classLoader.getResource(resource + ".txt").getPath(); 21 | return TreePredictorFactory.newInstance(resource, path, isGenerated); 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /treetops-core/src/test/java/io/github/horoc/treetops/core/validation/LoadModelTemplate.java: -------------------------------------------------------------------------------- 1 | package io.github.horoc.treetops.core.validation; 2 | 3 | import io.github.horoc.treetops.core.factory.TreePredictorFactory; 4 | import io.github.horoc.treetops.core.predictor.Predictor; 5 | 6 | /** 7 | * @author chenzhou@apache.org 8 | * created on 2023/2/18 9 | */ 10 | public abstract class LoadModelTemplate { 11 | /** 12 | * Get test feature. 
13 | * 14 | * @return features 15 | */ 16 | protected abstract double[] getFeature(); 17 | 18 | protected Predictor loadModel(String resource, boolean isGenerated) { 19 | ClassLoader classLoader = getClass().getClassLoader(); 20 | String path = classLoader.getResource(resource + ".txt").getPath(); 21 | return TreePredictorFactory.newInstance(resource, path, "./generated/", isGenerated); 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /treetops-benchmark/src/main/java/io/github/horoc/treetops/benchmark/wine/GeneratedPredictorBenchmark.java: -------------------------------------------------------------------------------- 1 | package io.github.horoc.treetops.benchmark.wine; 2 | 3 | import io.github.horoc.treetops.benchmark.common.AverageTimeBenchmarkTemplate; 4 | import java.util.Random; 5 | 6 | /** 7 | * @author chenzhou@apache.org 8 | * created on 2023/2/17 9 | */ 10 | public class GeneratedPredictorBenchmark extends AverageTimeBenchmarkTemplate { 11 | 12 | @Override 13 | protected String modelName() { 14 | return "wine_model"; 15 | } 16 | 17 | @Override 18 | protected boolean isGenerated() { 19 | return true; 20 | } 21 | 22 | @Override 23 | protected double[] getFeature() { 24 | features = new double[13]; 25 | Random r = new Random(); 26 | for (int i = 0; i < features.length; i++) { 27 | features[i] = -2.0 + 4 * r.nextDouble(); 28 | } 29 | return features; 30 | } 31 | } 32 | -------------------------------------------------------------------------------- /treetops-benchmark/src/main/java/io/github/horoc/treetops/benchmark/breastcancer/GeneratedPredictorBenchmark.java: -------------------------------------------------------------------------------- 1 | package io.github.horoc.treetops.benchmark.breastcancer; 2 | 3 | import io.github.horoc.treetops.benchmark.common.AverageTimeBenchmarkTemplate; 4 | import java.util.Random; 5 | 6 | /** 7 | * @author chenzhou@apache.org 8 | * created on 2023/2/17 9 | */ 10 | public 
/**
 * Missing-value handling types, each carrying its decision-type mask.
 *
 * @author chenzhou@apache.org
 * created on 2023/2/14
 */
public enum MissingType {
    /**
     * Decision-type mask for the default None value.
     */
    None(0),
    /**
     * Decision-type mask for the default Zero value.
     */
    Zero(1),
    /**
     * Decision-type mask for the default Nan value.
     */
    Nan(2);

    private final int mask;

    MissingType(int mask) {
        this.mask = mask;
    }

    /**
     * @return the integer mask attached to this constant
     */
    public int getMask() {
        return mask;
    }

    /**
     * Looks up the constant whose mask equals the given value.
     *
     * @param mask mask value to look up
     * @return the matching constant, or {@code null} when no constant has that mask
     */
    public static MissingType ofMask(int mask) {
        final MissingType[] candidates = values();
        for (int i = 0; i < candidates.length; i++) {
            final MissingType candidate = candidates[i];
            if (candidate.mask == mask) {
                return candidate;
            }
        }
        return null;
    }
}
Click on '....' 17 | 3. Scroll down to '....' 18 | 4. See error 19 | 20 | **Expected behavior** 21 | A clear and concise description of what you expected to happen. 22 | 23 | **Screenshots** 24 | If applicable, add screenshots to help explain your problem. 25 | 26 | **Desktop (please complete the following information):** 27 | - OS: [e.g. iOS] 28 | - Browser [e.g. chrome, safari] 29 | - Version [e.g. 22] 30 | 31 | **Smartphone (please complete the following information):** 32 | - Device: [e.g. iPhone6] 33 | - OS: [e.g. iOS8.1] 34 | - Browser [e.g. stock browser, safari] 35 | - Version [e.g. 22] 36 | 37 | **Additional context** 38 | Add any other context about the problem here. 39 | -------------------------------------------------------------------------------- /treetops-benchmark/src/main/java/io/github/horoc/treetops/benchmark/diabetes/GeneratedPredictorBenchmark.java: -------------------------------------------------------------------------------- 1 | package io.github.horoc.treetops.benchmark.diabetes; 2 | 3 | import io.github.horoc.treetops.benchmark.common.AverageTimeBenchmarkTemplate; 4 | import java.util.Random; 5 | 6 | /** 7 | * @author chenzhou@apache.org 8 | * created on 2023/2/17 9 | */ 10 | public class GeneratedPredictorBenchmark extends AverageTimeBenchmarkTemplate { 11 | 12 | @Override 13 | protected String modelName() { 14 | return "diabetes_model"; 15 | } 16 | 17 | @Override 18 | protected boolean isGenerated() { 19 | return true; 20 | } 21 | 22 | @Override 23 | protected double[] getFeature() { 24 | features = new double[10]; 25 | Random r = new Random(); 26 | for (int i = 0; i < features.length; i++) { 27 | if (i == 1) { 28 | features[i] = r.nextInt(2); 29 | } else { 30 | features[i] = -2.0 + 4 * r.nextDouble(); 31 | } 32 | } 33 | return features; 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /.github/workflows/gradle.yml: 
-------------------------------------------------------------------------------- 1 | # This workflow uses actions that are not certified by GitHub. 2 | # They are provided by a third-party and are governed by 3 | # separate terms of service, privacy policy, and support 4 | # documentation. 5 | # This workflow will build a Java project with Gradle and cache/restore any dependencies to improve the workflow execution time 6 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-java-with-gradle 7 | 8 | name: Continuous Integration 9 | 10 | on: 11 | push: 12 | branches: [ "master" ] 13 | pull_request: 14 | branches: [ "master" ] 15 | 16 | permissions: 17 | contents: read 18 | 19 | jobs: 20 | build: 21 | 22 | runs-on: ubuntu-latest 23 | 24 | steps: 25 | - uses: actions/checkout@v3 26 | - name: Set up JDK 8 27 | uses: actions/setup-java@v3 28 | with: 29 | java-version: '8' 30 | distribution: 'zulu' 31 | - name: Build with Gradle 32 | uses: gradle/gradle-build-action@67421db6bd0bf253fb4bd25b31ebb98943c375e1 33 | with: 34 | arguments: build 35 | -------------------------------------------------------------------------------- /treetops-core/src/main/java/io/github/horoc/treetops/core/factory/FileTreeModelFactory.java: -------------------------------------------------------------------------------- 1 | package io.github.horoc.treetops.core.factory; 2 | 3 | import io.github.horoc.treetops.core.loader.FileTreeModelLoader; 4 | import io.github.horoc.treetops.core.model.TreeModel; 5 | import javax.annotation.ParametersAreNonnullByDefault; 6 | import javax.annotation.concurrent.Immutable; 7 | import javax.annotation.concurrent.ThreadSafe; 8 | import org.apache.commons.lang3.Validate; 9 | 10 | /** 11 | * Tree model loading factory. 12 | *
13 | * 14 | * @author chenzhou@apache.org 15 | * created on 2023/2/14 16 | */ 17 | @Immutable 18 | @ThreadSafe 19 | @ParametersAreNonnullByDefault 20 | public class FileTreeModelFactory { 21 | 22 | /** 23 | * Load and parse tree model from file. 24 | * 25 | * @param resource file path 26 | * @return tree model 27 | */ 28 | public static TreeModel newInstance(final String resource) { 29 | Validate.notBlank(resource, "model file resource path must not be empty"); 30 | return FileTreeModelLoader.getInstance().loadModel(resource); 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 ChenZhou 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /treetops-core/src/main/java/io/github/horoc/treetops/core/predictor/objective/RegressionL2ObjectiveDecorator.java: -------------------------------------------------------------------------------- 1 | package io.github.horoc.treetops.core.predictor.objective; 2 | 3 | import io.github.horoc.treetops.core.model.TreeModel; 4 | import io.github.horoc.treetops.core.predictor.Predictor; 5 | 6 | /** 7 | * @author chenzhou@apache.org 8 | * created on 2023/2/15 9 | */ 10 | public final class RegressionL2ObjectiveDecorator extends AbstractOutputConvertor { 11 | 12 | private static final String SQRT_CONFIG_KEY = "sqrt"; 13 | 14 | private boolean sqrt; 15 | 16 | @Override 17 | public double[] convert(double[] input) { 18 | double ret = input[0]; 19 | if (sqrt) { 20 | if (!Double.isNaN(ret)) { 21 | ret = ret * ret * (ret >= 0 ? 1 : -1); 22 | } 23 | } 24 | return new double[]{ret}; 25 | } 26 | 27 | @Override 28 | public Predictor decorate(Predictor predictor, TreeModel treeModel) { 29 | this.predictor = predictor; 30 | if (SQRT_CONFIG_KEY.equals(treeModel.getObjectiveConfig())) { 31 | sqrt = true; 32 | } 33 | return this; 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /treetops-core/src/main/java/io/github/horoc/treetops/core/predictor/Predictor.java: -------------------------------------------------------------------------------- 1 | package io.github.horoc.treetops.core.predictor; 2 | 3 | /** 4 | * @author chenzhou@apache.org 5 | * created on 2023/2/14 6 | */ 7 | public interface Predictor { 8 | 9 | /** 10 | * All implementation of Predictor should be subclass of this package. 11 | */ 12 | String PREDICTOR_CLASS_PREFIX = "io.github.treetops.core.predictor"; 13 | 14 | /** 15 | * Refer to official library api: microsoft/LightGBM/src/boosting/gbdt.h#GBDT::PredictRaw. 16 | *
17 | * 18 | * @param features input feature, size of features should be equals to max_feature_idx 19 | * @return output value, size of output should be num_class 20 | */ 21 | double[] predictRaw(double[] features); 22 | 23 | /** 24 | * Refer to official library api: microsoft/LightGBM/src/boosting/gbdt.h#GBDT::Predict. 25 | *
26 | * 27 | * @param features input feature, size of features should be equals to max_feature_idx 28 | * @return output value, size of output should be num_class 29 | */ 30 | default double[] predict(double[] features) { 31 | return predictRaw(features); 32 | } 33 | } 34 | 35 | -------------------------------------------------------------------------------- /treetops-core/src/main/java/io/github/horoc/treetops/core/loader/FileTreeModelLoader.java: -------------------------------------------------------------------------------- 1 | package io.github.horoc.treetops.core.loader; 2 | 3 | import java.io.File; 4 | import java.io.FileInputStream; 5 | import java.io.FileNotFoundException; 6 | import java.io.InputStream; 7 | import org.apache.commons.lang3.Validate; 8 | 9 | /** 10 | * File model loader. 11 | *
12 | * @author chenzhou@apache.org 13 | * created on 2023/2/14 14 | */ 15 | public final class FileTreeModelLoader extends AbstractLoader { 16 | 17 | private FileTreeModelLoader() { 18 | } 19 | 20 | public static FileTreeModelLoader getInstance() { 21 | return SingletonHolder.instance; 22 | } 23 | 24 | @Override 25 | protected InputStream loadStream(final String resource) throws Exception { 26 | Validate.notBlank(resource, "model file resource path must not be empty"); 27 | File file = new File(resource); 28 | if (!file.exists()) { 29 | throw new FileNotFoundException(String.format("model file not found, resource: %s", resource)); 30 | } 31 | return new FileInputStream(file); 32 | } 33 | 34 | private static class SingletonHolder { 35 | private static FileTreeModelLoader instance = new FileTreeModelLoader(); 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /treetops-core/src/main/java/io/github/horoc/treetops/core/predictor/objective/MultiClassObjectiveDecorator.java: -------------------------------------------------------------------------------- 1 | package io.github.horoc.treetops.core.predictor.objective; 2 | 3 | import io.github.horoc.treetops.core.model.TreeModel; 4 | import io.github.horoc.treetops.core.predictor.Predictor; 5 | 6 | /** 7 | * @author chenzhou@apache.org 8 | * created on 2023/2/18 9 | */ 10 | public class MultiClassObjectiveDecorator extends AbstractOutputConvertor { 11 | 12 | @Override 13 | public double[] convert(double[] input) { 14 | double max = input[0]; 15 | double[] output = new double[input.length]; 16 | for (int i = 1; i < input.length; i++) { 17 | max = max < input[i] ? 
input[i] : max; 18 | } 19 | 20 | double expSum = 0.0; 21 | for (int i = 0; i < input.length; i++) { 22 | output[i] = Math.exp(input[i] - max); 23 | expSum += output[i]; 24 | } 25 | if (expSum == 0.0) { 26 | return output; 27 | } 28 | for (int i = 0; i < output.length; i++) { 29 | output[i] = output[i] / expSum; 30 | } 31 | return output; 32 | } 33 | 34 | @Override 35 | public Predictor decorate(Predictor predictor, TreeModel treeModel) { 36 | this.predictor = predictor; 37 | return this; 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /treetops-core/src/main/java/io/github/horoc/treetops/core/predictor/objective/AbstractOutputConvertor.java: -------------------------------------------------------------------------------- 1 | package io.github.horoc.treetops.core.predictor.objective; 2 | 3 | import io.github.horoc.treetops.core.model.TreeModel; 4 | import io.github.horoc.treetops.core.predictor.Predictor; 5 | 6 | /** 7 | * Objective convert function interface. 8 | *
9 | * 10 | * @author chenzhou@apache.org 11 | * created on 2023/2/14 12 | */ 13 | public abstract class AbstractOutputConvertor implements Predictor { 14 | 15 | protected Predictor predictor; 16 | 17 | /** 18 | * Convert input based on certain objective strategy. 19 | * 20 | * @param input raw input data 21 | * @return output data 22 | */ 23 | public abstract double[] convert(double[] input); 24 | 25 | /** 26 | * Binding to a predictor. 27 | * 28 | * @param predictor predictor which is need to be decorated 29 | * @param treeModel tree model 30 | * @return decorated predictor 31 | */ 32 | public abstract Predictor decorate(Predictor predictor, TreeModel treeModel); 33 | 34 | @Override 35 | public double[] predictRaw(double[] features) { 36 | return predictor.predictRaw(features); 37 | } 38 | 39 | @Override 40 | public double[] predict(double[] features) { 41 | return convert(predictor.predictRaw(features)); 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /treetops-core/src/main/java/io/github/horoc/treetops/core/predictor/objective/BinaryObjectiveDecorator.java: -------------------------------------------------------------------------------- 1 | package io.github.horoc.treetops.core.predictor.objective; 2 | 3 | import io.github.horoc.treetops.core.model.TreeModel; 4 | import io.github.horoc.treetops.core.predictor.Predictor; 5 | import org.apache.commons.lang3.StringUtils; 6 | 7 | /** 8 | * Binary objective function convertor. 
9 | * 10 | * @author chenzhou@apache.org 11 | * created on 2023/2/14 12 | */ 13 | public final class BinaryObjectiveDecorator extends AbstractOutputConvertor { 14 | 15 | private static final String CONFIG_SEPARATOR = ":"; 16 | 17 | private double sigmoid; 18 | 19 | @Override 20 | public Predictor decorate(Predictor predictor, TreeModel treeModel) { 21 | this.predictor = predictor; 22 | this.sigmoid = parseSigmoidValue(treeModel.getObjectiveConfig()); 23 | return this; 24 | } 25 | 26 | @Override 27 | public double[] convert(double[] input) { 28 | return new double[]{1.0f / (1.0f + Math.exp(-sigmoid * input[0]))}; 29 | } 30 | 31 | private double parseSigmoidValue(String objectiveConfig) { 32 | if (StringUtils.isNotBlank(objectiveConfig)) { 33 | String[] sp = objectiveConfig.split(CONFIG_SEPARATOR); 34 | if (sp.length >= 1) { 35 | return Double.parseDouble(sp[1]); 36 | } 37 | } 38 | return 0.0; 39 | } 40 | } 41 | -------------------------------------------------------------------------------- /treetops-core/src/test/java/org/treetops/core/validation/DiabetesModelTest.java: -------------------------------------------------------------------------------- 1 | package org.treetops.core.validation; 2 | 3 | import io.github.horoc.treetops.core.predictor.Predictor; 4 | import org.junit.jupiter.api.Assertions; 5 | import org.junit.jupiter.api.Test; 6 | 7 | /** 8 | * @author chenzhou@apache.org 9 | * created on 2023/2/18 10 | */ 11 | public class DiabetesModelTest extends LoadModelTemplate { 12 | 13 | @Override 14 | protected double[] getFeature() { 15 | return new double[] {0.41865135341839177, 1.0, 2.2034765108899994, 1.4731908880433968, -0.7561792976005409, -0.5608918117369689, -0.5254405155610098, 16 | -0.054499187536269665, 0.07803482757967586, 0.8481708171566219}; 17 | } 18 | 19 | @Test 20 | public void testPredictByGeneratedClass() { 21 | try { 22 | Predictor predictor = loadModel("diabetes_model", true); 23 | double[] result = predictor.predict(getFeature()); 24 | 
Assertions.assertEquals(258.1874753775234D, result[0]); 25 | } catch (Throwable e) { 26 | Assertions.fail(e); 27 | } 28 | } 29 | 30 | @Test 31 | public void testSimplePredictor() { 32 | try { 33 | Predictor predictor = loadModel("diabetes_model", false); 34 | double[] result = predictor.predict(getFeature()); 35 | Assertions.assertEquals(258.1874753775234D, result[0]); 36 | } catch (Throwable e) { 37 | Assertions.fail(e); 38 | } 39 | } 40 | } 41 | -------------------------------------------------------------------------------- /treetops-core/src/test/java/org/treetops/core/validation/CaliforniaHousingModelTest.java: -------------------------------------------------------------------------------- 1 | package org.treetops.core.validation; 2 | 3 | import io.github.horoc.treetops.core.predictor.Predictor; 4 | import org.junit.jupiter.api.Assertions; 5 | import org.junit.jupiter.api.Test; 6 | 7 | /** 8 | * @author chenzhou@apache.org 9 | * created on 2023/2/18 10 | */ 11 | public class CaliforniaHousingModelTest extends LoadModelTemplate { 12 | 13 | @Override 14 | protected double[] getFeature() { 15 | return new double[] {0.14798009991170735, -0.5275608347965259, 0.09460886182134916, -0.04474251825448332, 16 | 0.11084370442811828, 0.10687073221002079, -1.428840535575722, 1.2576618924353786}; 17 | } 18 | 19 | @Test 20 | public void testPredictByGeneratedClass() { 21 | try { 22 | Predictor predictor = loadModel("california_housing_model", true); 23 | double[] result = predictor.predict(getFeature()); 24 | Assertions.assertEquals(1.6224174281069879D, result[0]); 25 | } catch (Throwable e) { 26 | Assertions.fail(e); 27 | } 28 | } 29 | 30 | @Test 31 | public void testSimplePredictor() { 32 | try { 33 | Predictor predictor = loadModel("california_housing_model", false); 34 | double[] result = predictor.predict(getFeature()); 35 | Assertions.assertEquals(1.6224174281069879D, result[0]); 36 | } catch (Throwable e) { 37 | Assertions.fail(e); 38 | } 39 | } 40 | } 41 | 
-------------------------------------------------------------------------------- /treetops-core/src/main/java/io/github/horoc/treetops/core/loader/AbstractLoader.java: -------------------------------------------------------------------------------- 1 | package io.github.horoc.treetops.core.loader; 2 | 3 | import io.github.horoc.treetops.core.model.TreeModel; 4 | import io.github.horoc.treetops.core.parser.TreeModelParser; 5 | import java.io.InputStream; 6 | import java.nio.charset.StandardCharsets; 7 | import java.util.List; 8 | import org.apache.commons.io.IOUtils; 9 | import org.apache.commons.lang3.Validate; 10 | 11 | /** 12 | * @author chenzhou@apache.org 13 | * created on 2023/2/14 14 | */ 15 | public abstract class AbstractLoader { 16 | 17 | /** 18 | * Get input stream from corresponding resource path. 19 | * 20 | * @param resource resource path 21 | * @return InputStream from resource 22 | * @throws Exception exception while loading stream from resource 23 | */ 24 | protected abstract InputStream loadStream(String resource) throws Exception; 25 | 26 | /** 27 | * Load model from resource based on loadStream implementation, and parse into {@link TreeModel}. 
28 | * 29 | * @param resource resource pah 30 | * @return TreeModel instance 31 | */ 32 | public TreeModel loadModel(final String resource) { 33 | Validate.notBlank(resource, "model resource path must not be empty"); 34 | List lines; 35 | try (InputStream stream = loadStream(resource)) { 36 | lines = IOUtils.readLines(stream, StandardCharsets.UTF_8); 37 | } catch (Exception e) { 38 | throw new RuntimeException(String.format("fail to load model from resource: %s", resource), e); 39 | } 40 | return TreeModelParser.parseTreeModel(lines); 41 | } 42 | } 43 | -------------------------------------------------------------------------------- /treetops-core/src/main/java/io/github/horoc/treetops/core/predictor/MetaDataHolder.java: -------------------------------------------------------------------------------- 1 | package io.github.horoc.treetops.core.predictor; 2 | 3 | import io.github.horoc.treetops.core.model.TreeModel; 4 | import io.github.horoc.treetops.core.model.TreeNode; 5 | import java.util.Objects; 6 | 7 | /** 8 | * Meta data holder for Predictor. 9 | * 10 | * @author chenzhou@apache.org 11 | * created on 2023/2/14 12 | */ 13 | public class MetaDataHolder { 14 | 15 | /** 16 | * category bit threshold used in process of category node decision. 17 | */ 18 | private long[][] catBitSet; 19 | 20 | public void initialize(TreeModel model) { 21 | if (model.isContainsCatNode()) { 22 | catBitSet = new long[model.getTrees().size()][]; 23 | for (int i = 0; i < catBitSet.length; i++) { 24 | TreeNode root = model.getTrees().get(i); 25 | if (Objects.isNull(root.getCatThreshold()) || root.getCatThreshold().isEmpty()) { 26 | continue; 27 | } 28 | catBitSet[i] = root.getCatThreshold().stream().mapToLong(l -> l).toArray(); 29 | } 30 | } 31 | } 32 | 33 | /** 34 | * Refer to include/LightGBM/utils/common.h#FindInBitsets. 
35 | * 36 | * @param index tree index 37 | * @param begin begin 38 | * @param n length 39 | * @param val feature value 40 | * @return is find bit 41 | */ 42 | protected boolean findCatBitset(int index, int begin, int n, double val) { 43 | int pos = (int) val; 44 | int i1 = pos / 32; 45 | if (i1 >= n) { 46 | return false; 47 | } 48 | int i2 = pos % 32; 49 | return ((catBitSet[index][i1 + begin] >> i2) & 1) != 0; 50 | } 51 | } 52 | -------------------------------------------------------------------------------- /docs/release/release.md: -------------------------------------------------------------------------------- 1 | # Release 2 | 3 | To release to Maven Central, follow these steps. 4 | 5 | 6 | 7 | ## 1. GPG Key 8 | 9 | ``` 10 | brew install gpg 11 | 12 | gpg --generate-key 13 | ``` 14 | 15 | 16 | 17 | To see the generated gpg key: 18 | 19 | ``` 20 | gpg -k 21 | ``` 22 | 23 | The last 8 characters of the pub key are your `Key ID`, like `8D65454A` 24 | 25 | ``` 26 | pub ed25519 2023-02-22 [SC] [expires: 2025-02-21] 27 | 811E7040F84D1BD44BEC15EE6DE91A6E8D65454A 28 | uid ... ... 29 | sub ... ... 30 | ``` 31 | 32 | Export the gpg file, substituting your own Key ID: 33 | 34 | ``` 35 | gpg --export-secret-keys 8D65454A > secret.gpg 36 | ``` 37 | 38 | Send the key to a public key server: 39 | 40 | ``` 41 | gpg --keyserver keyserver.ubuntu.com --send-keys 8D65454A 42 | ``` 43 | 44 | 45 | 46 | Add the gpg secret file to the treetops-core module. 47 | 48 | 49 | 50 | ## 2. Modify Properties 51 | 52 | 53 | 54 | ``` 55 | group=io.github.horoc 56 | version=1.0-SNAPSHOT // version you want to release 57 | 58 | ossrhUsername=UserName 59 | ossrhPassword=Password 60 | 61 | signing.keyId=KeyId // key id, last 8 characters of your public key 62 | signing.password=PublicKeyPassword // password of your gpg key 63 | signing.secretKeyRingFile=PathToYourKeyRingFile // gpg file path, for treetops-core, it should be `secret.gpg` 64 | ``` 65 | 66 | 67 | 68 | ## 3. 
Run Publish 69 | 70 | ``` 71 | ./gradlw publish 72 | ``` 73 | 74 | 75 | 76 | ## 4. Check Publish Status 77 | 78 | for SNAPSHOT: 79 | 80 | ``` 81 | https://s01.oss.sonatype.org/content/repositories/snapshots/io/github/horoc 82 | ``` 83 | 84 | for Release: 85 | 86 | ``` 87 | https://s01.oss.sonatype.org/ 88 | ``` 89 | 90 | 91 | 92 | ## 5. Clean Up Your Modification 93 | 94 | 1. **You should clean the content of the gradle.properties to avoid publishing your password to the GitHub repository.** 95 | 2. **You should delete the gpg file which you add to the project** 96 | 97 | -------------------------------------------------------------------------------- /treetops-core/src/main/java/io/github/horoc/treetops/core/predictor/PredictorWrapper.java: -------------------------------------------------------------------------------- 1 | package io.github.horoc.treetops.core.predictor; 2 | 3 | import io.github.horoc.treetops.core.model.TreeModel; 4 | import java.util.Objects; 5 | 6 | /** 7 | * Predictor Wrapper, introduce preprocess/postprocess of prediction process. 8 | * 9 | * @author chenzhou@apache.org 10 | * created on 2023/2/14 11 | */ 12 | public final class PredictorWrapper implements Predictor { 13 | 14 | private Predictor innerPredictor; 15 | 16 | /** 17 | * index of feature must not greater than it. 
18 | */ 19 | private int maxFeatureIdx; 20 | 21 | public PredictorWrapper(Predictor innerPredictor, TreeModel treeModel) { 22 | if (Objects.isNull(innerPredictor)) { 23 | throw new IllegalArgumentException("new PredictorWrapper error, innerPredictor can not be null"); 24 | } 25 | this.innerPredictor = innerPredictor; 26 | this.maxFeatureIdx = treeModel.getMaxFeatureIndex(); 27 | } 28 | 29 | @Override 30 | public double[] predictRaw(double[] features) { 31 | checkInputFeature(features); 32 | return innerPredictor.predictRaw(features); 33 | } 34 | 35 | @Override 36 | public double[] predict(double[] features) { 37 | checkInputFeature(features); 38 | return innerPredictor.predict(features); 39 | } 40 | 41 | /** 42 | * pre-check of input features. 43 | * @param features input features 44 | */ 45 | private void checkInputFeature(double[] features) { 46 | if (Objects.isNull(features) || features.length > maxFeatureIdx + 1) { 47 | throw new IllegalArgumentException("input features size does not match the predict model"); 48 | } 49 | } 50 | 51 | /** 52 | * help gc, clean reference. 
53 | */ 54 | public void release() { 55 | this.innerPredictor = null; 56 | } 57 | } 58 | -------------------------------------------------------------------------------- /treetops-core/src/test/java/org/treetops/core/validation/BreastCancerModelTest.java: -------------------------------------------------------------------------------- 1 | package org.treetops.core.validation; 2 | 3 | import io.github.horoc.treetops.core.predictor.Predictor; 4 | import org.junit.jupiter.api.Assertions; 5 | import org.junit.jupiter.api.Test; 6 | 7 | /** 8 | * @author chenzhou@apache.org 9 | * created on 2023/2/18 10 | */ 11 | public class BreastCancerModelTest extends LoadModelTemplate { 12 | 13 | @Override 14 | protected double[] getFeature() { 15 | return new double[] {-0.20656117887535716, 0.2863110515326301, -0.13712355201435386, -0.2792598864377929, 1.0133758820201568, 0.8065563120018667, 0.699320480275225, 0.8460646517821447, 16 | 1.1112791563311162, 1.4817350698939018, -0.05259361139069434, -0.5193621584675753, 0.11234262958699215, -0.14668713819749174, -0.5423482925856186, -0.15806337722209013, 17 | 0.08707974741228168, 0.250429487524521, -0.4228423090304796, 0.0794691444911607, 0.029159331082772376, 0.6485704748522869, 0.1798703441573848, -0.06360678115852603, 1.0972739926049955, 18 | 0.835473817212997, 1.1437848605273928, 1.3779123052290023, 1.1069571429479146, 1.4936880726625947}; 19 | } 20 | 21 | @Test 22 | public void testPredictByGeneratedClass() { 23 | try { 24 | Predictor predictor = loadModel("breast_cancer_model", true); 25 | double[] result = predictor.predict(getFeature()); 26 | Assertions.assertEquals(0.03825810017556627D, result[0]); 27 | } catch (Throwable e) { 28 | Assertions.fail(e); 29 | } 30 | } 31 | 32 | @Test 33 | public void testSimplePredictor() { 34 | try { 35 | Predictor predictor = loadModel("breast_cancer_model", false); 36 | double[] result = predictor.predict(getFeature()); 37 | Assertions.assertEquals(0.03825810017556627D, result[0]); 38 | } 
catch (Throwable e) { 39 | Assertions.fail(e); 40 | } 41 | } 42 | } 43 | -------------------------------------------------------------------------------- /treetops-core/src/test/java/org/treetops/core/validation/WineModelTest.java: -------------------------------------------------------------------------------- 1 | package org.treetops.core.validation; 2 | 3 | import io.github.horoc.treetops.core.predictor.Predictor; 4 | import org.junit.jupiter.api.Assertions; 5 | import org.junit.jupiter.api.Test; 6 | 7 | /** 8 | * @author chenzhou@apache.org 9 | * created on 2023/2/18 10 | */ 11 | public class WineModelTest extends LoadModelTemplate { 12 | 13 | @Override 14 | protected double[] getFeature() { 15 | return new double[] {0.913332708127694, -0.5981563241524619, -0.42590882302929406, -0.9293651811915726, 1.2819851519177938, 0.4885310849747506, 0.874184282556729, -1.2236095387553816, 16 | 0.050987616822829956, 0.34255654624083676, -0.16430337010617904, 0.830960739456274, 0.9970864580546609}; 17 | } 18 | 19 | @Test 20 | public void testPredictByGeneratedClass() { 21 | try { 22 | Predictor predictor = loadModel("wine_model", true); 23 | double[] result = predictor.predict(getFeature()); 24 | assertDoubleEquals(0.9849612333276241D, result[0]); 25 | assertDoubleEquals(0.008531186707393178D, result[1]); 26 | assertDoubleEquals(0.006507579964982725D, result[2]); 27 | } catch (Throwable e) { 28 | Assertions.fail(e); 29 | } 30 | } 31 | 32 | @Test 33 | public void testSimplePredictor() { 34 | try { 35 | Predictor predictor = loadModel("wine_model", false); 36 | double[] result = predictor.predict(getFeature()); 37 | assertDoubleEquals(0.9849612333276241D, result[0]); 38 | assertDoubleEquals(0.008531186707393178D, result[1]); 39 | assertDoubleEquals(0.006507579964982725D, result[2]); 40 | } catch (Throwable e) { 41 | Assertions.fail(e); 42 | } 43 | } 44 | 45 | private void assertDoubleEquals(double expected, double actual) { 46 | Assertions.assertTrue(Math.abs(expected - 
/**
 * Mutable holder for the attributes of a tree model; every attribute has a plain getter and setter.
 *
 * @author chenzhou@apache.org
 * created on 2023/2/14
 */
public class TreeModel {

    // NOTE(review): element type is not declared here; MetaDataHolder reads the elements as
    // TreeNode, so this is presumably meant to be List<TreeNode> - confirm.
    private List trees = new ArrayList<>();

    private String objectiveConfig;

    private String objectiveType;

    private boolean containsCatNode;

    private int maxFeatureIndex;

    private int numberTreePerIteration;

    private int numClass;

    /** @return the stored tree list (the same instance, not a copy) */
    public List getTrees() {
        return this.trees;
    }

    /** @param trees tree list to store; the reference is kept as-is */
    public void setTrees(List trees) {
        this.trees = trees;
    }

    /** @return objective config text */
    public String getObjectiveConfig() {
        return this.objectiveConfig;
    }

    /** @param objectiveConfig objective config text */
    public void setObjectiveConfig(String objectiveConfig) {
        this.objectiveConfig = objectiveConfig;
    }

    /** @return objective type */
    public String getObjectiveType() {
        return this.objectiveType;
    }

    /** @param objectiveType objective type */
    public void setObjectiveType(String objectiveType) {
        this.objectiveType = objectiveType;
    }

    /** @return the contains-categorical-node flag */
    public boolean isContainsCatNode() {
        return this.containsCatNode;
    }

    /** @param containsCatNode the contains-categorical-node flag */
    public void setContainsCatNode(boolean containsCatNode) {
        this.containsCatNode = containsCatNode;
    }

    /** @return maximum feature index */
    public int getMaxFeatureIndex() {
        return this.maxFeatureIndex;
    }

    /** @param maxFeatureIndex maximum feature index */
    public void setMaxFeatureIndex(int maxFeatureIndex) {
        this.maxFeatureIndex = maxFeatureIndex;
    }

    /** @return number of trees per iteration */
    public int getNumberTreePerIteration() {
        return this.numberTreePerIteration;
    }

    /** @param numberTreePerIteration number of trees per iteration */
    public void setNumberTreePerIteration(int numberTreePerIteration) {
        this.numberTreePerIteration = numberTreePerIteration;
    }

    /** @return number of classes */
    public int getNumClass() {
        return this.numClass;
    }

    /** @param numClass number of classes */
    public void setNumClass(int numClass) {
        this.numClass = numClass;
    }
}
43 | * 44 | * @return is generated predictor 45 | */ 46 | protected abstract boolean isGenerated(); 47 | 48 | /** 49 | * Get test feature. 50 | * 51 | * @return features 52 | */ 53 | protected abstract double[] getFeature(); 54 | 55 | @Setup 56 | public void setup() { 57 | TreePredictorFactory.setTreeModelLoader(new ClassPathLoader()); 58 | this.predictor = TreePredictorFactory.newInstance(modelName(), "/" + modelName() + ".txt", isGenerated()); 59 | features = getFeature(); 60 | } 61 | 62 | @Benchmark 63 | @Group 64 | public void predict() { 65 | for (int i = 0; i < BATCH_SIZE; i++) { 66 | predictor.predict(features); 67 | } 68 | } 69 | } 70 | -------------------------------------------------------------------------------- /treetops-core/src/test/java/io/github/horoc/treetops/core/validation/DiabetesModelTest.java: -------------------------------------------------------------------------------- 1 | package io.github.horoc.treetops.core.validation; 2 | 3 | import io.github.horoc.treetops.core.predictor.Predictor; 4 | import org.junit.jupiter.api.Assertions; 5 | import org.junit.jupiter.api.Test; 6 | 7 | /** 8 | * @author chenzhou@apache.org 9 | * created on 2023/2/18 10 | */ 11 | public class DiabetesModelTest extends LoadModelTemplate { 12 | 13 | @Override 14 | protected double[] getFeature() { 15 | return new double[] {0.41865135341839177, 1.0, 2.2034765108899994, 1.4731908880433968, -0.7561792976005409, -0.5608918117369689, -0.5254405155610098, 16 | -0.054499187536269665, 0.07803482757967586, 0.8481708171566219}; 17 | } 18 | 19 | @Test 20 | public void testPredictByGeneratedClass() { 21 | try { 22 | Predictor predictor = loadModel("diabetes_model", true); 23 | double[] result = predictor.predict(getFeature()); 24 | Assertions.assertEquals(258.1874753775234D, result[0]); 25 | } catch (Throwable e) { 26 | Assertions.fail(e); 27 | } 28 | } 29 | 30 | @Test 31 | public void testSimplePredictor() { 32 | try { 33 | Predictor predictor = loadModel("diabetes_model", 
false); 34 | double[] result = predictor.predict(getFeature()); 35 | Assertions.assertEquals(258.1874753775234D, result[0]); 36 | } catch (Throwable e) { 37 | Assertions.fail(e); 38 | } 39 | } 40 | 41 | @Test 42 | public void testNanFeature() { 43 | try { 44 | Predictor predictor = loadModel("diabetes_model", false); 45 | double[] features = getFeature(); 46 | features[1] = Double.NaN; 47 | features[3] = Double.NaN; 48 | double[] result = predictor.predict(features); 49 | Assertions.assertEquals(223.98172555247731D, result[0]); 50 | } catch (Throwable e) { 51 | Assertions.fail(e); 52 | } 53 | } 54 | 55 | @Test 56 | public void testNanFeatureByGeneratedClass() { 57 | try { 58 | Predictor predictor = loadModel("diabetes_model", true); 59 | double[] features = getFeature(); 60 | features[1] = Double.NaN; 61 | features[3] = Double.NaN; 62 | double[] result = predictor.predict(features); 63 | Assertions.assertEquals(223.98172555247731D, result[0]); 64 | } catch (Throwable e) { 65 | Assertions.fail(e); 66 | } 67 | } 68 | } 69 | -------------------------------------------------------------------------------- /treetops-core/src/test/java/io/github/horoc/treetops/core/validation/CaliforniaHousingModelTest.java: -------------------------------------------------------------------------------- 1 | package io.github.horoc.treetops.core.validation; 2 | 3 | import io.github.horoc.treetops.core.predictor.Predictor; 4 | import org.junit.jupiter.api.Assertions; 5 | import org.junit.jupiter.api.Test; 6 | 7 | /** 8 | * @author chenzhou@apache.org 9 | * created on 2023/2/18 10 | */ 11 | public class CaliforniaHousingModelTest extends LoadModelTemplate { 12 | 13 | @Override 14 | protected double[] getFeature() { 15 | return new double[] {0.14798009991170735, -0.5275608347965259, 0.09460886182134916, -0.04474251825448332, 16 | 0.11084370442811828, 0.10687073221002079, -1.428840535575722, 1.2576618924353786}; 17 | } 18 | 19 | @Test 20 | public void testPredictByGeneratedClass() { 21 | try 
{ 22 | Predictor predictor = loadModel("california_housing_model", true); 23 | double[] result = predictor.predict(getFeature()); 24 | Assertions.assertEquals(1.6224174281069879D, result[0]); 25 | } catch (Throwable e) { 26 | Assertions.fail(e); 27 | } 28 | } 29 | 30 | @Test 31 | public void testSimplePredictor() { 32 | try { 33 | Predictor predictor = loadModel("california_housing_model", false); 34 | double[] result = predictor.predict(getFeature()); 35 | Assertions.assertEquals(1.6224174281069879D, result[0]); 36 | } catch (Throwable e) { 37 | Assertions.fail(e); 38 | } 39 | } 40 | 41 | @Test 42 | public void testNanFeature() { 43 | try { 44 | Predictor predictor = loadModel("california_housing_model", false); 45 | double[] features = getFeature(); 46 | features[0] = Double.NaN; 47 | features[5] = Double.NaN; 48 | double[] result = predictor.predict(features); 49 | Assertions.assertEquals(1.5796328060168647D, result[0]); 50 | } catch (Throwable e) { 51 | Assertions.fail(e); 52 | } 53 | } 54 | 55 | @Test 56 | public void testNanFeatureByGeneratedClass() { 57 | try { 58 | Predictor predictor = loadModel("california_housing_model", true); 59 | double[] features = getFeature(); 60 | features[0] = Double.NaN; 61 | features[5] = Double.NaN; 62 | double[] result = predictor.predict(features); 63 | Assertions.assertEquals(1.5796328060168647D, result[0]); 64 | } catch (Throwable e) { 65 | Assertions.fail(e); 66 | } 67 | } 68 | } 69 | -------------------------------------------------------------------------------- /docs/benchmark/result.md: -------------------------------------------------------------------------------- 1 | # Benchmark Result 2 | 3 | ## Environment 4 | 5 | - CPU: Intel(R) Xeon(R) Platinum 8372HC CPU @ 3.40GHz 6 | - Memory: 8G 7 | - JDK: 1.8.0_362 8 | - Tool: JMH 9 | 10 | Test config: see class `io/github/horoc/treetops/benchmark/common/AverageTimeBenchmarkTemplate.java` 11 | 12 | ## BatchSize: 100 13 | 14 | ``` 15 | Benchmark Mode Cnt Score Error Units 16 | 
i.g.t.b.breastcancer.GeneratedPredictorBenchmark.group avgt 10 41839.248 ± 1021.564 ns/op 17 | i.g.t.b.breastcancer.SimplePredictorBenchmark.group avgt 10 496211.590 ± 144021.236 ns/op 18 | i.g.t.b.californiahousing.GeneratedPredictorBenchmark.group avgt 10 47484.875 ± 6165.425 ns/op 19 | i.g.t.b.californiahousing.SimplePredictorBenchmark.group avgt 10 681600.553 ± 257224.470 ns/op 20 | i.g.t.b.diabetes.GeneratedPredictorBenchmark.group avgt 10 42794.777 ± 4549.512 ns/op 21 | i.g.t.b.diabetes.SimplePredictorBenchmark.group avgt 10 465262.379 ± 193065.665 ns/op 22 | i.g.t.b.wine.GeneratedPredictorBenchmark.group avgt 10 111460.641 ± 12352.069 ns/op 23 | i.g.t.b.wine.SimplePredictorBenchmark.group avgt 10 1200014.729 ± 283487.105 ns/op 24 | ``` 25 | 26 | ## BatchSize: 500 27 | 28 | ``` 29 | Benchmark Mode Cnt Score Error Units 30 | i.g.t.b.breastcancer.GeneratedPredictorBenchmark.group avgt 10 208786.500 ± 7699.946 ns/op 31 | i.g.t.b.breastcancer.SimplePredictorBenchmark.group avgt 10 1994982.258 ± 403898.299 ns/op 32 | i.g.t.b.californiahousing.GeneratedPredictorBenchmark.group avgt 10 222625.586 ± 1881.976 ns/op 33 | i.g.t.b.californiahousing.SimplePredictorBenchmark.group avgt 10 3409444.662 ± 1336838.246 ns/op 34 | i.g.t.b.diabetes.GeneratedPredictorBenchmark.group avgt 10 147066.124 ± 15336.055 ns/op 35 | i.g.t.b.diabetes.SimplePredictorBenchmark.group avgt 10 2023224.922 ± 498936.117 ns/op 36 | i.g.t.b.wine.GeneratedPredictorBenchmark.group avgt 10 517825.494 ± 85232.860 ns/op 37 | i.g.t.b.wine.SimplePredictorBenchmark.group avgt 10 5981068.592 ± 1483154.756 ns/op 38 | ``` 39 | 40 | 41 | ## Graph 42 | 43 | ```python 44 | from pyecharts.charts import Bar 45 | from pyecharts import options as opts 46 | from pyecharts.globals import ThemeType 47 | from pyecharts.render import make_snapshot 48 | 49 | bar = ( 50 | Bar(init_opts=opts.InitOpts(theme=ThemeType.WESTEROS)) 51 | .add_xaxis(["bc/100", "ch/100", "db/100", "wn/100","bc/500", "ch/500", "db/500", "wn/500"]) 52 
| .add_yaxis("asm", [41, 47, 42, 111, 208, 222, 147, 517]) 53 | .add_yaxis("simple", [496, 681, 465, 1200, 1994, 3409, 2023, 5981]) 54 | .set_global_opts(title_opts=opts.TitleOpts(title="Average Latency (us)", subtitle="model_name/batch_size")) 55 | ) 56 | bar.render() 57 | ``` -------------------------------------------------------------------------------- /treetops-core/src/test/java/io/github/horoc/treetops/core/validation/BreastCancerModelTest.java: -------------------------------------------------------------------------------- 1 | package io.github.horoc.treetops.core.validation; 2 | 3 | import io.github.horoc.treetops.core.predictor.Predictor; 4 | import org.junit.jupiter.api.Assertions; 5 | import org.junit.jupiter.api.Test; 6 | 7 | /** 8 | * @author chenzhou@apache.org 9 | * created on 2023/2/18 10 | */ 11 | public class BreastCancerModelTest extends LoadModelTemplate { 12 | 13 | @Override 14 | protected double[] getFeature() { 15 | return new double[] {-0.20656117887535716, 0.2863110515326301, -0.13712355201435386, -0.2792598864377929, 1.0133758820201568, 0.8065563120018667, 0.699320480275225, 0.8460646517821447, 16 | 1.1112791563311162, 1.4817350698939018, -0.05259361139069434, -0.5193621584675753, 0.11234262958699215, -0.14668713819749174, -0.5423482925856186, -0.15806337722209013, 17 | 0.08707974741228168, 0.250429487524521, -0.4228423090304796, 0.0794691444911607, 0.029159331082772376, 0.6485704748522869, 0.1798703441573848, -0.06360678115852603, 1.0972739926049955, 18 | 0.835473817212997, 1.1437848605273928, 1.3779123052290023, 1.1069571429479146, 1.4936880726625947}; 19 | } 20 | 21 | @Test 22 | public void testPredictByGeneratedClass() { 23 | try { 24 | Predictor predictor = loadModel("breast_cancer_model", true); 25 | double[] result = predictor.predict(getFeature()); 26 | Assertions.assertEquals(0.03825810017556627D, result[0]); 27 | } catch (Throwable e) { 28 | Assertions.fail(e); 29 | } 30 | } 31 | 32 | @Test 33 | public void 
testSimplePredictor() { 34 | try { 35 | Predictor predictor = loadModel("breast_cancer_model", false); 36 | double[] result = predictor.predict(getFeature()); 37 | Assertions.assertEquals(0.03825810017556627D, result[0]); 38 | } catch (Throwable e) { 39 | Assertions.fail(e); 40 | } 41 | } 42 | 43 | @Test 44 | public void testNanFeature() { 45 | try { 46 | Predictor predictor = loadModel("breast_cancer_model", false); 47 | double[] features = getFeature(); 48 | features[0] = Double.NaN; 49 | features[27] = Double.NaN; 50 | double[] result = predictor.predict(features); 51 | Assertions.assertEquals(0.1716604128665585D, result[0]); 52 | } catch (Throwable e) { 53 | Assertions.fail(e); 54 | } 55 | } 56 | 57 | @Test 58 | public void testNanFeatureByGeneratedClass() { 59 | try { 60 | Predictor predictor = loadModel("breast_cancer_model", true); 61 | double[] features = getFeature(); 62 | features[0] = Double.NaN; 63 | features[27] = Double.NaN; 64 | double[] result = predictor.predict(features); 65 | Assertions.assertEquals(0.1716604128665585D, result[0]); 66 | } catch (Throwable e) { 67 | Assertions.fail(e); 68 | } 69 | } 70 | } 71 | -------------------------------------------------------------------------------- /gradlew.bat: -------------------------------------------------------------------------------- 1 | @rem 2 | @rem Copyright 2015 the original author or authors. 3 | @rem 4 | @rem Licensed under the Apache License, Version 2.0 (the "License"); 5 | @rem you may not use this file except in compliance with the License. 6 | @rem You may obtain a copy of the License at 7 | @rem 8 | @rem https://www.apache.org/licenses/LICENSE-2.0 9 | @rem 10 | @rem Unless required by applicable law or agreed to in writing, software 11 | @rem distributed under the License is distributed on an "AS IS" BASIS, 12 | @rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | @rem See the License for the specific language governing permissions and 14 | @rem limitations under the License. 15 | @rem 16 | 17 | @if "%DEBUG%" == "" @echo off 18 | @rem ########################################################################## 19 | @rem 20 | @rem Gradle startup script for Windows 21 | @rem 22 | @rem ########################################################################## 23 | 24 | @rem Set local scope for the variables with windows NT shell 25 | if "%OS%"=="Windows_NT" setlocal 26 | 27 | set DIRNAME=%~dp0 28 | if "%DIRNAME%" == "" set DIRNAME=. 29 | set APP_BASE_NAME=%~n0 30 | set APP_HOME=%DIRNAME% 31 | 32 | @rem Resolve any "." and ".." in APP_HOME to make it shorter. 33 | for %%i in ("%APP_HOME%") do set APP_HOME=%%~fi 34 | 35 | @rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. 36 | set DEFAULT_JVM_OPTS="-Xmx64m" "-Xms64m" 37 | 38 | @rem Find java.exe 39 | if defined JAVA_HOME goto findJavaFromJavaHome 40 | 41 | set JAVA_EXE=java.exe 42 | %JAVA_EXE% -version >NUL 2>&1 43 | if "%ERRORLEVEL%" == "0" goto execute 44 | 45 | echo. 46 | echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. 47 | echo. 48 | echo Please set the JAVA_HOME variable in your environment to match the 49 | echo location of your Java installation. 50 | 51 | goto fail 52 | 53 | :findJavaFromJavaHome 54 | set JAVA_HOME=%JAVA_HOME:"=% 55 | set JAVA_EXE=%JAVA_HOME%/bin/java.exe 56 | 57 | if exist "%JAVA_EXE%" goto execute 58 | 59 | echo. 60 | echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% 61 | echo. 62 | echo Please set the JAVA_HOME variable in your environment to match the 63 | echo location of your Java installation. 
64 | 65 | goto fail 66 | 67 | :execute 68 | @rem Setup the command line 69 | 70 | set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar 71 | 72 | 73 | @rem Execute Gradle 74 | "%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %* 75 | 76 | :end 77 | @rem End local scope for the variables with windows NT shell 78 | if "%ERRORLEVEL%"=="0" goto mainEnd 79 | 80 | :fail 81 | rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of 82 | rem the _cmd.exe /c_ return code! 83 | if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1 84 | exit /b 1 85 | 86 | :mainEnd 87 | if "%OS%"=="Windows_NT" endlocal 88 | 89 | :omega 90 | -------------------------------------------------------------------------------- /treetops-core/src/test/java/io/github/horoc/treetops/core/validation/WineModelTest.java: --------------------------------------------------------------------------------
package io.github.horoc.treetops.core.validation;

import io.github.horoc.treetops.core.predictor.Predictor;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

/**
 * Validation tests for the {@code wine_model} test model.
 * Each test loads the model through {@code loadModel}, predicts on a fixed feature vector
 * and compares the first three output values with reference values.
 *
 * <p>Tests declare {@code throws Exception} instead of catching {@link Throwable} and calling
 * {@code Assertions.fail(e)}: the catch-all also caught the assertion failure itself and
 * re-wrapped it, which hid the expected/actual values from the test report.
 *
 * @author chenzhou@apache.org
 * created on 2023/2/18
 */
public class WineModelTest extends LoadModelTemplate {

    private static final String MODEL_NAME = "wine_model";

    /**
     * Absolute tolerance, kept at the value the tests have always used. It is far below the
     * spacing of doubles at the magnitudes compared here, so the comparison is effectively exact.
     */
    private static final double TOLERANCE = 1e-35f;

    /** Reference outputs for the unmodified feature vector. */
    private static final double[] EXPECTED = {
        0.9849612333276241D, 0.008531186707393178D, 0.006507579964982725D};

    /** Reference outputs when features 2 and 7 are replaced by NaN. */
    private static final double[] EXPECTED_WITH_NAN = {
        0.9838801475136683D, 0.009129569672519997D, 0.006990282813811561D};

    @Override
    protected double[] getFeature() {
        return new double[] {0.913332708127694, -0.5981563241524619, -0.42590882302929406, -0.9293651811915726, 1.2819851519177938, 0.4885310849747506, 0.874184282556729, -1.2236095387553816,
            0.050987616822829956, 0.34255654624083676, -0.16430337010617904, 0.830960739456274, 0.9970864580546609};
    }

    @Test
    public void testPredictByGeneratedClass() throws Exception {
        assertPrediction(true, getFeature(), EXPECTED);
    }

    @Test
    public void testSimplePredictor() throws Exception {
        assertPrediction(false, getFeature(), EXPECTED);
    }

    @Test
    public void testNanFeature() throws Exception {
        assertPrediction(false, featureWithNan(), EXPECTED_WITH_NAN);
    }

    @Test
    public void testNanFeatureByGeneratedClass() throws Exception {
        assertPrediction(true, featureWithNan(), EXPECTED_WITH_NAN);
    }

    /**
     * Builds the feature vector used by the NaN tests: the base vector with
     * positions 2 and 7 set to {@link Double#NaN}.
     */
    private double[] featureWithNan() {
        double[] features = getFeature();
        features[2] = Double.NaN;
        features[7] = Double.NaN;
        return features;
    }

    /**
     * Loads the model, predicts on {@code features} and asserts the leading output values.
     *
     * @param generated second argument handed to {@code loadModel}; the "ByGeneratedClass" tests pass {@code true}
     * @param features  input feature vector
     * @param expected  reference values for {@code result[0..expected.length-1]}
     * @throws Exception anything thrown while loading or predicting; JUnit reports it with its full stack trace
     */
    private void assertPrediction(boolean generated, double[] features, double[] expected) throws Exception {
        Predictor predictor = loadModel(MODEL_NAME, generated);
        double[] result = predictor.predict(features);
        for (int i = 0; i < expected.length; i++) {
            assertDoubleEquals(expected[i], result[i]);
        }
    }

    /**
     * Compares two doubles within {@link #TOLERANCE}. Uses {@code assertEquals} with a delta rather
     * than {@code assertTrue} on a hand-computed difference so a failure reports both values.
     */
    private void assertDoubleEquals(double expected, double actual) {
        Assertions.assertEquals(expected, actual, TOLERANCE);
    }
}

-------------------------------------------------------------------------------- /treetops-core/src/main/java/io/github/horoc/treetops/core/factory/ObjectiveDecoratorFactory.java:
-------------------------------------------------------------------------------- 1 | package io.github.horoc.treetops.core.factory; 2 | 3 | import io.github.horoc.treetops.core.model.TreeModel; 4 | import io.github.horoc.treetops.core.predictor.Predictor; 5 | import io.github.horoc.treetops.core.predictor.objective.AbstractOutputConvertor; 6 | import io.github.horoc.treetops.core.predictor.objective.BinaryObjectiveDecorator; 7 | import io.github.horoc.treetops.core.predictor.objective.CrossEntropyLambdaObjectiveDecorator; 8 | import io.github.horoc.treetops.core.predictor.objective.CrossEntropyObjectiveConvertor; 9 | import io.github.horoc.treetops.core.predictor.objective.MultiClassObjectiveDecorator; 10 | import io.github.horoc.treetops.core.predictor.objective.RegressionL2ObjectiveDecorator; 11 | import io.github.horoc.treetops.core.predictor.objective.RegressionObjectiveDecorator; 12 | import java.util.HashMap; 13 | import java.util.Map; 14 | import javax.annotation.ParametersAreNonnullByDefault; 15 | import javax.annotation.concurrent.ThreadSafe; 16 | 17 | /** 18 | * Predictor objective function decorate factory. 19 | *
20 | * 21 | * @author chenzhou@apache.org 22 | * created on 2023/2/14 23 | */ 24 | @ThreadSafe 25 | @ParametersAreNonnullByDefault 26 | public class ObjectiveDecoratorFactory { 27 | 28 | private static final Map> CONVERTORS = new HashMap<>(); 29 | 30 | /** 31 | * Any new objective convertor should register into the container first 32 | */ 33 | static { 34 | CONVERTORS.put("binary", BinaryObjectiveDecorator.class); 35 | CONVERTORS.put("regression", RegressionL2ObjectiveDecorator.class); 36 | CONVERTORS.put("regression_l1", RegressionL2ObjectiveDecorator.class); 37 | CONVERTORS.put("quantile", RegressionL2ObjectiveDecorator.class); 38 | CONVERTORS.put("huber", RegressionL2ObjectiveDecorator.class); 39 | CONVERTORS.put("fair", RegressionL2ObjectiveDecorator.class); 40 | CONVERTORS.put("mape", RegressionL2ObjectiveDecorator.class); 41 | CONVERTORS.put("poisson", RegressionObjectiveDecorator.class); 42 | CONVERTORS.put("gamma", RegressionObjectiveDecorator.class); 43 | CONVERTORS.put("tweedie", RegressionObjectiveDecorator.class); 44 | CONVERTORS.put("multiclass", MultiClassObjectiveDecorator.class); 45 | CONVERTORS.put("cross_entropy", CrossEntropyObjectiveConvertor.class); 46 | CONVERTORS.put("cross_entropy_lambda", CrossEntropyLambdaObjectiveDecorator.class); 47 | } 48 | 49 | /** 50 | * Get decorated predictor instance based on tree model config. 
51 | * 52 | * @param predictor raw predictor 53 | * @param treeModel tree model 54 | * @return decorated predictor 55 | * @throws Exception throw exception when getting convertor instance failed 56 | */ 57 | static Predictor decoratePredictorByObjectiveType(final Predictor predictor, final TreeModel treeModel) throws Exception { 58 | Class convertorClass = CONVERTORS.get(treeModel.getObjectiveType()); 59 | if (convertorClass == null) { 60 | throw new RuntimeException(String.format("unsupported objective type : %s", treeModel.getObjectiveType())); 61 | } 62 | 63 | AbstractOutputConvertor convertor = convertorClass.newInstance(); 64 | return convertor.decorate(predictor, treeModel); 65 | } 66 | 67 | /** 68 | * Entry point for custom convertor, 69 | * new convertor class should implement {@link AbstractOutputConvertor}. 70 | * 71 | * @param type type name, should be distinct from exist convertor 72 | * @param clazz class 73 | */ 74 | public static void registerNewConvertor(final String type, final Class clazz) { 75 | CONVERTORS.putIfAbsent(type, clazz); 76 | } 77 | 78 | /** 79 | * Check if objective type is supported. 
80 | * 81 | * @param type type name 82 | * @return is supported 83 | */
    public static boolean isValidObjectiveType(final String type) {
        // Supported means "has an entry in the convertor registry".
        return CONVERTORS.containsKey(type);
    }
}

-------------------------------------------------------------------------------- /treetops-core/src/main/java/io/github/horoc/treetops/core/model/TreeNode.java: --------------------------------------------------------------------------------
package io.github.horoc.treetops.core.model;

import java.util.List;

/**
 * Mutable holder for one node (internal or leaf) of a decision tree.
 * The class carries no logic of its own: every member below is a plain field with a getter and a setter.
 *
 * @author chenzhou@apache.org
 * created on 2023/2/14
 */
public class TreeNode {

    // Index of the tree this node belongs to. SimplePredictor takes it modulo the class count
    // to choose the output slot, and passes it to findCatBitset for categorical splits.
    private int treeIndex;

    // Index of this node; SimplePredictor uses it as the position to read from splitFeatures.
    private int nodeIndex;

    // True when the split is categorical; SimplePredictor then takes its categorical decision path.
    private boolean isCategoryNode;

    // Direction taken by a numerical split when the feature counts as missing: left if true, right otherwise.
    private boolean isDefaultLeftDecision;

    // Raw decision-type bit field; SimplePredictor reads the missing-type code from (decisionType >> 2) & 3.
    private int decisionType;

    // Split feature indices; splitFeatures.get(nodeIndex) selects the input feature tested at this node.
    private List splitFeatures;

    // Threshold of a numerical split: SimplePredictor goes left when feature <= threshold.
    private double threshold;

    // Start offset handed to findCatBitset for a categorical split.
    private int catBoundaryBegin;

    // End offset for a categorical split; SimplePredictor passes (catBoundaryEnd - catBoundaryBegin) as the length.
    private int catBoundaryEnd;

    // NOTE(review): presumably the categorical threshold data the two boundaries refer to — confirm against the parser.
    private List catThreshold;

    // Child followed when the decision goes left.
    private TreeNode leftNode;

    // Child followed when the decision goes right.
    private TreeNode rightNode;

    // True for a leaf; SimplePredictor then returns leafValue without evaluating a split.
    private boolean isLeaf;

    // Output value of a leaf node.
    private double leafValue;

    // NOTE(review): not read by the visible predictor code; presumably all nodes of the owning tree — confirm.
    private List allNodes;

    /** Creates an empty node; all fields keep their Java default values. */
    public TreeNode() {
    }

    /**
     * Creates a node with its tree and node index set.
     *
     * @param treeIndex index of the owning tree
     * @param nodeIndex index of this node
     */
    public TreeNode(int treeIndex, int nodeIndex) {
        this.treeIndex = treeIndex;
        this.nodeIndex = nodeIndex;
    }

    /**
     * Creates a node with only its node index set; {@code treeIndex} stays 0.
     *
     * @param nodeIndex index of this node
     */
    public TreeNode(int nodeIndex) {
        this.nodeIndex = nodeIndex;
    }

    // ---- plain accessors, one getter/setter pair per field ----

    public int getTreeIndex() {
        return treeIndex;
    }

    public void setTreeIndex(int treeIndex) {
        this.treeIndex = treeIndex;
    }

    public int getNodeIndex() {
        return nodeIndex;
    }

    public void setNodeIndex(int nodeIndex) {
        this.nodeIndex = nodeIndex;
    }

    public boolean isCategoryNode() {
        return isCategoryNode;
    }

    public void setCategoryNode(boolean categoryNode) {
        isCategoryNode = categoryNode;
    }

    public boolean isDefaultLeftDecision() {
        return isDefaultLeftDecision;
    }

    public void setDefaultLeftDecision(boolean defaultLeftDecision) {
        isDefaultLeftDecision = defaultLeftDecision;
    }

    public int getDecisionType() {
        return decisionType;
    }

    public void setDecisionType(int decisionType) {
        this.decisionType = decisionType;
    }

    public List getSplitFeatures() {
        return splitFeatures;
    }

    public void setSplitFeatures(List splitFeatures) {
        this.splitFeatures = splitFeatures;
    }

    public double getThreshold() {
        return threshold;
    }

    public void setThreshold(double threshold) {
        this.threshold = threshold;
    }

    public int getCatBoundaryBegin() {
        return catBoundaryBegin;
    }

    public void setCatBoundaryBegin(int catBoundaryBegin) {
        this.catBoundaryBegin = catBoundaryBegin;
    }

    public int getCatBoundaryEnd() {
        return catBoundaryEnd;
    }

    public void setCatBoundaryEnd(int catBoundaryEnd) {
        this.catBoundaryEnd = catBoundaryEnd;
    }

    public List getCatThreshold() {
        return catThreshold;
    }

    public void setCatThreshold(List catThreshold) {
        this.catThreshold = catThreshold;
    }

    public TreeNode getLeftNode() {
        return leftNode;
    }

    public void setLeftNode(TreeNode leftNode) {
        this.leftNode = leftNode;
    }

    public TreeNode getRightNode() {
        return rightNode;
    }

    public void setRightNode(TreeNode rightNode) {
        this.rightNode = rightNode;
    }

    public boolean isLeaf() {
        return isLeaf;
    }

    public void setLeaf(boolean leaf) {
        isLeaf = leaf;
    }

    public double getLeafValue() {
        return leafValue;
    }

    public void setLeafValue(double leafValue) {
        this.leafValue = leafValue;
    }

    public List getAllNodes() {
        return allNodes;
    }

    public void setAllNodes(List allNodes) {
        this.allNodes = allNodes;
    }
}

-------------------------------------------------------------------------------- /treetops-core/src/main/java/io/github/horoc/treetops/core/predictor/SimplePredictor.java: --------------------------------------------------------------------------------
package io.github.horoc.treetops.core.predictor;

import io.github.horoc.treetops.core.model.MissingType;
import io.github.horoc.treetops.core.model.TreeModel;
import io.github.horoc.treetops.core.model.TreeNode;

/**
 * Simple predictor implementation, follow the LightGBM cpp implementation.
 * Prediction walks the {@link TreeNode} structure of every tree in the model.
 *
 * @author chenzhou@apache.org
 * created on 2023/2/14
 */
public class SimplePredictor extends MetaDataHolder implements Predictor {

    // Values within [-K_ZERO_THRESHOLD, K_ZERO_THRESHOLD] are treated as zero by isZero().
    private static final double K_ZERO_THRESHOLD = 1e-35f;

    // Model whose trees are traversed on every prediction.
    private final TreeModel treeModel;

    /**
     * @param treeModel model to predict with
     */
    public SimplePredictor(TreeModel treeModel) {
        this.treeModel = treeModel;
    }

    /**
     * Sums the decision value of every tree into one slot per class.
     *
     * @param features input data
     * @return array of length {@code treeModel.getNumClass()} holding the accumulated raw values
     */
    @Override
    public double[] predictRaw(double[] features) {
        double[] ret = new double[treeModel.getNumClass()];
        for (TreeNode root : treeModel.getTrees()) {
            // A tree contributes to the slot given by its tree index modulo the number of classes.
            ret[root.getTreeIndex() % treeModel.getNumClass()] += this.decision(root, features);
        }
        return ret;
    }

    /**
     * Refer to official library: microsoft/LightGBM/include/LightGBM/tree.h#Tree::Decision.
     *

35 | * 36 | * @param treeNode tree node meta data 37 | * @param features input data 38 | * @return decision value of this node 39 | */
    private double decision(TreeNode treeNode, double[] features) {
        // Internal node: delegate to the split routine matching the node kind; each routine
        // recurses back into this method with the chosen child.
        if (!treeNode.isLeaf()) {
            return treeNode.isCategoryNode()
                ? categoricalDecision(treeNode, features)
                : numericalDecision(treeNode, features);
        }
        // Leaf: the walk ends here and the stored leaf value is the tree's output.
        return treeNode.getLeafValue();
    }

    /**
     * Refer to official library: microsoft/LightGBM/include/LightGBM/tree.h#Tree::NumericalDecision.
     *

55 | * 56 | * @param treeNode tree node meta data 57 | * @param features input data 58 | * @return decision value of this node 59 | */ 60 | private double numericalDecision(TreeNode treeNode, double[] features) { 61 | MissingType missingType = MissingType.ofMask((treeNode.getDecisionType() >> 2) & 3); 62 | double threshold = treeNode.getThreshold(); 63 | double feature = features[treeNode.getSplitFeatures().get(treeNode.getNodeIndex())]; 64 | if (Double.isNaN(feature) && missingType != MissingType.Nan) { 65 | feature = 0.0; 66 | } 67 | 68 | boolean isZeroMiss = missingType == MissingType.Zero && isZero(feature); 69 | boolean isNanMiss = missingType == MissingType.Nan && Double.isNaN(feature); 70 | if (isZeroMiss || isNanMiss) { 71 | if (treeNode.isDefaultLeftDecision()) { 72 | return decision(treeNode.getLeftNode(), features); 73 | } else { 74 | return decision(treeNode.getRightNode(), features); 75 | } 76 | } 77 | 78 | if (feature <= threshold) { 79 | return decision(treeNode.getLeftNode(), features); 80 | } else { 81 | return decision(treeNode.getRightNode(), features); 82 | } 83 | } 84 | 85 | /** 86 | * Refer to official library: microsoft/LightGBM/include/LightGBM/tree.h#Tree::CategoricalDecision. 87 | *

88 | * 89 | * @param treeNode tree node meta data 90 | * @param features input data 91 | * @return decision value of this node 92 | */ 93 | private double categoricalDecision(TreeNode treeNode, double[] features) { 94 | double feature = features[treeNode.getSplitFeatures().get(treeNode.getNodeIndex())]; 95 | if (Double.isNaN(feature)) { 96 | return decision(treeNode.getRightNode(), features); 97 | } else { 98 | int val = (int) feature; 99 | if (val < 0) { 100 | return decision(treeNode.getRightNode(), features); 101 | } 102 | } 103 | if (findCatBitset(treeNode.getTreeIndex(), treeNode.getCatBoundaryBegin(), 104 | treeNode.getCatBoundaryEnd() - treeNode.getCatBoundaryBegin(), feature)) { 105 | return decision(treeNode.getLeftNode(), features); 106 | } 107 | return decision(treeNode.getRightNode(), features); 108 | } 109 | 110 | private boolean isZero(double val) { 111 | return val >= -K_ZERO_THRESHOLD && val <= K_ZERO_THRESHOLD; 112 | } 113 | } 114 | -------------------------------------------------------------------------------- /treetops-core/src/main/java/io/github/horoc/treetops/core/model/RawTreeBlock.java: -------------------------------------------------------------------------------- 1 | package io.github.horoc.treetops.core.model; 2 | 3 | import java.util.List; 4 | 5 | /** 6 | * @author chenzhou@apache.org 7 | * created on 2023/2/14 8 | */ 9 | public class RawTreeBlock { 10 | /** 11 | * match with model file field: 12 | * - Tree=%d. 13 | */ 14 | private int tree; 15 | 16 | /** 17 | * match with model file field: 18 | * - num_leaves=%d. 19 | */ 20 | private int numLeaves; 21 | 22 | /** 23 | * match with model file field: 24 | * - num_cat=%d. 25 | */ 26 | private int numCat; 27 | 28 | /** 29 | * match with model file field: 30 | * - split_feature=%f %f ... 31 | */ 32 | private List splitFeature; 33 | 34 | /** 35 | * match with model file field: 36 | * - decision_type=%d %d ... 
37 | */
    private List decisionType;

    /**
     * match with model file field:
     * - left_child=%d %d ...
     */
    private List leftChild;

    /**
     * match with model file field:
     * - right_child=%d %d ...
     */
    private List rightChild;

    /**
     * match with model file field:
     * - leaf_value=%f %f ...
     */
    private List leafValue;

    /**
     * match with model file field:
     * - internal_value=%f %f ...
     */
    private List internalValue;

    /**
     * match with model file field:
     * - threshold=%f %f ...
     */
    private List threshold;

    /**
     * match with model file field:
     * - cat_boundaries=%d %d ...
     */
    private List catBoundaries;

    /**
     * match with model file field:
     * - cat_threshold=%f %f ...
     */
    private List catThreshold;

    // ---- plain accessors: each getter returns the stored value, each setter replaces it ----

    public int getTree() {
        return tree;
    }

    public void setTree(int tree) {
        this.tree = tree;
    }

    public int getNumLeaves() {
        return numLeaves;
    }

    public void setNumLeaves(int numLeaves) {
        this.numLeaves = numLeaves;
    }

    public int getNumCat() {
        return numCat;
    }

    public void setNumCat(int numCat) {
        this.numCat = numCat;
    }

    public List getSplitFeature() {
        return splitFeature;
    }

    public void setSplitFeature(List splitFeature) {
        this.splitFeature = splitFeature;
    }

    public List getDecisionType() {
        return decisionType;
    }

    public void setDecisionType(List decisionType) {
        this.decisionType = decisionType;
    }

    public List getLeftChild() {
        return leftChild;
    }

    public void setLeftChild(List leftChild) {
        this.leftChild = leftChild;
    }

    public List getRightChild() {
        return rightChild;
    }

    public void setRightChild(List rightChild) {
        this.rightChild = rightChild;
    }

    public List getLeafValue() {
        return leafValue;
    }

    public void setLeafValue(List leafValue) {
        this.leafValue = leafValue;
    }

    public List getInternalValue() {
        return internalValue;
    }

    public void setInternalValue(List internalValue) {
        this.internalValue = internalValue;
    }

    public List getThreshold() {
        return threshold;
    }

    public void setThreshold(List threshold) {
        this.threshold = threshold;
    }

    public List getCatBoundaries() {
        return catBoundaries;
    }

    public void setCatBoundaries(List catBoundaries) {
        this.catBoundaries = catBoundaries;
    }

    public List getCatThreshold() {
        return catThreshold;
    }

    public void setCatThreshold(List catThreshold) {
        this.catThreshold = catThreshold;
    }
}

-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 🍃 Treetops 2 | 3 | ![](https://img.shields.io/badge/License-MIT%20-green.svg?label=license) 4 | [![Continuous Integration](https://github.com/horoc/treetops/actions/workflows/gradle.yml/badge.svg?branch=master)](https://github.com/horoc/treetops/actions/workflows/gradle.yml) 5 | GitHub Stars 6 | Java support 7 | GitHub repo size 8 | 9 | > 🚀 Fast LightGBM tree model inference lib which is based on ASM dynamic code generation framework. 10 | > 11 | > 🐒 Easy to integrate with Java project, no need to install extra libs. 12 | > 13 | > 💯 Fully implemented by Java, easy to customize your own functionality. 14 | 15 | ---------------------------------------- 16 | 17 | ## How Fast 18 | 19 |
20 | 21 |
22 | 23 | ### Specification 24 | 25 | - Asm: Generated LightGBM predictor based on the ASM framework. 26 | - Simple: Tree-based data structure predictor which follows the official LightGBM C++ implementation. 27 | 28 | ### Test Models 29 | 30 | - bc: model trained by breast cancer dataset (100 trees, binary) 31 | - ch: model trained by california housing dataset (100 trees, regression) 32 | - db: model trained by diabetes dataset (100 trees, regression) 33 | - wn: model trained by wine dataset (300 trees, classification) 34 | 35 | ### Environment 36 | 37 | - CPU: Intel Xeon Cooper Lake 3.4GHz, 4C 38 | - Memory: 8G 39 | - JDK: zulu JDK 8 40 | - Tools: JMH; for the test configuration see `AverageTimeBenchmarkTemplate.java` 41 | 42 | ## Quick Start 43 | 44 | ### Dependencies 45 | 46 | Currently, we have published the `0.1.0` version. 47 | 48 | Add dependencies: 49 | 50 | ``` 51 | implementation 'io.github.horoc:treetops-core:0.1.0' 52 | ``` 53 | 54 | ### Model 55 | 56 | You can train your own model or just use the test model files for testing. 57 | 58 | > Test model path: `treetops-core/src/test/resources` 59 | 60 | ### API 61 | 62 | ```java 63 | // model name may only contain the characters [a-zA-Z0-9_] 64 | Predictor predictor = TreePredictorFactory.newInstance("your_model_name_v0", filePathOfYourModel); 65 | predictor.predict(features); 66 | ``` 67 | 68 | If you want to save the generated class file and have a look, you can specify the save path.
69 | 70 | ```java 71 | Predictor predictor = TreePredictorFactory.newInstance("your_model_name_v0", filePathOfYourModel, pathToSaveYourClass); 72 | predictor.predict(features); 73 | ``` 74 | 75 | If you want to compare with the simple predictor (disable generation), you can create the predictor with: 76 | ```java 77 | Predictor predictor = TreePredictorFactory.newInstance("your_model_name_v0", "", false); 78 | predictor.predict(features); 79 | ``` 80 | 81 | ## Core Idea 82 | 83 | **What treetops mainly does is translate the model file into a hard-coded class instead of storing it in a tree-based data structure, and that's the core idea of treetops.** 84 | 85 | ### Example 86 | 87 | For example, the following configuration is one of the trees in a model. 88 | 89 | ``` 90 | Tree=0 91 | num_leaves=4 92 | num_cat=0 93 | split_feature=1 2 2 94 | split_gain=0.568011 0.483606 0.45669 95 | threshold=0.73144941452196321 0.90708366268745222 0.85551601478390116 96 | decision_type=2 2 2 97 | left_child=1 -1 -2 98 | right_child=2 -3 -4 99 | leaf_value=0.49510661266514339 0.50645382200299838 0.50688948369558862 0.49040602357823876 100 | leaf_weight=326 114 39 21 101 | leaf_count=326 114 39 21 102 | internal_value=0.498415 0.496366 0.503957 103 | internal_weight=0 365 135 104 | internal_count=500 365 135 105 | is_linear=0 106 | shrinkage=1 107 | ``` 108 | 109 | The output decision value of this tree is based on every internal and leaf node's split strategy and value. 110 | 111 | According to the config, we can see there are three internal nodes and four leaf nodes. If we store this tree in a tree-based data structure, we would need to iterate from the root to the leaves to make a decision. However, this process can result in lots of memory accesses and function calls. This also affects the hit rate of the CPU instruction cache. 112 | 113 | If we hard-code the tree structure, we can avoid much of this overhead.
114 | 115 | A generated tree decision function looks like this: 116 | 117 | ``` 118 | private tree_0([D)D 119 | ... ... 120 | IFEQ L0 121 | GOTO L1 122 | L0 123 | FRAME APPEND [D] 124 | DLOAD 2 125 | LDC 0.7314494145219632 126 | DCMPG 127 | IFGE L2 128 | GOTO L1 129 | L1 130 | 131 | ... ... 132 | 133 | L8 134 | FRAME SAME 135 | LDC 0.49040602357823876 136 | DRETURN 137 | MAXSTACK = 4 138 | MAXLOCALS = 4 139 | ``` 140 | 141 | The corresponding Java code: 142 | 143 | ```java 144 | private double tree_0(double[] var1) { 145 | double var2 = var1[1]; 146 | if (var2 == var2 && !(var2 < 0.7314494145219632D)) { 147 | var2 = var1[2]; 148 | return var2 == var2 && !(var2 < 0.8555160147839012D) ? 0.49040602357823876D : 0.5064538220029984D; 149 | } else { 150 | var2 = var1[2]; 151 | return var2 == var2 && !(var2 < 0.9070836626874522D) ? 0.5068894836955886D : 0.4951066126651434D; 152 | } 153 | } 154 | ``` 155 | 156 | We precompute the decision rules and use conditional statements to evaluate input features based on the tree config. This can be particularly advantageous for large trees, as dynamic traversal can become a significant bottleneck. 157 | 158 | ## Author 159 | 160 | * **Chen Zhou** - *Initial work* - [ChenZhou](https://github.com/horoc) 161 | 162 | See also the list of [contributors](https://github.com/horoc/treetops/contributors) who participated in this project. 163 | 164 | ## License 165 | 166 | This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details. 167 | -------------------------------------------------------------------------------- /gradlew: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | 3 | # 4 | # Copyright 2015 the original author or authors. 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at 9 | # 10 | # https://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | # 18 | 19 | ############################################################################## 20 | ## 21 | ## Gradle start up script for UN*X 22 | ## 23 | ############################################################################## 24 | 25 | # Attempt to set APP_HOME 26 | # Resolve links: $0 may be a link 27 | PRG="$0" 28 | # Need this for relative symlinks. 29 | while [ -h "$PRG" ] ; do 30 | ls=`ls -ld "$PRG"` 31 | link=`expr "$ls" : '.*-> \(.*\)$'` 32 | if expr "$link" : '/.*' > /dev/null; then 33 | PRG="$link" 34 | else 35 | PRG=`dirname "$PRG"`"/$link" 36 | fi 37 | done 38 | SAVED="`pwd`" 39 | cd "`dirname \"$PRG\"`/" >/dev/null 40 | APP_HOME="`pwd -P`" 41 | cd "$SAVED" >/dev/null 42 | 43 | APP_NAME="Gradle" 44 | APP_BASE_NAME=`basename "$0"` 45 | 46 | # Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. 47 | DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"' 48 | 49 | # Use the maximum available, or set MAX_FD != -1 to use that value. 50 | MAX_FD="maximum" 51 | 52 | warn () { 53 | echo "$*" 54 | } 55 | 56 | die () { 57 | echo 58 | echo "$*" 59 | echo 60 | exit 1 61 | } 62 | 63 | # OS specific support (must be 'true' or 'false'). 
64 | cygwin=false 65 | msys=false 66 | darwin=false 67 | nonstop=false 68 | case "`uname`" in 69 | CYGWIN* ) 70 | cygwin=true 71 | ;; 72 | Darwin* ) 73 | darwin=true 74 | ;; 75 | MSYS* | MINGW* ) 76 | msys=true 77 | ;; 78 | NONSTOP* ) 79 | nonstop=true 80 | ;; 81 | esac 82 | 83 | CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar 84 | 85 | 86 | # Determine the Java command to use to start the JVM. 87 | if [ -n "$JAVA_HOME" ] ; then 88 | if [ -x "$JAVA_HOME/jre/sh/java" ] ; then 89 | # IBM's JDK on AIX uses strange locations for the executables 90 | JAVACMD="$JAVA_HOME/jre/sh/java" 91 | else 92 | JAVACMD="$JAVA_HOME/bin/java" 93 | fi 94 | if [ ! -x "$JAVACMD" ] ; then 95 | die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME 96 | 97 | Please set the JAVA_HOME variable in your environment to match the 98 | location of your Java installation." 99 | fi 100 | else 101 | JAVACMD="java" 102 | which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. 103 | 104 | Please set the JAVA_HOME variable in your environment to match the 105 | location of your Java installation." 106 | fi 107 | 108 | # Increase the maximum file descriptors if we can. 109 | if [ "$cygwin" = "false" -a "$darwin" = "false" -a "$nonstop" = "false" ] ; then 110 | MAX_FD_LIMIT=`ulimit -H -n` 111 | if [ $? -eq 0 ] ; then 112 | if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then 113 | MAX_FD="$MAX_FD_LIMIT" 114 | fi 115 | ulimit -n $MAX_FD 116 | if [ $? 
-ne 0 ] ; then 117 | warn "Could not set maximum file descriptor limit: $MAX_FD" 118 | fi 119 | else 120 | warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT" 121 | fi 122 | fi 123 | 124 | # For Darwin, add options to specify how the application appears in the dock 125 | if $darwin; then 126 | GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\"" 127 | fi 128 | 129 | # For Cygwin or MSYS, switch paths to Windows format before running java 130 | if [ "$cygwin" = "true" -o "$msys" = "true" ] ; then 131 | APP_HOME=`cygpath --path --mixed "$APP_HOME"` 132 | CLASSPATH=`cygpath --path --mixed "$CLASSPATH"` 133 | 134 | JAVACMD=`cygpath --unix "$JAVACMD"` 135 | 136 | # We build the pattern for arguments to be converted via cygpath 137 | ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null` 138 | SEP="" 139 | for dir in $ROOTDIRSRAW ; do 140 | ROOTDIRS="$ROOTDIRS$SEP$dir" 141 | SEP="|" 142 | done 143 | OURCYGPATTERN="(^($ROOTDIRS))" 144 | # Add a user-defined pattern to the cygpath arguments 145 | if [ "$GRADLE_CYGPATTERN" != "" ] ; then 146 | OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)" 147 | fi 148 | # Now convert the arguments - kludge to limit ourselves to /bin/sh 149 | i=0 150 | for arg in "$@" ; do 151 | CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -` 152 | CHECK2=`echo "$arg"|egrep -c "^-"` ### Determine if an option 153 | 154 | if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### Added a condition 155 | eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"` 156 | else 157 | eval `echo args$i`="\"$arg\"" 158 | fi 159 | i=`expr $i + 1` 160 | done 161 | case $i in 162 | 0) set -- ;; 163 | 1) set -- "$args0" ;; 164 | 2) set -- "$args0" "$args1" ;; 165 | 3) set -- "$args0" "$args1" "$args2" ;; 166 | 4) set -- "$args0" "$args1" "$args2" "$args3" ;; 167 | 5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;; 168 | 6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;; 169 | 
7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;; 170 | 8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;; 171 | 9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;; 172 | esac 173 | fi 174 | 175 | # Escape application args 176 | save () { 177 | for i do printf %s\\n "$i" | sed "s/'/'\\\\''/g;1s/^/'/;\$s/\$/' \\\\/" ; done 178 | echo " " 179 | } 180 | APP_ARGS=`save "$@"` 181 | 182 | # Collect all arguments for the java command, following the shell quoting and substitution rules 183 | eval set -- $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS "\"-Dorg.gradle.appname=$APP_BASE_NAME\"" -classpath "\"$CLASSPATH\"" org.gradle.wrapper.GradleWrapperMain "$APP_ARGS" 184 | 185 | exec "$JAVACMD" "$@" 186 | -------------------------------------------------------------------------------- /treetops-core/src/main/java/io/github/horoc/treetops/core/factory/TreePredictorFactory.java: -------------------------------------------------------------------------------- 1 | package io.github.horoc.treetops.core.factory; 2 | 3 | import io.github.horoc.treetops.core.generator.PredictorClassGenerator; 4 | import io.github.horoc.treetops.core.loader.AbstractLoader; 5 | import io.github.horoc.treetops.core.loader.FileTreeModelLoader; 6 | import io.github.horoc.treetops.core.model.TreeModel; 7 | import io.github.horoc.treetops.core.predictor.MetaDataHolder; 8 | import io.github.horoc.treetops.core.predictor.Predictor; 9 | import io.github.horoc.treetops.core.predictor.PredictorWrapper; 10 | import io.github.horoc.treetops.core.predictor.SimplePredictor; 11 | import java.io.File; 12 | import java.io.FileOutputStream; 13 | import java.lang.ref.WeakReference; 14 | import java.util.HashMap; 15 | import java.util.Map; 16 | import java.util.regex.Pattern; 17 | import javax.annotation.ParametersAreNonnullByDefault; 18 | import javax.annotation.concurrent.ThreadSafe; 19 | import 
org.apache.commons.lang3.StringUtils;

/**
 * Core factory of tree predictors; users should only obtain predictor instances through this factory.
 * <p>
 * Every public entry point is {@code static synchronized}, so the cache and the
 * configurable loader/threshold are guarded by the class monitor.
 *
 * @author chenzhou@apache.org
 * created on 2023/2/14
 */
@ThreadSafe
@ParametersAreNonnullByDefault
public class TreePredictorFactory {

    /**
     * Predictor instance cache pool, keyed by generated class name.
     * Values are weak references, so an entry whose predictor has been collected is dropped on the next lookup.
     */
    private static final Map<String, WeakReference<Predictor>> PREDICTORS = new HashMap<>();

    /**
     * Accepted model names: one or more of letters, digits and underscore.
     */
    private static final Pattern MODEL_NAME_PATTERN = Pattern.compile("^[a-zA-Z0-9_]+$");

    /**
     * Tree model loader.
     */
    private static AbstractLoader treeModelLoader = FileTreeModelLoader.getInstance();

    /**
     * Due to the limitation of asm framework, generated class can not be too large,
     * thus when the tree nums of model larger than threshold,
     * will use {@link SimplePredictor} as predictor implementation.
     * <p>
     * Default threshold is 300, which can meet the requirement of most business scenarios.
     */
    private static int asmGenerationTreeNumsThreshold = 300;

    /**
     * Entry point for user custom model loader, default model loader is {@link FileTreeModelLoader}.
     * <br>
     * User custom model loader should implement {@link AbstractLoader}.
     *
     * @param loader custom loader
     */
    public static synchronized void setTreeModelLoader(final AbstractLoader loader) {
        treeModelLoader = loader;
    }

    /**
     * Entry point for user to customize the tree nums threshold.
     * <br>
     * Be careful, the value should not be too large, since asm has a limitation of class size.
     * See {@link org.objectweb.asm.ClassWriter#toByteArray()} for details.
     *
     * @param threshold custom value
     */
    public static synchronized void setGenerationTreeNumsThreshold(final int threshold) {
        asmGenerationTreeNumsThreshold = threshold;
    }

    /**
     * Refer to {@link TreePredictorFactory#newInstance(java.lang.String, java.lang.String, java.lang.String, boolean)}.
     *
     * @param modelName model name, should be distinct from existing Predictors, and must only contain characters: [a-zA-Z0-9_]
     * @param resource  resource path, if using default model loader, it means file path
     * @return Predictor instance
     */
    public static synchronized Predictor newInstance(final String modelName, final String resource) {
        return newInstance(modelName, resource, null);
    }

    /**
     * Refer to {@link TreePredictorFactory#newInstance(java.lang.String, java.lang.String, java.lang.String, boolean)}.
     *
     * @param modelName        model name, should be distinct from existing Predictors, and must only contain characters: [a-zA-Z0-9_]
     * @param resource         resource path, if using default model loader, it means file path
     * @param enableGeneration if enableGeneration is false, will always get {@link SimplePredictor} instance
     * @return Predictor instance
     */
    public static synchronized Predictor newInstance(final String modelName, final String resource, boolean enableGeneration) {
        return newInstance(modelName, resource, null, enableGeneration);
    }

    /**
     * Refer to {@link TreePredictorFactory#newInstance(java.lang.String, java.lang.String, java.lang.String, boolean)}.
     *
     * @param modelName        model name, should be distinct from existing Predictors, and must only contain characters: [a-zA-Z0-9_]
     * @param resource         resource path, if using default model loader, it means file path
     * @param saveClassFileDir generated Predictor class save path, can be null if it's not necessary
     * @return Predictor instance
     */
    public static synchronized Predictor newInstance(final String modelName, final String resource, final String saveClassFileDir) {
        return newInstance(modelName, resource, saveClassFileDir, true);
    }

    /**
     * Create predictor from resource: <br>
     * 1. load model data from resource path. <br>
     * 2. parse to {@link TreeModel} instance. <br>
     * 3. generate {@link Predictor} based on model detail.
     * <p>
     * A live predictor previously created under the same model name is returned from the cache
     * without reloading the resource.
     *
     * @param modelName        model name, should be distinct from existing Predictors, and must only contain characters: [a-zA-Z0-9_]
     * @param resource         resource path, if using default model loader, it means file path
     * @param saveClassFileDir generated Predictor class save path, can be null if it's not necessary
     * @param enableGeneration if enableGeneration is false, will always get {@link SimplePredictor} instance
     * @return Predictor instance
     * @throws IllegalArgumentException if the model name contains characters outside [a-zA-Z0-9_]
     * @throws RuntimeException         if loading, generating or instantiating the predictor fails (cause attached)
     */
    public static synchronized Predictor newInstance(final String modelName, final String resource, final String saveClassFileDir,
                                                     boolean enableGeneration) {
        checkModelName(modelName);

        String className = toClassName(modelName);
        // predictor which is no longer used will be removed after gc
        WeakReference<Predictor> cachedReference = PREDICTORS.get(className);
        if (cachedReference != null) {
            Predictor cached = cachedReference.get();
            if (cached != null) {
                return cached;
            }
            // referent was collected: drop the stale entry and rebuild below
            PREDICTORS.remove(className);
        }

        try {
            TreeModel treeModel = treeModelLoader.loadModel(resource);
            Predictor predictor;
            // if model is too large, downgrade to simple predictor implementation
            if (!enableGeneration || treeModel.getTrees().size() > asmGenerationTreeNumsThreshold) {
                predictor = new SimplePredictor(treeModel);
            } else {
                // new class loader to do class generation
                PredictorClassGenerator generator = PredictorClassGenerator.getInstance();
                byte[] bytes = generator.generateCode(className, treeModel);
                if (StringUtils.isNotBlank(saveClassFileDir)) {
                    saveClass(bytes, className, saveClassFileDir);
                }
                // Class.newInstance() is deprecated; go through the no-arg constructor explicitly
                Object targetObj = generator.defineClassFromCode(className, bytes).getDeclaredConstructor().newInstance();
                predictor = (Predictor) targetObj;
            }

            // init meta data if need
            if (predictor instanceof MetaDataHolder) {
                ((MetaDataHolder) predictor).initialize(treeModel);
            }

            // objective decorate
            Predictor objectivePredictor = ObjectiveDecoratorFactory.decoratePredictorByObjectiveType(predictor, treeModel);

            // wrapper
            Predictor predictorWrapper = new PredictorWrapper(objectivePredictor, treeModel);

            PREDICTORS.put(className, new WeakReference<>(predictorWrapper));
            return predictorWrapper;
        } catch (Throwable e) {
            // Throwable on purpose: class definition problems surface as Errors (e.g. LinkageError)
            throw new RuntimeException(String.format("fail to generate predict instance, modelName: %s", modelName), e);
        }
    }

    /**
     * Clean predictor reference.
     *
     * @param predictor should be PredictorWrapper instance which created by factory
     */
    private static synchronized void releasePredictor(Predictor predictor) {
        if (predictor instanceof PredictorWrapper) {
            ((PredictorWrapper) predictor).release();
        }
    }

    /**
     * Save raw bytecode into file, creating parent directories as needed.
     *
     * @param bytes     class bytecode
     * @param className class name
     * @param dir       save path
     * @throws Exception if the file can not be created or written
     */
    private static void saveClass(final byte[] bytes, final String className, final String dir) throws Exception {
        File file = new File(customClassFilePath(className, dir));
        file.getParentFile().mkdirs();

        // FileOutputStream creates (or truncates) the file itself; a failed mkdirs surfaces here as an exception
        try (FileOutputStream fos = new FileOutputStream(file)) {
            fos.write(bytes);
        }
    }

    /**
     * Build the class file path: {@code dir/<package as directories>/<SimpleName>.class}.
     */
    private static String customClassFilePath(final String className, final String dir) {
        return StringUtils.join(dir, File.separator, StringUtils.join(className.split("\\."), File.separator), ".class");
    }

    /**
     * Generated class will start with '_'.
     *
     * @param modelName model name
     * @return class name
     */
    private static String toClassName(final String modelName) {
        return Predictor.PREDICTOR_CLASS_PREFIX + "._" + modelName;
    }

    /**
     * Reject model names that do not match {@link #MODEL_NAME_PATTERN}.
     *
     * @param modelName model name
     * @throws IllegalArgumentException if the name is not made only of [a-zA-Z0-9_]
     */
    private static void checkModelName(final String modelName) {
        if (!MODEL_NAME_PATTERN.matcher(modelName).matches()) {
            throw new IllegalArgumentException(String.format("illegal model name: %s, valid character: [a-zA-Z0-9_]", modelName));
        }
    }
}

-------------------------------------------------------------------------------- /checkstyle/checkstyle.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | 138 | 139 | 140 | 141 | 142 | 143 | 144 | 145 | 146 | 147 | 148 | 149 | 150 | 151 | 152 | 153 | 154 | 155 | 156 | 157 | 158 | 159 | 160 | 161 | 162 | 163 | 164 | 165 | 166 | 167 | 168 | 170 | 172 | 173 | 174 | 175 | 176 | 177 | 178 | 179 | 180 | 181 | 182 | 183 | 184 | 185 | 186 | 187 | 188 | 189 | 190 | 191 | 192 | 193 | 194 | 195 | 196 | 197 | 198 | 199 | 200 | 201 | 202 | 203 | 204 | 205 | 206 | 207 | 208 | 209 | 210 | 211 | 212 | 213 | 214 | 215 | 216 | 217 | 218 | 219 | 220 | 221 | 222 | 223 | 224 | 225 | 226 | 227 | 228 | 229 | 230 | 231 | 232 | 233 | 234 | 235 | 236 | 238 | 239 |
240 | 241 | 242 | 243 | 244 | 245 | 246 | -------------------------------------------------------------------------------- /treetops-core/src/main/java/io/github/horoc/treetops/core/generator/PredictorClassGenerator.java: --------------------------------------------------------------------------------
package io.github.horoc.treetops.core.generator;

import io.github.horoc.treetops.core.model.MissingType;
import io.github.horoc.treetops.core.model.TreeModel;
import io.github.horoc.treetops.core.model.TreeNode;
import java.util.Map;
import java.util.stream.Collectors;
import org.apache.commons.lang3.StringUtils;
import org.objectweb.asm.ClassVisitor;
import org.objectweb.asm.ClassWriter;
import org.objectweb.asm.Label;
import org.objectweb.asm.MethodVisitor;
import org.objectweb.asm.Opcodes;
import org.objectweb.asm.util.CheckClassAdapter;

/**
 * Predictor class generator based on asm framework.
 * <p>
 * For every tree of the model a private method {@code tree_<index>([D)D} is emitted in which each
 * node is a labelled block; {@code predictRaw([D)[D} sums the tree outputs per class.
 * Each generator instance is its own {@link ClassLoader}, used to define the generated class.
 *
 * @author chenzhou@apache.org
 * created on 2023/2/14
 */
public final class PredictorClassGenerator extends ClassLoader implements Generator, Opcodes {

    /**
     * Local variable slot of the {@code double[] features} parameter in generated instance methods.
     */
    private static final int FEATURE_PARAMETER_INDEX = 1;

    /**
     * Values within [-threshold, threshold] are treated as zero when missing type is Zero.
     * The float literal is intentional: the double holds the widened float value.
     */
    private static final double K_ZERO_THRESHOLD = 1e-35f;

    /**
     * JVM internal name of instance constructors.
     */
    private static final String INIT = "<init>";

    private static final String TREE_METHOD_PREFIX = "tree_";

    private static final String OBJECT_INTERNAL_NAME = "java/lang/Object";

    private static final String PREDICTOR_INTERNAL_NAME = "io/github/horoc/treetops/core/predictor/Predictor";

    private static final String META_DATA_HOLDER_INTERNAL_NAME = "io/github/horoc/treetops/core/predictor/MetaDataHolder";

    private static final String PREDICT_METHOD = "predictRaw";

    private static final String FIND_CAT_BIT_SET_METHOD = "findCatBitset";

    private PredictorClassGenerator() {
    }

    /**
     * We can not maintain a singleton instance of generator here, <br>
     * since we want jvm to unload class which would be no longer used.
     *
     * @return a new generator instance
     */
    public static PredictorClassGenerator getInstance() {
        return new PredictorClassGenerator();
    }

    /**
     * Define a class in this class loader from raw bytecode.
     *
     * @param className binary class name
     * @param code      class bytecode
     * @return the defined class
     */
    @Override
    public Class<?> defineClassFromCode(final String className, final byte[] code) {
        return this.defineClass(className, code, 0, code.length);
    }

    /**
     * Generate the bytecode of a predictor class for the given model.
     *
     * @param className binary class name of the class to generate
     * @param model     model config
     * @return class bytecode
     */
    @Override
    public byte[] generateCode(final String className, final TreeModel model) {
        ClassWriter cw = new ClassWriter(ClassWriter.COMPUTE_FRAMES);
        ClassVisitor cv = new CheckClassAdapter(cw);
        String internalClassName = toInternalName(className);

        // define class
        cv.visit(V1_8, ACC_PUBLIC | ACC_SUPER, internalClassName, null, getSuperName(model),
            new String[] {PREDICTOR_INTERNAL_NAME});

        // define init method
        addInitMethod(cv, model);

        // tree decision method
        // description : private double tree_[%tree_index](double[] features);
        model.getTrees().forEach(t -> addTreeMethod(cv, internalClassName, t));

        // prediction method
        addPredictionMethod(cv, internalClassName, model);

        cv.visitEnd();
        return cw.toByteArray();
    }

    /**
     * Define init method: a public no-arg constructor that only calls the super constructor.
     *
     * @param cv    class visitor
     * @param model model config
     */
    private void addInitMethod(ClassVisitor cv, final TreeModel model) {
        MethodVisitor methodVisitor = simpleVisitMethod(cv, ACC_PUBLIC, INIT, "()V");
        methodVisitor.visitCode();
        methodVisitor.visitVarInsn(ALOAD, 0);
        // the super class must be the same one declared in the class header
        methodVisitor.visitMethodInsn(INVOKESPECIAL, getSuperName(model), INIT, "()V", false);
        methodVisitor.visitInsn(RETURN);
        methodVisitor.visitMaxs(1, 1);
        methodVisitor.visitEnd();
    }

    /**
     * Define {@code public double[] predictRaw(double[] features)}:
     * {@code result[treeIndex % numClass] += tree_<treeIndex>(features)} for every tree.
     *
     * @param cv        class visitor
     * @param className internal name of the generated class
     * @param model     model config
     */
    private void addPredictionMethod(ClassVisitor cv, final String className, final TreeModel model) {
        MethodVisitor methodVisitor = simpleVisitMethod(cv, ACC_PUBLIC, PREDICT_METHOD, "([D)[D");
        methodVisitor.visitCode();

        // double[] result = new double[numClass]; stored in local 2
        methodVisitor.visitLdcInsn(model.getNumClass());
        methodVisitor.visitIntInsn(NEWARRAY, T_DOUBLE);
        methodVisitor.visitVarInsn(ASTORE, 2);

        for (int i = 0; i < model.getTrees().size(); i++) {
            TreeNode root = model.getTrees().get(i);
            methodVisitor.visitVarInsn(ALOAD, 2);
            methodVisitor.visitLdcInsn(root.getTreeIndex() % model.getNumClass());
            // duplicate (array, index) so it can be used for both the load and the store
            methodVisitor.visitInsn(DUP2);
            methodVisitor.visitInsn(DALOAD);
            methodVisitor.visitVarInsn(ALOAD, 0);
            methodVisitor.visitVarInsn(ALOAD, 1);
            methodVisitor.visitMethodInsn(INVOKESPECIAL, className, TREE_METHOD_PREFIX + root.getTreeIndex(), "([D)D", false);
            methodVisitor.visitInsn(DADD);
            methodVisitor.visitInsn(DASTORE);
        }

        methodVisitor.visitVarInsn(ALOAD, 2);
        methodVisitor.visitInsn(ARETURN);
        methodVisitor.visitMaxs(1, 1);
        methodVisitor.visitEnd();
    }

    /**
     * Define {@code private double tree_<treeIndex>(double[] features)}; one labelled block per node,
     * emitted in the order of {@code root.getAllNodes()}.
     *
     * @param cv        class visitor
     * @param className internal name of the generated class
     * @param root      tree root
     */
    private void addTreeMethod(ClassVisitor cv, final String className, final TreeNode root) {
        MethodVisitor methodVisitor = simpleVisitMethod(cv, ACC_PRIVATE, TREE_METHOD_PREFIX + root.getTreeIndex(), "([D)D");
        methodVisitor.visitCode();

        Map<Integer, Label> labels = root.getAllNodes().stream().collect(Collectors.toMap(TreeNode::getNodeIndex, o -> new Label()));
        root.getAllNodes().forEach(node -> defineNodeBlock(methodVisitor, node, className, labels));

        methodVisitor.visitMaxs(1, 1);
        methodVisitor.visitEnd();
    }

    /**
     * Dispatch to the leaf / categorical / numerical block emitter.
     */
    private void defineNodeBlock(MethodVisitor methodVisitor, final TreeNode node, final String className,
                                 final Map<Integer, Label> labels) {
        if (node.isLeaf()) {
            defineLeafNodeBlock(methodVisitor, node, labels);
            return;
        }

        if (node.isCategoryNode()) {
            defineCategoryNodeBlock(methodVisitor, node, className, labels);
        } else {
            defineNumericalNodeBlock(methodVisitor, node, labels);
        }
    }

    /**
     * Leaf block: return the leaf value.
     */
    private void defineLeafNodeBlock(MethodVisitor methodVisitor, final TreeNode node, final Map<Integer, Label> labels) {
        int nodeIndex = node.getNodeIndex();
        methodVisitor.visitLabel(labels.get(nodeIndex));
        methodVisitor.visitLdcInsn(Double.valueOf(node.getLeafValue()));
        methodVisitor.visitInsn(DRETURN);
    }

    /**
     * Numerical split block: handle the missing value per missing type, then
     * go left when {@code feature <= threshold}, right otherwise.
     */
    @SuppressWarnings("Duplicates")
    private void defineNumericalNodeBlock(MethodVisitor methodVisitor, final TreeNode node, final Map<Integer, Label> labels) {
        int nodeIndex = node.getNodeIndex();
        methodVisitor.visitLabel(labels.get(nodeIndex));

        // load feature into local 2
        loadFeatureByIndex(methodVisitor, node.getSplitFeatures().get(nodeIndex));
        methodVisitor.visitVarInsn(DSTORE, 2);

        // if missing_type != nan and feature is nan, set feature to zero
        MissingType missingType = MissingType.ofMask((node.getDecisionType() >> 2) & 3);
        if (missingType != MissingType.Nan) {
            // x compared with itself is non-zero only when x is NaN
            methodVisitor.visitVarInsn(DLOAD, 2);
            methodVisitor.visitVarInsn(DLOAD, 2);
            methodVisitor.visitInsn(DCMPL);
            Label label = new Label();
            // if feature is not nan, go to continue
            methodVisitor.visitJumpInsn(IFEQ, label);
            // set feature to zero
            methodVisitor.visitInsn(DCONST_0);
            methodVisitor.visitVarInsn(DSTORE, 2);
            // continue
            methodVisitor.visitLabel(label);
        }

        // if missingType == zero and feature is zero
        if (missingType == MissingType.Zero) {
            Label label = new Label();
            // if feature < -1e-35, not zero, jump to continue
            methodVisitor.visitVarInsn(DLOAD, 2);
            methodVisitor.visitLdcInsn(Double.valueOf(-K_ZERO_THRESHOLD));
            methodVisitor.visitInsn(DCMPL);
            methodVisitor.visitJumpInsn(IFLT, label);

            // if feature > 1e-35, not zero, jump to continue
            methodVisitor.visitVarInsn(DLOAD, 2);
            methodVisitor.visitLdcInsn(Double.valueOf(K_ZERO_THRESHOLD));
            methodVisitor.visitInsn(DCMPG);
            methodVisitor.visitJumpInsn(IFGT, label);

            // if feature is zero, jump to the default child
            if (node.isDefaultLeftDecision()) {
                methodVisitor.visitJumpInsn(GOTO, labels.get(node.getLeftNode().getNodeIndex()));
            } else {
                methodVisitor.visitJumpInsn(GOTO, labels.get(node.getRightNode().getNodeIndex()));
            }

            // continue
            methodVisitor.visitLabel(label);
        }

        // if missingType == nan and feature is nan
        if (missingType == MissingType.Nan) {
            Label label = new Label();
            // if feature is not nan, jump to continue
            methodVisitor.visitVarInsn(DLOAD, 2);
            methodVisitor.visitVarInsn(DLOAD, 2);
            methodVisitor.visitInsn(DCMPL);
            methodVisitor.visitJumpInsn(IFEQ, label);

            // if feature is nan, jump to the default child
            if (node.isDefaultLeftDecision()) {
                methodVisitor.visitJumpInsn(GOTO, labels.get(node.getLeftNode().getNodeIndex()));
            } else {
                methodVisitor.visitJumpInsn(GOTO, labels.get(node.getRightNode().getNodeIndex()));
            }

            // continue
            methodVisitor.visitLabel(label);
        }

        // compare to threshold
        methodVisitor.visitVarInsn(DLOAD, 2);
        methodVisitor.visitLdcInsn(Double.valueOf(node.getThreshold()));
        methodVisitor.visitInsn(DCMPG);
        // feature > threshold, jump to right (strictly greater: equality must fall through to the left child)
        methodVisitor.visitJumpInsn(IFGT, labels.get(node.getRightNode().getNodeIndex()));
        // feature <= threshold, jump to left
        methodVisitor.visitJumpInsn(GOTO, labels.get(node.getLeftNode().getNodeIndex()));
    }

    /**
     * Categorical split block: NaN and negative values go right; otherwise go left when
     * {@code findCatBitset(treeIndex, catBoundaryBegin, catBoundaryEnd - catBoundaryBegin, feature)}
     * returns true, right when it returns false.
     */
    private void defineCategoryNodeBlock(MethodVisitor methodVisitor, final TreeNode node, final String className,
                                         final Map<Integer, Label> labels) {
        int nodeIndex = node.getNodeIndex();
        methodVisitor.visitLabel(labels.get(nodeIndex));

        // load feature into local 2
        loadFeatureByIndex(methodVisitor, node.getSplitFeatures().get(nodeIndex));
        methodVisitor.visitVarInsn(DSTORE, 2);

        // if feature isNaN, jump to right child node
        methodVisitor.visitVarInsn(DLOAD, 2);
        methodVisitor.visitVarInsn(DLOAD, 2);
        methodVisitor.visitInsn(DCMPL);
        methodVisitor.visitJumpInsn(IFNE, labels.get(node.getRightNode().getNodeIndex()));

        // if feature < 0, jump to right child node
        methodVisitor.visitVarInsn(DLOAD, 2);
        methodVisitor.visitInsn(DCONST_0);
        methodVisitor.visitInsn(DCMPG);
        methodVisitor.visitJumpInsn(IFLT, labels.get(node.getRightNode().getNodeIndex()));

        // if found in the category bitset, jump to left child node
        methodVisitor.visitVarInsn(ALOAD, 0);
        methodVisitor.visitLdcInsn(node.getTreeIndex());
        methodVisitor.visitLdcInsn(node.getCatBoundaryBegin());
        methodVisitor.visitLdcInsn(node.getCatBoundaryEnd() - node.getCatBoundaryBegin());
        methodVisitor.visitVarInsn(DLOAD, 2);
        methodVisitor.visitMethodInsn(INVOKEVIRTUAL, className, FIND_CAT_BIT_SET_METHOD, "(IIID)Z", false);
        methodVisitor.visitJumpInsn(IFNE, labels.get(node.getLeftNode().getNodeIndex()));

        // others, jump to right child node
        methodVisitor.visitJumpInsn(GOTO, labels.get(node.getRightNode().getNodeIndex()));
    }

    /**
     * Push {@code features[index]} onto the operand stack.
     */
    private void loadFeatureByIndex(MethodVisitor methodVisitor, int index) {
        // load features[index] to stack
        methodVisitor.visitVarInsn(ALOAD, FEATURE_PARAMETER_INDEX);
        methodVisitor.visitLdcInsn(index);
        methodVisitor.visitInsn(DALOAD);
    }

    /**
     * Visit a method without generic signature or declared exceptions.
     */
    private MethodVisitor simpleVisitMethod(ClassVisitor cv, int access, final String name,
                                            final String descriptor) {
        return cv.visitMethod(access, name, descriptor, null, null);
    }

    /**
     * Convert a binary class name ({@code a.b.C}) to its internal form ({@code a/b/C}).
     */
    private String toInternalName(final String name) {
        return StringUtils.join(name.split("\\."), "/");
    }

    /**
     * Super class of the generated class: MetaDataHolder when the model contains
     * categorical nodes, Object otherwise.
     */
    private String getSuperName(final TreeModel model) {
        if (model.isContainsCatNode()) {
            return META_DATA_HOLDER_INTERNAL_NAME;
        } else {
            return OBJECT_INTERNAL_NAME;
        }
    }
}

-------------------------------------------------------------------------------- /treetops-core/src/main/java/io/github/horoc/treetops/core/parser/TreeModelParser.java: --------------------------------------------------------------------------------
package io.github.horoc.treetops.core.parser;

import io.github.horoc.treetops.core.factory.ObjectiveDecoratorFactory;
import io.github.horoc.treetops.core.model.RawTreeBlock;
import io.github.horoc.treetops.core.model.TreeModel;
import io.github.horoc.treetops.core.model.TreeNode;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.Consumer;
import java.util.function.Function;
import javax.annotation.Nonnull;
import javax.annotation.ParametersAreNonnullByDefault;
import javax.annotation.concurrent.Immutable;
import javax.annotation.concurrent.ThreadSafe;
import org.apache.commons.lang3.StringUtils;
import
org.apache.commons.lang3.Validate; 20 | 21 | /** 22 | * Parse String model info into TreeModel instance. 23 | * 24 | * @author chenzhou@apache.org 25 | * created on 2023/2/14 26 | */ 27 | @Immutable 28 | @ThreadSafe 29 | @ParametersAreNonnullByDefault 30 | public class TreeModelParser { 31 | 32 | /** 33 | * Refer to official library: microsoft/LightGBM/include/LightGBM/tree.h#kCategoricalMask. 34 | */ 35 | private static final int CATEGORICAL_MASK = 1; 36 | 37 | /** 38 | * Refer to official library: microsoft/LightGBM/include/LightGBM/tree.h#kDefaultLeftMask. 39 | */ 40 | private static final int DEFAULT_LEFT_MASK = 2; 41 | 42 | /** 43 | * Model file config separator. 44 | */ 45 | private static final String CONFIG_SEPARATOR = " "; 46 | 47 | /** 48 | * Parse workflow:
49 | * 1. iterator each line to find meta block header or tree block header
50 | * 2. parse different block into tree model instance
51 | * 52 | * @param rawLines raw data 53 | * @return tree model instance 54 | */ 55 | public static TreeModel parseTreeModel(@Nonnull final List rawLines) { 56 | Validate.notNull(rawLines); 57 | 58 | TreeModel treeModel = new TreeModel(); 59 | int curLineIndex = 0; 60 | while (curLineIndex < rawLines.size()) { 61 | String line = rawLines.get(curLineIndex); 62 | if (StringUtils.isBlank(line)) { 63 | curLineIndex++; 64 | continue; 65 | } 66 | // we only need tree info and meta block info 67 | if (isTreeBlockHeader(line)) { 68 | TreeNode tree = new TreeNode(); 69 | curLineIndex = initTreeBlock(tree, rawLines, curLineIndex); 70 | treeModel.getTrees().add(tree); 71 | // mark model contains category node, need to initialize threshold data later 72 | if (Objects.nonNull(tree.getCatThreshold()) && !tree.getCatThreshold().isEmpty()) { 73 | treeModel.setContainsCatNode(true); 74 | } 75 | } else if (isMetaInfoBlockHeader(line)) { 76 | // skip first line of meta info block 77 | curLineIndex = initMetaBlock(treeModel, rawLines, curLineIndex + 1); 78 | } 79 | curLineIndex++; 80 | } 81 | checkTreeBlock(treeModel); 82 | return treeModel; 83 | } 84 | 85 | /** 86 | * load meta block info into model. 
87 | * 88 | * @param treeModel tree model 89 | * @param rawStingLines raw string lines 90 | * @param offset read offset 91 | * @return next offset 92 | */ 93 | private static int initMetaBlock(TreeModel treeModel, final List rawStingLines, int offset) { 94 | Map rawDataMap = new HashMap<>(); 95 | final int nextOffset = parseRawKeyValueMap(rawStingLines, rawDataMap, offset); 96 | convertAndSetField("objective", rawDataMap, TreeModelParser::parseObjectiveType, treeModel::setObjectiveType); 97 | convertAndSetField("objective", rawDataMap, TreeModelParser::parseObjectiveConfig, treeModel::setObjectiveConfig); 98 | convertAndSetField("num_class", rawDataMap, Integer::valueOf, treeModel::setNumClass); 99 | convertAndSetField("max_feature_idx", rawDataMap, Integer::valueOf, treeModel::setMaxFeatureIndex); 100 | convertAndSetField("num_tree_per_iteration", rawDataMap, Integer::valueOf, treeModel::setNumberTreePerIteration); 101 | return nextOffset; 102 | } 103 | 104 | /** 105 | * load tree block info into model. 
106 | * 107 | * @param root tree node root 108 | * @param rawStingLines raw string lines 109 | * @param offset read offset 110 | * @return next offset 111 | */ 112 | private static int initTreeBlock(TreeNode root, final List rawStingLines, int offset) { 113 | Map rawDataMap = new HashMap<>(); 114 | final int nextOffset = parseRawKeyValueMap(rawStingLines, rawDataMap, offset); 115 | 116 | RawTreeBlock block = new RawTreeBlock(); 117 | convertAndSetField("Tree", rawDataMap, Integer::valueOf, block::setTree); 118 | convertAndSetField("num_leaves", rawDataMap, Integer::valueOf, block::setNumLeaves); 119 | convertAndSetField("num_cat", rawDataMap, Integer::valueOf, block::setNumCat); 120 | convertAndSetField("split_feature", rawDataMap, val -> fromStringToList(val, Integer::valueOf), block::setSplitFeature); 121 | convertAndSetField("decision_type", rawDataMap, val -> fromStringToList(val, Integer::valueOf), block::setDecisionType); 122 | convertAndSetField("left_child", rawDataMap, val -> fromStringToList(val, Integer::valueOf), block::setLeftChild); 123 | convertAndSetField("right_child", rawDataMap, val -> fromStringToList(val, Integer::valueOf), block::setRightChild); 124 | convertAndSetField("leaf_value", rawDataMap, val -> fromStringToList(val, Double::valueOf), block::setLeafValue); 125 | convertAndSetField("internal_value", rawDataMap, val -> fromStringToList(val, Double::valueOf), block::setInternalValue); 126 | convertAndSetField("threshold", rawDataMap, val -> fromStringToList(val, Double::valueOf), block::setThreshold); 127 | convertAndSetField("cat_boundaries", rawDataMap, val -> fromStringToList(val, Integer::valueOf), block::setCatBoundaries); 128 | convertAndSetField("cat_threshold", rawDataMap, val -> fromStringToList(val, Long::valueOf), block::setCatThreshold); 129 | 130 | // init all nodes 131 | int treeSize = block.getLeftChild().size(); 132 | List treeNodes = new ArrayList<>(treeSize); 133 | treeNodes.add(0, root); 134 | for (int i = 1; i < 
treeSize; i++) { 135 | treeNodes.add(i, new TreeNode(i)); 136 | } 137 | treeNodes.forEach(o -> initTreeSingleNode(o, block)); 138 | 139 | // link all nodes 140 | for (int i = 0; i < treeSize; i++) { 141 | linkTreeNode(treeNodes.get(i), treeNodes, block); 142 | } 143 | 144 | // sort node array by index 145 | treeNodes.sort((a, b) -> { 146 | int i = a.getNodeIndex(); 147 | int j = b.getNodeIndex(); 148 | if (i >= 0 && j >= 0) { 149 | return i - j; 150 | } else { 151 | return j - i; 152 | } 153 | }); 154 | 155 | return nextOffset; 156 | } 157 | 158 | /** 159 | * load tree node block info into model. 160 | * 161 | * @param node tree node 162 | * @param block raw block data 163 | */ 164 | private static void initTreeSingleNode(TreeNode node, final RawTreeBlock block) { 165 | int nodeIndex = node.getNodeIndex(); 166 | node.setLeaf(false); 167 | node.setTreeIndex(block.getTree()); 168 | node.setDecisionType(block.getDecisionType().get(nodeIndex)); 169 | node.setCategoryNode(isCategoryNode(block.getDecisionType().get(nodeIndex))); 170 | node.setDefaultLeftDecision(isDefaultLeftDecisionNode(block.getDecisionType().get(nodeIndex))); 171 | node.setSplitFeatures(block.getSplitFeature()); 172 | node.setCatThreshold(block.getCatThreshold()); 173 | if (node.isCategoryNode()) { 174 | node.setCatBoundaryBegin(block.getThreshold().get(nodeIndex).intValue()); 175 | node.setCatBoundaryEnd(node.getCatBoundaryBegin() + 1); 176 | } else { 177 | node.setThreshold(block.getThreshold().get(nodeIndex)); 178 | } 179 | } 180 | 181 | private static void linkTreeNode(TreeNode node, List treeNodes, final RawTreeBlock block) { 182 | int leftIndex = block.getLeftChild().get(node.getNodeIndex()); 183 | if (leftIndex < 0) { 184 | TreeNode leftLeaf = new TreeNode(node.getTreeIndex(), leftIndex); 185 | leftLeaf.setLeaf(true); 186 | leftLeaf.setLeafValue(block.getLeafValue().get(-leftIndex - 1)); 187 | node.setLeftNode(leftLeaf); 188 | treeNodes.add(leftLeaf); 189 | } else { 190 | 
node.setLeftNode(treeNodes.get(leftIndex)); 191 | } 192 | 193 | int rightIndex = block.getRightChild().get(node.getNodeIndex()); 194 | if (rightIndex < 0) { 195 | TreeNode rightLeaf = new TreeNode(node.getTreeIndex(), rightIndex); 196 | rightLeaf.setLeaf(true); 197 | rightLeaf.setLeafValue(block.getLeafValue().get(-rightIndex - 1)); 198 | node.setRightNode(rightLeaf); 199 | treeNodes.add(rightLeaf); 200 | } else { 201 | node.setRightNode(treeNodes.get(rightIndex)); 202 | } 203 | 204 | node.setAllNodes(treeNodes); 205 | } 206 | 207 | private static int parseRawKeyValueMap(List metaInfos, Map rawDataMap, int offset) { 208 | int nextOffset = offset; 209 | for (; nextOffset < metaInfos.size(); nextOffset++) { 210 | String line = metaInfos.get(nextOffset); 211 | if (StringUtils.isBlank(line)) { 212 | return nextOffset; 213 | } 214 | String[] sp = line.split("="); 215 | if (sp.length != 2) { 216 | throw new RuntimeException(String.format("try to parse tree model failed, invalid key-value content %s", line)); 217 | } 218 | rawDataMap.put(sp[0], sp[1]); 219 | } 220 | return nextOffset; 221 | } 222 | 223 | private static boolean isCategoryNode(int decisionType) { 224 | return (decisionType & CATEGORICAL_MASK) > 0; 225 | } 226 | 227 | private static boolean isDefaultLeftDecisionNode(int decisionType) { 228 | return (decisionType & DEFAULT_LEFT_MASK) > 0; 229 | } 230 | 231 | private static boolean isTreeBlockHeader(final String line) { 232 | return StringUtils.isNoneBlank(line) && line.startsWith("Tree="); 233 | } 234 | 235 | private static boolean isMetaInfoBlockHeader(final String line) { 236 | return StringUtils.isNoneBlank(line) && "tree".equals(line); 237 | } 238 | 239 | private static String parseObjectiveType(final String objective) { 240 | if (StringUtils.isNotBlank(objective)) { 241 | return objective.split(CONFIG_SEPARATOR)[0]; 242 | } 243 | return StringUtils.EMPTY; 244 | } 245 | 246 | private static String parseObjectiveConfig(final String objective) { 247 | if 
(StringUtils.isNotBlank(objective)) { 248 | String[] sp = objective.split(CONFIG_SEPARATOR); 249 | if (sp.length > 1) { 250 | return sp[1]; 251 | } 252 | } 253 | return StringUtils.EMPTY; 254 | } 255 | 256 | private static List fromStringToList(final String str, Function converter) { 257 | String[] splits = str.split(CONFIG_SEPARATOR); 258 | List ret = new ArrayList<>(splits.length); 259 | for (String val : splits) { 260 | ret.add(converter.apply(val)); 261 | } 262 | return ret; 263 | } 264 | 265 | private static void convertAndSetField(final String key, Map rawDataMap, 266 | Function converter, Consumer setter) { 267 | String rawValue = rawDataMap.get(key); 268 | if (StringUtils.isNoneBlank(rawValue)) { 269 | try { 270 | setter.accept(converter.apply(rawValue)); 271 | } catch (Throwable e) { 272 | throw new RuntimeException(String.format("parsing tree model file error, " 273 | + "try to convert field key: , value: ", key, rawValue), e); 274 | } 275 | } 276 | } 277 | 278 | private static void checkTreeBlock(final TreeModel treeModel) { 279 | Validate.notNull(treeModel, 280 | "parsing tree model failed, can not find any valid block"); 281 | 282 | Validate.notEmpty(treeModel.getTrees(), 283 | "parsing tree model failed, can not find any valid block"); 284 | 285 | Validate.isTrue(treeModel.getNumClass() > 0, 286 | "parsing tree model failed, invalid meta info, num class must be positive"); 287 | 288 | Validate.isTrue(treeModel.getMaxFeatureIndex() >= 0, 289 | "parsing tree model failed, invalid meta info, max feature index can not be negative"); 290 | 291 | Validate.isTrue(treeModel.getNumberTreePerIteration() > 0, 292 | "parsing tree model failed, invalid meta info, number tree per iteration must be positive"); 293 | 294 | Validate.isTrue(ObjectiveDecoratorFactory.isValidObjectiveType(treeModel.getObjectiveType()), 295 | "parsing tree model failed, invalid meta info, objective type %s is not supported", treeModel.getObjectiveType()); 296 | } 297 | } 298 | 
--------------------------------------------------------------------------------