├── .gitignore ├── README.md ├── labelshiftexperiments ├── __init__.py ├── cifarandmnist.py └── maketable.py ├── notebooks ├── CIFAR10.ipynb ├── CIFAR100.ipynb ├── KaggleDiabeticRetinopathy.ipynb ├── MNIST.ipynb ├── README.txt ├── blog_colab.ipynb ├── cifar100_label_shift_adaptation_results.json.gz ├── cifar10_label_shift_adaptation_results.json.gz ├── demo │ ├── blog_colab.ipynb │ ├── demo_shifted_test_labels.txt.gz │ ├── demo_shifted_test_preds.txt.gz │ ├── demo_valid_labels.txt.gz │ └── demo_valid_preds.txt.gz ├── kaggledr_label_shift_adaptation_results.json.gz ├── mnist_label_shift_adaptation_results.json.gz └── obtaining_predictions │ ├── README.txt │ ├── cifar10 │ ├── Download_CIFAR10_models_from_zenodo_and_make_predictions.ipynb │ ├── README.txt │ ├── intro_plot.ipynb │ └── train_cifar100.py │ ├── cifar100 │ ├── README.txt │ ├── getPreds.ipynb │ └── train_cifar100.py │ ├── diabetic_retinopathy │ └── README.txt │ └── mnist │ ├── README.txt │ └── Train_MNIST_and_make_predictions.ipynb └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | Accompanying code for the paper *Maximum Likelihood With Bias-Corrected Calibration is Hard-To-Beat at Label Shift Adaptation* 3 | Accepted to ICML 2020 4 | Authors: Amr Alexandari\*, Anshul Kundaje† and Avanti Shrikumar\*† 5 | *co-first authors, †co-corresponding authors 6 | 7 | See https://colab.research.google.com/github/kundajelab/labelshiftexperiments/blob/master/notebooks/demo/blog_colab.ipynb for a demo notebook illustrating the core functionality 8 | 9 | Core calibration and label shift adaptation code lives in 
https://github.com/kundajelab/abstention 10 | 11 | See the notebooks/ folder for code to replicate tables in the paper, and post a github issue if you have questions! 12 | -------------------------------------------------------------------------------- /labelshiftexperiments/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import division, print_function 2 | from . import cifarandmnist 3 | from . import maketable 4 | -------------------------------------------------------------------------------- /labelshiftexperiments/cifarandmnist.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import numpy as np 3 | from collections import defaultdict, OrderedDict 4 | import scipy 5 | import sys 6 | import abstention.calibration 7 | 8 | 9 | def read_labels(fh): 10 | to_return = [] 11 | for line in fh: 12 | the_class=int(line.rstrip()) 13 | to_add = np.zeros(10) 14 | to_add[the_class] = 1 15 | to_return.append(to_add) 16 | return np.array(to_return) 17 | 18 | 19 | def read_preds(fh): 20 | return np.array([[float(x) for x in y.rstrip().split("\t")] 21 | for y in fh]) 22 | 23 | 24 | def sample_from_probs_arr(arr_with_probs, rng): 25 | rand_num = rng.uniform() 26 | cdf_so_far = 0 27 | for (idx, prob) in enumerate(arr_with_probs): 28 | cdf_so_far += prob 29 | if (cdf_so_far >= rand_num 30 | or idx == (len(arr_with_probs) - 1)): # need the 31 | # letterIdx==(len(row)-1) clause because of potential floating point errors 32 | # that mean arrWithProbs doesn't sum to 1 33 | return idx 34 | 35 | 36 | def get_func_to_draw_label_proportions(test_labels): 37 | test_class_to_indices = defaultdict(list) 38 | for index,row in enumerate(test_labels): 39 | row_label = np.argmax(row) 40 | test_class_to_indices[row_label].append(index) 41 | def draw_test_indices(total_to_return, label_proportions, rng): 42 | indices_to_use = [] 43 | for class_index, class_proportion in enumerate(label_proportions): 44 | indices_to_use.extend(rng.choice( 45 | test_class_to_indices[class_index], 46 | int(total_to_return*class_proportion), 47 | replace=True)) 48 | for i in range(total_to_return-len(indices_to_use)): 49 | class_index = sample_from_probs_arr(label_proportions, rng) 50 | indices_to_use.append( 51 | rng.choice(test_class_to_indices[class_index])) 52 | return indices_to_use 53 | return draw_test_indices 54 | 55 | 56 | def run_calibmethods(valid_preacts, valid_labels, 57 | test_preacts, test_labels, calibname_to_calibfactory, 58 | samplesize, 59 | samplesizesseen, 60 | metric_to_samplesize_to_calibname_to_unshiftedvals): 61 | calibname_to_calibfunc = {} 62 | calibname_to_calibvalidpreds = {} 63 | for calibname, calibfactory in\ 64 | calibname_to_calibfactory.items(): 65 | calibfunc = calibfactory( 66 | valid_preacts=valid_preacts, 67 | valid_labels=valid_labels) 68 | 69 | unshifted_test_preds = calibfunc(test_preacts) 70 | unshifted_test_nll = -np.mean( 71 | np.sum(np.log(unshifted_test_preds) 72 | *test_labels, axis=-1)) 73 | unshifted_test_ece = abstention.calibration.compute_ece( 74 | softmax_out=unshifted_test_preds, 75 | labels=test_labels, bins=15) 76 | unshifted_test_jsdiv =\ 77 | scipy.spatial.distance.jensenshannon( 78 | p=np.mean(unshifted_test_preds, axis=0), 79 | q=np.mean(test_labels, axis=0)) 80 | 81 | #if statement is there to avoid double-counting 82 | if (samplesize not in samplesizesseen): 83 | metric_to_samplesize_to_calibname_to_unshiftedvals[ 84 | 
'ece'][samplesize][calibname].append(unshifted_test_ece) 85 | metric_to_samplesize_to_calibname_to_unshiftedvals[ 86 | 'nll'][samplesize][calibname].append(unshifted_test_nll) 87 | metric_to_samplesize_to_calibname_to_unshiftedvals[ 88 | 'jsdiv'][samplesize][calibname].append(unshifted_test_jsdiv) 89 | calibname_to_calibfunc[calibname] = calibfunc 90 | calibname_to_calibvalidpreds[calibname] = calibfunc(valid_preacts) 91 | 92 | return (calibname_to_calibfunc, calibname_to_calibvalidpreds) 93 | 94 | 95 | def run_experiments(num_trials, seeds, alphas_and_samplesize, 96 | shifttype, 97 | calibname_to_calibfactory, 98 | imbalanceadaptername_to_imbalanceadapter, 99 | adaptncalib_pairs, 100 | validglobprefix, 101 | testglobprefix, 102 | valid_labels, test_labels): 103 | 104 | draw_test_indices = get_func_to_draw_label_proportions(test_labels) 105 | 106 | alpha_to_samplesize_to_adaptername_to_metric_to_vals =( 107 | defaultdict( 108 | lambda: defaultdict( 109 | lambda: defaultdict( 110 | lambda: defaultdict(list))))) 111 | alpha_to_samplesize_to_baselineacc = defaultdict( 112 | lambda: defaultdict(list)) 113 | metric_to_samplesize_to_calibname_to_unshiftedvals = defaultdict( 114 | lambda: defaultdict(lambda: defaultdict(list))) 115 | 116 | samplesizesseen = set() 117 | for (alpha,samplesize) in alphas_and_samplesize: 118 | for seed in seeds: 119 | print("Seed",seed) 120 | for trial_num in range(num_trials): 121 | rng = np.random.RandomState(seed*num_trials + trial_num) 122 | test_preacts = read_preds( 123 | open(glob.glob(testglobprefix+str(seed)+"*.txt")[0])) 124 | valid_preacts = read_preds( 125 | open(glob.glob(validglobprefix+str(seed)+"*.txt")[0])) 126 | #let's also sample different validation sets 127 | # according to the random seed AND the trialnum 128 | sample_valid_indices = rng.choice( 129 | a=np.arange(len(valid_preacts)), 130 | size=samplesize, replace=False) 131 | sample_valid_preacts = valid_preacts[sample_valid_indices] 132 | sample_valid_labels = valid_labels[sample_valid_indices] 133 | 134 | (calibname_to_calibfunc, 135 | calibname_to_calibvalidpreds) = ( 136 | run_calibmethods( 137 | valid_preacts=sample_valid_preacts, 138 | valid_labels=sample_valid_labels, 139 | test_preacts=test_preacts, 140 | test_labels=test_labels, 141 | calibname_to_calibfactory=calibname_to_calibfactory, 142 | samplesize=samplesize, 143 | samplesizesseen=samplesizesseen, 144 | metric_to_samplesize_to_calibname_to_unshiftedvals= 145 | metric_to_samplesize_to_calibname_to_unshiftedvals)) 146 | 147 | #note the calibration method that did the best according to 148 | #each metric, and save it 149 | for metricname in metric_to_samplesize_to_calibname_to_unshiftedvals: 150 | calibname_to_unshiftedvals = metric_to_samplesize_to_calibname_to_unshiftedvals[metricname][samplesize] 151 | best_calibname = min(list(calibname_to_unshiftedvals.keys()), 152 | key=lambda x: calibname_to_unshiftedvals[x][-1]) 153 | calibname_to_calibfunc['best-'+metricname] = calibname_to_calibfunc[best_calibname] 154 | calibname_to_calibvalidpreds['best-'+metricname] = calibname_to_calibvalidpreds[best_calibname] 155 | 156 | if (shifttype=='dirichlet'): 157 | altered_class_priors = rng.dirichlet([ 158 | alpha for x in range(10)]) 159 | elif (shifttype=='tweakone'): 160 | altered_class_priors = np.full((10), (1.0-alpha)/9) 161 | altered_class_priors[3] = alpha 162 | else: 163 | raise RuntimeError("Unsupported shift type",shifttype) 164 | 165 | test_indices = draw_test_indices( 166 | total_to_return=samplesize, 167 | 
label_proportions=altered_class_priors, 168 | rng=rng) 169 | shifted_test_labels = test_labels[test_indices] 170 | shifted_test_preacts = test_preacts[test_indices] 171 | 172 | calibname_to_calibshiftedtestpreds = {} 173 | for (calibname, calibfunc) in calibname_to_calibfunc.items(): 174 | calibname_to_calibshiftedtestpreds[calibname] =( 175 | calibfunc(shifted_test_preacts)) 176 | 177 | shifted_test_baseline_accuracy = np.mean( 178 | np.argmax(shifted_test_labels,axis=-1)== 179 | np.argmax(abstention.calibration.softmax( 180 | preact=shifted_test_preacts, 181 | temp=1.0, biases=None),axis=-1)) 182 | alpha_to_samplesize_to_baselineacc[alpha][samplesize].append( 183 | shifted_test_baseline_accuracy) 184 | 185 | ideal_shift_weights = (np.mean(shifted_test_labels,axis=0)/ 186 | np.mean(sample_valid_labels,axis=0)) 187 | true_shifted_priors = np.mean(shifted_test_labels, axis=0) 188 | for adapter_name,calib_name in adaptncalib_pairs: 189 | calib_shifted_test_preds =\ 190 | calibname_to_calibshiftedtestpreds[calib_name] 191 | calib_valid_preds = calibname_to_calibvalidpreds[ 192 | calib_name] 193 | imbalance_adapter =\ 194 | imbalanceadaptername_to_imbalanceadapter[adapter_name] 195 | imbalance_adapter_func = imbalance_adapter( 196 | valid_labels=sample_valid_labels, 197 | tofit_initial_posterior_probs=calib_shifted_test_preds, 198 | valid_posterior_probs=calib_valid_preds) 199 | shift_weights = imbalance_adapter_func.multipliers 200 | unnormed_estimshiftedpriors = np.mean( 201 | sample_valid_labels, axis=0)*shift_weights 202 | estim_shifted_priors = (unnormed_estimshiftedpriors/ 203 | np.sum(unnormed_estimshiftedpriors*shift_weights)) 204 | adapted_shifted_test_preds = imbalance_adapter_func( 205 | calib_shifted_test_preds) 206 | adapted_shifted_test_accuracy = np.mean( 207 | np.argmax(shifted_test_labels,axis=-1)== 208 | np.argmax(adapted_shifted_test_preds,axis=-1)) 209 | delta_from_baseline = (adapted_shifted_test_accuracy 210 | -shifted_test_baseline_accuracy) 211 | #expected value of mse weights; weighted by the class 212 | # proportions in the test set 213 | alpha_to_samplesize_to_adaptername_to_metric_to_vals[ 214 | alpha][samplesize][adapter_name+":"+calib_name][ 215 | 'mseweights_testsetprop'].append( 216 | np.sum(true_shifted_priors*( 217 | np.square(ideal_shift_weights-shift_weights)))) 218 | #mse weights but each class weighted evenly 219 | alpha_to_samplesize_to_adaptername_to_metric_to_vals[ 220 | alpha][samplesize][adapter_name+":"+calib_name][ 221 | 'mseweights_even'].append( 222 | np.mean(np.square( 223 | ideal_shift_weights-shift_weights))) 224 | alpha_to_samplesize_to_adaptername_to_metric_to_vals[ 225 | alpha][samplesize][adapter_name+":"+calib_name][ 226 | 'jsdiv'].append( 227 | scipy.spatial.distance.jensenshannon( 228 | p=true_shifted_priors, q=estim_shifted_priors)) 229 | alpha_to_samplesize_to_adaptername_to_metric_to_vals[ 230 | alpha][samplesize][adapter_name+":"+calib_name][ 231 | 'delta_acc'].append(delta_from_baseline) 232 | 233 | if (samplesize not in samplesizesseen): 234 | print("Calibration stats") 235 | for metric in ['ece', 'nll', 'jsdiv']: 236 | print("Metric",metric) 237 | for calibname in calibname_to_calibfactory: 238 | print(calibname, np.mean( 239 | metric_to_samplesize_to_calibname_to_unshiftedvals[ 240 | metric][samplesize][calibname])) 241 | samplesizesseen.add(samplesize) 242 | 243 | print("On alpha",alpha,"sample size", samplesize) 244 | for metric_name in ['delta_acc', 'jsdiv', 245 | 'mseweights_testsetprop', 246 | 'mseweights_even']: 247 | 
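#for each adapter+calibration combination, print the median of this metric
#across trials, with the standard error of the mean as the +/- value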
print("Metric",metric_name) 248 | for adapter_name,calib_name in adaptncalib_pairs: 249 | adaptncalib_name = adapter_name+":"+calib_name 250 | n = len(alpha_to_samplesize_to_adaptername_to_metric_to_vals[ 251 | alpha][samplesize][adaptncalib_name][metric_name]) 252 | 253 | print(adaptncalib_name, np.median( 254 | alpha_to_samplesize_to_adaptername_to_metric_to_vals[ 255 | alpha][samplesize][adaptncalib_name][metric_name]), "+/-", 256 | (1.0/np.sqrt(n))*np.std( 257 | alpha_to_samplesize_to_adaptername_to_metric_to_vals[ 258 | alpha][samplesize][adaptncalib_name][metric_name], 259 | ddof=1)) 260 | sys.stdout.flush() 261 | 262 | return (alpha_to_samplesize_to_adaptername_to_metric_to_vals, 263 | alpha_to_samplesize_to_baselineacc, 264 | metric_to_samplesize_to_calibname_to_unshiftedvals) 265 | 266 | -------------------------------------------------------------------------------- /labelshiftexperiments/maketable.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from collections import defaultdict 3 | from abstention.figure_making_utils import ( 4 | wilcox_srs, get_ustats_mat, 5 | get_top_method_indices) 6 | 7 | def get_methodname_to_ranks(methodname_to_vals, methodnames, sortsign): 8 | methodname_to_ranks = defaultdict(list) 9 | for i in range(len(methodname_to_vals[methodnames[0]])): 10 | methodname_and_val = [ 11 | (x, methodname_to_vals[x][i]) for x in methodnames] 12 | rank_and_methodnameandval = enumerate( 13 | sorted(methodname_and_val, key=lambda x: sortsign*x[1])) 14 | methodname_and_rank = [(x[1][0], x[0]) 15 | for x in rank_and_methodnameandval] 16 | for methodname, rank in methodname_and_rank: 17 | methodname_to_ranks[methodname].append(rank) 18 | return methodname_to_ranks 19 | 20 | 21 | def stderr(vals): 22 | return (1.0/np.sqrt(len(vals)))*np.std(vals, ddof=1) 23 | 24 | 25 | def render_calibration_table( 26 | metric_to_samplesize_to_calibname_to_unshiftedvals, 27 | ustat_threshold, metrics_in_table, 28 | samplesizes_in_table, calibnames_in_table, 29 | metricname_to_nicename, calibname_to_nicename, caption, label, 30 | applyunderline, 31 | decimals=3): 32 | 33 | metric_to_samplesize_to_calibname_to_ranks = defaultdict( 34 | lambda: defaultdict(lambda: defaultdict(list))) 35 | metric_to_samplesize_to_bestmethods = defaultdict(lambda: dict()) 36 | metric_to_samplesize_to_toprankedmethod = defaultdict(lambda: dict()) 37 | 38 | for metricname in metrics_in_table: 39 | for samplesize in samplesizes_in_table: 40 | methodname_to_vals =\ 41 | metric_to_samplesize_to_calibname_to_unshiftedvals[metricname][samplesize] 42 | methodname_and_avgvals = [ 43 | (methodname, np.median(methodname_to_vals[methodname])) 44 | for methodname in calibnames_in_table] 45 | toprankedmethod = ( 46 | min(methodname_and_avgvals, key=lambda x: x[1])[0] 47 | if applyunderline else None) 48 | ustats_mat = get_ustats_mat( 49 | method_to_perfs=methodname_to_vals, 50 | method_names=calibnames_in_table) 51 | tied_top_methods = ( 52 | get_top_method_indices( 53 | sorting_metric_vals=[x[1] for x in methodname_and_avgvals], 54 | ustats_mat=ustats_mat, 55 | threshold=ustat_threshold, 56 | largerisbetter=False)) 57 | metric_to_samplesize_to_bestmethods[metricname][samplesize] = ( 58 | [calibnames_in_table[x] for x in tied_top_methods]) 59 | metric_to_samplesize_to_calibname_to_ranks[ 60 | metricname][samplesize] = ( 61 | get_methodname_to_ranks(methodname_to_vals=methodname_to_vals, 62 | methodnames=calibnames_in_table, 63 | sortsign=1)) 64 | 
metric_to_samplesize_to_toprankedmethod[ 65 | metricname][samplesize] = toprankedmethod 66 | 67 | toprint = (""" 68 | \\begin{table*} 69 | \\adjustbox{max width=\\textwidth}{ 70 | \\centering 71 | \\begin{tabular}{ c | """+" | ".join([" ".join(["c" for samplesize in samplesizes_in_table]) 72 | for metricname in metrics_in_table])+""" } 73 | \\multirow{2}{*}{\\begin{tabular}{c}\\textbf{Calibration} \\\\ \\textbf{Method} \\end{tabular}} & """ 74 | +(" & ".join(["\\multicolumn{"+str(len(samplesizes_in_table))+"}{| c}{"+metricname_to_nicename[metricname]+"}" 75 | for metricname in metrics_in_table]))+"""\\\\ 76 | \cline{2-"""+str(1+len(metrics_in_table)*len(samplesizes_in_table))+"""} 77 | & """+(" & ".join([" & ".join(["$n$="+str(samplesize) for samplesize in samplesizes_in_table]) 78 | for metricname in metrics_in_table]))+"\\\\\n \hline\n "+ 79 | "\n ".join([ 80 | calibname_to_nicename[calibname]+" & "+(" & ".join([ 81 | ("\\textbf{" if calibname in metric_to_samplesize_to_bestmethods[metricname][samplesize] else "") 82 | +("\\underline{" if calibname==metric_to_samplesize_to_toprankedmethod[metricname][samplesize] else "") 83 | +str(np.round(np.median(metric_to_samplesize_to_calibname_to_unshiftedvals[metricname][samplesize][calibname]), decimals=decimals)) 84 | #+" +/- " 85 | #+str(np.round(stderr(metric_to_samplesize_to_calibname_to_unshiftedvals[metricname][samplesize][calibname]), decimals=decimals)) 86 | +"; " 87 | +str(np.round(np.median(metric_to_samplesize_to_calibname_to_ranks[metricname][samplesize][calibname]), decimals=decimals)) 88 | #+" +/-" 89 | #+str(np.round(stderr(metric_to_samplesize_to_calibname_to_ranks[metricname][samplesize][calibname]), decimals=decimals)) 90 | +("}" if calibname==metric_to_samplesize_to_toprankedmethod[metricname][samplesize] else "") 91 | +("}" if calibname in metric_to_samplesize_to_bestmethods[metricname][samplesize] else "") 92 | for metricname in metrics_in_table for samplesize in samplesizes_in_table 93 | ]))+"\\\\" 94 | for calibname in calibnames_in_table 95 | ]) 96 | +""" 97 | \\end{tabular}} 98 | \\caption{"""+caption+"""} 99 | \\label{tab:"""+label+"""} 100 | \\end{table*} 101 | """) 102 | return toprint 103 | 104 | 105 | def render_adaptation_table( 106 | alpha_to_samplesize_to_adaptncalib_to_metric_to_vals, 107 | ustat_threshold, 108 | valmultiplier, 109 | adaptname_to_nicename, calibname_to_nicename, 110 | methodgroups, metric, largerisbetter, 111 | alphas_in_table, samplesizes_in_table, caption, label, 112 | applyunderline, 113 | symbol='\\alpha', 114 | decimals=3): 115 | 116 | methodgroupname_to_alpha_to_samplesize_to_bestmethods =\ 117 | defaultdict(lambda: defaultdict(lambda: {})) 118 | methodgroupname_to_alpha_to_samplesize_to_toprankedmethod =\ 119 | defaultdict(lambda: defaultdict(lambda: {})) 120 | methodgroupname_to_alpha_to_samplesize_to_methodname_to_ranks =\ 121 | defaultdict(lambda: defaultdict(lambda: {})) 122 | 123 | for methodgroupname in methodgroups: 124 | for alpha in alphas_in_table: 125 | for samplesize in samplesizes_in_table: 126 | methodname_to_vals = dict( 127 | [(methodname, 128 | alpha_to_samplesize_to_adaptncalib_to_metric_to_vals[ 129 | alpha][samplesize][methodname][metric]) 130 | for methodname in methodgroups[methodgroupname]]) 131 | methodname_and_avgvals = [ 132 | (methodname, np.median(methodname_to_vals[methodname])) 133 | for methodname in methodgroups[methodgroupname]] 134 | toprankedmethod = min(methodname_and_avgvals, 135 | key=lambda x: (-1 if largerisbetter else 1)*x[1])[0] 136 | ustats_mat = 
get_ustats_mat( 137 | method_to_perfs=methodname_to_vals, 138 | method_names=methodgroups[methodgroupname]) 139 | tied_top_methods = ( 140 | get_top_method_indices( 141 | sorting_metric_vals=[x[1] for x in methodname_and_avgvals], 142 | ustats_mat=ustats_mat, 143 | threshold=ustat_threshold, 144 | largerisbetter=largerisbetter)) 145 | 146 | methodgroupname_to_alpha_to_samplesize_to_bestmethods[ 147 | methodgroupname][alpha][samplesize] = ( 148 | [methodgroups[methodgroupname][x] for x in tied_top_methods]) 149 | methodgroupname_to_alpha_to_samplesize_to_toprankedmethod[ 150 | methodgroupname][alpha][samplesize] = ( 151 | toprankedmethod if applyunderline else None) 152 | methodgroupname_to_alpha_to_samplesize_to_methodname_to_ranks[ 153 | methodgroupname][alpha][samplesize] = ( 154 | get_methodname_to_ranks( 155 | methodname_to_vals=methodname_to_vals, 156 | methodnames=methodgroups[methodgroupname], 157 | sortsign=(-1 if largerisbetter else 1))) 158 | 159 | toprint = (""" 160 | \\begin{table*} 161 | \\adjustbox{max width=\\textwidth}{ 162 | \\centering 163 | \\begin{tabular}{ c | c | """+(" | ".join([ " ".join( ["c" for samplesize in samplesizes_in_table ] ) for alpha in alphas_in_table]))+"}\n") 164 | toprint += (" \\multirow{2}{*}{\\begin{tabular}{c}\\textbf{Shift} \\\\ \\textbf{Estimator} \\end{tabular}}" 165 | +" & \\multirow{2}{*}{\\begin{tabular}{c}\\textbf{Calibration} \\\\ \\textbf{Method} \\end{tabular}} & " 166 | +((" & ".join(["\\multicolumn{"+str(len(samplesizes_in_table))+"}{| c}{$"+symbol+"="+str(alpha)+"$}" 167 | for alpha in alphas_in_table]))+"\\\\ \n") 168 | +" \\cline{3-"+str(2+len(alphas_in_table)*len(samplesizes_in_table))+"}\n" 169 | +" & & "+(" & ".join([" & ".join(["$n$="+str(samplesize) for samplesize in samplesizes_in_table]) 170 | for alpha in alphas_in_table]))+"\\\\") 171 | #toprint += " \\hline \\hline" 172 | for methodgroupnum, methodgroupname in enumerate(methodgroups.keys()): 173 | #if (methodgroupnum > 0): 174 | toprint += "\n \\hline\n \\hline" 175 | for adaptncalib in methodgroups[methodgroupname]: 176 | adaptname = adaptncalib.split(":")[0] 177 | calibname = adaptncalib.split(":")[1] 178 | toprint += "\n " 179 | toprint += adaptname_to_nicename[adaptname] 180 | toprint += " & "+calibname_to_nicename[calibname] 181 | toprint += " & " 182 | toprint += " & ".join([ 183 | ("\\textbf{" if adaptncalib in methodgroupname_to_alpha_to_samplesize_to_bestmethods[methodgroupname][alpha][samplesize] else "") 184 | +("\\underline{" if adaptncalib==methodgroupname_to_alpha_to_samplesize_to_toprankedmethod[methodgroupname][alpha][samplesize] else "") 185 | +str(np.round(valmultiplier*np.median(alpha_to_samplesize_to_adaptncalib_to_metric_to_vals[alpha][samplesize][adaptncalib][metric]), decimals=decimals)) 186 | #+" +/- " 187 | #+str(np.round(stderr(alpha_to_samplesize_to_adaptncalib_to_metric_to_vals[alpha][samplesize][adaptncalib][metric]), decimals=decimals)) 188 | +"; " 189 | +str(np.round(np.median(methodgroupname_to_alpha_to_samplesize_to_methodname_to_ranks[methodgroupname][alpha][samplesize][adaptncalib]), decimals=decimals)) 190 | #+" +/-" 191 | #+str(np.round(stderr(methodgroupname_to_alpha_to_samplesize_to_methodname_to_ranks[methodgroupname][alpha][samplesize][adaptncalib]), decimals=decimals)) 192 | +("}" if adaptncalib==methodgroupname_to_alpha_to_samplesize_to_toprankedmethod[methodgroupname][alpha][samplesize] else "") 193 | +("}" if adaptncalib in methodgroupname_to_alpha_to_samplesize_to_bestmethods[methodgroupname][alpha][samplesize] else "") 194 | 
for alpha in alphas_in_table for samplesize in samplesizes_in_table ]) 195 | toprint += "\\\\" 196 | 197 | toprint += """ 198 | \\end{tabular}} 199 | \\caption{"""+caption+"""} 200 | \\label{tab:"""+label+"""} 201 | \\end{table*} 202 | """ 203 | return toprint 204 | -------------------------------------------------------------------------------- /notebooks/README.txt: -------------------------------------------------------------------------------- 1 | The notebooks in this folder replicate the results starting from saved model predictions that were uploaded to zenodo. See the code under the folder obtaining_predictions for how the models/predictions were obtained. The .json.gz files save the results of the label shift adaptation for ease of regenerating the final tables. 2 | -------------------------------------------------------------------------------- /notebooks/blog_colab.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/kundajelab/labelshiftexperiments/blob/master/notebooks/blog_colab.ipynb)" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [ 15 | { 16 | "name": "stdout", 17 | "output_type": "stream", 18 | "text": [ 19 | "--2020-11-20 17:23:44-- https://zenodo.org/record/3406662/files/test_labels.txt.gz?download?=1\n", 20 | "Resolving zenodo.org (zenodo.org)... 137.138.76.77\n", 21 | "Connecting to zenodo.org (zenodo.org)|137.138.76.77|:443... connected.\n", 22 | "HTTP request sent, awaiting response... 200 OK\n", 23 | "Length: 6001 (5.9K) [application/octet-stream]\n", 24 | "Saving to: ‘test_labels.txt.gz’\n", 25 | "\n", 26 | "test_labels.txt.gz 100%[===================>] 5.86K --.-KB/s in 0s \n", 27 | "\n", 28 | "2020-11-20 17:23:45 (498 MB/s) - ‘test_labels.txt.gz’ saved [6001/6001]\n", 29 | "\n", 30 | "--2020-11-20 17:23:45-- https://zenodo.org/record/3406662/files/valid_labels.txt.gz?download?=1\n", 31 | "Resolving zenodo.org (zenodo.org)... 137.138.76.77\n", 32 | "Connecting to zenodo.org (zenodo.org)|137.138.76.77|:443... connected.\n", 33 | "HTTP request sent, awaiting response... 200 OK\n", 34 | "Length: 5969 (5.8K) [application/octet-stream]\n", 35 | "Saving to: ‘valid_labels.txt.gz’\n", 36 | "\n", 37 | "valid_labels.txt.gz 100%[===================>] 5.83K --.-KB/s in 0s \n", 38 | "\n", 39 | "2020-11-20 17:23:46 (531 MB/s) - ‘valid_labels.txt.gz’ saved [5969/5969]\n", 40 | "\n", 41 | "--2020-11-20 17:23:46-- https://zenodo.org/record/3406662/files/testpreacts_model_cifar10_balanced_seed-0_bestbefore-100_currentepoch-100_valacc-91_vgg.txt.gz?download=1\n", 42 | "Resolving zenodo.org (zenodo.org)... 137.138.76.77\n", 43 | "Connecting to zenodo.org (zenodo.org)|137.138.76.77|:443... connected.\n", 44 | "HTTP request sent, awaiting response... 
200 OK\n", 45 | "Length: 477575 (466K) [application/octet-stream]\n", 46 | "Saving to: ‘testpreacts_model_cifar10_balanced_seed-0_bestbefore-100_currentepoch-100_valacc-91_vgg.txt.gz’\n", 47 | "\n", 48 | "testpreacts_model_c 100%[===================>] 466.38K 594KB/s in 0.8s \n", 49 | "\n", 50 | "2020-11-20 17:23:48 (594 KB/s) - ‘testpreacts_model_cifar10_balanced_seed-0_bestbefore-100_currentepoch-100_valacc-91_vgg.txt.gz’ saved [477575/477575]\n", 51 | "\n", 52 | "--2020-11-20 17:23:48-- https://zenodo.org/record/3406662/files/validpreacts_model_cifar10_balanced_seed-0_bestbefore-100_currentepoch-100_valacc-91_vgg.txt.gz?download=1\n", 53 | "Resolving zenodo.org (zenodo.org)... 137.138.76.77\n", 54 | "Connecting to zenodo.org (zenodo.org)|137.138.76.77|:443... connected.\n", 55 | "HTTP request sent, awaiting response... 200 OK\n", 56 | "Length: 477745 (467K) [application/octet-stream]\n", 57 | "Saving to: ‘validpreacts_model_cifar10_balanced_seed-0_bestbefore-100_currentepoch-100_valacc-91_vgg.txt.gz’\n", 58 | "\n", 59 | "validpreacts_model_ 100%[===================>] 466.55K 592KB/s in 0.8s \n", 60 | "\n", 61 | "2020-11-20 17:23:50 (592 KB/s) - ‘validpreacts_model_cifar10_balanced_seed-0_bestbefore-100_currentepoch-100_valacc-91_vgg.txt.gz’ saved [477745/477745]\n", 62 | "\n" 63 | ] 64 | } 65 | ], 66 | "source": [ 67 | "!wget https://zenodo.org/record/3406662/files/test_labels.txt.gz?download?=1 -O test_labels.txt.gz\n", 68 | "!wget https://zenodo.org/record/3406662/files/valid_labels.txt.gz?download?=1 -O valid_labels.txt.gz\n", 69 | "!wget https://zenodo.org/record/3406662/files/testpreacts_model_cifar10_balanced_seed-0_bestbefore-100_currentepoch-100_valacc-91_vgg.txt.gz?download=1 -O testpreacts_model_cifar10_balanced_seed-0_bestbefore-100_currentepoch-100_valacc-91_vgg.txt.gz\n", 70 | "!wget https://zenodo.org/record/3406662/files/validpreacts_model_cifar10_balanced_seed-0_bestbefore-100_currentepoch-100_valacc-91_vgg.txt.gz?download=1 -O validpreacts_model_cifar10_balanced_seed-0_bestbefore-100_currentepoch-100_valacc-91_vgg.txt.gz" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": 2, 76 | "metadata": {}, 77 | "outputs": [], 78 | "source": [ 79 | "import gzip\n", 80 | "import glob\n", 81 | "import numpy as np\n", 82 | "from collections import defaultdict\n", 83 | "from abstention.calibration import softmax, TempScaling\n", 84 | "from abstention.label_shift import EMImbalanceAdapter" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": 3, 90 | "metadata": {}, 91 | "outputs": [], 92 | "source": [ 93 | "def read_labels(fh):\n", 94 | " to_return = []\n", 95 | " for line in fh:\n", 96 | " the_class=int(line.rstrip())\n", 97 | " to_add = np.zeros(10)\n", 98 | " to_add[the_class] = 1\n", 99 | " to_return.append(to_add)\n", 100 | " return np.array(to_return)\n", 101 | "\n", 102 | "test_labels = read_labels(gzip.open(glob.glob(\"test_labels.txt.gz\")[0]))\n", 103 | "valid_labels = read_labels(gzip.open(glob.glob(\"valid_labels.txt.gz\")[0]))\n", 104 | "\n", 105 | "def read_preds(fh):\n", 106 | " return np.array([[float(x) for x in y.decode(\"utf-8\").rstrip().split(\"\\t\")]\n", 107 | " for y in fh])\n", 108 | "\n", 109 | "test_preds = softmax(preact=read_preds(gzip.open(glob.glob(\"testpreacts_model_cifar10_balanced_seed-0_bestbefore-100_currentepoch-100_valacc-91_vgg.txt.gz\")[0])),\n", 110 | " temp=1, biases=None)\n", 111 | "valid_preds = 
softmax(preact=read_preds(gzip.open(glob.glob(\"validpreacts_model_cifar10_balanced_seed-0_bestbefore-100_currentepoch-100_valacc-91_vgg.txt.gz\")[0])),\n", 112 | " temp=1, biases=None)" 113 | ] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": 4, 118 | "metadata": {}, 119 | "outputs": [], 120 | "source": [ 121 | "def sample_from_probs_arr(arr_with_probs):\n", 122 | " rand_num = np.random.random()\n", 123 | " cdf_so_far = 0\n", 124 | " for (idx, prob) in enumerate(arr_with_probs):\n", 125 | " cdf_so_far += prob\n", 126 | " if (cdf_so_far >= rand_num\n", 127 | " or idx == (len(arr_with_probs) - 1)): # need the\n", 128 | " # letterIdx==(len(row)-1) clause because of potential floating point errors\n", 129 | " # that mean arrWithProbs doesn't sum to 1\n", 130 | " return idx\n", 131 | " \n", 132 | "test_class_to_indices = defaultdict(list)\n", 133 | "for index,row in enumerate(test_labels):\n", 134 | " row_label = np.argmax(row)\n", 135 | " test_class_to_indices[row_label].append(index)\n", 136 | "\n", 137 | "def draw_test_indices(total_to_return, label_proportions):\n", 138 | " indices_to_use = []\n", 139 | " for class_index, class_proportion in enumerate(label_proportions):\n", 140 | " indices_to_use.extend(np.random.choice(\n", 141 | " test_class_to_indices[class_index],\n", 142 | " int(total_to_return*class_proportion),\n", 143 | " replace=True))\n", 144 | " for i in range(total_to_return-len(indices_to_use)):\n", 145 | " class_index = sample_from_probs_arr(label_proportions)\n", 146 | " indices_to_use.append(\n", 147 | " np.random.choice(test_class_to_indices[class_index]))\n", 148 | " return indices_to_use\n", 149 | "\n", 150 | "dirichlet_alpha = 0.1\n", 151 | "samplesize = 1000\n", 152 | "dirichlet_dist = np.random.dirichlet([dirichlet_alpha for x in range(10)])\n", 153 | "test_indices = draw_test_indices(total_to_return=samplesize,\n", 154 | " label_proportions=dirichlet_dist)\n", 155 | "shifted_test_labels = test_labels[test_indices]\n", 156 | "shifted_test_preds = test_preds[test_indices]" 157 | ] 158 | }, 159 | { 160 | "cell_type": "code", 161 | "execution_count": 5, 162 | "metadata": {}, 163 | "outputs": [], 164 | "source": [ 165 | "imbalance_adapter = EMImbalanceAdapter(calibrator_factory =\n", 166 | " TempScaling(verbose=False, bias_positions='all'))\n", 167 | "\n", 168 | "imbalance_adapter_func = imbalance_adapter(valid_labels=valid_labels,\n", 169 | " tofit_initial_posterior_probs=shifted_test_preds,\n", 170 | " valid_posterior_probs=valid_preds)\n", 171 | "\n", 172 | "adapted_shifted_test_preds = imbalance_adapter_func(shifted_test_preds)" 173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": 6, 178 | "metadata": {}, 179 | "outputs": [ 180 | { 181 | "name": "stdout", 182 | "output_type": "stream", 183 | "text": [ 184 | "0.903 0.961\n" 185 | ] 186 | } 187 | ], 188 | "source": [ 189 | "test_accuracy = np.mean(np.argmax(shifted_test_labels,axis=-1)==np.argmax(shifted_test_preds,axis=-1))\n", 190 | "adapted_test_accuracy = np.mean(np.argmax(shifted_test_labels,axis=-1)==np.argmax(adapted_shifted_test_preds,axis=-1))\n", 191 | "print(test_accuracy, adapted_test_accuracy)" 192 | ] 193 | } 194 | ], 195 | "metadata": { 196 | "kernelspec": { 197 | "display_name": "Python [conda env:basepair]", 198 | "language": "python", 199 | "name": "conda-env-basepair-py" 200 | }, 201 | "language_info": { 202 | "codemirror_mode": { 203 | "name": "ipython", 204 | "version": 3 205 | }, 206 | "file_extension": ".py", 207 | "mimetype": "text/x-python", 208 | 
"name": "python", 209 | "nbconvert_exporter": "python", 210 | "pygments_lexer": "ipython3", 211 | "version": "3.6.8" 212 | } 213 | }, 214 | "nbformat": 4, 215 | "nbformat_minor": 4 216 | } 217 | -------------------------------------------------------------------------------- /notebooks/cifar100_label_shift_adaptation_results.json.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kundajelab/labelshiftexperiments/5524cb4674df1f8782fa5a12d1973c573b30a8e1/notebooks/cifar100_label_shift_adaptation_results.json.gz -------------------------------------------------------------------------------- /notebooks/cifar10_label_shift_adaptation_results.json.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kundajelab/labelshiftexperiments/5524cb4674df1f8782fa5a12d1973c573b30a8e1/notebooks/cifar10_label_shift_adaptation_results.json.gz -------------------------------------------------------------------------------- /notebooks/demo/blog_colab.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "kernelspec": { 6 | "name": "python3", 7 | "display_name": "Python 3" 8 | }, 9 | "colab": { 10 | "name": "blog_colab.ipynb", 11 | "provenance": [], 12 | "toc_visible": true, 13 | "include_colab_link": true 14 | } 15 | }, 16 | "cells": [ 17 | { 18 | "cell_type": "markdown", 19 | "metadata": { 20 | "id": "view-in-github", 21 | "colab_type": "text" 22 | }, 23 | "source": [ 24 | "\"Open" 25 | ] 26 | }, 27 | { 28 | "cell_type": "markdown", 29 | "metadata": { 30 | "id": "MEGggxiZWf1R" 31 | }, 32 | "source": [ 33 | "## Maximum Likelihood + Bias-Corrected Temperature Scaling\n", 34 | "\n", 35 | "This notebook demonstrates how to perform label shift domain adaptation using " 36 | ] 37 | }, 38 | { 39 | "cell_type": "markdown", 40 | "metadata": { 41 | "id": "WDSQi_3fzlnj" 42 | }, 43 | "source": [ 44 | "### Setup\n", 45 | "\n", 46 | "Download the datasets" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "metadata": { 52 | "id": "lFPf85JSWOfc", 53 | "colab": { 54 | "base_uri": "https://localhost:8080/" 55 | }, 56 | "outputId": "bf526096-d0c5-4a62-f3f3-737dd3bdb12e" 57 | }, 58 | "source": [ 59 | "!wget https://raw.github.com/kundajelab/labelshiftexperiments/master/notebooks/demo/demo_valid_labels.txt.gz -O demo_valid_labels.txt.gz\n", 60 | "!wget https://raw.github.com/kundajelab/labelshiftexperiments/master/notebooks/demo/demo_shifted_test_preds.txt.gz -O demo_shifted_test_preds.txt.gz\n", 61 | "!wget https://raw.github.com/kundajelab/labelshiftexperiments/master/notebooks/demo/demo_valid_preds.txt.gz -O demo_valid_preds.txt.gz" 62 | ], 63 | "execution_count": 1, 64 | "outputs": [ 65 | { 66 | "output_type": "stream", 67 | "text": [ 68 | "--2020-11-22 02:56:31-- https://raw.github.com/kundajelab/labelshiftexperiments/master/notebooks/demo/demo_valid_labels.txt.gz\n", 69 | "Resolving raw.github.com (raw.github.com)... 151.101.0.133, 151.101.64.133, 151.101.128.133, ...\n", 70 | "Connecting to raw.github.com (raw.github.com)|151.101.0.133|:443... connected.\n", 71 | "HTTP request sent, awaiting response... 
301 Moved Permanently\n", 72 | "Location: https://raw.githubusercontent.com/kundajelab/labelshiftexperiments/master/notebooks/demo/demo_valid_labels.txt.gz [following]\n", 73 | "--2020-11-22 02:56:31-- https://raw.githubusercontent.com/kundajelab/labelshiftexperiments/master/notebooks/demo/demo_valid_labels.txt.gz\n", 74 | "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 151.101.0.133, 151.101.64.133, 151.101.128.133, ...\n", 75 | "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|151.101.0.133|:443... connected.\n", 76 | "HTTP request sent, awaiting response... 200 OK\n", 77 | "Length: 5969 (5.8K) [application/octet-stream]\n", 78 | "Saving to: ‘demo_valid_labels.txt.gz’\n", 79 | "\n", 80 | "demo_valid_labels.t 100%[===================>] 5.83K --.-KB/s in 0s \n", 81 | "\n", 82 | "2020-11-22 02:56:31 (51.0 MB/s) - ‘demo_valid_labels.txt.gz’ saved [5969/5969]\n", 83 | "\n", 84 | "--2020-11-22 02:56:31-- https://raw.github.com/kundajelab/labelshiftexperiments/master/notebooks/demo/demo_shifted_test_preds.txt.gz\n", 85 | "Resolving raw.github.com (raw.github.com)... 151.101.0.133, 151.101.64.133, 151.101.128.133, ...\n", 86 | "Connecting to raw.github.com (raw.github.com)|151.101.0.133|:443... connected.\n", 87 | "HTTP request sent, awaiting response... 301 Moved Permanently\n", 88 | "Location: https://raw.githubusercontent.com/kundajelab/labelshiftexperiments/master/notebooks/demo/demo_shifted_test_preds.txt.gz [following]\n", 89 | "--2020-11-22 02:56:32-- https://raw.githubusercontent.com/kundajelab/labelshiftexperiments/master/notebooks/demo/demo_shifted_test_preds.txt.gz\n", 90 | "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 151.101.0.133, 151.101.64.133, 151.101.128.133, ...\n", 91 | "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|151.101.0.133|:443... connected.\n", 92 | "HTTP request sent, awaiting response... 200 OK\n", 93 | "Length: 85108 (83K) [application/octet-stream]\n", 94 | "Saving to: ‘demo_shifted_test_preds.txt.gz’\n", 95 | "\n", 96 | "demo_shifted_test_p 100%[===================>] 83.11K --.-KB/s in 0.02s \n", 97 | "\n", 98 | "2020-11-22 02:56:32 (3.85 MB/s) - ‘demo_shifted_test_preds.txt.gz’ saved [85108/85108]\n", 99 | "\n", 100 | "--2020-11-22 02:56:32-- https://raw.github.com/kundajelab/labelshiftexperiments/master/notebooks/demo/demo_valid_preds.txt.gz\n", 101 | "Resolving raw.github.com (raw.github.com)... 151.101.0.133, 151.101.64.133, 151.101.128.133, ...\n", 102 | "Connecting to raw.github.com (raw.github.com)|151.101.0.133|:443... connected.\n", 103 | "HTTP request sent, awaiting response... 301 Moved Permanently\n", 104 | "Location: https://raw.githubusercontent.com/kundajelab/labelshiftexperiments/master/notebooks/demo/demo_valid_preds.txt.gz [following]\n", 105 | "--2020-11-22 02:56:32-- https://raw.githubusercontent.com/kundajelab/labelshiftexperiments/master/notebooks/demo/demo_valid_preds.txt.gz\n", 106 | "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 151.101.0.133, 151.101.64.133, 151.101.128.133, ...\n", 107 | "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|151.101.0.133|:443... connected.\n", 108 | "HTTP request sent, awaiting response... 
200 OK\n", 109 | "Length: 959677 (937K) [application/octet-stream]\n", 110 | "Saving to: ‘demo_valid_preds.txt.gz’\n", 111 | "\n", 112 | "demo_valid_preds.tx 100%[===================>] 937.18K --.-KB/s in 0.06s \n", 113 | "\n", 114 | "2020-11-22 02:56:33 (16.4 MB/s) - ‘demo_valid_preds.txt.gz’ saved [959677/959677]\n", 115 | "\n" 116 | ], 117 | "name": "stdout" 118 | } 119 | ] 120 | }, 121 | { 122 | "cell_type": "markdown", 123 | "metadata": { 124 | "id": "toUqiI6MXvVh" 125 | }, 126 | "source": [ 127 | "Install the necessary package" 128 | ] 129 | }, 130 | { 131 | "cell_type": "code", 132 | "metadata": { 133 | "colab": { 134 | "base_uri": "https://localhost:8080/" 135 | }, 136 | "id": "MTr1k2e-Xwym", 137 | "outputId": "e2d6bc2f-5e88-4d77-d574-0b1875efbe8a" 138 | }, 139 | "source": [ 140 | "!pip install abstention" 141 | ], 142 | "execution_count": 2, 143 | "outputs": [ 144 | { 145 | "output_type": "stream", 146 | "text": [ 147 | "Collecting abstention\n", 148 | " Downloading https://files.pythonhosted.org/packages/c2/cb/b9a4ef4a0efecf1ac74fc12a459f05d17dc76ebba9c9ee1c62b9d651bb18/abstention-0.1.3.1.tar.gz\n", 149 | "Requirement already satisfied: numpy>=1.9 in /usr/local/lib/python3.6/dist-packages (from abstention) (1.18.5)\n", 150 | "Requirement already satisfied: scikit-learn>=0.20.0 in /usr/local/lib/python3.6/dist-packages (from abstention) (0.22.2.post1)\n", 151 | "Requirement already satisfied: scipy>=1.1.0 in /usr/local/lib/python3.6/dist-packages (from abstention) (1.4.1)\n", 152 | "Requirement already satisfied: joblib>=0.11 in /usr/local/lib/python3.6/dist-packages (from scikit-learn>=0.20.0->abstention) (0.17.0)\n", 153 | "Building wheels for collected packages: abstention\n", 154 | " Building wheel for abstention (setup.py) ... \u001b[?25l\u001b[?25hdone\n", 155 | " Created wheel for abstention: filename=abstention-0.1.3.1-cp36-none-any.whl size=25470 sha256=f16debecbfdee13d197c22ed52c9a349d2bb8822780917b58c1f7741e5f9de71\n", 156 | " Stored in directory: /root/.cache/pip/wheels/7c/a8/fc/5ddf92c0e5934d70543ea30142078287d911f01e75cffb808c\n", 157 | "Successfully built abstention\n", 158 | "Installing collected packages: abstention\n", 159 | "Successfully installed abstention-0.1.3.1\n" 160 | ], 161 | "name": "stdout" 162 | } 163 | ] 164 | }, 165 | { 166 | "cell_type": "markdown", 167 | "metadata": { 168 | "id": "Q2QqBu8Uy8D9" 169 | }, 170 | "source": [ 171 | "Import relevant modules and define functions for reading in the data" 172 | ] 173 | }, 174 | { 175 | "cell_type": "code", 176 | "metadata": { 177 | "id": "n3U_XH8bWOfc" 178 | }, 179 | "source": [ 180 | "import gzip\n", 181 | "import numpy as np\n", 182 | "from collections import defaultdict\n", 183 | "from scipy.special import softmax\n", 184 | "\n", 185 | "def read_labels(fh):\n", 186 | " to_return = []\n", 187 | " for line in fh:\n", 188 | " the_class=int(line.rstrip())\n", 189 | " to_add = np.zeros(10)\n", 190 | " to_add[the_class] = 1\n", 191 | " to_return.append(to_add)\n", 192 | " return np.array(to_return)\n", 193 | "\n", 194 | "def read_preds(fh):\n", 195 | " return np.array([[float(x) for x in y.decode(\"utf-8\").rstrip().split(\"\\t\")]\n", 196 | " for y in fh])" 197 | ], 198 | "execution_count": 3, 199 | "outputs": [] 200 | }, 201 | { 202 | "cell_type": "markdown", 203 | "metadata": { 204 | "id": "OBEvqeu84Ll3" 205 | }, 206 | "source": [ 207 | "Read in the validation set predictions and labels, as well as the predictions on the (label shifted) test set" 208 | ] 209 | }, 210 | { 211 | "cell_type": "code", 212 | 
"metadata": { 213 | "id": "tGqi1Xub4OS_" 214 | }, 215 | "source": [ 216 | "valid_labels = read_labels(gzip.open(\"demo_valid_labels.txt.gz\", \"rb\"))\n", 217 | "valid_preds = read_preds(gzip.open(\"demo_valid_preds.txt.gz\", \"rb\"))\n", 218 | "shifted_test_preds = read_preds(gzip.open(\"demo_shifted_test_preds.txt.gz\", \"rb\"))" 219 | ], 220 | "execution_count": 4, 221 | "outputs": [] 222 | }, 223 | { 224 | "cell_type": "markdown", 225 | "metadata": { 226 | "id": "WvsUpsRHznpJ" 227 | }, 228 | "source": [ 229 | "### Perform label shift adaptation\n", 230 | "\n", 231 | "Apply Maximum Likelihood + BCTS" 232 | ] 233 | }, 234 | { 235 | "cell_type": "code", 236 | "metadata": { 237 | "id": "QA8EnUvcWOfd" 238 | }, 239 | "source": [ 240 | "from abstention.calibration import TempScaling\n", 241 | "from abstention.label_shift import EMImbalanceAdapter\n", 242 | "\n", 243 | "#Instantiate the BCTS calibrator factory\n", 244 | "bcts_calibrator_factory = TempScaling(verbose=False, bias_positions='all')\n", 245 | "#Specify that we would like to use Maximum Likelihood (EM) for the\n", 246 | "# label shift adaptation, with BCTS for calibration\n", 247 | "imbalance_adapter = EMImbalanceAdapter(calibrator_factory=\n", 248 | " bcts_calibrator_factory)\n", 249 | "#Get the function that will do the label shift adaptation (creating this\n", 250 | "# function requires supplying the validation set labels/predictions as well as\n", 251 | "# the test-set predictions)\n", 252 | "imbalance_adapter_func = imbalance_adapter(valid_labels=valid_labels,\n", 253 | " tofit_initial_posterior_probs=shifted_test_preds,\n", 254 | " valid_posterior_probs=valid_preds)\n", 255 | "#Get the adapted test-set predictions\n", 256 | "adapted_shifted_test_preds = imbalance_adapter_func(shifted_test_preds)" 257 | ], 258 | "execution_count": 5, 259 | "outputs": [] 260 | }, 261 | { 262 | "cell_type": "markdown", 263 | "metadata": { 264 | "id": "xdVWdeSmz1RB" 265 | }, 266 | "source": [ 267 | "### Evaluation\n", 268 | "\n", 269 | "Download and read in the labels for the test set\n" 270 | ] 271 | }, 272 | { 273 | "cell_type": "code", 274 | "metadata": { 275 | "id": "LtrgofVa0mbL", 276 | "colab": { 277 | "base_uri": "https://localhost:8080/" 278 | }, 279 | "outputId": "4f33a5d4-98af-44cb-c2d7-378b974e5d3d" 280 | }, 281 | "source": [ 282 | "!wget https://raw.github.com/kundajelab/labelshiftexperiments/master/notebooks/demo/demo_shifted_test_labels.txt.gz -O demo_shifted_test_labels.txt.gz\n", 283 | "\n", 284 | "shifted_test_labels = read_labels(gzip.open(\"demo_shifted_test_labels.txt.gz\", \"rb\"))" 285 | ], 286 | "execution_count": 6, 287 | "outputs": [ 288 | { 289 | "output_type": "stream", 290 | "text": [ 291 | "--2020-11-22 02:56:39-- https://raw.github.com/kundajelab/labelshiftexperiments/master/notebooks/demo/demo_shifted_test_labels.txt.gz\n", 292 | "Resolving raw.github.com (raw.github.com)... 151.101.0.133, 151.101.64.133, 151.101.128.133, ...\n", 293 | "Connecting to raw.github.com (raw.github.com)|151.101.0.133|:443... connected.\n", 294 | "HTTP request sent, awaiting response... 301 Moved Permanently\n", 295 | "Location: https://raw.githubusercontent.com/kundajelab/labelshiftexperiments/master/notebooks/demo/demo_shifted_test_labels.txt.gz [following]\n", 296 | "--2020-11-22 02:56:40-- https://raw.githubusercontent.com/kundajelab/labelshiftexperiments/master/notebooks/demo/demo_shifted_test_labels.txt.gz\n", 297 | "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 
151.101.0.133, 151.101.64.133, 151.101.128.133, ...\n", 298 | "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|151.101.0.133|:443... connected.\n", 299 | "HTTP request sent, awaiting response... 200 OK\n", 300 | "Length: 71 [application/octet-stream]\n", 301 | "Saving to: ‘demo_shifted_test_labels.txt.gz’\n", 302 | "\n", 303 | "demo_shifted_test_l 100%[===================>] 71 --.-KB/s in 0s \n", 304 | "\n", 305 | "2020-11-22 02:56:40 (3.35 MB/s) - ‘demo_shifted_test_labels.txt.gz’ saved [71/71]\n", 306 | "\n" 307 | ], 308 | "name": "stdout" 309 | } 310 | ] 311 | }, 312 | { 313 | "cell_type": "markdown", 314 | "metadata": { 315 | "id": "1l6uFuwf0pGr" 316 | }, 317 | "source": [ 318 | "Evaluate the improvement in performance due to domain adaptation" 319 | ] 320 | }, 321 | { 322 | "cell_type": "code", 323 | "metadata": { 324 | "colab": { 325 | "base_uri": "https://localhost:8080/" 326 | }, 327 | "id": "4zsm221-WOfd", 328 | "outputId": "e4d9b95d-69f2-47ff-8201-d845c11a81a5" 329 | }, 330 | "source": [ 331 | "#Get the test set accuracy WITHOUT label shift adaptation\n", 332 | "unadapted_test_accuracy = np.mean(np.argmax(shifted_test_labels,axis=-1)==np.argmax(shifted_test_preds,axis=-1))\n", 333 | "#Get the test-set accuracy WITH label shift adaptation\n", 334 | "adapted_test_accuracy = np.mean(np.argmax(shifted_test_labels,axis=-1)==np.argmax(adapted_shifted_test_preds,axis=-1))\n", 335 | "\n", 336 | "print(\"Accuracy without label shift adaptation:\", unadapted_test_accuracy)\n", 337 | "print(\"Accuracy with label shift adaptation:\", adapted_test_accuracy)" 338 | ], 339 | "execution_count": 7, 340 | "outputs": [ 341 | { 342 | "output_type": "stream", 343 | "text": [ 344 | "Accuracy without label shift adaptation: 0.707\n", 345 | "Accuracy with label shift adaptation: 0.986\n" 346 | ], 347 | "name": "stdout" 348 | } 349 | ] 350 | }, 351 | { 352 | "cell_type": "markdown", 353 | "metadata": { 354 | "id": "ncoOivjZqMl4" 355 | }, 356 | "source": [ 357 | "## Misc\n", 358 | "\n", 359 | "This is the code that was used to generate the `demo_*` files\n", 360 | "\n", 361 | "```\n", 362 | "import gzip\n", 363 | "import glob\n", 364 | "import numpy as np\n", 365 | "from collections import defaultdict\n", 366 | "from scipy.special import softmax\n", 367 | "\n", 368 | "\n", 369 | "def sample_from_probs_arr(arr_with_probs):\n", 370 | " rand_num = np.random.random()\n", 371 | " cdf_so_far = 0\n", 372 | " for (idx, prob) in enumerate(arr_with_probs):\n", 373 | " cdf_so_far += prob\n", 374 | " if (cdf_so_far >= rand_num\n", 375 | " or idx == (len(arr_with_probs) - 1)): # need the\n", 376 | " # letterIdx==(len(row)-1) clause because of potential floating point errors\n", 377 | " # that mean arrWithProbs doesn't sum to 1\n", 378 | " return idx\n", 379 | "\n", 380 | "\n", 381 | "def draw_test_indices(total_to_return, label_proportions):\n", 382 | " indices_to_use = []\n", 383 | " for class_index, class_proportion in enumerate(label_proportions):\n", 384 | " indices_to_use.extend(np.random.choice(\n", 385 | " TEST_CLASS_TO_INDICES[class_index],\n", 386 | " int(total_to_return*class_proportion),\n", 387 | " replace=True))\n", 388 | " for i in range(total_to_return-len(indices_to_use)):\n", 389 | " class_index = sample_from_probs_arr(label_proportions)\n", 390 | " indices_to_use.append(\n", 391 | " np.random.choice(TEST_CLASS_TO_INDICES[class_index]))\n", 392 | " return indices_to_use\n", 393 | "\n", 394 | "\n", 395 | "def write_preds(preds, filename):\n", 396 | " f = open(filename,'w')\n", 397 | 
" for pred in preds:\n", 398 | " f.write(\"\\t\".join([str(x) for x in pred])+\"\\n\") \n", 399 | " f.close()\n", 400 | "\n", 401 | "\n", 402 | "def write_labels(labels, filename):\n", 403 | " f = open(filename,'w')\n", 404 | " f.write(\"\\n\".join([str(np.argmax(x, axis=-1)) for x in labels]))\n", 405 | " f.close()\n", 406 | "\n", 407 | "\n", 408 | "def read_labels(fh):\n", 409 | " to_return = []\n", 410 | " for line in fh:\n", 411 | " the_class=int(line.rstrip())\n", 412 | " to_add = np.zeros(10)\n", 413 | " to_add[the_class] = 1\n", 414 | " to_return.append(to_add)\n", 415 | " return np.array(to_return)\n", 416 | "\n", 417 | "\n", 418 | "def read_preds(fh):\n", 419 | " return np.array([[float(x) for x in y.decode(\"utf-8\").rstrip().split(\"\\t\")]\n", 420 | " for y in fh])\n", 421 | "\n", 422 | "\n", 423 | "!wget https://zenodo.org/record/3406662/files/test_labels.txt.gz?download?=1 -O test_labels.txt.gz\n", 424 | "!wget https://zenodo.org/record/3406662/files/testpreacts_model_cifar10_balanced_seed-0_bestbefore-100_currentepoch-100_valacc-91_vgg.txt.gz?download=1 -O testpreacts_model_cifar10_balanced_seed-0_bestbefore-100_currentepoch-100_valacc-91_vgg.txt.gz\n", 425 | "!wget https://zenodo.org/record/3406662/files/validpreacts_model_cifar10_balanced_seed-0_bestbefore-100_currentepoch-100_valacc-91_vgg.txt.gz?download=1 -O validpreacts_model_cifar10_balanced_seed-0_bestbefore-100_currentepoch-100_valacc-91_vgg.txt.gz\n", 426 | "!wget https://zenodo.org/record/3406662/files/valid_labels.txt.gz?download?=1 -O demo_valid_labels.txt.gz\n", 427 | "\n", 428 | "\n", 429 | "test_labels = read_labels(gzip.open(\"test_labels.txt.gz\"))\n", 430 | "test_preds = softmax(read_preds(gzip.open(\n", 431 | " \"testpreacts_model_cifar10_balanced_seed-0_bestbefore-100_currentepoch-100_valacc-91_vgg.txt.gz\")),\n", 432 | " axis=1)\n", 433 | "valid_preds = softmax(read_preds(gzip.open(\n", 434 | " \"validpreacts_model_cifar10_balanced_seed-0_bestbefore-100_currentepoch-100_valacc-91_vgg.txt.gz\")),\n", 435 | " axis=1)\n", 436 | "\n", 437 | "\n", 438 | "dirichlet_alpha = 0.1\n", 439 | "samplesize = 1000\n", 440 | "dirichlet_dist = np.random.RandomState(123).dirichlet(\n", 441 | " [dirichlet_alpha for x in range(10)])\n", 442 | "\n", 443 | "TEST_CLASS_TO_INDICES = defaultdict(list)\n", 444 | "for index,row in enumerate(test_labels):\n", 445 | " row_label = np.argmax(row)\n", 446 | " TEST_CLASS_TO_INDICES[row_label].append(index)\n", 447 | "\n", 448 | "test_indices = draw_test_indices(total_to_return=samplesize,\n", 449 | " label_proportions=dirichlet_dist)\n", 450 | "shifted_test_labels = test_labels[test_indices]\n", 451 | "shifted_test_preds = test_preds[test_indices]\n", 452 | "\n", 453 | "write_preds(preds=valid_preds, filename=\"demo_valid_preds.txt\")\n", 454 | "write_preds(preds=shifted_test_preds, filename=\"demo_shifted_test_preds.txt\")\n", 455 | "write_labels(labels=shifted_test_labels, filename=\"demo_shifted_test_labels.txt\")\n", 456 | "!gzip -f *.txt\n", 457 | "```\n", 458 | "\n" 459 | ] 460 | } 461 | ] 462 | } -------------------------------------------------------------------------------- /notebooks/demo/demo_shifted_test_labels.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kundajelab/labelshiftexperiments/5524cb4674df1f8782fa5a12d1973c573b30a8e1/notebooks/demo/demo_shifted_test_labels.txt.gz -------------------------------------------------------------------------------- 
/notebooks/demo/demo_shifted_test_preds.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kundajelab/labelshiftexperiments/5524cb4674df1f8782fa5a12d1973c573b30a8e1/notebooks/demo/demo_shifted_test_preds.txt.gz -------------------------------------------------------------------------------- /notebooks/demo/demo_valid_labels.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kundajelab/labelshiftexperiments/5524cb4674df1f8782fa5a12d1973c573b30a8e1/notebooks/demo/demo_valid_labels.txt.gz -------------------------------------------------------------------------------- /notebooks/demo/demo_valid_preds.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kundajelab/labelshiftexperiments/5524cb4674df1f8782fa5a12d1973c573b30a8e1/notebooks/demo/demo_valid_preds.txt.gz -------------------------------------------------------------------------------- /notebooks/kaggledr_label_shift_adaptation_results.json.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kundajelab/labelshiftexperiments/5524cb4674df1f8782fa5a12d1973c573b30a8e1/notebooks/kaggledr_label_shift_adaptation_results.json.gz -------------------------------------------------------------------------------- /notebooks/mnist_label_shift_adaptation_results.json.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kundajelab/labelshiftexperiments/5524cb4674df1f8782fa5a12d1973c573b30a8e1/notebooks/mnist_label_shift_adaptation_results.json.gz -------------------------------------------------------------------------------- /notebooks/obtaining_predictions/README.txt: -------------------------------------------------------------------------------- 1 | This has the code (or links to the code) for training the models and obtaining predictions. But note that the predictions themselves were saved and uploaded to zenodo so that the results of the experiments could be replicated. 2 | -------------------------------------------------------------------------------- /notebooks/obtaining_predictions/cifar10/Download_CIFAR10_models_from_zenodo_and_make_predictions.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "colab": { 8 | "base_uri": "https://localhost:8080/", 9 | "height": 173 10 | }, 11 | "colab_type": "code", 12 | "id": "mI8d9Ba259nS", 13 | "outputId": "c68357ef-f6e7-4b13-9ad5-e718d85f7a69" 14 | }, 15 | "outputs": [ 16 | { 17 | "name": "stdout", 18 | "output_type": "stream", 19 | "text": [ 20 | "Collecting wget\n", 21 | " Downloading https://files.pythonhosted.org/packages/47/6a/62e288da7bcda82b935ff0c6cfe542970f04e29c756b0e147251b2fb251f/wget-3.2.zip\n", 22 | "Building wheels for collected packages: wget\n", 23 | " Building wheel for wget (setup.py) ... 
\u001b[?25ldone\n", 24 | "\u001b[?25h Stored in directory: /root/.cache/pip/wheels/40/15/30/7d8f7cea2902b4db79e3fea550d7d7b85ecb27ef992b618f3f\n", 25 | "Successfully built wget\n", 26 | "Installing collected packages: wget\n", 27 | "Successfully installed wget-3.2\n" 28 | ] 29 | } 30 | ], 31 | "source": [ 32 | "!pip install wget" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 2, 38 | "metadata": { 39 | "colab": { 40 | "base_uri": "https://localhost:8080/", 41 | "height": 1975 42 | }, 43 | "colab_type": "code", 44 | "id": "Fv31piT55-7K", 45 | "outputId": "7c531c2f-fefe-4ff8-c9c8-b39e4693d2ca" 46 | }, 47 | "outputs": [ 48 | { 49 | "name": "stderr", 50 | "output_type": "stream", 51 | "text": [ 52 | "Using TensorFlow backend.\n" 53 | ] 54 | }, 55 | { 56 | "name": "stdout", 57 | "output_type": "stream", 58 | "text": [ 59 | "keras version: 2.2.4\n", 60 | "tensorflow version: 1.13.1\n", 61 | "Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz\n", 62 | "170500096/170498071 [==============================] - 18s 0us/step\n", 63 | "On model model_cifar10_balanced_seed-0_bestbefore-100_currentepoch-100_valacc-91_vgg.h5\n", 64 | "Downloading model_cifar10_balanced_seed-0_bestbefore-100_currentepoch-100_valacc-91_vgg.h5\n", 65 | "WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.\n", 66 | "Instructions for updating:\n", 67 | "Colocations handled automatically by placer.\n", 68 | "WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/keras/backend/tensorflow_backend.py:3445: calling dropout (from tensorflow.python.ops.nn_ops) with keep_prob is deprecated and will be removed in a future version.\n", 69 | "Instructions for updating:\n", 70 | "Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.\n" 71 | ] 72 | }, 73 | { 74 | "name": "stderr", 75 | "output_type": "stream", 76 | "text": [ 77 | "/usr/local/lib/python3.6/dist-packages/keras/engine/saving.py:292: UserWarning: No training configuration found in save file: the model was *not* compiled. 
Compile it manually.\n", 78 | " warnings.warn('No training configuration found in save file: '\n", 79 | "/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:43: UserWarning: Update your `Model` call to the Keras 2 API: `Model(inputs=Tensor(\"co..., outputs=Tensor(\"de...)`\n" 80 | ] 81 | }, 82 | { 83 | "name": "stdout", 84 | "output_type": "stream", 85 | "text": [ 86 | "Making predictions on validation set\n", 87 | "Making predictions on test set\n", 88 | "Test accuracy 0.8967\n", 89 | "Valid accuracy 0.9054\n", 90 | "Saving testpreacts_model_cifar10_balanced_seed-0_bestbefore-100_currentepoch-100_valacc-91_vgg.txt\n", 91 | "a1e7f2dfef7be74264a695d50b47bad7 testpreacts_model_cifar10_balanced_seed-0_bestbefore-100_currentepoch-100_valacc-91_vgg.txt\n", 92 | "Saving validpreacts_model_cifar10_balanced_seed-0_bestbefore-100_currentepoch-100_valacc-91_vgg.txt\n", 93 | "78946084522fda04a0c1207caf8186b9 validpreacts_model_cifar10_balanced_seed-0_bestbefore-100_currentepoch-100_valacc-91_vgg.txt\n", 94 | "On model model_cifar10_balanced_seed-10_bestbefore-100_currentepoch-100_valacc-91_vgg.h5\n", 95 | "Downloading model_cifar10_balanced_seed-10_bestbefore-100_currentepoch-100_valacc-91_vgg.h5\n", 96 | "Making predictions on validation set\n", 97 | "Making predictions on test set\n", 98 | "Test accuracy 0.9029\n", 99 | "Valid accuracy 0.9073\n", 100 | "Saving testpreacts_model_cifar10_balanced_seed-10_bestbefore-100_currentepoch-100_valacc-91_vgg.txt\n", 101 | "55c13ffa492ecb6bd0492b031a338352 testpreacts_model_cifar10_balanced_seed-10_bestbefore-100_currentepoch-100_valacc-91_vgg.txt\n", 102 | "Saving validpreacts_model_cifar10_balanced_seed-10_bestbefore-100_currentepoch-100_valacc-91_vgg.txt\n", 103 | "594aa548712bcef8fbcdc1427a947f3f validpreacts_model_cifar10_balanced_seed-10_bestbefore-100_currentepoch-100_valacc-91_vgg.txt\n", 104 | "On model model_cifar10_balanced_seed-20_bestbefore-100_currentepoch-100_valacc-91_vgg.h5\n", 105 | "Downloading model_cifar10_balanced_seed-20_bestbefore-100_currentepoch-100_valacc-91_vgg.h5\n", 106 | "Making predictions on validation set\n", 107 | "Making predictions on test set\n", 108 | "Test accuracy 0.9025\n", 109 | "Valid accuracy 0.9078\n", 110 | "Saving testpreacts_model_cifar10_balanced_seed-20_bestbefore-100_currentepoch-100_valacc-91_vgg.txt\n", 111 | "2ca636af80d18f04237acc4c848e1ed0 testpreacts_model_cifar10_balanced_seed-20_bestbefore-100_currentepoch-100_valacc-91_vgg.txt\n", 112 | "Saving validpreacts_model_cifar10_balanced_seed-20_bestbefore-100_currentepoch-100_valacc-91_vgg.txt\n", 113 | "607fc7d6242556dae2b036e8fa7277f7 validpreacts_model_cifar10_balanced_seed-20_bestbefore-100_currentepoch-100_valacc-91_vgg.txt\n", 114 | "On model model_cifar10_balanced_seed-30_bestbefore-100_currentepoch-100_valacc-91_vgg.h5\n", 115 | "Downloading model_cifar10_balanced_seed-30_bestbefore-100_currentepoch-100_valacc-91_vgg.h5\n", 116 | "Making predictions on validation set\n", 117 | "Making predictions on test set\n", 118 | "Test accuracy 0.9014\n", 119 | "Valid accuracy 0.9115\n", 120 | "Saving testpreacts_model_cifar10_balanced_seed-30_bestbefore-100_currentepoch-100_valacc-91_vgg.txt\n", 121 | "485e99b4b70ee68290c8be231537c359 testpreacts_model_cifar10_balanced_seed-30_bestbefore-100_currentepoch-100_valacc-91_vgg.txt\n", 122 | "Saving validpreacts_model_cifar10_balanced_seed-30_bestbefore-100_currentepoch-100_valacc-91_vgg.txt\n", 123 | "e3cf0e20fd38ca6581f71094d949eec5 
validpreacts_model_cifar10_balanced_seed-30_bestbefore-100_currentepoch-100_valacc-91_vgg.txt\n", 124 | "On model model_cifar10_balanced_seed-40_bestbefore-100_currentepoch-100_valacc-91_vgg.h5\n", 125 | "Downloading model_cifar10_balanced_seed-40_bestbefore-100_currentepoch-100_valacc-91_vgg.h5\n", 126 | "Making predictions on validation set\n", 127 | "Making predictions on test set\n", 128 | "Test accuracy 0.9055\n", 129 | "Valid accuracy 0.9102\n", 130 | "Saving testpreacts_model_cifar10_balanced_seed-40_bestbefore-100_currentepoch-100_valacc-91_vgg.txt\n", 131 | "3516263f2cd1c1ec6b46ca4734cd3317 testpreacts_model_cifar10_balanced_seed-40_bestbefore-100_currentepoch-100_valacc-91_vgg.txt\n", 132 | "Saving validpreacts_model_cifar10_balanced_seed-40_bestbefore-100_currentepoch-100_valacc-91_vgg.txt\n", 133 | "5c8bef1ea1170444859de1dec12f94a8 validpreacts_model_cifar10_balanced_seed-40_bestbefore-100_currentepoch-100_valacc-91_vgg.txt\n", 134 | "On model model_cifar10_balanced_seed-50_bestbefore-100_currentepoch-100_valacc-91_vgg.h5\n", 135 | "Downloading model_cifar10_balanced_seed-50_bestbefore-100_currentepoch-100_valacc-91_vgg.h5\n", 136 | "Making predictions on validation set\n", 137 | "Making predictions on test set\n", 138 | "Test accuracy 0.9039\n", 139 | "Valid accuracy 0.9115\n", 140 | "Saving testpreacts_model_cifar10_balanced_seed-50_bestbefore-100_currentepoch-100_valacc-91_vgg.txt\n", 141 | "7aeed6c0690d35f209eadd0ecd7bab93 testpreacts_model_cifar10_balanced_seed-50_bestbefore-100_currentepoch-100_valacc-91_vgg.txt\n", 142 | "Saving validpreacts_model_cifar10_balanced_seed-50_bestbefore-100_currentepoch-100_valacc-91_vgg.txt\n", 143 | "e2b98916f42df71bad93f0954e060944 validpreacts_model_cifar10_balanced_seed-50_bestbefore-100_currentepoch-100_valacc-91_vgg.txt\n", 144 | "On model model_cifar10_balanced_seed-60_bestbefore-100_currentepoch-100_valacc-91_vgg.h5\n", 145 | "Downloading model_cifar10_balanced_seed-60_bestbefore-100_currentepoch-100_valacc-91_vgg.h5\n", 146 | "Making predictions on validation set\n", 147 | "Making predictions on test set\n", 148 | "Test accuracy 0.9013\n", 149 | "Valid accuracy 0.9096\n", 150 | "Saving testpreacts_model_cifar10_balanced_seed-60_bestbefore-100_currentepoch-100_valacc-91_vgg.txt\n", 151 | "123364a7f734b6a37c75f38300f042db testpreacts_model_cifar10_balanced_seed-60_bestbefore-100_currentepoch-100_valacc-91_vgg.txt\n", 152 | "Saving validpreacts_model_cifar10_balanced_seed-60_bestbefore-100_currentepoch-100_valacc-91_vgg.txt\n", 153 | "e28edce2aa583ae93db363f88724900f validpreacts_model_cifar10_balanced_seed-60_bestbefore-100_currentepoch-100_valacc-91_vgg.txt\n", 154 | "On model model_cifar10_balanced_seed-70_bestbefore-100_currentepoch-100_valacc-91_vgg.h5\n", 155 | "Downloading model_cifar10_balanced_seed-70_bestbefore-100_currentepoch-100_valacc-91_vgg.h5\n", 156 | "Making predictions on validation set\n", 157 | "Making predictions on test set\n", 158 | "Test accuracy 0.9011\n", 159 | "Valid accuracy 0.9061\n", 160 | "Saving testpreacts_model_cifar10_balanced_seed-70_bestbefore-100_currentepoch-100_valacc-91_vgg.txt\n", 161 | "33f8d2f3ea8bed29dfa70f0175cb51f1 testpreacts_model_cifar10_balanced_seed-70_bestbefore-100_currentepoch-100_valacc-91_vgg.txt\n", 162 | "Saving validpreacts_model_cifar10_balanced_seed-70_bestbefore-100_currentepoch-100_valacc-91_vgg.txt\n", 163 | "941c58d0572f170be3a044e10e38a679 validpreacts_model_cifar10_balanced_seed-70_bestbefore-100_currentepoch-100_valacc-91_vgg.txt\n", 164 | "On model 
model_cifar10_balanced_seed-80_bestbefore-100_currentepoch-100_valacc-90_vgg.h5\n", 165 | "Downloading model_cifar10_balanced_seed-80_bestbefore-100_currentepoch-100_valacc-90_vgg.h5\n", 166 | "Making predictions on validation set\n", 167 | "Making predictions on test set\n", 168 | "Test accuracy 0.9003\n", 169 | "Valid accuracy 0.903\n", 170 | "Saving testpreacts_model_cifar10_balanced_seed-80_bestbefore-100_currentepoch-100_valacc-90_vgg.txt\n", 171 | "825080050787b804bfc53dcceb892f67 testpreacts_model_cifar10_balanced_seed-80_bestbefore-100_currentepoch-100_valacc-90_vgg.txt\n", 172 | "Saving validpreacts_model_cifar10_balanced_seed-80_bestbefore-100_currentepoch-100_valacc-90_vgg.txt\n", 173 | "be9481578bf4f53df6a5fc07c8951485 validpreacts_model_cifar10_balanced_seed-80_bestbefore-100_currentepoch-100_valacc-90_vgg.txt\n", 174 | "On model model_cifar10_balanced_seed-90_bestbefore-100_currentepoch-100_valacc-91_vgg.h5\n", 175 | "Downloading model_cifar10_balanced_seed-90_bestbefore-100_currentepoch-100_valacc-91_vgg.h5\n", 176 | "Making predictions on validation set\n", 177 | "Making predictions on test set\n", 178 | "Test accuracy 0.9068\n", 179 | "Valid accuracy 0.9119\n", 180 | "Saving testpreacts_model_cifar10_balanced_seed-90_bestbefore-100_currentepoch-100_valacc-91_vgg.txt\n", 181 | "a9472bf70d0e78a53fb1969e8f0422a8 testpreacts_model_cifar10_balanced_seed-90_bestbefore-100_currentepoch-100_valacc-91_vgg.txt\n", 182 | "Saving validpreacts_model_cifar10_balanced_seed-90_bestbefore-100_currentepoch-100_valacc-91_vgg.txt\n", 183 | "5484e8bf745771966f8b5e01225de701 validpreacts_model_cifar10_balanced_seed-90_bestbefore-100_currentepoch-100_valacc-91_vgg.txt\n" 184 | ] 185 | } 186 | ], 187 | "source": [ 188 | "import keras\n", 189 | "from keras.models import load_model\n", 190 | "from keras.models import Model\n", 191 | "print(\"keras version:\", keras.__version__)\n", 192 | "import tensorflow as tf\n", 193 | "print(\"tensorflow version:\", tf.__version__)\n", 194 | "import wget\n", 195 | "import os\n", 196 | "import sys\n", 197 | "from keras.datasets import cifar10\n", 198 | "import numpy as np\n", 199 | "(x_train, y_train), (x_test, y_test) = cifar10.load_data()\n", 200 | "\n", 201 | "mean = np.mean(x_train,axis=(0,1,2,3))\n", 202 | "std = np.std(x_train, axis=(0, 1, 2, 3))\n", 203 | "x_train = (x_train-mean)/(std+1e-7)\n", 204 | "x_test = (x_test-mean)/(std+1e-7)\n", 205 | "x_valid = x_train[:10000]\n", 206 | "y_valid = y_train[:10000]\n", 207 | "\n", 208 | "model_files = [\n", 209 | " \"model_cifar10_balanced_seed-0_bestbefore-100_currentepoch-100_valacc-91_vgg.h5\",\n", 210 | " \"model_cifar10_balanced_seed-10_bestbefore-100_currentepoch-100_valacc-91_vgg.h5\",\n", 211 | " \"model_cifar10_balanced_seed-20_bestbefore-100_currentepoch-100_valacc-91_vgg.h5\",\n", 212 | " \"model_cifar10_balanced_seed-30_bestbefore-100_currentepoch-100_valacc-91_vgg.h5\",\n", 213 | " \"model_cifar10_balanced_seed-40_bestbefore-100_currentepoch-100_valacc-91_vgg.h5\",\n", 214 | " \"model_cifar10_balanced_seed-50_bestbefore-100_currentepoch-100_valacc-91_vgg.h5\",\n", 215 | " \"model_cifar10_balanced_seed-60_bestbefore-100_currentepoch-100_valacc-91_vgg.h5\",\n", 216 | " \"model_cifar10_balanced_seed-70_bestbefore-100_currentepoch-100_valacc-91_vgg.h5\",\n", 217 | " \"model_cifar10_balanced_seed-80_bestbefore-100_currentepoch-100_valacc-90_vgg.h5\",\n", 218 | " \"model_cifar10_balanced_seed-90_bestbefore-100_currentepoch-100_valacc-91_vgg.h5\"\n", 219 | "]\n", 220 | "\n", 221 | "for model_file in 
model_files:\n", 222 | " print(\"On model\", model_file)\n", 223 | " if (os.path.isfile(model_file)==False):\n", 224 | " print(\"Downloading\", model_file)\n", 225 | " wget.download(\"https://zenodo.org/record/2648107/files/\"\n", 226 | " +model_file+\"?download=1\", out=model_file)\n", 227 | " model = load_model(model_file)\n", 228 | " \n", 229 | " pre_softmax_model = Model(input=model.input,\n", 230 | " output=model.layers[-2].output)\n", 231 | " print(\"Making predictions on validation set\")\n", 232 | " valid_preacts = pre_softmax_model.predict(x_valid)\n", 233 | " print(\"Making predictions on test set\")\n", 234 | " test_preacts = pre_softmax_model.predict(x_test)\n", 235 | " \n", 236 | " print(\"Test accuracy\",np.mean(np.argmax(test_preacts,axis=1)\n", 237 | " == np.squeeze(y_test)))\n", 238 | " print(\"Valid accuracy\",np.mean(np.argmax(valid_preacts,axis=1)\n", 239 | " == np.squeeze(y_valid)))\n", 240 | " sys.stdout.flush()\n", 241 | " test_predictions_file = (\"testpreacts_\"+model_file.split(\".\")[0])+\".txt\"\n", 242 | " print(\"Saving\", test_predictions_file)\n", 243 | " f = open(test_predictions_file,'w')\n", 244 | " for test_preact in test_preacts:\n", 245 | " f.write(\"\\t\".join([str(x) for x in test_preact])+\"\\n\") \n", 246 | " f.close()\n", 247 | " !md5sum $test_predictions_file\n", 248 | " !gzip $test_predictions_file\n", 249 | "\n", 250 | " valid_predictions_file = (\"validpreacts_\"+model_file.split(\".\")[0])+\".txt\"\n", 251 | " print(\"Saving\", valid_predictions_file)\n", 252 | " f = open(valid_predictions_file,'w')\n", 253 | " for valid_preact in valid_preacts:\n", 254 | " f.write(\"\\t\".join([str(x) for x in valid_preact])+\"\\n\") \n", 255 | " f.close()\n", 256 | " !md5sum $valid_predictions_file\n", 257 | " !gzip $valid_predictions_file\n", 258 | " " 259 | ] 260 | } 261 | ], 262 | "metadata": { 263 | "accelerator": "GPU", 264 | "colab": { 265 | "collapsed_sections": [], 266 | "include_colab_link": true, 267 | "name": "gist - Download CIFAR10 models from zenodo and make predictions.ipynb", 268 | "provenance": [], 269 | "version": "0.3.2" 270 | }, 271 | "kernelspec": { 272 | "display_name": "Python [default]", 273 | "language": "python", 274 | "name": "python3" 275 | }, 276 | "language_info": { 277 | "codemirror_mode": { 278 | "name": "ipython", 279 | "version": 3 280 | }, 281 | "file_extension": ".py", 282 | "mimetype": "text/x-python", 283 | "name": "python", 284 | "nbconvert_exporter": "python", 285 | "pygments_lexer": "ipython3", 286 | "version": "3.6.5" 287 | } 288 | }, 289 | "nbformat": 4, 290 | "nbformat_minor": 2 291 | } 292 | -------------------------------------------------------------------------------- /notebooks/obtaining_predictions/cifar10/README.txt: -------------------------------------------------------------------------------- 1 | Code used to train models was based on https://github.com/geifmany/selective_deep_learning 2 | 3 | Kundaje lab internal notes: 4 | Model training code at https://github.com/kundajelab/uncertainty_experiments/blob/master/cifar10/Train%20model.ipynb. 
Permalink: https://github.com/kundajelab/uncertainty_experiments/blob/bb16abb6cbb04877fa8b1a821a80ce8f5723b0d9/cifar10/Train%20model.ipynb 5 | -------------------------------------------------------------------------------- /notebooks/obtaining_predictions/cifar10/train_cifar100.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import keras 3 | from keras.models import load_model 4 | from keras.models import Model 5 | print("keras version:", keras.__version__) 6 | import tensorflow as tf 7 | print("tensorflow version:", tf.__version__) 8 | import os 9 | import sys 10 | import numpy as np 11 | from keras.datasets import cifar100 12 | from keras.preprocessing.image import ImageDataGenerator 13 | from keras.models import Sequential 14 | from keras.layers import Dense, Dropout, Activation, Flatten 15 | from keras.layers import Conv2D, MaxPooling2D, BatchNormalization 16 | from keras import optimizers 17 | from keras.layers.core import Lambda 18 | from keras import backend as K 19 | from keras import regularizers 20 | import random 21 | 22 | class cifar100vgg: 23 | def __init__(self,train=True): 24 | self.num_classes = 100 25 | self.weight_decay = 0.0005 26 | self.x_shape = [32,32,3] 27 | 28 | self.model = self.build_model() 29 | if train: 30 | self.model = self.train(self.model) 31 | else: 32 | self.model.load_weights('cifar100vgg.h5') 33 | 34 | 35 | def build_model(self): 36 | # Build the network of vgg for 10 classes with massive dropout and weight decay as described in the paper. 37 | 38 | model = Sequential() 39 | weight_decay = self.weight_decay 40 | 41 | model.add(Conv2D(64, (3, 3), padding='same', 42 | input_shape=self.x_shape,kernel_regularizer=regularizers.l2(weight_decay))) 43 | model.add(Activation('relu')) 44 | model.add(BatchNormalization()) 45 | model.add(Dropout(0.3)) 46 | 47 | model.add(Conv2D(64, (3, 3), padding='same',kernel_regularizer=regularizers.l2(weight_decay))) 48 | model.add(Activation('relu')) 49 | model.add(BatchNormalization()) 50 | 51 | model.add(MaxPooling2D(pool_size=(2, 2))) 52 | 53 | model.add(Conv2D(128, (3, 3), padding='same',kernel_regularizer=regularizers.l2(weight_decay))) 54 | model.add(Activation('relu')) 55 | model.add(BatchNormalization()) 56 | model.add(Dropout(0.4)) 57 | 58 | model.add(Conv2D(128, (3, 3), padding='same',kernel_regularizer=regularizers.l2(weight_decay))) 59 | model.add(Activation('relu')) 60 | model.add(BatchNormalization()) 61 | 62 | model.add(MaxPooling2D(pool_size=(2, 2))) 63 | 64 | model.add(Conv2D(256, (3, 3), padding='same',kernel_regularizer=regularizers.l2(weight_decay))) 65 | model.add(Activation('relu')) 66 | model.add(BatchNormalization()) 67 | model.add(Dropout(0.4)) 68 | 69 | model.add(Conv2D(256, (3, 3), padding='same',kernel_regularizer=regularizers.l2(weight_decay))) 70 | model.add(Activation('relu')) 71 | model.add(BatchNormalization()) 72 | model.add(Dropout(0.4)) 73 | 74 | model.add(Conv2D(256, (3, 3), padding='same',kernel_regularizer=regularizers.l2(weight_decay))) 75 | model.add(Activation('relu')) 76 | model.add(BatchNormalization()) 77 | 78 | model.add(MaxPooling2D(pool_size=(2, 2))) 79 | 80 | 81 | model.add(Conv2D(512, (3, 3), padding='same',kernel_regularizer=regularizers.l2(weight_decay))) 82 | model.add(Activation('relu')) 83 | model.add(BatchNormalization()) 84 | model.add(Dropout(0.4)) 85 | 86 | model.add(Conv2D(512, (3, 3), padding='same',kernel_regularizer=regularizers.l2(weight_decay))) 87 | model.add(Activation('relu')) 
88 | model.add(BatchNormalization()) 89 | model.add(Dropout(0.4)) 90 | 91 | model.add(Conv2D(512, (3, 3), padding='same',kernel_regularizer=regularizers.l2(weight_decay))) 92 | model.add(Activation('relu')) 93 | model.add(BatchNormalization()) 94 | 95 | model.add(MaxPooling2D(pool_size=(2, 2))) 96 | 97 | 98 | model.add(Conv2D(512, (3, 3), padding='same',kernel_regularizer=regularizers.l2(weight_decay))) 99 | model.add(Activation('relu')) 100 | model.add(BatchNormalization()) 101 | model.add(Dropout(0.4)) 102 | 103 | model.add(Conv2D(512, (3, 3), padding='same',kernel_regularizer=regularizers.l2(weight_decay))) 104 | model.add(Activation('relu')) 105 | model.add(BatchNormalization()) 106 | model.add(Dropout(0.4)) 107 | 108 | model.add(Conv2D(512, (3, 3), padding='same',kernel_regularizer=regularizers.l2(weight_decay))) 109 | model.add(Activation('relu')) 110 | model.add(BatchNormalization()) 111 | 112 | model.add(MaxPooling2D(pool_size=(2, 2))) 113 | model.add(Dropout(0.5)) 114 | 115 | model.add(Flatten()) 116 | model.add(Dense(512,kernel_regularizer=regularizers.l2(weight_decay))) 117 | model.add(Activation('relu')) 118 | model.add(BatchNormalization()) 119 | 120 | model.add(Dropout(0.5)) 121 | model.add(Dense(self.num_classes)) 122 | model.add(Activation('softmax')) 123 | return model 124 | 125 | 126 | def normalize(self,X_train,X_test): 127 | #this function normalize inputs for zero mean and unit variance 128 | # it is used when training a model. 129 | # Input: training set and test set 130 | # Output: normalized training set and test set according to the trianing set statistics. 131 | mean = np.mean(X_train,axis=(0,1,2,3)) 132 | std = np.std(X_train, axis=(0, 1, 2, 3)) 133 | print(mean) 134 | print(std) 135 | X_train = (X_train-mean)/(std+1e-7) 136 | X_test = (X_test-mean)/(std+1e-7) 137 | return X_train, X_test 138 | 139 | def normalize_production(self,x): 140 | #this function is used to normalize instances in production according to saved training set statistics 141 | # Input: X - a training set 142 | # Output X - a normalized training set according to normalization constants. 
143 | 144 | #these values produced during first training and are general for the standard cifar10 training set normalization 145 | mean = 121.936 146 | std = 68.389 147 | return (x-mean)/(std+1e-7) 148 | 149 | def predict(self,x,normalize=True,batch_size=50): 150 | if normalize: 151 | x = self.normalize_production(x) 152 | return self.model.predict(x,batch_size) 153 | 154 | def train(self,model): 155 | 156 | #training parameters 157 | batch_size = 128 158 | maxepoches = 250 159 | learning_rate = 0.1 160 | lr_decay = 1e-6 161 | lr_drop = 20 162 | 163 | # The data, shuffled and split between train and test sets: 164 | (x_full_train, y_full_train), (x_test, y_test) = cifar100.load_data() 165 | x_full_train = x_full_train.astype('float32') 166 | x_test = x_test.astype('float32') 167 | x_full_train, x_test = self.normalize(x_full_train, x_test) 168 | 169 | y_full_train = keras.utils.to_categorical(y_full_train, self.num_classes) 170 | y_test = keras.utils.to_categorical(y_test, self.num_classes) 171 | 172 | x_train = x_full_train[:-10000] 173 | y_train = y_full_train[:-10000] 174 | 175 | def lr_scheduler(epoch): 176 | return learning_rate * (0.5 ** (epoch // lr_drop)) 177 | reduce_lr = keras.callbacks.LearningRateScheduler(lr_scheduler) 178 | 179 | 180 | #data augmentation 181 | datagen = ImageDataGenerator( 182 | featurewise_center=False, # set input mean to 0 over the dataset 183 | samplewise_center=False, # set each sample mean to 0 184 | featurewise_std_normalization=False, # divide inputs by std of the dataset 185 | samplewise_std_normalization=False, # divide each input by its std 186 | zca_whitening=False, # apply ZCA whitening 187 | rotation_range=15, # randomly rotate images in the range (degrees, 0 to 180) 188 | width_shift_range=0.1, # randomly shift images horizontally (fraction of total width) 189 | height_shift_range=0.1, # randomly shift images vertically (fraction of total height) 190 | horizontal_flip=True, # randomly flip images 191 | vertical_flip=False) # randomly flip images 192 | # (std, mean, and principal components if ZCA whitening is applied). 193 | datagen.fit(x_train) 194 | 195 | 196 | 197 | #optimization details 198 | sgd = optimizers.SGD(lr=learning_rate, decay=lr_decay, momentum=0.9, nesterov=True) 199 | model.compile(loss='categorical_crossentropy', optimizer=sgd,metrics=['accuracy']) 200 | 201 | 202 | # training process in a for loop with learning rate drop every 25 epoches. 
203 | 204 | historytemp = model.fit_generator(datagen.flow(x_train, y_train, 205 | batch_size=batch_size), 206 | steps_per_epoch=x_train.shape[0] // batch_size, 207 | epochs=maxepoches, 208 | validation_data=(x_test, y_test),callbacks=[reduce_lr],verbose=2) 209 | model.save_weights('cifar100vgg.h5') 210 | return model 211 | 212 | def save_model(self, name): 213 | self.model.save_weights('cifar100vgg_'+name+'.h5') 214 | 215 | def load_model(self, weights): 216 | self.model.load_weights(weights) 217 | 218 | def getModel(self): 219 | return Model(input=self.model.input, output=self.model.layers[-2].output) 220 | 221 | 222 | # The data, shuffled and split between train and test sets: 223 | (x_full_train, y_full_train), (x_test, y_test) = cifar100.load_data() 224 | x_full_train = x_full_train.astype('float32') 225 | x_test = x_test.astype('float32') 226 | dummy_model = cifar100vgg(train=False) 227 | x_full_train, x_test = dummy_model.normalize(x_full_train, x_test) 228 | 229 | y_full_train = keras.utils.to_categorical(y_full_train, dummy_model.num_classes) 230 | y_test = keras.utils.to_categorical(y_test, dummy_model.num_classes) 231 | 232 | x_train = x_full_train[:-10000] 233 | y_train = y_full_train[:-10000] 234 | x_valid = x_full_train[-10000:] 235 | y_valid = y_full_train[-10000:] 236 | 237 | 238 | for seed in np.arange(20, 100,10): 239 | np.random.seed(seed) 240 | random.seed(seed) 241 | model = cifar100vgg() 242 | model.save_model("seed"+str(seed)) 243 | pre_softmax_model = model.getModel() 244 | 245 | valid_preacts = pre_softmax_model.predict(x_valid) 246 | test_preacts = pre_softmax_model.predict(x_test) 247 | sys.stdout.flush() 248 | 249 | test_predictions_file = "testpreacts_seed"+str(seed)+".txt" 250 | print("Saving", test_predictions_file) 251 | f = open(test_predictions_file,'w') 252 | for test_preact in test_preacts: 253 | f.write("\t".join([str(x) for x in test_preact])+"\n") 254 | f.close() 255 | 256 | valid_predictions_file = "validpreacts_seed"+str(seed)+".txt" 257 | print("Saving", valid_predictions_file) 258 | f = open(valid_predictions_file,'w') 259 | for valid_preact in valid_preacts: 260 | f.write("\t".join([str(x) for x in valid_preact])+"\n") 261 | f.close() -------------------------------------------------------------------------------- /notebooks/obtaining_predictions/cifar100/README.txt: -------------------------------------------------------------------------------- 1 | Code used to train models was based on https://github.com/geifmany/selective_deep_learning 2 | 3 | Kundaje lab internal notes: 4 | - Amr Alexandar trained these models 5 | -------------------------------------------------------------------------------- /notebooks/obtaining_predictions/cifar100/getPreds.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stderr", 10 | "output_type": "stream", 11 | "text": [ 12 | "Using TensorFlow backend.\n" 13 | ] 14 | }, 15 | { 16 | "name": "stdout", 17 | "output_type": "stream", 18 | "text": [ 19 | "keras version: 2.2.4\n", 20 | "tensorflow version: 1.14.0\n" 21 | ] 22 | } 23 | ], 24 | "source": [ 25 | "import keras\n", 26 | "from keras.models import load_model\n", 27 | "from keras.models import Model\n", 28 | "print(\"keras version:\", keras.__version__)\n", 29 | "import tensorflow as tf\n", 30 | "print(\"tensorflow version:\", tf.__version__)\n", 31 | "import os\n", 32 | "import sys\n", 33 | 
"from keras.datasets import cifar10\n", 34 | "import numpy as np\n", 35 | "from __future__ import print_function\n", 36 | "import keras\n", 37 | "from keras.datasets import cifar100\n", 38 | "from keras.preprocessing.image import ImageDataGenerator\n", 39 | "from keras.models import Sequential\n", 40 | "from keras.layers import Dense, Dropout, Activation, Flatten\n", 41 | "from keras.layers import Conv2D, MaxPooling2D, BatchNormalization\n", 42 | "from keras import optimizers\n", 43 | "import numpy as np\n", 44 | "from keras.layers.core import Lambda\n", 45 | "from keras import backend as K\n", 46 | "from keras import regularizers\n", 47 | "import random" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": 2, 53 | "metadata": {}, 54 | "outputs": [], 55 | "source": [ 56 | "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": 3, 62 | "metadata": {}, 63 | "outputs": [], 64 | "source": [ 65 | "class cifar100vgg:\n", 66 | " def __init__(self,train=True):\n", 67 | " self.num_classes = 100\n", 68 | " self.weight_decay = 0.0005\n", 69 | " self.x_shape = [32,32,3]\n", 70 | "\n", 71 | " self.model = self.build_model()\n", 72 | " if train:\n", 73 | " self.model = self.train(self.model)\n", 74 | " else:\n", 75 | " self.model.load_weights('cifar100vgg.h5')\n", 76 | "\n", 77 | "\n", 78 | " def build_model(self):\n", 79 | " # Build the network of vgg for 10 classes with massive dropout and weight decay as described in the paper.\n", 80 | "\n", 81 | " model = Sequential()\n", 82 | " weight_decay = self.weight_decay\n", 83 | "\n", 84 | " model.add(Conv2D(64, (3, 3), padding='same',\n", 85 | " input_shape=self.x_shape,kernel_regularizer=regularizers.l2(weight_decay)))\n", 86 | " model.add(Activation('relu'))\n", 87 | " model.add(BatchNormalization())\n", 88 | " model.add(Dropout(0.3))\n", 89 | "\n", 90 | " model.add(Conv2D(64, (3, 3), padding='same',kernel_regularizer=regularizers.l2(weight_decay)))\n", 91 | " model.add(Activation('relu'))\n", 92 | " model.add(BatchNormalization())\n", 93 | "\n", 94 | " model.add(MaxPooling2D(pool_size=(2, 2)))\n", 95 | "\n", 96 | " model.add(Conv2D(128, (3, 3), padding='same',kernel_regularizer=regularizers.l2(weight_decay)))\n", 97 | " model.add(Activation('relu'))\n", 98 | " model.add(BatchNormalization())\n", 99 | " model.add(Dropout(0.4))\n", 100 | "\n", 101 | " model.add(Conv2D(128, (3, 3), padding='same',kernel_regularizer=regularizers.l2(weight_decay)))\n", 102 | " model.add(Activation('relu'))\n", 103 | " model.add(BatchNormalization())\n", 104 | "\n", 105 | " model.add(MaxPooling2D(pool_size=(2, 2)))\n", 106 | "\n", 107 | " model.add(Conv2D(256, (3, 3), padding='same',kernel_regularizer=regularizers.l2(weight_decay)))\n", 108 | " model.add(Activation('relu'))\n", 109 | " model.add(BatchNormalization())\n", 110 | " model.add(Dropout(0.4))\n", 111 | "\n", 112 | " model.add(Conv2D(256, (3, 3), padding='same',kernel_regularizer=regularizers.l2(weight_decay)))\n", 113 | " model.add(Activation('relu'))\n", 114 | " model.add(BatchNormalization())\n", 115 | " model.add(Dropout(0.4))\n", 116 | "\n", 117 | " model.add(Conv2D(256, (3, 3), padding='same',kernel_regularizer=regularizers.l2(weight_decay)))\n", 118 | " model.add(Activation('relu'))\n", 119 | " model.add(BatchNormalization())\n", 120 | "\n", 121 | " model.add(MaxPooling2D(pool_size=(2, 2)))\n", 122 | "\n", 123 | "\n", 124 | " model.add(Conv2D(512, (3, 3), padding='same',kernel_regularizer=regularizers.l2(weight_decay)))\n", 125 | 
" model.add(Activation('relu'))\n", 126 | " model.add(BatchNormalization())\n", 127 | " model.add(Dropout(0.4))\n", 128 | "\n", 129 | " model.add(Conv2D(512, (3, 3), padding='same',kernel_regularizer=regularizers.l2(weight_decay)))\n", 130 | " model.add(Activation('relu'))\n", 131 | " model.add(BatchNormalization())\n", 132 | " model.add(Dropout(0.4))\n", 133 | "\n", 134 | " model.add(Conv2D(512, (3, 3), padding='same',kernel_regularizer=regularizers.l2(weight_decay)))\n", 135 | " model.add(Activation('relu'))\n", 136 | " model.add(BatchNormalization())\n", 137 | "\n", 138 | " model.add(MaxPooling2D(pool_size=(2, 2)))\n", 139 | "\n", 140 | "\n", 141 | " model.add(Conv2D(512, (3, 3), padding='same',kernel_regularizer=regularizers.l2(weight_decay)))\n", 142 | " model.add(Activation('relu'))\n", 143 | " model.add(BatchNormalization())\n", 144 | " model.add(Dropout(0.4))\n", 145 | "\n", 146 | " model.add(Conv2D(512, (3, 3), padding='same',kernel_regularizer=regularizers.l2(weight_decay)))\n", 147 | " model.add(Activation('relu'))\n", 148 | " model.add(BatchNormalization())\n", 149 | " model.add(Dropout(0.4))\n", 150 | "\n", 151 | " model.add(Conv2D(512, (3, 3), padding='same',kernel_regularizer=regularizers.l2(weight_decay)))\n", 152 | " model.add(Activation('relu'))\n", 153 | " model.add(BatchNormalization())\n", 154 | "\n", 155 | " model.add(MaxPooling2D(pool_size=(2, 2)))\n", 156 | " model.add(Dropout(0.5))\n", 157 | "\n", 158 | " model.add(Flatten())\n", 159 | " model.add(Dense(512,kernel_regularizer=regularizers.l2(weight_decay)))\n", 160 | " model.add(Activation('relu'))\n", 161 | " model.add(BatchNormalization())\n", 162 | "\n", 163 | " model.add(Dropout(0.5))\n", 164 | " model.add(Dense(self.num_classes))\n", 165 | " model.add(Activation('softmax'))\n", 166 | " return model\n", 167 | "\n", 168 | "\n", 169 | " def normalize(self,X_train,X_test):\n", 170 | " #this function normalize inputs for zero mean and unit variance\n", 171 | " # it is used when training a model.\n", 172 | " # Input: training set and test set\n", 173 | " # Output: normalized training set and test set according to the trianing set statistics.\n", 174 | " mean = np.mean(X_train,axis=(0,1,2,3))\n", 175 | " std = np.std(X_train, axis=(0, 1, 2, 3))\n", 176 | " print(mean)\n", 177 | " print(std)\n", 178 | " X_train = (X_train-mean)/(std+1e-7)\n", 179 | " X_test = (X_test-mean)/(std+1e-7)\n", 180 | " return X_train, X_test\n", 181 | "\n", 182 | " def normalize_production(self,x):\n", 183 | " #this function is used to normalize instances in production according to saved training set statistics\n", 184 | " # Input: X - a training set\n", 185 | " # Output X - a normalized training set according to normalization constants.\n", 186 | "\n", 187 | " #these values produced during first training and are general for the standard cifar10 training set normalization\n", 188 | " mean = 121.936\n", 189 | " std = 68.389\n", 190 | " return (x-mean)/(std+1e-7)\n", 191 | "\n", 192 | " def predict(self,x,normalize=True,batch_size=50):\n", 193 | " if normalize:\n", 194 | " x = self.normalize_production(x)\n", 195 | " return self.model.predict(x,batch_size)\n", 196 | "\n", 197 | " def train(self,model):\n", 198 | "\n", 199 | " #training parameters\n", 200 | " batch_size = 128\n", 201 | " maxepoches = 250\n", 202 | " learning_rate = 0.1\n", 203 | " lr_decay = 1e-6\n", 204 | " lr_drop = 20\n", 205 | "\n", 206 | " # The data, shuffled and split between train and test sets:\n", 207 | " (x_full_train, y_full_train), (x_test, y_test) = 
cifar100.load_data()\n", 208 | " x_full_train = x_full_train.astype('float32')\n", 209 | " x_test = x_test.astype('float32')\n", 210 | " x_full_train, x_test = self.normalize(x_full_train, x_test)\n", 211 | "\n", 212 | " y_full_train = keras.utils.to_categorical(y_full_train, self.num_classes)\n", 213 | " y_test = keras.utils.to_categorical(y_test, self.num_classes)\n", 214 | " \n", 215 | " x_train = x_full_train[:-10000]\n", 216 | " y_train = y_full_train[:-10000]\n", 217 | "\n", 218 | " def lr_scheduler(epoch):\n", 219 | " return learning_rate * (0.5 ** (epoch // lr_drop))\n", 220 | " reduce_lr = keras.callbacks.LearningRateScheduler(lr_scheduler)\n", 221 | "\n", 222 | "\n", 223 | " #data augmentation\n", 224 | " datagen = ImageDataGenerator(\n", 225 | " featurewise_center=False, # set input mean to 0 over the dataset\n", 226 | " samplewise_center=False, # set each sample mean to 0\n", 227 | " featurewise_std_normalization=False, # divide inputs by std of the dataset\n", 228 | " samplewise_std_normalization=False, # divide each input by its std\n", 229 | " zca_whitening=False, # apply ZCA whitening\n", 230 | " rotation_range=15, # randomly rotate images in the range (degrees, 0 to 180)\n", 231 | " width_shift_range=0.1, # randomly shift images horizontally (fraction of total width)\n", 232 | " height_shift_range=0.1, # randomly shift images vertically (fraction of total height)\n", 233 | " horizontal_flip=True, # randomly flip images\n", 234 | " vertical_flip=False) # randomly flip images\n", 235 | " # (std, mean, and principal components if ZCA whitening is applied).\n", 236 | " datagen.fit(x_train)\n", 237 | "\n", 238 | "\n", 239 | "\n", 240 | " #optimization details\n", 241 | " sgd = optimizers.SGD(lr=learning_rate, decay=lr_decay, momentum=0.9, nesterov=True)\n", 242 | " model.compile(loss='categorical_crossentropy', optimizer=sgd,metrics=['accuracy'])\n", 243 | "\n", 244 | "\n", 245 | " # training process in a for loop with learning rate drop every 25 epoches.\n", 246 | "\n", 247 | " historytemp = model.fit_generator(datagen.flow(x_train, y_train,\n", 248 | " batch_size=batch_size),\n", 249 | " steps_per_epoch=x_train.shape[0] // batch_size,\n", 250 | " epochs=maxepoches,\n", 251 | " validation_data=(x_test, y_test),callbacks=[reduce_lr],verbose=2)\n", 252 | " model.save_weights('cifar100vgg.h5')\n", 253 | " return model\n", 254 | " \n", 255 | " def save_model(self, name):\n", 256 | " self.model.save_weights('cifar100vgg_'+name+'.h5')\n", 257 | " \n", 258 | " def load_model(self, weights):\n", 259 | " self.model.load_weights(weights)\n", 260 | " \n", 261 | " def getModel(self):\n", 262 | " return Model(input=self.model.input, output=self.model.layers[-2].output)\n", 263 | " \n", 264 | " def evaluate(self, x, y):\n", 265 | " self.model.evaluate(x=x, y=y, batch_size=200, verbose=1)" 266 | ] 267 | }, 268 | { 269 | "cell_type": "code", 270 | "execution_count": 4, 271 | "metadata": {}, 272 | "outputs": [ 273 | { 274 | "name": "stdout", 275 | "output_type": "stream", 276 | "text": [ 277 | "WARNING:tensorflow:From /users/amr1/miniconda3/envs/basepair/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:74: The name tf.get_default_graph is deprecated. Please use tf.compat.v1.get_default_graph instead.\n", 278 | "\n", 279 | "WARNING:tensorflow:From /users/amr1/miniconda3/envs/basepair/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:517: The name tf.placeholder is deprecated. 
Please use tf.compat.v1.placeholder instead.\n", 280 | "\n", 281 | "WARNING:tensorflow:From /users/amr1/miniconda3/envs/basepair/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:4138: The name tf.random_uniform is deprecated. Please use tf.random.uniform instead.\n", 282 | "\n", 283 | "WARNING:tensorflow:From /users/amr1/miniconda3/envs/basepair/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:174: The name tf.get_default_session is deprecated. Please use tf.compat.v1.get_default_session instead.\n", 284 | "\n", 285 | "WARNING:tensorflow:From /users/amr1/miniconda3/envs/basepair/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:181: The name tf.ConfigProto is deprecated. Please use tf.compat.v1.ConfigProto instead.\n", 286 | "\n", 287 | "WARNING:tensorflow:From /users/amr1/miniconda3/envs/basepair/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:1834: The name tf.nn.fused_batch_norm is deprecated. Please use tf.compat.v1.nn.fused_batch_norm instead.\n", 288 | "\n", 289 | "WARNING:tensorflow:From /users/amr1/miniconda3/envs/basepair/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:3445: calling dropout (from tensorflow.python.ops.nn_ops) with keep_prob is deprecated and will be removed in a future version.\n", 290 | "Instructions for updating:\n", 291 | "Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.\n", 292 | "WARNING:tensorflow:From /users/amr1/miniconda3/envs/basepair/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:3976: The name tf.nn.max_pool is deprecated. Please use tf.nn.max_pool2d instead.\n", 293 | "\n" 294 | ] 295 | } 296 | ], 297 | "source": [ 298 | "seed = 0\n", 299 | "np.random.seed(seed)\n", 300 | "random.seed(seed)\n", 301 | "model = cifar100vgg(train=False)" 302 | ] 303 | }, 304 | { 305 | "cell_type": "code", 306 | "execution_count": 5, 307 | "metadata": {}, 308 | "outputs": [], 309 | "source": [ 310 | "model.load_model(\"cifar100vgg_seed0.h5\")" 311 | ] 312 | }, 313 | { 314 | "cell_type": "code", 315 | "execution_count": 6, 316 | "metadata": {}, 317 | "outputs": [ 318 | { 319 | "name": "stderr", 320 | "output_type": "stream", 321 | "text": [ 322 | "/users/amr1/miniconda3/envs/basepair/lib/python3.6/site-packages/ipykernel/__main__.py:198: UserWarning: Update your `Model` call to the Keras 2 API: `Model(inputs=Tensor(\"co..., outputs=Tensor(\"de...)`\n" 323 | ] 324 | } 325 | ], 326 | "source": [ 327 | "pre_softmax_model = model.getModel()" 328 | ] 329 | }, 330 | { 331 | "cell_type": "code", 332 | "execution_count": 7, 333 | "metadata": {}, 334 | "outputs": [ 335 | { 336 | "name": "stdout", 337 | "output_type": "stream", 338 | "text": [ 339 | "121.93584\n", 340 | "68.38902\n", 341 | "Saving testpreacts_seed0.txt\n", 342 | "Saving validpreacts_seed0.txt\n" 343 | ] 344 | } 345 | ], 346 | "source": [ 347 | "# The data, shuffled and split between train and test sets:\n", 348 | "(x_full_train, y_full_train), (x_test, y_test) = cifar100.load_data()\n", 349 | "x_full_train = x_full_train.astype('float32')\n", 350 | "x_test = x_test.astype('float32')\n", 351 | "x_full_train, x_test = model.normalize(x_full_train, x_test)\n", 352 | "\n", 353 | "y_full_train = keras.utils.to_categorical(y_full_train, model.num_classes)\n", 354 | "y_test = keras.utils.to_categorical(y_test, model.num_classes)\n", 355 | "\n", 356 | "x_train = x_full_train[:-10000]\n", 357 | "y_train = y_full_train[:-10000]\n", 358 | "x_valid = x_full_train[-10000:]\n", 359 | "y_valid = 
y_full_train[-10000:]\n", 360 | "\n", 361 | "\n", 362 | "valid_preacts = pre_softmax_model.predict(x_valid)\n", 363 | "test_preacts = pre_softmax_model.predict(x_test)\n", 364 | "sys.stdout.flush()\n", 365 | "\n", 366 | "test_predictions_file = \"testpreacts_seed\"+str(seed)+\".txt\"\n", 367 | "print(\"Saving\", test_predictions_file)\n", 368 | "f = open(test_predictions_file,'w')\n", 369 | "for test_preact in test_preacts:\n", 370 | " f.write(\"\\t\".join([str(x) for x in test_preact])+\"\\n\") \n", 371 | "f.close()\n", 372 | "\n", 373 | "valid_predictions_file = \"validpreacts_seed\"+str(seed)+\".txt\"\n", 374 | "print(\"Saving\", valid_predictions_file)\n", 375 | "f = open(valid_predictions_file,'w')\n", 376 | "for valid_preact in valid_preacts:\n", 377 | " f.write(\"\\t\".join([str(x) for x in valid_preact])+\"\\n\") \n", 378 | "f.close()" 379 | ] 380 | }, 381 | { 382 | "cell_type": "code", 383 | "execution_count": 8, 384 | "metadata": {}, 385 | "outputs": [ 386 | { 387 | "name": "stdout", 388 | "output_type": "stream", 389 | "text": [ 390 | "0.6649\n" 391 | ] 392 | } 393 | ], 394 | "source": [ 395 | "print(np.mean(np.argmax(y_valid, axis=-1)==np.argmax(valid_preacts, axis=-1)))" 396 | ] 397 | }, 398 | { 399 | "cell_type": "code", 400 | "execution_count": 9, 401 | "metadata": {}, 402 | "outputs": [ 403 | { 404 | "name": "stdout", 405 | "output_type": "stream", 406 | "text": [ 407 | "done\n" 408 | ] 409 | } 410 | ], 411 | "source": [ 412 | "print(\"done\")" 413 | ] 414 | } 415 | ], 416 | "metadata": { 417 | "kernelspec": { 418 | "display_name": "Python [conda env:basepair]", 419 | "language": "python", 420 | "name": "conda-env-basepair-py" 421 | }, 422 | "language_info": { 423 | "codemirror_mode": { 424 | "name": "ipython", 425 | "version": 3 426 | }, 427 | "file_extension": ".py", 428 | "mimetype": "text/x-python", 429 | "name": "python", 430 | "nbconvert_exporter": "python", 431 | "pygments_lexer": "ipython3", 432 | "version": "3.6.8" 433 | } 434 | }, 435 | "nbformat": 4, 436 | "nbformat_minor": 2 437 | } 438 | -------------------------------------------------------------------------------- /notebooks/obtaining_predictions/cifar100/train_cifar100.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import keras 3 | from keras.models import load_model 4 | from keras.models import Model 5 | print("keras version:", keras.__version__) 6 | import tensorflow as tf 7 | print("tensorflow version:", tf.__version__) 8 | import os 9 | import sys 10 | import numpy as np 11 | from keras.datasets import cifar100 12 | from keras.preprocessing.image import ImageDataGenerator 13 | from keras.models import Sequential 14 | from keras.layers import Dense, Dropout, Activation, Flatten 15 | from keras.layers import Conv2D, MaxPooling2D, BatchNormalization 16 | from keras import optimizers 17 | from keras.layers.core import Lambda 18 | from keras import backend as K 19 | from keras import regularizers 20 | import random 21 | 22 | class cifar100vgg: 23 | def __init__(self,train=True): 24 | self.num_classes = 100 25 | self.weight_decay = 0.0005 26 | self.x_shape = [32,32,3] 27 | 28 | self.model = self.build_model() 29 | if train: 30 | self.model = self.train(self.model) 31 | else: 32 | self.model.load_weights('cifar100vgg.h5') 33 | 34 | 35 | def build_model(self): 36 | # Build the network of vgg for 10 classes with massive dropout and weight decay as described in the paper. 
37 | 38 | model = Sequential() 39 | weight_decay = self.weight_decay 40 | 41 | model.add(Conv2D(64, (3, 3), padding='same', 42 | input_shape=self.x_shape,kernel_regularizer=regularizers.l2(weight_decay))) 43 | model.add(Activation('relu')) 44 | model.add(BatchNormalization()) 45 | model.add(Dropout(0.3)) 46 | 47 | model.add(Conv2D(64, (3, 3), padding='same',kernel_regularizer=regularizers.l2(weight_decay))) 48 | model.add(Activation('relu')) 49 | model.add(BatchNormalization()) 50 | 51 | model.add(MaxPooling2D(pool_size=(2, 2))) 52 | 53 | model.add(Conv2D(128, (3, 3), padding='same',kernel_regularizer=regularizers.l2(weight_decay))) 54 | model.add(Activation('relu')) 55 | model.add(BatchNormalization()) 56 | model.add(Dropout(0.4)) 57 | 58 | model.add(Conv2D(128, (3, 3), padding='same',kernel_regularizer=regularizers.l2(weight_decay))) 59 | model.add(Activation('relu')) 60 | model.add(BatchNormalization()) 61 | 62 | model.add(MaxPooling2D(pool_size=(2, 2))) 63 | 64 | model.add(Conv2D(256, (3, 3), padding='same',kernel_regularizer=regularizers.l2(weight_decay))) 65 | model.add(Activation('relu')) 66 | model.add(BatchNormalization()) 67 | model.add(Dropout(0.4)) 68 | 69 | model.add(Conv2D(256, (3, 3), padding='same',kernel_regularizer=regularizers.l2(weight_decay))) 70 | model.add(Activation('relu')) 71 | model.add(BatchNormalization()) 72 | model.add(Dropout(0.4)) 73 | 74 | model.add(Conv2D(256, (3, 3), padding='same',kernel_regularizer=regularizers.l2(weight_decay))) 75 | model.add(Activation('relu')) 76 | model.add(BatchNormalization()) 77 | 78 | model.add(MaxPooling2D(pool_size=(2, 2))) 79 | 80 | 81 | model.add(Conv2D(512, (3, 3), padding='same',kernel_regularizer=regularizers.l2(weight_decay))) 82 | model.add(Activation('relu')) 83 | model.add(BatchNormalization()) 84 | model.add(Dropout(0.4)) 85 | 86 | model.add(Conv2D(512, (3, 3), padding='same',kernel_regularizer=regularizers.l2(weight_decay))) 87 | model.add(Activation('relu')) 88 | model.add(BatchNormalization()) 89 | model.add(Dropout(0.4)) 90 | 91 | model.add(Conv2D(512, (3, 3), padding='same',kernel_regularizer=regularizers.l2(weight_decay))) 92 | model.add(Activation('relu')) 93 | model.add(BatchNormalization()) 94 | 95 | model.add(MaxPooling2D(pool_size=(2, 2))) 96 | 97 | 98 | model.add(Conv2D(512, (3, 3), padding='same',kernel_regularizer=regularizers.l2(weight_decay))) 99 | model.add(Activation('relu')) 100 | model.add(BatchNormalization()) 101 | model.add(Dropout(0.4)) 102 | 103 | model.add(Conv2D(512, (3, 3), padding='same',kernel_regularizer=regularizers.l2(weight_decay))) 104 | model.add(Activation('relu')) 105 | model.add(BatchNormalization()) 106 | model.add(Dropout(0.4)) 107 | 108 | model.add(Conv2D(512, (3, 3), padding='same',kernel_regularizer=regularizers.l2(weight_decay))) 109 | model.add(Activation('relu')) 110 | model.add(BatchNormalization()) 111 | 112 | model.add(MaxPooling2D(pool_size=(2, 2))) 113 | model.add(Dropout(0.5)) 114 | 115 | model.add(Flatten()) 116 | model.add(Dense(512,kernel_regularizer=regularizers.l2(weight_decay))) 117 | model.add(Activation('relu')) 118 | model.add(BatchNormalization()) 119 | 120 | model.add(Dropout(0.5)) 121 | model.add(Dense(self.num_classes)) 122 | model.add(Activation('softmax')) 123 | return model 124 | 125 | 126 | def normalize(self,X_train,X_test): 127 | #this function normalize inputs for zero mean and unit variance 128 | # it is used when training a model. 
129 | # Input: training set and test set 130 | # Output: normalized training set and test set according to the trianing set statistics. 131 | mean = np.mean(X_train,axis=(0,1,2,3)) 132 | std = np.std(X_train, axis=(0, 1, 2, 3)) 133 | print(mean) 134 | print(std) 135 | X_train = (X_train-mean)/(std+1e-7) 136 | X_test = (X_test-mean)/(std+1e-7) 137 | return X_train, X_test 138 | 139 | def normalize_production(self,x): 140 | #this function is used to normalize instances in production according to saved training set statistics 141 | # Input: X - a training set 142 | # Output X - a normalized training set according to normalization constants. 143 | 144 | #these values produced during first training and are general for the standard cifar10 training set normalization 145 | mean = 121.936 146 | std = 68.389 147 | return (x-mean)/(std+1e-7) 148 | 149 | def predict(self,x,normalize=True,batch_size=50): 150 | if normalize: 151 | x = self.normalize_production(x) 152 | return self.model.predict(x,batch_size) 153 | 154 | def train(self,model): 155 | 156 | #training parameters 157 | batch_size = 128 158 | maxepoches = 250 159 | learning_rate = 0.1 160 | lr_decay = 1e-6 161 | lr_drop = 20 162 | 163 | # The data, shuffled and split between train and test sets: 164 | (x_full_train, y_full_train), (x_test, y_test) = cifar100.load_data() 165 | x_full_train = x_full_train.astype('float32') 166 | x_test = x_test.astype('float32') 167 | x_full_train, x_test = self.normalize(x_full_train, x_test) 168 | 169 | y_full_train = keras.utils.to_categorical(y_full_train, self.num_classes) 170 | y_test = keras.utils.to_categorical(y_test, self.num_classes) 171 | 172 | x_train = x_full_train[:-10000] 173 | y_train = y_full_train[:-10000] 174 | 175 | def lr_scheduler(epoch): 176 | return learning_rate * (0.5 ** (epoch // lr_drop)) 177 | reduce_lr = keras.callbacks.LearningRateScheduler(lr_scheduler) 178 | 179 | 180 | #data augmentation 181 | datagen = ImageDataGenerator( 182 | featurewise_center=False, # set input mean to 0 over the dataset 183 | samplewise_center=False, # set each sample mean to 0 184 | featurewise_std_normalization=False, # divide inputs by std of the dataset 185 | samplewise_std_normalization=False, # divide each input by its std 186 | zca_whitening=False, # apply ZCA whitening 187 | rotation_range=15, # randomly rotate images in the range (degrees, 0 to 180) 188 | width_shift_range=0.1, # randomly shift images horizontally (fraction of total width) 189 | height_shift_range=0.1, # randomly shift images vertically (fraction of total height) 190 | horizontal_flip=True, # randomly flip images 191 | vertical_flip=False) # randomly flip images 192 | # (std, mean, and principal components if ZCA whitening is applied). 193 | datagen.fit(x_train) 194 | 195 | 196 | 197 | #optimization details 198 | sgd = optimizers.SGD(lr=learning_rate, decay=lr_decay, momentum=0.9, nesterov=True) 199 | model.compile(loss='categorical_crossentropy', optimizer=sgd,metrics=['accuracy']) 200 | 201 | 202 | # training process in a for loop with learning rate drop every 25 epoches. 
203 | 204 | historytemp = model.fit_generator(datagen.flow(x_train, y_train, 205 | batch_size=batch_size), 206 | steps_per_epoch=x_train.shape[0] // batch_size, 207 | epochs=maxepoches, 208 | validation_data=(x_test, y_test),callbacks=[reduce_lr],verbose=2) 209 | model.save_weights('cifar100vgg.h5') 210 | return model 211 | 212 | def save_model(self, name): 213 | self.model.save_weights('cifar100vgg_'+name+'.h5') 214 | 215 | def load_model(self, weights): 216 | self.model.load_weights(weights) 217 | 218 | def getModel(self): 219 | return Model(input=self.model.input, output=self.model.layers[-2].output) 220 | 221 | 222 | # The data, shuffled and split between train and test sets: 223 | (x_full_train, y_full_train), (x_test, y_test) = cifar100.load_data() 224 | x_full_train = x_full_train.astype('float32') 225 | x_test = x_test.astype('float32') 226 | dummy_model = cifar100vgg(train=False) 227 | x_full_train, x_test = dummy_model.normalize(x_full_train, x_test) 228 | 229 | y_full_train = keras.utils.to_categorical(y_full_train, dummy_model.num_classes) 230 | y_test = keras.utils.to_categorical(y_test, dummy_model.num_classes) 231 | 232 | x_train = x_full_train[:-10000] 233 | y_train = y_full_train[:-10000] 234 | x_valid = x_full_train[-10000:] 235 | y_valid = y_full_train[-10000:] 236 | 237 | 238 | for seed in np.arange(20, 100,10): 239 | np.random.seed(seed) 240 | random.seed(seed) 241 | model = cifar100vgg() 242 | model.save_model("seed"+str(seed)) 243 | pre_softmax_model = model.getModel() 244 | 245 | valid_preacts = pre_softmax_model.predict(x_valid) 246 | test_preacts = pre_softmax_model.predict(x_test) 247 | sys.stdout.flush() 248 | 249 | test_predictions_file = "testpreacts_seed"+str(seed)+".txt" 250 | print("Saving", test_predictions_file) 251 | f = open(test_predictions_file,'w') 252 | for test_preact in test_preacts: 253 | f.write("\t".join([str(x) for x in test_preact])+"\n") 254 | f.close() 255 | 256 | valid_predictions_file = "validpreacts_seed"+str(seed)+".txt" 257 | print("Saving", valid_predictions_file) 258 | f = open(valid_predictions_file,'w') 259 | for valid_preact in valid_preacts: 260 | f.write("\t".join([str(x) for x in valid_preact])+"\n") 261 | f.close() -------------------------------------------------------------------------------- /notebooks/obtaining_predictions/diabetic_retinopathy/README.txt: -------------------------------------------------------------------------------- 1 | We used the original model made publicly available at https://github.com/JeffreyDF/kaggle_diabetic_retinopathy (won 5th place at the Kaggle Diabetic Retinopathy detection challenge) 2 | 3 | Our code for making the predictions using one eye at a time is at https://github.com/kundajelab/kaggle_diabetic_retinopathy/blob/26ca72b09393d9c4360d635f7caa90aaf4d6744a/notebooks/OneEyeAtATimeValidationSetPredictions.ipynb 4 | 5 | For each eye, the predictions were averaged over different rotations and flips in order to get the final predictions, which were prepared for upload to Zenodo with this script: https://github.com/kundajelab/kaggle_diabetic_retinopathy/blob/26ca72b09393d9c4360d635f7caa90aaf4d6744a/notebooks/prepare_data_for_zenodo_upload/prepare_data.py 6 | -------------------------------------------------------------------------------- /notebooks/obtaining_predictions/mnist/README.txt: -------------------------------------------------------------------------------- 1 | Kundajelab internal notes: 2 | - Amr Alexandari trained these models. 
3 | -------------------------------------------------------------------------------- /notebooks/obtaining_predictions/mnist/Train_MNIST_and_make_predictions.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stderr", 10 | "output_type": "stream", 11 | "text": [ 12 | "Using TensorFlow backend.\n" 13 | ] 14 | }, 15 | { 16 | "name": "stdout", 17 | "output_type": "stream", 18 | "text": [ 19 | "keras version: 2.2.4\n", 20 | "tensorflow version: 1.14.0\n" 21 | ] 22 | } 23 | ], 24 | "source": [ 25 | "from __future__ import print_function\n", 26 | "import keras\n", 27 | "from keras.models import load_model\n", 28 | "from keras.models import Sequential, Model\n", 29 | "print(\"keras version:\", keras.__version__)\n", 30 | "import tensorflow as tf\n", 31 | "print(\"tensorflow version:\", tf.__version__)\n", 32 | "import random\n", 33 | "import os\n", 34 | "import sys\n", 35 | "import numpy as np\n", 36 | "from keras.datasets import mnist\n", 37 | "from keras.layers import Dense, Dropout, Flatten\n", 38 | "from keras.layers import Conv2D, MaxPooling2D, Activation\n", 39 | "from keras import backend as K\n", 40 | "from keras.callbacks import EarlyStopping" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": 2, 46 | "metadata": {}, 47 | "outputs": [ 48 | { 49 | "name": "stdout", 50 | "output_type": "stream", 51 | "text": [ 52 | "x_train shape: (60000, 28, 28, 1)\n", 53 | "60000 train samples\n", 54 | "10000 valid samples\n", 55 | "10000 test samples\n" 56 | ] 57 | } 58 | ], 59 | "source": [ 60 | "batch_size = 128\n", 61 | "num_classes = 10\n", 62 | "epochs = 10\n", 63 | "\n", 64 | "# input image dimensions\n", 65 | "img_rows, img_cols = 28, 28\n", 66 | "\n", 67 | "# the data, split between train and test sets\n", 68 | "(x_train, y_train), (x_test, y_test) = mnist.load_data()\n", 69 | "\n", 70 | "if K.image_data_format() == 'channels_first':\n", 71 | " x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)\n", 72 | " x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)\n", 73 | " input_shape = (1, img_rows, img_cols)\n", 74 | "else:\n", 75 | " x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)\n", 76 | " x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)\n", 77 | " input_shape = (img_rows, img_cols, 1)\n", 78 | "\n", 79 | "full_x_train = x_train.astype('float32')\n", 80 | "x_test = x_test.astype('float32')\n", 81 | "full_x_train /= 255\n", 82 | "x_test /= 255\n", 83 | "x_valid = full_x_train[-10000:]\n", 84 | "print('x_train shape:', full_x_train.shape)\n", 85 | "print(full_x_train.shape[0], 'train samples')\n", 86 | "print(x_valid.shape[0], 'valid samples')\n", 87 | "print(x_test.shape[0], 'test samples')\n", 88 | "\n", 89 | "# convert class vectors to binary class matrices\n", 90 | "full_y_train = keras.utils.to_categorical(y_train, num_classes)\n", 91 | "y_valid = full_y_train[-10000:]\n", 92 | "y_test = keras.utils.to_categorical(y_test, num_classes)" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": 3, 98 | "metadata": {}, 99 | "outputs": [ 100 | { 101 | "data": { 102 | "text/plain": [ 103 | "0" 104 | ] 105 | }, 106 | "execution_count": 3, 107 | "metadata": {}, 108 | "output_type": "execute_result" 109 | } 110 | ], 111 | "source": [ 112 | "output_file = \"test_labels.txt\"\n", 113 | "f = open(output_file, 'w')\n", 114 | 
"f.write(\"\\n\".join([\"\\t\".join([str(x) for x in y]) for y in y_test]))\n", 115 | "f.close()\n", 116 | "os.system(\"gzip -f \"+output_file)\n", 117 | "\n", 118 | "output_file = \"valid_labels.txt\"\n", 119 | "f = open(output_file, 'w')\n", 120 | "f.write(\"\\n\".join([\"\\t\".join([str(x) for x in y]) for y in y_valid]))\n", 121 | "f.close()\n", 122 | "os.system(\"gzip -f \"+output_file)\n", 123 | "\n", 124 | "output_file = \"train_labels.txt\"\n", 125 | "f = open(output_file, 'w')\n", 126 | "f.write(\"\\n\".join([\"\\t\".join([str(x) for x in y]) for y in full_y_train]))\n", 127 | "f.close()\n", 128 | "os.system(\"gzip -f \"+output_file)" 129 | ] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "execution_count": 4, 134 | "metadata": {}, 135 | "outputs": [ 136 | { 137 | "name": "stdout", 138 | "output_type": "stream", 139 | "text": [ 140 | "On train set size 30000\n", 141 | "WARNING:tensorflow:From /users/amr1/miniconda3/envs/basepair/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:74: The name tf.get_default_graph is deprecated. Please use tf.compat.v1.get_default_graph instead.\n", 142 | "\n", 143 | "WARNING:tensorflow:From /users/amr1/miniconda3/envs/basepair/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:517: The name tf.placeholder is deprecated. Please use tf.compat.v1.placeholder instead.\n", 144 | "\n", 145 | "WARNING:tensorflow:From /users/amr1/miniconda3/envs/basepair/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:4138: The name tf.random_uniform is deprecated. Please use tf.random.uniform instead.\n", 146 | "\n", 147 | "WARNING:tensorflow:From /users/amr1/miniconda3/envs/basepair/lib/python3.6/site-packages/keras/optimizers.py:790: The name tf.train.Optimizer is deprecated. Please use tf.compat.v1.train.Optimizer instead.\n", 148 | "\n", 149 | "WARNING:tensorflow:From /users/amr1/miniconda3/envs/basepair/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:3295: The name tf.log is deprecated. Please use tf.math.log instead.\n", 150 | "\n", 151 | "Mean y train: [0.0987 0.1141 0.09826667 0.10243333 0.09753333 0.0903\n", 152 | " 0.09916667 0.10356667 0.09583333 0.1001 ]\n", 153 | "Mean y valid: [0.0991 0.1064 0.099 0.103 0.0983 0.0915 0.0967 0.109 0.1009 0.0961]\n", 154 | "WARNING:tensorflow:From /users/amr1/miniconda3/envs/basepair/lib/python3.6/site-packages/tensorflow/python/ops/math_grad.py:1250: add_dispatch_support..wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.\n", 155 | "Instructions for updating:\n", 156 | "Use tf.where in 2.0, which has the same broadcast rule as np.where\n", 157 | "WARNING:tensorflow:From /users/amr1/miniconda3/envs/basepair/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:986: The name tf.assign_add is deprecated. 
Please use tf.compat.v1.assign_add instead.\n", 158 | "\n", 159 | "Train on 30000 samples, validate on 10000 samples\n", 160 | "Epoch 1/10\n", 161 | "30000/30000 [==============================] - 4s 141us/step - loss: 1.1742 - acc: 0.7175 - val_loss: 0.6226 - val_acc: 0.8657\n", 162 | "Epoch 2/10\n", 163 | "30000/30000 [==============================] - 1s 22us/step - loss: 0.5563 - acc: 0.8654 - val_loss: 0.4470 - val_acc: 0.8933\n", 164 | "Epoch 3/10\n", 165 | "30000/30000 [==============================] - 1s 22us/step - loss: 0.4487 - acc: 0.8852 - val_loss: 0.3881 - val_acc: 0.9009\n", 166 | "Epoch 4/10\n", 167 | "30000/30000 [==============================] - 1s 22us/step - loss: 0.4018 - acc: 0.8932 - val_loss: 0.3557 - val_acc: 0.9094\n", 168 | "Epoch 5/10\n", 169 | "30000/30000 [==============================] - 1s 22us/step - loss: 0.3741 - acc: 0.8990 - val_loss: 0.3367 - val_acc: 0.9108\n", 170 | "Epoch 6/10\n", 171 | "30000/30000 [==============================] - 1s 22us/step - loss: 0.3552 - acc: 0.9028 - val_loss: 0.3242 - val_acc: 0.9130\n", 172 | "Epoch 7/10\n", 173 | "30000/30000 [==============================] - 1s 24us/step - loss: 0.3414 - acc: 0.9062 - val_loss: 0.3136 - val_acc: 0.9156\n", 174 | "Epoch 8/10\n", 175 | "30000/30000 [==============================] - 1s 22us/step - loss: 0.3306 - acc: 0.9092 - val_loss: 0.3051 - val_acc: 0.9173\n", 176 | "Epoch 9/10\n", 177 | "30000/30000 [==============================] - 1s 22us/step - loss: 0.3214 - acc: 0.9115 - val_loss: 0.2984 - val_acc: 0.9169\n", 178 | "Epoch 10/10\n", 179 | "30000/30000 [==============================] - 1s 22us/step - loss: 0.3138 - acc: 0.9126 - val_loss: 0.2926 - val_acc: 0.9194\n", 180 | "Making predictions on validation set\n" 181 | ] 182 | }, 183 | { 184 | "name": "stderr", 185 | "output_type": "stream", 186 | "text": [ 187 | "/users/amr1/miniconda3/envs/basepair/lib/python3.6/site-packages/ipykernel_launcher.py:36: UserWarning: Update your `Model` call to the Keras 2 API: `Model(inputs=Tensor(\"fl..., outputs=Tensor(\"de...)`\n" 188 | ] 189 | }, 190 | { 191 | "name": "stdout", 192 | "output_type": "stream", 193 | "text": [ 194 | "Making predictions on test set\n", 195 | "Test accuracy: 0.9178\n", 196 | "Valid accuracy: 0.9194\n", 197 | "Saving testpreacts_model_mnist_set-30000_seed-0.txt\n", 198 | "f0f52eca81d56c315628f9effb3dfbb0 testpreacts_model_mnist_set-30000_seed-0.txt\n", 199 | "Saving validpreacts_model_mnist_set-30000_seed-0.txt\n", 200 | "f8b8c16e39c2dc3dc98f917d9828de7f validpreacts_model_mnist_set-30000_seed-0.txt\n", 201 | "On train set size 30000\n", 202 | "Mean y train: [0.0987 0.1141 0.09826667 0.10243333 0.09753333 0.0903\n", 203 | " 0.09916667 0.10356667 0.09583333 0.1001 ]\n", 204 | "Mean y valid: [0.0991 0.1064 0.099 0.103 0.0983 0.0915 0.0967 0.109 0.1009 0.0961]\n", 205 | "Train on 30000 samples, validate on 10000 samples\n", 206 | "Epoch 1/10\n", 207 | "30000/30000 [==============================] - 1s 30us/step - loss: 1.1656 - acc: 0.7336 - val_loss: 0.6216 - val_acc: 0.8713\n", 208 | "Epoch 2/10\n", 209 | "30000/30000 [==============================] - 1s 23us/step - loss: 0.5498 - acc: 0.8703 - val_loss: 0.4456 - val_acc: 0.8940\n", 210 | "Epoch 3/10\n", 211 | "30000/30000 [==============================] - 1s 23us/step - loss: 0.4429 - acc: 0.8881 - val_loss: 0.3855 - val_acc: 0.9018\n", 212 | "Epoch 4/10\n", 213 | "30000/30000 [==============================] - 1s 24us/step - loss: 0.3964 - acc: 0.8962 - val_loss: 0.3539 - val_acc: 0.9070\n", 214 | "Epoch 
5/10\n", 215 | "30000/30000 [==============================] - 1s 22us/step - loss: 0.3690 - acc: 0.9017 - val_loss: 0.3346 - val_acc: 0.9114\n", 216 | "Epoch 6/10\n", 217 | "30000/30000 [==============================] - 1s 23us/step - loss: 0.3502 - acc: 0.9057 - val_loss: 0.3205 - val_acc: 0.9140\n", 218 | "Epoch 7/10\n", 219 | "30000/30000 [==============================] - 1s 24us/step - loss: 0.3361 - acc: 0.9087 - val_loss: 0.3094 - val_acc: 0.9158\n", 220 | "Epoch 8/10\n", 221 | "30000/30000 [==============================] - 1s 23us/step - loss: 0.3248 - acc: 0.9109 - val_loss: 0.3017 - val_acc: 0.9181\n", 222 | "Epoch 9/10\n", 223 | "30000/30000 [==============================] - 1s 24us/step - loss: 0.3155 - acc: 0.9139 - val_loss: 0.2946 - val_acc: 0.9201\n", 224 | "Epoch 10/10\n", 225 | "30000/30000 [==============================] - 1s 23us/step - loss: 0.3077 - acc: 0.9159 - val_loss: 0.2886 - val_acc: 0.9210\n", 226 | "Making predictions on validation set\n", 227 | "Making predictions on test set\n", 228 | "Test accuracy: 0.9183\n", 229 | "Valid accuracy: 0.921\n", 230 | "Saving testpreacts_model_mnist_set-30000_seed-10.txt\n", 231 | "260be080ee3540ded634d773be845134 testpreacts_model_mnist_set-30000_seed-10.txt\n", 232 | "Saving validpreacts_model_mnist_set-30000_seed-10.txt\n", 233 | "9c1619390d5a00b9b940eafbe7695356 validpreacts_model_mnist_set-30000_seed-10.txt\n", 234 | "On train set size 30000\n", 235 | "Mean y train: [0.0987 0.1141 0.09826667 0.10243333 0.09753333 0.0903\n", 236 | " 0.09916667 0.10356667 0.09583333 0.1001 ]\n", 237 | "Mean y valid: [0.0991 0.1064 0.099 0.103 0.0983 0.0915 0.0967 0.109 0.1009 0.0961]\n", 238 | "Train on 30000 samples, validate on 10000 samples\n", 239 | "Epoch 1/10\n", 240 | "30000/30000 [==============================] - 1s 34us/step - loss: 1.1828 - acc: 0.7143 - val_loss: 0.6320 - val_acc: 0.8631\n", 241 | "Epoch 2/10\n", 242 | "30000/30000 [==============================] - 1s 23us/step - loss: 0.5625 - acc: 0.8633 - val_loss: 0.4522 - val_acc: 0.8898\n", 243 | "Epoch 3/10\n", 244 | "30000/30000 [==============================] - 1s 22us/step - loss: 0.4527 - acc: 0.8831 - val_loss: 0.3905 - val_acc: 0.9013\n", 245 | "Epoch 4/10\n", 246 | "30000/30000 [==============================] - 1s 23us/step - loss: 0.4046 - acc: 0.8932 - val_loss: 0.3582 - val_acc: 0.9059\n", 247 | "Epoch 5/10\n", 248 | "30000/30000 [==============================] - 1s 24us/step - loss: 0.3758 - acc: 0.8988 - val_loss: 0.3396 - val_acc: 0.9097\n", 249 | "Epoch 6/10\n", 250 | "30000/30000 [==============================] - 1s 23us/step - loss: 0.3564 - acc: 0.9036 - val_loss: 0.3244 - val_acc: 0.9122\n", 251 | "Epoch 7/10\n", 252 | "30000/30000 [==============================] - 1s 23us/step - loss: 0.3419 - acc: 0.9070 - val_loss: 0.3133 - val_acc: 0.9155\n", 253 | "Epoch 8/10\n", 254 | "30000/30000 [==============================] - 1s 22us/step - loss: 0.3305 - acc: 0.9099 - val_loss: 0.3050 - val_acc: 0.9171\n", 255 | "Epoch 9/10\n", 256 | "30000/30000 [==============================] - 1s 22us/step - loss: 0.3210 - acc: 0.9118 - val_loss: 0.2979 - val_acc: 0.9183\n", 257 | "Epoch 10/10\n", 258 | "30000/30000 [==============================] - 1s 22us/step - loss: 0.3130 - acc: 0.9140 - val_loss: 0.2918 - val_acc: 0.9192\n", 259 | "Making predictions on validation set\n", 260 | "Making predictions on test set\n", 261 | "Test accuracy: 0.918\n", 262 | "Valid accuracy: 0.9192\n", 263 | "Saving testpreacts_model_mnist_set-30000_seed-20.txt\n", 264 | 
"bf586c2636ed2c4727cc3505b85d12e5 testpreacts_model_mnist_set-30000_seed-20.txt\n", 265 | "Saving validpreacts_model_mnist_set-30000_seed-20.txt\n", 266 | "2718805d70ab7ece95b5e1cfb987893c validpreacts_model_mnist_set-30000_seed-20.txt\n", 267 | "On train set size 30000\n", 268 | "Mean y train: [0.0987 0.1141 0.09826667 0.10243333 0.09753333 0.0903\n", 269 | " 0.09916667 0.10356667 0.09583333 0.1001 ]\n", 270 | "Mean y valid: [0.0991 0.1064 0.099 0.103 0.0983 0.0915 0.0967 0.109 0.1009 0.0961]\n", 271 | "Train on 30000 samples, validate on 10000 samples\n", 272 | "Epoch 1/10\n", 273 | "30000/30000 [==============================] - 1s 33us/step - loss: 1.2099 - acc: 0.7147 - val_loss: 0.6378 - val_acc: 0.8682\n", 274 | "Epoch 2/10\n", 275 | "30000/30000 [==============================] - 1s 22us/step - loss: 0.5614 - acc: 0.8677 - val_loss: 0.4506 - val_acc: 0.8913\n", 276 | "Epoch 3/10\n", 277 | "30000/30000 [==============================] - 1s 22us/step - loss: 0.4488 - acc: 0.8865 - val_loss: 0.3878 - val_acc: 0.8999\n", 278 | "Epoch 4/10\n", 279 | "30000/30000 [==============================] - 1s 22us/step - loss: 0.4001 - acc: 0.8949 - val_loss: 0.3568 - val_acc: 0.9061\n", 280 | "Epoch 5/10\n", 281 | "30000/30000 [==============================] - 1s 22us/step - loss: 0.3715 - acc: 0.9010 - val_loss: 0.3363 - val_acc: 0.9094\n", 282 | "Epoch 6/10\n", 283 | "30000/30000 [==============================] - 1s 22us/step - loss: 0.3521 - acc: 0.9044 - val_loss: 0.3222 - val_acc: 0.9131\n", 284 | "Epoch 7/10\n", 285 | "30000/30000 [==============================] - 1s 22us/step - loss: 0.3377 - acc: 0.9081 - val_loss: 0.3115 - val_acc: 0.9158\n", 286 | "Epoch 8/10\n", 287 | "30000/30000 [==============================] - 1s 22us/step - loss: 0.3262 - acc: 0.9104 - val_loss: 0.3034 - val_acc: 0.9170\n", 288 | "Epoch 9/10\n", 289 | "30000/30000 [==============================] - 1s 22us/step - loss: 0.3169 - acc: 0.9131 - val_loss: 0.2972 - val_acc: 0.9186\n", 290 | "Epoch 10/10\n", 291 | "30000/30000 [==============================] - 1s 22us/step - loss: 0.3090 - acc: 0.9145 - val_loss: 0.2906 - val_acc: 0.9209\n", 292 | "Making predictions on validation set\n", 293 | "Making predictions on test set\n", 294 | "Test accuracy: 0.9181\n", 295 | "Valid accuracy: 0.9209\n", 296 | "Saving testpreacts_model_mnist_set-30000_seed-30.txt\n", 297 | "1432e16698c34f231e53aca497524da4 testpreacts_model_mnist_set-30000_seed-30.txt\n", 298 | "Saving validpreacts_model_mnist_set-30000_seed-30.txt\n", 299 | "81f9001f7d0e90ecc692b42f3eabfa86 validpreacts_model_mnist_set-30000_seed-30.txt\n", 300 | "On train set size 30000\n", 301 | "Mean y train: [0.0987 0.1141 0.09826667 0.10243333 0.09753333 0.0903\n", 302 | " 0.09916667 0.10356667 0.09583333 0.1001 ]\n", 303 | "Mean y valid: [0.0991 0.1064 0.099 0.103 0.0983 0.0915 0.0967 0.109 0.1009 0.0961]\n", 304 | "Train on 30000 samples, validate on 10000 samples\n", 305 | "Epoch 1/10\n", 306 | "30000/30000 [==============================] - 1s 34us/step - loss: 1.1571 - acc: 0.7264 - val_loss: 0.6238 - val_acc: 0.8636\n", 307 | "Epoch 2/10\n", 308 | "30000/30000 [==============================] - 1s 23us/step - loss: 0.5570 - acc: 0.8623 - val_loss: 0.4497 - val_acc: 0.8880\n", 309 | "Epoch 3/10\n", 310 | "30000/30000 [==============================] - 1s 23us/step - loss: 0.4503 - acc: 0.8827 - val_loss: 0.3905 - val_acc: 0.8994\n", 311 | "Epoch 4/10\n", 312 | "30000/30000 [==============================] - 1s 23us/step - loss: 0.4038 - acc: 0.8917 - val_loss: 
0.3599 - val_acc: 0.9057\n", 313 | "Epoch 5/10\n", 314 | "30000/30000 [==============================] - 1s 23us/step - loss: 0.3760 - acc: 0.8984 - val_loss: 0.3404 - val_acc: 0.9086\n", 315 | "Epoch 6/10\n", 316 | "30000/30000 [==============================] - 1s 22us/step - loss: 0.3574 - acc: 0.9026 - val_loss: 0.3264 - val_acc: 0.9120\n", 317 | "Epoch 7/10\n", 318 | "30000/30000 [==============================] - 1s 22us/step - loss: 0.3432 - acc: 0.9059 - val_loss: 0.3178 - val_acc: 0.9143\n", 319 | "Epoch 8/10\n", 320 | "30000/30000 [==============================] - 1s 22us/step - loss: 0.3323 - acc: 0.9090 - val_loss: 0.3077 - val_acc: 0.9167\n", 321 | "Epoch 9/10\n", 322 | "30000/30000 [==============================] - 1s 23us/step - loss: 0.3234 - acc: 0.9119 - val_loss: 0.3006 - val_acc: 0.9176\n", 323 | "Epoch 10/10\n", 324 | "30000/30000 [==============================] - 1s 23us/step - loss: 0.3158 - acc: 0.9132 - val_loss: 0.2953 - val_acc: 0.9203\n", 325 | "Making predictions on validation set\n", 326 | "Making predictions on test set\n", 327 | "Test accuracy: 0.9167\n", 328 | "Valid accuracy: 0.9203\n", 329 | "Saving testpreacts_model_mnist_set-30000_seed-40.txt\n", 330 | "232078a3be3dd38a43d967672fb7dd2b testpreacts_model_mnist_set-30000_seed-40.txt\n", 331 | "Saving validpreacts_model_mnist_set-30000_seed-40.txt\n", 332 | "07f6f7ced1dfddb9055f6e5883d5cf36 validpreacts_model_mnist_set-30000_seed-40.txt\n", 333 | "On train set size 30000\n", 334 | "Mean y train: [0.0987 0.1141 0.09826667 0.10243333 0.09753333 0.0903\n", 335 | " 0.09916667 0.10356667 0.09583333 0.1001 ]\n", 336 | "Mean y valid: [0.0991 0.1064 0.099 0.103 0.0983 0.0915 0.0967 0.109 0.1009 0.0961]\n", 337 | "Train on 30000 samples, validate on 10000 samples\n", 338 | "Epoch 1/10\n", 339 | "30000/30000 [==============================] - 1s 36us/step - loss: 1.1192 - acc: 0.7434 - val_loss: 0.6134 - val_acc: 0.8694\n", 340 | "Epoch 2/10\n", 341 | "30000/30000 [==============================] - 1s 23us/step - loss: 0.5502 - acc: 0.8659 - val_loss: 0.4460 - val_acc: 0.8920\n", 342 | "Epoch 3/10\n", 343 | "30000/30000 [==============================] - 1s 23us/step - loss: 0.4464 - acc: 0.8840 - val_loss: 0.3875 - val_acc: 0.9012\n", 344 | "Epoch 4/10\n", 345 | "30000/30000 [==============================] - 1s 23us/step - loss: 0.3992 - acc: 0.8937 - val_loss: 0.3572 - val_acc: 0.9065\n", 346 | "Epoch 5/10\n", 347 | "30000/30000 [==============================] - 1s 23us/step - loss: 0.3713 - acc: 0.9004 - val_loss: 0.3361 - val_acc: 0.9118\n", 348 | "Epoch 6/10\n", 349 | "30000/30000 [==============================] - 1s 22us/step - loss: 0.3521 - acc: 0.9038 - val_loss: 0.3231 - val_acc: 0.9137\n", 350 | "Epoch 7/10\n", 351 | "30000/30000 [==============================] - 1s 22us/step - loss: 0.3375 - acc: 0.9079 - val_loss: 0.3117 - val_acc: 0.9167\n", 352 | "Epoch 8/10\n", 353 | "30000/30000 [==============================] - 1s 23us/step - loss: 0.3259 - acc: 0.9107 - val_loss: 0.3030 - val_acc: 0.9171\n", 354 | "Epoch 9/10\n", 355 | "30000/30000 [==============================] - 1s 22us/step - loss: 0.3165 - acc: 0.9126 - val_loss: 0.2963 - val_acc: 0.9187\n", 356 | "Epoch 10/10\n", 357 | "30000/30000 [==============================] - 1s 22us/step - loss: 0.3087 - acc: 0.9151 - val_loss: 0.2900 - val_acc: 0.9200\n", 358 | "Making predictions on validation set\n", 359 | "Making predictions on test set\n", 360 | "Test accuracy: 0.9167\n", 361 | "Valid accuracy: 0.92\n", 362 | "Saving 
testpreacts_model_mnist_set-30000_seed-50.txt\n", 363 | "071a68e9a11eb7388f54fedabb8af47a testpreacts_model_mnist_set-30000_seed-50.txt\n", 364 | "Saving validpreacts_model_mnist_set-30000_seed-50.txt\n", 365 | "f2fb73c24a2c43aac354be2161968422 validpreacts_model_mnist_set-30000_seed-50.txt\n", 366 | "On train set size 30000\n", 367 | "Mean y train: [0.0987 0.1141 0.09826667 0.10243333 0.09753333 0.0903\n", 368 | " 0.09916667 0.10356667 0.09583333 0.1001 ]\n", 369 | "Mean y valid: [0.0991 0.1064 0.099 0.103 0.0983 0.0915 0.0967 0.109 0.1009 0.0961]\n", 370 | "Train on 30000 samples, validate on 10000 samples\n", 371 | "Epoch 1/10\n", 372 | "30000/30000 [==============================] - 1s 38us/step - loss: 1.1516 - acc: 0.7345 - val_loss: 0.6133 - val_acc: 0.8682\n", 373 | "Epoch 2/10\n", 374 | "30000/30000 [==============================] - 1s 23us/step - loss: 0.5463 - acc: 0.8688 - val_loss: 0.4432 - val_acc: 0.8909\n", 375 | "Epoch 3/10\n", 376 | "30000/30000 [==============================] - 1s 23us/step - loss: 0.4418 - acc: 0.8865 - val_loss: 0.3852 - val_acc: 0.8994\n", 377 | "Epoch 4/10\n", 378 | "30000/30000 [==============================] - 1s 23us/step - loss: 0.3955 - acc: 0.8956 - val_loss: 0.3542 - val_acc: 0.9061\n", 379 | "Epoch 5/10\n", 380 | "30000/30000 [==============================] - 1s 23us/step - loss: 0.3682 - acc: 0.9017 - val_loss: 0.3350 - val_acc: 0.9100\n", 381 | "Epoch 6/10\n", 382 | "30000/30000 [==============================] - 1s 23us/step - loss: 0.3493 - acc: 0.9060 - val_loss: 0.3216 - val_acc: 0.9133\n", 383 | "Epoch 7/10\n", 384 | "30000/30000 [==============================] - 1s 23us/step - loss: 0.3351 - acc: 0.9094 - val_loss: 0.3103 - val_acc: 0.9160\n", 385 | "Epoch 8/10\n", 386 | "30000/30000 [==============================] - 1s 23us/step - loss: 0.3242 - acc: 0.9110 - val_loss: 0.3028 - val_acc: 0.9173\n", 387 | "Epoch 9/10\n", 388 | "30000/30000 [==============================] - 1s 23us/step - loss: 0.3151 - acc: 0.9135 - val_loss: 0.2960 - val_acc: 0.9189\n", 389 | "Epoch 10/10\n", 390 | "30000/30000 [==============================] - 1s 23us/step - loss: 0.3075 - acc: 0.9155 - val_loss: 0.2899 - val_acc: 0.9195\n", 391 | "Making predictions on validation set\n", 392 | "Making predictions on test set\n", 393 | "Test accuracy: 0.9182\n", 394 | "Valid accuracy: 0.9195\n", 395 | "Saving testpreacts_model_mnist_set-30000_seed-60.txt\n", 396 | "e88f6b8b1e711ce76b303ef202a7ee94 testpreacts_model_mnist_set-30000_seed-60.txt\n", 397 | "Saving validpreacts_model_mnist_set-30000_seed-60.txt\n", 398 | "b278c9c9da926a8b54c3a8a033ddad45 validpreacts_model_mnist_set-30000_seed-60.txt\n", 399 | "On train set size 30000\n", 400 | "Mean y train: [0.0987 0.1141 0.09826667 0.10243333 0.09753333 0.0903\n", 401 | " 0.09916667 0.10356667 0.09583333 0.1001 ]\n", 402 | "Mean y valid: [0.0991 0.1064 0.099 0.103 0.0983 0.0915 0.0967 0.109 0.1009 0.0961]\n", 403 | "Train on 30000 samples, validate on 10000 samples\n", 404 | "Epoch 1/10\n", 405 | "30000/30000 [==============================] - 1s 38us/step - loss: 1.1595 - acc: 0.7385 - val_loss: 0.6254 - val_acc: 0.8720\n", 406 | "Epoch 2/10\n", 407 | "30000/30000 [==============================] - 1s 23us/step - loss: 0.5535 - acc: 0.8679 - val_loss: 0.4480 - val_acc: 0.8927\n", 408 | "Epoch 3/10\n", 409 | "30000/30000 [==============================] - 1s 23us/step - loss: 0.4465 - acc: 0.8846 - val_loss: 0.3892 - val_acc: 0.8996\n", 410 | "Epoch 4/10\n", 411 | "30000/30000 [==============================] - 
1s 23us/step - loss: 0.3998 - acc: 0.8941 - val_loss: 0.3581 - val_acc: 0.9051\n", 412 | "Epoch 5/10\n", 413 | "30000/30000 [==============================] - 1s 24us/step - loss: 0.3721 - acc: 0.8997 - val_loss: 0.3383 - val_acc: 0.9090\n", 414 | "Epoch 6/10\n", 415 | "30000/30000 [==============================] - 1s 23us/step - loss: 0.3532 - acc: 0.9042 - val_loss: 0.3245 - val_acc: 0.9125\n", 416 | "Epoch 7/10\n", 417 | "30000/30000 [==============================] - 1s 23us/step - loss: 0.3389 - acc: 0.9073 - val_loss: 0.3150 - val_acc: 0.9141\n", 418 | "Epoch 8/10\n", 419 | "30000/30000 [==============================] - 1s 23us/step - loss: 0.3280 - acc: 0.9099 - val_loss: 0.3062 - val_acc: 0.9146\n", 420 | "Epoch 9/10\n", 421 | "30000/30000 [==============================] - 1s 23us/step - loss: 0.3186 - acc: 0.9117 - val_loss: 0.2991 - val_acc: 0.9163\n", 422 | "Epoch 10/10\n", 423 | "30000/30000 [==============================] - 1s 23us/step - loss: 0.3111 - acc: 0.9144 - val_loss: 0.2941 - val_acc: 0.9182\n", 424 | "Making predictions on validation set\n", 425 | "Making predictions on test set\n", 426 | "Test accuracy: 0.9168\n", 427 | "Valid accuracy: 0.9182\n", 428 | "Saving testpreacts_model_mnist_set-30000_seed-70.txt\n", 429 | "701dd72ed9a99076e3428e5bd73cd29d testpreacts_model_mnist_set-30000_seed-70.txt\n", 430 | "Saving validpreacts_model_mnist_set-30000_seed-70.txt\n", 431 | "c163e6153430ca9e24813308f53b45fa validpreacts_model_mnist_set-30000_seed-70.txt\n", 432 | "On train set size 30000\n", 433 | "Mean y train: [0.0987 0.1141 0.09826667 0.10243333 0.09753333 0.0903\n", 434 | " 0.09916667 0.10356667 0.09583333 0.1001 ]\n", 435 | "Mean y valid: [0.0991 0.1064 0.099 0.103 0.0983 0.0915 0.0967 0.109 0.1009 0.0961]\n", 436 | "Train on 30000 samples, validate on 10000 samples\n", 437 | "Epoch 1/10\n", 438 | "30000/30000 [==============================] - 1s 38us/step - loss: 1.1987 - acc: 0.7223 - val_loss: 0.6414 - val_acc: 0.8644\n", 439 | "Epoch 2/10\n", 440 | "30000/30000 [==============================] - 1s 23us/step - loss: 0.5650 - acc: 0.8657 - val_loss: 0.4546 - val_acc: 0.8884\n", 441 | "Epoch 3/10\n", 442 | "30000/30000 [==============================] - 1s 23us/step - loss: 0.4525 - acc: 0.8838 - val_loss: 0.3911 - val_acc: 0.9004\n", 443 | "Epoch 4/10\n", 444 | "30000/30000 [==============================] - 1s 23us/step - loss: 0.4033 - acc: 0.8923 - val_loss: 0.3584 - val_acc: 0.9055\n", 445 | "Epoch 5/10\n", 446 | "30000/30000 [==============================] - 1s 23us/step - loss: 0.3741 - acc: 0.8994 - val_loss: 0.3377 - val_acc: 0.9111\n", 447 | "Epoch 6/10\n", 448 | "30000/30000 [==============================] - 1s 23us/step - loss: 0.3542 - acc: 0.9036 - val_loss: 0.3234 - val_acc: 0.9130\n", 449 | "Epoch 7/10\n", 450 | "30000/30000 [==============================] - 1s 23us/step - loss: 0.3395 - acc: 0.9078 - val_loss: 0.3115 - val_acc: 0.9154\n", 451 | "Epoch 8/10\n", 452 | "30000/30000 [==============================] - 1s 23us/step - loss: 0.3278 - acc: 0.9099 - val_loss: 0.3023 - val_acc: 0.9174\n", 453 | "Epoch 9/10\n", 454 | "30000/30000 [==============================] - 1s 23us/step - loss: 0.3181 - acc: 0.9130 - val_loss: 0.2962 - val_acc: 0.9185\n", 455 | "Epoch 10/10\n", 456 | "30000/30000 [==============================] - 1s 23us/step - loss: 0.3102 - acc: 0.9153 - val_loss: 0.2891 - val_acc: 0.9201\n", 457 | "Making predictions on validation set\n", 458 | "Making predictions on test set\n", 459 | "Test accuracy: 0.918\n", 460 | "Valid 
accuracy: 0.9201\n", 461 | "Saving testpreacts_model_mnist_set-30000_seed-80.txt\n", 462 | "b4fbf67adcce57b92f32cfeb8fbd2197 testpreacts_model_mnist_set-30000_seed-80.txt\n", 463 | "Saving validpreacts_model_mnist_set-30000_seed-80.txt\n", 464 | "e95d92082bb3ef6afadca58759bd36bd validpreacts_model_mnist_set-30000_seed-80.txt\n", 465 | "On train set size 30000\n", 466 | "Mean y train: [0.0987 0.1141 0.09826667 0.10243333 0.09753333 0.0903\n", 467 | " 0.09916667 0.10356667 0.09583333 0.1001 ]\n", 468 | "Mean y valid: [0.0991 0.1064 0.099 0.103 0.0983 0.0915 0.0967 0.109 0.1009 0.0961]\n", 469 | "Train on 30000 samples, validate on 10000 samples\n", 470 | "Epoch 1/10\n", 471 | "30000/30000 [==============================] - 1s 42us/step - loss: 1.1638 - acc: 0.7276 - val_loss: 0.6319 - val_acc: 0.8621\n", 472 | "Epoch 2/10\n", 473 | "30000/30000 [==============================] - 1s 22us/step - loss: 0.5602 - acc: 0.8641 - val_loss: 0.4535 - val_acc: 0.8906\n", 474 | "Epoch 3/10\n", 475 | "30000/30000 [==============================] - 1s 23us/step - loss: 0.4516 - acc: 0.8836 - val_loss: 0.3929 - val_acc: 0.9016\n", 476 | "Epoch 4/10\n", 477 | "30000/30000 [==============================] - 1s 22us/step - loss: 0.4036 - acc: 0.8925 - val_loss: 0.3611 - val_acc: 0.9066\n", 478 | "Epoch 5/10\n", 479 | "30000/30000 [==============================] - 1s 22us/step - loss: 0.3754 - acc: 0.8990 - val_loss: 0.3421 - val_acc: 0.9109\n", 480 | "Epoch 6/10\n", 481 | "30000/30000 [==============================] - 1s 23us/step - loss: 0.3562 - acc: 0.9029 - val_loss: 0.3278 - val_acc: 0.9126\n", 482 | "Epoch 7/10\n", 483 | "30000/30000 [==============================] - 1s 23us/step - loss: 0.3420 - acc: 0.9057 - val_loss: 0.3173 - val_acc: 0.9142\n", 484 | "Epoch 8/10\n", 485 | "30000/30000 [==============================] - 1s 22us/step - loss: 0.3309 - acc: 0.9083 - val_loss: 0.3084 - val_acc: 0.9167\n", 486 | "Epoch 9/10\n", 487 | "30000/30000 [==============================] - 1s 23us/step - loss: 0.3217 - acc: 0.9110 - val_loss: 0.3017 - val_acc: 0.9174\n", 488 | "Epoch 10/10\n", 489 | "30000/30000 [==============================] - 1s 23us/step - loss: 0.3140 - acc: 0.9134 - val_loss: 0.2969 - val_acc: 0.9192\n", 490 | "Making predictions on validation set\n", 491 | "Making predictions on test set\n", 492 | "Test accuracy: 0.9159\n", 493 | "Valid accuracy: 0.9192\n", 494 | "Saving testpreacts_model_mnist_set-30000_seed-90.txt\n", 495 | "47fea0b166d23f8dcb30e9fa363960e6 testpreacts_model_mnist_set-30000_seed-90.txt\n", 496 | "Saving validpreacts_model_mnist_set-30000_seed-90.txt\n", 497 | "a3856050266c8bbc06b4dbc2d74b5c20 validpreacts_model_mnist_set-30000_seed-90.txt\n" 498 | ] 499 | } 500 | ], 501 | "source": [ 502 | "from keras import optimizers\n", 503 | "model_files = []\n", 504 | "for seed in range(0,100,10):\n", 505 | " np.random.seed(seed)\n", 506 | " random.seed(seed)\n", 507 | " for model_idx,train_set_size in enumerate([30000]):\n", 508 | " model_file = \"model_mnist_set-\"+str(train_set_size)+\"_seed-\"+str(seed)+\".h5\"\n", 509 | " model_files.append(model_file)\n", 510 | " print(\"On train set size\",train_set_size)\n", 511 | "\n", 512 | " model = Sequential()\n", 513 | " model.add(Flatten(input_shape=input_shape))\n", 514 | " model.add(Dense(256, activation='relu'))\n", 515 | " model.add(Dense(num_classes))\n", 516 | " model.add(Activation(\"softmax\"))\n", 517 | "\n", 518 | " optimizer = optimizers.SGD(lr=0.01, momentum=0.5, decay=5e-4)\n", 519 | " 
model.compile(loss=keras.losses.categorical_crossentropy,\n", 520 | " optimizer=optimizer,\n", 521 | " metrics=['accuracy'])\n", 522 | " x_train = full_x_train[:train_set_size] \n", 523 | " y_train = full_y_train[:train_set_size]\n", 524 | " print(\"Mean y train:\",np.mean(y_train, axis=0))\n", 525 | " print(\"Mean y valid:\",np.mean(y_valid, axis=0))\n", 526 | " model.fit(x_train, y_train,\n", 527 | " batch_size=batch_size,\n", 528 | " epochs=epochs,\n", 529 | " verbose=1,\n", 530 | " validation_data=(x_valid, y_valid),\n", 531 | " callbacks=[EarlyStopping(\n", 532 | " monitor='val_loss', patience=10,\n", 533 | " restore_best_weights=True)])\n", 534 | " model.save(model_file)\n", 535 | "\n", 536 | " pre_softmax_model = Model(input=model.input,\n", 537 | " output=model.layers[-2].output)\n", 538 | " print(\"Making predictions on validation set\")\n", 539 | " valid_preacts = pre_softmax_model.predict(x_valid)\n", 540 | " print(\"Making predictions on test set\")\n", 541 | " test_preacts = pre_softmax_model.predict(x_test)\n", 542 | " print('Test accuracy:', np.mean(np.argmax(test_preacts,axis=-1)\n", 543 | " ==np.argmax(y_test,axis=-1)))\n", 544 | " print('Valid accuracy:', np.mean(np.argmax(valid_preacts,axis=-1)\n", 545 | " ==np.argmax(y_valid,axis=-1)))\n", 546 | " sys.stdout.flush()\n", 547 | " test_predictions_file = (\"testpreacts_\"+model_file.split(\".\")[0])+\".txt\"\n", 548 | " print(\"Saving\", test_predictions_file)\n", 549 | " f = open(test_predictions_file,'w')\n", 550 | " for test_preact in test_preacts:\n", 551 | " f.write(\"\\t\".join([str(x) for x in test_preact])+\"\\n\") \n", 552 | " f.close()\n", 553 | " !md5sum $test_predictions_file\n", 554 | " !gzip $test_predictions_file\n", 555 | "\n", 556 | " valid_predictions_file = (\"validpreacts_\"+model_file.split(\".\")[0])+\".txt\"\n", 557 | " print(\"Saving\", valid_predictions_file)\n", 558 | " f = open(valid_predictions_file,'w')\n", 559 | " for valid_preact in valid_preacts:\n", 560 | " f.write(\"\\t\".join([str(x) for x in valid_preact])+\"\\n\") \n", 561 | " f.close()\n", 562 | " !md5sum $valid_predictions_file\n", 563 | " !gzip $valid_predictions_file" 564 | ] 565 | } 566 | ], 567 | "metadata": { 568 | "accelerator": "GPU", 569 | "colab": { 570 | "collapsed_sections": [], 571 | "include_colab_link": true, 572 | "name": "gist - Download CIFAR10 models from zenodo and make predictions.ipynb", 573 | "provenance": [], 574 | "version": "0.3.2" 575 | }, 576 | "kernelspec": { 577 | "display_name": "Python [conda env:basepair]", 578 | "language": "python", 579 | "name": "conda-env-basepair-py" 580 | }, 581 | "language_info": { 582 | "codemirror_mode": { 583 | "name": "ipython", 584 | "version": 3 585 | }, 586 | "file_extension": ".py", 587 | "mimetype": "text/x-python", 588 | "name": "python", 589 | "nbconvert_exporter": "python", 590 | "pygments_lexer": "ipython3", 591 | "version": "3.6.8" 592 | } 593 | }, 594 | "nbformat": 4, 595 | "nbformat_minor": 4 596 | } 597 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from distutils.core import setup 2 | 3 | if __name__== '__main__': 4 | setup(include_package_data=True, 5 | description='Experiments on label shift domain adaptation', 6 | url='NA', 7 | version='0.1.0.0', 8 | packages=['labelshiftexperiments'], 9 | setup_requires=[], 10 | install_requires=['numpy>=1.9', 11 | 'scikit-learn>=0.20.0', 12 | 'scipy>=1.1.0'], 13 | scripts=[], 14 | 
name='labelshiftexperiments') 15 | --------------------------------------------------------------------------------
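
The distutils-based setup.py above declares numpy, scikit-learn, and scipy as dependencies and installs the labelshiftexperiments package. A minimal install-and-import sketch is below; the exact commands are assumed rather than documented in the repository.

# Assumed workflow, from the repository root:
#   pip install .
# after which the experiment code can be imported as a package:
import labelshiftexperiments
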