├── .gitignore
├── plots
│   ├── html
│   │   ├── .DS_Store
│   │   ├── PCA_analysis.html
│   │   ├── roc.html
│   │   ├── roc_pr.html
│   │   ├── roc_pr_curve.html
│   │   ├── roc_pr_plot.html
│   │   └── roc_cvfold
│   │       └── bokeh-1.0.4.min.css
│   └── png
│       ├── Normalized_corr_matrix.png
│       ├── Z_normalized_corr_matrix.png
│       └── Z_normalized_corr_matrix_Abs.png
├── pip_requirements.txt
├── code
│   ├── __pycache__
│   │   └── accstats.cpython-36.pyc
│   ├── accstats.py
│   ├── pca_feature_correlation.py
│   └── roc_pr_curve.py
├── README.md
└── LICENSE
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | __pycache__/*
3 | .ipynb_checkpoints
4 | .ipynb_checkpoints/*
5 | .DS_Store
6 |
--------------------------------------------------------------------------------
/plots/html/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gaurav-kaushik/Data-Visualizations-Medium/HEAD/plots/html/.DS_Store
--------------------------------------------------------------------------------
/pip_requirements.txt:
--------------------------------------------------------------------------------
1 | seaborn>=0.9.0
2 | bokeh>=0.12.0
3 | numpy>=1.14.0
4 | pandas>=0.22.0
5 | matplotlib>=2.1.0
6 | scikit-learn>=0.19.0
7 |
--------------------------------------------------------------------------------
/plots/png/Normalized_corr_matrix.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gaurav-kaushik/Data-Visualizations-Medium/HEAD/plots/png/Normalized_corr_matrix.png
--------------------------------------------------------------------------------
/plots/png/Z_normalized_corr_matrix.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gaurav-kaushik/Data-Visualizations-Medium/HEAD/plots/png/Z_normalized_corr_matrix.png
--------------------------------------------------------------------------------
/code/__pycache__/accstats.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gaurav-kaushik/Data-Visualizations-Medium/HEAD/code/__pycache__/accstats.cpython-36.pyc
--------------------------------------------------------------------------------
/plots/png/Z_normalized_corr_matrix_Abs.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gaurav-kaushik/Data-Visualizations-Medium/HEAD/plots/png/Z_normalized_corr_matrix_Abs.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Understanding Data and Machine Learning Models with Visualizations
2 |
3 | ### Part 1 - PCA and Feature Correlation
4 |
5 | * code/Interactive PCA and Feature Correlation.ipynb
6 | * code/pca\_feature\_correlation.py
7 | * Post on [Cascade.Bio blog](https://medium.com/cascade-bio-blog/creating-visualizations-to-better-understand-your-data-and-models-part-1-a51e7e5af9c0)
8 |
9 | ### Part 2 - Machine Learning Decision Boundary Visualization
10 |
11 | * code/Interactive\_Model\_Predictions\_and\_Decision\_Boundaries.ipynb
12 | * Post on [Cascade.Bio blog](https://medium.com/cascade-bio-blog/creating-visualizations-to-better-understand-your-data-and-models-part-2-28d5c46e956)
13 |
14 | ### Part 3 - ROC Curves
15 |
16 | * code/Interactive\_ROC\_analysis.ipynb
17 | * code/accstats.py
18 | * code/roc\_pr\_curve.py
19 | * Post on [HackerNoon](https://medium.com/hackernoon/making-sense-of-real-world-data-roc-curves-and-when-to-use-them-90a17e6d1db)
20 |
21 |
22 | ---
23 |
24 | Note: code was developed in Python 3.6 and likely not backwards compatible because of liberal use of f-strings (sorry not sorry)
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2019 Gaurav Kaushik
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy
4 | of this software and associated documentation files (the "Software"), to deal
5 | in the Software without restriction, including without limitation the rights
6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 | copies of the Software, and to permit persons to whom the Software is
8 | furnished to do so, subject to the following conditions:
9 |
10 | The above copyright notice and this permission notice shall be included in all
11 | copies or substantial portions of the Software.
12 |
13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
19 | SOFTWARE.
20 |
--------------------------------------------------------------------------------
/code/accstats.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import argparse
4 | import numpy as np
5 |
6 |
7 | def confusion_matrix(acc:float=0.995, subpop:float=1e-4, population:float=8.5e6) -> dict:
8 | """
9 | Generates confusion matrix and derived variables
10 | based on accuracy of detecting a fraction (subpop) of a population
11 |
12 | Args:
13 | acc (float): accuracy (0-1), applied both to the subpopulation and to the rest of the population
14 | subpop (float): fraction of population in subpopulation (0-1)
15 | population (float): Total population size (absolute number)
16 |
17 | Returns:
18 | dict: derived variables as a dictionary
19 | """
20 |
21 | # Inputs
22 | population = int(population)
23 | print(f"\nInputs\n------------")
24 | print(f"Accuracy (%): {100*acc}")
25 | print(f"Subpopulation (%): {100*subpop}")
26 | print(f"Population size: {population}")
27 | print(f"Expected subpopulation size: {int(subpop*population)}")
28 |
29 | # Check variables
30 | if acc > 1 or acc < 0:
31 | print("\nERROR: give valid accuracy (0 to 1).")
32 | return
33 | if subpop > 1 or subpop < 0:
34 | print("\nERROR: give a valid subpop fraction (0 to 1).")
35 | return
36 | if population < 1:
37 | print("\nERROR: cannot have zero or negative populations.")
38 | return
39 |
40 | # confusion matrix
41 | tp = np.rint(population*subpop*acc).astype(int)
42 | fp = np.rint(population*(1-subpop)*(1-acc)).astype(int)
43 | tn = np.rint(population*(1-subpop)*acc).astype(int)
44 | fn = np.rint(population*subpop).astype(int) - tp
45 | print(f"\nResults\n------------")
46 | print(f"True Positives: {tp}")
47 | print(f"False Positives (Type I): {fp}")
48 | print(f"True Negatives: {tn}")
49 | print(f"False Negatives (Type II): {fn}")
50 |
51 | # derivations
52 | round_var = 4 # round vars to this place
53 | tpr = np.round((tp)/(tp+fn), round_var)
54 | fpr = np.round((fp)/(fp+tn), round_var)
55 | precision = np.round((tp)/(tp+fp), round_var)
56 | specificity = np.round((tn)/(tn+fp), round_var)
57 | fdr = np.round((fp)/(fp+tp), round_var)
58 | fscore = np.round(2*tp/(2*tp+fp+fn), round_var)
59 | print(f"\nDerivations\n------------")
60 | print(f"True Positive Rate (Recall, Power): {tpr}")
61 | print(f"False Positive Rate: {fpr}")
62 | print(f"Precision: {precision}")
63 | print(f"Specificity: {specificity}")
64 | print(f"False Discovery Rate: {fdr}")
65 | print(f"F-Score: {fscore}")
66 |
67 | # output a dictionary of derived variables
68 | output = {
69 | 'True_Positives':tp,
70 | 'False_Positives':fp,
71 | 'True_Negatives':tn,
72 | 'False_Negatives':fn,
73 | 'TPR':tpr,
74 | 'FPR':fpr,
75 | 'Precision':precision,
76 | 'Specificity':specificity,
77 | 'FDR':fdr,
78 | 'FScore':fscore
79 | }
80 |
81 | return output
82 |
83 |
84 | if __name__ == '__main__':
85 | parser = argparse.ArgumentParser()
86 | parser.add_argument("-a", "--acc", type=float, default=0.995, help="Accuracy (0 -> 1)")
87 | parser.add_argument("-s", "--sub", type=float, default=1e-4, help="Subpop fraction (0 -> 1)")
88 | parser.add_argument("-p", "--pop", type=float, default=8.5e6, help="Pop size")
89 | args = parser.parse_args()
90 | accuracy = args.acc
91 | subpopulation = args.sub
92 | population = args.pop
93 |
94 | # main
95 | confusion_matrix(accuracy, subpopulation, population)
96 |
--------------------------------------------------------------------------------
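The accstats.py defaults are 99.5% accuracy, a 0.01% subpopulation and 8.5 million people. Under those inputs the subpopulation holds only 850 people, so a 0.5% error rate on everyone else yields far more false positives than true positives. The plain-Python sketch below reproduces that arithmetic.

`confusion_counts` is a hypothetical helper, not part of accstats.py. It assumes `acc` applies both to the subpopulation (sensitivity) and to the rest of the population (specificity). Under that assumption, false positives come only from people outside the subpopulation.

```python
def confusion_counts(acc: float, subpop: float, population: int) -> dict:
    """Confusion-matrix counts for a test whose accuracy `acc` is assumed to be
    both its sensitivity (on the subpopulation) and its specificity (on everyone else).
    Hypothetical helper for illustration; not the accstats.py API."""
    if not (0 <= acc <= 1 and 0 <= subpop <= 1) or population < 1:
        raise ValueError("acc and subpop must be in [0, 1]; population must be >= 1")
    positives = round(population * subpop)   # size of the subpopulation
    negatives = population - positives       # everyone else
    tp = round(positives * acc)              # subpopulation members correctly flagged
    fn = positives - tp                      # subpopulation members missed
    tn = round(negatives * acc)              # others correctly cleared
    fp = negatives - tn                      # others wrongly flagged
    precision = tp / (tp + fp) if tp + fp else float("nan")
    return {"tp": tp, "fp": fp, "tn": tn, "fn": fn, "precision": precision}


counts = confusion_counts(acc=0.995, subpop=1e-4, population=8_500_000)
print(counts)
```

With these inputs it returns 846 true positives against 42,496 false positives, a precision just under 2%.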
/code/pca_feature_correlation.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import argparse
4 | import os
5 | import pandas as pd
6 | import numpy as np
7 | import seaborn as sns
8 | from sklearn import datasets
9 | from sklearn.decomposition import PCA
10 | from sklearn.preprocessing import StandardScaler
11 | import matplotlib.pyplot as plt
12 | from matplotlib.colors import cnames
13 | from itertools import cycle
14 | from bokeh.plotting import output_file, figure, show, ColumnDataSource
15 | from bokeh.models import HoverTool
16 | import warnings
17 | warnings.filterwarnings(action='ignore')
18 |
19 |
20 | def get_float_list(range_max:int, div:int=100) -> list:
21 | """ To get 0 -> 1, range_max must be same order of mag as div """
22 | return [float(x)/div for x in range(int(range_max))]
23 |
24 |
25 | def get_colorcycle(colordict:dict):
26 | """ Get an endless color cycle over the color names (keys) of colordict, e.g. the output of get_colordict() """
27 | return cycle(list(colordict.keys()))
28 |
29 |
30 | def get_colordict(filter_:str='dark') -> dict:
31 | """ return dictionary of colornames by filter """
32 | return dict((k, v) for k, v in cnames.items() if filter_ in k)
33 |
34 |
35 | def pca_report_interactive(X, scale_X:bool=True, save_plot:bool=False):
36 | """
37 | X: input data matrix
38 | scale_X: whether to rescale X with StandardScaler [default: True, i.e. X is not prescaled]
39 | save_plot: write the plot to 'PCA_analysis.html' (the plot is shown in either case)
40 | """
41 |
42 | # calculate mean and var
43 | X_mean, X_var = X.mean(), X.var()
44 | print('\n*--- PCA Report ---*\n')
45 | print(f'X mean:\t\t{X_mean:.3f}\nX variance:\t{X_var:.3f}')
46 |
47 | if scale_X:
48 | # rescale and run PCA
49 | print("\n...Rescaling data...\n")
50 | scaler = StandardScaler()
51 | X_scaled = scaler.fit_transform(X)
52 | X_s_mean, X_s_var = X_scaled.mean(), X_scaled.var()
53 | print(f'X_scaled mean:\t\t{np.round(X_s_mean):.3f}')
54 | print(f'X_scaled variance:\t{np.round(X_s_var):.3f}\n')
55 | pca_ = PCA().fit(X_scaled)
56 | X_pca = PCA().fit_transform(X)
57 | else:
58 | # run PCA directly
59 | print("...Assuming data is properly scaled...")
60 | pca_ = PCA().fit(X)
61 | X_pca = PCA().fit_transform(X)
62 |
63 | # Get cumulative explained variance for each dimension
64 | pca_evr = pca_.explained_variance_ratio_
65 | cumsum_ = np.cumsum(pca_evr)
66 |
67 | # Get dimensions where var >= 95% and values for variance at 2D, 3D
68 | dim_95 = np.argmax(cumsum_ >= 0.95) + 1
69 | twoD = np.round(cumsum_[1], decimals=3)*100
70 | threeD = np.round(cumsum_[2], decimals=3)*100
71 | instances_, dims_ = X.shape
72 |
73 | # check shape of X
74 | if dims_ > instances_:
75 | print("WARNING: number of features greater than number of instances.")
76 | dimensions = list(range(1, instances_+1))
77 | else:
78 | dimensions = list(range(1, dims_+1))
79 |
80 | # Print report
81 | print("\n -- Summary --")
82 | print(f"You can reduce from {dims_} to {dim_95} dimensions while retaining 95% of variance.")
83 | print(f"2 principal components explain {twoD:.2f}% of variance.")
84 | print(f"3 principal components explain {threeD:.2f}% of variance.")
85 |
86 | """ - Plotting - """
87 | # Create custom HoverTool: only glyphs named 'PCA' show hover info
88 | hover_ = HoverTool(names=['PCA'], tooltips=[("dimensions", "@x_dim"),
89 | ("cumulative variance", "@y_cumvar"),
90 | ("explained variance", "@y_var")])
91 | p_tools = [hover_, 'crosshair', 'zoom_in', 'zoom_out', 'save', 'reset', 'tap', 'box_zoom']
92 |
93 | # insert 0 at beginning for cleaner plotting
94 | cumsum_plot = np.insert(cumsum_, 0, 0)
95 | pca_evr_plot = np.insert(pca_evr, 0, 0)
96 | dimensions_plot = np.insert(dimensions, 0, 0)
97 |
98 | """
99 | ColumnDataSource
100 | - a special type in Bokeh that allows you to store data for plotting
101 | - store data as dict (key:list)
102 | - to plot two keys against one another, make sure they're the same length!
103 | - below:
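104 | x_dim # dimension index (0 prepended for cleaner plotting)
105 | y_cumvar # cumulative explained variance (same length as x_dim)
106 | y_var # explained variance per PC (same length as x_dim)
107 | var95_x, var95_y # vertical line at x = dim_95, from y = 0 to 0.95
108 | twoD_x, twoD_y # vertical line at x = 2, up to the cumulative variance at 2 PCs
109 | threeD_x, threeD_y # vertical line at x = 3, up to the cumulative variance at 3 PCs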
110 | """
111 |
112 | # get sources
113 | source_PCA = ColumnDataSource(data=dict(x_dim = dimensions_plot,y_cumvar = cumsum_plot, y_var = pca_evr_plot))
114 | source_var95 = ColumnDataSource(data=dict(var95_x = [dim_95]*96, var95_y = get_float_list(96)))
115 | source_twoD = ColumnDataSource(data=dict(twoD_x = [2]*(int(twoD)+1), twoD_y = get_float_list(twoD+1)))
116 | source_threeD = ColumnDataSource(data=dict(threeD_x = [3]*(int(threeD)+1), threeD_y = get_float_list(threeD+1)))
117 |
118 | """ PLOT """
119 | # set up figure and add axis labels
120 | p = figure(title='PCA Analysis', tools=p_tools)
121 | p.xaxis.axis_label = f'N of {dims_} Principal Components'
122 | p.yaxis.axis_label = 'Variance Explained (per PC & Cumulative)'
123 |
124 | # add vertical reference lines: x=dim_95 (95% variance), x=2, x=3
125 | p.line('twoD_x', 'twoD_y', line_width=0.5, line_dash='dotted', color='#435363', source=source_twoD) # x=2
126 | p.line('threeD_x', 'threeD_y', line_width=0.5, line_dash='dotted', color='#435363', source=source_threeD) # x=3
127 | p.line('var95_x', 'var95_y', line_width=2, line_dash='dotted', color='#435363', source=source_var95) # x = dim_95, up to var = 0.95
128 |
129 | # add bar plot for variance per dimension
130 | p.vbar(x='x_dim', top='y_var', width=.5, bottom=0, color='#D9F2EF', source=source_PCA, name='PCA')
131 |
132 | # add cumulative variance (scatter + line)
133 | p.line('x_dim', 'y_cumvar', line_width=1, color='#F79737', source=source_PCA)
134 | p.circle('x_dim', 'y_cumvar', size=7, color='#FF4C00', source=source_PCA, name='PCA')
135 |
136 | # change gridlines
137 | p.ygrid.grid_line_alpha = 0.25
138 | p.xgrid.grid_line_alpha = 0.25
139 |
140 | # change axis bounds and grid
141 | p.xaxis.bounds = (0, dims_)
142 | p.yaxis.bounds = (0, 1)
143 | p.grid.bounds = (0, dims_)
144 |
145 | # save and show p
146 | if save_plot:
147 | output_file('PCA_analysis.html')
148 | show(p)
149 |
150 | # output PCA info as a dataframe
151 | df_PCA = pd.DataFrame({'dimension': dimensions, 'variance_cumulative': cumsum_, 'variance': pca_evr}).set_index(['dimension'])
152 |
153 | return df_PCA, X_pca, pca_evr
154 |
155 |
156 | def pca_feature_correlation(X, X_pca, explained_var, features:list=None, fig_dpi:int=150, save_plot:bool=False):
157 | """
158 | 1. Get dot product of X and X_pca
159 | 2. Run normalizations of X*X_pca
160 | 3. Retrieve df/matrices
161 |
162 | X: data (numpy matrix)
163 | X_pca: PCA-transformed data (instances x PCs)
164 | explained_var: explained variance ratio for each PC (1D array)
165 | features: list of feature names
166 | fig_dpi: dpi to use for heatmaps
167 | save_plot: save heatmaps to PNG files instead of showing them
168 | """
169 |
170 | # Add zeroes for data where features > instances
171 | outer_diff = X.T.shape[0] - X_pca.shape[1]
172 | if outer_diff > 0: # outer dims must match to get sq matrix
173 | Z = np.zeros([X_pca.shape[0], outer_diff])
174 | X_pca = np.c_[X_pca, Z]
175 | explained_var = np.append(explained_var, np.zeros(outer_diff))
176 |
177 | # Get correlation between original features (X) and PCs (X_pca)
178 | dot_matrix = np.dot(X.T, X_pca)
179 | print(f"X*X_pca: {X.T.shape} * {X_pca.shape} = {dot_matrix.shape}")
180 |
181 | # Correlation matrix -> df
182 | df_dotproduct = pd.DataFrame(dot_matrix)
183 | df_dotproduct.columns = [f'PC{i+1}' for i in range(dot_matrix.shape[1])]
184 | if features is not None: df_dotproduct.index = features
185 |
186 | # Normalize & Sort
187 | df_n, df_na, df_nabv = normalize_dataframe(df_dotproduct, explained_var, fig_dpi=fig_dpi, plot_opt=True, save_plot=save_plot)
188 |
189 | return df_dotproduct, df_n, df_na, df_nabv
190 |
191 |
192 | def normalize_dataframe(df, explained_var=None, fig_dpi:int=150, plot_opt:bool=True, save_plot:bool=False):
193 | """
194 | 1. Get z-normalized df (normalized to µ=0, σ=1)
195 | 2. Get absolute value of z-normalized df
196 | 3. If explained_var is provided, multiply each PC column of (2) by that PC's explained variance
197 | """
198 | # Normalize, Reindex, & Sort
199 | df_norm = (df.copy()-df.mean())/df.std()
200 | df_norm = df_norm.sort_values(list(df_norm.columns), ascending=False)
201 |
202 | # Absolute value of normalized (& sort)
203 | df_abs = df_norm.copy().abs().set_index(df_norm.index)
204 | df_abs = df_abs.sort_values(by=list(df_abs.columns), ascending=False)
205 |
206 | # Plot
207 | if plot_opt:
208 | # Z-normalized corr matrix
209 | plt.figure(dpi=fig_dpi)
210 | ax_normal = sns.heatmap(df_norm, cmap="RdBu")
211 | ax_normal.set_title("Z-Normalized Data")
212 | if save_plot:
213 | plt.savefig('Z_normalized_corr_matrix.png')
214 | else:
215 | plt.show()
216 |
217 | # |Z-normalized corr matrix|
218 | plt.figure(dpi=fig_dpi)
219 | ax_abs = sns.heatmap(df_abs, cmap="Purples")
220 | ax_abs.set_title("|Z-Normalized|")
221 | if save_plot:
222 | plt.savefig('Z_normalized_corr_matrix_Abs.png')
223 | else:
224 | plt.show()
225 |
226 | # Re-normalize by explained var (& sort)
227 | if explained_var is not None and np.any(explained_var):
228 | df_byvar = df_abs.copy()*explained_var
229 | df_byvar = df_byvar.sort_values(by=list(df_norm.columns), ascending=False)
230 | if plot_opt:
231 | plt.figure(dpi=fig_dpi)
232 | ax_relative = sns.heatmap(df_byvar, cmap="Purples")
233 | ax_relative.set_title("|Z-Normalized|*Explained_Variance")
234 | if save_plot:
235 | plt.savefig('Normalized_corr_matrix.png')
236 | else:
237 | plt.show()
238 | else:
239 | df_byvar = None
240 | return df_norm, df_abs, df_byvar
241 |
242 |
243 | def pca_rank_features(df_nabv, verbose:bool=True):
244 | """
245 | Given a dataframe df_nabv with dimensions [f, p], where:
246 | f = features (sorted)
247 | p = principal components
248 | df_nabv.values are |Z-normalized X|*pca_.explained_variance_ratio_
249 |
250 | 1. Create column 'score_' holding the sum of each row, and sort by it
251 | 2. Set index as 'rank'
252 | """
253 | df_rank = df_nabv.copy().assign(score_ = df_nabv.sum(axis=1)).sort_values('score_', ascending=False)
254 | df_rank['feature_'] = df_rank.index
255 | df_rank.index = range(1, len(df_rank)+1)
256 | df_rank.drop(df_nabv.columns, axis=1, inplace=True)
257 | df_rank.index.rename('rank', inplace=True)
258 | if verbose: print(df_rank)
259 | return df_rank
260 |
261 |
262 | def pca_full_report(X, features_:list=None, fig_dpi:int=150, save_plot:bool=False):
263 | """
264 | Run complete PCA workflow:
265 | 1. pca_report_interactive()
266 | 2. pca_feature_correlation()
267 | 3. pca_rank_features()
268 |
269 | X: data (numpy array)
270 | features_: list of feature names
271 | fig_dpi: image resolution
272 | save_plot: write plots to files (the heatmaps are then not shown)
273 | """
274 | # Retrieve the interactive report
275 | df_pca, X_pca, pca_evr = pca_report_interactive(X, save_plot=save_plot)
276 | # Get feature-PC correlation matrices
277 | df_corr, df_n, df_na, df_nabv = pca_feature_correlation(X, X_pca, pca_evr, features_, fig_dpi, save_plot)
278 | # Get rank for each feature
279 | df_rank = pca_rank_features(df_nabv)
280 | return (df_pca, X_pca, pca_evr, df_corr, df_n, df_na, df_nabv, df_rank)
281 |
282 |
283 | if __name__ == '__main__':
284 | """ IRIS """
285 | data = datasets.load_iris()
286 | outputs = pca_full_report(X=data.data, features_=data.feature_names, save_plot=True)
287 |
288 |
--------------------------------------------------------------------------------
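pca_feature_correlation.py ranks features in three steps:

1. Project the data onto its principal components and take the feature-by-PC dot product.
2. Z-normalize each PC column and take absolute values.
3. Weight each column by its explained-variance ratio and sum across PCs.

The NumPy-only sketch below condenses those steps without pandas or plotting. `rank_features` is a hypothetical name. It departs from the script in two ways: it standardizes X before both the PCA and the dot product, and it computes the PCA with `np.linalg.svd` instead of scikit-learn.

```python
import numpy as np


def rank_features(X, feature_names):
    """Score features by |z-normalized feature-PC dot product| weighted by explained variance.

    Hypothetical, NumPy-only condensation of the ranking workflow; assumes every
    feature has non-zero variance and that X has full column rank.
    """
    X = np.asarray(X, dtype=float)
    Xs = (X - X.mean(axis=0)) / X.std(axis=0)        # standardize (population std, like StandardScaler)
    _, s, Vt = np.linalg.svd(Xs, full_matrices=False)
    X_pca = Xs @ Vt.T                                # data projected onto the PCs
    evr = s ** 2 / np.sum(s ** 2)                    # explained-variance ratio per PC
    dot = Xs.T @ X_pca                               # features x PCs
    # z-normalize each PC column across features (sample std, as pandas .std() does)
    z = (dot - dot.mean(axis=0)) / dot.std(axis=0, ddof=1)
    scores = (np.abs(z) * evr).sum(axis=1)           # weight by variance share, sum over PCs
    order = np.argsort(scores)[::-1]                 # highest score first
    return [(feature_names[i], float(scores[i])) for i in order]
```

Each PC column is z-normalized across features. A feature therefore scores highly on a PC when its loading stands out from the other features' loadings, whether that loading is unusually large or unusually small.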
/plots/html/PCA_analysis.html:
--------------------------------------------------------------------------------
(Bokeh HTML page titled "Bokeh Plot"; page markup and scripts not captured)
--------------------------------------------------------------------------------
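The PCA_analysis.html plot and the report printed by `pca_report_interactive` rest on two quantities:

- each principal component's share of the variance;
- the smallest number of components whose cumulative share reaches 95%.

The sketch below computes both with NumPy alone, assuming PCA via an SVD of the centered data. The script itself uses scikit-learn's `PCA`, and both helper names are hypothetical.

```python
import numpy as np


def explained_variance_ratio(X):
    """Explained-variance ratio of each principal component, largest first."""
    X = np.asarray(X, dtype=float)
    Xc = X - X.mean(axis=0)                       # PCA works on centered data
    s = np.linalg.svd(Xc, compute_uv=False)       # singular values, descending
    var = s ** 2                                  # proportional to each PC's variance
    return var / var.sum()


def n_components_for(ratios, threshold=0.95):
    """Smallest k whose first k ratios sum to at least `threshold`."""
    cumsum = np.cumsum(ratios)
    if cumsum[-1] < threshold:
        raise ValueError("cumulative explained variance never reaches the threshold")
    # argmax returns the first True, i.e. the first index where cumsum >= threshold
    return int(np.argmax(cumsum >= threshold)) + 1
```

Raising an error when the threshold is never reached is a choice made here. `np.argmax` on an all-False array returns 0, which would otherwise report a single component.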
/plots/html/roc.html:
--------------------------------------------------------------------------------
(Bokeh HTML page titled "Bokeh Plot"; page markup and scripts not captured)
--------------------------------------------------------------------------------
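roc.html plots the curve that scikit-learn's `roc_curve` and `auc` return. The plain-Python sketch below shows what those two calls compute:

1. Sort instances by score.
2. Lower the threshold one distinct score at a time, recording (FPR, TPR) after each step.
3. Integrate the resulting curve with the trapezoidal rule.

The helper names are hypothetical, and the sketch assumes both classes are present. It also shows why the score column and `pos_label` must describe the same class.

```python
def roc_points(y_true, scores, pos_label=1):
    """(FPR, TPR) points from sweeping a threshold down through the distinct scores.

    Instances with tied scores cross the threshold together, so ties produce a
    single diagonal step. Assumes y_true contains both positives and negatives.
    """
    pairs = sorted(zip(scores, y_true), key=lambda p: p[0], reverse=True)
    n_pos = sum(1 for y in y_true if y == pos_label)
    n_neg = len(y_true) - n_pos
    tp = fp = 0
    points = [(0.0, 0.0)]                     # threshold above every score
    i = 0
    while i < len(pairs):
        threshold = pairs[i][0]
        # count every instance whose score equals the current threshold
        while i < len(pairs) and pairs[i][0] == threshold:
            if pairs[i][1] == pos_label:
                tp += 1
            else:
                fp += 1
            i += 1
        points.append((fp / n_neg, tp / n_pos))
    return points


def trapezoid_auc(points):
    """Area under a piecewise-linear curve whose points are sorted by x."""
    return sum((x1 - x0) * (y0 + y1) / 2 for (x0, y0), (x1, y1) in zip(points, points[1:]))


y = [0, 0, 1, 1]
p1 = [0.1, 0.4, 0.35, 0.8]                              # scores for class 1
print(trapezoid_auc(roc_points(y, p1, pos_label=1)))    # 0.75
print(trapezoid_auc(roc_points(y, p1, pos_label=0)))    # 0.25
```

The mismatched call returns 1 - 0.75, because scores that rank class 1 high are being read as evidence for class 0.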
/plots/html/roc_pr.html:
--------------------------------------------------------------------------------
(Bokeh HTML page titled "Bokeh Plot"; page markup and scripts not captured)
--------------------------------------------------------------------------------
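roc_pr.html pairs the ROC curve with a precision-recall curve summarized by average precision (AP). scikit-learn documents `average_precision_score` as a step-wise sum over thresholds, Σ (R_n - R_{n-1}) * P_n, with no interpolation. That differs from a trapezoidal `auc(recall, precision)`.

The plain-Python sketch below implements that sum. The function name is hypothetical, and the positive label has to match the class the scores belong to.

```python
def average_precision(y_true, scores, pos_label=1):
    """AP = sum over thresholds of (R_n - R_{n-1}) * P_n, thresholds at the distinct scores.

    Assumes y_true contains at least one positive.
    """
    pairs = sorted(zip(scores, y_true), key=lambda p: p[0], reverse=True)
    n_pos = sum(1 for y in y_true if y == pos_label)
    tp = fp = 0
    prev_recall, ap = 0.0, 0.0
    i = 0
    while i < len(pairs):
        threshold = pairs[i][0]
        # everything scored at or above the threshold is predicted positive
        while i < len(pairs) and pairs[i][0] == threshold:
            if pairs[i][1] == pos_label:
                tp += 1
            else:
                fp += 1
            i += 1
        recall = tp / n_pos
        precision = tp / (tp + fp)
        ap += (recall - prev_recall) * precision   # weight precision by the recall gained
        prev_recall = recall
    return ap


y = [0, 0, 1, 1]
p1 = [0.1, 0.4, 0.35, 0.8]                       # scores for class 1
print(average_precision(y, p1))                  # about 0.833
print(average_precision(y, p1, pos_label=0))     # 0.5: label and scores disagree
```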
/plots/html/roc_pr_curve.html:
--------------------------------------------------------------------------------
(Bokeh HTML page titled "Bokeh Plot"; page markup and scripts not captured)
--------------------------------------------------------------------------------
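Averaging ROC curves across cross-validation folds, as `interpolate_mean_tpr` in roc_pr_curve.py does, requires evaluating every fold's curve at the same false positive rates. The sketch below interpolates each fold's TPR onto a shared FPR grid with `np.interp` and averages the TPRs at each grid point.

`mean_roc` is a hypothetical name. Pinning the ends of the averaged curve to (0, 0) and (1, 1) is a choice made here, in place of the script's extra origin point.

```python
import numpy as np


def mean_roc(curves, n_points=101):
    """Vertically average ROC curves given as (fpr, tpr) pairs.

    Each fpr sequence must be sorted in ascending order, as np.interp expects.
    """
    grid = np.linspace(0.0, 1.0, n_points)                # shared FPR grid
    tprs = [np.interp(grid, fpr, tpr) for fpr, tpr in curves]
    mean_tpr = np.mean(tprs, axis=0)                      # average TPR at each grid FPR
    mean_tpr[0], mean_tpr[-1] = 0.0, 1.0                  # pin the ends to (0, 0) and (1, 1)
    return grid, mean_tpr
```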
/plots/html/roc_pr_plot.html:
--------------------------------------------------------------------------------
(Bokeh HTML page titled "Bokeh Plot"; page markup and scripts not captured)
--------------------------------------------------------------------------------
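`plot_ROC` marks the ROC points whose decision thresholds are nearest 0.5. It skips the first threshold, which `get_ROC_data` sets to NaN.

The NumPy sketch below performs that lookup on plain arrays. `nearest_threshold_points` is a hypothetical helper, and `k=2` copies the script's choice of two points.

```python
import numpy as np


def nearest_threshold_points(fpr, tpr, thresholds, target=0.5, k=2):
    """Return (fpr, tpr, threshold) arrays for the k points whose thresholds are closest to target.

    NaN thresholds are ignored; results are ordered from closest to farthest.
    """
    fpr, tpr = np.asarray(fpr, dtype=float), np.asarray(tpr, dtype=float)
    thresholds = np.asarray(thresholds, dtype=float)
    valid = np.flatnonzero(~np.isnan(thresholds))            # indices with a real threshold
    closest = valid[np.argsort(np.abs(thresholds[valid] - target))[:k]]
    return fpr[closest], tpr[closest], thresholds[closest]
```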
/code/roc_pr_curve.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import pandas as pd
4 | import numpy as np
5 | from itertools import cycle
6 | from sklearn import datasets
7 | from sklearn.datasets import make_classification
8 | from sklearn.decomposition import PCA
9 | from sklearn.preprocessing import StandardScaler
10 | from sklearn.metrics import roc_curve, auc, precision_recall_curve, average_precision_score, accuracy_score
11 | from sklearn.model_selection import train_test_split, StratifiedKFold
12 | from sklearn.ensemble import RandomForestClassifier
13 | from sklearn.linear_model import LogisticRegression
14 | from sklearn.naive_bayes import GaussianNB
15 | from bokeh.plotting import output_file, figure, show, ColumnDataSource
16 | from bokeh.models import HoverTool
17 | from bokeh.layouts import row
18 | import matplotlib.pyplot as plt
19 | from matplotlib.colors import cnames
20 | cnames = dict((k, v) for k, v in cnames.items() if 'dark' in k)
21 |
22 |
23 | def PCA_2D_labeled(X, y, cnames:list, target_names:list):
24 | """
25 | Get a quick 2D rescaled PCA of a labeled dataset
26 |
27 | Args:
28 | X (numpy.ndarray): data
29 | y (numpy.ndarray): labels
30 | cnames: a list of color names (str)
31 | target_names: a list of target names (str)
32 |
33 | Returns:
34 | matplotlib plot object
35 | """
36 |
37 | # rescaled, 2D PCA
38 | X_2D = PCA(2).fit_transform(StandardScaler().fit_transform(X))
39 |
40 | # plot
41 | plt.figure(dpi=150)
42 | for c, i, t in zip(cnames, sorted(set(y)), target_names):
43 | # plot each column with a color pertaining to the labels
44 | plt.scatter(X_2D[y==i, 0], X_2D[y==i, 1], color=c, alpha=.2, lw=1, label=t)
45 | plt.legend(loc='best')
46 | plt.xticks([])
47 | plt.yticks([])
48 | plt.xlabel('PC1')
49 | plt.ylabel('PC2')
50 | plt.tight_layout()
51 | plt.show()
52 | return plt
53 |
54 |
55 | """ Function to plot summary figures for a trained classifier """
56 |
57 | def classifier_plots(clf_trained, X_test, y_test, target_names:list, minority_idx:int=0, ylog:bool=False):
58 | """
59 | Get summary plots for a trained classifier
60 |
61 | Args:
62 | clf_trained: trained sklearn clf
63 | X_test (np.ndarray): withheld test data
64 | y_test (np.ndarray): withheld test data labels
65 | target_names (list): list of target labels/names
66 | minority_idx: (int): index for the minority class (e.g. 0, 1)
67 | ylog (bool): toggle log-scaling on the y-axis
68 |
69 | Returns:
70 | None
71 | """
72 |
73 | """ Probability Dist """
74 | # get the probability distribution
75 | probas = clf_trained.predict_proba(X_test)
76 |
77 | # PLOT - count
78 | plt.figure(dpi=150)
79 | plt.hist(probas, bins=20)
80 | plt.title('Classification Probabilities')
81 | plt.xlabel('Probability')
82 | plt.ylabel('# of Instances')
83 | plt.xlim([0.5, 1.0])
84 | if ylog: plt.yscale('log')
85 | plt.legend(target_names)
86 | plt.show()
87 |
88 |
89 | # PLOT - density
90 | plt.figure(dpi=150)
91 | plt.hist(probas[:, minority_idx], bins=20, density=True)
92 | plt.title('Classification Density (Minority)')
93 | plt.xlabel('Probability')
94 | plt.ylabel('% of Total')
95 | if ylog: plt.yscale('log')
96 | plt.xlim([0, 1.0])
97 | plt.legend(target_names)
98 | plt.show()
99 |
100 | """ ROC curve """
101 |
102 | # get false and true positive rates
103 | fpr, tpr, _ = roc_curve(y_test, probas[:, minority_idx], pos_label=clf_trained.classes_[minority_idx])
104 |
105 | # get area under the curve
106 | clf_auc = auc(fpr, tpr)
107 |
108 | # PLOT ROC curve
109 | plt.figure(dpi=150)
110 | plt.plot(fpr, tpr, lw=1, color='green', label=f'AUC = {clf_auc:.3f}')
111 | plt.plot([0,1], [0,1], '--k', lw=0.5, label='Random')
112 | plt.title('ROC')
113 | plt.xlabel('False Positive Rate')
114 | plt.ylabel('True Positive Rate (Recall)')
115 | plt.xlim([-0.05, 1.05])
116 | plt.ylim([-0.05, 1.05])
117 | plt.legend()
118 | plt.show()
119 |
120 | """ Precision Recall Curve """
121 |
122 | # get precision and recall values
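123 | precision, recall, _ = precision_recall_curve(y_test, probas[:, minority_idx], pos_label=clf_trained.classes_[minority_idx])
124 |
125 | # average precision score (binarize y_test so the minority class is the positive label)
126 | avg_precision = average_precision_score(np.asarray(y_test) == clf_trained.classes_[minority_idx], probas[:, minority_idx])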
127 |
128 | # precision auc
129 | pr_auc = auc(recall, precision)
130 |
131 | # plot
132 | plt.figure(dpi=150)
133 | plt.plot(recall, precision, lw=1, color='blue', label=f'AP={avg_precision:.3f}; AUC={pr_auc:.3f}')
134 | plt.fill_between(recall, precision, -1, facecolor='lightblue', alpha=0.5)
135 |
136 | plt.title('PR Curve')
137 | plt.xlabel('Recall (TPR)')
138 | plt.ylabel('Precision')
139 | plt.xlim([-0.05, 1.05])
140 | plt.ylim([-0.05, 1.05])
141 | plt.legend()
142 | plt.show()
143 |
144 |
145 | """
146 | Interactive Plots and Utility Functions
147 | """
148 |
149 | def get_clf_name(clf):
150 | """
151 | get_clf_name takes a classifier (trained or untrained) and returns its name as a string
152 | clf.__str__() returns the classifier's name followed by its params in parentheses
153 | e.g. "LogisticRegression(...)"
154 | We then split on "(", use filter to drop empty strings, convert to list, and return first item
155 | clf: sklearn classifier (e.g. rf = RandomForestClassifier())
156 | """
157 | return list(filter(None, clf.__str__().split("(")))[0]
158 |
159 |
160 | def get_ROC_data(data, clf, pos_label_=None, verbose=False):
161 |
162 | """
163 | source_ROC, df_ROC, clf = get_ROC_data(data, clf, verbose=False)
164 |
165 | get_ROC_data will return ColumnDataSource and dataframes with TPR and FPR
166 | for a particular dataset and an untrained classifier. The CDS (ColumnDataSource) can be used
167 | to plot a Bokeh plot while the dataframe can be used for additional
168 | exploration and plotting with other libs. Note that the dataframes
169 | are returned with metadata (e.g. AUC and the clf used).
170 |
171 | data: tuple of our data (X_train, X_test, y_train, y_test)
172 | where each item in the tuple is a numpy ndarray
173 | clf: an untrained classifier (e.g. rf = RandomForestClassifier())
174 | pos_label_: if targets are not binary (0, 1) then indicate integer for "positive" [default: None]
175 | verbose: print warnings [default: False]
176 | """
177 |
178 | # split data into training, testing
179 | (X_train, X_test, y_train, y_test) = data
180 |
181 | # train and retrieve class probabilities for each instance in the test data
182 | probas_ = clf.fit(X_train, y_train).predict_proba(X_test)
183 |
184 | # choose the positive label: the largest target value unless one is given
185 | # (compare with None so that a pos_label_ of 0 is respected)
186 | if pos_label_ is None:
187 | pos_label_ = np.max(y_train)
188 | if verbose:
189 | print(f"Warning: Maximum target value of '{pos_label_}' used as positive.")
190 | print("You can use 'pos_label_' to indicate your own.")
191 |
192 | # get values for roc curve
193 | fpr, tpr, thresholds = roc_curve(y_test, probas_[:, list(clf.classes_).index(pos_label_)], pos_label=pos_label_)
194 | thresholds[0] = np.nan
195 |
196 | # get area under the curve (AUC)
197 | roc_auc = auc(fpr, tpr)
198 |
199 | # create legend variables - we'll create an array with len(tpr)
200 | auc_ = [f"AUC: {roc_auc:.3f}"]*len(tpr)
201 | clf_name = get_clf_name(clf)
202 | clf_ = [f"{clf_name}, AUC: {roc_auc:.3f}"]*len(tpr)
203 |
204 | # create bokeh column source for plotting new ROC
205 | source_ROC = ColumnDataSource(data=dict(x_fpr=fpr,
206 | y_tpr=tpr,
207 | thresh=thresholds,
208 | auc_legend=auc_,
209 | clf_legend=clf_))
210 |
211 | # create output dataframe with TPR and FPR, and metadata
212 | df_ROC = pd.DataFrame({'TPR':tpr, 'FPR':fpr, 'Thresholds':thresholds})
213 | df_ROC.auc = roc_auc
214 | df_ROC.clf = get_clf_name(clf)
215 | df_ROC.score = clf.score(X_test, y_test)
216 |
217 | return source_ROC, df_ROC, clf
218 |
219 |
220 | def interpolate_mean_tpr(FPRs=None, TPRs=None, df_list=None):
221 | """
222 | mean_fpr, mean_tpr = interpolate_mean_tpr(FPRs=None, TPRs=None, df_list=None)
223 |
224 | FPRs: False positive rates (list of n arrays)
225 | TPRs: True positive rates (list of n arrays)
226 | df_list: DataFrames with TPR, FPR columns (list of n DataFrames)
227 | """
228 |
229 | # seed empty linspace
230 | mean_tpr, mean_fpr = 0, np.linspace(0, 1, 101)
231 |
232 | if TPRs and FPRs:
233 | for idx, PRs in enumerate(zip(FPRs, TPRs)):
234 | mean_tpr += np.interp(mean_fpr, PRs[0], PRs[1])
235 |
236 | elif df_list:
237 | for idx, df_ in enumerate(df_list):
238 | mean_tpr += np.interp(mean_fpr, df_.FPR, df_.TPR)
239 |
240 | else:
241 | print("Please give valid inputs.")
242 | return None, None
243 |
244 | # normalize by length of inputs (# indices looped over)
245 | mean_tpr /= (idx+1)
246 |
247 | # add origin point
248 | mean_fpr = np.insert(mean_fpr, 0, 0)
249 | mean_tpr = np.insert(mean_tpr, 0, 0)
250 |
251 | return mean_fpr, mean_tpr
252 |
253 |
254 | def plot_ROC(clf, X, y, test_size_:float=0.5, pos_label_=None, filename:str=None, verbose:bool=False):
255 | """
256 | clf, df_ROC = plot_ROC(clf, X, y, test_size_=0.5, pos_label_=None, filename=None, verbose=False)
257 |
258 | Plot an interactive ROC curve for a binary classifier.
259 | It returns the trained clf and a dataframe of TPR, FPR, and thresholds for the test split.
260 |
261 | clf: untrained classifier object (e.g. rf_clf = RandomForestClassifier())
262 | X: training + testing data
263 | y: targets (numeric/integers)
264 | test_size_: fraction of data to be reserved for testing [default: 0.5]
265 | pos_label_: if targets are not binary (0, 1) then indicate integer for "positive" [default: None]
266 | filename: if provided, save to html [default: None]
267 | verbose: print warnings [default: False]
268 | """
269 |
270 | """ Split and get ROC curve data """
271 | data = train_test_split(X, y, test_size=test_size_)
272 | source_ROC, df_ROC, clf = get_ROC_data(data, clf, pos_label_, verbose)
273 |
274 | """ Set up initial PLOT """
275 | # Create custom HoverTool -- we'll name each ROC curve 'ROC' so we only see info on hover there
276 | hover_ = HoverTool(names=['ROC'], tooltips=[("TPR", "@y_tpr"), ("FPR", "@x_fpr"), ("Thresh", "@thresh")])
277 |
278 | # Create your toolbox
279 | p_tools = [hover_, 'crosshair', 'zoom_in', 'zoom_out', 'save', 'reset', 'tap', 'box_zoom']
280 |
281 | # Create figure and labels
282 | clf_name = get_clf_name(clf)
283 | p = figure(title=f'{clf_name} ROC curve', tools=p_tools)
284 | p.xaxis.axis_label = 'False Positive Rate'
285 | p.yaxis.axis_label = 'True Positive Rate'
286 |
287 | """ PLOT ROC """
288 | p.line('x_fpr', 'y_tpr', line_width=1, color="blue", source=source_ROC)
289 | p.circle('x_fpr', 'y_tpr', size=3, color="orange", legend='auc_legend', source=source_ROC, name='ROC')
290 |
291 | """ Plot Threshold==0.5 """
292 | # get the two points whose thresholds are closest to 0.5 (copy to avoid pandas' SettingWithCopyWarning)
293 | df_half = df_ROC.dropna().iloc[(df_ROC['Thresholds'].dropna()-0.5).abs().argsort()[:2]].copy()
294 | df_half['Legend'] = 'Thresh~0.5'
295 | source_half = ColumnDataSource(data=dict(x_fpr=df_half.FPR,
296 | y_tpr=df_half.TPR,
297 | thresh=df_half.Thresholds,
298 | legend_=df_half.Legend))
299 | p.circle('x_fpr', 'y_tpr', size=5, color="blue", source=source_half, legend="legend_", name='ROC')
300 |
301 | """ PLOT chance line """
302 | # Plot chance (tpr = fpr)
303 | p.line([0, 1], [0, 1], line_dash='dashed', line_width=0.5, color='black', name='Chance')
304 |
305 | # Finishing touches
306 | p.legend.location = "bottom_right"
307 |
308 | """ save and show """
309 | if filename:
310 | output_file(filename)
311 | show(p)
312 |
313 | return clf, df_ROC
314 |
315 |
316 | def plot_ROC_CV(clf, X, y, cv_fold=3, pos_label_=None, verbose=False):
317 |
318 | """
319 | clf, classifiers, df_ROCs = plot_ROC_CV(clf, X, y, cv_fold=3, pos_label_=None, verbose=False)
320 |
321 | Plot an interactive ROC curve for a binary classifier with n='cv_fold' cross-validations.
322 | It returns the original clf, the classifier returned for each fold, and a list of dataframes
323 | (one per fold) with the TPR, FPR and thresholds of that fold's test split.
324 |
325 | clf: untrained classifier object (e.g. rf_clf = RandomForestClassifier())
326 | X: training + testing data
327 | y: targets (numeric/integers)
328 | cv_fold: cross-validations to run [default: 3]
329 | pos_label_: if targets are not binary (0, 1) then indicate integer for "positive" [default: None]
330 | verbose: print warnings [default: False]
331 | """
332 |
333 | """ Check cross-validations to run and get stratification """
334 | # Check cross-validation > 1 and get stratified data
335 | if cv_fold > 1:
336 | skf = StratifiedKFold(cv_fold)
337 | else:
338 | print(f"cv_fold must be greater than 1. You have input {cv_fold}")
339 | return clf, [], []
340 |
341 | """ Get source data for each ROC curve """
342 | # Loop over each split in the data and get source data, df, and clf
343 | source_ROCs, df_ROCs, classifiers = [], [], []
344 | for idx, val in enumerate(skf.split(X, y)):
345 | (train, test) = val
346 | data = (X[train], X[test], y[train], y[test]) # note that skf returns indices, not values (X and y are assumed to be numpy arrays)
347 | source_, df_, clf_ = get_ROC_data(data, clf, pos_label_, verbose)
348 | source_ROCs.append(source_)
349 | df_ROCs.append(df_)
350 | classifiers.append(clf_)
351 |
352 | """ Set up initial PLOT """
353 | # Create custom HoverTool -- we'll name each ROC curve 'ROC' so we only see info on hover there
354 | hover_ = HoverTool(names=['ROC'], tooltips=[("TPR", "@y_tpr"), ("FPR", "@x_fpr"), ("Threshold", "@thresh")])
355 |
356 | # Create your toolbox
357 | p_tools = [hover_, 'crosshair', 'zoom_in', 'zoom_out', 'save', 'reset', 'tap', 'box_zoom']
358 |
359 | # Create figure and labels
360 | clf_name = get_clf_name(clf)
361 | p = figure(title=f'{clf_name} ROC curve with {cv_fold}-fold cross-validation', tools=p_tools)
362 | p.xaxis.axis_label = 'False Positive Rate'
363 | p.yaxis.axis_label = 'True Positive Rate'
364 |
365 | """ Get ROC CURVE for each iteration """
366 | # cycle through matplotlib's named colors
367 | colors_ = cycle(list(cnames.keys()))
368 |
369 | # plot each ROC curve - loop over source_ROCs, colors_
370 | for _, val in enumerate(zip(source_ROCs, colors_)):
371 | (ROC, color_) = val
372 | p.line('x_fpr', 'y_tpr', line_width=1, color=color_, source=ROC)
373 | p.circle('x_fpr', 'y_tpr', size=10, color=color_, legend='auc_legend', source=ROC, name='ROC')
374 |
375 | """ Mean ROC and AUC for all curves and plot """
376 | # process inputs
377 | mean_fpr, mean_tpr = interpolate_mean_tpr(df_list=df_ROCs)
378 | mean_auc = auc(mean_fpr, mean_tpr)
379 | mean_legend = [f'Mean, AUC: {mean_auc:.3f}']*len(mean_tpr)
380 |
381 | # Create ColumnDataSource
382 | source_ROC_mean = ColumnDataSource(data=dict(x_fpr=mean_fpr,
383 | y_tpr=mean_tpr,
384 | auc_legend=mean_legend))
385 |
386 | # Plot mean ROC
387 | p.line('x_fpr', 'y_tpr', legend='auc_legend', color='black',
388 | line_width=3.33, line_alpha=0.33, line_dash='dashdot', source=source_ROC_mean, name='ROC')
389 |
390 | # Plot chance (tpr = fpr)
391 | p.line([0, 1], [0, 1], line_dash='dashed', line_width=0.5, color='black', name='Chance')
392 |
393 | # Finishing touches
394 | p.legend.location = "bottom_right"
395 | show(p)
396 |
397 | return clf, classifiers, df_ROCs
398 |
399 |
400 | def plot_ROC_clfs(classifiers, X, y, test_size=0.33, pos_label_=None, verbose=False):
401 |
402 | """
403 | classifiers, df_ROCs = plot_ROC_clfs(classifiers, X, y, test_size=0.33, pos_label_=None, verbose=False)
404 |
405 | Plot interactive ROC curves for several binary classifiers trained and tested on one shared
406 | train/test split, together with their mean ROC curve. It returns the fitted classifiers and a
407 | list of ROC dataframes (one per classifier); the test score of each classifier is also printed.
408 |
409 | classifiers: list of untrained classifiers
410 | X: training + testing data
411 | y: targets (numeric/integers)
412 | test_size: test size for train_test_split (0 < x < 1)
413 | pos_label_: if targets are not binary (0, 1) then indicate integer for "positive" [default: None]
414 | verbose: print warnings [default: False]
415 | """
416 |
417 | """ Get source data for each ROC curve """
418 |
419 | # Get training and test data
420 | (data_) = train_test_split(X, y, test_size=test_size)
421 |
422 | # Loop over each CLASSIFIER now -- note that we don't redefine our classifiers
423 | source_ROCs, df_ROCs = [], []
424 | for _, clf_ in enumerate(classifiers):
425 | source_, df_, clf_ = get_ROC_data(data_, clf_, pos_label_, verbose)
426 | source_ROCs.append(source_)
427 | df_ROCs.append(df_)
428 |
429 | """ Set up initial PLOT """
430 |
431 | # Create custom HoverTool -- we'll name each ROC curve 'ROC' so we only see info on hover there
432 | hover_ = HoverTool(names=['ROC'], tooltips=[("TPR", "@y_tpr"), ("FPR", "@x_fpr"), ("Threshold", "@thresh")])
433 |
434 | # Create your toolbox
435 | p_tools = [hover_, 'crosshair', 'zoom_in', 'zoom_out', 'save', 'reset', 'tap', 'box_zoom']
436 |
437 | # Create figure and labels
438 | p = figure(title=f'Benchmarking {len(classifiers)} classifiers', tools=p_tools)
439 | p.xaxis.axis_label = 'False Positive Rate'
440 | p.yaxis.axis_label = 'True Positive Rate'
441 |
442 | """ Get ROC CURVE for each iteration """
443 |
444 | # cycle through matplotlib's named colors
445 | colors_ = cycle(list(cnames.keys()))
446 |
447 | # loop over source, color and plot each ROC curve
448 | for _, val in enumerate(zip(source_ROCs, colors_)):
449 | (ROC, color_) = val
450 | p.line('x_fpr', 'y_tpr', line_width=1, color=color_, source=ROC)
451 | p.circle('x_fpr', 'y_tpr', size=5, color=color_, legend='clf_legend', source=ROC, name='ROC')
452 |
453 | """ Mean ROC and AUC for all curves and plot """
454 |
455 | # process mean values, legend, ColumnDataSource
456 | mean_fpr, mean_tpr = interpolate_mean_tpr(df_list=df_ROCs)
457 | mean_auc = auc(mean_fpr, mean_tpr)
458 | mean_legend = [f'Mean, AUC: {mean_auc:.3f}']*len(mean_tpr)
459 | source_ROC_mean = ColumnDataSource(data=dict(x_fpr = mean_fpr, y_tpr = mean_tpr, roc_legend=mean_legend))
460 |
461 | # PLOT mean ROC
462 | p.line('x_fpr', 'y_tpr', legend='roc_legend', color='black',
463 | line_width=5, line_alpha=0.3, line_dash='dashed', source=source_ROC_mean, name='ROC')
464 |
465 | # PLOT chance (tpr = fpr)
466 | p.line([0, 1], [0, 1], line_dash='dashed', line_width=0.2, color='black', name='Chance')
467 |
468 | # Finishing touches
469 | p.legend.location = "bottom_right"
470 | show(p)
471 |
472 | # Print scores
473 | print("Scores:")
474 | # Get scores for each classifier:
475 | for i, df_ in enumerate(df_ROCs):
476 | print(df_.clf, np.round(df_.score, decimals=3))
477 |
478 | return classifiers, df_ROCs
479 |
480 |
481 | def get_ROC_PR_data(data, clf, pos_label_=None, verbose=False):
482 |
483 | """
484 | source_ROC, source_PR, df_ROC, df_PR, clf = get_ROC_PR_data(data, clf, pos_label_=None, verbose=False)
485 |
486 | get_ROC_PR_data fits an untrained classifier on a train/test split and returns ColumnDataSources
487 | and dataframes for both the ROC curve (TPR, FPR) and the PR curve (precision, recall). The CDSs
488 | can be used to make Bokeh plots, while the dataframes can be used for additional
489 | exploration and plotting with other libs. Note that the dataframes
490 | are returned with metadata (e.g. AUC and the clf used).
491 |
492 | data: tuple of our data (X_train, X_test, y_train, y_test)
493 | where each item in the tuple is a numpy ndarray
494 | clf: an untrained classifier (e.g. rf = RandomForestClassifier())
495 | pos_label_: if targets are not binary (0, 1) then indicate integer for "positive" [default: None]
496 | verbose: print warnings [default: False]
497 | """
498 |
499 | # split data into training, testing
500 | (X_train, X_test, y_train, y_test) = data
501 |
502 | # fit on the training split and get per-sample class probabilities for the test split
503 | probas = clf.fit(X_train, y_train).predict_proba(X_test)
504 |
505 | # default the positive label to the largest training target value
506 | # (probas[:,1] assumes the positive class is the second column of predict_proba, i.e. clf.classes_[1])
507 | if pos_label_ is None:
508 | pos_label_ = np.max(y_train)
509 | if verbose:
510 | print(f"Warning: Maximum target value of '{pos_label_}' used as positive.")
511 | print("You can use 'pos_label_' to indicate your own.")
512 |
513 | """ ROC """
514 | fpr, tpr, roc_thresholds = roc_curve(y_test, probas[:,1], pos_label=pos_label_)
515 | roc_thresholds[0] = np.nan
516 |
517 | # get area under the curve (AUC)
518 | roc_auc = auc(fpr, tpr)
519 |
520 | """ PR """
521 | # get precision and recall values
522 | precision, recall, pr_thresholds = precision_recall_curve(y_test, probas[:,1], pos_label=pos_label_)
523 | pr_thresholds = np.append(pr_thresholds, np.nan) # match lengths: the last (precision=1, recall=0) point has no threshold
524 |
525 | # average precision score
526 | avg_precision = average_precision_score(y_test, probas[:,1])
527 |
528 | # area under the PR curve via the trapezoidal rule (can be optimistic compared with avg_precision)
529 | pr_auc = auc(recall, precision)
530 |
531 |
532 | """ Create Sources """
533 | # create legend columns: one repeated label per point, since Bokeh columns must have equal lengths
534 | roc_auc_ = [f"AUC: {roc_auc:.3f}"]*len(tpr)
535 | pr_auc_ = [f"AUC: {pr_auc:.3f}"]*len(precision)
536 | clf_name = get_clf_name(clf)
537 | clf_roc = [f"{clf_name}, AUC: {roc_auc:.3f}"]*len(tpr)
538 | clf_pr = [f"{clf_name}, AUC: {pr_auc:.3f}"]*len(precision)
539 |
540 | # create bokeh column source for plotting new ROC
541 | source_ROC = ColumnDataSource(data=dict(x_fpr=fpr,
542 | y_tpr=tpr,
543 | thresh_roc=roc_thresholds,
544 | auc_legend=roc_auc_,
545 | clf_legend=clf_roc))
546 |
547 | source_PR = ColumnDataSource(data=dict(x_rec=recall,
548 | y_prec=precision,
549 | thresh_pr=pr_thresholds,
550 | auc_legend=pr_auc_,
551 | clf_legend=clf_pr))
552 |
553 | """ Dataframes """
554 | # create output dataframe with TPR and FPR, and metadata
555 | df_ROC = pd.DataFrame({'TPR':tpr, 'FPR':fpr, 'Thresholds':roc_thresholds})
556 | df_ROC.auc = roc_auc
557 | df_ROC.clf = get_clf_name(clf)
558 | df_ROC.score = clf.score(X_test, y_test)
559 |
560 | # create output dataframe with recall, precision, thresholds, and metadata
561 | df_PR = pd.DataFrame({'Recall':recall, 'Precision':precision, 'Thresholds':pr_thresholds})
562 | df_PR.auc = pr_auc
563 | df_PR.clf = get_clf_name(clf)
564 | df_PR.score = clf.score(X_test, y_test)
565 |
566 | return source_ROC, source_PR, df_ROC, df_PR, clf
567 |
568 |
569 | def plot_ROC_PR(clf, X, y, test_size_:float=0.5, pos_label_:str=None, filename:str=None, verbose:bool=False):
570 | """
571 | clf, df_ROC = plot_ROC_PR(clf, X, y, test_size_=0.5, pos_label_=None, filename=None, verbose=False)
572 |
573 | Plot interactive ROC and precision-recall (PR) curves side by side for a binary classifier.
574 | It returns the fitted clf and a dataframe with the TPR, FPR and thresholds of the test split.
575 |
576 | clf: untrained classifier object (e.g. rf_clf = RandomForestClassifier())
577 | X: training + testing data
578 | y: targets (numeric/integers)
579 | test_size_: fraction of data to be reserved for testing [default: 0.5]
580 | pos_label_: if targets are not binary (0, 1) then indicate integer for "positive" [default: None]
581 | filename: if provided, save to html [default: None]
582 | verbose: print warnings [default: False]
583 | """
584 |
585 | """ Split and get ROC and PR curve data """
586 | data = train_test_split(X, y, test_size=test_size_)
587 | source_ROC, source_PR, df_ROC, df_PR, clf = get_ROC_PR_data(data, clf, pos_label_, verbose)
588 |
589 |
590 | """ PLOT ROC """
591 |
592 | # Create custom HoverTool -- we'll make one for each curve
593 | hover_ROC = HoverTool(names=['ROC'], tooltips=[("TPR", "@y_tpr"),
594 | ("FPR", "@x_fpr"),
595 | ("Thresh", "@thresh_roc"),
596 | ])
597 |
598 | # Create your toolbox
599 | p_tools_ROC = [hover_ROC, 'crosshair', 'zoom_in', 'zoom_out', 'save', 'reset', 'tap', 'box_zoom']
600 |
601 | clf_name = get_clf_name(clf)
602 | p1 = figure(title=f'{clf_name} ROC curve', tools=p_tools_ROC)
603 | p1.xaxis.axis_label = 'False Positive Rate'
604 | p1.yaxis.axis_label = 'True Positive Rate'
605 |
606 | # plot curve and datapts
607 | p1.line('x_fpr', 'y_tpr', line_width=1, color="blue", source=source_ROC)
608 | p1.circle('x_fpr', 'y_tpr', size=3, color="orange", legend='auc_legend', source=source_ROC, name='ROC')
609 |
610 | # highlight values closest to threshold == 0.5
611 | df_half = df_ROC.dropna().iloc[(df_ROC['Thresholds'].dropna()-0.5).abs().argsort()[:2]].copy()
612 | df_half['Legend'] = 'Thresh~0.5'
613 | source_half = ColumnDataSource(data=dict(x_fpr=df_half.FPR,
614 | y_tpr=df_half.TPR,
615 | thresh_roc=df_half.Thresholds,
616 | legend_=df_half.Legend))
617 | p1.circle('x_fpr', 'y_tpr', size=5, color="blue", source=source_half, legend="legend_", name='ROC')
618 |
619 | # Plot chance (tpr = fpr)
620 | p1.line([0, 1], [0, 1], line_dash='dashed', line_width=0.5, color='black', name='Chance')
621 |
622 | # Finishing touches
623 | p1.legend.location = "bottom_right"
624 |
625 | """ PLOT PR """
626 |
627 | # Create custom HoverTool -- we'll make one for each curve
628 | hover_PR = HoverTool(names=['PR'], tooltips=[("Precision", "@y_prec"),
629 | ("Recall", "@x_rec"),
630 | ("Thresh", "@thresh_pr")
631 | ])
632 |
633 | # Create your toolbox
634 | p_tools_PR = [hover_PR, 'crosshair', 'zoom_in', 'zoom_out', 'save', 'reset', 'tap', 'box_zoom']
635 |
636 | p2 = figure(title=f'{clf_name} PR curve', tools=p_tools_PR)
637 | p2.xaxis.axis_label = 'Recall'
638 | p2.yaxis.axis_label = 'Precision'
639 |
640 | p2.line('x_rec', 'y_prec', line_width=1, color="blue", source=source_PR)
641 | p2.circle('x_rec', 'y_prec', size=3, color="orange", legend='auc_legend', source=source_PR, name='PR')
642 |
643 | # highlight values closest to threshold == 0.5
644 | df_half = df_PR.dropna().iloc[(df_PR['Thresholds'].dropna()-0.5).abs().argsort()[:2]].copy()
645 | df_half['Legend'] = 'Thresh~0.5'
646 | source_half = ColumnDataSource(data=dict(x_rec=df_half.Recall,
647 | y_prec=df_half.Precision,
648 | thresh_pr=df_half.Thresholds,
649 | legend_=df_half.Legend))
650 | p2.circle('x_rec', 'y_prec', size=5, color="blue", source=source_half, legend="legend_", name='PR')
651 |
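652 | # Reference diagonal (precision = 1 - recall); a no-skill classifier's PR curve is a horizontal line at the positive-class prevalence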
653 | p2.line([0, 1], [1, 0], line_dash='dashed', line_width=0.5, color='black', name='Chance')
654 |
655 | # Finishing touches
656 | p2.legend.location = "bottom_left"
657 |
658 | """ save and show """
659 | if filename:
660 | output_file(filename)
661 | show(row(p1, p2))
662 |
663 | return clf, df_ROC
664 |
665 |
666 | if __name__ == '__main__':
667 |
668 | # get UCI Breast Cancer Data
669 | data = datasets.load_breast_cancer()
670 | X = data.data
671 | y = data.target
672 | target_names = list(data.target_names)
673 |
674 | # Classifier plots
675 | X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5)
676 | rf_clf = RandomForestClassifier(n_estimators=100)
677 | rf_clf.fit(X_train, y_train)
678 | classifier_plots(rf_clf, X_test, y_test, target_names, ylog=False)
679 |
680 | # Interactive ROC curve for the breast cancer data
681 | rf_clf = RandomForestClassifier(n_estimators=100)
682 | rf_clf, df_rf_ROC = plot_ROC(rf_clf, X, y, test_size_=0.5)
683 |
684 | # Interactive ROC curve with cross-validation
685 | rf_clf = RandomForestClassifier(n_estimators=100)
686 | rf_clf, rf_cv_clfs, df_rf_ROCs = plot_ROC_CV(rf_clf, X, y)
687 |
688 | # Benchmark classifiers
689 | rf_bench = RandomForestClassifier(random_state=42, n_estimators=100)
690 | lr_bench = LogisticRegression(random_state=42, solver='saga')
691 | gnb_bench = GaussianNB(priors=None, var_smoothing=1e-06)
692 | clfs_benchmark, dfs_bench = plot_ROC_clfs(classifiers=[rf_bench, lr_bench, gnb_bench], X=X, y=y)
693 |
694 | # Interactive ROC and PR curves
695 | rf_clf = RandomForestClassifier(n_estimators=100)
696 | rf_clf, df_rf_ROC_PR = plot_ROC_PR(rf_clf, X, y, test_size_=0.33)
697 |
--------------------------------------------------------------------------------
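interpolate_mean_tpr in roc_pr_curve.py averages ROC curves vertically: each curve's TPR is linearly interpolated onto one shared FPR grid, the interpolated TPRs are averaged point by point, and the origin is prepended. The sketch below shows that technique in isolation with plain NumPy. `mean_roc` and its `n_points` parameter are hypothetical names used only for illustration (they are not part of the repository), and each curve is assumed to be an `(fpr, tpr)` pair with sorted FPR values, as `sklearn.metrics.roc_curve` returns.

```python
import numpy as np


def mean_roc(curves, n_points=101):
    """Vertically average ROC curves given as (fpr, tpr) pairs (hypothetical helper).

    Each TPR curve is linearly interpolated onto a shared FPR grid, the
    interpolated TPRs are averaged, and the origin (0, 0) is prepended.
    """
    if not curves:
        raise ValueError("need at least one (fpr, tpr) curve")
    grid = np.linspace(0.0, 1.0, n_points)
    # np.interp assumes its xp argument is increasing (it does not check);
    # the FPR array from roc_curve is sorted, so it qualifies
    tprs = [np.interp(grid, np.asarray(fpr, dtype=float), np.asarray(tpr, dtype=float))
            for fpr, tpr in curves]
    mean_tpr = np.mean(tprs, axis=0)
    return np.insert(grid, 0, 0.0), np.insert(mean_tpr, 0, 0.0)


# Two toy folds with different numbers of points
fold_a = ([0.0, 0.5, 1.0], [0.0, 1.0, 1.0])
fold_b = ([0.0, 1.0], [0.0, 1.0])
mean_fpr, mean_tpr = mean_roc([fold_a, fold_b])
print(mean_fpr[51], mean_tpr[51])
```

Interpolating onto a fixed grid is what makes folds comparable: each fold's thresholds yield a different number of ROC points, so the raw TPR arrays cannot be averaged element by element.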
/plots/html/roc_cvfold/bokeh-1.0.4.min.css:
--------------------------------------------------------------------------------
1 | .bk-root{font-family:"Helvetica Neue",Helvetica,Arial,sans-serif;font-size:10pt;position:relative;width:auto;height:auto}.bk-root .bk-shading{position:absolute;display:block;border:1px dashed green}.bk-root .bk-tile-attribution a{color:black}.bk-root .bk-tool-icon-box-select{background-image:url("")}.bk-root .bk-tool-icon-box-zoom{background-image:url("")}.bk-root .bk-tool-icon-zoom-in{background-image:url("")}.bk-root .bk-tool-icon-zoom-out{background-image:url("")}.bk-root .bk-tool-icon-help{background-image:url("")}.bk-root .bk-tool-icon-hover{background-image:url("")}.bk-root .bk-tool-icon-crosshair{background-image:url("")}.bk-root .bk-tool-icon-lasso-select{background-image:url("")}.bk-root .bk-tool-icon-pan{background-image:url("")}.bk-root .bk-tool-icon-xpan{background-image:url("")}.bk-root .bk-tool-icon-ypan{background-image:url("")}.bk-root .bk-tool-icon-range{background-image:url("")}.bk-root .bk-tool-icon-polygon-select{background-image:url("")}.bk-root .bk-tool-icon-redo{background-image:url("")}.bk-root .bk-tool-icon-reset{background-image:url("")}.bk-root .bk-tool-icon-save{background-image:url("")}.bk-root .bk-tool-icon-tap-select{background-image:url("")}.bk-root .bk-tool-icon-undo{background-image:url("")}.bk-root .bk-tool-icon-wheel-pan{background-image:url("")}.bk-root .bk-tool-icon-wheel-zoom{background-image:url("")}.bk-root .bk-tool-icon-box-edit{background-image:url("")}.bk-root .bk-tool-icon-freehand-draw{background-image:url("")}.bk-root .bk-tool-icon-poly-draw{background-image:url("")}.bk-root .bk-tool-icon-point-draw{background-image:url("")}.bk-root .bk-tool-icon-poly-edit{background-image:url("")}.bk-root .bk-grid-row,.bk-root .bk-grid-column{display:flex;display:-webkit-flex;flex-wrap:nowrap;-webkit-flex-wrap:nowrap}.bk-root .bk-grid-row>*,.bk-root .bk-grid-column>*{flex-shrink:0;-webkit-flex-shrink:0}.bk-root .bk-grid-row{flex-direction:row;-webkit-flex-direction:row}.bk-root 
.bk-grid-column{flex-direction:column;-webkit-flex-direction:column}.bk-root .bk-canvas-wrapper{position:relative;font-size:12pt}.bk-root .bk-canvas,.bk-root .bk-canvas-overlays,.bk-root .bk-canvas-events{position:absolute;top:0;left:0;width:100%;height:100%}.bk-root .bk-canvas-map{position:absolute;border:0}.bk-root .bk-logo{margin:5px;position:relative;display:block;background-repeat:no-repeat}.bk-root .bk-logo.bk-grey{filter:url("data:image/svg+xml;utf8,#grayscale");filter:gray;-webkit-filter:grayscale(100%)}.bk-root .bk-logo-notebook{display:inline-block;vertical-align:middle;margin-right:5px}.bk-root .bk-logo-small{width:20px;height:20px;background-image:url()}.bk-root .bk-toolbar,.bk-root .bk-toolbar *{box-sizing:border-box;margin:0;padding:0}.bk-root .bk-toolbar-hidden{visibility:hidden;opacity:0;transition:visibility .3s linear,opacity .3s linear}.bk-root .bk-toolbar,.bk-root .bk-button-bar{display:flex;display:-webkit-flex;flex-wrap:nowrap;-webkit-flex-wrap:nowrap;align-items:center;-webkit-align-items:center;user-select:none;-moz-user-select:none;-webkit-user-select:none;-ms-user-select:none}.bk-root .bk-toolbar .bk-logo{flex-shrink:0;-webkit-flex-shrink:0}.bk-root .bk-toolbar-above,.bk-root .bk-toolbar-below{flex-direction:row;-webkit-flex-direction:row;justify-content:flex-end;-webkit-justify-content:flex-end}.bk-root .bk-toolbar-above .bk-button-bar,.bk-root .bk-toolbar-below .bk-button-bar{display:flex;display:-webkit-flex;flex-direction:row;-webkit-flex-direction:row}.bk-root .bk-toolbar-above .bk-logo,.bk-root .bk-toolbar-below .bk-logo{order:1;-webkit-order:1;margin-left:5px}.bk-root .bk-toolbar-left,.bk-root .bk-toolbar-right{flex-direction:column;-webkit-flex-direction:column;justify-content:flex-start;-webkit-justify-content:flex-start}.bk-root .bk-toolbar-left .bk-button-bar,.bk-root .bk-toolbar-right .bk-button-bar{display:flex;display:-webkit-flex;flex-direction:column;-webkit-flex-direction:column}.bk-root .bk-toolbar-left .bk-logo,.bk-root 
.bk-toolbar-right .bk-logo{order:0;-webkit-order:0;margin-bottom:5px}.bk-root .bk-toolbar-button{width:30px;height:30px;background-size:60%;background-color:transparent;background-repeat:no-repeat;background-position:center center}.bk-root .bk-toolbar-button:hover{background-color:#f9f9f9}.bk-root .bk-toolbar-button:focus{outline:0}.bk-root .bk-toolbar-button::-moz-focus-inner{border:0}.bk-root .bk-toolbar-above .bk-toolbar-button{border-bottom:2px solid transparent}.bk-root .bk-toolbar-above .bk-toolbar-button.bk-active{border-bottom-color:#26aae1}.bk-root .bk-toolbar-below .bk-toolbar-button{border-top:2px solid transparent}.bk-root .bk-toolbar-below .bk-toolbar-button.bk-active{border-top-color:#26aae1}.bk-root .bk-toolbar-right .bk-toolbar-button{border-left:2px solid transparent}.bk-root .bk-toolbar-right .bk-toolbar-button.bk-active{border-left-color:#26aae1}.bk-root .bk-toolbar-left .bk-toolbar-button{border-right:2px solid transparent}.bk-root .bk-toolbar-left .bk-toolbar-button.bk-active{border-right-color:#26aae1}.bk-root .bk-button-bar+.bk-button-bar:before{content:" ";display:inline-block;background-color:lightgray}.bk-root .bk-toolbar-above .bk-button-bar+.bk-button-bar:before,.bk-root .bk-toolbar-below .bk-button-bar+.bk-button-bar:before{height:10px;width:1px}.bk-root .bk-toolbar-left .bk-button-bar+.bk-button-bar:before,.bk-root .bk-toolbar-right .bk-button-bar+.bk-button-bar:before{height:1px;width:10px}.bk-root .bk-tooltip{font-family:"HelveticaNeue-Light","Helvetica Neue Light","Helvetica Neue",Helvetica,Arial,"Lucida Grande",sans-serif;font-weight:300;font-size:12px;position:absolute;padding:5px;border:1px solid #e5e5e5;color:#2f2f2f;background-color:white;pointer-events:none;opacity:.95}.bk-root .bk-tooltip>div:not(:first-child){margin-top:5px;border-top:#e5e5e5 1px dashed}.bk-root .bk-tooltip.bk-left.bk-tooltip-arrow::before{position:absolute;margin:-7px 0 0 0;top:50%;width:0;height:0;border-style:solid;border-width:7px 0 7px 
0;border-color:transparent;content:" ";display:block;left:-10px;border-right-width:10px;border-right-color:#909599}.bk-root .bk-tooltip.bk-left::before{left:-10px;border-right-width:10px;border-right-color:#909599}.bk-root .bk-tooltip.bk-right.bk-tooltip-arrow::after{position:absolute;margin:-7px 0 0 0;top:50%;width:0;height:0;border-style:solid;border-width:7px 0 7px 0;border-color:transparent;content:" ";display:block;right:-10px;border-left-width:10px;border-left-color:#909599}.bk-root .bk-tooltip.bk-right::after{right:-10px;border-left-width:10px;border-left-color:#909599}.bk-root .bk-tooltip.bk-above::before{position:absolute;margin:0 0 0 -7px;left:50%;width:0;height:0;border-style:solid;border-width:0 7px 0 7px;border-color:transparent;content:" ";display:block;top:-10px;border-bottom-width:10px;border-bottom-color:#909599}.bk-root .bk-tooltip.bk-below::after{position:absolute;margin:0 0 0 -7px;left:50%;width:0;height:0;border-style:solid;border-width:0 7px 0 7px;border-color:transparent;content:" ";display:block;bottom:-10px;border-top-width:10px;border-top-color:#909599}.bk-root .bk-tooltip-row-label{text-align:right;color:#26aae1}.bk-root .bk-tooltip-row-value{color:default}.bk-root .bk-tooltip-color-block{width:12px;height:12px;margin-left:5px;margin-right:5px;outline:#ddd solid 1px;display:inline-block}.rendered_html .bk-root .bk-tooltip table,.rendered_html .bk-root .bk-tooltip tr,.rendered_html .bk-root .bk-tooltip th,.rendered_html .bk-root .bk-tooltip td{border:0;padding:1px}
--------------------------------------------------------------------------------
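get_ROC_PR_data in roc_pr_curve.py pads pr_thresholds because `sklearn.metrics.precision_recall_curve` returns precision and recall arrays that are one element longer than thresholds. Following scikit-learn's documented convention, element i of precision and recall belongs to thresholds[i], and the final point (precision 1, recall 0) has no threshold, so the padding belongs at the end; padding at the front shifts every threshold one row away from its precision/recall pair. The sketch below assumes that convention. `align_pr_thresholds` and `nearest_threshold_indices` are hypothetical helpers (not part of the repository), and the hard-coded arrays were worked out by hand for y_true=[0, 0, 1, 1] and scores [0.1, 0.4, 0.35, 0.8].

```python
import numpy as np


def align_pr_thresholds(precision, recall, thresholds):
    """Pad thresholds with a trailing NaN so all three arrays line up row by row.

    Assumes the precision_recall_curve convention: precision[i] and recall[i]
    belong to thresholds[i], and the last (precision=1, recall=0) point has no threshold.
    """
    precision = np.asarray(precision, dtype=float)
    recall = np.asarray(recall, dtype=float)
    thresholds = np.asarray(thresholds, dtype=float)
    if not (len(precision) == len(recall) == len(thresholds) + 1):
        raise ValueError("expected len(precision) == len(recall) == len(thresholds) + 1")
    return precision, recall, np.append(thresholds, np.nan)


def nearest_threshold_indices(thresholds, target=0.5, k=2):
    """Row indices of the k thresholds closest to target, skipping NaN padding."""
    th = np.asarray(thresholds, dtype=float)
    valid = np.flatnonzero(~np.isnan(th))  # rows that carry a real threshold
    order = np.argsort(np.abs(th[valid] - target), kind="stable")
    return valid[order[:k]]


# Hand-worked PR curve for y_true=[0, 0, 1, 1], scores=[0.1, 0.4, 0.35, 0.8]
precision = [0.5, 2 / 3, 0.5, 1.0, 1.0]
recall = [1.0, 1.0, 0.5, 0.5, 0.0]
thresholds = [0.1, 0.35, 0.4, 0.8]

p, r, t = align_pr_thresholds(precision, recall, thresholds)
for i in nearest_threshold_indices(t):
    print(f"threshold={t[i]:.2f} precision={p[i]:.3f} recall={r[i]:.2f}")
```

With the NaN appended, the threshold nearest 0.5 (0.4) stays paired with precision 0.5 and recall 0.5, which is what a "Thresh~0.5" highlight should mark.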