├── .gitignore
├── README.md
├── cifar100vgg.py
├── cifar10vgg.py
└── risk_control.py

/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# dotenv
.env

# virtualenv
.venv
venv/
ENV/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Selective Classification for Deep Neural Networks
By Yonatan Geifman and Ran El-Yaniv

### Introduction:
This repository contains the implementation of VGG-16 models for CIFAR-10 and CIFAR-100, together with the risk bound proposed in the paper ["Selective Classification for Deep Neural Networks"](https://arxiv.org/abs/1705.08500).

### Citation

If you use these models in your research, please cite:

    @ARTICLE{2017arXiv170508500G,
        author = {{Geifman}, Y. and {El-Yaniv}, R.},
        title = "{Selective Classification for Deep Neural Networks}",
        journal = {ArXiv e-prints},
        archivePrefix = "arXiv",
        eprint = {1705.08500},
        year = 2017
    }

### Weights files:

[cifar-100 weights](https://drive.google.com/open?id=0B4odNGNGJ56qTEdnT1RjTU44Zms)

[cifar-10 weights](https://drive.google.com/open?id=0B4odNGNGJ56qVW9JdkthbzBsX28)
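### Usage (sketch):

The snippet below is not part of the original scripts; it is a minimal sketch of how the pieces fit together, and it assumes the pre-trained `cifar10vgg.h5` weights file linked above has been downloaded into the working directory. It mirrors the `__main__` block of `cifar10vgg.py`.

    import numpy as np
    import keras
    from keras.datasets import cifar10
    from cifar10vgg import cifar10vgg
    from risk_control import risk_control

    # load the CIFAR-10 test set and one-hot encode the labels
    (_, _), (x_test, y_test) = cifar10.load_data()
    x_test = x_test.astype('float32')
    y_test = keras.utils.to_categorical(y_test, 10)

    # train=False loads the pre-trained weights instead of training from scratch
    model = cifar10vgg(train=False)

    probs = model.predict(x_test)                               # softmax outputs (normalized internally)
    kappa = np.max(probs, 1)                                    # confidence rating: max softmax response
    residuals = (np.argmax(probs, 1) != np.argmax(y_test, 1))   # 1 for an error, 0 for a correct prediction

    # theta: selection threshold, b_star: risk bound that holds with probability at least 1 - delta
    theta, b_star = risk_control().bound(rstar=0.02, delta=0.001, kappa=kappa, residuals=residuals)

`bound` also prints a LaTeX-formatted table row (as used for the tables in the paper) containing the requested target risk, the empirical risk and coverage, and the computed bound.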
--------------------------------------------------------------------------------
/cifar100vgg.py:
--------------------------------------------------------------------------------

from __future__ import print_function
import keras
from keras.datasets import cifar100
from keras.preprocessing.image import ImageDataGenerator
from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation, Flatten
from keras.layers import Conv2D, MaxPooling2D, BatchNormalization
from keras import optimizers
import numpy as np
import pickle
from keras.layers.core import Lambda
from keras import backend as K
from keras import regularizers
from risk_control import risk_control


class cifar100vgg:
    def __init__(self, train=True):
        self.num_classes = 100
        self.weight_decay = 0.0005
        self.x_shape = [32, 32, 3]

        self.model = self.build_model()
        if train:
            self.model = self.train(self.model)
        else:
            self.model.load_weights('cifar100vgg.h5')

    def build_model(self):
        # Build the VGG network for 100 classes, with heavy dropout and weight decay as described in the paper.

        model = Sequential()
        weight_decay = self.weight_decay

        model.add(Conv2D(64, (3, 3), padding='same',
                         input_shape=self.x_shape, kernel_regularizer=regularizers.l2(weight_decay)))
        model.add(Activation('relu'))
        model.add(BatchNormalization())
        model.add(Dropout(0.3))

        model.add(Conv2D(64, (3, 3), padding='same', kernel_regularizer=regularizers.l2(weight_decay)))
        model.add(Activation('relu'))
        model.add(BatchNormalization())

        model.add(MaxPooling2D(pool_size=(2, 2)))

        model.add(Conv2D(128, (3, 3), padding='same', kernel_regularizer=regularizers.l2(weight_decay)))
        model.add(Activation('relu'))
        model.add(BatchNormalization())
        model.add(Dropout(0.4))

        model.add(Conv2D(128, (3, 3), padding='same', kernel_regularizer=regularizers.l2(weight_decay)))
        model.add(Activation('relu'))
        model.add(BatchNormalization())

        model.add(MaxPooling2D(pool_size=(2, 2)))

        model.add(Conv2D(256, (3, 3), padding='same', kernel_regularizer=regularizers.l2(weight_decay)))
        model.add(Activation('relu'))
        model.add(BatchNormalization())
        model.add(Dropout(0.4))

        model.add(Conv2D(256, (3, 3), padding='same', kernel_regularizer=regularizers.l2(weight_decay)))
        model.add(Activation('relu'))
        model.add(BatchNormalization())
        model.add(Dropout(0.4))

        model.add(Conv2D(256, (3, 3), padding='same', kernel_regularizer=regularizers.l2(weight_decay)))
        model.add(Activation('relu'))
        model.add(BatchNormalization())

        model.add(MaxPooling2D(pool_size=(2, 2)))

        model.add(Conv2D(512, (3, 3), padding='same', kernel_regularizer=regularizers.l2(weight_decay)))
        model.add(Activation('relu'))
        model.add(BatchNormalization())
        model.add(Dropout(0.4))
        model.add(Conv2D(512, (3, 3), padding='same', kernel_regularizer=regularizers.l2(weight_decay)))
        model.add(Activation('relu'))
        model.add(BatchNormalization())
        model.add(Dropout(0.4))

        model.add(Conv2D(512, (3, 3), padding='same', kernel_regularizer=regularizers.l2(weight_decay)))
        model.add(Activation('relu'))
        model.add(BatchNormalization())

        model.add(MaxPooling2D(pool_size=(2, 2)))

        model.add(Conv2D(512, (3, 3), padding='same', kernel_regularizer=regularizers.l2(weight_decay)))
        model.add(Activation('relu'))
        model.add(BatchNormalization())
        model.add(Dropout(0.4))

        model.add(Conv2D(512, (3, 3), padding='same', kernel_regularizer=regularizers.l2(weight_decay)))
        model.add(Activation('relu'))
        model.add(BatchNormalization())
        model.add(Dropout(0.4))

        model.add(Conv2D(512, (3, 3), padding='same', kernel_regularizer=regularizers.l2(weight_decay)))
        model.add(Activation('relu'))
        model.add(BatchNormalization())

        model.add(MaxPooling2D(pool_size=(2, 2)))
        model.add(Dropout(0.5))

        model.add(Flatten())
        model.add(Dense(512, kernel_regularizer=regularizers.l2(weight_decay)))
        model.add(Activation('relu'))
        model.add(BatchNormalization())

        model.add(Dropout(0.5))
        model.add(Dense(self.num_classes))
        model.add(Activation('softmax'))
        return model

    def normalize(self, X_train, X_test):
        # This function normalizes the inputs to zero mean and unit variance.
        # It is used when training a model.
        # Input: training set and test set
        # Output: training set and test set, normalized with the training set statistics.
        mean = np.mean(X_train, axis=(0, 1, 2, 3))
        std = np.std(X_train, axis=(0, 1, 2, 3))
        print(mean)
        print(std)
        X_train = (X_train - mean) / (std + 1e-7)
        X_test = (X_test - mean) / (std + 1e-7)
        return X_train, X_test

    def normalize_production(self, x):
        # This function normalizes instances in production according to the saved training set statistics.
        # Input: x - a batch of instances to classify
        # Output: the batch normalized with the training set constants below.
        # These constants were produced during the first training run; they are the
        # training set statistics for the standard CIFAR-100 normalization.
        mean = 121.936
        std = 68.389
        return (x - mean) / (std + 1e-7)

    def predict(self, x, normalize=True, batch_size=50):
        if normalize:
            x = self.normalize_production(x)
        return self.model.predict(x, batch_size)

    def train(self, model):

        # training parameters
        batch_size = 128
        maxepoches = 250
        learning_rate = 0.1
        lr_decay = 1e-6

        # The data, shuffled and split between train and test sets:
        (x_train, y_train), (x_test, y_test) = cifar100.load_data()
        x_train = x_train.astype('float32')
        x_test = x_test.astype('float32')
        x_train, x_test = self.normalize(x_train, x_test)

        y_train = keras.utils.to_categorical(y_train, self.num_classes)
        y_test = keras.utils.to_categorical(y_test, self.num_classes)

        lrf = learning_rate

        # data augmentation
        datagen = ImageDataGenerator(
            featurewise_center=False,  # set input mean to 0 over the dataset
            samplewise_center=False,  # set each sample mean to 0
            featurewise_std_normalization=False,  # divide inputs by std of the dataset
            samplewise_std_normalization=False,  # divide each input by its std
            zca_whitening=False,  # apply ZCA whitening
            rotation_range=15,  # randomly rotate images in the range (degrees, 0 to 180)
            width_shift_range=0.1,  # randomly shift images horizontally (fraction of total width)
            height_shift_range=0.1,  # randomly shift images vertically (fraction of total height)
            horizontal_flip=True,  # randomly flip images horizontally
            vertical_flip=False)  # do not flip images vertically
        # compute quantities required for featurewise normalization
        # (std, mean, and principal components if ZCA whitening is applied)
        datagen.fit(x_train)

        # optimization details
        sgd = optimizers.SGD(lr=lrf, decay=lr_decay, momentum=0.9, nesterov=True)
        model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy'])

        # training process in a for loop with a learning rate drop every 25 epochs
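        # Note (added for clarity): each iteration of the loop below trains exactly one epoch
        # (epochs=epoch combined with initial_epoch=epoch-1). Whenever the learning rate is halved,
        # a fresh SGD optimizer is created and the model is re-compiled, which also resets the
        # optimizer's momentum state while keeping the learned weights.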
        for epoch in range(1, maxepoches):

            if epoch % 25 == 0 and epoch > 0:
                lrf /= 2
                sgd = optimizers.SGD(lr=lrf, decay=lr_decay, momentum=0.9, nesterov=True)
                model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy'])

            historytemp = model.fit_generator(datagen.flow(x_train, y_train,
                                                           batch_size=batch_size),
                                              steps_per_epoch=x_train.shape[0] // batch_size,
                                              epochs=epoch,
                                              validation_data=(x_test, y_test), initial_epoch=epoch - 1)
        model.save_weights('cifar100vgg.h5')
        return model


if __name__ == '__main__':

    delta = 0.001

    (x_train, y_train), (x_test, y_test) = cifar100.load_data()
    x_train = x_train.astype('float32')
    x_test = x_test.astype('float32')

    y_train = keras.utils.to_categorical(y_train, 100)
    y_test = keras.utils.to_categorical(y_test, 100)

    model = cifar100vgg(train=False)

    predicted_x = model.predict(x_test)
    kappa = np.max(predicted_x, 1)
    residuals = (np.argmax(predicted_x, 1) != np.argmax(y_test, 1))
    bound_cal = risk_control()
    [theta, b_star] = bound_cal.bound(0.15, delta, kappa, residuals)
--------------------------------------------------------------------------------
/cifar10vgg.py:
--------------------------------------------------------------------------------

from __future__ import print_function
import keras
from keras.datasets import cifar10
from keras.preprocessing.image import ImageDataGenerator
from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation, Flatten
from keras.layers import Conv2D, MaxPooling2D, BatchNormalization
from keras import optimizers
import numpy as np
import pickle
from keras.layers.core import Lambda
from keras import backend as K
from keras import regularizers
from risk_control import risk_control


class cifar10vgg:
    def __init__(self, train=False):
        self.num_classes = 10
        self.weight_decay = 0.0005
        self.x_shape = [32, 32, 3]

        self.model = self.build_model()
        if train:
            self.model = self.train(self.model)
        else:
            self.model.load_weights('cifar10vgg.h5')

    def build_model(self):
        # Build the VGG network for 10 classes, with heavy dropout and weight decay as described in the paper.
        model = Sequential()
        weight_decay = self.weight_decay

        model.add(Conv2D(64, (3, 3), padding='same',
                         input_shape=self.x_shape, kernel_regularizer=regularizers.l2(weight_decay)))
        model.add(Activation('relu'))
        model.add(BatchNormalization())
        model.add(Dropout(0.3))

        model.add(Conv2D(64, (3, 3), padding='same', kernel_regularizer=regularizers.l2(weight_decay)))
        model.add(Activation('relu'))
        model.add(BatchNormalization())

        model.add(MaxPooling2D(pool_size=(2, 2)))

        model.add(Conv2D(128, (3, 3), padding='same', kernel_regularizer=regularizers.l2(weight_decay)))
        model.add(Activation('relu'))
        model.add(BatchNormalization())
        model.add(Dropout(0.4))

        model.add(Conv2D(128, (3, 3), padding='same', kernel_regularizer=regularizers.l2(weight_decay)))
        model.add(Activation('relu'))
        model.add(BatchNormalization())

        model.add(MaxPooling2D(pool_size=(2, 2)))

        model.add(Conv2D(256, (3, 3), padding='same', kernel_regularizer=regularizers.l2(weight_decay)))
        model.add(Activation('relu'))
        model.add(BatchNormalization())
        model.add(Dropout(0.4))

        model.add(Conv2D(256, (3, 3), padding='same', kernel_regularizer=regularizers.l2(weight_decay)))
        model.add(Activation('relu'))
        model.add(BatchNormalization())
        model.add(Dropout(0.4))

        model.add(Conv2D(256, (3, 3), padding='same', kernel_regularizer=regularizers.l2(weight_decay)))
        model.add(Activation('relu'))
        model.add(BatchNormalization())

        model.add(MaxPooling2D(pool_size=(2, 2)))

        model.add(Conv2D(512, (3, 3), padding='same', kernel_regularizer=regularizers.l2(weight_decay)))
        model.add(Activation('relu'))
        model.add(BatchNormalization())
        model.add(Dropout(0.4))

        model.add(Conv2D(512, (3, 3), padding='same', kernel_regularizer=regularizers.l2(weight_decay)))
        model.add(Activation('relu'))
        model.add(BatchNormalization())
        model.add(Dropout(0.4))

        model.add(Conv2D(512, (3, 3), padding='same', kernel_regularizer=regularizers.l2(weight_decay)))
        model.add(Activation('relu'))
        model.add(BatchNormalization())

        model.add(MaxPooling2D(pool_size=(2, 2)))

        model.add(Conv2D(512, (3, 3), padding='same', kernel_regularizer=regularizers.l2(weight_decay)))
        model.add(Activation('relu'))
        model.add(BatchNormalization())
        model.add(Dropout(0.4))

        model.add(Conv2D(512, (3, 3), padding='same', kernel_regularizer=regularizers.l2(weight_decay)))
        model.add(Activation('relu'))
        model.add(BatchNormalization())
        model.add(Dropout(0.4))

        model.add(Conv2D(512, (3, 3), padding='same', kernel_regularizer=regularizers.l2(weight_decay)))
        model.add(Activation('relu'))
        model.add(BatchNormalization())

        model.add(MaxPooling2D(pool_size=(2, 2)))
        model.add(Dropout(0.5))

        model.add(Flatten())
        model.add(Dense(512, kernel_regularizer=regularizers.l2(weight_decay)))
        model.add(Activation('relu'))
        model.add(BatchNormalization())

        model.add(Dropout(0.5))
        model.add(Dense(self.num_classes))
        model.add(Activation('softmax'))
        return model

    def normalize(self, X_train, X_test):
        # This function normalizes the inputs to zero mean and unit variance.
        # It is used when training a model.
        # Input: training set and test set
        # Output: training set and test set, normalized with the training set statistics.
        mean = np.mean(X_train, axis=(0, 1, 2, 3))
        std = np.std(X_train, axis=(0, 1, 2, 3))
        X_train = (X_train - mean) / (std + 1e-7)
        X_test = (X_test - mean) / (std + 1e-7)
        return X_train, X_test

    def normalize_production(self, x):
        # This function normalizes instances in production according to the saved training set statistics.
        # Input: x - a batch of instances to classify
        # Output: the batch normalized with the training set constants below.

        # These constants were produced during the first training run; they are the
        # training set statistics for the standard CIFAR-10 normalization.
        mean = 120.707
        std = 64.15
        return (x - mean) / (std + 1e-7)

    def predict(self, x, normalize=True, batch_size=50):
        if normalize:
            x = self.normalize_production(x)
        return self.model.predict(x, batch_size)

    def train(self, model):

        # training parameters
        batch_size = 128
        maxepoches = 250
        learning_rate = 0.1
        lr_decay = 1e-6

        # The data, shuffled and split between train and test sets:
        (x_train, y_train), (x_test, y_test) = cifar10.load_data()
        x_train = x_train.astype('float32')
        x_test = x_test.astype('float32')
        x_train, x_test = self.normalize(x_train, x_test)

        y_train = keras.utils.to_categorical(y_train, self.num_classes)
        y_test = keras.utils.to_categorical(y_test, self.num_classes)

        lrf = learning_rate

        # data augmentation
        datagen = ImageDataGenerator(
            featurewise_center=False,  # set input mean to 0 over the dataset
            samplewise_center=False,  # set each sample mean to 0
            featurewise_std_normalization=False,  # divide inputs by std of the dataset
            samplewise_std_normalization=False,  # divide each input by its std
            zca_whitening=False,  # apply ZCA whitening
            rotation_range=15,  # randomly rotate images in the range (degrees, 0 to 180)
            width_shift_range=0.1,  # randomly shift images horizontally (fraction of total width)
            height_shift_range=0.1,  # randomly shift images vertically (fraction of total height)
            horizontal_flip=True,  # randomly flip images horizontally
            vertical_flip=False)  # do not flip images vertically
        # compute quantities required for featurewise normalization
        # (std, mean, and principal components if ZCA whitening is applied)
        datagen.fit(x_train)

        # optimization details
        sgd = optimizers.SGD(lr=lrf, decay=lr_decay, momentum=0.9, nesterov=True)
        model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy'])

        # training process in a for loop with a learning rate drop every 25 epochs
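        # Note (added for clarity): as in cifar100vgg.py, every pass through the loop trains a
        # single epoch (epochs=epoch with initial_epoch=epoch-1), and each learning rate drop
        # re-compiles the model with a new SGD instance, resetting its momentum state.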
        for epoch in range(1, maxepoches):

            if epoch % 25 == 0 and epoch > 0:
                lrf /= 2
                sgd = optimizers.SGD(lr=lrf, decay=lr_decay, momentum=0.9, nesterov=True)
                model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy'])

            historytemp = model.fit_generator(datagen.flow(x_train, y_train,
                                                           batch_size=batch_size),
                                              steps_per_epoch=x_train.shape[0] // batch_size,
                                              epochs=epoch,
                                              validation_data=(x_test, y_test), initial_epoch=epoch - 1)
        model.save_weights('cifar10vgg.h5')
        return model


if __name__ == '__main__':

    delta = 0.001

    (x_train, y_train), (x_test, y_test) = cifar10.load_data()
    x_train = x_train.astype('float32')
    x_test = x_test.astype('float32')

    y_train = keras.utils.to_categorical(y_train, 10)
    y_test = keras.utils.to_categorical(y_test, 10)

    model = cifar10vgg(train=True)

    predicted_x = model.predict(x_test)
    kappa = np.max(predicted_x, 1)
    residuals = (np.argmax(predicted_x, 1) != np.argmax(y_test, 1))
    bound_cal = risk_control()
    [theta, b_star] = bound_cal.bound(0.02, delta, kappa, residuals)
    print(theta)
    [theta, b_star] = bound_cal.bound(0.04, delta, kappa, residuals)
--------------------------------------------------------------------------------
/risk_control.py:
--------------------------------------------------------------------------------
import pickle
import numpy as np
from scipy.stats import binom
import scipy
import math
from scipy.optimize import fsolve
import random


class risk_control:

    def calculate_bound(self, delta, m, erm):
        # A solver for the inverse of the binomial CDF, based on binary search.
        precision = 1e-7

        def func(b):
            return (-1 * delta) + scipy.stats.binom.cdf(int(m * erm), m, b)

        a = erm  # start the binary search from the empirical risk
        c = 1  # the upper bound is 1
        b = (a + c) / 2  # mid point
        funcval = func(b)
        while abs(funcval) > precision:
            if a == 1.0 and c == 1.0:
                b = 1.0
                break
            elif funcval > 0:
                a = b
            else:
                c = b
            b = (a + c) / 2
            funcval = func(b)
        return b
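    # Note (added for clarity): calculate_bound numerically inverts the binomial CDF. Given m
    # held-out samples with empirical risk erm, it searches for the value b at which
    # binom.cdf(int(m * erm), m, b) equals delta, i.e. the largest true risk that is still
    # consistent, at confidence level 1 - delta, with the observed empirical risk. bound()
    # below uses this value as the guaranteed risk of each candidate confidence threshold.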
    def bound(self, rstar, delta, kappa, residuals, split=True):
        # Calculates the risk bound proposed in the paper; the procedure follows Algorithm 1.
        # Input: rstar - the requested risk bound
        #        delta - the desired confidence parameter delta
        #        kappa - rating (confidence) function over the points; higher values mean more confident predictions
        #        residuals - a vector of 0/1 residuals: 0 for a correct prediction, 1 for an error
        #        split - a boolean that controls whether to split the data into train and validation halves
        # Output: [theta, bound] (also prints a LaTeX-formatted row for the tables in the paper)

        # when splitting into train and validation, this is the fraction used for validation
        valsize = 0.5

        probs = kappa
        FY = residuals

        if split:
            idx = list(range(len(FY)))
            random.shuffle(idx)
            slice = round(len(FY) * (1 - valsize))
            FY_val = FY[idx[slice:]]
            probs_val = probs[idx[slice:]]
            FY = FY[idx[:slice]]
            probs = probs[idx[:slice]]
        m = len(FY)

        probs_idx_sorted = np.argsort(probs)

        a = 0
        b = m - 1
        deltahat = delta / math.ceil(math.log2(m))

        for q in range(math.ceil(math.log2(m)) + 1):
            # the loop runs ceil(log2(m)) + 1 iterations, but the bound is evaluated on only
            # about log2(m) distinct candidate thetas
            mid = math.ceil((a + b) / 2)

            mi = len(FY[probs_idx_sorted[mid:]])
            theta = probs[probs_idx_sorted[mid]]
            risk = sum(FY[probs_idx_sorted[mid:]]) / mi
            if split:
                testrisk = sum(FY_val[probs_val >= theta]) / len(FY_val[probs_val >= theta])
                testcov = len(FY_val[probs_val >= theta]) / len(FY_val)
            bound = self.calculate_bound(deltahat, mi, risk)
            coverage = mi / m
            if bound > rstar:
                a = mid
            else:
                b = mid

        if split:
            print("%.2f & %.4f & %.4f & %.4f & %.4f & %.4f \\\\" % (rstar, risk, coverage, testrisk, testcov, bound))
        else:
            print("%.2f & %.4f & %.4f & %.4f \\\\" % (rstar, risk, coverage, bound))
        return [theta, bound]
--------------------------------------------------------------------------------