├── InterpretableDecisionTreeClassifier.py ├── InterpretableDecisionTreeRegression.py ├── README.md ├── demo ├── __init__.py ├── run_demo_classification.py ├── run_demo_classifier_comparison.py ├── run_demo_regression.py ├── uci_comparison.py └── uci_loader.py ├── example_dt.png └── treeutils.py /InterpretableDecisionTreeClassifier.py: -------------------------------------------------------------------------------- 1 | import re 2 | from sklearn.metrics import f1_score, make_scorer 3 | from sklearn.tree import DecisionTreeClassifier 4 | from treeutils import simplify_tree, tree_to_code 5 | 6 | class IDecisionTreeClassifier(DecisionTreeClassifier): 7 | """A decision tree classifier. 8 | 9 | Read more in the :ref:`User Guide `. 10 | 11 | Parameters 12 | ---------- 13 | criterion : string, optional (default="gini") 14 | The function to measure the quality of a split. Supported criteria are 15 | "gini" for the Gini impurity and "entropy" for the information gain. 16 | 17 | splitter : string, optional (default="best") 18 | The strategy used to choose the split at each node. Supported 19 | strategies are "best" to choose the best split and "random" to choose 20 | the best random split. 21 | 22 | max_features : int, float, string or None, optional (default=None) 23 | The number of features to consider when looking for the best split: 24 | 25 | - If int, then consider `max_features` features at each split. 26 | - If float, then `max_features` is a percentage and 27 | `int(max_features * n_features)` features are considered at each 28 | split. 29 | - If "auto", then `max_features=sqrt(n_features)`. 30 | - If "sqrt", then `max_features=sqrt(n_features)`. 31 | - If "log2", then `max_features=log2(n_features)`. 32 | - If None, then `max_features=n_features`. 33 | 34 | Note: the search for a split does not stop until at least one 35 | valid partition of the node samples is found, even if it requires to 36 | effectively inspect more than ``max_features`` features. 37 | 38 | max_depth : int or None, optional (default=None) 39 | The maximum depth of the tree. If None, then nodes are expanded until 40 | all leaves are pure or until all leaves contain less than 41 | min_samples_split samples. 42 | 43 | min_samples_split : int, float, optional (default=2) 44 | The minimum number of samples required to split an internal node: 45 | 46 | - If int, then consider `min_samples_split` as the minimum number. 47 | - If float, then `min_samples_split` is a percentage and 48 | `ceil(min_samples_split * n_samples)` are the minimum 49 | number of samples for each split. 50 | 51 | .. versionchanged:: 0.18 52 | Added float values for percentages. 53 | 54 | min_samples_leaf : int, float, optional (default=1) 55 | The minimum number of samples required to be at a leaf node: 56 | 57 | - If int, then consider `min_samples_leaf` as the minimum number. 58 | - If float, then `min_samples_leaf` is a percentage and 59 | `ceil(min_samples_leaf * n_samples)` are the minimum 60 | number of samples for each node. 61 | 62 | .. versionchanged:: 0.18 63 | Added float values for percentages. 64 | 65 | min_weight_fraction_leaf : float, optional (default=0.) 66 | The minimum weighted fraction of the sum total of weights (of all 67 | the input samples) required to be at a leaf node. Samples have 68 | equal weight when sample_weight is not provided. 69 | 70 | max_leaf_nodes : int or None, optional (default=None) 71 | Grow a tree with ``max_leaf_nodes`` in best-first fashion. 72 | Best nodes are defined as relative reduction in impurity. 
73 | If None then unlimited number of leaf nodes. 74 | 75 | class_weight : dict, list of dicts, "balanced" or None, optional (default=None) 76 | Weights associated with classes in the form ``{class_label: weight}``. 77 | If not given, all classes are supposed to have weight one. For 78 | multi-output problems, a list of dicts can be provided in the same 79 | order as the columns of y. 80 | 81 | The "balanced" mode uses the values of y to automatically adjust 82 | weights inversely proportional to class frequencies in the input data 83 | as ``n_samples / (n_classes * np.bincount(y))`` 84 | 85 | For multi-output, the weights of each column of y will be multiplied. 86 | 87 | Note that these weights will be multiplied with sample_weight (passed 88 | through the fit method) if sample_weight is specified. 89 | 90 | random_state : int, RandomState instance or None, optional (default=None) 91 | If int, random_state is the seed used by the random number generator; 92 | If RandomState instance, random_state is the random number generator; 93 | If None, the random number generator is the RandomState instance used 94 | by `np.random`. 95 | 96 | min_impurity_split : float, optional (default=1e-7) 97 | Threshold for early stopping in tree growth. A node will split 98 | if its impurity is above the threshold, otherwise it is a leaf. 99 | 100 | .. versionadded:: 0.18 101 | 102 | presort : bool, optional (default=False) 103 | Whether to presort the data to speed up the finding of best splits in 104 | fitting. For the default settings of a decision tree on large 105 | datasets, setting this to true may slow down the training process. 106 | When using either a smaller dataset or a restricted depth, this may 107 | speed up the training. 108 | 109 | Attributes 110 | ---------- 111 | classes_ : array of shape = [n_classes] or a list of such arrays 112 | The classes labels (single output problem), 113 | or a list of arrays of class labels (multi-output problem). 114 | 115 | feature_importances_ : array of shape = [n_features] 116 | The feature importances. The higher, the more important the 117 | feature. The importance of a feature is computed as the (normalized) 118 | total reduction of the criterion brought by that feature. It is also 119 | known as the Gini importance [4]_. 120 | 121 | max_features_ : int, 122 | The inferred value of max_features. 123 | 124 | n_classes_ : int or list 125 | The number of classes (for single output problems), 126 | or a list containing the number of classes for each 127 | output (for multi-output problems). 128 | 129 | n_features_ : int 130 | The number of features when ``fit`` is performed. 131 | 132 | n_outputs_ : int 133 | The number of outputs when ``fit`` is performed. 134 | 135 | tree_ : Tree object 136 | The underlying Tree object. 137 | 138 | See also 139 | -------- 140 | DecisionTreeRegressor 141 | 142 | References 143 | ---------- 144 | 145 | .. [1] https://en.wikipedia.org/wiki/Decision_tree_learning 146 | 147 | .. [2] L. Breiman, J. Friedman, R. Olshen, and C. Stone, "Classification 148 | and Regression Trees", Wadsworth, Belmont, CA, 1984. 149 | 150 | .. [3] T. Hastie, R. Tibshirani and J. Friedman. "Elements of Statistical 151 | Learning", Springer, 2009. 152 | 153 | .. [4] L. Breiman, and A. 
Cutler, "Random Forests", 154 | http://www.stat.berkeley.edu/~breiman/RandomForests/cc_home.htm 155 | 156 | Examples 157 | -------- 158 | >>> from sklearn.datasets import load_iris 159 | >>> from sklearn.model_selection import cross_val_score 160 | >>> from sklearn.tree import DecisionTreeClassifier 161 | >>> clf = DecisionTreeClassifier(random_state=0) 162 | >>> iris = load_iris() 163 | >>> cross_val_score(clf, iris.data, iris.target, cv=10) 164 | ... # doctest: +SKIP 165 | ... 166 | array([ 1. , 0.93..., 0.86..., 0.93..., 0.93..., 167 | 0.93..., 0.93..., 1. , 0.93..., 1. ]) 168 | """ 169 | def __init__(self, 170 | criterion="gini", 171 | splitter="best", 172 | max_depth=None, 173 | min_samples_split=2, 174 | min_samples_leaf=1, 175 | min_weight_fraction_leaf=0., 176 | max_features=None, 177 | random_state=None, 178 | max_leaf_nodes=None, 179 | min_impurity_split=1e-7, 180 | class_weight=None, 181 | presort=False, 182 | verbose=False, 183 | acceptable_score_drop=0.01): 184 | 185 | super(IDecisionTreeClassifier, self).__init__( 186 | criterion=criterion, 187 | splitter=splitter, 188 | max_depth=max_depth, 189 | min_samples_split=min_samples_split, 190 | min_samples_leaf=min_samples_leaf, 191 | min_weight_fraction_leaf=min_weight_fraction_leaf, 192 | max_features=max_features, 193 | max_leaf_nodes=max_leaf_nodes, 194 | class_weight=class_weight, 195 | random_state=random_state, 196 | min_impurity_split=min_impurity_split, 197 | presort=presort) 198 | 199 | self.acceptable_score_drop = acceptable_score_drop 200 | self.verbose = verbose 201 | 202 | 203 | def fit(self, X, y, sample_weight=None, check_input=True, 204 | X_idx_sorted=None, scorer=make_scorer(f1_score, greater_is_better=True)): 205 | """Build a decision tree classifier from the training set (X, y). 206 | 207 | Parameters 208 | ---------- 209 | X : array-like or sparse matrix, shape = [n_samples, n_features] 210 | The training input samples. Internally, it will be converted to 211 | ``dtype=np.float32`` and if a sparse matrix is provided 212 | to a sparse ``csc_matrix``. 213 | 214 | y : array-like, shape = [n_samples] or [n_samples, n_outputs] 215 | The target values (class labels) as integers or strings. 216 | 217 | sample_weight : array-like, shape = [n_samples] or None 218 | Sample weights. If None, then samples are equally weighted. Splits 219 | that would create child nodes with net zero or negative weight are 220 | ignored while searching for a split in each node. Splits are also 221 | ignored if they would result in any single class carrying a 222 | negative weight in either child node. 223 | 224 | check_input : boolean, (default=True) 225 | Allow to bypass several input checking. 226 | Don't use this parameter unless you know what you do. 227 | 228 | X_idx_sorted : array-like, shape = [n_samples, n_features], optional 229 | The indexes of the sorted training input samples. If many tree 230 | are grown on the same dataset, this allows the ordering to be 231 | cached between trees. If None, the data will be sorted here. 232 | Don't use this parameter unless you know what to do. 233 | 234 | Returns 235 | ------- 236 | self : object 237 | Returns self. 
238 | """ 239 | 240 | super(IDecisionTreeClassifier, self).fit( 241 | X, y, 242 | sample_weight=sample_weight, 243 | check_input=check_input, 244 | X_idx_sorted=X_idx_sorted) 245 | 246 | simplify_tree(self, X, y, scorer, self.acceptable_score_drop, verbose=self.verbose) 247 | 248 | return self 249 | 250 | def __str__(self): 251 | feature_names = ["ft"+str(i) for i in range(len(self.feature_importances_))] 252 | return self.tostring(feature_names) 253 | 254 | def tostring(self, feature_names, decimals=4): 255 | return re.sub('\n\s+return', ' return', tree_to_code(self, feature_names, decimals=decimals)) -------------------------------------------------------------------------------- /InterpretableDecisionTreeRegression.py: -------------------------------------------------------------------------------- 1 | from sklearn.metrics import make_scorer, r2_score 2 | from sklearn.tree import DecisionTreeRegressor 3 | from treeutils import simplify_tree, tree_to_code 4 | import re 5 | 6 | class IDecisionTreeRegressor(DecisionTreeRegressor): 7 | """A decision tree regressor. 8 | 9 | Read more in the :ref:`User Guide `. 10 | 11 | Parameters 12 | ---------- 13 | criterion : string, optional (default="mse") 14 | The function to measure the quality of a split. Supported criteria 15 | are "mse" for the mean squared error, which is equal to variance 16 | reduction as feature selection criterion, and "mae" for the mean 17 | absolute error. 18 | 19 | .. versionadded:: 0.18 20 | Mean Absolute Error (MAE) criterion. 21 | 22 | splitter : string, optional (default="best") 23 | The strategy used to choose the split at each node. Supported 24 | strategies are "best" to choose the best split and "random" to choose 25 | the best random split. 26 | 27 | max_features : int, float, string or None, optional (default=None) 28 | The number of features to consider when looking for the best split: 29 | 30 | - If int, then consider `max_features` features at each split. 31 | - If float, then `max_features` is a percentage and 32 | `int(max_features * n_features)` features are considered at each 33 | split. 34 | - If "auto", then `max_features=n_features`. 35 | - If "sqrt", then `max_features=sqrt(n_features)`. 36 | - If "log2", then `max_features=log2(n_features)`. 37 | - If None, then `max_features=n_features`. 38 | 39 | Note: the search for a split does not stop until at least one 40 | valid partition of the node samples is found, even if it requires to 41 | effectively inspect more than ``max_features`` features. 42 | 43 | max_depth : int or None, optional (default=None) 44 | The maximum depth of the tree. If None, then nodes are expanded until 45 | all leaves are pure or until all leaves contain less than 46 | min_samples_split samples. 47 | 48 | min_samples_split : int, float, optional (default=2) 49 | The minimum number of samples required to split an internal node: 50 | 51 | - If int, then consider `min_samples_split` as the minimum number. 52 | - If float, then `min_samples_split` is a percentage and 53 | `ceil(min_samples_split * n_samples)` are the minimum 54 | number of samples for each split. 55 | 56 | .. versionchanged:: 0.18 57 | Added float values for percentages. 58 | 59 | min_samples_leaf : int, float, optional (default=1) 60 | The minimum number of samples required to be at a leaf node: 61 | 62 | - If int, then consider `min_samples_leaf` as the minimum number. 
63 | - If float, then `min_samples_leaf` is a percentage and 64 | `ceil(min_samples_leaf * n_samples)` are the minimum 65 | number of samples for each node. 66 | 67 | .. versionchanged:: 0.18 68 | Added float values for percentages. 69 | 70 | min_weight_fraction_leaf : float, optional (default=0.) 71 | The minimum weighted fraction of the sum total of weights (of all 72 | the input samples) required to be at a leaf node. Samples have 73 | equal weight when sample_weight is not provided. 74 | 75 | max_leaf_nodes : int or None, optional (default=None) 76 | Grow a tree with ``max_leaf_nodes`` in best-first fashion. 77 | Best nodes are defined as relative reduction in impurity. 78 | If None then unlimited number of leaf nodes. 79 | 80 | random_state : int, RandomState instance or None, optional (default=None) 81 | If int, random_state is the seed used by the random number generator; 82 | If RandomState instance, random_state is the random number generator; 83 | If None, the random number generator is the RandomState instance used 84 | by `np.random`. 85 | 86 | min_impurity_split : float, optional (default=1e-7) 87 | Threshold for early stopping in tree growth. If the impurity 88 | of a node is below the threshold, the node is a leaf. 89 | 90 | .. versionadded:: 0.18 91 | 92 | presort : bool, optional (default=False) 93 | Whether to presort the data to speed up the finding of best splits in 94 | fitting. For the default settings of a decision tree on large 95 | datasets, setting this to true may slow down the training process. 96 | When using either a smaller dataset or a restricted depth, this may 97 | speed up the training. 98 | 99 | Attributes 100 | ---------- 101 | feature_importances_ : array of shape = [n_features] 102 | The feature importances. 103 | The higher, the more important the feature. 104 | The importance of a feature is computed as the 105 | (normalized) total reduction of the criterion brought 106 | by that feature. It is also known as the Gini importance [4]_. 107 | 108 | max_features_ : int, 109 | The inferred value of max_features. 110 | 111 | n_features_ : int 112 | The number of features when ``fit`` is performed. 113 | 114 | n_outputs_ : int 115 | The number of outputs when ``fit`` is performed. 116 | 117 | tree_ : Tree object 118 | The underlying Tree object. 119 | 120 | See also 121 | -------- 122 | DecisionTreeClassifier 123 | 124 | References 125 | ---------- 126 | 127 | .. [1] https://en.wikipedia.org/wiki/Decision_tree_learning 128 | 129 | .. [2] L. Breiman, J. Friedman, R. Olshen, and C. Stone, "Classification 130 | and Regression Trees", Wadsworth, Belmont, CA, 1984. 131 | 132 | .. [3] T. Hastie, R. Tibshirani and J. Friedman. "Elements of Statistical 133 | Learning", Springer, 2009. 134 | 135 | .. [4] L. Breiman, and A. Cutler, "Random Forests", 136 | http://www.stat.berkeley.edu/~breiman/RandomForests/cc_home.htm 137 | 138 | Examples 139 | -------- 140 | >>> from sklearn.datasets import load_boston 141 | >>> from sklearn.model_selection import cross_val_score 142 | >>> from sklearn.tree import DecisionTreeRegressor 143 | >>> boston = load_boston() 144 | >>> regressor = DecisionTreeRegressor(random_state=0) 145 | >>> cross_val_score(regressor, boston.data, boston.target, cv=10) 146 | ... # doctest: +SKIP 147 | ... 
148 | array([ 0.61..., 0.57..., -0.34..., 0.41..., 0.75..., 149 | 0.07..., 0.29..., 0.33..., -1.42..., -1.77...]) 150 | """ 151 | def __init__(self, 152 | criterion="mse", 153 | splitter="best", 154 | max_depth=None, 155 | min_samples_split=2, 156 | min_samples_leaf=1, 157 | min_weight_fraction_leaf=0., 158 | max_features=None, 159 | random_state=None, 160 | max_leaf_nodes=None, 161 | min_impurity_split=1e-7, 162 | presort=False, 163 | verbose=False, 164 | acceptable_score_drop=0.01): 165 | 166 | super(IDecisionTreeRegressor, self).__init__( 167 | criterion=criterion, 168 | splitter=splitter, 169 | max_depth=max_depth, 170 | min_samples_split=min_samples_split, 171 | min_samples_leaf=min_samples_leaf, 172 | min_weight_fraction_leaf=min_weight_fraction_leaf, 173 | max_features=max_features, 174 | max_leaf_nodes=max_leaf_nodes, 175 | random_state=random_state, 176 | min_impurity_split=min_impurity_split, 177 | presort=presort) 178 | 179 | self.acceptable_score_drop = acceptable_score_drop 180 | self.verbose = verbose 181 | 182 | def fit(self, X, y, sample_weight=None, check_input=True, 183 | X_idx_sorted=None, scorer=make_scorer(r2_score, greater_is_better=True)): 184 | """Build a decision tree regressor from the training set (X, y). 185 | 186 | Parameters 187 | ---------- 188 | X : array-like or sparse matrix, shape = [n_samples, n_features] 189 | The training input samples. Internally, it will be converted to 190 | ``dtype=np.float32`` and if a sparse matrix is provided 191 | to a sparse ``csc_matrix``. 192 | 193 | y : array-like, shape = [n_samples] or [n_samples, n_outputs] 194 | The target values (real numbers). Use ``dtype=np.float64`` and 195 | ``order='C'`` for maximum efficiency. 196 | 197 | sample_weight : array-like, shape = [n_samples] or None 198 | Sample weights. If None, then samples are equally weighted. Splits 199 | that would create child nodes with net zero or negative weight are 200 | ignored while searching for a split in each node. 201 | 202 | check_input : boolean, (default=True) 203 | Allow to bypass several input checking. 204 | Don't use this parameter unless you know what you do. 205 | 206 | X_idx_sorted : array-like, shape = [n_samples, n_features], optional 207 | The indexes of the sorted training input samples. If many tree 208 | are grown on the same dataset, this allows the ordering to be 209 | cached between trees. If None, the data will be sorted here. 210 | Don't use this parameter unless you know what to do. 211 | 212 | Returns 213 | ------- 214 | self : object 215 | Returns self. 
216 | """ 217 | 218 | super(IDecisionTreeRegressor, self).fit( 219 | X, y, 220 | sample_weight=sample_weight, 221 | check_input=check_input, 222 | X_idx_sorted=X_idx_sorted) 223 | 224 | simplify_tree(self, X, y, scorer, self.acceptable_score_drop, verbose=self.verbose) 225 | 226 | return self 227 | 228 | def __str__(self): 229 | feature_names = ["ft"+str(i) for i in range(len(self.feature_importances_))] 230 | return self.tostring(feature_names) 231 | 232 | def tostring(self, feature_names, decimals=4): 233 | return re.sub('\n\s+return', ' return', tree_to_code(self, feature_names, transform_to_probabilities=False, decimals=decimals)) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Highly interpretable, sklearn-compatible classifier and regressor based on simplified decision trees 2 | =============== 3 | 4 | Implementation of a simple, greedy optimization approach to simplifying decision trees for better interpretability and readability. 5 | 6 | It produces small decision trees, which makes trained classifiers **easily interpretable to human experts**, and is competitive with state of the art classifiers such as random forests or SVMs. 7 | 8 | Turns out to frequently outperform [Bayesian Rule Lists](https://github.com/tmadl/sklearn-expertsys) in terms of accuracy and computational complexity, and Logistic Regression in terms of interpretability. 9 | Note that a feature selection method is highly advisable on large datasets, as the runtime directly depends on the number of features. 10 | 11 | Usage 12 | =============== 13 | 14 | The project requires [scikit-learn](http://scikit-learn.org/stable/install.html). 15 | 16 | The included `InterpretableDecisionTreeClassifier` and `InterpretableDecisionTreeRegressor` both work as scikit-learn estimators, with a `model.fit(X,y)` method which takes training data `X` (numpy array or pandas DataFrame) and labels `y`. 17 | 18 | The learned rules of a trained model can be displayed simply by casting the object as a string, e.g. `print model`, or by using the `model.tostring(feature_names=['feature1', 'feature2', ], decimals=1)` method and specifying names for the features and, optionally, the rounding precision. 19 | 20 | Example output on `breast cancer` dataset: 21 | 22 | ```python 23 | # Data from https://archive.ics.uci.edu/ml/datasets/Breast+Cancer+Wisconsin+(Diagnostic) 24 | def breast_cancer_probability(radius, texture, perimeter, area, smoothness, compactness, concavity, concave_points, symmetry, fractal_dimension): 25 | if perimeter <= 2.5: 26 | if concavity <= 5.5: return 0.012 27 | else: return 0.875 28 | else: 29 | if area <= 2.5: return 0.217 30 | else: return 0.917 31 | ``` 32 | 33 | Tree size and complexity can be reduced by two parameters: 34 | * the classical [`max_depth` parameter](http://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeClassifier.html#sklearn.tree.DecisionTreeClassifier), and 35 | * the `acceptable_score_drop` parameter, which specifies the maximum acceptable reduction in classifier performance (higher means more branches can be pruned). By default, the F1-score is used for this purpose. A `scorer` parameter can be passed to the `fit` method if optimization based on other scores is preferred. 
36 | 37 | Self-contained usage example: 38 | 39 | ```python 40 | import numpy as np 41 | from sklearn.datasets.samples_generator import make_moons 42 | from sklearn.model_selection._validation import cross_val_score 43 | from InterpretableDecisionTreeClassifier import * 44 | 45 | X, y = make_moons(300, noise=0.4) 46 | print("Decision Tree F1 score:",np.mean(cross_val_score(DecisionTreeClassifier(), X, y, scoring="f1"))) 47 | print("Interpretable Decision Tree F1 score:",np.mean(cross_val_score(IDecisionTreeClassifier(), X, y, scoring="f1"))) 48 | 49 | """ 50 | **Output:** 51 | Decision Tree F1 score: 0.81119342213567125 52 | Interpretable Decision Tree F1 score: 0.8416950113378685 53 | """ 54 | ``` 55 | 56 | ![Simplified decision tree on moons dataset](example_dt.png) 57 | 58 | Comparison with other sklearn classifiers (can be reproduced with `run_demo_classifier_comparison.py'. Rule List Classifier: see [here](https://github.com/tmadl/sklearn-expertsys)) 59 | 60 | ```python 61 | D.Tree3 F1 D.Tree5 F1 Interpr.D.Tree3 F1 Interpr.D.Tree5 F1 RuleListClassifier F1 Random Forest F1 62 | ========================================================================================================================================================== 63 | diabetes_scale 0.814 (SE=0.006) 0.808 (SE=0.007) 0.826 (SE=0.005) *0.833 (SE=0.005) 0.765 (SE=0.007) 0.793 (SE=0.006) 64 | breast-cancer 0.899 (SE=0.005) 0.912 (SE=0.005) 0.920 (SE=0.004) 0.917 (SE=0.004) 0.938 (SE=0.004) *0.946 (SE=0.004) 65 | uci-20070111 haberman 0.380 (SE=0.020) 0.305 (SE=0.019) 0.380 (SE=0.020) *0.404 (SE=0.015) 0.321 (SE=0.019) 0.268 (SE=0.017) 66 | heart 0.827 (SE=0.005) 0.800 (SE=0.005) 0.824 (SE=0.005) *0.828 (SE=0.006) 0.792 (SE=0.006) 0.808 (SE=0.008) 67 | liver-disorders 0.684 (SE=0.013) 0.610 (SE=0.017) *0.702 (SE=0.014) 0.670 (SE=0.016) 0.663 (SE=0.019) 0.635 (SE=0.016) 68 | 69 | ==== Interpretable DT for dataset `diabetes_scale' (lines of full DT: 24, lines of interpretable DT: 6, simplification factor: 0.25) ==== 70 | def probability_of_class_one(ft0, ft1, ft2, ft3, ft4, ft5, ft6, ft7): 71 | if ft1 <= 0.2814: return 0.8062 72 | else: 73 | if ft5 <= -0.1073: return 0.6842 74 | else: return 0.2754 75 | 76 | ==== end of DT for dataset `diabetes_scale'. F1 score: 0.835061262959 ==== 77 | 78 | ==== Interpretable DT for dataset `breast-cancer' (lines of full DT: 24, lines of interpretable DT: 8, simplification factor: 0.333333333333) ==== 79 | def probability_of_class_one(ft0, ft1, ft2, ft3, ft4, ft5, ft6, ft7, ft8, ft9): 80 | if ft2 <= 2.5: 81 | if ft6 <= 5.5: return 0.0122 82 | else: return 0.875 83 | else: 84 | if ft3 <= 2.5: return 0.2174 85 | else: return 0.9174 86 | 87 | ==== end of DT for dataset `breast-cancer'. F1 score: 0.936605316973 ==== 88 | 89 | WARNING: No target found. Taking last column of data matrix as target 90 | ==== Interpretable DT for dataset `uci-20070111 haberman' (lines of full DT: 21, lines of interpretable DT: 10, simplification factor: 0.47619047619) ==== 91 | def probability_of_class_one(ft0, ft1, ft2): 92 | if ft2 <= 4.5: 93 | if ft0 <= 77.5: return 0.1754 94 | else: return 1.0 95 | else: 96 | if ft0 <= 42.5: 97 | if ft2 <= 20.5: return 0.0833 98 | else: return 0.6667 99 | else: return 0.5902 100 | 101 | ==== end of DT for dataset `uci-20070111 haberman'. 
F1 score: 0.544217687075 ==== 102 | 103 | ==== Interpretable DT for dataset `heart' (lines of full DT: 24, lines of interpretable DT: 12, simplification factor: 0.5) ==== 104 | def probability_of_class_one(ft0, ft1, ft2, ft3, ft4, ft5, ft6, ft7, ft8, ft9, ft10, ft11, ft12): 105 | if ft12 <= 4.5: 106 | if ft2 <= 3.5: return 0.901 107 | else: 108 | if ft11 <= 0.5: return 0.8065 109 | else: return 0.15 110 | else: 111 | if ft11 <= 0.5: 112 | if ft8 <= 0.5: return 0.6897 113 | else: return 0.2083 114 | else: return 0.0923 115 | 116 | ==== end of DT for dataset `heart'. F1 score: 0.87459807074 ==== 117 | 118 | ==== Interpretable DT for dataset `liver-disorders' (lines of full DT: 24, lines of interpretable DT: 6, simplification factor: 0.25) ==== 119 | def probability_of_class_one(ft0, ft1, ft2, ft3, ft4, ft5): 120 | if ft4 <= 20.5: 121 | if ft2 <= 19.5: return 0.6833 122 | else: return 0.25 123 | else: return 0.678 124 | 125 | ==== end of DT for dataset `liver-disorders'. F1 score: 0.774193548387 ==== 126 | ``` -------------------------------------------------------------------------------- /demo/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tmadl/sklearn-interpretable-tree/4a5229018b6ededf8ae64996fc5dbf3a06338b9b/demo/__init__.py -------------------------------------------------------------------------------- /demo/run_demo_classification.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt, numpy as np 2 | from sklearn.datasets import make_moons 3 | from sklearn.tree import DecisionTreeClassifier 4 | from InterpretableDecisionTreeClassifier import IDecisionTreeClassifier 5 | from treeutils import tree_to_code 6 | from sklearn.model_selection._split import train_test_split 7 | from sklearn.metrics import f1_score 8 | 9 | X, y = make_moons(300, noise=0.4) 10 | Xtrain, Xtest, ytrain, ytest = train_test_split(X, y) 11 | 12 | clf1 = DecisionTreeClassifier(max_depth=4).fit(Xtrain,ytrain) 13 | clf2 = IDecisionTreeClassifier(max_depth=4).fit(Xtrain,ytrain) 14 | 15 | print("=== original decision tree ===") 16 | features = ["ft"+str(i) for i in range(X.shape[1])] 17 | print(tree_to_code(clf1, features)) # output large tree 18 | print("=== simplified (interpretable) decision tree ===") 19 | print(tree_to_code(clf2, features)) 20 | 21 | h = 0.02 22 | x_min, x_max = X[:, 0].min() - .5, X[:, 0].max() + .5 23 | y_min, y_max = X[:, 1].min() - .5, X[:, 1].max() + .5 24 | xx, yy = np.meshgrid(np.arange(x_min, x_max, h), 25 | np.arange(y_min, y_max, h)) 26 | 27 | 28 | plt.subplot(1,2,1) 29 | plt.title("original decision tree. F1: "+str(f1_score(ytest, clf1.predict(Xtest)))) 30 | Z = clf1.predict_proba(np.c_[xx.ravel(), yy.ravel()])[:, 1] 31 | Z = Z.reshape(xx.shape) 32 | plt.contourf(xx, yy, Z, alpha=.8) 33 | plt.scatter(X[:,0], X[:,1], c=y) 34 | 35 | plt.subplot(1,2,2) 36 | plt.title("simplified (interpretable) decision tree. 
F1: "+str(f1_score(ytest, clf2.predict(Xtest)))) 37 | Z = clf2.predict_proba(np.c_[xx.ravel(), yy.ravel()])[:, 1] 38 | Z = Z.reshape(xx.shape) 39 | plt.contourf(xx, yy, Z, alpha=.8) 40 | plt.scatter(X[:,0], X[:,1], c=y) 41 | plt.show() -------------------------------------------------------------------------------- /demo/run_demo_classifier_comparison.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | from uci_comparison import * 4 | from sklearn.ensemble.forest import RandomForestClassifier 5 | from InterpretableDecisionTreeClassifier import * 6 | 7 | estimators = { 8 | 'Random Forest': RandomForestClassifier(), 9 | 'D.Tree3': DecisionTreeClassifier(max_depth=3), 10 | 'Interpr.D.Tree3': IDecisionTreeClassifier(max_depth=3), 11 | 'D.Tree5': DecisionTreeClassifier(max_depth=5), 12 | 'Interpr.D.Tree5': IDecisionTreeClassifier(max_depth=5), 13 | } 14 | 15 | # optionally, pass a list of UCI dataset identifiers as the datasets parameter, e.g. datasets=['iris', 'diabetes'] 16 | # optionally, pass a dict of scoring functions as the metric parameter, e.g. metrics={'F1-score': f1_score} 17 | compare_estimators(estimators) 18 | 19 | print("") 20 | for d in comparison_datasets: 21 | # load 22 | try: 23 | X, y = getdataset(d) 24 | except: 25 | print("FAILED TO LOAD",d," - SKIPPING") 26 | continue 27 | # train 28 | clf = IDecisionTreeClassifier(max_depth=3).fit(X,y) 29 | # tostring 30 | itree = str(clf) 31 | # compare with full DT 32 | clf = DecisionTreeClassifier(max_depth=3).fit(X,y) 33 | fulltree = tree_to_code(clf, ["ft"+str(i) for i in range(X.shape[1])]) 34 | # statistics 35 | linestats = "lines of full DT: {}, lines of interpretable DT: {}, simplification factor: {}" 36 | linestats = linestats.format(len(fulltree.split("\n")), len(itree.split("\n")), 1.0/len(fulltree.split("\n"))*len(itree.split("\n"))) 37 | print("==== Interpretable DT for dataset `{}' ({}) ====".format(d,linestats)) 38 | # print 39 | print(itree) 40 | # add evaluation code 41 | ppred = [] 42 | ftlist = ",".join(["ft"+str(i) for i in range(X.shape[1])]) 43 | itree += "\nfor x in X:\n " 44 | itree += ftlist + " = x\n" 45 | itree += " ppred.append(probability_of_class_one(" + ftlist + "))\n" 46 | exec(itree) 47 | # evaluate 48 | ypred = 1*(np.array(ppred)>0.5) 49 | print("==== end of DT for dataset `{}'. 
F1 score: {} ====".format(d,f1_score(y,ypred))) 50 | print("") -------------------------------------------------------------------------------- /demo/run_demo_regression.py: -------------------------------------------------------------------------------- 1 | from sklearn import datasets 2 | from sklearn.model_selection import cross_val_predict 3 | from sklearn import linear_model 4 | import matplotlib.pyplot as plt 5 | from InterpretableDecisionTreeRegression import IDecisionTreeRegressor 6 | 7 | lr = linear_model.LinearRegression() 8 | boston = datasets.load_boston() 9 | y = boston.target 10 | 11 | predicted1 = cross_val_predict(lr, boston.data, y, cv=10) 12 | predicted2 = cross_val_predict(IDecisionTreeRegressor(max_depth=5), boston.data, y, cv=10) 13 | 14 | print(IDecisionTreeRegressor(max_depth=5).fit(boston.data, y)) 15 | 16 | fig, ax = plt.subplots() 17 | ax.scatter(y, predicted1, color='b') 18 | ax.scatter(y, predicted2, color='r') 19 | ax.plot([y.min(), y.max()], [y.min(), y.max()], 'k--', lw=4) 20 | ax.set_xlabel('Measured') 21 | ax.set_ylabel('Predicted') 22 | plt.show() -------------------------------------------------------------------------------- /demo/uci_comparison.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.metrics.classification import accuracy_score, f1_score 3 | import re, string 4 | from uci_loader import getdataset, tonumeric 5 | from sklearn.cross_validation import KFold 6 | from scipy.stats.stats import mannwhitneyu, ttest_ind 7 | 8 | comparison_datasets = [ 9 | "diabetes_scale", 10 | "breast-cancer", 11 | "uci-20070111 haberman", 12 | "heart", 13 | "liver-disorders", 14 | ] 15 | 16 | metrics = { 17 | #'Acc.': accuracy_score, 18 | 'F1score ': f1_score 19 | } 20 | 21 | def shorten(d): 22 | return "".join(re.findall("[^\W\d_]", d.lower().replace('datasets-', '').replace('uci', ''))) 23 | 24 | def print_results_table(results, rows, cols, cellsize=20): 25 | row_format =("{:>"+str(cellsize)+"}") * (len(cols) + 1) 26 | print row_format.format("", *cols) 27 | print "".join(["="]*cellsize*(len(cols)+1)) 28 | for rh, row in zip(rows, results): 29 | print row_format.format(rh, *row) 30 | 31 | def compare_estimators(estimators, datasets = comparison_datasets, metrics = metrics, n_cv_folds = 10, decimals = 3, cellsize = 22): 32 | if type(estimators) != dict: 33 | raise Exception("First argument needs to be a dict containing 'name': Estimator pairs") 34 | if type(metrics) != dict: 35 | raise Exception("Argument metrics needs to be a dict containing 'name': scoring function pairs") 36 | cols = [] 37 | for e in range(len(estimators)): 38 | for mname in metrics.keys(): 39 | cols.append(sorted(estimators.keys())[e]+" "+mname) 40 | 41 | rows = {} 42 | mean_results = {} 43 | std_results = {} 44 | found_datasets = [] 45 | for i in range(len(datasets)): 46 | d = datasets[i] 47 | print "comparing on dataset",i,d 48 | mean_result = [] 49 | std_result = [] 50 | try: 51 | X, y = getdataset(d) 52 | except: 53 | print "FAILED TO LOAD",d," - SKIPPING" 54 | continue 55 | found_datasets.append(i) 56 | rows[i] = shorten(d)+" (n="+str(len(y))+")" 57 | for e in range(len(estimators.keys())): 58 | est = estimators[sorted(estimators.keys())[e]] 59 | mresults = [[] for j in range(len(metrics))] 60 | for train_idx, test_idx in KFold(len(y), n_folds=n_cv_folds): 61 | est.fit(X[train_idx, :], y[train_idx]) 62 | y_pred = est.predict(X[test_idx, :]) 63 | for j in range(len(metrics)): 64 | try: 65 | 
mresults[j].append(metrics.values()[j](y[test_idx], y_pred)) 66 | except: 67 | mresults[j].append(metrics.values()[j](tonumeric(y[test_idx]), tonumeric(y_pred))) 68 | 69 | for j in range(len(metrics)): 70 | mean_result.append(np.mean(mresults[j])) 71 | std_result.append(np.std(mresults[j])/n_cv_folds) 72 | mean_results[i] = mean_result 73 | std_results[i] = std_result 74 | 75 | results = [] 76 | for i in found_datasets: 77 | result = [] 78 | 79 | sigstars = ["*"]*(len(estimators)*len(metrics)) 80 | for j in range(len(estimators)): 81 | for k in range(len(metrics)): 82 | for l in range(len(estimators)): 83 | #if j != l and mean_results[i][j*len(metrics)+k] < mean_results[i][l*len(metrics)+k] + 2*(std_results[i][j*len(metrics)+k] + std_results[i][l*len(metrics)+k]): 84 | if j != l and mean_results[i][j*len(metrics)+k] < mean_results[i][l*len(metrics)+k]: 85 | sigstars[j*len(metrics)+k] = "" 86 | 87 | for j in range(len(estimators)): 88 | for k in range(len(metrics)): 89 | result.append((sigstars[j*len(metrics)+k]+"%."+str(decimals)+"f (SE=%."+str(decimals)+"f)") % (mean_results[i][j*len(metrics)+k], std_results[i][j*len(metrics)+k])) 90 | results.append(result) 91 | 92 | print_results_table(results, rows, cols, cellsize) 93 | 94 | return mean_results, std_results, results 95 | -------------------------------------------------------------------------------- /demo/uci_loader.py: -------------------------------------------------------------------------------- 1 | from sklearn.datasets import fetch_mldata 2 | from sklearn.preprocessing import OneHotEncoder 3 | import numpy as np 4 | 5 | def dshape(X): 6 | if len(X.shape) == 1: 7 | return X.reshape(-1,1) 8 | else: 9 | return X if X.shape[0]>X.shape[1] else X.T 10 | 11 | def unpack(t): 12 | while type(t) == list or type(t) == np.ndarray: 13 | t = t[0] 14 | return t 15 | 16 | def tonumeric(lst): 17 | lbls = {} 18 | for t in lst.flatten(): 19 | if unpack(t) not in lbls: 20 | lbls[unpack(t)] = len(lbls.keys()) 21 | return np.array([lbls[unpack(t)] for t in lst.flatten()]) 22 | 23 | def getdataset(datasetname, onehot_encode_strings=True): 24 | # load 25 | dataset = fetch_mldata(datasetname) 26 | # get X and y 27 | X = dshape(dataset.data) 28 | try: 29 | target = dshape(dataset.target) 30 | except: 31 | print "WARNING: No target found. Taking last column of data matrix as target" 32 | target = X[:, -1] 33 | X = X[:, :-1] 34 | if len(target.shape)>1 and target.shape[1]>X.shape[1]: # some mldata sets are mixed up... 
35 | X = target 36 | target = dshape(dataset.data) 37 | if len(X.shape) == 1 or X.shape[1] <= 1: 38 | for k in dataset.keys(): 39 | if k != 'data' and k != 'target' and len(dataset[k]) == X.shape[1]: 40 | X = np.hstack((X, dshape(dataset[k]))) 41 | # one-hot for categorical values 42 | if onehot_encode_strings: 43 | cat_ft=[i for i in range(X.shape[1]) if 'str' in str(type(unpack(X[0,i]))) or 'unicode' in str(type(unpack(X[0,i])))] 44 | if len(cat_ft): 45 | for i in cat_ft: 46 | X[:,i] = tonumeric(X[:,i]) 47 | X = OneHotEncoder(categorical_features=cat_ft).fit_transform(X) 48 | # if sparse, make dense 49 | try: 50 | X = X.toarray() 51 | except: 52 | pass 53 | # convert y to monotonically increasing ints 54 | y = tonumeric(target).astype(int) 55 | return np.nan_to_num(X.astype(float)),y -------------------------------------------------------------------------------- /example_dt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tmadl/sklearn-interpretable-tree/4a5229018b6ededf8ae64996fc5dbf3a06338b9b/example_dt.png -------------------------------------------------------------------------------- /treeutils.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | from sklearn.metrics import make_scorer, f1_score 4 | from sklearn.tree import _tree 5 | import numpy as np, random as rnd 6 | 7 | import sys 8 | try: 9 | from cStringIO import StringIO 10 | except: 11 | from io import StringIO 12 | 13 | def simplify_tree(decision_tree, X, y, scorer=make_scorer(f1_score, greater_is_better=True), acceptable_score_drop=0.0, verbose=1): 14 | current_score, original_score = 0, 1 15 | 16 | while current_score != original_score: 17 | current_score = scorer(decision_tree, X, y) 18 | original_score = current_score 19 | tree = decision_tree.tree_ 20 | 21 | removed_branches = [] 22 | nodes = np.random.permutation(np.arange(tree.node_count)) 23 | for i in nodes: 24 | current_left, current_right = tree.children_left[i], tree.children_right[i] 25 | 26 | if tree.children_left[i] >= 0 or tree.children_right[i] >= 0: 27 | tree.children_left[i], tree.children_right[i] = -1, -1 28 | auc = scorer(decision_tree, X, y) 29 | if auc >= current_score - acceptable_score_drop: 30 | current_score = auc 31 | removed_branches.append(i) 32 | else: 33 | tree.children_left[i], tree.children_right[i] = current_left, current_right 34 | 35 | if verbose: 36 | print("Removed",len(removed_branches)," branches. current score: ", current_score) 37 | 38 | return decision_tree 39 | 40 | def tree_to_code(tree, feature_names, decimals=4, transform_to_probabilities=True): 41 | tree_ = tree.tree_ 42 | tree_feature_name = [ 43 | feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!" 
44 | for i in tree_.feature 45 | ] 46 | rounding_multiplier = np.power(10, decimals) 47 | round = lambda x: np.round(x*rounding_multiplier)/rounding_multiplier 48 | def leaf_value(value, samples=1): 49 | if transform_to_probabilities: 50 | return round(value / samples)[0][1] 51 | else: 52 | return value[0] 53 | 54 | stdout_ = sys.stdout 55 | sys.stdout = StringIO() 56 | 57 | #print("def probability_of_class_one({}):".format(", ".join(feature_names))+"") 58 | 59 | def recurse(node, depth): 60 | indent = " " * depth 61 | if tree_.feature[node] != _tree.TREE_UNDEFINED: 62 | name = tree_feature_name[node] 63 | threshold = tree_.threshold[node] 64 | 65 | if tree_.feature[tree_.children_left[node]] == _tree.TREE_UNDEFINED and \ 66 | tree_.feature[tree_.children_right[node]] == _tree.TREE_UNDEFINED and \ 67 | np.all(np.equal(tree_.value[tree_.children_left[node]], tree_.value[tree_.children_right[node]])): 68 | print("{}return {}".format(indent, leaf_value(tree_.value[node], tree_.weighted_n_node_samples[node]))) 69 | 70 | else: 71 | print("{}if {} <= {}:".format(indent, name, round(threshold))) 72 | recurse(tree_.children_left[node], depth + 1) 73 | print("{}else:".format(indent)) # # if {} > {}".format(indent, name, threshold)) 74 | recurse(tree_.children_right[node], depth + 1) 75 | else: 76 | if transform_to_probabilities: 77 | p = round(tree_.value[node] / tree_.weighted_n_node_samples[node])[0] 78 | else: 79 | p = tree_.value[node] 80 | 81 | print("{}return {}".format(indent, leaf_value(tree_.value[node], tree_.weighted_n_node_samples[node]))) 82 | 83 | recurse(0, 1) 84 | 85 | string = sys.stdout.getvalue() 86 | sys.stdout = stdout_ 87 | 88 | string = "def probability_of_class_one({}):".format(", ".join([f for f in feature_names if f in string]))+"\n"+string 89 | 90 | return string 91 | --------------------------------------------------------------------------------
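The `treeutils` helpers above should also work directly on an already-fitted, plain scikit-learn tree (the demo scripts call `tree_to_code` this way, and `simplify_tree` only touches the fitted `tree_` arrays). A minimal sketch under that assumption, for the sklearn versions this package targets; the moons data, `max_depth`, and score tolerance are arbitrary illustrative choices:

```python
from sklearn.datasets import make_moons
from sklearn.tree import DecisionTreeClassifier
from treeutils import simplify_tree, tree_to_code

X, y = make_moons(300, noise=0.4)
clf = DecisionTreeClassifier(max_depth=4).fit(X, y)

# Greedy in-place pruning: branches are collapsed as long as the default
# F1-based scorer does not drop by more than acceptable_score_drop.
simplify_tree(clf, X, y, acceptable_score_drop=0.01, verbose=0)

# Render the pruned tree as a small Python function over named features.
print(tree_to_code(clf, ["ft%d" % i for i in range(X.shape[1])], decimals=3))
```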