├── README.md ├── caipi-draw-weights.py ├── caipi-draw.py ├── caipi.py ├── caipi ├── __init__.py ├── image.py ├── learners.py ├── problem.py ├── tabular.py ├── text.py └── utils.py ├── colors-draw-all.sh ├── data ├── fer2013.py └── toy_colors.npz ├── form ├── caipi.csv └── questionnaire.pdf ├── matplotlibrc ├── prepare-newsgroups.py ├── prepare-reviews.py ├── results ├── colors-rule0__l1svm__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=-1__C=None__F=4__S=2000__K=0.75__R=100__V=None__s=0-params.pickle ├── colors-rule0__l1svm__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=-1__C=None__F=4__S=2000__K=0.75__R=100__V=None__s=0.pickle ├── colors-rule0__l1svm__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=-1__F=4__S=2000__K=0.75__R=100__s=0-params.pickle ├── colors-rule0__l1svm__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=-1__F=4__S=2000__K=0.75__R=100__s=0.pickle ├── colors-rule0__l1svm__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=0__C=None__F=4__S=2000__K=0.75__R=100__V=None__s=0-params.pickle ├── colors-rule0__l1svm__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=0__C=None__F=4__S=2000__K=0.75__R=100__V=None__s=0.pickle ├── colors-rule0__l1svm__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=0__F=4__S=2000__K=0.75__R=100__s=0-params.pickle ├── colors-rule0__l1svm__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=0__F=4__S=2000__K=0.75__R=100__s=0.pickle ├── colors-rule0__lr__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=-1__F=4__S=2000__K=0.75__R=100__s=0-params.pickle ├── colors-rule0__lr__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=-1__F=4__S=2000__K=0.75__R=100__s=0.pickle ├── colors-rule0__lr__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=0__F=4__S=2000__K=0.75__R=100__s=0-params.pickle ├── colors-rule0__lr__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=0__F=4__S=2000__K=0.75__R=100__s=0.pickle ├── colors-rule0__svm__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=-1__C=None__F=4__S=2000__K=0.75__R=100__V=None__s=0-params.pickle ├── colors-rule0__svm__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=-1__C=None__F=4__S=2000__K=0.75__R=100__V=None__s=0.pickle ├── colors-rule0__svm__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=-1__F=4__S=2000__K=0.75__R=100__s=0-params.pickle ├── colors-rule0__svm__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=-1__F=4__S=2000__K=0.75__R=100__s=0.pickle ├── colors-rule0__svm__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=0__C=None__F=4__S=2000__K=0.75__R=100__V=None__s=0-params.pickle ├── colors-rule0__svm__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=0__C=None__F=4__S=2000__K=0.75__R=100__V=None__s=0.pickle ├── colors-rule0__svm__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=0__F=4__S=2000__K=0.75__R=100__s=0-params.pickle ├── colors-rule0__svm__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=0__F=4__S=2000__K=0.75__R=100__s=0.pickle ├── colors-rule1__l1svm__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=0__C=None__F=3__S=2000__K=0.75__R=100__V=None__s=0-params.pickle ├── colors-rule1__l1svm__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=0__C=None__F=3__S=2000__K=0.75__R=100__V=None__s=0.pickle ├── colors-rule1__l1svm__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=0__F=3__S=2000__K=0.75__R=100__s=0-params.pickle ├── 
colors-rule1__l1svm__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=0__F=3__S=2000__K=0.75__R=100__s=0.pickle ├── colors-rule1__lr__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=0__F=3__S=2000__K=0.75__R=100__s=0-params.pickle ├── colors-rule1__lr__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=0__F=3__S=2000__K=0.75__R=100__s=0.pickle ├── colors-rule1__svm__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=0__C=None__F=3__S=2000__K=0.75__R=100__V=None__s=0-params.pickle ├── colors-rule1__svm__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=0__C=None__F=3__S=2000__K=0.75__R=100__V=None__s=0.pickle ├── colors-rule1__svm__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=0__F=3__S=2000__K=0.75__R=100__s=0-params.pickle ├── colors-rule1__svm__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=0__F=3__S=2000__K=0.75__R=100__s=0.pickle ├── newsgroups__lr__least-confident__k=10__n=None__p=0.0__P=1.0__T=101__e=20__E=-1__C=add-contrast-fp__F=10__S=200__K=3.0__R=10__V=tfidf__s=0.pickle ├── newsgroups__lr__least-confident__k=10__n=None__p=0.0__P=1.0__T=101__e=20__E=0__C=add-contrast-fp__F=10__S=200__K=3.0__R=10__V=tfidf__s=0.pickle ├── newsgroups__lr__least-confident__k=10__n=None__p=0.0__P=1.0__T=301__e=20__E=-1__C=add-contrast-fp__F=10__S=200__K=3.0__R=25__V=tfidf__s=0.pickle └── newsgroups__lr__least-confident__k=10__n=None__p=0.0__P=1.0__T=301__e=20__E=0__C=add-contrast-fp__F=10__S=200__K=3.0__R=25__V=tfidf__s=0.pickle ├── run-caipi-color.sh ├── run-caipi-toy.sh ├── run-caipi-ttt.sh └── versus-rrr.py /README.md: -------------------------------------------------------------------------------- 1 | Caipi 2 | ===== 3 | 4 | An implementation of the CAIPI framework for interactive explanatory learning. 5 | 6 | 7 | Required Packages 8 | ----------------- 9 | 10 | Caipi is written in Python 3.5. 
Make sure that you have the following 11 | packages: 12 | 13 | - [numpy](https://www.numpy.org) 14 | - [sklearn](https://scikit-learn.org) 15 | - [lime](https://github.com/marcotcr/lime) 16 | - [blessings](https://pypi.python.org/pypi/blessings) 17 | - [nltk](http://www.nltk.org/) for the 20 newsgroups task 18 | - [skimage](http://scikit-image.org/) for the image classification task 19 | 20 | 21 | Usage 22 | ----- 23 | 24 | You can run CAIPI as follows: 25 | ```bash 26 | python3 caipi.py $problem $learner $example-selection-strategy 27 | ``` 28 | For the complete list of options, type: 29 | ```bash 30 | python3 caipi.py --help 31 | ``` 32 | -------------------------------------------------------------------------------- /caipi-draw-weights.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import sys 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | from caipi import load 7 | 8 | 9 | def nrm(w): 10 | return (w - np.min(w)) / (np.max(w) - np.min(w) + 1e-13) 11 | 12 | 13 | def nstd(x): 14 | return np.std(x, axis=0) / np.sqrt(x.shape[0]) 15 | 16 | 17 | plt.style.use('ggplot') 18 | 19 | data = load(sys.argv[2]) 20 | data = np.array(data) 21 | if data.ndim == 3: 22 | data = data.reshape((10, 101, 1, 300)) 23 | n_folds, n_iters, n_classes, n_features = data.shape 24 | 25 | if n_features == 300: 26 | 27 | RULE0_COORDS = {(0, 4), (0, 20), (0, 24), (4, 20), (4, 24), (20, 24)} 28 | RULE0_BASIS = np.array([1.0 if (i, j) in RULE0_COORDS else 0.0 29 | for i in range(5*5) 30 | for j in range(i+1, 5*5)]) 31 | 32 | RULE1_COORDS = {(1, 2), (1, 3), (2, 3)} 33 | RULE1_BASIS = np.array([-1.0 if (i, j) in RULE1_COORDS else 0.0 34 | for i in range(5*5) 35 | for j in range(i+1, 5*5)]) 36 | 37 | DICTIONARY = np.array([RULE0_BASIS, RULE1_BASIS]).T 38 | 39 | results = [] 40 | for k in range(n_folds): 41 | weights = data[k, :, 0, :] 42 | # (n_iters, n_features) 43 | alpha, residuals, _, _ = np.linalg.lstsq(DICTIONARY, weights.T, 44 | rcond=None) 45 | # (n_rules, n_iters), (n_iters,) 46 | result = np.vstack((alpha, residuals.T)) 47 | # (n_rules + 1, n_iters) 48 | results.append(result) 49 | 50 | results = np.array(results) 51 | # (n_folds, n_rules + 1, n_iters) 52 | 53 | def plot_both(ax, results, what, label, linestyle='-'): 54 | temp = results[:, what, :].reshape((n_folds, -1)) 55 | x = np.arange(temp.shape[-1]) 56 | y, yerr = np.mean(temp, axis=0), nstd(temp) 57 | ax.plot(x, y, linewidth=2, label=label, linestyle=linestyle) 58 | ax.fill_between(x, y - yerr, y + yerr, alpha=0.35, linewidth=0) 59 | 60 | fig, ax = plt.subplots(1, 1) 61 | ax.set_title('Rule Coefficients and Residual', fontsize=16) 62 | ax.set_xlabel('Iteration', fontsize=16) 63 | ax.tick_params(axis='both', which='major', labelsize=16) 64 | plot_both(ax, results, 0, 'Coeff. Rule 0') 65 | plot_both(ax, results, 1, 'Coeff. 
Rule 1') 66 | plot_both(ax, results, 2, 'Residual', linestyle=':') 67 | 68 | legend = ax.legend(loc='upper right', 69 | fontsize=16, 70 | shadow=False) 71 | 72 | fig.savefig(sys.argv[1] + '__coeff', bbox_inches='tight', pad_inches=0) 73 | 74 | fig = plt.figure(figsize=(30, 100)) 75 | ax = fig.add_subplot(111) 76 | ax.set_axis_off() 77 | 78 | matrix = data.reshape((n_folds, n_iters, n_classes * n_features)) 79 | matrix = matrix.mean(axis=0) 80 | 81 | ax.matshow(nrm(matrix), cmap=plt.get_cmap('gray')) 82 | 83 | fig.savefig(sys.argv[1] + '__weights', bbox_inches='tight', pad_inches=0) 84 | -------------------------------------------------------------------------------- /caipi-draw.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | from caipi import load 6 | 7 | 8 | class Tango: 9 | # From light to dark 10 | YELLOW = ("#fce94f", "#edd400", "#c4a000") 11 | ORANGE = ("#fcaf3e", "#f57900", "#ce5c00") 12 | BROWN = ("#e9b96e", "#c17d11", "#8f5902") 13 | GREEN = ("#8ae234", "#73d216", "#4e9a06") 14 | BLUE = ("#729fcf", "#3465a4", "#204a87") 15 | VIOLET = ("#ad7fa8", "#75507b", "#5c3566") 16 | RED = ("#ef2929", "#cc0000", "#a40000") 17 | WHITE = ("#eeeeec", "#d3d7cf", "#babdb6") 18 | BLACK = ("#888a85", "#555753", "#2e3436") 19 | 20 | 21 | def get_style(args): 22 | 23 | label = { 24 | 'svm': 'SVM', 25 | 'l1svm': 'L1 SVM', 26 | 'lr': 'LR', 27 | }[args.learner] 28 | 29 | if args.start_expl_at >= 0: 30 | label += ' + Corr.' 31 | 32 | base_color = { 33 | 'svm': Tango.RED, 34 | 'l1svm': Tango.VIOLET, 35 | 'lr': Tango.GREEN, 36 | }[args.learner] 37 | 38 | shade = 0 if args.start_expl_at >= 0 else 2 39 | color = base_color[shade] 40 | 41 | style, marker = { 42 | True: ('-', 's'), 43 | False: (':', '*'), 44 | }[args.start_expl_at >= 0] 45 | 46 | return label, color, style, marker 47 | 48 | 49 | def draw(args): 50 | plt.style.use('ggplot') 51 | np.set_printoptions(precision=2, linewidth=80, threshold=np.inf) 52 | 53 | pickle_data, instant_data, pickle_args = [], [], [] 54 | for path in args.pickles: 55 | data = load(path) 56 | pickle_data.append(np.array(data['perfs'])) 57 | try: 58 | instant_data.append(np.array(data['instant_perfs'])) 59 | except KeyError: 60 | pass 61 | pickle_args.append(data['args']) 62 | 63 | min_folds = min(list(len(datum) for datum in pickle_data)) 64 | print('# folds =', min_folds) 65 | min_measures = min(datum.shape[-1] for datum in pickle_data) 66 | perfs = np.array([datum[:min_folds,:,:min_measures] for datum in pickle_data]) 67 | instant_perfs = np.array([datum[:min_folds] for datum in instant_data]) 68 | max_iters = pickle_args[0].max_iters 69 | 70 | # perfs have shape: [n_pickles, n_folds, n_iters, n_measures] 71 | if perfs.shape[-1] == 3: 72 | to_title = [ 73 | 'Predictive F1', 74 | 'Confounder Rec.', 75 | '# Corrections', 76 | ] 77 | else: 78 | to_title = [ 79 | 'Predictive Pr', 'Predictive Rc', 'Predictive F1', 80 | 'Explanatory Pr', 'Explanatory Rc', 'Explanatory F1', 81 | '# Corrections', 82 | ] 83 | 84 | to_title_inst = [ 85 | 'Inst. Predictive Pr', 'Inst. Predictive Rc', 'Inst. Predictive F1', 86 | 'Inst. Explanatory Pr', 'Inst. Explanatory Rc', 'Inst. 
Explanatory F1', 87 | ] 88 | 89 | for i_measure in range(perfs.shape[-1]): 90 | 91 | #print(to_title[i_measure]) 92 | #print(perfs[:, :, :, i_measure]) 93 | 94 | fig, ax = plt.subplots(1, 1) 95 | ax.set_title(to_title[i_measure], fontsize=16) 96 | ax.set_xlabel('Iterations', fontsize=16) 97 | ax.tick_params(axis='both', which='major', labelsize=16) 98 | if to_title[i_measure].startswith('Predictive'): 99 | ax.set_ylim(args.min_pred_val, args.max_pred_val) 100 | else: 101 | n_ticks = len(ax.get_xticklabels()) 102 | eval_iters = max_iters // n_ticks 103 | labels = list(range(0, max_iters, eval_iters)) 104 | ax.set_xticklabels(['dunno'] + [str(l) for l in labels]) 105 | 106 | for i_pickle in range(perfs.shape[0]): 107 | perf = perfs[i_pickle, :, :, i_measure] 108 | 109 | y = np.mean(perf, axis=0) 110 | yerr = np.std(perf, axis=0) / np.sqrt(perf.shape[0]) 111 | if -1 in y: 112 | yerr = yerr[y != -1] 113 | y = y[y != -1] 114 | x = np.arange(y.shape[0]) 115 | 116 | label, color, style, marker = get_style(pickle_args[i_pickle]) 117 | ax.plot(x, y, label=label, color=color, 118 | linestyle=style, linewidth=2) 119 | ax.fill_between(x, y - yerr, y + yerr, color=color, 120 | alpha=0.35, linewidth=0) 121 | 122 | legend = ax.legend(loc='lower right', 123 | fontsize=16, 124 | shadow=False) 125 | 126 | fig.savefig(args.basename + '_{}.png'.format(i_measure), 127 | bbox_inches='tight', pad_inches=0) 128 | 129 | if not len(instant_perfs): 130 | print('Your pickle files do not have instant perfs, skipping.') 131 | return 132 | 133 | for i_measure in range(instant_perfs.shape[-1]): 134 | 135 | #print(to_title_inst[i_measure]) 136 | #print(instant_perfs[:, :, :, i_measure]) 137 | 138 | fig, ax = plt.subplots(1, 1) 139 | ax.set_title(to_title_inst[i_measure], fontsize=16) 140 | ax.set_xlabel('Iterations', fontsize=16) 141 | ax.tick_params(axis='both', which='major', labelsize=16) 142 | if 'Predictive' in to_title_inst[i_measure]: 143 | ax.set_ylim(args.min_inst_pred_val, args.max_inst_pred_val) 144 | 145 | 146 | 147 | for i_pickle in range(instant_perfs.shape[0]): 148 | perf = instant_perfs[i_pickle, :, :, i_measure] 149 | 150 | # running average of the per-query perfs, one curve per fold 151 | ys = np.cumsum(perf, axis=1) / np.arange(1, perf.shape[1] + 1) 152 | y = np.mean(ys, axis=0) 153 | yerr = np.std(ys, axis=0) / np.sqrt(ys.shape[0]) 154 | if -1 in y: 155 | yerr = yerr[y != -1] 156 | y = y[y != -1] 157 | x = np.arange(y.shape[0]) 158 | 159 | label, color, style, marker = get_style(pickle_args[i_pickle]) 160 | ax.plot(x, y, label=label, color=color, 161 | linestyle=style, linewidth=2) 162 | ax.fill_between(x, y - yerr, y + yerr, color=color, 163 | alpha=0.35, linewidth=0) 164 | 165 | legend = ax.legend(loc='lower right', 166 | fontsize=16, 167 | shadow=False) 168 | 169 | fig.savefig(args.basename + '_instant_{}.png'.format(i_measure), 170 | bbox_inches='tight', pad_inches=0) 171 | 172 | 173 | if __name__ == '__main__': 174 | import argparse 175 | 176 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 177 | parser.add_argument('basename', type=str, 178 | help='basename of the loss/time PNG plots') 179 | parser.add_argument('pickles', type=str, nargs='+', 180 | help='one or more pickled result files') 181 | parser.add_argument('--min-pred-val', type=float, default=0, 182 | help='minimum pred. score') 183 | parser.add_argument('--max-pred-val', type=float, default=1.05, 184 | help='maximum pred. 
score') 185 | parser.add_argument('--min-inst-pred-val', type=float, default=0, 186 | help='minimum instantaneous pred. score') 187 | parser.add_argument('--max-inst-pred-val', type=float, default=1.05, 188 | help='minimum instantaneous pred. score') 189 | parser.add_argument('--legend', action='store_true', 190 | help='whether to draw the legend') 191 | args = parser.parse_args() 192 | 193 | draw(args) 194 | -------------------------------------------------------------------------------- /caipi.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import numpy as np 4 | from sklearn.utils import check_random_state 5 | from sklearn.model_selection import StratifiedKFold, StratifiedShuffleSplit 6 | from os.path import join 7 | 8 | from caipi import * 9 | 10 | 11 | PROBLEMS = { 12 | 'toy-fst': lambda *args, **kwargs: \ 13 | ToyProblem(*args, rule='fst', **kwargs), 14 | 'toy-lst': lambda *args, **kwargs: \ 15 | ToyProblem(*args, rule='lst', **kwargs), 16 | 'colors-rule0': lambda *args, **kwargs: \ 17 | ColorsProblem(*args, rule=0, **kwargs), 18 | 'colors-rule1': lambda *args, **kwargs: \ 19 | ColorsProblem(*args, rule=1, **kwargs), 20 | 'reviews': ReviewsProblem, 21 | 'newsgroups': NewsgroupsProblem, 22 | } 23 | 24 | 25 | LEARNERS = { 26 | 'lr': lambda *args, **kwargs: \ 27 | LinearLearner(*args, model='lr', **kwargs), 28 | 'svm': lambda *args, **kwargs: \ 29 | LinearLearner(*args, model='svm', **kwargs), 30 | 'l1svm': lambda *args, **kwargs: \ 31 | LinearLearner(*args, model='l1svm', **kwargs), 32 | 'elastic': lambda *args, **kwargs: \ 33 | LinearLearner(*args, model='elastic', **kwargs), 34 | } 35 | 36 | 37 | def _get_basename(args): 38 | basename = '__'.join([args.problem, args.learner, args.strategy]) 39 | fields = [ 40 | ('k', args.n_folds), 41 | ('n', args.n_examples), 42 | ('p', args.prop_known), 43 | ('P', args.prop_eval), 44 | ('T', args.max_iters), 45 | ('e', args.eval_iters), 46 | ('E', args.start_expl_at), 47 | ('C', args.corr_type), 48 | ('F', args.n_features), 49 | ('S', args.n_samples), 50 | ('K', args.kernel_width), 51 | ('R', args.lime_repeats), 52 | ('V', args.vectorizer), 53 | ('s', args.seed), 54 | ] 55 | basename += '__' + '__'.join([name + '=' + str(value) 56 | for name, value in fields]) 57 | return join('results', basename) 58 | 59 | 60 | def _subsample(problem, examples, prop, rng=None): 61 | rng = check_random_state(rng) 62 | 63 | classes = sorted(set(problem.y)) 64 | if 0 <= prop <= 1: 65 | n_sampled = int(round(len(examples) * prop)) 66 | n_sampled_per_class = max(n_sampled // len(classes), 3) 67 | else: 68 | n_sampled_per_class = max(int(prop), 3) 69 | 70 | sample = [] 71 | for y in classes: 72 | examples_y = np.array([i for i in examples if problem.y[i] == y]) 73 | pi = rng.permutation(len(examples_y)) 74 | sample.extend(examples_y[pi[:n_sampled_per_class]]) 75 | 76 | return list(sample) 77 | 78 | 79 | def eval_passive(problem, args, rng=None): 80 | """Useful for checking the based performance of the learner and whether 81 | the explanations are stable.""" 82 | 83 | rng = check_random_state(rng) 84 | basename = _get_basename(args) 85 | 86 | folds = StratifiedShuffleSplit(n_splits=args.n_folds, random_state=0) \ 87 | .split(problem.y, problem.y) 88 | train_examples, test_examples = list(folds)[0] 89 | eval_examples = _subsample(problem, test_examples, 90 | args.prop_eval, rng=0) 91 | print('#train={} #test={} #eval={}'.format( 92 | len(train_examples), len(test_examples), len(eval_examples))) 93 | 
94 | print(' #explainable in train', len(set(train_examples) & problem.explainable)) 95 | print(' #explainable in eval', len(set(eval_examples) & problem.explainable)) 96 | 97 | learner = LEARNERS[args.learner](problem, strategy=args.strategy, rng=0) 98 | 99 | learner.fit(problem.X[train_examples], 100 | problem.y[train_examples]) 101 | train_params = learner.get_params() 102 | 103 | print('Computing full-train performance...') 104 | perf = problem.eval(learner, train_examples, 105 | test_examples, eval_examples, 106 | t='train', basename=basename) 107 | print('perf on full training set =', perf) 108 | 109 | print('Checking LIME stability...') 110 | perf = problem.eval(learner, train_examples, 111 | test_examples, eval_examples, 112 | t='train2', basename=basename) 113 | print('perf on full training set =', perf) 114 | 115 | print('Computing corrections for {} examples...'.format(len(train_examples))) 116 | X_test_tuples = {tuple(densify(problem.X[i]).ravel()) 117 | for i in test_examples} 118 | 119 | all_corrections = set() 120 | for j, i in enumerate(train_examples): 121 | print(' correcting {:3d} / {:3d}'.format(j + 1, len(train_examples))) 122 | x = densify(problem.X[i]) 123 | pred_y = learner.predict(x)[0] 124 | pred_expl = problem.explain(learner, train_examples, i, pred_y) 125 | corrections = problem.query_corrections(i, pred_y, pred_expl, 126 | X_test_tuples) 127 | all_corrections.update(corrections) 128 | 129 | print('all_corrections =', all_corrections) 130 | 131 | print('Computing corrected train performance...') 132 | train_corr_examples = list(sorted(set(train_examples) | all_corrections)) 133 | learner.fit(problem.X[train_corr_examples], 134 | problem.y[train_corr_examples]) 135 | train_corr_params = learner.get_params() 136 | perf = problem.eval(learner, train_examples, 137 | test_examples, eval_examples, 138 | t='train+corr', basename=basename) 139 | print('perf on corrected set =', perf) 140 | 141 | print('w_train :\n', train_params) 142 | print('w_{train+corr} :\n', train_corr_params) 143 | 144 | dump(basename + '_passive_models.pickle', { 145 | 'w_train': train_params, 146 | 'w_both': train_corr_params 147 | }) 148 | 149 | 150 | def caipi(problem, 151 | learner, 152 | train_examples, 153 | known_examples, 154 | test_examples, 155 | eval_examples, 156 | max_iters=100, 157 | start_expl_at=-1, 158 | eval_iters=10, 159 | basename=None, 160 | rng=None): 161 | rng = check_random_state(rng) 162 | 163 | print('CAIPI T={} #train={} #known={} #test={} #eval={}'.format( 164 | max_iters, 165 | len(train_examples), len(known_examples), 166 | len(test_examples), len(eval_examples))) 167 | print(' #explainable in train', len(set(train_examples) & problem.explainable)) 168 | print(' #explainable in eval', len(set(eval_examples) & problem.explainable)) 169 | 170 | X_test_tuples = {tuple(densify(problem.X[i]).ravel()) 171 | for i in test_examples} 172 | 173 | #learner.select_model(problem.X[train_examples], 174 | # problem.y[train_examples]) 175 | #learner.fit(problem.X[train_examples], 176 | # problem.y[train_examples]) 177 | #perf = problem.eval(learner, 178 | # train_examples, 179 | # test_examples, 180 | # eval_examples, 181 | # t='train', 182 | # basename=basename) 183 | #params = np.round(learner.get_params(), decimals=1) 184 | #print('train model = {params}, perfs = {perf}'.format(**locals())) 185 | 186 | #learner.select_model(problem.X[known_examples], 187 | # problem.y[known_examples]) 188 | learner.fit(problem.X[known_examples], 189 | problem.y[known_examples]) 190 | 191 | 
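# Main CAIPI loop: each iteration (1) selects an unlabeled example via the
# query strategy, (2) predicts its label and, once t >= start_expl_at, a
# LIME explanation for it, (3) queries the oracle for the true label,
# (4) turns explanation mistakes into counterexamples via
# query_corrections(), and (5) refits on the known labels plus all
# accumulated counterexamples.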
corrections = set() 192 | perfs, instant_perfs, params = [], [], [] 193 | for t in range(max_iters): 194 | 195 | if len(known_examples) >= len(train_examples): 196 | break 197 | 198 | unknown_examples = set(train_examples) - set(known_examples) 199 | i = learner.select_query(problem, unknown_examples & problem.explainable) 200 | assert i in train_examples and i not in known_examples 201 | x = densify(problem.X[i]) 202 | 203 | explain = 0 <= start_expl_at <= t 204 | 205 | pred_y = learner.predict(x)[0] 206 | pred_expl = problem.explain(learner, known_examples, i, pred_y) \ 207 | if explain else None 208 | 209 | print('evaluating on query...') 210 | instant_perf = problem.eval(learner, 211 | known_examples, 212 | [i], 213 | [i], 214 | t=t, 215 | basename=basename + '_instant') 216 | instant_perfs.append(instant_perf) 217 | 218 | true_y = problem.query_label(i) 219 | known_examples.append(i) 220 | 221 | if explain: 222 | new_corrections = problem.query_corrections(i, pred_y, pred_expl, 223 | X_test_tuples) 224 | corrections.update(new_corrections) 225 | 226 | learner.fit(problem.X[known_examples + list(corrections)], 227 | problem.y[known_examples + list(corrections)]) 228 | params.append(learner.get_params()) 229 | 230 | do_eval = eval_iters > 0 and t % eval_iters == 0 231 | 232 | print('evaluating on test|eval...') 233 | perf = problem.eval(learner, 234 | train_examples, 235 | test_examples, 236 | eval_examples if do_eval else None, 237 | t=t, basename=basename) 238 | perf = tuple(list(perf) + list([len(corrections)])) 239 | perfs.append(perf) 240 | 241 | params_for_print = np.round(learner.get_params(), decimals=1) 242 | print('{t:3d} : model = {params_for_print}, perfs on query = {instant_perf}, perfs on test = {perf}'.format(**locals())) 243 | 244 | return perfs, instant_perfs, params 245 | 246 | 247 | def eval_interactive(problem, args, rng=None): 248 | """The main evaluation loop.""" 249 | 250 | rng = check_random_state(args.seed) 251 | basename = _get_basename(args) 252 | 253 | folds = StratifiedKFold(n_splits=args.n_folds, shuffle=True, random_state=0) \ 254 | .split(problem.y, problem.y) 255 | 256 | perfs, instant_perfs, params = [], [], [] 257 | for k, (train_examples, test_examples) in enumerate(folds): 258 | print() 259 | print(80 * '=') 260 | print('Running fold {}/{}'.format(k + 1, args.n_folds)) 261 | print(80 * '=') 262 | 263 | train_examples = list(train_examples) 264 | known_examples = _subsample(problem, train_examples, 265 | args.prop_known, rng=0) 266 | test_examples = list(test_examples) 267 | eval_examples = _subsample(problem, test_examples, 268 | args.prop_eval, rng=0) 269 | 270 | learner = LEARNERS[args.learner](problem, strategy=args.strategy, rng=0) 271 | 272 | perf, instant_perf, param = \ 273 | caipi(problem, 274 | learner, 275 | train_examples, 276 | known_examples, 277 | test_examples, 278 | eval_examples, 279 | max_iters=args.max_iters, 280 | start_expl_at=args.start_expl_at, 281 | eval_iters=args.eval_iters, 282 | basename=basename + '_fold={}'.format(k), 283 | rng=rng) 284 | perfs.append(perf) 285 | instant_perfs.append(instant_perf) 286 | params.append(param) 287 | 288 | dump(basename + '.pickle', 289 | {'args': args, 'perfs': perfs, 'instant_perfs': instant_perfs}) 290 | dump(basename + '-params.pickle', params) 291 | 292 | 293 | def main(): 294 | import argparse 295 | 296 | fmt_class = argparse.ArgumentDefaultsHelpFormatter 297 | parser = argparse.ArgumentParser(formatter_class=fmt_class) 298 | parser.add_argument('problem', 
choices=sorted(PROBLEMS.keys()), 299 | help='name of the problem') 300 | parser.add_argument('learner', choices=sorted(LEARNERS.keys()), 301 | default='svm', help='Active learner to use') 302 | parser.add_argument('strategy', type=str, default='random', 303 | help='Query selection strategy to use') 304 | parser.add_argument('-s', '--seed', type=int, default=0, 305 | help='RNG seed') 306 | 307 | group = parser.add_argument_group('Evaluation') 308 | group.add_argument('-k', '--n-folds', type=int, default=10, 309 | help='Number of cross-validation folds') 310 | group.add_argument('-n', '--n-examples', type=int, default=None, 311 | help='Restrict dataset to this many examples') 312 | group.add_argument('-p', '--prop-known', type=float, default=0.1, 313 | help='Proportion of initial labelled examples') 314 | group.add_argument('-P', '--prop-eval', type=float, default=0.1, 315 | help='Proportion of the test set to evaluate the ' 316 | 'explanations on') 317 | group.add_argument('-T', '--max-iters', type=int, default=100, 318 | help='Maximum number of learning iterations') 319 | group.add_argument('-e', '--eval-iters', type=int, default=10, 320 | help='Interval for evaluating performance on the ' 321 | 'evaluation set') 322 | group.add_argument('--passive', action='store_true', 323 | help='DEBUG: eval perfs using passive learning') 324 | 325 | group = parser.add_argument_group('Interaction') 326 | group.add_argument('-E', '--start-expl-at', type=int, default=-1, 327 | help='Iteration at which corrections kick in') 328 | group.add_argument('-C', '--corr-type', type=str, default=None, 329 | help='Type of correction feedback to use') 330 | group.add_argument('-F', '--n-features', type=int, default=10, 331 | help='Number of LIME features to present the user') 332 | group.add_argument('-S', '--n-samples', type=int, default=5000, 333 | help='Size of the LIME sampled dataset') 334 | group.add_argument('-K', '--kernel-width', type=float, default=0.75, 335 | help='LIME kernel width') 336 | group.add_argument('-R', '--lime-repeats', type=int, default=1, 337 | help='Number of times to re-run LIME') 338 | 339 | group = parser.add_argument_group('Text') 340 | group.add_argument('--vectorizer', type=str, default=None, 341 | help='Text vectorizer to use') 342 | args = parser.parse_args() 343 | 344 | np.seterr(all='raise') 345 | np.set_printoptions(precision=3, linewidth=80, threshold=np.inf) 346 | np.random.seed(args.seed) 347 | 348 | rng = np.random.RandomState(args.seed) 349 | 350 | print('Creating problem...') 351 | problem = PROBLEMS[args.problem](n_examples=args.n_examples, 352 | corr_type=args.corr_type, 353 | n_samples=args.n_samples, 354 | n_features=args.n_features, 355 | kernel_width=args.kernel_width, 356 | lime_repeats=args.lime_repeats, 357 | vect_type=args.vectorizer, 358 | rng=rng) 359 | 360 | if args.passive: 361 | print('Evaluating passive learning...') 362 | eval_passive(problem, args, rng=rng) 363 | else: 364 | print('Evaluating interactive learning...') 365 | eval_interactive(problem, args, rng=rng) 366 | 367 | if __name__ == '__main__': 368 | main() 369 | -------------------------------------------------------------------------------- /caipi/__init__.py: -------------------------------------------------------------------------------- 1 | from . import utils 2 | from .utils import * 3 | 4 | from . import learners 5 | from .learners import * 6 | 7 | from . import problem 8 | from .problem import * 9 | 10 | from . import tabular 11 | from .tabular import * 12 | 13 | from . 
import text 14 | from .text import * 15 | 16 | from . import image 17 | from .image import * 18 | -------------------------------------------------------------------------------- /caipi/image.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import gzip 4 | from os.path import join 5 | from matplotlib.cm import get_cmap 6 | from itertools import product 7 | from skimage.color import gray2rgb, rgb2gray 8 | from sklearn.linear_model import Ridge 9 | from sklearn.utils import check_random_state 10 | from lime.lime_image import LimeImageExplainer 11 | from lime.wrappers.scikit_image import SegmentationAlgorithm 12 | 13 | from . import Problem, PipeStep, densify, vstack, hstack 14 | 15 | 16 | class ImageProblem(Problem): 17 | def __init__(self, **kwargs): 18 | labels = kwargs.pop('labels') 19 | images = kwargs.pop('images') 20 | self.class_names = kwargs.pop('class_names') 21 | n_examples = kwargs.pop('n_examples', None) 22 | self.lime_repeats = kwargs.pop('lime_repeats', 1) 23 | 24 | if n_examples is not None: 25 | rng = check_random_state(kwargs.get('rng', None)) 26 | perm = rng.permutation(len(labels))[:n_examples] 27 | images, labels = images[perm], labels[perm] 28 | 29 | self.y = labels 30 | self.images = self._add_confounders(images) 31 | self.X = np.stack([gray2rgb(image) for image in self.images], 0) 32 | 33 | self.explainable = set(range(len(self.y))) 34 | 35 | super().__init__(**kwargs) 36 | 37 | def _add_confounders(self, images): 38 | noisy_images = [] 39 | for image, label in zip(images, self.y): 40 | confounder = self._y_to_confounder(image, label) 41 | noisy_images.append(np.maximum(image, confounder)) 42 | return np.array(noisy_images, dtype=np.uint8) 43 | 44 | def _y_to_confounder(self, image, label): 45 | dd = image.shape[-1] // len(self.class_names) 46 | ys, xs = range(label * dd, (label + 1) * dd), range(dd) 47 | mask = np.zeros_like(image) 48 | mask[np.ix_(ys, xs)] = 255 49 | return mask 50 | 51 | def preproc(self, images): 52 | return np.array([rgb2gray(image).ravel() for image in images]) 53 | 54 | def explain(self, learner, known_examples, i, pred_y, return_segments=False): 55 | explainer = LimeImageExplainer(verbose=False) 56 | 57 | local_model = Ridge(alpha=1, fit_intercept=True, random_state=0) 58 | # NOTE we *oversegment* the image on purpose! 
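# Small quickshift superpixels stay close to pixel scale, so the LIME
# mask can localize the tiny label-dependent patch that _add_confounders()
# pastes into each image (the same patch _eval_expl() scores recall
# against).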
59 | segmenter = SegmentationAlgorithm('quickshift', 60 | kernel_size=1, 61 | max_dist=4, 62 | ratio=0.1, 63 | sigma=0, 64 | random_seed=0) 65 | expl = explainer.explain_instance(self.X[i], 66 | top_labels=len(self.class_names), 67 | classifier_fn=learner.predict_proba, 68 | segmentation_fn=segmenter, 69 | model_regressor=local_model, 70 | num_samples=self.n_samples, 71 | num_features=self.n_features, 72 | batch_size=1, 73 | hide_color=False) 74 | #print(expl.top_labels) 75 | _, mask = expl.get_image_and_mask(pred_y, 76 | positive_only=False, 77 | num_features=self.n_features, 78 | min_weight=0.01, 79 | hide_rest=False) 80 | if return_segments: 81 | return mask, expl.segments 82 | return mask 83 | 84 | def query_label(self, i): 85 | return self.y[i] 86 | 87 | @staticmethod 88 | def _extract_coords(image, mask): 89 | return {(r, c) 90 | for r in range(image.shape[0]) 91 | for c in range(image.shape[1]) 92 | if mask[r, c] != 0} 93 | 94 | def query_corrections(self, i, pred_y, pred_mask, X_test): 95 | true_y = self.y[i] 96 | if pred_mask is None: 97 | return set() 98 | if pred_y != true_y: 99 | return set() 100 | if i not in self.explainable: 101 | return set() 102 | 103 | image = self.images[i] 104 | conf_mask = self._y_to_confounder(image, self.y[i]) 105 | conf_mask[conf_mask == 255] = 2 106 | 107 | conf_coords = self._extract_coords(image, conf_mask) 108 | pred_coords = self._extract_coords(image, pred_mask) 109 | fp_coords = conf_coords & pred_coords 110 | 111 | X_corrections = [] 112 | for value in [-10, 0, 11]: 113 | corr_image = np.array(image, copy=True) 114 | for r, c in fp_coords: 115 | print('correcting pixel {},{} for label {}'.format( 116 | r, c, true_y)) 117 | corr_image[r, c] = value 118 | X_corrections.append(gray2rgb(corr_image)) 119 | n_corrections = len(X_corrections) 120 | 121 | if not n_corrections: 122 | return set() 123 | 124 | X_corrections = np.array(X_corrections) 125 | y_corrections = np.array([pred_y] * n_corrections, dtype=np.int8) 126 | extra_examples = set(range(self.X.shape[0], 127 | self.X.shape[0] + n_corrections)) 128 | 129 | self.X = vstack([self.X, X_corrections]) 130 | self.y = hstack([self.y, y_corrections]) 131 | 132 | return extra_examples 133 | 134 | def _eval_expl(self, learner, known_examples, eval_examples, 135 | t=None, basename=None): 136 | if eval_examples is None: 137 | return -1, 138 | 139 | perfs = [] 140 | for i in set(eval_examples) & self.explainable: 141 | true_y = self.y[i] 142 | pred_y = learner.predict(densify(self.X[i]))[0] 143 | 144 | image = self.images[i] 145 | conf_mask = self._y_to_confounder(image, true_y) 146 | conf_mask[conf_mask == 255] = 2 147 | 148 | pred_mask, segments = \ 149 | self.explain(learner, known_examples, i, pred_y, 150 | return_segments=True) 151 | 152 | # Compute confounder recall 153 | conf_coords = self._extract_coords(image, conf_mask) 154 | pred_coords = self._extract_coords(image, pred_mask) 155 | perfs.append(len(conf_coords & pred_coords) / len(conf_coords)) 156 | 157 | if basename is None: 158 | continue 159 | 160 | self.save_expl(basename + '_{}_true.png'.format(i), 161 | i, true_y, mask=conf_mask) 162 | self.save_expl(basename + '_{}_{}_expl.png'.format(i, t), 163 | i, pred_y, mask=pred_mask) 164 | 165 | return np.mean(perfs, axis=0), 166 | 167 | def eval(self, learner, known_examples, test_examples, eval_examples, 168 | t=None, basename=None): 169 | pred_perfs = learner.score(self.X[test_examples], 170 | self.y[test_examples]), 171 | expl_perfs = self._eval_expl(learner, 172 | known_examples, 173 | 
eval_examples, 174 | t=t, basename=basename) 175 | return tuple(pred_perfs) + tuple(expl_perfs) 176 | 177 | def save_expl(self, path, i, y, mask=None, segments=None): 178 | fig, ax = plt.subplots(1, 1) 179 | ax.set_aspect('equal') 180 | ax.text(0.5, 1.05, 181 | 'true = {} | this = {}'.format(self.y[i], y), 182 | horizontalalignment='center', 183 | transform=ax.transAxes) 184 | 185 | cmap = get_cmap('tab20') 186 | 187 | r, c = self.images[i].shape 188 | if mask is not None: 189 | image = np.zeros((r, c, 3)) 190 | for r, c in product(range(r), range(c)): 191 | image[r, c] = cmap((mask[r, c] & 3) / 3)[:3] 192 | elif segments is not None: 193 | image = np.zeros((r, c, 3)) 194 | for r, c in product(range(r), range(c)): 195 | image[r, c] = cmap((segments[r, c] & 15) / 15)[:3] 196 | else: 197 | image = self.X[i] 198 | ax.imshow(image) 199 | 200 | fig.savefig(path, bbox_inches=0, pad_inches=0) 201 | plt.close(fig) 202 | 203 | 204 | def _load_mnist(path, kind='train'): 205 | """Load MNIST data from `path`""" 206 | labels_path = join(path, '{}-labels-idx1-ubyte.gz'.format(kind)) 207 | with gzip.open(labels_path, 'rb') as fp: 208 | labels = np.frombuffer(fp.read(), dtype=np.uint8, offset=8) 209 | 210 | images_path = join(path, '{}-images-idx3-ubyte.gz'.format(kind)) 211 | with gzip.open(images_path, 'rb') as fp: 212 | images = np.frombuffer(fp.read(), dtype=np.uint8, offset=16) 213 | 214 | return images.reshape(len(labels), 28, 28), labels 215 | 216 | 217 | class MNISTProblem(ImageProblem): 218 | def __init__(self, n_examples=None, **kwargs): 219 | path = join('data', 'mnist') 220 | tr_images, tr_labels = _load_mnist(path, kind='train') 221 | ts_images, ts_labels = _load_mnist(path, kind='t10k') 222 | images = np.vstack((tr_images, ts_images)) 223 | labels = np.hstack((tr_labels, ts_labels)) 224 | 225 | CLASS_NAMES = list(map(str, range(10))) 226 | 227 | super().__init__(images=images, 228 | labels=labels, 229 | class_names=CLASS_NAMES, 230 | n_examples=n_examples, 231 | **kwargs) 232 | 233 | 234 | class FashionProblem(ImageProblem): 235 | def __init__(self, n_examples=None, **kwargs): 236 | path = join('data', 'fashion') 237 | tr_images, tr_labels = _load_mnist(path, kind='train') 238 | ts_images, ts_labels = _load_mnist(path, kind='t10k') 239 | images = np.vstack((tr_images, ts_images)) 240 | labels = np.hstack((tr_labels, ts_labels)) 241 | 242 | CLASS_NAMES = [ 243 | 'T-shirt', 'trouser', 'pullover', 'dress', 'coat', 'sandal', 244 | 'shirt', 'sneaker', 'bag', 'ankle_boots' 245 | ] 246 | 247 | super().__init__(images=images, 248 | labels=labels, 249 | class_names=CLASS_NAMES, 250 | n_examples=n_examples, 251 | **kwargs) 252 | -------------------------------------------------------------------------------- /caipi/learners.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.svm import LinearSVC 3 | from sklearn.linear_model import LogisticRegression, SGDClassifier 4 | from sklearn.calibration import CalibratedClassifierCV 5 | from sklearn.model_selection import StratifiedKFold 6 | from sklearn.utils import check_random_state 7 | 8 | from .utils import densify, vstack, hstack 9 | 10 | 11 | class ActiveLearner: 12 | def __init__(self, problem, rng=None): 13 | self.problem = problem 14 | self.rng = check_random_state(rng) 15 | 16 | def _check_preprocessed(self, X): 17 | if X.ndim != 2: 18 | return self.problem.preproc(X) 19 | return X 20 | 21 | def fit(self, X, y): 22 | X = self._check_preprocessed(X) 23 | 
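# Two models are kept in sync: _decision_model drives predict(),
# decision_function() and score(), while _prob_model supplies the
# predict_proba() that LIME needs; for the SVM variants LinearLearner
# wraps the decision model in a CalibratedClassifierCV to obtain
# calibrated probabilities.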
self._decision_model.fit(X, y) 24 | if self._decision_model is not self._prob_model: 25 | self._prob_model.fit(X, y) 26 | 27 | def get_params(self): 28 | try: 29 | return np.array(self._decision_model.coef_, copy=True) 30 | except AttributeError: 31 | return None 32 | 33 | def decision_function(self, X): 34 | X = self._check_preprocessed(X) 35 | return self._decision_model.decision_function(X) 36 | 37 | def score(self, X, y): 38 | X = self._check_preprocessed(X) 39 | return self._decision_model.score(X, y) 40 | 41 | def predict(self, X): 42 | X = self._check_preprocessed(X) 43 | return self._decision_model.predict(X) 44 | 45 | def predict_proba(self, X): 46 | X = self._check_preprocessed(X) 47 | return self._prob_model.predict_proba(X) 48 | 49 | 50 | class LinearLearner(ActiveLearner): 51 | def __init__(self, *args, strategy='random', model='svm', C=None, 52 | sparse=False, **kwargs): 53 | 54 | super().__init__(*args, **kwargs) 55 | self.model = model 56 | 57 | pm = None 58 | if model == 'lr': 59 | # logistic regression 60 | dm = pm = LogisticRegression(C=C or 1000, 61 | penalty='l2', 62 | multi_class='ovr', 63 | fit_intercept=False, 64 | random_state=0) 65 | 66 | elif model == 'svm': 67 | # linear SVM (dense) 68 | dm = LinearSVC(C=C or 1000, 69 | penalty='l2', 70 | loss='hinge', 71 | multi_class='ovr', 72 | random_state=0) 73 | 74 | elif model == 'l1svm': 75 | # linear SVM (sparse) 76 | dm = LinearSVC(C=C or 1, 77 | penalty='l1', 78 | loss='squared_hinge', 79 | dual=False, 80 | multi_class='ovr', 81 | random_state=0) 82 | 83 | elif model == 'elastic': 84 | # elastic net (kinda sparse) 85 | dm = SGDClassifier(penalty='elasticnet', 86 | loss='hinge', 87 | l1_ratio=0.15, 88 | random_state=0) 89 | 90 | if pm is None: 91 | cv = StratifiedKFold(shuffle=True, random_state=0) 92 | pm = CalibratedClassifierCV(dm, method='sigmoid', cv=cv) 93 | 94 | self._decision_model = dm 95 | self._prob_model = pm 96 | 97 | self.select_query = { 98 | 'random': self._select_at_random, 99 | 'least-confident': self._select_least_confident, 100 | 'least-margin': self._select_least_margin, 101 | }[strategy] 102 | 103 | def _select_at_random(self, problem, examples): 104 | return self.rng.choice(sorted(examples)) 105 | 106 | def _select_least_confident(self, problem, examples): 107 | examples = sorted(examples) 108 | margins = np.abs(self.decision_function(problem.X[examples])) 109 | # NOTE margins has shape (n_examples,) or (n_examples, n_classes) 110 | if margins.ndim == 2: 111 | margins = margins.min(axis=1) 112 | return examples[np.argmin(margins)] 113 | 114 | def _select_least_margin(self, problem, examples): 115 | examples = sorted(examples) 116 | probs = self.predict_proba(problem.X[examples]) 117 | # NOTE probs has shape (n_examples, n_classes) 118 | diffs = np.zeros(probs.shape[0]) 119 | for i, prob in enumerate(probs): 120 | sorted_indices = np.argsort(prob) 121 | diffs[i] = prob[sorted_indices[-1]] - prob[sorted_indices[-2]] 122 | return examples[np.argmin(diffs)] 123 | -------------------------------------------------------------------------------- /caipi/problem.py: -------------------------------------------------------------------------------- 1 | from sklearn.utils import check_random_state 2 | 3 | 4 | class Problem: 5 | def __init__(self, 6 | n_samples=5000, 7 | n_features=10, 8 | kernel_width=1.0, 9 | metric='euclidean', 10 | rng=None, 11 | **kwargs): 12 | self.rng = check_random_state(rng) 13 | 14 | self.n_samples = n_samples 15 | self.n_features = n_features 16 | self.kernel_width = kernel_width 17 
| self.metric = metric 18 | 19 | def explain(self, learner, known_examples, i, y_pred): 20 | """Computes the learner's explanation of a prediction.""" 21 | raise NotImplementedError() 22 | 23 | def query_label(self, i): 24 | """Queries the oracle for a label.""" 25 | raise NotImplementedError() 26 | 27 | def query_corrections(self, X_corr, y_corr, i, pred_y, pred_expl): 28 | """Queries the oracle for an improved explanation.""" 29 | raise NotImplementedError() 30 | 31 | def save_expl(self, path, i, pred_y, expl): 32 | """Saves an explanation to file.""" 33 | raise NotImplementedError() 34 | 35 | def eval(self, learner, known_examples, test_examples, eval_examples, 36 | t=None, basename=None): 37 | """Evaluates the learner.""" 38 | raise NotImplementedError() 39 | -------------------------------------------------------------------------------- /caipi/tabular.py: -------------------------------------------------------------------------------- 1 | import re 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | from os.path import join 5 | from itertools import product 6 | from collections import defaultdict 7 | from sklearn.pipeline import make_pipeline 8 | from sklearn.linear_model import Ridge 9 | from sklearn.metrics import precision_recall_fscore_support as prfs 10 | from lime.lime_tabular import LimeTabularExplainer 11 | from matplotlib.patches import Circle, RegularPolygon 12 | from time import time 13 | 14 | from . import Problem, PipeStep, densify, vstack, hstack, setprfs 15 | 16 | 17 | _FEAT_NAME_REGEX = re.compile('[0-4],[0-4]') 18 | 19 | 20 | class TabularProblem(Problem): 21 | def __init__(self, **kwargs): 22 | self.y = kwargs.pop('y') 23 | self.Z = kwargs.pop('Z') 24 | self.class_names = kwargs.pop('class_names') 25 | self.z_names = kwargs.pop('z_names') 26 | self.categorical_features = kwargs.pop('categorical_features', 27 | list(range(self.Z.shape[1]))) 28 | self.discretize_features = kwargs.pop('discretize_features', False) 29 | self.lime_repeats = kwargs.pop('lime_repeats', 1) 30 | 31 | self.X = np.array([self.z_to_x(z) for z in self.Z], dtype=np.float64) 32 | self.explainable = set(range(len(self.y))) 33 | 34 | super().__init__(**kwargs) 35 | 36 | def z_to_x(self, z): 37 | """Converts an interpretable instance to an instance.""" 38 | raise NotImplementedError() 39 | 40 | def z_to_y(self, z): 41 | """Computes the true label of an interpretable instance.""" 42 | raise NotImplementedError() 43 | 44 | def z_to_expl(self, z): 45 | """Computes the true explanation of an interpretable instance.""" 46 | raise NotImplementedError() 47 | 48 | def query_label(self, i): 49 | return self.y[i] 50 | 51 | @staticmethod 52 | def _to_feat_name(disc_feat): 53 | return _FEAT_NAME_REGEX.findall(disc_feat)[0] 54 | 55 | @staticmethod 56 | def _feat_to_bounds(feat): 57 | EPS = 1e-13 58 | if ' > ' in feat: # (lb, +infty) 59 | name, lb = feat.split(' > ') 60 | lb, ub = float(lb), np.inf 61 | elif ' < ' in feat: # (lb, ub] 62 | name, ub = feat.split(' <= ') 63 | lb, name = name.split(' < ') 64 | lb, ub = float(lb), float(ub) 65 | elif ' <= ' in feat: # (-infty, ub] 66 | name, ub = feat.split(' <= ') 67 | lb, ub = -np.inf, float(ub) 68 | elif '=' in feat: # [lb, ub] ~ (lb - EPS, ub] 69 | name, value = feat.split('=') 70 | lb, ub = float(value) - EPS, float(value) 71 | else: # no discretization 72 | name, lb, ub = feat, -np.inf, np.inf 73 | return name, lb, ub 74 | 75 | def explain(self, learner, known_examples, i, y_pred): 76 | lime = LimeTabularExplainer(self.Z[known_examples], 77 | 
class_names=self.class_names, 78 | feature_names=self.z_names, 79 | kernel_width=self.kernel_width, 80 | categorical_features=self.categorical_features, 81 | discretize_continuous=self.discretize_features, 82 | feature_selection='forward_selection', 83 | verbose=False) 84 | 85 | step = PipeStep(lambda Z: np.array([self.z_to_x(z) for z in Z], 86 | dtype=np.float64)) 87 | pipeline = make_pipeline(step, learner) 88 | 89 | 90 | try: 91 | counts = defaultdict(int) 92 | for r in range(self.lime_repeats): 93 | 94 | t = time() 95 | local_model = Ridge(alpha=1, fit_intercept=True, random_state=0) 96 | expl = lime.explain_instance(self.Z[i], 97 | pipeline.predict_proba, 98 | model_regressor=local_model, 99 | num_samples=self.n_samples, 100 | num_features=self.n_features, 101 | distance_metric=self.metric) 102 | print(' LIME {}/{} took {}s'.format(r + 1, self.lime_repeats, 103 | time() - t)) 104 | 105 | for feat, coeff in expl.as_list(): 106 | coeff = int(np.sign(coeff)) 107 | counts[(feat, coeff)] += 1 108 | 109 | sorted_counts = sorted(counts.items(), key=lambda _: _[-1]) 110 | sorted_counts = list(sorted_counts)[-self.n_features:] 111 | return [fs for fs, _ in sorted_counts] 112 | 113 | except FloatingPointError: 114 | # XXX sometimes the calibrator classifier CV throws this 115 | print('Warning: LIME failed, returning no explanation') 116 | return None 117 | 118 | def _eval_expl(self, learner, known_examples, eval_examples, 119 | t=None, basename=None): 120 | 121 | if eval_examples is None: 122 | return -1, -1, -1 123 | 124 | perfs = [] 125 | for i in eval_examples: 126 | true_y = self.y[i] 127 | true_expl = self.z_to_expl(self.Z[i]) 128 | 129 | pred_y = learner.predict(densify(self.X[i]))[0] 130 | pred_expl = self.explain(learner, known_examples, i, pred_y) 131 | if pred_expl is None: 132 | print('Warning: skipping eval example') 133 | return -1, -1, -1 134 | 135 | true_feats = {feat.split('=')[0] for feat, _ in true_expl} 136 | pred_feats = {self._feat_to_bounds(feat)[0] for feat, _ 137 | in pred_expl} 138 | perfs.append(setprfs(true_feats, pred_feats)) 139 | 140 | if basename is None: 141 | continue 142 | 143 | self.save_expl(basename + '_{}_true.png'.format(i), 144 | i, true_y, true_expl) 145 | self.save_expl(basename + '_{}_{}.png'.format(i, t), 146 | i, pred_y, pred_expl) 147 | 148 | return np.mean(perfs, axis=0) 149 | 150 | def eval(self, learner, known_examples, test_examples, eval_examples, 151 | t=None, basename=None): 152 | pred_perfs = prfs(self.y[test_examples], 153 | learner.predict(self.X[test_examples]), 154 | average='binary')[:3] 155 | expl_perfs = self._eval_expl(learner, 156 | known_examples, 157 | eval_examples, 158 | t=t, basename=basename) 159 | return tuple(pred_perfs) + tuple(expl_perfs) 160 | 161 | 162 | _COORDS_FST = [[0, 0], [0, 2]] 163 | _COORDS_LST = [[2, 0], [2, 2]] 164 | 165 | 166 | class ToyProblem(TabularProblem): 167 | """A toy problem about classifying 3x3 black and white images. 
168 | 169 | NOTE: some folds can be very unluky, for instance when the train corner 170 | patterns look like this: 171 | 172 | {(0.0, 0.0, 0.0, 1.0) 173 | (1.0, 1.0, 1.0, 1.0), 174 | (0.0, 1.0, 0.0, 0.0), 175 | (0.0, 0.0, 1.0, 0.0), 176 | (0.0, 0.0, 0.0, 0.0), 177 | (0.0, 1.0, 1.0, 0.0), 178 | (0.0, 1.0, 0.0, 1.0)} 179 | 180 | while the test corner patterns look like this: 181 | 182 | {(1.0, 0.0, 1.0, 0.0), 183 | (1.0, 1.0, 1.0, 1.0), 184 | (1.0, 0.0, 0.0, 0.0), 185 | (1.0, 0.0, 0.0, 1.0)} 186 | 187 | In this case, the train can be classified correctly by using the first 188 | pixel alone, while the test is predicted as all positive (recall 1, but 189 | very low precision). This can actually happen. 190 | """ 191 | 192 | def __init__(self, rule='fst', **kwargs): 193 | if not rule in ('fst', 'lst'): 194 | raise ValueError('invalid rule "{}"'.format(rule)) 195 | 196 | Z = np.array(list(product([0, 1], repeat=9))) 197 | y = np.array(list(map(self.z_to_y, Z))) 198 | 199 | notxor = lambda a, b: (a and b) or (not a and not b) 200 | valid_examples = [i for i in range(len(y)) 201 | if notxor(self._rule_fst(Z[i]), self._rule_lst(Z[i]))] 202 | 203 | z_names = ['{},{}'.format(r, c) 204 | for r, c in product(range(3), repeat=2)] 205 | 206 | self.rule = rule 207 | super().__init__(y=y[valid_examples], 208 | Z=Z[valid_examples], 209 | class_names=['negative', 'positive'], 210 | z_names=z_names, 211 | metric='hamming', 212 | **kwargs) 213 | 214 | def z_to_x(self, z): 215 | return z 216 | 217 | @staticmethod 218 | def _rule_fst(z): 219 | return all([z[3*r+c] for r, c in _COORDS_FST]) 220 | 221 | @staticmethod 222 | def _rule_lst(z): 223 | return all([z[3*r+c] for r, c in _COORDS_LST]) 224 | 225 | def z_to_y(self, z): 226 | return 1 if self._rule_fst(z) or self._rule_lst(z) else 0 227 | 228 | def z_to_expl(self, z): 229 | z = z.reshape((3, -1)) 230 | feat_coeff = [] 231 | for r, c in (_COORDS_FST if self.rule == 'fst' else _COORDS_LST): 232 | value = z[r, c] 233 | feat_coeff.append(('{},{}={}'.format(r, c, value), 2*value-1)) 234 | return feat_coeff 235 | 236 | def _parse_feat(self, feat): 237 | r = int(feat.split(',')[0]) 238 | c = int(feat.split(',')[-1].split('=')[0]) 239 | value = int(feat.split(',')[-1].split('=')[-1]) 240 | return r, c, value 241 | 242 | def query_corrections(self, X_corr, y_corr, i, pred_y, pred_expl, X_test): 243 | true_y = self.y[i] 244 | if pred_expl is None or pred_y != true_y: 245 | return X_corr, y_corr 246 | 247 | z = self.Z[i] 248 | true_feats = {feat.split('=')[0] for (feat, _) in self.z_to_expl(z)} 249 | pred_feats = {feat.split('=')[0] for (feat, _) in pred_expl} 250 | 251 | Z_new_corr = [] 252 | for feat in pred_feats - true_feats: 253 | r, c, _ = self._parse_feat(feat) 254 | z_corr = np.array(z, copy=True) 255 | z_corr[3*r+c] = 1 - z_corr[3*r+c] 256 | if self.z_to_y(z_corr) != true_y or tuple(z_corr) in X_test: 257 | continue 258 | Z_new_corr.append(z_corr) 259 | 260 | if not len(Z_new_corr): 261 | return X_corr, y_corr 262 | 263 | X_new_corr = np.array(Z_new_corr, dtype=np.float64) 264 | y_new_corr = np.array([true_y for _ in Z_new_corr], dtype=np.int8) 265 | 266 | X_corr = vstack([X_corr, X_new_corr]) 267 | y_corr = hstack([y_corr, y_new_corr]) 268 | return X_corr, y_corr 269 | 270 | def save_expl(self, path, i, y, expl): 271 | z = self.Z[i].reshape((3, -1)) 272 | 273 | fig, ax = plt.subplots(1, 1) 274 | ax.set_aspect('equal') 275 | ax.grid(True) 276 | ax.set_xticks([-0.5, 0.5, 1.5]) 277 | ax.set_yticks([-0.5, 0.5, 1.5]) 278 | ax.xaxis.set_ticklabels([]) 279 | 
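# ticks sit on the half-integer cell borders so grid(True) draws the
# 3x3 lattice; the tick labels themselves are hidden on both axes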
ax.yaxis.set_ticklabels([]) 280 | 281 | ax.imshow(1 - z, interpolation='nearest', cmap=plt.get_cmap('binary')) 282 | for feat, coeff in expl: 283 | r, c, value = self._parse_feat(feat) 284 | if z[r, c] == value: 285 | color = '#00FF00' if coeff > 0 else '#FF0000' 286 | ax.add_patch(Circle((c, r), 0.4, color=color)) 287 | 288 | title = 'The system says "{}". The user says "{}"'.format( 289 | self.class_names[y], self.class_names[self.y[i]]) 290 | ax.text(0.5, 1.05, title, 291 | horizontalalignment='center', 292 | transform=ax.transAxes, 293 | fontsize=12) 294 | 295 | fig.savefig(path, bbox_inches=0, pad_inches=0, 296 | facecolor='#BFBFBF', 297 | edgecolor='none') 298 | plt.close(fig) 299 | 300 | 301 | 302 | _COLORS = [ 303 | (255, 0, 0), # r 304 | (0, 255, 0), # g 305 | (0, 128, 255), # b 306 | (128, 0, 255), # v 307 | ] 308 | 309 | class ColorsProblem(TabularProblem): 310 | """Colors problem from the "Right for the Right Reasons" paper.""" 311 | 312 | def __init__(self, rule=0, n_examples=1000, **kwargs): 313 | if not rule in (0, 1): 314 | raise ValueError('invalid rule "{}"'.format(rule)) 315 | self.rule = rule 316 | 317 | data = np.load(join('data', 'toy_colors.npz')) 318 | images = np.vstack([data['arr_0'], data['arr_1']]) 319 | images = np.array([image.reshape((5, 5, 3)) for image in images]) 320 | labels = 1 - np.hstack([data['arr_2'], data['arr_3']]) 321 | 322 | if n_examples: 323 | rng = np.random.RandomState(0) 324 | pi = rng.permutation(len(images))[:n_examples] 325 | else: 326 | pi = list(range(len(images))) 327 | 328 | z_names = ['{},{}'.format(r, c) 329 | for r, c in product(range(5), repeat=2)] 330 | 331 | super().__init__(y=np.array([labels[i] for i in pi]), 332 | Z=np.array([self._image_to_z(images[i]) for i in pi]), 333 | class_names=['negative', 'positive'], 334 | z_names=z_names, 335 | metric='hamming', 336 | discretize_features=True, 337 | categorical_features=[], 338 | **kwargs) 339 | 340 | @staticmethod 341 | def _image_to_z(image): 342 | return np.array([_COLORS.index(tuple(image[r, c])) 343 | for r, c in product(range(5), repeat=2)], 344 | dtype=np.float64) 345 | 346 | def z_to_x(self, z): 347 | x = [1 if z[i] == z[j] else 0 348 | for i in range(5*5) 349 | for j in range(i+1, 5*5)] 350 | return np.array(x, dtype=np.float64) 351 | 352 | @staticmethod 353 | def _rule0(z): 354 | return z[0,0] == z[0,4] and z[0,0] == z[4,0] and z[0,0] == z[4,4] 355 | 356 | @staticmethod 357 | def _rule1(z): 358 | return z[0,1] != z[0,2] and z[0,1] != z[0,3] and z[0,2] != z[0,3] 359 | 360 | def z_to_y(self, z): 361 | z = z.reshape((5, 5)) 362 | return self._rule0(z) if self.rule == 0 else self._rule1(z) 363 | 364 | def z_to_expl(self, z): 365 | z = z.reshape((5, 5)) 366 | 367 | if self.rule == 0: 368 | COORDS = [[0, 0], [0, 4], [4, 0], [4, 4]] 369 | else: 370 | COORDS = [[0, 1], [0, 2], [0, 3]] 371 | 372 | counts = np.bincount([z[r,c] for r, c in COORDS]) 373 | max_count, max_value = np.max(counts), np.argmax(counts) 374 | 375 | feat_to_coeff = defaultdict(int) 376 | if self.rule == 0: 377 | for r, c in COORDS: 378 | weight = 1 if max_count != 1 and z[r, c] == max_value else -1 379 | feat_to_coeff['{},{}={}'.format(r, c, int(z[r, c]))] += weight 380 | else: 381 | for r, c in COORDS: 382 | weight = 1 if max_count == 1 or z[r, c] != max_value else -1 383 | feat_to_coeff['{},{}={}'.format(r, c, int(z[r, c]))] += weight 384 | 385 | return list(feat_to_coeff.items()) 386 | 387 | def query_corrections(self, i, pred_y, pred_expl, X_test): 388 | if pred_expl is None: 389 | return set() 390 | if 
pred_y != self.y[i]: 391 | return set() 392 | 393 | z = self.Z[i] 394 | true_feats = {feat for (feat, _) in self.z_to_expl(z)} 395 | pred_feats = {feat for (feat, _) in pred_expl} 396 | 397 | ALL_VALUES = set(range(4)) 398 | 399 | Z_corr = [] 400 | for feat in pred_feats - true_feats: 401 | feat, lb, ub = self._feat_to_bounds(feat) 402 | r, c = feat.split(',') 403 | r, c = int(r), int(c) 404 | other_values = {value for value in ALL_VALUES if not (lb < value <= ub)} 405 | for value in other_values: 406 | z_corr = np.array(z, copy=True) 407 | z_corr[5*r+c] = value 408 | if self.z_to_y(z_corr) != pred_y: 409 | continue 410 | if tuple(z_corr) in X_test: 411 | continue 412 | Z_corr.append(z_corr) 413 | 414 | if not len(Z_corr): 415 | return set() 416 | 417 | X_corr = np.array([self.z_to_x(z_corr) for z_corr in Z_corr], 418 | dtype=np.float64) 419 | y_corr = np.array([pred_y for _ in Z_corr], dtype=np.int8) 420 | 421 | n_examples = len(self.y) 422 | self.X = vstack([self.X, X_corr]) 423 | self.y = hstack([self.y, y_corr]) 424 | 425 | return set(range(n_examples, len(self.y))) 426 | 427 | def save_expl(self, path, i, pred_y, expl): 428 | z = self.Z[i] 429 | image = np.array([_COLORS[int(value)] for value in z]).reshape((5, 5, 3)) 430 | 431 | fig, ax = plt.subplots(1, 1) 432 | ax.set_aspect('equal') 433 | 434 | z = z.reshape((5, 5)) 435 | ax.imshow(image, interpolation='nearest') 436 | for feat, coeff in expl: 437 | feat, lb, ub = self._feat_to_bounds(feat) 438 | r, c = feat.split(',') 439 | r, c = int(r), int(c) 440 | if lb < z[r, c] <= ub: 441 | color = '#FFFFFF' if coeff > 0 else '#000000' 442 | ax.add_patch(Circle((c, r), 0.35, color=color)) 443 | 444 | ax.text(0.5, 1.05, 'true = {} | pred = {}'.format(self.y[i], pred_y), 445 | horizontalalignment='center', 446 | transform=ax.transAxes) 447 | 448 | fig.savefig(path, bbox_inches=0, pad_inches=0) 449 | plt.close(fig) 450 | 451 | 452 | 453 | _TRIPLETS = [ 454 | [[0, 0], [0, 1], [0, 2]], 455 | [[1, 0], [1, 1], [1, 2]], 456 | [[2, 0], [2, 1], [2, 2]], 457 | [[0, 0], [1, 0], [2, 0]], 458 | [[0, 1], [1, 1], [2, 1]], 459 | [[0, 2], [1, 2], [2, 2]], 460 | [[0, 0], [1, 1], [2, 2]], 461 | [[0, 2], [1, 1], [2, 0]], 462 | ] 463 | 464 | _SALIENT_CONFIGS = [ 465 | # Win configs 466 | (( 1, 1, 1), 1), 467 | # Almost-win configs 468 | (( 1, 1, 0), -1), 469 | (( 1, 0, 1), -1), 470 | (( 0, 1, 1), -1), 471 | (( 1, 1, -1), -1), 472 | (( 1, -1, 1), -1), 473 | ((-1, 1, 1), -1), 474 | ] 475 | 476 | class TTTProblem(TabularProblem): 477 | """Tic-tac-toe endgames.""" 478 | 479 | def __init__(self, **kwargs): 480 | Z, y = [], [] 481 | with open(join('data', 'tic-tac-toe.data'), 'rt') as fp: 482 | for line in map(str.strip, fp): 483 | chars = line.split(',') 484 | Z.append([{'x': 1, 'b': 0, 'o': -1}[c] for c in chars[:-1]]) 485 | y.append({'positive': 1, 'negative': 0}[chars[-1]]) 486 | Z = np.array(Z, dtype=np.float64) 487 | y = np.array(y, dtype=np.int8) 488 | 489 | class_names = ['no-win', 'win'] 490 | z_names = [] 491 | for r, c in product(range(3), repeat=2): 492 | z_names.append('{r},{c}'.format(**locals())) 493 | 494 | super().__init__(Z, y, class_names, z_names, **kwargs) 495 | 496 | @staticmethod 497 | def get_config(z, triplet): 498 | return tuple(int(z[3*r+c]) for r, c in triplet) 499 | 500 | def z_to_y(self, z): 501 | for triplet in _TRIPLETS: 502 | if self.get_config(z, triplet) == (1, 1, 1): 503 | return True 504 | return False 505 | 506 | def z_to_expl(self, z): 507 | feat_coeff = set() 508 | for triplet in _TRIPLETS: 509 | config = self.get_config(z, 
triplet) 510 |                 for salient_config, coeff in _SALIENT_CONFIGS:
511 |                     if config == tuple(salient_config):
512 |                         for r, c in triplet:
513 |                             value = int(z[3*r+c])
514 |                             if (value == 1 if coeff > 0 else value != 1):  # coeff is +/-1, so test its sign explicitly
515 |                                 feat = '{r},{c}={value}'.format(**locals())
516 |                                 feat_coeff.add((feat, coeff))
517 |         # print(self.z_to_y(z))
518 |         # print(z.reshape((3, 3)))
519 |         # print(feat_coeff)
520 |         # quit()
521 |         return feat_coeff
522 |
523 |     @staticmethod
524 |     def z_to_x(z):
525 |         CONFIGS = list(product([-1, 0, 1], repeat=3))
526 |
527 |         def is_piece_at(z, i, j, piece):
528 |             return 1 if z[i*3 + j] == piece else 0
529 |
530 |         x = []
531 |         for i in range(3):
532 |             x.extend([is_piece_at(z, i, 0, config[0]) and
533 |                       is_piece_at(z, i, 1, config[1]) and
534 |                       is_piece_at(z, i, 2, config[2])
535 |                       for config in CONFIGS])
536 |         for j in range(3):
537 |             x.extend([is_piece_at(z, 0, j, config[0]) and
538 |                       is_piece_at(z, 1, j, config[1]) and
539 |                       is_piece_at(z, 2, j, config[2])
540 |                       for config in CONFIGS])
541 |         x.extend([is_piece_at(z, 0, 0, config[0]) and
542 |                   is_piece_at(z, 1, 1, config[1]) and
543 |                   is_piece_at(z, 2, 2, config[2])
544 |                   for config in CONFIGS])
545 |         x.extend([is_piece_at(z, 0, 2, config[0]) and
546 |                   is_piece_at(z, 1, 1, config[1]) and
547 |                   is_piece_at(z, 2, 0, config[2])
548 |                   for config in CONFIGS])
549 |
550 |         assert sum(x) == 3 + 3 + 1 + 1  # exactly one config fires per row, column and diagonal
551 |         return np.array(x, dtype=np.float64)
552 |
553 |     def query_label(self, i):
554 |         return self.y[i]
555 |
556 |     def query_improved_expl(self, i, pred_y, pred_z):
557 |         true_y = self.y[i]
558 |         if pred_y != true_y:
559 |             return None, None
560 |
561 |         raise NotImplementedError()  # XXX the code below is unreachable legacy code (old board-based API)
562 |
563 |         board = self.boards[i]
564 |         true_feats = [feat for (feat, coeff) in
565 |                       self._board_to_expl(self.boards[i])]
566 |         pred_feats = [feat for (feat, coeff) in pred_z]
567 |
568 |         alt_boards = []
569 |         for feat in set(pred_feats) - set(true_feats):
570 |             indices = feat.split('[')[-1].split(']')[0].split(',')
571 |             i, j = int(indices[0]), int(indices[1])
572 |             for alt_piece in set(['o', 'b', 'x']) - set([str(board[i, j])]):
573 |                 alt_board = np.array(board)
574 |                 alt_board[i,j] = alt_piece
575 |                 # Do not add board with a wrong label
576 |                 if true_y == self._board_to_y(alt_board):
577 |                     alt_boards.append(alt_board)
578 |         if not len(alt_boards):
579 |             return None, None
580 |
581 |         X_extra = [self._z_to_x(self._board_to_z(alt_board))
582 |                    for alt_board in alt_boards]
583 |         y_extra = [pred_y for alt_board in alt_boards]
584 |
585 |         return (np.array(X_extra, dtype=np.float64),
586 |                 np.array(y_extra, dtype=np.int8))
587 |
588 |     def _score_features(self, board, expl):
589 |         scores = np.zeros((3, 3))
590 |         for feat, coeff in expl:
591 |             indices = feat.split('[')[-1].split(']')[0].split(',')
592 |             value = int(feat.split('=')[-1])
593 |             i, j = int(indices[0]), int(indices[1])
594 |             if self._PIECE_TO_INT[board[i,j]] == value:
595 |                 scores[i, j] += coeff
596 |         return scores
597 |
598 |     def save_expl(self, path, i, y, z):
599 |         board = self.boards[i]
600 |         scores = self._score_features(board, z)
601 |
602 |         fig = plt.figure(figsize=[3, 3])
603 |         ax = fig.add_subplot(111)
604 |
605 |         for i in range(4):
606 |             ax.plot([i, i], [0, 3], 'k')
607 |             ax.plot([0, 3], [i, i], 'k')
608 |
609 |         ax.set_position([0, 0, 1, 1])
610 |         ax.set_axis_off()
611 |         ax.set_xlim(-1, 4)
612 |         ax.set_ylim(-1, 4)
613 |
614 |         for i, j in product(range(3), range(3)):
615 |             if board[i, j] != 'b':
616 |                 ax.plot(j + 0.5,
617 |                         3 - (i + 0.5),
618 |                         board[i, j],
619 |                         markersize=25,
620 |                         markerfacecolor=(0, 0, 0),
621 |                         markeredgecolor=(0, 0, 0),
622 |                         markeredgewidth=2)
623 |             if np.abs(scores[i][j]) > 0:
624 |                 color = (0, 1, 0, 0.3) if scores[i,j] > 0 else (1, 0, 0, 0.3)
625 |                 ax.plot(j + 0.5,
626 |                         3 - (i + 0.5),
627 |                         's',
628 |                         markersize=35,
629 |                         markerfacecolor=color,
630 |                         markeredgewidth=0)
631 |
632 |         ax.text(0.5, 0.825, 'y = {}'.format(y),
633 |                 horizontalalignment='center',
634 |                 transform=ax.transAxes)
635 |
636 |         fig.savefig(path, bbox_inches=0, pad_inches=0)
637 |         plt.close(fig)
638 |
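# ---------------------------------------------------------------------
# Editor's sketch, not part of the original file: a minimal check of the
# z_to_x featurization above on a hand-written board. The __main__ guard
# keeps the snippet inert when the module is imported.
if __name__ == '__main__':
    _z = np.array([ 1,  1,  1,    # x x x  <- 'x' wins on the top row
                   -1,  0, -1,    # o . o  (encoding: x=1, b=0, o=-1)
                    0, -1,  1],   # . o x
                  dtype=np.float64)
    _x = TTTProblem.z_to_x(_z)
    assert len(_x) == 8 * 27          # 8 lines x 27 possible piece triples
    assert _x.sum() == 3 + 3 + 1 + 1  # exactly one triple fires per line
# ---------------------------------------------------------------------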
639 |
640 | class SudokuProblem(TabularProblem):
641 |     """Classification of sudoku instances as feasible or infeasible (stub: the hooks below are not implemented yet)."""
642 |     def __init__(self, **kwargs):
643 |
644 |         Z = np.array([])
645 |         y = np.array([], dtype=np.int8)
646 |
647 |         z_names = ['{},{}'.format(r, c)
648 |                    for r, c in product(range(3), repeat=2)]
649 |
650 |         super().__init__(y=y,
651 |                          Z=Z,
652 |                          class_names=['infeasible', 'feasible'],
653 |                          z_names=z_names,
654 |                          metric='hamming',
655 |                          **kwargs)
656 |
657 |     def z_to_x(self, z):
658 |         raise NotImplementedError()
659 |
660 |     def z_to_y(self, z):
661 |         raise NotImplementedError()
662 |
663 |     def z_to_expl(self, z):
664 |         raise NotImplementedError()
665 |
666 |     def query_corrections(self, X_corr, y_corr, i, pred_y, pred_expl, X_test):
667 |         true_y = self.y[i]
668 |         if pred_expl is None or pred_y != true_y:
669 |             return X_corr, y_corr
670 |         raise NotImplementedError()
671 |
672 |     def save_expl(self, path, i, y, expl):
673 |         raise NotImplementedError()
674 |
-------------------------------------------------------------------------------- /caipi/text.py: --------------------------------------------------------------------------------
1 | import numpy as np
2 | import re, blessings
3 | from time import time
4 | from os.path import join
5 | from collections import defaultdict
6 | from scipy.sparse import csr_matrix
7 | from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer
8 | from sklearn.pipeline import make_pipeline
9 | from sklearn.linear_model import Ridge
10 | from sklearn.metrics import precision_recall_fscore_support as prfs
11 | from sklearn.utils import check_random_state
12 | from lime.lime_text import LimeTextExplainer
13 | from gensim.models.keyedvectors import KeyedVectors
14 |
15 | from .
import Problem, load, densify, vstack, hstack 16 | 17 | 18 | _TERM = blessings.Terminal() 19 | 20 | 21 | class Normalizer: 22 | def fit(self, X, y=None): 23 | return self 24 | 25 | def transform(self, X, norms=None, return_norms=False, append_value=1): 26 | new_X, ret_norms = [], [] 27 | try: 28 | iterator = enumerate(X.todense()) 29 | except: 30 | iterator = enumerate(X) 31 | for i, x in iterator: 32 | x = np.array(x).ravel() 33 | norm = np.linalg.norm(x) if norms is None else norms[i] 34 | norm = max(1e-16, norm) 35 | ret_norms.append(norm) 36 | new_X.append(x.ravel() / norm) 37 | new_X = np.array(new_X) 38 | if append_value is not None: 39 | new_X = np.hstack([new_X, 40 | append_value * np.ones((new_X.shape[0], 1))]) 41 | new_X = csr_matrix(np.array(new_X)) 42 | if return_norms: 43 | return new_X, ret_norms 44 | return new_X 45 | 46 | 47 | class Word2VecVectorizer: 48 | def __init__(self, path): 49 | self.word2vec = KeyedVectors.load_word2vec_format(path, binary=True) 50 | self.n_features = self.word2vec.wv['the'].shape[0] 51 | self.no_embedding = set() 52 | 53 | def _embed_document(self, text): 54 | word_vectors = [] 55 | for word in text.split(): 56 | word = word.lower() 57 | try: 58 | word_vectors.append(self.word2vec.wv[word]) 59 | except KeyError: 60 | if word not in self.no_embedding: 61 | print('Warning: could not embed "{}"'.format(word)) 62 | self.no_embedding.add(word) 63 | return (np.mean(word_vectors, axis=0) if len(word_vectors) else 64 | np.zeros(self.n_features)) 65 | 66 | def fit(self, docs, y=None): 67 | return self 68 | 69 | def transform(self, docs): 70 | return np.array([self._embed_document(doc) for doc in docs]) 71 | 72 | 73 | class TextProblem(Problem): 74 | def __init__(self, **kwargs): 75 | self.class_names = kwargs.pop('class_names') 76 | self.y = kwargs.pop('y') 77 | self.docs = kwargs.pop('docs') 78 | self.processed_docs = kwargs.pop('processed_docs') 79 | self.explanations = kwargs.pop('explanations') 80 | self.lime_repeats = kwargs.pop('lime_repeats', 1) 81 | n_examples = kwargs.pop('n_examples', None) 82 | self.corr_type = kwargs.pop('corr_type', 'replace-expl') or 'replace-expl' 83 | self.vect_type = kwargs.pop('vect_type', 'glove') 84 | super().__init__(**kwargs) 85 | 86 | if n_examples is not None: 87 | rng = check_random_state(kwargs.get('rng', None)) 88 | perm = rng.permutation(len(self.y))[:n_examples] 89 | self.y = self.y[perm] 90 | self.docs = [self.docs[i] for i in perm] 91 | self.processed_docs = [self.processed_docs[i] for i in perm] 92 | self.explanations = [self.explanations[i] for i in perm] 93 | 94 | self.normalizer = Normalizer() 95 | if self.vect_type == 'binary': 96 | self.vectorizer = CountVectorizer(lowercase=False, binary=True) \ 97 | .fit(self.processed_docs) 98 | elif self.vect_type == 'tfidf': 99 | self.vectorizer = TfidfVectorizer(lowercase=False) \ 100 | .fit(self.processed_docs) 101 | elif self.vect_type == 'glove': 102 | path = join('data', 'word2vec_glove.6B.300d.bin') 103 | self.vectorizer = Word2VecVectorizer(path) 104 | elif self.vect_type == 'google-news': 105 | path = join('data', 'GoogleNews-vectors-negative300.bin') 106 | self.vectorizer = Word2VecVectorizer(path) 107 | else: 108 | raise ValueError('unknown vect_type "{}"'.format(self.vect_type)) 109 | 110 | self.X = self.vectorizer.transform(self.processed_docs) 111 | self.X, self.norms = self.normalizer.transform(self.X, 112 | return_norms=True, 113 | append_value=1) 114 | 115 | self.explainable = {i for i in range(len(self.y)) 116 | if self.explanations[i] is not None 
and len(self.explanations[i])} 117 | 118 | def _masks_to_expl(self, i): 119 | """Turns a list of word masks (which might highlight repeated words) 120 | into a set of words.""" 121 | words = self.processed_docs[i].split() 122 | true_masks = self.explanations[i] 123 | true_mask = np.sum(true_masks, axis=0) 124 | selected_words = {words[i] for i in np.where(true_mask)[0]} 125 | return [(word, self.y[i]) for word in selected_words] 126 | 127 | def explain(self, learner, known_examples, i, y_pred): 128 | explainer = LimeTextExplainer(class_names=self.class_names) 129 | 130 | # XXX hack 131 | if self.explanations[i] is None: 132 | n_features = 1 133 | else: 134 | true_expl = self._masks_to_expl(i) 135 | n_features = max(len(true_expl), 1) # XXX FIXME XXX 136 | 137 | pipeline = make_pipeline(self.vectorizer, self.normalizer, learner) 138 | 139 | counts = defaultdict(int) 140 | for r in range(self.lime_repeats): 141 | 142 | t = time() 143 | local_model = Ridge(alpha=1, fit_intercept=True, random_state=0) 144 | expl = explainer.explain_instance(self.processed_docs[i], 145 | pipeline.predict_proba, 146 | model_regressor=local_model, 147 | num_features=n_features, 148 | num_samples=self.n_samples) 149 | print(' LIME {}/{} took {:3.2f}s'.format(r + 1, self.lime_repeats, 150 | time() - t)) 151 | 152 | for word, coeff in expl.as_list(): 153 | counts[(word, int(np.sign(coeff)))] += 1 154 | 155 | sorted_counts = sorted(counts.items(), key=lambda _: _[-1]) 156 | sorted_counts = list(sorted_counts)[-n_features:] 157 | return [ws for ws, _ in sorted_counts] 158 | 159 | def query_label(self, i): 160 | return self.y[i] 161 | 162 | def query_corrections(self, i, pred_y, pred_expl, X_test): 163 | if pred_expl is None: 164 | return set() 165 | if pred_y != self.y[i]: 166 | return set() 167 | if i not in self.explainable: 168 | # XXX makes perf drop for some reason 169 | return set() 170 | 171 | all_words = set(self.processed_docs[i].split()) 172 | true_words = {word for word, _ in self._masks_to_expl(i)} 173 | pred_words = {word for word, _ in pred_expl} 174 | fp_words = pred_words - true_words 175 | 176 | print('original =', ' '.join(all_words)) 177 | if self.corr_type == 'replace-expl': 178 | correction = ' '.join(true_words) 179 | print('correction =', correction) 180 | print() 181 | self.X[i] = self.normalizer.transform( 182 | self.vectorizer.transform([correction]))[0] 183 | self.processed_docs[i] = correction 184 | extra_examples = set() 185 | 186 | elif self.corr_type == 'replace-no-fp': 187 | correction = ' '.join(all_words - fp_words) 188 | print('correction =', correction) 189 | print() 190 | self.X[i] = self.normalizer.transform( 191 | self.vectorizer.transform([correction]))[0] 192 | self.processed_docs[i] = correction 193 | extra_examples = set() 194 | 195 | elif self.corr_type == 'add-contrast': 196 | # NOTE make sure to set fit_intercept=False! 197 | words = np.array(self.processed_docs[i].split()) 198 | 199 | corrections = [] 200 | for mask in self.explanations[i]: 201 | masked_indices = np.where(mask)[0] 202 | rationale = ' '.join(words[masked_indices]) 203 | corrections.append(rationale) 204 | 205 | X_corrections = self.vectorizer.transform(corrections) 206 | X_corrections = self.normalizer.transform(X_corrections, 207 | norms=[self.norms[i] for _ in corrections], 208 | append_value=0) 209 | 210 | extra_examples = set(range(self.X.shape[0], 211 | self.X.shape[0] + len(corrections))) 212 | 213 | elif self.corr_type == 'add-contrast-fp': 214 | # NOTE make sure to set fit_intercept=False! 
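# (Editor's note, inferred from the surrounding code rather than stated
# anywhere: the Normalizer appends a constant column -- append_value=1
# for ordinary examples, append_value=0 for the contrastive corrections
# built below -- so the learner's bias must live in that column as an
# ordinary weight. A learner fitted with its own fit_intercept=True
# would apply its bias to the corrections too, defeating the contrast.)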
215 | words = np.array(self.processed_docs[i].split()) 216 | 217 | correction = ' '.join(word for word in words 218 | if word not in fp_words) 219 | corrections = [correction] 220 | 221 | X_corrections = self.vectorizer.transform(corrections) 222 | X_corrections = self.normalizer.transform(X_corrections, 223 | norms=[self.norms[i] for _ in corrections], 224 | append_value=0) 225 | 226 | extra_examples = set(range(self.X.shape[0], 227 | self.X.shape[0] + len(corrections))) 228 | 229 | else: 230 | raise ValueError('unknown correction type "{}"'.format( 231 | self.corr_type)) 232 | 233 | if len(extra_examples): 234 | y_corrections = np.array([pred_y] * len(corrections)) 235 | self.X = vstack([self.X, X_corrections]) 236 | self.y = hstack([self.y, y_corrections]) 237 | 238 | return extra_examples 239 | 240 | def _eval_expl(self, learner, known_examples, eval_examples, 241 | t=None, basename=None): 242 | if eval_examples is None: 243 | return -1, -1, -1 244 | 245 | perfs = [] 246 | for i in set(eval_examples) & self.explainable: 247 | true_y = self.y[i] 248 | true_expl = self._masks_to_expl(i) 249 | 250 | pred_y = learner.predict(densify(self.X[i]))[0] 251 | pred_expl = self.explain(learner, known_examples, i, pred_y) 252 | 253 | # NOTE here we don't care if the coefficients are wrong, since 254 | # those depend on whether the prediction is wrong 255 | 256 | true_words = {(word, np.sign(coeff)) for word, coeff in true_expl} 257 | pred_words = {(word, np.sign(coeff) if true_y == pred_y else -np.sign(coeff)) 258 | for word, coeff in pred_expl} 259 | 260 | matches = true_words & pred_words 261 | pr = len(matches) / len(pred_words) if len(pred_words) else 0.0 262 | rc = len(matches) / len(true_words) if len(true_words) else 0.0 263 | f1 = 0.0 if pr + rc <= 0 else 2 * pr * rc / (pr + rc) 264 | perfs.append((pr, rc, f1)) 265 | 266 | if basename is None: 267 | continue 268 | 269 | self.save_expl(basename + '_{}_{}.txt'.format(i, t), 270 | i, pred_y, pred_expl) 271 | self.save_expl(basename + '_{}_true.txt'.format(i), 272 | i, true_y, true_expl) 273 | 274 | return np.mean(perfs, axis=0) 275 | 276 | def eval(self, learner, known_examples, test_examples, eval_examples, 277 | t=None, basename=None): 278 | pred_perfs = prfs(self.y[test_examples], 279 | learner.predict(self.X[test_examples]), 280 | average='weighted')[:3] 281 | expl_perfs = self._eval_expl(learner, 282 | known_examples, 283 | eval_examples, 284 | t=t, basename=basename) 285 | return np.array(tuple(pred_perfs) + tuple(expl_perfs)) 286 | 287 | @staticmethod 288 | def _highlight_words(text, expl): 289 | for word, sign in expl: 290 | color = _TERM.green if sign >= 0 else _TERM.red 291 | colored_word = color + word + _TERM.normal 292 | matches = list(re.compile(word).finditer(text)) 293 | matches.reverse() 294 | for match in matches: 295 | start = match.start() 296 | text = text[:start] + colored_word + text[start+len(word):] 297 | return text 298 | 299 | def save_expl(self, path, i, pred_y, expl): 300 | with open(path, 'wt') as fp: 301 | fp.write('true y: ' + self.class_names[self.y[i]] + '\n') 302 | fp.write('pred y: ' + self.class_names[pred_y] + '\n') 303 | fp.write(80 * '-' + '\n') 304 | #fp.write(self._highlight_words(self.docs[i], expl)) 305 | fp.write('\n' + 80 * '-' + '\n') 306 | fp.write('explanation:\n') 307 | for word, sign in sorted(expl): 308 | fp.write('{:32s} : {:3.1f}\n'.format(word, sign)) 309 | 310 | 311 | class ReviewsProblem(TextProblem): 312 | def __init__(self, **kwargs): 313 | 314 | path = join('data', 
'review_polarity_rationales.pickle') 315 |         try:
316 |             dataset = load(path)
317 |         except FileNotFoundError:
318 |             raise RuntimeError('Run the data preparation script first!')
319 |
320 |         super().__init__(class_names=['neg', 'pos'],
321 |                          y=dataset['y'],
322 |                          docs=dataset['docs'],
323 |                          processed_docs=dataset['docs'],
324 |                          explanations=dataset['explanations'],
325 |                          **kwargs)
326 |
327 |
328 | class NewsgroupsProblem(TextProblem):
329 |     def __init__(self, **kwargs):
330 |
331 |         path = join('data', 'newsgroups.pickle')
332 |         try:
333 |             dataset = load(path)
334 |         except FileNotFoundError:
335 |             raise RuntimeError('Run the data preparation script first!')
336 |
337 |         super().__init__(class_names=['neg', 'pos'],
338 |                          y=dataset['y'],
339 |                          docs=dataset['docs'],
340 |                          processed_docs=dataset['docs'],
341 |                          explanations=dataset['explanations'],
342 |                          **kwargs)
343 |
-------------------------------------------------------------------------------- /caipi/utils.py: --------------------------------------------------------------------------------
1 | import pickle
2 | import numpy as np
3 | import scipy as sp
4 |
5 |
6 | def load(path, **kwargs):
7 |     with open(path, 'rb') as fp:
8 |         return pickle.load(fp, **kwargs)
9 |
10 |
11 | def dump(path, what, **kwargs):
12 |     with open(path, 'wb') as fp:
13 |         pickle.dump(what, fp, **kwargs)
14 |
15 |
16 | def densify(x):
17 |     try:
18 |         x = x.toarray()
19 |     except AttributeError:
20 |         pass
21 |     if x.shape[0] != 1:
22 |         # if X[i] is already dense, densify(X[i]) is a no-op, so we get an x
23 |         # of shape (n_features,) and we turn it into (1, n_features);
24 |         # if X[i] is sparse, densify(X[i]) gives an x of shape (1, n_features).
25 |         x = x[np.newaxis, ...]
26 |     return x
27 |
28 |
29 | def _stack(arrays, d_stack, s_stack):
30 |     arrays = [a for a in arrays if a is not None]
31 |     if len(arrays) == 0:
32 |         return None
33 |     if len(arrays) == 1:
34 |         return arrays[0]
35 |     if isinstance(arrays[0], sp.sparse.csr_matrix):
36 |         return s_stack(arrays)
37 |     else:
38 |         return d_stack(arrays)
39 |
40 | vstack = lambda arrays: _stack(arrays, np.vstack, sp.sparse.vstack)
41 | hstack = lambda arrays: _stack(arrays, np.hstack, sp.sparse.hstack)
42 |
43 |
44 | def setprfs(true, pred):
45 |     matches = true & pred
46 |     pr = len(matches) / len(pred) if len(pred) else 0.0
47 |     rc = len(matches) / len(true) if len(true) else 0.0
48 |     f1 = 0.0 if pr + rc <= 0 else 2*pr*rc / (pr + rc)
49 |     return pr, rc, f1
50 |
51 |
52 | class PipeStep:
53 |     def __init__(self, func):
54 |         self.func = func
55 |
56 |     def fit(self, *args, **kwargs):
57 |         return self
58 |
59 |     def transform(self, X):
60 |         return self.func(X)
61 |
-------------------------------------------------------------------------------- /colors-draw-all.sh: --------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | ./caipi-draw.py colors-rule0-10folds \
4 |     results/colors-rule0__l1svm__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=-1__F=4__S=2000__K=0.75__R=100__s=0.pickle \
5 |     results/colors-rule0__l1svm__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=0__F=4__S=2000__K=0.75__R=100__s=0.pickle \
6 |     results/colors-rule0__svm__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=-1__F=4__S=2000__K=0.75__R=100__s=0.pickle \
7 |     results/colors-rule0__svm__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=0__F=4__S=2000__K=0.75__R=100__s=0.pickle \
8 |     --min-pred-f1 0.6
9 |
10 | ./caipi-draw.py colors-rule1-10folds \
11 |
results/colors-rule0__l1svm__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=-1__F=4__S=2000__K=0.75__R=100__s=0.pickle \ 12 | results/colors-rule1__l1svm__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=0__F=3__S=2000__K=0.75__R=100__s=0.pickle \ 13 | results/colors-rule0__svm__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=-1__F=4__S=2000__K=0.75__R=100__s=0.pickle \ 14 | results/colors-rule1__svm__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=0__F=3__S=2000__K=0.75__R=100__s=0.pickle \ 15 | --min-pred-f1 0.6 16 | 17 | ./caipi-draw-weights.py colors-ruleX__l1svm__least-confident__noei results/colors-rule0__l1svm__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=-1__F=4__S=2000__K=0.75__R=100__s=0-params.pickle 18 | ./caipi-draw-weights.py colors-rule0__l1svm__least-confident__ei results/colors-rule0__l1svm__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=0__F=4__S=2000__K=0.75__R=100__s=0-params.pickle 19 | ./caipi-draw-weights.py colors-rule1__l1svm__least-confident__ei results/colors-rule1__l1svm__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=0__F=3__S=2000__K=0.75__R=100__s=0-params.pickle 20 | 21 | ./caipi-draw-weights.py colors-ruleX__svm__least-confident__noei results/colors-rule0__svm__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=-1__F=4__S=2000__K=0.75__R=100__s=0-params.pickle 22 | ./caipi-draw-weights.py colors-rule0__svm__least-confident__ei results/colors-rule0__svm__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=0__F=4__S=2000__K=0.75__R=100__s=0-params.pickle 23 | ./caipi-draw-weights.py colors-rule1__svm__least-confident__ei results/colors-rule1__svm__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=0__F=3__S=2000__K=0.75__R=100__s=0-params.pickle 24 | 25 | exit 26 | 27 | ./caipi-draw.py colors-rule0 \ 28 | results/colors-rule0__l1svm__least-confident__k=3__n=None__p=0.0__P=5.0__T=101__e=20__E=-1__F=4__S=2000__K=0.75__R=100__s=0.pickle \ 29 | results/colors-rule0__l1svm__least-confident__k=3__n=None__p=0.0__P=5.0__T=101__e=20__E=0__F=4__S=2000__K=0.75__R=100__s=0.pickle \ 30 | results/colors-rule0__svm__least-confident__k=3__n=None__p=0.0__P=5.0__T=101__e=20__E=-1__F=4__S=2000__K=0.75__R=100__s=0.pickle \ 31 | results/colors-rule0__svm__least-confident__k=3__n=None__p=0.0__P=5.0__T=101__e=20__E=0__F=4__S=2000__K=0.75__R=100__s=0.pickle \ 32 | --min-pred-f1 0.6 33 | 34 | ./caipi-draw.py colors-rule1 \ 35 | results/colors-rule1__l1svm__least-confident__k=3__n=None__p=0.0__P=5.0__T=101__e=20__E=-1__F=3__S=2000__K=0.75__R=100__s=0.pickle \ 36 | results/colors-rule1__l1svm__least-confident__k=3__n=None__p=0.0__P=5.0__T=101__e=20__E=0__F=3__S=2000__K=0.75__R=100__s=0.pickle \ 37 | results/colors-rule1__svm__least-confident__k=3__n=None__p=0.0__P=5.0__T=101__e=20__E=-1__F=3__S=2000__K=0.75__R=100__s=0.pickle \ 38 | results/colors-rule1__svm__least-confident__k=3__n=None__p=0.0__P=5.0__T=101__e=20__E=0__F=3__S=2000__K=0.75__R=100__s=0.pickle \ 39 | --min-pred-f1 0.6 40 | 41 | ./caipi-draw-weights.py colors-ruleX__l1svm__least-confident__noei results/colors-rule0__l1svm__least-confident__k=3__n=None__p=0.0__P=5.0__T=101__e=20__E=-1__F=4__S=2000__K=0.75__R=100__s=0-params.pickle 42 | ./caipi-draw-weights.py colors-rule0__l1svm__least-confident__ei results/colors-rule0__l1svm__least-confident__k=3__n=None__p=0.0__P=5.0__T=101__e=20__E=0__F=4__S=2000__K=0.75__R=100__s=0-params.pickle 43 | ./caipi-draw-weights.py colors-rule1__l1svm__least-confident__ei 
results/colors-rule1__l1svm__least-confident__k=3__n=None__p=0.0__P=5.0__T=101__e=20__E=0__F=3__S=2000__K=0.75__R=100__s=0-params.pickle 44 | 45 | ./caipi-draw-weights.py colors-ruleX__lr__least-confident__noei results/colors-rule1__lr__least-confident__k=3__n=None__p=0.0__P=5.0__T=101__e=20__E=-1__F=3__S=2000__K=0.75__R=100__s=0-params.pickle 46 | ./caipi-draw-weights.py colors-rule1__lr__least-confident__ei results/colors-rule1__lr__least-confident__k=3__n=None__p=0.0__P=5.0__T=101__e=20__E=0__F=3__S=2000__K=0.75__R=100__s=0-params.pickle 47 | 48 | ./caipi-draw-weights.py colors-ruleX__svm__least-confident__noei results/colors-rule0__svm__least-confident__k=3__n=None__p=0.0__P=5.0__T=101__e=20__E=-1__F=4__S=2000__K=0.75__R=100__s=0-params.pickle 49 | ./caipi-draw-weights.py colors-rule0__svm__least-confident__ei results/colors-rule0__svm__least-confident__k=3__n=None__p=0.0__P=5.0__T=101__e=20__E=0__F=4__S=2000__K=0.75__R=100__s=0-params.pickle 50 | ./caipi-draw-weights.py colors-rule1__svm__least-confident__ei results/colors-rule1__svm__least-confident__k=3__n=None__p=0.0__P=5.0__T=101__e=20__E=0__F=3__S=2000__K=0.75__R=100__s=0-params.pickle 51 | -------------------------------------------------------------------------------- /data/fer2013.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """The facial emotion recognition dataset can be downloaded at: 3 | 4 | https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge 5 | """ 6 | 7 | import numpy as np 8 | import pickle 9 | import matplotlib.pyplot as plt 10 | 11 | print('loading') 12 | y, images = [], [] 13 | with open('fer2013.csv', 'rt') as fp: 14 | for line in fp.readlines()[1:]: 15 | label, pixels, _ = line.split(',') 16 | y.append(label) 17 | image = np.array(pixels.split()).astype(np.uint8).reshape(48, 48) 18 | images.append(image) 19 | 20 | print('reshaping') 21 | X = np.array(images) 22 | y = np.array(y).astype(np.uint8) 23 | 24 | print('saving') 25 | with open('fer2013.pickle', 'wb') as fp: 26 | pickle.dump({ 27 | 'data': X, 28 | 'target': y, 29 | 'class_names': ('anger', 'disgust', 'fear', 'happiness', 'sadness', 'surprise', 'neutral'), 30 | }, fp, protocol=pickle.HIGHEST_PROTOCOL) 31 | 32 | # print('drawing') 33 | # for i, x in enumerate(X): 34 | # fig, ax = plt.subplots(1, 1) 35 | # fig.set_size_inches((1, 1)) 36 | # ax.imshow(x, cmap='gist_gray', aspect='equal') 37 | # ax.get_xaxis().set_visible(False) 38 | # ax.get_yaxis().set_visible(False) 39 | # fig.savefig('example-{}.png'.format(i), bbox_inches='tight', pad_inches=0) 40 | -------------------------------------------------------------------------------- /data/toy_colors.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefanoteso/caipi/b867958bc13460d018f3cbe57ce9ea39fddaca31/data/toy_colors.npz -------------------------------------------------------------------------------- /form/caipi.csv: -------------------------------------------------------------------------------- 1 | "Timestamp","Do you believe that the AI system eventually learned to classify images correctly?","Do you believe that the AI system eventually learned the correct classification rule?","Would you like to further assess the AI system by checking whether it classifies 10 random images correctly?","Do you believe that the AI system eventually learned to classify images correctly?","Do you believe that the AI system 
eventually learned the correct classification rule?","Would you like to further assess the AI system by checking whether it classifies 10 random images correctly?","Do you believe that the AI system eventually learned to classify images correctly?","Do you believe that the AI system eventually learned the correct classification rule?","Would you like to further assess the AI system by checking whether it classifies 10 random images correctly?" 2 | "2018/05/08 12:35:07 pm EET","Yes","Yes","No","No","No","No","Yes","No","No" 3 | "2018/05/08 12:39:10 pm EET","No","No","Yes","Yes","Yes","Yes","Yes","Yes","Yes" 4 | "2018/05/08 12:46:22 pm EET","Yes","Yes","Yes","Yes","Yes","Yes","No","No","Yes" 5 | "2018/05/08 1:03:57 pm EET","Yes","Yes","Yes","Yes","Yes","No","No","No","No" 6 | "2018/05/08 1:11:45 pm EET","Yes","No","Yes","Yes","Yes","Yes","No","No","No" 7 | "2018/05/08 1:33:59 pm EET","Yes","No","Yes","Yes","Yes","Yes","No","No","Yes" 8 | "2018/05/08 1:36:40 pm EET","No","No","No","No","No","Yes","No","No","No" 9 | "2018/05/08 1:57:34 pm EET","Yes","No","Yes","No","No","Yes","No","No","No" 10 | "2018/05/08 5:36:05 pm EET","No","No","Yes","Yes","Yes","Yes","No","No","Yes" 11 | "2018/05/08 7:23:15 pm EET","Yes","Yes","Yes","Yes","Yes","No","No","No","Yes" 12 | "2018/05/08 8:43:10 pm EET","No","No","Yes","No","No","Yes","No","No","No" 13 | "2018/05/08 9:09:23 pm EET","Yes","No","No","Yes","No","Yes","No","No","No" 14 | "2018/05/08 9:37:49 pm EET","Yes","No","Yes","Yes","Yes","No","Yes","No","Yes" 15 | "2018/05/09 1:17:27 am EET","No","No","Yes","Yes","Yes","Yes","No","No","Yes" 16 | "2018/05/09 8:17:34 am EET","Yes","Yes","Yes","Yes","Yes","Yes","No","No","No" 17 | "2018/05/09 9:41:22 am EET","Yes","Yes","Yes","Yes","Yes","Yes","Yes","Yes","No" 18 | "2018/05/10 4:12:20 am EET","No","No","Yes","Yes","No","No","Yes","No","No" -------------------------------------------------------------------------------- /form/questionnaire.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefanoteso/caipi/b867958bc13460d018f3cbe57ce9ea39fddaca31/form/questionnaire.pdf -------------------------------------------------------------------------------- /matplotlibrc: -------------------------------------------------------------------------------- 1 | backend: agg 2 | -------------------------------------------------------------------------------- /prepare-newsgroups.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import numpy as np 4 | import spacy 5 | from os.path import join 6 | from sklearn.datasets import fetch_20newsgroups 7 | from sklearn.linear_model import SGDClassifier 8 | from sklearn.feature_extraction.text import TfidfVectorizer 9 | from caipi import load, dump 10 | 11 | 12 | SPACY = spacy.load('en_core_web_sm', disable=['parser', 'ner']) 13 | POS_TAGS = {'ADJ', 'ADV', 'NOUN', 'VERB'} 14 | 15 | 16 | def simplify(line): 17 | tokens = SPACY(line) 18 | valid_lemmas = [] 19 | for i, token in enumerate(tokens): 20 | if (token.pos_ in POS_TAGS and 21 | token.lemma_ != '-PRON-'): 22 | valid_lemmas.append(token.lemma_) 23 | return ' '.join(valid_lemmas) 24 | 25 | 26 | categories = ['alt.atheism', 'soc.religion.christian'] 27 | twenty = fetch_20newsgroups(subset='all', 28 | categories=categories, 29 | remove=['headers', 'footers'], 30 | random_state=0) 31 | docs = [simplify(doc) for doc in twenty.data] 32 | 33 | 34 | vectorizer = TfidfVectorizer(lowercase=False) 35 | X = 
vectorizer.fit_transform(docs).toarray() 36 | y = twenty.target
37 |
38 | vocabulary = np.array(vectorizer.get_feature_names())
39 | feature_selector = SGDClassifier(penalty='l1', random_state=0)
40 | feature_selector.fit(X, y)
41 | print('feature_selector acc =', feature_selector.score(X, y))
42 | coef = np.abs(feature_selector.coef_.ravel())
43 |
44 | selected_indices = [i for i in coef.argsort()[::-1] if coef[i] >= 1]
45 | selected_words = vocabulary[selected_indices]
46 |
47 | print('# words =', len(vocabulary))
48 | print('# selected words =', len(selected_words))
49 |
50 | docs2 = []
51 | rats = []
52 | keep = []
53 | for i, doc in enumerate(docs):
54 |     words = np.array(doc.split())
55 |     if len(words) == 0:
56 |         continue
57 |     indices = [i for i in range(len(words))
58 |                if words[i] in selected_words]
59 |     print('% relevant =', len(indices) / len(words))
60 |     mask = np.zeros((1, len(words)))
61 |     mask[0, indices] = 1
62 |     # print(mask)
63 |     rats.append(mask)
64 |     docs2.append(doc)
65 |     keep.append(i)
66 | y2 = y[keep]
67 |
68 | vectorizer = TfidfVectorizer(lowercase=False, vocabulary=selected_words)
69 | X2 = vectorizer.fit_transform(docs2).toarray()
70 | feature_selector = SGDClassifier(penalty='l1', random_state=0)
71 | feature_selector.fit(X2, y2)
72 | print('feature_selector acc =', feature_selector.score(X2, y2))
73 |
74 | dataset = {
75 |     'y': y2,
76 |     'docs': docs2,
77 |     'explanations': rats,
78 | }
79 |
80 | dump(join('data', 'newsgroups.pickle'), dataset)
81 |
-------------------------------------------------------------------------------- /prepare-reviews.py: --------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | import re
4 | import pickle
5 | import numpy as np
6 | import spacy
7 | from os import listdir
8 | from os.path import join
9 | from sklearn.feature_extraction.text import TfidfVectorizer
10 | from sklearn.linear_model import SGDClassifier
11 | from caipi import load, dump
12 |
13 |
14 | N_DOCUMENTS_PER_CLASS = np.nan  # nan: 'k >= nan' is never true, so all documents are kept
15 | METHOD = 'global'
16 |
17 |
18 | # Make sure to download the dataset from:
19 | #
20 | #   http://cs.jhu.edu/~ozaidan/rationales
21 | #
22 | # and uncompress it in data/review_polarity_rationales/
23 |
24 |
25 | SPACY = spacy.load('en_core_web_sm', disable=['parser', 'ner'])
26 |
27 |
28 | POS_TAGS = {'ADJ', 'ADV', 'NOUN', 'VERB'}
29 | RAT_TAGS = {'POS', '/POS', 'NEG', '/NEG'}
30 | RAT_TAGS2 = {'<' + tag + '>' for tag in RAT_TAGS}
31 | REGEX = re.compile('<(POS|NEG)> (?P<rationale>[^<>]*) </(POS|NEG)>')  # one tagged rationale span; the group name is not used downstream
32 |
33 |
34 | # NOTE 'oscar winner'
35 | # XXX what about negation?
it would need POS tagging maybe 36 | 37 | 38 | def simplify(line): 39 | tokens = SPACY(line) 40 | valid_lemmas = [] 41 | for i, token in enumerate(tokens): 42 | if (token.pos_ in POS_TAGS and 43 | token.lemma_ != '-PRON-'): 44 | valid_lemmas.append(token.lemma_) 45 | if (token.text in RAT_TAGS and 46 | tokens[i-1].text == '<' and 47 | tokens[i+1].text == '>'): 48 | valid_lemmas.append('<' + token.text + '>') 49 | return ' '.join(valid_lemmas) 50 | 51 | 52 | def process_rats(line): 53 | matches = list(REGEX.finditer(line)) 54 | if len(matches) == 0: 55 | return line, None 56 | 57 | ranges = [] 58 | for match in matches: 59 | ranges.extend([(match.start(), True), (match.end(), False)]) 60 | if not ranges[0] == (0, True): 61 | ranges = [(0, False)] + ranges 62 | if not ranges[-1] == (len(line), False): 63 | ranges = ranges + [(len(line), False)] 64 | 65 | words = line.split() 66 | masks = np.zeros((len(matches), len(words))) 67 | 68 | j = 0 69 | valid_words = [] 70 | for i in range(len(ranges) - 1): 71 | s, is_rationale = ranges[i] 72 | e, _ = ranges[i + 1] 73 | segment_words = [word for word in line[s:e].strip().split() 74 | if not word in RAT_TAGS2] 75 | 76 | if is_rationale: 77 | masks[j, len(valid_words):len(valid_words)+len(segment_words)] = 1 78 | j += 1 79 | valid_words.extend(segment_words) 80 | 81 | return ' '.join(valid_words), masks 82 | 83 | 84 | def read_docs(base_path, label): 85 | docs, rats = [], [] 86 | rel_paths = sorted(listdir(base_path)) 87 | for k, rel_path in enumerate(rel_paths): 88 | if k >= N_DOCUMENTS_PER_CLASS: 89 | break 90 | print('processing {}/{} {}'.format(k + 1, len(rel_paths), rel_path)) 91 | n = rel_path.split('_')[-1].split('.')[0] 92 | with open(join(base_path, rel_path), encoding='latin-1') as fp: 93 | doc = simplify(fp.read().strip()) 94 | doc, masks = process_rats(doc) 95 | docs.append(doc) 96 | rats.append(masks) 97 | return docs, rats 98 | 99 | 100 | np.set_printoptions(threshold=np.nan) 101 | 102 | try: 103 | print('Loading...') 104 | y, docs, rats = load('reviews.pickle') 105 | 106 | except: 107 | print('Reading documents...') 108 | pos_docs, pos_rats = read_docs(join('data', 'review_polarity_rationales', 'withRats_pos'), +1) 109 | neg_docs, neg_rats = read_docs(join('data', 'review_polarity_rationales', 'withRats_neg'), -1) 110 | 111 | print('Saving...') 112 | y = np.array([+1] * len(pos_docs) + [-1] * len(neg_docs)) 113 | docs = pos_docs + neg_docs 114 | rats = pos_rats + neg_rats 115 | dump('reviews.pickle', (y, docs, rats)) 116 | 117 | vectorizer = TfidfVectorizer(lowercase=False) 118 | X = vectorizer.fit_transform(docs).toarray() 119 | vocabulary = np.array(vectorizer.get_feature_names()) 120 | 121 | model = SGDClassifier(penalty='l1', random_state=0).fit(X, y) 122 | coef = np.abs(model.coef_.ravel()) 123 | selected = coef.argsort()[-len(vocabulary) // 5:] 124 | relevant_words = set(vocabulary[selected]) 125 | 126 | print('feature selector acc =', model.score(X, y)) 127 | print('# words =', len(vocabulary)) 128 | print('# relevant words =', len(relevant_words)) 129 | 130 | rats = [] 131 | for doc in docs: 132 | 133 | words = doc.split() 134 | relevant_indices = [i for i in range(len(words)) 135 | if words[i] in relevant_words] 136 | print('# relevant in doc =', len(relevant_indices)) 137 | mask = np.zeros((1, len(words))) 138 | mask[0, relevant_indices] = 1 139 | rats.append(mask) 140 | 141 | dataset = { 142 | 'y': y, 143 | 'docs': docs, 144 | 'explanations': rats, 145 | } 146 | 147 | with open(join('data', 
'review_polarity_rationales.pickle'), 'wb') as fp: 148 | pickle.dump(dataset, fp) 149 | -------------------------------------------------------------------------------- /results/colors-rule0__l1svm__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=-1__C=None__F=4__S=2000__K=0.75__R=100__V=None__s=0-params.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefanoteso/caipi/b867958bc13460d018f3cbe57ce9ea39fddaca31/results/colors-rule0__l1svm__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=-1__C=None__F=4__S=2000__K=0.75__R=100__V=None__s=0-params.pickle -------------------------------------------------------------------------------- /results/colors-rule0__l1svm__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=-1__C=None__F=4__S=2000__K=0.75__R=100__V=None__s=0.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefanoteso/caipi/b867958bc13460d018f3cbe57ce9ea39fddaca31/results/colors-rule0__l1svm__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=-1__C=None__F=4__S=2000__K=0.75__R=100__V=None__s=0.pickle -------------------------------------------------------------------------------- /results/colors-rule0__l1svm__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=-1__F=4__S=2000__K=0.75__R=100__s=0-params.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefanoteso/caipi/b867958bc13460d018f3cbe57ce9ea39fddaca31/results/colors-rule0__l1svm__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=-1__F=4__S=2000__K=0.75__R=100__s=0-params.pickle -------------------------------------------------------------------------------- /results/colors-rule0__l1svm__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=-1__F=4__S=2000__K=0.75__R=100__s=0.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefanoteso/caipi/b867958bc13460d018f3cbe57ce9ea39fddaca31/results/colors-rule0__l1svm__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=-1__F=4__S=2000__K=0.75__R=100__s=0.pickle -------------------------------------------------------------------------------- /results/colors-rule0__l1svm__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=0__C=None__F=4__S=2000__K=0.75__R=100__V=None__s=0-params.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefanoteso/caipi/b867958bc13460d018f3cbe57ce9ea39fddaca31/results/colors-rule0__l1svm__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=0__C=None__F=4__S=2000__K=0.75__R=100__V=None__s=0-params.pickle -------------------------------------------------------------------------------- /results/colors-rule0__l1svm__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=0__C=None__F=4__S=2000__K=0.75__R=100__V=None__s=0.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefanoteso/caipi/b867958bc13460d018f3cbe57ce9ea39fddaca31/results/colors-rule0__l1svm__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=0__C=None__F=4__S=2000__K=0.75__R=100__V=None__s=0.pickle -------------------------------------------------------------------------------- 
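The result filenames above all follow the same double-underscore convention: problem, learner, query strategy, then 'key=value' hyperparameter fields (k, n, p, P, T, e, E, C, F, S, K, R, V, s). An editorial sketch for recovering the parameters from a filename; 'parse_result_name' is a made-up helper, not part of the repository:

import os

def parse_result_name(path):
    # strip the directory and the '.pickle' / '-params.pickle' suffix
    stem = os.path.basename(path)
    for suffix in ('-params.pickle', '.pickle'):
        if stem.endswith(suffix):
            stem = stem[:-len(suffix)]
            break
    # fields are joined with double underscores
    fields = stem.split('__')
    problem, learner, strategy = fields[:3]
    params = dict(field.split('=', 1) for field in fields[3:])
    return problem, learner, strategy, params

print(parse_result_name(
    'results/colors-rule0__lr__least-confident__k=10__n=None'
    '__p=0.0__P=5.0__T=101__e=-1__E=0__F=4__S=2000__K=0.75'
    '__R=100__s=0.pickle'))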
/results/colors-rule0__l1svm__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=0__F=4__S=2000__K=0.75__R=100__s=0-params.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefanoteso/caipi/b867958bc13460d018f3cbe57ce9ea39fddaca31/results/colors-rule0__l1svm__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=0__F=4__S=2000__K=0.75__R=100__s=0-params.pickle -------------------------------------------------------------------------------- /results/colors-rule0__l1svm__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=0__F=4__S=2000__K=0.75__R=100__s=0.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefanoteso/caipi/b867958bc13460d018f3cbe57ce9ea39fddaca31/results/colors-rule0__l1svm__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=0__F=4__S=2000__K=0.75__R=100__s=0.pickle -------------------------------------------------------------------------------- /results/colors-rule0__lr__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=-1__F=4__S=2000__K=0.75__R=100__s=0-params.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefanoteso/caipi/b867958bc13460d018f3cbe57ce9ea39fddaca31/results/colors-rule0__lr__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=-1__F=4__S=2000__K=0.75__R=100__s=0-params.pickle -------------------------------------------------------------------------------- /results/colors-rule0__lr__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=-1__F=4__S=2000__K=0.75__R=100__s=0.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefanoteso/caipi/b867958bc13460d018f3cbe57ce9ea39fddaca31/results/colors-rule0__lr__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=-1__F=4__S=2000__K=0.75__R=100__s=0.pickle -------------------------------------------------------------------------------- /results/colors-rule0__lr__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=0__F=4__S=2000__K=0.75__R=100__s=0-params.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefanoteso/caipi/b867958bc13460d018f3cbe57ce9ea39fddaca31/results/colors-rule0__lr__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=0__F=4__S=2000__K=0.75__R=100__s=0-params.pickle -------------------------------------------------------------------------------- /results/colors-rule0__lr__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=0__F=4__S=2000__K=0.75__R=100__s=0.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefanoteso/caipi/b867958bc13460d018f3cbe57ce9ea39fddaca31/results/colors-rule0__lr__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=0__F=4__S=2000__K=0.75__R=100__s=0.pickle -------------------------------------------------------------------------------- /results/colors-rule0__svm__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=-1__C=None__F=4__S=2000__K=0.75__R=100__V=None__s=0-params.pickle: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/stefanoteso/caipi/b867958bc13460d018f3cbe57ce9ea39fddaca31/results/colors-rule0__svm__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=-1__C=None__F=4__S=2000__K=0.75__R=100__V=None__s=0-params.pickle -------------------------------------------------------------------------------- /results/colors-rule0__svm__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=-1__C=None__F=4__S=2000__K=0.75__R=100__V=None__s=0.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefanoteso/caipi/b867958bc13460d018f3cbe57ce9ea39fddaca31/results/colors-rule0__svm__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=-1__C=None__F=4__S=2000__K=0.75__R=100__V=None__s=0.pickle -------------------------------------------------------------------------------- /results/colors-rule0__svm__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=-1__F=4__S=2000__K=0.75__R=100__s=0-params.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefanoteso/caipi/b867958bc13460d018f3cbe57ce9ea39fddaca31/results/colors-rule0__svm__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=-1__F=4__S=2000__K=0.75__R=100__s=0-params.pickle -------------------------------------------------------------------------------- /results/colors-rule0__svm__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=-1__F=4__S=2000__K=0.75__R=100__s=0.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefanoteso/caipi/b867958bc13460d018f3cbe57ce9ea39fddaca31/results/colors-rule0__svm__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=-1__F=4__S=2000__K=0.75__R=100__s=0.pickle -------------------------------------------------------------------------------- /results/colors-rule0__svm__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=0__C=None__F=4__S=2000__K=0.75__R=100__V=None__s=0-params.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefanoteso/caipi/b867958bc13460d018f3cbe57ce9ea39fddaca31/results/colors-rule0__svm__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=0__C=None__F=4__S=2000__K=0.75__R=100__V=None__s=0-params.pickle -------------------------------------------------------------------------------- /results/colors-rule0__svm__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=0__C=None__F=4__S=2000__K=0.75__R=100__V=None__s=0.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefanoteso/caipi/b867958bc13460d018f3cbe57ce9ea39fddaca31/results/colors-rule0__svm__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=0__C=None__F=4__S=2000__K=0.75__R=100__V=None__s=0.pickle -------------------------------------------------------------------------------- /results/colors-rule0__svm__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=0__F=4__S=2000__K=0.75__R=100__s=0-params.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefanoteso/caipi/b867958bc13460d018f3cbe57ce9ea39fddaca31/results/colors-rule0__svm__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=0__F=4__S=2000__K=0.75__R=100__s=0-params.pickle 
-------------------------------------------------------------------------------- /results/colors-rule0__svm__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=0__F=4__S=2000__K=0.75__R=100__s=0.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefanoteso/caipi/b867958bc13460d018f3cbe57ce9ea39fddaca31/results/colors-rule0__svm__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=0__F=4__S=2000__K=0.75__R=100__s=0.pickle -------------------------------------------------------------------------------- /results/colors-rule1__l1svm__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=0__C=None__F=3__S=2000__K=0.75__R=100__V=None__s=0-params.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefanoteso/caipi/b867958bc13460d018f3cbe57ce9ea39fddaca31/results/colors-rule1__l1svm__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=0__C=None__F=3__S=2000__K=0.75__R=100__V=None__s=0-params.pickle -------------------------------------------------------------------------------- /results/colors-rule1__l1svm__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=0__C=None__F=3__S=2000__K=0.75__R=100__V=None__s=0.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefanoteso/caipi/b867958bc13460d018f3cbe57ce9ea39fddaca31/results/colors-rule1__l1svm__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=0__C=None__F=3__S=2000__K=0.75__R=100__V=None__s=0.pickle -------------------------------------------------------------------------------- /results/colors-rule1__l1svm__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=0__F=3__S=2000__K=0.75__R=100__s=0-params.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefanoteso/caipi/b867958bc13460d018f3cbe57ce9ea39fddaca31/results/colors-rule1__l1svm__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=0__F=3__S=2000__K=0.75__R=100__s=0-params.pickle -------------------------------------------------------------------------------- /results/colors-rule1__l1svm__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=0__F=3__S=2000__K=0.75__R=100__s=0.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefanoteso/caipi/b867958bc13460d018f3cbe57ce9ea39fddaca31/results/colors-rule1__l1svm__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=0__F=3__S=2000__K=0.75__R=100__s=0.pickle -------------------------------------------------------------------------------- /results/colors-rule1__lr__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=0__F=3__S=2000__K=0.75__R=100__s=0-params.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefanoteso/caipi/b867958bc13460d018f3cbe57ce9ea39fddaca31/results/colors-rule1__lr__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=0__F=3__S=2000__K=0.75__R=100__s=0-params.pickle -------------------------------------------------------------------------------- /results/colors-rule1__lr__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=0__F=3__S=2000__K=0.75__R=100__s=0.pickle: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/stefanoteso/caipi/b867958bc13460d018f3cbe57ce9ea39fddaca31/results/colors-rule1__lr__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=0__F=3__S=2000__K=0.75__R=100__s=0.pickle -------------------------------------------------------------------------------- /results/colors-rule1__svm__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=0__C=None__F=3__S=2000__K=0.75__R=100__V=None__s=0-params.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefanoteso/caipi/b867958bc13460d018f3cbe57ce9ea39fddaca31/results/colors-rule1__svm__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=0__C=None__F=3__S=2000__K=0.75__R=100__V=None__s=0-params.pickle -------------------------------------------------------------------------------- /results/colors-rule1__svm__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=0__C=None__F=3__S=2000__K=0.75__R=100__V=None__s=0.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefanoteso/caipi/b867958bc13460d018f3cbe57ce9ea39fddaca31/results/colors-rule1__svm__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=0__C=None__F=3__S=2000__K=0.75__R=100__V=None__s=0.pickle -------------------------------------------------------------------------------- /results/colors-rule1__svm__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=0__F=3__S=2000__K=0.75__R=100__s=0-params.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefanoteso/caipi/b867958bc13460d018f3cbe57ce9ea39fddaca31/results/colors-rule1__svm__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=0__F=3__S=2000__K=0.75__R=100__s=0-params.pickle -------------------------------------------------------------------------------- /results/colors-rule1__svm__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=0__F=3__S=2000__K=0.75__R=100__s=0.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefanoteso/caipi/b867958bc13460d018f3cbe57ce9ea39fddaca31/results/colors-rule1__svm__least-confident__k=10__n=None__p=0.0__P=5.0__T=101__e=-1__E=0__F=3__S=2000__K=0.75__R=100__s=0.pickle -------------------------------------------------------------------------------- /results/newsgroups__lr__least-confident__k=10__n=None__p=0.0__P=1.0__T=101__e=20__E=-1__C=add-contrast-fp__F=10__S=200__K=3.0__R=10__V=tfidf__s=0.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefanoteso/caipi/b867958bc13460d018f3cbe57ce9ea39fddaca31/results/newsgroups__lr__least-confident__k=10__n=None__p=0.0__P=1.0__T=101__e=20__E=-1__C=add-contrast-fp__F=10__S=200__K=3.0__R=10__V=tfidf__s=0.pickle -------------------------------------------------------------------------------- /results/newsgroups__lr__least-confident__k=10__n=None__p=0.0__P=1.0__T=101__e=20__E=0__C=add-contrast-fp__F=10__S=200__K=3.0__R=10__V=tfidf__s=0.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefanoteso/caipi/b867958bc13460d018f3cbe57ce9ea39fddaca31/results/newsgroups__lr__least-confident__k=10__n=None__p=0.0__P=1.0__T=101__e=20__E=0__C=add-contrast-fp__F=10__S=200__K=3.0__R=10__V=tfidf__s=0.pickle 
-------------------------------------------------------------------------------- /results/newsgroups__lr__least-confident__k=10__n=None__p=0.0__P=1.0__T=301__e=20__E=-1__C=add-contrast-fp__F=10__S=200__K=3.0__R=25__V=tfidf__s=0.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefanoteso/caipi/b867958bc13460d018f3cbe57ce9ea39fddaca31/results/newsgroups__lr__least-confident__k=10__n=None__p=0.0__P=1.0__T=301__e=20__E=-1__C=add-contrast-fp__F=10__S=200__K=3.0__R=25__V=tfidf__s=0.pickle -------------------------------------------------------------------------------- /results/newsgroups__lr__least-confident__k=10__n=None__p=0.0__P=1.0__T=301__e=20__E=0__C=add-contrast-fp__F=10__S=200__K=3.0__R=25__V=tfidf__s=0.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefanoteso/caipi/b867958bc13460d018f3cbe57ce9ea39fddaca31/results/newsgroups__lr__least-confident__k=10__n=None__p=0.0__P=1.0__T=301__e=20__E=0__C=add-contrast-fp__F=10__S=200__K=3.0__R=25__V=tfidf__s=0.pickle -------------------------------------------------------------------------------- /run-caipi-color.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | for L in l1svm svm lr; do 4 | 5 | # No corrections 6 | for S in random least-confident; do 7 | ./caipi.py colors-rule0 $L $S -k 10 -p 0 -P 5 -T 101 -e 20 -E -1 -F 4 -S 2000 -R 100 8 | ./caipi.py colors-rule1 $L $S -k 10 -p 0 -P 5 -T 101 -e 20 -E -1 -F 3 -S 2000 -R 100 9 | done 10 | 11 | # With corrections 12 | for S in random least-confident; do 13 | ./caipi.py colors-rule0 $L $S -k 10 -p 0 -P 5 -T 101 -e 20 -E 0 -F 4 -S 2000 -R 100 14 | ./caipi.py colors-rule1 $L $S -k 10 -p 0 -P 5 -T 101 -e 20 -E 0 -F 3 -S 2000 -R 100 15 | done 16 | 17 | done 18 | -------------------------------------------------------------------------------- /run-caipi-toy.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | for PROBLEM in toy-fst toy-lst; do 4 | for L in lr l1svm svm; do 5 | for S in random least-confident; do 6 | ./caipi.py $PROBLEM $L $S -p 0 -P 5 -k 3 -T 1001 -e 5 -S 1000 -F 2 -E 0 2>/dev/null 7 | #./caipi.py $PROBLEM $L $S -p 0 -P 1 -k 3 -T 101 -e 5 -S 1000 -F 2 -E 0 -I 2>/dev/null 8 | done 9 | done 10 | done 11 | -------------------------------------------------------------------------------- /run-caipi-ttt.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | for S in random least-confident; do 4 | ./caipi.py ttt svm $S -p 0.01 -S 10000 -F 3 -S 15000 5 | done 6 | 7 | for S in random least-confident; do 8 | ./caipi.py ttt svm $S -p 0.01 -S 10000 -F 3 -S 15000 -E 10 -I 9 | done 10 | -------------------------------------------------------------------------------- /versus-rrr.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import sys 4 | 5 | sys.path.append('../rrr') 6 | 7 | import numpy as np 8 | import matplotlib.pyplot as plt 9 | from multilayer_perceptron import MultilayerPerceptron 10 | from figure_grid import * 11 | from local_linear_explanation import explanation_grid 12 | import decoy_mnist 13 | 14 | 15 | np.random.seed(0) 16 | 17 | NUM_EPOCHS = 64 # this is the default in rrr 18 | 19 | 20 | def correct_one(x, e, label, n_counterexamples): 21 | x = x.reshape((28, 28)) # 28 x 28 x [0 ... 
255] 22 |     e = e.reshape((28, 28)) # 28 x 28 x {False, True}
23 |
24 |     X_counterexamples = []
25 |     for _ in range(n_counterexamples):
26 |         x_counterexample = np.array(x, copy=True)
27 |         x_counterexample[e] = np.random.randint(0, 256, size=x[e].shape)
28 |         X_counterexamples.append(x_counterexample.ravel())
29 |     return X_counterexamples
30 |
31 |
32 | def get_corrections(X, E, y, n_counterexamples=2):
33 |     X_counterexamples, y_counterexamples = [], []
34 |     for x, e, label in zip(X, E, y):
35 |         temp = correct_one(x, e, label, n_counterexamples)
36 |         X_counterexamples.extend(temp)
37 |         y_counterexamples.extend([label] * len(temp))
38 |     return np.array(X_counterexamples), np.array(y_counterexamples)
39 |
40 |
41 | (X_tr_orig, X_tr_decoy, y_tr, E_tr,
42 |  X_ts_orig, X_ts_decoy, y_ts, E_ts) = \
43 |     decoy_mnist.generate_dataset(cachefile='../data/fashion/decoy-fashion.npz')
44 |
45 | for n_counterexamples in range(1, 6):
46 |     print('Fitting MLP corrected on {} examples'.format(n_counterexamples))
47 |
48 |     X_tr_corrections, y_tr_corrections = get_corrections(X_tr_decoy, E_tr, y_tr,
49 |                                                          n_counterexamples=n_counterexamples)
50 |     X_tr_corrected = np.vstack([X_tr_decoy, X_tr_corrections])
51 |     y_tr_corrected = np.hstack([y_tr, y_tr_corrections])
52 |
53 |     print('# examples =', len(y_tr_corrected))
54 |
55 |     mlp_corrected = MultilayerPerceptron()
56 |     mlp_corrected.fit(X_tr_corrected, y_tr_corrected, num_epochs=NUM_EPOCHS,
57 |                       verbose=100)
58 |     print('avg. acc. on train (decoy) ', mlp_corrected.score(X_tr_decoy, y_tr))
59 |     print('avg. acc. on test (decoy)  ', mlp_corrected.score(X_ts_decoy, y_ts))
60 |     print('avg. acc. on test (nodecoy)', mlp_corrected.score(X_ts_orig, y_ts))
61 |
62 | quit()  # XXX early exit: the baseline fits below are never reached
63 |
64 | print('Fitting MLP annotated')
65 | mlp_annotated = MultilayerPerceptron(l2_grads=1000)
66 | mlp_annotated.fit(X_tr_decoy, y_tr, E_tr, num_epochs=NUM_EPOCHS)
67 | print('avg. acc. on train (decoy) ', mlp_annotated.score(X_tr_decoy, y_tr))
68 | print('avg. acc. on test (decoy)  ', mlp_annotated.score(X_ts_decoy, y_ts))
69 | print('avg. acc. on test (nodecoy)', mlp_annotated.score(X_ts_orig, y_ts))
70 |
71 | print('Fitting MLP normal')
72 | mlp_normal = MultilayerPerceptron()
73 | mlp_normal.fit(X_tr_decoy, y_tr, num_epochs=NUM_EPOCHS)
74 | print('avg. acc. on train (decoy) ', mlp_normal.score(X_tr_decoy, y_tr))
75 | print('avg. acc. on test (decoy)  ', mlp_normal.score(X_ts_decoy, y_ts))
76 | print('avg. acc. on test (nodecoy)', mlp_normal.score(X_ts_orig, y_ts))
77 |
78 | print('Fitting MLP nodecoy')
79 | mlp_nodecoy = MultilayerPerceptron()
80 | mlp_nodecoy.fit(X_tr_orig, y_tr, num_epochs=NUM_EPOCHS)
81 | print('avg. acc. on train (decoy) ', mlp_nodecoy.score(X_tr_decoy, y_tr))
82 | print('avg. acc. on test (decoy)  ', mlp_nodecoy.score(X_ts_decoy, y_ts))
83 | print('avg. acc. on test (nodecoy)', mlp_nodecoy.score(X_ts_orig, y_ts))
84 |
--------------------------------------------------------------------------------
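As a closing illustration, a toy run of the masking logic in correct_one above, shrunk to a 4x4 "image" so the effect is visible at a glance. This is an editorial sketch, independent of the rrr checkout that versus-rrr.py expects:

import numpy as np

rng = np.random.RandomState(0)
x = np.arange(16, dtype=np.uint8).reshape(4, 4)   # stand-in image
e = np.zeros((4, 4), dtype=bool)
e[0, :] = True                                    # "decoy" mask: top row
x_ce = np.array(x, copy=True)
x_ce[e] = rng.randint(0, 256, size=x[e].shape)    # randomize masked pixels only
assert (x_ce[~e] == x[~e]).all()                  # unmasked pixels untouched
print(x)
print(x_ce)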