├── __init__.py ├── ReadMe.md ├── classConfusion.py └── Documentation.ipynb /__init__.py: -------------------------------------------------------------------------------- 1 | from .classConfusion import * 2 | -------------------------------------------------------------------------------- /ReadMe.md: -------------------------------------------------------------------------------- 1 | # ClassConfusion 2 | 3 | Class Confusion was designed to help interpret your model's decisions through visuals such as graphs or confusion matrices that go more in-depth than the standard `plot_confusion_matrix`. Class Confusion can be used with both Tabular and Image classification models. 4 | 5 | To use it, pass in your ClassificationInterpretation object along with a list of the classes you want to examine: 6 | ```python3 7 | from ClassConfusion import * 8 | ClassConfusion(interp, classList) 9 | ``` 10 | 11 | You can also pass in the exact class combinations you want to see; just make them a list of (actual, predicted) tuples and set `is_ordered=True`: 12 | 13 | ```python3 14 | comboList = [('<50k', '>=50k')] 15 | ClassConfusion(interp, comboList, is_ordered=True) 16 | ``` 17 | 18 | Please read the Documentation notebook for a guide on how to use this class.
19 | 20 | Some example outputs: 21 | ![](https://camo.githubusercontent.com/dc2f4b6e86db5e41274b60e605de25dd3a29ee27/68747470733a2f2f692e696d6775722e636f6d2f6a41453642566d2e706e67) 22 | 23 | ![](https://camo.githubusercontent.com/cefb9ee9dd7ed469afff8b899040a8330ca043df/68747470733a2f2f692e696d6775722e636f6d2f695555537032412e706e67) 24 | -------------------------------------------------------------------------------- /classConfusion.py: -------------------------------------------------------------------------------- 1 | import math 2 | import pandas as pd 3 | import re 4 | import matplotlib.pyplot as plt 5 | from tqdm import tqdm 6 | 7 | from itertools import permutations 8 | from fastai.train import ClassificationInterpretation 9 | from google.colab import widgets 10 | 11 | class ClassConfusion(): 12 | "Plot the most confused datapoints and statistics for the models misses." 13 | def __init__(self, interp:ClassificationInterpretation, classlist:list, 14 | is_ordered:bool=False, cut_off:int=100, varlist:list=None, 15 | figsize:tuple=(8,8)): 16 | self.interp = interp 17 | self._is_tab = (str(type(interp.learn.data)) == "") 18 | if self._is_tab: 19 | if interp.learn.data.train_ds.x.cont_names != []: 20 | for x in range(len(interp.learn.data.procs)): 21 | if "Normalize" in str(interp.learn.data.procs[x]): 22 | self.means = interp.learn.data.train_ds.x.processor[0].procs[x].means 23 | self.stds = interp.learn.data.train_ds.x.processor[0].procs[x].stds 24 | self.is_ordered = is_ordered 25 | self.cut_off = cut_off 26 | self.figsize = figsize 27 | self.classl = classlist 28 | self.varlist = varlist 29 | self._show_losses(classlist) 30 | 31 | def _show_losses(self, classl:list, **kwargs): 32 | "Checks if the model is for Tabular or Images and gathers top losses" 33 | _, self.tl_idx = self.interp.top_losses(len(self.interp.losses)) 34 | self._tab_losses() if self._is_tab else self._create_tabs() 35 | 36 | def _create_tabs(self): 37 | "Creates a tab for each variable" 38 | 
self.lis = self.classl if self.is_ordered else list(permutations(self.classl, 2)) 39 | if self._is_tab: 40 | self._boxes = len(self.df_list) 41 | self._cols = math.ceil(math.sqrt(self._boxes)) 42 | self._rows = math.ceil(self._boxes/self._cols) 43 | self.tbnames = list(self.df_list[0].columns)[:-1] if self.varlist is None else self.varlist 44 | else: 45 | vals = self.interp.most_confused() 46 | self._ranges = [] 47 | self.tbnames = [] 48 | self._boxes = int(input('Please enter a value for `k`, or the top images you will see: ')) 49 | for x in iter(vals): 50 | for y in range(len(self.lis)): 51 | if x[0:2] == self.lis[y]: 52 | self._ranges.append(x[2]) 53 | self.tbnames.append(str(x[0] + ' | ' + x[1])) 54 | 55 | self.tb = widgets.TabBar(self.tbnames) 56 | self._populate_tabs() 57 | 58 | def _populate_tabs(self): 59 | "Adds relevant graphs to each tab" 60 | with tqdm(total=len(self.tbnames)) as pbar: 61 | for i, tab in enumerate(self.tbnames): 62 | with self.tb.output_to(i): 63 | self._plot_tab(tab) if self._is_tab else self._plot_imgs(tab, i) 64 | pbar.update(1) 65 | 66 | def _plot_imgs(self, tab:str, i:int ,**kwargs): 67 | "Plots the most confused images" 68 | classes_gnd = self.interp.data.classes 69 | x = 0 70 | if self._ranges[i] < self._boxes: 71 | cols = math.ceil(math.sqrt(self._ranges[i])) 72 | rows = math.ceil(self._ranges[i]/cols) 73 | 74 | if self._ranges[i] < 4 or self._boxes < 4: 75 | cols = 2 76 | rows = 2 77 | else: 78 | cols = math.ceil(math.sqrt(self._boxes)) 79 | rows = math.ceil(self._boxes/cols) 80 | fig, ax = plt.subplots(rows, cols, figsize=self.figsize) 81 | 82 | [axi.set_axis_off() for axi in ax.ravel()] 83 | for j, idx in enumerate(self.tl_idx): 84 | if self._boxes < x+1 or x > self._ranges[i]: 85 | break 86 | da, cl = self.interp.data.dl(self.interp.ds_type).dataset[idx] 87 | row = (int)(x / cols) 88 | col = x % cols 89 | 90 | ix = int(cl) 91 | if str(cl) == tab.split(' ')[0] and str(classes_gnd[self.interp.pred_class[idx]]) == tab.split(' 
')[2]: 92 | img, lbl = self.interp.data.valid_ds[idx] 93 | fn = self.interp.data.valid_ds.x.items[idx] 94 | fn = re.search('([^/*]+)_\d+.*$', str(fn)).group(0) 95 | img.show(ax=ax[row, col]) 96 | ax[row,col].set_title(fn) 97 | x += 1 98 | plt.show(fig) 99 | plt.tight_layout() 100 | 101 | def _plot_tab(self, tab:str): 102 | "Generates graphs" 103 | if self._boxes is not None: 104 | fig, ax = plt.subplots(self._boxes, figsize=self.figsize) 105 | else: 106 | fig, ax = plt.subplots(self._cols, self._rows, figsize=self.figsize) 107 | fig.subplots_adjust(hspace=.5) 108 | for j, x in enumerate(self.df_list): 109 | title = f'{"".join(x.columns[-1])} {tab} distribution' 110 | if self._boxes is None: 111 | row = int(j / self._cols) 112 | col = j % row 113 | if tab in self.cat_names: 114 | vals = pd.value_counts(x[tab].values) 115 | if self._boxes is not None: 116 | if vals.nunique() < 10: 117 | fig = vals.plot(kind='bar', title=title, ax=ax[j], rot=0, width=.75) 118 | elif vals.nunique() > self.cut_off: 119 | print(f'Number of values is above {self.cut_off}') 120 | else: 121 | fig = vals.plot(kind='barh', title=title, ax=ax[j], width=.75) 122 | else: 123 | fig = vals.plot(kind='barh', title=title, ax=ax[row, col], width=.75) 124 | else: 125 | vals = x[tab] 126 | if self._boxes is not None: 127 | axs = vals.plot(kind='hist', ax=ax[j], title=title, y='Frequency') 128 | else: 129 | axs = vals.plot(kind='hist', ax=ax[row, col], title=title, y='Frequency') 130 | axs.set_ylabel('Frequency') 131 | if len(set(vals)) > 1: 132 | vals.plot(kind='kde', ax=axs, title=title, secondary_y=True) 133 | else: 134 | print('Less than two unique values, cannot graph the KDE') 135 | plt.show(fig) 136 | plt.tight_layout() 137 | 138 | def _tab_losses(self, **kwargs): 139 | "Gathers dataframes of the combinations data" 140 | classes = self.interp.data.classes 141 | cat_names = self.interp.data.x.cat_names 142 | cont_names = self.interp.data.x.cont_names 143 | comb = self.classl if self.is_ordered 
else list(permutations(self.classl,2)) 144 | 145 | self.df_list = [] 146 | arr = [] 147 | for i, idx in enumerate(self.tl_idx): 148 | da, _ = self.interp.data.dl(self.interp.ds_type).dataset[idx] 149 | res = '' 150 | for c, n in zip(da.cats, da.names[:len(da.cats)]): 151 | string = f'{da.classes[n][c]}' 152 | if string == 'True' or string == 'False': 153 | string += ';' 154 | res += string 155 | 156 | else: 157 | string = string[1:] 158 | res += string + ';' 159 | for c, n in zip(da.conts, da.names[len(da.cats):]): 160 | res += f'{c:.4f};' 161 | arr.append(res) 162 | f = pd.DataFrame([ x.split(';')[:-1] for x in arr], columns=da.names) 163 | for i, var in enumerate(self.interp.data.cont_names): 164 | f[var] = f[var].apply(lambda x: float(x) * self.stds[var] + self.means[var]) 165 | f['Original'] = 'Original' 166 | self.df_list.append(f) 167 | 168 | for j, x in enumerate(comb): 169 | arr = [] 170 | for i, idx in enumerate(self.tl_idx): 171 | da, cl = self.interp.data.dl(self.interp.ds_type).dataset[idx] 172 | cl = int(cl) 173 | 174 | if classes[self.interp.pred_class[idx]] == comb[j][0] and classes[cl] == comb[j][1]: 175 | res = '' 176 | for c, n in zip(da.cats, da.names[:len(da.cats)]): 177 | string = f'{da.classes[n][c]}' 178 | if string == 'True' or string == 'False': 179 | string += ';' 180 | res += string 181 | else: 182 | string = string[1:] 183 | res += string + ';' 184 | for c, n in zip(da.conts, da.names[len(da.cats):]): 185 | res += f'{c:.4f};' 186 | arr.append(res) 187 | f = pd.DataFrame([ x.split(';')[:-1] for x in arr], columns=da.names) 188 | for i, var in enumerate(self.interp.data.cont_names): 189 | f[var] = f[var].apply(lambda x: float(x) * self.stds[var] + self.means[var]) 190 | f[str(x)] = str(x) 191 | self.df_list.append(f) 192 | self.cat_names = cat_names 193 | self._create_tabs() 194 | -------------------------------------------------------------------------------- /Documentation.ipynb: 
-------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Class Confusion Widget\n", 8 | "This widget was designed to help interpret your model's decisions through visuals such as graphs or confusion matrices that go more in-depth than the standard `plot_confusion_matrix`. Class Confusion can be used with **both** Tabular and Image classification models. (Note: because widgets do not export well, images of the output are shown instead. The code is still there for you to run!)\n" 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "metadata": {}, 14 | "source": [ 15 | "---\n", 16 | "\n", 17 | "# Images\n", 18 | "Before we can use the widget, we need to finish training our model and generate a `ClassificationInterpretation` object." 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 3, 24 | "metadata": {}, 25 | "outputs": [], 26 | "source": [ 27 | "!git clone https://github.com/muellerzr/ClassConfusion.git\n", 28 | "from fastai.vision import *\n", 29 | "from ClassConfusion import *" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": 5, 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "path = untar_data(URLs.PETS)\n", 39 | "path_img = path/'images'\n", 40 | "fnames = get_image_files(path_img)\n", 41 | "pat = r'/([^/]+)_\\d+.jpg$'" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": 7, 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "data = ImageDataBunch.from_name_re(path_img, fnames, pat, ds_tfms=get_transforms(), \n", 51 | " size=224, bs=64).normalize(imagenet_stats)" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": 8, 57 | "metadata": {}, 58 | "outputs": [], 59 | "source": [ 60 | "learn = cnn_learner(data, models.resnet34, metrics=error_rate)" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 |
"execution_count": 9, 66 | "metadata": {}, 67 | "outputs": [ 68 | { 69 | "data": { 70 | "text/html": [ 71 | "\n", 72 | " \n", 73 | " \n", 74 | " \n", 75 | " \n", 76 | " \n", 77 | " \n", 78 | " \n", 79 | " \n", 80 | " \n", 81 | " \n", 82 | " \n", 83 | " \n", 84 | " \n", 85 | " \n", 86 | " \n", 87 | " \n", 88 | " \n", 89 | " \n", 90 | " \n", 91 | " \n", 92 | " \n", 93 | " \n", 94 | " \n", 95 | " \n", 96 | " \n", 97 | "
epochtrain_lossvalid_losserror_ratetime
02.1838460.5084920.14468100:12
11.1357530.4189190.10638300:11
" 98 | ], 99 | "text/plain": [ 100 | "" 101 | ] 102 | }, 103 | "metadata": {}, 104 | "output_type": "display_data" 105 | } 106 | ], 107 | "source": [ 108 | "learn.fit_one_cycle(2)" 109 | ] 110 | }, 111 | { 112 | "cell_type": "markdown", 113 | "metadata": {}, 114 | "source": [ 115 | "Class Confusion's constructor differs depending on our use-case. For both cases, we are interested in the `classlist`, `is_ordered`, and `figsize` variables.\n", 116 | "\n", 117 | "* `interp`: Either a Tabular or Image ClassificationInterpretation object\n", 118 | "\n", 119 | "\n", 120 | "* `classlist`: Here you pass in the list of classes you are interested in looking at. How you pass them in depends on whether you have specific combinations you want to try. If we just want to look at all combinations between a few classes, we can pass their class names as a normal array, `['Abyssinian', 'Bengal', 'Birman']`. If we want to pass in one or more specific combinations, we pass them in as a list of arrays or tuples, `[('Abyssinian', 'Bengal'), ('Bengal', 'Birman')]`. Here we have what our **actual** class was first, and the **prediction** second.\n", 121 | "\n", 122 | "\n", 123 | "* `is_ordered`: This determines whether to generate all the combinations from the set of names you passed in. If you have a specific list of combinations, set `is_ordered` to True.\n", 124 | "\n", 125 | "\n", 126 | "* `figsize`: A tuple giving the size of the plotted figure. Defaults to (8,8)\n", 127 | "\n", 128 | "For image models, calling it will also ask for a `k` value. `k` is the same as `k` from `plot_top_losses`, which is the number of images you want to look at."
129 | ] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "execution_count": 12, 134 | "metadata": {}, 135 | "outputs": [], 136 | "source": [ 137 | "interp = ClassificationInterpretation.from_learner(learn)" 138 | ] 139 | }, 140 | { 141 | "cell_type": "markdown", 142 | "metadata": {}, 143 | "source": [ 144 | "Let's look at an example set for the 'Ragdoll', 'Birman', and 'Maine_Coon' classes" 145 | ] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "execution_count": null, 150 | "metadata": {}, 151 | "outputs": [], 152 | "source": [ 153 | "classlist = ['Ragdoll', 'Birman', 'Maine_Coon']\n", 154 | "ClassConfusion(interp, classlist, is_ordered=False, figsize=(8,8))" 155 | ] 156 | }, 157 | { 158 | "cell_type": "markdown", 159 | "metadata": {}, 160 | "source": [ 161 | "![](https://i.imgur.com/jAE6BVm.png)" 162 | ] 163 | }, 164 | { 165 | "cell_type": "markdown", 166 | "metadata": {}, 167 | "source": [ 168 | "The output is now our confused images as well as their filenames, in case we want to go find those particular instances.\n", 169 | "\n", 170 | "Next, let's look at a set of classes in a particular order." 171 | ] 172 | }, 173 | { 174 | "cell_type": "code", 175 | "execution_count": null, 176 | "metadata": {}, 177 | "outputs": [], 178 | "source": [ 179 | "classlist = [('Ragdoll', 'Birman'), ('British_Shorthair', 'Russian_Blue')]\n", 180 | "ClassConfusion(interp, classlist, is_ordered=True)" 181 | ] 182 | }, 183 | { 184 | "cell_type": "markdown", 185 | "metadata": {}, 186 | "source": [ 187 | "![](https://i.imgur.com/EFLUEnQ.png)" 188 | ] 189 | }, 190 | { 191 | "cell_type": "markdown", 192 | "metadata": {}, 193 | "source": [ 194 | "Now we are looking at exact cells from our Confusion Matrix!" 195 | ] 196 | }, 197 | { 198 | "cell_type": "markdown", 199 | "metadata": {}, 200 | "source": [ 201 | "---\n", 202 | "\n", 203 | "## Tabular\n", 204 | "Tabular has a bit more bells and whistles than the Images does. We'll look at the `ADULT_SAMPLE` dataset for an example. 
\n", 205 | "\n", 206 | "\n", 207 | "Along with the standard constructor items above, there are two more, `cut_off` and `varlist`:\n", 208 | "\n", 209 | "* `cut_off`: This is the cut-off number, an integer, for plotting categorical variables. It sets the maximum number of bars on a graph (100 by default); a variable with more values than that is not plotted, and a `Number of values is above 100` message is shown instead before moving on to the next set.\n", 210 | "\n", 211 | "\n", 212 | "* `varlist`: This is a list of variables that you specifically want to look at. By default, ClassConfusion will use every variable that was used in the model, including the `_na` columns." 213 | ] 214 | }, 215 | { 216 | "cell_type": "code", 217 | "execution_count": 20, 218 | "metadata": {}, 219 | "outputs": [], 220 | "source": [ 221 | "from fastai.tabular import *" 222 | ] 223 | }, 224 | { 225 | "cell_type": "code", 226 | "execution_count": 22, 227 | "metadata": {}, 228 | "outputs": [], 229 | "source": [ 230 | "path = untar_data(URLs.ADULT_SAMPLE)\n", 231 | "df = pd.read_csv(path/'adult.csv')" 232 | ] 233 | }, 234 | { 235 | "cell_type": "code", 236 | "execution_count": 24, 237 | "metadata": {}, 238 | "outputs": [], 239 | "source": [ 240 | "dep_var = 'salary'\n", 241 | "cat_names = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race']\n", 242 | "cont_names = ['age', 'fnlwgt', 'education-num']\n", 243 | "procs = [FillMissing, Categorify, Normalize]" 244 | ] 245 | }, 246 | { 247 | "cell_type": "code", 248 | "execution_count": 25, 249 | "metadata": {}, 250 | "outputs": [], 251 | "source": [ 252 | "test = TabularList.from_df(df.iloc[800:1000].copy(), path=path, cat_names=cat_names, cont_names=cont_names)" 253 | ] 254 | }, 255 | { 256 | "cell_type": "code", 257 | "execution_count": 26, 258 | "metadata": {}, 259 | "outputs": [], 260 | "source": [ 261 | "data = (TabularList.from_df(df, path=path, cat_names=cat_names, cont_names=cont_names, procs=procs)\n", 262 | " .split_by_idx(list(range(800,1000)))\n", 263 | "
.label_from_df(cols=dep_var)\n", 264 | " .add_test(test)\n", 265 | " .databunch())" 266 | ] 267 | }, 268 | { 269 | "cell_type": "code", 270 | "execution_count": 27, 271 | "metadata": {}, 272 | "outputs": [], 273 | "source": [ 274 | "learn = tabular_learner(data, layers=[200,100], metrics=accuracy)" 275 | ] 276 | }, 277 | { 278 | "cell_type": "code", 279 | "execution_count": 28, 280 | "metadata": {}, 281 | "outputs": [ 282 | { 283 | "data": { 284 | "text/html": [ 285 | "\n", 286 | " \n", 287 | " \n", 288 | " \n", 289 | " \n", 290 | " \n", 291 | " \n", 292 | " \n", 293 | " \n", 294 | " \n", 295 | " \n", 296 | " \n", 297 | " \n", 298 | " \n", 299 | " \n", 300 | " \n", 301 | " \n", 302 | " \n", 303 | " \n", 304 | "
epochtrain_lossvalid_lossaccuracytime
00.3666780.3975670.82500000:04
" 305 | ], 306 | "text/plain": [ 307 | "" 308 | ] 309 | }, 310 | "metadata": {}, 311 | "output_type": "display_data" 312 | } 313 | ], 314 | "source": [ 315 | "learn.fit(1, 1e-2)" 316 | ] 317 | }, 318 | { 319 | "cell_type": "code", 320 | "execution_count": 29, 321 | "metadata": {}, 322 | "outputs": [], 323 | "source": [ 324 | "interp = ClassificationInterpretation.from_learner(learn)" 325 | ] 326 | }, 327 | { 328 | "cell_type": "code", 329 | "execution_count": 30, 330 | "metadata": {}, 331 | "outputs": [ 332 | { 333 | "data": { 334 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAARoAAAEmCAYAAAC9C19sAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAGEVJREFUeJzt3Xm8VWW9x/HPV0FDxQFlksmrMoimIIhDV8UhlUGxMhWVnHLqem/mNUUrc6gks1veKOe8iBFq5QROXW9aEpiAiJqklgPCYQYVlEL73T/WOrDFM3E8z16cfb7v12u/ztrPevZav80+fM/zrLX23ooIzMxS2qToAsys8jlozCw5B42ZJeegMbPkHDRmlpyDxsySc9DYBpHURtKDkt6WdM8n2M7Jkh5rytqKIulASX8puo6NmXwdTWWSdBJwIdAHeBeYBXw3Ip76hNsdBfw7cEBEfPCJC93ISQqgZ0S8WnQtzZlHNBVI0oXAj4HvAR2B7sDPgBFNsPkewMstIWQaQlKromtoFiLCtwq6AdsAK4Ev1tFnc7Igmp/ffgxsnq8bDLwF/CewCKgCTs/XXQn8A1iT7+NM4ArgzpJt7wQE0Cq/fxrwN7JR1WvAySXtT5U87gDgGeDt/OcBJeueAK4GpuTbeQzYoZbnVl3/xSX1HwsMBV4GlgGXlfQfBEwFVuR9xwKb5et+nz+XVfnzPaFk+5cAC4Dx1W35Y3bJ97F3fn9HYAkwuOjfjUJ/L4suwLcmfkHhKOCD6v/otfS5CpgGdADaA38Ers7XDc4ffxXQOv8P+h6wXb5+/WCpNWiALYF3gN75us7A7vny2qAB2gHLgVH540bm97fP1z8B/BXoBbTJ74+p5blV1395Xv9ZwGJgAtAW2B1YDeyc9x8A7JfvdyfgJeCCku0FsGsN2/8+WWC3KQ2avM9Z+Xa2AB4Friv696Lom6dOlWd7YEnUPbU5GbgqIhZFxGKykcqokvVr8vVrIuIhsr/mvRtZzz+BPSS1iYiqiHixhj7DgFciYnxEfBARvwTmAEeX9Lk9Il6OiPeBu4F+dexzDdnxqDXARGAH4PqIeDff/4vAngARMSMipuX7fR24CTi4Ac/p2xHx97yej4iIW4BXgKfJwvUb9Wyv4jloKs9SYId6jh3sCLxRcv+NvG3tNtYLqveArTa0kIhYRTbdOBeokjRZUp8G1FNdU5eS+ws2oJ6lEfFhvlwdBAtL1r9f/XhJvSRNkrRA0jtkx7V2qGPbAIsjYnU9fW4B9gB+EhF/r6dvxXPQVJ6pZFODY+voM5/soG617nlbY6wimyJU61S6MiIejYjPkv1ln0P2H7C+eqprmtfImjbEDWR19YyIrYHLANXzmDpP1Uraiuy4123AFZLaNUWhzZmDpsJExNtkxyd+KulYSVtIai1piKRr826/BL4
pqb2kHfL+dzZyl7OAgyR1l7QNcGn1CkkdJR0jaUvg72RTsA9r2MZDQC9JJ0lqJekEoC8wqZE1bYi2ZMeRVuajrfPWW78Q2HkDt3k9MCMivgxMBm78xFU2cw6aChQR/0V2Dc03yQ6EzgXOB+7Lu3wHmA7MBp4HZuZtjdnXb4G78m3N4KPhsAnZ2av5ZGdiDga+UsM2lgLD875Lyc4YDY+IJY2paQNdBJxEdjbrFrLnUuoKYJykFZKOr29jkkaQHZA/N2+6ENhb0slNVnEz5Av2zCw5j2jMLDkHjZkl56Axs+QcNGaWXIt+Q5hatQlt1rboMqwW/XfrXnQJVo+ZM2csiYj29fVr2UGzWVs2713vGUsryJSnxxZdgtWjTWutf0V3jTx1MrPkHDRmlpyDxsySc9CYWXIOGjNLzkFjZsk5aMwsOQeNmSXnoDGz5Bw0Zpacg8bMknPQmFlyDhozS85BY2bJOWjMLDkHjZkl56Axs+QcNGaWnIPGzJJz0JhZcg4aM0vOQWNmyTlozCw5B42ZJeegMbPkHDRmlpyDxsySc9CYWXIOGjNLzkFjZsk5aMwsOQeNmSXnoDGz5Bw0Zpacg8bMknPQmFlyDhozS85BY2bJOWjMLDkHjZkl56Axs+QcNGaWnIPGzJJz0JhZcg4aM0vOQWNmyTlompEbv30ybzx+DdPvuexj6y4YdRjvPzuW7bfdEoADB/Rkwe9/wLSJo5k2cTSXnn1Uuctt0c758hl037EDA/rtsbZt2bJlDDvqs+yxW0+GHfVZli9fXmCF5bVRB42k/5H0mqRZ+a1f3i5J/y3pVUmzJe2dtw+WNKnYqtMZ/+A0RvzbTz/W3rXjthy6Xx/erFr2kfYpz/6V/U4cw34njuGamx8pV5kGjDr1NO6f9NF/8+uuHcPgQw/jhZdeYfChh3HdtWMKqq78NrqgkbSZpC1Lmr4eEf3y26y8bQjQM7+dDdxQ7jqLMGXmX1n29nsfa7/2oi/wjevvIyIKqMpq8q8HHkS7du0+0jbpwfs5ZdSpAJwy6lQefOC+IkorxEYTNJJ2k/RD4C9Ar3q6jwDuiMw0YFtJndfb3j6SnpW0c6KSNwrDDv408xet4PmX531s3b57/gtP3zWa+8aex247dyqgOiu1aOFCOnfOfk07d+7M4kWLCq6ofAoNGklbSjpd0lPArcBLwJ4R8WxJt+/m06MfSdo8b+sCzC3p81beVr3dA4AbgRER8be0z6I4bT7VmkvOPJKrbpj8sXWz5syl99Bvse8JY7hh4pPc/aOzC6jQLFP0iKYKOBP4ckR8JiJujYh3S9ZfCvQB9gHaAZfk7aphW9Xzht2Am4GjI+LN9TtJOlvSdEnT44P3m+p5FGLnru3p0WV7/nTXpcyZfCVdOmzL1AmX0HH7try7ajWr3v8HAI8+9Wdat9p07YFiK0aHjh2pqqoCoKqqivYdOhRcUfkUHTTHAfOAeyVdLqlH6cqIqMqnR38HbgcG5aveArqVdO0KzM+Xq4DVQP+adhgRN0fEwIgYqFZtmvCplN+Lr86nx2GX0mfYt+kz7NvMW7SC/U/6PguXvkvH7duu7Tdw9x5sIrF0xaoCq7Vhw4/hzvHjALhz/DiGHz2i4IrKp1WRO4+Ix4DHJG0PnALcL2kJ2QjndUmdI6JKkoBjgRfyhz4AnC9pIrAv8HberzewgmyU9JikVRHxRLmfVyrjrjmNAwf0ZIdtt+LVR67m6hsfYtx9U2vs+7nD+3PWFw/kgw8/ZPXqNXzp0tvLXG3L9qVTRvKHJ59gyZIl7LJTV751+ZVcdPFoThl5PONuv41u3brzi4n3FF1m2WhjO1MhaRBQFRFzJf0f0J5sqjQLODciVubBMxY4CngPOD0ipksaDFwUEcMldQceBs6IiKdr2tcmW3SIzXsfX4ZnZY2x/JmxRZdg9WjTWjMiYmB9/Qod0dQkIv5UsnxoLX0C+Lca2p8AnsiX3wR2T1KkmW2Qoo/RmFkL4KAxs+Q
cNGaWnIPGzJJz0JhZcg4aM0vOQWNmyTlozCw5B42ZJeegMbPkHDRmlpyDxsySc9CYWXIOGjNLzkFjZsk5aMwsOQeNmSXnoDGz5Bw0Zpacg8bMknPQmFlyDhozS85BY2bJOWjMLDkHjZkl56Axs+QcNGaWnIPGzJJz0JhZcg4aM0vOQWNmyTlozCw5B42ZJeegMbPkWtW2QtKDQNS2PiKOSVKRmVWcWoMGuK5sVZhZRas1aCLiyXIWYmaVq64RDQCSegLXAH2BT1W3R8TOCesyswrSkIPBtwM3AB8AhwB3AONTFmVmlaUhQdMmIh4HFBFvRMQVwKFpyzKzSlLv1AlYLWkT4BVJ5wPzgA5pyzKzStKQEc0FwBbAfwADgFHAqSmLMrPKUu+IJiKeyRdXAqenLcfMKlFDzjr9jhou3IsIH6cxswZpyDGai0qWPwV8gewMlJlZgzRk6jRjvaYpknwxn5k1WEOmTu1K7m5CdkC4U7KKymiPXt2Y9L9+p8XG6u331hRdgjWRhkydZpAdoxHZlOk14MyURZlZZWlI0OwWEatLGyRtnqgeM6tADbmO5o81tE1t6kLMrHLV9Xk0nYAuQBtJ/cmmTgBbk13AZ2bWIHVNnY4ETgO6Aj9kXdC8A1yWtiwzqyR1fR7NOGCcpC9ExK/LWJOZVZiGHKMZIGnb6juStpP0nYQ1mVmFaUjQDImIFdV3ImI5MDRdSWZWaRoSNJuWns6W1Abw6W0za7CGXEdzJ/C4pNvz+6cD49KVZGaVpiHvdbpW0mzgcLIzT48APVIXZmaVo6FfILcA+CfZO7cPA15KVpGZVZy6LtjrBZwIjASWAneRfW7wIWWqzcwqRF1TpznAH4CjI+JVAElfK0tVZlZR6po6fYFsyvQ7SbdIOox1VwebmTVYrUETEfdGxAlAH+AJ4GtAR0k3SDqiTPWZWQWo92BwRKyKiF9ExHCy9z3NAkYnr8zMKkZDzzoBEBHLIuImfzC5mW2IDQoaM7PGcNCYWXIOGjNLzkFjZsk5aMwsOQeNmSXnoDGz5Bw0Zpacg8bMknPQmFlyDhozS85BY2bJOWjMLDkHjZkl56Axs+QcNGaWnIPGzJJz0JhZcg4aM0vOQWNmyTlozCw5B42ZJeegMbPkHDRmlpyDxsySc9CYWXIOGjNLzkFjZsk5aMwsOQeNmSXnoGmm5s+bywkjjuTQ/ftx+Gf25uc3jQXgR9//DoP22Jkhg/dlyOB9+b/fPlJwpS3TvLfm8vnhn+XAfT7NQfvuxS03/ASA5cuWcfyIIezfvy/HjxjCiuXLC660PBQRRddQI0mDgfuB1/Km30TEVfm6o4DrgU2BWyNiTN7+OjAwIpY0ZB979hsQkx6f0sSVl8fCBVUsWriAT+/Vn5Xvvsvwww7g5vF3M/m+X7PFlltyzvlfK7rET2zz1psWXUKjLVxQxcIFC9izX/b6HHHwvtw+4Vfc9Ys72G67dvz7hRfzk/+6lhUrlvOtq64putxG67TNZjMiYmB9/ZKNaCRt1wSb+UNE9Mtv1SGzKfBTYAjQFxgpqW8T7KtZ6dipM5/eqz8AW7Vty669+rCwan7BVVm1jp06s2e/da9Pz959WDB/Po8+9CDHnzQKgONPGsUjkx8ossyySTl1mi5pgqRDJakJtzsIeDUi/hYR/wAmAiNKO0hqI+kRSWc14X43WnPffIMXn59FvwH7AHDHbTdy5EH7cNF/nMPbK1rG0Hxj9uYbr/PC7OfYe+AgFi9eRMdOnYEsjJYsXlxwdeWRMmh6AROA84E/S7pM0o7VKyX9SNKsGm6jS7axv6TnJD0safe8rQswt6TPW3lbta2AB4EJEXHL+kVJOlvSdEnTly1t/i/yqpUrOfe0kVz+3R/Qtu3WnHL6Wfx++p95+Imn6dCxE1dfPrr+jVgyq1au5MujTuCqa66j7dZbF11OYVql2nBEfAhMAiZJag9cA7wp6YCI+FNE1HcQYSbQIyJWShoK3Af0BGoaHZU
eaLofuDYiflFLXTcDN0N2jGaDntRGZs2aNZx7+kiOPe4Ehgw/FoD2HTquXT9y1BmccdLniyqvxVuzZg1njjqBzx8/kmHHfA6A9u07sHBBFR07dWbhgip2aN++4CrLI+lZJ0nbSDobeIBshHMmMDtfV+eIJiLeiYiV+fJDQGtJO5CNYLqV7KYrUHpwYgowpImnaxudiODir57Lrr16c9ZXvrq2feGCqrXLj06+n959Wtzhq41CRPC188+mZ+8+nHv+BWvbjxhyNHdPGA/A3RPGc+TQo4sqsaySnXWSdCewP3APcFtEvLKBj+8ELIyIkDQI+BXQg+xM08vAYcA84BngpIh4sfqsE/AtYLOIOK+ufTTns07PTJvCccMPp0/fPdhkk+zvxde/cSUP/OZu/vzCbCTRtVsPvvfDn6w9JtDcNOezTk9PncKIow5ht93XvT6XXn41ew8cxNmnnsS8t+bSpWs3bhn3S7Zr167gahuvoWedUgbNMcBDEfFBIx9/PnAe8AHwPnBhRPwxXzcU+DFZ6Pw8Ir6bt79OFjRLgZ8DiyPi4tr20ZyDpiVozkHTUhQeNM2Bg2bj5qDZ+BV+HY2ZWTUHjZkl56Axs+QcNGaWnIPGzJJz0JhZcg4aM0vOQWNmyTlozCw5B42ZJeegMbPkHDRmlpyDxsySc9CYWXIOGjNLzkFjZsk5aMwsOQeNmSXnoDGz5Bw0Zpacg8bMknPQmFlyDhozS85BY2bJOWjMLDkHjZkl56Axs+QcNGaWnIPGzJJz0JhZcg4aM0vOQWNmyTlozCw5B42ZJeegMbPkHDRmlpyDxsySc9CYWXIOGjNLzkFjZsk5aMwsOQeNmSXnoDGz5Bw0Zpacg8bMknPQmFlyioiiayiMpMXAG0XX0YR2AJYUXYTVqdJeox4R0b6+Ti06aCqNpOkRMbDoOqx2LfU18tTJzJJz0JhZcg6aynJz0QVYvVrka+RjNGaWnEc0Zpacg8bMknPQmFlyDppmTtJONbTtU/5KzGrnoGn+fiOpS/UdSQcDPy+wHluPpDNraBtTRC1FcdA0f+cA90nqJGkocD0wtOCa7KOOk3Ry9R1JPwPqvWy/kvj0dgWQtD9wE7AaGBYRiwsuyUpIagM8QDbSHAIsi4gLiq2qvBw0zZSkB4HSF68vUAUsB4iIY4qoy9aR1K7kblvgPmAKcDlARCwroq4iOGiaqfxYTK0i4sly1WI1k/Qa2R8D1bA6ImLnMpdUGAdNMyepI9CF7Bd6fkQsLLgks49x0DRTkvoBNwLbAPPy5q7ACuC8iHi2qNpsHUl9gBGU/DEA7o+IOYUWVmYOmmZK0izgnIh4er32/YCbImKvYiqzapIuAUYCE4G38uauwInAxIhoMae4HTTNlKRXIqJnLetejYhdy12TfZSkl4HdI2LNeu2bAS/W9vpVolZFF2CN9rCkycAdwNy8rRvwJeCRwqqyUv8EduTjHxfbOV/XYnhE04xJGsK6+b/IhucPRMRDhRZmAEg6ChgLvMK6PwbdgV2B8yOixfxBcNCYJSRpE2AQH/1j8ExEfFhoYWXmoGmmJO0ZEbPz5dbAJWS/0C8A34mI94qsz2omqV1LulCvmt/r1Hz9T8nyGLLh+A+BNmSnva1gkr5Zstw3Pzg8Q9LrkvYtsLSy84immZL0bET0z5dnAftExBpJAp6LiD2LrdAkzYyIvfPlycDYiHhY0iDgxxFxQLEVlo/POjVf20j6HNmodPPqU6gREZL812Pjs2NEPAwQEX/K32jZYjhomq8ngeo3Tk6T1DEiFkrqRGV9E2JztrOkB8gOAneVtEXJsbPWBdZVdp46mSVSwxtfZ0TEyvz9acdFxE+LqKsIDppmTNIWQM+IeK6krTvwYUTMq/2RZuXls07N2xqyj/LcsqTtVrIrT20jIuni0p8tjYOmGcsPAN8LnABrRzPtI2J6oYVZTU5c72eL4qBp/m4FTs+XvwTcXmAtVr+aPgSr4vmsUzMXEXMkIak
X2UcS/GvRNZmtzyOaynAb2chmdkQsL7oYs/U5aCrD3cBeZIFjttHx1KkC5BeBbVN0HVanJ/KfvyuyiKL4OhozS85TJ7MykTQw/xjPFsdBY1YGkjoDfwSOL7qWInjqZFYGkkYDu5C9ZWRwweWUnUc0ZuUxCrgU2EzSLkUXU24OGrPEJB0CzImIJWRXbp9ZcEll56AxS+9M1l3jdBfwxfxDy1uMFvVkzcpN0rbAfkD1p+u9A0wDhhZZV7n5YLCZJecRjZkl56Axs+QcNPaJSfpQ0ixJL0i6J/+I0cZua7CkSfnyMfn1J7X13VbSVxqxjyskXdTYGm3DOWisKbwfEf0iYg/gH8C5pSuV2eDftYh4ICLG1NFlW2CDg8bKz0FjTe0PwK6SdpL0kqSfATOBbpKOkDRV0sx85LMVgKSjJM2R9BTw+eoNSTpN0th8uaOkeyU9l98OIPuGzl3y0dQP8n5fl/SMpNmSrizZ1jck/UXS/wK9y/avYYCDxpqQpFbAEOD5vKk3cEf+jZqrgG8Ch+ff3jgduFDSp4BbgKOBA4FOtWz+v4EnI2IvYG/gRWA08Nd8NPV1SUcAPcm+g7wfMEDSQZIGkH1Wb3+yINuniZ+61cOfR2NNoU3+tbyQjWhuA3YE3oiIaXn7fkBfYEr2rb1sBkwF+gCvRcQrAJLuBM6uYR+Hkn0mMhHxIfC2pO3W63NEfns2v78VWfC0Be6t/vK2/EvdrIwcNNYU3o+IfqUNeZisKm0CfhsRI9fr1w9oqou5BFwTETett48LmnAf1gieOlm5TAM+I2lXyL78Lv9A9TnAv5S80XBkLY9/HDgvf+ymkrYG3iUbrVR7FDij5NhPF0kdgN8Dn5PURlJbsmmalZGDxsoiIhYDpwG/lDSbLHj6RMRqsqnS5Pxg8Bu1bOKrwCGSngdmALtHxFKyqdgLkn4QEY8BE4Cpeb9fAW0jYibZe4xmAb8mm95ZGfktCGaWnEc0Zpacg8bMknPQmFlyDhozS85BY2bJOWjMLDkHjZkl9/8GhLfeows3HwAAAABJRU5ErkJggg==\n", 335 | "text/plain": [ 336 | "
" 337 | ] 338 | }, 339 | "metadata": { 340 | "needs_background": "light" 341 | }, 342 | "output_type": "display_data" 343 | } 344 | ], 345 | "source": [ 346 | "interp.plot_confusion_matrix()" 347 | ] 348 | }, 349 | { 350 | "cell_type": "markdown", 351 | "metadata": {}, 352 | "source": [ 353 | "With tabular problems, looking at each *individual* row will probably not help us much. Instead, **Class Confusion** plots every variable for whichever combinations we choose, so we can see how the distribution of those variables in our misses differs from the overall dataset distribution. For example, let's explore `>=50k` and `<50k`:" 354 | ] 355 | }, 356 | { 357 | "cell_type": "code", 358 | "execution_count": null, 359 | "metadata": {}, 360 | "outputs": [], 361 | "source": [ 362 | "ClassConfusion(interp, ['>=50k', '<50k'], figsize=(12,12))" 363 | ] 364 | }, 365 | { 366 | "cell_type": "markdown", 367 | "metadata": {}, 368 | "source": [ 369 | "![](https://i.imgur.com/iUUSp2A.png)" 370 | ] 371 | }, 372 | { 373 | "cell_type": "markdown", 374 | "metadata": {}, 375 | "source": [ 376 | "Now we can see the distributions for each of those two missed boxes in our confusion matrix, and look at what is really going on there. If we look at education, we can see that in many of the cases where we wrongly predicted whether people were making above or below 50k, they were often high school graduates pursuing some college degree. \n", 377 | "\n" 378 | ] 379 | }, 380 | { 381 | "cell_type": "markdown", 382 | "metadata": {}, 383 | "source": [ 384 | "We can also look at the distribution for continuous variables as well.
Shown below is `age`:" 385 | ] 386 | }, 387 | { 388 | "cell_type": "code", 389 | "execution_count": null, 390 | "metadata": {}, 391 | "outputs": [], 392 | "source": [ 393 | "ClassConfusion(interp, ['>=50k', '<50k'], figsize=(12,12))" 394 | ] 395 | }, 396 | { 397 | "cell_type": "markdown", 398 | "metadata": {}, 399 | "source": [ 400 | "![](https://i.imgur.com/jMiTb3y.png)" 401 | ] 402 | }, 403 | { 404 | "cell_type": "markdown", 405 | "metadata": {}, 406 | "source": [ 407 | "If we want to look at specific variables, we pass them into `varlist`. Below is `age`, `education`, and `relationship`:" 408 | ] 409 | }, 410 | { 411 | "cell_type": "code", 412 | "execution_count": null, 413 | "metadata": {}, 414 | "outputs": [], 415 | "source": [ 416 | "ClassConfusion(interp, ['>=50k', '<50k'], varlist=['age', 'education', 'relationship'],\n", 417 | " figsize=(12,12))" 418 | ] 419 | }, 420 | { 421 | "cell_type": "markdown", 422 | "metadata": {}, 423 | "source": [ 424 | "![](https://i.imgur.com/ZIqwljr.png)" 425 | ] 426 | }, 427 | { 428 | "cell_type": "markdown", 429 | "metadata": {}, 430 | "source": [ 431 | "We can plot the distribution for our true positives as well, if we want to compare those by using histograms:" 432 | ] 433 | }, 434 | { 435 | "cell_type": "code", 436 | "execution_count": null, 437 | "metadata": {}, 438 | "outputs": [], 439 | "source": [ 440 | "ClassConfusion(interp, [['>=50k', '>=50k'], ['>=50k', '<50k']], varlist=['age', 'education', 'relationship'],\n", 441 | " is_ordered=True, figsize=(12,12))" 442 | ] 443 | }, 444 | { 445 | "cell_type": "markdown", 446 | "metadata": {}, 447 | "source": [ 448 | "![](https://i.imgur.com/xNUUPz0.png)" 449 | ] 450 | }, 451 | { 452 | "cell_type": "code", 453 | "execution_count": null, 454 | "metadata": {}, 455 | "outputs": [], 456 | "source": [] 457 | } 458 | ], 459 | "metadata": { 460 | "kernelspec": { 461 | "display_name": "Python 3", 462 | "language": "python", 463 | "name": "python3" 464 | }, 465 | "language_info": { 466 
| "codemirror_mode": { 467 | "name": "ipython", 468 | "version": 3 469 | }, 470 | "file_extension": ".py", 471 | "mimetype": "text/x-python", 472 | "name": "python", 473 | "nbconvert_exporter": "python", 474 | "pygments_lexer": "ipython3", 475 | "version": "3.6.7" 476 | } 477 | }, 478 | "nbformat": 4, 479 | "nbformat_minor": 2 480 | } 481 | --------------------------------------------------------------------------------