> PREDICTORS = new HashMap<>();
35 |
36 | private static final Pattern MODEL_NAME_PATTERN = Pattern.compile("^[a-zA-Z0-9_]+$");
37 |
38 | /**
39 | * Tree model loader.
40 | */
41 | private static AbstractLoader treeModelLoader = FileTreeModelLoader.getInstance();
42 |
43 | /**
44 | * Due to the limitation of asm framework, generated class can not be too large,
45 | * thus when the tree nums of model larger than threshold,
46 | * will use {@link SimplePredictor} as predictor implementation.
47 | *
48 | * Default threshold is 300, which can meet the requirement of most business scenarios.
49 | */
50 | private static int asmGenerationTreeNumsThreshold = 300;
51 |
52 | /**
53 | * Entry point for user custom model loader, default model loader is {@link FileTreeModelLoader}
54 | * User custom model loader should implement {@link AbstractLoader}.
55 | *
56 | * @param loader custom loader
57 | */
58 | public static synchronized void setTreeModelLoader(final AbstractLoader loader) {
59 | treeModelLoader = loader;
60 | }
61 |
62 | /**
63 | * Entry point for user to custom tree nums threshold.
64 | * To be careful, the value should not be too large, since asm has a limitation of class size.
65 | * Can see the detail {@link org.objectweb.asm.ClassWriter#toByteArray()}.
66 | *
67 | * @param threshold custom value
68 | */
69 | public static synchronized void setGenerationTreeNumsThreshold(final int threshold) {
70 | asmGenerationTreeNumsThreshold = threshold;
71 | }
72 |
73 | /**
74 | * Refer to {@link TreePredictorFactory#newInstance(java.lang.String, java.lang.String, java.lang.String, boolean)}.
75 | *
76 | * @param modelName model name, should be distinct from exist Predictor, and must only contain character: [a-zA-z0-9_]
77 | * @param resource resource path, if using default model loader, it means file path
78 | * @return Predictor instance
79 | */
80 | public static synchronized Predictor newInstance(final String modelName, final String resource) {
81 | return newInstance(modelName, resource, null);
82 | }
83 |
84 | /**
85 | * Refer to {@link TreePredictorFactory#newInstance(java.lang.String, java.lang.String, java.lang.String, boolean)}.
86 | *
87 | * @param modelName model name, should be distinct from exist Predictor, and must only contain character: [a-zA-z0-9_]
88 | * @param resource resource path, if using default model loader, it means file path
89 | * @param enableGeneration if enableGeneration is false, will always get {@link SimplePredictor} instance
90 | * @return Predictor instance
91 | */
92 | public static synchronized Predictor newInstance(final String modelName, final String resource, boolean enableGeneration) {
93 | return newInstance(modelName, resource, null, enableGeneration);
94 | }
95 |
96 | /**
97 | * Refer to {@link TreePredictorFactory#newInstance(java.lang.String, java.lang.String, java.lang.String, boolean)}.
98 | *
99 | * @param modelName model name, should be distinct from exist Predictor, and must only contain character: [a-zA-z0-9_]
100 | * @param resource resource path, if using default model loader, it means file path
101 | * @param saveClassFileDir generated Predictor class save path, can be null if it's not necessary
102 | * @return Predictor instance
103 | */
104 | public static synchronized Predictor newInstance(final String modelName, final String resource, final String saveClassFileDir) {
105 | return newInstance(modelName, resource, saveClassFileDir, true);
106 | }
107 |
108 | /**
109 | * Create predictor from resource:
110 | * 1. load model data from resource path.
111 | * 2. parse to {@link TreeModel} instance.
112 | * 3. generate {@link Predictor} based on model detail.
113 | *
114 | * @param modelName model name, should be distinct from exist Predictor, and must only contain character: [a-zA-z0-9_]
115 | * @param resource resource path, if using default model loader, it means file path
116 | * @param saveClassFileDir generated Predictor class save path, can be null if it's not necessary
117 | * @param enableGeneration if enableGeneration is false, will always get {@link SimplePredictor} instance
118 | * @return Predictor instance
119 | */
120 | public static synchronized Predictor newInstance(final String modelName, final String resource, final String saveClassFileDir,
121 | boolean enableGeneration) {
122 | checkModelName(modelName);
123 |
124 | String className = toClassName(modelName);
125 | // predictor which is no longer used will be removed after gc
126 | if (PREDICTORS.containsKey(className)) {
127 | Predictor predictor = PREDICTORS.get(className).get();
128 | if (predictor != null) {
129 | return predictor;
130 | } else {
131 | PREDICTORS.remove(className);
132 | }
133 | }
134 | try {
135 | TreeModel treeModel = treeModelLoader.loadModel(resource);
136 | Predictor predictor;
137 | // if model is too large, downgrade to simple predictor implementation
138 | if (!enableGeneration || treeModel.getTrees().size() > asmGenerationTreeNumsThreshold) {
139 | predictor = new SimplePredictor(treeModel);
140 | } else {
141 | // new class loader to do class generation
142 | PredictorClassGenerator generator = PredictorClassGenerator.getInstance();
143 | byte[] bytes = generator.generateCode(className, treeModel);
144 | if (StringUtils.isNotBlank(saveClassFileDir)) {
145 | saveClass(bytes, className, saveClassFileDir);
146 | }
147 | Object targetObj = generator.defineClassFromCode(className, bytes).newInstance();
148 | predictor = (Predictor) targetObj;
149 | }
150 |
151 | // init meta data if need
152 | if (predictor instanceof MetaDataHolder) {
153 | ((MetaDataHolder) predictor).initialize(treeModel);
154 | }
155 |
156 | // objective decorate
157 | Predictor objectivePredictor = ObjectiveDecoratorFactory.decoratePredictorByObjectiveType(predictor, treeModel);
158 |
159 | // wrapper
160 | Predictor predictorWrapper = new PredictorWrapper(objectivePredictor, treeModel);
161 |
162 | PREDICTORS.put(className, new WeakReference<>(predictorWrapper));
163 | return predictorWrapper;
164 | } catch (Throwable e) {
165 | throw new RuntimeException(String.format("fail to generate predict instance, modelName: %s", modelName), e);
166 | }
167 | }
168 |
169 | /**
170 | * Clean predictor reference.
171 | *
172 | * @param predictor should be PredictorWrapper instance which created by factory
173 | */
174 | private static synchronized void releasePredictor(Predictor predictor) {
175 | if (predictor instanceof PredictorWrapper) {
176 | ((PredictorWrapper) predictor).release();
177 | }
178 | }
179 |
180 | /**
181 | * Save raw bytecode into file.
182 | *
183 | * @param bytes class bytecode
184 | * @param className class name
185 | * @param dir save path
186 | * @throws Exception
187 | */
188 | private static void saveClass(final byte[] bytes, final String className, final String dir) throws Exception {
189 | String fileName = customClassFilePath(className, dir);
190 |
191 | File file = new File(fileName);
192 | file.getParentFile().mkdirs();
193 | file.createNewFile();
194 |
195 | try (FileOutputStream fos = new FileOutputStream(fileName)) {
196 | fos.write(bytes);
197 | }
198 | }
199 |
200 | private static String customClassFilePath(final String className, final String dir) {
201 | return StringUtils.join(dir, File.separator, StringUtils.join(className.split("\\."), File.separator), ".class");
202 | }
203 |
204 | /**
205 | * Generated class will start with '_'.
206 | *
207 | * @param modelName model name
208 | * @return class name
209 | */
210 | private static String toClassName(final String modelName) {
211 | return Predictor.PREDICTOR_CLASS_PREFIX + "._" + modelName;
212 | }
213 |
214 | private static void checkModelName(final String modelName) {
215 | if (!MODEL_NAME_PATTERN.matcher(modelName).matches()) {
216 | throw new IllegalArgumentException(String.format("illegal model name: %s, valid character: [a-zA-z0-9_]", modelName));
217 | }
218 | }
219 | }
220 |
--------------------------------------------------------------------------------
/checkstyle/checkstyle.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 |
55 |
56 |
57 |
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 |
69 |
70 |
71 |
72 |
73 |
74 |
75 |
76 |
77 |
78 |
79 |
80 |
81 |
82 |
83 |
84 |
85 |
86 |
87 |
88 |
89 |
90 |
91 |
92 |
93 |
94 |
95 |
96 |
97 |
98 |
99 |
100 |
101 |
102 |
103 |
104 |
105 |
106 |
107 |
108 |
109 |
110 |
111 |
112 |
113 |
114 |
115 |
116 |
117 |
118 |
119 |
120 |
121 |
122 |
123 |
124 |
125 |
126 |
127 |
128 |
129 |
130 |
131 |
132 |
133 |
134 |
135 |
136 |
137 |
138 |
139 |
140 |
141 |
142 |
143 |
144 |
145 |
146 |
147 |
148 |
149 |
150 |
151 |
152 |
153 |
154 |
155 |
156 |
157 |
158 |
159 |
160 |
161 |
162 |
163 |
164 |
165 |
166 |
167 |
168 |
170 |
172 |
173 |
174 |
175 |
176 |
177 |
178 |
179 |
180 |
181 |
182 |
183 |
184 |
185 |
186 |
187 |
188 |
189 |
190 |
191 |
192 |
193 |
194 |
195 |
196 |
197 |
198 |
199 |
200 |
201 |
202 |
203 |
204 |
205 |
206 |
207 |
208 |
209 |
210 |
211 |
212 |
213 |
214 |
215 |
216 |
217 |
218 |
219 |
220 |
221 |
222 |
223 |
224 |
225 |
226 |
227 |
228 |
229 |
230 |
231 |
232 |
233 |
234 |
235 |
236 |
238 |
239 |
240 |
241 |
242 |
243 |
244 |
245 |
246 |
--------------------------------------------------------------------------------
/treetops-core/src/main/java/io/github/horoc/treetops/core/generator/PredictorClassGenerator.java:
--------------------------------------------------------------------------------
1 | package io.github.horoc.treetops.core.generator;
2 |
3 | import io.github.horoc.treetops.core.model.MissingType;
4 | import io.github.horoc.treetops.core.model.TreeModel;
5 | import io.github.horoc.treetops.core.model.TreeNode;
6 | import java.util.Map;
7 | import java.util.stream.Collectors;
8 | import org.apache.commons.lang3.StringUtils;
9 | import org.objectweb.asm.ClassVisitor;
10 | import org.objectweb.asm.ClassWriter;
11 | import org.objectweb.asm.Label;
12 | import org.objectweb.asm.MethodVisitor;
13 | import org.objectweb.asm.Opcodes;
14 | import org.objectweb.asm.util.CheckClassAdapter;
15 |
16 | /**
17 | * Predictor class generator based on asm framework.
18 | *
19 | *
20 | * @author chenzhou@apache.org
21 | * created on 2023/2/14
22 | */
23 | public final class PredictorClassGenerator extends ClassLoader implements Generator, Opcodes {
24 |
25 | private static final int FEATURE_PARAMETER_INDEX = 1;
26 |
27 | private static final double K_ZERO_THRESHOLD = 1e-35f;
28 |
29 | private static final String INIT = "";
30 |
31 | private static final String TREE_METHOD_PREFIX = "tree_";
32 |
33 | private static final String OBJECT_INTERNAL_NAME = "java/lang/Object";
34 |
35 | private static final String PREDICTOR_INTERNAL_NAME = "io/github/horoc/treetops/core/predictor/Predictor";
36 |
37 | private static final String META_DATA_HOLDER_INTERNAL_NAME = "io/github/horoc/treetops/core/predictor/MetaDataHolder";
38 |
39 | private static final String PREDICT_METHOD = "predictRaw";
40 |
41 | private static final String FIND_CAT_BIT_SET_METHOD = "findCatBitset";
42 |
private PredictorClassGenerator() {
}

/**
 * Create a fresh generator, which is a class loader of its own.
 * A singleton can not be kept here, since we want jvm to unload
 * generated classes which are no longer used.
 *
 * @return new generator instance
 */
public static PredictorClassGenerator getInstance() {
    return new PredictorClassGenerator();
}
55 |
56 | @Override
57 | public Class> defineClassFromCode(final String className, final byte[] code) {
58 | return this.defineClass(className, code, 0, code.length);
59 | }
60 |
61 | @Override
62 | public byte[] generateCode(final String className, final TreeModel model) {
63 | ClassWriter cw = new ClassWriter(ClassWriter.COMPUTE_FRAMES);
64 | ClassVisitor cv = new CheckClassAdapter(cw);
65 | String internalClassName = toInternalName(className);
66 |
67 | // define class
68 | cv.visit(V1_8, ACC_PUBLIC | ACC_SUPER, toInternalName(internalClassName), null, getSuperName(model),
69 | new String[] {PREDICTOR_INTERNAL_NAME});
70 |
71 | // define init method
72 | addInitMethod(cv, model);
73 |
74 | // tree decision method
75 | // description : private double tree_[%tree_index](double[] features);
76 | model.getTrees().forEach(t -> addTreeMethod(cv, internalClassName, t));
77 |
78 | // prediction method
79 | addPredictionMethod(cv, internalClassName, model);
80 |
81 | cv.visitEnd();
82 | return cw.toByteArray();
83 | }
84 |
85 | /**
86 | * Define init method.
87 | *
88 | * @param cv class visitor
89 | * @param model model config
90 | */
91 | private void addInitMethod(ClassVisitor cv, final TreeModel model) {
92 | MethodVisitor methodVisitor = simpleVisitMethod(cv, ACC_PUBLIC, INIT, "()V");
93 | methodVisitor.visitCode();
94 | methodVisitor.visitVarInsn(ALOAD, 0);
95 | if (model.isContainsCatNode()) {
96 | methodVisitor.visitMethodInsn(INVOKESPECIAL, META_DATA_HOLDER_INTERNAL_NAME, INIT, "()V", false);
97 | } else {
98 | methodVisitor.visitMethodInsn(INVOKESPECIAL, OBJECT_INTERNAL_NAME, INIT, "()V", false);
99 | }
100 | methodVisitor.visitInsn(RETURN);
101 | methodVisitor.visitMaxs(1, 1);
102 | methodVisitor.visitEnd();
103 | }
104 |
105 | private void addPredictionMethod(ClassVisitor cv, final String className, final TreeModel model) {
106 | MethodVisitor methodVisitor = simpleVisitMethod(cv, ACC_PUBLIC, PREDICT_METHOD, "([D)[D");
107 | methodVisitor.visitCode();
108 |
109 | methodVisitor.visitLdcInsn(model.getNumClass());
110 | methodVisitor.visitIntInsn(NEWARRAY, T_DOUBLE);
111 | methodVisitor.visitVarInsn(ASTORE, 2);
112 |
113 | for (int i = 0; i < model.getTrees().size(); i++) {
114 | TreeNode root = model.getTrees().get(i);
115 | methodVisitor.visitVarInsn(ALOAD, 2);
116 | methodVisitor.visitLdcInsn(root.getTreeIndex() % model.getNumClass());
117 | methodVisitor.visitInsn(DUP2);
118 | methodVisitor.visitInsn(DALOAD);
119 | methodVisitor.visitVarInsn(ALOAD, 0);
120 | methodVisitor.visitVarInsn(ALOAD, 1);
121 | methodVisitor.visitMethodInsn(INVOKESPECIAL, className, TREE_METHOD_PREFIX + root.getTreeIndex(), "([D)D", false);
122 | methodVisitor.visitInsn(DADD);
123 | methodVisitor.visitInsn(DASTORE);
124 | }
125 |
126 | methodVisitor.visitVarInsn(ALOAD, 2);
127 | methodVisitor.visitInsn(ARETURN);
128 | methodVisitor.visitMaxs(1, 1);
129 | methodVisitor.visitEnd();
130 | }
131 |
132 | private void addTreeMethod(ClassVisitor cv, final String className, final TreeNode root) {
133 | MethodVisitor methodVisitor = simpleVisitMethod(cv, ACC_PRIVATE, TREE_METHOD_PREFIX + root.getTreeIndex(), "([D)D");
134 | methodVisitor.visitCode();
135 |
136 | Map labels = root.getAllNodes().stream().collect(Collectors.toMap(TreeNode::getNodeIndex, o -> new Label()));
137 | root.getAllNodes().forEach(node -> defineNodeBlock(methodVisitor, node, className, labels));
138 |
139 | methodVisitor.visitMaxs(1, 1);
140 | methodVisitor.visitEnd();
141 | }
142 |
143 | private void defineNodeBlock(MethodVisitor methodVisitor, final TreeNode node, final String className,
144 | final Map labels) {
145 | if (node.isLeaf()) {
146 | defineLeafNodeBlock(methodVisitor, node, labels);
147 | return;
148 | }
149 |
150 | if (node.isCategoryNode()) {
151 | defineCategoryNodeBlock(methodVisitor, node, className, labels);
152 | } else {
153 | defineNumericalNodeBlock(methodVisitor, node, labels);
154 | }
155 | }
156 |
157 | private void defineLeafNodeBlock(MethodVisitor methodVisitor, final TreeNode node, final Map labels) {
158 | int nodeIndex = node.getNodeIndex();
159 | methodVisitor.visitLabel(labels.get(nodeIndex));
160 | methodVisitor.visitLdcInsn(new Double(node.getLeafValue()));
161 | methodVisitor.visitInsn(DRETURN);
162 | }
163 |
164 | @SuppressWarnings("Duplicates")
165 | private void defineNumericalNodeBlock(MethodVisitor methodVisitor, final TreeNode node, final Map labels) {
166 | int nodeIndex = node.getNodeIndex();
167 | methodVisitor.visitLabel(labels.get(nodeIndex));
168 |
169 | // load feature
170 | loadFeatureByIndex(methodVisitor, node.getSplitFeatures().get(nodeIndex));
171 | methodVisitor.visitVarInsn(DSTORE, 2);
172 |
173 | // if missing_type != nan and feature is nan, set feature to zero
174 | MissingType missingType = MissingType.ofMask((node.getDecisionType() >> 2) & 3);
175 | if (missingType != MissingType.Nan) {
176 | methodVisitor.visitVarInsn(DLOAD, 2);
177 | methodVisitor.visitVarInsn(DLOAD, 2);
178 | methodVisitor.visitInsn(DCMPL);
179 | Label label = new Label();
180 | // if feature is not nan, go to continue
181 | methodVisitor.visitJumpInsn(IFEQ, label);
182 | // set feature to zero
183 | methodVisitor.visitInsn(DCONST_0);
184 | methodVisitor.visitVarInsn(DSTORE, 2);
185 | // continue
186 | methodVisitor.visitLabel(label);
187 | }
188 |
189 | // if missingType == zero and feature is zero
190 | if (missingType == MissingType.Zero) {
191 | Label label = new Label();
192 | // if feature < -1e-35, not zero, jump to continue
193 | methodVisitor.visitVarInsn(DLOAD, 2);
194 | methodVisitor.visitLdcInsn(new Double(-K_ZERO_THRESHOLD));
195 | methodVisitor.visitInsn(DCMPL);
196 | methodVisitor.visitJumpInsn(IFLT, label);
197 |
198 | // if feature > 1e-35, not zero, jump to continue
199 | methodVisitor.visitVarInsn(DLOAD, 2);
200 | methodVisitor.visitLdcInsn(new Double(K_ZERO_THRESHOLD));
201 | methodVisitor.visitInsn(DCMPG);
202 | methodVisitor.visitJumpInsn(IFGT, label);
203 |
204 | // if feature is zero, jump to next node
205 | if (node.isDefaultLeftDecision()) {
206 | methodVisitor.visitJumpInsn(GOTO, labels.get(node.getLeftNode().getNodeIndex()));
207 | } else {
208 | methodVisitor.visitJumpInsn(GOTO, labels.get(node.getRightNode().getNodeIndex()));
209 | }
210 |
211 | // continue
212 | methodVisitor.visitLabel(label);
213 | }
214 |
215 | // if missingType == nan and feature is nan
216 | if (missingType == MissingType.Nan) {
217 | Label label = new Label();
218 | // if feature is not nan, jump to continue
219 | methodVisitor.visitVarInsn(DLOAD, 2);
220 | methodVisitor.visitVarInsn(DLOAD, 2);
221 | methodVisitor.visitInsn(DCMPL);
222 | methodVisitor.visitJumpInsn(IFEQ, label);
223 |
224 | // if feature is nan, jump to next node
225 | if (node.isDefaultLeftDecision()) {
226 | methodVisitor.visitJumpInsn(GOTO, labels.get(node.getLeftNode().getNodeIndex()));
227 | } else {
228 | methodVisitor.visitJumpInsn(GOTO, labels.get(node.getRightNode().getNodeIndex()));
229 | }
230 |
231 | // continue
232 | methodVisitor.visitLabel(label);
233 | }
234 |
235 | // compare to threshold
236 | methodVisitor.visitVarInsn(DLOAD, 2);
237 | methodVisitor.visitLdcInsn(new Double(node.getThreshold()));
238 | methodVisitor.visitInsn(DCMPG);
239 | // feature > threshold, jump to right
240 | methodVisitor.visitJumpInsn(IFGE, labels.get(node.getRightNode().getNodeIndex()));
241 | // feature <= threshold, jump to left
242 | methodVisitor.visitJumpInsn(GOTO, labels.get(node.getLeftNode().getNodeIndex()));
243 | }
244 |
245 | private void defineCategoryNodeBlock(MethodVisitor methodVisitor, final TreeNode node, final String className,
246 | final Map labels) {
247 | int nodeIndex = node.getNodeIndex();
248 | methodVisitor.visitLabel(labels.get(nodeIndex));
249 |
250 | // load feature
251 | loadFeatureByIndex(methodVisitor, node.getSplitFeatures().get(nodeIndex));
252 | methodVisitor.visitVarInsn(DSTORE, 2);
253 |
254 | // if feature isNaN, jump to right child node
255 | methodVisitor.visitVarInsn(DLOAD, 2);
256 | methodVisitor.visitVarInsn(DLOAD, 2);
257 | methodVisitor.visitInsn(DCMPL);
258 | methodVisitor.visitJumpInsn(IFNE, labels.get(node.getRightNode().getNodeIndex()));
259 |
260 | // if feature < 0, jump to right child node
261 | methodVisitor.visitVarInsn(DLOAD, 2);
262 | methodVisitor.visitInsn(DCONST_0);
263 | methodVisitor.visitInsn(DCMPG);
264 | methodVisitor.visitJumpInsn(IFLT, labels.get(node.getRightNode().getNodeIndex()));
265 |
266 | // if findInBitset, jump to right child node
267 | methodVisitor.visitVarInsn(ALOAD, 0);
268 | methodVisitor.visitLdcInsn(node.getTreeIndex());
269 | methodVisitor.visitLdcInsn(node.getCatBoundaryBegin());
270 | methodVisitor.visitLdcInsn(node.getCatBoundaryEnd() - node.getCatBoundaryBegin());
271 | methodVisitor.visitVarInsn(DLOAD, 2);
272 | methodVisitor.visitMethodInsn(INVOKEVIRTUAL, className, FIND_CAT_BIT_SET_METHOD, "(IIID)Z", false);
273 | methodVisitor.visitJumpInsn(IFNE, labels.get(node.getLeftNode().getNodeIndex()));
274 |
275 | // others, jump left child node
276 | methodVisitor.visitJumpInsn(GOTO, labels.get(node.getRightNode().getNodeIndex()));
277 | }
278 |
279 | private void loadFeatureByIndex(MethodVisitor methodVisitor, int index) {
280 | // load features[index] to stack
281 | methodVisitor.visitVarInsn(ALOAD, FEATURE_PARAMETER_INDEX);
282 | methodVisitor.visitLdcInsn(index);
283 | methodVisitor.visitInsn(DALOAD);
284 | }
285 |
286 | private MethodVisitor simpleVisitMethod(ClassVisitor cv, int access, final String name,
287 | final String descriptor) {
288 | return cv.visitMethod(access, name, descriptor, null, null);
289 | }
290 |
291 | private String toInternalName(final String name) {
292 | return StringUtils.join(name.split("\\."), "/");
293 | }
294 |
295 | private String getSuperName(final TreeModel model) {
296 | if (model.isContainsCatNode()) {
297 | return META_DATA_HOLDER_INTERNAL_NAME;
298 | } else {
299 | return OBJECT_INTERNAL_NAME;
300 | }
301 | }
302 | }
303 |
--------------------------------------------------------------------------------
/treetops-core/src/main/java/io/github/horoc/treetops/core/parser/TreeModelParser.java:
--------------------------------------------------------------------------------
1 | package io.github.horoc.treetops.core.parser;
2 |
3 | import io.github.horoc.treetops.core.factory.ObjectiveDecoratorFactory;
4 | import io.github.horoc.treetops.core.model.RawTreeBlock;
5 | import io.github.horoc.treetops.core.model.TreeModel;
6 | import io.github.horoc.treetops.core.model.TreeNode;
7 | import java.util.ArrayList;
8 | import java.util.HashMap;
9 | import java.util.List;
10 | import java.util.Map;
11 | import java.util.Objects;
12 | import java.util.function.Consumer;
13 | import java.util.function.Function;
14 | import javax.annotation.Nonnull;
15 | import javax.annotation.ParametersAreNonnullByDefault;
16 | import javax.annotation.concurrent.Immutable;
17 | import javax.annotation.concurrent.ThreadSafe;
18 | import org.apache.commons.lang3.StringUtils;
19 | import org.apache.commons.lang3.Validate;
20 |
21 | /**
22 | * Parse String model info into TreeModel instance.
23 | *
24 | * @author chenzhou@apache.org
25 | * created on 2023/2/14
26 | */
27 | @Immutable
28 | @ThreadSafe
29 | @ParametersAreNonnullByDefault
30 | public class TreeModelParser {
31 |
32 | /**
33 | * Refer to official library: microsoft/LightGBM/include/LightGBM/tree.h#kCategoricalMask.
34 | */
35 | private static final int CATEGORICAL_MASK = 1;
36 |
37 | /**
38 | * Refer to official library: microsoft/LightGBM/include/LightGBM/tree.h#kDefaultLeftMask.
39 | */
40 | private static final int DEFAULT_LEFT_MASK = 2;
41 |
42 | /**
43 | * Model file config separator.
44 | */
45 | private static final String CONFIG_SEPARATOR = " ";
46 |
47 | /**
48 | * Parse workflow:
49 | * 1. iterator each line to find meta block header or tree block header
50 | * 2. parse different block into tree model instance
51 | *
52 | * @param rawLines raw data
53 | * @return tree model instance
54 | */
55 | public static TreeModel parseTreeModel(@Nonnull final List rawLines) {
56 | Validate.notNull(rawLines);
57 |
58 | TreeModel treeModel = new TreeModel();
59 | int curLineIndex = 0;
60 | while (curLineIndex < rawLines.size()) {
61 | String line = rawLines.get(curLineIndex);
62 | if (StringUtils.isBlank(line)) {
63 | curLineIndex++;
64 | continue;
65 | }
66 | // we only need tree info and meta block info
67 | if (isTreeBlockHeader(line)) {
68 | TreeNode tree = new TreeNode();
69 | curLineIndex = initTreeBlock(tree, rawLines, curLineIndex);
70 | treeModel.getTrees().add(tree);
71 | // mark model contains category node, need to initialize threshold data later
72 | if (Objects.nonNull(tree.getCatThreshold()) && !tree.getCatThreshold().isEmpty()) {
73 | treeModel.setContainsCatNode(true);
74 | }
75 | } else if (isMetaInfoBlockHeader(line)) {
76 | // skip first line of meta info block
77 | curLineIndex = initMetaBlock(treeModel, rawLines, curLineIndex + 1);
78 | }
79 | curLineIndex++;
80 | }
81 | checkTreeBlock(treeModel);
82 | return treeModel;
83 | }
84 |
85 | /**
86 | * load meta block info into model.
87 | *
88 | * @param treeModel tree model
89 | * @param rawStingLines raw string lines
90 | * @param offset read offset
91 | * @return next offset
92 | */
93 | private static int initMetaBlock(TreeModel treeModel, final List rawStingLines, int offset) {
94 | Map rawDataMap = new HashMap<>();
95 | final int nextOffset = parseRawKeyValueMap(rawStingLines, rawDataMap, offset);
96 | convertAndSetField("objective", rawDataMap, TreeModelParser::parseObjectiveType, treeModel::setObjectiveType);
97 | convertAndSetField("objective", rawDataMap, TreeModelParser::parseObjectiveConfig, treeModel::setObjectiveConfig);
98 | convertAndSetField("num_class", rawDataMap, Integer::valueOf, treeModel::setNumClass);
99 | convertAndSetField("max_feature_idx", rawDataMap, Integer::valueOf, treeModel::setMaxFeatureIndex);
100 | convertAndSetField("num_tree_per_iteration", rawDataMap, Integer::valueOf, treeModel::setNumberTreePerIteration);
101 | return nextOffset;
102 | }
103 |
104 | /**
105 | * load tree block info into model.
106 | *
107 | * @param root tree node root
108 | * @param rawStingLines raw string lines
109 | * @param offset read offset
110 | * @return next offset
111 | */
112 | private static int initTreeBlock(TreeNode root, final List rawStingLines, int offset) {
113 | Map rawDataMap = new HashMap<>();
114 | final int nextOffset = parseRawKeyValueMap(rawStingLines, rawDataMap, offset);
115 |
116 | RawTreeBlock block = new RawTreeBlock();
117 | convertAndSetField("Tree", rawDataMap, Integer::valueOf, block::setTree);
118 | convertAndSetField("num_leaves", rawDataMap, Integer::valueOf, block::setNumLeaves);
119 | convertAndSetField("num_cat", rawDataMap, Integer::valueOf, block::setNumCat);
120 | convertAndSetField("split_feature", rawDataMap, val -> fromStringToList(val, Integer::valueOf), block::setSplitFeature);
121 | convertAndSetField("decision_type", rawDataMap, val -> fromStringToList(val, Integer::valueOf), block::setDecisionType);
122 | convertAndSetField("left_child", rawDataMap, val -> fromStringToList(val, Integer::valueOf), block::setLeftChild);
123 | convertAndSetField("right_child", rawDataMap, val -> fromStringToList(val, Integer::valueOf), block::setRightChild);
124 | convertAndSetField("leaf_value", rawDataMap, val -> fromStringToList(val, Double::valueOf), block::setLeafValue);
125 | convertAndSetField("internal_value", rawDataMap, val -> fromStringToList(val, Double::valueOf), block::setInternalValue);
126 | convertAndSetField("threshold", rawDataMap, val -> fromStringToList(val, Double::valueOf), block::setThreshold);
127 | convertAndSetField("cat_boundaries", rawDataMap, val -> fromStringToList(val, Integer::valueOf), block::setCatBoundaries);
128 | convertAndSetField("cat_threshold", rawDataMap, val -> fromStringToList(val, Long::valueOf), block::setCatThreshold);
129 |
130 | // init all nodes
131 | int treeSize = block.getLeftChild().size();
132 | List treeNodes = new ArrayList<>(treeSize);
133 | treeNodes.add(0, root);
134 | for (int i = 1; i < treeSize; i++) {
135 | treeNodes.add(i, new TreeNode(i));
136 | }
137 | treeNodes.forEach(o -> initTreeSingleNode(o, block));
138 |
139 | // link all nodes
140 | for (int i = 0; i < treeSize; i++) {
141 | linkTreeNode(treeNodes.get(i), treeNodes, block);
142 | }
143 |
144 | // sort node array by index
145 | treeNodes.sort((a, b) -> {
146 | int i = a.getNodeIndex();
147 | int j = b.getNodeIndex();
148 | if (i >= 0 && j >= 0) {
149 | return i - j;
150 | } else {
151 | return j - i;
152 | }
153 | });
154 |
155 | return nextOffset;
156 | }
157 |
158 | /**
159 | * load tree node block info into model.
160 | *
161 | * @param node tree node
162 | * @param block raw block data
163 | */
164 | private static void initTreeSingleNode(TreeNode node, final RawTreeBlock block) {
165 | int nodeIndex = node.getNodeIndex();
166 | node.setLeaf(false);
167 | node.setTreeIndex(block.getTree());
168 | node.setDecisionType(block.getDecisionType().get(nodeIndex));
169 | node.setCategoryNode(isCategoryNode(block.getDecisionType().get(nodeIndex)));
170 | node.setDefaultLeftDecision(isDefaultLeftDecisionNode(block.getDecisionType().get(nodeIndex)));
171 | node.setSplitFeatures(block.getSplitFeature());
172 | node.setCatThreshold(block.getCatThreshold());
173 | if (node.isCategoryNode()) {
174 | node.setCatBoundaryBegin(block.getThreshold().get(nodeIndex).intValue());
175 | node.setCatBoundaryEnd(node.getCatBoundaryBegin() + 1);
176 | } else {
177 | node.setThreshold(block.getThreshold().get(nodeIndex));
178 | }
179 | }
180 |
181 | private static void linkTreeNode(TreeNode node, List treeNodes, final RawTreeBlock block) {
182 | int leftIndex = block.getLeftChild().get(node.getNodeIndex());
183 | if (leftIndex < 0) {
184 | TreeNode leftLeaf = new TreeNode(node.getTreeIndex(), leftIndex);
185 | leftLeaf.setLeaf(true);
186 | leftLeaf.setLeafValue(block.getLeafValue().get(-leftIndex - 1));
187 | node.setLeftNode(leftLeaf);
188 | treeNodes.add(leftLeaf);
189 | } else {
190 | node.setLeftNode(treeNodes.get(leftIndex));
191 | }
192 |
193 | int rightIndex = block.getRightChild().get(node.getNodeIndex());
194 | if (rightIndex < 0) {
195 | TreeNode rightLeaf = new TreeNode(node.getTreeIndex(), rightIndex);
196 | rightLeaf.setLeaf(true);
197 | rightLeaf.setLeafValue(block.getLeafValue().get(-rightIndex - 1));
198 | node.setRightNode(rightLeaf);
199 | treeNodes.add(rightLeaf);
200 | } else {
201 | node.setRightNode(treeNodes.get(rightIndex));
202 | }
203 |
204 | node.setAllNodes(treeNodes);
205 | }
206 |
207 | private static int parseRawKeyValueMap(List metaInfos, Map rawDataMap, int offset) {
208 | int nextOffset = offset;
209 | for (; nextOffset < metaInfos.size(); nextOffset++) {
210 | String line = metaInfos.get(nextOffset);
211 | if (StringUtils.isBlank(line)) {
212 | return nextOffset;
213 | }
214 | String[] sp = line.split("=");
215 | if (sp.length != 2) {
216 | throw new RuntimeException(String.format("try to parse tree model failed, invalid key-value content %s", line));
217 | }
218 | rawDataMap.put(sp[0], sp[1]);
219 | }
220 | return nextOffset;
221 | }
222 |
223 | private static boolean isCategoryNode(int decisionType) {
224 | return (decisionType & CATEGORICAL_MASK) > 0;
225 | }
226 |
227 | private static boolean isDefaultLeftDecisionNode(int decisionType) {
228 | return (decisionType & DEFAULT_LEFT_MASK) > 0;
229 | }
230 |
231 | private static boolean isTreeBlockHeader(final String line) {
232 | return StringUtils.isNoneBlank(line) && line.startsWith("Tree=");
233 | }
234 |
235 | private static boolean isMetaInfoBlockHeader(final String line) {
236 | return StringUtils.isNoneBlank(line) && "tree".equals(line);
237 | }
238 |
239 | private static String parseObjectiveType(final String objective) {
240 | if (StringUtils.isNotBlank(objective)) {
241 | return objective.split(CONFIG_SEPARATOR)[0];
242 | }
243 | return StringUtils.EMPTY;
244 | }
245 |
246 | private static String parseObjectiveConfig(final String objective) {
247 | if (StringUtils.isNotBlank(objective)) {
248 | String[] sp = objective.split(CONFIG_SEPARATOR);
249 | if (sp.length > 1) {
250 | return sp[1];
251 | }
252 | }
253 | return StringUtils.EMPTY;
254 | }
255 |
256 | private static List fromStringToList(final String str, Function converter) {
257 | String[] splits = str.split(CONFIG_SEPARATOR);
258 | List ret = new ArrayList<>(splits.length);
259 | for (String val : splits) {
260 | ret.add(converter.apply(val));
261 | }
262 | return ret;
263 | }
264 |
265 | private static void convertAndSetField(final String key, Map rawDataMap,
266 | Function converter, Consumer setter) {
267 | String rawValue = rawDataMap.get(key);
268 | if (StringUtils.isNoneBlank(rawValue)) {
269 | try {
270 | setter.accept(converter.apply(rawValue));
271 | } catch (Throwable e) {
272 | throw new RuntimeException(String.format("parsing tree model file error, "
273 | + "try to convert field key: , value: ", key, rawValue), e);
274 | }
275 | }
276 | }
277 |
278 | private static void checkTreeBlock(final TreeModel treeModel) {
279 | Validate.notNull(treeModel,
280 | "parsing tree model failed, can not find any valid block");
281 |
282 | Validate.notEmpty(treeModel.getTrees(),
283 | "parsing tree model failed, can not find any valid block");
284 |
285 | Validate.isTrue(treeModel.getNumClass() > 0,
286 | "parsing tree model failed, invalid meta info, num class must be positive");
287 |
288 | Validate.isTrue(treeModel.getMaxFeatureIndex() >= 0,
289 | "parsing tree model failed, invalid meta info, max feature index can not be negative");
290 |
291 | Validate.isTrue(treeModel.getNumberTreePerIteration() > 0,
292 | "parsing tree model failed, invalid meta info, number tree per iteration must be positive");
293 |
294 | Validate.isTrue(ObjectiveDecoratorFactory.isValidObjectiveType(treeModel.getObjectiveType()),
295 | "parsing tree model failed, invalid meta info, objective type %s is not supported", treeModel.getObjectiveType());
296 | }
297 | }
298 |
--------------------------------------------------------------------------------