├── AdditiveDecisionTree └── AdditiveDecisionTree.py ├── README.md ├── Results ├── results_18_08_2021_14_26_34_heatmap.png ├── results_18_08_2021_14_26_34_summarized.csv └── results_24_05_2024_10_17_47_plot.png └── examples ├── Accuracy_Test_Additive_Decision_Tree.py └── Simple_Example_Additive_Decision_Tree.ipynb /AdditiveDecisionTree/AdditiveDecisionTree.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | from sklearn import tree 4 | from sklearn.base import BaseEstimator 5 | from sklearn.metrics import f1_score, mean_squared_error 6 | from info_gain import info_gain 7 | import matplotlib.pyplot as plt 8 | 9 | # Node types 10 | NOT_SPLIT = -1 11 | CAN_NOT_SPLIT = -2 12 | ADDITIVE_NODE = -100 13 | 14 | 15 | # Colours used in visualizations, with each class represented by a consistent colour. 16 | tableau_palette_list = ["tab:blue", "tab:orange", "tab:green", "tab:red", "tab:purple", "tab:brown", "tab:pink", 17 | "tab:gray", "tab:olive", "tab:cyan"] 18 | 19 | 20 | def clean_data(df): 21 | df = df.fillna(0.0) # todo: fill as in datasets evaluator. including one-hot encoding cat cols 22 | df = df.replace([np.inf, -np.inf], 0.0) 23 | return df 24 | 25 | 26 | class InternalTree: 27 | """ 28 | Represents the collection of nodes within an Additive Decision Tree and information about these. 29 | """ 30 | def __init__(self): 31 | self.children_left = None 32 | self.children_right = None 33 | self.feature = None 34 | self.threshold = None 35 | self.indexes = None 36 | self.depths = None 37 | self.can_split = None 38 | self.leaf_explanation = None 39 | self.node_best_threshold_arr = None 40 | self.node_used_feature_arr = None 41 | self.node_used_threshold_arr = None 42 | self.node_measurement = None 43 | self.node_best_measurement_arr = None 44 | 45 | 46 | class AdditiveDecisionTree(BaseEstimator): 47 | def __init__( self, 48 | min_samples_split=8, 49 | min_samples_leaf=6, 50 | max_depth=np.inf, 51 | verbose_level=0, 52 | allow_additive_nodes=True, 53 | max_added_splits_per_node=5): 54 | """ 55 | :param min_samples_split: As with standard decision trees. The minimum number of samples in a node to allow 56 | further splitting. 57 | :param min_samples_leaf: As with standard decision trees. The minimum number of samples in that would result 58 | in a leaf node to allow further splitting. 59 | :param max_depth: As with standard decision trees. The maximum path length from the root to any leaf node. 60 | :param verbose_level: Controls the display of output during fitting and predicting. 61 | :param allow_additive_nodes: If False, behaves similar to standard decision tree. 62 | :param max_added_splits_per_node: The maximum number of splits that may be included in any additive node. 63 | """ 64 | 65 | # Variables related to the fitting process 66 | self.min_samples_split = min_samples_split 67 | self.min_samples_leaf = min_samples_leaf 68 | self.max_depth = max_depth 69 | self.allow_additive_nodes = allow_additive_nodes 70 | self.max_added_splits_per_node = max_added_splits_per_node 71 | 72 | # Dataframe holding the original and generated features and arrays describing the features 73 | self.X = None 74 | self.y = None 75 | 76 | # Parallel arrays related to the tree that is generated 77 | self.tree_ = InternalTree() 78 | self.tree_.children_left = [CAN_NOT_SPLIT] # Index of the left child. -2 for leaves. 79 | self.tree_.children_right = [CAN_NOT_SPLIT] # Index of the right child. -2 for leaves. 
80 | self.tree_.feature = [NOT_SPLIT] # The feature used for this node. -1 for nodes not yet split. -2 for leaves. 81 | self.tree_.threshold = [CAN_NOT_SPLIT] # The threshold used for this node. -2 for leaves. 82 | self.tree_.indexes = [[]] # The row indexes out of the full self.X covered by each node 83 | self.tree_.depths = [0] # The depth of each node 84 | self.tree_.can_split = [True] # Indication if the igr can be calculated for this node. Generally True. 85 | self.tree_.leaf_explanation = [""] # Indicates the reason the leaf was not split further. Blank for internal nodes and leaves with full purity. 86 | self.tree_.node_best_threshold_arr = [[]] 87 | self.tree_.node_used_feature_arr = [[]] 88 | self.tree_.node_used_threshold_arr = [[]] 89 | self.tree_.node_measurement = [CAN_NOT_SPLIT] # The information gain ratio (IGR) of the split at each node. -2 for leaves. 90 | self.tree_.node_best_measurement_arr = [[]] 91 | 92 | # Logging information 93 | self.verbose_level = verbose_level 94 | 95 | def __str__(self): 96 | return (f"min_samples_split: {self.min_samples_split}, min_samples_leaf: {self.min_samples_leaf}, " 97 | f"max_depth: {self.max_depth}, allow_additive_nodes: {self.allow_additive_nodes}, " 98 | f"max_added_splits_per_node: {self.max_added_splits_per_node}") 99 | 100 | def check_nodes_arrays(self): 101 | """ 102 | Test for internal consistency. Check all parallel arrays are the same length. 103 | """ 104 | assert len(self.tree_.children_left) == len(self.tree_.children_right) 105 | assert len(self.tree_.children_left) == len(self.tree_.feature) 106 | assert len(self.tree_.children_left) == len(self.tree_.threshold) 107 | assert len(self.tree_.children_left) == len(self.tree_.indexes) 108 | assert len(self.tree_.children_left) == len(self.tree_.depths) 109 | assert len(self.tree_.children_left) == len(self.tree_.can_split) 110 | assert len(self.tree_.children_left) == len(self.tree_.node_measurement) 111 | assert len(self.tree_.children_left) == len(self.tree_.node_best_measurement_arr) 112 | assert len(self.tree_.children_left) == len(self.tree_.node_best_threshold_arr) 113 | assert len(self.tree_.children_left) == len(self.tree_.node_used_feature_arr) 114 | assert len(self.tree_.children_left) == len(self.tree_.node_used_threshold_arr) 115 | 116 | def fit(self, X, y): 117 | """ 118 | Fit a model. 119 | :param X: matrix of values. 120 | :param y: target column. May be categorical or numeric. 121 | :return: self 122 | """ 123 | # If possible, split a single node, creating two child nodes. 124 | def split_node(node_idx): 125 | if self.tree_.feature[node_idx] != NOT_SPLIT: 126 | self.log(3, "Node already split. (Features[node_idx] is not -1.). Feature: " + \ 127 | str(self.tree_.feature[node_idx])) 128 | return False 129 | 130 | self.log(2, "\n\n####################################") 131 | self.log(1, "Calling split_node for:", node_idx) 132 | 133 | if len(self.tree_.indexes[node_idx]) <= self.min_samples_split: 134 | self.log(2, "Too few rows to split further. #rows: ", len(self.tree_.indexes[node_idx])) 135 | self.tree_.feature[node_idx] = CAN_NOT_SPLIT 136 | self.tree_.leaf_explanation[node_idx] = "Too few rows to split further." 137 | return False 138 | 139 | if self.tree_.depths[node_idx] >= self.max_depth: 140 | self.log(2, "Maximum depth reached. Depth: ", self.tree_.depths[node_idx]) 141 | self.tree_.feature[node_idx] = CAN_NOT_SPLIT 142 | self.tree_.leaf_explanation[node_idx] = "Maximum depth reached." 
143 | return False 144 | 145 | if not self.tree_.can_split[node_idx]: 146 | self.log(2, "Cannot split this node. (Cannot calculate the igr.)") 147 | self.tree_.feature[node_idx] = CAN_NOT_SPLIT 148 | self.tree_.leaf_explanation[node_idx] = "Cannot calculate splitting criteria." 149 | return False 150 | 151 | if self.check_node_complete(node_idx): 152 | return True 153 | 154 | # Get the set of rows at this node 155 | X_local = self.X.loc[self.tree_.indexes[node_idx]] 156 | y_local = self.y.loc[self.tree_.indexes[node_idx]] 157 | self.log(5, self.tree_.indexes[node_idx]) 158 | self.log(3, "# rows this node: ", len(X_local)) 159 | assert len(X_local) == len(y_local), "Lengths wrong: " + str(len(X_local)) + ", " + str(len(y_local)) 160 | 161 | # todo: Increase the efficiency of this. Where sampling is use, specify as a hyperparameter. 162 | sample_size = 1000 163 | if len(X_local) > sample_size: 164 | X_local = X_local.sample(n=sample_size, random_state=0) 165 | y_local = y_local.loc[X_local.index] 166 | assert len(X_local) == len(y_local), "Lengths incorrect after taking sample: " + str(len(X_local)) + ", " + str(len(y_local)) 167 | 168 | self.log(4, "X_local:") 169 | if self.verbose_level >= 4: 170 | print("X:") 171 | print(X_local.head()) 172 | print("y:") 173 | print(y.head()) 174 | 175 | # If the verbose_level is sufficiently high, render plots of each pair of features at this node. 176 | # todo: move to subclasses 177 | if self.verbose_level >= 6: 178 | # todo: use the general method to produce scatter plots 179 | for c1_idx in range(len(X.columns)-1): 180 | for c2_idx in range(c1_idx+1, len(X.columns)): 181 | for class_idx in range(len(self.classes_)): 182 | class_name = self.classes_[class_idx] 183 | idx_arr = y_local.loc[y_local == class_name].index 184 | X_curr_class = X_local.loc[idx_arr] 185 | plt.scatter(X_curr_class[X_curr_class.columns[c1_idx]], 186 | X_curr_class[X_curr_class.columns[c2_idx]], 187 | alpha=0.1, 188 | c=tableau_palette_list[class_idx], 189 | label=self.classes_[class_idx]) 190 | plt.title("Columns "+str(c1_idx)+"-"+str(c2_idx)) 191 | plt.legend() 192 | plt.show() 193 | 194 | # Loop through each column and determine the information gain ratio using that column. 195 | best_col_idx, best_col_threshold, best_col_measurement, measurement_arr, threshold_arr, good_split_found = \ 196 | self.find_best_col(X_local, y_local, node_idx) 197 | 198 | if not good_split_found: 199 | return False 200 | 201 | # Check if this split would result in too few rows in either child 202 | X_local = self.X.loc[self.tree_.indexes[node_idx]] 203 | y_local = self.y.loc[self.tree_.indexes[node_idx]] 204 | attribute_arr = np.where(X_local[self.X.columns[best_col_idx]] <= best_col_threshold, 0, 1) 205 | if (attribute_arr.tolist().count(0) < self.min_samples_leaf) or (attribute_arr.tolist().count(1) < self.min_samples_leaf): 206 | self.tree_.feature[node_idx] = CAN_NOT_SPLIT 207 | self.log(2, "Split would result in too small child nodes.") 208 | self.tree_.leaf_explanation[node_idx] = "Split would result in too small child nodes." 
209 | return True 210 | 211 | # Update this node 212 | new_left_idx = len(self.tree_.feature) 213 | new_right_idx = len(self.tree_.feature) + 1 214 | self.tree_.children_left[node_idx] = new_left_idx 215 | self.tree_.children_right[node_idx] = new_right_idx 216 | self.tree_.feature[node_idx] = best_col_idx 217 | self.tree_.threshold[node_idx] = best_col_threshold 218 | self.tree_.node_measurement[node_idx] = best_col_measurement 219 | self.tree_.node_best_measurement_arr[node_idx] = measurement_arr 220 | self.tree_.node_best_threshold_arr[node_idx] = threshold_arr 221 | 222 | # Create nodes for the 2 children 223 | self.tree_.children_left.extend([-2, -2]) 224 | self.tree_.children_right.extend([-2, -2]) 225 | self.tree_.feature.extend([-1, -1]) 226 | self.tree_.threshold.extend([-2, -2]) 227 | self.tree_.indexes.extend([[], []]) 228 | new_depth = self.tree_.depths[node_idx]+1 229 | self.tree_.depths.extend([new_depth, new_depth]) 230 | self.tree_.can_split.extend([True, True]) 231 | self.tree_.leaf_explanation.extend(["", ""]) 232 | self.tree_.node_measurement.extend([-2, -2]) 233 | self.tree_.node_best_measurement_arr.extend([[], []]) 234 | self.tree_.node_best_threshold_arr.extend([[], []]) 235 | self.tree_.node_used_feature_arr.extend([[], []]) 236 | self.tree_.node_used_threshold_arr.extend([[], []]) 237 | 238 | # Set the indexes of the two child nodes 239 | self.tree_.indexes[new_left_idx] = X_local.iloc[np.where(attribute_arr <= 0)[0]].index 240 | self.tree_.indexes[new_right_idx] = X_local.iloc[np.where(attribute_arr > 0)[0]].index 241 | 242 | # Update the subclass-specific stats about the new nodes 243 | self.update_node_stats(new_left_idx, new_right_idx) 244 | 245 | # Log messages 246 | self.log(5, "where arr 1:", np.where(attribute_arr <= 0)[0]) 247 | self.log(5, "where arr 2:", np.where(attribute_arr > 0)[0]) 248 | self.log(3, "# rows in left child: ", len(np.where(attribute_arr <= 0)[0])) 249 | self.log(3, "# rows in right child: ", len(np.where(attribute_arr > 0)[0])) 250 | self.log(5, "new_left_idx indexes", self.tree_.indexes[new_left_idx]) 251 | self.log(5, "new_right_idx indexes", self.tree_.indexes[new_right_idx]) 252 | 253 | return True 254 | 255 | # Initialize the variables related to the data 256 | self.X = X 257 | self.y = y 258 | self.init_variables(y) 259 | 260 | # Initialize the variables related to the root node of the tree 261 | self.tree_.indexes[0] = self.X.index # The first node contains all rows 262 | self.init_root(y) 263 | 264 | # Build the tree. Loop through each node until no more nodes can be split. 
265 | num_nodes_split = 1 266 | while num_nodes_split > 0: 267 | num_nodes_split = 0 268 | for node_idx in range(len(self.tree_.feature)): 269 | num_nodes_split += split_node(node_idx) 270 | self.check_nodes_arrays() 271 | 272 | # Create additive nodes 273 | if self.allow_additive_nodes: 274 | self.create_additive_nodes() 275 | self.remove_disconnected_nodes() 276 | 277 | return self 278 | 279 | def output_tree(self): 280 | print("\n********************************************************") 281 | print("Generated Tree") 282 | print("********************************************************") 283 | print(f"\n# Nodes: {self.get_num_nodes()}") 284 | print(f"\nLeft Chidren:\n{self.tree_.children_left}") 285 | print(f"\nRight Chidren:\n{self.tree_.children_right}") 286 | print(f"\n# Rows: \n{[len(x) for x in self.tree_.indexes]}") 287 | print(f"\nFeatures:\n{self.tree_.feature}") 288 | print(f"\nFeatures in additive nodes:\n{self.tree_.node_used_feature_arr}") 289 | print(f"\nThresholds:\n{self.tree_.threshold}") 290 | print(f"\nDepths:\n{self.tree_.depths}") 291 | print(f"\nCan split: \n{self.tree_.can_split}") 292 | self.subclass_output_tree() # todo: probably have subclasses call super.outputtree() 293 | print("********************************************************\n") 294 | 295 | def remove_disconnected_nodes(self): 296 | """ 297 | Remove any nodes that were disconnected from the tree during pruning. 298 | :return: None 299 | """ 300 | node_reachable_arr = [0]*len(self.tree_.feature) 301 | node_reachable_arr[0] = 1 302 | for i in range(len(self.tree_.children_left)): 303 | node_idx = self.tree_.children_left[i] 304 | if (node_idx >= 0) and (node_reachable_arr[i] == 1): 305 | node_reachable_arr[node_idx]=1 306 | node_idx = self.tree_.children_right[i] 307 | if (node_idx >= 0) and (node_reachable_arr[i] == 1): 308 | node_reachable_arr[node_idx]=1 309 | 310 | for e in range(len(node_reachable_arr)-1, -1, -1): 311 | if node_reachable_arr[e] == 0: 312 | self.tree_.children_left.pop(e) 313 | self.tree_.children_right.pop(e) 314 | self.tree_.feature.pop(e) 315 | self.tree_.threshold.pop(e) 316 | self.tree_.indexes.pop(e) 317 | self.tree_.depths.pop(e) 318 | self.tree_.can_split.pop(e) 319 | self.tree_.node_measurement.pop(e) 320 | self.tree_.node_best_measurement_arr.pop(e) 321 | self.tree_.node_best_threshold_arr.pop(e) 322 | self.tree_.node_used_feature_arr.pop(e) 323 | self.tree_.node_used_threshold_arr.pop(e) 324 | self.pop_unused_array_elements(e) 325 | 326 | self.check_nodes_arrays() 327 | 328 | # Adjust the indexes of the child nodes 329 | num_popped_prev = [0]*len(node_reachable_arr) 330 | count_false = 0 331 | for i in range(len(node_reachable_arr)): 332 | num_popped_prev[i] = count_false 333 | if node_reachable_arr[i] == 0: 334 | count_false += 1 335 | for i in range(len(self.tree_.children_left)): 336 | left_child = self.tree_.children_left[i] 337 | if left_child >= 0: 338 | self.tree_.children_left[i] -= num_popped_prev[left_child] 339 | right_child = self.tree_.children_right[i] 340 | if right_child >= 0: 341 | self.tree_.children_right[i] -= num_popped_prev[right_child] 342 | 343 | def get_num_nodes(self): 344 | return len(self.tree_.feature) 345 | 346 | def get_model_complexity(self): 347 | """ 348 | Returns global complexity. Count each node as 1, except additive nodes, which count based on the number 349 | of aggregations. 350 | todo: also measure avg. 
local complexity 351 | """ 352 | complexity = self.get_num_nodes() 353 | for i, feature_idx in enumerate(self.tree_.feature): 354 | if feature_idx == ADDITIVE_NODE: 355 | complexity += len(self.tree_.node_used_feature_arr[i]) 356 | return complexity 357 | 358 | def _predict(self, predict_X, testing_additive=True): 359 | """ 360 | Find the prediction and decision paths for each record provided 361 | :param predict_X: matrix of values. Must match the matrix used for fitting. 362 | :param testing_additive: Used to set log level 363 | :return: array of predictions and array of decision paths. Both match the length of predict_X. 364 | """ 365 | 366 | header_log_level = 2 if (not testing_additive) else 5 367 | self.log(header_log_level, "%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%") 368 | self.log(2, "PREDICT") 369 | self.log(header_log_level, "%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%") 370 | 371 | def find_leaf(row): 372 | """ 373 | Find the leaf node for the specified row 374 | :param row: pandas series 375 | :return: node id of the leaf node and path as an array of node ids 376 | """ 377 | self.log(5, "row:", row) 378 | curr_node = 0 379 | path = [0] 380 | while self.tree_.feature[curr_node] >= 0: 381 | curr_feat = self.tree_.feature[curr_node] 382 | curr_feat_name = self.X.columns[curr_feat] 383 | curr_threshold = self.tree_.threshold[curr_node] 384 | self.log(5, "curr_node: ", curr_node, ", curr_feat: ", curr_feat) 385 | if row[curr_feat_name] >= curr_threshold: 386 | curr_node = self.tree_.children_right[curr_node] 387 | else: 388 | curr_node = self.tree_.children_left[curr_node] 389 | path.append(curr_node) 390 | return curr_node, path 391 | 392 | pred_arr = [] 393 | decision_path_arr = [] 394 | self.additive_votes = [""]*len(predict_X) 395 | row_num = 0 396 | for row_idx in range(len(predict_X)): 397 | row = predict_X.iloc[row_idx] 398 | leaf_idx, path = find_leaf(row) 399 | pred_arr.append(self.get_prediction(leaf_idx, row, row_idx)) # Call method in subclass 400 | decision_path_arr.append(path) 401 | row_num += 1 402 | return pred_arr, decision_path_arr 403 | 404 | def predict(self, predict_X, testing_additive=False): 405 | """ 406 | Predict for all rows in predict_X 407 | :param predict_X: matrix of values. Must match the format used for fitting. 408 | :param testing_additive: Used to set log level. 409 | :return: array of predictions. The length matches the length of predict_X. 410 | """ 411 | pred, _ = self._predict(predict_X, testing_additive) 412 | return pred 413 | 414 | # todo: do for regression 415 | def get_explanations(self, predict_X, y): 416 | y = pd.Series(y) 417 | pred, paths = self._predict(predict_X) 418 | 419 | preamble_str = f"Initial distribution of classes: {self.classes_}: {self.class_counts[0]}" 420 | explanations = [preamble_str] 421 | 422 | for path_idx, path in enumerate(paths): 423 | if y.iloc[path_idx] == pred[path_idx]: 424 | correct_indicator = "Correct" 425 | else: 426 | correct_indicator = f"Wrong. Correct target value: {y.iloc[path_idx]}" 427 | expl = "..............................................................." 428 | expl += f"\nPrediction for row {path_idx}: {pred[path_idx]} -- {correct_indicator}\n" 429 | expl += "..............................................................." 
430 | expl += f"\nPath: {path}" 431 | for path_element_idx, node_idx in enumerate(path): 432 | col_idx = self.tree_.feature[node_idx] 433 | if col_idx == ADDITIVE_NODE: 434 | and_indicator = "\n" 435 | if node_idx > 0: 436 | and_indicator = "\n\nAND " 437 | expl += f"{and_indicator}vote based on: {self.additive_votes[path_idx]}" 438 | elif col_idx == -2: 439 | expl += f"\nwhere the majority class is: {self.get_majority_class(node_idx)}" 440 | else: 441 | col_name = self.X.columns[col_idx] 442 | and_indicator = "\n" 443 | if node_idx > 0: 444 | and_indicator = "\nAND " 445 | sign_indicator = "greater than" 446 | if path[path_element_idx+1] == self.tree_.children_left[node_idx]: 447 | sign_indicator = "less than" 448 | expl += f"\n{and_indicator}{self.X.columns[col_idx]} is {sign_indicator} " 449 | expl += f"{self.tree_.threshold[node_idx]} \n (has value: {predict_X.iloc[path_idx][col_name]})" 450 | next_node_idx = path[path_element_idx+1] 451 | expl += f" --> (Class distribution: {self.class_counts[next_node_idx]}" 452 | explanations.append(expl.strip()) 453 | 454 | return explanations 455 | 456 | def log(self, min_verbose_level, *log_str): 457 | if self.verbose_level >= min_verbose_level: 458 | m = "" 459 | for s in log_str: 460 | m += str(s) 461 | print(m) 462 | 463 | 464 | class AdditiveDecisionTreeClasssifier(AdditiveDecisionTree): 465 | def __init__( self, 466 | min_samples_split=8, 467 | min_samples_leaf=6, 468 | max_depth=np.inf, 469 | verbose_level=0, 470 | allow_additive_nodes=True, 471 | max_added_splits_per_node=5): 472 | super().__init__(min_samples_split, min_samples_leaf, max_depth, verbose_level, allow_additive_nodes, 473 | max_added_splits_per_node) 474 | 475 | # Summary information about the class distributions at each node 476 | self.classes_ = [] # The set of unique target classes in y. Used to maintain a consistent order. 477 | self.class_counts = [] # Each node contains an array giving the count of each class, in the order defined 478 | # by self.classes_ 479 | self.class_counts_arr = [[]] 480 | 481 | def init_variables(self, y): 482 | # Get the unique set of classes in y. Used to maintain a consistent order. 
483 | self.classes_ = list(set(y)) 484 | 485 | def init_root(self, y): 486 | self.class_counts.append(self.get_class_counts_for_node(y)) 487 | 488 | def get_class_counts_for_node(self, local_y): 489 | counts_arr = pd.Series(local_y).value_counts() 490 | sorted_counts_arr = [] 491 | for c in self.classes_: 492 | if c in counts_arr: 493 | sorted_counts_arr.append(counts_arr[c]) 494 | else: 495 | sorted_counts_arr.append(0) 496 | return sorted_counts_arr 497 | 498 | def get_majority_class(self, leaf_idx): 499 | class_arr = self.class_counts[leaf_idx] 500 | class_idx = class_arr.index(max(class_arr)) 501 | return self.classes_[class_idx] 502 | 503 | def get_multiple_spits_majority_class(self, node_idx, row, row_idx): 504 | display_row = -1 505 | if row_idx == display_row: print("node_idx: ", node_idx) 506 | if row_idx == display_row: print("row: ", row) 507 | if row_idx == display_row: print("Features used: here: ", self.tree_.node_used_feature_arr[node_idx]) 508 | if row_idx == display_row: print("Thresholds used: here: ", self.tree_.node_used_threshold_arr[node_idx]) 509 | if row_idx == display_row: print("self.class_counts_arr: ", self.class_counts_arr) 510 | class_votes = [0]*len(self.classes_) 511 | additive_votes_str = "" 512 | for i in range(len(self.tree_.node_used_feature_arr[node_idx])): 513 | col_idx = self.tree_.node_used_feature_arr[node_idx][i] 514 | threshold = self.tree_.node_used_threshold_arr[node_idx][i] 515 | curr_feat_name = self.X.columns[col_idx] 516 | if row_idx == display_row: print("i:", i, ", curr_feat_name: ", curr_feat_name) 517 | if row_idx == display_row: print("self.class_counts_arr[node_idx]: ", self.class_counts_arr[node_idx]) 518 | if row[curr_feat_name] >= threshold: 519 | class_arr = self.class_counts_arr[node_idx][i][1] 520 | class_idx = class_arr.index(max(class_arr)) 521 | class_votes[class_idx] += 1 522 | if row_idx == display_row: print(curr_feat_name, " Over. class_idx: ", class_idx, " of: ", class_arr) 523 | additive_votes_str += f"\n {i+1}: {curr_feat_name} is greater than {threshold} \n (has value {row[curr_feat_name]}) " 524 | additive_votes_str += f" --> (class distribution: {self.class_counts_arr[node_idx][i][1]})" 525 | else: 526 | class_arr = self.class_counts_arr[node_idx][i][0] 527 | class_idx = class_arr.index(max(class_arr)) 528 | class_votes[class_idx] += 1 529 | if row_idx == display_row: print(curr_feat_name, " Under. class_idx: ", class_idx, " of: ", class_arr) 530 | additive_votes_str += f"\n {i+1}: {curr_feat_name} is less than {threshold}\n (has value {row[curr_feat_name]}) " 531 | additive_votes_str += f" --> (class distribution: {self.class_counts_arr[node_idx][i][0]})" 532 | if row_idx == display_row: print("class_votes: ", class_votes) 533 | highest_vote = np.argmax(class_votes) 534 | additive_votes_str += f"\nThe class with the most votes is {self.classes_[highest_vote]}" 535 | if row_idx == display_row: print("highest_vote: ", highest_vote) 536 | self.additive_votes[row_idx] = additive_votes_str 537 | return self.classes_[highest_vote] 538 | 539 | def check_node_complete(self, node_idx): 540 | if (0 in self.class_counts[node_idx]) and (self.class_counts[node_idx].count(0) == (len(self.classes_)-1)): 541 | self.tree_.feature[node_idx] = -2 542 | self.log(2, "Node has full purity. 
Making a leaf node.") 543 | return True 544 | 545 | def update_node_stats(self, new_left_idx, new_right_idx): 546 | # Set the class counts in the two child nodes # todo put in subclass 547 | y_left = self.y.loc[self.tree_.indexes[new_left_idx]] 548 | self.class_counts.append(self.get_class_counts_for_node(y_left)) 549 | y_right = self.y.loc[self.tree_.indexes[new_right_idx]] 550 | self.class_counts.append(self.get_class_counts_for_node(y_right)) 551 | 552 | def find_best_col(self, X_local, y_local, node_idx): 553 | best_col_idx = -1 554 | best_col_threshold = -1 555 | best_col_igr = -1 556 | igr_arr = [] 557 | threshold_arr = [] 558 | 559 | for col_idx in range(len(self.X.columns)): 560 | X_local = clean_data(X_local) 561 | stump = tree.DecisionTreeClassifier(random_state=0, max_depth=1) # Used to get thresholds 562 | stump.fit(X_local[[X_local.columns[col_idx]]].values.reshape(-1,1), y_local) 563 | threshold = stump.tree_.threshold[0] 564 | attribute_arr = np.where(X_local[X_local.columns[col_idx]] <= threshold, 0, 1) 565 | igr = info_gain.info_gain_ratio(attribute_arr, y_local) 566 | self.log(4, "node_idx:", node_idx, ", col_idx:", col_idx, ", igr: ", round(igr, 2), 567 | ", threshold:", round(threshold, 2)) 568 | igr_arr.append(igr) 569 | threshold_arr.append(threshold) 570 | if igr > best_col_igr: 571 | best_col_idx = col_idx 572 | best_col_threshold = threshold 573 | best_col_igr = igr 574 | 575 | good_split_found = True 576 | if (len(X_local) < 50) and (best_col_igr < 0.1): 577 | self.log(2, "Cannot split this node. (igr is too low given the number of rows) ") 578 | self.tree_.feature[node_idx] = -2 579 | self.tree_.leaf_explanation[node_idx] = "igr is too low given the number of rows." 580 | good_split_found = False 581 | 582 | if best_col_idx == -1: 583 | self.tree_.can_split[node_idx] = False 584 | good_split_found = False 585 | 586 | return best_col_idx, best_col_threshold, best_col_igr, igr_arr, threshold_arr, good_split_found 587 | 588 | def create_additive_nodes(self): 589 | 590 | # Potentially replace any non-leaf nodes with additive nodes. Doing this, we do not change the size 591 | # of the tree or the parallel arrays, though may leave some nodes unreachable. 
592 | def check_node(node_index): 593 | used_col = self.tree_.feature[node_index] 594 | used_igr = self.tree_.node_measurement[node_index] 595 | good_cols = [] 596 | good_thresholds = [] 597 | for col_idx, igr_val in enumerate(self.tree_.node_best_measurement_arr[node_index]): 598 | if igr_val > 0.9 * used_igr or igr_val > 0.4: # todo: make hyperparameters 599 | good_cols.append(col_idx) 600 | good_thresholds.append(self.tree_.node_best_threshold_arr[node_index][col_idx]) 601 | if len(good_cols) > 1: 602 | # Get the training score given the current tree 603 | y_pred = self.predict(self.X, testing_additive=True) 604 | curr_train_score = f1_score(self.y, y_pred, average='macro') 605 | 606 | # Temporarily replace this node with an additive node 607 | self.tree_.feature[node_index] = ADDITIVE_NODE 608 | self.tree_.node_used_feature_arr[node_index] = good_cols 609 | self.tree_.node_used_threshold_arr[node_index] = good_thresholds 610 | 611 | # Set the class counts in the multiple splits of the data here 612 | X_local = self.X.loc[self.tree_.indexes[node_index]] 613 | y_local = self.y.loc[self.tree_.indexes[node_index]] 614 | assert len(X_local) == len(y_local) 615 | class_counts_arr = [] 616 | for i in range(len(good_cols)): 617 | col_idx = good_cols[i] 618 | threshold = good_thresholds[i] 619 | 620 | attribute_arr = np.where(X_local[self.X.columns[col_idx]]>=threshold, 1, 0) 621 | left_indexes = X_local.iloc[np.where(attribute_arr<=0)[0]].index 622 | right_indexes = X_local.iloc[np.where(attribute_arr>0)[0]].index 623 | 624 | y_left = self.y.loc[left_indexes] 625 | y_right = self.y.loc[right_indexes] 626 | class_counts_arr.append([self.get_class_counts_for_node(y_left), self.get_class_counts_for_node(y_right)]) 627 | 628 | self.class_counts_arr[node_index] = class_counts_arr 629 | 630 | # Determine the training score given an additive node here 631 | y_pred = self.predict(self.X, testing_additive=True) 632 | updated_train_score = f1_score(self.y, y_pred, average='macro') 633 | 634 | # If the additive node did not improve the accuracy, return the tree 635 | if updated_train_score < curr_train_score: 636 | self.tree_.feature[node_index] = used_col 637 | self.tree_.node_used_feature_arr[node_index] = [] 638 | self.tree_.node_used_threshold_arr[node_index] = [] 639 | else: 640 | self.tree_.children_left[node_index] = -2 641 | self.tree_.children_right[node_index] = -2 642 | 643 | # Remove once working. check put back right. 
644 | #y_pred = self.predict(self.X) 645 | #curr_train_score = f1_score(self.y, y_pred, average='macro') 646 | #print(" curr_train_score restored: ", curr_train_score) 647 | 648 | self.class_counts_arr = [[]]*len(self.tree_.feature) 649 | 650 | # todo: recode to just go through backwards 651 | checked_nodes = [False]*len(self.tree_.children_left) 652 | for i in range(len(self.tree_.feature)): 653 | if self.tree_.feature[i] < 0: 654 | checked_nodes[i]=True 655 | count_unchecked = checked_nodes.count(False) 656 | while count_unchecked > 0: 657 | for i in range(len(checked_nodes)): 658 | left_child_idx = self.tree_.children_left[i] 659 | right_child_idx = self.tree_.children_right[i] 660 | if (not checked_nodes[i]) and checked_nodes[left_child_idx] and checked_nodes[right_child_idx]: 661 | check_node(i) 662 | checked_nodes[i]=True 663 | count_unchecked = checked_nodes.count(False) 664 | 665 | def get_prediction(self, leaf_idx, row, row_idx): 666 | if self.tree_.feature[leaf_idx] == ADDITIVE_NODE: 667 | return self.get_multiple_spits_majority_class(leaf_idx, row, row_idx) 668 | else: 669 | return self.get_majority_class(leaf_idx) 670 | 671 | def subclass_output_tree(self): 672 | print(f"\nClass counts:\n{self.class_counts}") 673 | print(f"\nLeaf Class Counts:\n{[bx for ax,bx in zip(self.tree_.feature,self.class_counts) if ax < 0]}") 674 | print(f"\nNode igr: \n{self.tree_.node_measurement}") 675 | 676 | def pop_unused_array_elements(self, e): 677 | self.class_counts.pop(e) 678 | self.class_counts_arr.pop(e) 679 | 680 | 681 | class AdditiveDecisionTreeRegressor(AdditiveDecisionTree): 682 | def __init__( self, 683 | min_samples_split=8, 684 | min_samples_leaf=6, 685 | max_depth=np.inf, 686 | verbose_level=0, 687 | allow_additive_nodes=True, 688 | max_added_splits_per_node=5): 689 | super().__init__(min_samples_split, min_samples_leaf, max_depth, verbose_level, allow_additive_nodes, max_added_splits_per_node) 690 | self.average_y_value = [] # The average y value at each node 691 | 692 | def init_variables(self, y): 693 | pass 694 | 695 | def init_root(self, y): 696 | self.average_y_value.append(self.y.mean()) 697 | 698 | # Applies only to classification. 
699 | def check_node_complete(self, node_idx): 700 | return False # todo: in principle all rows may have identical y value, should check 701 | 702 | def update_node_stats(self, new_left_idx, new_right_idx): 703 | y_left = self.y.loc[self.tree_.indexes[new_left_idx]] 704 | self.average_y_value.append(y_left.mean()) 705 | y_right = self.y.loc[self.tree_.indexes[new_right_idx]] 706 | self.average_y_value.append(y_right.mean()) 707 | 708 | def find_best_col(self, X_local, y_local, node_idx): 709 | best_col_idx = -1 710 | best_col_threshold = -1 711 | best_col_mse_gain = 0.0 712 | mse_gain_arr = [] 713 | threshold_arr = [] 714 | 715 | mse_before = mean_squared_error(y_local, [y_local.mean()]*len(y_local)) 716 | for col_idx in range(len(self.X.columns)): 717 | X_local = clean_data(X_local) 718 | stump = tree.DecisionTreeRegressor(random_state=0, max_depth=1) # Used to get thresholds 719 | stump.fit(X_local[[X_local.columns[col_idx]]].values.reshape(-1,1), y_local) 720 | threshold = stump.tree_.threshold[0] 721 | over_threshold_bool_arr = X_local[X_local.columns[col_idx]]>threshold 722 | 723 | left_arr = [y_val for y_val,over_bool in zip(y_local, over_threshold_bool_arr) if over_bool == False] 724 | right_arr = [y_val for y_val,over_bool in zip(y_local, over_threshold_bool_arr) if over_bool == True] 725 | if len(left_arr)==0 or len(right_arr)==0: 726 | continue 727 | 728 | mean_left = np.array(left_arr).mean() 729 | mean_right = np.array(right_arr).mean() 730 | new_pred_y = [mean_left if over_bool==False else mean_right for over_bool in over_threshold_bool_arr] 731 | mse_after = mean_squared_error(y_local, new_pred_y) 732 | mse_gain = mse_before - mse_after 733 | mse_gain_arr.append(mse_gain) 734 | threshold_arr.append(threshold) 735 | if (mse_gain > best_col_mse_gain) and (len(left_arr) >= self.min_samples_leaf) and (len(right_arr) >= self.min_samples_leaf): 736 | best_col_idx = col_idx 737 | best_col_threshold = threshold 738 | best_col_mse_gain = mse_gain 739 | 740 | good_split_found = True 741 | if (len(X_local) < 50) and ((best_col_mse_gain/mse_before) < 0.01): 742 | # print("case 1") 743 | self.log(2, "Cannot split this node. (Gain in MSE is too low given the number of rows) ") 744 | self.tree_.feature[node_idx] = -2 745 | self.tree_.leaf_explanation[node_idx] = "Gain in MSE is too low given the number of rows." 746 | good_split_found = False 747 | 748 | if best_col_idx == -1: 749 | self.tree_.can_split[node_idx] = False 750 | good_split_found = False 751 | 752 | return best_col_idx, best_col_threshold, best_col_mse_gain, mse_gain_arr, threshold_arr, good_split_found 753 | 754 | def create_additive_nodes(self): 755 | # Potentially replace any non-leaf nodes with additive nodes. Doing this, we do not change the size 756 | # of the tree or the parallel arrays, though may leave some nodes unreachable. 
757 | def check_node(node_index): 758 | used_col = self.tree_.feature[node_index] 759 | used_mse_gain = self.tree_.node_measurement[node_index] 760 | good_cols = [] 761 | good_thresholds = [] 762 | good_mse_gains = [] 763 | for col_idx, mse_gain_val in enumerate(self.tree_.node_best_measurement_arr[node_index]): 764 | if mse_gain_val > 0.0 and mse_gain_val > (0.9 * used_mse_gain): 765 | good_cols.append(col_idx) 766 | good_thresholds.append(self.tree_.node_best_threshold_arr[node_index][col_idx]) 767 | good_mse_gains.append(mse_gain_val) 768 | if len(good_cols) > self.max_added_splits_per_node: 769 | ind = np.argpartition(good_mse_gains, -self.max_added_splits_per_node)[-self.max_added_splits_per_node:] 770 | good_cols = np.array(good_cols)[ind].tolist() 771 | good_thresholds = np.array(good_thresholds)[ind].tolist() 772 | if not used_col in good_cols: 773 | good_cols.append(used_col) 774 | good_thresholds.append(self.tree_.threshold[node_index]) 775 | 776 | if (len(good_cols)>1): 777 | # Get the training score given the current tree. todo: just get for this part of the tree -- the rest of the tree is the same 778 | y_pred = self.predict(self.X, testing_additive=True) 779 | curr_train_mse = mean_squared_error(self.y, y_pred) 780 | 781 | # Set the average y value in the multiple splits of the data here 782 | X_local = self.X.loc[self.tree_.indexes[node_index]] 783 | y_local = self.y.loc[self.tree_.indexes[node_index]] 784 | assert len(X_local) == len(y_local) 785 | new_pred_y_arr = [] 786 | new_average_y_arr = [] 787 | for i in range(len(good_cols)): 788 | col_idx = good_cols[i] 789 | threshold = good_thresholds[i] 790 | over_threshold_bool_arr = X_local[X_local.columns[col_idx]]>threshold 791 | left_arr = [y_val for y_val,over_bool in zip(y_local, over_threshold_bool_arr) if over_bool == False] 792 | right_arr = [y_val for y_val,over_bool in zip(y_local, over_threshold_bool_arr) if over_bool == True] 793 | if len(left_arr) == 0 or len(right_arr)==0: 794 | new_average_y_arr.append((np.NaN, np.NaN)) 795 | else: 796 | mean_left = np.array(left_arr).mean() 797 | mean_right = np.array(right_arr).mean() 798 | new_pred_y = [mean_left if over_bool == False else mean_right for over_bool in over_threshold_bool_arr] 799 | new_pred_y_arr.append(new_pred_y) 800 | new_average_y_arr.append((mean_left, mean_right)) 801 | 802 | # Remove any elements with nan's 803 | for i in range(len(good_cols)-1, -1, -1): 804 | if new_average_y_arr[i][0] is np.NaN: 805 | good_cols.pop(i) 806 | good_thresholds.pop(i) 807 | new_average_y_arr.pop(i) 808 | 809 | if len(new_average_y_arr) == 0: 810 | return 811 | 812 | # Temporarily replace this node with an additive node 813 | self.tree_.feature[node_index] = ADDITIVE_NODE 814 | self.tree_.node_used_feature_arr[node_index] = good_cols 815 | self.tree_.node_used_threshold_arr[node_index] = good_thresholds 816 | self.average_y_value_arr[node_index] = new_average_y_arr 817 | 818 | final_pred = [] 819 | new_pred_y_arr_np = np.array(new_pred_y_arr) 820 | new_pred_y_arr_means = new_pred_y_arr_np.mean(axis=1) 821 | 822 | # Determine the training score given an additive node here 823 | y_pred = self.predict(self.X, testing_additive=True) 824 | updated_train_mse = mean_squared_error(self.y, y_pred) 825 | 826 | # If the additive node did not improve the accuracy, return the tree 827 | if updated_train_mse > curr_train_mse: 828 | self.tree_.feature[node_index] = used_col 829 | self.tree_.node_used_feature_arr[node_index] = [] 830 | self.tree_.node_used_threshold_arr[node_index] = [] 
831 | self.average_y_value_arr[node_index] = [] 832 | else: 833 | self.tree_.children_left[node_index] = -2 834 | self.tree_.children_right[node_index] = -2 835 | 836 | self.average_y_value_arr = [[]]*len(self.tree_.feature) 837 | 838 | # todo: recode to just go through backwards 839 | checked_nodes = [False]*len(self.tree_.children_left) 840 | for i in range(len(self.tree_.feature)): 841 | if self.tree_.feature[i] < 0: 842 | checked_nodes[i]=True 843 | count_unchecked = checked_nodes.count(False) 844 | while count_unchecked > 0: 845 | for i in range(len(checked_nodes)): 846 | left_child_idx = self.tree_.children_left[i] 847 | right_child_idx = self.tree_.children_right[i] 848 | if (not checked_nodes[i]) and (checked_nodes[left_child_idx]) and (checked_nodes[right_child_idx]): 849 | check_node(i) 850 | checked_nodes[i] = True 851 | count_unchecked = checked_nodes.count(False) 852 | 853 | def get_multiple_spits_average_prediction(self, leaf_idx, row): 854 | num_features_in_node = len(self.tree_.node_used_feature_arr[leaf_idx]) 855 | average_prediction = 0.0 856 | num_features_used = 0 857 | for i in range(num_features_in_node): 858 | if row[self.tree_.node_used_feature_arr[leaf_idx][i]] <= self.tree_.node_used_threshold_arr[leaf_idx][i]: 859 | split_prediction = self.average_y_value_arr[leaf_idx][i][0] 860 | else: 861 | split_prediction = self.average_y_value_arr[leaf_idx][i][1] 862 | if not (split_prediction is np.NaN): 863 | average_prediction += split_prediction 864 | num_features_used += 1 865 | 866 | return average_prediction / num_features_used 867 | 868 | def get_prediction(self, leaf_idx, row, row_idx): 869 | if self.tree_.feature[leaf_idx] == ADDITIVE_NODE: 870 | return self.get_multiple_spits_average_prediction(leaf_idx, row) 871 | else: 872 | return self.average_y_value[leaf_idx] 873 | 874 | def subclass_output_tree(self): 875 | print(f"\nAverage target values:\n{self.average_y_value}") 876 | if hasattr(self, "average_y_value_arr"): 877 | print(f"\nAverage target values in additive nodes:{self.average_y_value_arr}") 878 | 879 | def pop_unused_array_elements(self, e): 880 | self.average_y_value.pop(e) 881 | self.average_y_value_arr.pop(e) 882 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AdditiveDecisionTree 2 | 3 | ## Summary 4 | This tool provides an implementation of a decision tree, similar to a standard decision tree such as in sklearn, but utilizing an additive approach to fitting the data. Both AdditiveDecitionTreeClassifier and AdditiveDecisionTreeRegressor classes are provided. 5 | 6 | This tool provides, on the whole, comparable accuracy to standard decision trees, but in many cases provides greater accuracy and/or improved interpretability. As such, it can be a useful tool for generating interpretable models and may be considered a useful XAI tool. It is not intended to be competitive with approaches such as boosting or neural networks in terms of accuracy, but is simply a tool to generate interpretable models. It can often produce models comparable in accuracy to deeper standard decision trees, while having a lower overall complexity compared to these. 
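Both estimators follow the usual sklearn-style fit-predict pattern. As a minimal sketch of the regressor (the classifier is shown in the Examples section below; the synthetic dataset and column names here are only for illustration):

```python
import pandas as pd
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
from AdditiveDecisionTree import AdditiveDecisionTreeRegressor

# Synthetic data purely for illustration. The implementation indexes the data with
# .loc and .columns, so pandas objects are passed rather than raw numpy arrays.
X_arr, y_arr = make_regression(n_samples=500, n_features=6, noise=0.1, random_state=0)
X = pd.DataFrame(X_arr, columns=[f"f{i}" for i in range(X_arr.shape[1])])
y = pd.Series(y_arr)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

adt = AdditiveDecisionTreeRegressor(max_depth=4)
adt.fit(X_train, y_train)
y_pred = adt.predict(X_test)
```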
7 | 
8 | For an overview, see the Medium article: https://medium.com/towards-data-science/additive-decision-trees-85f2feda2223
9 | 
10 | #### Limitations of Decision Trees
11 | This tool addresses some well-known limitations of decision trees, in particular their limited stability, their need to split based on fewer and fewer samples lower in the trees, repeated sub-trees, and their tendency to overfit if not restricted or pruned. These limitations are typically addressed by ensembling decision trees, either through bagging or boosting, which results in models that are generally more accurate but highly uninterpretable. Constructing oblivious trees (this is done, for example, within CatBoost) and oblique decision trees (see: [RotationFeatures](https://github.com/Brett-Kennedy/RotationFeatures)) are other approaches that mitigate some of these limitations, and also often produce more stable trees.
12 | 
13 | Decision trees are considered to be among the more interpretable models, but only where it is possible to construct them to a shallow depth, perhaps to 4 or 5 levels at most. However, decision trees often need to be fairly deep to achieve high levels of accuracy, which can greatly undermine their interpretability. As decision trees are likely the most, or among the most, commonly used models where interpretability is required, our comparisons, both in terms of accuracy and interpretability, are made with respect to standard decision trees.
14 | 
15 | #### Intuition Behind Additive Decision Trees
16 | The intuition behind AdditiveDecisionTrees is that often the true function *f(x)*, mapping the input x to the target y, is based on logical conditions, while in other cases it is simply a probabilistic function where each input feature may be considered somewhat independently (as with the Naive Bayes assumption). For example, the true f(x) may include something to the effect of:
17 | ```
18 | If A > 10 Then: y = class Y
19 | Else if B < 19 Then: y = class X
20 | Else if C * D > 44 Then: y = class Y
21 | Else y = class Z.
22 | ```
23 | In this case, the true f(x) is composed of logical conditions and may be accurately (and in a simple manner) represented as a series of rules, such as in a
24 | Decision Tree, Rule List, or Rule Set. Note that conditions may be viewed as interactions, where how one feature predicts the target is dependent on the values of other columns. Here, one rule is based explicitly on the interaction C * D, but all rules entail interactions, as they may fire only if previous rules do not, and therefore the relationships between the features used in these rules are affected by other features.
25 | 
26 | On the other hand, the true f(x) may be a set of patterns related to probabilities, more of the form:
27 | ```
28 | The higher A is, the more likely y is to be class X, regardless of B, C and D
29 | The higher B is, up to 100.0, the more likely y is to be class Y, regardless of A, C and D
30 | The higher B is, where B is 100.0 or more, the more likely y is to be class Z, regardless of A, C and D
31 | The higher C * D is, the more likely y is to be class X, regardless of A and B.
32 | ```
33 | Some of these patterns may involve two or more features, and some a single feature. In this form of function, for each instance, the feature values (or combinations of feature values) each contribute some probability to the target value (to each class in the case of classification), and these probabilities are summed to determine the overall probability distribution. Here feature interactions may still exist within the probabilistic patterns, as in the case of C * D.
34 | 
35 | While there are other ways to categorize functions, this taxonomy is quite useful, and many true functions may be viewed as some combination of these two broad classes of function.
36 | 
37 | Standard decision trees do not explicitly assume the true function is conditional, and can accurately (often through the use of very large trees) capture non-conditional relationships such as those based on probabilities. They do, however, model the function as a set of conditions, which can limit their expressive power and lower their interpretability.
38 | 
39 | Though f(x) is ultimately a set of rules for any classification problem, the rules may be largely independent of each other, each simply a probability distribution based on one or a small number of features.
40 | 
41 | AdditiveDecisionTrees remove the assumption in standard decision trees that f(x) is best modeled as a set of conditions, but do support conditions where the data suggests they exist. The central idea is that the true f(x) may be based on logical conditions, probabilities, or some combination of these.
42 | 
43 | The case where f(x) is based on the features independently (each feature's relationship to the target does not depend on any other feature) may be modelled better by linear or logistic regression, Naive Bayes models, or GAMs (Generalized Additive Models), among other models, which simply predict based on a weighted sum of each independent feature. That is, each relevant feature contributes to the final prediction without consideration of the other features (though interaction features may be created). f(x), in this case, is simply a probability distribution associated with each input feature. In these cases, linear regression, logistic regression, Naive Bayes, and GAMs can be quite interpretable and may be suitable choices for XAI.
44 | 
45 | Conversely, linear and logistic regressions do not capture well the cases where there are strong conditions in the function f(x), while decision trees can, at least potentially, model these quite closely. It is usually not known a priori whether the true f(x) contains strong conditions and, as such, whether it is desirable to model the function as a decision tree does: by repeatedly splitting the data into subsets and developing a different prediction for each leaf node based entirely on the data points within it.
46 | 
47 | #### Splitting Policy
48 | We describe here how Additive Decision Trees are constructed, and particularly their splitting policy. Note, the process is simpler to present for classification problems, and so most examples relate to classification, but the ideas apply equally to regression.
49 | 
50 | The approach taken by AdditiveDecisionTrees is to split the dataspace where appropriate, and to make an aggregate decision based on numerous potential splits (all standard axis-parallel splits over different input features) where that appears more appropriate. This is done such that single splits appear higher in the tree, where there are larger numbers of samples to base the splits on and they may be found in a more reliable manner, while lower in the tree, where there are fewer samples to rely on, the decisions are based on a collection of splits, each using the full set of samples in that subset.
51 | 
52 | This provides for straightforward explanations for each row (known as *local* explanations) and for the models as a whole (known as *global* explanations). Though the final trees may be somewhat more complex than a standard decision tree of equal depth, as some nodes may be based on multiple splits, Additive Decision Trees tend to be more accurate than standard decision trees of equal depth, and simpler than standard decision trees of equal accuracy. The explanations for individual rows may be presented simply through the corresponding decision paths, as with standard decision trees, though the final nodes may be based on averaging over multiple splits. The maximum number of splits aggregated together is configurable, but 4 or 5 is typically sufficient. In most cases, as well, all splits agree, and only one needs to be presented to the user. In fact, even where the splits disagree, the majority prediction may be presented as a single split. Therefore, the explanations are usually similar to those for standard decision trees, but with shorter decision paths.
53 | 
54 | This, then, produces a model where there are a small number of splits, ideally representing the true conditions, if any, in the model, followed by *additive nodes*, which are leaf nodes that average the predictions of multiple splits, providing more robust predictions. This reduces the need to split the data into progressively smaller subsets, each with less statistical significance.
55 | 
56 | AdditiveDecisionTrees, therefore, provide a simple form of ensembling, but one that still allows a single, interpretable model, easily supporting both global and local explanations. As the model still follows a simple tree structure, contrapositive explanations may be easily generated as well.
57 | 
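To make the aggregation at additive nodes concrete, the following toy sketch (illustrative only, not code from the library; the feature names, thresholds, and classes are invented) shows how a single additive node can combine the votes of several one-feature splits into one prediction:

```python
from collections import Counter

# A toy additive node: several simple axis-parallel splits, each voting for a class.
# The feature names, thresholds, and classes here are invented for illustration.
splits = [
    {"feature": "A", "threshold": 10.0, "below": "X", "above": "Y"},
    {"feature": "B", "threshold": 19.0, "below": "X", "above": "Z"},
    {"feature": "C", "threshold": 44.0, "below": "X", "above": "Y"},
]

def additive_node_predict(row):
    # Each split makes its own prediction based on a single feature and threshold.
    votes = [s["above"] if row[s["feature"]] >= s["threshold"] else s["below"] for s in splits]
    # The node's prediction is the majority vote over the per-split predictions.
    return Counter(votes).most_common(1)[0][0]

print(additive_node_predict({"A": 12.0, "B": 5.0, "C": 50.0}))  # two of three splits vote for class Y
```

For regression, the same idea applies, but the per-split predictions are averaged rather than voted on.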
58 | ## Installation
59 | 
60 | The source code is provided in a single .py file which may be included in any project. It requires numpy, pandas, scikit-learn, matplotlib, and the info_gain package.
61 | 
62 | ## Local Interpretability
63 | 
64 | In standard decision trees, local explanations (explanations of the prediction for a single instance) are presented as the path from the root to the leaf node where the instance ends, with each split point on the path leading to this final decision. However, this can be misleading and confusing, as very often multiple rules would be just as, or more, appropriate. As standard decision trees split each node based on a single feature, they select the split that has the greatest information gain. The fact that other splits using other features are nearly as useful, and lead to similar though lower information gain, is lost, creating the impression that only the selected feature is relevant. So, as well as producing less stable trees, this process can lead to lower interpretability, an issue Additive Decision Trees reduce.
65 | 
66 | ## Pruning Algorithm
67 | The pruning algorithm executes after a tree, similar to a standard decision tree, is constructed. The pruning algorithm seeks to reduce significant sub-trees within the tree to single additive nodes, based on a small set of simple rules (comparable to the rules used in standard decision trees, but such that the additive nodes use multiple such rules).
68 | 
69 | The algorithm behaves similarly to most pruning algorithms, starting at the bottom, at the leaves, and working towards the root node. At each node, a decision is made to either leave the node as is, or convert it to an additive node, that is, a node combining multiple data splits.
70 | 
71 | At each node, the accuracy of the tree is evaluated on the training data given the current split, then again treating this node as an additive node. If the accuracy is higher with this node set as an additive node, it is set as such, and all nodes below it are removed. This node itself may later be removed, if a node above it is converted to an additive node. Testing indicates a very significant proportion of sub-trees benefit from being aggregated in this way.
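The loop below is a simplified sketch of this bottom-up pass. The `tree` object and its methods are hypothetical stand-ins used only for illustration; the actual logic lives in `create_additive_nodes()` in AdditiveDecisionTree.py.

```python
# Simplified sketch of the pruning pass. The tree object and its method names here
# are hypothetical; the real implementation works on parallel node arrays.
def prune_to_additive_nodes(tree, X_train, y_train, score_fn):
    # Visit internal nodes from the leaves up toward the root.
    for node in tree.internal_nodes_bottom_up():
        base_score = score_fn(y_train, tree.predict(X_train))

        # Temporarily convert this node into an additive node that aggregates the
        # strongest candidate splits recorded when the node was originally split.
        tree.convert_to_additive(node)
        additive_score = score_fn(y_train, tree.predict(X_train))

        if additive_score >= base_score:
            tree.remove_subtree_below(node)    # keep the additive node
        else:
            tree.restore_single_split(node)    # revert to the original single split
```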
72 | 
73 | ## Inference at Additive Nodes
74 | An additive node is a terminal node where the predictions are based on multiple splits. To make a prediction, when an additive node is reached, a prediction is made based on each split, and these are then aggregated into a single prediction: for classification the splits vote and the majority class is taken, while for regression the per-split predictions are averaged.
75 | 
76 | Standard, non-additive leaf nodes behave as in any other decision tree, producing a classification estimate based on the majority class in the node's subspace, or a regression estimate based on the average value in the node's subspace.
77 | 
78 | ## Evaluation Metrics
79 | To evaluate the effectiveness of the tool we considered both accuracy (macro F1 score for classification and normalized root mean squared error (NRMSE) for regression) and interpretability, measured by the size of the tree. Details regarding the complexity metric are included below.
80 | 
81 | To evaluate, we compared against standard decision trees, both where the two models used default hyperparameters and where both used a grid search to estimate the best parameters. [DatasetsEvaluator](https://github.com/Brett-Kennedy/DatasetsEvaluator) was used to collect the datasets and run the tests. This allowed evaluating on a large number of datasets (100 were used) without bias, as the datasets were randomly selected from OpenML.
82 | 
83 | Results for classification on 100 datasets:
84 | 
85 | | Model | Avg. f1_macro | Avg. Train-Test Gap | Avg. Fit Time | Avg. Complexity |
86 | | ---- | ---- | ---- | ---- | ---- |
87 | | DT | 0.634 | 0.359 | 0.0172 | 251.893 |
88 | | ADT | 0.617 | 0.156 | 3.991 | 39.907 |
89 | 
90 | Here 'ADT' refers to Additive Decision Trees. The Train-Test Gap was found by subtracting the macro F1 score on the test set from that on the train set, and is used to estimate overfitting. ADT models suffered considerably less from over-fitting.
91 | 
92 | Additive Decision Trees performed very similarly to standard decision trees with respect to accuracy, and in many cases do better. Their complexity is, however, considerably lower. This allows users to understand the models while considering fewer overall rules.
93 | 
94 | Results over 100 classification datasets:
95 | ![Result Plot](https://github.com/Brett-Kennedy/AdditiveDecisionTree/blob/main/Results/results_24_05_2024_10_17_47_plot.png)
96 | 
97 | The first plot tracks the 100 datasets on the x-axis, with F1 score (macro) on the y-axis. The second tracks the same 100 datasets on the x-axis, and model complexity on the y-axis.
98 | 
99 | This shows, in the first plot, model accuracy (higher is better) and, in the second plot, model complexity (lower is better). It can be seen that, compared to standard decision trees, at least for the 100 random datasets tested, AdditiveDecisionTrees are competitive in terms of accuracy, and consistently better in terms of complexity (and thus interpretability), though alternative measures of model complexity could be used.
100 | 
101 | ## Examples
102 | 
103 | AdditiveDecisionTrees follow the standard sklearn fit-predict API framework.
104 | 
105 | ```python
106 | from sklearn.datasets import load_iris
107 | from sklearn.model_selection import train_test_split
108 | from AdditiveDecisionTree import AdditiveDecisionTreeClasssifier
109 | 
110 | # The estimators expect pandas objects (the implementation uses .loc and .columns)
111 | X, y = load_iris(return_X_y=True, as_frame=True)
112 | X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
113 | 
114 | adt = AdditiveDecisionTreeClasssifier()
115 | adt.fit(X_train, y_train)
116 | y_pred_test = adt.predict(X_test)
117 | ```
118 | 
119 | ## Example Files
120 | Two example files are provided.
121 | 
122 | [**Simple_Example_Additive_Decision_Tree**](https://github.com/Brett-Kennedy/AdditiveDecisionTree/blob/main/examples/Simple_Example_Additive_Decision_Tree.ipynb)
123 | 
124 | This is a notebook providing some simple examples using the model.
125 | 
126 | [**Accuracy_Test_Additive_Tree.py**](https://github.com/Brett-Kennedy/AdditiveDecisionTree/blob/main/examples/Accuracy_Test_Additive_Decision_Tree.py)
127 | 
128 | This is a Python file intended to test the accuracy and model complexity of AdditiveDecisionTrees compared to sklearn decision trees, evaluated over 100 datasets, for both classification and regression problems. To provide a fair comparison, tests are performed where both models use default parameters and where both use CV grid search to estimate the optimal parameters. Results for an execution of this file are included in the Results folder.
129 | 
130 | ## Methods
131 | 
132 | ### AdditiveDecisionTree
133 | 
134 | ```
135 | adt = AdditiveDecisionTreeClasssifier(min_samples_split=8,
136 |                                       min_samples_leaf=6,
137 |                                       max_depth=np.inf,
138 |                                       verbose_level=0,
139 |                                       allow_additive_nodes=True,
140 |                                       max_added_splits_per_node=5)
141 | ```
142 | 
143 | #### Parameters
144 | 
145 | **min_samples_split**: int
146 | 
147 | As with standard decision trees. The minimum number of samples in a node to allow further splitting.
148 | 
149 | **min_samples_leaf**: int
150 | 
151 | As with standard decision trees. The minimum number of samples that would remain in each resulting leaf node in order to allow further splitting.
152 | 
153 | **max_depth**: int
154 | 
155 | As with standard decision trees. The maximum path length from the root to any leaf node.
156 | 
157 | **verbose_level**: int
158 | 
159 | Controls the display of output during fitting and predicting.
160 | 
161 | **allow_additive_nodes**: bool
162 | 
163 | If False, behaves similarly to a standard decision tree.
164 | 
165 | **max_added_splits_per_node**: int
166 | 
167 | The maximum number of splits that may be included in any additive node.
168 | 
169 | ##
170 | 
171 | ### fit
172 | ```
173 | adt.fit(X, y)
174 | ```
175 | Fit an Additive Decision Tree to the training data provided.
176 | 
177 | #### Parameters
178 | **X**: 2d array-like of shape (n_samples, n_features); the current implementation expects a pandas DataFrame
179 | 
180 | **y**: array-like of shape (n_samples); the current implementation expects a pandas Series
181 | 
182 | ##
183 | 
184 | ### predict
185 | ```
186 | y_pred = adt.predict(X)
187 | ```
188 | 
189 | Predict the class labels for the provided data.
190 | 
191 | ##
192 | 
193 | ### output_tree
194 | ```
195 | adt.output_tree()
196 | ```
197 | 
198 | Outputs a summary of the full tree, including all nodes.
199 | 
200 | ##
201 | 
202 | ### get_model_complexity
203 | ```
204 | adt.get_model_complexity()
205 | ```
206 | 
207 | Outputs a score indicating the number of nodes, counting additive nodes based on the number of splits they aggregate.
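As a quick sketch of how this can be used (reusing `adt`, `X_train`, and `y_train` from the example above, and assuming a standard sklearn tree is fit on the same data for comparison):

```python
from sklearn import tree

# Compare the ADT complexity score to the node count of a standard decision tree
# fit on the same training data.
dt = tree.DecisionTreeClassifier(random_state=0)
dt.fit(X_train, y_train)

print("Standard decision tree nodes:", dt.tree_.node_count)
print("Additive decision tree complexity:", adt.get_model_complexity())
```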
208 | 
209 | 
210 | ##
211 | 
212 | ### get_explanations
213 | ```
214 | adt.get_explanations(X, y)
215 | ```
216 | 
217 | Outputs the decision path for each record included in X.
218 | 
219 | 
220 | ## Interpretability Metric
221 | The evaluation uses a straightforward approach to measuring the global complexity of models, that is, the complexity of the overall description of the model (as opposed to local complexity, which measures the complexity of the explanations for individual rows). For standard decision trees, it simply uses the number of nodes (a common metric, though others, such as the number of leaf nodes, are also commonly used). For additive trees, we do this as well, but count each additive node as many times as there are splits aggregated at that node. We therefore measure the total number of comparisons of feature values to thresholds (the number of splits), regardless of whether the results are aggregated or not. Future work will consider additional metrics.
222 | 
--------------------------------------------------------------------------------
/Results/results_18_08_2021_14_26_34_heatmap.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Brett-Kennedy/AdditiveDecisionTree/e637d7e195ea45c4f9a18329c4de3b3870f709e9/Results/results_18_08_2021_14_26_34_heatmap.png
--------------------------------------------------------------------------------
/Results/results_18_08_2021_14_26_34_summarized.csv:
--------------------------------------------------------------------------------
1 | Model,Feature Engineering Description,Avg f1_macro,Avg. Train-Test Gap,Avg. Fit Time,Avg. Complexity
2 | ADT,,0.617875525955198,0.15660414872125247,3.9917161989212038,39.906666666666666
3 | DT,,0.6344495042480763,0.3597553756012678,0.008185651302337647,251.89333333333332
4 | 
--------------------------------------------------------------------------------
/Results/results_24_05_2024_10_17_47_plot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Brett-Kennedy/AdditiveDecisionTree/e637d7e195ea45c4f9a18329c4de3b3870f709e9/Results/results_24_05_2024_10_17_47_plot.png
--------------------------------------------------------------------------------
/examples/Accuracy_Test_Additive_Decision_Tree.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from sklearn import tree
3 | from warnings import filterwarnings
4 | 
5 | import sys
6 | sys.path.insert(0, 'C:\python_projects\AdditiveDecisionTree_project\AdditiveDecisionTree')
7 | from AdditiveDecisionTree import AdditiveDecisionTreeClasssifier, AdditiveDecisionTreeRegressor
8 | 
9 | sys.path.insert(0, 'C:\python_projects\DatasetsEvaluator_project\DatasetsEvaluator')
10 | import DatasetsEvaluator as de
11 | 
12 | filterwarnings('ignore')
13 | np.random.seed(0)
14 | 
15 | 
16 | # These specify how many datasets are used in the tests below. Ideally about 50 to 100 datasets would be used,
17 | # but these may be set lower. Set to 0 to skip tests.
18 | NUM_DATASETS_CLASSIFICATION_DEFAULT = 100
19 | NUM_DATASETS_CLASSIFICATION_GRID_SEARCH = 0
20 | NUM_DATASETS_REGRESSION_DEFAULT = 0
21 | NUM_DATASETS_REGRESSION_GRID_SEARCH = 0
22 | 
23 | 
24 | def print_header(test_name):
25 |     stars = "*****************************************************"
26 |     print(f"\n\n{stars}\n{test_name}\n{stars}")
27 | 
28 | 
29 | def test_classification_default_parameters(datasets_tester, partial_result_folder, results_folder):
30 |     print_header("Classification with default parameters")
31 | 
32 |     dt = tree.DecisionTreeClassifier(random_state=0)
33 |     adt = AdditiveDecisionTreeClasssifier(allow_additive_nodes=True)
34 | 
35 |     summary_df, saved_file_name = datasets_tester.run_tests(
36 |         estimators_arr=[
37 |             ("DT", "", "Default", dt),
38 |             ("ADT", "", "with additive", adt),
39 |         ],
40 |         num_cv_folds=3,
41 |         show_warnings=False,
42 |         partial_result_folder=partial_result_folder,
43 |         results_folder=results_folder,
44 |         run_parallel=True)
45 | 
46 |     datasets_tester.summarize_results(summary_df, 'Avg f1_macro', saved_file_name, results_folder)
47 |     datasets_tester.plot_results(summary_df, 'Avg f1_macro', saved_file_name, results_folder)
48 | 
49 | 
50 | def test_classification_grid_search(datasets_tester, partial_result_folder, results_folder):
51 |     # As this takes much longer than testing with the default parameters, we test with fewer datasets. Note that,
52 |     # although run_tests_parameter_search() uses CV to search for the best hyperparameters, it uses a single
53 |     # train-test split to evaluate the resulting model, so the predictions are evaluated quickly, though with
54 |     # more variability than if CV were used for the evaluation as well.
55 | 
56 |     print_header("Classification with grid search for best parameters")
57 | 
58 |     dt = tree.DecisionTreeClassifier(random_state=0)
59 |     adt = AdditiveDecisionTreeClasssifier(allow_additive_nodes=True)
60 | 
61 |     orig_parameters = {
62 |         'max_depth': (3, 4, 5, 6)
63 |     }
64 |     adt_parameters = {
65 |         'max_depth': (3, 4, 5, 6)
66 |     }
67 | 
68 |     # This provides an example using some non-default parameters.
69 |     summary_df, saved_file_name = datasets_tester.run_tests_parameter_search(
70 |         estimators_arr=[
71 |             ("DT", "Original Features", "", dt),
72 |             ("ADT", "", "", adt)
73 |         ],
74 |         parameters_arr=[orig_parameters, adt_parameters],
75 |         num_cv_folds=5,
76 |         show_warnings=False,
77 |         results_folder=results_folder,
78 |         partial_result_folder=partial_result_folder,
79 |         run_parallel=True)
80 | 
81 |     datasets_tester.summarize_results(summary_df, 'f1_macro', saved_file_name, results_folder)
82 |     datasets_tester.plot_results(summary_df, 'f1_macro', saved_file_name, results_folder)
83 | 
84 | 
85 | def test_regression_default_parameters(datasets_tester, partial_result_folder, results_folder):
86 |     print_header("Regression with default parameters")
87 | 
88 |     dt = tree.DecisionTreeRegressor(random_state=0)
89 |     adt = AdditiveDecisionTreeRegressor(allow_additive_nodes=True)
90 | 
91 |     summary_df, saved_file_name = datasets_tester.run_tests(
92 |         estimators_arr=[
93 |             ("DT", "Original Features", "Default", dt),
94 |             ("ADT", "", "Default", adt)],
95 |         num_cv_folds=3,
96 |         show_warnings=True,
97 |         results_folder=results_folder,
98 |         partial_result_folder=partial_result_folder,
99 |         run_parallel=True)
100 | 
101 |     datasets_tester.summarize_results(summary_df, 'Avg NRMSE', saved_file_name, results_folder)
102 |     datasets_tester.plot_results(summary_df, 'Avg NRMSE', saved_file_name, results_folder)
103 | 
104 | 
105 | def test_regression_grid_search(datasets_tester, partial_result_folder, results_folder):
106 |     # As this takes much longer than testing with the default parameters, we test with fewer datasets. Note that,
107 |     # although run_tests_parameter_search() uses CV to search for the best hyperparameters, it uses a single
108 |     # train-test split to evaluate the resulting model, so the predictions are evaluated quickly, though with
109 |     # more variability than if CV were used for the evaluation as well.
110 | 
111 |     print_header("Regression with grid search for best parameters")
112 | 
113 |     dt = tree.DecisionTreeRegressor(random_state=0)
114 |     adt = AdditiveDecisionTreeRegressor(allow_additive_nodes=True)
115 | 
116 |     orig_parameters = {
117 |         'max_depth': (3, 4, 5, 6, 100)
118 |     }
119 | 
120 |     adt_parameters = {
121 |         'max_depth': (3, 4, 5, 6, 100)
122 |     }
123 | 
124 |     # This provides an example using some non-default parameters.
125 |     summary_df, saved_file_name = datasets_tester.run_tests_parameter_search(
126 |         estimators_arr=[
127 |             ("DT", "Original Features", "", dt),
128 |             ("ADT", "", "", adt)],
129 |         parameters_arr=[orig_parameters, adt_parameters],
130 |         num_cv_folds=5,
131 |         show_warnings=False,
132 |         partial_result_folder=partial_result_folder,
133 |         results_folder=results_folder,
134 |         run_parallel=True)
135 | 
136 |     datasets_tester.summarize_results(summary_df, 'NRMSE', saved_file_name, results_folder)
137 |     datasets_tester.plot_results(summary_df, 'NRMSE', saved_file_name, results_folder)
138 | 
139 | 
140 | def main():
141 |     cache_folder = "c:\\dataset_cache"
142 |     partial_result_folder = "c:\\intermediate_results"
143 |     results_folder = "c:\\results"
144 | 
145 |     # These datasets are a bit slower to test, so are excluded from some tests.
146 |     exclude_list = ["oil_spill", "fri_c4_1000_50", "fri_c3_1000_50", "fri_c1_1000_50", "fri_c2_1000_50", "waveform-5000",
147 |                     "mfeat-zernike", "auml_eml_1_b"]
148 | 
149 |     #########################################################################
150 |     # Get datasets for classification tests
151 |     #########################################################################
152 |     datasets_tester = de.DatasetsTester(
153 |         problem_type="classification",
154 |         path_local_cache=cache_folder
155 |     )
156 |     matching_datasets = datasets_tester.find_datasets(
157 |         min_num_classes=2,
158 |         max_num_classes=20,
159 |         min_num_minority_class=5,
160 |         max_num_minority_class=np.inf,
161 |         min_num_features=0,
162 |         max_num_features=np.inf,
163 |         min_num_instances=500,
164 |         max_num_instances=5_000,
165 |         min_num_numeric_features=2,
166 |         max_num_numeric_features=50,
167 |         min_num_categorical_features=0,
168 |         max_num_categorical_features=50)
169 |     print("Number matching datasets found: ", len(matching_datasets))
170 |     # Note: some datasets may have errors loading or testing.
171 | datasets_tester.collect_data( 172 | max_num_datasets_used=NUM_DATASETS_CLASSIFICATION_DEFAULT, 173 | use_automatic_exclude_list=True, 174 | exclude_list=exclude_list, 175 | save_local_cache=True, 176 | check_local_cache=True) 177 | 178 | test_classification_default_parameters(datasets_tester, partial_result_folder, results_folder) 179 | 180 | datasets_tester.collect_data( 181 | max_num_datasets_used=NUM_DATASETS_CLASSIFICATION_GRID_SEARCH, 182 | use_automatic_exclude_list=True, 183 | exclude_list=exclude_list, 184 | save_local_cache=True, 185 | check_local_cache=True) 186 | 187 | test_classification_grid_search(datasets_tester, partial_result_folder, results_folder) 188 | 189 | ######################################################################### 190 | # Collect & test with the regression datasets 191 | ######################################################################### 192 | datasets_tester = de.DatasetsTester( 193 | problem_type="regression", 194 | path_local_cache=cache_folder 195 | ) 196 | matching_datasets = datasets_tester.find_datasets( 197 | min_num_features=0, 198 | max_num_features=np.inf, 199 | min_num_instances=500, 200 | max_num_instances=5_000, 201 | min_num_numeric_features=2, 202 | max_num_numeric_features=50, 203 | min_num_categorical_features=0, 204 | max_num_categorical_features=50) 205 | datasets_tester.collect_data( 206 | max_num_datasets_used=NUM_DATASETS_REGRESSION_DEFAULT, 207 | exclude_list=exclude_list, 208 | use_automatic_exclude_list=True, 209 | preview_data=False, 210 | save_local_cache=True, 211 | check_local_cache=True) 212 | 213 | test_regression_default_parameters(datasets_tester, partial_result_folder, results_folder) 214 | 215 | datasets_tester.collect_data( 216 | max_num_datasets_used=NUM_DATASETS_REGRESSION_GRID_SEARCH, 217 | exclude_list=exclude_list, 218 | use_automatic_exclude_list=True, 219 | preview_data=False, 220 | save_local_cache=True, 221 | check_local_cache=True) 222 | 223 | test_regression_grid_search(datasets_tester, partial_result_folder, results_folder) 224 | 225 | 226 | if __name__ == "__main__": 227 | main() 228 | -------------------------------------------------------------------------------- /examples/Simple_Example_Additive_Decision_Tree.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "7283a3a8", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import pandas as pd\n", 11 | "import numpy as np\n", 12 | "from sklearn.datasets import load_iris, load_breast_cancer, load_wine, load_diabetes, make_regression \n", 13 | "from sklearn import tree\n", 14 | "from sklearn.model_selection import train_test_split, RandomizedSearchCV\n", 15 | "from sklearn.metrics import f1_score, mean_squared_error\n", 16 | "\n", 17 | "# If AdditiveDecisionTree.py is not in the current folder, specify the path \n", 18 | "import sys \n", 19 | "sys.path.insert(0, 'C:\\python_projects\\AdditiveDecisionTree_project\\AdditiveDecisionTree') \n", 20 | "from AdditiveDecisionTree import AdditiveDecisionTreeClasssifier, AdditiveDecisionTreeRegressor\n", 21 | "\n", 22 | "np.random.seed(0)" 23 | ] 24 | }, 25 | { 26 | "cell_type": "markdown", 27 | "id": "74761305", 28 | "metadata": {}, 29 | "source": [ 30 | "## Methods used to load the toy datasets" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": 2, 36 | "id": "80859ea5", 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "# 
Classification datasets \n", 41 | "\n", 42 | "def get_iris():\n", 43 | " iris = load_iris()\n", 44 | " X, y = iris.data, iris.target\n", 45 | " X = pd.DataFrame(X, columns=iris['feature_names'])\n", 46 | " y = pd.Series(y)\n", 47 | " return X, y\n", 48 | "\n", 49 | "def get_breast_cancer():\n", 50 | " X, y = load_breast_cancer(return_X_y=True, as_frame=True)\n", 51 | " return X, y\n", 52 | "\n", 53 | "def get_wine():\n", 54 | " X, y = load_wine(return_X_y=True, as_frame=True)\n", 55 | " return X, y\n", 56 | "\n", 57 | "# Regression datasets\n", 58 | "\n", 59 | "def get_diabetes():\n", 60 | " data = load_diabetes()\n", 61 | " X = pd.DataFrame(data.data, columns=data.feature_names)\n", 62 | " y = pd.Series(data.target)\n", 63 | " return X, y\n", 64 | "\n", 65 | "def get_make_regression():\n", 66 | " np.random.seed(0)\n", 67 | " X, y = make_regression(noise=0.0)\n", 68 | " X = pd.DataFrame(X)\n", 69 | " y = pd.Series(y)\n", 70 | " return X, y" 71 | ] 72 | }, 73 | { 74 | "cell_type": "markdown", 75 | "id": "f84fac78", 76 | "metadata": {}, 77 | "source": [ 78 | "## Example using sklearn's Decision Tree and AddtiveDecisionTree on toy datasets" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": 3, 84 | "id": "898c2ca6", 85 | "metadata": { 86 | "scrolled": false 87 | }, 88 | "outputs": [ 89 | { 90 | "name": "stdout", 91 | "output_type": "stream", 92 | "text": [ 93 | "\n", 94 | "Iris\n", 95 | "Standard DT: Training score: 1.0, Testing score: 0.97, Complexity: 13\n", 96 | "Additive DT: Training score: 0.96, Testing score: 0.88, Complexity: 5\n", 97 | "\n", 98 | "Wine\n", 99 | "Standard DT: Training score: 1.0, Testing score: 0.92, Complexity: 13\n", 100 | "Additive DT: Training score: 0.97, Testing score: 0.95, Complexity: 7\n", 101 | "\n", 102 | "Breast Cancer\n", 103 | "Standard DT: Training score: 0.99, Testing score: 0.92, Complexity: 23\n", 104 | "Additive DT: Training score: 0.97, Testing score: 0.91, Complexity: 11\n" 105 | ] 106 | } 107 | ], 108 | "source": [ 109 | "# Note: this provides only an example of using AdditiveDecisionTree and does not \n", 110 | "# properly test its accuracy. We can, though, see that in terms of test scores,\n", 111 | "# ADT (Additive Decision Trees) often do about the same as DT (standard Decsion\n", 112 | "# Trees), but sometimes one or the other does better. 
\n", 113 | "# Training scores are also show to give a sense of overfitting.\n", 114 | "\n", 115 | "# To estimate complexity for DTs, we use the number of nodes\n", 116 | "# To estimate complexity for ADTs, we call get_model_complexity(),\n", 117 | "# which is similar, but considers that additive nodes are more complex.\n", 118 | "\n", 119 | "def evaluate_model(clf, clf_desc, X_train, X_test, y_train, y_test):\n", 120 | " clf.fit(X_train, y_train)\n", 121 | " y_pred_train = clf.predict(X_train)\n", 122 | " score_train = f1_score(y_train, y_pred_train, average='macro')\n", 123 | " y_pred_test = clf.predict(X_test)\n", 124 | " score_test = f1_score(y_test, y_pred_test, average='macro')\n", 125 | " complexity = 0\n", 126 | " if hasattr(clf, \"get_model_complexity\"):\n", 127 | " complexity = clf.get_model_complexity()\n", 128 | " elif hasattr(clf, \"tree_\"):\n", 129 | " complexity = len(clf.tree_.feature)\n", 130 | " print(f\"{clf_desc}: Training score: {round(score_train,2)}, Testing score: {round(score_test,2)}, Complexity: {complexity}\")\n", 131 | "\n", 132 | " \n", 133 | "def evaluate_dataset(dataset_name, X,y):\n", 134 | " print(f\"\\n{dataset_name}\")\n", 135 | " X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)\n", 136 | "\n", 137 | " dt_1 = tree.DecisionTreeClassifier(max_depth=4, random_state=42)\n", 138 | " evaluate_model(dt_1, \"Standard DT\", X_train, X_test, y_train, y_test)\n", 139 | "\n", 140 | " adt = AdditiveDecisionTreeClasssifier(max_depth=4, allow_additive_nodes=True, verbose_level=0)\n", 141 | " evaluate_model(adt, \"Additive DT\", X_train, X_test, y_train, y_test)\n", 142 | " return adt\n", 143 | " \n", 144 | " \n", 145 | "X,y = get_iris()\n", 146 | "evaluate_dataset(\"Iris\", X,y)\n", 147 | "\n", 148 | "X,y = get_wine()\n", 149 | "evaluate_dataset(\"Wine\", X,y)\n", 150 | "\n", 151 | "X,y = get_breast_cancer()\n", 152 | "adt = evaluate_dataset(\"Breast Cancer\", X,y)" 153 | ] 154 | }, 155 | { 156 | "cell_type": "markdown", 157 | "id": "06e883f6", 158 | "metadata": {}, 159 | "source": [ 160 | "## Summary Output of the AdditiveDecisionTree" 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "execution_count": 4, 166 | "id": "575195ff", 167 | "metadata": {}, 168 | "outputs": [ 169 | { 170 | "name": "stdout", 171 | "output_type": "stream", 172 | "text": [ 173 | "\n", 174 | "********************************************************\n", 175 | "Generated Tree\n", 176 | "********************************************************\n", 177 | "\n", 178 | "# Nodes: 9\n", 179 | "\n", 180 | "Left Chidren:\n", 181 | "[1, 3, 5, -2, -2, 7, -2, -2, -2]\n", 182 | "\n", 183 | "Right Chidren:\n", 184 | "[2, 4, 6, -2, -2, 8, -2, -2, -2]\n", 185 | "\n", 186 | "# Rows: \n", 187 | "[426, 260, 166, 252, 8, 30, 136, 14, 16]\n", 188 | "\n", 189 | "Features:\n", 190 | "[7, 20, 23, -100, -2, 21, -2, -2, -2]\n", 191 | "\n", 192 | "Features in additive nodes:\n", 193 | "[[], [], [], [1, 13], [], [], [], [], []]\n", 194 | "\n", 195 | "Thresholds:\n", 196 | "[0.04891999997198582, 17.589999198913574, 785.7999877929688, 21.574999809265137, -2, 23.739999771118164, -2, -2, -2]\n", 197 | "\n", 198 | "Depths:\n", 199 | "[0, 1, 1, 2, 2, 2, 2, 3, 3]\n", 200 | "\n", 201 | "Can split: \n", 202 | "[True, True, True, True, True, True, True, True, True]\n", 203 | "\n", 204 | "Class counts:\n", 205 | "[[159, 267], [13, 247], [146, 20], [7, 245], [6, 2], [13, 17], [133, 3], [0, 14], [13, 3]]\n", 206 | "\n", 207 | "Leaf Class Counts:\n", 208 | "[[7, 245], [6, 2], [133, 3], [0, 14], [13, 
3]]\n", 209 | "\n", 210 | "Node igr: \n", 211 | "[0.4156254639152989, 0.2031712696855239, 0.29661687709662865, 0.18239393289682015, -2, 0.43241893359216155, -2, -2, -2]\n", 212 | "********************************************************\n", 213 | "\n" 214 | ] 215 | } 216 | ], 217 | "source": [ 218 | "# This continues the example with the Breast Cancer dataset.\n", 219 | "\n", 220 | "# The output to explain an Additive Decsion Tree is similar as for\n", 221 | "# scikit-learn decision trees, though has slighly more information.\n", 222 | "# For example, it provides the depth of each node and the class counts \n", 223 | "# in each node. \n", 224 | "\n", 225 | "# Here node 3 is an additive node. In the features list, it is specified\n", 226 | "# as feature -100. In the Features in addtivie nodes list, we see it\n", 227 | "# uses both feature 1 and feature 13. \n", 228 | "\n", 229 | "adt.output_tree()" 230 | ] 231 | }, 232 | { 233 | "cell_type": "markdown", 234 | "id": "23d9add9", 235 | "metadata": {}, 236 | "source": [ 237 | "## Explanations of Predictions" 238 | ] 239 | }, 240 | { 241 | "cell_type": "code", 242 | "execution_count": 5, 243 | "id": "369d85db", 244 | "metadata": { 245 | "scrolled": false 246 | }, 247 | "outputs": [ 248 | { 249 | "name": "stdout", 250 | "output_type": "stream", 251 | "text": [ 252 | "\n", 253 | "\n", 254 | "Initial distribution of classes: [0, 1]: [159, 267]\n", 255 | "\n", 256 | "\n", 257 | "...............................................................\n", 258 | "Prediction for row 0: 0 -- Correct\n", 259 | "...............................................................\n", 260 | "Path: [0, 2, 6]\n", 261 | "\n", 262 | "mean concave points is greater than 0.04891999997198582 \n", 263 | " (has value: 0.1471) --> (Class distribution: [146, 20]\n", 264 | "\n", 265 | "AND worst area is greater than 785.7999877929688 \n", 266 | " (has value: 2019.0) --> (Class distribution: [133, 3]\n", 267 | "where the majority class is: 0\n", 268 | "\n", 269 | "\n", 270 | "...............................................................\n", 271 | "Prediction for row 1: 0 -- Correct\n", 272 | "...............................................................\n", 273 | "Path: [0, 2, 6]\n", 274 | "\n", 275 | "mean concave points is greater than 0.04891999997198582 \n", 276 | " (has value: 0.07017) --> (Class distribution: [146, 20]\n", 277 | "\n", 278 | "AND worst area is greater than 785.7999877929688 \n", 279 | " (has value: 1956.0) --> (Class distribution: [133, 3]\n", 280 | "where the majority class is: 0\n", 281 | "\n", 282 | "\n", 283 | "...............................................................\n", 284 | "Prediction for row 2: 0 -- Correct\n", 285 | "...............................................................\n", 286 | "Path: [0, 2, 6]\n", 287 | "\n", 288 | "mean concave points is greater than 0.04891999997198582 \n", 289 | " (has value: 0.1279) --> (Class distribution: [146, 20]\n", 290 | "\n", 291 | "AND worst area is greater than 785.7999877929688 \n", 292 | " (has value: 1709.0) --> (Class distribution: [133, 3]\n", 293 | "where the majority class is: 0\n", 294 | "\n", 295 | "\n", 296 | "...............................................................\n", 297 | "Prediction for row 3: 0 -- Correct\n", 298 | "...............................................................\n", 299 | "Path: [0, 2, 5, 8]\n", 300 | "\n", 301 | "mean concave points is greater than 0.04891999997198582 \n", 302 | " (has value: 0.1052) --> (Class distribution: [146, 20]\n", 303 | "\n", 304 | 
"AND worst area is less than 785.7999877929688 \n", 305 | " (has value: 567.7) --> (Class distribution: [13, 17]\n", 306 | "\n", 307 | "AND worst texture is greater than 23.739999771118164 \n", 308 | " (has value: 26.5) --> (Class distribution: [13, 3]\n", 309 | "where the majority class is: 0\n", 310 | "\n", 311 | "\n", 312 | "...............................................................\n", 313 | "Prediction for row 4: 0 -- Correct\n", 314 | "...............................................................\n", 315 | "Path: [0, 2, 6]\n", 316 | "\n", 317 | "mean concave points is greater than 0.04891999997198582 \n", 318 | " (has value: 0.1043) --> (Class distribution: [146, 20]\n", 319 | "\n", 320 | "AND worst area is greater than 785.7999877929688 \n", 321 | " (has value: 1575.0) --> (Class distribution: [133, 3]\n", 322 | "where the majority class is: 0\n" 323 | ] 324 | } 325 | ], 326 | "source": [ 327 | "# This provides explanations (in the form of the decision path)\n", 328 | "# for the first five rows. \n", 329 | "\n", 330 | "exp_arr = adt.get_explanations(X[:5], y[:5])\n", 331 | "for exp in exp_arr: \n", 332 | " print(\"\\n\")\n", 333 | " print(exp)" 334 | ] 335 | }, 336 | { 337 | "cell_type": "code", 338 | "execution_count": 6, 339 | "id": "8b61b253", 340 | "metadata": {}, 341 | "outputs": [ 342 | { 343 | "name": "stdout", 344 | "output_type": "stream", 345 | "text": [ 346 | "\n", 347 | "\n", 348 | "Initial distribution of classes: [0, 1]: [159, 267]\n", 349 | "\n", 350 | "\n", 351 | "...............................................................\n", 352 | "Prediction for row 0: 1 -- Correct\n", 353 | "...............................................................\n", 354 | "Path: [0, 1, 3]\n", 355 | "\n", 356 | "mean concave points is less than 0.04891999997198582 \n", 357 | " (has value: 0.04781) --> (Class distribution: [13, 247]\n", 358 | "\n", 359 | "AND worst radius is less than 17.589999198913574 \n", 360 | " (has value: 15.11) --> (Class distribution: [7, 245]\n", 361 | "\n", 362 | "AND vote based on: \n", 363 | " 1: mean texture is less than 21.574999809265137\n", 364 | " (has value 14.36) --> (class distribution: [1, 209])\n", 365 | " 2: area error is less than 42.19000053405762\n", 366 | " (has value 23.56) --> (class distribution: [4, 243])\n", 367 | "The class with the most votes is 1\n" 368 | ] 369 | } 370 | ], 371 | "source": [ 372 | "# This gives an example (Row 19) where the decision path includes \n", 373 | "# node 3, which is an additive node. 
\n", 374 | "\n", 375 | "exp_arr = adt.get_explanations(X.loc[19:19], y.loc[19:19])\n", 376 | "for exp in exp_arr: \n", 377 | " print(\"\\n\")\n", 378 | " print(exp)" 379 | ] 380 | }, 381 | { 382 | "cell_type": "markdown", 383 | "id": "56aa13f9", 384 | "metadata": {}, 385 | "source": [ 386 | "## Example wtih Regression" 387 | ] 388 | }, 389 | { 390 | "cell_type": "code", 391 | "execution_count": 7, 392 | "id": "58a08f62", 393 | "metadata": { 394 | "scrolled": false 395 | }, 396 | "outputs": [ 397 | { 398 | "name": "stdout", 399 | "output_type": "stream", 400 | "text": [ 401 | "\n", 402 | "Diabetes\n", 403 | "Standard DT: Training MSE: 2281.54, Testing MSE: 4373.97, Complexity: 29\n", 404 | "Additive DT: Training MSE: 2159.58, Testing MSE: 4291.76, Complexity: 33\n", 405 | "\n", 406 | "Make Regression\n", 407 | "Standard DT: Training MSE: 3487.28, Testing MSE: 23856.35, Complexity: 17\n", 408 | "Additive DT: Training MSE: 3302.9, Testing MSE: 21077.32, Complexity: 20\n" 409 | ] 410 | } 411 | ], 412 | "source": [ 413 | "# Note: this provides only an example of using AdditiveDecisionTree and does \n", 414 | "# not properly test its accuracy\n", 415 | "\n", 416 | "# In these examples, the additive decision trees provide slightly lower errors\n", 417 | "# but slightly higher complexity.\n", 418 | "\n", 419 | "# In general, Additive Decision Trees tend to work better for classification \n", 420 | "# than regression at least with default hyperparameters.\n", 421 | "\n", 422 | "\n", 423 | "def evaluate_model(clf, clf_desc, X_train, X_test, y_train, y_test):\n", 424 | " clf.fit(X_train, y_train)\n", 425 | " y_pred_train = clf.predict(X_train)\n", 426 | " score_train = mean_squared_error(y_train, y_pred_train)\n", 427 | " y_pred_test = clf.predict(X_test)\n", 428 | " score_test = mean_squared_error(y_test, y_pred_test)\n", 429 | " complexity = 0\n", 430 | " if hasattr(clf, \"get_model_complexity\"):\n", 431 | " complexity = clf.get_model_complexity()\n", 432 | " elif hasattr(clf, \"tree_\"):\n", 433 | " complexity = len(clf.tree_.feature)\n", 434 | " print(f\"{clf_desc}: Training MSE: {round(score_train,2)}, Testing MSE: {round(score_test,2)}, Complexity: {complexity}\")\n", 435 | "\n", 436 | " \n", 437 | "def evaluate_dataset(dataset_name, X,y):\n", 438 | " print(f\"\\n{dataset_name}\")\n", 439 | " X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)\n", 440 | "\n", 441 | " dt_1 = tree.DecisionTreeRegressor(max_depth=4, min_samples_leaf=5, random_state=42)\n", 442 | " evaluate_model(dt_1, \"Standard DT\", X_train, X_test, y_train, y_test)\n", 443 | "\n", 444 | " adt = AdditiveDecisionTreeRegressor(max_depth=4, min_samples_leaf=5, allow_additive_nodes=True, verbose_level=0)\n", 445 | " evaluate_model(adt, \"Additive DT\", X_train, X_test, y_train, y_test)\n", 446 | " return adt\n", 447 | " \n", 448 | " \n", 449 | "X,y = get_diabetes()\n", 450 | "adt = evaluate_dataset(\"Diabetes\", X, y)\n", 451 | "\n", 452 | "X,y = get_make_regression()\n", 453 | "adt = evaluate_dataset(\"Make Regression\", X, y)" 454 | ] 455 | }, 456 | { 457 | "cell_type": "code", 458 | "execution_count": 8, 459 | "id": "a5066d1f", 460 | "metadata": {}, 461 | "outputs": [ 462 | { 463 | "name": "stdout", 464 | "output_type": "stream", 465 | "text": [ 466 | "\n", 467 | "********************************************************\n", 468 | "Generated Tree\n", 469 | "********************************************************\n", 470 | "\n", 471 | "# Nodes: 13\n", 472 | "\n", 473 | "Left Chidren:\n", 474 | "[1, 3, 5, -2, 
7, 9, -2, 11, -2, -2, -2, -2, -2]\n", 475 | "\n", 476 | "Right Chidren:\n", 477 | "[2, 4, 6, -2, 8, 10, -2, 12, -2, -2, -2, -2, -2]\n", 478 | "\n", 479 | "# Rows: \n", 480 | "[75, 53, 22, 7, 46, 16, 6, 31, 15, 11, 5, 25, 6]\n", 481 | "\n", 482 | "Features:\n", 483 | "[57, 53, 46, -2, 43, 41, -2, 14, -100, -100, -2, -2, -2]\n", 484 | "\n", 485 | "Features in additive nodes:\n", 486 | "[[], [], [], [], [], [], [], [], [32, 96], [72, 15, 79, 85, 96], [], [], []]\n", 487 | "\n", 488 | "Thresholds:\n", 489 | "[0.2633100152015686, -0.9771790504455566, 0.5912367105484009, -2, 0.3434883654117584, 1.1032692193984985, -2, 0.5415648818016052, -0.8923328518867493, -0.13171404972672462, -2, -2, -2]\n", 490 | "\n", 491 | "Depths:\n", 492 | "[0, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4]\n", 493 | "\n", 494 | "Can split: \n", 495 | "[True, True, True, True, True, True, True, True, True, True, True, True, True]\n", 496 | "\n", 497 | "Average target values:\n", 498 | "[-28.94558781986425, -72.30996714891963, 75.5231442001328, 66.46907443748522, -93.42851695554644, 35.158511395759106, 183.16216501179596, -125.21029964481193, -27.746166064397826, -6.244026611424963, 126.24409501156406, -154.15960673535238, -4.588186767560032]\n", 499 | "\n", 500 | "Average target values in additive nodes:[[], [], [], [], [], [], [], [], [(-132.02410519598638, 10.173084528907104), (-80.16582925228708, 32.16202043604704)], [(21.68185255974749, -80.7130377345515), (-75.58467168872085, 33.37919914702983), (36.54293366212939, -57.58837893969019), (33.37919914702983, -75.58467168872085), (-68.33470394239224, 29.23636043484205)], [], [], []]\n", 501 | "********************************************************\n", 502 | "\n" 503 | ] 504 | } 505 | ], 506 | "source": [ 507 | "adt.output_tree()" 508 | ] 509 | }, 510 | { 511 | "cell_type": "markdown", 512 | "id": "a9859947", 513 | "metadata": {}, 514 | "source": [ 515 | "## Example Tuning Hyperparameters with a Cross Validated Grid Search" 516 | ] 517 | }, 518 | { 519 | "cell_type": "code", 520 | "execution_count": 9, 521 | "id": "d6555065", 522 | "metadata": { 523 | "scrolled": false 524 | }, 525 | "outputs": [ 526 | { 527 | "name": "stdout", 528 | "output_type": "stream", 529 | "text": [ 530 | "test_score: 4277.794998844322\n", 531 | "best estimator: min_samples_split: 25, min_samples_leaf: 15, max_depth: 5, allow_additive_nodes: True, max_added_splits_per_node: 5\n" 532 | ] 533 | } 534 | ], 535 | "source": [ 536 | "# Note: this can be several minutes to execute.\n", 537 | "\n", 538 | "X,y = get_diabetes()\n", 539 | "X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)\n", 540 | "\n", 541 | "parameters = {\n", 542 | " 'min_samples_split': (5,10,25,50), \n", 543 | " 'min_samples_leaf': (5,10,15),\n", 544 | " 'max_depth': (4,5,6,7),\n", 545 | " 'allow_additive_nodes': (True, False),\n", 546 | " 'max_added_splits_per_node': (2,3,4,5,10)\n", 547 | "}\n", 548 | "\n", 549 | "estimator = AdditiveDecisionTreeRegressor(max_depth=4, min_samples_leaf=5)\n", 550 | "gs_estimator = RandomizedSearchCV(estimator, parameters, scoring='neg_mean_squared_error',n_iter=100)\n", 551 | "gs_estimator.fit(X_train, y_train)\n", 552 | "y_pred = gs_estimator.predict(X_test)\n", 553 | "test_score = mean_squared_error(list(y_pred), list(y_test)) \n", 554 | "\n", 555 | "print(\"test_score: \", test_score)\n", 556 | "print(\"best estimator: \", gs_estimator.best_estimator_)" 557 | ] 558 | }, 559 | { 560 | "cell_type": "code", 561 | "execution_count": 10, 562 | "id": "c8678ad5", 563 | "metadata": {}, 564 
| "outputs": [ 565 | { 566 | "data": { 567 | "text/plain": [ 568 | "41" 569 | ] 570 | }, 571 | "execution_count": 10, 572 | "metadata": {}, 573 | "output_type": "execute_result" 574 | } 575 | ], 576 | "source": [ 577 | "# Create an instance of the best model found during tuning\n", 578 | "\n", 579 | "adt = AdditiveDecisionTreeRegressor(\n", 580 | " min_samples_split=25, \n", 581 | " min_samples_leaf=15, \n", 582 | " max_depth=5, \n", 583 | " allow_additive_nodes=True, \n", 584 | " max_added_splits_per_node=5)\n", 585 | "adt.fit(X_train, y_train)\n", 586 | "\n", 587 | "adt.get_model_complexity()" 588 | ] 589 | }, 590 | { 591 | "cell_type": "code", 592 | "execution_count": null, 593 | "id": "0e0b7a91", 594 | "metadata": {}, 595 | "outputs": [], 596 | "source": [] 597 | } 598 | ], 599 | "metadata": { 600 | "kernelspec": { 601 | "display_name": "Python 3 (ipykernel)", 602 | "language": "python", 603 | "name": "python3" 604 | }, 605 | "language_info": { 606 | "codemirror_mode": { 607 | "name": "ipython", 608 | "version": 3 609 | }, 610 | "file_extension": ".py", 611 | "mimetype": "text/x-python", 612 | "name": "python", 613 | "nbconvert_exporter": "python", 614 | "pygments_lexer": "ipython3", 615 | "version": "3.9.6" 616 | } 617 | }, 618 | "nbformat": 4, 619 | "nbformat_minor": 5 620 | } 621 | --------------------------------------------------------------------------------