├── README.md ├── feat_drift ├── README.md ├── data │ ├── feature_imp1.csv │ └── feature_imp2.csv ├── feat_drift_test.ipynb ├── feature_drift_output │ └── feature_drift_201802_201803.html ├── images │ ├── feat_drift.PNG │ └── feat_drift.gif └── src │ ├── feature_drift_draw.py │ └── feature_drift_template.html ├── sankey ├── README.md ├── data │ └── titanic_train.csv ├── images │ ├── sankey_flow_col.PNG │ ├── sankey_flow_col_val.PNG │ ├── sankey_flow_same.PNG │ ├── sankey_flow_tab20.PNG │ └── sankey_flow_val.PNG ├── sankey_flow_output │ └── sankey_flow_Titanic.html ├── sankey_flow_test.ipynb └── src │ ├── generate_sankey_flow.py │ └── sankey_flow_template.html └── tree ├── README.md ├── data └── titanic_train.csv ├── generate_tree_test.ipynb ├── image ├── sankey_tree.gif └── simple_tree.gif ├── sankey_tree_output ├── sankey_tree_Iris_Tree.html └── sankey_tree_Titanic_Tree.html ├── simple_tree_output ├── simple_tree_Iris_Tree.html └── simple_tree_Titanic_Tree.html └── src ├── generate_tree.py ├── sankey_tree_template.html └── simple_tree_template.html /README.md: -------------------------------------------------------------------------------- 1 | # Nuance 2 | I use Nuance to collect the various visualization ideas I have had during my career as a data scientist. 3 | It is not a package yet, just a list of small ideas. Feel free to try them out! 4 | 5 | ## Why Nuance? 6 | **nuance n.** 7 | a subtle difference in meaning or opinion or attitude 8 | 9 | ## How to use? 10 | Please check the instructions in the corresponding folder. 11 | 12 | ## List of ideas 13 | 1. **simple tree**: [visualize a sklearn Decision Tree](https://github.com/SauceCat/Nuance/blob/master/tree) 14 | 15 |
16 | 2. **sankey tree**: [visualize a sklearn Decision Tree](https://github.com/SauceCat/Nuance/blob/master/tree) 17 | 18 |
19 | 3. **sankey flow**: [visualize a sankey flow](https://github.com/SauceCat/Nuance/tree/master/sankey) 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 |
30 |
31 | 4. **feature drift**: [visualize feature drift](https://github.com/SauceCat/Nuance/tree/master/feat_drift) 32 |
33 | 34 |     35 | 36 | -------------------------------------------------------------------------------- /feat_drift/README.md: -------------------------------------------------------------------------------- 1 | ## Visualize feature drift 2 | 3 | 4 | ## What's feature drift? 5 | **"Feature drifts occur whenever the relevance of a feature grows or shrinks for incoming instances."** 6 | Check this paper: [A survey on feature drift adaptation: Definition, benchmark, challenges and future directions](https://www.sciencedirect.com/science/article/pii/S0164121216301030) 7 | 8 | **In short:** if your training data is time-dependent, the set of important features selected by the same model may change considerably over time. 9 | The idea is to visualize the "feature drift" between two training sets, usually taken from different snapshots, so that the visualization can help detect feature drift over time. The expected inputs are two dataframes containing feature importance values. See [this notebook](https://github.com/SauceCat/Nuance/blob/master/feat_drift/feat_drift_test.ipynb) for more details. 10 | 11 | ## How to use? 12 | 1. Download the folder [**feat_drift**](https://github.com/SauceCat/Nuance/tree/master/feat_drift) 13 | 2. The folder structure: 14 | ``` 15 | feat_drift 16 | - src: source code for the visualization 17 | - data: test data 18 | - feature_drift_output: output folder (re-created if it was deleted) 19 | - feat_drift_test.ipynb: instructions and examples in a Jupyter notebook 20 | - ... 21 | ``` 22 | 3. Install `jinja2` 23 | ``` 24 | pip install jinja2 25 | ``` 26 | 4.
Use the feature drift visualization (it loads D3.js from the web, so a network connection is required): 27 | ```python 28 | import sys 29 | sys.path.insert(0, 'src/') 30 | import feature_drift_draw 31 | # feat_imp1, feat_imp2: feature importance DataFrames with columns 'feat_name' and 'imp' (see feat_drift_test.ipynb) 32 | feature_drift_draw.feature_drift_graph(feat_imp1=feat_imp1, feat_imp2=feat_imp2, feature_name='feat_name', imp_name='imp', 33 | ds_name1='training set', ds_name2='test set', graph_name='train_test', 34 | top_n=20, max_bar_width=300, bar_height=30, middle_gap=300, fontsize=12, color_dict=None) 35 | ``` 36 | 5. An HTML file will be generated in [feature_drift_output](https://github.com/SauceCat/Nuance/tree/master/feat_drift/feature_drift_output). Open it in any browser you like (I like Chrome anyway). 37 | 38 | ## Parameters 39 | ```python 40 | def feature_drift_graph(feat_imp1, feat_imp2, feature_name, imp_name, ds_name1, ds_name2, graph_name=None, 41 | top_n=None, max_bar_width=300, bar_height=30, middle_gap=300, fontsize=12, color_dict=None): 42 | """ 43 | Draw feature drift graph 44 | :param feat_imp1: feature importance dataframe #1 45 | :param feat_imp2: feature importance dataframe #2 46 | :param feature_name: column name of features 47 | :param imp_name: column name of importance value 48 | :param ds_name1: name of dataset #1 49 | :param ds_name2: name of dataset #2 50 | :param graph_name: suffix of the output file, feature_drift_<graph_name>.html (feature_drift_output.html if None) 51 | :param top_n: show only the top_n features of each dataset (all features if None) 52 | :param max_bar_width: maximum bar width 53 | :param bar_height: bar height 54 | :param middle_gap: gap between the two groups of bars 55 | :param fontsize: font size 56 | :param color_dict: color dictionary with keys 'drop', 'up_or_stable', 'disappear', 'appear' 57 | """ 58 | ``` 59 | -------------------------------------------------------------------------------- /feat_drift/data/feature_imp1.csv: 
-------------------------------------------------------------------------------- 1 | feat_name,imp 2 | feat_0,13.3436942857 3 | feat_1,15.800973571400002 4 | feat_2,21.0843816667 5 | feat_3,48.5161874074 6 | feat_4,113.18283999999998 7 | feat_5,19.0311508852 8 | feat_6,31.0144289552 9 | feat_7,44.0083055556 10 | 
feat_8,21.79879273 11 | feat_9,18.5955840707 12 | feat_10,40.70143425 13 | feat_11,44.3106395238 14 | feat_12,58.8463938636 15 | feat_13,26.846384736799997 16 | feat_14,17.89738 17 | feat_15,20.8532046377 18 | feat_16,109.927572857 19 | feat_17,45.3217483929 20 | feat_18,30.324107931 21 | feat_19,19.0750985714 22 | feat_20,15.9053894444 23 | feat_21,43.1775085 24 | feat_22,16.1404971724 25 | feat_23,30.933264415700002 26 | feat_24,25.120187 27 | feat_25,19.1116620577 28 | feat_26,23.112241081100002 29 | feat_27,175.000243176 30 | feat_28,39.8828694737 31 | feat_29,54.21781872729999 32 | feat_30,23.4141987356 33 | feat_31,30.3297 34 | feat_32,39.556396 35 | feat_33,18.1330640678 36 | feat_34,44.2011352174 37 | feat_35,20.492 38 | feat_36,14.839914088099999 39 | feat_37,24.425098869 40 | feat_38,23.087398800000006 41 | feat_39,15.864219090899999 42 | feat_40,36.5840980508 43 | feat_41,52.3627042029 44 | feat_42,17.157752380999998 45 | feat_43,159.114082 46 | feat_44,30.9611028058 47 | feat_45,46.2027222222 48 | feat_46,37.8220857851 49 | feat_47,45.8025334091 50 | feat_48,50.6727130435 51 | feat_49,55.133470512799995 52 | feat_50,33.0315573684 53 | feat_51,23.7468116667 54 | feat_52,11.451292186 55 | feat_53,31.55620125 56 | feat_54,16.6248694444 57 | feat_55,29.06870825 58 | feat_56,24.9980913953 59 | feat_57,15.538474 60 | feat_58,16.471437055 61 | feat_59,20.3773357143 62 | feat_60,23.528037551 63 | feat_61,65.3004364286 64 | feat_62,45.0990146667 65 | feat_63,17.89653 66 | feat_64,16.2706778333 67 | feat_65,18.4892417005 68 | feat_66,24.78189 69 | feat_67,162.19046175399998 70 | feat_68,34.09889375 71 | feat_69,24.2538968056 72 | feat_70,430.755799091 73 | feat_71,36.1628342105 74 | feat_72,96.30485301739999 75 | feat_73,26.820805714299997 76 | feat_74,13.4593211864 77 | feat_75,65.3470578947 78 | feat_76,26.874275 79 | feat_77,60.66391 80 | feat_78,18.2826344444 81 | feat_79,139.530246267 82 | 
-------------------------------------------------------------------------------- /feat_drift/data/feature_imp2.csv: -------------------------------------------------------------------------------- 1 | feat_name,imp 2 | feat_0,10.8216213571 3 | feat_1,23.4163501481 4 | feat_2,11.565690294100001 5 | feat_3,48.9171478125 6 | feat_4,113.584958947 7 | feat_5,11.0962521419 8 | feat_6,17.8663032911 9 | feat_7,26.624295454499997 10 | feat_8,8.02320777095 11 | feat_9,11.2232443967 12 | feat_10,24.5941423598 13 | feat_11,28.699749 14 | feat_12,25.775422679000002 15 | feat_13,16.1959695 16 | feat_14,11.397097619 17 | feat_15,10.6688073224 18 | feat_16,146.73557250000005 19 | feat_17,23.4428398347 20 | feat_18,11.718083461500001 21 | feat_19,13.567881818699998 22 | feat_20,8.789489411760002 23 | feat_21,30.6499793103 24 | feat_22,10.708858439 25 | feat_23,16.2798533926 26 | feat_24,8.307419928060003 27 | feat_25,11.1156423881 28 | feat_26,10.9661604717 29 | feat_27,28.9956640463 30 | feat_28,11.607423439200002 31 | feat_29,35.8433884211 32 | feat_30,11.9170976667 33 | feat_31,7.815233333330001 34 | feat_32,22.894614 35 | feat_33,12.3991074603 36 | feat_34,16.485753913 37 | feat_35,11.297780625 38 | feat_36,8.3932906449 39 | feat_37,15.230949728299999 40 | feat_38,8.202025200000001 41 | feat_39,10.4006885714 42 | feat_40,12.918274803800001 43 | feat_41,41.7556285075 44 | feat_42,8.27308114035 45 | feat_43,90.61743875 46 | feat_44,16.160808051900002 47 | feat_45,39.1513645833 48 | feat_46,28.195249158200003 49 | feat_47,11.925391071400002 50 | feat_48,39.5888431379 51 | feat_49,25.952234090900003 52 | feat_50,11.6053776503 53 | feat_51,10.2122409615 54 | feat_52,8.86319782 55 | feat_53,26.429423928600002 56 | feat_54,8.92956916667 57 | feat_55,13.319904821400002 58 | feat_56,22.599339901 59 | feat_57,19.431493224 60 | feat_58,9.70069634855 61 | feat_59,19.0049746667 62 | feat_60,35.4501718966 63 | feat_61,72.6334658974 64 | feat_62,38.310038 65 | feat_63,7.72638458065 66 | 
feat_64,9.84875153025 67 | feat_65,15.948526475 68 | feat_66,18.7450463636 69 | feat_67,151.57445307700002 70 | feat_68,10.638199758999999 71 | feat_69,21.6299969863 72 | feat_70,180.07161175 73 | feat_71,39.1086955 74 | feat_72,211.432106861 75 | feat_73,26.0100392308 76 | feat_74,9.14663557073 77 | feat_75,19.9266190476 78 | feat_76,10.540188666699999 79 | feat_77,15.1881593548 80 | feat_78,12.2875825 81 | feat_79,128.742952879 82 | -------------------------------------------------------------------------------- /feat_drift/feat_drift_test.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import pandas as pd" 10 | ] 11 | }, 12 | { 13 | "cell_type": "markdown", 14 | "metadata": {}, 15 | "source": [ 16 | "## read fake feature importance \n", 17 | "The expected inputs are two feature importance dataframes to compare. \n", 18 | "It is assumed that both dataframes contain exactly the same set of features." 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 2, 24 | "metadata": {}, 25 | "outputs": [], 26 | "source": [ 27 | "feat_imp1 = pd.read_csv('data/feature_imp1.csv')\n", 28 | "feat_imp2 = pd.read_csv('data/feature_imp2.csv')" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 3, 34 | "metadata": {}, 35 | "outputs": [ 36 | { 37 | "data": { 38 | "text/html": [ 39 | "
\n", 40 | "\n", 53 | "\n", 54 | " \n", 55 | " \n", 56 | " \n", 57 | " \n", 58 | " \n", 59 | " \n", 60 | " \n", 61 | " \n", 62 | " \n", 63 | " \n", 64 | " \n", 65 | " \n", 66 | " \n", 67 | " \n", 68 | " \n", 69 | " \n", 70 | " \n", 71 | " \n", 72 | " \n", 73 | " \n", 74 | " \n", 75 | " \n", 76 | " \n", 77 | " \n", 78 | " \n", 79 | " \n", 80 | " \n", 81 | " \n", 82 | " \n", 83 | " \n", 84 | " \n", 85 | " \n", 86 | " \n", 87 | " \n", 88 | "
\n", 89 | "
" 90 | ], 91 | "text/plain": [ 92 | " feat_name imp\n", 93 | "0 feat_0 13.343694\n", 94 | "1 feat_1 15.800974\n", 95 | "2 feat_2 21.084382\n", 96 | "3 feat_3 48.516187\n", 97 | "4 feat_4 113.182840" 98 | ] 99 | }, 100 | "execution_count": 3, 101 | "metadata": {}, 102 | "output_type": "execute_result" 103 | } 104 | ], 105 | "source": [ 106 | "feat_imp1.head()" 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": 4, 112 | "metadata": {}, 113 | "outputs": [ 114 | { 115 | "data": { 116 | "text/html": [ 117 | "
\n", 118 | "\n", 131 | "\n", 132 | " \n", 133 | " \n", 134 | " \n", 135 | " \n", 136 | " \n", 137 | " \n", 138 | " \n", 139 | " \n", 140 | " \n", 141 | " \n", 142 | " \n", 143 | " \n", 144 | " \n", 145 | " \n", 146 | " \n", 147 | " \n", 148 | " \n", 149 | " \n", 150 | " \n", 151 | " \n", 152 | " \n", 153 | " \n", 154 | " \n", 155 | " \n", 156 | " \n", 157 | " \n", 158 | " \n", 159 | " \n", 160 | " \n", 161 | " \n", 162 | " \n", 163 | " \n", 164 | " \n", 165 | " \n", 166 | "
\n", 167 | "
" 168 | ], 169 | "text/plain": [ 170 | " feat_name imp\n", 171 | "0 feat_0 10.821621\n", 172 | "1 feat_1 23.416350\n", 173 | "2 feat_2 11.565690\n", 174 | "3 feat_3 48.917148\n", 175 | "4 feat_4 113.584959" 176 | ] 177 | }, 178 | "execution_count": 4, 179 | "metadata": {}, 180 | "output_type": "execute_result" 181 | } 182 | ], 183 | "source": [ 184 | "feat_imp2.head()" 185 | ] 186 | }, 187 | { 188 | "cell_type": "markdown", 189 | "metadata": {}, 190 | "source": [ 191 | "## test feature drift graph\n", 192 | "Here we assume that these two feature importance results come from models trained on datasets from different snapshot months. \n", 193 | "- feat_imp1: feature importance from dataset 201802\n", 194 | "- feat_imp2: feature importance from dataset 201803 " 195 | ] 196 | }, 197 | { 198 | "cell_type": "code", 199 | "execution_count": 5, 200 | "metadata": {}, 201 | "outputs": [], 202 | "source": [ 203 | "import sys\n", 204 | "sys.path.insert(0, 'src/')\n", 205 | "import feature_drift_draw" 206 | ] 207 | }, 208 | { 209 | "cell_type": "code", 210 | "execution_count": 6, 211 | "metadata": {}, 212 | "outputs": [], 213 | "source": [ 214 | "feature_drift_draw.feature_drift_graph(feat_imp1=feat_imp1, feat_imp2=feat_imp2, feature_name='feat_name', imp_name='imp',\n", 215 | " ds_name1='201802', ds_name2='201803', graph_name='201802_201803',\n", 216 | " top_n=20, max_bar_width=300, bar_height=30, middle_gap=300, fontsize=12, color_dict=None)" 217 | ] 218 | } 219 | ], 220 | "metadata": { 221 | "kernelspec": { 222 | "display_name": "Python 2", 223 | "language": "python", 224 | "name": "python2" 225 | }, 226 | "language_info": { 227 | "codemirror_mode": { 228 | "name": "ipython", 229 | "version": 2 230 | }, 231 | "file_extension": ".py", 232 | "mimetype": "text/x-python", 233 | "name": "python", 234 | "nbconvert_exporter": "python", 235 | "pygments_lexer": "ipython2", 236 | "version": "2.7.14" 237 | } 238 | }, 239 | "nbformat": 4, 240 | "nbformat_minor": 2 241 | } 242 | 
-------------------------------------------------------------------------------- /feat_drift/feature_drift_output/feature_drift_201802_201803.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 13 | 14 | 15 | 16 |
17 |

Feature Drift

18 |

Visualize how important features drift through two datasets.

19 |
20 |
21 |
22 | 23 | 279 | 280 | -------------------------------------------------------------------------------- /feat_drift/images/feat_drift.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sosuneko/Nuance/e8b486ae7459850a00f1e8bbd756e7a57aed4417/feat_drift/images/feat_drift.PNG -------------------------------------------------------------------------------- /feat_drift/images/feat_drift.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sosuneko/Nuance/e8b486ae7459850a00f1e8bbd756e7a57aed4417/feat_drift/images/feat_drift.gif -------------------------------------------------------------------------------- /feat_drift/src/feature_drift_draw.py: -------------------------------------------------------------------------------- 1 | 2 | import pandas as pd 3 | import jinja2 4 | import os 5 | 6 | 7 | def _process_imp(imp_df, imp_name): 8 | """ 9 | Preprocessing on the input feature importance dataframe 10 | 11 | :param imp_df: feature importance pandas dataframe 12 | :param imp_name: column name of the importance value 13 | :return: 14 | dataframe with relative_imp and feat_rank 15 | """ 16 | 17 | imp_df = imp_df.sort_values(by=imp_name, ascending=False).reset_index(drop=True) 18 | imp_df['relative_imp'] = imp_df[imp_name] * 1.0 / imp_df[imp_name].max() 19 | imp_df['relative_imp'] = imp_df['relative_imp'].apply(lambda x : round(x, 3)) 20 | imp_df['feat_rank'] = imp_df.index.values + 1 21 | return imp_df 22 | 23 | 24 | def _rank2color(x, color_dict): 25 | """ 26 | Map change of rank to color 27 | 28 | :param x: row of dataframe 29 | :param color_dict: color dictionary 30 | """ 31 | 32 | if x['feat_rank_x'] < x['feat_rank_y']: 33 | return color_dict['drop'] 34 | if x['feat_rank_x'] >= x['feat_rank_y']: 35 | return color_dict['up_or_stable'] 36 | if pd.isnull(x['feat_rank_y']): 37 | return color_dict['disappear'] 38 | if 
pd.isnull(x['feat_rank_x']): 39 | return color_dict['appear'] 40 | 41 | 42 | def _get_mark(x): 43 | """ 44 | '1' for feature appears on both feature importance dataframes 45 | '0' for feature disappears on either dataframe 46 | """ 47 | if pd.isnull(x['feat_rank_y']) or pd.isnull(x['feat_rank_x']): 48 | return "0" 49 | else: 50 | return "1" 51 | 52 | 53 | def _merge_feat_imp(imp_df1, imp_df2, feature_name, top_n, color_dict): 54 | """ 55 | Merge and compare two feature importance dataframes 56 | 57 | :param imp_df1: feature importance dataframe #1 58 | :param imp_df2: feature importance dataframe #2 59 | :param feature_name: column name of features 60 | :param top_n: show top_n features 61 | :param color_dict: color dictionary 62 | :return: 63 | The merged dataframe 64 | """ 65 | 66 | imp_df1['pos'] = 'left' 67 | imp_df2['pos'] = 'right' 68 | if top_n: 69 | both_imp = imp_df1.head(top_n).merge(imp_df2.head(top_n), on=feature_name, how='outer') 70 | else: 71 | both_imp = imp_df1.merge(imp_df2, on=feature_name, how='outer') 72 | 73 | both_imp['bar_color'] = both_imp.apply(lambda x : _rank2color(x, color_dict), axis=1) 74 | both_imp['bar_mark'] = both_imp.apply(lambda x : _get_mark(x), axis=1) 75 | 76 | return both_imp 77 | 78 | 79 | def feature_drift_graph(feat_imp1, feat_imp2, feature_name, imp_name, ds_name1, ds_name2, graph_name=None, 80 | top_n=None, max_bar_width=300, bar_height=30, middle_gap=300, fontsize=12, color_dict=None): 81 | """ 82 | Draw feature drift graph 83 | 84 | :param feat_imp1: feature importance dataframe #1 85 | :param feat_imp2: feature importance dataframe #2 86 | :param feature_name: column name of features 87 | :param imp_name: column name of importance value 88 | :param ds_name1: name of dataset #1 89 | :param ds_name2: name of dataset #2 90 | :param top_n: show top_n features 91 | :param max_bar_width: maximum bar width 92 | :param bar_height: bar height 93 | :param middle_gap: gap between bars 94 | :param fontsize: font size 95 | :param 
color_dict: color dictionary 96 | """ 97 | 98 | feat_imp1 = _process_imp(feat_imp1, imp_name) 99 | feat_imp2 = _process_imp(feat_imp2, imp_name) 100 | 101 | if color_dict is None: 102 | color_dict = { 103 | 'drop': '#f17182', 104 | 'up_or_stable': '#abdda4', 105 | 'disappear': '#bababa', 106 | 'appear': '#9ac6df' 107 | } 108 | 109 | both_imp = _merge_feat_imp(feat_imp1, feat_imp2, feature_name, top_n, color_dict) 110 | both_imp = both_imp.rename(columns={feature_name: 'feat_name'}) # the selections below and the template records use 'feat_name' 111 | bar_left_data = both_imp[['feat_name', 'relative_imp_x', 'pos_x', 'bar_color', 'bar_mark'] 112 | ].dropna().sort_values('relative_imp_x', ascending=False) 113 | bar_left_data.columns = [col.replace('_x', '') for col in bar_left_data.columns.values] 114 | 115 | bar_right_data = both_imp[['feat_name', 'relative_imp_y', 'pos_y', 'bar_color', 'bar_mark'] 116 | ].dropna().sort_values('relative_imp_y', ascending=False) 117 | bar_right_data.columns = [col.replace('_y', '') for col in bar_right_data.columns.values] 118 | 119 | line_data = both_imp[['feat_name', 'bar_color', 'feat_rank_x', 'feat_rank_y']].dropna()[['feat_name', 'bar_color']] 120 | 121 | legend_data = [ 122 | {'name': 'Drop', 'color': color_dict['drop']}, 123 | {'name': 'Up & Stable', 'color': color_dict['up_or_stable']}, 124 | {'name': 'Disappear', 'color': color_dict['disappear']}, 125 | {'name': 'Appear', 'color': color_dict['appear']} 126 | ] 127 | 128 | # load the template that sits next to this module (src/), independent of the working directory 129 | temp = open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'feature_drift_template.html')).read() 130 | template = jinja2.Template(temp) 131 | 132 | # create the output folder if it does not exist 133 | if not os.path.exists('feature_drift_output'): 134 | os.mkdir('feature_drift_output') 135 | 136 | # generate output html 137 | if graph_name is None: 138 | 
output_path = 'feature_drift_output/feature_drift_output.html' 139 | else: 140 | output_path = 'feature_drift_output/feature_drift_%s.html' %graph_name 141 | 142 | with open(output_path, 'w') as fh: # text mode: render() returns text, which a 'wb' file rejects on Python 3 143 | fh.write(template.render({'bar_left_data': bar_left_data.to_dict('records'), 144 | 
'bar_right_data': bar_right_data.to_dict('records'), 145 | 'line_data': line_data.to_dict('records'), 146 | 'legend_data': legend_data, 147 | 'max_bar_width': max_bar_width, 'bar_height': bar_height, 148 | 'middle_gap': middle_gap, 'fontsize': fontsize, 149 | 'ds_name1': ds_name1, 'ds_name2': ds_name2})) 150 | -------------------------------------------------------------------------------- /feat_drift/src/feature_drift_template.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 13 | 14 | 15 | 16 |
17 |

Feature Drift

18 |

Visualize how important features drift through two datasets.

19 |
20 |
21 |
22 | 23 | 279 | 280 | -------------------------------------------------------------------------------- /sankey/README.md: -------------------------------------------------------------------------------- 1 | ## Visualize a sankey flow 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 |
12 | 13 | ## How to use? 14 | 1. Download the folder [**sankey**](https://github.com/SauceCat/Nuance/tree/master/sankey) 15 | 2. The folder structure: 16 | ``` 17 | sankey 18 | - src: source code for the visualization 19 | - data: test data 20 | - sankey_flow_output: output folder (re-created if it was deleted) 21 | - sankey_flow_test.ipynb: instructions and examples in a Jupyter notebook 22 | - ... 23 | ``` 24 | 3. Install `jinja2` 25 | ``` 26 | pip install jinja2 27 | ``` 28 | 4. Use the sankey_flow visualization (it loads D3.js from the web, so a network connection is required): 29 | ```python 30 | import sys 31 | sys.path.insert(0, 'src/') 32 | import generate_sankey_flow 33 | # raw: Titanic DataFrame, use_cols: list of state columns (both defined in sankey_flow_test.ipynb) 34 | generate_sankey_flow.draw_sankey_flow(df=raw[use_cols], node_color_type='col', link_color_type='source', 35 | width=1600, height=900, graph_name='Titanic', 36 | node_color_mapping=None, color_map=None, link_color=None) 37 | ``` 38 | 5. An HTML file will be generated in [sankey_flow_output](https://github.com/SauceCat/Nuance/tree/master/sankey/sankey_flow_output). Open it in any browser you like (I like Chrome anyway). 
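To draw the flow, you need to know how many rows (units) move from each value of one state column to each value of the next column. `draw_sankey_flow` does this counting for you. The sketch below is a minimal pandas illustration of that counting step, not the code in `src/generate_sankey_flow.py`. It assumes links only connect adjacent columns. `count_sankey_links` and the `n_units` column are hypothetical names.

```python
import pandas as pd


def count_sankey_links(df):
    """Hypothetical helper: count units flowing between adjacent columns.

    Nodes are identified by (column, value), so the same value in two
    different columns (e.g. 0 in SibSp and 0 in Parch) stays two nodes.
    """
    cols = list(df.columns)
    if len(cols) < 2:
        # a flow needs at least two states
        return pd.DataFrame(columns=['source_col', 'source_val',
                                     'target_col', 'target_val', 'n_units'])

    frames = []
    for source_col, target_col in zip(cols[:-1], cols[1:]):
        # one row per observed (source value, target value) combination;
        # groupby drops rows with NaN in either column by default
        counts = (df.groupby([source_col, target_col]).size()
                  .reset_index(name='n_units'))
        frames.append(pd.DataFrame({
            'source_col': source_col,
            'source_val': counts[source_col].tolist(),
            'target_col': target_col,
            'target_val': counts[target_col].tolist(),
            'n_units': counts['n_units'].tolist(),
        }))
    return pd.concat(frames, ignore_index=True)


# toy data shaped like three of the Titanic state columns
toy = pd.DataFrame({
    'Survived': [0, 1, 1, 0],
    'Sex': ['male', 'female', 'female', 'male'],
    'Embarked': ['S', 'C', 'S', 'S'],
})
links = count_sankey_links(toy)
print(links)
```

With the notebook's data, `count_sankey_links(raw[use_cols])` would produce one row per link between consecutive Titanic states. Because NaN keys are dropped, filling `Embarked` with `'unknown'` first (as the notebook does) keeps those passengers in the flow.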
39 | 40 | ## Parameters 41 | ```python 42 | def draw_sankey_flow(df, node_color_type, link_color_type, width, height, 43 | graph_name=None, node_color_mapping=None, color_map=None, link_color=None): 44 | ''' 45 | :param df: 46 | pandas DataFrame, each column represents a state 47 | :param node_color_type: 48 | node coloring strategy, can be one of ['col', 'val', 'col_val', 'cus'] 49 | - 'col': each column has a different color 50 | - 'val': each unique value has a different color (unique values across all columns) 51 | - 'col_val': each unique value in each column has a different color 52 | - 'cus': custom node color mapping provided by the user 53 | :param link_color_type: 54 | link coloring strategy, default='source' 55 | Can be one of ['source', 'target', 'both', 'same'] 56 | - 'source': same color as the source node 57 | - 'target': same color as the target node 58 | - 'both': color from both the source and the target node 59 | - 'same': all links have the same color 60 | :param width: width of the graph 61 | :param height: height of the graph 62 | :param graph_name: name of the graph, used in the output file name (e.g. sankey_flow_Titanic.html) 63 | :param node_color_mapping: 64 | if node_color_type == 'cus', node_color_mapping must be provided 65 | example: 66 | node_color_mapping = { 67 | 'type': 'col', 68 | 'mapping': { 69 | column1: color1, column2: color2, ...
70 | } 71 | } 72 | :param color_map: matplotlib color map 73 | :param link_color: 74 | if link_color_type == 'same', link color should be provided 75 | ''' 76 | ``` 77 | -------------------------------------------------------------------------------- /sankey/images/sankey_flow_col.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sosuneko/Nuance/e8b486ae7459850a00f1e8bbd756e7a57aed4417/sankey/images/sankey_flow_col.PNG -------------------------------------------------------------------------------- /sankey/images/sankey_flow_col_val.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sosuneko/Nuance/e8b486ae7459850a00f1e8bbd756e7a57aed4417/sankey/images/sankey_flow_col_val.PNG -------------------------------------------------------------------------------- /sankey/images/sankey_flow_same.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sosuneko/Nuance/e8b486ae7459850a00f1e8bbd756e7a57aed4417/sankey/images/sankey_flow_same.PNG -------------------------------------------------------------------------------- /sankey/images/sankey_flow_tab20.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sosuneko/Nuance/e8b486ae7459850a00f1e8bbd756e7a57aed4417/sankey/images/sankey_flow_tab20.PNG -------------------------------------------------------------------------------- /sankey/images/sankey_flow_val.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sosuneko/Nuance/e8b486ae7459850a00f1e8bbd756e7a57aed4417/sankey/images/sankey_flow_val.PNG -------------------------------------------------------------------------------- /sankey/sankey_flow_output/sankey_flow_Titanic.html: 
-------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 27 | 28 | 29 | 30 |
31 |
32 |
33 | 34 | 35 | 36 | 331 | 493 | 494 | 495 | -------------------------------------------------------------------------------- /sankey/sankey_flow_test.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## test sankey_flow visualization with Titanic dataset" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": { 14 | "collapsed": true 15 | }, 16 | "outputs": [], 17 | "source": [ 18 | "import pandas as pd\n", 19 | "import numpy as np" 20 | ] 21 | }, 22 | { 23 | "cell_type": "markdown", 24 | "metadata": {}, 25 | "source": [ 26 | "## read data" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": 2, 32 | "metadata": { 33 | "collapsed": true 34 | }, 35 | "outputs": [], 36 | "source": [ 37 | "raw = pd.read_csv('data/titanic_train.csv')" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": 3, 43 | "metadata": { 44 | "collapsed": false 45 | }, 46 | "outputs": [ 47 | { 48 | "data": { 49 | "text/html": [ 50 | "
\n", 51 | "\n", 52 | " \n", 53 | " \n", 54 | " \n", 55 | " \n", 56 | " \n", 57 | " \n", 58 | " \n", 59 | " \n", 60 | " \n", 61 | " \n", 62 | " \n", 63 | " \n", 64 | " \n", 65 | " \n", 66 | " \n", 67 | " \n", 68 | " \n", 69 | " \n", 70 | " \n", 71 | " \n", 72 | " \n", 73 | " \n", 74 | " \n", 75 | " \n", 76 | " \n", 77 | " \n", 78 | " \n", 79 | " \n", 80 | " \n", 81 | " \n", 82 | " \n", 83 | " \n", 84 | " \n", 85 | " \n", 86 | " \n", 87 | " \n", 88 | " \n", 89 | " \n", 90 | " \n", 91 | " \n", 92 | " \n", 93 | " \n", 94 | " \n", 95 | " \n", 96 | " \n", 97 | " \n", 98 | " \n", 99 | " \n", 100 | " \n", 101 | " \n", 102 | " \n", 103 | " \n", 104 | " \n", 105 | " \n", 106 | " \n", 107 | " \n", 108 | " \n", 109 | " \n", 110 | " \n", 111 | " \n", 112 | " \n", 113 | " \n", 114 | " \n", 115 | " \n", 116 | " \n", 117 | " \n", 118 | " \n", 119 | " \n", 120 | " \n", 121 | " \n", 122 | " \n", 123 | " \n", 124 | " \n", 125 | " \n", 126 | " \n", 127 | " \n", 128 | " \n", 129 | " \n", 130 | " \n", 131 | " \n", 132 | " \n", 133 | " \n", 134 | " \n", 135 | " \n", 136 | " \n", 137 | " \n", 138 | " \n", 139 | " \n", 140 | " \n", 141 | " \n", 142 | " \n", 143 | " \n", 144 | " \n", 145 | " \n", 146 | "
\n", 147 | "
" 148 | ], 149 | "text/plain": [ 150 | " PassengerId Survived Pclass \\\n", 151 | "0 1 0 3 \n", 152 | "1 2 1 1 \n", 153 | "2 3 1 3 \n", 154 | "3 4 1 1 \n", 155 | "4 5 0 3 \n", 156 | "\n", 157 | " Name Sex Age SibSp \\\n", 158 | "0 Braund, Mr. Owen Harris male 22.0 1 \n", 159 | "1 Cumings, Mrs. John Bradley (Florence Briggs Th... female 38.0 1 \n", 160 | "2 Heikkinen, Miss. Laina female 26.0 0 \n", 161 | "3 Futrelle, Mrs. Jacques Heath (Lily May Peel) female 35.0 1 \n", 162 | "4 Allen, Mr. William Henry male 35.0 0 \n", 163 | "\n", 164 | " Parch Ticket Fare Cabin Embarked \n", 165 | "0 0 A/5 21171 7.2500 NaN S \n", 166 | "1 0 PC 17599 71.2833 C85 C \n", 167 | "2 0 STON/O2. 3101282 7.9250 NaN S \n", 168 | "3 0 113803 53.1000 C123 S \n", 169 | "4 0 373450 8.0500 NaN S " 170 | ] 171 | }, 172 | "execution_count": 3, 173 | "metadata": {}, 174 | "output_type": "execute_result" 175 | } 176 | ], 177 | "source": [ 178 | "raw.head()" 179 | ] 180 | }, 181 | { 182 | "cell_type": "code", 183 | "execution_count": 4, 184 | "metadata": { 185 | "collapsed": false 186 | }, 187 | "outputs": [ 188 | { 189 | "data": { 190 | "text/plain": [ 191 | "array(['PassengerId', 'Survived', 'Pclass', 'Name', 'Sex', 'Age', 'SibSp',\n", 192 | " 'Parch', 'Ticket', 'Fare', 'Cabin', 'Embarked'], dtype=object)" 193 | ] 194 | }, 195 | "execution_count": 4, 196 | "metadata": {}, 197 | "output_type": "execute_result" 198 | } 199 | ], 200 | "source": [ 201 | "raw.columns.values" 202 | ] 203 | }, 204 | { 205 | "cell_type": "code", 206 | "execution_count": 5, 207 | "metadata": { 208 | "collapsed": false 209 | }, 210 | "outputs": [ 211 | { 212 | "name": "stdout", 213 | "output_type": "stream", 214 | "text": [ 215 | "Survived: [0 1]\n", 216 | "Pclass: [3 1 2]\n", 217 | "Sex: ['male' 'female']\n", 218 | "SibSp: [1 0 3 4 2 5 8]\n", 219 | "Parch: [0 1 2 5 3 4 6]\n", 220 | "Embarked: ['S' 'C' 'Q' nan]\n" 221 | ] 222 | } 223 | ], 224 | "source": [ 225 | "# use these columns as layers of nodes\n", 226 | "use_cols = 
['Survived', 'Pclass', 'Sex', 'SibSp', 'Parch', 'Embarked']\n", 227 | "for col in use_cols:\n", 228 | " print('%s: %s' % (col, str(raw[col].unique())))" 229 | ] 230 | }, 231 | { 232 | "cell_type": "code", 233 | "execution_count": 6, 234 | "metadata": { 235 | "collapsed": true 236 | }, 237 | "outputs": [], 238 | "source": [ 239 | "raw['Embarked'] = raw['Embarked'].fillna('unknown')" 240 | ] 241 | }, 242 | { 243 | "cell_type": "markdown", 244 | "metadata": {}, 245 | "source": [ 246 | "## input for generate_sankey_flow\n", 247 | "The function expects a pandas DataFrame in which each column represents a state. \n", 248 | "Take Titanic for example: the columns **['Survived', 'Pclass', 'Sex', 'SibSp', 'Parch', 'Embarked']** are 6 different states. The unique values in each column are the possible values of that state, and each row represents a unit and its flow path. \n", 249 | "To draw a proper sankey flow, you need to count how many units flow from (state1, value1) to (state2, value2), how many flow from (state2, value2) to (state3, value3), and so on, which is tedious. Fortunately, this is handled automatically by **generate_sankey_flow**." 250 | ] 251 | }, 252 | { 253 | "cell_type": "code", 254 | "execution_count": 7, 255 | "metadata": { 256 | "collapsed": false 257 | }, 258 | "outputs": [ 259 | { 260 | "data": { 261 | "text/html": [ 262 | "
" 324 | ], 325 | "text/plain": [ 326 | " Survived Pclass Sex SibSp Parch Embarked\n", 327 | "0 0 3 male 1 0 S\n", 328 | "1 1 1 female 1 0 C\n", 329 | "2 1 3 female 0 0 S\n", 330 | "3 1 1 female 1 0 S\n", 331 | "4 0 3 male 0 0 S" 332 | ] 333 | }, 334 | "execution_count": 7, 335 | "metadata": {}, 336 | "output_type": "execute_result" 337 | } 338 | ], 339 | "source": [ 340 | "raw[use_cols].head()" 341 | ] 342 | }, 343 | { 344 | "cell_type": "markdown", 345 | "metadata": {}, 346 | "source": [ 347 | "## color strategy\n", 348 | "**node_color_type** (required) \n", 349 | "Can be ['col', 'val', 'col_val', 'cus']\n", 350 | "- 'col': each column has a different color\n", 351 | "- 'val': each unique value has a different color (unique values are pooled across all columns)\n", 352 | "- 'col_val': each unique value within each column has a different color\n", 353 | "- 'cus': user-provided node color mapping, passed via **node_color_mapping**\n", 354 | "\n", 355 | "**node_color_mapping**, default=None \n", 356 | "Custom color mapping, used when `node_color_type=\"cus\"`.\n", 357 | "- node_color_mapping = '#fff' \n", 358 | "All nodes have the same color\n", 359 | "- node_color_mapping = dict() \n", 360 | "\n", 361 | "```python\n", 362 | "node_color_mapping = {\n", 363 | " 'type': 'col',\n", 364 | " 'mapping': {\n", 365 | " column1: color1, column2: color2, ...\n", 366 | " }\n", 367 | "}\n", 368 | "\n", 369 | "node_color_mapping = {\n", 370 | " 'type': 'val',\n", 371 | " 'mapping': {\n", 372 | " value1: color1, value2: color2, ...\n", 373 | " }\n", 374 | "}\n", 375 | "\n", 376 | "node_color_mapping = {\n", 377 | " 'type': 'col_val',\n", 378 | " 'mapping': {\n", 379 | " column1: {value1: color1, value2: color2, ...},\n", 380 | " column2: {value1: color3, value2: color4, ...}\n", 381 | " }\n", 382 | "}\n", 383 | "```\n", 384 | "\n", 385 | "**link_color_type** (required) \n", 386 | "Can be ['source', 'target', 'both', 'same']\n", 387 | "- 'source': same color as the source node\n", 388 | "- 'target': same color as the target node\n", 389 | "- 
'both': color from both target and source\n", 390 | "- 'same': all links have same color\n", 391 | "\n", 392 | "**link_color**, default=None \n", 393 | "Only required when `link_color_type=\"same\"`" 394 | ] 395 | }, 396 | { 397 | "cell_type": "markdown", 398 | "metadata": {}, 399 | "source": [ 400 | "## start generating sankey flow!" 401 | ] 402 | }, 403 | { 404 | "cell_type": "code", 405 | "execution_count": 8, 406 | "metadata": { 407 | "collapsed": true 408 | }, 409 | "outputs": [], 410 | "source": [ 411 | "import sys\n", 412 | "sys.path.insert(0, 'src/')\n", 413 | "import generate_sankey_flow" 414 | ] 415 | }, 416 | { 417 | "cell_type": "markdown", 418 | "metadata": {}, 419 | "source": [ 420 | "### use node_color_type" 421 | ] 422 | }, 423 | { 424 | "cell_type": "code", 425 | "execution_count": 9, 426 | "metadata": { 427 | "collapsed": true 428 | }, 429 | "outputs": [], 430 | "source": [ 431 | "generate_sankey_flow.draw_sankey_flow(df=raw[use_cols], node_color_type='col', link_color_type='source', \n", 432 | " width=1600, height=900, graph_name='Titanic', \n", 433 | " node_color_mapping=None, color_map=None, link_color=None)" 434 | ] 435 | }, 436 | { 437 | "cell_type": "code", 438 | "execution_count": 18, 439 | "metadata": { 440 | "collapsed": true 441 | }, 442 | "outputs": [], 443 | "source": [ 444 | "# change color_map\n", 445 | "generate_sankey_flow.draw_sankey_flow(df=raw[use_cols], node_color_type='col', link_color_type='same', \n", 446 | " width=1600, height=900, graph_name='Titanic_same', \n", 447 | " node_color_mapping=None, color_map='tab20', link_color='#ccc')" 448 | ] 449 | }, 450 | { 451 | "cell_type": "markdown", 452 | "metadata": {}, 453 | "source": [ 454 | "### use customized node_color_type and provide node_color_mapping" 455 | ] 456 | }, 457 | { 458 | "cell_type": "code", 459 | "execution_count": 15, 460 | "metadata": { 461 | "collapsed": false 462 | }, 463 | "outputs": [], 464 | "source": [ 465 | "node_color_mapping = {\n", 466 | " 'type': 
'col',\n", 467 | " 'mapping': {\n", 468 | " 'Survived': '#9e0142', \n", 469 | " 'Pclass': '#d53e4f', \n", 470 | " 'Sex': '#f46d43', \n", 471 | " 'SibSp': '#fdae61', \n", 472 | " 'Parch': '#fee08b', \n", 473 | " 'Embarked': '#ffffbf'\n", 474 | " }\n", 475 | "}\n", 476 | "generate_sankey_flow.draw_sankey_flow(df=raw[use_cols], node_color_type='cus', link_color_type='source', \n", 477 | " width=1600, height=900, graph_name='Titanic_col', \n", 478 | " node_color_mapping=node_color_mapping, color_map=None, link_color=None)" 479 | ] 480 | }, 481 | { 482 | "cell_type": "code", 483 | "execution_count": 16, 484 | "metadata": { 485 | "collapsed": true 486 | }, 487 | "outputs": [], 488 | "source": [ 489 | "node_color_mapping = {\n", 490 | " 'type': 'val',\n", 491 | " 'mapping': {\n", 492 | " 0: '#9e0142',\n", 493 | " 1: '#d53e4f',\n", 494 | " 2: '#f46d43',\n", 495 | " 3: '#fdae61',\n", 496 | " 4: '#fee08b',\n", 497 | " 5: '#ffffbf',\n", 498 | " 6: '#e6f598',\n", 499 | " 8: '#abdda4',\n", 500 | " 'male': '#66c2a5',\n", 501 | " 'female': '#3288bd',\n", 502 | " 'S': '#5e4fa2',\n", 503 | " 'C': '#9e0142',\n", 504 | " 'Q': '#d53e4f',\n", 505 | " 'unknown': '#f46d43'\n", 506 | " }\n", 507 | "}\n", 508 | "generate_sankey_flow.draw_sankey_flow(df=raw[use_cols], node_color_type='cus', link_color_type='source', \n", 509 | " width=1600, height=900, graph_name='Titanic_val', \n", 510 | " node_color_mapping=node_color_mapping, color_map=None, link_color=None)" 511 | ] 512 | }, 513 | { 514 | "cell_type": "code", 515 | "execution_count": 17, 516 | "metadata": { 517 | "collapsed": true 518 | }, 519 | "outputs": [], 520 | "source": [ 521 | "node_color_mapping = {\n", 522 | " 'type': 'col_val',\n", 523 | " 'mapping': {\n", 524 | " 'Survived': {0: '#9e0142', 1: '#5e4fa2'},\n", 525 | " 'Pclass': {1: '#d53e4f', 2: '#3288bd', 3: '#f46d43'},\n", 526 | " 'Sex': {'male': '#fdae61', 'female': '#66c2a5'},\n", 527 | " 'SibSp': {0: '#fee08b', 1: '#ffffbf', 2: '#e6f598', 3: '#abdda4', 4: '#9e0142', 5: 
'#5e4fa2', 8: '#d53e4f'},\n", 528 | " 'Parch': {0: '#f46d43', 1: '#fdae61', 2: '#fee08b', 3: '#ffffbf', 4: '#e6f598', 5: '#abdda4', 6: '#66c2a5'},\n", 529 | " 'Embarked': {'S': '#3288bd', 'C': '#9e0142', 'Q': '#d53e4f', 'unknown': '#f46d43'}\n", 530 | " }\n", 531 | "}\n", 532 | "generate_sankey_flow.draw_sankey_flow(df=raw[use_cols], node_color_type='cus', link_color_type='source', \n", 533 | " width=1600, height=900, graph_name='Titanic_col_val', \n", 534 | " node_color_mapping=node_color_mapping, color_map=None, link_color=None)" 535 | ] 536 | } 537 | ], 538 | "metadata": { 539 | "anaconda-cloud": {}, 540 | "kernelspec": { 541 | "display_name": "Python [default]", 542 | "language": "python", 543 | "name": "python2" 544 | }, 545 | "language_info": { 546 | "codemirror_mode": { 547 | "name": "ipython", 548 | "version": 2.0 549 | }, 550 | "file_extension": ".py", 551 | "mimetype": "text/x-python", 552 | "name": "python", 553 | "nbconvert_exporter": "python", 554 | "pygments_lexer": "ipython2", 555 | "version": "2.7.12" 556 | } 557 | }, 558 | "nbformat": 4, 559 | "nbformat_minor": 0 560 | } -------------------------------------------------------------------------------- /sankey/src/generate_sankey_flow.py: -------------------------------------------------------------------------------- 1 | 2 | import matplotlib 3 | import matplotlib.pyplot as plt 4 | 5 | import json 6 | import jinja2 7 | import os 8 | 9 | import sys 10 | reload(sys) 11 | sys.setdefaultencoding('utf-8') 12 | 13 | 14 | def _get_node_names(df): 15 | ''' 16 | :param df: pandas DataFrame, each column represents a state 17 | :return: 18 | dictionary of (column name, column unique value list) 19 | which means (state name, unique values in this state) 20 | ''' 21 | 22 | node_infos = {} 23 | 24 | for col in df.columns.values: 25 | col_values = sorted(df[col].unique()) 26 | col_names = [] 27 | 28 | for col_value in col_values: 29 | node_name = '%s: %s' %(col, str(col_value)) 30 | col_names.append(node_name) 31 
| 32 | node_infos[col] = {'col_names': col_names, 'col_values': col_values} 33 | 34 | return node_infos 35 | 36 | 37 | def _get_node_colors(node_infos, node_color_type, node_color_mapping, cm): 38 | ''' 39 | :param node_infos: 40 | dictionary of (state name, unique values in this state) 41 | :param node_color_type: 42 | node coloring strategy, can be one of ['col', 'val', 'col_val', 'cus'] 43 | - 'col': each column has a different color 44 | - 'val': each unique value has a different color (unique values are pooled across all columns) 45 | - 'col_val': each unique value within each column has a different color 46 | - 'cus': user-provided node color mapping 47 | :param node_color_mapping: 48 | if node_color_type == 'cus', node_color_mapping must not be None 49 | :param cm: matplotlib color map 50 | :return: 51 | dictionary of (node name, node color) 52 | ''' 53 | 54 | node_colors = {} 55 | 56 | if (node_color_type == 'col') or (node_color_type == 'cus' and node_color_mapping['type'] == 'col'): 57 | for col_idx, col in enumerate(node_infos.keys()): 58 | if node_color_type == 'cus': 59 | col_color = node_color_mapping['mapping'][col] 60 | else: 61 | col_color = cm(col_idx % 20) 62 | for col_name in node_infos[col]['col_names']: 63 | node_colors[col_name] = matplotlib.colors.rgb2hex(col_color) 64 | 65 | if (node_color_type == 'val') or (node_color_type == 'cus' and node_color_mapping['type'] == 'val'): 66 | unique_values = [] 67 | for col in node_infos.keys(): 68 | unique_values += list(node_infos[col]['col_values']) 69 | unique_values = list(set(unique_values)) 70 | 71 | for col in node_infos.keys(): 72 | for col_value in node_infos[col]['col_values']: 73 | if node_color_type == 'cus': 74 | val_color = node_color_mapping['mapping'][col_value] 75 | else: 76 | val_idx = unique_values.index(col_value) 77 | val_color = matplotlib.colors.rgb2hex(cm(val_idx % 20)) 78 | node_colors['%s: %s' %(col, str(col_value))] = val_color 79 | 80 | if (node_color_type == 'col_val') or (node_color_type == 
'cus' and node_color_mapping['type'] == 'col_val'): 81 | for col_idx, col in enumerate(node_infos.keys()): 82 | if node_color_type == 'cus': 83 | for col_value in node_color_mapping['mapping'][col]: 84 | node_colors['%s: %s' % (col, str(col_value))] = node_color_mapping['mapping'][col][col_value] 85 | else: 86 | for col_name_idx, col_name in enumerate(node_infos[col]['col_names']): 87 | col_name_idx_true = sum(len(node_infos[c]['col_names']) for c in list(node_infos.keys())[:col_idx]) + col_name_idx # offset past every node of the preceding columns 88 | node_colors[col_name] = matplotlib.colors.rgb2hex(cm(col_name_idx_true % 20)) 89 | 90 | return node_colors 91 | 92 | 93 | def _prepare_sankey_data(df, node_colors): 94 | ''' 95 | :param df: pandas DataFrame, each column represents a state 96 | :param node_colors: dictionary of (node name, node color) 97 | :return: 98 | dictionary of links and nodes 99 | ''' 100 | 101 | links = [] 102 | for i in range(len(df.columns.values)-1): 103 | source_col, target_col = df.columns.values[i], df.columns.values[i+1] 104 | temp_df = df[[source_col, target_col]].copy() 105 | temp_df['count'] = 1 106 | temp_df = temp_df.rename(columns={source_col: 'source', target_col: 'target'}) 107 | temp_df_gp = temp_df.groupby(['source', 'target'], as_index=False).count() 108 | 109 | temp_df_gp['source'] = temp_df_gp['source'].apply(lambda x : '%s: %s' %(source_col, str(x))) 110 | temp_df_gp['target'] = temp_df_gp['target'].apply(lambda x : '%s: %s' %(target_col, str(x))) 111 | 112 | temp_df_gp['color_source'] = temp_df_gp['source'].apply(lambda x : node_colors.get(x, '#000')) 113 | temp_df_gp['color_target'] = temp_df_gp['target'].apply(lambda x : node_colors.get(x, '#000')) 114 | temp_df_gp['value'] = temp_df_gp['count'].map(str) 115 | 116 | links += temp_df_gp[['source', 'target', 'value']].to_dict('records') 117 | 118 | nodes = [{'name': n, 'color': c} for (n, c) in node_colors.items()] 119 | data = { 120 | 'links': links, 121 | 'nodes': nodes 122 | } 123 | return data 124 | 125 |
126 | def draw_sankey_flow(df, node_color_type, link_color_type, width, height, graph_name=None, 127 | node_color_mapping=None, color_map=None, link_color=None): 128 | ''' 129 | :param df: 130 | pandas DataFrame, each column represents a state 131 | :param node_color_type: 132 | node coloring strategy, can be one of ['col', 'val', 'col_val', 'cus'] 133 | - 'col': each column has a different color 134 | - 'val': each unique value has a different color (unique values are pooled across all columns) 135 | - 'col_val': each unique value within each column has a different color 136 | - 'cus': user-provided node color mapping 137 | :param link_color_type: 138 | link coloring strategy (required, no default) 139 | Can be one of ['source', 'target', 'both', 'same'] 140 | - 'source': same color as the source node 141 | - 'target': same color as the target node 142 | - 'both': color from both target and source 143 | - 'same': all links have the same color 144 | :param width: width of the graph 145 | :param height: height of the graph 146 | :param graph_name: name of the graph, used in the output file name 147 | :param node_color_mapping: 148 | if node_color_type == 'cus', node_color_mapping must be provided 149 | example: 150 | node_color_mapping = { 151 | 'type': 'col', 152 | 'mapping': { 153 | column1: color1, column2: color2, ...
154 | } 155 | } 156 | :param color_map: name of a matplotlib color map; falls back to 'tab20' if None or not recognized 157 | :param link_color: 158 | if link_color_type == 'same', link_color must be provided 159 | ''' 160 | 161 | # get node infos 162 | node_infos = _get_node_names(df) 163 | 164 | # get node colors ('tab20' is the current name of the old 'Vega20' palette) 165 | if color_map is None: 166 | cm = plt.get_cmap('tab20') 167 | else: 168 | try: 169 | cm = plt.get_cmap(color_map) 170 | except ValueError: 171 | cm = plt.get_cmap('tab20') 172 | 173 | node_colors = _get_node_colors(node_infos, node_color_type, node_color_mapping, cm) 174 | 175 | # prepare data for sankey 176 | sankey_data = _prepare_sankey_data(df, node_colors) 177 | 178 | # render the output; the template is read from this module's folder rather than from the working directory 179 | temp = open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'sankey_flow_template.html')).read() 180 | template = jinja2.Template(temp) 181 | 182 | # create the output folder if it does not exist 183 | if not os.path.exists('sankey_flow_output'): 184 | os.mkdir('sankey_flow_output') 185 | 186 | # generate output html 187 | if graph_name is None: 188 | output_path = 'sankey_flow_output/sankey_flow_output.html' 189 | else: 190 | output_path = 'sankey_flow_output/sankey_flow_%s.html' %(graph_name) 191 | with open(output_path, 'wb') as fh: 192 | fh.write(template.render({'data': json.dumps(sankey_data), 'link_color_type': link_color_type, 193 | 'link_color': link_color, 'width': width, 'height': height})) 194 | print('The output is in %s. Enjoy!' %(output_path)) 195 | -------------------------------------------------------------------------------- /sankey/src/sankey_flow_template.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 27 | 28 | 29 | 30 |
33 | 34 | 35 | 36 | 331 | 493 | 494 | 495 | -------------------------------------------------------------------------------- /tree/README.md: -------------------------------------------------------------------------------- 1 | ## Visualize a sklearn Decision Tree Classifier 2 | - simple tree 3 | 4 | - sankey tree 5 | 6 | 7 | ## How to use? 8 | Take **simple_tree** as an example: 9 | 1. Download the folder [**tree**](https://github.com/SauceCat/Nuance/tree/master/tree) 10 | 2. The folder structure: 11 | ``` 12 | tree 13 | - src: source code for the visualizations 14 | - data: test data 15 | - simple_tree_output: simple_tree outputs (recreated automatically if deleted) 16 | - sankey_tree_output: sankey_tree outputs (recreated automatically if deleted) 17 | - generate_tree_test.ipynb: instructions and examples in a Jupyter notebook 18 | - ... 19 | ``` 20 | 3. Install `jinja2` 21 | ``` 22 | pip install jinja2 23 | ``` 24 | 4. Generate a simple_tree or sankey_tree visualization. Here `dt` is a fitted `DecisionTreeClassifier` and `titanic[features]` is the DataFrame it was fitted on (see `generate_tree_test.ipynb`). The visualizations depend on D3.js, which is loaded over the network, so you need an internet connection to view them. 25 | ```python 26 | import sys 27 | sys.path.insert(0, 'src/') 28 | import generate_tree 29 | 30 | # simple tree 31 | generate_tree.generate_simple_tree(tree_title='Titanic_Tree', tree_model=dt, X=titanic[features], 32 | target_names=['Not Survived', 'Survived'], target_colors=None, 33 | color_map=None, width=1500, height=1000) 34 | 35 | # sankey tree 36 | generate_tree.generate_sankey_tree(tree_title='Titanic_Tree', tree_model=dt, X=titanic[features], 37 | target_names=['Not Survived', 'Survived'], target_colors=None, 38 | color_map=None, width=1500, height=1200) 39 | ``` 40 | 5. An HTML file will be generated in the [simple_tree_output](https://github.com/SauceCat/Nuance/tree/master/tree/simple_tree_output) or [sankey_tree_output](https://github.com/SauceCat/Nuance/tree/master/tree/sankey_tree_output) folder. Open it in any browser you like (I like Chrome anyway).
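Under the hood, `_parse_tree` in `src/generate_tree.py` walks the parallel arrays that a fitted sklearn tree exposes on `tree_model.tree_` (`children_left`, `children_right`, `feature`, `threshold`, `value`), where a child index of `-1` marks a leaf, and turns them into nested dictionaries for the D3 page. The snippet below is only a rough, stand-alone sketch of that walk, not the code in `src/generate_tree.py`: the toy arrays, the class counts in them, and the helper name `to_nested` are all hypothetical.

```python
# Hypothetical toy tree, laid out like sklearn's `tree_` arrays:
# node 0 splits on feature index 1 ('Sex') at 0.5; nodes 1 and 2 are leaves.
children_left = [1, -1, -1]      # -1 means "no child" (leaf)
children_right = [2, -1, -1]
feature = [1, -2, -2]            # leaves carry a placeholder feature index
threshold = [0.5, -2.0, -2.0]
value = [[[50, 50]], [[10, 40]], [[40, 10]]]   # made-up class counts per node

feature_names = ['Pclass', 'Sex']
target_names = ['Not Survived', 'Survived']


def to_nested(node_id, name):
    """Recursively turn the parallel arrays into nested dicts."""
    counts = value[node_id][0]
    node = {
        'name': name,
        'value': counts,
        # predicted class = index of the largest count (first one on ties)
        'predict': target_names[counts.index(max(counts))],
    }
    left, right = children_left[node_id], children_right[node_id]
    if left != -1 or right != -1:
        feat = feature_names[feature[node_id]]
        thr = round(threshold[node_id], 3)
        node['children'] = []
        if left != -1:   # samples with feature <= threshold go left
            node['children'].append(to_nested(left, '%s <= %s' % (feat, thr)))
        if right != -1:  # samples with feature > threshold go right
            node['children'].append(to_nested(right, '%s > %s' % (feat, thr)))
    return node


tree = to_nested(0, 'Titanic_Tree')
for child in tree['children']:
    print('%s -> %s' % (child['name'], child['predict']))
```

With a real model, the same arrays live on `dt.tree_`, so the toy lists above would be replaced by `dt.tree_.children_left`, `dt.tree_.threshold`, and so on.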
41 | 42 | ## Parameters 43 | ```python 44 | def generate_simple_tree(tree_title, tree_model, X, target_names, 45 | target_colors=None, color_map=None, width=None, height=None): 46 | ''' 47 | visualize a sklearn Decision Tree Classifier 48 | 49 | :param tree_title: string 50 | name of the tree 51 | :param tree_model: a fitted sklearn Decision Tree Classifier 52 | :param X: pandas DataFrame 53 | dataset the model was fitted on 54 | :param target_names: list 55 | names of the target classes, in the order of tree_model.classes_ 56 | :param target_colors: list, default=None 57 | colors for the target classes, in the same order 58 | :param color_map: string, default=None 59 | matplotlib color map name, like 'tab20' 60 | :param width: int 61 | width of the html page 62 | :param height: int 63 | height of the html page 64 | ''' 65 | 66 | def generate_sankey_tree(tree_title, tree_model, X, target_names, 67 | target_colors=None, color_map=None, width=None, height=None): 68 | ''' 69 | visualize a sklearn Decision Tree Classifier 70 | 71 | :param tree_title: string 72 | name of the tree 73 | :param tree_model: a fitted sklearn Decision Tree Classifier 74 | :param X: pandas DataFrame 75 | dataset the model was fitted on 76 | :param target_names: list 77 | names of the target classes, in the order of tree_model.classes_ 78 | :param target_colors: list, default=None 79 | colors for the target classes, in the same order 80 | :param color_map: string, default=None 81 | matplotlib color map name, like 'tab20' 82 | :param width: int 83 | width of the html page 84 | :param height: int 85 | height of the html page 86 | ''' 87 | ``` 88 | -------------------------------------------------------------------------------- /tree/generate_tree_test.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## test generate_tree visualization with \n", 8 | "1. Titanic dataset: Binary classification Decision Tree\n", 9 | "2. 
Iris dataset: Multiclass classification Decision Tree" 10 | ] 11 | }, 12 | { 13 | "cell_type": "markdown", 14 | "metadata": {}, 15 | "source": [ 16 | "## Titanic: Binary classifier" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 1, 22 | "metadata": { 23 | "collapsed": true 24 | }, 25 | "outputs": [], 26 | "source": [ 27 | "import pandas as pd\n", 28 | "import numpy as np\n", 29 | "\n", 30 | "from sklearn.tree import DecisionTreeClassifier" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": 2, 36 | "metadata": { 37 | "collapsed": true 38 | }, 39 | "outputs": [], 40 | "source": [ 41 | "# read dataset\n", 42 | "titanic = pd.read_csv('data/titanic_train.csv')\n", 43 | "\n", 44 | "# impute null values\n", 45 | "titanic[\"Age\"] = titanic[\"Age\"].fillna(titanic[\"Age\"].dropna().median())\n", 46 | "titanic[\"Embarked\"] = titanic[\"Embarked\"].fillna(\"S\")\n", 47 | "\n", 48 | "# encode categorical features\n", 49 | "titanic['Sex'] = titanic['Sex'].apply(lambda x : 1 if x == 'male' else 0)\n", 50 | "titanic = pd.get_dummies(titanic, columns=['Embarked'])\n", 51 | "\n", 52 | "# features to use\n", 53 | "features = ['Pclass', 'Sex', 'Age', 'SibSp', 'Parch', 'Fare', 'Embarked_C', 'Embarked_Q', 'Embarked_S']" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": 3, 59 | "metadata": {}, 60 | "outputs": [ 61 | { 62 | "name": "stdout", 63 | "output_type": "stream", 64 | "text": [ 65 | "CPU times: user 4.45 ms, sys: 411 µs, total: 4.86 ms\n", 66 | "Wall time: 3.83 ms\n" 67 | ] 68 | }, 69 | { 70 | "data": { 71 | "text/plain": [ 72 | "DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=None,\n", 73 | " max_features=None, max_leaf_nodes=20,\n", 74 | " min_impurity_decrease=0.0, min_impurity_split=None,\n", 75 | " min_samples_leaf=1, min_samples_split=2,\n", 76 | " min_weight_fraction_leaf=0.0, presort=False, random_state=24,\n", 77 | " splitter='best')" 78 | ] 79 | }, 80 | "execution_count": 3, 81 |
"metadata": {}, 82 | "output_type": "execute_result" 83 | } 84 | ], 85 | "source": [ 86 | "dt = DecisionTreeClassifier(random_state=24, max_leaf_nodes=20)\n", 87 | "%time dt.fit(titanic[features], titanic['Survived'])" 88 | ] 89 | }, 90 | { 91 | "cell_type": "markdown", 92 | "metadata": {}, 93 | "source": [ 94 | "### visualize the tree model" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": 4, 100 | "metadata": { 101 | "collapsed": true 102 | }, 103 | "outputs": [], 104 | "source": [ 105 | "import sys\n", 106 | "sys.path.insert(0, 'src/')\n", 107 | "import generate_tree" 108 | ] 109 | }, 110 | { 111 | "cell_type": "markdown", 112 | "metadata": {}, 113 | "source": [ 114 | "#### generate simple tree" 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": 6, 120 | "metadata": {}, 121 | "outputs": [ 122 | { 123 | "name": "stdout", 124 | "output_type": "stream", 125 | "text": [ 126 | "The output is in simple_tree_output/simple_tree_Titanic_Tree.html. Enjoy!\n", 127 | "CPU times: user 11.8 ms, sys: 0 ns, total: 11.8 ms\n", 128 | "Wall time: 11.4 ms\n" 129 | ] 130 | } 131 | ], 132 | "source": [ 133 | "%%time\n", 134 | "generate_tree.generate_simple_tree(tree_title='Titanic_Tree', tree_model=dt, X=titanic[features], \n", 135 | " target_names=['Not Survived', 'Survived'], target_colors = None,\n", 136 | " color_map=None, width=1500, height=1000)" 137 | ] 138 | }, 139 | { 140 | "cell_type": "markdown", 141 | "metadata": {}, 142 | "source": [ 143 | "#### generate sankey tree" 144 | ] 145 | }, 146 | { 147 | "cell_type": "code", 148 | "execution_count": 7, 149 | "metadata": {}, 150 | "outputs": [ 151 | { 152 | "name": "stdout", 153 | "output_type": "stream", 154 | "text": [ 155 | "The output is in sankey_tree_output/sankey_tree_Titanic_Tree.html. 
Enjoy!\n", 156 | "CPU times: user 28.6 ms, sys: 5.19 ms, total: 33.8 ms\n", 157 | "Wall time: 31.1 ms\n" 158 | ] 159 | } 160 | ], 161 | "source": [ 162 | "%%time\n", 163 | "generate_tree.generate_sankey_tree(tree_title='Titanic_Tree', tree_model=dt, X=titanic[features], \n", 164 | " target_names=['Not Survived', 'Survived'], target_colors = None,\n", 165 | " color_map=None, width=1500, height=1200)" 166 | ] 167 | }, 168 | { 169 | "cell_type": "markdown", 170 | "metadata": {}, 171 | "source": [ 172 | "## Iris: multiclass classifier" 173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": 8, 178 | "metadata": {}, 179 | "outputs": [], 180 | "source": [ 181 | "from sklearn.datasets import load_iris\n", 182 | "import pandas as pd\n", 183 | "\n", 184 | "iris = load_iris()\n", 185 | "clf = DecisionTreeClassifier()\n", 186 | "clf = clf.fit(iris.data, iris.target)" 187 | ] 188 | }, 189 | { 190 | "cell_type": "markdown", 191 | "metadata": {}, 192 | "source": [ 193 | "#### generate simple tree" 194 | ] 195 | }, 196 | { 197 | "cell_type": "code", 198 | "execution_count": 10, 199 | "metadata": {}, 200 | "outputs": [ 201 | { 202 | "name": "stdout", 203 | "output_type": "stream", 204 | "text": [ 205 | "The output is in simple_tree_output/simple_tree_Iris_Tree.html. 
Enjoy!\n", 206 | "CPU times: user 13.8 ms, sys: 6.28 ms, total: 20.1 ms\n", 207 | "Wall time: 17.5 ms\n" 208 | ] 209 | } 210 | ], 211 | "source": [ 212 | "%%time\n", 213 | "generate_tree.generate_simple_tree(tree_title='Iris_Tree', tree_model=clf, \n", 214 | " X=pd.DataFrame(iris.data, columns=iris.feature_names), \n", 215 | " target_names=list(iris.target_names), target_colors = None,\n", 216 | " color_map='Vega10', width=1200, height=1000)" 217 | ] 218 | }, 219 | { 220 | "cell_type": "markdown", 221 | "metadata": {}, 222 | "source": [ 223 | "#### generate sankey tree" 224 | ] 225 | }, 226 | { 227 | "cell_type": "code", 228 | "execution_count": 11, 229 | "metadata": {}, 230 | "outputs": [ 231 | { 232 | "name": "stdout", 233 | "output_type": "stream", 234 | "text": [ 235 | "The output is in sankey_tree_output/sankey_tree_Iris_Tree.html. Enjoy!\n", 236 | "CPU times: user 22.1 ms, sys: 2.52 ms, total: 24.7 ms\n", 237 | "Wall time: 21.8 ms\n" 238 | ] 239 | } 240 | ], 241 | "source": [ 242 | "%%time\n", 243 | "generate_tree.generate_sankey_tree(tree_title='Iris_Tree', tree_model=clf, \n", 244 | " X=pd.DataFrame(iris.data, columns=iris.feature_names), \n", 245 | " target_names=list(iris.target_names), target_colors = None,\n", 246 | " color_map='Vega10', width=1200, height=1000)" 247 | ] 248 | } 249 | ], 250 | "metadata": { 251 | "anaconda-cloud": {}, 252 | "kernelspec": { 253 | "display_name": "Python 2", 254 | "language": "python", 255 | "name": "python2" 256 | }, 257 | "language_info": { 258 | "codemirror_mode": { 259 | "name": "ipython", 260 | "version": 2 261 | }, 262 | "file_extension": ".py", 263 | "mimetype": "text/x-python", 264 | "name": "python", 265 | "nbconvert_exporter": "python", 266 | "pygments_lexer": "ipython2", 267 | "version": "2.7.14" 268 | } 269 | }, 270 | "nbformat": 4, 271 | "nbformat_minor": 1 272 | } 273 | -------------------------------------------------------------------------------- /tree/image/sankey_tree.gif: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/sosuneko/Nuance/e8b486ae7459850a00f1e8bbd756e7a57aed4417/tree/image/sankey_tree.gif -------------------------------------------------------------------------------- /tree/image/simple_tree.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sosuneko/Nuance/e8b486ae7459850a00f1e8bbd756e7a57aed4417/tree/image/simple_tree.gif -------------------------------------------------------------------------------- /tree/sankey_tree_output/sankey_tree_Iris_Tree.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 20 | 21 | 22 | 23 | 24 | 354 | 355 | -------------------------------------------------------------------------------- /tree/sankey_tree_output/sankey_tree_Titanic_Tree.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 20 | 21 | 22 | 23 | 24 | 354 | 355 | -------------------------------------------------------------------------------- /tree/simple_tree_output/simple_tree_Iris_Tree.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 25 | 26 | 27 | 28 | 29 | 308 | 309 | -------------------------------------------------------------------------------- /tree/simple_tree_output/simple_tree_Titanic_Tree.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 25 | 26 | 27 | 28 | 29 | 308 | 309 | -------------------------------------------------------------------------------- /tree/src/generate_tree.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | 4 | import matplotlib.pyplot as plt 5 | import matplotlib 6 | 7 | import copy 8 | import sys 9 | import os 10 | 11 | import json 12 | import jinja2 13 | 14 | 15 | def 
_get_tree_info(X, tree_model, target_names, target_colors, tree_title, color_map): 16 | ''' 17 | collect the information needed to render the tree 18 | 19 | :param X: pandas DataFrame 20 | dataset the model was fitted on 21 | :param tree_model: a fitted sklearn Decision Tree Classifier 22 | :param target_names: list 23 | names of the target classes, in the order of tree_model.classes_ 24 | :param target_colors: list, default=None 25 | colors for the target classes, in the same order 26 | :param tree_title: string 27 | name of the tree 28 | :param color_map: string, default=None 29 | matplotlib color map name, like 'tab20' 30 | :return: 31 | dictionary of useful information 32 | ''' 33 | # classify features into 3 types: binary, int, and float (everything else) 34 | binary_features = [] 35 | for col in X.columns.values: 36 | if list(sorted(np.unique(X[col].values))) == [0, 1]: 37 | binary_features.append(col) 38 | 39 | int_features = [] 40 | for col in list(set(X.columns.values) - set(binary_features)): 41 | if list(X[col].map(int).values) == list(X[col].values): 42 | int_features.append(col) 43 | 44 | # get feature names 45 | feature_names = X.columns.values 46 | 47 | # check target names (n_classes_ is a plain int for single-output classifiers) 48 | if not isinstance(target_names, list) or len(target_names) != tree_model.n_classes_: 49 | raise ValueError("target_names should be a list of length %d."
% tree_model.n_classes_) 50 | 51 | # color mapping for targets 52 | if target_colors is None: 53 | if color_map is not None: 54 | cm = plt.get_cmap(color_map) 55 | else: 56 | cm = plt.get_cmap('tab20') # current name of the old 'Vega20' palette 57 | target_colors = [] 58 | for n in range(tree_model.n_classes_): 59 | target_colors.append(str(matplotlib.colors.rgb2hex(cm(n + 1)))) 60 | 61 | tree_info = { 62 | 'tree_model': tree_model, 63 | 'features': [feature_names[i] for i in tree_model.tree_.feature], 64 | 'tree_title': tree_title, 65 | 'binary_features': binary_features, 66 | 'int_features': int_features, 67 | 'target_names': target_names, 68 | 'target_colors': target_colors 69 | } 70 | return tree_info 71 | 72 | 73 | def _parse_tree(node_id, parent, pos, tree_info): 74 | ''' 75 | recursively convert the tree into nested dictionaries 76 | 77 | :param node_id: int 78 | node id 79 | :param parent: int 80 | parent node id ('null' for the root) 81 | :param pos: string 82 | position of the node relative to its parent ('left' or 'right') 83 | :param tree_info: dict 84 | information of the tree model 85 | :return: 86 | dictionary for this node, with its subtree under 'children' 87 | ''' 88 | tree_model = tree_info['tree_model'] 89 | features = tree_info['features'] 90 | tree_title = tree_info['tree_title'] 91 | binary_features = tree_info['binary_features'] 92 | int_features = tree_info['int_features'] 93 | target_names = tree_info['target_names'] 94 | 95 | node = {} 96 | if parent == 'null': 97 | node['name'] = tree_title 98 | else: 99 | feature = features[parent] 100 | if pos == 'left': 101 | if feature in binary_features: 102 | node['name'] = feature + ': 0' 103 | elif feature in int_features: 104 | node['name'] = feature + " <= " + str(int(tree_model.tree_.threshold[parent])) 105 | else: 106 | node['name'] = feature + " <= " + str(round(tree_model.tree_.threshold[parent], 3)) 107 | else: 108 | if feature in binary_features: 109 | node['name'] = feature + ': 1' 110 | elif feature in int_features: 111 | node['name'] = feature + " > " + str(int(tree_model.tree_.threshold[parent])) 112 | else: 113 | node['name'] = feature + "
> " + str(round(tree_model.tree_.threshold[parent], 3)) 114 | try: 115 | node['parent'] = int(parent) 116 | except: 117 | node['parent'] = parent 118 | 119 | node['self'] = int(node_id) 120 | node['sample'] = int(tree_model.tree_.n_node_samples[node_id]) 121 | node['impurity'] = round(tree_model.tree_.impurity[node_id], 3) 122 | node['value'] = [int(v) for v in tree_model.tree_.value[node_id][0]] 123 | node['predict'] = target_names[np.argmax(node['value'])] 124 | node['color'] = tree_info['target_colors'][np.argmax(node['value'])] 125 | node['pos'] = pos 126 | 127 | if tree_model.tree_.children_left[node_id] != -1 or tree_model.tree_.children_right[node_id] != -1: 128 | node['children'] = [] 129 | if tree_model.tree_.children_left[node_id] != -1: 130 | child = tree_model.tree_.children_left[node_id] 131 | node['children'].append(_parse_tree(child, node_id, 'left', tree_info)) 132 | if tree_model.tree_.children_right[node_id] != -1: 133 | child = tree_model.tree_.children_right[node_id] 134 | node['children'].append(_parse_tree(child, node_id, 'right', tree_info)) 135 | return node 136 | 137 | 138 | def _extract_rules(node_id, parent, pos, tree_rules, tree_info): 139 | ''' 140 | extract rules for each tree node 141 | 142 | :param node_id: int 143 | tree node id 144 | :param parent: int 145 | parent node id 146 | :param pos: string 147 | position of the node 148 | :param tree_rules: dict 149 | key: node_id, value: rule 150 | :param tree_info: dict 151 | information of the tree model 152 | :return: 153 | complete tree_rules 154 | ''' 155 | features = tree_info['features'] 156 | tree_model = tree_info['tree_model'] 157 | 158 | tree_rules[node_id] = {} 159 | tree_rules[node_id]['features'] = {} 160 | 161 | if parent != "null": 162 | previous = copy.deepcopy(tree_rules[parent]['features']) 163 | tree_rules[node_id]['features'] = previous 164 | feat = features[parent] 165 | thre = tree_model.tree_.threshold[parent] 166 | if feat not in previous.keys(): 167 | 
tree_rules[node_id]['features'][feat] = [-sys.maxint, sys.maxint] 168 | if pos == "left": 169 | origin = tree_rules[node_id]['features'][feat][1] 170 | tree_rules[node_id]['features'][feat][1] = np.min([thre, origin]) 171 | if pos == "right": 172 | origin = tree_rules[node_id]['features'][feat][0] 173 | tree_rules[node_id]['features'][feat][0] = np.max([thre, origin]) 174 | 175 | if tree_model.tree_.children_left[node_id] != -1: 176 | child = tree_model.tree_.children_left[node_id] 177 | _extract_rules(child, node_id, "left", tree_rules, tree_info) 178 | 179 | if tree_model.tree_.children_right[node_id] != -1: 180 | child = tree_model.tree_.children_right[node_id] 181 | _extract_rules(child, node_id, "right", tree_rules, tree_info) 182 | 183 | return tree_rules 184 | 185 | 186 | def _clean_rules(tree_rules, tree_info): 187 | ''' 188 | clean up the rules for each branch 189 | 190 | :param tree_rules: dict 191 | key: node_id, value: rule 192 | :param tree_info: dict 193 | information of the tree model 194 | :return: 195 | cleaned rules with the sample structure as tree_rules 196 | ''' 197 | tree_rules_clean = {} 198 | for key in tree_rules.keys(): 199 | key = int(key) 200 | node = copy.deepcopy(tree_rules[key]) 201 | rules = [] 202 | if node['features'].keys(): 203 | for k in node['features'].keys(): 204 | feat = node['features'][k] 205 | if k in tree_info['binary_features']: 206 | if feat[0] == -sys.maxint: 207 | rule = k + ': 0' 208 | else: 209 | rule = k + ': 1' 210 | elif k in tree_info['int_features']: 211 | if feat[0] == -sys.maxint: 212 | rule = k + ' <= ' + str(int(feat[1])) 213 | elif feat[1] == sys.maxint: 214 | rule = k + ' > ' + str(int(feat[0])) 215 | else: 216 | rule = str(int(feat[0])) + ' < ' + k + ' <= ' + str(int(feat[1])) 217 | else: 218 | if feat[0] == -sys.maxint: 219 | rule = k + ' <= ' + str(round(feat[1], 3)) 220 | elif feat[1] == sys.maxint: 221 | rule = k + ' > ' + str(round(feat[0], 3)) 222 | else: 223 | rule = str(round(feat[0], 3)) + ' < 
' + k + ' <= ' + str(round(feat[1], 3)) 224 | rules.append(rule) 225 | rules = sorted(rules, key= lambda x : len(x)) 226 | tree_rules_clean[key] = rules 227 | return tree_rules_clean 228 | 229 | 230 | def generate_simple_tree(tree_title, tree_model, X, target_names, 231 | target_colors=None, color_map=None, width=None, height=None): 232 | ''' 233 | visualize a sklearn Decision Tree Classifier 234 | 235 | :param tree_title: string 236 | name of the tree 237 | :param tree_model: a fitted sklearn Decision Tree Classifier 238 | :param X: pandas DataFrame 239 | dataset model was fitted on 240 | :param target_names: list 241 | list of names for targets 242 | :param target_colors: list, default=None 243 | list of colors for targets 244 | :param color_map: string, default=None 245 | matplotlib color map name, like 'Vega20' 246 | :param width: int 247 | width of the html page 248 | :param height: int 249 | height of the html page 250 | ''' 251 | 252 | # get tree information 253 | tree_info = _get_tree_info(X, tree_model, target_names, target_colors, tree_title, color_map) 254 | 255 | # get the tree structure 256 | final_tree = _parse_tree(0, "null", "null", tree_info) 257 | 258 | # extract tree rules 259 | tree_rules = {} 260 | tree_rules = _extract_rules(0, "null", "null", tree_rules, tree_info) 261 | 262 | # clean up rules 263 | tree_rules_clean = _clean_rules(tree_rules, tree_info) 264 | 265 | # get template 266 | temp = open('src/simple_tree_template.html').read() 267 | template = jinja2.Template(temp) 268 | 269 | # create the output root if it is not exits 270 | if not os.path.exists('simple_tree_output'): 271 | os.mkdir('simple_tree_output') 272 | 273 | # generate output html 274 | with open('simple_tree_output/simple_tree_%s.html' %(tree_title), 'wb') as fh: 275 | render_result = { 276 | 'tree': json.dumps(final_tree), 'rule': json.dumps(tree_rules_clean), 277 | 'num_node': tree_info['tree_model'].tree_.capacity, 278 | 'tree_depth': 
tree_info['tree_model'].tree_.max_depth, 279 | 'width': width, 'height': height, 'n_classes': tree_info['tree_model'].n_classes_ 280 | } 281 | fh.write(template.render(render_result)) 282 | print('The output is in simple_tree_output/simple_tree_%s.html. Enjoy!' %(tree_title)) 283 | 284 | 285 | def generate_sankey_tree(tree_title, tree_model, X, target_names, 286 | target_colors=None, color_map=None, width=None, height=None): 287 | ''' 288 | visualize a sklearn Decision Tree Classifier 289 | 290 | :param tree_title: string 291 | name of the tree 292 | :param tree_model: a fitted sklearn Decision Tree Classifier 293 | :param X: pandas DataFrame 294 | dataset model was fitted on 295 | :param target_names: list 296 | list of names for targets 297 | :param target_colors: list, default=None 298 | list of colors for targets 299 | :param color_map: string, default=None 300 | matplotlib color map name, like 'Vega20' 301 | :param width: int 302 | width of the html page 303 | :param height: int 304 | height of the html page 305 | ''' 306 | 307 | # get tree information 308 | tree_info = _get_tree_info(X, tree_model, target_names, target_colors, tree_title, color_map) 309 | 310 | # get the tree structure 311 | final_tree = _parse_tree(0, "null", "null", tree_info) 312 | 313 | # extract tree rules 314 | tree_rules = {} 315 | tree_rules = _extract_rules(0, "null", "null", tree_rules, tree_info) 316 | 317 | # clean up rules 318 | tree_rules_clean = _clean_rules(tree_rules, tree_info) 319 | 320 | # get template 321 | temp = open('src/sankey_tree_template.html').read() 322 | template = jinja2.Template(temp) 323 | 324 | # create the output root if it is not exits 325 | if not os.path.exists('sankey_tree_output'): 326 | os.mkdir('sankey_tree_output') 327 | 328 | # generate output html 329 | with open('sankey_tree_output/sankey_tree_%s.html' %(tree_title), 'wb') as fh: 330 | render_result = { 331 | 'tree': json.dumps(final_tree), 'rule': json.dumps(tree_rules_clean), 332 | 'num_node': 
tree_info['tree_model'].tree_.capacity, 333 | 'tree_depth': tree_info['tree_model'].tree_.max_depth, 334 | 'width': width, 'height': height, 'target_colors': tree_info['target_colors'], 335 | 'max_samples': np.max(tree_info['tree_model'].tree_.n_node_samples), 336 | 'min_samples': np.min(tree_info['tree_model'].tree_.n_node_samples), 337 | } 338 | fh.write(template.render(render_result)) 339 | print('The output is in sankey_tree_output/sankey_tree_%s.html. Enjoy!' %(tree_title)) 340 | -------------------------------------------------------------------------------- /tree/src/sankey_tree_template.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 20 | 21 | 22 | 23 | 24 | 354 | 355 | -------------------------------------------------------------------------------- /tree/src/simple_tree_template.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 25 | 26 | 27 | 28 | 29 | 308 | 309 | --------------------------------------------------------------------------------
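The rule logic in `_extract_rules` and `_clean_rules` boils down to interval tightening: every node copies its parent's per-feature `[low, high]` interval, a left branch (`feature <= threshold`) lowers `high`, and a right branch (`feature > threshold`) raises `low`. The snippet below is a minimal, self-contained sketch of that idea, not the module's own code. It assumes a hypothetical toy tree stored in plain lists shaped like sklearn's `tree_.children_left`, `children_right`, `feature` and `threshold` arrays (with `-1` for "no child"). It also assumes `math.inf` as the open-bound sentinel, and it skips the binary/integer formatting branches; `TOY_TREE`, `FEATURE_NAMES`, `extract_bounds` and `format_rules` are illustrative names only.

```python
import copy
import math

# Hypothetical toy tree in the layout of sklearn's tree_ arrays:
# node 0 splits on age <= 30.5, node 2 on fare <= 10.0, node 4 on age <= 50.5.
# -1 marks "no child"; -2 marks "no feature / threshold" (a leaf).
TOY_TREE = {
    "children_left": [1, -1, 3, -1, 5, -1, -1],
    "children_right": [2, -1, 4, -1, 6, -1, -1],
    "feature": [0, -2, 1, -2, 0, -2, -2],
    "threshold": [30.5, -2.0, 10.0, -2.0, 50.5, -2.0, -2.0],
}
FEATURE_NAMES = ["age", "fare"]


def extract_bounds(tree, feature_names, node_id=0, parent=None, pos=None, bounds=None):
    """Map every node id to {feature: [low, high]}, meaning low < feature <= high."""
    if bounds is None:
        bounds = {}
    # Inherit a private copy of the parent's intervals (the root starts with none).
    current = copy.deepcopy(bounds[parent]) if parent is not None else {}
    if parent is not None:
        feat = feature_names[tree["feature"][parent]]
        thre = tree["threshold"][parent]
        low, high = current.get(feat, [-math.inf, math.inf])
        if pos == "left":   # left child: feat <= thre tightens the upper bound
            high = min(high, thre)
        else:               # right child: feat > thre tightens the lower bound
            low = max(low, thre)
        current[feat] = [low, high]
    bounds[node_id] = current
    # Recurse into existing children (pre-order, left before right).
    for child, side in ((tree["children_left"][node_id], "left"),
                        (tree["children_right"][node_id], "right")):
        if child != -1:
            extract_bounds(tree, feature_names, child, node_id, side, bounds)
    return bounds


def format_rules(bounds, ndigits=3):
    """Turn intervals into readable rule strings, shortest first."""
    rules = {}
    for node_id, intervals in bounds.items():
        out = []
        for feat, (low, high) in intervals.items():
            if low == -math.inf:
                out.append(f"{feat} <= {round(high, ndigits)}")
            elif high == math.inf:
                out.append(f"{feat} > {round(low, ndigits)}")
            else:
                out.append(f"{round(low, ndigits)} < {feat} <= {round(high, ndigits)}")
        rules[node_id] = sorted(out, key=len)
    return rules


rules = format_rules(extract_bounds(TOY_TREE, FEATURE_NAMES))
for node_id, node_rules in rules.items():
    print(node_id, node_rules)
# the line for node 5 reads: 5 ['fare > 10.0', '30.5 < age <= 50.5']
```

Because each child starts from a deep copy of its parent's intervals, sibling branches never share state; `_extract_rules` gets the same guarantee from `copy.deepcopy(tree_rules[parent]['features'])`.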