├── README.md
├── config.py
├── datasets.py
├── losses.py
├── main.py
├── sample_xrays
│   ├── Atelectasis.png
│   ├── Cardiomegaly_Edema_Effusion.png
│   ├── Effusion.png
│   ├── Fibrosis.png
│   └── No Finding.png
└── trainer.py
/README.md:
--------------------------------------------------------------------------------
1 | # NIH-Chest-X-rays-Multi-Label-Image-Classification-In-Pytorch
2 | Multi-label image classification of chest X-rays in PyTorch.
3 |
4 | # Requirements
5 | * torch >= 0.4
6 | * torchvision >= 0.2.2
7 | * opencv-python
8 | * numpy >= 1.7.3
9 | * matplotlib
10 | * tqdm
* pandas
* scikit-learn
11 |
12 | # Dataset
13 | The [NIH Chest X-ray Dataset](https://www.kaggle.com/nih-chest-xrays/data#Data_Entry_2017.csv) is used for multi-label disease classification of chest X-rays.
14 | There are 15 classes in total (14 diseases and one for "No Finding").
15 | Each image is labelled either "No Finding" or with one or more of the following disease classes:
16 | * Atelectasis
17 | * Consolidation
18 | * Infiltration
19 | * Pneumothorax
20 | * Edema
21 | * Emphysema
22 | * Fibrosis
23 | * Effusion
24 | * Pneumonia
25 | * Pleural_Thickening
26 | * Cardiomegaly
27 | * Nodule
* Mass
28 | * Hernia
29 |
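Because one image can carry several labels at once (for example `Cardiomegaly|Edema|Effusion` in the `Finding Labels` column), each target is a 15-dimensional multi-hot vector rather than a single class index. The pure-Python sketch below mirrors what `__getitem__` in datasets.py does with `torch.zeros`; the hand-written class list is an assumption (datasets.py builds its list from the labels it actually samples and sorts it the same way).

```python
from typing import List

# Assumption: the 15 label strings as they appear in Data_Entry_2017.csv,
# sorted alphabetically like the class list built in datasets.py.
CLASSES = sorted([
    "Atelectasis", "Cardiomegaly", "Consolidation", "Edema", "Effusion",
    "Emphysema", "Fibrosis", "Hernia", "Infiltration", "Mass",
    "No Finding", "Nodule", "Pleural_Thickening", "Pneumonia", "Pneumothorax",
])


def encode_labels(finding_labels: str) -> List[float]:
    """Turn a 'Finding Labels' string such as 'Edema|Effusion' into a 0/1 vector."""
    target = [0.0] * len(CLASSES)
    for label in finding_labels.split("|"):
        target[CLASSES.index(label)] = 1.0  # raises ValueError for an unknown label
    return target


print(encode_labels("Cardiomegaly|Edema|Effusion"))
```

A "No Finding" image simply gets a single 1 in the "No Finding" position, so it is handled by the same code path as the disease labels.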
30 | There are 112,120 X-ray images of size 1024x1024 pixels, of which 86,524 are listed for training/validation (`train_val_list.txt`) and 25,596 for testing (`test_list.txt`).
31 |
32 | # Sample X-Ray Images
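33 | Example images are included in the `sample_xrays` directory: Atelectasis, Cardiomegaly_Edema_Effusion (a multi-label example), Effusion, Fibrosis and No Finding.
44 |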
45 | # Model
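46 | An ImageNet-pretrained ResNet-50 is used for transfer learning on this dataset: its final fully connected layer (`fc`) is replaced by a linear layer with one output per class (15 outputs).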
47 |
48 | # Loss Function
49 | Two loss functions are available, selected with `--loss_func`:
50 | * Focal Loss (`FocalLoss`, default)
51 | * Binary Cross-Entropy Loss (`BCE`)
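For reference, losses.py computes focal loss on top of the element-wise BCE: with p_t = exp(-BCE), the loss is (1 - p_t)^γ · BCE, averaged over every (image, class) entry; main.py uses γ = 2. The dependency-free sketch below reproduces that arithmetic for a flat list of logits. It is an illustration of the formula, not a drop-in replacement for the PyTorch module.

```python
import math
from typing import Sequence


def bce_with_logits(logit: float, target: float) -> float:
    """Numerically stable binary cross-entropy computed from a raw logit."""
    # max(x, 0) - x*t + log(1 + exp(-|x|)) equals -[t*log(sigmoid(x)) + (1-t)*log(1-sigmoid(x))]
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


def focal_loss(logits: Sequence[float], targets: Sequence[float], gamma: float = 2.0) -> float:
    """Mean focal loss over all entries, following the formula used in losses.FocalLoss."""
    total = 0.0
    for x, t in zip(logits, targets):
        bce = bce_with_logits(x, t)
        p_t = math.exp(-bce)  # probability the model assigns to the true label
        total += (1.0 - p_t) ** gamma * bce
    return total / len(logits)


print(focal_loss([0.0, 3.0, -2.0], [1.0, 1.0, 0.0]))
```

With γ = 0 the modulating factor is 1 and the loss reduces to plain mean BCE; larger γ down-weights entries the model already predicts confidently, which helps with the heavy class imbalance in this dataset.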
52 |
53 | # Training
54 | * ### From Scratch
55 | The following layers are trainable; all other layers are frozen:
56 | * layer2
57 | * layer3
58 | * layer4
59 | * fc
60 |
61 | Terminal Code:
62 | ```
63 | python main.py
64 | ```
65 |
66 | * ### Resuming From a Saved Checkpoint
67 | Resuming requires a saved checkpoint, which is a dictionary containing:
68 | * epochs (number of epochs the model has been trained for so far)
69 | * model (architecture and learned weights of the model)
70 | * lr_scheduler_state_dict (state_dict of the lr_scheduler)
71 | * losses_dict (a dictionary containing the following losses)
72 |
73 | * mean train epoch losses for all the epochs
74 | * mean val epoch losses for all the epochs
75 | * batch train loss for all the training batches
76 | * batch val loss for all the validation batches
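A rough sketch of that dictionary layout, using a plain pickle round trip as a stand-in for `torch.save`/`torch.load` (both are pickle-based). `make_checkpoint` is a hypothetical helper and the model is a placeholder string; the `losses_dict` keys are the ones main.py initialises. Because the checkpoint pickles the whole model object, recent PyTorch releases (2.6 onwards), where `torch.load` defaults to `weights_only=True`, may need `weights_only=False` to load it.

```python
import io
import pickle


def make_checkpoint(epochs, model, lr_scheduler_state_dict, losses_dict):
    """Hypothetical helper: bundle the training state under the keys listed above."""
    return {
        "epochs": epochs,
        "model": model,
        "lr_scheduler_state_dict": lr_scheduler_state_dict,
        "losses_dict": losses_dict,
    }


# Empty loss history, as main.py creates it when training from scratch.
losses_dict = {"epoch_train_loss": [], "epoch_val_loss": [],
               "total_train_loss_list": [], "total_val_loss_list": []}
ckpt = make_checkpoint(0, "placeholder for the nn.Module", None, losses_dict)

# Stand-in for torch.save(ckpt, path) followed by torch.load(path).
buffer = io.BytesIO()
pickle.dump(ckpt, buffer)
buffer.seek(0)
restored = pickle.load(buffer)

# These are the entries main.py reads back when resuming.
epochs_till_now = restored["epochs"]
losses_dict = restored["losses_dict"]
print(epochs_till_now, sorted(losses_dict))
```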
77 |
78 | Different layers of the model are frozen or unfrozen in different stages (defined at the end of this README) to fit the model well to the data. The stage is passed from the terminal with `--stage STAGE`; it only takes effect together with `--resume`, since training from scratch always runs stage 1.
79 |
80 | Terminal Code:
81 | ```
82 | python main.py --resume --ckpt checkpoint_file.pth --stage 2
83 | ```
84 |
85 | Training creates a **models** directory and saves the checkpoints there; the file name passed to **--ckpt** is looked up inside this directory.
86 |
87 | # Testing
88 | To test a model, pass a saved checkpoint with the **--ckpt** argument and enable test mode with the **--test** argument.
89 |
90 | Terminal Code:
91 | ```
92 | python main.py --test --ckpt checkpoint_file.pth
93 | ```
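In test mode, `get_roc_auc_score` in trainer.py computes one ROC AUC per class with scikit-learn and averages them over every class except "No Finding". The sketch below reproduces that averaging without scikit-learn, using the fact that ROC AUC equals the probability that a randomly chosen positive is scored above a randomly chosen negative (ties count as half). The function names and the toy data are hypothetical, and the pairwise loop is O(positives × negatives) per class, so it only suits small illustrative inputs.

```python
from typing import List, Sequence


def binary_auc(y_true: Sequence[int], y_score: Sequence[float]) -> float:
    """ROC AUC as P(score of a random positive > score of a random negative), ties count 1/2."""
    pos = [s for t, s in zip(y_true, y_score) if t == 1]
    neg = [s for t, s in zip(y_true, y_score) if t == 0]
    if not pos or not neg:
        raise ValueError("ROC AUC needs at least one positive and one negative example")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


def mean_auc_excluding(y_true: List[List[int]], y_probs: List[List[float]],
                       classes: List[str], excluded: str = "No Finding") -> float:
    """Unweighted mean of per-class AUCs, skipping the `excluded` class column."""
    skip = classes.index(excluded)
    aucs = []
    for i in range(len(classes)):
        if i == skip:
            continue
        aucs.append(binary_auc([row[i] for row in y_true], [row[i] for row in y_probs]))
    return sum(aucs) / len(aucs)


# Toy example with three hypothetical classes.
classes = ["A", "No Finding", "B"]
y_true = [[1, 0, 0], [0, 1, 1], [1, 0, 1], [0, 1, 0]]
y_probs = [[0.9, 0.5, 0.6], [0.1, 0.5, 0.4], [0.8, 0.5, 0.7], [0.2, 0.5, 0.3]]
print(mean_auc_excluding(y_true, y_probs, classes))  # 0.875 = mean of 1.0 (A) and 0.75 (B)
```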
94 |
95 | # Result
96 | The model achieved an average per-class **ROC AUC score** of **0.73241** over all classes (excluding the "No Finding" class) after training in the following stages:
97 |
98 | #### STAGE 1
99 | * Loss Function: FocalLoss
100 | * lr: 1e-5
101 | * Training Layers: layer2, layer3, layer4, fc
102 | * Epochs: 2
103 |
104 | #### STAGE 2
105 | * Loss Function: FocalLoss
106 | * lr: 3e-4
107 | * Training Layers: layer3, layer4, fc
108 | * Epochs: 1
109 |
110 | #### STAGE 3
111 | * Loss Function: FocalLoss
112 | * lr: 1e-3
113 | * Training Layers: layer4, fc
114 | * Epochs: 3
115 |
116 | #### STAGE 4
117 | * Loss Function: FocalLoss
118 | * lr: 1e-3
119 | * Training Layers: fc
120 | * Epochs: 2
121 |
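main.py repeats a similar `named_parameters()` loop for every stage. As a sketch, the stage table above could be encoded once as a lookup from stage number to top-level module names; `STAGE_LAYERS`, `is_trainable` and `trainable_layers` are hypothetical names, not part of the repo. Matching on the first component of a parameter name (`layer3.0.conv1.weight` gives `layer3`) is a slightly stricter variant of the substring checks in main.py; for ResNet-50's parameter names both select the same layers.

```python
from typing import Dict, List, Set

# Stage -> top-level ResNet-50 modules left trainable (taken from the stage list above).
STAGE_LAYERS: Dict[int, Set[str]] = {
    1: {"layer2", "layer3", "layer4", "fc"},
    2: {"layer3", "layer4", "fc"},
    3: {"layer4", "fc"},
    4: {"fc"},
}


def is_trainable(param_name: str, stage: int) -> bool:
    """True if a parameter such as 'layer3.0.conv1.weight' should receive gradients in `stage`."""
    return param_name.split(".")[0] in STAGE_LAYERS[stage]


def trainable_layers(param_names: List[str], stage: int) -> List[str]:
    """Top-level module names that would be trained, in first-seen order."""
    seen: List[str] = []
    for name in param_names:
        top = name.split(".")[0]
        if is_trainable(name, stage) and top not in seen:
            seen.append(top)
    return seen


# With PyTorch the lookup would be applied as:
#     for name, param in model.named_parameters():
#         param.requires_grad = is_trainable(name, stage)
names = ["conv1.weight", "bn1.weight", "layer1.0.conv1.weight", "layer2.0.conv1.weight",
         "layer3.0.conv1.weight", "layer4.0.conv1.weight", "fc.weight", "fc.bias"]
print(trainable_layers(names, 3))  # ['layer4', 'fc']
```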
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | pkl_dir_path = 'pickles'
2 | train_val_df_pkl_path = 'train_val_df.pickle'
3 | test_df_pkl_path = 'test_df.pickle'
4 | disease_classes_pkl_path = 'disease_classes.pickle'
5 | models_dir = 'models'
6 |
7 | from torchvision import transforms
8 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
9 | std=[0.229, 0.224, 0.225])
10 |
11 | # transforms.RandomHorizontalFlip() is not used because some diseases may be more likely to be present in a specific lung (left/right)
12 | transform = transforms.Compose([transforms.ToPILImage(),
13 | transforms.Resize(224),
14 | transforms.ToTensor(),
15 | normalize])
16 |
--------------------------------------------------------------------------------
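For reference, `config.transform` resizes each image so its shorter side is 224 pixels (the square 1024x1024 X-rays become 224x224), converts it to a float tensor in [0, 1] and normalises each channel with the ImageNet mean and standard deviation. The sketch below spells out the last two steps for a single 8-bit, 3-channel pixel; it is a hand-written illustration with a hypothetical function name, not torchvision code.

```python
from typing import Tuple

# ImageNet statistics used by config.normalize (RGB channel order).
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)


def normalize_pixel(pixel: Tuple[int, int, int]) -> Tuple[float, ...]:
    """Mimic ToTensor() followed by Normalize() for one 8-bit, 3-channel pixel."""
    # ToTensor() maps 0..255 to 0.0..1.0; Normalize() then computes (x - mean) / std per channel.
    return tuple((value / 255.0 - m) / s for value, m, s in zip(pixel, MEAN, STD))


print(normalize_pixel((255, 255, 255)))
```

If the X-ray PNGs are grayscale, `cv2.imread` (which returns BGR order) yields three identical channels, so the BGR/RGB mismatch with the ImageNet statistics has no practical effect here.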
/datasets.py:
--------------------------------------------------------------------------------
1 | import glob, os, sys, pdb, time
2 | import pandas as pd
3 | import numpy as np
4 | import cv2
5 | import pickle
6 | from torch.utils.data import Dataset
7 | from tqdm import tqdm
8 | import torch
9 |
10 | import config
11 |
12 | def q(text = ''): # easy way to exit the script, useful while debugging
13 | print('> ', text)
14 | sys.exit()
15 |
16 | class XRaysTrainDataset(Dataset):
17 | def __init__(self, data_dir, transform = None):
18 | self.data_dir = data_dir
19 |
20 | self.transform = transform
21 | # print('self.data_dir: ', self.data_dir)
22 |
23 | # full dataframe including train_val and test set
24 | self.df = self.get_df()
25 | print('self.df.shape: {}'.format(self.df.shape))
26 |
27 | self.make_pkl_dir(config.pkl_dir_path)
28 |
29 | # get train_val_df
30 | if not os.path.exists(os.path.join(config.pkl_dir_path, config.train_val_df_pkl_path)):
31 |
32 | self.train_val_df = self.get_train_val_df()
33 | print('\nself.train_val_df.shape: {}'.format(self.train_val_df.shape))
34 |
35 | # pickle dump the train_val_df
36 | with open(os.path.join(config.pkl_dir_path, config.train_val_df_pkl_path), 'wb') as handle:
37 | pickle.dump(self.train_val_df, handle, protocol = pickle.HIGHEST_PROTOCOL)
38 | print('{}: dumped'.format(config.train_val_df_pkl_path))
39 |
40 | else:
41 | # pickle load the train_val_df
42 | with open(os.path.join(config.pkl_dir_path, config.train_val_df_pkl_path), 'rb') as handle:
43 | self.train_val_df = pickle.load(handle)
44 | print('\n{}: loaded'.format(config.train_val_df_pkl_path))
45 | print('self.train_val_df.shape: {}'.format(self.train_val_df.shape))
46 |
47 | self.the_chosen, self.all_classes, self.all_classes_dict = self.choose_the_indices()
48 |
49 | if not os.path.exists(os.path.join(config.pkl_dir_path, config.disease_classes_pkl_path)):
50 | # pickle dump the classes list
51 | with open(os.path.join(config.pkl_dir_path, config.disease_classes_pkl_path), 'wb') as handle:
52 | pickle.dump(self.all_classes, handle, protocol = pickle.HIGHEST_PROTOCOL)
53 | print('\n{}: dumped'.format(config.disease_classes_pkl_path))
54 | else:
55 | print('\n{}: already exists'.format(config.disease_classes_pkl_path))
56 |
57 | self.new_df = self.train_val_df.iloc[self.the_chosen, :] # this is the sampled train_val data
58 | print('\nself.all_classes_dict: {}'.format(self.all_classes_dict))
59 |
60 | def resample(self):
61 | self.the_chosen, self.all_classes, self.all_classes_dict = self.choose_the_indices()
62 | self.new_df = self.train_val_df.iloc[self.the_chosen, :]
63 | print('\nself.all_classes_dict: {}'.format(self.all_classes_dict))
64 |
65 | def make_pkl_dir(self, pkl_dir_path):
66 | if not os.path.exists(pkl_dir_path):
67 | os.mkdir(pkl_dir_path)
68 |
69 | def get_train_val_df(self):
70 |
71 | # get the list of train_val data
72 | train_val_list = set(self.get_train_val_list()) # a set makes the membership check below O(1)
73 |
74 | train_val_df = pd.DataFrame()
75 | print('\nbuilding train_val_df...')
76 | for i in tqdm(range(self.df.shape[0])):
77 | filename = os.path.basename(self.df.iloc[i,0])
78 | # print('filename: ', filename)
79 | if filename in train_val_list:
80 | train_val_df = pd.concat([train_val_df, self.df.iloc[i:i+1, :]]) # DataFrame.append was removed in pandas 2.0
81 |
82 | # print('train_val_df.shape: {}'.format(train_val_df.shape))
83 |
84 | return train_val_df
85 |
86 | def __getitem__(self, index):
87 | row = self.new_df.iloc[index, :]
88 |
89 | img = cv2.imread(row['image_links'])
90 | labels = str.split(row['Finding Labels'], '|')
91 |
92 | target = torch.zeros(len(self.all_classes))
93 | for lab in labels:
94 | lab_idx = self.all_classes.index(lab)
95 | target[lab_idx] = 1
96 |
97 | if self.transform is not None:
98 | img = self.transform(img)
99 |
100 | return img, target
101 |
102 | def choose_the_indices(self):
103 |
104 | max_examples_per_class = 10000 # maximum number of examples of any class sampled into the training set
105 | the_chosen = []
106 | all_classes = {}
107 | length = len(self.train_val_df)
108 | # for i in tqdm(range(len(merged_df))):
109 | print('\nSampling the huuuge training dataset')
110 | for i in tqdm(list(np.random.choice(range(length),length, replace = False))):
111 |
112 | temp = str.split(self.train_val_df.iloc[i, :]['Finding Labels'], '|')
113 |
114 | # special case of ultra minority hernia. we will use all the images with 'Hernia' tagged in them.
115 | if 'Hernia' in temp:
116 | the_chosen.append(i)
117 | for t in temp:
118 | if t not in all_classes:
119 | all_classes[t] = 1
120 | else:
121 | all_classes[t] += 1
122 | continue
123 |
124 | # choose if multiple labels
125 | if len(temp) > 1:
126 | bool_lis = [False]*len(temp)
127 | # check if any label crosses the upper limit
128 | for idx, t in enumerate(temp):
129 | if t in all_classes:
130 | if all_classes[t] < max_examples_per_class:
131 | bool_lis[idx] = True
132 | else:
133 | bool_lis[idx] = True
134 | # if all labels are under the upper limit, append
135 | if sum(bool_lis) == len(temp):
136 | the_chosen.append(i)
137 | # maintain count
138 | for t in temp:
139 | if t not in all_classes:
140 | all_classes[t] = 1
141 | else:
142 | all_classes[t] += 1
143 | else: # these are single label images
144 | for t in temp:
145 | if t not in all_classes:
146 | all_classes[t] = 1
147 | the_chosen.append(i) # keep the first image of a new class as well
148 | elif all_classes[t] < max_examples_per_class:
149 | all_classes[t] += 1
150 | the_chosen.append(i)
151 |
152 | # print('len(all_classes): ', len(all_classes))
153 | # print('all_classes: ', all_classes)
154 | # print('len(the_chosen): ', len(the_chosen))
155 |
156 | '''
157 | if len(the_chosen) != len(set(the_chosen)):
158 | print('\nDuplicate indices found!')
159 | print('and the difference is: ', len(the_chosen) - len(set(the_chosen)))
160 | else:
161 | print('\nGood')
162 | '''
163 |
164 | return the_chosen, sorted(list(all_classes)), all_classes
165 |
166 | def get_df(self):
167 | csv_path = os.path.join(self.data_dir, 'Data_Entry_2017.csv')
168 | print('\n{} found: {}'.format(csv_path, os.path.exists(csv_path)))
169 |
170 | all_xray_df = pd.read_csv(csv_path)
171 |
172 | df = pd.DataFrame()
173 | df['image_links'] = [x for x in glob.glob(os.path.join(self.data_dir, 'images*', '*', '*.png'))]
174 |
175 | df['Image Index'] = df['image_links'].apply(os.path.basename) # file name, e.g. 00000001_000.png
176 | merged_df = df.merge(all_xray_df, how = 'inner', on = ['Image Index'])
177 | merged_df = merged_df[['image_links','Finding Labels']]
178 | return merged_df
179 |
180 | def get_train_val_list(self):
181 | with open(os.path.join(self.data_dir, 'train_val_list.txt'), 'r') as f: # honours --data_path instead of a hard-coded folder
182 | train_val_list = f.read().splitlines()
183 | return train_val_list
184 |
185 | def __len__(self):
186 | return len(self.new_df)
187 |
188 |
189 | # prepare the test dataset
190 | class XRaysTestDataset(Dataset):
191 | def __init__(self, data_dir, transform = None):
192 | self.data_dir = data_dir
193 | self.transform = transform
194 | # print('self.data_dir: ', self.data_dir)
195 |
196 | # full dataframe including train_val and test set
197 | self.df = self.get_df()
198 | print('\nself.df.shape: {}'.format(self.df.shape))
199 |
200 | self.make_pkl_dir(config.pkl_dir_path)
201 |
202 | # loading the classes list
203 | with open(os.path.join(config.pkl_dir_path, config.disease_classes_pkl_path), 'rb') as handle:
204 | self.all_classes = pickle.load(handle)
205 |
206 | # get test_df
207 | if not os.path.exists(os.path.join(config.pkl_dir_path, config.test_df_pkl_path)):
208 |
209 | self.test_df = self.get_test_df()
210 | print('self.test_df.shape: ', self.test_df.shape)
211 |
212 | # pickle dump the test_df
213 | with open(os.path.join(config.pkl_dir_path, config.test_df_pkl_path), 'wb') as handle:
214 | pickle.dump(self.test_df, handle, protocol = pickle.HIGHEST_PROTOCOL)
215 | print('\n{}: dumped'.format(config.test_df_pkl_path))
216 | else:
217 | # pickle load the test_df
218 | with open(os.path.join(config.pkl_dir_path, config.test_df_pkl_path), 'rb') as handle:
219 | self.test_df = pickle.load(handle)
220 | print('\n{}: loaded'.format(config.test_df_pkl_path))
221 | print('self.test_df.shape: {}'.format(self.test_df.shape))
222 |
223 | def __getitem__(self, index):
224 | row = self.test_df.iloc[index, :]
225 |
226 | img = cv2.imread(row['image_links'])
227 | labels = str.split(row['Finding Labels'], '|')
228 |
229 | target = torch.zeros(len(self.all_classes))
230 | for lab in labels:
231 | lab_idx = self.all_classes.index(lab)
232 | target[lab_idx] = 1
233 |
234 | if self.transform is not None:
235 | img = self.transform(img)
236 |
237 | return img, target
238 |
239 | def make_pkl_dir(self, pkl_dir_path):
240 | if not os.path.exists(pkl_dir_path):
241 | os.mkdir(pkl_dir_path)
242 |
243 | def get_df(self):
244 | csv_path = os.path.join(self.data_dir, 'Data_Entry_2017.csv')
245 |
246 | all_xray_df = pd.read_csv(csv_path)
247 |
248 | df = pd.DataFrame()
249 | df['image_links'] = [x for x in glob.glob(os.path.join(self.data_dir, 'images*', '*', '*.png'))]
250 |
251 | df['Image Index'] = df['image_links'].apply(os.path.basename) # file name, e.g. 00000001_000.png
252 | merged_df = df.merge(all_xray_df, how = 'inner', on = ['Image Index'])
253 | merged_df = merged_df[['image_links','Finding Labels']]
254 | return merged_df
255 |
256 | def get_test_df(self):
257 |
258 | # get the list of test data
259 | test_list = set(self.get_test_list()) # a set makes the membership check below O(1)
260 |
261 | test_df = pd.DataFrame()
262 | print('\nbuilding test_df...')
263 | for i in tqdm(range(self.df.shape[0])):
264 | filename = os.path.basename(self.df.iloc[i,0])
265 | # print('filename: ', filename)
266 | if filename in test_list:
267 | test_df = pd.concat([test_df, self.df.iloc[i:i+1, :]]) # DataFrame.append was removed in pandas 2.0
268 |
269 | print('test_df.shape: ', test_df.shape)
270 |
271 | return test_df
272 |
273 | def get_test_list(self):
274 | with open(os.path.join(self.data_dir, 'test_list.txt'), 'r') as f: # honours --data_path instead of a hard-coded folder
275 | test_list = f.read().splitlines()
276 | return test_list
277 |
278 | def __len__(self):
279 | return len(self.test_df)
280 |
281 |
282 |
283 |
284 |
285 |
286 | '''
287 | # prepare the test dataset
288 | import random
289 | class XRaysTestDataset2(Dataset):
290 | def __init__(self, test_data_dir, transform = None):
291 | self.test_data_dir = test_data_dir
292 | self.transform = transform
293 | self.data_list = self.get_data_list(self.test_data_dir)
294 |
295 | self.subset = self.data_list[:1000]
296 |
297 | def __getitem__(self, index):
298 | img_path = self.data_list[index]
299 | img = cv2.imread(img_path)
300 |
301 | if self.transform is not None:
302 | img = self.transform(img)
303 |
304 | return img_path
305 |
306 | def sample(self):
307 |
308 | random.shuffle(self.data_list)
309 |
310 | self.subset = self.data_list[:np.random.randint(500,700)]
311 |
312 | def __len__(self):
313 | return len(self.subset)
314 |
315 | def get_data_list(self, data_dir):
316 | data_list = []
317 | for path in glob.glob(data_dir + os.sep + '*'):
318 | data_list.append(path)
319 | return data_list
320 | '''
321 |
--------------------------------------------------------------------------------
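`choose_the_indices` caps every class at `max_examples_per_class` (10,000) images while keeping all Hernia images. A condensed sketch of the same idea: visit rows in random order and keep a row only if every label on it is still under the cap, except for rows that mention Hernia, which are always kept. The function name and the `seed` argument are assumptions for illustration; single-label and multi-label rows go through the same rule here.

```python
import random
from collections import Counter
from typing import List, Sequence


def capped_sample(label_strings: Sequence[str], cap: int,
                  always_keep: str = "Hernia", seed: int = 0) -> List[int]:
    """Pick row indices so that no class gets more than `cap` examples,
    while keeping every row that mentions `always_keep`."""
    order = list(range(len(label_strings)))
    random.Random(seed).shuffle(order)  # visit rows in random order, like choose_the_indices
    counts: Counter = Counter()
    chosen: List[int] = []
    for i in order:
        labels = label_strings[i].split("|")
        if always_keep in labels or all(counts[t] < cap for t in labels):
            chosen.append(i)
            counts.update(labels)  # every label on the row counts toward its class
    return chosen


rows = ["Effusion"] * 10 + ["Edema"] * 2 + ["Hernia"] * 5
kept = capped_sample(rows, cap=3)
print(len(kept))  # 10: 3 Effusion + 2 Edema + 5 Hernia
```

Rows with Hernia still add to the counts of their other labels, so classes that co-occur with Hernia can end up slightly above the cap, as in the original loop.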
/losses.py:
--------------------------------------------------------------------------------
1 | import torch, sys, os, pdb
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | class FocalLoss(nn.Module):
6 |
7 | def __init__(self, device, gamma = 1.0):
8 | super(FocalLoss, self).__init__()
9 | self.device = device
10 | self.gamma = torch.tensor(gamma, dtype = torch.float32).to(device)
11 | self.eps = 1e-6
12 |
13 | # self.BCE_loss = nn.BCEWithLogitsLoss(reduction='none').to(device)
14 |
15 | def forward(self, input, target):
16 |
17 | BCE_loss = F.binary_cross_entropy_with_logits(input, target, reduction='none').to(self.device)
18 | # BCE_loss = self.BCE_loss(input, target)
19 | pt = torch.exp(-BCE_loss) # p_t, the predicted probability of the true label, recovered from the BCE so that log(0) NaNs are avoided
20 | F_loss = (1-pt)**self.gamma * BCE_loss
21 |
22 | return F_loss.mean()
23 |
24 | # def forward(self, input, target):
25 |
26 | # # input are not the probabilities, they are just the cnn out vector
27 | # # input and target shape: (bs, n_classes)
28 | # # sigmoid
29 | # probs = torch.sigmoid(input)
30 | # log_probs = -torch.log(probs)
31 |
32 | # focal_loss = torch.sum( torch.pow(1-probs + self.eps, self.gamma).mul(log_probs).mul(target) , dim=1)
33 | # # bce_loss = torch.sum(log_probs.mul(target), dim = 1)
34 |
35 | # return focal_loss.mean() #, bce_loss
36 |
37 | if __name__ == '__main__':
38 | inp = torch.tensor([[1., 0.95],
39 | [.9, 0.3],
40 | [0.6, 0.4]], requires_grad = True)
41 | target = torch.tensor([[1., 1],
42 | [1, 0],
43 | [0, 0]])
44 |
45 | print('inp\n',inp, '\n')
46 | print('target\n',target, '\n')
47 |
48 | print('inp.requires_grad:', inp.requires_grad, inp.shape)
49 | print('target.requires_grad:', target.requires_grad, target.shape)
50 |
51 |
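52 | loss = FocalLoss(device = torch.device('cpu'), gamma = 2)
53 |
54 | focal_loss = loss(inp, target) # forward() returns a single scalar (the mean focal loss)
55 | print('\nbce_loss', F.binary_cross_entropy_with_logits(inp, target), '\n')
56 | print('\nfocal_loss',focal_loss, '\n')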
57 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os, pdb, sys, glob, time
3 | import numpy as np
4 | import pandas as pd
5 | from tqdm import tqdm
6 | import cv2
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torchvision.models as models
11 |
12 | # import custom dataset classes
13 | from datasets import XRaysTrainDataset
14 | from datasets import XRaysTestDataset
15 |
16 | # import necessary libraries for defining the optimizers
17 | import torch.optim as optim
18 |
19 | from trainer import fit
20 | import config
21 |
22 | def q(text = ''): # easy way to exit the script, useful while debugging
23 | print('> ', text)
24 | sys.exit()
25 |
26 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
27 | print(f'\ndevice: {device}')
28 |
29 | parser = argparse.ArgumentParser(description='Following are the arguments that can be passed from the terminal')
30 | parser.add_argument('--data_path', type = str, default = 'NIH Chest X-rays', help = 'Name of the dataset folder inside data/ (must contain Data_Entry_2017.csv)')
31 | parser.add_argument('--bs', type = int, default = 128, help = 'batch size')
32 | parser.add_argument('--lr', type = float, default = 1e-5, help = 'Learning Rate for the optimizer')
33 | parser.add_argument('--stage', type = int, default = 1, help = 'Stage, it decides which layers of the Neural Net to train')
34 | parser.add_argument('--loss_func', type = str, default = 'FocalLoss', choices = {'BCE', 'FocalLoss'}, help = 'loss function')
35 | parser.add_argument('-r','--resume', action = 'store_true') # args.resume will return True if -r or --resume is used in the terminal
36 | parser.add_argument('--ckpt', type = str, help = 'File name of the checkpoint to load, inside the models directory')
37 | parser.add_argument('-t','--test', action = 'store_true') # args.test will return True if -t or --test is used in the terminal
38 | args = parser.parse_args()
39 |
40 | if args.resume and args.test: # --resume and --test cannot be combined
41 | q('The flow of this code has been designed either to train the model or to test it.\nPlease choose either --resume or --test')
42 |
43 | stage = args.stage
44 | if not args.resume:
45 | print(f'\nOverwriting stage to 1, as the model training is being done from scratch')
46 | stage = 1
47 |
48 | if args.test:
49 | print('TESTING THE MODEL')
50 | else:
51 | if args.resume:
52 | print('RESUMING THE MODEL TRAINING')
53 | else:
54 | print('TRAINING THE MODEL FROM SCRATCH')
55 |
56 | script_start_time = time.time() # tells the total run time of this script
57 |
58 | # mention the path of the data
59 | data_dir = os.path.join('data',args.data_path) # Data_Entry_2017.csv should be present in the mentioned path
60 |
61 | # define a function to count the total number of trainable parameters
62 | def count_parameters(model):
63 | num_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
64 | return num_parameters/1e6 # in terms of millions
65 |
66 | # make the datasets
67 | XRayTrain_dataset = XRaysTrainDataset(data_dir, transform = config.transform)
68 | train_percentage = 0.8
69 | train_dataset, val_dataset = torch.utils.data.random_split(XRayTrain_dataset, [int(len(XRayTrain_dataset)*train_percentage), len(XRayTrain_dataset)-int(len(XRayTrain_dataset)*train_percentage)])
70 |
71 | XRayTest_dataset = XRaysTestDataset(data_dir, transform = config.transform)
72 |
73 | print('\n-----Initial Dataset Information-----')
74 | print('num images in train_dataset : {}'.format(len(train_dataset)))
75 | print('num images in val_dataset : {}'.format(len(val_dataset)))
76 | print('num images in XRayTest_dataset: {}'.format(len(XRayTest_dataset)))
77 | print('-------------------------------------')
78 |
79 | # make the dataloaders
80 | batch_size = args.bs # 128 by default
81 | train_loader = torch.utils.data.DataLoader(train_dataset, batch_size = batch_size, shuffle = True)
82 | val_loader = torch.utils.data.DataLoader(val_dataset, batch_size = batch_size, shuffle = False)
83 | test_loader = torch.utils.data.DataLoader(XRayTest_dataset, batch_size = batch_size, shuffle = False)
84 |
85 | print('\n-----Initial Batchloaders Information -----')
86 | print('num batches in train_loader: {}'.format(len(train_loader)))
87 | print('num batches in val_loader : {}'.format(len(val_loader)))
88 | print('num batches in test_loader : {}'.format(len(test_loader)))
89 | print('-------------------------------------------')
90 |
91 | # sanity check
92 | if len(XRayTrain_dataset.all_classes) != 15: # 15 is the unique number of diseases in this dataset
93 | q('\nnumber of classes not equal to 15 !')
94 |
95 | a,b = train_dataset[0]
96 | print('\nwe are working with \nImages shape: {} and \nTarget shape: {}'.format( a.shape, b.shape))
97 |
98 | # make models directory, where the models and the loss plots will be saved
99 | if not os.path.exists(config.models_dir):
100 | os.mkdir(config.models_dir)
101 |
102 | # define the loss function
103 | if args.loss_func == 'FocalLoss': # by default
104 | from losses import FocalLoss
105 | loss_fn = FocalLoss(device = device, gamma = 2.).to(device)
106 | elif args.loss_func == 'BCE':
107 | loss_fn = nn.BCEWithLogitsLoss().to(device)
108 |
109 | # define the learning rate
110 | lr = args.lr
111 |
112 | if not args.test: # training
113 |
114 | # initialize the model if not args.resume
115 | if not args.resume:
116 | print('\ntraining from scratch')
117 | # import pretrained model
118 | model = models.resnet50(pretrained=True) # pretrained defaults to False
119 | # change the last linear layer
120 | num_ftrs = model.fc.in_features
121 | model.fc = nn.Linear(num_ftrs, len(XRayTrain_dataset.all_classes)) # 15 output classes
122 | model.to(device)
123 |
124 | print('----- STAGE 1 -----') # only training 'layer2', 'layer3', 'layer4' and 'fc'
125 | for name, param in model.named_parameters(): # all requires_grad by default, are True initially
126 | # print('{}: {}'.format(name, param.requires_grad)) # this shows True for all the parameters
127 | if ('layer2' in name) or ('layer3' in name) or ('layer4' in name) or ('fc' in name):
128 | param.requires_grad = True
129 | else:
130 | param.requires_grad = False
131 |
132 | # since we are not resuming the training of the model
133 | epochs_till_now = 0
134 |
135 | # making empty lists to collect all the losses
136 | losses_dict = {'epoch_train_loss': [], 'epoch_val_loss': [], 'total_train_loss_list': [], 'total_val_loss_list': []}
137 |
138 | else:
139 | if args.ckpt is None:
140 | q('ERROR: Please select a valid checkpoint to resume from')
141 |
142 | print('\nckpt loaded: {}'.format(args.ckpt))
143 | ckpt = torch.load(os.path.join(config.models_dir, args.ckpt))
144 |
145 | # since we are resuming the training of the model
146 | epochs_till_now = ckpt['epochs']
147 | model = ckpt['model']
148 | model.to(device)
149 |
150 | # loading previous loss lists to collect future losses
151 | losses_dict = ckpt['losses_dict']
152 |
153 | # printing some hyperparameters
154 | print('\n> loss_fn: {}'.format(loss_fn))
155 | print('> epochs_till_now: {}'.format(epochs_till_now))
156 | print('> batch_size: {}'.format(batch_size))
157 | print('> stage: {}'.format(stage))
158 | print('> lr: {}'.format(lr))
159 |
160 | else: # testing
161 | if args.ckpt is None:
162 | q('ERROR: Please select a checkpoint to load the testing model from')
163 |
164 | print('\ncheckpoint loaded: {}'.format(args.ckpt))
165 | ckpt = torch.load(os.path.join(config.models_dir, args.ckpt))
166 |
167 | # read the number of epochs the checkpointed model was trained for
168 | epochs_till_now = ckpt['epochs']
169 | model = ckpt['model'].to(device)
170 |
171 | # loss history stored in the checkpoint
172 | losses_dict = ckpt['losses_dict']
173 |
174 | # make changes(freezing/unfreezing the model's layers) in the following, for training the model for different stages
175 | if (not args.test) and (args.resume):
176 |
177 | if stage == 1:
178 |
179 | print('\n----- STAGE 1 -----') # only training 'layer2', 'layer3', 'layer4' and 'fc'
180 | for name, param in model.named_parameters(): # all requires_grad by default, are True initially
181 | # print('{}: {}'.format(name, param.requires_grad)) # this shows True for all the parameters
182 | if ('layer2' in name) or ('layer3' in name) or ('layer4' in name) or ('fc' in name):
183 | param.requires_grad = True
184 | else:
185 | param.requires_grad = False
186 |
187 | elif stage == 2:
188 |
189 | print('\n----- STAGE 2 -----') # only training 'layer3', 'layer4' and 'fc'
190 | for name, param in model.named_parameters():
191 | # print('{}: {}'.format(name, param.requires_grad)) # this shows True for all the parameters
192 | if ('layer3' in name) or ('layer4' in name) or ('fc' in name):
193 | param.requires_grad = True
194 | else:
195 | param.requires_grad = False
196 |
197 | elif stage == 3:
198 |
199 | print('\n----- STAGE 3 -----') # only training 'layer4' and 'fc'
200 | for name, param in model.named_parameters():
201 | # print('{}: {}'.format(name, param.requires_grad)) # this shows True for all the parameters
202 | if ('layer4' in name) or ('fc' in name):
203 | param.requires_grad = True
204 | else:
205 | param.requires_grad = False
206 |
207 | elif stage == 4:
208 |
209 | print('\n----- STAGE 4 -----') # only training 'fc'
210 | for name, param in model.named_parameters():
211 | # print('{}: {}'.format(name, param.requires_grad)) # this shows True for all the parameters
212 | if ('fc' in name):
213 | param.requires_grad = True
214 | else:
215 | param.requires_grad = False
216 |
217 |
218 | if not args.test:
219 | # checking the layers which are going to be trained (irrespective of args.resume)
220 | trainable_layers = []
221 | for name, param in model.named_parameters():
222 | if param.requires_grad == True:
223 | layer_name = str.split(name, '.')[0]
224 | if layer_name not in trainable_layers:
225 | trainable_layers.append(layer_name)
226 | print('\nfollowing are the trainable layers...')
227 | print(trainable_layers)
228 |
229 | print('\nwe have {} Million trainable parameters here in the {} model'.format(count_parameters(model), model.__class__.__name__))
230 |
231 | optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr = lr)
232 |
233 | # make changes in the parameters of the following 'fit' function
234 | fit(device, XRayTrain_dataset, train_loader, val_loader,
235 | test_loader, model, loss_fn,
236 | optimizer, losses_dict,
237 | epochs_till_now = epochs_till_now, epochs = 3,
238 | log_interval = 25, save_interval = 1,
239 | lr = lr, bs = batch_size, stage = stage,
240 | test_only = args.test)
241 |
242 | script_time = time.time() - script_start_time
243 | m, s = divmod(script_time, 60)
244 | h, m = divmod(m, 60)
245 | print('Total script run time: {} h {} m'.format(int(h), int(m)))
246 |
247 | # '''
248 | # This is how the model is trained...
249 | # ##### STAGE 1 ##### FocalLoss lr = 1e-5
250 | # training layers = layer2, layer3, layer4, fc
251 | # epochs = 2
252 | # ##### STAGE 2 ##### FocalLoss lr = 3e-4
253 | # training layers = layer3, layer4, fc
254 | # epochs = 5
255 | # ##### STAGE 3 ##### FocalLoss lr = 7e-4
256 | # training layers = layer4, fc
257 | # epochs = 4
258 | # ##### STAGE 4 ##### FocalLoss lr = 1e-3
259 | # training layers = fc
260 | # epochs = 3
261 | # '''
262 |
--------------------------------------------------------------------------------
/sample_xrays/Atelectasis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/n0obcoder/NIH-Chest-X-Rays-Multi-Label-Image-Classification-In-Pytorch/0489269a518c9bc6580dcd80eea0c79ee92a8269/sample_xrays/Atelectasis.png
--------------------------------------------------------------------------------
/sample_xrays/Cardiomegaly_Edema_Effusion.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/n0obcoder/NIH-Chest-X-Rays-Multi-Label-Image-Classification-In-Pytorch/0489269a518c9bc6580dcd80eea0c79ee92a8269/sample_xrays/Cardiomegaly_Edema_Effusion.png
--------------------------------------------------------------------------------
/sample_xrays/Effusion.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/n0obcoder/NIH-Chest-X-Rays-Multi-Label-Image-Classification-In-Pytorch/0489269a518c9bc6580dcd80eea0c79ee92a8269/sample_xrays/Effusion.png
--------------------------------------------------------------------------------
/sample_xrays/Fibrosis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/n0obcoder/NIH-Chest-X-Rays-Multi-Label-Image-Classification-In-Pytorch/0489269a518c9bc6580dcd80eea0c79ee92a8269/sample_xrays/Fibrosis.png
--------------------------------------------------------------------------------
/sample_xrays/No Finding.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/n0obcoder/NIH-Chest-X-Rays-Multi-Label-Image-Classification-In-Pytorch/0489269a518c9bc6580dcd80eea0c79ee92a8269/sample_xrays/No Finding.png
--------------------------------------------------------------------------------
/trainer.py:
--------------------------------------------------------------------------------
1 | import matplotlib
2 | matplotlib.use('Agg')
3 | import matplotlib.pyplot as plt
4 |
5 | import sys, os, time, random, pdb
6 | import numpy as np
7 | import pandas as pd
8 | import torch.nn.functional as F
9 | import torch
10 | import pickle
11 | import tqdm, pdb
12 | from sklearn.metrics import roc_auc_score
13 |
14 | import config
15 |
16 | def get_roc_auc_score(y_true, y_probs):
17 | '''
18 | Uses roc_auc_score from sklearn.metrics to compute the ROC AUC of each class for the given y_true and y_probs, and returns their unweighted (macro) mean over all classes except 'No Finding'.
19 | '''
20 |
21 | with open(os.path.join(config.pkl_dir_path, config.disease_classes_pkl_path), 'rb') as handle:
22 | all_classes = pickle.load(handle)
23 |
24 | NoFindingIndex = all_classes.index('No Finding')
25 |
26 | if True:
27 | print('\nNoFindingIndex: ', NoFindingIndex)
28 | print('y_true.shape, y_probs.shape ', y_true.shape, y_probs.shape)
29 | GT_and_probs = {'y_true': y_true, 'y_probs': y_probs}
30 | with open('GT_and_probs', 'wb') as handle:
31 | pickle.dump(GT_and_probs, handle, protocol = pickle.HIGHEST_PROTOCOL)
32 |
33 | class_roc_auc_list = []
34 | useful_classes_roc_auc_list = []
35 |
36 | for i in range(y_true.shape[1]):
37 | class_roc_auc = roc_auc_score(y_true[:, i], y_probs[:, i])
38 | class_roc_auc_list.append(class_roc_auc)
39 | if i != NoFindingIndex:
40 | useful_classes_roc_auc_list.append(class_roc_auc)
41 | if True:
42 | print('\nclass_roc_auc_list: ', class_roc_auc_list)
43 | print('\nuseful_classes_roc_auc_list', useful_classes_roc_auc_list)
44 |
45 | return np.mean(np.array(useful_classes_roc_auc_list))
46 |
47 | def make_plot(epoch_train_loss, epoch_val_loss, total_train_loss_list, total_val_loss_list, save_name):
48 | '''
49 | This function makes the following 4 different plots-
50 | 1. mean train loss VS number of epochs
51 | 2. mean val loss VS number of epochs
52 | 3. batch train loss for all the training batches VS number of batches
53 | 4. batch val loss for all the validation batches VS number of batches
54 | '''
55 | fig = plt.figure(figsize=(16,16))
56 | fig.suptitle('loss trends', fontsize=20)
57 | ax1 = fig.add_subplot(221)
58 | ax2 = fig.add_subplot(222)
59 | ax3 = fig.add_subplot(223)
60 | ax4 = fig.add_subplot(224)
61 |
62 | ax1.title.set_text('epoch train loss VS #epochs')
63 | ax1.set_xlabel('#epochs')
64 | ax1.set_ylabel('epoch train loss')
65 | ax1.plot(epoch_train_loss)
66 |
67 | ax2.title.set_text('epoch val loss VS #epochs')
68 | ax2.set_xlabel('#epochs')
69 | ax2.set_ylabel('epoch val loss')
70 | ax2.plot(epoch_val_loss)
71 |
72 | ax3.title.set_text('batch train loss VS #batches')
73 | ax3.set_xlabel('#batches')
74 | ax3.set_ylabel('batch train loss')
75 | ax3.plot(total_train_loss_list)
76 |
77 | ax4.title.set_text('batch val loss VS #batches')
78 | ax4.set_xlabel('#batches')
79 | ax4.set_ylabel('batch val loss')
80 | ax4.plot(total_val_loss_list)
81 |
82 | plt.savefig(os.path.join(config.models_dir,'losses_{}.png'.format(save_name)))
83 |
84 | def get_resampled_train_val_dataloaders(XRayTrain_dataset, transform, bs):
85 | '''
86 | Resamples the XRaysTrainDataset object and returns training and validation dataloaders, obtained by splitting the resampled dataset in an 80:20 ratio.
87 | '''
88 | XRayTrain_dataset.resample()
89 |
90 | train_percentage = 0.8
91 | train_dataset, val_dataset = torch.utils.data.random_split(XRayTrain_dataset, [int(len(XRayTrain_dataset)*train_percentage), len(XRayTrain_dataset)-int(len(XRayTrain_dataset)*train_percentage)])
92 |
93 | print('\n-----Resampled Dataset Information-----')
94 | print('num images in train_dataset : {}'.format(len(train_dataset)))
95 | print('num images in val_dataset : {}'.format(len(val_dataset)))
96 | print('---------------------------------------')
97 |
98 | # make dataloaders
99 | train_loader = torch.utils.data.DataLoader(train_dataset, batch_size = bs, shuffle = True)
100 | val_loader = torch.utils.data.DataLoader(val_dataset, batch_size = bs, shuffle = False)
101 |
102 | print('\n-----Resampled Batchloaders Information -----')
103 | print('num batches in train_loader: {}'.format(len(train_loader)))
104 | print('num batches in val_loader : {}'.format(len(val_loader)))
105 | print('---------------------------------------------\n')
106 |
107 | return train_loader, val_loader
108 |
109 | def train_epoch(device, train_loader, model, loss_fn, optimizer, epochs_till_now, final_epoch, log_interval):
110 | '''
111 | Takes in the data from the 'train_loader', calculates the loss over it using the 'loss_fn'
112 | and optimizes the 'model' using the 'optimizer'
113 |
114 | Also prints the batch loss after every 'log_interval' batches (batch-level ROC AUC evaluation is currently disabled).
115 | '''
116 | model.train()
117 |
118 | running_train_loss = 0
119 | train_loss_list = []
120 |
121 | start_time = time.time()
122 | for batch_idx, (img, target) in enumerate(train_loader):
123 | # print(type(img), img.shape) # , np.unique(img))
124 |
125 | img = img.to(device)
126 | target = target.to(device)
127 |
128 | optimizer.zero_grad()
129 | out = model(img)
130 | loss = loss_fn(out, target)
131 | running_train_loss += loss.item()*img.shape[0]
132 | train_loss_list.append(loss.item())
133 |
134 | loss.backward()
135 | optimizer.step()
136 |
137 | if (batch_idx+1)%log_interval == 0:
138 | # batch metric evaluation
139 | # # out_detached = out.detach()
140 | # # batch_roc_auc_score = get_roc_auc_score(target, out_detached.numpy())
141 | # 'out' is a torch.Tensor and 'roc_auc_score' function first tries to convert it into a numpy array, but since 'out' has requires_grad = True, it throws an error
142 | # RuntimeError: Can't call numpy() on Variable that requires grad. Use var.detach().numpy() instead.
143 | # so we have to 'detach' the 'out' tensor and then convert it into a numpy array to avoid the error !
144 |
145 | batch_time = time.time() - start_time
146 | m, s = divmod(batch_time, 60)
147 | print('Train Loss for batch {}/{} @epoch{}/{}: {} in {} mins {} secs'.format(str(batch_idx+1).zfill(3), str(len(train_loader)).zfill(3), epochs_till_now, final_epoch, round(loss.item(), 5), int(m), round(s, 2)))
148 |
149 | start_time = time.time()
150 |
151 | return train_loss_list, running_train_loss/float(len(train_loader.dataset))
152 |
153 | def val_epoch(device, val_loader, model, loss_fn, epochs_till_now = None, final_epoch = None, log_interval = 1, test_only = False):
154 | '''
155 | It essentially takes in the val_loader/test_loader, the model and the loss function and evaluates
156 | the loss and the ROC AUC score for all the data in the dataloader.
157 |
158 | It also prints the batch loss for every 'log_interval'th batch, but only when 'test_only' is False.
159 | '''
160 | model.eval()
161 |
162 | running_val_loss = 0
163 | val_loss_list = []
164 | val_loader_examples_num = len(val_loader.dataset)
165 |
166 | probs = np.zeros((val_loader_examples_num, 15), dtype = np.float32)
167 | gt = np.zeros((val_loader_examples_num, 15), dtype = np.float32)
168 | k=0
169 |
170 | with torch.no_grad():
171 | batch_start_time = time.time()
172 | for batch_idx, (img, target) in enumerate(val_loader):
173 | if test_only:
174 | per = ((batch_idx+1)/len(val_loader))*100
175 | a_, b_ = divmod(per, 1)
176 | print(f'{str(batch_idx+1).zfill(len(str(len(val_loader))))}/{str(len(val_loader)).zfill(len(str(len(val_loader))))} ({str(int(a_)).zfill(2)}.{str(int(100*b_)).zfill(2)} %)', end = '\r')
177 | # print(type(img), img.shape) # , np.unique(img))
178 |
179 | img = img.to(device)
180 | target = target.to(device)
181 |
182 | out = model(img)
183 | loss = loss_fn(out, target)
184 | running_val_loss += loss.item()*img.shape[0]
185 | val_loss_list.append(loss.item())
186 |
187 | # store the raw model outputs and the targets for metric evaluation
188 | probs[k: k + out.shape[0], :] = out.cpu()
189 | gt[ k: k + out.shape[0], :] = target.cpu()
190 | k += out.shape[0]
191 |
192 | if ((batch_idx+1)%log_interval == 0) and (not test_only): # only when ((batch_idx + 1) is divisible by log_interval) and (when test_only = False)
193 | # batch metric evaluation
194 | # batch_roc_auc_score = get_roc_auc_score(target, out)
195 |
196 | batch_time = time.time() - batch_start_time
197 | m, s = divmod(batch_time, 60)
198 | print('Val Loss for batch {}/{} @epoch{}/{}: {} in {} mins {} secs'.format(str(batch_idx+1).zfill(3), str(len(val_loader)).zfill(3), epochs_till_now, final_epoch, round(loss.item(), 5), int(m), round(s, 2)))
199 |
200 | batch_start_time = time.time()
201 |
202 | # metric evaluation
203 | roc_auc = get_roc_auc_score(gt, probs)
204 |
205 | return val_loss_list, running_val_loss/float(len(val_loader.dataset)), roc_auc
206 |
207 | def fit(device, XRayTrain_dataset, train_loader, val_loader, test_loader, model,
208 | loss_fn, optimizer, losses_dict,
209 | epochs_till_now, epochs,
210 | log_interval, save_interval,
211 | lr, bs, stage, test_only = False):
212 | '''
213 | Trains the 'model' on 'train_loader' and validates it on 'val_loader' for 'epochs' epochs or, if 'test_only' is True, evaluates it once on 'test_loader' and exits.
214 | If training ('test_only' = False), it saves the optimized 'model' and the loss plots after every 'save_interval'th epoch.
215 | '''
216 | epoch_train_loss, epoch_val_loss, total_train_loss_list, total_val_loss_list = losses_dict['epoch_train_loss'], losses_dict['epoch_val_loss'], losses_dict['total_train_loss_list'], losses_dict['total_val_loss_list']
217 |
218 | final_epoch = epochs_till_now + epochs
219 |
220 | if test_only:
221 | print('\n======= Testing... =======\n')
222 | test_start_time = time.time()
223 | test_loss, mean_running_test_loss, test_roc_auc = val_epoch(device, test_loader, model, loss_fn, log_interval = log_interval, test_only = test_only)
224 | total_test_time = time.time() - test_start_time
225 | m, s = divmod(total_test_time, 60)
226 | print('test_roc_auc: {} in {} mins {} secs'.format(test_roc_auc, int(m), int(s)))
227 | sys.exit()
228 |
229 | starting_epoch = epochs_till_now
230 | print('\n======= Training after epoch #{}... =======\n'.format(epochs_till_now))
231 |
232 | # epoch_train_loss = []
233 | # epoch_val_loss = []
234 |
235 | # total_train_loss_list = []
236 | # total_val_loss_list = []
237 |
238 | for epoch in range(epochs):
239 |
240 | if starting_epoch != epochs_till_now:
241 | # resample the train_loader and val_loader
242 | train_loader, val_loader = get_resampled_train_val_dataloaders(XRayTrain_dataset, config.transform, bs = bs)
243 |
244 | epochs_till_now += 1
245 | print('============ EPOCH {}/{} ============'.format(epochs_till_now, final_epoch))
246 | epoch_start_time = time.time()
247 |
248 | print('TRAINING')
249 | train_loss, mean_running_train_loss = train_epoch(device, train_loader, model, loss_fn, optimizer, epochs_till_now, final_epoch, log_interval)
250 | print('VALIDATION')
251 | val_loss, mean_running_val_loss, roc_auc = val_epoch(device, val_loader, model, loss_fn , epochs_till_now, final_epoch, log_interval)
252 |
253 | epoch_train_loss.append(mean_running_train_loss)
254 | epoch_val_loss.append(mean_running_val_loss)
255 |
256 | total_train_loss_list.extend(train_loss)
257 | total_val_loss_list.extend(val_loss)
258 |
259 | save_name = 'stage{}_{}_{}'.format(stage, str.split(str(lr), '.')[-1], str(epochs_till_now).zfill(2))
260 |
261 | # the following piece of code needs to be worked on !!! LATEST DEVELOPMENT TILL HERE
262 | if ((epoch+1)%save_interval == 0) or test_only:
263 | save_path = os.path.join(config.models_dir, '{}.pth'.format(save_name))
264 |
265 | torch.save({
266 | 'epochs': epochs_till_now,
267 | 'model': model, # it saves the whole model
268 | 'losses_dict': {'epoch_train_loss': epoch_train_loss, 'epoch_val_loss': epoch_val_loss, 'total_train_loss_list': total_train_loss_list, 'total_val_loss_list': total_val_loss_list}
269 | }, save_path)
270 |
271 | print('\ncheckpoint {} saved'.format(save_path))
272 |
273 | make_plot(epoch_train_loss, epoch_val_loss, total_train_loss_list, total_val_loss_list, save_name)
274 | print('loss plots saved !!!')
275 |
276 | print('\nTRAIN LOSS : {}'.format(mean_running_train_loss))
277 | print('VAL LOSS : {}'.format(mean_running_val_loss))
278 | print('VAL ROC_AUC: {}'.format(roc_auc))
279 |
280 | total_epoch_time = time.time() - epoch_start_time
281 | m, s = divmod(total_epoch_time, 60)
282 | h, m = divmod(m, 60)
283 | print('\nEpoch {}/{} took {} h {} m'.format(epochs_till_now, final_epoch, int(h), int(m)))
284 |
285 |
286 |
287 | '''
288 | def pred_n_write(test_loader, model, save_name):
289 | res = np.zeros((3000, 15), dtype = np.float32)
290 | k=0
291 | for batch_idx, img in tqdm.tqdm(enumerate(test_loader)):
292 | model.eval()
293 | with torch.no_grad():
294 | pred = torch.sigmoid(model(img))
295 | # print(k)
296 | res[k: k + pred.shape[0], :] = pred
297 | k += pred.shape[0]
298 |
299 | # write csv
300 | print('populating the csv')
301 | submit = pd.DataFrame()
302 | submit['ImageID'] = [str.split(i, os.sep)[-1] for i in test_loader.dataset.data_list]
303 | with open('disease_classes.pickle', 'rb') as handle:
304 | disease_classes = pickle.load(handle)
305 |
306 | for idx, col in enumerate(disease_classes):
307 | if col == 'Hernia':
308 | submit['Hern'] = res[:, idx]
309 | elif col == 'Pleural_Thickening':
310 | submit['Pleural_thickening'] = res[:, idx]
311 | elif col == 'No Finding':
312 | submit['No_findings'] = res[:, idx]
313 | else:
314 | submit[col] = res[:, idx]
315 | rand_num = str(random.randint(1000, 9999))
316 | csv_name = '{}___{}.csv'.format(save_name, rand_num)
317 | submit.to_csv('res/' + csv_name, index = False)
318 | print('{} saved !'.format(csv_name))
319 | '''
320 |
--------------------------------------------------------------------------------
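A note on the metric in trainer.py: `get_roc_auc_score` computes one ROC AUC per class and averages them over every class except 'No Finding', which is a macro average rather than a micro one. `val_epoch` passes it the raw model outputs stored in `probs`. Assuming the model returns unnormalized logits (the model and loss definitions are not part of this listing), that is still fine, because ROC AUC depends only on how the scores rank the images, and a sigmoid preserves that ranking. The following sketch checks both points on toy data; `mean_auc_excluding`, the three-class subset and all numbers are hypothetical and not part of the repository.

```python
import numpy as np
from sklearn.metrics import roc_auc_score


def mean_auc_excluding(y_true, scores, class_names, excluded="No Finding"):
    """Hypothetical helper: macro-average ROC AUC over all classes except `excluded`."""
    aucs = [
        roc_auc_score(y_true[:, i], scores[:, i])
        for i, name in enumerate(class_names)
        if name != excluded
    ]
    return float(np.mean(aucs))


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


# Toy example: 3 hypothetical classes, 6 images. Every column contains both
# labels; otherwise roc_auc_score raises a ValueError for that class.
class_names = ["Atelectasis", "No Finding", "Effusion"]
y_true = np.array([
    [1, 0, 0],
    [0, 1, 0],
    [1, 0, 1],
    [0, 1, 0],
    [0, 0, 1],
    [1, 0, 0],
], dtype=np.float32)

# Made-up raw model outputs (logits), one column per class.
logits = np.array([
    [ 2.0, -1.0, -0.5],
    [-1.5,  1.2, -2.0],
    [ 0.3, -0.7,  1.1],
    [ 0.5,  0.9, -1.0],
    [-0.2, -0.3,  0.4],
    [ 1.0, -2.0,  0.6],
])

# Same ranking, so the same AUC with or without the sigmoid.
auc_logits = mean_auc_excluding(y_true, logits, class_names)
auc_probs = mean_auc_excluding(y_true, sigmoid(logits), class_names)
print(round(auc_logits, 4), round(auc_probs, 4))  # → 0.8819 0.8819
```

Here the 'Atelectasis' column scores 8/9 and the 'Effusion' column 7/8 (the fraction of positive/negative pairs ranked correctly), so the mean is 127/144. If a pooled (micro) score were wanted instead, `roc_auc_score` accepts `average='micro'` on the full 2-D arrays, but that is not what the repository computes.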