├── README.md
├── config.py
├── datasets.py
├── losses.py
├── main.py
├── sample_xrays
│   ├── Atelectasis.png
│   ├── Cardiomegaly_Edema_Effusion.png
│   ├── Effusion.png
│   ├── Fibrosis.png
│   └── No Finding.png
└── trainer.py


/README.md:
--------------------------------------------------------------------------------
# NIH-Chest-X-rays-Multi-Label-Image-Classification-In-Pytorch
Multi-Label Image Classification of the Chest X-Rays In Pytorch

# Requirements
* torch >= 0.4
* torchvision >= 0.2.2
* opencv-python
* numpy >= 1.7.3
* pandas
* scikit-learn
* matplotlib
* tqdm

# Dataset
The [NIH Chest X-ray Dataset](https://www.kaggle.com/nih-chest-xrays/data#Data_Entry_2017.csv) is used for multi-label disease classification of chest X-rays.
There are a total of 15 classes (14 diseases, and one for "No Finding").
Each image is labelled either "No Finding" or with one or more of the following disease classes:
* Atelectasis
* Consolidation
* Infiltration
* Pneumothorax
* Edema
* Emphysema
* Fibrosis
* Effusion
* Pneumonia
* Pleural_Thickening
* Cardiomegaly
* Nodule
* Mass
* Hernia

There are 112,120 X-ray images of size 1024x1024 pixels, of which 86,524 are in the official train/validation split and 25,596 in the test split.

# Sample X-Ray Images
[Image: Atelectasis]
[Image: Cardiomegaly | Edema | Effusion]
[Image: No Finding]
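The sample file `sample_xrays/Cardiomegaly_Edema_Effusion.png` carries three findings at once. In `Data_Entry_2017.csv` the findings of an image are stored as one pipe-separated `Finding Labels` string (for example `Cardiomegaly|Edema|Effusion`), which `datasets.py` splits on `'|'` and turns into a multi-hot target vector. The sketch below restates that step in plain Python; `encode_finding_labels` is a hypothetical name, and the alphabetical class order is an assumption based on the `sorted(...)` call in `datasets.py`.

```python
from typing import List


# Hypothetical helper mirroring the target construction in datasets.py:
# split the pipe-separated 'Finding Labels' string and set a 1 at the
# index of every label that is present (multi-hot encoding).
def encode_finding_labels(finding_labels: str, classes: List[str]) -> List[int]:
    target = [0] * len(classes)
    for label in finding_labels.split('|'):
        # list.index raises ValueError for an unknown label, as in datasets.py
        target[classes.index(label)] = 1
    return target


# Assumed class order: alphabetical, like sorted(list(all_classes)) in datasets.py.
CLASSES = sorted([
    'Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema', 'Effusion',
    'Emphysema', 'Fibrosis', 'Hernia', 'Infiltration', 'Mass', 'No Finding',
    'Nodule', 'Pleural_Thickening', 'Pneumonia', 'Pneumothorax',
])
```

A "No Finding" image gets a single 1 at the "No Finding" index, while the three-finding sample above gets three 1s; this is why the model uses one sigmoid output per class rather than a softmax over classes.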

# Model
A ResNet-50 pretrained on ImageNet is used for transfer learning on this dataset; its final fully connected layer (`fc`) is replaced with a new linear layer with 15 outputs, one per class.

# Loss Function
Two loss functions are available:
* Focal Loss (default)
* Binary Cross Entropy Loss (BCE Loss)

# Training
* ### From Scratch
  Training starts from the ImageNet-pretrained weights. The following layers are set to trainable:
  * layer2
  * layer3
  * layer4
  * fc

  Terminal Code:
  ```
  python main.py
  ```

* ### Resuming From a Saved Checkpoint
  A saved checkpoint needs to be loaded. It is a dictionary containing:
  * epochs (the number of epochs the model has been trained for so far)
  * model (the architecture and the learnt weights of the model)
  * lr_scheduler_state_dict (the state_dict of the lr_scheduler)
  * losses_dict (a dictionary containing the following losses)
    * mean train epoch losses for all the epochs
    * mean val epoch losses for all the epochs
    * batch train loss for all the training batches
    * batch val loss for all the val batches

  Different layers of the model are frozen/unfrozen in different stages, listed in the Result section at the end of this README.md file, to fit the model well to the data. The stage can be passed from the terminal using the argument `--stage STAGE`; it is only used when resuming, since training from scratch always runs stage 1.

  Terminal Code:
  ```
  python main.py --resume --ckpt checkpoint_file.pth --stage 2
  ```

Training the model will create a **models** directory and save the checkpoints there; the `--ckpt` argument expects the name of a checkpoint file inside this directory.
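The stage-wise freezing (stages 1 to 4, as listed in the Result section) amounts to a mapping from stage number to the top-level ResNet-50 modules that stay trainable. Below is a minimal sketch under that reading: `STAGE_TRAINABLE_LAYERS`, `is_trainable` and `apply_stage` are hypothetical names not found in the repository, and the sketch matches on the first dotted component of a parameter name (for example `layer3` in `layer3.0.conv1.weight`) instead of the substring tests used in `main.py`. It only needs objects with a `requires_grad` attribute, so it runs without PyTorch; with a real model you would pass `model.named_parameters()`.

```python
from typing import Dict, Iterable, List, Tuple

# Hypothetical mapping from training stage to the ResNet-50 top-level
# modules left trainable, following the stages in the Result section.
STAGE_TRAINABLE_LAYERS: Dict[int, Tuple[str, ...]] = {
    1: ('layer2', 'layer3', 'layer4', 'fc'),
    2: ('layer3', 'layer4', 'fc'),
    3: ('layer4', 'fc'),
    4: ('fc',),
}


def is_trainable(param_name: str, stage: int) -> bool:
    """True if a parameter such as 'layer3.0.conv1.weight' should be trained
    in the given stage (exact match on the top-level module name)."""
    top_level_module = param_name.split('.')[0]
    return top_level_module in STAGE_TRAINABLE_LAYERS[stage]


def apply_stage(named_parameters: Iterable[tuple], stage: int) -> List[str]:
    """Set requires_grad on each (name, param) pair and return the names
    of the parameters that remain trainable."""
    trainable = []
    for name, param in named_parameters:
        param.requires_grad = is_trainable(name, stage)
        if param.requires_grad:
            trainable.append(name)
    return trainable
```

Matching the whole top-level module name avoids accidental hits if some other parameter name happened to contain a substring such as `fc`; for the stock torchvision ResNet-50 both approaches select the same parameters.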

# Testing
A saved checkpoint needs to be loaded using the **--ckpt** argument, and the **--test** argument needs to be passed to activate the test mode.

Terminal Code:
```
python main.py --test --ckpt checkpoint_file.pth
```

# Result
The model achieved an average **ROC AUC Score** of **0.73241** over all classes (excluding the "No Finding" class) after training in the following stages:

#### STAGE 1
* Loss Function: FocalLoss
* lr: 1e-5
* Training Layers: layer2, layer3, layer4, fc
* Epochs: 2

#### STAGE 2
* Loss Function: FocalLoss
* lr: 3e-4
* Training Layers: layer3, layer4, fc
* Epochs: 1

#### STAGE 3
* Loss Function: FocalLoss
* lr: 1e-3
* Training Layers: layer4, fc
* Epochs: 3

#### STAGE 4
* Loss Function: FocalLoss
* lr: 1e-3
* Training Layers: fc
* Epochs: 2

--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
pkl_dir_path = 'pickles'
train_val_df_pkl_path = 'train_val_df.pickle'
test_df_pkl_path = 'test_df.pickle'
disease_classes_pkl_path = 'disease_classes.pickle'
models_dir = 'models'

from torchvision import transforms
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

# transforms.RandomHorizontalFlip() is not used because some diseases might be more likely to be present in a specific lung (left/right)
transform = transforms.Compose([transforms.ToPILImage(),
                                transforms.Resize(224),
                                transforms.ToTensor(),
                                normalize])

--------------------------------------------------------------------------------
/datasets.py:
--------------------------------------------------------------------------------
import glob, os, sys, pdb, time
import pandas as pd
import numpy as np
import cv2
import pickle
from torch.utils.data import Dataset
from tqdm import tqdm
import torch

import config

def q(text = ''): # easy way to exit the script. useful while debugging
    print('> ', text)
    sys.exit()

class XRaysTrainDataset(Dataset):
    def __init__(self, data_dir, transform = None):
        self.data_dir = data_dir

        self.transform = transform
        # print('self.data_dir: ', self.data_dir)

        # full dataframe including train_val and test set
        self.df = self.get_df()
        print('self.df.shape: {}'.format(self.df.shape))

        self.make_pkl_dir(config.pkl_dir_path)

        # get train_val_df
        if not os.path.exists(os.path.join(config.pkl_dir_path, config.train_val_df_pkl_path)):

            self.train_val_df = self.get_train_val_df()
            print('\nself.train_val_df.shape: {}'.format(self.train_val_df.shape))

            # pickle dump the train_val_df
            with open(os.path.join(config.pkl_dir_path, config.train_val_df_pkl_path), 'wb') as handle:
                pickle.dump(self.train_val_df, handle, protocol = pickle.HIGHEST_PROTOCOL)
            print('{}: dumped'.format(config.train_val_df_pkl_path))

        else:
            # pickle load the train_val_df
            with open(os.path.join(config.pkl_dir_path, config.train_val_df_pkl_path), 'rb') as handle:
                self.train_val_df = pickle.load(handle)
            print('\n{}: loaded'.format(config.train_val_df_pkl_path))
            print('self.train_val_df.shape: {}'.format(self.train_val_df.shape))

        self.the_chosen, self.all_classes, self.all_classes_dict = self.choose_the_indices()

        if not os.path.exists(os.path.join(config.pkl_dir_path, config.disease_classes_pkl_path)):
            # pickle dump the classes list
            with open(os.path.join(config.pkl_dir_path, config.disease_classes_pkl_path), 'wb') as handle:
                pickle.dump(self.all_classes, handle, protocol = pickle.HIGHEST_PROTOCOL)
            print('\n{}: dumped'.format(config.disease_classes_pkl_path))
        else:
            print('\n{}: already exists'.format(config.disease_classes_pkl_path))

        self.new_df = self.train_val_df.iloc[self.the_chosen, :] # this is the sampled train_val data
        print('\nself.all_classes_dict: {}'.format(self.all_classes_dict))

    def resample(self):
        self.the_chosen, self.all_classes, self.all_classes_dict = self.choose_the_indices()
        self.new_df = self.train_val_df.iloc[self.the_chosen, :]
        print('\nself.all_classes_dict: {}'.format(self.all_classes_dict))

    def make_pkl_dir(self, pkl_dir_path):
        if not os.path.exists(pkl_dir_path):
            os.mkdir(pkl_dir_path)

    def get_train_val_df(self):

        # get the list of train_val data
        train_val_list = self.get_train_val_list()

        print('\nbuilding train_val_df...')
        # keep the rows whose image file name is in the official train_val list
        # (DataFrame.append was removed in pandas 2.0, so filter with isin instead)
        filenames = self.df['image_links'].apply(os.path.basename)
        train_val_df = self.df[filenames.isin(train_val_list)]

        return train_val_df

    def __getitem__(self, index):
        row = self.new_df.iloc[index, :]

        img = cv2.imread(row['image_links'])
        labels = str.split(row['Finding Labels'], '|')

        target = torch.zeros(len(self.all_classes))
        for lab in labels:
            lab_idx = self.all_classes.index(lab)
            target[lab_idx] = 1

        if self.transform is not None:
            img = self.transform(img)

        return img, target

    def choose_the_indices(self):

        max_examples_per_class = 10000 # the maximum number of examples that would be sampled in the training set for any class
        the_chosen = []
        all_classes = {}
        length = len(self.train_val_df)
        # for i in tqdm(range(len(merged_df))):
        print('\nSampling the huge training dataset')
        for i in tqdm(list(np.random.choice(range(length), length, replace = False))):

            temp = str.split(self.train_val_df.iloc[i, :]['Finding Labels'], '|')

            # special case of the ultra-minority class 'Hernia': use all the images tagged with 'Hernia'
            if 'Hernia' in temp:
                the_chosen.append(i)
                for t in temp:
                    if t not in all_classes:
                        all_classes[t] = 1
                    else:
                        all_classes[t] += 1
                continue

            # choose if multiple labels
            if len(temp) > 1:
                bool_lis = [False]*len(temp)
                # check if any label crosses the upper limit
                for idx, t in enumerate(temp):
                    if t in all_classes:
                        if all_classes[t] < max_examples_per_class:
                            bool_lis[idx] = True
                    else:
                        bool_lis[idx] = True
                # if all labels are under the upper limit, append
                if sum(bool_lis) == len(temp):
                    the_chosen.append(i)
                    # maintain count
                    for t in temp:
                        if t not in all_classes:
                            all_classes[t] = 1
                        else:
                            all_classes[t] += 1
            else: # these are single label images
                for t in temp:
                    if t not in all_classes:
                        all_classes[t] = 1
                    else:
                        if all_classes[t] < max_examples_per_class:
                            all_classes[t] += 1
                            the_chosen.append(i)

        # print('len(all_classes): ', len(all_classes))
        # print('all_classes: ', all_classes)
        # print('len(the_chosen): ', len(the_chosen))

        '''
        if len(the_chosen) != len(set(the_chosen)):
            print('\nDuplicates found !!!')
            print('and the difference is: ', len(the_chosen) - len(set(the_chosen)))
        else:
            print('\nGood')
        '''

        return the_chosen, sorted(list(all_classes)), all_classes

    def get_df(self):
        csv_path = os.path.join(self.data_dir, 'Data_Entry_2017.csv')
        print('\n{} found: {}'.format(csv_path, os.path.exists(csv_path)))

        all_xray_df = pd.read_csv(csv_path)

        df = pd.DataFrame()
        df['image_links'] = glob.glob(os.path.join(self.data_dir, 'images*', '*', '*.png'))

        df['Image Index'] = df['image_links'].apply(os.path.basename) # e.g. '00000001_000.png'
        merged_df = df.merge(all_xray_df, how = 'inner', on = ['Image Index'])
        merged_df = merged_df[['image_links','Finding Labels']]
        return merged_df

    def get_train_val_list(self):
        # read the list from the same data directory as Data_Entry_2017.csv
        with open(os.path.join(self.data_dir, 'train_val_list.txt'), 'r') as f:
            train_val_list = str.split(f.read(), '\n')
        return train_val_list

    def __len__(self):
        return len(self.new_df)


# prepare the test dataset
class XRaysTestDataset(Dataset):
    def __init__(self, data_dir, transform = None):
        self.data_dir = data_dir
        self.transform = transform
        # print('self.data_dir: ', self.data_dir)

        # full dataframe including train_val and test set
        self.df = self.get_df()
        print('\nself.df.shape: {}'.format(self.df.shape))

        self.make_pkl_dir(config.pkl_dir_path)

        # loading the classes list
        with open(os.path.join(config.pkl_dir_path, config.disease_classes_pkl_path), 'rb') as handle:
            self.all_classes = pickle.load(handle)

        # get test_df
        if not os.path.exists(os.path.join(config.pkl_dir_path, config.test_df_pkl_path)):

            self.test_df = self.get_test_df()
            print('self.test_df.shape: ', self.test_df.shape)

            # pickle dump the test_df
            with open(os.path.join(config.pkl_dir_path, config.test_df_pkl_path), 'wb') as handle:
                pickle.dump(self.test_df, handle, protocol = pickle.HIGHEST_PROTOCOL)
            print('\n{}: dumped'.format(config.test_df_pkl_path))
        else:
            # pickle load the test_df
            with open(os.path.join(config.pkl_dir_path, config.test_df_pkl_path), 'rb') as handle:
                self.test_df = pickle.load(handle)
            print('\n{}: loaded'.format(config.test_df_pkl_path))
            print('self.test_df.shape: {}'.format(self.test_df.shape))

    def __getitem__(self, index):
        row = self.test_df.iloc[index, :]

        img = cv2.imread(row['image_links'])
        labels = str.split(row['Finding Labels'], '|')

        target = torch.zeros(len(self.all_classes))
        for lab in labels:
            lab_idx = self.all_classes.index(lab)
            target[lab_idx] = 1

        if self.transform is not None:
            img = self.transform(img)

        return img, target

    def make_pkl_dir(self, pkl_dir_path):
        if not os.path.exists(pkl_dir_path):
            os.mkdir(pkl_dir_path)

    def get_df(self):
        csv_path = os.path.join(self.data_dir, 'Data_Entry_2017.csv')

        all_xray_df = pd.read_csv(csv_path)

        df = pd.DataFrame()
        df['image_links'] = glob.glob(os.path.join(self.data_dir, 'images*', '*', '*.png'))

        df['Image Index'] = df['image_links'].apply(os.path.basename) # e.g. '00000001_000.png'
        merged_df = df.merge(all_xray_df, how = 'inner', on = ['Image Index'])
        merged_df = merged_df[['image_links','Finding Labels']]
        return merged_df

    def get_test_df(self):

        # get the list of test data
        test_list = self.get_test_list()

        print('\nbuilding test_df...')
        # keep the rows whose image file name is in the official test list
        # (DataFrame.append was removed in pandas 2.0, so filter with isin instead)
        filenames = self.df['image_links'].apply(os.path.basename)
        test_df = self.df[filenames.isin(test_list)]

        print('test_df.shape: ', test_df.shape)

        return test_df

    def get_test_list(self):
        # read the list from the same data directory as Data_Entry_2017.csv
        with open(os.path.join(self.data_dir, 'test_list.txt'), 'r') as f:
            test_list = str.split(f.read(), '\n')
        return test_list

    def __len__(self):
        return len(self.test_df)



'''
# prepare the test dataset
import random
class XRaysTestDataset2(Dataset):
    def __init__(self, test_data_dir, transform = None):
        self.test_data_dir = test_data_dir
        self.transform = transform
        self.data_list = self.get_data_list(self.test_data_dir)

        self.subset = self.data_list[:1000]

    def __getitem__(self, index):
        img_path = self.data_list[index]
        img = cv2.imread(img_path)

        if self.transform is not None:
            img = self.transform(img)

        return img_path

    def sample(self):

        random.shuffle(self.data_list)

        self.subset = self.data_list[:np.random.randint(500,700)]

    def __len__(self):
        return len(self.subset)

    def get_data_list(self, data_dir):
        data_list = []
        for path in glob.glob(data_dir + os.sep + '*'):
            data_list.append(path)
        return data_list
'''

--------------------------------------------------------------------------------
/losses.py:
--------------------------------------------------------------------------------
import torch, sys, os, pdb
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):

    def __init__(self, device, gamma = 1.0):
        super(FocalLoss, self).__init__()
        self.device = device
        self.gamma = torch.tensor(gamma, dtype = torch.float32).to(device)
        self.eps = 1e-6

        # self.BCE_loss = nn.BCEWithLogitsLoss(reduction='none').to(device)

    def forward(self, input, target):

        BCE_loss = F.binary_cross_entropy_with_logits(input, target, reduction='none').to(self.device)
        # BCE_loss = self.BCE_loss(input, target)
        pt = torch.exp(-BCE_loss) # prevents nans when probability 0
        F_loss = (1-pt)**self.gamma * BCE_loss

        return F_loss.mean()

    # def forward(self, input, target):

    #     # input are not the probabilities, they are just the cnn out vector
    #     # input and target shape: (bs, n_classes)
    #     # sigmoid
    #     probs = torch.sigmoid(input)
    #     log_probs = -torch.log(probs)

    #     focal_loss = torch.sum(  torch.pow(1-probs + self.eps, self.gamma).mul(log_probs).mul(target)  , dim=1)
    #     # bce_loss = torch.sum(log_probs.mul(target), dim = 1)

    #     return focal_loss.mean() #, bce_loss

if __name__ == '__main__':
    inp = torch.tensor([[1., 0.95],
                        [.9, 0.3],
                        [0.6, 0.4]], requires_grad = True)
    target = torch.tensor([[1., 1],
                           [1, 0],
                           [0, 0]])

    print('inp\n',inp, '\n')
    print('target\n',target, '\n')

    print('inp.requires_grad:', inp.requires_grad, inp.shape)
    print('target.requires_grad:', target.requires_grad, target.shape)

    # FocalLoss needs a device; forward() returns a single scalar (the mean focal loss)
    loss = FocalLoss(device = torch.device('cpu'), gamma = 2)

    focal_loss = loss(inp, target)
    bce_loss = F.binary_cross_entropy_with_logits(inp, target) # plain BCE, for comparison
    print('\nbce_loss',bce_loss, '\n')
    print('\nfocal_loss',focal_loss, '\n')

--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import argparse
import os, pdb, sys, glob, time
import numpy as np
import pandas as pd
from tqdm import tqdm
import cv2

import torch
import torch.nn as nn
import torchvision.models as models

# import custom dataset classes
from datasets import XRaysTrainDataset
from datasets import XRaysTestDataset

# import necessary libraries for defining the optimizers
import torch.optim as optim

from trainer import fit
import config

def q(text = ''): # easy way to exit the script. useful while debugging
    print('> ', text)
    sys.exit()

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f'\ndevice: {device}')

parser = argparse.ArgumentParser(description='Following are the arguments that can be passed from the terminal itself ! Cool huh ? :D')
parser.add_argument('--data_path', type = str, default = 'NIH Chest X-rays', help = 'Path of the data directory, relative to the data/ folder')
parser.add_argument('--bs', type = int, default = 128, help = 'batch size')
parser.add_argument('--lr', type = float, default = 1e-5, help = 'Learning Rate for the optimizer')
parser.add_argument('--stage', type = int, default = 1, help = 'Stage, it decides which layers of the Neural Net to train')
parser.add_argument('--loss_func', type = str, default = 'FocalLoss', choices = {'BCE', 'FocalLoss'}, help = 'loss function')
parser.add_argument('-r','--resume', action = 'store_true') # args.resume will return True if -r or --resume is used in the terminal
parser.add_argument('--ckpt', type = str, help = 'File name (inside the models directory) of the checkpoint that you want to load')
parser.add_argument('-t','--test', action = 'store_true') # args.test will return True if -t or --test is used in the terminal
args = parser.parse_args()

if args.resume and args.test: # what if --test is not defined at all? this is one test case
    q('The flow of this code has been designed either to train the model or to test it.\nPlease choose either --resume or --test')

stage = args.stage
if not args.resume:
    print(f'\nOverwriting stage to 1, as the model training is being done from scratch')
    stage = 1

if args.test:
    print('TESTING THE MODEL')
else:
    if args.resume:
        print('RESUMING THE MODEL TRAINING')
    else:
        print('TRAINING THE MODEL FROM SCRATCH')

script_start_time = time.time() # tells the total run time of this script

# mention the path of the data
data_dir = os.path.join('data',args.data_path) # Data_Entry_2017.csv should be present in the mentioned path

# define a function to count the total number of trainable parameters
def count_parameters(model):
    num_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return num_parameters/1e6 # in terms of millions

# make the datasets
XRayTrain_dataset = XRaysTrainDataset(data_dir, transform = config.transform)
train_percentage = 0.8
train_dataset, val_dataset = torch.utils.data.random_split(XRayTrain_dataset, [int(len(XRayTrain_dataset)*train_percentage), len(XRayTrain_dataset)-int(len(XRayTrain_dataset)*train_percentage)])

XRayTest_dataset = XRaysTestDataset(data_dir, transform = config.transform)

print('\n-----Initial Dataset Information-----')
print('num images in train_dataset   : {}'.format(len(train_dataset)))
print('num images in val_dataset     : {}'.format(len(val_dataset)))
print('num images in XRayTest_dataset: {}'.format(len(XRayTest_dataset)))
print('-------------------------------------')

# make the dataloaders
batch_size = args.bs # 128 by default
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size = batch_size, shuffle = True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size = batch_size, shuffle = False)
test_loader = torch.utils.data.DataLoader(XRayTest_dataset, batch_size = batch_size, shuffle = False)

print('\n-----Initial Batchloaders Information -----')
print('num batches in train_loader: {}'.format(len(train_loader)))
print('num batches in val_loader  : {}'.format(len(val_loader)))
print('num batches in test_loader : {}'.format(len(test_loader)))
print('-------------------------------------------')

# sanity check
if len(XRayTrain_dataset.all_classes) != 15: # 15 unique classes in this dataset: 14 diseases + 'No Finding'
    q('\nnumber of classes not equal to 15 !')

a,b = train_dataset[0]
print('\nwe are working with \nImages shape: {} and \nTarget shape: {}'.format( a.shape, b.shape))

# make models directory, where the models and the loss plots will be saved
if not os.path.exists(config.models_dir):
    os.mkdir(config.models_dir)

# define the loss function
if args.loss_func == 'FocalLoss': # by default
    from losses import FocalLoss
    loss_fn = FocalLoss(device = device, gamma = 2.).to(device)
elif args.loss_func == 'BCE':
    loss_fn = nn.BCEWithLogitsLoss().to(device)

# define the learning rate
lr = args.lr

if not args.test: # training

    # initialize the model if not args.resume
    if not args.resume:
        print('\ntraining from scratch')
        # import pretrained model
        model = models.resnet50(pretrained=True) # pretrained = False by default
        # change the last linear layer
        num_ftrs = model.fc.in_features
        model.fc = nn.Linear(num_ftrs, len(XRayTrain_dataset.all_classes)) # 15 output classes
        model.to(device)

        print('----- STAGE 1 -----') # only training 'layer2', 'layer3', 'layer4' and 'fc'
        for name, param in model.named_parameters(): # all requires_grad by default, are True initially
            # print('{}: {}'.format(name, param.requires_grad)) # this shows True for all the parameters
            if ('layer2' in name) or ('layer3' in name) or ('layer4' in name) or ('fc' in name):
                param.requires_grad = True
            else:
                param.requires_grad = False

        # since we are not resuming the training of the model
        epochs_till_now = 0

        # making empty lists to collect all the losses
        losses_dict = {'epoch_train_loss': [], 'epoch_val_loss': [], 'total_train_loss_list': [], 'total_val_loss_list': []}

    else:
        if args.ckpt is None:
            q('ERROR: Please select a valid checkpoint to resume from')

        print('\nckpt loaded: {}'.format(args.ckpt))
        ckpt = torch.load(os.path.join(config.models_dir, args.ckpt))

        # since we are resuming the training of the model
        epochs_till_now = ckpt['epochs']
        model = ckpt['model']
        model.to(device)

        # loading previous loss lists to collect future losses
        losses_dict = ckpt['losses_dict']

    # printing some hyperparameters
    print('\n> loss_fn: {}'.format(loss_fn))
    print('> epochs_till_now: {}'.format(epochs_till_now))
    print('> batch_size: {}'.format(batch_size))
    print('> stage: {}'.format(stage))
    print('> lr: {}'.format(lr))

else: # testing
    if args.ckpt is None:
        q('ERROR: Please select a checkpoint to load the testing model from')

    print('\ncheckpoint loaded: {}'.format(args.ckpt))
    ckpt = torch.load(os.path.join(config.models_dir, args.ckpt))

    # load the trained model and its training history from the checkpoint
    epochs_till_now = ckpt['epochs']
    model = ckpt['model']

    # loading previous loss lists to collect future losses
    losses_dict = ckpt['losses_dict']

# make changes (freezing/unfreezing the model's layers) in the following, for training the model in different stages
if (not args.test) and (args.resume):

    if stage == 1:

        print('\n----- STAGE 1 -----') # only training 'layer2', 'layer3', 'layer4' and 'fc'
        for name, param in model.named_parameters(): # all requires_grad by default, are True initially
            # print('{}: {}'.format(name, param.requires_grad)) # this shows True for all the parameters
            if ('layer2' in name) or ('layer3' in name) or ('layer4' in name) or ('fc' in name):
                param.requires_grad = True
            else:
                param.requires_grad = False

    elif stage == 2:

        print('\n----- STAGE 2 -----') # only training 'layer3', 'layer4' and 'fc'
        for name, param in model.named_parameters():
            # print('{}: {}'.format(name, param.requires_grad)) # this shows True for all the parameters
            if ('layer3' in name) or ('layer4' in name) or ('fc' in name):
                param.requires_grad = True
            else:
                param.requires_grad = False

    elif stage == 3:

        print('\n----- STAGE 3 -----') # only training 'layer4' and 'fc'
        for name, param in model.named_parameters():
            # print('{}: {}'.format(name, param.requires_grad)) # this shows True for all the parameters
            if ('layer4' in name) or ('fc' in name):
                param.requires_grad = True
            else:
                param.requires_grad = False

    elif stage == 4:

        print('\n----- STAGE 4 -----') # only training 'fc'
        for name, param in model.named_parameters():
            # print('{}: {}'.format(name, param.requires_grad)) # this shows True for all the parameters
            if ('fc' in name):
                param.requires_grad = True
            else:
                param.requires_grad = False


if not args.test:
    # checking the layers which are going to be trained (irrespective of args.resume)
    trainable_layers = []
    for name, param in model.named_parameters():
        if param.requires_grad:
            layer_name = str.split(name, '.')[0]
            if layer_name not in trainable_layers:
                trainable_layers.append(layer_name)
    print('\nfollowing are the trainable layers...')
    print(trainable_layers)

    print('\nwe have {} Million trainable parameters here in the {} model'.format(count_parameters(model), model.__class__.__name__))

optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr = lr)

# make changes in the parameters of the following 'fit' function
fit(device, XRayTrain_dataset, train_loader, val_loader,
    test_loader, model, loss_fn,
    optimizer, losses_dict,
    epochs_till_now = epochs_till_now, epochs = 3,
    log_interval = 25, save_interval = 1,
    lr = lr, bs = batch_size, stage = stage,
    test_only = args.test)

script_time = time.time() - script_start_time
m, s = divmod(script_time, 60)
h, m = divmod(m, 60)
print('{} h {}m taken by the whole script!'.format(int(h), int(m)))

# '''
# This is how the model is trained...
# ##### STAGE 1 ##### FocalLoss lr = 1e-5
#             training layers = layer2, layer3, layer4, fc
#             epochs = 2
# ##### STAGE 2 ##### FocalLoss lr = 3e-4
#             training layers = layer3, layer4, fc
#             epochs = 5
# ##### STAGE 3 ##### FocalLoss lr = 7e-4
#             training layers = layer4, fc
#             epochs = 4
# ##### STAGE 4 ##### FocalLoss lr = 1e-3
#             training layers = fc
#             epochs = 3
# '''

--------------------------------------------------------------------------------
/sample_xrays/Atelectasis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/n0obcoder/NIH-Chest-X-Rays-Multi-Label-Image-Classification-In-Pytorch/0489269a518c9bc6580dcd80eea0c79ee92a8269/sample_xrays/Atelectasis.png
--------------------------------------------------------------------------------
/sample_xrays/Cardiomegaly_Edema_Effusion.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/n0obcoder/NIH-Chest-X-Rays-Multi-Label-Image-Classification-In-Pytorch/0489269a518c9bc6580dcd80eea0c79ee92a8269/sample_xrays/Cardiomegaly_Edema_Effusion.png
--------------------------------------------------------------------------------
/sample_xrays/Effusion.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/n0obcoder/NIH-Chest-X-Rays-Multi-Label-Image-Classification-In-Pytorch/0489269a518c9bc6580dcd80eea0c79ee92a8269/sample_xrays/Effusion.png
--------------------------------------------------------------------------------
/sample_xrays/Fibrosis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/n0obcoder/NIH-Chest-X-Rays-Multi-Label-Image-Classification-In-Pytorch/0489269a518c9bc6580dcd80eea0c79ee92a8269/sample_xrays/Fibrosis.png
--------------------------------------------------------------------------------
/sample_xrays/No Finding.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/n0obcoder/NIH-Chest-X-Rays-Multi-Label-Image-Classification-In-Pytorch/0489269a518c9bc6580dcd80eea0c79ee92a8269/sample_xrays/No Finding.png
--------------------------------------------------------------------------------
/trainer.py:
--------------------------------------------------------------------------------
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

import sys, os, time, random, pdb
import numpy as np
import pandas as pd
import torch.nn.functional as F
import torch
import pickle
import tqdm, pdb
from sklearn.metrics import roc_auc_score

import config

def get_roc_auc_score(y_true, y_probs):
    '''
    Uses the roc_auc_score function from sklearn.metrics to compute the ROC AUC score of each class for the given y_true and y_probs, and returns their mean (a macro average) over all classes except 'No Finding'.
    '''

    with open(os.path.join(config.pkl_dir_path, config.disease_classes_pkl_path), 'rb') as handle:
        all_classes = pickle.load(handle)

    NoFindingIndex = all_classes.index('No Finding')

    if True: # debug output: print the shapes and dump the ground truth and probabilities to disk
        print('\nNoFindingIndex: ', NoFindingIndex)
        print('y_true.shape, y_probs.shape ', y_true.shape, y_probs.shape)
        GT_and_probs = {'y_true': y_true, 'y_probs': y_probs}
        with open('GT_and_probs', 'wb') as handle:
            pickle.dump(GT_and_probs, handle, protocol = pickle.HIGHEST_PROTOCOL)

    class_roc_auc_list = []
    useful_classes_roc_auc_list = []

    for i in range(y_true.shape[1]):
        class_roc_auc = roc_auc_score(y_true[:, i], y_probs[:, i])
        class_roc_auc_list.append(class_roc_auc)
        if i != NoFindingIndex:
            useful_classes_roc_auc_list.append(class_roc_auc)
    if True:
        print('\nclass_roc_auc_list: ', class_roc_auc_list)
        print('\nuseful_classes_roc_auc_list', useful_classes_roc_auc_list)

    return np.mean(np.array(useful_classes_roc_auc_list))

def make_plot(epoch_train_loss, epoch_val_loss, total_train_loss_list, total_val_loss_list, save_name):
    '''
    This function makes the following 4 different plots-
    1. mean train loss VS number of epochs
    2. mean val loss VS number of epochs
    3. batch train loss for all the training batches VS number of batches
    4. batch val loss for all the validation batches VS number of batches
    '''
    fig = plt.figure(figsize=(16,16))
    fig.suptitle('loss trends', fontsize=20)
    ax1 = fig.add_subplot(221)
    ax2 = fig.add_subplot(222)
    ax3 = fig.add_subplot(223)
    ax4 = fig.add_subplot(224)

    ax1.title.set_text('epoch train loss VS #epochs')
    ax1.set_xlabel('#epochs')
    ax1.set_ylabel('epoch train loss')
    ax1.plot(epoch_train_loss)

    ax2.title.set_text('epoch val loss VS #epochs')
    ax2.set_xlabel('#epochs')
    ax2.set_ylabel('epoch val loss')
    ax2.plot(epoch_val_loss)

    ax3.title.set_text('batch train loss VS #batches')
    ax3.set_xlabel('#batches')
    ax3.set_ylabel('batch train loss')
    ax3.plot(total_train_loss_list)

    ax4.title.set_text('batch val loss VS #batches')
    ax4.set_xlabel('#batches')
    ax4.set_ylabel('batch val loss')
    ax4.plot(total_val_loss_list)

    plt.savefig(os.path.join(config.models_dir,'losses_{}.png'.format(save_name)))
    plt.close(fig) # free the figure's memory, since a new plot is made every time this is called

def get_resampled_train_val_dataloaders(XRayTrain_dataset, transform, bs):
    '''
    Resamples the XRaysTrainDataset class object and returns training and validation dataloaders, by splitting the sampled dataset in an 80-20 ratio.
    '''
    XRayTrain_dataset.resample()

    train_percentage = 0.8
    n_total = len(XRayTrain_dataset)
    n_train = int(n_total*train_percentage)
    train_dataset, val_dataset = torch.utils.data.random_split(XRayTrain_dataset, [n_train, n_total - n_train])

    print('\n-----Resampled Dataset Information-----')
    print('num images in train_dataset : {}'.format(len(train_dataset)))
    print('num images in val_dataset : {}'.format(len(val_dataset)))
    print('---------------------------------------')

    # make dataloaders (only the training data is shuffled)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size = bs, shuffle = True)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size = bs, shuffle = False)

    print('\n-----Resampled Batchloaders Information -----')
    print('num batches in train_loader: {}'.format(len(train_loader)))
    print('num batches in val_loader : {}'.format(len(val_loader)))
    print('---------------------------------------------\n')

    return train_loader, val_loader

def train_epoch(device, train_loader, model, loss_fn, optimizer, epochs_till_now, final_epoch, log_interval):
    '''
    Takes in the data from the 'train_loader', calculates the loss over it using the 'loss_fn'
    and optimizes the 'model' using the 'optimizer'.

    Also prints the batch loss after every 'log_interval' batches.
    '''
    model.train()

    running_train_loss = 0
    train_loss_list = []

    start_time = time.time()
    for batch_idx, (img, target) in enumerate(train_loader):
        # print(type(img), img.shape) # , np.unique(img))

        img = img.to(device)
        target = target.to(device)

        optimizer.zero_grad()
        out = model(img)
        loss = loss_fn(out, target)
        running_train_loss += loss.item()*img.shape[0]
        train_loss_list.append(loss.item())

        loss.backward()
        optimizer.step()

        if (batch_idx+1)%log_interval == 0:
            # batch metric evaluation
            # # out_detached = out.detach()
            # # batch_roc_auc_score = get_roc_auc_score(target, out_detached.numpy())
            # 'out' is a torch.Tensor and the 'roc_auc_score' function first tries to convert it into a numpy array, but since 'out' has requires_grad = True, it throws an error:
            # RuntimeError: Can't call numpy() on Variable that requires grad. Use var.detach().numpy() instead.
            # So we have to 'detach' the 'out' tensor and then convert it into a numpy array to avoid the error!

            batch_time = time.time() - start_time
            m, s = divmod(batch_time, 60)
            print('Train Loss for batch {}/{} @epoch{}/{}: {} in {} mins {} secs'.format(str(batch_idx+1).zfill(3), str(len(train_loader)).zfill(3), epochs_till_now, final_epoch, round(loss.item(), 5), int(m), round(s, 2)))

            start_time = time.time()

    return train_loss_list, running_train_loss/float(len(train_loader.dataset))

def val_epoch(device, val_loader, model, loss_fn, epochs_till_now = None, final_epoch = None, log_interval = 1, test_only = False):
    '''
    Takes in the val_loader/test_loader, the model and the loss function, and evaluates
    the loss and the ROC AUC score over all the data in the dataloader.

    It also prints the loss for every 'log_interval'th batch, but only when 'test_only' is False.
    '''
    model.eval()

    running_val_loss = 0
    val_loss_list = []
    val_loader_examples_num = len(val_loader.dataset)

    probs = np.zeros((val_loader_examples_num, 15), dtype = np.float32)
    gt = np.zeros((val_loader_examples_num, 15), dtype = np.float32)
    k=0

    with torch.no_grad():
        batch_start_time = time.time()
        for batch_idx, (img, target) in enumerate(val_loader):
            if test_only:
                per = ((batch_idx+1)/len(val_loader))*100
                a_, b_ = divmod(per, 1)
                print(f'{str(batch_idx+1).zfill(len(str(len(val_loader))))}/{str(len(val_loader)).zfill(len(str(len(val_loader))))} ({str(int(a_)).zfill(2)}.{str(int(100*b_)).zfill(2)} %)', end = '\r')
            # print(type(img), img.shape) # , np.unique(img))

            img = img.to(device)
            target = target.to(device)

            out = model(img)
            loss = loss_fn(out, target)
            running_val_loss += loss.item()*img.shape[0]
            val_loss_list.append(loss.item())

            # store the raw model outputs and the targets for metric evaluation
            # (ROC AUC depends only on how the scores are ranked, so no sigmoid is needed here)
            probs[k: k + out.shape[0], :] = out.cpu()
            gt[ k: k + out.shape[0], :] = target.cpu()
            k += out.shape[0]

            if ((batch_idx+1)%log_interval == 0) and (not test_only): # only when ((batch_idx + 1) is divisible by log_interval) and (when test_only = False)
                # batch metric evaluation
                # batch_roc_auc_score = get_roc_auc_score(target, out)

                batch_time = time.time() - batch_start_time
                m, s = divmod(batch_time, 60)
                print('Val Loss for batch {}/{} @epoch{}/{}: {} in {} mins {} secs'.format(str(batch_idx+1).zfill(3), str(len(val_loader)).zfill(3), epochs_till_now, final_epoch, round(loss.item(), 5), int(m), round(s, 2)))

                batch_start_time = time.time()

    # compute the metric over the whole dataloader
    roc_auc = get_roc_auc_score(gt, probs)
    return val_loss_list, running_val_loss/float(len(val_loader.dataset)), roc_auc

def fit(device, XRayTrain_dataset, train_loader, val_loader, test_loader, model,
        loss_fn, optimizer, losses_dict,
        epochs_till_now, epochs,
        log_interval, save_interval,
        lr, bs, stage, test_only = False):
    '''
    Trains the 'model' on 'train_loader'/'val_loader' for 'epochs' epochs, or, if 'test_only' is True,
    evaluates it once on 'test_loader' and exits.
    When training, it saves a checkpoint and the loss plots after every 'save_interval'th epoch.
    '''
    epoch_train_loss, epoch_val_loss, total_train_loss_list, total_val_loss_list = losses_dict['epoch_train_loss'], losses_dict['epoch_val_loss'], losses_dict['total_train_loss_list'], losses_dict['total_val_loss_list']

    final_epoch = epochs_till_now + epochs

    if test_only:
        print('\n======= Testing... =======\n')
        test_start_time = time.time()
        # pass log_interval by keyword; positionally it would land in the 'epochs_till_now' parameter
        test_loss, mean_running_test_loss, test_roc_auc = val_epoch(device, test_loader, model, loss_fn, log_interval = log_interval, test_only = test_only)
        total_test_time = time.time() - test_start_time
        m, s = divmod(total_test_time, 60)
        print('test_roc_auc: {} in {} mins {} secs'.format(test_roc_auc, int(m), int(s)))
        sys.exit()

    starting_epoch = epochs_till_now
    print('\n======= Training after epoch #{}... =======\n'.format(epochs_till_now))

    # epoch_train_loss = []
    # epoch_val_loss = []

    # total_train_loss_list = []
    # total_val_loss_list = []

    for epoch in range(epochs):

        if starting_epoch != epochs_till_now:
            # resample the train_loader and val_loader
            train_loader, val_loader = get_resampled_train_val_dataloaders(XRayTrain_dataset, config.transform, bs = bs)

        epochs_till_now += 1
        print('============ EPOCH {}/{} ============'.format(epochs_till_now, final_epoch))
        epoch_start_time = time.time()

        print('TRAINING')
        train_loss, mean_running_train_loss = train_epoch(device, train_loader, model, loss_fn, optimizer, epochs_till_now, final_epoch, log_interval)
        print('VALIDATION')
        val_loss, mean_running_val_loss, roc_auc = val_epoch(device, val_loader, model, loss_fn, epochs_till_now, final_epoch, log_interval)

        epoch_train_loss.append(mean_running_train_loss)
        epoch_val_loss.append(mean_running_val_loss)

        total_train_loss_list.extend(train_loss)
        total_val_loss_list.extend(val_loss)

        save_name = 'stage{}_{}_{}'.format(stage, str.split(str(lr), '.')[-1], str(epochs_till_now).zfill(2))

        # the following piece of code needs to be worked on! LATEST DEVELOPMENT TILL HERE
        if ((epoch+1)%save_interval == 0) or test_only:
            save_path = os.path.join(config.models_dir, '{}.pth'.format(save_name))

            torch.save({
                'epochs': epochs_till_now,
                'model': model, # it saves the whole model
                'losses_dict': {'epoch_train_loss': epoch_train_loss, 'epoch_val_loss': epoch_val_loss, 'total_train_loss_list': total_train_loss_list, 'total_val_loss_list': total_val_loss_list}
                }, save_path)

            print('\ncheckpoint {} saved'.format(save_path))

            make_plot(epoch_train_loss, epoch_val_loss, total_train_loss_list, total_val_loss_list, save_name)
            print('loss plots saved !!!')

        print('\nTRAIN LOSS : {}'.format(mean_running_train_loss))
        print('VAL LOSS : {}'.format(mean_running_val_loss))
        print('VAL ROC_AUC: {}'.format(roc_auc))

        total_epoch_time = time.time() - epoch_start_time
        m, s = divmod(total_epoch_time, 60)
        h, m = divmod(m, 60)
        print('\nEpoch {}/{} took {} h {} m'.format(epochs_till_now, final_epoch, int(h), int(m)))



'''
def pred_n_write(test_loader, model, save_name):
    res = np.zeros((3000, 15), dtype = np.float32)
    k=0
    for batch_idx, img in tqdm.tqdm(enumerate(test_loader)):
        model.eval()
        with torch.no_grad():
            pred = torch.sigmoid(model(img))
            # print(k)
            res[k: k + pred.shape[0], :] = pred
            k += pred.shape[0]

    # write csv
    print('populating the csv')
    submit = pd.DataFrame()
    submit['ImageID'] = [str.split(i, os.sep)[-1] for i in test_loader.dataset.data_list]
    with open('disease_classes.pickle', 'rb') as handle:
        disease_classes = pickle.load(handle)

    for idx, col in enumerate(disease_classes):
        if col == 'Hernia':
            submit['Hern'] = res[:, idx]
        elif col == 'Pleural_Thickening':
            submit['Pleural_thickening'] = res[:, idx]
        elif col == 'No Finding':
            submit['No_findings'] = res[:, idx]
        else:
            submit[col] = res[:, idx]
    rand_num = str(random.randint(1000, 9999))
    csv_name = '{}___{}.csv'.format(save_name, rand_num)
    submit.to_csv('res/' + csv_name, index = False)
    print('{} saved !'.format(csv_name))
'''
--------------------------------------------------------------------------------
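get_roc_auc_score() in trainer.py computes one ROC AUC per class with sklearn's `roc_auc_score` and returns their unweighted mean (a macro average), skipping the 'No Finding' column. val_epoch() feeds it the raw model outputs without applying a sigmoid. The sketch below is a hypothetical standalone helper, `mean_auc_excluding` (not part of the repository), that takes the column to skip as an argument instead of reading the class list from the pickle under `config.pkl_dir_path`. It reproduces the same averaging on toy data and illustrates why skipping the sigmoid is harmless: ROC AUC only depends on how the scores are ranked, and a sigmoid is strictly increasing.

```python
import numpy as np
from sklearn.metrics import roc_auc_score


def mean_auc_excluding(y_true, y_scores, exclude_index):
    """Unweighted mean of per-class ROC AUCs, skipping one column.

    Hypothetical helper: mirrors the averaging in get_roc_auc_score(), but the
    index to skip (e.g. 'No Finding') is passed in instead of read from a pickle.
    Every remaining column of y_true must contain both 0s and 1s, otherwise
    roc_auc_score raises a ValueError.
    """
    aucs = [
        roc_auc_score(y_true[:, i], y_scores[:, i])
        for i in range(y_true.shape[1])
        if i != exclude_index
    ]
    return float(np.mean(aucs))


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    y_true = rng.integers(0, 2, size=(200, 4))          # 200 toy samples, 4 toy classes
    logits = rng.normal(size=(200, 4)) + 1.5 * y_true   # positives score higher on average
    # Both calls should print the same score: the sigmoid keeps the ranking unchanged.
    print(mean_auc_excluding(y_true, logits, exclude_index=3))
    print(mean_auc_excluding(y_true, sigmoid(logits), exclude_index=3))
```

The same invariance means thresholded metrics (precision, recall, F1) are the ones that would actually need the sigmoid; the ranking-based AUC does not.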
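fit() checkpoints by pickling the whole model object under the 'model' key and does not store the optimizer state, while the README lists an lr_scheduler_state_dict entry among the checkpoint contents. PyTorch's serialization guide recommends saving `state_dict()`s instead, which also makes it possible to keep optimizer state such as momentum buffers across a `--resume`. The following is a minimal sketch of that alternative; `save_checkpoint`, `load_checkpoint` and the key names are hypothetical, and main.py's resume path would have to read the same keys. A scheduler's `state_dict()` could be stored and restored the same way.

```python
import os
import tempfile

import torch
import torch.nn as nn


def save_checkpoint(path, model, optimizer, epochs, losses_dict):
    """Hypothetical helper: store state_dicts rather than the pickled model object."""
    torch.save(
        {
            "epochs": epochs,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "losses_dict": losses_dict,
        },
        path,
    )


def load_checkpoint(path, model, optimizer):
    """Hypothetical helper: restore into an already-built model and optimizer.

    Returns (epochs, losses_dict) so training can continue where it stopped.
    """
    ckpt = torch.load(path, map_location="cpu")
    model.load_state_dict(ckpt["model_state_dict"])
    optimizer.load_state_dict(ckpt["optimizer_state_dict"])
    return ckpt["epochs"], ckpt["losses_dict"]


if __name__ == "__main__":
    torch.manual_seed(0)
    model = nn.Linear(4, 3)  # toy stand-in for the ResNet50 used in the repository
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    model(torch.randn(8, 4)).pow(2).mean().backward()
    optimizer.step()  # creates momentum buffers, which a whole-model pickle does not include

    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "stage1_01.pth")
        save_checkpoint(path, model, optimizer, epochs=1, losses_dict={"epoch_train_loss": [0.42]})

        fresh_model = nn.Linear(4, 3)
        fresh_opt = torch.optim.SGD(fresh_model.parameters(), lr=0.1, momentum=0.9)
        epochs, losses = load_checkpoint(path, fresh_model, fresh_opt)
        print(epochs, losses)
```

Because only tensors and plain Python containers are saved, the checkpoint does not depend on the model class being importable under the same module path when it is loaded; the trade-off is that the architecture must be rebuilt in code before calling `load_checkpoint`.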