├── Classifier ├── DenseNet │ ├── DenseNet_newJSN.py │ ├── densenet.py │ ├── subpixel.py │ └── tensorflow_backend.py ├── ResNet.py ├── VGG.py └── integration │ └── model ensembel.py ├── Example_images ├── 01E392EE-69F9-4E33-BFCE-E5C968654078-1920x1431.jpeg ├── 0cc8ac00ad4d2c3e8ea3ed4a9c776a_jumbo.jpeg ├── 11368d1bfb309b273d60a7138dae35_jumbo.jpeg ├── 1312A392-67A3-4EBF-9319-810CF6DA5EF6.jpeg ├── 1B734A89-A1BF-49A8-A1D3-66FAFA4FAC5D.jpeg ├── 1c13301604cbe667b39ca3fe335501_jumbo.jpeg ├── 201b87f9902cce6732917d2f292bd3_jumbo.jpeg ├── 23E99E2E-447C-46E5-8EB2-D35D12473C39-1920x1440.png ├── 2966893D-5DDF-4B68-9E2B-4979D5956C8E.jpeg ├── 2B8649B2-00C4-4233-85D5-1CE240CF233B.jpeg ├── 2C26F453-AF3B-4517-BB9E-802CF2179543.jpeg ├── 32a46f77ff2a5acc2168b20b974cf8_jumbo.jpeg ├── 39EE8E69-5801-48DE-B6E3-BE7D1BCF3092.jpeg ├── 4d844df58f10acb17fc50351fd9440_jumbo.jpeg ├── 58cb9263f16e94305c730685358e4e_jumbo.jpeg ├── 5A78BCA9-5B7A-440D-8A4E-AE7710EA6EAD-2048x1661.jpeg ├── 5CBC2E94-D358-401E-8928-965CCD965C5C-1920x1581.jpeg ├── 5CBC2E94-D358-401E-8928-965CCD965C5C-2048x1687.jpeg ├── 61c6828be4bb24b9e29e6ebfcfec0a_jumbo.jpeg ├── 6CB4EFC6-68FA-4CD5-940C-BEFA8DAFE9A7-1920x1239.jpeg ├── 7848bf2d6be7318bf1457253990d25_jumbo.jpeg ├── 7AF6C1AF-D249-4BD2-8C26-449304105D03.jpeg ├── 7E335538-2F86-424E-A0AB-6397783A38D0-1536x1246.jpeg ├── 7E335538-2F86-424E-A0AB-6397783A38D0-1920x1558.jpeg ├── 8549249b763152e944d3ad092a2a2d_jumbo.jpeg ├── 85E52EB3-56E9-4D67-82DA-DEA247C82886.jpeg ├── 93FE0BB1-022D-4F24-9727-987A07975FFB.jpeg ├── 9C34AF49-E589-44D5-92D3-168B3B04E4A6.jpeg ├── 9ad688b362b011bd3f7503799515ef_jumbo.jpeg ├── 9f987e36c0a19aeb1f3c9151b66317_jumbo.jpeg ├── 9fdd3c3032296fd04d2cad5d9070d4_jumbo.jpeg ├── B2D20576-00B7-4519-A415-72DE29C90C34.jpeg ├── B59DD164-51D5-40DF-A926-6A42DD52EBE8-1920x1472.jpeg ├── CD50BA96-6982-4C80-AE7B-5F67ACDBFA56.jpeg ├── CE13BB46-B19A-4B06-92CE-C479125C6CEA.jpeg ├── F051E018-DAD1-4506-AD43-BE4CA29E960B.jpeg ├── F2DE909F-E19C-4900-92F5-8F435B031AC6.jpeg ├── F4341CE7-73C9-45C6-99C8-8567A5484B63.jpeg ├── F63AB6CE-1968-4154-A70F-913AF154F53D-1920x1275.jpeg ├── a092a272b78ce7c23e6a490721b750_jumbo.jpeg ├── a36d7944927e369c90035d4fcbb7af_jumbo.jpeg ├── b1921029beb35ebb6bc80b1bd5c043_jumbo.jpeg ├── b418d50351b48ee58bcb4c2841e95b_jumbo.jpeg ├── b4da827908ad8209137382e301cc24_jumbo.jpeg ├── b81bbc0418db1202a4e8d6015afb32_jumbo.jpeg ├── bc4aafa5ad0aaa24a92afe73b06e74_jumbo.jpeg ├── c02786050656210c20eb86d3bc0d48_jumbo.jpeg ├── df1053d3e8896b53ef140773e10e26_jumbo.jpeg └── e493ebb5ce513a0ad49237f008595b_jumbo.jpeg ├── README.md ├── noteboks ├── Choosing_best_pretrained_model_with_WeightWatcher.ipynb ├── Decision_Visualization_GradCAM_LRP_ResNet18.ipynb ├── Decision_Visualization_GradCAM_LRP_VGG19.ipynb ├── ResNet-18.ipynb └── VGG-19.ipynb └── utils ├── CXR_preprocessing.py ├── class_balancing.py ├── gradcamutils.py ├── helper.py └── lossprettifier.py /Classifier/DenseNet/DenseNet_newJSN.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import numpy as np 3 | 4 | np.random.seed(3768) # for reproducibility 5 | from keras.preprocessing import sequence 6 | from keras.utils import np_utils 7 | from keras.models import Sequential,load_model,Model 8 | from keras.layers import Dense, Dropout, Activation, Flatten 9 | from keras.layers import * 10 | from keras.optimizers import SGD 11 | from random import shuffle 12 | import time 13 | import csv 14 | import os 15 | import densenet 16 | from keras.callbacks import CSVLogger 17 
| from keras import callbacks
18 | from PIL import Image
19 | from keras.preprocessing.image import ImageDataGenerator
20 | import tensorflow as tf
21 | import keras
22 | from sklearn.metrics import classification_report
23 | import sklearn.metrics as sklm
24 | from keras.callbacks import EarlyStopping
25 | from keras.callbacks import LearningRateScheduler
26 | from keras import initializers
27 |
28 |
29 |
30 | def get_session():
31 | config = tf.ConfigProto()
32 | config.gpu_options.allow_growth = True
33 | return tf.Session(config=config)
34 | # use this environment flag to change which GPU to use
35 | #os.environ["CUDA_VISIBLE_DEVICES"] = ""
36 | # set the modified tf session as backend in keras
37 | keras.backend.tensorflow_backend.set_session(get_session())
38 |
39 | def dense_to_one_hot(labels_dense, num_classes=5):
40 | return np.eye(num_classes)[labels_dense]
41 |
42 | def load():
43 | imgList=[]
44 | labelList=[]
45 | reader = open("/data/jiao/newlabel.csv") # label file path
46 | data=reader.readlines()
47 | files = os.listdir('/data/jiao/XR/ROI_resize/front/training/') # training path for ROIs/MRIs
48 | shuffle(files)
49 | for file in files:
50 | if file.endswith(".xml"): continue
51 | fi_d = os.path.join('/data/jiao/XR/ROI_resize/front/training/',file) # training path for ROIs/MRIs
52 | img=Image.open(fi_d).convert('L')
53 | im=np.array(img.resize((320,352), Image.ANTIALIAS))
54 | patient=file.split('_')[0]
55 | direction=file.split('_')[1].split('.')[0]
56 | label="q" # sentinel: stays "q" if the patient is not found in the label file
57 | for row in data:
58 | if patient in row.split(",")[0]:
59 | if "L" in direction:
60 | label=row.split(",")[3]
61 | else:
62 | label=row.split(",")[6]
63 | break
64 | if "V" in file: # for class balance, Grade 3 images borrowed from other stages are named with the stage name, e.g. V3
65 | label="3"
66 | if "8" not in label and "9" not in label and "X" not in label and '.' not in label: # grades 8, 9 and X are not used in our case, so such samples are skipped
67 | #if "." in label:
68 | #label='4'
69 | label = dense_to_one_hot(int(label), 4)
70 | imgList.append(im)
71 | labelList.append(label)
72 | return np.array(imgList),np.array(labelList)
73 |
74 | def load_val():
75 | imgList=[]
76 | labelList=[]
77 | reader = open("/data/jiao/newlabel.csv") # label file path
78 | data=reader.readlines()
79 | files = os.listdir('/data/jiao/XR/ROI_resize/front/validation/') # validation path for ROIs/MRIs
80 | for file in files:
81 | if file.endswith(".xml"): continue
82 | fi_d = os.path.join('/data/jiao/XR/ROI_resize/front/validation/',file) # validation path for ROIs/MRIs
83 | img=Image.open(fi_d).convert('L')
84 | im=np.array(img.resize((320,352), Image.ANTIALIAS))
85 | patient=file.split('_')[0]
86 | direction=file.split('_')[1].split('.')[0]
87 | label="q"
88 | for row in data:
89 | if patient in row.split(",")[0]:
90 | if "L" in direction:
91 | label=row.split(",")[3]
92 | else:
93 | label=row.split(",")[6]
94 | break
95 | if "V" in file:
96 | label="3"
97 | if "8" not in label and "9" not in label and "X" not in label and '.' not in label:
98 | #if "." in label:
99 | #label='4'
100 | label = dense_to_one_hot(int(label), 4)
101 | imgList.append(im)
102 | labelList.append(label)
103 | return np.array(imgList),np.array(labelList)
104 |
105 | def load_valY(): # load the validation labels as plain integer grades (not one-hot), for the sklearn metrics below
106 |
107 | labelList=[]
108 | reader = open("/data/jiao/newlabel.csv")
109 | data=reader.readlines()
110 | files = os.listdir('/data/jiao/XR/ROI_resize/front/validation/')
111 | for file in files:
112 | if file.endswith(".xml"): continue
113 | patient=file.split('_')[0]
114 | direction=file.split('_')[1].split('.')[0]
115 | label="q"
116 | for row in data:
117 | if patient in row.split(",")[0]:
118 | if "L" in direction:
119 | label=row.split(",")[3]
120 | else:
121 | label=row.split(",")[6]
122 | break
123 | if "V" in file:
124 | label="3"
125 | if "8" not in label and "9" not in label and "X" not in label and '.' not in label:
126 | #if "." in label:
127 | #label='4'
128 | labelList.append(int(label))
129 | return np.array(labelList)
130 |
131 |
132 | batch_size=32
133 | model = densenet.DenseNetImageNet201(input_shape=(352,320,1),classes=4, weights=None) # swap in DenseNetImageNet121/161/169/201 or your own architecture; the full signature is input_shape=None, bottleneck=True, reduction=0.5, dropout_rate=0.0, weight_decay=1e-6, include_top=True, weights='imagenet', input_tensor=None, classes=1000, activation='softmax'
134 | sgd = SGD(lr=0.01, decay=1e-6, momentum=0.95, nesterov=True)
135 | model.compile(optimizer=sgd, loss='mse',metrics=['accuracy'])
136 |
137 | datagen = ImageDataGenerator(
138 | featurewise_center=True,
139 | samplewise_center=False, # set each sample mean to 0
140 | featurewise_std_normalization=True,
141 | samplewise_std_normalization=False)
142 | X_train, Y_train = load()
143 | X_test, Y_test = load_val()
144 | X_train = X_train.reshape( len(X_train), len(X_train[0]), len(X_train[0][0]),1)
145 | X_test = X_test.reshape( len(X_test), len(X_test[0]), len(X_test[0][0]),1)
146 | X_train = X_train.astype('float32')
147 | X_test = X_test.astype('float32')
148 | X_train /= 255
149 | X_test /= 255
150 | datagen.fit(X_train) # learn the feature-wise mean/std from the training set
151 | for i in range(len(X_test)):
152 | X_test[i] = datagen.standardize(X_test[i]) # apply the training-set statistics to the validation images
153 | earlystop=EarlyStopping(monitor='val_acc', min_delta=0, patience=300, verbose=1, mode='auto', restore_best_weights=True)
154 | history = model.fit_generator(datagen.flow(X_train, Y_train,batch_size=batch_size),steps_per_epoch=32,epochs=4096,shuffle=True,validation_data=(X_test, Y_test), verbose=1,callbacks=[earlystop])
155 | score, acc = model.evaluate(X_test,Y_test,batch_size=batch_size) # returns (loss, accuracy)
156 | print("Accuracy:",acc)
157 | if acc>0.6: # save the model only when validation accuracy exceeds 60%
158 | model.save_weights("DenseNet-JSNnew-front.h5")
159 | y_pred = model.predict(X_test)
160 | Y_predict = y_pred.argmax(axis=-1)
161 | f=open('DenseNetRESULTS-JSNnew-front.txt','a') # append a performance report
162 | f.write(classification_report(load_valY(), Y_predict))
163 | f.write(str(sklm.cohen_kappa_score(load_valY(), Y_predict))+","+str(acc)+","+str(score)+"\n")
164 | print(classification_report(load_valY(), Y_predict))
165 |
-------------------------------------------------------------------------------- /Classifier/DenseNet/densenet.py: -------------------------------------------------------------------------------- 1 | '''DenseNet models for Keras.
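Usage sketch (editor's example): this mirrors the call made in DenseNet_newJSN.py; the grayscale 352x320 input shape and 4 output classes are that script's choices, not requirements of this module.

```python
import densenet

model = densenet.DenseNetImageNet201(input_shape=(352, 320, 1), classes=4, weights=None)
model.summary()
```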
2 | 3 | # Reference 4 | 5 | - [Densely Connected Convolutional Networks](https://arxiv.org/pdf/1608.06993.pdf) 6 | 7 | - [The One Hundred Layers Tiramisu: Fully Convolutional DenseNets for Semantic Segmentation](https://arxiv.org/pdf/1611.09326.pdf) 8 | 9 | ''' 10 | 11 | from __future__ import print_function 12 | 13 | from __future__ import absolute_import 14 | 15 | from __future__ import division 16 | 17 | 18 | 19 | import warnings 20 | 21 | 22 | 23 | from keras.models import Model 24 | 25 | from keras.layers.core import Dense, Dropout, Activation, Reshape 26 | 27 | from keras.layers.convolutional import Conv2D, Conv2DTranspose, UpSampling2D 28 | 29 | from keras.layers.pooling import AveragePooling2D, MaxPooling2D 30 | 31 | from keras.layers.pooling import GlobalAveragePooling2D 32 | 33 | from keras.layers import Input 34 | 35 | from keras.layers.merge import concatenate 36 | 37 | from keras.layers.normalization import BatchNormalization 38 | 39 | from keras.regularizers import l2 40 | 41 | from keras.utils.layer_utils import convert_all_kernels_in_model, convert_dense_weights_data_format 42 | 43 | from keras.utils.data_utils import get_file 44 | 45 | from keras.engine.topology import get_source_inputs 46 | 47 | from keras_applications.imagenet_utils import _obtain_input_shape 48 | 49 | from keras_applications.imagenet_utils import decode_predictions 50 | 51 | import keras.backend as K 52 | 53 | from keras import backend as K 54 | 55 | from keras.engine import Layer 56 | 57 | from keras.utils.generic_utils import get_custom_objects 58 | 59 | from keras.backend import normalize_data_format 60 | 61 | 62 | class SubPixelUpscaling(Layer): 63 | 64 | """ Sub-pixel convolutional upscaling layer based on the paper "Real-Time Single Image 65 | 66 | and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network" 67 | 68 | (https://arxiv.org/abs/1609.05158). 69 | 70 | This layer requires a Convolution2D prior to it, having output filters computed according to 71 | 72 | the formula : 73 | 74 | filters = k * (scale_factor * scale_factor) 75 | 76 | where k = a user defined number of filters (generally larger than 32) 77 | 78 | scale_factor = the upscaling factor (generally 2) 79 | 80 | This layer performs the depth to space operation on the convolution filters, and returns a 81 | 82 | tensor with the size as defined below. 83 | 84 | # Example : 85 | 86 | ```python 87 | 88 | # A standard subpixel upscaling block 89 | 90 | x = Convolution2D(256, 3, 3, padding='same', activation='relu')(...) 91 | 92 | u = SubPixelUpscaling(scale_factor=2)(x) 93 | 94 | [Optional] 95 | 96 | x = Convolution2D(256, 3, 3, padding='same', activation='relu')(u) 97 | 98 | ``` 99 | 100 | In practice, it is useful to have a second convolution layer after the 101 | 102 | SubPixelUpscaling layer to speed up the learning process. 103 | 104 | However, if you are stacking multiple SubPixelUpscaling blocks, it may increase 105 | 106 | the number of parameters greatly, so the Convolution layer after SubPixelUpscaling 107 | 108 | layer can be removed. 109 | 110 | # Arguments 111 | 112 | scale_factor: Upscaling factor. 113 | 114 | data_format: Can be None, 'channels_first' or 'channels_last'. 115 | 116 | # Input shape 117 | 118 | 4D tensor with shape: 119 | 120 | `(samples, k * (scale_factor * scale_factor) channels, rows, cols)` if data_format='channels_first' 121 | 122 | or 4D tensor with shape: 123 | 124 | `(samples, rows, cols, k * (scale_factor * scale_factor) channels)` if data_format='channels_last'. 
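Shape example (illustrative): with data_format='channels_last', scale_factor=2 and k=64, an input of shape `(samples, 32, 32, 256)` maps to `(samples, 64, 64, 64)`.

Note: `call` relies on `K_BACKEND.depth_to_space`; the theano/tensorflow switch that defines `K_BACKEND` lives in subpixel.py, and this inline copy of the layer needs the same import to run.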
125 | 126 | # Output shape 127 | 128 | 4D tensor with shape: 129 | 130 | `(samples, k channels, rows * scale_factor, cols * scale_factor))` if data_format='channels_first' 131 | 132 | or 4D tensor with shape: 133 | 134 | `(samples, rows * scale_factor, cols * scale_factor, k channels)` if data_format='channels_last'. 135 | 136 | """ 137 | 138 | 139 | def __init__(self, scale_factor=2, data_format=None, **kwargs): 140 | 141 | super(SubPixelUpscaling, self).__init__(**kwargs) 142 | 143 | self.scale_factor = scale_factor 144 | 145 | self.data_format = normalize_data_format(data_format) 146 | 147 | 148 | def build(self, input_shape): 149 | 150 | pass 151 | 152 | 153 | def call(self, x, mask=None): 154 | 155 | y = K_BACKEND.depth_to_space(x, self.scale_factor, self.data_format) 156 | 157 | return y 158 | 159 | 160 | def compute_output_shape(self, input_shape): 161 | 162 | if self.data_format == 'channels_first': 163 | 164 | b, k, r, c = input_shape 165 | 166 | return (b, k // (self.scale_factor ** 2), r * self.scale_factor, c * self.scale_factor) 167 | 168 | else: 169 | 170 | b, r, c, k = input_shape 171 | 172 | return (b, r * self.scale_factor, c * self.scale_factor, k // (self.scale_factor ** 2)) 173 | 174 | 175 | def get_config(self): 176 | 177 | config = {'scale_factor': self.scale_factor, 178 | 179 | 'data_format': self.data_format} 180 | 181 | base_config = super(SubPixelUpscaling, self).get_config() 182 | 183 | return dict(list(base_config.items()) + list(config.items())) 184 | 185 | get_custom_objects().update({'SubPixelUpscaling': SubPixelUpscaling}) 186 | 187 | 188 | DENSENET_121_WEIGHTS_PATH = r'https://github.com/titu1994/DenseNet/releases/download/v3.0/DenseNet-BC-121-32.h5' 189 | 190 | DENSENET_161_WEIGHTS_PATH = r'https://github.com/titu1994/DenseNet/releases/download/v3.0/DenseNet-BC-161-48.h5' 191 | 192 | DENSENET_169_WEIGHTS_PATH = r'https://github.com/titu1994/DenseNet/releases/download/v3.0/DenseNet-BC-169-32.h5' 193 | 194 | DENSENET_121_WEIGHTS_PATH_NO_TOP = r'https://github.com/titu1994/DenseNet/releases/download/v3.0/DenseNet-BC-121-32-no-top.h5' 195 | 196 | DENSENET_161_WEIGHTS_PATH_NO_TOP = r'https://github.com/titu1994/DenseNet/releases/download/v3.0/DenseNet-BC-161-48-no-top.h5' 197 | 198 | DENSENET_169_WEIGHTS_PATH_NO_TOP = r'https://github.com/titu1994/DenseNet/releases/download/v3.0/DenseNet-BC-169-32-no-top.h5' 199 | 200 | 201 | 202 | def preprocess_input(x, data_format=None): 203 | 204 | """Preprocesses a tensor encoding a batch of images. 205 | 206 | 207 | 208 | # Arguments 209 | 210 | x: input Numpy tensor, 4D. 211 | 212 | data_format: data format of the image tensor. 213 | 214 | 215 | 216 | # Returns 217 | 218 | Preprocessed tensor. 219 | 220 | """ 221 | 222 | if data_format is None: 223 | 224 | data_format = K.image_data_format() 225 | 226 | assert data_format in {'channels_last', 'channels_first'} 227 | 228 | 229 | 230 | if data_format == 'channels_first': 231 | 232 | if x.ndim == 3: 233 | 234 | # 'RGB'->'BGR' 235 | 236 | x = x[::-1, ...] 237 | 238 | # Zero-center by mean pixel 239 | 240 | x[0, :, :] -= 103.939 241 | 242 | x[1, :, :] -= 116.779 243 | 244 | x[2, :, :] -= 123.68 245 | 246 | else: 247 | 248 | x = x[:, ::-1, ...] 
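# Zero-center by mean pixel (same BGR constants as the 3D branch above)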
249 | 250 | x[:, 0, :, :] -= 103.939 251 | 252 | x[:, 1, :, :] -= 116.779 253 | 254 | x[:, 2, :, :] -= 123.68 255 | 256 | else: 257 | 258 | # 'RGB'->'BGR' 259 | 260 | x = x[..., ::-1] 261 | 262 | # Zero-center by mean pixel 263 | 264 | x[..., 0] -= 103.939 265 | 266 | x[..., 1] -= 116.779 267 | 268 | x[..., 2] -= 123.68 269 | 270 | 271 | 272 | x *= 0.017 # scale values 273 | 274 | 275 | 276 | return x 277 | 278 | 279 | 280 | 281 | 282 | def DenseNet(input_shape=None, depth=40, nb_dense_block=3, growth_rate=12, nb_filter=-1, nb_layers_per_block=-1, 283 | 284 | bottleneck=False, reduction=0.0, dropout_rate=0.0, weight_decay=1e-4, subsample_initial_block=False, 285 | 286 | include_top=True, weights=None, input_tensor=None, 287 | 288 | classes=10, activation='softmax'): 289 | 290 | 291 | 292 | if weights not in {'imagenet', None}: 293 | 294 | raise ValueError('The `weights` argument should be either ' 295 | 296 | '`None` (random initialization) or `cifar10` ' 297 | 298 | '(pre-training on CIFAR-10).') 299 | 300 | 301 | 302 | if weights == 'imagenet' and include_top and classes != 1000: 303 | 304 | raise ValueError('If using `weights` as ImageNet with `include_top`' 305 | 306 | ' as true, `classes` should be 1000') 307 | 308 | 309 | 310 | if activation not in ['softmax', 'sigmoid']: 311 | 312 | raise ValueError('activation must be one of "softmax" or "sigmoid"') 313 | 314 | 315 | 316 | if activation == 'sigmoid' and classes != 1: 317 | 318 | raise ValueError('sigmoid activation can only be used when classes = 1') 319 | 320 | 321 | 322 | # Determine proper input shape 323 | 324 | input_shape = _obtain_input_shape(input_shape, 325 | 326 | default_size=32, 327 | 328 | min_size=8, 329 | 330 | data_format=K.image_data_format(), 331 | 332 | require_flatten=include_top) 333 | 334 | 335 | 336 | if input_tensor is None: 337 | 338 | img_input = Input(shape=input_shape) 339 | 340 | else: 341 | 342 | if not K.is_keras_tensor(input_tensor): 343 | 344 | img_input = Input(tensor=input_tensor, shape=input_shape) 345 | 346 | else: 347 | 348 | img_input = input_tensor 349 | 350 | 351 | 352 | x = __create_dense_net(classes, img_input, include_top, depth, nb_dense_block, 353 | 354 | growth_rate, nb_filter, nb_layers_per_block, bottleneck, reduction, 355 | 356 | dropout_rate, weight_decay, subsample_initial_block, activation) 357 | 358 | 359 | 360 | # Ensure that the model takes into account 361 | 362 | # any potential predecessors of `input_tensor`. 363 | 364 | if input_tensor is not None: 365 | 366 | inputs = get_source_inputs(input_tensor) 367 | 368 | else: 369 | 370 | inputs = img_input 371 | 372 | # Create model. 
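# Note: the DENSENET_*_WEIGHTS_PATH constants above give the pretrained-weight URLs,
# but this copy of the file never downloads or loads them, so `weights='imagenet'`
# only passes the argument check; load pretrained weights via model.load_weights().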
373 | 374 | model = Model(inputs, x, name='densenet') 375 | 376 | 377 | return model 378 | 379 | 380 | 381 | 382 | def DenseNetImageNet121(input_shape=None, 383 | 384 | bottleneck=True, 385 | 386 | reduction=0.5, 387 | 388 | dropout_rate=0.0, 389 | 390 | weight_decay=1e-6, 391 | 392 | include_top=True, 393 | 394 | weights='imagenet', 395 | 396 | input_tensor=None, 397 | 398 | classes=1000, 399 | 400 | activation='softmax'): 401 | 402 | return DenseNet(input_shape, depth=121, nb_dense_block=4, growth_rate=12, nb_filter=8, 403 | 404 | nb_layers_per_block=[6, 12, 24, 16], bottleneck=bottleneck, reduction=reduction, 405 | 406 | dropout_rate=dropout_rate, weight_decay=weight_decay, subsample_initial_block=True, 407 | 408 | include_top=include_top, weights=weights, input_tensor=input_tensor, 409 | 410 | classes=classes, activation=activation) 411 | 412 | 413 | 414 | 415 | 416 | def DenseNetImageNet169(input_shape=None, 417 | 418 | bottleneck=True, 419 | 420 | reduction=0.5, 421 | 422 | dropout_rate=0.0, 423 | 424 | weight_decay=1e-6, 425 | 426 | include_top=True, 427 | 428 | weights='imagenet', 429 | 430 | input_tensor=None, 431 | 432 | classes=1000, 433 | 434 | activation='softmax'): 435 | 436 | return DenseNet(input_shape, depth=169, nb_dense_block=4, growth_rate=12, nb_filter=8, 437 | 438 | nb_layers_per_block=[6, 12, 32, 32], bottleneck=bottleneck, reduction=reduction, 439 | 440 | dropout_rate=dropout_rate, weight_decay=weight_decay, subsample_initial_block=True, 441 | 442 | include_top=include_top, weights=weights, input_tensor=input_tensor, 443 | 444 | classes=classes, activation=activation) 445 | 446 | 447 | 448 | 449 | 450 | def DenseNetImageNet201(input_shape=None, 451 | 452 | bottleneck=True, 453 | 454 | reduction=0.5, 455 | 456 | dropout_rate=0.0, 457 | 458 | weight_decay=1e-6, 459 | 460 | include_top=True, 461 | 462 | weights=None, 463 | 464 | input_tensor=None, 465 | 466 | classes=1000, 467 | 468 | activation='softmax'): 469 | 470 | return DenseNet(input_shape, depth=201, nb_dense_block=4, growth_rate=16, nb_filter=8, 471 | 472 | nb_layers_per_block=[6, 12, 48, 32], bottleneck=bottleneck, reduction=reduction, 473 | 474 | dropout_rate=dropout_rate, weight_decay=weight_decay, subsample_initial_block=True, 475 | 476 | include_top=include_top, weights=weights, input_tensor=input_tensor, 477 | 478 | classes=classes, activation=activation) 479 | 480 | 481 | 482 | 483 | 484 | def DenseNetImageNet264(input_shape=None, 485 | 486 | bottleneck=True, 487 | 488 | reduction=0.5, 489 | 490 | dropout_rate=0.5, 491 | 492 | weight_decay=1e-4, 493 | 494 | include_top=True, 495 | 496 | weights=None, 497 | 498 | input_tensor=None, 499 | 500 | classes=1000, 501 | 502 | activation='softmax'): 503 | 504 | return DenseNet(input_shape, depth=201, nb_dense_block=4, growth_rate=16, nb_filter=8, 505 | 506 | nb_layers_per_block=[6, 12, 64, 48], bottleneck=bottleneck, reduction=reduction, 507 | 508 | dropout_rate=dropout_rate, weight_decay=weight_decay, subsample_initial_block=True, 509 | 510 | include_top=include_top, weights=weights, input_tensor=input_tensor, 511 | 512 | classes=classes, activation=activation) 513 | 514 | 515 | 516 | 517 | 518 | def DenseNetImageNet161(input_shape=None, 519 | 520 | bottleneck=True, 521 | 522 | reduction=0.5, 523 | 524 | dropout_rate=0.0, 525 | 526 | weight_decay=1e-6, 527 | 528 | include_top=True, 529 | 530 | weights='imagenet', 531 | 532 | input_tensor=None, 533 | 534 | classes=1000, 535 | 536 | activation='softmax'): 537 | 538 | return DenseNet(input_shape, depth=161, 
nb_dense_block=4, growth_rate=12, nb_filter=8, 539 | 540 | nb_layers_per_block=[6, 12, 36, 24], bottleneck=bottleneck, reduction=reduction, 541 | 542 | dropout_rate=dropout_rate, weight_decay=weight_decay, subsample_initial_block=True, 543 | 544 | include_top=include_top, weights=weights, input_tensor=input_tensor, 545 | 546 | classes=classes, activation=activation) 547 | 548 | 549 | 550 | 551 | 552 | def __conv_block(ip, nb_filter, bottleneck=False, dropout_rate=None, weight_decay=1e-4): 553 | 554 | ''' Apply BatchNorm, Relu, 3x3 Conv2D, optional bottleneck block and dropout 555 | 556 | Args: 557 | 558 | ip: Input keras tensor 559 | 560 | nb_filter: number of filters 561 | 562 | bottleneck: add bottleneck block 563 | 564 | dropout_rate: dropout rate 565 | 566 | weight_decay: weight decay factor 567 | 568 | Returns: keras tensor with batch_norm, relu and convolution2d added (optional bottleneck) 569 | 570 | ''' 571 | 572 | concat_axis = 1 if K.image_data_format() == 'channels_first' else -1 573 | 574 | 575 | 576 | x = BatchNormalization(axis=concat_axis, epsilon=1.1e-5)(ip) 577 | 578 | x = Activation('relu')(x) 579 | 580 | 581 | 582 | if bottleneck: 583 | 584 | inter_channel = nb_filter * 4 # Obtained from https://github.com/liuzhuang13/DenseNet/blob/master/densenet.lua 585 | 586 | 587 | 588 | x = Conv2D(inter_channel, (1, 1), kernel_initializer='he_normal', padding='same', use_bias=False, 589 | 590 | kernel_regularizer=l2(weight_decay))(x) 591 | 592 | x = BatchNormalization(axis=concat_axis, epsilon=1.1e-5)(x) 593 | 594 | x = Activation('relu')(x) 595 | 596 | 597 | 598 | x = Conv2D(nb_filter, (3, 3), kernel_initializer='he_normal', padding='same', use_bias=False)(x) 599 | 600 | if dropout_rate: 601 | 602 | x = Dropout(dropout_rate)(x) 603 | 604 | 605 | 606 | return x 607 | 608 | 609 | 610 | 611 | 612 | def __dense_block(x, nb_layers, nb_filter, growth_rate, bottleneck=False, dropout_rate=None, weight_decay=1e-4, 613 | 614 | grow_nb_filters=True, return_concat_list=False): 615 | 616 | ''' Build a dense_block where the output of each conv_block is fed to subsequent ones 617 | 618 | Args: 619 | 620 | x: keras tensor 621 | 622 | nb_layers: the number of layers of conv_block to append to the model. 
623 | 624 | nb_filter: number of filters 625 | 626 | growth_rate: growth rate 627 | 628 | bottleneck: bottleneck block 629 | 630 | dropout_rate: dropout rate 631 | 632 | weight_decay: weight decay factor 633 | 634 | grow_nb_filters: flag to decide to allow number of filters to grow 635 | 636 | return_concat_list: return the list of feature maps along with the actual output 637 | 638 | Returns: keras tensor with nb_layers of conv_block appended 639 | 640 | ''' 641 | 642 | concat_axis = 1 if K.image_data_format() == 'channels_first' else -1 643 | 644 | 645 | 646 | x_list = [x] 647 | 648 | 649 | 650 | for i in range(nb_layers): 651 | 652 | cb = __conv_block(x, growth_rate, bottleneck, dropout_rate, weight_decay) 653 | 654 | x_list.append(cb) 655 | 656 | 657 | 658 | x = concatenate([x, cb], axis=concat_axis) 659 | 660 | 661 | 662 | if grow_nb_filters: 663 | 664 | nb_filter += growth_rate 665 | 666 | 667 | 668 | if return_concat_list: 669 | 670 | return x, nb_filter, x_list 671 | 672 | else: 673 | 674 | return x, nb_filter 675 | 676 | 677 | 678 | 679 | 680 | def __transition_block(ip, nb_filter, compression=1.0, weight_decay=1e-4): 681 | 682 | ''' Apply BatchNorm, Relu 1x1, Conv2D, optional compression, dropout and Maxpooling2D 683 | 684 | Args: 685 | 686 | ip: keras tensor 687 | 688 | nb_filter: number of filters 689 | 690 | compression: calculated as 1 - reduction. Reduces the number of feature maps 691 | 692 | in the transition block. 693 | 694 | dropout_rate: dropout rate 695 | 696 | weight_decay: weight decay factor 697 | 698 | Returns: keras tensor, after applying batch_norm, relu-conv, dropout, maxpool 699 | 700 | ''' 701 | 702 | concat_axis = 1 if K.image_data_format() == 'channels_first' else -1 703 | 704 | 705 | 706 | x = BatchNormalization(axis=concat_axis, epsilon=1.1e-5)(ip) 707 | 708 | x = Activation('relu')(x) 709 | 710 | x = Conv2D(int(nb_filter * compression), (1, 1), kernel_initializer='he_normal', padding='same', use_bias=False, 711 | 712 | kernel_regularizer=l2(weight_decay))(x) 713 | 714 | x = AveragePooling2D((2, 2), strides=(2, 2))(x) 715 | 716 | 717 | 718 | return x 719 | 720 | 721 | 722 | 723 | 724 | def __transition_up_block(ip, nb_filters, type='deconv', weight_decay=1E-4): 725 | 726 | ''' SubpixelConvolutional Upscaling (factor = 2) 727 | 728 | Args: 729 | 730 | ip: keras tensor 731 | 732 | nb_filters: number of layers 733 | 734 | type: can be 'upsampling', 'subpixel', 'deconv'. Determines type of upsampling performed 735 | 736 | weight_decay: weight decay factor 737 | 738 | Returns: keras tensor, after applying upsampling operation. 
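Example (illustrative): `__transition_up_block(x, nb_filters=128, type='subpixel')` runs Conv2D -> SubPixelUpscaling(scale_factor=2) -> Conv2D, as implemented below.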
739 | 740 | ''' 741 | 742 | 743 | 744 | if type == 'upsampling': 745 | 746 | x = UpSampling2D()(ip) 747 | 748 | elif type == 'subpixel': 749 | 750 | x = Conv2D(nb_filters, (3, 3), activation='relu', padding='same', kernel_regularizer=l2(weight_decay), 751 | 752 | use_bias=False, kernel_initializer='he_normal')(ip) 753 | 754 | x = SubPixelUpscaling(scale_factor=2)(x) 755 | 756 | x = Conv2D(nb_filters, (3, 3), activation='relu', padding='same', kernel_regularizer=l2(weight_decay), 757 | 758 | use_bias=False, kernel_initializer='he_normal')(x) 759 | 760 | else: 761 | 762 | x = Conv2DTranspose(nb_filters, (3, 3), activation='relu', padding='same', strides=(2, 2), 763 | 764 | kernel_initializer='he_normal', kernel_regularizer=l2(weight_decay))(ip) 765 | 766 | 767 | 768 | return x 769 | 770 | 771 | 772 | 773 | 774 | def __create_dense_net(nb_classes, img_input, include_top, depth=40, nb_dense_block=3, growth_rate=12, nb_filter=-1, 775 | 776 | nb_layers_per_block=-1, bottleneck=False, reduction=0.0, dropout_rate=None, weight_decay=1e-4, 777 | 778 | subsample_initial_block=False, activation='softmax'): 779 | 780 | ''' Build the DenseNet model 781 | 782 | Args: 783 | 784 | nb_classes: number of classes 785 | 786 | img_input: tuple of shape (channels, rows, columns) or (rows, columns, channels) 787 | 788 | include_top: flag to include the final Dense layer 789 | 790 | depth: number or layers 791 | 792 | nb_dense_block: number of dense blocks to add to end (generally = 3) 793 | 794 | growth_rate: number of filters to add per dense block 795 | 796 | nb_filter: initial number of filters. Default -1 indicates initial number of filters is 2 * growth_rate 797 | 798 | nb_layers_per_block: number of layers in each dense block. 799 | 800 | Can be a -1, positive integer or a list. 801 | 802 | If -1, calculates nb_layer_per_block from the depth of the network. 803 | 804 | If positive integer, a set number of layers per dense block. 805 | 806 | If list, nb_layer is used as provided. Note that list size must 807 | 808 | be (nb_dense_block + 1) 809 | 810 | bottleneck: add bottleneck blocks 811 | 812 | reduction: reduction factor of transition blocks. Note : reduction value is inverted to compute compression 813 | 814 | dropout_rate: dropout rate 815 | 816 | weight_decay: weight decay rate 817 | 818 | subsample_initial_block: Set to True to subsample the initial convolution and 819 | 820 | add a MaxPool2D before the dense blocks are added. 821 | 822 | subsample_initial: 823 | 824 | activation: Type of activation at the top layer. Can be one of 'softmax' or 'sigmoid'. 825 | 826 | Note that if sigmoid is used, classes must be 1. 
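Example (illustrative): depth=40, nb_dense_block=3 and nb_layers_per_block=-1 gives (40 - 4) / 3 = 12 layers per dense block, halved to 6 if bottleneck=True; nb_filter=-1 starts the stem at 2 * growth_rate filters.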
827 | 828 | Returns: keras tensor with nb_layers of conv_block appended 829 | 830 | ''' 831 | 832 | 833 | 834 | concat_axis = 1 if K.image_data_format() == 'channels_first' else -1 835 | 836 | 837 | 838 | if reduction != 0.0: 839 | 840 | assert reduction <= 1.0 and reduction > 0.0, 'reduction value must lie between 0.0 and 1.0' 841 | 842 | 843 | 844 | # layers in each dense block 845 | 846 | if type(nb_layers_per_block) is list or type(nb_layers_per_block) is tuple: 847 | 848 | nb_layers = list(nb_layers_per_block) # Convert tuple to list 849 | 850 | final_nb_layer = nb_layers[-1] 851 | 852 | nb_layers = nb_layers[:-1] 853 | 854 | else: 855 | 856 | if nb_layers_per_block == -1: 857 | 858 | assert (depth - 4) % 3 == 0, 'Depth must be 3 N + 4 if nb_layers_per_block == -1' 859 | 860 | count = int((depth - 4) / 3) 861 | 862 | 863 | 864 | if bottleneck: 865 | 866 | count = count // 2 867 | 868 | 869 | 870 | nb_layers = [count for _ in range(nb_dense_block)] 871 | 872 | final_nb_layer = count 873 | 874 | else: 875 | 876 | final_nb_layer = nb_layers_per_block 877 | 878 | nb_layers = [nb_layers_per_block] * nb_dense_block 879 | 880 | 881 | 882 | # compute initial nb_filter if -1, else accept users initial nb_filter 883 | 884 | if nb_filter <= 0: 885 | 886 | nb_filter = 2 * growth_rate 887 | 888 | 889 | 890 | # compute compression factor 891 | 892 | compression = 1.0 - reduction 893 | 894 | 895 | 896 | # Initial convolution 897 | 898 | if subsample_initial_block: 899 | 900 | initial_kernel = (7, 7) 901 | 902 | initial_strides = (2, 2) 903 | 904 | else: 905 | 906 | initial_kernel = (3, 3) 907 | 908 | initial_strides = (1, 1) 909 | 910 | 911 | 912 | x = Conv2D(nb_filter, initial_kernel, kernel_initializer='he_normal', padding='same', 913 | 914 | strides=initial_strides, use_bias=False, kernel_regularizer=l2(weight_decay))(img_input) 915 | 916 | 917 | 918 | if subsample_initial_block: 919 | 920 | x = BatchNormalization(axis=concat_axis, epsilon=1.1e-5)(x) 921 | 922 | x = Activation('relu')(x) 923 | 924 | x = MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x) 925 | 926 | 927 | 928 | # Add dense blocks 929 | 930 | for block_idx in range(nb_dense_block - 1): 931 | 932 | x, nb_filter = __dense_block(x, nb_layers[block_idx], nb_filter, growth_rate, bottleneck=bottleneck, 933 | 934 | dropout_rate=dropout_rate, weight_decay=weight_decay) 935 | 936 | # add transition_block 937 | 938 | x = __transition_block(x, nb_filter, compression=compression, weight_decay=weight_decay) 939 | 940 | nb_filter = int(nb_filter * compression) 941 | 942 | 943 | 944 | # The last dense_block does not have a transition_block 945 | 946 | x, nb_filter = __dense_block(x, final_nb_layer, nb_filter, growth_rate, bottleneck=bottleneck, 947 | 948 | dropout_rate=dropout_rate, weight_decay=weight_decay) 949 | 950 | 951 | 952 | x = BatchNormalization(axis=concat_axis, epsilon=1.1e-5)(x) 953 | 954 | x = Activation('relu',name='feature')(x) 955 | 956 | x = GlobalAveragePooling2D()(x) 957 | 958 | 959 | 960 | if include_top: 961 | 962 | x = Dense(nb_classes, activation=activation)(x) 963 | 964 | 965 | 966 | return x 967 | 968 | 969 | 970 | 971 | 972 | def __create_fcn_dense_net(nb_classes, img_input, include_top, nb_dense_block=5, growth_rate=12, 973 | 974 | reduction=0.0, dropout_rate=None, weight_decay=1e-4, 975 | 976 | nb_layers_per_block=4, nb_upsampling_conv=128, upsampling_type='upsampling', 977 | 978 | init_conv_filters=48, input_shape=None, activation='deconv'): 979 | 980 | 981 | concat_axis = 1 if K.image_data_format() == 
'channels_first' else -1 982 | 983 | 984 | 985 | if concat_axis == 1: # channels_first dim ordering 986 | 987 | _, rows, cols = input_shape 988 | 989 | else: 990 | 991 | rows, cols, _ = input_shape 992 | 993 | 994 | 995 | if reduction != 0.0: 996 | 997 | assert reduction <= 1.0 and reduction > 0.0, 'reduction value must lie between 0.0 and 1.0' 998 | 999 | 1000 | 1001 | # check if upsampling_conv has minimum number of filters 1002 | 1003 | # minimum is set to 12, as at least 3 color channels are needed for correct upsampling 1004 | 1005 | 1006 | 1007 | # layers in each dense block 1008 | 1009 | if type(nb_layers_per_block) is list or type(nb_layers_per_block) is tuple: 1010 | 1011 | nb_layers = list(nb_layers_per_block) # Convert tuple to list 1012 | 1013 | 1014 | bottleneck_nb_layers = nb_layers[-1] 1015 | 1016 | rev_layers = nb_layers[::-1] 1017 | 1018 | nb_layers.extend(rev_layers[1:]) 1019 | 1020 | else: 1021 | 1022 | bottleneck_nb_layers = nb_layers_per_block 1023 | 1024 | nb_layers = [nb_layers_per_block] * (2 * nb_dense_block + 1) 1025 | 1026 | 1027 | 1028 | # compute compression factor 1029 | 1030 | compression = 1.0 - reduction 1031 | 1032 | 1033 | 1034 | # Initial convolution 1035 | 1036 | x = Conv2D(init_conv_filters, (7, 7), kernel_initializer='he_normal', padding='same', name='initial_conv2D', 1037 | 1038 | use_bias=False, kernel_regularizer=l2(weight_decay))(img_input) 1039 | 1040 | x = BatchNormalization(axis=concat_axis, epsilon=1.1e-5)(x) 1041 | 1042 | x = Activation('relu')(x) 1043 | 1044 | 1045 | 1046 | nb_filter = init_conv_filters 1047 | 1048 | 1049 | 1050 | skip_list = [] 1051 | 1052 | 1053 | 1054 | # Add dense blocks and transition down block 1055 | 1056 | for block_idx in range(nb_dense_block): 1057 | 1058 | x, nb_filter = __dense_block(x, nb_layers[block_idx], nb_filter, growth_rate, dropout_rate=dropout_rate, 1059 | 1060 | weight_decay=weight_decay) 1061 | 1062 | 1063 | 1064 | # Skip connection 1065 | 1066 | skip_list.append(x) 1067 | 1068 | 1069 | 1070 | # add transition_block 1071 | 1072 | x = __transition_block(x, nb_filter, compression=compression, weight_decay=weight_decay) 1073 | 1074 | 1075 | 1076 | nb_filter = int(nb_filter * compression) # this is calculated inside transition_down_block 1077 | 1078 | 1079 | 1080 | # The last dense_block does not have a transition_down_block 1081 | 1082 | # return the concatenated feature maps without the concatenation of the input 1083 | 1084 | _, nb_filter, concat_list = __dense_block(x, bottleneck_nb_layers, nb_filter, growth_rate, 1085 | 1086 | dropout_rate=dropout_rate, weight_decay=weight_decay, 1087 | 1088 | return_concat_list=True) 1089 | 1090 | 1091 | 1092 | skip_list = skip_list[::-1] # reverse the skip list 1093 | 1094 | 1095 | 1096 | # Add dense blocks and transition up block 1097 | 1098 | for block_idx in range(nb_dense_block): 1099 | 1100 | n_filters_keep = growth_rate * nb_layers[nb_dense_block + block_idx] 1101 | 1102 | 1103 | 1104 | # upsampling block must upsample only the feature maps (concat_list[1:]), 1105 | 1106 | # not the concatenation of the input with the feature maps (concat_list[0]. 
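# (concat_list[0] is the block input itself.)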
1107 | 1108 | l = concatenate(concat_list[1:], axis=concat_axis) 1109 | 1110 | 1111 | 1112 | t = __transition_up_block(l, nb_filters=n_filters_keep, type=upsampling_type, weight_decay=weight_decay) 1113 | 1114 | 1115 | 1116 | # concatenate the skip connection with the transition block 1117 | 1118 | x = concatenate([t, skip_list[block_idx]], axis=concat_axis) 1119 | 1120 | 1121 | 1122 | # Dont allow the feature map size to grow in upsampling dense blocks 1123 | 1124 | x_up, nb_filter, concat_list = __dense_block(x, nb_layers[nb_dense_block + block_idx + 1], nb_filter=growth_rate, 1125 | 1126 | growth_rate=growth_rate, dropout_rate=dropout_rate, 1127 | 1128 | weight_decay=weight_decay, return_concat_list=True, 1129 | 1130 | grow_nb_filters=False) 1131 | 1132 | 1133 | 1134 | if include_top: 1135 | 1136 | x = Conv2D(nb_classes, (1, 1), activation='linear', padding='same', use_bias=False)(x_up) 1137 | 1138 | 1139 | 1140 | if K.image_data_format() == 'channels_first': 1141 | 1142 | channel, row, col = input_shape 1143 | 1144 | else: 1145 | 1146 | row, col, channel = input_shape 1147 | 1148 | 1149 | 1150 | x = Reshape((row * col, nb_classes))(x) 1151 | 1152 | x = Activation(activation)(x) 1153 | 1154 | x = Reshape((row, col, nb_classes))(x) 1155 | 1156 | else: 1157 | 1158 | x = x_up 1159 | 1160 | 1161 | 1162 | return x 1163 | 1164 | 1165 | 1166 | 1167 | 1168 | 1169 | 1170 | 1171 | 1172 | if __name__ == '__main__': 1173 | 1174 | 1175 | 1176 | from keras.utils.vis_utils import plot_model 1177 | 1178 | #model = DenseNetFCN((32, 32, 3), growth_rate=16, nb_layers_per_block=[4, 5, 7, 10, 12, 15], upsampling_type='deconv') 1179 | 1180 | model = DenseNet((32, 32, 3), depth=100, nb_dense_block=3, 1181 | 1182 | growth_rate=12, bottleneck=True, reduction=0.5, weights=None) 1183 | 1184 | model.summary() 1185 | 1186 | 1187 | 1188 | from keras.callbacks import ModelCheckpoint, TensorBoard 1189 | 1190 | #plot_model(model, 'test.png', show_shapes=True) 1191 | -------------------------------------------------------------------------------- /Classifier/DenseNet/subpixel.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | 4 | 5 | from keras import backend as K 6 | 7 | from keras.engine import Layer 8 | 9 | from keras.utils.generic_utils import get_custom_objects 10 | 11 | from keras.backend import normalize_data_format 12 | 13 | 14 | 15 | if K.backend() == 'theano': 16 | 17 | import theano_backend as K_BACKEND 18 | 19 | else: 20 | 21 | import tensorflow_backend as K_BACKEND 22 | 23 | 24 | 25 | class SubPixelUpscaling(Layer): 26 | 27 | """ Sub-pixel convolutional upscaling layer based on the paper "Real-Time Single Image 28 | 29 | and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network" 30 | 31 | (https://arxiv.org/abs/1609.05158). 32 | 33 | This layer requires a Convolution2D prior to it, having output filters computed according to 34 | 35 | the formula : 36 | 37 | filters = k * (scale_factor * scale_factor) 38 | 39 | where k = a user defined number of filters (generally larger than 32) 40 | 41 | scale_factor = the upscaling factor (generally 2) 42 | 43 | This layer performs the depth to space operation on the convolution filters, and returns a 44 | 45 | tensor with the size as defined below. 46 | 47 | # Example : 48 | 49 | ```python 50 | 51 | # A standard subpixel upscaling block 52 | 53 | x = Convolution2D(256, 3, 3, padding='same', activation='relu')(...) 
54 |
55 | u = SubPixelUpscaling(scale_factor=2)(x)
56 |
57 | [Optional]
58 |
59 | x = Convolution2D(256, 3, 3, padding='same', activation='relu')(u)
60 |
61 | ```
62 |
63 | In practice, it is useful to have a second convolution layer after the
64 |
65 | SubPixelUpscaling layer to speed up the learning process.
66 |
67 | However, if you are stacking multiple SubPixelUpscaling blocks, it may increase
68 |
69 | the number of parameters greatly, so the Convolution layer after SubPixelUpscaling
70 |
71 | layer can be removed.
72 |
73 | # Arguments
74 |
75 | scale_factor: Upscaling factor.
76 |
77 | data_format: Can be None, 'channels_first' or 'channels_last'.
78 |
79 | # Input shape
80 |
81 | 4D tensor with shape:
82 |
83 | `(samples, k * (scale_factor * scale_factor) channels, rows, cols)` if data_format='channels_first'
84 |
85 | or 4D tensor with shape:
86 |
87 | `(samples, rows, cols, k * (scale_factor * scale_factor) channels)` if data_format='channels_last'.
88 |
89 | # Output shape
90 |
91 | 4D tensor with shape:
92 |
93 | `(samples, k channels, rows * scale_factor, cols * scale_factor)` if data_format='channels_first'
94 |
95 | or 4D tensor with shape:
96 |
97 | `(samples, rows * scale_factor, cols * scale_factor, k channels)` if data_format='channels_last'.
98 |
99 | """
100 |
101 |
102 |
103 | def __init__(self, scale_factor=2, data_format=None, **kwargs):
104 |
105 | super(SubPixelUpscaling, self).__init__(**kwargs)
106 |
107 |
108 |
109 | self.scale_factor = scale_factor
110 |
111 | self.data_format = normalize_data_format(data_format)
112 |
113 |
114 |
115 | def build(self, input_shape):
116 |
117 | pass
118 |
119 |
120 |
121 | def call(self, x, mask=None):
122 |
123 | y = K_BACKEND.depth_to_space(x, self.scale_factor, self.data_format)
124 |
125 | return y
126 |
127 |
128 |
129 | def compute_output_shape(self, input_shape):
130 |
131 | if self.data_format == 'channels_first':
132 |
133 | b, k, r, c = input_shape
134 |
135 | return (b, k // (self.scale_factor ** 2), r * self.scale_factor, c * self.scale_factor)
136 |
137 | else:
138 |
139 | b, r, c, k = input_shape
140 |
141 | return (b, r * self.scale_factor, c * self.scale_factor, k // (self.scale_factor ** 2))
142 |
143 |
144 |
145 | def get_config(self):
146 |
147 | config = {'scale_factor': self.scale_factor,
148 |
149 | 'data_format': self.data_format}
150 |
151 | base_config = super(SubPixelUpscaling, self).get_config()
152 |
153 | return dict(list(base_config.items()) + list(config.items()))
154 |
155 |
156 |
157 |
158 |
159 | get_custom_objects().update({'SubPixelUpscaling': SubPixelUpscaling}) # register so load_model() can deserialize the layer
160 |
-------------------------------------------------------------------------------- /Classifier/DenseNet/tensorflow_backend.py: --------------------------------------------------------------------------------
1 |
2 | import tensorflow as tf
3 |
4 |
5 |
6 | from keras.backend import tensorflow_backend as KTF
7 |
8 | from keras.backend.common import image_data_format
9 |
10 |
11 |
12 | py_all = all
13 |
14 |
15 |
16 | def depth_to_space(input, scale, data_format=None):
17 |
18 | ''' Uses the phase-shift algorithm to trade channels/depth for spatial resolution '''
19 |
20 | if data_format is None:
21 |
22 | data_format = image_data_format()
23 |
24 |
25 |
26 | if data_format == 'channels_first':
27 |
28 | data_format = 'NCHW'
29 |
30 | else:
31 |
32 | data_format = 'NHWC'
33 |
34 |
35 |
36 | # tf.depth_to_space expects 'NHWC'/'NCHW' in upper case, so the string is passed through unchanged
37 |
38 | out = tf.depth_to_space(input, scale, data_format=data_format)
39 |
40 | return out
41 |
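A minimal smoke test for the `depth_to_space` shim above (editor's sketch; assumes TensorFlow 1.x with a Keras 2.2-era backend, run from Classifier/DenseNet so the module imports). With scale 2 and the default channels_last format, a (1, 8, 8, 16) tensor comes back as (1, 16, 16, 4):

    import numpy as np
    import tensorflow as tf

    from tensorflow_backend import depth_to_space  # the shim defined above

    x = tf.placeholder(tf.float32, shape=(1, 8, 8, 16))
    y = depth_to_space(x, 2)  # data_format defaults to keras.backend.image_data_format()

    with tf.Session() as sess:
        out = sess.run(y, feed_dict={x: np.random.rand(1, 8, 8, 16)})
        print(out.shape)  # expected: (1, 16, 16, 4)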
-------------------------------------------------------------------------------- /Classifier/ResNet.py: -------------------------------------------------------------------------------- 1 | """ResNet v1, v2, and segmentation models for Keras. 2 | 3 | 4 | 5 | # Reference 6 | 7 | 8 | 9 | - [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385) 10 | 11 | - [Identity Mappings in Deep Residual Networks](https://arxiv.org/abs/1603.05027) 12 | 13 | 14 | 15 | Reference material for extended functionality: 16 | 17 | 18 | 19 | - [ResNeXt](https://arxiv.org/abs/1611.05431) for Tiny ImageNet support. 20 | 21 | - [Dilated Residual Networks](https://arxiv.org/pdf/1705.09914) for segmentation support 22 | 23 | - [Deep Residual Learning for Instrument Segmentation in 24 | 25 | Robotic Surgery](https://arxiv.org/abs/1703.08580) 26 | 27 | for segmentation support. 28 | 29 | 30 | 31 | Implementation Adapted from: github.com/raghakot/keras-resnet 32 | 33 | """ # pylint: disable=E501 34 | 35 | from __future__ import division 36 | 37 | 38 | 39 | import six 40 | 41 | from keras.models import Model 42 | 43 | from keras.layers import Input 44 | 45 | from keras.layers import Activation 46 | 47 | from keras.layers import Reshape 48 | 49 | from keras.layers import Dense 50 | 51 | from keras.layers import Conv2D 52 | 53 | from keras.layers import MaxPooling2D 54 | 55 | from keras.layers import GlobalMaxPooling2D 56 | 57 | from keras.layers import GlobalAveragePooling2D 58 | 59 | from keras.layers import Dropout 60 | 61 | from keras.layers.merge import add 62 | 63 | from keras.layers.normalization import BatchNormalization 64 | 65 | from keras.regularizers import l2 66 | 67 | from keras import backend as K 68 | 69 | from keras_applications.imagenet_utils import _obtain_input_shape 70 | 71 | 72 | 73 | 74 | 75 | def _bn_relu(x, bn_name=None, relu_name=None): 76 | 77 | """Helper to build a BN -> relu block 78 | 79 | """ 80 | 81 | norm = BatchNormalization(axis=CHANNEL_AXIS, name=bn_name)(x) 82 | 83 | return Activation("relu", name=relu_name)(norm) 84 | 85 | 86 | 87 | 88 | 89 | def _conv_bn_relu(**conv_params): 90 | 91 | """Helper to build a conv -> BN -> relu residual unit activation function. 
92 | 93 | This is the original ResNet v1 scheme in https://arxiv.org/abs/1512.03385 94 | 95 | """ 96 | 97 | filters = conv_params["filters"] 98 | 99 | kernel_size = conv_params["kernel_size"] 100 | 101 | strides = conv_params.setdefault("strides", (1, 1)) 102 | 103 | dilation_rate = conv_params.setdefault("dilation_rate", (1, 1)) 104 | 105 | conv_name = conv_params.setdefault("conv_name", None) 106 | 107 | bn_name = conv_params.setdefault("bn_name", None) 108 | 109 | relu_name = conv_params.setdefault("relu_name", None) 110 | 111 | kernel_initializer = conv_params.setdefault("kernel_initializer", "random_normal") 112 | 113 | padding = conv_params.setdefault("padding", "same") 114 | 115 | kernel_regularizer = conv_params.setdefault("kernel_regularizer", l2(1.e-4)) 116 | 117 | 118 | 119 | def f(x): 120 | 121 | x = Conv2D(filters=filters, kernel_size=kernel_size, 122 | 123 | strides=strides, padding=padding, 124 | 125 | dilation_rate=dilation_rate, 126 | 127 | kernel_initializer=kernel_initializer, 128 | 129 | kernel_regularizer=kernel_regularizer, 130 | 131 | name=conv_name)(x) 132 | 133 | return _bn_relu(x, bn_name=bn_name, relu_name=relu_name) 134 | 135 | 136 | 137 | return f 138 | 139 | 140 | 141 | 142 | 143 | def _bn_relu_conv(**conv_params): 144 | 145 | """Helper to build a BN -> relu -> conv residual unit with full pre-activation 146 | 147 | function. This is the ResNet v2 scheme proposed in 148 | 149 | http://arxiv.org/pdf/1603.05027v2.pdf 150 | 151 | """ 152 | 153 | filters = conv_params["filters"] 154 | 155 | kernel_size = conv_params["kernel_size"] 156 | 157 | strides = conv_params.setdefault("strides", (1, 1)) 158 | 159 | dilation_rate = conv_params.setdefault("dilation_rate", (1, 1)) 160 | 161 | conv_name = conv_params.setdefault("conv_name", None) 162 | 163 | bn_name = conv_params.setdefault("bn_name", None) 164 | 165 | relu_name = conv_params.setdefault("relu_name", None) 166 | 167 | kernel_initializer = conv_params.setdefault("kernel_initializer", "random_normal") 168 | 169 | padding = conv_params.setdefault("padding", "same") 170 | 171 | kernel_regularizer = conv_params.setdefault("kernel_regularizer", l2(1.e-4)) 172 | 173 | 174 | 175 | def f(x): 176 | 177 | activation = _bn_relu(x, bn_name=bn_name, relu_name=relu_name) 178 | 179 | return Conv2D(filters=filters, kernel_size=kernel_size, 180 | 181 | strides=strides, padding=padding, 182 | 183 | dilation_rate=dilation_rate, 184 | 185 | kernel_initializer=kernel_initializer, 186 | 187 | kernel_regularizer=kernel_regularizer, 188 | 189 | name=conv_name)(activation) 190 | 191 | 192 | 193 | return f 194 | 195 | 196 | 197 | 198 | 199 | def _shortcut(input_feature, residual, conv_name_base=None, bn_name_base=None): 200 | 201 | """Adds a shortcut between input and residual block and merges them with "sum" 202 | 203 | """ 204 | 205 | # Expand channels of shortcut to match residual. 206 | 207 | # Stride appropriately to match residual (width, height) 208 | 209 | # Should be int if network architecture is correctly configured. 210 | 211 | input_shape = K.int_shape(input_feature) 212 | 213 | residual_shape = K.int_shape(residual) 214 | 215 | stride_width = int(round(input_shape[ROW_AXIS] / residual_shape[ROW_AXIS])) 216 | 217 | stride_height = int(round(input_shape[COL_AXIS] / residual_shape[COL_AXIS])) 218 | 219 | equal_channels = input_shape[CHANNEL_AXIS] == residual_shape[CHANNEL_AXIS] 220 | 221 | 222 | 223 | shortcut = input_feature 224 | 225 | # 1 X 1 conv if shape is different. Else identity. 
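# (This strided 1x1 convolution is the projection shortcut, option (B) in He et al.,
# matching both the spatial size and the channel count of the residual branch.)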
226 | 227 | if stride_width > 1 or stride_height > 1 or not equal_channels: 228 | 229 | print('reshaping via a convolution...') 230 | 231 | if conv_name_base is not None: 232 | 233 | conv_name_base = conv_name_base + '1' 234 | 235 | shortcut = Conv2D(filters=residual_shape[CHANNEL_AXIS], 236 | 237 | kernel_size=(1, 1), 238 | 239 | strides=(stride_width, stride_height), 240 | 241 | padding="valid", 242 | 243 | kernel_initializer="he_normal", 244 | 245 | kernel_regularizer=l2(0.0001), 246 | 247 | name=conv_name_base)(input_feature) 248 | 249 | if bn_name_base is not None: 250 | 251 | bn_name_base = bn_name_base + '1' 252 | 253 | shortcut = BatchNormalization(axis=CHANNEL_AXIS, 254 | 255 | name=bn_name_base)(shortcut) 256 | 257 | 258 | 259 | return add([shortcut, residual]) 260 | 261 | 262 | 263 | 264 | 265 | def _residual_block(block_function, filters, blocks, stage, 266 | 267 | transition_strides=None, transition_dilation_rates=None, 268 | 269 | dilation_rates=None, is_first_layer=False, dropout=None, 270 | 271 | residual_unit=_bn_relu_conv): 272 | 273 | """Builds a residual block with repeating bottleneck blocks. 274 | 275 | 276 | 277 | stage: integer, current stage label, used for generating layer names 278 | 279 | blocks: number of blocks 'a','b'..., current block label, used for generating 280 | 281 | layer names 282 | 283 | transition_strides: a list of tuples for the strides of each transition 284 | 285 | transition_dilation_rates: a list of tuples for the dilation rate of each 286 | 287 | transition 288 | 289 | """ 290 | 291 | if transition_dilation_rates is None: 292 | 293 | transition_dilation_rates = [(1, 1)] * blocks 294 | 295 | if transition_strides is None: 296 | 297 | transition_strides = [(1, 1)] * blocks 298 | 299 | if dilation_rates is None: 300 | 301 | dilation_rates = [1] * blocks 302 | 303 | 304 | 305 | def f(x): 306 | 307 | for i in range(blocks): 308 | 309 | is_first_block = is_first_layer and i == 0 310 | 311 | x = block_function(filters=filters, stage=stage, block=i, 312 | 313 | transition_strides=transition_strides[i], 314 | 315 | dilation_rate=dilation_rates[i], 316 | 317 | is_first_block_of_first_layer=is_first_block, 318 | 319 | dropout=dropout, 320 | 321 | residual_unit=residual_unit)(x) 322 | 323 | return x 324 | 325 | 326 | 327 | return f 328 | 329 | 330 | 331 | 332 | 333 | def _block_name_base(stage, block): 334 | 335 | """Get the convolution name base and batch normalization name base defined by 336 | 337 | stage and block. 338 | 339 | 340 | 341 | If there are less than 26 blocks they will be labeled 'a', 'b', 'c' to match the 342 | 343 | paper and keras and beyond 26 blocks they will simply be numbered. 344 | 345 | """ 346 | 347 | if block < 27: 348 | 349 | block = '%c' % (block + 97) # 97 is the ascii number for lowercase 'a' 350 | 351 | conv_name_base = 'res' + str(stage) + block + '_branch' 352 | 353 | bn_name_base = 'bn' + str(stage) + block + '_branch' 354 | 355 | return conv_name_base, bn_name_base 356 | 357 | 358 | 359 | 360 | 361 | def basic_block(filters, stage, block, transition_strides=(1, 1), 362 | 363 | dilation_rate=(1, 1), is_first_block_of_first_layer=False, dropout=None, 364 | 365 | residual_unit=_bn_relu_conv): 366 | 367 | """Basic 3 X 3 convolution blocks for use on resnets with layers <= 34. 
368 | 369 | Follows improved proposed scheme in http://arxiv.org/pdf/1603.05027v2.pdf 370 | 371 | """ 372 | 373 | def f(input_features): 374 | 375 | conv_name_base, bn_name_base = _block_name_base(stage, block) 376 | 377 | if is_first_block_of_first_layer: 378 | 379 | # don't repeat bn->relu since we just did bn->relu->maxpool 380 | 381 | x = Conv2D(filters=filters, kernel_size=(3, 3), 382 | 383 | strides=transition_strides, 384 | 385 | dilation_rate=dilation_rate, 386 | 387 | padding="same", 388 | 389 | kernel_initializer="he_normal", 390 | 391 | kernel_regularizer=l2(1e-4), 392 | 393 | name=conv_name_base + '2a')(input_features) 394 | 395 | else: 396 | 397 | x = residual_unit(filters=filters, kernel_size=(3, 3), 398 | 399 | strides=transition_strides, 400 | 401 | dilation_rate=dilation_rate, 402 | 403 | conv_name_base=conv_name_base + '2a', 404 | 405 | bn_name_base=bn_name_base + '2a')(input_features) 406 | 407 | 408 | 409 | if dropout is not None: 410 | 411 | x = Dropout(dropout)(x) 412 | 413 | 414 | 415 | x = residual_unit(filters=filters, kernel_size=(3, 3), 416 | 417 | conv_name_base=conv_name_base + '2b', 418 | 419 | bn_name_base=bn_name_base + '2b')(x) 420 | 421 | 422 | 423 | return _shortcut(input_features, x) 424 | 425 | 426 | 427 | return f 428 | 429 | 430 | 431 | 432 | 433 | def bottleneck(filters, stage, block, transition_strides=(1, 1), 434 | 435 | dilation_rate=(1, 1), is_first_block_of_first_layer=False, dropout=None, 436 | 437 | residual_unit=_bn_relu_conv): 438 | 439 | """Bottleneck architecture for > 34 layer resnet. 440 | 441 | Follows improved proposed scheme in http://arxiv.org/pdf/1603.05027v2.pdf 442 | 443 | 444 | 445 | Returns: 446 | 447 | A final conv layer of filters * 4 448 | 449 | """ 450 | 451 | def f(input_feature): 452 | 453 | conv_name_base, bn_name_base = _block_name_base(stage, block) 454 | 455 | if is_first_block_of_first_layer: 456 | 457 | # don't repeat bn->relu since we just did bn->relu->maxpool 458 | 459 | x = Conv2D(filters=filters, kernel_size=(1, 1), 460 | 461 | strides=transition_strides, 462 | 463 | dilation_rate=dilation_rate, 464 | 465 | padding="same", 466 | 467 | kernel_initializer="he_normal", 468 | 469 | kernel_regularizer=l2(1e-4), 470 | 471 | name=conv_name_base + '2a')(input_feature) 472 | 473 | else: 474 | 475 | x = residual_unit(filters=filters, kernel_size=(1, 1), 476 | 477 | strides=transition_strides, 478 | 479 | dilation_rate=dilation_rate, 480 | 481 | conv_name_base=conv_name_base + '2a', 482 | 483 | bn_name_base=bn_name_base + '2a')(input_feature) 484 | 485 | 486 | 487 | if dropout is not None: 488 | 489 | x = Dropout(dropout)(x) 490 | 491 | 492 | 493 | x = residual_unit(filters=filters, kernel_size=(3, 3), 494 | 495 | conv_name_base=conv_name_base + '2b', 496 | 497 | bn_name_base=bn_name_base + '2b')(x) 498 | 499 | 500 | 501 | if dropout is not None: 502 | 503 | x = Dropout(dropout)(x) 504 | 505 | 506 | 507 | x = residual_unit(filters=filters * 4, kernel_size=(1, 1), 508 | 509 | conv_name_base=conv_name_base + '2c', 510 | 511 | bn_name_base=bn_name_base + '2c')(x) 512 | 513 | 514 | 515 | return _shortcut(input_feature, x) 516 | 517 | 518 | 519 | return f 520 | 521 | 522 | 523 | 524 | 525 | def _handle_dim_ordering(): 526 | 527 | global ROW_AXIS 528 | 529 | global COL_AXIS 530 | 531 | global CHANNEL_AXIS 532 | 533 | if K.image_data_format() == 'channels_last': 534 | 535 | ROW_AXIS = 1 536 | 537 | COL_AXIS = 2 538 | 539 | CHANNEL_AXIS = 3 540 | 541 | else: 542 | 543 | CHANNEL_AXIS = 1 544 | 545 | ROW_AXIS = 2 546 | 547 | 
COL_AXIS = 3 548 | 549 | 550 | 551 | 552 | 553 | def _string_to_function(identifier): 554 | 555 | if isinstance(identifier, six.string_types): 556 | 557 | res = globals().get(identifier) 558 | 559 | if not res: 560 | 561 | raise ValueError('Invalid {}'.format(identifier)) 562 | 563 | return res 564 | 565 | return identifier 566 | 567 | 568 | 569 | 570 | 571 | def ResNet(input_shape=None, classes=10, block='bottleneck', residual_unit='v2', 572 | 573 | repetitions=None, initial_filters=64, activation='softmax', include_top=True, 574 | 575 | input_tensor=None, dropout=None, transition_dilation_rate=(1, 1), 576 | 577 | initial_strides=(2, 2), initial_kernel_size=(7, 7), initial_pooling='max', 578 | 579 | final_pooling=None, top='classification'): 580 | 581 | """Builds a custom ResNet like architecture. Defaults to ResNet50 v2. 582 | 583 | 584 | 585 | Args: 586 | 587 | input_shape: optional shape tuple, only to be specified 588 | 589 | if `include_top` is False (otherwise the input shape 590 | 591 | has to be `(224, 224, 3)` (with `channels_last` dim ordering) 592 | 593 | or `(3, 224, 224)` (with `channels_first` dim ordering). 594 | 595 | It should have exactly 3 dimensions, 596 | 597 | and width and height should be no smaller than 8. 598 | 599 | E.g. `(224, 224, 3)` would be one valid value. 600 | 601 | classes: The number of outputs at final softmax layer 602 | 603 | block: The block function to use. This is either `'basic'` or `'bottleneck'`. 604 | 605 | The original paper used `basic` for layers < 50. 606 | 607 | repetitions: Number of repetitions of various block units. 608 | 609 | At each block unit, the number of filters are doubled and the input size 610 | 611 | is halved. Default of None implies the ResNet50v2 values of [3, 4, 6, 3]. 612 | 613 | residual_unit: the basic residual unit, 'v1' for conv bn relu, 'v2' for bn relu 614 | 615 | conv. See [Identity Mappings in 616 | 617 | Deep Residual Networks](https://arxiv.org/abs/1603.05027) 618 | 619 | for details. 620 | 621 | dropout: None for no dropout, otherwise rate of dropout from 0 to 1. 622 | 623 | Based on [Wide Residual Networks.(https://arxiv.org/pdf/1605.07146) paper. 624 | 625 | transition_dilation_rate: Dilation rate for transition layers. For semantic 626 | 627 | segmentation of images use a dilation rate of (2, 2). 628 | 629 | initial_strides: Stride of the very first residual unit and MaxPooling2D call, 630 | 631 | with default (2, 2), set to (1, 1) for small images like cifar. 632 | 633 | initial_kernel_size: kernel size of the very first convolution, (7, 7) for 634 | 635 | imagenet and (3, 3) for small image datasets like tiny imagenet and cifar. 636 | 637 | See [ResNeXt](https://arxiv.org/abs/1611.05431) paper for details. 638 | 639 | initial_pooling: Determine if there will be an initial pooling layer, 640 | 641 | 'max' for imagenet and None for small image datasets. 642 | 643 | See [ResNeXt](https://arxiv.org/abs/1611.05431) paper for details. 644 | 645 | final_pooling: Optional pooling mode for feature extraction at the final 646 | 647 | model layer when `include_top` is `False`. 648 | 649 | - `None` means that the output of the model 650 | 651 | will be the 4D tensor output of the 652 | 653 | last convolutional layer. 654 | 655 | - `avg` means that global average pooling 656 | 657 | will be applied to the output of the 658 | 659 | last convolutional layer, and thus 660 | 661 | the output of the model will be a 662 | 663 | 2D tensor. 664 | 665 | - `max` means that global max pooling will 666 | 667 | be applied. 
668 | 669 | top: Defines final layers to evaluate based on a specific problem type. Options 670 | 671 | are 'classification' for ImageNet style problems, 'segmentation' for 672 | 673 | problems like the Pascal VOC dataset, and None to exclude these layers 674 | 675 | entirely. 676 | 677 | 678 | 679 | Returns: 680 | 681 | The keras `Model`. 682 | 683 | """ 684 | 685 | if activation not in ['softmax', 'sigmoid', None]: 686 | 687 | raise ValueError('activation must be one of "softmax", "sigmoid", or None') 688 | 689 | if activation == 'sigmoid' and classes != 1: 690 | 691 | raise ValueError('sigmoid activation can only be used when classes = 1') 692 | 693 | if repetitions is None: 694 | 695 | repetitions = [3, 4, 6, 3] 696 | 697 | # Determine proper input shape 698 | 699 | input_shape = _obtain_input_shape(input_shape, 700 | 701 | default_size=32, 702 | 703 | min_size=8, 704 | 705 | data_format=K.image_data_format(), 706 | 707 | require_flatten=include_top) 708 | 709 | _handle_dim_ordering() 710 | 711 | if len(input_shape) != 3: 712 | 713 | raise Exception("Input shape should be a tuple (nb_channels, nb_rows, nb_cols)") 714 | 715 | 716 | 717 | if block == 'basic': 718 | 719 | block_fn = basic_block 720 | 721 | elif block == 'bottleneck': 722 | 723 | block_fn = bottleneck 724 | 725 | elif isinstance(block, six.string_types): 726 | 727 | block_fn = _string_to_function(block) 728 | 729 | else: 730 | 731 | block_fn = block 732 | 733 | 734 | 735 | if residual_unit == 'v2': 736 | 737 | residual_unit = _bn_relu_conv 738 | 739 | elif residual_unit == 'v1': 740 | 741 | residual_unit = _conv_bn_relu 742 | 743 | elif isinstance(residual_unit, six.string_types): 744 | 745 | residual_unit = _string_to_function(residual_unit) 746 | 747 | else: 748 | 749 | residual_unit = residual_unit 750 | 751 | 752 | 753 | # Permute dimension order if necessary 754 | 755 | if K.image_data_format() == 'channels_first': 756 | 757 | input_shape = (input_shape[1], input_shape[2], input_shape[0]) 758 | 759 | # Determine proper input shape 760 | 761 | input_shape = _obtain_input_shape(input_shape, 762 | 763 | default_size=32, 764 | 765 | min_size=8, 766 | 767 | data_format=K.image_data_format(), 768 | 769 | require_flatten=include_top) 770 | 771 | 772 | 773 | img_input = Input(shape=input_shape, tensor=input_tensor) 774 | 775 | x = _conv_bn_relu(filters=initial_filters, kernel_size=initial_kernel_size, 776 | 777 | strides=initial_strides)(img_input) 778 | 779 | if initial_pooling == 'max': 780 | 781 | x = MaxPooling2D(pool_size=(3, 3), strides=initial_strides, padding="same")(x) 782 | 783 | 784 | 785 | block = x 786 | 787 | filters = initial_filters 788 | 789 | for i, r in enumerate(repetitions): 790 | 791 | transition_dilation_rates = [transition_dilation_rate] * r 792 | 793 | transition_strides = [(1, 1)] * r 794 | 795 | if transition_dilation_rate == (1, 1): 796 | 797 | transition_strides[0] = (2, 2) 798 | 799 | block = _residual_block(block_fn, filters=filters, 800 | 801 | stage=i, blocks=r, 802 | 803 | is_first_layer=(i == 0), 804 | 805 | dropout=dropout, 806 | 807 | transition_dilation_rates=transition_dilation_rates, 808 | 809 | transition_strides=transition_strides, 810 | 811 | residual_unit=residual_unit)(block) 812 | 813 | filters *= 2 814 | 815 | 816 | 817 | # Last activation 818 | 819 | inter = _bn_relu(block) 820 | 821 | 822 | 823 | # Classifier block 824 | 825 | if include_top and top is 'classification': 826 | 827 | x = GlobalAveragePooling2D()(inter) 828 | 829 | x = Dense(units=classes, activation=activation, 
830 | 831 | kernel_initializer="he_normal")(x) 832 | 833 | elif include_top and top is 'segmentation': 834 | 835 | x = Conv2D(classes, (1, 1), activation='linear', padding='same')(inter) 836 | 837 | 838 | 839 | if K.image_data_format() == 'channels_first': 840 | 841 | channel, row, col = input_shape 842 | 843 | else: 844 | 845 | row, col, channel = input_shape 846 | 847 | 848 | 849 | x = Reshape((row * col, classes))(x) 850 | 851 | x = Activation(activation)(x) 852 | 853 | x = Reshape((row, col, classes))(x) 854 | 855 | elif final_pooling == 'avg': 856 | 857 | x = GlobalAveragePooling2D()(inter) 858 | 859 | elif final_pooling == 'max': 860 | 861 | x = GlobalMaxPooling2D()(inter) 862 | 863 | 864 | 865 | model = Model(inputs=img_input, outputs=x) 866 | 867 | return model 868 | 869 | 870 | 871 | 872 | 873 | def ResNet18(input_shape, classes): 874 | 875 | """ResNet with 18 layers and v2 residual units 876 | 877 | """ 878 | 879 | return ResNet(input_shape, classes, basic_block, repetitions=[2, 2, 2, 2]) 880 | 881 | 882 | 883 | 884 | 885 | def ResNet34(input_shape, classes): 886 | 887 | """ResNet with 34 layers and v2 residual units 888 | 889 | """ 890 | 891 | return ResNet(input_shape, classes, basic_block, repetitions=[3, 4, 6, 3]) 892 | 893 | 894 | 895 | 896 | 897 | def ResNet50(input_shape, classes): 898 | 899 | """ResNet with 50 layers and v2 residual units 900 | 901 | """ 902 | 903 | return ResNet(input_shape, classes, bottleneck, repetitions=[3, 4, 6, 3]) 904 | 905 | 906 | 907 | 908 | 909 | def ResNet101(input_shape, classes): 910 | 911 | """ResNet with 101 layers and v2 residual units 912 | 913 | """ 914 | 915 | return ResNet(input_shape, classes, bottleneck, repetitions=[3, 4, 23, 3]) 916 | 917 | 918 | 919 | 920 | 921 | def ResNet152(input_shape, classes): 922 | 923 | """ResNet with 152 layers and v2 residual units 924 | 925 | """ 926 | 927 | return ResNet(input_shape, classes, bottleneck, repetitions=[3, 8, 36, 3]) -------------------------------------------------------------------------------- /Classifier/VGG.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from __future__ import division 4 | 5 | from __future__ import print_function 6 | 7 | import os 8 | from keras import layers 9 | from keras import models 10 | from keras.models import * 11 | from keras.layers import * 12 | 13 | 14 | def VGG16(): 15 | 16 | model = Sequential() 17 | 18 | model.add(ZeroPadding2D((1,1),input_shape=(544,352,1))) #here to change input shape 19 | 20 | model.add(Convolution2D(2, 3, 3, activation='relu',kernel_initializer="he_normal")) 21 | 22 | model.add(ZeroPadding2D((1,1))) 23 | 24 | model.add(Convolution2D(2, 3, 3, activation='relu',kernel_initializer="he_normal")) 25 | 26 | model.add(MaxPooling2D((2,2), strides=(2,2))) 27 | 28 | 29 | 30 | model.add(ZeroPadding2D((1,1))) 31 | 32 | model.add(Convolution2D(4, 3, 3, activation='relu',kernel_initializer="he_normal")) 33 | 34 | model.add(ZeroPadding2D((1,1))) 35 | 36 | model.add(Convolution2D(4, 3, 3, activation='relu',kernel_initializer="he_normal")) 37 | 38 | model.add(MaxPooling2D((2,2), strides=(2,2))) 39 | 40 | 41 | 42 | model.add(ZeroPadding2D((1,1))) 43 | 44 | model.add(Convolution2D(8, 3, 3, activation='relu',kernel_initializer="he_normal")) 45 | 46 | model.add(ZeroPadding2D((1,1))) 47 | 48 | model.add(Convolution2D(8, 3, 3, activation='relu',kernel_initializer="he_normal")) 49 | 50 | model.add(ZeroPadding2D((1,1))) 51 | 52 | model.add(Convolution2D(8, 3, 3, 
activation='relu',kernel_initializer="he_normal")) 53 | 54 | model.add(MaxPooling2D((2,2), strides=(2,2))) 55 | 56 | 57 | 58 | model.add(ZeroPadding2D((1,1))) 59 | 60 | model.add(Convolution2D(16, 3, 3, activation='relu',kernel_initializer="he_normal")) 61 | 62 | model.add(ZeroPadding2D((1,1))) 63 | 64 | model.add(Convolution2D(16, 3, 3, activation='relu',kernel_initializer="he_normal")) 65 | 66 | model.add(ZeroPadding2D((1,1))) 67 | 68 | model.add(Convolution2D(16, 3, 3, activation='relu',kernel_initializer="he_normal")) 69 | 70 | model.add(MaxPooling2D((2,2), strides=(2,2))) 71 | 72 | 73 | 74 | model.add(ZeroPadding2D((1,1))) 75 | 76 | model.add(Convolution2D(16, 3, 3, activation='relu',kernel_initializer="he_normal")) 77 | 78 | model.add(ZeroPadding2D((1,1))) 79 | 80 | model.add(Convolution2D(16, 3, 3, activation='relu',kernel_initializer="he_normal")) 81 | 82 | model.add(ZeroPadding2D((1,1))) 83 | 84 | model.add(Convolution2D(16, 3, 3, activation='relu',kernel_initializer="he_normal")) 85 | 86 | model.add(MaxPooling2D((2,2), strides=(2,2), name='inter')) 87 | 88 | 89 | 90 | model.add(Flatten()) 91 | 92 | model.add(Dense(1024, activation='relu', name='fc1')) 93 | 94 | model.add(Dropout(0.5)) 95 | 96 | model.add(Dense(256, activation='relu', name='fc2')) 97 | 98 | #model.add(Dropout(0.5)) 99 | 100 | model.add(Dense(4, activation='softmax')) 101 | 102 | 103 | return model 104 | 105 | 106 | 107 | def VGG19(input_shape=None,classes=5, use_soft=True): 108 | 109 | 110 | img_input = layers.Input(shape=input_shape) 111 | 112 | 113 | # Block 1 114 | 115 | x = layers.Conv2D(2, (3, 3), 116 | 117 | activation='relu', 118 | 119 | padding='same', 120 | 121 | name='block1_conv1',kernel_initializer="he_normal")(img_input) 122 | 123 | x = layers.Conv2D(2, (3, 3), 124 | 125 | activation='relu', 126 | 127 | padding='same', 128 | 129 | name='block1_conv2',kernel_initializer="he_normal")(x) 130 | 131 | x = layers.MaxPooling2D((2, 2), strides=(2, 2), name='block1_pool')(x) 132 | 133 | 134 | 135 | # Block 2 136 | 137 | x = layers.Conv2D(4, (3, 3), 138 | 139 | activation='relu', 140 | 141 | padding='same', 142 | 143 | name='block2_conv1',kernel_initializer="he_normal")(x) 144 | 145 | x = layers.Conv2D(4, (3, 3), 146 | 147 | activation='relu', 148 | 149 | padding='same', 150 | 151 | name='block2_conv2',kernel_initializer="he_normal")(x) 152 | 153 | x = layers.MaxPooling2D((2, 2), strides=(2, 2), name='block2_pool')(x) 154 | 155 | 156 | 157 | # Block 3 158 | 159 | x = layers.Conv2D(8, (3, 3), 160 | 161 | activation='relu', 162 | 163 | padding='same', 164 | 165 | name='block3_conv1',kernel_initializer="he_normal")(x) 166 | 167 | x = layers.Conv2D(8, (3, 3), 168 | 169 | activation='relu', 170 | 171 | padding='same', 172 | 173 | name='block3_conv2',kernel_initializer="he_normal")(x) 174 | 175 | x = layers.Conv2D(8, (3, 3), 176 | 177 | activation='relu', 178 | 179 | padding='same', 180 | 181 | name='block3_conv3',kernel_initializer="he_normal")(x) 182 | x = layers.Conv2D(8, (3, 3), 183 | 184 | activation='relu', 185 | 186 | padding='same', 187 | 188 | name='block3_conv4',kernel_initializer="he_normal")(x) 189 | 190 | x = layers.MaxPooling2D((2, 2), strides=(2, 2), name='block3_pool')(x) 191 | 192 | 193 | 194 | # Block 4 195 | 196 | x = layers.Conv2D(16, (3, 3), 197 | 198 | activation='relu', 199 | 200 | padding='same', 201 | 202 | name='block4_conv1',kernel_initializer="he_normal")(x) 203 | 204 | x = layers.Conv2D(16, (3, 3), 205 | 206 | activation='relu', 207 | 208 | padding='same', 209 | 210 | 
name='block4_conv2',kernel_initializer="he_normal")(x) 211 | 212 | x = layers.Conv2D(16, (3,3), 213 | 214 | activation='relu', 215 | 216 | padding='same', 217 | 218 | name='block4_conv3',kernel_initializer="he_normal")(x) 219 | x = layers.Conv2D(16, (3,3), 220 | 221 | activation='relu', 222 | 223 | padding='same', 224 | 225 | name='block4_conv4',kernel_initializer="he_normal")(x) 226 | 227 | x = layers.MaxPooling2D((2, 2), strides=(2, 2), name='block4_pool')(x) 228 | 229 | 230 | 231 | # Block 5 232 | 233 | x = layers.Conv2D(16, (3, 3), 234 | 235 | activation='relu', 236 | 237 | padding='same', 238 | 239 | name='block5_conv1',kernel_initializer="he_normal")(x) 240 | 241 | x = layers.Conv2D(16, (3, 3), 242 | 243 | activation='relu', 244 | 245 | padding='same', 246 | 247 | name='block5_conv2',kernel_initializer="he_normal")(x) 248 | 249 | x = layers.Conv2D(16, (3, 3), 250 | 251 | activation='relu', 252 | 253 | padding='same', 254 | 255 | name='block5_conv3',kernel_initializer="he_normal")(x) 256 | x = layers.Conv2D(16, (3, 3), 257 | 258 | activation='relu', 259 | 260 | padding='same', 261 | 262 | name='block5_conv4',kernel_initializer="he_normal")(x) 263 | 264 | x = layers.MaxPooling2D((2, 2), strides=(2,2), name='block5_pool')(x) 265 | 266 | 267 | # Classification block 268 | 269 | x = layers.Flatten(name='flatten')(x) 270 | 271 | x = layers.Dense(512, activation='relu', name='fc1')(x) 272 | x=layers.Dropout(0.5)(x) 273 | 274 | x = layers.Dense(128, activation='relu', name='fc2')(x) 275 | #x=layers.Dropout(0.8)(x) 276 | 277 | if use_soft: 278 | x = Dense(classes, activation = "softmax", name='predictions')(x) 279 | else: 280 | x = Dense(classes, activation = "linear", name = "Z_4")(x) 281 | 282 | model = models.Model(img_input, x, name='vgg19') 283 | 284 | return model 285 | 286 | def VGG19_dense(input_shape=None,classes=4): 287 | 288 | 289 | img_input = layers.Input(shape=input_shape) 290 | 291 | 292 | # Block 1 293 | 294 | x = layers.Conv2D(2, (3, 3), 295 | 296 | activation='relu', 297 | 298 | padding='same', 299 | 300 | name='block1_conv1',kernel_initializer="he_normal")(img_input) 301 | 302 | x = layers.Conv2D(2, (3, 3), 303 | 304 | activation='relu', 305 | 306 | padding='same', 307 | 308 | name='block1_conv2',kernel_initializer="he_normal")(x) 309 | 310 | x = layers.MaxPooling2D((2, 2), strides=(2, 2), name='block1_pool')(x) 311 | 312 | 313 | 314 | # Block 2 315 | 316 | x = layers.Conv2D(4, (3, 3), 317 | 318 | activation='relu', 319 | 320 | padding='same', 321 | 322 | name='block2_conv1',kernel_initializer="he_normal")(x) 323 | 324 | x = layers.Conv2D(4, (3, 3), 325 | 326 | activation='relu', 327 | 328 | padding='same', 329 | 330 | name='block2_conv2',kernel_initializer="he_normal")(x) 331 | 332 | x = layers.MaxPooling2D((2, 2), strides=(2, 2), name='block2_pool')(x) 333 | 334 | 335 | 336 | # Block 3 337 | 338 | x = layers.Conv2D(8, (3, 3), 339 | 340 | activation='relu', 341 | 342 | padding='same', 343 | 344 | name='block3_conv1',kernel_initializer="he_normal")(x) 345 | 346 | x = layers.Conv2D(8, (3, 3), 347 | 348 | activation='relu', 349 | 350 | padding='same', 351 | 352 | name='block3_conv2',kernel_initializer="he_normal")(x) 353 | 354 | x = layers.Conv2D(8, (3, 3), 355 | 356 | activation='relu', 357 | 358 | padding='same', 359 | 360 | name='block3_conv3',kernel_initializer="he_normal")(x) 361 | x = layers.Conv2D(8, (3, 3), 362 | 363 | activation='relu', 364 | 365 | padding='same', 366 | 367 | name='block3_conv4',kernel_initializer="he_normal")(x) 368 | 369 | x = 
layers.MaxPooling2D((2, 2), strides=(2, 2), name='block3_pool')(x)
370 | 
371 | 
372 | 
373 | # Block 4
374 | 
375 | x = layers.Conv2D(16, (3, 3),
376 | 
377 | activation='relu',
378 | 
379 | padding='same',
380 | 
381 | name='block4_conv1',kernel_initializer="he_normal")(x)
382 | 
383 | x = layers.Conv2D(16, (3, 3),
384 | 
385 | activation='relu',
386 | 
387 | padding='same',
388 | 
389 | name='block4_conv2',kernel_initializer="he_normal")(x)
390 | 
391 | x = layers.Conv2D(16, (3, 3),
392 | 
393 | activation='relu',
394 | 
395 | padding='same',
396 | 
397 | name='block4_conv3',kernel_initializer="he_normal")(x)
398 | x = layers.Conv2D(16, (3, 3),
399 | 
400 | activation='relu',
401 | 
402 | padding='same',
403 | 
404 | name='block4_conv4',kernel_initializer="he_normal")(x)
405 | 
406 | x = layers.MaxPooling2D((2, 2), strides=(2, 2), name='block4_pool')(x)
407 | 
408 | 
409 | 
410 | # Block 5
411 | 
412 | x = layers.Conv2D(16, (3, 3),
413 | 
414 | activation='relu',
415 | 
416 | padding='same',
417 | 
418 | name='block5_conv1',kernel_initializer="he_normal")(x)
419 | 
420 | x = layers.Conv2D(16, (3, 3),
421 | 
422 | activation='relu',
423 | 
424 | padding='same',
425 | 
426 | name='block5_conv2',kernel_initializer="he_normal")(x)
427 | 
428 | x = layers.Conv2D(16, (3, 3),
429 | 
430 | activation='relu',
431 | 
432 | padding='same',
433 | 
434 | name='block5_conv3',kernel_initializer="he_normal")(x)
435 | x = layers.Conv2D(16, (3, 3),
436 | 
437 | activation='relu',
438 | 
439 | padding='same',
440 | 
441 | name='block5_conv4',kernel_initializer="he_normal")(x)
442 | 
443 | x = layers.MaxPooling2D((2, 2), strides=(2, 2), name='block5_pool')(x)
444 | 
445 | 
446 | x = layers.Conv2D(32, (3, 3), activation='relu', name="feature")(x)
447 | x = layers.Conv2D(32, (1, 1), activation='relu')(x)
448 | x = layers.Conv2D(classes, (1, 1), activation='softmax')(x)
449 | x = layers.AveragePooling2D((9, 9))(x)
450 | 
451 | model = models.Model(img_input, x, name='vgg19')
452 | model.summary()
453 | 
454 | 
455 | 
456 | return model
457 | 
458 | 
-------------------------------------------------------------------------------- /Classifier/integration/model ensembel.py: --------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import numpy as np
3 | import time
4 | import csv
5 | import os
6 | from sklearn.metrics import classification_report
7 | import sklearn.metrics as sklm
8 | import math
9 | from scipy import stats
10 | 
11 | def dense_to_one_hot(labels_dense, num_classes=5):
12 | return np.eye(num_classes)[labels_dense]
13 | 
14 | def load_val(): # single predicted class (or the average of two predictions via the commented line)
15 | labelList=[]
16 | reader = open("/data/jiao/singlefront.csv") # prediction file path
17 | data=reader.readlines()
18 | label="q"
19 | for row in data:
20 | label=int(row.split(",")[0]) # predicted class in the first column
21 | #label=math.ceil((int(row.split(",")[0])+int(row.split(",")[2]))/2)
22 | labelList.append(label)
23 | return np.array(labelList)
24 | 
25 | def load_v(): # prediction maximization: majority vote over the per-model predicted classes
26 | labelList=[]
27 | imlist=[]
28 | reader = open("/data/jiao/DenseNet-cls.csv") # prediction file path
29 | data=reader.readlines()
30 | label="q"
31 | for row in data:
32 | for i in range(4):
33 | imlist.append(int(row.split(",")[i]))
34 | label=np.bincount(imlist)
35 | imlist=[]
36 | labelList.append(np.argmax(label))
37 | return np.array(labelList)
38 | 
39 | 
40 | def load_va(): # average Softmax probabilities
41 | imlist=[]
42 | labelList=[]
43 | reader = open("/data/jiao/DenseNet-posfront.csv") # prediction file path
44 | data=reader.readlines()
45 | 
label="q" 46 | for row in data: 47 | for i in range(4): 48 | imlist.append(float(row.split(",")[i+12])) 49 | label=imlist.index(max(imlist)) 50 | imlist=[] 51 | labelList.append(label) 52 | return np.array(labelList) 53 | 54 | def load_valY(): 55 | imgList=[] 56 | labelList=[] 57 | reader = open("/data/jiao/newlabel.csv") #original label file path 58 | data=reader.readlines() 59 | files = os.listdir('/data/jiao/XR/front/validation/')#test set image path 60 | for file in files: 61 | if file.endswith(".xml"):continue 62 | patient=file.split('_')[0] 63 | direction=file.split('_')[1].split('.')[0] 64 | label="q" 65 | for row in data: 66 | if patient in row.split(",")[0]: 67 | if "L" in direction: 68 | label=row.split(",")[3] 69 | else: 70 | label=row.split(",")[6] 71 | break 72 | if "V" in file: 73 | label="3" 74 | if "8" not in label and "9" not in label and "X" not in label and '.' not in label: 75 | #if "." in label: 76 | #label='4' 77 | labelList.append(int(label)) 78 | return np.array(labelList) 79 | 80 | print(sklm.accuracy_score(load_valY(),load_va())) 81 | print(sklm.classification_report(load_valY(), load_va())) 82 | print(sklm.confusion_matrix(load_valY(), load_va())) 83 | print(sklm.mean_squared_error(load_valY(), load_va())) 84 | print(sklm.mean_absolute_error(load_valY(), load_va())) -------------------------------------------------------------------------------- /Example_images/01E392EE-69F9-4E33-BFCE-E5C968654078-1920x1431.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rezacsedu/DeepCOVIDExplainer/5d77c295c9d8db678d8481532b7f01336b696777/Example_images/01E392EE-69F9-4E33-BFCE-E5C968654078-1920x1431.jpeg -------------------------------------------------------------------------------- /Example_images/0cc8ac00ad4d2c3e8ea3ed4a9c776a_jumbo.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rezacsedu/DeepCOVIDExplainer/5d77c295c9d8db678d8481532b7f01336b696777/Example_images/0cc8ac00ad4d2c3e8ea3ed4a9c776a_jumbo.jpeg -------------------------------------------------------------------------------- /Example_images/11368d1bfb309b273d60a7138dae35_jumbo.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rezacsedu/DeepCOVIDExplainer/5d77c295c9d8db678d8481532b7f01336b696777/Example_images/11368d1bfb309b273d60a7138dae35_jumbo.jpeg -------------------------------------------------------------------------------- /Example_images/1312A392-67A3-4EBF-9319-810CF6DA5EF6.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rezacsedu/DeepCOVIDExplainer/5d77c295c9d8db678d8481532b7f01336b696777/Example_images/1312A392-67A3-4EBF-9319-810CF6DA5EF6.jpeg -------------------------------------------------------------------------------- /Example_images/1B734A89-A1BF-49A8-A1D3-66FAFA4FAC5D.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rezacsedu/DeepCOVIDExplainer/5d77c295c9d8db678d8481532b7f01336b696777/Example_images/1B734A89-A1BF-49A8-A1D3-66FAFA4FAC5D.jpeg -------------------------------------------------------------------------------- /Example_images/1c13301604cbe667b39ca3fe335501_jumbo.jpeg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/rezacsedu/DeepCOVIDExplainer/5d77c295c9d8db678d8481532b7f01336b696777/Example_images/1c13301604cbe667b39ca3fe335501_jumbo.jpeg -------------------------------------------------------------------------------- /Example_images/201b87f9902cce6732917d2f292bd3_jumbo.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rezacsedu/DeepCOVIDExplainer/5d77c295c9d8db678d8481532b7f01336b696777/Example_images/201b87f9902cce6732917d2f292bd3_jumbo.jpeg -------------------------------------------------------------------------------- /Example_images/23E99E2E-447C-46E5-8EB2-D35D12473C39-1920x1440.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rezacsedu/DeepCOVIDExplainer/5d77c295c9d8db678d8481532b7f01336b696777/Example_images/23E99E2E-447C-46E5-8EB2-D35D12473C39-1920x1440.png -------------------------------------------------------------------------------- /Example_images/2966893D-5DDF-4B68-9E2B-4979D5956C8E.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rezacsedu/DeepCOVIDExplainer/5d77c295c9d8db678d8481532b7f01336b696777/Example_images/2966893D-5DDF-4B68-9E2B-4979D5956C8E.jpeg -------------------------------------------------------------------------------- /Example_images/2B8649B2-00C4-4233-85D5-1CE240CF233B.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rezacsedu/DeepCOVIDExplainer/5d77c295c9d8db678d8481532b7f01336b696777/Example_images/2B8649B2-00C4-4233-85D5-1CE240CF233B.jpeg -------------------------------------------------------------------------------- /Example_images/2C26F453-AF3B-4517-BB9E-802CF2179543.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rezacsedu/DeepCOVIDExplainer/5d77c295c9d8db678d8481532b7f01336b696777/Example_images/2C26F453-AF3B-4517-BB9E-802CF2179543.jpeg -------------------------------------------------------------------------------- /Example_images/32a46f77ff2a5acc2168b20b974cf8_jumbo.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rezacsedu/DeepCOVIDExplainer/5d77c295c9d8db678d8481532b7f01336b696777/Example_images/32a46f77ff2a5acc2168b20b974cf8_jumbo.jpeg -------------------------------------------------------------------------------- /Example_images/39EE8E69-5801-48DE-B6E3-BE7D1BCF3092.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rezacsedu/DeepCOVIDExplainer/5d77c295c9d8db678d8481532b7f01336b696777/Example_images/39EE8E69-5801-48DE-B6E3-BE7D1BCF3092.jpeg -------------------------------------------------------------------------------- /Example_images/4d844df58f10acb17fc50351fd9440_jumbo.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rezacsedu/DeepCOVIDExplainer/5d77c295c9d8db678d8481532b7f01336b696777/Example_images/4d844df58f10acb17fc50351fd9440_jumbo.jpeg -------------------------------------------------------------------------------- /Example_images/58cb9263f16e94305c730685358e4e_jumbo.jpeg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/rezacsedu/DeepCOVIDExplainer/5d77c295c9d8db678d8481532b7f01336b696777/Example_images/58cb9263f16e94305c730685358e4e_jumbo.jpeg -------------------------------------------------------------------------------- /Example_images/5A78BCA9-5B7A-440D-8A4E-AE7710EA6EAD-2048x1661.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rezacsedu/DeepCOVIDExplainer/5d77c295c9d8db678d8481532b7f01336b696777/Example_images/5A78BCA9-5B7A-440D-8A4E-AE7710EA6EAD-2048x1661.jpeg -------------------------------------------------------------------------------- /Example_images/5CBC2E94-D358-401E-8928-965CCD965C5C-1920x1581.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rezacsedu/DeepCOVIDExplainer/5d77c295c9d8db678d8481532b7f01336b696777/Example_images/5CBC2E94-D358-401E-8928-965CCD965C5C-1920x1581.jpeg -------------------------------------------------------------------------------- /Example_images/5CBC2E94-D358-401E-8928-965CCD965C5C-2048x1687.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rezacsedu/DeepCOVIDExplainer/5d77c295c9d8db678d8481532b7f01336b696777/Example_images/5CBC2E94-D358-401E-8928-965CCD965C5C-2048x1687.jpeg -------------------------------------------------------------------------------- /Example_images/61c6828be4bb24b9e29e6ebfcfec0a_jumbo.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rezacsedu/DeepCOVIDExplainer/5d77c295c9d8db678d8481532b7f01336b696777/Example_images/61c6828be4bb24b9e29e6ebfcfec0a_jumbo.jpeg -------------------------------------------------------------------------------- /Example_images/6CB4EFC6-68FA-4CD5-940C-BEFA8DAFE9A7-1920x1239.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rezacsedu/DeepCOVIDExplainer/5d77c295c9d8db678d8481532b7f01336b696777/Example_images/6CB4EFC6-68FA-4CD5-940C-BEFA8DAFE9A7-1920x1239.jpeg -------------------------------------------------------------------------------- /Example_images/7848bf2d6be7318bf1457253990d25_jumbo.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rezacsedu/DeepCOVIDExplainer/5d77c295c9d8db678d8481532b7f01336b696777/Example_images/7848bf2d6be7318bf1457253990d25_jumbo.jpeg -------------------------------------------------------------------------------- /Example_images/7AF6C1AF-D249-4BD2-8C26-449304105D03.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rezacsedu/DeepCOVIDExplainer/5d77c295c9d8db678d8481532b7f01336b696777/Example_images/7AF6C1AF-D249-4BD2-8C26-449304105D03.jpeg -------------------------------------------------------------------------------- /Example_images/7E335538-2F86-424E-A0AB-6397783A38D0-1536x1246.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rezacsedu/DeepCOVIDExplainer/5d77c295c9d8db678d8481532b7f01336b696777/Example_images/7E335538-2F86-424E-A0AB-6397783A38D0-1536x1246.jpeg -------------------------------------------------------------------------------- /Example_images/7E335538-2F86-424E-A0AB-6397783A38D0-1920x1558.jpeg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/rezacsedu/DeepCOVIDExplainer/5d77c295c9d8db678d8481532b7f01336b696777/Example_images/7E335538-2F86-424E-A0AB-6397783A38D0-1920x1558.jpeg -------------------------------------------------------------------------------- /Example_images/8549249b763152e944d3ad092a2a2d_jumbo.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rezacsedu/DeepCOVIDExplainer/5d77c295c9d8db678d8481532b7f01336b696777/Example_images/8549249b763152e944d3ad092a2a2d_jumbo.jpeg -------------------------------------------------------------------------------- /Example_images/85E52EB3-56E9-4D67-82DA-DEA247C82886.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rezacsedu/DeepCOVIDExplainer/5d77c295c9d8db678d8481532b7f01336b696777/Example_images/85E52EB3-56E9-4D67-82DA-DEA247C82886.jpeg -------------------------------------------------------------------------------- /Example_images/93FE0BB1-022D-4F24-9727-987A07975FFB.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rezacsedu/DeepCOVIDExplainer/5d77c295c9d8db678d8481532b7f01336b696777/Example_images/93FE0BB1-022D-4F24-9727-987A07975FFB.jpeg -------------------------------------------------------------------------------- /Example_images/9C34AF49-E589-44D5-92D3-168B3B04E4A6.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rezacsedu/DeepCOVIDExplainer/5d77c295c9d8db678d8481532b7f01336b696777/Example_images/9C34AF49-E589-44D5-92D3-168B3B04E4A6.jpeg -------------------------------------------------------------------------------- /Example_images/9ad688b362b011bd3f7503799515ef_jumbo.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rezacsedu/DeepCOVIDExplainer/5d77c295c9d8db678d8481532b7f01336b696777/Example_images/9ad688b362b011bd3f7503799515ef_jumbo.jpeg -------------------------------------------------------------------------------- /Example_images/9f987e36c0a19aeb1f3c9151b66317_jumbo.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rezacsedu/DeepCOVIDExplainer/5d77c295c9d8db678d8481532b7f01336b696777/Example_images/9f987e36c0a19aeb1f3c9151b66317_jumbo.jpeg -------------------------------------------------------------------------------- /Example_images/9fdd3c3032296fd04d2cad5d9070d4_jumbo.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rezacsedu/DeepCOVIDExplainer/5d77c295c9d8db678d8481532b7f01336b696777/Example_images/9fdd3c3032296fd04d2cad5d9070d4_jumbo.jpeg -------------------------------------------------------------------------------- /Example_images/B2D20576-00B7-4519-A415-72DE29C90C34.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rezacsedu/DeepCOVIDExplainer/5d77c295c9d8db678d8481532b7f01336b696777/Example_images/B2D20576-00B7-4519-A415-72DE29C90C34.jpeg -------------------------------------------------------------------------------- /Example_images/B59DD164-51D5-40DF-A926-6A42DD52EBE8-1920x1472.jpeg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/rezacsedu/DeepCOVIDExplainer/5d77c295c9d8db678d8481532b7f01336b696777/Example_images/B59DD164-51D5-40DF-A926-6A42DD52EBE8-1920x1472.jpeg -------------------------------------------------------------------------------- /Example_images/CD50BA96-6982-4C80-AE7B-5F67ACDBFA56.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rezacsedu/DeepCOVIDExplainer/5d77c295c9d8db678d8481532b7f01336b696777/Example_images/CD50BA96-6982-4C80-AE7B-5F67ACDBFA56.jpeg -------------------------------------------------------------------------------- /Example_images/CE13BB46-B19A-4B06-92CE-C479125C6CEA.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rezacsedu/DeepCOVIDExplainer/5d77c295c9d8db678d8481532b7f01336b696777/Example_images/CE13BB46-B19A-4B06-92CE-C479125C6CEA.jpeg -------------------------------------------------------------------------------- /Example_images/F051E018-DAD1-4506-AD43-BE4CA29E960B.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rezacsedu/DeepCOVIDExplainer/5d77c295c9d8db678d8481532b7f01336b696777/Example_images/F051E018-DAD1-4506-AD43-BE4CA29E960B.jpeg -------------------------------------------------------------------------------- /Example_images/F2DE909F-E19C-4900-92F5-8F435B031AC6.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rezacsedu/DeepCOVIDExplainer/5d77c295c9d8db678d8481532b7f01336b696777/Example_images/F2DE909F-E19C-4900-92F5-8F435B031AC6.jpeg -------------------------------------------------------------------------------- /Example_images/F4341CE7-73C9-45C6-99C8-8567A5484B63.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rezacsedu/DeepCOVIDExplainer/5d77c295c9d8db678d8481532b7f01336b696777/Example_images/F4341CE7-73C9-45C6-99C8-8567A5484B63.jpeg -------------------------------------------------------------------------------- /Example_images/F63AB6CE-1968-4154-A70F-913AF154F53D-1920x1275.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rezacsedu/DeepCOVIDExplainer/5d77c295c9d8db678d8481532b7f01336b696777/Example_images/F63AB6CE-1968-4154-A70F-913AF154F53D-1920x1275.jpeg -------------------------------------------------------------------------------- /Example_images/a092a272b78ce7c23e6a490721b750_jumbo.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rezacsedu/DeepCOVIDExplainer/5d77c295c9d8db678d8481532b7f01336b696777/Example_images/a092a272b78ce7c23e6a490721b750_jumbo.jpeg -------------------------------------------------------------------------------- /Example_images/a36d7944927e369c90035d4fcbb7af_jumbo.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rezacsedu/DeepCOVIDExplainer/5d77c295c9d8db678d8481532b7f01336b696777/Example_images/a36d7944927e369c90035d4fcbb7af_jumbo.jpeg -------------------------------------------------------------------------------- /Example_images/b1921029beb35ebb6bc80b1bd5c043_jumbo.jpeg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/rezacsedu/DeepCOVIDExplainer/5d77c295c9d8db678d8481532b7f01336b696777/Example_images/b1921029beb35ebb6bc80b1bd5c043_jumbo.jpeg -------------------------------------------------------------------------------- /Example_images/b418d50351b48ee58bcb4c2841e95b_jumbo.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rezacsedu/DeepCOVIDExplainer/5d77c295c9d8db678d8481532b7f01336b696777/Example_images/b418d50351b48ee58bcb4c2841e95b_jumbo.jpeg -------------------------------------------------------------------------------- /Example_images/b4da827908ad8209137382e301cc24_jumbo.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rezacsedu/DeepCOVIDExplainer/5d77c295c9d8db678d8481532b7f01336b696777/Example_images/b4da827908ad8209137382e301cc24_jumbo.jpeg -------------------------------------------------------------------------------- /Example_images/b81bbc0418db1202a4e8d6015afb32_jumbo.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rezacsedu/DeepCOVIDExplainer/5d77c295c9d8db678d8481532b7f01336b696777/Example_images/b81bbc0418db1202a4e8d6015afb32_jumbo.jpeg -------------------------------------------------------------------------------- /Example_images/bc4aafa5ad0aaa24a92afe73b06e74_jumbo.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rezacsedu/DeepCOVIDExplainer/5d77c295c9d8db678d8481532b7f01336b696777/Example_images/bc4aafa5ad0aaa24a92afe73b06e74_jumbo.jpeg -------------------------------------------------------------------------------- /Example_images/c02786050656210c20eb86d3bc0d48_jumbo.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rezacsedu/DeepCOVIDExplainer/5d77c295c9d8db678d8481532b7f01336b696777/Example_images/c02786050656210c20eb86d3bc0d48_jumbo.jpeg -------------------------------------------------------------------------------- /Example_images/df1053d3e8896b53ef140773e10e26_jumbo.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rezacsedu/DeepCOVIDExplainer/5d77c295c9d8db678d8481532b7f01336b696777/Example_images/df1053d3e8896b53ef140773e10e26_jumbo.jpeg -------------------------------------------------------------------------------- /Example_images/e493ebb5ce513a0ad49237f008595b_jumbo.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rezacsedu/DeepCOVIDExplainer/5d77c295c9d8db678d8481532b7f01336b696777/Example_images/e493ebb5ce513a0ad49237f008595b_jumbo.jpeg -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## DeepCOVIDExplainer: Explainable COVID-19 Diagnosis from Chest X-rays 2 | Supplementary materials for "DeepCOVIDExplainer: Explainable COVID-19 Diagnosis from Chest Radiography Images" accepted at IEEE International Conference on Bioinformatics and Biomedicine (BIBM'2020), to be held in Seoul, South Korea. We provide details of dataset, preprocessing, network architectures, and some additional results. Nevertheless, we'll provide trained models, preprocessed data, interactive Python notebooks, and a web application showing live demo. 
As planned, we keep this repo updated.
3 | 
4 | ### Methods
5 | The pipeline of "DeepCOVIDExplainer" consists of preprocessing, classification, snapshot neural ensembling, and decision visualization. After necessary preprocessing of the CXR images, DenseNet, ResNet, and VGGNet models are trained in a transfer learning setting, creating their model snapshots, followed by a neural snapshot ensemble based on averaging the Softmax class posteriors and on the prediction maximization of the best-performing models. Finally, class-discriminating attention maps are generated using gradient-guided class activation maps (Grad-CAM++) and layer-wise relevance propagation (LRP) to provide explanations of the predictions and to identify the critical regions on patients' chests.
6 | 
7 | #### Erratum
8 | Although we mistakenly mentioned in the paper that the networks' weights were initialized with ImageNet, we actually trained them from scratch (please refer to the Jupyter notebooks). The reason is that ImageNet contains photos of general objects, which would activate the internal representations of the networks' hidden layers with geometric forms, colorful patterns, or irrelevant shapes that are usually not present in biomedical images, e.g., X-rays, MRIs, or CT scans.
9 | 
10 | ### Datasets
11 | We chose 3 different versions of the COVIDx dataset to train and evaluate the model. COVIDx v1.0 had a total of 5,941 CXR images from 2,839 patients. It is based on the COVID-19 image dataset curated by Joseph P. C., et al. and the Kaggle CXR Pneumonia dataset (https://www.kaggle.com/paultimothymooney/chest-xray-pneumonia) by Paul Mooney. COVIDx v1.0 has already been used in the literature (https://arxiv.org/abs/2003.09871). However, the Kaggle CXR images are of children. Therefore, to avoid possible prediction bias (e.g., the model might predict based on the chest size itself), we used the CXRs of adults with pneumonia, augmented with more CXR images from confirmed COVID-19 cases.
12 | 
13 | The COVIDx v2.0 dataset is also based on the COVID-19 image dataset, but comes with the RSNA Pneumonia Detection Challenge dataset (https://www.kaggle.com/c/rsna-pneumonia-detection-challenge) provided by the Radiological Society of North America. COVIDx v3.0 is likewise based on the COVID-19 image dataset and the RSNA Pneumonia dataset, which we enriched with 49 additional COVID-19 CXRs from: i) the Italian Radiological Case repository (https://radiopaedia.org/articles/covid-19-3?lang=us), and ii) Radiopaedia.org (provided by Dr. Fabio Macori) (https://www.sirm.org/category/senza-categoria/covid-19/). COVIDx v1.0 CXR images are categorized as normal, bacterial, non-COVID19 viral, and COVID19 viral, whereas COVIDx v2.0 and v3.0 CXR images are categorized as normal, pneumonia, and COVID19 viral. The distribution of images and patient cases amongst the different infection types is as follows:
14 | 
15 | ### Data availability
16 | We will open source the preprocessed data in .npy format so that the community can build the models with ease. However, it'll take a few more days.
17 | 
18 | ### Availability of pretrained models
19 | We plan to make all the pretrained models and some computational resources publicly available, but it will take time. For the time being, only VGG-19 and ResNet-18 are available upon reasonable request.
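20 | 
21 | ### Quick instructions
22 | A quick example on a small dataset can be performed as follows. This is a minimal sketch rather than the exact training setup of the paper: it assumes the preprocessed `.npy` arrays used in the notebooks, the 4-class setup of COVIDx v1.0, and that the `Classifier` modules are on the Python path; the file names and hyperparameters are placeholders.
23 | 
24 |     import numpy as np
25 |     from keras.optimizers import SGD
26 |     from VGG import VGG19          # Classifier/VGG.py
27 |     from ResNet import ResNet18    # Classifier/ResNet.py
28 | 
29 |     # Hypothetical preprocessed arrays; shapes follow the notebooks (224x224x3 CXR images)
30 |     x_train, y_train = np.load('data/x_train.npy'), np.load('data/y_train.npy')
31 |     x_test = np.load('data/x_test.npy')
32 |     y_train = np.eye(4)[y_train]  # one-hot encode the four COVIDx v1.0 classes
33 | 
34 |     models = [VGG19(input_shape=(224, 224, 3), classes=4), ResNet18((224, 224, 3), 4)]
35 |     for m in models:
36 |         m.compile(optimizer=SGD(lr=0.01, momentum=0.95, nesterov=True),
37 |                   loss='categorical_crossentropy', metrics=['accuracy'])
38 |         m.fit(x_train, y_train, batch_size=32, epochs=5, validation_split=0.1)
39 | 
40 |     # Ensemble by averaging the Softmax class posteriors of the trained models (see Methods)
41 |     probs = np.mean([m.predict(x_test) for m in models], axis=0)
42 |     y_pred = probs.argmax(axis=1)
43 | 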
44 | ### Citation request
45 | If you use the code of this repository in your research, please consider citing the following paper:
46 | 
47 |     @inproceedings{DeepCOVIDExplainer,
48 |       title={DeepCOVIDExplainer: Explainable COVID-19 Diagnosis from Chest X-ray Images},
49 |       author={Karim, Md Rezaul and Döhmen, Till and Rebholz-Schuhmann, Dietrich and Decker, Stefan and Cochez, Michael and Beyan, Oya},
50 |       conference={IEEE International Conference on Bioinformatics and Biomedicine (BIBM'2020)},
51 |       publisher={IEEE},
52 |       year={2020}
53 |     }
54 | 
55 | ### Contributing
56 | We'll provide a contact email address in the future, in case readers have any questions.
57 | 
-------------------------------------------------------------------------------- /noteboks/ResNet-18.ipynb: --------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 41,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import cv2\n",
10 | "import numpy as np\n",
11 | "import matplotlib.pyplot as plt\n",
12 | "import os"
13 | ]
14 | },
15 | {
16 | "cell_type": "code",
17 | "execution_count": 42,
18 | "metadata": {},
19 | "outputs": [],
20 | "source": [
21 | "x_train = np.load('data/x_train.npy')\n",
22 | "y_train = np.load('data/y_train.npy')\n",
23 | "x_test = np.load('data/x_test.npy')\n",
24 | "y_test = np.load('data/y_test.npy')"
25 | ]
26 | },
27 | {
28 | "cell_type": "code",
29 | "execution_count": 43,
30 | "metadata": {},
31 | "outputs": [
32 | {
33 | "data": {
34 | "text/plain": [
35 | "(654, 224, 224, 3)"
36 | ]
37 | },
38 | "execution_count": 43,
39 | "metadata": {},
40 | "output_type": "execute_result"
41 | }
42 | ],
43 | "source": [
44 | "x_test.shape"
45 | ]
46 | },
47 | {
48 | "cell_type": "code",
49 | "execution_count": 44,
50 | "metadata": {},
51 | "outputs": [],
52 | "source": [
53 | "# optional: rescale pixel values to [0, 1] with x_train = x_train / 255.0 (not applied in this run)"
54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": 45, 59 | "metadata": {}, 60 | "outputs": [], 61 | "source": [ 62 | "from __future__ import print_function\n", 63 | "\n", 64 | "import numpy as np\n", 65 | "from random import shuffle\n", 66 | "import time\n", 67 | "import csv\n", 68 | "from PIL import Image\n", 69 | "import os\n", 70 | "import tensorflow as tf\n", 71 | "import keras\n", 72 | "from keras.callbacks import EarlyStopping, LearningRateScheduler\n", 73 | "from keras import initializers\n", 74 | "from keras.optimizers import SGD\n", 75 | "from keras.preprocessing import sequence\n", 76 | "from keras.utils import np_utils\n", 77 | "from keras.models import Sequential,load_model,Model\n", 78 | "from keras.layers import Dense, Dropout, Activation, Flatten\n", 79 | "from keras.layers import *\n", 80 | "from keras.callbacks import CSVLogger\n", 81 | "from keras import callbacks\n", 82 | "from keras.preprocessing.image import ImageDataGenerator\n", 83 | "\n", 84 | "from sklearn.metrics import classification_report\n", 85 | "from sklearn.model_selection import train_test_split\n", 86 | "import sklearn.metrics as sklm\n", 87 | "import lossprettifier\n", 88 | "from ResNet import *" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": 46, 94 | "metadata": {}, 95 | "outputs": [], 96 | "source": [ 97 | "# for reproducibility\n", 98 | "np.random.seed(3768)\n", 99 | "\n", 100 | "# use this environment flag to change which GPU to use \n", 101 | "os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\"\n", 102 | "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\" # specify which GPU(s) to be used\n", 103 | "\n", 104 | "#Get TensorFlow session\n", 105 | "def get_session(): \n", 106 | " config = tf.ConfigProto() \n", 107 | " config.gpu_options.allow_growth = True \n", 108 | " return tf.Session(config=config) \n", 109 | " \n", 110 | "# One hot encoding of labels \n", 111 | "def dense_to_one_hot(labels_dense,num_clases=4):\n", 112 | " return np.eye(num_clases)[labels_dense]" 113 | ] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": 47, 118 | "metadata": {}, 119 | "outputs": [], 120 | "source": [ 121 | "# Preparing training and test sets\n", 122 | "x_train, x_valid, y_train, y_valid = train_test_split(x_train, y_train, test_size=0.10, random_state=42)" 123 | ] 124 | }, 125 | { 126 | "cell_type": "code", 127 | "execution_count": 48, 128 | "metadata": {}, 129 | "outputs": [], 130 | "source": [ 131 | "y_train = dense_to_one_hot(y_train,num_clases=4)\n", 132 | "y_valid= dense_to_one_hot(y_valid,num_clases=4)" 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": 49, 138 | "metadata": {}, 139 | "outputs": [ 140 | { 141 | "name": "stderr", 142 | "output_type": "stream", 143 | "text": [ 144 | "/home/reza/.local/lib/python3.6/site-packages/keras_preprocessing/image/image_data_generator.py:348: UserWarning: This ImageDataGenerator specifies `featurewise_std_normalization`, which overrides setting of `featurewise_center`.\n", 145 | " warnings.warn('This ImageDataGenerator specifies '\n" 146 | ] 147 | } 148 | ], 149 | "source": [ 150 | "#Image data generation for the training \n", 151 | "datagen = ImageDataGenerator(\n", 152 | " featurewise_center = False, \n", 153 | " samplewise_center = False, # set each sample mean to 0\n", 154 | " featurewise_std_normalization = True, \n", 155 | " samplewise_std_normalization = False) \n", 156 | "\n", 157 | "datagen.fit(x_train) \n", 158 | "for i in range(len(x_test)):\n", 159 | " x_test[i] = 
datagen.standardize(x_test[i])" 160 | ] 161 | }, 162 | { 163 | "cell_type": "code", 164 | "execution_count": null, 165 | "metadata": {}, 166 | "outputs": [ 167 | { 168 | "name": "stdout", 169 | "output_type": "stream", 170 | "text": [ 171 | "reshaping via a convolution...\n", 172 | "reshaping via a convolution...\n", 173 | "reshaping via a convolution...\n", 174 | "reshaping via a convolution...\n", 175 | "Epoch 1/50\n", 176 | "32/32 [==============================] - 78s 2s/step - loss: 3.8593 - accuracy: 0.6270 - val_loss: 1130.2300 - val_accuracy: 0.2539\n", 177 | "Epoch 2/50\n", 178 | "32/32 [==============================] - 73s 2s/step - loss: 3.6147 - accuracy: 0.7143 - val_loss: 286.7065 - val_accuracy: 0.2704\n", 179 | "Epoch 3/50\n", 180 | "32/32 [==============================] - 74s 2s/step - loss: 3.5902 - accuracy: 0.7178 - val_loss: 102.1051 - val_accuracy: 0.2684\n", 181 | "Epoch 4/50\n", 182 | "32/32 [==============================] - 74s 2s/step - loss: 3.5250 - accuracy: 0.7393 - val_loss: 26.7099 - val_accuracy: 0.2505\n", 183 | "Epoch 5/50\n", 184 | "32/32 [==============================] - 75s 2s/step - loss: 3.4561 - accuracy: 0.7832 - val_loss: 14.2865 - val_accuracy: 0.2744\n", 185 | "Epoch 6/50\n", 186 | "32/32 [==============================] - 73s 2s/step - loss: 3.4043 - accuracy: 0.7695 - val_loss: 9.3615 - val_accuracy: 0.2704\n", 187 | "Epoch 7/50\n", 188 | "32/32 [==============================] - 73s 2s/step - loss: 3.4066 - accuracy: 0.7871 - val_loss: 13.1856 - val_accuracy: 0.2584\n", 189 | "Epoch 8/50\n", 190 | "32/32 [==============================] - 73s 2s/step - loss: 3.4122 - accuracy: 0.7852 - val_loss: 5.2803 - val_accuracy: 0.3579\n", 191 | "Epoch 9/50\n", 192 | "32/32 [==============================] - 74s 2s/step - loss: 3.3905 - accuracy: 0.7861 - val_loss: 3.5473 - val_accuracy: 0.7276\n", 193 | "Epoch 10/50\n", 194 | "32/32 [==============================] - 73s 2s/step - loss: 3.3387 - accuracy: 0.7959 - val_loss: 3.8871 - val_accuracy: 0.6044\n", 195 | "Epoch 11/50\n", 196 | "32/32 [==============================] - 75s 2s/step - loss: 3.3018 - accuracy: 0.8018 - val_loss: 3.6547 - val_accuracy: 0.7038\n", 197 | "Epoch 12/50\n", 198 | "32/32 [==============================] - 74s 2s/step - loss: 3.3143 - accuracy: 0.8086 - val_loss: 3.3349 - val_accuracy: 0.7555\n", 199 | "Epoch 13/50\n", 200 | "32/32 [==============================] - 73s 2s/step - loss: 3.2848 - accuracy: 0.8132 - val_loss: 3.4284 - val_accuracy: 0.6859\n", 201 | "Epoch 14/50\n", 202 | "32/32 [==============================] - 74s 2s/step - loss: 3.3172 - accuracy: 0.8018 - val_loss: 3.8279 - val_accuracy: 0.5209\n", 203 | "Epoch 15/50\n", 204 | "32/32 [==============================] - 74s 2s/step - loss: 3.2460 - accuracy: 0.8340 - val_loss: 3.1901 - val_accuracy: 0.7694\n", 205 | "Epoch 16/50\n", 206 | "32/32 [==============================] - 74s 2s/step - loss: 3.2157 - accuracy: 0.8412 - val_loss: 3.4472 - val_accuracy: 0.7535\n", 207 | "Epoch 17/50\n", 208 | "32/32 [==============================] - 75s 2s/step - loss: 3.2261 - accuracy: 0.8232 - val_loss: 3.3914 - val_accuracy: 0.7475\n", 209 | "Epoch 18/50\n", 210 | "32/32 [==============================] - 74s 2s/step - loss: 3.2335 - accuracy: 0.8320 - val_loss: 3.5598 - val_accuracy: 0.7383\n", 211 | "Epoch 19/50\n", 212 | "32/32 [==============================] - 74s 2s/step - loss: 3.2080 - accuracy: 0.8291 - val_loss: 3.3610 - val_accuracy: 0.7455\n", 213 | "Epoch 20/50\n", 214 | "32/32 
[==============================] - 74s 2s/step - loss: 3.1325 - accuracy: 0.8760 - val_loss: 3.7209 - val_accuracy: 0.7217\n", 215 | "Epoch 21/50\n", 216 | "32/32 [==============================] - 74s 2s/step - loss: 3.1284 - accuracy: 0.8631 - val_loss: 3.4455 - val_accuracy: 0.7495\n", 217 | "Epoch 22/50\n", 218 | "32/32 [==============================] - 75s 2s/step - loss: 3.1604 - accuracy: 0.8496 - val_loss: 3.0686 - val_accuracy: 0.7714\n", 219 | "Epoch 23/50\n", 220 | "32/32 [==============================] - 74s 2s/step - loss: 3.1449 - accuracy: 0.8369 - val_loss: 3.9068 - val_accuracy: 0.7515\n", 221 | "Epoch 24/50\n", 222 | "32/32 [==============================] - 70s 2s/step - loss: 3.1027 - accuracy: 0.8682 - val_loss: 3.1616 - val_accuracy: 0.7594\n", 223 | "Epoch 25/50\n", 224 | "32/32 [==============================] - 70s 2s/step - loss: 3.0319 - accuracy: 0.8955 - val_loss: 8.5852 - val_accuracy: 0.3161\n", 225 | "Epoch 26/50\n", 226 | "32/32 [==============================] - 69s 2s/step - loss: 3.0787 - accuracy: 0.8838 - val_loss: 3.2822 - val_accuracy: 0.7634\n", 227 | "Epoch 27/50\n", 228 | "32/32 [==============================] - 70s 2s/step - loss: 3.0778 - accuracy: 0.8741 - val_loss: 3.7060 - val_accuracy: 0.7575\n", 229 | "Epoch 28/50\n", 230 | "32/32 [==============================] - 72s 2s/step - loss: 3.0845 - accuracy: 0.8721 - val_loss: 3.4433 - val_accuracy: 0.7217\n", 231 | "Epoch 29/50\n", 232 | "32/32 [==============================] - 73s 2s/step - loss: 3.0260 - accuracy: 0.8906 - val_loss: 6.7136 - val_accuracy: 0.3579\n", 233 | "Epoch 30/50\n", 234 | "32/32 [==============================] - 72s 2s/step - loss: 2.9881 - accuracy: 0.8945 - val_loss: 3.2591 - val_accuracy: 0.7495\n", 235 | "Epoch 31/50\n", 236 | "32/32 [==============================] - 73s 2s/step - loss: 2.9370 - accuracy: 0.9258 - val_loss: 3.5817 - val_accuracy: 0.7217\n", 237 | "Epoch 32/50\n", 238 | "32/32 [==============================] - 72s 2s/step - loss: 2.9795 - accuracy: 0.8955 - val_loss: 3.2548 - val_accuracy: 0.7416\n", 239 | "Epoch 33/50\n", 240 | "32/32 [==============================] - 72s 2s/step - loss: 3.0388 - accuracy: 0.8871 - val_loss: 3.3379 - val_accuracy: 0.7455\n", 241 | "Epoch 34/50\n", 242 | "32/32 [==============================] - 72s 2s/step - loss: 2.9381 - accuracy: 0.9141 - val_loss: 3.5816 - val_accuracy: 0.7694\n", 243 | "Epoch 35/50\n", 244 | "32/32 [==============================] - 74s 2s/step - loss: 2.8870 - accuracy: 0.9355 - val_loss: 3.3679 - val_accuracy: 0.7695\n", 245 | "Epoch 36/50\n", 246 | "32/32 [==============================] - 75s 2s/step - loss: 2.9027 - accuracy: 0.9248 - val_loss: 4.0747 - val_accuracy: 0.7992\n", 247 | "Epoch 37/50\n", 248 | "32/32 [==============================] - 72s 2s/step - loss: 2.8682 - accuracy: 0.9401 - val_loss: 3.4795 - val_accuracy: 0.7833\n", 249 | "Epoch 38/50\n", 250 | "32/32 [==============================] - 74s 2s/step - loss: 2.8957 - accuracy: 0.9229 - val_loss: 4.4265 - val_accuracy: 0.6441\n", 251 | "Epoch 39/50\n", 252 | "32/32 [==============================] - 74s 2s/step - loss: 2.8088 - accuracy: 0.9551 - val_loss: 4.0555 - val_accuracy: 0.7336\n", 253 | "Epoch 40/50\n", 254 | "32/32 [==============================] - 73s 2s/step - loss: 2.9258 - accuracy: 0.9141 - val_loss: 5.3826 - val_accuracy: 0.4692\n", 255 | "Epoch 41/50\n", 256 | "32/32 [==============================] - 74s 2s/step - loss: 2.9288 - accuracy: 0.9033 - val_loss: 6.1204 - val_accuracy: 0.4970\n", 257 
| "Epoch 42/50\n", 258 | "32/32 [==============================] - 73s 2s/step - loss: 2.9852 - accuracy: 0.8730 - val_loss: 3.9059 - val_accuracy: 0.7256\n", 259 | "Epoch 43/50\n", 260 | "32/32 [==============================] - 72s 2s/step - loss: 2.8609 - accuracy: 0.9287 - val_loss: 3.8044 - val_accuracy: 0.7555\n", 261 | "Epoch 44/50\n", 262 | "32/32 [==============================] - 72s 2s/step - loss: 2.7896 - accuracy: 0.9540 - val_loss: 3.2766 - val_accuracy: 0.7654\n", 263 | "Epoch 45/50\n", 264 | "11/32 [=========>....................] - ETA: 45s - loss: 2.7915 - accuracy: 0.9489" 265 | ] 266 | } 267 | ], 268 | "source": [ 269 | "#Defining hyperparameters\n", 270 | "batch_Size = 32\n", 271 | "steps_Per_Epoch = 32\n", 272 | "numEpochs = 50\n", 273 | "\n", 274 | "#Instantating ResNet18 model\n", 275 | "model = ResNet18((224, 224, 3), 4) \n", 276 | "\n", 277 | "#Creating an optimizers\n", 278 | "adaDelta = keras.optimizers.Adadelta(lr=1.0, rho=0.95)\n", 279 | "sgd = SGD(lr=0.01, decay=1e-6, momentum=0.95, nesterov=True)\n", 280 | "model.compile(optimizer = sgd , loss = 'categorical_crossentropy', metrics = ['accuracy'])\n", 281 | "\n", 282 | "#Creating early stopping \n", 283 | "earlystop = EarlyStopping(monitor = 'val_accuracy', min_delta = 0, patience = 50, verbose = 1, mode = 'auto', restore_best_weights = True) \n", 284 | "\n", 285 | "train_generator = datagen.flow(x_train, y_train, batch_size = batch_Size)\n", 286 | "validation_generator = datagen.flow(x_valid, y_valid, batch_size = batch_Size)\n", 287 | "\n", 288 | "# Model training\n", 289 | "history = model.fit_generator(\n", 290 | " train_generator,\n", 291 | " steps_per_epoch = steps_Per_Epoch,\n", 292 | " validation_data = validation_generator, \n", 293 | " validation_steps = 16,\n", 294 | " epochs = numEpochs,\n", 295 | " shuffle = True, \n", 296 | " verbose = 1)" 297 | ] 298 | }, 299 | { 300 | "cell_type": "code", 301 | "execution_count": 11, 302 | "metadata": {}, 303 | "outputs": [], 304 | "source": [ 305 | "modelPath = \"ResNet18_COVID19.h5\"\n", 306 | "resultPath = 'ResNet18_COVID19.txt'" 307 | ] 308 | }, 309 | { 310 | "cell_type": "code", 311 | "execution_count": 29, 312 | "metadata": {}, 313 | "outputs": [ 314 | { 315 | "name": "stdout", 316 | "output_type": "stream", 317 | "text": [ 318 | "Epoch 0 | LossA: 3.82(+0.00%) \u001b[0m\t| LossAB: 1515.01(+0.00%) \u001b[0m\t\n", 319 | "Epoch 1 | LossA: \u001b[32m3.61(-5.36%) ▼\u001b[0m\t| LossAB: \u001b[32m293.95(-80.60%) ▼\u001b[0m\t\n", 320 | "Epoch 2 | LossA: \u001b[32m3.59(-0.76%) ▼\u001b[0m\t| LossAB: \u001b[32m66.91(-77.24%) ▼\u001b[0m\t\n", 321 | "Epoch 3 | LossA: \u001b[32m3.49(-2.55%) ▼\u001b[0m\t| LossAB: \u001b[32m29.26(-56.27%) ▼\u001b[0m\t\n", 322 | "Epoch 4 | LossA: \u001b[32m3.48(-0.42%) ▼\u001b[0m\t| LossAB: \u001b[32m14.31(-51.10%) ▼\u001b[0m\t\n", 323 | "Epoch 5 | LossA: \u001b[32m3.40(-2.45%) ▼\u001b[0m\t| LossAB: \u001b[32m5.35(-62.63%) ▼\u001b[0m\t\n", 324 | "Epoch 6 | LossA: \u001b[91m3.42(+0.86%) ▲\u001b[0m\t| LossAB: \u001b[32m5.03(-5.97%) ▼\u001b[0m\t\n", 325 | "Epoch 7 | LossA: \u001b[91m3.43(+0.21%) ▲\u001b[0m\t| LossAB: \u001b[91m7.25(+44.15%) ▲\u001b[0m\t\n", 326 | "Epoch 8 | LossA: \u001b[32m3.41(-0.69%) ▼\u001b[0m\t| LossAB: \u001b[91m8.45(+16.64%) ▲\u001b[0m\t\n", 327 | "Epoch 9 | LossA: \u001b[32m3.35(-1.62%) ▼\u001b[0m\t| LossAB: \u001b[32m3.71(-56.15%) ▼\u001b[0m\t\n", 328 | "Epoch 10 | LossA: \u001b[32m3.32(-1.03%) ▼\u001b[0m\t| LossAB: \u001b[91m4.33(+16.90%) ▲\u001b[0m\t\n", 329 | "Epoch 11 | LossA: \u001b[32m3.31(-0.14%) 
▼\u001b[0m\t| LossAB: \u001b[32m4.06(-6.24%) ▼\u001b[0m\t\n", 330 | "Epoch 12 | LossA: \u001b[32m3.31(-0.17%) ▼\u001b[0m\t| LossAB: \u001b[32m3.67(-9.78%) ▼\u001b[0m\t\n", 331 | "Epoch 13 | LossA: \u001b[91m3.32(+0.32%) ▲\u001b[0m\t| LossAB: \u001b[91m3.74(+1.93%) ▲\u001b[0m\t\n", 332 | "Epoch 14 | LossA: \u001b[32m3.21(-3.14%) ▼\u001b[0m\t| LossAB: \u001b[32m3.45(-7.80%) ▼\u001b[0m\t\n", 333 | "Epoch 15 | LossA: \u001b[91m3.22(+0.32%) ▲\u001b[0m\t| LossAB: \u001b[32m3.31(-3.85%) ▼\u001b[0m\t\n", 334 | "Epoch 16 | LossA: \u001b[91m3.23(+0.07%) ▲\u001b[0m\t| LossAB: \u001b[91m3.36(+1.49%) ▲\u001b[0m\t\n", 335 | "Epoch 17 | LossA: \u001b[91m3.23(+0.18%) ▲\u001b[0m\t| LossAB: \u001b[91m3.88(+15.33%) ▲\u001b[0m\t\n", 336 | "Epoch 18 | LossA: \u001b[32m3.21(-0.67%) ▼\u001b[0m\t| LossAB: \u001b[91m4.27(+10.18%) ▲\u001b[0m\t\n", 337 | "Epoch 19 | LossA: \u001b[32m3.12(-2.97%) ▼\u001b[0m\t| LossAB: \u001b[32m3.33(-22.06%) ▼\u001b[0m\t\n", 338 | "Epoch 20 | LossA: \u001b[91m3.13(+0.54%) ▲\u001b[0m\t| LossAB: \u001b[91m3.72(+11.81%) ▲\u001b[0m\t\n", 339 | "Epoch 21 | LossA: \u001b[32m3.10(-0.88%) ▼\u001b[0m\t| LossAB: \u001b[91m4.98(+33.81%) ▲\u001b[0m\t\n", 340 | "Epoch 22 | LossA: \u001b[91m3.16(+1.81%) ▲\u001b[0m\t| LossAB: \u001b[32m4.05(-18.67%) ▼\u001b[0m\t\n", 341 | "Epoch 23 | LossA: \u001b[32m3.09(-2.11%) ▼\u001b[0m\t| LossAB: \u001b[32m3.81(-5.93%) ▼\u001b[0m\t\n", 342 | "Epoch 24 | LossA: \u001b[32m3.05(-1.53%) ▼\u001b[0m\t| LossAB: \u001b[32m3.80(-0.33%) ▼\u001b[0m\t\n", 343 | "Epoch 25 | LossA: \u001b[91m3.05(+0.07%) ▲\u001b[0m\t| LossAB: \u001b[32m3.14(-17.29%) ▼\u001b[0m\t\n", 344 | "Epoch 26 | LossA: \u001b[91m3.12(+2.29%) ▲\u001b[0m\t| LossAB: \u001b[91m7.82(+149.01%) ▲\u001b[0m\t\n", 345 | "Epoch 27 | LossA: \u001b[32m3.11(-0.40%) ▼\u001b[0m\t| LossAB: \u001b[32m3.29(-57.95%) ▼\u001b[0m\t\n", 346 | "Epoch 28 | LossA: \u001b[32m3.06(-1.60%) ▼\u001b[0m\t| LossAB: \u001b[91m3.49(+6.19%) ▲\u001b[0m\t\n", 347 | "Epoch 29 | LossA: \u001b[32m2.95(-3.33%) ▼\u001b[0m\t| LossAB: \u001b[91m3.87(+10.78%) ▲\u001b[0m\t\n", 348 | "Epoch 30 | LossA: \u001b[91m2.96(+0.08%) ▲\u001b[0m\t| LossAB: \u001b[91m4.56(+17.81%) ▲\u001b[0m\t\n", 349 | "Epoch 31 | LossA: \u001b[91m3.00(+1.32%) ▲\u001b[0m\t| LossAB: \u001b[91m9.76(+113.97%) ▲\u001b[0m\t\n", 350 | "Epoch 32 | LossA: \u001b[91m3.03(+1.06%) ▲\u001b[0m\t| LossAB: \u001b[32m4.85(-50.29%) ▼\u001b[0m\t\n", 351 | "Epoch 33 | LossA: \u001b[32m2.97(-1.97%) ▼\u001b[0m\t| LossAB: \u001b[32m3.52(-27.45%) ▼\u001b[0m\t\n", 352 | "Epoch 34 | LossA: \u001b[32m2.93(-1.31%) ▼\u001b[0m\t| LossAB: \u001b[32m3.31(-5.88%) ▼\u001b[0m\t\n", 353 | "Epoch 35 | LossA: \u001b[91m2.94(+0.22%) ▲\u001b[0m\t| LossAB: \u001b[91m3.69(+11.55%) ▲\u001b[0m\t\n", 354 | "Epoch 36 | LossA: \u001b[91m3.02(+2.93%) ▲\u001b[0m\t| LossAB: \u001b[32m3.44(-6.77%) ▼\u001b[0m\t\n", 355 | "Epoch 37 | LossA: \u001b[32m2.89(-4.23%) ▼\u001b[0m\t| LossAB: \u001b[91m3.50(+1.54%) ▲\u001b[0m\t\n", 356 | "Epoch 38 | LossA: \u001b[32m2.84(-1.83%) ▼\u001b[0m\t| LossAB: \u001b[91m4.12(+17.77%) ▲\u001b[0m\t\n", 357 | "Epoch 39 | LossA: \u001b[32m2.84(-0.19%) ▼\u001b[0m\t| LossAB: \u001b[91m4.78(+16.15%) ▲\u001b[0m\t\n", 358 | "Epoch 40 | LossA: \u001b[32m2.82(-0.63%) ▼\u001b[0m\t| LossAB: \u001b[32m3.61(-24.54%) ▼\u001b[0m\t\n", 359 | "Epoch 41 | LossA: \u001b[91m2.84(+0.70%) ▲\u001b[0m\t| LossAB: \u001b[32m3.32(-7.92%) ▼\u001b[0m\t\n", 360 | "Epoch 42 | LossA: \u001b[32m2.78(-2.10%) ▼\u001b[0m\t| LossAB: \u001b[91m3.41(+2.56%) ▲\u001b[0m\t\n", 361 | "Epoch 43 | LossA: \u001b[91m2.78(+0.21%) ▲\u001b[0m\t| 
LossAB: \u001b[91m5.48(+60.87%) ▲\u001b[0m\t\n", 362 | "Epoch 44 | LossA: \u001b[32m2.77(-0.52%) ▼\u001b[0m\t| LossAB: \u001b[32m3.44(-37.22%) ▼\u001b[0m\t\n", 363 | "Epoch 45 | LossA: \u001b[91m2.80(+1.26%) ▲\u001b[0m\t| LossAB: \u001b[91m3.45(+0.32%) ▲\u001b[0m\t\n", 364 | "Epoch 46 | LossA: \u001b[91m2.81(+0.26%) ▲\u001b[0m\t| LossAB: \u001b[32m3.37(-2.38%) ▼\u001b[0m\t\n", 365 | "Epoch 47 | LossA: \u001b[32m2.74(-2.48%) ▼\u001b[0m\t| LossAB: \u001b[91m3.98(+18.11%) ▲\u001b[0m\t\n", 366 | "Epoch 48 | LossA: \u001b[32m2.73(-0.48%) ▼\u001b[0m\t| LossAB: \u001b[32m3.23(-19.01%) ▼\u001b[0m\t\n", 367 | "Epoch 49 | LossA: \u001b[91m2.75(+0.81%) ▲\u001b[0m\t| LossAB: \u001b[91m3.71(+15.05%) ▲\u001b[0m\t\n", 368 | "Epoch 50 | LossA: \u001b[32m2.72(-1.12%) ▼\u001b[0m\t| LossAB: \u001b[91m4.69(+26.48%) ▲\u001b[0m\t\n", 369 | "Epoch 51 | LossA: \u001b[91m2.74(+0.72%) ▲\u001b[0m\t| LossAB: \u001b[32m3.51(-25.30%) ▼\u001b[0m\t\n", 370 | "Epoch 52 | LossA: \u001b[32m2.67(-2.47%) ▼\u001b[0m\t| LossAB: \u001b[91m3.52(+0.32%) ▲\u001b[0m\t\n", 371 | "Epoch 53 | LossA: \u001b[91m2.68(+0.22%) ▲\u001b[0m\t| LossAB: \u001b[32m3.35(-4.84%) ▼\u001b[0m\t\n", 372 | "Epoch 54 | LossA: \u001b[32m2.67(-0.44%) ▼\u001b[0m\t| LossAB: \u001b[91m3.70(+10.65%) ▲\u001b[0m\t\n", 373 | "Epoch 55 | LossA: \u001b[32m2.66(-0.13%) ▼\u001b[0m\t| LossAB: \u001b[32m3.44(-7.20%) ▼\u001b[0m\t\n", 374 | "Epoch 56 | LossA: \u001b[32m2.65(-0.47%) ▼\u001b[0m\t| LossAB: \u001b[32m3.12(-9.08%) ▼\u001b[0m\t\n", 375 | "Epoch 57 | LossA: \u001b[32m2.62(-0.97%) ▼\u001b[0m\t| LossAB: \u001b[91m6.54(+109.42%) ▲\u001b[0m\t\n", 376 | "Epoch 58 | LossA: \u001b[91m2.65(+1.13%) ▲\u001b[0m\t| LossAB: \u001b[32m3.99(-39.06%) ▼\u001b[0m\t\n", 377 | "Epoch 59 | LossA: \u001b[91m2.68(+1.08%) ▲\u001b[0m\t| LossAB: \u001b[91m4.03(+0.96%) ▲\u001b[0m\t\n", 378 | "Epoch 60 | LossA: \u001b[32m2.67(-0.33%) ▼\u001b[0m\t| LossAB: \u001b[91m6.35(+57.78%) ▲\u001b[0m\t\n", 379 | "Epoch 61 | LossA: \u001b[32m2.64(-1.40%) ▼\u001b[0m\t| LossAB: \u001b[32m3.07(-51.59%) ▼\u001b[0m\t\n", 380 | "Epoch 62 | LossA: \u001b[91m2.64(+0.09%) ▲\u001b[0m\t| LossAB: \u001b[91m4.74(+54.00%) ▲\u001b[0m\t\n", 381 | "Epoch 63 | LossA: \u001b[32m2.62(-0.67%) ▼\u001b[0m\t| LossAB: \u001b[32m3.75(-20.81%) ▼\u001b[0m\t\n", 382 | "Epoch 64 | LossA: \u001b[91m2.64(+0.69%) ▲\u001b[0m\t| LossAB: \u001b[32m3.18(-15.07%) ▼\u001b[0m\t\n", 383 | "Epoch 65 | LossA: \u001b[32m2.62(-0.80%) ▼\u001b[0m\t| LossAB: \u001b[91m3.21(+0.88%) ▲\u001b[0m\t\n", 384 | "Epoch 66 | LossA: \u001b[32m2.59(-0.89%) ▼\u001b[0m\t| LossAB: \u001b[91m4.46(+38.81%) ▲\u001b[0m\t\n", 385 | "Epoch 67 | LossA: \u001b[32m2.57(-0.97%) ▼\u001b[0m\t| LossAB: \u001b[32m4.21(-5.53%) ▼\u001b[0m\t\n", 386 | "Epoch 68 | LossA: \u001b[91m2.58(+0.37%) ▲\u001b[0m\t| LossAB: \u001b[91m4.51(+6.95%) ▲\u001b[0m\t\n", 387 | "Epoch 69 | LossA: \u001b[32m2.57(-0.17%) ▼\u001b[0m\t| LossAB: \u001b[32m3.42(-24.02%) ▼\u001b[0m\t\n", 388 | "Epoch 70 | LossA: \u001b[32m2.55(-0.99%) ▼\u001b[0m\t| LossAB: \u001b[91m3.76(+9.72%) ▲\u001b[0m\t\n", 389 | "Epoch 71 | LossA: \u001b[91m2.57(+1.02%) ▲\u001b[0m\t| LossAB: \u001b[91m4.42(+17.55%) ▲\u001b[0m\t\n", 390 | "Epoch 72 | LossA: \u001b[32m2.57(-0.12%) ▼\u001b[0m\t| LossAB: \u001b[32m4.04(-8.55%) ▼\u001b[0m\t\n", 391 | "Epoch 73 | LossA: \u001b[91m2.59(+0.69%) ▲\u001b[0m\t| LossAB: \u001b[32m2.86(-29.15%) ▼\u001b[0m\t\n", 392 | "Epoch 74 | LossA: \u001b[32m2.55(-1.48%) ▼\u001b[0m\t| LossAB: \u001b[91m4.66(+62.79%) ▲\u001b[0m\t\n", 393 | "Epoch 75 | LossA: \u001b[32m2.53(-0.79%) ▼\u001b[0m\t| LossAB: 
\u001b[32m3.42(-26.57%) ▼\u001b[0m\t\n", 394 | "Epoch 76 | LossA: \u001b[91m2.54(+0.18%) ▲\u001b[0m\t| LossAB: \u001b[91m4.01(+17.38%) ▲\u001b[0m\t\n", 395 | "Epoch 77 | LossA: \u001b[32m2.52(-0.70%) ▼\u001b[0m\t| LossAB: \u001b[32m3.28(-18.35%) ▼\u001b[0m\t\n", 396 | "Epoch 78 | LossA: \u001b[91m2.52(+0.09%) ▲\u001b[0m\t| LossAB: \u001b[91m3.29(+0.35%) ▲\u001b[0m\t\n", 397 | "Epoch 79 | LossA: \u001b[32m2.50(-0.90%) ▼\u001b[0m\t| LossAB: \u001b[91m3.88(+17.82%) ▲\u001b[0m\t\n", 398 | "Epoch 80 | LossA: \u001b[32m2.49(-0.39%) ▼\u001b[0m\t| LossAB: \u001b[91m4.00(+3.28%) ▲\u001b[0m\t\n", 399 | "Epoch 81 | LossA: \u001b[32m2.47(-0.64%) ▼\u001b[0m\t| LossAB: \u001b[32m2.85(-28.78%) ▼\u001b[0m\t\n", 400 | "Epoch 82 | LossA: \u001b[32m2.47(-0.21%) ▼\u001b[0m\t| LossAB: \u001b[91m3.26(+14.44%) ▲\u001b[0m\t\n", 401 | "Epoch 83 | LossA: \u001b[91m2.47(+0.27%) ▲\u001b[0m\t| LossAB: \u001b[91m3.94(+20.65%) ▲\u001b[0m\t\n", 402 | "Epoch 84 | LossA: \u001b[32m2.46(-0.73%) ▼\u001b[0m\t| LossAB: \u001b[91m3.98(+1.18%) ▲\u001b[0m\t\n", 403 | "Epoch 85 | LossA: \u001b[32m2.45(-0.03%) ▼\u001b[0m\t| LossAB: \u001b[32m3.62(-9.00%) ▼\u001b[0m\t\n", 404 | "Epoch 86 | LossA: \u001b[32m2.44(-0.57%) ▼\u001b[0m\t| LossAB: \u001b[32m3.44(-5.09%) ▼\u001b[0m\t\n", 405 | "Epoch 87 | LossA: \u001b[91m2.45(+0.29%) ▲\u001b[0m\t| LossAB: \u001b[91m4.96(+44.28%) ▲\u001b[0m\t\n", 406 | "Epoch 88 | LossA: \u001b[32m2.44(-0.19%) ▼\u001b[0m\t| LossAB: \u001b[32m3.58(-27.96%) ▼\u001b[0m\t\n", 407 | "Epoch 89 | LossA: \u001b[32m2.43(-0.38%) ▼\u001b[0m\t| LossAB: \u001b[91m4.10(+14.55%) ▲\u001b[0m\t\n", 408 | "Epoch 90 | LossA: \u001b[91m2.44(+0.35%) ▲\u001b[0m\t| LossAB: \u001b[32m3.65(-10.96%) ▼\u001b[0m\t\n", 409 | "Epoch 91 | LossA: \u001b[91m2.45(+0.21%) ▲\u001b[0m\t| LossAB: \u001b[32m3.09(-15.24%) ▼\u001b[0m\t\n", 410 | "Epoch 92 | LossA: \u001b[91m2.46(+0.45%) ▲\u001b[0m\t| LossAB: \u001b[91m3.31(+6.99%) ▲\u001b[0m\t\n", 411 | "Epoch 93 | LossA: \u001b[32m2.44(-0.90%) ▼\u001b[0m\t| LossAB: \u001b[32m3.17(-4.00%) ▼\u001b[0m\t\n", 412 | "Epoch 94 | LossA: \u001b[91m2.48(+1.66%) ▲\u001b[0m\t| LossAB: \u001b[91m4.07(+28.08%) ▲\u001b[0m\t\n", 413 | "Epoch 95 | LossA: \u001b[32m2.47(-0.29%) ▼\u001b[0m\t| LossAB: \u001b[91m7.56(+86.01%) ▲\u001b[0m\t\n", 414 | "Epoch 96 | LossA: \u001b[32m2.46(-0.30%) ▼\u001b[0m\t| LossAB: \u001b[32m5.02(-33.60%) ▼\u001b[0m\t\n", 415 | "Epoch 97 | LossA: \u001b[32m2.43(-1.13%) ▼\u001b[0m\t| LossAB: \u001b[32m3.67(-26.98%) ▼\u001b[0m\t\n", 416 | "Epoch 98 | LossA: \u001b[91m2.44(+0.40%) ▲\u001b[0m\t| LossAB: \u001b[32m2.78(-24.13%) ▼\u001b[0m\t\n", 417 | "Epoch 99 | LossA: \u001b[32m2.40(-1.65%) ▼\u001b[0m\t| LossAB: \u001b[91m3.61(+29.67%) ▲\u001b[0m\t\n", 418 | "Epoch 100 | LossA: \u001b[32m2.38(-0.90%) ▼\u001b[0m\t| LossAB: \u001b[32m2.78(-23.04%) ▼\u001b[0m\t\n", 419 | "Epoch 101 | LossA: \u001b[91m2.38(+0.05%) ▲\u001b[0m\t| LossAB: \u001b[91m3.51(+26.40%) ▲\u001b[0m\t\n", 420 | "Epoch 102 | LossA: \u001b[91m2.48(+3.90%) ▲\u001b[0m\t| LossAB: \u001b[32m3.33(-5.17%) ▼\u001b[0m\t\n", 421 | "Epoch 103 | LossA: \u001b[32m2.39(-3.45%) ▼\u001b[0m\t| LossAB: \u001b[91m6.08(+82.79%) ▲\u001b[0m\t\n", 422 | "Epoch 104 | LossA: \u001b[32m2.38(-0.33%) ▼\u001b[0m\t| LossAB: \u001b[32m3.24(-46.72%) ▼\u001b[0m\t\n", 423 | "Epoch 105 | LossA: \u001b[91m2.39(+0.46%) ▲\u001b[0m\t| LossAB: \u001b[91m4.13(+27.56%) ▲\u001b[0m\t\n", 424 | "Epoch 106 | LossA: \u001b[32m2.36(-1.54%) ▼\u001b[0m\t| LossAB: \u001b[32m4.04(-2.24%) ▼\u001b[0m\t\n", 425 | "Epoch 107 | LossA: \u001b[91m2.37(+0.65%) ▲\u001b[0m\t| LossAB: 
\u001b[32m4.01(-0.68%) ▼\u001b[0m\t\n", 426 | "Epoch 108 | LossA: \u001b[32m2.33(-1.75%) ▼\u001b[0m\t| LossAB: \u001b[32m3.06(-23.68%) ▼\u001b[0m\t\n", 427 | "Epoch 109 | LossA: \u001b[91m2.34(+0.35%) ▲\u001b[0m\t| LossAB: \u001b[91m3.41(+11.29%) ▲\u001b[0m\t\n", 428 | "Epoch 110 | LossA: \u001b[91m2.34(+0.09%) ▲\u001b[0m\t| LossAB: \u001b[91m3.93(+15.15%) ▲\u001b[0m\t\n", 429 | "Epoch 111 | LossA: \u001b[91m2.38(+1.87%) ▲\u001b[0m\t| LossAB: \u001b[91m4.78(+21.77%) ▲\u001b[0m\t\n", 430 | "Epoch 112 | LossA: \u001b[32m2.35(-1.25%) ▼\u001b[0m\t| LossAB: \u001b[32m2.91(-39.14%) ▼\u001b[0m\t\n", 431 | "Epoch 113 | LossA: \u001b[32m2.32(-1.27%) ▼\u001b[0m\t| LossAB: \u001b[91m3.15(+8.31%) ▲\u001b[0m\t\n", 432 | "Epoch 114 | LossA: \u001b[32m2.30(-1.20%) ▼\u001b[0m\t| LossAB: \u001b[91m3.60(+14.34%) ▲\u001b[0m\t\n", 433 | "Epoch 115 | LossA: \u001b[91m2.31(+0.60%) ▲\u001b[0m\t| LossAB: \u001b[32m3.41(-5.42%) ▼\u001b[0m\t\n", 434 | "Epoch 116 | LossA: \u001b[32m2.29(-0.68%) ▼\u001b[0m\t| LossAB: \u001b[91m3.41(+0.16%) ▲\u001b[0m\t\n", 435 | "Epoch 117 | LossA: \u001b[32m2.28(-0.64%) ▼\u001b[0m\t| LossAB: \u001b[32m3.22(-5.53%) ▼\u001b[0m\t\n", 436 | "Epoch 118 | LossA: \u001b[32m2.27(-0.51%) ▼\u001b[0m\t| LossAB: \u001b[91m3.54(+9.72%) ▲\u001b[0m\t\n", 437 | "Epoch 119 | LossA: \u001b[32m2.26(-0.38%) ▼\u001b[0m\t| LossAB: \u001b[91m3.78(+6.87%) ▲\u001b[0m\t\n", 438 | "Epoch 120 | LossA: \u001b[91m2.26(+0.11%) ▲\u001b[0m\t| LossAB: \u001b[91m4.40(+16.33%) ▲\u001b[0m\t\n", 439 | "Epoch 121 | LossA: \u001b[32m2.26(-0.00%) ▼\u001b[0m\t| LossAB: \u001b[32m3.16(-28.24%) ▼\u001b[0m\t\n", 440 | "Epoch 122 | LossA: \u001b[32m2.24(-0.77%) ▼\u001b[0m\t| LossAB: \u001b[91m4.18(+32.32%) ▲\u001b[0m\t\n", 441 | "Epoch 123 | LossA: \u001b[32m2.24(-0.41%) ▼\u001b[0m\t| LossAB: \u001b[32m2.94(-29.50%) ▼\u001b[0m\t\n", 442 | "Epoch 124 | LossA: \u001b[91m2.24(+0.07%) ▲\u001b[0m\t| LossAB: \u001b[91m3.50(+18.91%) ▲\u001b[0m\t\n", 443 | "Epoch 125 | LossA: \u001b[32m2.23(-0.43%) ▼\u001b[0m\t| LossAB: \u001b[91m3.72(+6.15%) ▲\u001b[0m\t\n", 444 | "Epoch 126 | LossA: \u001b[32m2.23(-0.10%) ▼\u001b[0m\t| LossAB: \u001b[91m4.22(+13.42%) ▲\u001b[0m\t\n", 445 | "Epoch 127 | LossA: \u001b[91m2.28(+2.64%) ▲\u001b[0m\t| LossAB: \u001b[32m2.83(-32.83%) ▼\u001b[0m\t\n", 446 | "Epoch 128 | LossA: \u001b[32m2.25(-1.37%) ▼\u001b[0m\t| LossAB: \u001b[91m3.73(+31.88%) ▲\u001b[0m\t\n", 447 | "Epoch 129 | LossA: \u001b[32m2.23(-1.11%) ▼\u001b[0m\t| LossAB: \u001b[32m3.25(-12.96%) ▼\u001b[0m\t\n", 448 | "Epoch 130 | LossA: \u001b[32m2.21(-0.77%) ▼\u001b[0m\t| LossAB: \u001b[91m3.44(+5.70%) ▲\u001b[0m\t\n", 449 | "Epoch 131 | LossA: \u001b[32m2.20(-0.29%) ▼\u001b[0m\t| LossAB: \u001b[32m2.62(-23.79%) ▼\u001b[0m\t\n", 450 | "Epoch 132 | LossA: \u001b[32m2.20(-0.13%) ▼\u001b[0m\t| LossAB: \u001b[91m2.80(+6.84%) ▲\u001b[0m\t\n", 451 | "Epoch 133 | LossA: \u001b[91m2.22(+0.71%) ▲\u001b[0m\t| LossAB: \u001b[91m3.82(+36.42%) ▲\u001b[0m\t\n", 452 | "Epoch 134 | LossA: \u001b[32m2.19(-0.97%) ▼\u001b[0m\t| LossAB: \u001b[91m4.07(+6.58%) ▲\u001b[0m\t\n", 453 | "Epoch 135 | LossA: \u001b[32m2.19(-0.03%) ▼\u001b[0m\t| LossAB: \u001b[32m2.81(-30.90%) ▼\u001b[0m\t\n", 454 | "Epoch 136 | LossA: \u001b[32m2.17(-1.26%) ▼\u001b[0m\t| LossAB: \u001b[91m3.67(+30.60%) ▲\u001b[0m\t\n", 455 | "Epoch 137 | LossA: \u001b[32m2.16(-0.27%) ▼\u001b[0m\t| LossAB: \u001b[32m3.58(-2.34%) ▼\u001b[0m\t\n", 456 | "Epoch 138 | LossA: \u001b[32m2.16(-0.05%) ▼\u001b[0m\t| LossAB: \u001b[91m3.73(+4.18%) ▲\u001b[0m\t\n", 457 | "Epoch 139 | LossA: \u001b[32m2.15(-0.39%) 
▼\u001b[0m\t| LossAB: \u001b[32m2.33(-37.67%) ▼\u001b[0m\t\n", 458 | "Epoch 140 | LossA: \u001b[32m2.15(-0.29%) ▼\u001b[0m\t| LossAB: \u001b[91m2.97(+27.61%) ▲\u001b[0m\t\n", 459 | "Epoch 141 | LossA: \u001b[32m2.13(-0.52%) ▼\u001b[0m\t| LossAB: \u001b[91m4.36(+46.84%) ▲\u001b[0m\t\n", 460 | "Epoch 142 | LossA: \u001b[91m2.14(+0.12%) ▲\u001b[0m\t| LossAB: \u001b[32m2.80(-35.81%) ▼\u001b[0m\t\n", 461 | "Epoch 143 | LossA: \u001b[32m2.13(-0.40%) ▼\u001b[0m\t| LossAB: \u001b[91m3.21(+14.74%) ▲\u001b[0m\t\n", 462 | "Epoch 144 | LossA: \u001b[91m2.13(+0.02%) ▲\u001b[0m\t| LossAB: \u001b[32m2.73(-15.12%) ▼\u001b[0m\t\n", 463 | "Epoch 145 | LossA: \u001b[32m2.11(-0.72%) ▼\u001b[0m\t| LossAB: \u001b[91m2.92(+6.93%) ▲\u001b[0m\t\n", 464 | "Epoch 146 | LossA: \u001b[32m2.11(-0.25%) ▼\u001b[0m\t| LossAB: \u001b[91m3.89(+33.44%) ▲\u001b[0m\t\n", 465 | "Epoch 147 | LossA: \u001b[32m2.10(-0.24%) ▼\u001b[0m\t| LossAB: \u001b[32m3.63(-6.77%) ▼\u001b[0m\t\n", 466 | "Epoch 148 | LossA: \u001b[32m2.10(-0.30%) ▼\u001b[0m\t| LossAB: \u001b[32m3.51(-3.31%) ▼\u001b[0m\t\n", 467 | "Epoch 149 | LossA: \u001b[32m2.09(-0.22%) ▼\u001b[0m\t| LossAB: \u001b[91m4.49(+28.10%) ▲\u001b[0m\t\n", 468 | "Epoch 150 | LossA: \u001b[32m2.09(-0.32%) ▼\u001b[0m\t| LossAB: \u001b[32m2.62(-41.76%) ▼\u001b[0m\t\n", 469 | "Epoch 151 | LossA: \u001b[32m2.08(-0.15%) ▼\u001b[0m\t| LossAB: \u001b[91m3.48(+32.87%) ▲\u001b[0m\t\n", 470 | "Epoch 152 | LossA: \u001b[32m2.08(-0.17%) ▼\u001b[0m\t| LossAB: \u001b[91m4.25(+22.37%) ▲\u001b[0m\t\n", 471 | "Epoch 153 | LossA: \u001b[32m2.07(-0.31%) ▼\u001b[0m\t| LossAB: \u001b[32m2.48(-41.78%) ▼\u001b[0m\t\n", 472 | "Epoch 154 | LossA: \u001b[32m2.07(-0.26%) ▼\u001b[0m\t| LossAB: \u001b[91m3.16(+27.51%) ▲\u001b[0m\t\n", 473 | "Epoch 155 | LossA: \u001b[32m2.06(-0.30%) ▼\u001b[0m\t| LossAB: \u001b[91m3.25(+2.89%) ▲\u001b[0m\t\n", 474 | "Epoch 156 | LossA: \u001b[32m2.06(-0.14%) ▼\u001b[0m\t| LossAB: \u001b[32m2.71(-16.69%) ▼\u001b[0m\t\n", 475 | "Epoch 157 | LossA: \u001b[32m2.05(-0.41%) ▼\u001b[0m\t| LossAB: \u001b[91m3.07(+13.53%) ▲\u001b[0m\t\n", 476 | "Epoch 158 | LossA: \u001b[32m2.04(-0.25%) ▼\u001b[0m\t| LossAB: \u001b[32m2.37(-22.97%) ▼\u001b[0m\t\n", 477 | "Epoch 159 | LossA: \u001b[32m2.04(-0.28%) ▼\u001b[0m\t| LossAB: \u001b[91m2.78(+17.40%) ▲\u001b[0m\t\n", 478 | "Epoch 160 | LossA: \u001b[32m2.03(-0.27%) ▼\u001b[0m\t| LossAB: \u001b[91m2.88(+3.50%) ▲\u001b[0m\t\n", 479 | "Epoch 161 | LossA: \u001b[91m2.04(+0.15%) ▲\u001b[0m\t| LossAB: \u001b[91m3.22(+12.04%) ▲\u001b[0m\t\n", 480 | "Epoch 162 | LossA: \u001b[32m2.03(-0.34%) ▼\u001b[0m\t| LossAB: \u001b[91m3.45(+6.93%) ▲\u001b[0m\t\n", 481 | "Epoch 163 | LossA: \u001b[32m2.02(-0.55%) ▼\u001b[0m\t| LossAB: \u001b[32m3.22(-6.60%) ▼\u001b[0m\t\n", 482 | "Epoch 164 | LossA: \u001b[32m2.02(-0.07%) ▼\u001b[0m\t| LossAB: \u001b[91m3.66(+13.68%) ▲\u001b[0m\t\n", 483 | "Epoch 165 | LossA: \u001b[91m2.03(+0.48%) ▲\u001b[0m\t| LossAB: \u001b[91m4.09(+11.80%) ▲\u001b[0m\t\n", 484 | "Epoch 166 | LossA: \u001b[32m2.01(-0.67%) ▼\u001b[0m\t| LossAB: \u001b[32m3.47(-15.15%) ▼\u001b[0m\t\n", 485 | "Epoch 167 | LossA: \u001b[32m2.01(-0.15%) ▼\u001b[0m\t| LossAB: \u001b[91m3.62(+4.26%) ▲\u001b[0m\t\n", 486 | "Epoch 168 | LossA: \u001b[32m2.00(-0.51%) ▼\u001b[0m\t| LossAB: \u001b[91m3.86(+6.65%) ▲\u001b[0m\t\n", 487 | "Epoch 169 | LossA: \u001b[32m1.99(-0.52%) ▼\u001b[0m\t| LossAB: \u001b[32m3.52(-8.80%) ▼\u001b[0m\t\n", 488 | "Epoch 170 | LossA: \u001b[32m1.99(-0.18%) ▼\u001b[0m\t| LossAB: \u001b[32m2.68(-23.80%) ▼\u001b[0m\t\n", 489 | "Epoch 171 | LossA: 
\u001b[32m1.98(-0.16%) ▼\u001b[0m\t| LossAB: \u001b[91m3.16(+17.69%) ▲\u001b[0m\t\n", 490 | "Epoch 172 | LossA: \u001b[32m1.98(-0.31%) ▼\u001b[0m\t| LossAB: \u001b[91m4.80(+52.18%) ▲\u001b[0m\t\n", 491 | "Epoch 173 | LossA: \u001b[32m1.97(-0.21%) ▼\u001b[0m\t| LossAB: \u001b[32m3.03(-36.88%) ▼\u001b[0m\t\n", 492 | "Epoch 174 | LossA: \u001b[32m1.97(-0.19%) ▼\u001b[0m\t| LossAB: \u001b[91m3.37(+10.96%) ▲\u001b[0m\t\n", 493 | "Epoch 175 | LossA: \u001b[32m1.96(-0.23%) ▼\u001b[0m\t| LossAB: \u001b[32m2.75(-18.40%) ▼\u001b[0m\t\n", 494 | "Epoch 176 | LossA: \u001b[32m1.96(-0.38%) ▼\u001b[0m\t| LossAB: \u001b[91m2.91(+5.81%) ▲\u001b[0m\t\n", 495 | "Epoch 177 | LossA: \u001b[32m1.95(-0.28%) ▼\u001b[0m\t| LossAB: \u001b[32m2.38(-18.23%) ▼\u001b[0m\t\n", 496 | "Epoch 178 | LossA: \u001b[32m1.95(-0.16%) ▼\u001b[0m\t| LossAB: \u001b[91m3.72(+56.50%) ▲\u001b[0m\t\n", 497 | "Epoch 179 | LossA: \u001b[32m1.94(-0.43%) ▼\u001b[0m\t| LossAB: \u001b[91m4.21(+13.11%) ▲\u001b[0m\t\n", 498 | "Epoch 180 | LossA: \u001b[32m1.93(-0.30%) ▼\u001b[0m\t| LossAB: \u001b[32m2.72(-35.25%) ▼\u001b[0m\t\n", 499 | "Epoch 181 | LossA: \u001b[32m1.93(-0.21%) ▼\u001b[0m\t| LossAB: \u001b[32m2.24(-17.87%) ▼\u001b[0m\t\n", 500 | "Epoch 182 | LossA: \u001b[32m1.92(-0.29%) ▼\u001b[0m\t| LossAB: \u001b[91m2.71(+21.13%) ▲\u001b[0m\t\n", 501 | "Epoch 183 | LossA: \u001b[32m1.92(-0.27%) ▼\u001b[0m\t| LossAB: \u001b[91m2.94(+8.38%) ▲\u001b[0m\t\n", 502 | "Epoch 184 | LossA: \u001b[32m1.92(-0.14%) ▼\u001b[0m\t| LossAB: \u001b[91m3.55(+21.03%) ▲\u001b[0m\t\n", 503 | "Epoch 185 | LossA: \u001b[32m1.91(-0.24%) ▼\u001b[0m\t| LossAB: \u001b[91m3.57(+0.45%) ▲\u001b[0m\t\n", 504 | "Epoch 186 | LossA: \u001b[32m1.90(-0.35%) ▼\u001b[0m\t| LossAB: \u001b[91m3.66(+2.47%) ▲\u001b[0m\t\n", 505 | "Epoch 187 | LossA: \u001b[32m1.90(-0.21%) ▼\u001b[0m\t| LossAB: \u001b[32m3.66(-0.03%) ▼\u001b[0m\t\n", 506 | "Epoch 188 | LossA: \u001b[32m1.90(-0.17%) ▼\u001b[0m\t| LossAB: \u001b[32m3.53(-3.50%) ▼\u001b[0m\t\n", 507 | "Epoch 189 | LossA: \u001b[32m1.89(-0.25%) ▼\u001b[0m\t| LossAB: \u001b[91m4.72(+33.86%) ▲\u001b[0m\t\n", 508 | "Epoch 190 | LossA: \u001b[32m1.89(-0.09%) ▼\u001b[0m\t| LossAB: \u001b[32m2.46(-47.89%) ▼\u001b[0m\t\n", 509 | "Epoch 191 | LossA: \u001b[32m1.88(-0.36%) ▼\u001b[0m\t| LossAB: \u001b[91m4.57(+85.60%) ▲\u001b[0m\t\n", 510 | "Epoch 192 | LossA: \u001b[32m1.88(-0.41%) ▼\u001b[0m\t| LossAB: \u001b[32m4.24(-7.13%) ▼\u001b[0m\t\n", 511 | "Epoch 193 | LossA: \u001b[32m1.88(-0.01%) ▼\u001b[0m\t| LossAB: \u001b[32m3.10(-26.86%) ▼\u001b[0m\t\n", 512 | "Epoch 194 | LossA: \u001b[32m1.87(-0.41%) ▼\u001b[0m\t| LossAB: \u001b[91m3.29(+6.07%) ▲\u001b[0m\t\n", 513 | "Epoch 195 | LossA: \u001b[32m1.87(-0.08%) ▼\u001b[0m\t| LossAB: \u001b[91m3.30(+0.26%) ▲\u001b[0m\t\n", 514 | "Epoch 196 | LossA: \u001b[32m1.86(-0.29%) ▼\u001b[0m\t| LossAB: \u001b[32m2.27(-31.29%) ▼\u001b[0m\t\n", 515 | "Epoch 197 | LossA: \u001b[32m1.85(-0.37%) ▼\u001b[0m\t| LossAB: \u001b[91m2.80(+23.47%) ▲\u001b[0m\t\n", 516 | "Epoch 198 | LossA: \u001b[32m1.85(-0.23%) ▼\u001b[0m\t| LossAB: \u001b[91m3.74(+33.43%) ▲\u001b[0m\t\n", 517 | "637/637 [==============================] - 7s 11ms/step\n", 518 | "Accuracy: 0.7503924369812012\n" 519 | ] 520 | } 521 | ], 522 | "source": [ 523 | "# visualizing losses and accuracy\n", 524 | "train_loss = history.history['loss']\n", 525 | "val_loss = history.history['val_loss']\n", 526 | "\n", 527 | "y_test_oh = dense_to_one_hot(y_test,num_clases=4)\n", 528 | "\n", 529 | "#Observing the losses but can be commented out as it's not mandatory 
\n", 530 | "reporter = lossprettifier.LossPrettifier(show_percentage=True)\n", 531 | "\n", 532 | "for i in range(numEpochs-1):\n", 533 | " reporter(epoch=i, LossA = train_loss[i], LossAB = val_loss[i])\n", 534 | "\n", 535 | "# Model evaluation \n", 536 | "score, acc = model.evaluate(x_test, y_test_oh, batch_size=batch_Size)\n", 537 | "print(\"Accuracy:\", acc)\n", 538 | "\n", 539 | "#if acc>0.675:\n", 540 | "model.save_weights(modelPath)" 541 | ] 542 | }, 543 | { 544 | "cell_type": "code", 545 | "execution_count": 31, 546 | "metadata": {}, 547 | "outputs": [ 548 | { 549 | "name": "stdout", 550 | "output_type": "stream", 551 | "text": [ 552 | " precision recall f1-score support\n", 553 | "\n", 554 | " 0 0.95 0.59 0.73 234\n", 555 | " 1 0.76 0.93 0.84 246\n", 556 | " 2 0.59 0.70 0.64 149\n", 557 | " 3 0.45 0.62 0.53 8\n", 558 | "\n", 559 | " accuracy 0.75 637\n", 560 | " macro avg 0.69 0.71 0.68 637\n", 561 | "weighted avg 0.79 0.75 0.75 637\n", 562 | "\n" 563 | ] 564 | } 565 | ], 566 | "source": [ 567 | "y_pred = model.predict(x_test)\n", 568 | "#y_pred = y_pred.reshape(len(y_test), 4)\n", 569 | "y_pred = np.argmax(y_pred, axis=1)\n", 570 | "\n", 571 | "# Writing results on file\n", 572 | "f = open(resultPath,'a') #create classification report\n", 573 | "f.write(classification_report(y_test, y_pred))\n", 574 | "f.write(str(sklm.cohen_kappa_score(y_test, y_pred))+\",\"+str(acc)+\",\"+str(score)+\"\\n\")\n", 575 | "\n", 576 | "#Print class-wise classification metrics\n", 577 | "print(classification_report(y_test, y_pred))" 578 | ] 579 | }, 580 | { 581 | "cell_type": "code", 582 | "execution_count": null, 583 | "metadata": {}, 584 | "outputs": [], 585 | "source": [] 586 | } 587 | ], 588 | "metadata": { 589 | "kernelspec": { 590 | "display_name": "Python 3", 591 | "language": "python", 592 | "name": "python3" 593 | }, 594 | "language_info": { 595 | "codemirror_mode": { 596 | "name": "ipython", 597 | "version": 3 598 | }, 599 | "file_extension": ".py", 600 | "mimetype": "text/x-python", 601 | "name": "python", 602 | "nbconvert_exporter": "python", 603 | "pygments_lexer": "ipython3", 604 | "version": "3.6.9" 605 | } 606 | }, 607 | "nbformat": 4, 608 | "nbformat_minor": 2 609 | } 610 | -------------------------------------------------------------------------------- /noteboks/VGG-19.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 44, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import cv2\n", 10 | "import numpy as np\n", 11 | "import matplotlib.pyplot as plt\n", 12 | "import os" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": 45, 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "x_train = np.load('data/x_train.npy')\n", 22 | "y_train = np.load('data/y_train.npy')\n", 23 | "x_test = np.load('data/x_test.npy')\n", 24 | "y_test = np.load('data/y_test.npy')" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": 46, 30 | "metadata": {}, 31 | "outputs": [ 32 | { 33 | "name": "stdout", 34 | "output_type": "stream", 35 | "text": [ 36 | "(5344, 224, 224, 3)\n", 37 | "(654, 224, 224, 3)\n" 38 | ] 39 | } 40 | ], 41 | "source": [ 42 | "print(x_train.shape)\n", 43 | "print(x_test.shape)" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": 47, 49 | "metadata": {}, 50 | "outputs": [ 51 | { 52 | "name": "stdout", 53 | "output_type": "stream", 54 | "text": [ 55 | "(5344,)\n", 56 | "(654,)\n" 57 | ] 58 | } 59 | ], 
60 | "source": [ 61 | "print(y_train.shape)\n", 62 | "print(y_test.shape)" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": null, 68 | "metadata": {}, 69 | "outputs": [], 70 | "source": [ 71 | "#x_train /= 255????" 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": 48, 77 | "metadata": {}, 78 | "outputs": [], 79 | "source": [ 80 | "from __future__ import print_function\n", 81 | "\n", 82 | "import numpy as np\n", 83 | "from random import shuffle\n", 84 | "import time\n", 85 | "import csv\n", 86 | "from PIL import Image\n", 87 | "import os\n", 88 | "import tensorflow as tf\n", 89 | "import keras\n", 90 | "from keras.callbacks import EarlyStopping, LearningRateScheduler\n", 91 | "from keras import initializers\n", 92 | "from keras.optimizers import SGD\n", 93 | "from keras.preprocessing import sequence\n", 94 | "from keras.utils import np_utils\n", 95 | "from keras.models import Sequential,load_model,Model\n", 96 | "from keras.layers import Dense, Dropout, Activation, Flatten\n", 97 | "from keras.layers import *\n", 98 | "from keras.callbacks import CSVLogger\n", 99 | "from keras import callbacks\n", 100 | "from keras.preprocessing.image import ImageDataGenerator\n", 101 | "\n", 102 | "from sklearn.metrics import classification_report\n", 103 | "from sklearn.model_selection import train_test_split\n", 104 | "import sklearn.metrics as sklm\n", 105 | "import lossprettifier\n", 106 | "from VGG import *" 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": 49, 112 | "metadata": {}, 113 | "outputs": [], 114 | "source": [ 115 | "# for reproducibility\n", 116 | "np.random.seed(3768)\n", 117 | "\n", 118 | "# use this environment flag to change which GPU to use \n", 119 | "os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\"\n", 120 | "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\" # specify which GPU(s) to be used\n", 121 | "\n", 122 | "#Get TensorFlow session\n", 123 | "def get_session(): \n", 124 | " config = tf.ConfigProto() \n", 125 | " config.gpu_options.allow_growth = True \n", 126 | " return tf.Session(config=config) \n", 127 | " \n", 128 | "# One hot encoding of labels \n", 129 | "def dense_to_one_hot(labels_dense,num_clases=4):\n", 130 | " return np.eye(num_clases)[labels_dense]" 131 | ] 132 | }, 133 | { 134 | "cell_type": "code", 135 | "execution_count": 50, 136 | "metadata": {}, 137 | "outputs": [], 138 | "source": [ 139 | "# Preparing training and test sets\n", 140 | "x_train, x_valid, y_train, y_valid = train_test_split(x_train, y_train, test_size=0.10, random_state=42)" 141 | ] 142 | }, 143 | { 144 | "cell_type": "code", 145 | "execution_count": 51, 146 | "metadata": {}, 147 | "outputs": [], 148 | "source": [ 149 | "y_train = dense_to_one_hot(y_train,num_clases=4)\n", 150 | "y_valid= dense_to_one_hot(y_valid,num_clases=4)" 151 | ] 152 | }, 153 | { 154 | "cell_type": "code", 155 | "execution_count": 52, 156 | "metadata": {}, 157 | "outputs": [ 158 | { 159 | "name": "stderr", 160 | "output_type": "stream", 161 | "text": [ 162 | "/home/reza/.local/lib/python3.6/site-packages/keras_preprocessing/image/image_data_generator.py:348: UserWarning: This ImageDataGenerator specifies `featurewise_std_normalization`, which overrides setting of `featurewise_center`.\n", 163 | " warnings.warn('This ImageDataGenerator specifies '\n" 164 | ] 165 | } 166 | ], 167 | "source": [ 168 | "#Image data generation for the training \n", 169 | "datagen = ImageDataGenerator(\n", 170 | " featurewise_center = False, \n", 171 | " samplewise_center = False, # 
set each sample mean to 0\n", 172 | " featurewise_std_normalization = True, \n", 173 | " samplewise_std_normalization = False) \n", 174 | "\n", 175 | "datagen.fit(x_train) \n", 176 | "for i in range(len(x_test)):\n", 177 | " x_test[i] = datagen.standardize(x_test[i])" 178 | ] 179 | }, 180 | { 181 | "cell_type": "code", 182 | "execution_count": 58, 183 | "metadata": {}, 184 | "outputs": [ 185 | { 186 | "name": "stdout", 187 | "output_type": "stream", 188 | "text": [ 189 | "Epoch 1/50\n", 190 | "32/32 [==============================] - 20s 627ms/step - loss: 1.1260 - accuracy: 0.4854 - val_loss: 0.8739 - val_accuracy: 0.5586\n", 191 | "Epoch 2/50\n", 192 | "32/32 [==============================] - 19s 602ms/step - loss: 0.9745 - accuracy: 0.5459 - val_loss: 0.7980 - val_accuracy: 0.6620\n", 193 | "Epoch 3/50\n", 194 | "32/32 [==============================] - 19s 591ms/step - loss: 0.8714 - accuracy: 0.6054 - val_loss: 0.6983 - val_accuracy: 0.6700\n", 195 | "Epoch 4/50\n", 196 | "32/32 [==============================] - 20s 609ms/step - loss: 0.7564 - accuracy: 0.6807 - val_loss: 0.4598 - val_accuracy: 0.7654\n", 197 | "Epoch 5/50\n", 198 | "32/32 [==============================] - 19s 604ms/step - loss: 0.7233 - accuracy: 0.7168 - val_loss: 0.9107 - val_accuracy: 0.7256\n", 199 | "Epoch 6/50\n", 200 | "32/32 [==============================] - 19s 592ms/step - loss: 0.6709 - accuracy: 0.7083 - val_loss: 0.6225 - val_accuracy: 0.7316\n", 201 | "Epoch 7/50\n", 202 | "32/32 [==============================] - 19s 593ms/step - loss: 0.6825 - accuracy: 0.7139 - val_loss: 0.4330 - val_accuracy: 0.7654\n", 203 | "Epoch 8/50\n", 204 | "32/32 [==============================] - 19s 605ms/step - loss: 0.7340 - accuracy: 0.7041 - val_loss: 0.6042 - val_accuracy: 0.7097\n", 205 | "Epoch 9/50\n", 206 | "32/32 [==============================] - 19s 604ms/step - loss: 0.7123 - accuracy: 0.6992 - val_loss: 0.3131 - val_accuracy: 0.7416\n", 207 | "Epoch 10/50\n", 208 | "32/32 [==============================] - 20s 621ms/step - loss: 0.6631 - accuracy: 0.7266 - val_loss: 0.7513 - val_accuracy: 0.7555\n", 209 | "Epoch 11/50\n", 210 | "32/32 [==============================] - 20s 612ms/step - loss: 0.6787 - accuracy: 0.7041 - val_loss: 0.5781 - val_accuracy: 0.7575\n", 211 | "Epoch 12/50\n", 212 | "32/32 [==============================] - 20s 612ms/step - loss: 0.6753 - accuracy: 0.7285 - val_loss: 0.6944 - val_accuracy: 0.7654\n", 213 | "Epoch 13/50\n", 214 | "32/32 [==============================] - 19s 608ms/step - loss: 0.6375 - accuracy: 0.7413 - val_loss: 0.4953 - val_accuracy: 0.7674\n", 215 | "Epoch 14/50\n", 216 | "32/32 [==============================] - 19s 601ms/step - loss: 0.6257 - accuracy: 0.7314 - val_loss: 0.4217 - val_accuracy: 0.7614\n", 217 | "Epoch 15/50\n", 218 | "32/32 [==============================] - 19s 592ms/step - loss: 0.5966 - accuracy: 0.7532 - val_loss: 0.5771 - val_accuracy: 0.7694\n", 219 | "Epoch 16/50\n", 220 | "32/32 [==============================] - 19s 606ms/step - loss: 0.5862 - accuracy: 0.7588 - val_loss: 0.5979 - val_accuracy: 0.7773\n", 221 | "Epoch 17/50\n", 222 | "32/32 [==============================] - 19s 608ms/step - loss: 0.6080 - accuracy: 0.7471 - val_loss: 0.5961 - val_accuracy: 0.7594\n", 223 | "Epoch 18/50\n", 224 | "32/32 [==============================] - 19s 600ms/step - loss: 0.6193 - accuracy: 0.7266 - val_loss: 0.6281 - val_accuracy: 0.7617\n", 225 | "Epoch 19/50\n", 226 | "32/32 [==============================] - 20s 615ms/step - loss: 0.6069 - 
accuracy: 0.7363 - val_loss: 0.4915 - val_accuracy: 0.7256\n", 227 | "Epoch 20/50\n", 228 | "32/32 [==============================] - 20s 615ms/step - loss: 0.5706 - accuracy: 0.7588 - val_loss: 0.4789 - val_accuracy: 0.7873\n", 229 | "Epoch 21/50\n", 230 | "32/32 [==============================] - 19s 587ms/step - loss: 0.5736 - accuracy: 0.7646 - val_loss: 0.5547 - val_accuracy: 0.7694\n", 231 | "Epoch 22/50\n", 232 | "32/32 [==============================] - 19s 602ms/step - loss: 0.6193 - accuracy: 0.7373 - val_loss: 0.6269 - val_accuracy: 0.7952\n", 233 | "Epoch 23/50\n", 234 | "32/32 [==============================] - 19s 601ms/step - loss: 0.5919 - accuracy: 0.7588 - val_loss: 0.7170 - val_accuracy: 0.7893\n", 235 | "Epoch 24/50\n", 236 | "32/32 [==============================] - 19s 592ms/step - loss: 0.5902 - accuracy: 0.7532 - val_loss: 0.4764 - val_accuracy: 0.7674\n", 237 | "Epoch 25/50\n", 238 | "32/32 [==============================] - 20s 620ms/step - loss: 0.5882 - accuracy: 0.7559 - val_loss: 0.6773 - val_accuracy: 0.8131\n", 239 | "Epoch 26/50\n", 240 | "32/32 [==============================] - 19s 597ms/step - loss: 0.5588 - accuracy: 0.7695 - val_loss: 0.6876 - val_accuracy: 0.7873\n", 241 | "Epoch 27/50\n", 242 | "32/32 [==============================] - 19s 606ms/step - loss: 0.5676 - accuracy: 0.7656 - val_loss: 0.3456 - val_accuracy: 0.7694\n", 243 | "Epoch 28/50\n", 244 | "32/32 [==============================] - 20s 617ms/step - loss: 0.6506 - accuracy: 0.7213 - val_loss: 0.6469 - val_accuracy: 0.7515\n", 245 | "Epoch 29/50\n", 246 | "32/32 [==============================] - 19s 603ms/step - loss: 0.6499 - accuracy: 0.7113 - val_loss: 0.3253 - val_accuracy: 0.7753\n", 247 | "Epoch 30/50\n", 248 | "32/32 [==============================] - 19s 597ms/step - loss: 0.5674 - accuracy: 0.7656 - val_loss: 0.6697 - val_accuracy: 0.7674\n", 249 | "Epoch 31/50\n", 250 | "32/32 [==============================] - 20s 612ms/step - loss: 0.6147 - accuracy: 0.7559 - val_loss: 0.6267 - val_accuracy: 0.7654\n", 251 | "Epoch 32/50\n", 252 | "32/32 [==============================] - 20s 610ms/step - loss: 0.5915 - accuracy: 0.7295 - val_loss: 0.6370 - val_accuracy: 0.7654\n", 253 | "Epoch 33/50\n", 254 | "32/32 [==============================] - 19s 588ms/step - loss: 0.5373 - accuracy: 0.7666 - val_loss: 0.5714 - val_accuracy: 0.7714\n", 255 | "Epoch 34/50\n", 256 | "32/32 [==============================] - 20s 610ms/step - loss: 0.5659 - accuracy: 0.7637 - val_loss: 0.6328 - val_accuracy: 0.7674\n", 257 | "Epoch 35/50\n", 258 | "32/32 [==============================] - 20s 613ms/step - loss: 0.5549 - accuracy: 0.7562 - val_loss: 1.1637 - val_accuracy: 0.7734\n", 259 | "Epoch 36/50\n", 260 | "32/32 [==============================] - 20s 620ms/step - loss: 0.5715 - accuracy: 0.7412 - val_loss: 0.3932 - val_accuracy: 0.7932\n", 261 | "Epoch 37/50\n", 262 | "32/32 [==============================] - 20s 614ms/step - loss: 0.5887 - accuracy: 0.7559 - val_loss: 0.5093 - val_accuracy: 0.7674\n", 263 | "Epoch 38/50\n", 264 | "32/32 [==============================] - 19s 602ms/step - loss: 0.5466 - accuracy: 0.7539 - val_loss: 0.5148 - val_accuracy: 0.7495\n", 265 | "Epoch 39/50\n", 266 | "32/32 [==============================] - 20s 614ms/step - loss: 0.4989 - accuracy: 0.7959 - val_loss: 0.4625 - val_accuracy: 0.7873\n", 267 | "Epoch 40/50\n", 268 | "32/32 [==============================] - 20s 618ms/step - loss: 0.5473 - accuracy: 0.7773 - val_loss: 0.7144 - val_accuracy: 0.7853\n", 269 | 
"Epoch 41/50\n", 270 | "32/32 [==============================] - 19s 584ms/step - loss: 0.5425 - accuracy: 0.7637 - val_loss: 0.7769 - val_accuracy: 0.7495\n", 271 | "Epoch 42/50\n", 272 | "32/32 [==============================] - 19s 609ms/step - loss: 0.5811 - accuracy: 0.7490 - val_loss: 0.6224 - val_accuracy: 0.7833\n", 273 | "Epoch 43/50\n", 274 | "32/32 [==============================] - 19s 587ms/step - loss: 0.5636 - accuracy: 0.7453 - val_loss: 0.7475 - val_accuracy: 0.7416\n", 275 | "Epoch 44/50\n", 276 | "32/32 [==============================] - 20s 619ms/step - loss: 0.5788 - accuracy: 0.7529 - val_loss: 0.5530 - val_accuracy: 0.7893\n", 277 | "Epoch 45/50\n", 278 | "32/32 [==============================] - 19s 607ms/step - loss: 0.5080 - accuracy: 0.7812 - val_loss: 0.6491 - val_accuracy: 0.7813\n", 279 | "Epoch 46/50\n", 280 | "32/32 [==============================] - 20s 613ms/step - loss: 0.5433 - accuracy: 0.7764 - val_loss: 0.5392 - val_accuracy: 0.7893\n", 281 | "Epoch 47/50\n", 282 | "32/32 [==============================] - 20s 612ms/step - loss: 0.4825 - accuracy: 0.7783 - val_loss: 0.6117 - val_accuracy: 0.7913\n", 283 | "Epoch 48/50\n", 284 | "32/32 [==============================] - 20s 618ms/step - loss: 0.4966 - accuracy: 0.7705 - val_loss: 0.4517 - val_accuracy: 0.7972\n", 285 | "Epoch 49/50\n", 286 | "32/32 [==============================] - 19s 602ms/step - loss: 0.5399 - accuracy: 0.7764 - val_loss: 0.5197 - val_accuracy: 0.8012\n", 287 | "Epoch 50/50\n", 288 | "32/32 [==============================] - 19s 593ms/step - loss: 0.5231 - accuracy: 0.7832 - val_loss: 0.5512 - val_accuracy: 0.7773\n" 289 | ] 290 | } 291 | ], 292 | "source": [ 293 | "#Defining hyperparameters\n", 294 | "batch_Size = 32\n", 295 | "steps_Per_Epoch = 32\n", 296 | "numEpochs = 50\n", 297 | "\n", 298 | "#Instantating VGG19 model\n", 299 | "model = VGG19((224,224,3),4) #VGG19_dense for revised VGG19, VGG19 for VGG19. 
Please pay attention to VGG16(), chnage the input shape and class number in VGG.py.\n", 300 | "\n", 301 | "#Creating an optimizers\n", 302 | "adaDelta = keras.optimizers.Adadelta(lr=1.0, rho=0.95)\n", 303 | "sgd = SGD(lr=0.01, decay=1e-6, momentum=0.95, nesterov=True)\n", 304 | "model.compile(optimizer = sgd , loss = 'categorical_crossentropy', metrics = ['accuracy'])\n", 305 | "\n", 306 | "#Creating early stopping \n", 307 | "earlystop = EarlyStopping(monitor = 'val_accuracy', min_delta = 0, patience = 50, verbose = 1, mode = 'auto', restore_best_weights = True) \n", 308 | "\n", 309 | "train_generator = datagen.flow(x_train, y_train, batch_size = batch_Size)\n", 310 | "validation_generator = datagen.flow(x_valid, y_valid, batch_size = batch_Size)\n", 311 | "\n", 312 | "# Model training\n", 313 | "history = model.fit_generator(\n", 314 | " train_generator,\n", 315 | " steps_per_epoch = steps_Per_Epoch,\n", 316 | " validation_data = validation_generator, \n", 317 | " validation_steps = 16,\n", 318 | " epochs = numEpochs,\n", 319 | " shuffle = True, \n", 320 | " verbose = 1)" 321 | ] 322 | }, 323 | { 324 | "cell_type": "code", 325 | "execution_count": 59, 326 | "metadata": {}, 327 | "outputs": [], 328 | "source": [ 329 | "modelPath = \"VGG19_COVID19.h5\"\n", 330 | "resultPath = 'VGG19_COVID19.txt'" 331 | ] 332 | }, 333 | { 334 | "cell_type": "code", 335 | "execution_count": 60, 336 | "metadata": {}, 337 | "outputs": [ 338 | { 339 | "name": "stdout", 340 | "output_type": "stream", 341 | "text": [ 342 | "Epoch 0 | LossA: 1.13(+0.00%) \u001b[0m\t| LossAB: 0.87(+0.00%) \u001b[0m\t\n", 343 | "Epoch 1 | LossA: \u001b[32m0.97(-13.46%) ▼\u001b[0m\t| LossAB: \u001b[32m0.80(-8.68%) ▼\u001b[0m\t\n", 344 | "Epoch 2 | LossA: \u001b[32m0.88(-9.63%) ▼\u001b[0m\t| LossAB: \u001b[32m0.70(-12.50%) ▼\u001b[0m\t\n", 345 | "Epoch 3 | LossA: \u001b[32m0.76(-14.12%) ▼\u001b[0m\t| LossAB: \u001b[32m0.46(-34.15%) ▼\u001b[0m\t\n", 346 | "Epoch 4 | LossA: \u001b[32m0.72(-4.37%) ▼\u001b[0m\t| LossAB: \u001b[91m0.91(+98.05%) ▲\u001b[0m\t\n", 347 | "Epoch 5 | LossA: \u001b[32m0.68(-5.74%) ▼\u001b[0m\t| LossAB: \u001b[32m0.62(-31.65%) ▼\u001b[0m\t\n", 348 | "Epoch 6 | LossA: \u001b[91m0.68(+0.10%) ▲\u001b[0m\t| LossAB: \u001b[32m0.43(-30.45%) ▼\u001b[0m\t\n", 349 | "Epoch 7 | LossA: \u001b[91m0.73(+7.55%) ▲\u001b[0m\t| LossAB: \u001b[91m0.60(+39.55%) ▲\u001b[0m\t\n", 350 | "Epoch 8 | LossA: \u001b[32m0.71(-2.96%) ▼\u001b[0m\t| LossAB: \u001b[32m0.31(-48.18%) ▼\u001b[0m\t\n", 351 | "Epoch 9 | LossA: \u001b[32m0.66(-6.91%) ▼\u001b[0m\t| LossAB: \u001b[91m0.75(+139.94%) ▲\u001b[0m\t\n", 352 | "Epoch 10 | LossA: \u001b[91m0.68(+2.35%) ▲\u001b[0m\t| LossAB: \u001b[32m0.58(-23.05%) ▼\u001b[0m\t\n", 353 | "Epoch 11 | LossA: \u001b[32m0.68(-0.49%) ▼\u001b[0m\t| LossAB: \u001b[91m0.69(+20.11%) ▲\u001b[0m\t\n", 354 | "Epoch 12 | LossA: \u001b[32m0.64(-5.25%) ▼\u001b[0m\t| LossAB: \u001b[32m0.50(-28.67%) ▼\u001b[0m\t\n", 355 | "Epoch 13 | LossA: \u001b[32m0.63(-2.21%) ▼\u001b[0m\t| LossAB: \u001b[32m0.42(-14.85%) ▼\u001b[0m\t\n", 356 | "Epoch 14 | LossA: \u001b[32m0.60(-4.87%) ▼\u001b[0m\t| LossAB: \u001b[91m0.58(+36.84%) ▲\u001b[0m\t\n", 357 | "Epoch 15 | LossA: \u001b[32m0.59(-1.51%) ▼\u001b[0m\t| LossAB: \u001b[91m0.60(+3.61%) ▲\u001b[0m\t\n", 358 | "Epoch 16 | LossA: \u001b[91m0.61(+3.71%) ▲\u001b[0m\t| LossAB: \u001b[32m0.60(-0.29%) ▼\u001b[0m\t\n", 359 | "Epoch 17 | LossA: \u001b[91m0.62(+1.86%) ▲\u001b[0m\t| LossAB: \u001b[91m0.63(+5.36%) ▲\u001b[0m\t\n", 360 | "Epoch 18 | LossA: \u001b[32m0.61(-2.00%) ▼\u001b[0m\t| LossAB: 
\u001b[32m0.49(-21.74%) ▼\u001b[0m\t\n", 361 | "Epoch 19 | LossA: \u001b[32m0.57(-5.98%) ▼\u001b[0m\t| LossAB: \u001b[32m0.48(-2.56%) ▼\u001b[0m\t\n", 362 | "Epoch 20 | LossA: \u001b[91m0.57(+0.54%) ▲\u001b[0m\t| LossAB: \u001b[91m0.55(+15.82%) ▲\u001b[0m\t\n", 363 | "Epoch 21 | LossA: \u001b[91m0.62(+7.96%) ▲\u001b[0m\t| LossAB: \u001b[91m0.63(+13.00%) ▲\u001b[0m\t\n", 364 | "Epoch 22 | LossA: \u001b[32m0.59(-4.43%) ▼\u001b[0m\t| LossAB: \u001b[91m0.72(+14.39%) ▲\u001b[0m\t\n", 365 | "Epoch 23 | LossA: \u001b[32m0.59(-0.60%) ▼\u001b[0m\t| LossAB: \u001b[32m0.48(-33.56%) ▼\u001b[0m\t\n", 366 | "Epoch 24 | LossA: \u001b[32m0.59(-0.02%) ▼\u001b[0m\t| LossAB: \u001b[91m0.68(+42.17%) ▲\u001b[0m\t\n", 367 | "Epoch 25 | LossA: \u001b[32m0.56(-5.00%) ▼\u001b[0m\t| LossAB: \u001b[91m0.69(+1.53%) ▲\u001b[0m\t\n", 368 | "Epoch 26 | LossA: \u001b[91m0.57(+1.58%) ▲\u001b[0m\t| LossAB: \u001b[32m0.35(-49.74%) ▼\u001b[0m\t\n", 369 | "Epoch 27 | LossA: \u001b[91m0.65(+14.08%) ▲\u001b[0m\t| LossAB: \u001b[91m0.65(+87.21%) ▲\u001b[0m\t\n", 370 | "Epoch 28 | LossA: \u001b[91m0.65(+0.16%) ▲\u001b[0m\t| LossAB: \u001b[32m0.33(-49.71%) ▼\u001b[0m\t\n", 371 | "Epoch 29 | LossA: \u001b[32m0.57(-12.52%) ▼\u001b[0m\t| LossAB: \u001b[91m0.67(+105.85%) ▲\u001b[0m\t\n", 372 | "Epoch 30 | LossA: \u001b[91m0.61(+8.35%) ▲\u001b[0m\t| LossAB: \u001b[32m0.63(-6.41%) ▼\u001b[0m\t\n", 373 | "Epoch 31 | LossA: \u001b[32m0.59(-3.78%) ▼\u001b[0m\t| LossAB: \u001b[91m0.64(+1.64%) ▲\u001b[0m\t\n", 374 | "Epoch 32 | LossA: \u001b[32m0.54(-9.16%) ▼\u001b[0m\t| LossAB: \u001b[32m0.57(-10.30%) ▼\u001b[0m\t\n", 375 | "Epoch 33 | LossA: \u001b[91m0.57(+5.32%) ▲\u001b[0m\t| LossAB: \u001b[91m0.63(+10.75%) ▲\u001b[0m\t\n", 376 | "Epoch 34 | LossA: \u001b[32m0.56(-1.39%) ▼\u001b[0m\t| LossAB: \u001b[91m1.16(+83.89%) ▲\u001b[0m\t\n", 377 | "Epoch 35 | LossA: \u001b[91m0.57(+2.40%) ▲\u001b[0m\t| LossAB: \u001b[32m0.39(-66.21%) ▼\u001b[0m\t\n", 378 | "Epoch 36 | LossA: \u001b[91m0.59(+3.01%) ▲\u001b[0m\t| LossAB: \u001b[91m0.51(+29.52%) ▲\u001b[0m\t\n", 379 | "Epoch 37 | LossA: \u001b[32m0.55(-7.14%) ▼\u001b[0m\t| LossAB: \u001b[91m0.51(+1.09%) ▲\u001b[0m\t\n", 380 | "Epoch 38 | LossA: \u001b[32m0.50(-8.72%) ▼\u001b[0m\t| LossAB: \u001b[32m0.46(-10.17%) ▼\u001b[0m\t\n", 381 | "Epoch 39 | LossA: \u001b[91m0.55(+9.69%) ▲\u001b[0m\t| LossAB: \u001b[91m0.71(+54.47%) ▲\u001b[0m\t\n", 382 | "Epoch 40 | LossA: \u001b[32m0.54(-0.88%) ▼\u001b[0m\t| LossAB: \u001b[91m0.78(+8.76%) ▲\u001b[0m\t\n", 383 | "Epoch 41 | LossA: \u001b[91m0.58(+7.11%) ▲\u001b[0m\t| LossAB: \u001b[32m0.62(-19.89%) ▼\u001b[0m\t\n", 384 | "Epoch 42 | LossA: \u001b[32m0.57(-2.44%) ▼\u001b[0m\t| LossAB: \u001b[91m0.75(+20.10%) ▲\u001b[0m\t\n", 385 | "Epoch 43 | LossA: \u001b[91m0.58(+2.09%) ▲\u001b[0m\t| LossAB: \u001b[32m0.55(-26.02%) ▼\u001b[0m\t\n", 386 | "Epoch 44 | LossA: \u001b[32m0.51(-11.66%) ▼\u001b[0m\t| LossAB: \u001b[91m0.65(+17.37%) ▲\u001b[0m\t\n", 387 | "Epoch 45 | LossA: \u001b[91m0.54(+6.27%) ▲\u001b[0m\t| LossAB: \u001b[32m0.54(-16.94%) ▼\u001b[0m\t\n", 388 | "Epoch 46 | LossA: \u001b[32m0.48(-11.19%) ▼\u001b[0m\t| LossAB: \u001b[91m0.61(+13.45%) ▲\u001b[0m\t\n", 389 | "Epoch 47 | LossA: \u001b[91m0.50(+2.92%) ▲\u001b[0m\t| LossAB: \u001b[32m0.45(-26.15%) ▼\u001b[0m\t\n", 390 | "Epoch 48 | LossA: \u001b[91m0.54(+8.72%) ▲\u001b[0m\t| LossAB: \u001b[91m0.52(+15.04%) ▲\u001b[0m\t\n", 391 | "654/654 [==============================] - 5s 8ms/step\n", 392 | "Accuracy: 0.6727828979492188\n" 393 | ] 394 | } 395 | ], 396 | "source": [ 397 | "#y_test_oh = 
dense_to_one_hot(y_test, num_clases=4)\n", 398 | "\n", 399 | "# visualizing losses and accuracy\n", 400 | "train_loss = history.history['loss']\n", 401 | "val_loss = history.history['val_loss']\n", 402 | "\n", 403 | "#Observing the losses but can be commented out as it's not mandatory \n", 404 | "reporter = lossprettifier.LossPrettifier(show_percentage=True)\n", 405 | "\n", 406 | "for i in range(numEpochs-1):\n", 407 | " reporter(epoch=i, LossA = train_loss[i], LossAB = val_loss[i])\n", 408 | "\n", 409 | "# Model evaluation \n", 410 | "score, acc = model.evaluate(x_test, y_test_oh, batch_size=batch_Size)\n", 411 | "print(\"Accuracy:\", acc)\n", 412 | "\n", 413 | "#if acc>0.675:\n", 414 | "model.save_weights(modelPath)" 415 | ] 416 | }, 417 | { 418 | "cell_type": "code", 419 | "execution_count": 61, 420 | "metadata": {}, 421 | "outputs": [ 422 | { 423 | "name": "stdout", 424 | "output_type": "stream", 425 | "text": [ 426 | " precision recall f1-score support\n", 427 | "\n", 428 | " 0 0.91 0.35 0.51 234\n", 429 | " 1 0.62 0.96 0.75 246\n", 430 | " 2 0.74 0.73 0.73 149\n", 431 | " 3 0.36 0.56 0.44 25\n", 432 | "\n", 433 | " accuracy 0.67 654\n", 434 | " macro avg 0.66 0.65 0.61 654\n", 435 | "weighted avg 0.74 0.67 0.65 654\n", 436 | "\n" 437 | ] 438 | } 439 | ], 440 | "source": [ 441 | "y_pred = model.predict(x_test)\n", 442 | "y_pred = y_pred.reshape(len(y_test), 4)\n", 443 | "y_pred = np.argmax(y_pred, axis=1)\n", 444 | "\n", 445 | "# Writing results on file\n", 446 | "f = open(resultPath,'a') #create classification report\n", 447 | "f.write(classification_report(y_test, y_pred))\n", 448 | "f.write(str(sklm.cohen_kappa_score(y_test, y_pred))+\",\"+str(acc)+\",\"+str(score)+\"\\n\")\n", 449 | "\n", 450 | "#Print class-wise classification metrics\n", 451 | "print(classification_report(y_test, y_pred))" 452 | ] 453 | }, 454 | { 455 | "cell_type": "code", 456 | "execution_count": null, 457 | "metadata": {}, 458 | "outputs": [], 459 | "source": [] 460 | } 461 | ], 462 | "metadata": { 463 | "kernelspec": { 464 | "display_name": "Python 3", 465 | "language": "python", 466 | "name": "python3" 467 | }, 468 | "language_info": { 469 | "codemirror_mode": { 470 | "name": "ipython", 471 | "version": 3 472 | }, 473 | "file_extension": ".py", 474 | "mimetype": "text/x-python", 475 | "name": "python", 476 | "nbconvert_exporter": "python", 477 | "pygments_lexer": "ipython3", 478 | "version": "3.6.9" 479 | } 480 | }, 481 | "nbformat": 4, 482 | "nbformat_minor": 2 483 | } 484 | -------------------------------------------------------------------------------- /utils/CXR_preprocessing.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from numpy import * 3 | import os 4 | import optparse 5 | from PIL import Image 6 | import png 7 | import pylab 8 | from scipy import misc, ndimage 9 | from medpy.filter.smoothing import anisotropic_diffusion 10 | from sklearn import preprocessing 11 | 12 | def histogram_t(tb): 13 | ''' 14 | Histogram equalization 15 | ''' 16 | totalpixel=0 17 | maptb=[] 18 | count=len(tb) 19 | for i in range(count): 20 | totalpixel+=tb[i] 21 | maptb.append(totalpixel) 22 | 23 | for i in range(count): 24 | maptb[i]=int(round((maptb[i]*(count-1))/totalpixel)) 25 | 26 | def histogram(light): 27 | return maptb[light] 28 | return histogram 29 | 30 | def remove_annotation(img): 31 | ''' 32 | Remove textual artifacts from X-ray imagese.g., a large number of images indicate the right side of the chest with a white `R' character, `L' for left. 
33 | :param img: Numpy array of image 34 | :return: Array of image with possible characters removed and inpainted. 35 | ''' 36 | mask = cv2.threshold(img, 224, 224, cv2.THRESH_BINARY)[1][:, :, 0].astype(np.uint8) 37 | img = img.astype(np.uint8) 38 | result = cv2.inpaint(img, mask, 10, cv2.INPAINT_NS).astype(np.float32) 39 | return result 40 | 41 | imagepath='cleanedCXR/' #image folder path, please pay attention to here images are already renamed with format "patient_direction". 42 | files = os.listdir(imagepath) 43 | 44 | for fi in files: 45 | fi_d = os.path.join(imagepath,fi) 46 | img=Image.open(fi_d) 47 | 48 | if 'P' in fi: #for chest x-rays 49 | (tempx1,tempy1)=img.size #original image separation 50 | width=tempx1//2 51 | left=img.crop((0,0,width,tempy1)) 52 | right=img.crop((width,0,tempx1,tempy1)) 53 | imgl = Image.new('RGBA',(width,tempy1)) 54 | imgr = Image.new('RGBA',(width,tempy1)) 55 | imgl.paste(left) 56 | imgr.paste(right) 57 | imgr=imgr.transpose(Image.FLIP_LEFT_RIGHT) #image flipping 58 | outl=imgl.resize((1023,2047),Image.ANTIALIAS) #image resize 59 | outr=imgr.resize((1023,2047),Image.ANTIALIAS) 60 | hisl=outl.histogram() #histogram 61 | hisfuncl=histogram_t(hisl) 62 | iml=outl.point(hisfuncl) 63 | 64 | hisr=outr.histogram() 65 | hisfuncr=histogram_t(hisr) 66 | imr=outr.point(hisfuncr) 67 | 68 | ir = anisotropic_diffusion(np.array(imr)) #noise removal 69 | ir = remove_annotation(ir) 70 | 71 | il = anisotropic_diffusion(np.array(iml)) 72 | il = remove_annotation(il) 73 | 74 | temp='cleanedCXR/temp/'+fi[0:4] #processed image path for saving 75 | imagel = Image.fromarray(il.astype('uint8')).convert("L") 76 | imager = Image.fromarray(ir.astype('uint8')).convert("L") 77 | 78 | imager.save(temp+'_PR.png') 79 | -------------------------------------------------------------------------------- /utils/class_balancing.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import yaml 3 | import os 4 | import datetime 5 | import random 6 | import dill 7 | import numpy as np 8 | from imblearn.over_sampling import RandomOverSampler 9 | from math import ceil 10 | from Preprocessing.CXR_preprocessing import remove_annotation 11 | 12 | def get_class_weights(histogram, class_multiplier=None): 13 | ''' 14 | Computes weights for each class to be applied in the loss function during training. 15 | :param histogram: A list depicting the number of each item in different class 16 | :param class_multiplier: List of values to multiply the calculated class weights by. For further control of class weighting. 
17 | :return: A dictionary containing weights for each class 18 | ''' 19 | weights = [None] * len(histogram) 20 | for i in range(len(histogram)): 21 | weights[i] = (1.0 / len(histogram)) * sum(histogram) / histogram[i] 22 | class_weight = {i: weights[i] for i in range(len(histogram))} 23 | if class_multiplier is not None: 24 | class_weight = [class_weight[i] * class_multiplier[i] for i in range(len(histogram))] 25 | print("Class weights: ", class_weight) 26 | return class_weight 27 | 28 | 29 | def random_minority_oversample(train_set): 30 | ''' 31 | Oversample the minority class using the specified algorithm 32 | :param train_set: Training set image file names and labels 33 | :return: A new training set containing oversampled examples 34 | ''' 35 | X_train = train_set[[x for x in train_set.columns if x != 'label']].to_numpy() 36 | if X_train.shape[1] == 1: 37 | X_train = np.expand_dims(X_train, axis=-1) 38 | Y_train = train_set['label'].to_numpy() 39 | sampler = RandomOverSampler(random_state=np.random.randint(0, high=1000)) 40 | X_resampled, Y_resampled = sampler.fit_resample(X_train, Y_train) 41 | filenames = X_resampled[:, 1] # Filename is in second column 42 | label_strs = X_resampled[:, 2] # Class name is in second column 43 | print("Train set shape before oversampling: ", X_train.shape, " Train set shape after resampling: ", X_resampled.shape) 44 | train_set_resampled = pd.DataFrame({'filename': filenames, 'label': Y_resampled, 'label_str': label_strs}) 45 | return train_set_resampled 46 | -------------------------------------------------------------------------------- /utils/gradcamutils.py: -------------------------------------------------------------------------------- 1 | from scipy.ndimage.interpolation import zoom 2 | import numpy as np 3 | 4 | from keras.backend import tensorflow_backend 5 | from keras import backend as K 6 | from keras.preprocessing.image import load_img, img_to_array 7 | 8 | import matplotlib.pyplot as plt 9 | 10 | def grad_cam(input_model, image, layer_name,H=360,W=360): 11 | cls = np.argmax(input_model.predict(image)) 12 | print(cls) 13 | def normalize(x): 14 | """Utility function to normalize a tensor by its L2 norm""" 15 | return (x + 1e-10) / (K.sqrt(K.mean(K.square(x))) + 1e-10) 16 | """GradCAM method for visualizing input saliency.""" 17 | y_c = input_model.output[0, cls] 18 | conv_output = input_model.get_layer(layer_name).output 19 | grads = K.gradients(y_c, conv_output)[0] 20 | grads = normalize(grads) 21 | gradient_function = K.function([input_model.input], [conv_output, grads]) 22 | 23 | output, grads_val = gradient_function([image]) 24 | output, grads_val = output[0, :], grads_val[0, :, :, :] 25 | 26 | weights = np.mean(grads_val, axis=(0, 1)) 27 | cam = np.dot(output, weights) 28 | 29 | cam = np.maximum(cam, 0) 30 | #cam = resize(cam, (H, W)) 31 | cam = zoom(cam,H/cam.shape[0]) 32 | #cam = np.maximum(cam, 0) 33 | cam = cam / cam.max() 34 | return cam 35 | 36 | def grad_cam_plus(input_model, img, layer_name,H=360,W=360): 37 | cls = np.argmax(input_model.predict(img)) 38 | print(cls) 39 | def normalize(x): 40 | """Utility function to normalize a tensor by its L2 norm""" 41 | return (x + 1e-10) / (K.sqrt(K.mean(K.square(x))) + 1e-10) 42 | """GradCAM method for visualizing input saliency.""" 43 | y_c = input_model.output[0, cls] 44 | conv_output = input_model.get_layer(layer_name).output 45 | grads = K.gradients(y_c, conv_output)[0] 46 | grads = normalize(grads) 47 | 48 | first = K.exp(y_c)*grads 49 | second = K.exp(y_c)*grads*grads 50 | third = 
K.exp(y_c)*grads*grads*grads 51 | 52 | gradient_function = K.function([input_model.input], [y_c,first,second,third, conv_output, grads]) 53 | y_c, conv_first_grad, conv_second_grad,conv_third_grad, conv_output, grads_val = gradient_function([img]) 54 | global_sum = np.sum(conv_output[0].reshape((-1,conv_first_grad[0].shape[2])), axis=0) 55 | 56 | alpha_num = conv_second_grad[0] 57 | alpha_denom = conv_second_grad[0]*2.0 + conv_third_grad[0]*global_sum.reshape((1,1,conv_first_grad[0].shape[2])) 58 | alpha_denom = np.where(alpha_denom != 0.0, alpha_denom, np.ones(alpha_denom.shape)) 59 | alphas = alpha_num/alpha_denom 60 | 61 | weights = np.maximum(conv_first_grad[0], 0.0) 62 | 63 | alpha_normalization_constant = np.sum(np.sum(alphas, axis=0),axis=0) 64 | 65 | alphas /= alpha_normalization_constant.reshape((1,1,conv_first_grad[0].shape[2])) 66 | 67 | deep_linearization_weights = np.sum((weights*alphas).reshape((-1,conv_first_grad[0].shape[2])),axis=0) 68 | #print deep_linearization_weights 69 | grad_CAM_map = np.sum(deep_linearization_weights*conv_output[0], axis=2) 70 | 71 | # Passing through ReLU 72 | cam = np.maximum(grad_CAM_map, 0) 73 | cam = zoom(cam,H/cam.shape[0]) 74 | cam = cam / np.max(cam) # scale 0 to 1.0 75 | #cam = resize(cam, (224,224)) 76 | 77 | return cam 78 | -------------------------------------------------------------------------------- /utils/helper.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | 4 | def heatmap(heatmap, cmap="seismic", interpolation="none", colorbar=False, M=None): 5 | if M is None: 6 | M = np.abs(heatmap).max() 7 | if M == 0: 8 | M = 1 9 | plt.imshow(heatmap, cmap=cmap, vmax=M, vmin=-M, interpolation=interpolation) 10 | plt.xticks([]) 11 | plt.yticks([]) 12 | plt.tight_layout() 13 | if colorbar: 14 | plt.colorbar() 15 | -------------------------------------------------------------------------------- /utils/lossprettifier.py: -------------------------------------------------------------------------------- 1 | class LossPrettifier(object): 2 | 3 | STYLE = { 4 | 'green' : '\033[32m', 5 | 'red' : '\033[91m', 6 | 'bold' : '\033[1m', 7 | } 8 | STYLE_END = '\033[0m' 9 | 10 | def __init__(self, show_percentage=False): 11 | 12 | self.show_percentage = show_percentage 13 | self.color_up = 'red' 14 | self.color_down = 'green' 15 | self.loss_terms = {} 16 | 17 | def __call__(self, epoch=None, **kwargs): 18 | 19 | if epoch is not None: 20 | print_string = f'Epoch {epoch: 5d} ' 21 | else: 22 | print_string = '' 23 | 24 | for key, value in kwargs.items(): 25 | 26 | pre_value = self.loss_terms.get(key, value) 27 | 28 | if value > pre_value: 29 | indicator = '▲' 30 | show_color = self.STYLE[self.color_up] 31 | elif value == pre_value: 32 | indicator = '' 33 | show_color = '' 34 | else: 35 | indicator = '▼' 36 | show_color = self.STYLE[self.color_down] 37 | 38 | if self.show_percentage: 39 | show_value = 0 if pre_value == 0 \ 40 | else (value - pre_value) / float(pre_value) 41 | key_string = f'| {key}: {show_color}{value:3.2f}({show_value:+3.2%}) {indicator}' 42 | else: 43 | key_string = f'| {key}: {show_color}{value:.4f} {indicator}' 44 | 45 | # Trim some long outputs 46 | key_string_part = key_string[:32] 47 | print_string += key_string_part+f'{self.STYLE_END}\t' 48 | 49 | self.loss_terms[key] = value 50 | 51 | print(print_string) 52 | --------------------------------------------------------------------------------