├── README.md
├── checkpoints
│   └── readme.md
├── models
│   ├── __init__.py
│   ├── catdog_vgg_selectivenet.py
│   ├── cifar10_vgg_selectivenet.py
│   └── svhn_vgg_selectivenet.py
├── results
│   └── readme.md
├── selectivnet_utils.py
└── train.py

/README.md:
--------------------------------------------------------------------------------
1 | # SelectiveNet
2 | Code for the paper "SelectiveNet: A Deep Neural Network with an Integrated Reject Option"
3 | 
4 | To run an experiment, run `python train.py --model_name [name] --baseline [name] --dataset [cifar_10/catsdogs/SVHN]`.
5 | 
6 | Calibration can be performed with the function `post_calibration` in `selectivnet_utils.py`.
7 | 
8 | 
--------------------------------------------------------------------------------
/checkpoints/readme.md:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/models/catdog_vgg_selectivenet.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | 
3 | import keras
4 | import numpy as np
5 | import os
6 | import pickle
7 | from keras import backend as K
8 | 
9 | from keras import optimizers
10 | from keras import regularizers
11 | from keras.layers import Conv2D, MaxPooling2D, BatchNormalization, Concatenate
12 | from keras.layers import Dense, Dropout, Activation, Flatten, Input
13 | from keras.layers.core import Lambda
14 | from keras.models import Model
15 | from keras.preprocessing.image import ImageDataGenerator
16 | 
17 | from selectivnet_utils import *
18 | 
19 | 
20 | class CatsvsDogVgg:
21 |     def __init__(self, train=True, filename="weightsvgg.h5", coverage=0.8, alpha=0.5, baseline=False):
22 |         self.lamda = coverage
23 |         self.alpha = alpha
24 |         self.mc_dropout_rate = K.variable(value=0)
25 |         self.num_classes = 2
26 |         self.weight_decay = 0.0005
27 |         self._load_data()
28 | 
29 |         self.x_shape = self.x_train.shape[1:]
30 |         self.filename = filename
31 | 
32 |         self.model = self.build_model()
33 |         if baseline:
34 |             self.alpha = 0
35 | 
36 |         if train:
37 |             self.model = self.train(self.model)
38 |         else:
39 |             self.model.load_weights("checkpoints/{}".format(self.filename))
40 | 
41 |     def build_model(self):
42 |         # Build a VGG-style network for the two cats-vs-dogs classes, with massive dropout and weight decay as described in the paper.
43 | weight_decay = self.weight_decay 44 | basic_dropout_rate = 0.3 45 | input = Input(shape=self.x_shape) 46 | curr = Conv2D(64, (3, 3), padding='same', kernel_regularizer=regularizers.l2(weight_decay))(input) 47 | curr = Activation('relu')(curr) 48 | curr = BatchNormalization()(curr) 49 | curr = Dropout(basic_dropout_rate)(curr) 50 | 51 | curr = Conv2D(64, (3, 3), padding='same', kernel_regularizer=regularizers.l2(weight_decay))(curr) 52 | 53 | curr = Activation('relu')(curr) 54 | curr = BatchNormalization()(curr) 55 | curr = MaxPooling2D(pool_size=(2, 2))(curr) 56 | 57 | curr = Conv2D(128, (3, 3), padding='same', kernel_regularizer=regularizers.l2(weight_decay))(curr) 58 | 59 | curr = Activation('relu')(curr) 60 | curr = BatchNormalization()(curr) 61 | curr = Dropout(basic_dropout_rate + 0.1)(curr) 62 | 63 | curr = Conv2D(128, (3, 3), padding='same', kernel_regularizer=regularizers.l2(weight_decay))(curr) 64 | 65 | curr = Activation('relu')(curr) 66 | curr = BatchNormalization()(curr) 67 | curr = MaxPooling2D(pool_size=(2, 2))(curr) 68 | 69 | curr = Conv2D(256, (3, 3), padding='same', kernel_regularizer=regularizers.l2(weight_decay))(curr) 70 | 71 | curr = Activation('relu')(curr) 72 | curr = BatchNormalization()(curr) 73 | curr = Dropout(basic_dropout_rate + 0.1)(curr) 74 | 75 | curr = Conv2D(256, (3, 3), padding='same', kernel_regularizer=regularizers.l2(weight_decay))(curr) 76 | 77 | curr = Activation('relu')(curr) 78 | curr = BatchNormalization()(curr) 79 | curr = Dropout(basic_dropout_rate + 0.1)(curr) 80 | 81 | curr = Conv2D(256, (3, 3), padding='same', kernel_regularizer=regularizers.l2(weight_decay))(curr) 82 | 83 | curr = Activation('relu')(curr) 84 | curr = BatchNormalization()(curr) 85 | curr = MaxPooling2D(pool_size=(2, 2))(curr) 86 | 87 | curr = Conv2D(512, (3, 3), padding='same', kernel_regularizer=regularizers.l2(weight_decay))(curr) 88 | 89 | curr = Activation('relu')(curr) 90 | curr = BatchNormalization()(curr) 91 | curr = Dropout(basic_dropout_rate + 0.1)(curr) 92 | 93 | curr = Conv2D(512, (3, 3), padding='same', kernel_regularizer=regularizers.l2(weight_decay))(curr) 94 | 95 | curr = Activation('relu')(curr) 96 | curr = BatchNormalization()(curr) 97 | curr = Dropout(basic_dropout_rate + 0.1)(curr) 98 | 99 | curr = Conv2D(512, (3, 3), padding='same', kernel_regularizer=regularizers.l2(weight_decay))(curr) 100 | 101 | curr = Activation('relu')(curr) 102 | curr = BatchNormalization()(curr) 103 | curr = MaxPooling2D(pool_size=(2, 2))(curr) 104 | 105 | curr = Conv2D(512, (3, 3), padding='same', kernel_regularizer=regularizers.l2(weight_decay))(curr) 106 | 107 | curr = Activation('relu')(curr) 108 | curr = BatchNormalization()(curr) 109 | curr = Dropout(basic_dropout_rate + 0.1)(curr) 110 | 111 | curr = Conv2D(512, (3, 3), padding='same', kernel_regularizer=regularizers.l2(weight_decay))(curr) 112 | 113 | curr = Activation('relu')(curr) 114 | curr = BatchNormalization()(curr) 115 | curr = Dropout(basic_dropout_rate + 0.1)(curr) 116 | 117 | curr = Conv2D(512, (3, 3), padding='same', kernel_regularizer=regularizers.l2(weight_decay))(curr) 118 | 119 | curr = Activation('relu')(curr) 120 | curr = BatchNormalization()(curr) 121 | curr = MaxPooling2D(pool_size=(2, 2))(curr) 122 | curr = Dropout(basic_dropout_rate + 0.2)(curr) 123 | 124 | curr = Flatten()(curr) 125 | curr = Dense(512, kernel_regularizer=regularizers.l2(weight_decay))(curr) 126 | 127 | curr = Activation('relu')(curr) 128 | curr = BatchNormalization()(curr) 129 | curr = Dropout(basic_dropout_rate + 0.2)(curr) 130 | 
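        # The Lambda layer below applies dropout with a rate read from the Keras variable
        # self.mc_dropout_rate: it is 0 during training and ordinary inference, and
        # mc_dropout() temporarily raises it (0.5 by default) to run repeated stochastic
        # forward passes, using the variance of the predictions across passes to score uncertainty.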
curr = Lambda(lambda x: K.dropout(x, level=self.mc_dropout_rate))(curr) 131 | 132 | # classification head (f) 133 | curr1 = Dense(self.num_classes, activation='softmax')(curr) 134 | 135 | # selection head (g) 136 | curr2 = Dense(512, kernel_regularizer=regularizers.l2(weight_decay))(curr) 137 | curr2 = Activation('relu')(curr2) 138 | curr2 = BatchNormalization()(curr2) 139 | # this normalization is identical to initialization of batchnorm gamma to 1/10 140 | curr2 = Lambda(lambda x: x / 10)(curr2) 141 | curr2 = Dense(1, activation='sigmoid')(curr2) 142 | # auxiliary head (h) 143 | selective_output = Concatenate(axis=1, name="selective_head")([curr1, curr2]) 144 | 145 | auxiliary_output = Dense(self.num_classes, activation='softmax', name="classification_head")(curr) 146 | 147 | model = Model(inputs=input, outputs=[selective_output, auxiliary_output]) 148 | 149 | self.input = input 150 | self.model_embeding = Model(inputs=input, outputs=curr) 151 | return model 152 | 153 | def normalize(self, X_train, X_test): 154 | # this function normalize inputs for zero mean and unit variance 155 | # it is used when training a model. 156 | # Input: training set and test set 157 | # Output: normalized training set and test set according to the trianing set statistics. 158 | mean = np.mean(X_train, axis=(0, 1, 2, 3)) 159 | std = np.std(X_train, axis=(0, 1, 2, 3)) 160 | X_train = (X_train - mean) / (std + 1e-7) 161 | X_test = (X_test - mean) / (std + 1e-7) 162 | return X_train, X_test 163 | 164 | def predict(self, x=None, batch_size=128): 165 | if x is None: 166 | x = self.x_test 167 | return self.model.predict(x, batch_size) 168 | 169 | def predict_embedding(self, x=None, batch_size=128): 170 | if x is None: 171 | x = self.x_test 172 | return self.model_embeding.predict(x, batch_size) 173 | 174 | def mc_dropout(self, batch_size=1000, dropout=0.5, iter=100): 175 | K.set_value(self.mc_dropout_rate, dropout) 176 | repititions = [] 177 | for i in range(iter): 178 | _, pred = self.model.predict(self.x_test, batch_size) 179 | repititions.append(pred) 180 | K.set_value(self.mc_dropout_rate, 0) 181 | 182 | repititions = np.array(repititions) 183 | mc = np.var(repititions, 0) 184 | mc = np.mean(mc, -1) 185 | return -mc 186 | 187 | def selective_risk_at_coverage(self, coverage, mc=False): 188 | _, pred = self.predict() 189 | 190 | if mc: 191 | sr = np.max(pred, 1) 192 | else: 193 | sr = self.mc_dropout() 194 | sr_sorted = np.sort(sr) 195 | threshold = sr_sorted[pred.shape[0] - int(coverage * pred.shape[0])] 196 | covered_idx = sr > threshold 197 | selective_acc = np.mean(np.argmax(pred[covered_idx], 1) == np.argmax(self.y_test[covered_idx], 1)) 198 | return selective_acc 199 | 200 | def _load_data(self): 201 | 202 | # The data, shuffled and split between train and test sets: 203 | (x_train, y_train), (x_test, y_test_label) = load_cats_vs_dogs() 204 | x_train = x_train.astype('float32') 205 | x_test = x_test.astype('float32') 206 | self.x_train, self.x_test = self.normalize(x_train, x_test) 207 | 208 | self.y_train = keras.utils.to_categorical(y_train, self.num_classes + 1) 209 | self.y_test = keras.utils.to_categorical(y_test_label, self.num_classes + 1) 210 | 211 | def train(self, model): 212 | c = self.lamda 213 | lamda = 32 214 | 215 | def selective_loss(y_true, y_pred): 216 | loss = K.categorical_crossentropy( 217 | K.repeat_elements(y_pred[:, -1:], self.num_classes, axis=1) * y_true[:, :-1], 218 | y_pred[:, :-1]) + lamda * K.maximum(-K.mean(y_pred[:, -1]) + c, 0) ** 2 219 | return loss 220 | 221 | def 
selective_acc(y_true, y_pred): 222 | g = K.cast(K.greater(y_pred[:, -1], 0.5), K.floatx()) 223 | temp1 = K.sum( 224 | (g) * K.cast(K.equal(K.argmax(y_true[:, :-1], axis=-1), K.argmax(y_pred[:, :-1], axis=-1)), K.floatx())) 225 | temp1 = temp1 / K.sum(g) 226 | return K.cast(temp1, K.floatx()) 227 | 228 | def coverage(y_true, y_pred): 229 | g = K.cast(K.greater(y_pred[:, -1], 0.5), K.floatx()) 230 | return K.mean(g) 231 | 232 | # training parameters 233 | batch_size = 128 234 | maxepoches = 300 235 | learning_rate = 0.1 236 | 237 | lr_decay = 1e-6 238 | 239 | lr_drop = 25 240 | 241 | def lr_scheduler(epoch): 242 | return learning_rate * (0.5 ** (epoch // lr_drop)) 243 | 244 | reduce_lr = keras.callbacks.LearningRateScheduler(lr_scheduler) 245 | 246 | # data augmentation 247 | datagen = ImageDataGenerator( 248 | featurewise_center=False, # set input mean to 0 over the dataset 249 | samplewise_center=False, # set each sample mean to 0 250 | featurewise_std_normalization=False, # divide inputs by std of the dataset 251 | samplewise_std_normalization=False, # divide each input by its std 252 | zca_whitening=False, # apply ZCA whitening 253 | rotation_range=15, # randomly rotate images in the range (degrees, 0 to 180) 254 | width_shift_range=0.1, # randomly shift images horizontally (fraction of total width) 255 | height_shift_range=0.1, # randomly shift images vertically (fraction of total height) 256 | horizontal_flip=True, # randomly flip images 257 | vertical_flip=False) # randomly flip images 258 | # (std, mean, and principal components if ZCA whitening is applied). 259 | datagen.fit(self.x_train) 260 | 261 | # optimization details 262 | sgd = optimizers.SGD(lr=learning_rate, decay=lr_decay, momentum=0.9, nesterov=True) 263 | 264 | model.compile(loss=[selective_loss, 'categorical_crossentropy'], loss_weights=[self.alpha, 1 - self.alpha], 265 | optimizer=sgd, metrics=['accuracy', selective_acc, coverage]) 266 | 267 | historytemp = model.fit_generator(my_generator(datagen.flow, self.x_train, self.y_train, 268 | batch_size=batch_size, k=self.num_classes), 269 | steps_per_epoch=self.x_train.shape[0] // batch_size, 270 | epochs=maxepoches, callbacks=[reduce_lr], 271 | validation_data=(self.x_test, [self.y_test, self.y_test[:, :-1]])) 272 | 273 | 274 | with open("checkpoints/{}_history.pkl".format(self.filename[:-3]), 'wb') as handle: 275 | pickle.dump(historytemp.history, handle, protocol=pickle.HIGHEST_PROTOCOL) 276 | 277 | model.save_weights("checkpoints/{}".format(self.filename)) 278 | 279 | return model 280 | -------------------------------------------------------------------------------- /models/cifar10_vgg_selectivenet.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import keras 4 | import numpy as np 5 | import pickle 6 | from keras import backend as K 7 | from keras import backend as K 8 | from keras import optimizers 9 | from keras import regularizers 10 | from keras.datasets import cifar10 11 | from keras.engine.topology import Layer 12 | from keras.layers import Conv2D, MaxPooling2D, BatchNormalization, Concatenate 13 | from keras.layers import Dense, Dropout, Activation, Flatten, Input 14 | from keras.layers.core import Lambda 15 | from keras.models import Model 16 | from keras.models import Sequential 17 | from keras.preprocessing.image import ImageDataGenerator 18 | 19 | from selectivnet_utils import * 20 | 21 | 22 | class cifar10vgg: 23 | def __init__(self, train=True, filename="weightsvgg.h5", 
coverage=0.8, alpha=0.5, baseline=False): 24 | self.lamda = coverage 25 | self.alpha = alpha 26 | self.mc_dropout_rate = K.variable(value=0) 27 | self.num_classes = 10 28 | self.weight_decay = 0.0005 29 | self._load_data() 30 | 31 | self.x_shape = self.x_train.shape[1:] 32 | self.filename = filename 33 | 34 | self.model = self.build_model() 35 | if baseline: 36 | self.alpha = 0 37 | 38 | if train: 39 | self.model = self.train(self.model) 40 | else: 41 | self.model.load_weights("checkpoints/{}".format(self.filename)) 42 | 43 | def build_model(self): 44 | # Build the network of vgg for 10 classes with massive dropout and weight decay as described in the paper. 45 | weight_decay = self.weight_decay 46 | basic_dropout_rate = 0.3 47 | input = Input(shape=self.x_shape) 48 | curr = Conv2D(64, (3, 3), padding='same', kernel_regularizer=regularizers.l2(weight_decay))(input) 49 | curr = Activation('relu')(curr) 50 | curr = BatchNormalization()(curr) 51 | curr = Dropout(basic_dropout_rate)(curr) 52 | 53 | curr = Conv2D(64, (3, 3), padding='same', kernel_regularizer=regularizers.l2(weight_decay))(curr) 54 | 55 | curr = Activation('relu')(curr) 56 | curr = BatchNormalization()(curr) 57 | curr = MaxPooling2D(pool_size=(2, 2))(curr) 58 | 59 | curr = Conv2D(128, (3, 3), padding='same', kernel_regularizer=regularizers.l2(weight_decay))(curr) 60 | 61 | curr = Activation('relu')(curr) 62 | curr = BatchNormalization()(curr) 63 | curr = Dropout(basic_dropout_rate + 0.1)(curr) 64 | 65 | curr = Conv2D(128, (3, 3), padding='same', kernel_regularizer=regularizers.l2(weight_decay))(curr) 66 | 67 | curr = Activation('relu')(curr) 68 | curr = BatchNormalization()(curr) 69 | curr = MaxPooling2D(pool_size=(2, 2))(curr) 70 | 71 | curr = Conv2D(256, (3, 3), padding='same', kernel_regularizer=regularizers.l2(weight_decay))(curr) 72 | 73 | curr = Activation('relu')(curr) 74 | curr = BatchNormalization()(curr) 75 | curr = Dropout(basic_dropout_rate + 0.1)(curr) 76 | 77 | curr = Conv2D(256, (3, 3), padding='same', kernel_regularizer=regularizers.l2(weight_decay))(curr) 78 | 79 | curr = Activation('relu')(curr) 80 | curr = BatchNormalization()(curr) 81 | curr = Dropout(basic_dropout_rate + 0.1)(curr) 82 | 83 | curr = Conv2D(256, (3, 3), padding='same', kernel_regularizer=regularizers.l2(weight_decay))(curr) 84 | 85 | curr = Activation('relu')(curr) 86 | curr = BatchNormalization()(curr) 87 | curr = MaxPooling2D(pool_size=(2, 2))(curr) 88 | 89 | curr = Conv2D(512, (3, 3), padding='same', kernel_regularizer=regularizers.l2(weight_decay))(curr) 90 | 91 | curr = Activation('relu')(curr) 92 | curr = BatchNormalization()(curr) 93 | curr = Dropout(basic_dropout_rate + 0.1)(curr) 94 | 95 | curr = Conv2D(512, (3, 3), padding='same', kernel_regularizer=regularizers.l2(weight_decay))(curr) 96 | 97 | curr = Activation('relu')(curr) 98 | curr = BatchNormalization()(curr) 99 | curr = Dropout(basic_dropout_rate + 0.1)(curr) 100 | 101 | curr = Conv2D(512, (3, 3), padding='same', kernel_regularizer=regularizers.l2(weight_decay))(curr) 102 | 103 | curr = Activation('relu')(curr) 104 | curr = BatchNormalization()(curr) 105 | curr = MaxPooling2D(pool_size=(2, 2))(curr) 106 | 107 | curr = Conv2D(512, (3, 3), padding='same', kernel_regularizer=regularizers.l2(weight_decay))(curr) 108 | 109 | curr = Activation('relu')(curr) 110 | curr = BatchNormalization()(curr) 111 | curr = Dropout(basic_dropout_rate + 0.1)(curr) 112 | 113 | curr = Conv2D(512, (3, 3), padding='same', kernel_regularizer=regularizers.l2(weight_decay))(curr) 114 | 115 | curr = 
Activation('relu')(curr) 116 | curr = BatchNormalization()(curr) 117 | curr = Dropout(basic_dropout_rate + 0.1)(curr) 118 | 119 | curr = Conv2D(512, (3, 3), padding='same', kernel_regularizer=regularizers.l2(weight_decay))(curr) 120 | 121 | curr = Activation('relu')(curr) 122 | curr = BatchNormalization()(curr) 123 | curr = MaxPooling2D(pool_size=(2, 2))(curr) 124 | curr = Dropout(basic_dropout_rate + 0.2)(curr) 125 | 126 | curr = Flatten()(curr) 127 | curr = Dense(512, kernel_regularizer=regularizers.l2(weight_decay))(curr) 128 | 129 | curr = Activation('relu')(curr) 130 | curr = BatchNormalization()(curr) 131 | curr = Dropout(basic_dropout_rate + 0.2)(curr) 132 | curr = Lambda(lambda x: K.dropout(x, level=self.mc_dropout_rate))(curr) 133 | 134 | # classification head (f) 135 | curr1 = Dense(self.num_classes, activation='softmax')(curr) 136 | 137 | # selection head (g) 138 | curr2 = Dense(512, kernel_regularizer=regularizers.l2(weight_decay))(curr) 139 | curr2 = Activation('relu')(curr2) 140 | curr2 = BatchNormalization()(curr2) 141 | # this normalization is identical to initialization of batchnorm gamma to 1/10 142 | curr2 = Lambda(lambda x: x / 10)(curr2) 143 | curr2 = Dense(1, activation='sigmoid')(curr2) 144 | # auxiliary head (h) 145 | selective_output = Concatenate(axis=1, name="selective_head")([curr1, curr2]) 146 | 147 | auxiliary_output = Dense(self.num_classes, activation='softmax', name="classification_head")(curr) 148 | 149 | model = Model(inputs=input, outputs=[selective_output, auxiliary_output]) 150 | 151 | self.input = input 152 | self.model_embeding = Model(inputs=input, outputs=curr) 153 | return model 154 | 155 | def normalize(self, X_train, X_test): 156 | # this function normalize inputs for zero mean and unit variance 157 | # it is used when training a model. 158 | # Input: training set and test set 159 | # Output: normalized training set and test set according to the trianing set statistics. 
160 | mean = np.mean(X_train, axis=(0, 1, 2, 3)) 161 | std = np.std(X_train, axis=(0, 1, 2, 3)) 162 | X_train = (X_train - mean) / (std + 1e-7) 163 | X_test = (X_test - mean) / (std + 1e-7) 164 | return X_train, X_test 165 | 166 | def predict(self, x=None, batch_size=128): 167 | if x is None: 168 | x = self.x_test 169 | return self.model.predict(x, batch_size) 170 | 171 | def predict_embedding(self, x=None, batch_size=128): 172 | if x is None: 173 | x = self.x_test 174 | return self.model_embeding.predict(x, batch_size) 175 | 176 | def mc_dropout(self, batch_size=1000, dropout=0.5, iter=100): 177 | K.set_value(self.mc_dropout_rate, dropout) 178 | repititions = [] 179 | for i in range(iter): 180 | _, pred = self.model.predict(self.x_test, batch_size) 181 | repititions.append(pred) 182 | K.set_value(self.mc_dropout_rate, 0) 183 | 184 | repititions = np.array(repititions) 185 | mc = np.var(repititions, 0) 186 | mc = np.mean(mc, -1) 187 | return -mc 188 | 189 | def selective_risk_at_coverage(self, coverage, mc=False): 190 | _, pred = self.predict() 191 | 192 | if mc: 193 | sr = np.max(pred, 1) 194 | else: 195 | sr = self.mc_dropout() 196 | sr_sorted = np.sort(sr) 197 | threshold = sr_sorted[pred.shape[0] - int(coverage * pred.shape[0])] 198 | covered_idx = sr > threshold 199 | selective_acc = np.mean(np.argmax(pred[covered_idx], 1) == np.argmax(self.y_test[covered_idx], 1)) 200 | return selective_acc 201 | 202 | def _load_data(self): 203 | 204 | # The data, shuffled and split between train and test sets: 205 | (x_train, y_train), (x_test, y_test_label) = cifar10.load_data() 206 | x_train = x_train.astype('float32') 207 | x_test = x_test.astype('float32') 208 | self.x_train, self.x_test = self.normalize(x_train, x_test) 209 | 210 | self.y_train = keras.utils.to_categorical(y_train, self.num_classes + 1) 211 | self.y_test = keras.utils.to_categorical(y_test_label, self.num_classes + 1) 212 | 213 | def train(self, model): 214 | c = self.lamda 215 | lamda = 32 216 | 217 | def selective_loss(y_true, y_pred): 218 | loss = K.categorical_crossentropy( 219 | K.repeat_elements(y_pred[:, -1:], self.num_classes, axis=1) * y_true[:, :-1], 220 | y_pred[:, :-1]) + lamda * K.maximum(-K.mean(y_pred[:, -1]) + c, 0) ** 2 221 | return loss 222 | 223 | def selective_acc(y_true, y_pred): 224 | g = K.cast(K.greater(y_pred[:, -1], 0.5), K.floatx()) 225 | temp1 = K.sum( 226 | (g) * K.cast(K.equal(K.argmax(y_true[:, :-1], axis=-1), K.argmax(y_pred[:, :-1], axis=-1)), K.floatx())) 227 | temp1 = temp1 / K.sum(g) 228 | return K.cast(temp1, K.floatx()) 229 | 230 | def coverage(y_true, y_pred): 231 | g = K.cast(K.greater(y_pred[:, -1], 0.5), K.floatx()) 232 | return K.mean(g) 233 | 234 | 235 | 236 | # training parameters 237 | batch_size = 128 238 | maxepoches = 300 239 | learning_rate = 0.1 240 | 241 | lr_decay = 1e-6 242 | 243 | lr_drop = 25 244 | 245 | def lr_scheduler(epoch): 246 | return learning_rate * (0.5 ** (epoch // lr_drop)) 247 | 248 | reduce_lr = keras.callbacks.LearningRateScheduler(lr_scheduler) 249 | 250 | # data augmentation 251 | datagen = ImageDataGenerator( 252 | featurewise_center=False, # set input mean to 0 over the dataset 253 | samplewise_center=False, # set each sample mean to 0 254 | featurewise_std_normalization=False, # divide inputs by std of the dataset 255 | samplewise_std_normalization=False, # divide each input by its std 256 | zca_whitening=False, # apply ZCA whitening 257 | rotation_range=15, # randomly rotate images in the range (degrees, 0 to 180) 258 | width_shift_range=0.1, # randomly 
shift images horizontally (fraction of total width) 259 | height_shift_range=0.1, # randomly shift images vertically (fraction of total height) 260 | horizontal_flip=True, # randomly flip images 261 | vertical_flip=False) # randomly flip images 262 | # (std, mean, and principal components if ZCA whitening is applied). 263 | datagen.fit(self.x_train) 264 | 265 | # optimization details 266 | sgd = optimizers.SGD(lr=learning_rate, decay=lr_decay, momentum=0.9, nesterov=True) 267 | 268 | model.compile(loss=[selective_loss, 'categorical_crossentropy'], loss_weights=[self.alpha, 1 - self.alpha], 269 | optimizer=sgd, metrics=['accuracy', selective_acc, coverage]) 270 | 271 | historytemp = model.fit_generator(my_generator(datagen.flow, self.x_train, self.y_train, 272 | batch_size=batch_size, k=self.num_classes), 273 | steps_per_epoch=self.x_train.shape[0] // batch_size, 274 | epochs=maxepoches, callbacks=[reduce_lr], 275 | validation_data=(self.x_test, [self.y_test, self.y_test[:, :-1]])) 276 | 277 | 278 | with open("checkpoints/{}_history.pkl".format(self.filename[:-3]), 'wb') as handle: 279 | pickle.dump(historytemp.history, handle, protocol=pickle.HIGHEST_PROTOCOL) 280 | 281 | model.save_weights("checkpoints/{}".format(self.filename)) 282 | 283 | return model 284 | -------------------------------------------------------------------------------- /models/svhn_vgg_selectivenet.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import keras 4 | import numpy as np 5 | import numpy as np 6 | import pickle 7 | import scipy.io as spio 8 | from keras import backend as K 9 | from keras import backend as K 10 | from keras import optimizers 11 | from keras import regularizers 12 | from keras.datasets import cifar10 13 | from keras.engine.topology import Layer 14 | from keras.layers import Conv2D, MaxPooling2D, BatchNormalization, Add, Subtract, Concatenate 15 | from keras.layers import Dense, Dropout, Activation, Flatten, Input 16 | from keras.layers.core import Lambda 17 | from keras.models import Model 18 | from keras.models import Sequential 19 | from keras.preprocessing.image import ImageDataGenerator 20 | 21 | from selectivnet_utils import * 22 | 23 | 24 | class SvhnVgg: 25 | def __init__(self, train=True, filename="weightsvgg.h5", coverage=0.8, alpha=0.5, baseline=False): 26 | self.lamda = coverage 27 | self.alpha = alpha 28 | self.mc_dropout_rate = K.variable(value=0) 29 | self.num_classes = 10 30 | self.weight_decay = 0.0005 31 | self._load_data() 32 | 33 | self.x_shape = self.x_train.shape[1:] 34 | self.filename = filename 35 | 36 | self.model = self.build_model() 37 | if baseline: 38 | self.alpha = 0 39 | 40 | if train: 41 | self.model = self.train(self.model) 42 | else: 43 | self.model.load_weights("checkpoints/{}".format(self.filename)) 44 | 45 | def build_model(self): 46 | # Build the network of vgg for 10 classes with massive dropout and weight decay as described in the paper. 
47 | weight_decay = self.weight_decay 48 | basic_dropout_rate = 0.3 49 | input = Input(shape=self.x_shape) 50 | curr = Conv2D(64, (3, 3), padding='same', kernel_regularizer=regularizers.l2(weight_decay))(input) 51 | curr = Activation('relu')(curr) 52 | curr = BatchNormalization()(curr) 53 | curr = Dropout(basic_dropout_rate)(curr) 54 | 55 | curr = Conv2D(64, (3, 3), padding='same', kernel_regularizer=regularizers.l2(weight_decay))(curr) 56 | 57 | curr = Activation('relu')(curr) 58 | curr = BatchNormalization()(curr) 59 | curr = MaxPooling2D(pool_size=(2, 2))(curr) 60 | 61 | curr = Conv2D(128, (3, 3), padding='same', kernel_regularizer=regularizers.l2(weight_decay))(curr) 62 | 63 | curr = Activation('relu')(curr) 64 | curr = BatchNormalization()(curr) 65 | curr = Dropout(basic_dropout_rate + 0.1)(curr) 66 | 67 | curr = Conv2D(128, (3, 3), padding='same', kernel_regularizer=regularizers.l2(weight_decay))(curr) 68 | 69 | curr = Activation('relu')(curr) 70 | curr = BatchNormalization()(curr) 71 | curr = MaxPooling2D(pool_size=(2, 2))(curr) 72 | 73 | curr = Conv2D(256, (3, 3), padding='same', kernel_regularizer=regularizers.l2(weight_decay))(curr) 74 | 75 | curr = Activation('relu')(curr) 76 | curr = BatchNormalization()(curr) 77 | curr = Dropout(basic_dropout_rate + 0.1)(curr) 78 | 79 | curr = Conv2D(256, (3, 3), padding='same', kernel_regularizer=regularizers.l2(weight_decay))(curr) 80 | 81 | curr = Activation('relu')(curr) 82 | curr = BatchNormalization()(curr) 83 | curr = Dropout(basic_dropout_rate + 0.1)(curr) 84 | 85 | curr = Conv2D(256, (3, 3), padding='same', kernel_regularizer=regularizers.l2(weight_decay))(curr) 86 | 87 | curr = Activation('relu')(curr) 88 | curr = BatchNormalization()(curr) 89 | curr = MaxPooling2D(pool_size=(2, 2))(curr) 90 | 91 | curr = Conv2D(512, (3, 3), padding='same', kernel_regularizer=regularizers.l2(weight_decay))(curr) 92 | 93 | curr = Activation('relu')(curr) 94 | curr = BatchNormalization()(curr) 95 | curr = Dropout(basic_dropout_rate + 0.1)(curr) 96 | 97 | curr = Conv2D(512, (3, 3), padding='same', kernel_regularizer=regularizers.l2(weight_decay))(curr) 98 | 99 | curr = Activation('relu')(curr) 100 | curr = BatchNormalization()(curr) 101 | curr = Dropout(basic_dropout_rate + 0.1)(curr) 102 | 103 | curr = Conv2D(512, (3, 3), padding='same', kernel_regularizer=regularizers.l2(weight_decay))(curr) 104 | 105 | curr = Activation('relu')(curr) 106 | curr = BatchNormalization()(curr) 107 | curr = MaxPooling2D(pool_size=(2, 2))(curr) 108 | 109 | curr = Conv2D(512, (3, 3), padding='same', kernel_regularizer=regularizers.l2(weight_decay))(curr) 110 | 111 | curr = Activation('relu')(curr) 112 | curr = BatchNormalization()(curr) 113 | curr = Dropout(basic_dropout_rate + 0.1)(curr) 114 | 115 | curr = Conv2D(512, (3, 3), padding='same', kernel_regularizer=regularizers.l2(weight_decay))(curr) 116 | 117 | curr = Activation('relu')(curr) 118 | curr = BatchNormalization()(curr) 119 | curr = Dropout(basic_dropout_rate + 0.1)(curr) 120 | 121 | curr = Conv2D(512, (3, 3), padding='same', kernel_regularizer=regularizers.l2(weight_decay))(curr) 122 | 123 | curr = Activation('relu')(curr) 124 | curr = BatchNormalization()(curr) 125 | curr = MaxPooling2D(pool_size=(2, 2))(curr) 126 | curr = Dropout(basic_dropout_rate + 0.2)(curr) 127 | 128 | curr = Flatten()(curr) 129 | curr = Dense(512, kernel_regularizer=regularizers.l2(weight_decay))(curr) 130 | 131 | curr = Activation('relu')(curr) 132 | curr = BatchNormalization()(curr) 133 | curr = Dropout(basic_dropout_rate + 0.2)(curr) 
134 | curr = Lambda(lambda x: K.dropout(x, level=self.mc_dropout_rate))(curr) 135 | 136 | # classification head (f) 137 | curr1 = Dense(self.num_classes, activation='softmax')(curr) 138 | 139 | # selection head (g) 140 | curr2 = Dense(512, kernel_regularizer=regularizers.l2(weight_decay))(curr) 141 | curr2 = Activation('relu')(curr2) 142 | curr2 = BatchNormalization()(curr2) 143 | # this normalization is identical to initialization of batchnorm gamma to 1/10 144 | curr2 = Lambda(lambda x: x / 10)(curr2) 145 | curr2 = Dense(1, activation='sigmoid')(curr2) 146 | # auxiliary head (h) 147 | selective_output = Concatenate(axis=1, name="selective_head")([curr1, curr2]) 148 | 149 | auxiliary_output = Dense(self.num_classes, activation='softmax', name="classification_head")(curr) 150 | 151 | model = Model(inputs=input, outputs=[selective_output, auxiliary_output]) 152 | 153 | self.input = input 154 | self.model_embeding = Model(inputs=input, outputs=curr) 155 | return model 156 | 157 | def normalize(self, X_train, X_test): 158 | # this function normalize inputs for zero mean and unit variance 159 | # it is used when training a model. 160 | # Input: training set and test set 161 | # Output: normalized training set and test set according to the trianing set statistics. 162 | mean = np.mean(X_train, axis=(0, 1, 2, 3)) 163 | std = np.std(X_train, axis=(0, 1, 2, 3)) 164 | X_train = (X_train - mean) / (std + 1e-7) 165 | X_test = (X_test - mean) / (std + 1e-7) 166 | return X_train, X_test 167 | 168 | def predict(self, x=None, batch_size=128): 169 | if x is None: 170 | x = self.x_test 171 | return self.model.predict(x, batch_size) 172 | 173 | def predict_embedding(self, x=None, batch_size=128): 174 | if x is None: 175 | x = self.x_test 176 | return self.model_embeding.predict(x, batch_size) 177 | 178 | def mc_dropout(self, batch_size=1000, dropout=0.5, iter=100): 179 | K.set_value(self.mc_dropout_rate, dropout) 180 | repititions = [] 181 | for i in range(iter): 182 | _, pred = self.model.predict(self.x_test, batch_size) 183 | repititions.append(pred) 184 | K.set_value(self.mc_dropout_rate, 0) 185 | 186 | repititions = np.array(repititions) 187 | mc = np.var(repititions, 0) 188 | mc = np.mean(mc, -1) 189 | return -mc 190 | 191 | def selective_risk_at_coverage(self, coverage, mc=False): 192 | _, pred = self.predict() 193 | 194 | if mc: 195 | sr = np.max(pred, 1) 196 | else: 197 | sr = self.mc_dropout() 198 | sr_sorted = np.sort(sr) 199 | threshold = sr_sorted[pred.shape[0] - int(coverage * pred.shape[0])] 200 | covered_idx = sr > threshold 201 | selective_acc = np.mean(np.argmax(pred[covered_idx], 1) == np.argmax(self.y_test[covered_idx], 1)) 202 | return selective_acc 203 | 204 | def _load_data(self): 205 | 206 | mat = spio.loadmat('datasets/train_32x32.mat', squeeze_me=True) 207 | self.x_train = mat["X"] 208 | self.y_train = mat["y"] 209 | self.x_train = np.moveaxis(self.x_train, -1, 0) 210 | del mat 211 | 212 | mat = spio.loadmat('datasets/test_32x32.mat', squeeze_me=True) 213 | self.x_test = mat["X"] 214 | self.y_test = mat["y"] 215 | del mat 216 | 217 | self.x_test = np.moveaxis(self.x_test, -1, 0) 218 | 219 | self.x_train, self.x_test = self.normalize(self.x_train, self.x_test) 220 | self.x_train = self.x_train.astype('float32') 221 | self.x_test = self.x_test.astype('float32') 222 | 223 | # 6. 
Preprocess class labels 224 | self.y_train = keras.utils.to_categorical(self.y_train - 1, self.num_classes+1) 225 | self.y_test = keras.utils.to_categorical(self.y_test - 1, self.num_classes+1) 226 | 227 | 228 | def train(self, model): 229 | c = self.lamda 230 | lamda = 32 231 | 232 | def selective_loss(y_true, y_pred): 233 | loss = K.categorical_crossentropy( 234 | K.repeat_elements(y_pred[:, -1:], self.num_classes, axis=1) * y_true[:, :-1], 235 | y_pred[:, :-1]) + lamda * K.maximum(-K.mean(y_pred[:, -1]) + c, 0) ** 2 236 | return loss 237 | 238 | def selective_acc(y_true, y_pred): 239 | g = K.cast(K.greater(y_pred[:, -1], 0.5), K.floatx()) 240 | temp1 = K.sum( 241 | (g) * K.cast(K.equal(K.argmax(y_true[:, :-1], axis=-1), K.argmax(y_pred[:, :-1], axis=-1)), K.floatx())) 242 | temp1 = temp1 / K.sum(g) 243 | return K.cast(temp1, K.floatx()) 244 | 245 | def coverage(y_true, y_pred): 246 | g = K.cast(K.greater(y_pred[:, -1], 0.5), K.floatx()) 247 | return K.mean(g) 248 | 249 | 250 | 251 | # training parameters 252 | batch_size = 128 253 | maxepoches = 300 254 | learning_rate = 0.1 255 | 256 | lr_decay = 1e-6 257 | 258 | lr_drop = 25 259 | 260 | def lr_scheduler(epoch): 261 | return learning_rate * (0.5 ** (epoch // lr_drop)) 262 | 263 | reduce_lr = keras.callbacks.LearningRateScheduler(lr_scheduler) 264 | 265 | # data augmentation 266 | datagen = ImageDataGenerator( 267 | featurewise_center=False, # set input mean to 0 over the dataset 268 | samplewise_center=False, # set each sample mean to 0 269 | featurewise_std_normalization=False, # divide inputs by std of the dataset 270 | samplewise_std_normalization=False, # divide each input by its std 271 | zca_whitening=False, # apply ZCA whitening 272 | rotation_range=15, # randomly rotate images in the range (degrees, 0 to 180) 273 | width_shift_range=0.1, # randomly shift images horizontally (fraction of total width) 274 | height_shift_range=0.1, # randomly shift images vertically (fraction of total height) 275 | horizontal_flip=True, # randomly flip images 276 | vertical_flip=False) # randomly flip images 277 | # (std, mean, and principal components if ZCA whitening is applied). 
278 | datagen.fit(self.x_train) 279 | 280 | # optimization details 281 | sgd = optimizers.SGD(lr=learning_rate, decay=lr_decay, momentum=0.9, nesterov=True) 282 | 283 | model.compile(loss=[selective_loss, 'categorical_crossentropy'], loss_weights=[self.alpha, 1 - self.alpha], 284 | optimizer=sgd, metrics=['accuracy', selective_acc, coverage]) 285 | 286 | historytemp = model.fit_generator(my_generator(datagen.flow, self.x_train, self.y_train, 287 | batch_size=batch_size, k=self.num_classes), 288 | steps_per_epoch=self.x_train.shape[0] // batch_size, 289 | epochs=maxepoches, callbacks=[reduce_lr], 290 | validation_data=(self.x_test, [self.y_test, self.y_test[:, :-1]])) 291 | 292 | with open("checkpoints/{}_history.pkl".format(self.filename[:-3]), 'wb') as handle: 293 | pickle.dump(historytemp.history, handle, protocol=pickle.HIGHEST_PROTOCOL) 294 | 295 | model.save_weights("checkpoints/{}".format(self.filename)) 296 | 297 | return model 298 | 299 | -------------------------------------------------------------------------------- /results/readme.md: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /selectivnet_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | import os 4 | from sklearn.linear_model import LogisticRegression as LR 5 | from sklearn.metrics import log_loss 6 | from sklearn.model_selection import train_test_split 7 | 8 | 9 | def to_train(filename): 10 | checkpoints = os.listdir("checkpoints/") 11 | if filename in checkpoints: 12 | return False 13 | else: 14 | return True 15 | 16 | 17 | def save_dict(filename, dict): 18 | 19 | with open(filename, 'w') as fp: 20 | json.dump(dict, fp) 21 | 22 | 23 | def calc_selective_risk(model, regression, calibrated_coverage=None): 24 | prediction, pred = model.predict() 25 | if calibrated_coverage is None: 26 | threshold = 0.5 27 | else: 28 | threshold = np.percentile(prediction[:, -1], 100 - 100 * calibrated_coverage) 29 | covered_idx = prediction[:, -1] > threshold 30 | 31 | coverage = np.mean(covered_idx) 32 | y_hat = np.argmax(prediction[:, :-1], 1) 33 | if regression: 34 | loss = np.sum(np.mean((prediction[covered_idx, :-1] - model.y_test[covered_idx, :-1]) ** 2, -1)) / np.sum( 35 | covered_idx) 36 | else: 37 | loss = np.sum(y_hat[covered_idx] != np.argmax(model.y_test[covered_idx, :], 1)) / np.sum(covered_idx) 38 | return loss, coverage 39 | 40 | 41 | def train_profile(model_name, model_cls, coverages, model_baseline=None, regression=False, alpha=0.5): 42 | results = {} 43 | for coverage_rate in coverages: 44 | print("running {}_{}.h5".format(model_name, coverage_rate)) 45 | model = model_cls(train=to_train("{}_{}.h5".format(model_name, coverage_rate)), 46 | filename="{}_{}.h5".format(model_name, coverage_rate), 47 | coverage=coverage_rate, 48 | alpha=alpha) 49 | 50 | loss, coverage = calc_selective_risk(model, regression) 51 | 52 | results[coverage] = {"lambda": coverage_rate, "selective_risk": loss} 53 | if model_baseline is not None: 54 | if regression: 55 | results[coverage]["baseline_risk"] = (model_baseline.selective_risk_at_coverage(coverage)) 56 | 57 | else: 58 | 59 | results[coverage]["baseline_risk"] = (1 - model_baseline.selective_risk_at_coverage(coverage)) 60 | results[coverage]["percentage"] = 1 - results[coverage]["selective_risk"] / results[coverage]["baseline_risk"] 61 | 62 | save_dict("results/{}.json".format(model_name), 
              results)
63 | 
64 | 
65 | def post_calibration(model_name, model_cls, lamda, calibrated_coverage=None, model_baseline=None, regression=False):
66 |     results = {}
67 |     print("calibrating {}_{}.h5".format(model_name, lamda))
68 |     model = model_cls(train=to_train("{}_{}.h5".format(model_name, lamda)),
69 |                       filename="{}_{}.h5".format(model_name, lamda), coverage=lamda)
70 |     loss, coverage = calc_selective_risk(model, regression, calibrated_coverage)
71 | 
72 |     results[coverage] = {"lambda": lamda, "selective_risk": loss}
73 |     if model_baseline is not None:
74 |         if regression:
75 |             results[coverage]["baseline_risk"] = (model_baseline.selective_risk_at_coverage(coverage))
76 | 
77 |         else:
78 |             results[coverage]["baseline_risk"] = (1 - model_baseline.selective_risk_at_coverage(coverage))
79 |             results[coverage]["mc_risk"] = (1 - model_baseline.selective_risk_at_coverage(coverage, mc=True))
80 | 
81 |         results[coverage]["percentage"] = 1 - results[coverage]["selective_risk"] / results[coverage]["baseline_risk"]
82 | 
83 |     return results
84 | 
85 | 
86 | def my_generator(func, x_train, y_train, batch_size, k=10):
87 |     while True:
88 |         res = next(func(x_train, y_train, batch_size))
89 | 
90 |         yield [res[0], [res[1], res[1][:, :-1]]]
91 | 
92 | 
93 | def create_cats_vs_dogs_npz(cats_vs_dogs_path='datasets'):
94 |     labels = ['cat', 'dog']
95 |     label_to_y_dict = {l: i for i, l in enumerate(labels)}
96 | 
97 |     def _load_from_dir(dir_name):
98 |         glob_path = os.path.join(cats_vs_dogs_path, dir_name, '*.*.jpg')
99 |         imgs_paths = glob(glob_path)  # assumes `from glob import glob` is available
100 |         images = [resize_and_crop_image(p, 64) for p in imgs_paths]  # external resize/crop helper, not defined in this repository
101 |         x = np.stack(images)
102 |         y = [label_to_y_dict[os.path.split(p)[-1][:3]] for p in imgs_paths]
103 |         y = np.array(y)
104 |         return x, y
105 | 
106 |     x_train, y_train = _load_from_dir('train')
107 |     x_test, y_test = _load_from_dir('test')
108 | 
109 |     np.savez_compressed(os.path.join(cats_vs_dogs_path, 'cats_vs_dogs.npz'),
110 |                         x_train=x_train, y_train=y_train,
111 |                         x_test=x_test, y_test=y_test)
112 | 
113 | 
114 | def load_cats_vs_dogs(cats_vs_dogs_path='datasets/'):
115 |     npz_file = np.load(os.path.join(cats_vs_dogs_path, 'cats_vs_dogs.npz'))
116 |     x_train = npz_file['x_train']
117 |     y_train = npz_file['y_train']
118 |     x_test = npz_file['x_test']
119 |     y_test = npz_file['y_test']
120 | 
121 |     return (x_train, y_train), (x_test, y_test)
122 | 
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | 
3 | from models.catdog_vgg_selectivenet import CatsvsDogVgg as CatsvsDogSelective
4 | from models.cifar10_vgg_selectivenet import cifar10vgg as cifar10Selective
5 | from models.svhn_vgg_selectivenet import SvhnVgg as SVHNSelective
6 | from selectivnet_utils import *
7 | 
8 | MODELS = {"cifar_10": cifar10Selective, "catsdogs": CatsvsDogSelective, "SVHN": SVHNSelective}
9 | 
10 | 
11 | 
12 | parser = argparse.ArgumentParser()
13 | parser.add_argument('--dataset', type=str, default='cifar_10')
14 | 
15 | parser.add_argument('--model_name', type=str, default='test')
16 | parser.add_argument('--baseline', type=str, default='none')
17 | parser.add_argument('--alpha', type=float, default=0.5)
18 | 
19 | args = parser.parse_args()
20 | 
21 | model_cls = MODELS[args.dataset]
22 | model_name = args.model_name
23 | baseline_name = args.baseline
24 | 
25 | coverages = [0.95, 0.9, 0.85, 0.8, 0.75, 0.7]
26 | 
27 | 
28 | if baseline_name == "none":
29 |     results = train_profile(model_name, model_cls,
coverages, alpha=args.alpha) 30 | else: 31 | model_baseline = model_cls(train=to_train("{}.h5".format(baseline_name)), 32 | filename="{}.h5".format(baseline_name), 33 | baseline=True) 34 | results = train_profile(model_name, model_cls, coverages, model_baseline=model_baseline, alpha=args.alpha) 35 | --------------------------------------------------------------------------------
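A minimal usage sketch, assuming the datasets are in place and that the checkpoints/ and results/ directories are writable; the experiment name "cifar_test" and the coverage values 0.8 and 0.7 are placeholder choices:

    # Train a profile of SelectiveNet models over the coverages listed in train.py:
    #   python train.py --dataset cifar_10 --model_name cifar_test --alpha 0.5
    # Afterwards, calibrate one of the trained models to an empirical coverage of about 0.7:
    from models.cifar10_vgg_selectivenet import cifar10vgg
    from selectivnet_utils import post_calibration

    # Loads checkpoints/cifar_test_0.8.h5 if it exists (otherwise trains it), thresholds the
    # selection head so that roughly 70% of the test set is covered, and returns the
    # resulting selective risk and empirical coverage.
    results = post_calibration("cifar_test", cifar10vgg, lamda=0.8, calibrated_coverage=0.7)
    print(results)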