├── .gitignore ├── plots ├── html │ ├── .DS_Store │ ├── PCA_analysis.html │ ├── roc.html │ ├── roc_pr.html │ ├── roc_pr_curve.html │ ├── roc_pr_plot.html │ └── roc_cvfold │ │ └── bokeh-1.0.4.min.css └── png │ ├── Normalized_corr_matrix.png │ ├── Z_normalized_corr_matrix.png │ └── Z_normalized_corr_matrix_Abs.png ├── pip_requirements.txt ├── code ├── __pycache__ │ └── accstats.cpython-36.pyc ├── accstats.py ├── pca_feature_correlation.py └── roc_pr_curve.py ├── README.md └── LICENSE /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | __pycache__/* 3 | .ipynb_checkpoints 4 | .ipynb_checkpoints/* 5 | .DS_Store 6 | -------------------------------------------------------------------------------- /plots/html/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gaurav-kaushik/Data-Visualizations-Medium/HEAD/plots/html/.DS_Store -------------------------------------------------------------------------------- /pip_requirements.txt: -------------------------------------------------------------------------------- 1 | seaborn>=0.9.0 2 | bokeh>=0.12.0 3 | numpy>=1.14.0 4 | pandas>=0.22.0 5 | matplotlib>=2.1.0 6 | scikit-learn>=0.19.0 7 | -------------------------------------------------------------------------------- /plots/png/Normalized_corr_matrix.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gaurav-kaushik/Data-Visualizations-Medium/HEAD/plots/png/Normalized_corr_matrix.png -------------------------------------------------------------------------------- /plots/png/Z_normalized_corr_matrix.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gaurav-kaushik/Data-Visualizations-Medium/HEAD/plots/png/Z_normalized_corr_matrix.png 
-------------------------------------------------------------------------------- /code/__pycache__/accstats.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gaurav-kaushik/Data-Visualizations-Medium/HEAD/code/__pycache__/accstats.cpython-36.pyc -------------------------------------------------------------------------------- /plots/png/Z_normalized_corr_matrix_Abs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gaurav-kaushik/Data-Visualizations-Medium/HEAD/plots/png/Z_normalized_corr_matrix_Abs.png -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Understanding Data and Machine Learning Models with Visualizations 2 | 3 | ### Part 1 - PCA and Feature Correlation 4 | 5 | * code/Interactive PCA and Feature Correlation.ipynb 6 | * code/pca\_feature\_correlation.py 7 | * Post on [Cascade.Bio blog](https://medium.com/cascade-bio-blog/creating-visualizations-to-better-understand-your-data-and-models-part-1-a51e7e5af9c0) 8 | 9 | ### Part 2 - Machine Learning Decision Boundary Visualization 10 | 11 | * code/Interactive\_Model\_Predictions\_and\_Decision\_Boundaries.ipynb 12 | * Post on [Cascade.Bio blog](https://medium.com/cascade-bio-blog/creating-visualizations-to-better-understand-your-data-and-models-part-2-28d5c46e956) 13 | 14 | ### Part 3 - ROC Curves 15 | 16 | * code/Interactive\_ROC\_analysis.ipynb 17 | * code/accstats.py 18 | * code/roc\_pr\_curve.py 19 | * Post on [HackerNoon](https://medium.com/hackernoon/making-sense-of-real-world-data-roc-curves-and-when-to-use-them-90a17e6d1db) 20 | 21 | 22 | --- 23 | 24 | Note: code was developed in Python 3.6 and likely not backwards compatible because of liberal use of f-strings (sorry not sorry) 
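As a rough illustration of the kind of calculation `code/accstats.py` (Part 3) performs, here is a minimal standalone sketch. It is not the repository's code: the helper name `base_rate_metrics` is hypothetical, and it assumes the same accuracy applies to both the positive and the negative class.

```python
def base_rate_metrics(acc=0.995, subpop=1e-4, population=8_500_000):
    """Hypothetical helper (not part of this repo).

    Assumes the same accuracy applies to the positive and negative class.
    """
    positives = round(population * subpop)   # true members of the subpopulation
    negatives = population - positives       # everyone else
    tp = round(positives * acc)              # members correctly flagged
    fn = positives - tp                      # members missed
    tn = round(negatives * acc)              # non-members correctly cleared
    fp = negatives - tn                      # non-members wrongly flagged
    return {
        "TP": tp, "FP": fp, "TN": tn, "FN": fn,
        "precision": tp / (tp + fp) if tp + fp else float("nan"),
        "recall": tp / (tp + fn) if tp + fn else float("nan"),
    }


metrics = base_rate_metrics()
print(f"precision={metrics['precision']:.4f}, recall={metrics['recall']:.4f}")
```

With a rare subpopulation, precision stays low even when accuracy is high, which is why accuracy alone can be a misleading summary.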
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2019 Gaurav Kaushik 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all 11 | copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | SOFTWARE. 
20 | -------------------------------------------------------------------------------- /code/accstats.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import argparse 4 | import numpy as np 5 | 6 | 7 | def confusion_matrix(acc:float=0.995, subpop:float=1e-4, population:float=8.5e6) -> dict: 8 | """ 9 | Generates confusion matrix and derived variables 10 | based on accuracy of detecting a fraction (subpop) of a population 11 | 12 | Args: 13 | acc (float): accuracy of detecting subpop (0-1) 14 | subpop (float): fraction of population in subpopulation (0-1) 15 | population (float): Total population size (absolute number) 16 | 17 | Returns: 18 | dict: derived variables as a dictionary (None if an input is invalid) 19 | """ 20 | 21 | # Inputs 22 | population = int(population) 23 | print(f"\nInputs\n------------") 24 | print(f"Accuracy (%): {100*acc}") 25 | print(f"Subpopulation (%): {100*subpop}") 26 | print(f"Population size: {population}") 27 | print(f"Predicted subpopulation size: {int(subpop*population)}") 28 | 29 | # Check variables 30 | if acc > 1 or acc < 0: 31 | print("\nERROR: give valid accuracy (0 to 1).") 32 | return 33 | if subpop > 1 or subpop < 0: 34 | print("\nERROR: give valid subpop fraction (0 to 1).") 35 | return 36 | if population < 1: 37 | print("\nERROR: cannot have zero or negative populations.") 38 | return 39 | 40 | # confusion matrix 41 | tp = np.rint(population*subpop*acc).astype(int) 42 | fp = np.rint(population*(1-subpop)*(1-acc)).astype(int) # only negatives can be false positives 43 | tn = np.rint(population*(1-subpop)*acc).astype(int) 44 | fn = np.rint(population*subpop).astype(int) - tp 45 | print(f"\nResults\n------------") 46 | print(f"True Positives (Power): {tp}") 47 | print(f"False Positives (Type I): {fp}") 48 | print(f"True Negatives: {tn}") 49 | print(f"False Negatives (Type II): {fn}") 50 | 51 | # derivations 52 | round_var = 4 # round vars to this place 53 | tpr = np.round((tp)/(tp+fn), round_var) 54 | fpr = np.round((fp)/(fp+tn), 
round_var) 55 | precision = np.round((tp)/(tp+fp), round_var) 56 | specificity = np.round((tn)/(tn+fp), round_var) 57 | fdr = np.round((fp)/(fp+tp), round_var) 58 | fscore = np.round(2*tp/(2*tp+fp+fn), round_var) 59 | print(f"\nDerivations\n------------") 60 | print(f"True Positive Rate (Recall): {tpr}") 61 | print(f"False Positive Rate: {fpr}") 62 | print(f"Precision: {precision}") 63 | print(f"Specificity: {specificity}") 64 | print(f"False Discovery Rate: {fdr}") 65 | print(f"F-Score: {fscore}") 66 | 67 | # output a dictionary of derived variables 68 | output = { 69 | 'True_Positives':tp, 70 | 'False_Positives':fp, 71 | 'True_Negatives':tn, 72 | 'False_Negatives':fn, 73 | 'TPR':tpr, 74 | 'FPR':fpr, 75 | 'Precision':precision, 76 | 'Specificity':specificity, 77 | 'FDR':fdr, 78 | 'FScore':fscore 79 | } 80 | 81 | return output 82 | 83 | 84 | if __name__ == '__main__': 85 | parser = argparse.ArgumentParser() 86 | parser.add_argument("-a", "--acc", type=float, default=0.995, help="Accuracy (0 -> 1)") 87 | parser.add_argument("-s", "--sub", type=float, default=1e-4, help="Subpop fraction (0 -> 1)") 88 | parser.add_argument("-p", "--pop", type=float, default=8.5e6, help="Pop size") 89 | args = parser.parse_args() 90 | accuracy = args.acc 91 | subpopulation = args.sub 92 | population = args.pop 93 | 94 | # main 95 | confusion_matrix(accuracy, subpopulation, population) 96 | -------------------------------------------------------------------------------- /code/pca_feature_correlation.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import argparse 4 | import os 5 | import pandas as pd 6 | import numpy as np 7 | import seaborn as sns 8 | from sklearn import datasets 9 | from sklearn.decomposition import PCA 10 | from sklearn.preprocessing import StandardScaler 11 | import matplotlib.pyplot as plt 12 | from matplotlib.colors import cnames 13 | from itertools import cycle 14 | from bokeh.plotting import 
output_file, figure, show, ColumnDataSource 15 | from bokeh.models import HoverTool 16 | import warnings 17 | warnings.filterwarnings(action='ignore') 18 | 19 | 20 | def get_float_list(range_max:int, div:int=100) -> list: 21 | """ To get 0 -> 1, range_max must be same order of mag as div """ 22 | return [float(x)/div for x in range(int(range_max))] 23 | 24 | 25 | def get_colorcycle(colordict:dict): 26 | """ Subset cnames with a string match and get a color cycle for plotting """ 27 | return cycle(list(colordict.keys())) 28 | 29 | 30 | def get_colordict(filter_:str='dark') -> dict: 31 | """ return dictionary of colornames by filter """ 32 | return dict((k, v) for k, v in cnames.items() if filter_ in k) 33 | 34 | 35 | def pca_report_interactive(X, scale_X:bool=True, save_plot:bool=False): 36 | """ 37 | X: input data matrix 38 | scale_X: determine whether to rescale X (StandardScaler) [default: True, i.e. X is assumed not to be prescaled] 39 | save_plot: save plot to file (html) and not show 40 | """ 41 | 42 | # calculate mean and var 43 | X_mean, X_var = X.mean(), X.var() 44 | print('\n*--- PCA Report ---*\n') 45 | print(f'X mean:\t\t{X_mean:.3f}\nX variance:\t{X_var:.3f}') 46 | 47 | if scale_X: 48 | # rescale and run PCA 49 | print("\n...Rescaling data...\n") 50 | scaler = StandardScaler() 51 | X_scaled = scaler.fit_transform(X) 52 | X_s_mean, X_s_var = X_scaled.mean(), X_scaled.var() 53 | print(f'X_scaled mean:\t\t{np.round(X_s_mean):.3f}') 54 | print(f'X_scaled variance:\t{np.round(X_s_var):.3f}\n') 55 | pca_ = PCA().fit(X_scaled) 56 | X_pca = pca_.transform(X_scaled) # project the same (scaled) data the PCA was fit on 57 | else: 58 | # run PCA directly 59 | print("...Assuming data is properly scaled...") 60 | pca_ = PCA().fit(X) 61 | X_pca = pca_.transform(X) 62 | 63 | # Get cumulative explained variance for each dimension 64 | pca_evr = pca_.explained_variance_ratio_ 65 | cumsum_ = np.cumsum(pca_evr) 66 | 67 | # Get dimensions where var >= 95% and values for variance at 2D, 3D 68 | dim_95 = np.argmax(cumsum_ >= 0.95) + 1 69 | 
twoD = np.round(cumsum_[1], decimals=3)*100 70 | threeD = np.round(cumsum_[2], decimals=3)*100 71 | instances_, dims_ = X.shape 72 | 73 | # check shape of X 74 | if dims_ > instances_: 75 | print("WARNING: number of features greater than number of instances.") 76 | dimensions = list(range(1, instances_+1)) 77 | else: 78 | dimensions = list(range(1, dims_+1)) 79 | 80 | # Print report 81 | print("\n -- Summary --") 82 | print(f"You can reduce from {dims_} to {dim_95} dimensions while retaining at least 95% of variance.") 83 | print(f"2 principal components explain {twoD:.2f}% of variance.") 84 | print(f"3 principal components explain {threeD:.2f}% of variance.") 85 | 86 | """ - Plotting - """ 87 | # Create custom HoverTool -- glyphs named 'PCA' are the only ones that show info on hover 88 | hover_ = HoverTool(names=['PCA'], tooltips=[("dimensions", "@x_dim"), 89 | ("cumulative variance", "@y_cumvar"), 90 | ("explained variance", "@y_var")]) 91 | p_tools = [hover_, 'crosshair', 'zoom_in', 'zoom_out', 'save', 'reset', 'tap', 'box_zoom'] 92 | 93 | # insert 0 at beginning for cleaner plotting 94 | cumsum_plot = np.insert(cumsum_, 0, 0) 95 | pca_evr_plot = np.insert(pca_evr, 0, 0) 96 | dimensions_plot = np.insert(dimensions, 0, 0) 97 | 98 | """ 99 | ColumnDataSource 100 | - a special type in Bokeh that allows you to store data for plotting 101 | - store data as dict (key:list) 102 | - to plot two keys against one another, make sure they're the same length! 
103 | - below: 104 | x_dim # of dimensions (length = # of dimensions) 105 | y_cumvar # cumulative variance (length = # of dimensions) 106 | var_95 # y = 0.95 (length = # of dimensions) 107 | zero_one # list of 0 to 1 108 | twoD # x = 2 109 | threeD # x = 3 110 | """ 111 | 112 | # get sources 113 | source_PCA = ColumnDataSource(data=dict(x_dim = dimensions_plot,y_cumvar = cumsum_plot, y_var = pca_evr_plot)) 114 | source_var95 = ColumnDataSource(data=dict(var95_x = [dim_95]*96, var95_y = get_float_list(96))) 115 | source_twoD = ColumnDataSource(data=dict(twoD_x = [2]*(int(twoD)+1), twoD_y = get_float_list(twoD+1))) 116 | source_threeD = ColumnDataSource(data=dict(threeD_x = [3]*(int(threeD)+1), threeD_y = get_float_list(threeD+1))) 117 | 118 | """ PLOT """ 119 | # set up figure and add axis labels 120 | p = figure(title='PCA Analysis', tools=p_tools) 121 | p.xaxis.axis_label = f'N of {dims_} Principal Components' 122 | p.yaxis.axis_label = 'Variance Explained (per PC & Cumulative)' 123 | 124 | # add reference lines: y=0.95, x=2, x=3 125 | p.line('twoD_x', 'twoD_y', line_width=0.5, line_dash='dotted', color='#435363', source=source_twoD) # x=2 126 | p.line('threeD_x', 'threeD_y', line_width=0.5, line_dash='dotted', color='#435363', source=source_threeD) # x=3 127 | p.line('var95_x', 'var95_y', line_width=2, line_dash='dotted', color='#435363', source=source_var95) # var = 0.95 128 | 129 | # add bar plot for variance per dimension 130 | p.vbar(x='x_dim', top='y_var', width=.5, bottom=0, color='#D9F2EF', source=source_PCA, name='PCA') 131 | 132 | # add cumulative variance (scatter + line) 133 | p.line('x_dim', 'y_cumvar', line_width=1, color='#F79737', source=source_PCA) 134 | p.circle('x_dim', 'y_cumvar', size=7, color='#FF4C00', source=source_PCA, name='PCA') 135 | 136 | # change gridlines 137 | p.ygrid.grid_line_alpha = 0.25 138 | p.xgrid.grid_line_alpha = 0.25 139 | 140 | # change axis bounds and grid 141 | p.xaxis.bounds = (0, dims_) 142 | p.yaxis.bounds = (0, 1) 
143 | p.grid.bounds = (0, dims_) 144 | 145 | # save and show p 146 | if save_plot: 147 | output_file('PCA_analysis.html') 148 | show(p) 149 | 150 | # output PCA info as a dataframe 151 | df_PCA = pd.DataFrame({'dimension': dimensions, 'variance_cumulative': cumsum_, 'variance': pca_evr}).set_index(['dimension']) 152 | 153 | return df_PCA, X_pca, pca_evr 154 | 155 | 156 | def pca_feature_correlation(X, X_pca, explained_var, features:list=None, fig_dpi:int=150, save_plot:bool=False): 157 | """ 158 | 1. Get dot product of X and X_pca 159 | 2. Run normalizations of X*X_pca 160 | 3. Retrieve df/matrices 161 | 162 | X: data (numpy matrix) 163 | X_pca: PCA-transformed data 164 | explained_var: explained variance ratio per PC (1D array) 165 | features: list of feature names 166 | fig_dpi: dpi to use for heatmaps 167 | save_plot: save heatmaps to file (png) instead of showing them 168 | """ 169 | 170 | # Add zeroes for data where features > instances 171 | outer_diff = X.T.shape[0] - X_pca.shape[1] 172 | if outer_diff > 0: # outer dims must match to get sq matrix 173 | Z = np.zeros([X_pca.shape[0], outer_diff]) 174 | X_pca = np.c_[X_pca, Z] 175 | explained_var = np.append(explained_var, np.zeros(outer_diff)) 176 | 177 | # Get correlation between original features (X) and PCs (X_pca) 178 | dot_matrix = np.dot(X.T, X_pca) 179 | print(f"X*X_pca: {X.T.shape} * {X_pca.shape} = {dot_matrix.shape}") 180 | 181 | # Correlation matrix -> df 182 | df_dotproduct = pd.DataFrame(dot_matrix) 183 | df_dotproduct.columns = [f'PC{i+1}' for i in range(dot_matrix.shape[1])] 184 | if features is not None: df_dotproduct.index = features 185 | 186 | # Normalize & Sort 187 | df_n, df_na, df_nabv = normalize_dataframe(df_dotproduct, explained_var, fig_dpi=fig_dpi, plot_opt=True, save_plot=save_plot) 188 | 189 | return df_dotproduct, df_n, df_na, df_nabv 190 | 191 | 192 | def normalize_dataframe(df, explained_var=None, fig_dpi:int=150, plot_opt:bool=True, save_plot:bool=False): 193 | """ 194 | 1. Get z-normalized df (normalized to µ=0, σ=1) 195 | 2. 
Get absolute value of z-normalized df 196 | 3. If explained_variance array provided, multiply (2) by it column-wise 197 | """ 198 | # Normalize, Reindex, & Sort 199 | df_norm = (df.copy()-df.mean())/df.std() 200 | df_norm = df_norm.sort_values(list(df_norm.columns), ascending=False) 201 | 202 | # Absolute value of normalized (& sort) 203 | df_abs = df_norm.copy().abs().set_index(df_norm.index) 204 | df_abs = df_abs.sort_values(by=list(df_abs.columns), ascending=False) 205 | 206 | # Plot 207 | if plot_opt: 208 | # Z-normalized corr matrix 209 | plt.figure(dpi=fig_dpi) 210 | ax_normal = sns.heatmap(df_norm, cmap="RdBu") 211 | ax_normal.set_title("Z-Normalized Data") 212 | if save_plot: 213 | plt.savefig('Z_normalized_corr_matrix.png') 214 | else: 215 | plt.show() 216 | 217 | # |Z-normalized corr matrix| 218 | plt.figure(dpi=fig_dpi) 219 | ax_abs = sns.heatmap(df_abs, cmap="Purples") 220 | ax_abs.set_title("|Z-Normalized|") 221 | if save_plot: 222 | plt.savefig('Z_normalized_corr_matrix_Abs.png') 223 | else: 224 | plt.show() 225 | 226 | # Re-normalize by explained var (& sort) 227 | if explained_var is not None and np.any(explained_var): 228 | df_byvar = df_abs.copy()*explained_var 229 | df_byvar = df_byvar.sort_values(by=list(df_norm.columns), ascending=False) 230 | if plot_opt: 231 | plt.figure(dpi=fig_dpi) 232 | ax_relative = sns.heatmap(df_byvar, cmap="Purples") 233 | ax_relative.set_title("|Z-Normalized|*Explained_Variance") 234 | if save_plot: 235 | plt.savefig('Normalized_corr_matrix.png') 236 | else: 237 | plt.show() 238 | else: 239 | df_byvar = None 240 | return df_norm, df_abs, df_byvar 241 | 242 | 243 | def pca_rank_features(df_nabv, verbose:bool=True): 244 | """ 245 | Given a dataframe df_nabv with dimensions [f, p], where: 246 | f = features (sorted) 247 | p = principal components 248 | df_nabv.values are |Z-normalized X|*pca_.explained_variance_ratio_ 249 | 250 | 1. Create column 'score_' with the sum of each row and sort by it 251 | 2. 
Set index as 'rank' 252 | """ 253 | df_rank = df_nabv.copy().assign(score_ = df_nabv.sum(axis=1)).sort_values('score_', ascending=False) 254 | df_rank['feature_'] = df_rank.index 255 | df_rank.index = range(1, len(df_rank)+1) 256 | df_rank.drop(df_nabv.columns, axis=1, inplace=True) 257 | df_rank.index.rename('rank', inplace=True) 258 | if verbose: print(df_rank) 259 | return df_rank 260 | 261 | 262 | def pca_full_report(X, features_:list=None, fig_dpi:int=150, save_plot:bool=False): 263 | """ 264 | Run complete PCA workflow: 265 | 1. pca_report_interactive() 266 | 2. pca_feature_correlation() 267 | 3. pca_rank_features() 268 | 269 | X: data (numpy array) 270 | features_: list of feature names 271 | fig_dpi: image resolution 272 | 273 | """ 274 | # Retrieve the interactive report 275 | df_pca, X_pca, pca_evr = pca_report_interactive(X, save_plot=save_plot) 276 | # Get feature-PC correlation matrices 277 | df_corr, df_n, df_na, df_nabv = pca_feature_correlation(X, X_pca, pca_evr, features_, fig_dpi, save_plot) 278 | # Get rank for each feature 279 | df_rank = pca_rank_features(df_nabv) 280 | return (df_pca, X_pca, pca_evr, df_corr, df_n, df_na, df_nabv, df_rank) 281 | 282 | 283 | if __name__ == '__main__': 284 | """ IRIS """ 285 | data = datasets.load_iris() 286 | outputs = pca_full_report(X=data.data, features_=data.feature_names, save_plot=True) 287 | 288 | -------------------------------------------------------------------------------- /plots/html/PCA_analysis.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | Bokeh Plot 7 | 8 | 9 | 10 | 11 | 14 | 15 | 16 | 17 |
18 |
19 |
20 | 21 | 24 | 59 | 60 | -------------------------------------------------------------------------------- /plots/html/roc.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | Bokeh Plot 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 |
34 | 35 | 36 | 37 | 38 | 39 | 42 | 77 | 78 | 79 | 80 | -------------------------------------------------------------------------------- /plots/html/roc_pr.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | Bokeh Plot 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 |
39 | 40 | 41 | 42 | 43 | 44 | 47 | 82 | 83 | 84 | 85 | -------------------------------------------------------------------------------- /plots/html/roc_pr_curve.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | Bokeh Plot 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 |
39 | 40 | 41 | 42 | 43 | 44 | 47 | 82 | 83 | 84 | 85 | -------------------------------------------------------------------------------- /plots/html/roc_pr_plot.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | Bokeh Plot 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 |
39 | 40 | 41 | 42 | 43 | 44 | 47 | 82 | 83 | 84 | 85 | -------------------------------------------------------------------------------- /code/roc_pr_curve.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import pandas as pd 4 | import numpy as np 5 | from itertools import cycle 6 | from sklearn import datasets 7 | from sklearn.datasets import make_classification 8 | from sklearn.decomposition import PCA 9 | from sklearn.preprocessing import StandardScaler 10 | from sklearn.metrics import roc_curve, auc, precision_recall_curve, average_precision_score, accuracy_score 11 | from sklearn.model_selection import train_test_split, StratifiedKFold 12 | from sklearn.ensemble import RandomForestClassifier 13 | from sklearn.linear_model import LogisticRegression 14 | from sklearn.naive_bayes import GaussianNB 15 | from bokeh.plotting import output_file, figure, show, ColumnDataSource 16 | from bokeh.models import HoverTool 17 | from bokeh.layouts import row 18 | import matplotlib.pyplot as plt 19 | from matplotlib.colors import cnames 20 | cnames = dict((k, v) for k, v in cnames.items() if 'dark' in k) 21 | 22 | 23 | def PCA_2D_labeled(X, y, cnames:list, target_names:list): 24 | """ 25 | Get a quick 2D rescaled PCA of a labeled dataset 26 | 27 | Args: 28 | X (numpy.ndarray): data 29 | y (numpy.ndarray): labels 30 | cnames: a list of color names (str) 31 | target_names: a list of target names (str) 32 | 33 | Returns: 34 | matplotlib plot object 35 | """ 36 | 37 | # rescaled, 2D PCA 38 | X_2D = PCA(2).fit_transform(StandardScaler().fit_transform(X)) 39 | 40 | # plot 41 | plt.figure(dpi=150) 42 | for c, i, t in zip(['red', 'green'], set(y), target_names): 43 | # plot each column with a color pertaining to the labels 44 | plt.scatter(X_2D[y==i, 0], X_2D[y==i, 1], color=c, alpha=.2, lw=1, label=t) 45 | plt.legend(loc='best') 46 | plt.xticks([]) 47 | plt.yticks([]) 48 | plt.xlabel('PC1') 49 | plt.ylabel('PC2') 50 
| plt.tight_layout() 51 | plt.show() 52 | return plt 53 | 54 | 55 | """ Function to return all plots for a classifier """ 56 | 57 | def classifier_plots(clf_trained, X_test, y_test, target_names:list, minority_idx:int=0, ylog:bool=False): 58 | """ 59 | Get summary plots for a trained classifier 60 | 61 | Args: 62 | clf_trained: trained sklearn clf 63 | X_test (np.ndarray): withheld test data 64 | y_test (np.ndarray): withheld test data labels 65 | target_names (list): list of target labels/names 66 | minority_idx (int): index for the minority class (e.g. 0, 1) 67 | ylog (bool): toggle log-scaling on yaxis 68 | 69 | Returns: 70 | None 71 | """ 72 | 73 | """ Probability Dist """ 74 | # get the probability distribution 75 | probas = clf_trained.predict_proba(X_test) 76 | 77 | # PLOT - count 78 | plt.figure(dpi=150) 79 | plt.hist(probas, bins=20) 80 | plt.title('Classification Probabilities') 81 | plt.xlabel('Probability') 82 | plt.ylabel('# of Instances') 83 | plt.xlim([0.5, 1.0]) 84 | if ylog: plt.yscale('log') 85 | plt.legend(target_names) 86 | plt.show() 87 | 88 | 89 | # PLOT - density 90 | plt.figure(dpi=150) 91 | plt.hist(probas[:, minority_idx], bins=20, density=True) 92 | plt.title('Classification Density (Minority)') 93 | plt.xlabel('Probability') 94 | plt.ylabel('Density') 95 | if ylog: plt.yscale('log') 96 | plt.xlim([0, 1.0]) 97 | plt.legend([target_names[minority_idx]]) 98 | plt.show() 99 | 100 | """ ROC curve """ 101 | 102 | # get false and true positive rates (class 0 is treated as the positive class) 103 | fpr, tpr, _ = roc_curve(y_test, probas[:,0], pos_label=0) 104 | 105 | # get area under the curve 106 | clf_auc = auc(fpr, tpr) 107 | 108 | # PLOT ROC curve 109 | plt.figure(dpi=150) 110 | plt.plot(fpr, tpr, lw=1, color='green', label=f'AUC = {clf_auc:.3f}') 111 | plt.plot([0,1], [0,1], '--k', lw=0.5, label='Random') 112 | plt.title('ROC') 113 | plt.xlabel('False Positive Rate') 114 | plt.ylabel('True Positive Rate (Recall)') 115 | plt.xlim([-0.05, 1.05]) 116 | plt.ylim([-0.05, 1.05]) 117 | plt.legend() 118 | 
plt.show() 119 | 120 | """ Precision Recall Curve """ 121 | 122 | # get precision and recall values 123 | precision, recall, _ = precision_recall_curve(y_test, probas[:,0], pos_label=0) 124 | 125 | # average precision score (class 0 as positive, matching the curve above) 126 | avg_precision = average_precision_score(y_test == 0, probas[:,0]) 127 | 128 | # area under the PR curve 129 | pr_auc = auc(recall, precision) 130 | 131 | # plot 132 | plt.figure(dpi=150) 133 | plt.plot(recall, precision, lw=1, color='blue', label=f'AP={avg_precision:.3f}; AUC={pr_auc:.3f}') 134 | plt.fill_between(recall, precision, -1, facecolor='lightblue', alpha=0.5) 135 | 136 | plt.title('PR Curve') 137 | plt.xlabel('Recall (TPR)') 138 | plt.ylabel('Precision') 139 | plt.xlim([-0.05, 1.05]) 140 | plt.ylim([-0.05, 1.05]) 141 | plt.legend() 142 | plt.show() 143 | 144 | 145 | """ 146 | Interactive Plots and Utility Functions 147 | """ 148 | 149 | def get_clf_name(clf): 150 | """ 151 | get_clf_name takes a classifier (trained or untrained) and returns its name as a string 152 | clf.__str__() will return a string of the classifier's name and params 153 | e.g. "LogisticRegression(...)" or "((LogisticRegression(...)))" 154 | We then split on "(", use filter to drop empty strings, convert to list, and return first item 155 | clf: sklearn classifier (e.g. rf = RandomForestClassifier()) 156 | """ 157 | return list(filter(None, clf.__str__().split("(")))[0] 158 | 159 | 160 | def get_ROC_data(data, clf, pos_label_=None, verbose=False): 161 | 162 | """ 163 | source_ROC, df_ROC, clf = get_ROC_data(data, clf, pos_label_=None, verbose=False) 164 | 165 | get_ROC_data will return a ColumnDataSource and a dataframe with TPR and FPR 166 | for a particular dataset and an untrained classifier. The CDS can be used 167 | to plot a Bokeh plot while the dataframe can be used for additional 168 | exploration and plotting with other libs. Note that the dataframe 169 | is returned with metadata (e.g. AUC and the clf used). 
170 | 171 | data: tuple of our data (X_train, X_test, y_train, y_test) 172 | where each item in the tuple is a numpy ndarray 173 | clf: an untrained classifier (e.g. rf = RandomForestClassifier()) 174 | pos_label_: if targets are not binary (0, 1) then indicate integer for "positive" [default: None] 175 | verbose: print warnings [default: False] 176 | """ 177 | 178 | # split data into training, testing 179 | (X_train, X_test, y_train, y_test) = data 180 | 181 | # train and retrieve class probabilities for each instance in the test data 182 | probas_ = clf.fit(X_train, y_train).predict_proba(X_test) 183 | 184 | # choose the positive label: 185 | # (defaults to the largest target value when none is given) 186 | if pos_label_ is None: 187 | pos_label_ = np.max(y_train) 188 | if verbose: 189 | print(f"Warning: Maximum target value of '{pos_label_}' used as positive.") 190 | print("You can use 'pos_label_' to indicate your own.") 191 | 192 | # get values for roc curve, scoring with the probability column of the positive class 193 | fpr, tpr, thresholds = roc_curve(y_test, probas_[:, list(clf.classes_).index(pos_label_)], pos_label=pos_label_) 194 | thresholds[0] = np.nan # first threshold is an artificial sentinel, not a real probability 195 | 196 | # get area under the curve (AUC) 197 | roc_auc = auc(fpr, tpr) 198 | 199 | # create legend variables - we'll create an array with len(tpr) 200 | auc_ = [f"AUC: {roc_auc:.3f}"]*len(tpr) 201 | clf_name = get_clf_name(clf) 202 | clf_ = [f"{clf_name}, AUC: {roc_auc:.3f}"]*len(tpr) 203 | 204 | # create bokeh column source for plotting new ROC 205 | source_ROC = ColumnDataSource(data=dict(x_fpr=fpr, 206 | y_tpr=tpr, 207 | thresh=thresholds, 208 | auc_legend=auc_, 209 | clf_legend=clf_)) 210 | 211 | # create output dataframe with TPR and FPR, and metadata 212 | df_ROC = pd.DataFrame({'TPR':tpr, 'FPR':fpr, 'Thresholds':thresholds}) 213 | df_ROC.auc = roc_auc 214 | df_ROC.clf = get_clf_name(clf) 215 | df_ROC.score = clf.score(X_test, y_test) 216 | 217 | return source_ROC, df_ROC, clf 218 | 219 | 220 | def interpolate_mean_tpr(FPRs=None, TPRs=None, df_list=None): 221 | """ 
222 | mean_fpr, mean_tpr = interpolate_mean_tpr(FPRs=None, TPRs=None, df_list=None) 223 | 224 | FPRs: False positive rates (list of n arrays) 225 | TPRs: True positive rates (list of n arrays) 226 | df_list: DataFrames with TPR, FPR columns (list of n DataFrames) 227 | """ 228 | 229 | # seed empty linspace 230 | mean_tpr, mean_fpr = 0, np.linspace(0, 1, 101) 231 | 232 | if TPRs and FPRs: 233 | for idx, PRs in enumerate(zip(FPRs, TPRs)): 234 | mean_tpr += np.interp(mean_fpr, PRs[0], PRs[1]) 235 | 236 | elif df_list: 237 | for idx, df_ in enumerate(df_list): 238 | mean_tpr += np.interp(mean_fpr, df_.FPR, df_.TPR) 239 | 240 | else: 241 | print("Please give valid inputs.") 242 | return None, None 243 | 244 | # normalize by length of inputs (# indices looped over) 245 | mean_tpr /= (idx+1) 246 | 247 | # add origin point 248 | mean_fpr = np.insert(mean_fpr, 0, 0) 249 | mean_tpr = np.insert(mean_tpr, 0, 0) 250 | 251 | return mean_fpr, mean_tpr 252 | 253 | 254 | def plot_ROC(clf, X, y, test_size_:float=0.5, pos_label_:str=None, filename:str=None, verbose:bool=False): 255 | """ 256 | clf, classifiers, df_ROCs = plot_ROC(clf, X, y, pos_label_=None, verbose=False) 257 | 258 | Plot an interactive ROC curve for a binary classifier. 259 | It returns the original clf, a classifier for each cv, a list of dataframes for each cv. 260 | 261 | clf: untrained classifier object (e.g. 
rf_clf = RandomForestClassifer()) 262 | X: training + testing data 263 | y: targets (numeric/integers) 264 | test_size: fraction of data to be reserved for testing [default: 0.5] 265 | pos_label_: if targets are not binary (0, 1) then indicate integer for "positive" [default: None] 266 | filename: if provided, save to html [default: None] 267 | verbose: print warnings [default: False] 268 | """ 269 | 270 | """ Split and get ROC curve data """ 271 | data = train_test_split(X, y, test_size=test_size_) 272 | source_ROC, df_ROC, clf = get_ROC_data(data, clf, pos_label_, verbose) 273 | 274 | """ Set up initial PLOT """ 275 | # Create custom HoverTool -- we'll name each ROC curve 'ROC' so we only see info on hover there 276 | hover_ = HoverTool(names=['ROC'], tooltips=[("TPR", "@y_tpr"), ("FPR", "@x_fpr"), ("Thresh", "@thresh")]) 277 | 278 | # Create your toolbox 279 | p_tools = [hover_, 'crosshair', 'zoom_in', 'zoom_out', 'save', 'reset', 'tap', 'box_zoom'] 280 | 281 | # Create figure and labels 282 | clf_name = get_clf_name(clf) 283 | p = figure(title=f'{clf_name} ROC curve', tools=p_tools) 284 | p.xaxis.axis_label = 'False Positive Rate' 285 | p.yaxis.axis_label = 'True Positive Rate' 286 | 287 | """ PLOT ROC """ 288 | p.line('x_fpr', 'y_tpr', line_width=1, color="blue", source=source_ROC) 289 | p.circle('x_fpr', 'y_tpr', size=3, color="orange", legend='auc_legend', source=source_ROC, name='ROC') 290 | 291 | """ Plot Threshold==0.5 """ 292 | # get value closest to threshold == 0.5 293 | df_half = df_ROC.dropna().iloc[(df_ROC['Thresholds'].dropna()-0.5).abs().argsort()[:2]] 294 | df_half['Legend'] = 'Thresh~0.5' 295 | source_half = ColumnDataSource(data=dict(x_fpr=df_half.FPR, 296 | y_tpr=df_half.TPR, 297 | thresh=df_half.Thresholds, 298 | legend_=df_half.Legend)) 299 | p.circle('x_fpr', 'y_tpr', size=5, color="blue", source=source_half, legend="legend_", name='ROC') 300 | 301 | """ PLOT chance line """ 302 | # Plot chance (tpr = fpr) 303 | p.line([0, 1], [0, 1], 
line_dash='dashed', line_width=0.5, color='black', name='Chance') 304 | 305 | # Finishing touches 306 | p.legend.location = "bottom_right" 307 | 308 | """ save and show """ 309 | if filename: 310 | output_file(filename) 311 | show(p) 312 | 313 | return clf, df_ROC 314 | 315 | 316 | def plot_ROC_CV(clf, X, y, cv_fold=3, pos_label_=None, verbose=False): 317 | 318 | """ 319 | clf, classifiers, df_ROCs = plot_ROC_CV(clf, X, y, cv_fold=3, pos_label_=None, verbose=False) 320 | 321 | Plot an interactive ROC curve for a binary classifier with n=cv_fold cross-validations. 322 | It returns the original clf, a classifier for each cv fold, 323 | and a list of dataframes, one per fold. 324 | 325 | clf: untrained classifier object (e.g. rf_clf = RandomForestClassifier()) 326 | X: training + testing data 327 | y: targets (numeric/integers) 328 | cv_fold: cross-validations to run [default: 3] 329 | pos_label_: if targets are not binary (0, 1) then indicate integer for "positive" [default: None] 330 | verbose: print warnings [default: False] 331 | """ 332 | 333 | """ Check cross-validations to run and get stratification """ 334 | # Check cross-validation > 1 and get stratified data 335 | if cv_fold > 1: 336 | skf = StratifiedKFold(cv_fold) 337 | else: 338 | print(f"cv_fold must be greater than 1.
You have input {cv_fold}") 339 | return clf 340 | 341 | """ Get source data for each ROC curve """ 342 | # Loop over each split in the data and get source data, df, and clf 343 | source_ROCs, df_ROCs, classifiers = [], [], [] 344 | for idx, val in enumerate(skf.split(X, y)): 345 | (train, test) = val 346 | data = (X[train], X[test], y[train], y[test]) # note that skf returns indices, not values 347 | source_, df_, clf_ = get_ROC_data(data, clf, pos_label_, verbose) 348 | source_ROCs.append(source_) 349 | df_ROCs.append(df_) 350 | classifiers.append(clf_) 351 | 352 | """ Set up initial PLOT """ 353 | # Create custom HoverTool -- we'll name each ROC curve 'ROC' so we only see info on hover there 354 | hover_ = HoverTool(names=['ROC'], tooltips=[("TPR", "@y_tpr"), ("FPR", "@x_fpr"), ("Threshold", "@thresh")]) 355 | 356 | # Create your toolbox 357 | p_tools = [hover_, 'crosshair', 'zoom_in', 'zoom_out', 'save', 'reset', 'tap', 'box_zoom'] 358 | 359 | # Create figure and labels 360 | clf_name = get_clf_name(clf) 361 | p = figure(title=f'{clf_name} ROC curve with {cv_fold}-fold cross-validation', tools=p_tools) 362 | p.xaxis.axis_label = 'False Positive Rate' 363 | p.yaxis.axis_label = 'True Positive Rate' 364 | 365 | """ Get ROC CURVE for each iteration """ 366 | # Set the matplotlib colorwheel as a cycle 367 | colors_ = cycle(list(cnames.keys())) 368 | 369 | # plot each ROC curve - loop over source_ROCs, colors_ 370 | for _, val in enumerate(zip(source_ROCs, colors_)): 371 | (ROC, color_) = val 372 | p.line('x_fpr', 'y_tpr', line_width=1, color=color_, source=ROC) 373 | p.circle('x_fpr', 'y_tpr', size=10, color=color_, legend='auc_legend', source=ROC, name='ROC') 374 | 375 | """ Mean ROC and AUC for all curves and plot """ 376 | # interpolate each fold onto a common FPR grid and average 377 | mean_fpr, mean_tpr = interpolate_mean_tpr(df_list=df_ROCs) 378 | mean_auc = auc(mean_fpr, mean_tpr) 379 | mean_legend = [f'Mean, AUC: {mean_auc:.3f}']*len(mean_tpr) 380 | 381 | # Create ColumnDataSource 382 |
source_ROC_mean = ColumnDataSource(data=dict(x_fpr=mean_fpr, 383 | y_tpr=mean_tpr, 384 | auc_legend=mean_legend)) 385 | 386 | # Plot mean ROC 387 | p.line('x_fpr', 'y_tpr', legend='auc_legend', color='black', 388 | line_width=3.33, line_alpha=0.33, line_dash='dashdot', source=source_ROC_mean, name='ROC') 389 | 390 | # Plot chance (tpr = fpr) 391 | p.line([0, 1], [0, 1], line_dash='dashed', line_width=0.5, color='black', name='Chance') 392 | 393 | # Finishing touches 394 | p.legend.location = "bottom_right" 395 | show(p) 396 | 397 | return clf, classifiers, df_ROCs 398 | 399 | 400 | def plot_ROC_clfs(classifiers, X, y, test_size=0.33, pos_label_=None, verbose=False): 401 | 402 | """ 403 | clf, classifiers, df_ROCs, precision_info = plot_ROC_clfs(clf, X, y,pos_label_=None, verbose=False) 404 | 405 | Plot an interactive ROC curve for a binary classifier with n='cv-fold' cross-validations. 406 | It returns the original clf, a classifier for each cv, and a list of dataframes for each cv. 407 | precision_info is a tuple of (precision, recall, avg_precision) of types (array, array, float). 
408 | 409 | classifiers: list of untrained classifiers 410 | X: training + testing data 411 | y: targets (numeric/integers) 412 | test_size: test size for train_test_split (0 < x < 1) 413 | pos_label_: if targets are not binary (0, 1) then indicate integer for "positive" [default: None] 414 | verbose: print warnings [default: False] 415 | """ 416 | 417 | """ Get source data for each ROC curve """ 418 | 419 | # Get training and test data 420 | (data_) = train_test_split(X, y, test_size=test_size) 421 | 422 | # Loop over each CLASSIFIER now -- note that we don't redefine our classifiers 423 | source_ROCs, df_ROCs = [], [] 424 | for _, clf_ in enumerate(classifiers): 425 | source_, df_, clf_ = get_ROC_data(data_, clf_, pos_label_, verbose) 426 | source_ROCs.append(source_) 427 | df_ROCs.append(df_) 428 | 429 | """ Set up initial PLOT """ 430 | 431 | # Create custom HoverTool -- we'll name each ROC curve 'ROC' so we only see info on hover there 432 | hover_ = HoverTool(names=['ROC'], tooltips=[("TPR", "@y_tpr"), ("FPR", "@x_fpr"), ("Threshold", "@thresh")]) 433 | 434 | # Create your toolbox 435 | p_tools = [hover_, 'crosshair', 'zoom_in', 'zoom_out', 'save', 'reset', 'tap', 'box_zoom'] 436 | 437 | # Create figure and labels 438 | p = figure(title=f'Benchmarking {len(classifiers)} classifiers', tools=p_tools) 439 | p.xaxis.axis_label = 'False Positive Rate' 440 | p.yaxis.axis_label = 'True Positive Rate' 441 | 442 | """ Get ROC CURVE for each iteration """ 443 | 444 | # Set the matplotlib colorwheel as a cycle 445 | colors_ = cycle(list(cnames.keys())) 446 | 447 | # loop over source, color and plot each ROC curve 448 | for _, val in enumerate(zip(source_ROCs, colors_)): 449 | (ROC, color_) = val 450 | p.line('x_fpr', 'y_tpr', line_width=1, color=color_, source=ROC) 451 | p.circle('x_fpr', 'y_tpr', size=5, color=color_, legend='clf_legend', source=ROC, name='ROC') 452 | 453 | """ Mean ROC and AUC for all curves and plot """ 454 | 455 | # process mean values, legend, 
ColumnDataSource 456 | mean_fpr, mean_tpr = interpolate_mean_tpr(df_list=df_ROCs) 457 | mean_auc = auc(mean_fpr, mean_tpr) 458 | mean_legend = [f'Mean, AUC: {mean_auc:.3f}']*len(mean_tpr) 459 | source_ROC_mean = ColumnDataSource(data=dict(x_fpr = mean_fpr, y_tpr = mean_tpr, roc_legend=mean_legend)) 460 | 461 | # PLOT mean ROC 462 | p.line('x_fpr', 'y_tpr', legend='roc_legend', color='black', 463 | line_width=5, line_alpha=0.3, line_dash='dashed', source=source_ROC_mean, name='ROC') 464 | 465 | # PLOT chance (tpr = fpr) 466 | p.line([0, 1], [0, 1], line_dash='dashed', line_width=0.2, color='black', name='Chance') 467 | 468 | # Finishing touches 469 | p.legend.location = "bottom_right" 470 | show(p) 471 | 472 | # Print scores 473 | print("Scores:") 474 | # Get scores for each classifier: 475 | for i, df_ in enumerate(df_ROCs): 476 | print(df_.clf, np.round(df_.score, decimals=3)) 477 | 478 | return classifiers, df_ROCs 479 | 480 | 481 | def get_ROC_PR_data(data, clf, pos_label_=None, verbose=False): 482 | 483 | """ 484 | source_ROC, source_PR, df_ROC, df_PR, clf = get_ROC_PR_data(data, clf, pos_label_=None, verbose=False) 485 | 486 | get_ROC_PR_data returns ColumnDataSources and dataframes with ROC (TPR, FPR) 487 | and PR (precision, recall) data for a dataset and an untrained classifier. Each CDS can be used 488 | to draw a Bokeh plot while the dataframes can be used for additional 489 | exploration and plotting with other libs. Note that the dataframes 490 | are returned with metadata (e.g. AUC and the clf used). 491 | 492 | data: tuple of our data (X_train, X_test, y_train, y_test) 493 | where each item in the tuple is a numpy ndarray 494 | clf: an untrained classifier (e.g.
rf = RandomForestClassifier()) 495 | pos_label_: if targets are not binary (0, 1) then indicate integer for "positive" [default: None] 496 | verbose: print warnings [default: False] 497 | """ 498 | 499 | # split data into training, testing 500 | (X_train, X_test, y_train, y_test) = data 501 | 502 | # train and retrieve per-sample class probabilities for the test data 503 | probas = clf.fit(X_train, y_train).predict_proba(X_test) 504 | 505 | # pick the positive label used for the ROC and PR curves below 506 | # (compare against None so that a label of 0 is respected) 507 | if pos_label_ is None: 508 | pos_label_ = np.max(y_train) 509 | if verbose: 510 | print(f"Warning: Maximum target value of '{pos_label_}' used as positive.") 511 | print("You can use 'pos_label_' to indicate your own.") 512 | 513 | """ ROC """ 514 | fpr, tpr, roc_thresholds = roc_curve(y_test, probas[:,1], pos_label=pos_label_) 515 | roc_thresholds[0] = np.nan 516 | 517 | # get area under the curve (AUC) 518 | roc_auc = auc(fpr, tpr) 519 | 520 | """ PR """ 521 | # get precision and recall values 522 | precision, recall, pr_thresholds = precision_recall_curve(y_test, probas[:,1], pos_label=pos_label_) 523 | pr_thresholds = np.append(pr_thresholds, np.nan) # the final (precision=1, recall=0) point has no threshold 524 | 525 | # average precision score 526 | avg_precision = average_precision_score(y_test, probas[:,1]) 527 | 528 | # area under the PR curve 529 | pr_auc = auc(recall, precision) 530 | 531 | 532 | """ Create Sources """ 533 | # create legend variables - one entry per point on each curve 534 | roc_auc_ = [f"AUC: {roc_auc:.3f}"]*len(tpr) 535 | pr_auc_ = [f"AUC: {pr_auc:.3f}"]*len(precision) 536 | clf_name = get_clf_name(clf) 537 | clf_roc = [f"{clf_name}, AUC: {roc_auc:.3f}"]*len(tpr) 538 | clf_pr = [f"{clf_name}, AUC: {pr_auc:.3f}"]*len(precision) 539 | 540 | # create bokeh column sources for plotting the ROC and PR curves 541 | source_ROC = ColumnDataSource(data=dict(x_fpr=fpr, 542 | y_tpr=tpr, 543 | thresh_roc=roc_thresholds, 544 |
auc_legend=roc_auc_, 545 | clf_legend=clf_roc)) 546 | 547 | source_PR = ColumnDataSource(data=dict(x_rec=recall, 548 | y_prec=precision, 549 | thresh_pr=pr_thresholds, 550 | auc_legend=pr_auc_, 551 | clf_legend=clf_pr)) 552 | 553 | """ Dataframes """ 554 | # create output dataframe with TPR and FPR, and metadata 555 | df_ROC = pd.DataFrame({'TPR':tpr, 'FPR':fpr, 'Thresholds':roc_thresholds}) 556 | df_ROC.auc = roc_auc 557 | df_ROC.clf = get_clf_name(clf) 558 | df_ROC.score = clf.score(X_test, y_test) 559 | 560 | # create output dataframe with TPR and FPR, and metadata 561 | df_PR = pd.DataFrame({'Recall':recall, 'Precision':precision, 'Thresholds':pr_thresholds}) 562 | df_PR.auc = pr_auc 563 | df_PR.clf = get_clf_name(clf) 564 | df_PR.score = clf.score(X_test, y_test) 565 | 566 | return source_ROC, source_PR, df_ROC, df_PR, clf 567 | 568 | 569 | def plot_ROC_PR(clf, X, y, test_size_:float=0.5, pos_label_:str=None, filename:str=None, verbose:bool=False): 570 | """ 571 | clf, classifiers, df_ROCs = plot_ROC(clf, X, y, pos_label_=None, verbose=False) 572 | 573 | Plot an interactive ROC curve for a binary classifier. 574 | It returns the original clf, a classifier for each cv, a list of dataframes for each cv. 575 | 576 | clf: untrained classifier object (e.g. 
rf_clf = RandomForestClassifier()) 577 | X: training + testing data 578 | y: targets (numeric/integers) 579 | test_size_: fraction of data to be reserved for testing [default: 0.5] 580 | pos_label_: if targets are not binary (0, 1) then indicate integer for "positive" [default: None] 581 | filename: if provided, save to html [default: None] 582 | verbose: print warnings [default: False] 583 | """ 584 | 585 | """ Split and get ROC and PR curve data """ 586 | data = train_test_split(X, y, test_size=test_size_) 587 | source_ROC, source_PR, df_ROC, df_PR, clf = get_ROC_PR_data(data, clf, pos_label_, verbose) 588 | 589 | 590 | """ PLOT ROC """ 591 | 592 | # Create custom HoverTool -- we'll make one for each curve 593 | hover_ROC = HoverTool(names=['ROC'], tooltips=[("TPR", "@y_tpr"), 594 | ("FPR", "@x_fpr"), 595 | ("Thresh", "@thresh_roc"), 596 | ]) 597 | 598 | # Create your toolbox 599 | p_tools_ROC = [hover_ROC, 'crosshair', 'zoom_in', 'zoom_out', 'save', 'reset', 'tap', 'box_zoom'] 600 | 601 | clf_name = get_clf_name(clf) 602 | p1 = figure(title=f'{clf_name} ROC curve', tools=p_tools_ROC) 603 | p1.xaxis.axis_label = 'False Positive Rate' 604 | p1.yaxis.axis_label = 'True Positive Rate' 605 | 606 | # plot curve and data points 607 | p1.line('x_fpr', 'y_tpr', line_width=1, color="blue", source=source_ROC) 608 | p1.circle('x_fpr', 'y_tpr', size=3, color="orange", legend='auc_legend', source=source_ROC, name='ROC') 609 | 610 | # highlight the two values closest to threshold == 0.5 611 | df_half = df_ROC.dropna().iloc[(df_ROC['Thresholds'].dropna()-0.5).abs().argsort()[:2]] 612 | df_half['Legend'] = 'Thresh~0.5' 613 | source_half = ColumnDataSource(data=dict(x_fpr=df_half.FPR, 614 | y_tpr=df_half.TPR, 615 | thresh_roc=df_half.Thresholds, 616 | legend_=df_half.Legend)) 617 | p1.circle('x_fpr', 'y_tpr', size=5, color="blue", source=source_half, legend="legend_", name='ROC') 618 | 619 | # Plot chance (tpr = fpr) 620 | p1.line([0, 1], [0, 1], line_dash='dashed', line_width=0.5, color='black',
name='Chance') 621 | 622 | # Finishing touches 623 | p1.legend.location = "bottom_right" 624 | 625 | """ PLOT PR """ 626 | 627 | # Create custom HoverTool -- we'll make one for each curve 628 | hover_PR = HoverTool(names=['PR'], tooltips=[("Precision", "@y_prec"), 629 | ("Recall", "@x_rec"), 630 | ("Thresh", "@thresh_pr") 631 | ]) 632 | 633 | # Create your toolbox 634 | p_tools_PR = [hover_PR, 'crosshair', 'zoom_in', 'zoom_out', 'save', 'reset', 'tap', 'box_zoom'] 635 | 636 | p2 = figure(title=f'{clf_name} PR curve', tools=p_tools_PR) 637 | p2.xaxis.axis_label = 'Recall' 638 | p2.yaxis.axis_label = 'Precision' 639 | 640 | p2.line('x_rec', 'y_prec', line_width=1, color="blue", source=source_PR) 641 | p2.circle('x_rec', 'y_prec', size=3, color="orange", legend='auc_legend', source=source_PR, name='PR') 642 | 643 | # highlight values closest to threshold == 0.5 644 | df_half = df_PR.dropna().iloc[(df_PR['Thresholds'].dropna()-0.5).abs().argsort()[:2]] 645 | df_half['Legend'] = 'Thresh~0.5' 646 | source_half = ColumnDataSource(data=dict(x_rec=df_half.Recall, 647 | y_prec=df_half.Precision, 648 | thresh_pr=df_half.Thresholds, 649 | legend_=df_half.Legend)) 650 | p2.circle('x_rec', 'y_prec', size=5, color="blue", source=source_half, legend="legend_", name='PR') 651 | 652 | # Plot chance (prec = rec) 653 | p2.line([0, 1], [1, 0], line_dash='dashed', line_width=0.5, color='black', name='Chance') 654 | 655 | # Finishing touches 656 | p2.legend.location = "bottom_left" 657 | 658 | """ save and show """ 659 | if filename: 660 | output_file(filename) 661 | show(row(p1, p2)) 662 | 663 | return clf, df_ROC 664 | 665 | 666 | if __name__ == '__main__': 667 | 668 | # get UCI Breast Cancer Data 669 | data = datasets.load_breast_cancer() 670 | X = data.data 671 | y = data.target 672 | target_names = list(data.target_names) 673 | 674 | # Classifier plots 675 | X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5) 676 | rf_clf = 
RandomForestClassifier(n_estimators=100) 677 | rf_clf.fit(X_train, y_train) 678 | classifier_plots(rf_clf, X_test, y_test, target_names, ylog=False) 679 | 680 | # ROC curve for synthetic data 681 | rf_clf = RandomForestClassifier(n_estimators=100) 682 | rf_clf, df_rf_ROC = plot_ROC(rf_clf, X, y, test_size_=0.5) 683 | 684 | # Interactive ROC curve with cross-validation 685 | rf_clf = RandomForestClassifier(n_estimators=100) 686 | rf_clf, rf_cv_clfs, df_rf_ROCs = plot_ROC_CV(rf_clf, X, y) 687 | 688 | # Benchmark classifiers 689 | rf_bench = RandomForestClassifier(random_state=42, n_estimators=100) 690 | lr_bench = LogisticRegression(random_state=42, solver='saga') 691 | gnb_bench = GaussianNB(priors=None, var_smoothing=1e-06) 692 | clfs_benchmark, dfs_bench = plot_ROC_clfs(classifiers=[rf_bench, lr_bench, gnb_bench], X=X, y=y) 693 | 694 | # Interactive ROC and PR curves 695 | rf_clf = RandomForestClassifier(n_estimators=100) 696 | rf_clf, df_rf_ROC_PR = plot_ROC_PR(rf_clf, X, y, test_size_=0.33) 697 | -------------------------------------------------------------------------------- /plots/html/roc_cvfold/bokeh-1.0.4.min.css: -------------------------------------------------------------------------------- 1 | .bk-root{font-family:"Helvetica Neue",Helvetica,Arial,sans-serif;font-size:10pt;position:relative;width:auto;height:auto}.bk-root .bk-shading{position:absolute;display:block;border:1px dashed green}.bk-root .bk-tile-attribution a{color:black}.bk-root .bk-tool-icon-box-select{background-image:url("")}.bk-root .bk-tool-icon-box-zoom{background-image:url("")}.bk-root .bk-tool-icon-zoom-in{background-image:url("")}.bk-root .bk-tool-icon-zoom-out{background-image:url("")}.bk-root .bk-tool-icon-help{background-image:url("")}.bk-root .bk-tool-icon-hover{background-image:url("")}.bk-root .bk-tool-icon-crosshair{background-image:url("")}.bk-root .bk-tool-icon-lasso-select{background-image:url("")}.bk-root .bk-tool-icon-pan{background-image:url("")}.bk-root 
.bk-tool-icon-xpan{background-image:url("")}.bk-root .bk-tool-icon-ypan{background-image:url("")}.bk-root .bk-tool-icon-range{background-image:url("")}.bk-root .bk-tool-icon-polygon-select{background-image:url("")}.bk-root .bk-tool-icon-redo{background-image:url("")}.bk-root .bk-tool-icon-reset{background-image:url("")}.bk-root .bk-tool-icon-save{background-image:url("")}.bk-root .bk-tool-icon-tap-select{background-image:url("")}.bk-root .bk-tool-icon-undo{background-image:url("")}.bk-root .bk-tool-icon-wheel-pan{background-image:url("")}.bk-root .bk-tool-icon-wheel-zoom{background-image:url("")}.bk-root .bk-tool-icon-box-edit{background-image:url("")}.bk-root .bk-tool-icon-freehand-draw{background-image:url("")}.bk-root .bk-tool-icon-poly-draw{background-image:url("")}.bk-root .bk-tool-icon-point-draw{background-image:url("")}.bk-root .bk-tool-icon-poly-edit{background-image:url("")}.bk-root .bk-grid-row,.bk-root .bk-grid-column{display:flex;display:-webkit-flex;flex-wrap:nowrap;-webkit-flex-wrap:nowrap}.bk-root .bk-grid-row>*,.bk-root .bk-grid-column>*{flex-shrink:0;-webkit-flex-shrink:0}.bk-root .bk-grid-row{flex-direction:row;-webkit-flex-direction:row}.bk-root .bk-grid-column{flex-direction:column;-webkit-flex-direction:column}.bk-root .bk-canvas-wrapper{position:relative;font-size:12pt}.bk-root .bk-canvas,.bk-root .bk-canvas-overlays,.bk-root .bk-canvas-events{position:absolute;top:0;left:0;width:100%;height:100%}.bk-root .bk-canvas-map{position:absolute;border:0}.bk-root .bk-logo{margin:5px;position:relative;display:block;background-repeat:no-repeat}.bk-root .bk-logo.bk-grey{filter:url("data:image/svg+xml;utf8,#grayscale");filter:gray;-webkit-filter:grayscale(100%)}.bk-root .bk-logo-notebook{display:inline-block;vertical-align:middle;margin-right:5px}.bk-root .bk-logo-small{width:20px;height:20px;background-image:url()}.bk-root .bk-toolbar,.bk-root .bk-toolbar *{box-sizing:border-box;margin:0;padding:0}.bk-root 
.bk-toolbar-hidden{visibility:hidden;opacity:0;transition:visibility .3s linear,opacity .3s linear}.bk-root .bk-toolbar,.bk-root .bk-button-bar{display:flex;display:-webkit-flex;flex-wrap:nowrap;-webkit-flex-wrap:nowrap;align-items:center;-webkit-align-items:center;user-select:none;-moz-user-select:none;-webkit-user-select:none;-ms-user-select:none}.bk-root .bk-toolbar .bk-logo{flex-shrink:0;-webkit-flex-shrink:0}.bk-root .bk-toolbar-above,.bk-root .bk-toolbar-below{flex-direction:row;-webkit-flex-direction:row;justify-content:flex-end;-webkit-justify-content:flex-end}.bk-root .bk-toolbar-above .bk-button-bar,.bk-root .bk-toolbar-below .bk-button-bar{display:flex;display:-webkit-flex;flex-direction:row;-webkit-flex-direction:row}.bk-root .bk-toolbar-above .bk-logo,.bk-root .bk-toolbar-below .bk-logo{order:1;-webkit-order:1;margin-left:5px}.bk-root .bk-toolbar-left,.bk-root .bk-toolbar-right{flex-direction:column;-webkit-flex-direction:column;justify-content:flex-start;-webkit-justify-content:flex-start}.bk-root .bk-toolbar-left .bk-button-bar,.bk-root .bk-toolbar-right .bk-button-bar{display:flex;display:-webkit-flex;flex-direction:column;-webkit-flex-direction:column}.bk-root .bk-toolbar-left .bk-logo,.bk-root .bk-toolbar-right .bk-logo{order:0;-webkit-order:0;margin-bottom:5px}.bk-root .bk-toolbar-button{width:30px;height:30px;background-size:60%;background-color:transparent;background-repeat:no-repeat;background-position:center center}.bk-root .bk-toolbar-button:hover{background-color:#f9f9f9}.bk-root .bk-toolbar-button:focus{outline:0}.bk-root .bk-toolbar-button::-moz-focus-inner{border:0}.bk-root .bk-toolbar-above .bk-toolbar-button{border-bottom:2px solid transparent}.bk-root .bk-toolbar-above .bk-toolbar-button.bk-active{border-bottom-color:#26aae1}.bk-root .bk-toolbar-below .bk-toolbar-button{border-top:2px solid transparent}.bk-root .bk-toolbar-below .bk-toolbar-button.bk-active{border-top-color:#26aae1}.bk-root .bk-toolbar-right 
.bk-toolbar-button{border-left:2px solid transparent}.bk-root .bk-toolbar-right .bk-toolbar-button.bk-active{border-left-color:#26aae1}.bk-root .bk-toolbar-left .bk-toolbar-button{border-right:2px solid transparent}.bk-root .bk-toolbar-left .bk-toolbar-button.bk-active{border-right-color:#26aae1}.bk-root .bk-button-bar+.bk-button-bar:before{content:" ";display:inline-block;background-color:lightgray}.bk-root .bk-toolbar-above .bk-button-bar+.bk-button-bar:before,.bk-root .bk-toolbar-below .bk-button-bar+.bk-button-bar:before{height:10px;width:1px}.bk-root .bk-toolbar-left .bk-button-bar+.bk-button-bar:before,.bk-root .bk-toolbar-right .bk-button-bar+.bk-button-bar:before{height:1px;width:10px}.bk-root .bk-tooltip{font-family:"HelveticaNeue-Light","Helvetica Neue Light","Helvetica Neue",Helvetica,Arial,"Lucida Grande",sans-serif;font-weight:300;font-size:12px;position:absolute;padding:5px;border:1px solid #e5e5e5;color:#2f2f2f;background-color:white;pointer-events:none;opacity:.95}.bk-root .bk-tooltip>div:not(:first-child){margin-top:5px;border-top:#e5e5e5 1px dashed}.bk-root .bk-tooltip.bk-left.bk-tooltip-arrow::before{position:absolute;margin:-7px 0 0 0;top:50%;width:0;height:0;border-style:solid;border-width:7px 0 7px 0;border-color:transparent;content:" ";display:block;left:-10px;border-right-width:10px;border-right-color:#909599}.bk-root .bk-tooltip.bk-left::before{left:-10px;border-right-width:10px;border-right-color:#909599}.bk-root .bk-tooltip.bk-right.bk-tooltip-arrow::after{position:absolute;margin:-7px 0 0 0;top:50%;width:0;height:0;border-style:solid;border-width:7px 0 7px 0;border-color:transparent;content:" ";display:block;right:-10px;border-left-width:10px;border-left-color:#909599}.bk-root .bk-tooltip.bk-right::after{right:-10px;border-left-width:10px;border-left-color:#909599}.bk-root .bk-tooltip.bk-above::before{position:absolute;margin:0 0 0 -7px;left:50%;width:0;height:0;border-style:solid;border-width:0 7px 0 
7px;border-color:transparent;content:" ";display:block;top:-10px;border-bottom-width:10px;border-bottom-color:#909599}.bk-root .bk-tooltip.bk-below::after{position:absolute;margin:0 0 0 -7px;left:50%;width:0;height:0;border-style:solid;border-width:0 7px 0 7px;border-color:transparent;content:" ";display:block;bottom:-10px;border-top-width:10px;border-top-color:#909599}.bk-root .bk-tooltip-row-label{text-align:right;color:#26aae1}.bk-root .bk-tooltip-row-value{color:default}.bk-root .bk-tooltip-color-block{width:12px;height:12px;margin-left:5px;margin-right:5px;outline:#ddd solid 1px;display:inline-block}.rendered_html .bk-root .bk-tooltip table,.rendered_html .bk-root .bk-tooltip tr,.rendered_html .bk-root .bk-tooltip th,.rendered_html .bk-root .bk-tooltip td{border:0;padding:1px} --------------------------------------------------------------------------------