├── LICENSE
├── README.md
├── examples_csv
│   ├── test_multiclass_patches.csv
│   └── train_WSI_multilabel.csv
├── test
│   ├── testing_WSI.py
│   └── testing_patches.py
└── train
    ├── train.py
    └── train_MoCo_HE_adversarial_loss.py

/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2021 ilmaro8
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Multiple_Instance_Learning
2 | Implementation of Multiple Instance Learning instance-based CNNs for histopathology image classification.
3 | 
4 | ## Reference
5 | If you find this repository useful in your research, please cite: N. Marini et al. (2022). "Unleashing the potential of digital pathology data by training computer-aided diagnosis models without human annotations"
6 | 
7 | Paper link: https://www.nature.com/articles/s41746-022-00635-4
8 | 
9 | ## Requirements
10 | Python==3.6.9, albumentations==0.1.8, numpy==1.17.3, opencv==4.2.0, pandas==0.25.2, pillow==6.1.0, torchvision==0.8.1, pytorch==1.7.0
11 | 
12 | ## Best models
13 | - The best models for the Multiple Instance Learning CNNs are available [here](https://drive.google.com/drive/folders/1-b3YJyJyydxMQPihVGGhQY15lrSH8dJo?usp=sharing).
14 | 
15 | ## Datasets
16 | ### Private datasets
17 | Two private datasets are used for training the CNNs:
18 | - AOEC
19 | - Radboudumc
20 | ### Publicly available datasets
21 | Seven publicly available datasets are used for testing the CNNs:
22 | - [GlaS](https://warwick.ac.uk/fac/cross_fac/tia/data/glascontest/)
23 | - [CRC](https://warwick.ac.uk/fac/cross_fac/tia/data/crc_grading/)
24 | - [UniToPatho](https://ieee-dataport.org/open-access/unitopatho)
25 | - [TCGA-COAD](https://portal.gdc.cancer.gov/projects/TCGA-COAD)
26 | - [Xu](https://bmcbioinformatics.biomedcentral.com/articles/10.1186/s12859-017-1685-x)
27 | - [AIDA](https://datahub.aida.scilifelab.se/10.23698/aida/drco)
28 | - [IMP-CRC](https://www.nature.com/articles/s41598-021-93746-z#data-availability)
29 | 
30 | ## Pre-Processing
31 | The WSIs are split into 224x224-pixel patches, extracted at 10x magnification.
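For intuition, dense grid extraction amounts to tiling the 10x pyramid level with non-overlapping 224x224 windows and keeping the tiles that contain tissue. A minimal sketch with OpenSlide (hypothetical paths and a crude brightness-based tissue filter; the actual extraction is done with the library referenced below):

```python
import numpy as np
import openslide

slide = openslide.OpenSlide("/PATH/TO/WSI.svs")  # hypothetical input path

# Pick the pyramid level closest to 10x, assuming a 40x base magnification
# (Multi_Scale_Tools reads the real magnification from the slide metadata).
level = slide.get_best_level_for_downsample(40 / 10)
width, height = slide.level_dimensions[level]
down = slide.level_downsamples[level]

size = 224
for y in range(0, height - size + 1, size):
    for x in range(0, width - size + 1, size):
        # read_region expects level-0 coordinates
        tile = slide.read_region((int(x * down), int(y * down)), level, (size, size)).convert("RGB")
        if np.asarray(tile).mean() < 220:  # crude filter: skip mostly-white background
            tile.save(f"/FOLDER/WHERE/TO/STORE/THE/PATCHES/patch_{x}_{y}.png")
```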
32 | The methods used to extract the patches come from the [Multi_Scale_Tools library](https://github.com/sara-nl/multi-scale-tools).
33 | 
34 | The extraction method is in the /preprocessing folder of the Multi_Scale_Tools library:
35 | - python Patch_Extractor_Dense_Grid.py -m 10 -w 1.25 -p 10 -r True -s 224 -x 0.7 -y 0 -i /PATH/CSV/IMAGES/TO/EXTRACT.csv -t /PATH/TISSUE/MASKS/TO/USE/ -o /FOLDER/WHERE/TO/STORE/THE/PATCHES/
36 | 
37 | More info: https://www.frontiersin.org/articles/10.3389/fcomp.2021.684521/full
38 | 
39 | ## CSV Input Files
40 | CSV files are used as input for the scripts. The CSV files have the following structure:
41 | - For each partition (train, validation, test), the CSV file has id_img, cancer, high-grade dysplasia, low-grade dysplasia, hyperplastic polyp, normal glands as columns.
42 | - For the patches used as test set, the CSV file has path_img, label as columns (see the loading sketch at the end of this README).
43 | 
44 | ## Training
45 | Script to train the CNN at WSI-level, using an instance-based MIL CNN:
46 | - python train.py -c resnet34 -b 512 -p att -e 10 -t multilabel -f True -i /PATH/WHERE/TO/FIND/THE/CSVS/INCLUDING/THE/PARTITIONS -o /PATH/WHERE/TO/SAVE/THE/MODEL/WEIGHTS -w /PATH/WHERE/TO/FIND/THE/PATCHES
47 | 
48 | Script to pre-train the CNN using MoCo: python train_MoCo_HE_adversarial_loss.py -c resnet34 -b 512 -p att -e 10 -t multilabel -f True -l 0.001 -i /PATH/WHERE/TO/FIND/THE/CSVS/INCLUDING/THE/PARTITIONS -o /PATH/WHERE/TO/SAVE/THE/MODEL/WEIGHTS -w /PATH/WHERE/TO/FIND/THE/PATCHES
49 | 
50 | ## Testing
51 | ### WSI-level
52 | Script to test the CNN at WSI-level:
53 | - python testing_WSI.py -c resnet34 -b 512 -p att -t multilabel -f True -m /PATH/TO/MODEL/WEIGHTS.pt -i /PATH/TO/INPUT/CSV.csv -w /PATH/WHERE/TO/FIND/THE/PATCHES
54 | 
55 | ### Patch-level
56 | - python testing_patches.py -c resnet34 -b 512 -p att -t multilabel -f True -m /PATH/TO/MODEL/WEIGHTS.pt -i /PATH/TO/INPUT/CSV.csv
57 | 
58 | ## Acknowledgements
59 | This project has received funding from the European Union's Horizon 2020 research and innovation programme under grant agreement No. 825292 [ExaMode](http://www.examode.eu). Infrastructure from the SURFsara HPC center was used to train the CNN models in parallel. Otálora thanks Minciencias through the call 756 for PhD studies.
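
## CSV loading example
All scripts load the CSV files with `header=None`, so the partition files are expected without a header row. A minimal sketch of how the two formats described in the CSV Input Files section are consumed (illustrative paths):

```python
import pandas as pd

# WSI-level partition file: id_img followed by the five binary labels
train = pd.read_csv("/PATH/CSVS/train_multilabel.csv", sep=",", header=None).values
wsi_ids, wsi_labels = train[:, 0], train[:, 1:]

# patch-level test file: path_img, label
test = pd.read_csv("/PATH/TO/INPUT/CSV.csv", sep=",", header=None).values
patch_paths, patch_labels = test[:, 0], test[:, 1]
```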
60 | 
--------------------------------------------------------------------------------
/examples_csv/test_multiclass_patches.csv:
--------------------------------------------------------------------------------
1 | PATH_IMG,class
2 | path0,0
3 | path1,1
4 | path2,2
5 | path3,3
6 | path4,4
7 | 
8 | 
--------------------------------------------------------------------------------
/examples_csv/train_WSI_multilabel.csv:
--------------------------------------------------------------------------------
1 | ID_IMG,class0,class1,class2,class3,class4
2 | id0,0,0,1,0,0
3 | id1,1,1,0,0,0
4 | id2,0,1,1,0,1
5 | id3,0,0,0,0,1
--------------------------------------------------------------------------------
/test/testing_WSI.py:
--------------------------------------------------------------------------------
1 | import sys, getopt
2 | import torch
3 | from torch.utils import data
4 | import numpy as np
5 | import pandas as pd
6 | from PIL import Image
7 | import albumentations as A
8 | import time
9 | import torch.nn.functional as F
10 | import matplotlib.pyplot as plt
11 | from matplotlib.pyplot import imshow
12 | import torch.utils.data
13 | from sklearn import metrics
14 | import os
15 | import argparse
16 | 
17 | import warnings
18 | warnings.filterwarnings("ignore")
19 | 
20 | args = sys.argv[1:]
21 | 
22 | print("CUDA current device " + str(torch.cuda.current_device()))
23 | print("CUDA devices available " + str(torch.cuda.device_count()))
24 | 
25 | #parser parameters
26 | parser = argparse.ArgumentParser(description='Configurations to train models.')
27 | parser.add_argument('-c', '--CNN', help='cnn architecture to use',type=str, default='resnet34')
28 | parser.add_argument('-b', '--BATCH_SIZE', help='batch_size',type=int, default=512)
29 | parser.add_argument('-p', '--pool', help='pooling algorithm',type=str, default='att')
30 | parser.add_argument('-t', '--TASK', help='task (binary/multilabel)',type=str, default='multilabel')
31 | parser.add_argument('-f', '--features', help='features_to_use: embedding (True) or features from CNN (False)',type=bool, default=True)
32 | parser.add_argument('-m', '--model', help='path of the model to load',type=str, default='./model/')
33 | parser.add_argument('-i', '--input', help='path of input csv',type=str, default='./model/')
34 | parser.add_argument('-w', '--wsi_folder', help='path where WSIs are stored',type=str, default='./images/')
35 | 
36 | args = parser.parse_args()
37 | 
38 | CNN_TO_USE = args.CNN
39 | BATCH_SIZE = args.BATCH_SIZE
40 | BATCH_SIZE_str = str(BATCH_SIZE)
41 | pool_algorithm = args.pool
42 | TASK = args.TASK
43 | EMBEDDING_bool = args.features
44 | INPUT_DATA = args.input
45 | MODEL_PATH = args.model
46 | WSI_FOLDER = args.wsi_folder
47 | 
48 | print("PARAMETERS")
49 | print("TASK: " + str(TASK))
50 | print("CNN used: " + str(CNN_TO_USE))
51 | print("POOLING ALGORITHM: " + str(pool_algorithm))
52 | print("BATCH_SIZE: " + str(BATCH_SIZE_str))
53 | 
54 | #create folder (used for saving weights)
55 | def create_dir(models_path):
56 |     if not os.path.isdir(models_path):
57 |         try:
58 |             os.mkdir(models_path)
59 |         except OSError:
60 |             print("Creation of the directory %s failed" % models_path)
61 |         else:
62 |             print("Successfully created the directory %s " % models_path)
63 | 
64 | def generate_list_instances(filename):
65 | 
66 |     instance_dir = WSI_FOLDER
67 |     fname = os.path.split(filename)[-1]
68 | 
69 |     instance_csv = instance_dir+fname+'/'+fname+'_paths_densely.csv'
70 | 
71 |     return instance_csv
72 | 
73 | checkpoint_path = MODEL_PATH+'checkpoints_MIL/'
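# NOTE: N_CLASSES is referenced below (model definition, label handling) but is
# never assigned in this script; the line below is a minimal guess mirroring
# train.py, where 'binary' uses 1 output and 'multilabel' the 5 colon classes
# (assumption, adjust to the trained model):
N_CLASSES = 1 if TASK == 'binary' else 5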
74 | create_dir(checkpoint_path) 75 | 76 | #path model file 77 | model_weights_filename = MODEL_PATH 78 | 79 | 80 | print("CSV LOADING ") 81 | csv_filename_testing = INPUT_DATA 82 | #read data 83 | test_dataset = pd.read_csv(csv_filename_testing, sep=',', header=None).values 84 | 85 | #MODEL DEFINITION 86 | pre_trained_network = torch.hub.load('pytorch/vision:v0.4.2', CNN_TO_USE, pretrained=True) 87 | if (('resnet' in CNN_TO_USE) or ('resnext' in CNN_TO_USE)): 88 | fc_input_features = pre_trained_network.fc.in_features 89 | elif (('densenet' in CNN_TO_USE)): 90 | fc_input_features = pre_trained_network.classifier.in_features 91 | elif ('mobilenet' in CNN_TO_USE): 92 | fc_input_features = pre_trained_network.classifier[1].in_features 93 | 94 | 95 | class MIL_model(torch.nn.Module): 96 | def __init__(self): 97 | """ 98 | In the constructor we instantiate two nn.Linear modules and assign them as 99 | member variables. 100 | """ 101 | super(MIL_model, self).__init__() 102 | self.conv_layers = torch.nn.Sequential(*list(pre_trained_network.children())[:-1]) 103 | 104 | if (torch.cuda.device_count()>1): 105 | self.conv_layers = torch.nn.DataParallel(self.conv_layers) 106 | 107 | self.fc_feat_in = fc_input_features 108 | self.N_CLASSES = N_CLASSES 109 | 110 | if (EMBEDDING_bool==True): 111 | 112 | if ('resnet34' in CNN_TO_USE): 113 | self.E = 128 114 | self.L = self.E 115 | self.D = 64 116 | self.K = self.N_CLASSES 117 | 118 | elif ('resnet50' in CNN_TO_USE): 119 | self.E = 256 120 | self.L = self.E 121 | self.D = 128 122 | self.K = self.N_CLASSES 123 | 124 | 125 | self.embedding = torch.nn.Linear(in_features=self.fc_feat_in, out_features=self.E) 126 | self.embedding_fc = torch.nn.Linear(in_features=self.E, out_features=self.N_CLASSES) 127 | 128 | else: 129 | self.fc = torch.nn.Linear(in_features=self.fc_feat_in, out_features=self.N_CLASSES) 130 | 131 | if ('resnet34' in CNN_TO_USE): 132 | self.L = fc_input_features 133 | self.D = 128 134 | self.K = self.N_CLASSES 135 | 136 | elif ('resnet50' in CNN_TO_USE): 137 | self.L = self.E 138 | self.D = 256 139 | self.K = self.N_CLASSES 140 | 141 | if (pool_algorithm=='att'): 142 | 143 | self.attention = torch.nn.Sequential( 144 | torch.nn.Linear(self.L, self.D), 145 | torch.nn.Tanh(), 146 | torch.nn.Linear(self.D, self.K) 147 | ) 148 | 149 | self.tanh = torch.nn.Tanh() 150 | self.relu = torch.nn.ReLU() 151 | 152 | def forward(self, x, conv_layers_out): 153 | """ 154 | In the forward function we accept a Tensor of input data and we must return 155 | a Tensor of output data. We can use Modules defined in the constructor as 156 | well as arbitrary operators on Tensors. 
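        Concretely: x is a batch of patch tensors (or None), and conv_layers_out
        holds precomputed patch features used when x is None. Returns
        (output_pool, output_fcn, A, features_to_return): sigmoid bag-level
        scores, per-patch softmax probabilities, attention weights, and patch
        embeddings.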
157 | """ 158 | #if used attention pooling 159 | A = None 160 | #m = torch.nn.Softmax(dim=1) 161 | m_binary = torch.nn.Sigmoid() 162 | m_multiclass = torch.nn.Softmax() 163 | dropout = torch.nn.Dropout(p=0.2) 164 | 165 | if x is not None: 166 | #print(x.shape) 167 | conv_layers_out=self.conv_layers(x) 168 | #print(x.shape) 169 | 170 | conv_layers_out = conv_layers_out.view(-1, self.fc_feat_in) 171 | 172 | #print(conv_layers_out.shape) 173 | 174 | if ('mobilenet' in CNN_TO_USE): 175 | dropout = torch.nn.Dropout(p=0.2) 176 | conv_layers_out = dropout(conv_layers_out) 177 | #print(conv_layers_out.shape) 178 | 179 | if (EMBEDDING_bool==True): 180 | #conv_layers_out = self.tanh(conv_layers_out) 181 | embedding_layer = self.embedding(conv_layers_out) 182 | #embedding_layer = self.tanh(embedding_layer) 183 | embedding_layer = self.relu(embedding_layer) 184 | features_to_return = embedding_layer 185 | 186 | #embedding_layer = self.tanh(embedding_layer) 187 | 188 | embedding_layer = dropout(embedding_layer) 189 | logits = self.embedding_fc(embedding_layer) 190 | 191 | else: 192 | logits = self.fc(conv_layers_out) 193 | features_to_return = conv_layers_out 194 | #print(output_fcn.shape) 195 | 196 | 197 | #print(output_fcn.size()) 198 | 199 | if (EMBEDDING_bool==True): 200 | A = self.attention(features_to_return) 201 | else: 202 | A = self.attention(conv_layers_out) # NxK 203 | 204 | #print(A.size()) 205 | #print(A) 206 | A = F.softmax(A, dim=0) # softmax over N 207 | #print(A.size()) 208 | #print(A) 209 | #A = A.view(-1, A.size()[0]) 210 | #print(A) 211 | 212 | output_pool = (logits * A).sum(dim = 0) #/ (A).sum(dim = 0) 213 | #print(output_pool.size()) 214 | #print(output_pool) 215 | #output_pool = torch.clamp(output_pool, 1e-7, 1 - 1e-7) 216 | 217 | output_fcn = m_multiclass(logits) 218 | output_pool = F.sigmoid(output_pool) 219 | return output_pool, output_fcn, A, features_to_return 220 | 221 | def accuracy_micro(y_true, y_pred): 222 | 223 | y_true_flatten = y_true.flatten() 224 | y_pred_flatten = y_pred.flatten() 225 | 226 | return metrics.accuracy_score(y_true_flatten, y_pred_flatten) 227 | 228 | 229 | def accuracy_macro(y_true, y_pred): 230 | 231 | n_classes = len(y_true[0]) 232 | 233 | acc_tot = 0.0 234 | 235 | for i in range(n_classes): 236 | 237 | acc = metrics.accuracy_score(y_true[:,i], y_pred[:,i]) 238 | #print(acc) 239 | acc_tot = acc_tot + acc 240 | 241 | acc_tot = acc_tot/n_classes 242 | 243 | return acc_tot 244 | 245 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 246 | 247 | from torchvision import transforms 248 | preprocess = transforms.Compose([ 249 | transforms.ToTensor(), 250 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 251 | ]) 252 | 253 | class Dataset_instance(data.Dataset): 254 | 255 | def __init__(self, list_IDs, partition): 256 | self.list_IDs = list_IDs 257 | self.set = partition 258 | 259 | def __len__(self): 260 | return len(self.list_IDs) 261 | 262 | def __getitem__(self, index): 263 | # Select sample 264 | ID = self.list_IDs[index][0] 265 | # Load data and get label 266 | img = Image.open(ID) 267 | X = np.asarray(img) 268 | img.close() 269 | #data transformation 270 | input_tensor = preprocess(X).type(torch.FloatTensor) 271 | 272 | #return input_tensor 273 | return input_tensor 274 | 275 | class Dataset_bag(data.Dataset): 276 | 277 | def __init__(self, list_IDs, labels): 278 | 279 | self.labels = labels 280 | self.list_IDs = list_IDs 281 | 282 | def __len__(self): 283 | 284 | return len(self.list_IDs) 285 | 286 
| def __getitem__(self, index): 287 | # Select sample 288 | ID = self.list_IDs[index] 289 | 290 | # Load data and get label 291 | instances_filename = generate_list_instances(ID) 292 | y = self.labels[index] 293 | if (TASK=='binary' and N_CLASSES==1): 294 | y = np.asarray(y) 295 | else: 296 | y = torch.tensor(y.tolist() , dtype=torch.float32) 297 | 298 | 299 | return instances_filename, y 300 | 301 | batch_size_bag = 1 302 | 303 | params_test_bag = {'batch_size': batch_size_bag, 304 | 'shuffle': True} 305 | 306 | if (TASK=='binary' and N_CLASSES==1): 307 | testing_set_bag = Dataset_bag(test_dataset[:,0], test_dataset[:,1]) 308 | testing_generator_bag = data.DataLoader(testing_set_bag, **params_test_bag) 309 | else: 310 | testing_set_bag = Dataset_bag(test_dataset[:,0], test_dataset[:,1:]) 311 | testing_generator_bag = data.DataLoader(testing_set_bag, **params_test_bag) 312 | 313 | print("testing") 314 | print("testing at WSI level") 315 | y_pred = [] 316 | y_true = [] 317 | 318 | model = torch.load(model_weights_filename) 319 | model.to(device) 320 | model.eval() 321 | 322 | kappa_score_general_filename = checkpoint_path+'kappa_score_general_'+TASK+'.csv' 323 | acc_balanced_filename = checkpoint_path+'acc_balanced_general_'+TASK+'.csv' 324 | acc_filename = checkpoint_path+'acc_general_'+TASK+'.csv' 325 | acc_macro_filename = checkpoint_path+'acc_macro_general_'+TASK+'.csv' 326 | acc_micro_filename = checkpoint_path+'acc_micro_general_'+TASK+'.csv' 327 | confusion_matrix_filename = checkpoint_path+'conf_matr_general_'+TASK+'.csv' 328 | roc_auc_filename = checkpoint_path+'roc_auc_general_'+TASK+'.csv' 329 | f1_score_macro_filename = checkpoint_path+'f1_macro_'+TASK+'.csv' 330 | f1_score_micro_filename = checkpoint_path+'f1_micro_'+TASK+'.csv' 331 | hamming_loss_filename = checkpoint_path+'hamming_loss_general_'+TASK+'.csv' 332 | recall_score_macro_filename = checkpoint_path+'recall_score_macro_general_'+TASK+'.csv' 333 | recall_score_micro_filename = checkpoint_path+'recall_score_micro_general_'+TASK+'.csv' 334 | jaccard_score_macro_filename = checkpoint_path+'jaccard_score_macro_'+TASK+'.csv' 335 | jaccard_score_micro_filename = checkpoint_path+'jaccard_score_micro_'+TASK+'.csv' 336 | roc_auc_score_macro_filename = checkpoint_path+'roc_auc_score_macro_general_'+TASK+'.csv' 337 | roc_auc_score_micro_filename = checkpoint_path+'roc_auc_score_micro_general_'+TASK+'.csv' 338 | precision_score_macro_filename = checkpoint_path+'precision_score_macro_general_'+TASK+'.csv' 339 | precision_score_micro_filename = checkpoint_path+'precision_score_micro_general_'+TASK+'.csv' 340 | auc_score_filename = checkpoint_path+'auc_score_general_'+TASK+'.csv' 341 | 342 | def save_metric(filename,value): 343 | array = [value] 344 | File = {'val':array} 345 | df = pd.DataFrame(File,columns=['val']) 346 | np.savetxt(filename, df.values, fmt='%s',delimiter=',') 347 | 348 | filenames_wsis = [] 349 | pred_cancers = [] 350 | pred_hgd = [] 351 | pred_lgd = [] 352 | pred_hyper = [] 353 | pred_normal = [] 354 | 355 | with torch.no_grad(): 356 | j = 0 357 | for inputs_bag,labels in testing_generator_bag: 358 | #inputs: bags, labels: labels of the bags 359 | labels_np = labels.cpu().data.numpy() 360 | len_bag = len(labels_np) 361 | 362 | #list of bags 363 | print("inputs_bag " + str(inputs_bag)) 364 | 365 | filename_wsi = inputs_bag[0].split('/')[-2] 366 | 367 | inputs_bag = list(inputs_bag) 368 | 369 | try: 370 | 371 | for b in range(len_bag): 372 | labs = [] 373 | labs.append(labels_np[b]) 374 | labs = 
np.array(labs).flatten() 375 | 376 | labels = torch.tensor(labs).float().to(device) 377 | 378 | #read csv with instances 379 | csv_instances = pd.read_csv(inputs_bag[b], sep=',', header=None).values 380 | #number of instances 381 | n_elems = len(csv_instances) 382 | 383 | #params generator instances 384 | batch_size_instance = int(BATCH_SIZE_str) 385 | 386 | num_workers = 4 387 | params_instance = {'batch_size': batch_size_instance, 388 | 'shuffle': True, 389 | 'num_workers': num_workers} 390 | 391 | #generator for instances 392 | instances = Dataset_instance(csv_instances,'valid') 393 | validation_generator_instance = data.DataLoader(instances, **params_instance) 394 | 395 | features = [] 396 | with torch.no_grad(): 397 | for instances in validation_generator_instance: 398 | instances = instances.to(device) 399 | 400 | # forward + backward + optimize 401 | feats = model.conv_layers(instances) 402 | feats = feats.view(-1, fc_input_features) 403 | feats_np = feats.cpu().data.numpy() 404 | 405 | features = np.append(features,feats_np) 406 | 407 | #del instances 408 | 409 | features_np = np.reshape(features,(n_elems,fc_input_features)) 410 | 411 | del features, feats 412 | 413 | inputs = torch.tensor(features_np).float().to(device) 414 | 415 | predictions, _, _, _ = model(None, inputs) 416 | 417 | outputs_np = predictions.cpu().data.numpy() 418 | labels_np = labels.cpu().data.numpy() 419 | 420 | filenames_wsis = np.append(filenames_wsis,filename_wsi) 421 | pred_cancers = np.append(pred_cancers,outputs_np[0]) 422 | pred_hgd = np.append(pred_hgd,outputs_np[1]) 423 | pred_lgd = np.append(pred_lgd,outputs_np[2]) 424 | pred_hyper = np.append(pred_hyper,outputs_np[3]) 425 | pred_normal = np.append(pred_normal,outputs_np[4]) 426 | 427 | #print(outputs_np,labels_np) 428 | print("["+str(j)+"/"+str(len(test_dataset))+"]") 429 | print("output: "+str(outputs_np)) 430 | print("ground truth:" + str(labels_np)) 431 | outputs_np = np.where(outputs_np > 0.5, 1, 0) 432 | 433 | torch.cuda.empty_cache() 434 | 435 | y_pred = np.append(y_pred,outputs_np) 436 | y_true = np.append(y_true,labels_np) 437 | 438 | j = j + 1 439 | 440 | except: 441 | 442 | pass 443 | 444 | 445 | filename_training_predictions = checkpoint_path+'WSI_predictions_AOEC.csv' 446 | 447 | File = {'filenames':filenames_wsis, 'pred_cancers':pred_cancers, 'pred_hgd':pred_hgd,'pred_lgd':pred_lgd, 'pred_hyper':pred_hyper, 'pred_normal':pred_normal} 448 | 449 | df = pd.DataFrame(File,columns=['filenames','pred_cancers','pred_hgd','pred_lgd','pred_hyper','pred_normal']) 450 | np.savetxt(filename_training_predictions, df.values, fmt='%s',delimiter=',') 451 | 452 | y_pred = np.reshape(y_pred,(j,N_CLASSES)) 453 | y_true = np.reshape(y_true,(j,N_CLASSES)) 454 | 455 | try: 456 | accuracy_score = metrics.accuracy_score(y_true=y_true, y_pred=y_pred) 457 | print("accuracy_score : " + str(accuracy_score)) 458 | save_metric(acc_filename,accuracy_score) 459 | except: 460 | pass 461 | 462 | try: 463 | accuracy_macro_score = accuracy_macro(y_true=y_true, y_pred=y_pred) 464 | print("accuracy_macro_score : " + str(accuracy_macro_score)) 465 | save_metric(acc_macro_filename,accuracy_macro_score) 466 | except: 467 | pass 468 | 469 | try: 470 | accuracy_micro_score = accuracy_micro(y_true=y_true, y_pred=y_pred) 471 | print("accuracy_micro_score : " + str(accuracy_micro_score)) 472 | save_metric(acc_micro_filename,accuracy_micro_score) 473 | except: 474 | pass 475 | 476 | try: 477 | hamming_loss = metrics.hamming_loss(y_true=y_true, y_pred=y_pred, sample_weight=None) 478 
| print("hamming_loss : " + str(hamming_loss)) 479 | save_metric(hamming_loss_filename,hamming_loss) 480 | except: 481 | pass 482 | 483 | try: 484 | zero_one_loss = metrics.zero_one_loss(y_true=y_true, y_pred=y_pred) 485 | print("zero_one_loss : " + str(zero_one_loss)) 486 | except: 487 | pass 488 | 489 | try: 490 | multilabel_confusion_matrix = metrics.multilabel_confusion_matrix(y_true=y_true, y_pred=y_pred) 491 | print("multilabel_confusion_matrix: ") 492 | print(multilabel_confusion_matrix) 493 | save_metric(confusion_matrix_filename,multilabel_confusion_matrix) 494 | except: 495 | pass 496 | 497 | try: 498 | target_names = ['cancer', 'hgd', 'lgd', 'hyper'] 499 | classification_report = metrics.classification_report(y_true, y_pred, target_names=target_names) 500 | print("classification_report: ") 501 | print(classification_report) 502 | except: 503 | pass 504 | 505 | try: 506 | jaccard_score_macro = metrics.jaccard_score(y_true=y_true, y_pred=y_pred, average='macro') 507 | jaccard_score_micro = metrics.jaccard_score(y_true=y_true, y_pred=y_pred, average='micro') 508 | print("jaccard_score_macro : " + str(jaccard_score_macro)) 509 | print("jaccard_score_micro : " + str(jaccard_score_micro)) 510 | save_metric(jaccard_score_macro_filename,jaccard_score_macro) 511 | save_metric(jaccard_score_micro_filename,jaccard_score_micro) 512 | except: 513 | pass 514 | 515 | try: 516 | f1_score_macro = metrics.f1_score(y_true=y_true, y_pred=y_pred, average='macro') 517 | f1_score_micro = metrics.f1_score(y_true=y_true, y_pred=y_pred, average='micro') 518 | print("f1_score_macro : " + str(f1_score_macro)) 519 | print("f1_score_micro : " + str(f1_score_micro)) 520 | save_metric(f1_score_macro_filename,f1_score_macro) 521 | save_metric(f1_score_micro_filename,f1_score_micro) 522 | except: 523 | pass 524 | 525 | try: 526 | recall_score_macro = metrics.recall_score(y_true=y_true, y_pred=y_pred, average='macro') 527 | recall_score_micro = metrics.recall_score(y_true=y_true, y_pred=y_pred, average='micro') 528 | print("recall_score_macro : " + str(recall_score_macro)) 529 | print("recall_score_micro : " + str(recall_score_micro)) 530 | save_metric(recall_score_macro_filename,recall_score_macro) 531 | save_metric(recall_score_micro_filename,recall_score_micro) 532 | except: 533 | pass 534 | 535 | try: 536 | precision_score_macro = metrics.precision_score(y_true=y_true, y_pred=y_pred, average='macro') 537 | precision_score_micro = metrics.precision_score(y_true=y_true, y_pred=y_pred, average='micro') 538 | print("precision_score_macro : " + str(precision_score_macro)) 539 | print("precision_score_micro : " + str(precision_score_micro)) 540 | save_metric(precision_score_macro_filename,precision_score_macro) 541 | save_metric(precision_score_micro_filename,precision_score_micro) 542 | except: 543 | pass 544 | 545 | try: 546 | roc_auc_score_macro = metrics.roc_auc_score(y_true=y_true, y_score=y_pred, average='macro') 547 | roc_auc_score_micro = metrics.roc_auc_score(y_true=y_true, y_score=y_pred, average='micro') 548 | print("roc_auc_score_macro : " + str(roc_auc_score_macro)) 549 | print("roc_auc_score_micro : " + str(roc_auc_score_micro)) 550 | save_metric(roc_auc_score_macro_filename,roc_auc_score_macro) 551 | save_metric(roc_auc_score_micro_filename,roc_auc_score_macro) 552 | except: 553 | pass 554 | -------------------------------------------------------------------------------- /test/testing_patches.py: -------------------------------------------------------------------------------- 1 | import sys, getopt 2 
| import torch 3 | from torch.utils import data 4 | import numpy as np 5 | import pandas as pd 6 | from PIL import Image 7 | import albumentations as A 8 | import time 9 | import torch.nn.functional as F 10 | import matplotlib.pyplot as plt 11 | from matplotlib.pyplot import imshow 12 | import torch.utils.data 13 | from sklearn import metrics 14 | import os 15 | import argparse 16 | 17 | args = sys.argv[1:] 18 | 19 | import warnings 20 | warnings.filterwarnings("ignore") 21 | 22 | print("CUDA current device " + str(torch.cuda.current_device())) 23 | print("CUDA devices available " + str(torch.cuda.device_count())) 24 | 25 | #parser parameters 26 | parser = argparse.ArgumentParser(description='Configurations to train models.') 27 | parser.add_argument('-c', '--CNN', help='cnn architecture to use',type=str, default='resnet34') 28 | parser.add_argument('-b', '--BATCH_SIZE', help='batch_size',type=int, default=512) 29 | parser.add_argument('-p', '--pool', help='pooling algorithm',type=str, default='att') 30 | parser.add_argument('-t', '--TASK', help='task (binary/multilabel)',type=str, default='resnet34') 31 | parser.add_argument('-f', '--features', help='features_to_use: embedding (True) or features from CNN (False)',type=bool, default=True) 32 | parser.add_argument('-m', '--model', help='path of the model to load',type=str, default='./model/') 33 | parser.add_argument('-i', '--input', help='path of input csv',type=str, default='./model/') 34 | 35 | args = parser.parse_args() 36 | 37 | CNN_TO_USE = args.CNN 38 | BATCH_SIZE = args.BATCH_SIZE 39 | BATCH_SIZE_str = str(BATCH_SIZE) 40 | pool_algorithm = args.pool 41 | TASK = args.TASK 42 | EMBEDDING_bool = args.features 43 | INPUT_DATA = args.input 44 | MODEL_PATH = args.model 45 | 46 | 47 | print("PARAMETERS") 48 | print("TASK: " + str(TASK)) 49 | print("CNN used: " + str(CNN_TO_USE)) 50 | print("POOLING ALGORITHM: " + str(pool_algorithm)) 51 | print("BATCH_SIZE: " + str(BATCH_SIZE_str)) 52 | 53 | #create folder (used for saving weights) 54 | def create_dir(models_path): 55 | if not os.path.isdir(models_path): 56 | try: 57 | os.mkdir(models_path) 58 | except OSError: 59 | print ("Creation of the directory %s failed" % models_path) 60 | else: 61 | print ("Successfully created the directory %s " % models_path) 62 | 63 | #DIRECTORIES CREATION 64 | 65 | checkpoint_path = MODEL_PATH+'checkpoints_MIL/' 66 | create_dir(checkpoint_path) 67 | 68 | #path model file 69 | model_weights_filename = MODEL_PATH 70 | 71 | 72 | print("CSV LOADING ") 73 | csv_filename_testing = INPUT_DATA 74 | 75 | 76 | #read data 77 | test_dataset = pd.read_csv(csv_filename_testing, sep=',', header=None).values 78 | 79 | #MODEL DEFINITION 80 | pre_trained_network = torch.hub.load('pytorch/vision:v0.4.2', CNN_TO_USE, pretrained=True) 81 | 82 | if (('resnet' in CNN_TO_USE) or ('resnext' in CNN_TO_USE)): 83 | fc_input_features = pre_trained_network.fc.in_features 84 | elif (('densenet' in CNN_TO_USE)): 85 | fc_input_features = pre_trained_network.classifier.in_features 86 | elif ('mobilenet' in CNN_TO_USE): 87 | fc_input_features = pre_trained_network.classifier[1].in_features 88 | 89 | class MIL_model(torch.nn.Module): 90 | def __init__(self): 91 | """ 92 | In the constructor we instantiate two nn.Linear modules and assign them as 93 | member variables. 
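        Concretely: builds the CNN backbone from pre_trained_network and either
        an embedding head (EMBEDDING_bool=True) or a plain fully connected
        classifier, plus an attention module when pool_algorithm == 'att'; all
        read from the module-level configuration parsed above.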
94 | """ 95 | super(MIL_model, self).__init__() 96 | self.conv_layers = torch.nn.Sequential(*list(pre_trained_network.children())[:-1]) 97 | #self.conv_layers = siamese_model.conv_layers 98 | """ 99 | if (torch.cuda.device_count()>1): 100 | self.conv_layers = torch.nn.DataParallel(self.conv_layers) 101 | """ 102 | self.fc_feat_in = fc_input_features 103 | self.N_CLASSES = N_CLASSES 104 | 105 | if (EMBEDDING_bool==True): 106 | 107 | if ('resnet18' in CNN_TO_USE): 108 | self.E = 128 109 | self.L = self.E 110 | self.D = 64 111 | self.K = self.N_CLASSES 112 | 113 | elif ('resnet34' in CNN_TO_USE): 114 | self.E = 128 115 | self.L = self.E 116 | self.D = 64 117 | self.K = self.N_CLASSES 118 | #self.K = 1 119 | elif ('resnet50' in CNN_TO_USE): 120 | self.E = 256 121 | self.L = self.E 122 | self.D = 128 123 | self.K = self.N_CLASSES 124 | 125 | #self.embedding = siamese_model.embedding 126 | self.embedding = torch.nn.Linear(in_features=self.fc_feat_in, out_features=self.E) 127 | self.embedding_fc = torch.nn.Linear(in_features=self.E, out_features=self.N_CLASSES) 128 | 129 | else: 130 | self.fc = torch.nn.Linear(in_features=self.fc_feat_in, out_features=self.N_CLASSES) 131 | 132 | if ('resnet18' in CNN_TO_USE): 133 | self.L = fc_input_features 134 | self.D = 128 135 | self.K = self.N_CLASSES 136 | 137 | elif ('resnet34' in CNN_TO_USE): 138 | self.L = fc_input_features 139 | self.D = 128 140 | self.K = self.N_CLASSES 141 | 142 | elif ('resnet50' in CNN_TO_USE): 143 | self.L = self.E 144 | self.D = 256 145 | self.K = self.N_CLASSES 146 | 147 | if (pool_algorithm=='att'): 148 | 149 | self.attention = torch.nn.Sequential( 150 | torch.nn.Linear(self.L, self.D), 151 | torch.nn.Tanh(), 152 | torch.nn.Linear(self.D, self.K) 153 | ) 154 | 155 | self.tanh = torch.nn.Tanh() 156 | self.relu = torch.nn.ReLU() 157 | 158 | def forward(self, x, conv_layers_out): 159 | """ 160 | In the forward function we accept a Tensor of input data and we must return 161 | a Tensor of output data. We can use Modules defined in the constructor as 162 | well as arbitrary operators on Tensors. 
163 | """ 164 | #if used attention pooling 165 | A = None 166 | #m = torch.nn.Softmax(dim=1) 167 | m_binary = torch.nn.Sigmoid() 168 | m_multiclass = torch.nn.Softmax() 169 | dropout = torch.nn.Dropout(p=0.2) 170 | 171 | 172 | if x is not None: 173 | #print(x.shape) 174 | conv_layers_out=self.conv_layers(x) 175 | #print(x.shape) 176 | 177 | conv_layers_out = conv_layers_out.view(-1, self.fc_feat_in) 178 | 179 | #print(conv_layers_out.shape) 180 | 181 | if ('mobilenet' in CNN_TO_USE): 182 | dropout = torch.nn.Dropout(p=0.2) 183 | conv_layers_out = dropout(conv_layers_out) 184 | #print(conv_layers_out.shape) 185 | 186 | if (EMBEDDING_bool==True): 187 | #conv_layers_out = self.tanh(conv_layers_out) 188 | embedding_layer = self.embedding(conv_layers_out) 189 | #embedding_layer = self.tanh(embedding_layer) 190 | embedding_layer = self.relu(embedding_layer) 191 | features_to_return = embedding_layer 192 | 193 | #embedding_layer = self.tanh(embedding_layer) 194 | 195 | embedding_layer = dropout(embedding_layer) 196 | logits = self.embedding_fc(embedding_layer) 197 | 198 | else: 199 | logits = self.fc(conv_layers_out) 200 | features_to_return = conv_layers_out 201 | #print(output_fcn.shape) 202 | 203 | 204 | #print(output_fcn.size()) 205 | 206 | if (EMBEDDING_bool==True): 207 | A = self.attention(features_to_return) 208 | else: 209 | A = self.attention(conv_layers_out) # NxK 210 | 211 | #print(A.size()) 212 | #print(A) 213 | A = F.softmax(A, dim=0) # softmax over N 214 | #print(A.size()) 215 | #print(A) 216 | #A = A.view(-1, A.size()[0]) 217 | #print(A) 218 | 219 | output_pool = (logits * A).sum(dim = 0) #/ (A).sum(dim = 0) 220 | #print(output_pool.size()) 221 | #print(output_pool) 222 | #output_pool = torch.clamp(output_pool, 1e-7, 1 - 1e-7) 223 | 224 | output_fcn = m_multiclass(logits) 225 | return output_pool, output_fcn, A, features_to_return 226 | 227 | 228 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 229 | 230 | model = torch.load(model_weights_filename) 231 | model.to(device) 232 | model.eval() 233 | 234 | from torchvision import transforms 235 | preprocess = transforms.Compose([ 236 | transforms.ToTensor(), 237 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 238 | ]) 239 | 240 | class Dataset_test_strong(data.Dataset): 241 | 242 | def __init__(self, list_IDs, labels): 243 | 244 | self.labels = labels 245 | self.list_IDs = list_IDs 246 | 247 | def __len__(self): 248 | 249 | return len(self.list_IDs) 250 | 251 | def __getitem__(self, index): 252 | 253 | # Select sample 254 | ID = self.list_IDs[index] 255 | # Load data and get label 256 | X = Image.open(ID) 257 | X = np.asarray(X) 258 | y = self.labels[index] 259 | #data augmentation 260 | #geometrical 261 | 262 | #data transformation 263 | input_tensor = preprocess(X).type(torch.FloatTensor) 264 | 265 | return input_tensor, np.asarray([y]), ID 266 | 267 | #read data 268 | test_dataset = pd.read_csv(csv_filename_testing, sep=',', header=None).values 269 | 270 | params_test = {'batch_size': int(BATCH_SIZE_str), 271 | #'shuffle': True, 272 | #'sampler': ImbalancedDatasetSampler(train_dataset), 273 | 'num_workers': 2} 274 | 275 | testing_set_strong = Dataset_test_strong(test_dataset[:,0], test_dataset[:,1]) 276 | testing_generator_strong = data.DataLoader(testing_set_strong, **params_test) 277 | 278 | y_pred = [] 279 | y_true = [] 280 | 281 | filenames = [] 282 | outputs_store = [] 283 | cumulative_labels = [] 284 | 285 | with torch.no_grad(): 286 | for inputs, labels, filename in 
testing_generator_strong: 287 | inputs, labels = inputs.to(device), labels.type(torch.FloatTensor).to(device) 288 | 289 | _, outputs, _, _ = model(inputs, None) 290 | 291 | #accumulate values 292 | outputs_np = outputs.cpu().data.numpy() 293 | labels_np = labels.cpu().data.numpy() 294 | 295 | filenames = np.append(filenames,filename) 296 | outputs_store = np.append(outputs_store,outputs_np) 297 | cumulative_labels = np.append(cumulative_labels,labels_np) 298 | 299 | outputs_np = np.argmax(outputs_np, axis=1) 300 | 301 | y_pred = np.append(y_pred,outputs_np) 302 | y_true = np.append(y_true,labels_np) 303 | 304 | #k-score 305 | k_score = metrics.cohen_kappa_score(y_true,y_pred, weights='quadratic') 306 | print("k_score " + str(k_score)) 307 | #f1_scrre 308 | f1_score = metrics.f1_score(y_true, y_pred, average='macro') 309 | print("f1_score " + str(f1_score)) 310 | #confusion matrix 311 | confusion_matrix = metrics.confusion_matrix(y_true=y_true, y_pred=y_pred) 312 | print("confusion_matrix ") 313 | print(str(confusion_matrix)) 314 | acc_balanced = metrics.balanced_accuracy_score(y_true, y_pred, sample_weight=None, adjusted=False) 315 | print("acc_balanced " + str(acc_balanced)) 316 | try: 317 | roc_auc_score = metrics.roc_auc_score(y_true, y_pred) 318 | print("roc_auc " + str(roc_auc_score)) 319 | except: 320 | pass 321 | 322 | 323 | kappa_score_general_filename = checkpoint_path+'kappa_score_general_multiclass_strong.csv' 324 | acc_balanced_filename = checkpoint_path+'acc_balanced_general_multiclass_strong.csv' 325 | confusion_matrix_filename = checkpoint_path+'conf_matr_general_multiclass_strong.csv' 326 | roc_auc_score_filename = checkpoint_path+'roc_auc_score_general_multiclass_strong.csv' 327 | f1_score_filename = checkpoint_path+'f1_score_general_multiclass_strong.csv' 328 | 329 | kappas = [k_score] 330 | 331 | File = {'val':kappas} 332 | df = pd.DataFrame(File,columns=['val']) 333 | np.savetxt(kappa_score_general_filename, df.values, fmt='%s',delimiter=',') 334 | 335 | f1_scores = [f1_score] 336 | 337 | File = {'val':f1_scores} 338 | df = pd.DataFrame(File,columns=['val']) 339 | np.savetxt(f1_score_filename, df.values, fmt='%s',delimiter=',') 340 | 341 | acc_balancs = [acc_balanced] 342 | 343 | File = {'val':acc_balancs} 344 | df = pd.DataFrame(File,columns=['val']) 345 | np.savetxt(acc_balanced_filename, df.values, fmt='%s',delimiter=',') 346 | 347 | conf_matr = [confusion_matrix] 348 | File = {'val':conf_matr} 349 | df = pd.DataFrame(File,columns=['val']) 350 | np.savetxt(confusion_matrix_filename, df.values, fmt='%s',delimiter=',') 351 | 352 | try: 353 | roc_auc = [roc_auc_score] 354 | File = {'val':roc_auc} 355 | df = pd.DataFrame(File,columns=['val']) 356 | np.savetxt(roc_auc_score_filename, df.values, fmt='%s',delimiter=',') 357 | except: 358 | pass 359 | 360 | predictions_radboudc_filename = checkpoint_path+'predictions_raw_aoec.csv' 361 | 362 | outputs_store = np.reshape(outputs_store,(len(test_dataset),5)) 363 | 364 | #print(outputs_store) 365 | 366 | pred_cancer = [row[0] for row in outputs_store] 367 | pred_hgd = [row[1] for row in outputs_store] 368 | pred_lgd = [row[2] for row in outputs_store] 369 | pred_hyper = [row[3] for row in outputs_store] 370 | pred_normal = [row[4] for row in outputs_store] 371 | 372 | File = {'filenames':filenames,'labels':cumulative_labels,'cancer':pred_cancer,'HGD':pred_hgd,'LGD':pred_lgd,'Hyper':pred_hyper,'Normal':pred_normal} 373 | df = pd.DataFrame(File,columns=['filenames','labels','cancer','HGD','LGD','Hyper','Normal']) 374 | 
np.savetxt(predictions_radboudc_filename, df.values, fmt='%s',delimiter=',') -------------------------------------------------------------------------------- /train/train.py: -------------------------------------------------------------------------------- 1 | import sys, getopt 2 | import torch 3 | from torch.utils import data 4 | import numpy as np 5 | import pandas as pd 6 | from PIL import Image 7 | import albumentations as A 8 | import time 9 | import torch.nn.functional as F 10 | import matplotlib.pyplot as plt 11 | from matplotlib.pyplot import imshow 12 | import torch.utils.data 13 | from sklearn import metrics 14 | import os 15 | from sklearn.decomposition import PCA 16 | from sklearn.metrics import silhouette_score 17 | from scipy.spatial import KDTree, cKDTree 18 | from sklearn.cluster import MiniBatchKMeans, KMeans, MeanShift, AffinityPropagation, AgglomerativeClustering 19 | from sklearn import metrics 20 | from scipy.stats import entropy 21 | #from topk import SmoothTop1SVM 22 | import argparse 23 | import warnings 24 | warnings.filterwarnings("ignore") 25 | 26 | argv = sys.argv[1:] 27 | 28 | print("CUDA current device " + str(torch.cuda.current_device())) 29 | print("CUDA devices available " + str(torch.cuda.device_count())) 30 | 31 | if torch.cuda.is_available(): 32 | device = torch.device("cuda") 33 | print("working on gpu") 34 | else: 35 | device = torch.device("cpu") 36 | print("working on cpu") 37 | print(torch.backends.cudnn.version()) 38 | 39 | #parser parameters 40 | parser = argparse.ArgumentParser(description='Configurations to train models.') 41 | parser.add_argument('-c', '--CNN', help='cnn architecture to use',type=str, default='resnet34') 42 | parser.add_argument('-b', '--BATCH_SIZE', help='batch_size',type=int, default=512) 43 | parser.add_argument('-p', '--pool', help='pooling algorithm',type=str, default='att') 44 | parser.add_argument('-e', '--EPOCHS', help='epochs to train',type=int, default=10) 45 | parser.add_argument('-t', '--TASK', help='task (binary/multilabel)',type=str, default='multilabel') 46 | parser.add_argument('-f', '--features', help='features_to_use: embedding (True) or features from CNN (False)',type=bool, default=True) 47 | parser.add_argument('-i', '--input_folder', help='path of the folder where train.csv and valid.csv are stored',type=str, default='./partition/') 48 | parser.add_argument('-o', '--output_folder', help='path where to store the model weights',type=str, default='./models/') 49 | parser.add_argument('-w', '--wsi_folder', help='path where WSIs are stored',type=str, default='./images/') 50 | 51 | args = parser.parse_args() 52 | 53 | CNN_TO_USE = args.CNN 54 | BATCH_SIZE = args.BATCH_SIZE 55 | BATCH_SIZE_str = str(BATCH_SIZE) 56 | pool_algorithm = args.pool 57 | EPOCHS = args.EPOCHS 58 | EPOCHS_str = EPOCHS 59 | TASK = args.TASK 60 | EMBEDDING_bool = args.features 61 | INPUT_FOLDER = args.input_folder 62 | OUTPUT_FOLDER = args.output_folder 63 | WSI_FOLDER = args.wsi_folder 64 | 65 | seed = 0 66 | 67 | torch.manual_seed(seed) 68 | if torch.cuda.is_available(): 69 | torch.cuda.manual_seed_all(seed) 70 | np.random.seed(seed) 71 | 72 | print("PARAMETERS") 73 | print("TASK: " + str(TASK)) 74 | print("N_EPOCHS: " + str(EPOCHS_str)) 75 | print("CNN used: " + str(CNN_TO_USE)) 76 | print("POOLING ALGORITHM: " + str(pool_algorithm)) 77 | print("BATCH_SIZE: " + str(BATCH_SIZE_str)) 78 | 79 | #create folder (used for saving weights) 80 | def create_dir(models_path): 81 | if not os.path.isdir(models_path): 82 | try: 83 | 
os.mkdir(models_path) 84 | except OSError: 85 | print ("Creation of the directory %s failed" % models_path) 86 | else: 87 | print ("Successfully created the directory %s " % models_path) 88 | 89 | def select_parameters_colour(): 90 | hue_min = -15 91 | hue_max = 8 92 | 93 | sat_min = -20 94 | sat_max = 10 95 | 96 | val_min = -8 97 | val_max = 8 98 | 99 | 100 | p1 = np.random.uniform(hue_min,hue_max,1) 101 | p2 = np.random.uniform(sat_min,sat_max,1) 102 | p3 = np.random.uniform(val_min,val_max,1) 103 | 104 | return p1[0],p2[0],p3[0] 105 | 106 | def select_rgb_shift(): 107 | r_min = -10 108 | r_max = 10 109 | 110 | g_min = -10 111 | g_max = 10 112 | 113 | b_min = -10 114 | b_max = 10 115 | 116 | 117 | p1 = np.random.uniform(r_min,r_max,1) 118 | p2 = np.random.uniform(g_min,g_max,1) 119 | p3 = np.random.uniform(b_min,b_max,1) 120 | 121 | return p1[0],p2[0],p3[0] 122 | 123 | def select_elastic_distorsion(): 124 | sigma_min = 0 125 | sigma_max = 20 126 | 127 | alpha_affine_min = -20 128 | alpha_affine_max = 20 129 | 130 | p1 = np.random.uniform(sigma_min,sigma_max,1) 131 | p2 = np.random.uniform(alpha_affine_min,alpha_affine_max,1) 132 | 133 | return p1[0],p2[0] 134 | 135 | def select_grid_distorsion(): 136 | dist_min = 0 137 | dist_max = 0.2 138 | 139 | p1 = np.random.uniform(dist_min,dist_max,1) 140 | 141 | return p1[0] 142 | 143 | def generate_transformer(label, prob = 0.5): 144 | list_operations = [] 145 | probas = np.random.rand(7) 146 | 147 | if (probas[0]>prob): 148 | #print("VerticalFlip") 149 | list_operations.append(A.VerticalFlip(always_apply=True)) 150 | if (probas[1]>prob): 151 | #print("HorizontalFlip") 152 | list_operations.append(A.HorizontalFlip(always_apply=True)) 153 | #""" 154 | if (probas[2]>prob): 155 | #print("RandomRotate90") 156 | #list_operations.append(A.RandomRotate90(always_apply=True)) 157 | 158 | p_rot = np.random.rand(1)[0] 159 | if (p_rot<=0.33): 160 | lim_rot = 90 161 | elif (p_rot>0.33 and p_rot<=0.66): 162 | lim_rot = 180 163 | else: 164 | lim_rot = 270 165 | list_operations.append(A.SafeRotate(always_apply=True, limit=(lim_rot,lim_rot+1e-4), interpolation=1, border_mode=4)) 166 | #""" 167 | """ 168 | if (probas[2]>prob): 169 | #print("RandomRotate90") 170 | list_operations.append(A.RandomRotate90(always_apply=True)) 171 | """ 172 | if (probas[3]>prob): 173 | #print("HueSaturationValue") 174 | p1, p2, p3 = select_parameters_colour() 175 | list_operations.append(A.HueSaturationValue(always_apply=True,hue_shift_limit=(p1,p1+1e-4),sat_shift_limit=(p2,p2+1e-4),val_shift_limit=(p3,p3+1e-4))) 176 | 177 | #print(p1,p2,p3) 178 | """ 179 | if (probas[4]>prob): 180 | p1, p2, p3 = select_rgb_shift() 181 | list_operations.append(A.RGBShift(r_shift_limit=(p1,p1+1e-4), g_shift_limit=(p2,p2+1e-4), b_shift_limit=(p3,p3+1e-4), always_apply=True)) 182 | #print(p1,p2,p3) 183 | """ 184 | 185 | if (np.array_equal(label,np.array([1,0,0,0,0])) or np.array_equal(label,np.array([0,1,0,0,0])) or np.array_equal(label,np.array([0,0,0,1,0])) or np.array_equal(label,np.array([1,1,0,0,0]))): 186 | 187 | if (probas[5]>prob): 188 | p1, p2 = select_elastic_distorsion() 189 | list_operations.append(A.ElasticTransform(alpha=1,border_mode=4, sigma=p1, alpha_affine=p2,always_apply=True)) 190 | #print(p1,p2) 191 | if (probas[6]>prob): 192 | p1 = select_grid_distorsion() 193 | list_operations.append(A.GridDistortion(num_steps=3, distort_limit=p1, interpolation=1, border_mode=4, always_apply=True)) 194 | #print(p1) 195 | 196 | pipeline_transform = A.Compose(list_operations) 197 | return 
pipeline_transform 198 | 199 | def generate_list_instances(filename): 200 | 201 | instance_dir = WSI_FOLDER 202 | fname = os.path.split(filename)[-1] 203 | 204 | instance_csv = instance_dir+fname+'/'+fname+'_paths_densely.csv' 205 | 206 | return instance_csv 207 | 208 | 209 | #DIRECTORIES CREATION 210 | print("CREATING/CHECKING DIRECTORIES") 211 | 212 | create_dir(OUTPUT_FOLDER) 213 | 214 | models_path = OUTPUT_FOLDER 215 | checkpoint_path = models_path+'checkpoints_MIL/' 216 | create_dir(checkpoint_path) 217 | 218 | #path model file 219 | model_weights_filename = models_path+'MIL_colon_'+TASK+'.pt' 220 | model_weights_filename_temporary = models_path+'MIL_colon_'+TASK+'_temporary.pt' 221 | 222 | #CSV LOADING 223 | print("CSV LOADING ") 224 | csv_folder = INPUT_FOLDER 225 | 226 | if (TASK=='binary'): 227 | 228 | N_CLASSES = 1 229 | #N_CLASSES = 2 230 | 231 | if (N_CLASSES==1): 232 | csv_filename_training = csv_folder+'train_binary.csv' 233 | csv_filename_validation = csv_folder+'valid_binary.csv' 234 | 235 | 236 | elif (TASK=='multilabel'): 237 | 238 | N_CLASSES = 5 239 | 240 | csv_filename_training = csv_folder+'train_multilabel.csv' 241 | csv_filename_validation = csv_folder+'valid_multilabel.csv' 242 | 243 | 244 | #read data 245 | train_dataset = pd.read_csv(csv_filename_training, sep=',', header=None).values#[:10] 246 | valid_dataset = pd.read_csv(csv_filename_validation, sep=',', header=None).values#[:10] 247 | 248 | class Balanced_Multimodal(torch.utils.data.sampler.Sampler): 249 | 250 | def __init__(self, dataset, indices=None, num_samples=None, alpha = 0.5): 251 | 252 | self.indices = list(range(len(dataset))) if indices is None else indices 253 | 254 | self.num_samples = len(self.indices) if num_samples is None else num_samples 255 | 256 | class_sample_count = [0,0,0,0,0] 257 | 258 | 259 | class_sample_count = np.sum(train_dataset[:,1:],axis=0) 260 | 261 | min_class = np.argmin(class_sample_count) 262 | class_sample_count = np.array(class_sample_count) 263 | weights = [] 264 | for c in class_sample_count: 265 | weights.append((c/class_sample_count[min_class])) 266 | 267 | ratio = np.array(weights).astype(np.float) 268 | 269 | label_to_count = {} 270 | for idx in self.indices: 271 | label = self._get_label(dataset, idx) 272 | for l in label: 273 | if l in label_to_count: 274 | label_to_count[l] += 1 275 | else: 276 | label_to_count[l] = 1 277 | 278 | weights = [] 279 | 280 | for idx in self.indices: 281 | c = 0 282 | for j, l in enumerate(self._get_label(dataset, idx)): 283 | c = c+(1/label_to_count[l])#*ratio[l] 284 | 285 | weights.append(c/(j+1)) 286 | #weights.append(c) 287 | 288 | self.weights_original = torch.DoubleTensor(weights) 289 | 290 | self.weights_uniform = np.repeat(1/self.num_samples, self.num_samples) 291 | 292 | #print(self.weights_a, self.weights_b) 293 | 294 | beta = 1 - alpha 295 | self.weights = (alpha * self.weights_original) + (beta * self.weights_uniform) 296 | 297 | 298 | def _get_label(self, dataset, idx): 299 | labels = np.where(dataset[idx,1:]==1)[0] 300 | #print(labels) 301 | #labels = dataset[idx,2] 302 | return labels 303 | 304 | def __iter__(self): 305 | return (self.indices[i] for i in torch.multinomial( 306 | self.weights, self.num_samples, replacement=True)) 307 | 308 | def __len__(self): 309 | return self.num_samples 310 | 311 | 312 | #MODEL DEFINITION 313 | pre_trained_network = torch.hub.load('pytorch/vision:v0.4.2', CNN_TO_USE, pretrained=True) 314 | 315 | if (('resnet' in CNN_TO_USE) or ('resnext' in CNN_TO_USE)): 316 | fc_input_features = 
pre_trained_network.fc.in_features 317 | elif (('densenet' in CNN_TO_USE)): 318 | fc_input_features = pre_trained_network.classifier.in_features 319 | elif ('mobilenet' in CNN_TO_USE): 320 | fc_input_features = pre_trained_network.classifier[1].in_features 321 | 322 | class MIL_model(torch.nn.Module): 323 | def __init__(self): 324 | """ 325 | In the constructor we instantiate two nn.Linear modules and assign them as 326 | member variables. 327 | """ 328 | super(MIL_model, self).__init__() 329 | self.conv_layers = torch.nn.Sequential(*list(pre_trained_network.children())[:-1]) 330 | #self.conv_layers = siamese_model.conv_layers 331 | """ 332 | if (torch.cuda.device_count()>1): 333 | self.conv_layers = torch.nn.DataParallel(self.conv_layers) 334 | """ 335 | self.fc_feat_in = fc_input_features 336 | self.N_CLASSES = N_CLASSES 337 | 338 | if (EMBEDDING_bool==True): 339 | 340 | if ('resnet18' in CNN_TO_USE): 341 | self.E = 128 342 | self.L = self.E 343 | self.D = 64 344 | self.K = self.N_CLASSES 345 | 346 | elif ('resnet34' in CNN_TO_USE): 347 | self.E = 128 348 | self.L = self.E 349 | self.D = 64 350 | self.K = self.N_CLASSES 351 | #self.K = 1 352 | elif ('resnet50' in CNN_TO_USE): 353 | self.E = 256 354 | self.L = self.E 355 | self.D = 128 356 | self.K = self.N_CLASSES 357 | 358 | #self.embedding = siamese_model.embedding 359 | self.embedding = torch.nn.Linear(in_features=self.fc_feat_in, out_features=self.E) 360 | self.embedding_fc = torch.nn.Linear(in_features=self.E, out_features=self.N_CLASSES) 361 | 362 | else: 363 | self.fc = torch.nn.Linear(in_features=self.fc_feat_in, out_features=self.N_CLASSES) 364 | 365 | if ('resnet18' in CNN_TO_USE): 366 | self.L = fc_input_features 367 | self.D = 128 368 | self.K = self.N_CLASSES 369 | 370 | elif ('resnet34' in CNN_TO_USE): 371 | self.L = fc_input_features 372 | self.D = 128 373 | self.K = self.N_CLASSES 374 | 375 | elif ('resnet50' in CNN_TO_USE): 376 | self.L = self.E 377 | self.D = 256 378 | self.K = self.N_CLASSES 379 | 380 | if (pool_algorithm=='att'): 381 | 382 | self.attention = torch.nn.Sequential( 383 | torch.nn.Linear(self.L, self.D), 384 | torch.nn.Tanh(), 385 | torch.nn.Linear(self.D, self.K) 386 | ) 387 | 388 | self.tanh = torch.nn.Tanh() 389 | self.relu = torch.nn.ReLU() 390 | 391 | def forward(self, x, conv_layers_out, labels_wsi_np): 392 | """ 393 | In the forward function we accept a Tensor of input data and we must return 394 | a Tensor of output data. We can use Modules defined in the constructor as 395 | well as arbitrary operators on Tensors. 
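        Concretely: x is a batch of patch tensors (or None), conv_layers_out
        holds precomputed patch features used when x is None, and labels_wsi_np
        carries the WSI-level labels stored on the module. Returns
        (output_pool, output_fcn, A, features_to_return): attention-pooled bag
        logits, per-patch softmax probabilities, attention weights, and patch
        embeddings.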
396 | """ 397 | #if used attention pooling 398 | A = None 399 | #m = torch.nn.Softmax(dim=1) 400 | m_binary = torch.nn.Sigmoid() 401 | m_multiclass = torch.nn.Softmax() 402 | dropout = torch.nn.Dropout(p=0.2) 403 | 404 | self.labels = labels_wsi_np 405 | 406 | 407 | if x is not None: 408 | #print(x.shape) 409 | conv_layers_out=self.conv_layers(x) 410 | #print(x.shape) 411 | 412 | conv_layers_out = conv_layers_out.view(-1, self.fc_feat_in) 413 | 414 | #print(conv_layers_out.shape) 415 | 416 | if ('mobilenet' in CNN_TO_USE): 417 | dropout = torch.nn.Dropout(p=0.2) 418 | conv_layers_out = dropout(conv_layers_out) 419 | #print(conv_layers_out.shape) 420 | 421 | if (EMBEDDING_bool==True): 422 | #conv_layers_out = self.tanh(conv_layers_out) 423 | embedding_layer = self.embedding(conv_layers_out) 424 | #embedding_layer = self.tanh(embedding_layer) 425 | embedding_layer = self.relu(embedding_layer) 426 | features_to_return = embedding_layer 427 | 428 | #embedding_layer = self.tanh(embedding_layer) 429 | 430 | embedding_layer = dropout(embedding_layer) 431 | logits = self.embedding_fc(embedding_layer) 432 | 433 | else: 434 | logits = self.fc(conv_layers_out) 435 | features_to_return = conv_layers_out 436 | #print(output_fcn.shape) 437 | 438 | 439 | #print(output_fcn.size()) 440 | 441 | if (EMBEDDING_bool==True): 442 | A = self.attention(features_to_return) 443 | else: 444 | A = self.attention(conv_layers_out) # NxK 445 | 446 | #print(A.size()) 447 | #print(A) 448 | A = F.softmax(A, dim=0) # softmax over N 449 | #print(A.size()) 450 | #print(A) 451 | #A = A.view(-1, A.size()[0]) 452 | #print(A) 453 | 454 | output_pool = (logits * A).sum(dim = 0) #/ (A).sum(dim = 0) 455 | #print(output_pool.size()) 456 | #print(output_pool) 457 | #output_pool = torch.clamp(output_pool, 1e-7, 1 - 1e-7) 458 | 459 | output_fcn = m_multiclass(logits) 460 | return output_pool, output_fcn, A, features_to_return 461 | 462 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 463 | model = MIL_model() 464 | model.to(device) 465 | 466 | from torchvision import transforms 467 | prob = 0.5 468 | pipeline_transform = A.Compose([ 469 | A.VerticalFlip(p=prob), 470 | A.HorizontalFlip(p=prob), 471 | A.RandomRotate90(p=prob), 472 | #A.ElasticTransform(alpha=0.1,p=prob), 473 | #A.HueSaturationValue(hue_shift_limit=(-15,8),sat_shift_limit=(-30,20),val_shift_limit=(-15,15),p=prob), 474 | ]) 475 | 476 | preprocess = transforms.Compose([ 477 | transforms.ToTensor(), 478 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 479 | ]) 480 | 481 | class Dataset_instance(data.Dataset): 482 | 483 | def __init__(self, list_IDs, partition, pipeline_transform): 484 | self.list_IDs = list_IDs 485 | self.set = partition 486 | self.pipeline_transform = pipeline_transform 487 | 488 | def __len__(self): 489 | return len(self.list_IDs) 490 | 491 | def __getitem__(self, index): 492 | # Select sample 493 | ID = self.list_IDs[index][0] 494 | # Load data and get label 495 | img = Image.open(ID) 496 | X = np.asarray(img) 497 | 498 | if (self.set == 'train'): 499 | #data augmentation 500 | X = self.pipeline_transform(image=X)['image'] 501 | #X = pipeline_transform(image=X)['image'] 502 | 503 | #data transformation 504 | input_tensor = preprocess(X).type(torch.FloatTensor) 505 | img.close() 506 | #return input_tensor 507 | return input_tensor 508 | 509 | class Dataset_bag(data.Dataset): 510 | 511 | def __init__(self, list_IDs, labels): 512 | 513 | self.labels = labels 514 | self.list_IDs = list_IDs 515 | 516 | def 
__len__(self):
517 | 
518 |         return len(self.list_IDs)
519 | 
520 |     def __getitem__(self, index):
521 |         # Select sample
522 |         ID = self.list_IDs[index]
523 | 
524 |         # Load data and get label
525 |         instances_filename = generate_list_instances(ID)
526 |         y = self.labels[index]
527 |         if (TASK=='binary' and N_CLASSES==1):
528 |             y = np.asarray(y)
529 |         else:
530 |             y = torch.tensor(y.tolist() , dtype=torch.float32)
531 | 
532 | 
533 |         return instances_filename, y
534 | 
535 | batch_size_bag = 1
536 | 
537 | sampler = Balanced_Multimodal
538 | params_train_bag = {'batch_size': batch_size_bag,
539 |                     'sampler': sampler(train_dataset,alpha=0.25)}
540 |                     #'shuffle': True}
541 | 
542 | params_valid_bag = {'batch_size': batch_size_bag,
543 |                     'shuffle': True}
544 | 
545 | 
546 | num_epochs = EPOCHS
547 | 
548 | if (TASK=='binary' and N_CLASSES==1):
549 |     training_set_bag = Dataset_bag(train_dataset[:,0], train_dataset[:,1])
550 |     training_generator_bag = data.DataLoader(training_set_bag, **params_train_bag)
551 | 
552 |     validation_set_bag = Dataset_bag(valid_dataset[:,0], valid_dataset[:,1])
553 |     validation_generator_bag = data.DataLoader(validation_set_bag, **params_valid_bag)
554 | 
555 | 
556 | else:
557 |     training_set_bag = Dataset_bag(train_dataset[:,0], train_dataset[:,1:])
558 |     training_generator_bag = data.DataLoader(training_set_bag, **params_train_bag)
559 | 
560 |     validation_set_bag = Dataset_bag(valid_dataset[:,0], valid_dataset[:,1:])
561 |     validation_generator_bag = data.DataLoader(validation_set_bag, **params_valid_bag)
562 | 
563 | 
564 | # Find total parameters and trainable parameters
565 | total_params = sum(p.numel() for p in model.parameters())
566 | print(f'{total_params:,} total parameters.')
567 | total_trainable_params = sum(
568 |     p.numel() for p in model.parameters() if p.requires_grad)
569 | print(f'{total_trainable_params:,} training parameters.')
570 | 
571 | criterion_wsi = torch.nn.BCELoss()
572 | 
573 | 
574 | import torch.optim as optim
575 | optimizer_str = 'adam'
576 | #optimizer_str = 'sgd'
577 | 
578 | lr_str = '0.0001'
579 | 
580 | wt_decay_str = '0.0001'
581 | 
582 | lr = float(lr_str)
583 | wt_decay = float(wt_decay_str)
584 | 
585 | if (optimizer_str == 'adam'):
586 |     optimizer = optim.Adam(model.parameters(),lr=lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=wt_decay, amsgrad=True)
587 | elif (optimizer_str == 'sgd'):
588 |     optimizer = optim.SGD(model.parameters(),lr=lr, momentum=0.9, weight_decay=wt_decay, nesterov=True)
589 | 
590 | 
591 | scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
592 | 
593 | def accuracy_micro(y_true, y_pred):
594 | 
595 |     y_true_flatten = y_true.flatten()
596 |     y_pred_flatten = y_pred.flatten()
597 | 
598 |     return metrics.accuracy_score(y_true_flatten, y_pred_flatten)
599 | 
600 | 
601 | def accuracy_macro(y_true, y_pred):
602 | 
603 |     n_classes = len(y_true[0])
604 | 
605 |     acc_tot = 0.0
606 | 
607 |     for i in range(n_classes):
608 | 
609 |         acc = metrics.accuracy_score(y_true[:,i], y_pred[:,i])
610 |         #print(acc)
611 |         acc_tot = acc_tot + acc
612 | 
613 |     acc_tot = acc_tot/n_classes
614 | 
615 |     return acc_tot
616 | 
617 | 
618 | def evaluate_validation_set(model, epoch, generator):
619 |     #accumulator for validation set
620 |     y_pred_val = []
621 |     y_true_val = []
622 | 
623 |     valid_loss = 0.0
624 | 
625 |     mode = 'valid'
626 |     wsi_store_loss = 0.0
627 |     patches_store_loss = 0.0
628 |     i_p = 0
629 |     bool_patches = False
630 | 
631 |     filenames_wsis = []
632 |     pred_cancers = []
633 |     pred_hgd = []
634 |     pred_lgd = []
635 |     pred_hyper = []
636 | 
pred_normal = [] 637 | 638 | model.eval() 639 | 640 | iterations = len(valid_dataset) 641 | 642 | with torch.no_grad(): 643 | j = 0 644 | for inputs_bag,labels in generator: 645 | print('[%d], %d / %d ' % (epoch, j, iterations)) 646 | #inputs: bags, labels: labels of the bags 647 | labels_np = labels.cpu().data.numpy() 648 | len_bag = len(labels_np) 649 | 650 | #list of bags 651 | filename_wsi = os.path.split(inputs_bag[0])[1] 652 | print("inputs_bag " + str(filename_wsi)) 653 | inputs_bag = list(inputs_bag) 654 | 655 | for b in range(len_bag): 656 | labs = [] 657 | labs.append(labels_np[b]) 658 | labs = np.array(labs).flatten() 659 | 660 | labels = torch.tensor(labs).float().to(device) 661 | labels_wsi_np = labels.cpu().data.numpy() 662 | 663 | #read csv with instances 664 | csv_instances = pd.read_csv(inputs_bag[b], sep=',', header=None).values 665 | #number of instances 666 | n_elems = len(csv_instances) 667 | print("num_instances " + str(n_elems)) 668 | #params generator instances 669 | batch_size_instance = BATCH_SIZE 670 | 671 | num_workers = 4 672 | params_instance = {'batch_size': batch_size_instance, 673 | 'shuffle': True, 674 | 'num_workers': num_workers} 675 | 676 | #generator for instances 677 | instances = Dataset_instance(csv_instances,'valid',pipeline_transform) 678 | validation_generator_instance = data.DataLoader(instances, **params_instance) 679 | 680 | features = [] 681 | with torch.no_grad(): 682 | for instances in validation_generator_instance: 683 | instances = instances.to(device) 684 | 685 | # forward + backward + optimize 686 | feats = model.conv_layers(instances) 687 | feats = feats.view(-1, fc_input_features) 688 | feats_np = feats.cpu().data.numpy() 689 | 690 | features = np.append(features,feats_np) 691 | 692 | #del instances 693 | 694 | features_np = np.reshape(features,(n_elems,fc_input_features)) 695 | 696 | del features, feats 697 | 698 | inputs = torch.tensor(features_np, requires_grad=True).float().to(device) 699 | 700 | predictions, probs, attn_layer, embeddings = model(None, inputs, labels_wsi_np) 701 | 702 | loss = criterion_wsi(predictions, labels) 703 | 704 | #loss.backward() 705 | 706 | sigm_prediction = F.sigmoid(predictions) 707 | outputs_wsi_np = sigm_prediction.cpu().data.numpy() 708 | 709 | del probs, attn_layer, embeddings 710 | 711 | 712 | #optimizer.step() 713 | #model.zero_grad() 714 | 715 | wsi_store_loss = wsi_store_loss + ((1 / (j+1)) * (loss.item() - wsi_store_loss)) 716 | 717 | valid_loss = wsi_store_loss 718 | 719 | print('output wsi: '+str(outputs_wsi_np)+', label: '+ str(labels_wsi_np) +', loss_WSI: '+str(wsi_store_loss)) 720 | 721 | print(outputs_wsi_np,labels_np) 722 | 723 | torch.cuda.empty_cache() 724 | output_norm = np.where(outputs_wsi_np > 0.5, 1, 0) 725 | 726 | y_pred_val = np.append(y_pred_val,output_norm) 727 | y_true_val = np.append(y_true_val,labels_np) 728 | 729 | micro_accuracy_valid = accuracy_micro(y_true_val, y_pred_val) 730 | print("micro_accuracy " + str(micro_accuracy_valid)) 731 | 732 | if (N_CLASSES==5): 733 | 734 | filenames_wsis = np.append(filenames_wsis,filename_wsi) 735 | pred_cancers = np.append(pred_cancers,outputs_wsi_np[0]) 736 | pred_hgd = np.append(pred_hgd,outputs_wsi_np[1]) 737 | pred_lgd = np.append(pred_lgd,outputs_wsi_np[2]) 738 | pred_hyper = np.append(pred_hyper,outputs_wsi_np[3]) 739 | pred_normal = np.append(pred_normal,outputs_wsi_np[4]) 740 | 741 | else: 742 | 743 | filenames_wsis = np.append(filenames_wsis,filename_wsi) 744 | pred_cancers = np.append(pred_cancers,outputs_wsi_np) 745 | 746 
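# Aside: a standalone sketch (not part of the original script) of the
# incremental mean used for the loss bookkeeping above. The update
# m <- m + (x - m) / (j + 1) equals the plain average of all losses seen so
# far, without storing them; the loss values below are made up.
#
# losses = [0.9, 0.7, 0.5, 0.3]
# running = 0.0
# for j, x in enumerate(losses):
#     running = running + ((1 / (j + 1)) * (x - running))
# print(running)                    # 0.6
# print(sum(losses) / len(losses))  # 0.6, identical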
bool_patches = False
747 | 
748 |             j = j+1
749 | 
750 |     if (N_CLASSES==5):
751 | 
752 |         #save validation predictions
753 |         filename_validation_predictions = checkpoint_path+'validation_predictions_'+str(epoch)+'.csv'
754 | 
755 |         File = {'filenames':filenames_wsis, 'pred_cancers':pred_cancers, 'pred_hgd':pred_hgd,'pred_lgd':pred_lgd, 'pred_hyper':pred_hyper,'pred_normal':pred_normal}
756 | 
757 |         df = pd.DataFrame(File,columns=['filenames','pred_cancers','pred_hgd','pred_lgd','pred_hyper','pred_normal'])
758 |         np.savetxt(filename_validation_predictions, df.values, fmt='%s',delimiter=',')
759 | 
760 |     else:
761 | 
762 |         #save validation predictions
763 |         filename_validation_predictions = checkpoint_path+'validation_predictions_'+str(epoch)+'.csv'
764 | 
765 |         File = {'filenames':filenames_wsis, 'pred_cancers':pred_cancers}
766 | 
767 |         df = pd.DataFrame(File,columns=['filenames','pred_cancers'])
768 |         np.savetxt(filename_validation_predictions, df.values, fmt='%s',delimiter=',')
769 | 
770 |     return valid_loss, wsi_store_loss, patches_store_loss
771 | 
772 | #number of iterations per epoch
773 | epoch = 0
774 | if (TASK=='binary'):
775 |     iterations = len(train_dataset)
776 | elif (TASK=='multilabel'):
777 |     iterations = len(train_dataset)#+100
778 | 
779 | tot_batches_training = iterations#int(len(train_dataset)/batch_size_bag)
780 | best_loss = 100000.0
781 | 
782 | #number of epochs without improvement
783 | EARLY_STOP_NUM = 12
784 | early_stop_cont = 0
785 | epoch = 0
786 | 
787 | validation_checkpoints = checkpoint_path+'validation_losses/'
788 | create_dir(validation_checkpoints)
789 | 
790 | NUM_WSI_TO_CLUSTER = 1
791 | THRESHOLD = 0.7
792 | #ALPHA = 1
793 | 
794 | def entropy_uncertainty(prob):
795 |     i = np.argmax(prob)
796 |     v = entropy(prob, base=2)
797 |     return v
798 | 
799 | while (epoch<num_epochs and early_stop_cont<EARLY_STOP_NUM):
921 |             output_norm = np.where(outputs_wsi_np > 0.5, 1, 0)
922 |             y_pred = np.append(y_pred,output_norm)
923 |             y_true = np.append(y_true,labels_wsi_np)
924 | 
925 |             if (N_CLASSES==5):
926 | 
927 |                 filenames_wsis = np.append(filenames_wsis,filename_wsi)
928 |                 pred_cancers = np.append(pred_cancers,outputs_wsi_np[0])
929 |                 pred_hgd = np.append(pred_hgd,outputs_wsi_np[1])
930 |                 pred_lgd = np.append(pred_lgd,outputs_wsi_np[2])
931 |                 pred_hyper = np.append(pred_hyper,outputs_wsi_np[3])
932 |                 pred_normal = np.append(pred_normal,outputs_wsi_np[4])
933 | 
934 |             else:
935 | 
936 |                 filenames_wsis = np.append(filenames_wsis,filename_wsi)
937 |                 pred_cancers = np.append(pred_cancers,outputs_wsi_np)
938 | 
939 |             micro_accuracy_train = accuracy_micro(y_true, y_pred)
940 |             print("micro_accuracy " + str(micro_accuracy_train))
941 | 
942 |             #del predictions, labels, inputs
943 |             torch.cuda.empty_cache()
944 | 
945 |             torch.save(model, model_weights_filename_temporary)
946 | 
947 |         bool_patches = False
948 | 
949 |         print()
950 |         #i = i+1
951 |         #scheduler.step()
952 | 
953 |     if (N_CLASSES==5):
954 | 
955 |         #save training predictions
956 |         filename_training_predictions = checkpoint_path+'training_predictions_'+str(epoch)+'.csv'
957 | 
958 |         File = {'filenames':filenames_wsis, 'pred_cancers':pred_cancers, 'pred_hgd':pred_hgd,'pred_lgd':pred_lgd, 'pred_hyper':pred_hyper, 'pred_normal':pred_normal}
959 | 
960 |         df = pd.DataFrame(File,columns=['filenames','pred_cancers','pred_hgd','pred_lgd','pred_hyper','pred_normal'])
961 |         np.savetxt(filename_training_predictions, df.values, fmt='%s',delimiter=',')
962 | 
963 |     else:
964 | 
965 |         filename_training_predictions = checkpoint_path+'training_predictions_'+str(epoch)+'.csv'
966 | 
967 |         File = {'filenames':filenames_wsis, 'pred_cancers':pred_cancers}
968 | 
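# Aside: a standalone sketch (not part of the original script) of what
# entropy_uncertainty() above returns: the base-2 Shannon entropy of a
# probability vector, low for confident predictions and maximal for uniform
# ones. Assumes scipy.stats.entropy, which the function itself calls.
#
# from scipy.stats import entropy
#
# print(entropy([0.95, 0.05], base=2))  # ~0.29 bits, confident prediction
# print(entropy([0.5, 0.5], base=2))    # 1.0 bit, maximally uncertain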
df = pd.DataFrame(File,columns=['filenames','pred_cancers']) 970 | np.savetxt(filename_training_predictions, df.values, fmt='%s',delimiter=',') 971 | 972 | model.eval() 973 | 974 | print("epoch "+str(epoch)+ " train loss: " + str(train_loss) + " train micro accuracy " + str(micro_accuracy_train)) 975 | 976 | print("evaluating validation") 977 | valid_loss, valid_wsi_store_loss, valid_patches_store_loss = evaluate_validation_set(model, epoch, validation_generator_bag) 978 | 979 | #save validation 980 | filename_val = validation_checkpoints+'validation_value_'+str(epoch)+'.csv' 981 | array_val = [valid_loss] 982 | array_val_WSI = [valid_wsi_store_loss] 983 | array_val_patches = [valid_patches_store_loss] 984 | File = {'val':array_val, 'val_WSI': array_val_WSI, 'val_patches': array_val_patches} 985 | df = pd.DataFrame(File,columns=['val', 'val_WSI', 'val_patches']) 986 | np.savetxt(filename_val, df.values, fmt='%s',delimiter=',') 987 | 988 | #save_hyperparameters 989 | filename_hyperparameters = checkpoint_path+'hyperparameters.csv' 990 | array_n_classes = [str(N_CLASSES)] 991 | array_lr = [lr_str] 992 | array_opt = [optimizer_str] 993 | array_wt_decay = [wt_decay_str] 994 | array_embedding = [EMBEDDING_bool] 995 | array_data = [DATA_TO_OPTIMIZE] 996 | array_valid = [VALIDATION_DATA] 997 | array_alpha = [ALPHA] 998 | File = {'n_classes':array_n_classes,'opt':array_opt, 'lr':array_lr,'wt_decay':array_wt_decay,'embedding':array_embedding,'data':array_data,'valid_data':array_valid,'alpha':array_alpha} 999 | 1000 | df = pd.DataFrame(File,columns=['n_classes','opt','lr','wt_decay', 'embedding','data','valid_data','alpha']) 1001 | np.savetxt(filename_hyperparameters, df.values, fmt='%s',delimiter=',') 1002 | 1003 | 1004 | 1005 | if (best_loss>valid_loss): 1006 | early_stop_cont = 0 1007 | print ("=> Saving a new best model") 1008 | print("previous loss : " + str(best_loss) + ", new loss function: " + str(valid_loss)) 1009 | best_loss = valid_loss 1010 | torch.save(model, model_weights_filename) 1011 | else: 1012 | early_stop_cont = early_stop_cont+1 1013 | 1014 | epoch = epoch+1 1015 | if (early_stop_cont == EARLY_STOP_NUM): 1016 | print("EARLY STOPPING") 1017 | 1018 | torch.cuda.empty_cache() -------------------------------------------------------------------------------- /train/train_MoCo_HE_adversarial_loss.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import os 4 | from PIL import Image 5 | import albumentations as A 6 | import torch 7 | from torch.utils import data 8 | import torch.utils.data 9 | import argparse 10 | import warnings 11 | import sys 12 | from torch.utils.data import SubsetRandomSampler, WeightedRandomSampler, Sampler 13 | 14 | from lars import LARS 15 | 16 | warnings.filterwarnings("ignore") 17 | 18 | argv = sys.argv[1:] 19 | 20 | print("CUDA current device " + str(torch.cuda.current_device())) 21 | print("CUDA devices available " + str(torch.cuda.device_count())) 22 | 23 | #parser parameters 24 | parser = argparse.ArgumentParser(description='Configurations to train models.') 25 | parser.add_argument('-n', '--N_EXP', help='number of experiment',type=int, default=0) 26 | parser.add_argument('-c', '--CNN', help='cnn_to_use',type=str, default='resnet34') 27 | parser.add_argument('-b', '--BATCH_SIZE', help='batch_size',type=int, default=256) 28 | parser.add_argument('-e', '--EPOCHS', help='epochs to train',type=int, default=10) 29 | parser.add_argument('-m', '--MAG', help='magnification to 
select',type=str, default='10')
30 | parser.add_argument('-f', '--features', help='features_to_use: embedding (True) or features from CNN (False)',type=str, default='True')
31 | parser.add_argument('-l', '--lr', help='learning rate',type=float, default=1e-4)
32 | parser.add_argument('-i', '--input_folder', help='path of the folder where train.csv and valid.csv are stored',type=str, default='./partition/')
33 | parser.add_argument('-o', '--output_folder', help='path where to store the model weights',type=str, default='./models/')
34 | parser.add_argument('-w', '--wsi_folder', help='path where WSIs are stored',type=str, default='./images/')
35 | parser.add_argument('-t', '--TASK', help='task to solve: binary or multilabel',type=str, default='multilabel')
36 | args = parser.parse_args()
37 | 
38 | N_EXP = args.N_EXP
39 | N_EXP_str = str(N_EXP)
40 | CNN_TO_USE = args.CNN
41 | BATCH_SIZE = args.BATCH_SIZE
42 | BATCH_SIZE_str = str(BATCH_SIZE)
43 | EPOCHS = args.EPOCHS
44 | EPOCHS_str = str(EPOCHS)
45 | MAGNIFICATION = args.MAG
46 | EMBEDDING_bool = args.features
47 | lr = args.lr
48 | INPUT_FOLDER = args.input_folder
49 | OUTPUT_FOLDER = args.output_folder
50 | WSI_FOLDER = args.wsi_folder
51 | TASK = args.TASK
52 | if (EMBEDDING_bool=='True'):
53 |     EMBEDDING_bool = True
54 | else:
55 |     EMBEDDING_bool = False
56 | 
57 | num_keys = 4096
58 | num_keys = 8192
59 | num_keys = 16384
60 | num_keys = 32768
61 | #num_keys = 65536
62 | 
63 | 
64 | #print(EMBEDDING_bool)
65 | 
66 | seed = N_EXP
67 | torch.manual_seed(seed)
68 | if torch.cuda.is_available():
69 |     torch.cuda.manual_seed_all(seed)
70 | np.random.seed(seed)
71 | 
72 | print("PARAMETERS")
73 | print("N_EPOCHS: " + str(EPOCHS_str))
74 | print("CNN used: " + str(CNN_TO_USE))
75 | print("BATCH_SIZE: " + str(BATCH_SIZE_str))
76 | print("MAGNIFICATION: " + str(MAGNIFICATION))
77 | 
78 | #create folder (used for saving weights)
79 | def create_dir(models_path):
80 |     if not os.path.isdir(models_path):
81 |         try:
82 |             os.mkdir(models_path)
83 |         except OSError:
84 |             print ("Creation of the directory %s failed" % models_path)
85 |         else:
86 |             print ("Successfully created the directory %s " % models_path)
87 | #DIRECTORIES CREATION
88 | print("CREATING/CHECKING DIRECTORIES")
89 | 
90 | create_dir(OUTPUT_FOLDER)
91 | 
92 | models_path = OUTPUT_FOLDER
93 | checkpoint_path = models_path+'checkpoints_MIL/'
94 | create_dir(checkpoint_path)
95 | 
96 | #path model file
97 | model_weights_filename = models_path+'MIL_colon_'+TASK+'.pt'
98 | model_weights_filename_temporary = models_path+'MIL_colon_'+TASK+'_temporary.pt'
99 | 
100 | 
101 | #CSV LOADING
102 | print("CSV LOADING ")
103 | 
104 | k = 10
105 | N_CLASSES = 5
106 | csv_folder = INPUT_FOLDER
107 | 
108 | if (TASK=='binary'):
109 | 
110 |     N_CLASSES = 1
111 |     #N_CLASSES = 2
112 | 
113 |     if (N_CLASSES==1):
114 |         csv_filename_training = csv_folder+'train_binary.csv'
115 |         csv_filename_validation = csv_folder+'valid_binary.csv'
116 | 
117 | 
118 | elif (TASK=='multilabel'):
119 | 
120 |     N_CLASSES = 5
121 | 
122 |     csv_filename_training = csv_folder+'train_multilabel.csv'
123 |     csv_filename_validation = csv_folder+'valid_multilabel.csv'
124 | 
125 | #read data
126 | train_dataset = pd.read_csv(csv_filename_training, sep=',', header=None).values#[:10]
127 | valid_dataset = pd.read_csv(csv_filename_validation, sep=',', header=None).values#[:10]
128 | 
129 | print(len(train_dataset))
130 | 
131 | n_centers = 1
132 | 
133 | #reverse autograd
134 | from torch.autograd import Function
135 | class ReverseLayerF(Function):
136 | 
137 |     @staticmethod
138 |     def forward(ctx, x, alpha):
139 |         ctx.alpha = alpha
140 | 
141 |         return x.view_as(x)
142 | 
143 |     @staticmethod
144 |     def 
backward(ctx, grad_output): 145 | output = grad_output.neg() * ctx.alpha 146 | 147 | return output, None 148 | 149 | class domain_predictor(torch.nn.Module): 150 | def __init__(self, n_centers): 151 | super(domain_predictor, self).__init__() 152 | # domain predictor 153 | self.fc_feat_in = fc_input_features 154 | self.n_centers = n_centers 155 | 156 | if ('resnet18' in CNN_TO_USE): 157 | self.E = 128 158 | 159 | elif ('resnet34' in CNN_TO_USE): 160 | self.E = 128 161 | 162 | elif ('resnet50' in CNN_TO_USE): 163 | self.E = 256 164 | 165 | elif ('densenet121' in CNN_TO_USE): 166 | self.E = 128 167 | 168 | 169 | self.domain_embedding = torch.nn.Linear(in_features=self.fc_feat_in, out_features=self.E) 170 | self.domain_classifier = torch.nn.Linear(in_features=self.E, out_features=self.n_centers) 171 | 172 | #self.domain_predictor = domain_predictor(6) 173 | self.prelu = torch.nn.PReLU(num_parameters=1, init=0.25) 174 | 175 | def forward(self, x): 176 | 177 | dropout = torch.nn.Dropout(p=0.1) 178 | m_binary = torch.nn.Sigmoid() 179 | relu = torch.nn.ReLU() 180 | 181 | domain_emb = self.domain_embedding(x) 182 | 183 | domain_emb = self.prelu(domain_emb) 184 | domain_emb = dropout(domain_emb) 185 | 186 | domain_prob = self.domain_classifier(domain_emb) 187 | 188 | #domain_prob = m_binary(domain_prob) 189 | 190 | return domain_prob 191 | 192 | pre_trained_network = torch.hub.load('pytorch/vision:v0.10.0', CNN_TO_USE, pretrained=True) 193 | 194 | if (('resnet' in CNN_TO_USE) or ('resnext' in CNN_TO_USE)): 195 | fc_input_features = pre_trained_network.fc.in_features 196 | elif (('densenet' in CNN_TO_USE)): 197 | fc_input_features = pre_trained_network.classifier.in_features 198 | elif ('mobilenet' in CNN_TO_USE): 199 | fc_input_features = pre_trained_network.classifier[1].in_features 200 | 201 | 202 | class Encoder(torch.nn.Module): 203 | def __init__(self, dim): 204 | """ 205 | In the constructor we instantiate two nn.Linear modules and assign them as 206 | member variables. 
207 | """ 208 | super(Encoder, self).__init__() 209 | 210 | pre_trained_network = torch.hub.load('pytorch/vision:v0.10.0', CNN_TO_USE, pretrained=True) 211 | 212 | if (('resnet' in CNN_TO_USE) or ('resnext' in CNN_TO_USE)): 213 | fc_input_features = pre_trained_network.fc.in_features 214 | elif (('densenet' in CNN_TO_USE)): 215 | fc_input_features = pre_trained_network.classifier.in_features 216 | elif ('mobilenet' in CNN_TO_USE): 217 | fc_input_features = pre_trained_network.classifier[1].in_features 218 | 219 | self.conv_layers = torch.nn.Sequential(*list(pre_trained_network.children())[:-1]) 220 | 221 | if (torch.cuda.device_count()>1): 222 | self.conv_layers = torch.nn.DataParallel(self.conv_layers) 223 | 224 | self.fc_feat_in = fc_input_features 225 | self.N_CLASSES = N_CLASSES 226 | 227 | self.dim = dim 228 | 229 | if (EMBEDDING_bool==True): 230 | 231 | if ('resnet34' in CNN_TO_USE): 232 | self.E = self.dim 233 | self.L = self.E 234 | self.D = 64 235 | self.K = self.N_CLASSES 236 | 237 | elif ('resnet50' in CNN_TO_USE): 238 | self.E = self.dim 239 | self.L = self.E 240 | self.D = 128 241 | self.K = self.N_CLASSES 242 | 243 | elif ('resnet152' in CNN_TO_USE): 244 | self.E = self.dim 245 | self.L = self.E 246 | self.D = 128 247 | self.K = self.N_CLASSES 248 | 249 | 250 | self.embedding = torch.nn.Linear(in_features=self.fc_feat_in, out_features=self.E) 251 | 252 | 253 | self.domain_predictor = domain_predictor(6) 254 | self.prelu = torch.nn.PReLU(num_parameters=1, init=0.25) 255 | 256 | def forward(self, x, mode, alpha): 257 | """ 258 | In the forward function we accept a Tensor of input data and we must return 259 | a Tensor of output data. We can use Modules defined in the constructor as 260 | well as arbitrary operators on Tensors. 261 | """ 262 | #if used attention pooling 263 | A = None 264 | #m = torch.nn.Softmax(dim=1) 265 | dropout = torch.nn.Dropout(p=0.2) 266 | relu = torch.nn.ReLU() 267 | tanh = torch.nn.Tanh() 268 | 269 | 270 | if x is not None: 271 | #print(x.shape) 272 | conv_layers_out=self.conv_layers(x) 273 | #print(x.shape) 274 | 275 | conv_layers_out = conv_layers_out.view(-1, self.fc_feat_in) 276 | 277 | #print(conv_layers_out.shape) 278 | 279 | if ('mobilenet' in CNN_TO_USE): 280 | #dropout = torch.nn.Dropout(p=0.2) 281 | conv_layers_out = dropout(conv_layers_out) 282 | #print(conv_layers_out.shape) 283 | 284 | if (EMBEDDING_bool==True): 285 | #conv_layers_out = relu(conv_layers_out) 286 | #conv_layers_out = dropout(conv_layers_out) 287 | embedding_layer = self.embedding(conv_layers_out) 288 | embedding_layer = self.prelu(embedding_layer) 289 | 290 | features_to_return = embedding_layer 291 | 292 | else: 293 | features_to_return = conv_layers_out 294 | 295 | norm = torch.norm(features_to_return, p='fro', dim=1, keepdim=True) 296 | 297 | #normalized_array = features_to_return #/ norm 298 | #normalized_array = features_to_return 299 | normalized_array = torch.nn.functional.normalize(features_to_return, dim=1) 300 | 301 | if (mode=='train'): 302 | reverse_feature = ReverseLayerF.apply(conv_layers_out, alpha) 303 | 304 | output_domain = self.domain_predictor(reverse_feature) 305 | 306 | return normalized_array, output_domain 307 | 308 | return normalized_array 309 | 310 | backbone = 'resnet34' 311 | #moco_dim = 768 312 | moco_dim = 128 313 | moco_m = 0.999 314 | temperature = 0.07 315 | 316 | batch_size = BATCH_SIZE 317 | 318 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 319 | 320 | encoder = Encoder(dim=moco_dim).to(device) 321 | 
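# Aside: ReverseLayerF (defined earlier in this script) is a gradient
# reversal layer: identity in the forward pass, gradients multiplied by
# -alpha in the backward pass, which pushes the shared features to confuse
# the domain classifier. A quick standalone check of the sign flip
# (x_demo / y_demo are illustrative names, not from the original script):
x_demo = torch.ones(3, requires_grad=True)
y_demo = ReverseLayerF.apply(x_demo, 0.5)  # forward pass: identity
y_demo.sum().backward()
print(x_demo.grad)  # tensor([-0.5000, -0.5000, -0.5000])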
momentum_encoder = Encoder(dim=moco_dim).to(device) 322 | 323 | encoder.embedding.weight.data.normal_(mean=0.0, std=0.01) 324 | encoder.embedding.bias.data.zero_() 325 | 326 | #momentum_encoder.embedding.weight.data.normal_(mean=0.0, std=0.01) 327 | #momentum_encoder.embedding.bias.data.zero_() 328 | 329 | import torchvision 330 | 331 | momentum_encoder.load_state_dict(encoder.state_dict(), strict=False) 332 | 333 | for param in momentum_encoder.parameters(): 334 | param.requires_grad = False 335 | 336 | #del pre_trained_network 337 | 338 | #DATA AUGMENTATION 339 | from torchvision import transforms 340 | prob = 0.75 341 | pipeline_transform_paper = A.Compose([ 342 | #A.RandomScale(scale_limit=(-0.005,0.005), interpolation=2, p=prob), 343 | #A.RandomCrop(height=220, width=220, p=prob), 344 | #A.Resize(224,224,always_apply=True), 345 | #A.MotionBlur(blur_limit=3, p=prob), 346 | #A.MedianBlur(blur_limit=3, p=prob), 347 | #A.CropAndPad(percent=(-0.01, -0.05),pad_mode=1,always_apply=True), 348 | A.RandomResizedCrop(height=224, width=224, scale=(0.8, 1), always_apply=True), 349 | A.VerticalFlip(p=prob), 350 | A.HorizontalFlip(p=prob), 351 | A.RandomRotate90(p=prob), 352 | #A.HueSaturationValue(hue_shift_limit=(-15,8),sat_shift_limit=(-20,10),val_shift_limit=(-8,8),always_apply=True), 353 | A.ColorJitter (brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1, always_apply=True), 354 | A.GaussianBlur (blur_limit=(1, 3), sigma_limit=0, always_apply=True), 355 | #A.HueSaturationValue(hue_shift_limit=(-25,10),sat_shift_limit=(-25,15),val_shift_limit=(-15,15),always_apply=True), 356 | #A.RGBShift (r_shift_limit=10, g_shift_limit=10, b_shift_limit=10, always_apply=True, p=prob), 357 | #A.CLAHE(clip_limit=2.0, tile_grid_size=(4, 4), p=prob), 358 | #A.RandomBrightness(limit=0.2, p=prob), 359 | #A.RandomContrast(limit=0.2, p=prob), 360 | #A.GaussNoise(p=prob), 361 | #A.ElasticTransform(alpha=2,border_mode=4, sigma=20, alpha_affine=20, p=prob, always_apply=True), 362 | #A.GridDistortion(num_steps=2, distort_limit=0.2, interpolation=1, border_mode=4, p=prob), 363 | #A.GlassBlur(sigma=0.3, max_delta=2, iterations=1, p=prob), 364 | #A.OpticalDistortion (distort_limit=0.2, shift_limit=0.2, interpolation=1, border_mode=4, value=None, p=prob), 365 | #A.GridDropout (ratio=0.3, unit_size_min=3, unit_size_max=40, holes_number_x=3, holes_number_y=3, shift_x=1, shift_y=10, random_offset=True, fill_value=0, p=prob), 366 | #A.Equalize(p=prob), 367 | #A.Posterize(p=prob, always_apply=True), 368 | #A.RandomGamma(p=prob, always_apply=True), 369 | #A.Superpixels(p_replace=0.05, n_segments=100, max_size=128, interpolation=1, p=prob), 370 | #A.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.3, p=prob), 371 | A.ToGray(p=0.2), 372 | #A.CoarseDropout (max_holes=20, max_height=10, max_width=10, min_holes=None, min_height=1, min_width=1, fill_value=0, p=prob), 373 | #A.CoarseDropout (max_holes=20, max_height=10, max_width=10, min_holes=None, min_height=1, min_width=1, fill_value=255, p=prob), 374 | ]) 375 | 376 | pipeline_transform = A.Compose([ 377 | #A.RandomScale(scale_limit=(-0.005,0.005), interpolation=2, p=prob), 378 | #A.RandomCrop(height=220, width=220, p=prob), 379 | #A.Resize(224,224,always_apply=True), 380 | #A.MotionBlur(blur_limit=3, p=prob), 381 | #A.MedianBlur(blur_limit=3, p=prob), 382 | #A.CropAndPad(percent=(-0.01, -0.05),pad_mode=1,always_apply=True), 383 | A.RandomResizedCrop(height=224, width=224, scale=(0.8, 1), p = prob), 384 | A.VerticalFlip(p=prob), 385 | A.HorizontalFlip(p=prob), 386 | 
A.RandomRotate90(p=prob), 387 | #A.HueSaturationValue(hue_shift_limit=(-15,8),sat_shift_limit=(-20,10),val_shift_limit=(-8,8),always_apply=True), 388 | A.ColorJitter (brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1, always_apply=True), 389 | A.GaussianBlur (blur_limit=(1, 3), sigma_limit=0, always_apply=True), 390 | #A.HueSaturationValue(hue_shift_limit=(-25,10),sat_shift_limit=(-25,15),val_shift_limit=(-15,15),always_apply=True), 391 | #A.RGBShift (r_shift_limit=10, g_shift_limit=10, b_shift_limit=10, always_apply=True, p=prob), 392 | #A.CLAHE(clip_limit=2.0, tile_grid_size=(4, 4), p=prob), 393 | #A.RandomBrightness(limit=0.2, p=prob), 394 | #A.RandomContrast(limit=0.2, p=prob), 395 | #A.GaussNoise(p=prob), 396 | A.ElasticTransform(alpha=2,border_mode=4, sigma=10, alpha_affine=10, p=prob), 397 | A.GridDistortion(num_steps=1, distort_limit=0.1, interpolation=1, border_mode=4, p=prob), 398 | #A.GlassBlur(sigma=0.3, max_delta=2, iterations=1, p=prob), 399 | A.OpticalDistortion (distort_limit=0.2, shift_limit=0.2, interpolation=1, border_mode=4, value=None, p=prob), 400 | #A.GridDropout (ratio=0.3, unit_size_min=3, unit_size_max=40, holes_number_x=3, holes_number_y=3, shift_x=1, shift_y=10, random_offset=True, fill_value=0, p=prob), 401 | #A.Equalize(p=prob), 402 | #A.Posterize(p=prob, always_apply=True), 403 | #A.RandomGamma(p=prob, always_apply=True), 404 | #A.Superpixels(p_replace=0.05, n_segments=100, max_size=128, interpolation=1, p=prob), 405 | #A.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.3, p=prob), 406 | A.ToGray(p=0.2), 407 | #A.CoarseDropout (max_holes=20, max_height=10, max_width=10, min_holes=None, min_height=1, min_width=1, fill_value=0, p=prob), 408 | #A.CoarseDropout (max_holes=20, max_height=10, max_width=10, min_holes=None, min_height=1, min_width=1, fill_value=255, p=prob), 409 | ]) 410 | 411 | p_soft = 0.5 412 | pipeline_transform_soft = A.Compose([ 413 | #A.ElasticTransform(alpha=0.01,p=p_soft), 414 | #A.RGBShift (r_shift_limit=10, g_shift_limit=10, b_shift_limit=10, always_apply=True, p=p_soft), 415 | A.HueSaturationValue(hue_shift_limit=(-15,8),sat_shift_limit=(-20,10),val_shift_limit=(-8,8),p=p_soft), 416 | A.VerticalFlip(p=p_soft), 417 | A.HorizontalFlip(p=p_soft), 418 | A.RandomRotate90(p=p_soft), 419 | #A.HueSaturationValue(hue_shift_limit=(-25,10),sat_shift_limit=(-25,15),val_shift_limit=(-15,15),p=p_soft), 420 | #A.CLAHE(clip_limit=1.0, tile_grid_size=(8, 8), p=p_soft), 421 | #A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2, p=p_soft), 422 | #A.RandomBrightness(limit=0.1, p=p_soft), 423 | #A.RandomContrast(limit=0.1, p=p_soft), 424 | ]) 425 | 426 | #DATA NORMALIZATION 427 | preprocess = transforms.Compose([ 428 | transforms.ToTensor(), 429 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 430 | ]) 431 | 432 | 433 | 434 | def generate_list_instances(filename): 435 | 436 | instance_dir = WSI_FOLDER 437 | fname = os.path.split(filename)[-1] 438 | 439 | instance_csv = instance_dir+fname+'/'+fname+'_paths_densely.csv' 440 | 441 | return instance_csv 442 | 443 | 444 | class ImbalancedDatasetSampler_multilabel(torch.utils.data.sampler.Sampler): 445 | 446 | def __init__(self, dataset, indices=None, num_samples=None): 447 | 448 | self.indices = list(range(len(dataset))) if indices is None else indices 449 | 450 | self.num_samples = len(self.indices) if num_samples is None else num_samples 451 | 452 | 453 | label_to_count = {} 454 | for idx in self.indices: 455 | label = self._get_label(dataset, 
idx) 456 | for l in label: 457 | if l in label_to_count: 458 | label_to_count[l] += 1 459 | else: 460 | label_to_count[l] = 1 461 | 462 | weights = [] 463 | 464 | for idx in self.indices: 465 | c = 0 466 | for j, l in enumerate(self._get_label(dataset, idx)): 467 | c = c+(1/label_to_count[l]) 468 | 469 | weights.append(c/(j+1)) 470 | self.weights = torch.DoubleTensor(weights) 471 | 472 | def _get_label(self, dataset, idx): 473 | labels = np.where(dataset[idx,1:]==1)[0] 474 | #print(labels) 475 | #labels = dataset[idx,2] 476 | return labels 477 | 478 | def __iter__(self): 479 | return (self.indices[i] for i in torch.multinomial( 480 | self.weights, self.num_samples, replacement=True)) 481 | 482 | def __len__(self): 483 | return self.num_samples 484 | 485 | class Balanced_Multimodal(torch.utils.data.sampler.Sampler): 486 | 487 | def __init__(self, dataset, indices=None, num_samples=None, alpha = 0.5): 488 | 489 | self.indices = list(range(len(dataset))) if indices is None else indices 490 | 491 | self.num_samples = len(self.indices) if num_samples is None else num_samples 492 | 493 | class_sample_count = [0,0,0,0,0] 494 | 495 | 496 | class_sample_count = np.sum(train_dataset[:,1:],axis=0) 497 | 498 | min_class = np.argmin(class_sample_count) 499 | class_sample_count = np.array(class_sample_count) 500 | weights = [] 501 | for c in class_sample_count: 502 | weights.append((c/class_sample_count[min_class])) 503 | 504 | ratio = np.array(weights).astype(np.float) 505 | 506 | label_to_count = {} 507 | for idx in self.indices: 508 | label = self._get_label(dataset, idx) 509 | for l in label: 510 | if l in label_to_count: 511 | label_to_count[l] += 1 512 | else: 513 | label_to_count[l] = 1 514 | 515 | weights = [] 516 | 517 | for idx in self.indices: 518 | c = 0 519 | for j, l in enumerate(self._get_label(dataset, idx)): 520 | c = c+(1/label_to_count[l])#*ratio[l] 521 | 522 | weights.append(c/(j+1)) 523 | #weights.append(c) 524 | 525 | self.weights_original = torch.DoubleTensor(weights) 526 | 527 | self.weights_uniform = np.repeat(1/self.num_samples, self.num_samples) 528 | 529 | #print(self.weights_a, self.weights_b) 530 | 531 | beta = 1 - alpha 532 | self.weights = (alpha * self.weights_original) + (beta * self.weights_uniform) 533 | 534 | 535 | def _get_label(self, dataset, idx): 536 | labels = np.where(dataset[idx,1:]==1)[0] 537 | #print(labels) 538 | #labels = dataset[idx,2] 539 | return labels 540 | 541 | def __iter__(self): 542 | return (self.indices[i] for i in torch.multinomial( 543 | self.weights, self.num_samples, replacement=True)) 544 | 545 | def __len__(self): 546 | return self.num_samples 547 | 548 | def H_E_Staining(img, Io=240, alpha=1, beta=0.15): 549 | ''' 550 | Normalize staining appearence of H&E stained images 551 | 552 | Example use: 553 | see test.py 554 | 555 | Input: 556 | I: RGB input image 557 | Io: (optional) transmitted light intensity 558 | 559 | Output: 560 | Inorm: normalized image 561 | H: hematoxylin image 562 | E: eosin image 563 | 564 | Reference: 565 | A method for normalizing histology slides for quantitative analysis. M. 
566 |     Macenko et al., ISBI 2009
567 |     '''
568 | 
569 |     # define height and width of image
570 |     h, w, c = img.shape
571 | 
572 |     # reshape image
573 |     img = img.reshape((-1,3))
574 | 
575 |     # calculate optical density
576 |     OD = -np.log((img.astype(np.float)+1)/Io)
577 | 
578 |     # remove transparent pixels
579 |     ODhat = OD[~np.any(OD<beta, axis=1)]
580 | 
581 |     # compute eigenvectors of the optical-density covariance matrix
582 |     eigvals, eigvecs = np.linalg.eigh(np.cov(ODhat.T))
583 | 
584 |     # project on the plane spanned by the eigenvectors
585 |     # corresponding to the two largest eigenvalues
586 |     That = ODhat.dot(eigvecs[:,1:3])
587 | 
588 |     # angular coordinates of the projected pixels
589 |     phi = np.arctan2(That[:,1],That[:,0])
590 | 
591 |     minPhi = np.percentile(phi, alpha)
592 |     maxPhi = np.percentile(phi, 100-alpha)
593 | 
594 |     # extreme stain vectors
595 |     vMin = eigvecs[:,1:3].dot(np.array([(np.cos(minPhi), np.sin(minPhi))]).T)
596 |     vMax = eigvecs[:,1:3].dot(np.array([(np.cos(maxPhi), np.sin(maxPhi))]).T)
597 | 
598 |     # a heuristic to make the vector corresponding to hematoxylin first
599 |     # and the one corresponding to eosin second
600 |     if vMin[0] > vMax[0]:
601 |         HE = np.array((vMin[:,0], vMax[:,0])).T
602 |     else:
603 |         HE = np.array((vMax[:,0], vMin[:,0])).T
604 | 
605 |     return HE
606 | 
607 | class Dataset_instance(data.Dataset):
608 | 
609 |     def __init__(self, list_IDs, mode):
610 |         self.list_IDs = list_IDs
611 |         #self.list_IDs = list_IDs[:,0]
612 |         #self.list_hes = list_IDs[:,1:]
613 |         self.mode = mode
614 | 
615 |     def __len__(self):
616 |         return len(self.list_IDs)
617 | 
618 |     def __getitem__(self, index):
619 |         # Select sample
620 |         ID = self.list_IDs[index]
621 |         # Load data and get label
622 |         img = Image.open(ID)
623 |         X = np.asarray(img)
624 |         img.close()
625 | 
626 |         #k = pipeline_transform_soft(image=k)['image']
627 |         #k = pipeline_transform(image=q)['image']
628 | 
629 |         h_e_matrix = [0,0,0,0,0,0]
630 |         h_e_matrix = np.array(h_e_matrix)
631 | 
632 |         if (self.mode == 'train'):
633 | 
634 |             #h_e_matrix = self.list_hes[index].tolist()
635 | 
636 |             b = False
637 |             k = X
638 | 
639 |             while(b==False):
640 | 
641 |                 k = pipeline_transform_soft(image=k)['image']
642 | 
643 |                 try:
644 |                     h_e_matrix = H_E_Staining(k)
645 |                     b = True
646 |                 except:
647 |                     k = pipeline_transform_soft(image=k)['image']
648 |                     pass
649 | 
650 |                 #idx_n = np.random.randint(0,self.__len__())
651 |                 #self.__getitem__(idx_n)
652 |                 #pass
653 | 
654 |             q = pipeline_transform_paper(image=k)['image']
655 |             h_e_matrix = np.reshape(h_e_matrix, 6)
656 |             h_e_matrix = np.asarray(h_e_matrix)
657 | 
658 | 
659 |             #h_e_matrix = np.reshape(h_e_matrix, 6)
660 |             #h_e_matrix = np.array(h_e_matrix)
661 | 
662 |             #print(h_e_matrix)
663 |             #print(q.shape)
664 |         else:
665 |             k = X
666 |             q = pipeline_transform(image=k)['image']
667 | 
668 |         del X
669 |         #data transformation
670 |         q = preprocess(q).type(torch.FloatTensor)
671 |         k = preprocess(k).type(torch.FloatTensor)
672 |         h_e_matrix = torch.FloatTensor(h_e_matrix)
673 |         #return input_tensor
674 |         return k, q, h_e_matrix
675 | 
676 | class Dataset_bag(data.Dataset):
677 | 
678 |     def __init__(self, list_IDs, labels):
679 | 
680 |         self.labels = labels
681 |         self.list_IDs = list_IDs
682 | 
683 |     def __len__(self):
684 | 
685 |         return len(self.list_IDs)
686 | 
687 |     def __getitem__(self, index):
688 |         # Select sample
689 |         WSI = self.list_IDs[index]
690 | 
691 |         return WSI
692 | 
693 | #parameters bag
694 | batch_size_bag = 16
695 | 
696 | """
697 | sampler = ImbalancedDatasetSampler_multilabel
698 | params_train_bag = {'batch_size': batch_size_bag,
699 |                     'sampler': sampler(train_dataset)}
700 |                     #'shuffle': True}
701 | """
702 | #"""
703 | sampler = Balanced_Multimodal
704 | params_train_bag = {'batch_size': batch_size_bag,
705 |                     #'sampler': sampler(train_dataset,alpha=0.25)}
706 |                     'shuffle': True}
707 | #"""
708 | """
709 | sampler = Balanced_Multimodal
710 | params_bag_train = {'batch_size': batch_size_bag,
711 |                     'sampler': sampler(train_dataset,alpha=0.5)}
712 |                     #'shuffle': True}
713 | """
714 | 
715 | 
716 | params_bag_test = {'batch_size': batch_size_bag,
717 |                    #'sampler': sampler(train_dataset)
718 |                    'shuffle': True}
719 | 
720 | params_bag_train_queue = {'batch_size': int(batch_size_bag*2),
721 |                           'sampler': sampler(train_dataset,alpha=0.25)}
722 |                           #'shuffle': True}
723 | 
724 | params_bag_test_queue = {'batch_size': int(batch_size_bag*2),
725 
| #'sampler': sampler(train_dataset) 726 | 'shuffle': True} 727 | 728 | training_set_bag = Dataset_bag(train_dataset[:,0], train_dataset[:,1:]) 729 | training_generator_bag = data.DataLoader(training_set_bag, **params_train_bag) 730 | 731 | #validation_set_bag = Dataset_bag(valid_dataset[:,0], valid_dataset[:,1:]) 732 | #validation_generator_bag = data.DataLoader(validation_set_bag, **params_bag_test) 733 | 734 | training_set_bag = Dataset_bag(train_dataset[:,0], train_dataset[:,1:]) 735 | training_generator_bag_queue = data.DataLoader(training_set_bag, **params_bag_train_queue) 736 | 737 | #validation_set_bag = Dataset_bag(valid_dataset[:,0], valid_dataset[:,1:]) 738 | #validation_generator_bag_queue = data.DataLoader(validation_set_bag, **params_bag_test_queue) 739 | 740 | #params patches generated 741 | 742 | # Find total parameters and trainable parameters 743 | total_params = sum(p.numel() for p in encoder.parameters()) 744 | print(f'{total_params:,} total parameters.') 745 | 746 | total_trainable_params = sum( 747 | p.numel() for p in encoder.parameters() if p.requires_grad) 748 | print(f'{total_trainable_params:,} training parameters.') 749 | 750 | torch.backends.cudnn.benchmark=True 751 | 752 | class RMSELoss(torch.nn.Module): 753 | def __init__(self, eps=1e-6): 754 | super().__init__() 755 | self.mse = torch.nn.MSELoss() 756 | self.eps = eps 757 | 758 | def forward(self,yhat,y): 759 | loss = torch.sqrt(self.mse(yhat,y) + self.eps) 760 | return loss 761 | 762 | def loss_function(q, k, queue): 763 | 764 | #N is the batch size 765 | N = q.shape[0] 766 | 767 | #C is the dimension of the representation 768 | C = q.shape[1] 769 | 770 | #BMM stands for batch matrix multiplication 771 | #If mat1 is B × n × M tensor, then mat2 is B × m × P tensor, 772 | #Then output a B × n × P tensor. 773 | pos = torch.exp(torch.div(torch.bmm(q.view(N,1,C), k.view(N,C,1)).view(N, 1),temperature)) 774 | 775 | #Matrix multiplication is performed between the query and the queue tensor 776 | neg = torch.sum(torch.exp(torch.div(torch.mm(q.view(N,C), torch.t(queue)),temperature)), dim=1) 777 | 778 | #Sum up 779 | denominator = neg + pos 780 | 781 | return torch.mean(-torch.log(torch.div(pos,denominator))) 782 | 783 | criterion = torch.nn.CrossEntropyLoss().to(device) 784 | #criterion_domain = RMSELoss().to(device) 785 | criterion_domain = torch.nn.L1Loss() 786 | 787 | lambda_val = 0.5 788 | 789 | import torch.optim as optim 790 | optimizer_str = 'adam' 791 | #optimizer_str = 'sgd' 792 | #optimizer_str = 'lars' 793 | 794 | #print(model.conv_layers.parameters()) 795 | 796 | # Optimizer 797 | SGD_momentum = 0.9 798 | weight_decay = 1e-4 799 | shuffle_bn = True 800 | 801 | if (optimizer_str=='sgd'): 802 | optimizer = optim.SGD(encoder.parameters(), 803 | lr=lr, 804 | momentum=SGD_momentum, 805 | weight_decay=weight_decay) 806 | 807 | elif(optimizer_str=='lars'): 808 | optimizer = LARS(params=encoder.parameters(), 809 | lr=lr, 810 | momentum=SGD_momentum, 811 | weight_decay=weight_decay, 812 | eta=0.001, 813 | max_epoch=EPOCHS) 814 | 815 | elif (optimizer_str=='adam'): 816 | optimizer = optim.Adam(encoder.parameters(),lr=lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=weight_decay, amsgrad=True) 817 | #scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1) 818 | 819 | def momentum_step(m=1): 820 | ''' 821 | Momentum step (Eq (2)). 822 | Args: 823 | - m (float): momentum value. 
1) m = 0 -> copy parameter of encoder to key encoder 824 | 2) m = 0.999 -> momentum update of key encoder 825 | ''' 826 | params_q = encoder.state_dict() 827 | params_k = momentum_encoder.state_dict() 828 | 829 | dict_params_k = dict(params_k) 830 | 831 | for name in params_q: 832 | theta_k = dict_params_k[name] 833 | theta_q = params_q[name].data 834 | dict_params_k[name].data.copy_(m * theta_k + (1-m) * theta_q) 835 | 836 | momentum_encoder.load_state_dict(dict_params_k) 837 | 838 | def update_lr(epoch): 839 | ''' 840 | Learning rate scheduling. 841 | Args: 842 | - epoch (float): Set new learning rate by a given epoch. 843 | ''' 844 | 845 | if epoch < 10: 846 | lr = args.lr 847 | elif epoch >= 10 and epoch < 20: 848 | lr = args.lr * 0.1 849 | elif epoch >= 20 : 850 | lr = args.lr * 0.01 851 | 852 | for param_group in optimizer.param_groups: 853 | param_group['lr'] = lr 854 | 855 | def update_queue(queue, k): 856 | 857 | len_k = k.shape[0] 858 | len_queue = queue.shape[0] 859 | 860 | new_queue = torch.cat([k, queue], dim=0) 861 | 862 | new_queue = new_queue[:num_keys] 863 | 864 | return new_queue 865 | 866 | ''' ######################## < Step 4 > Start training ######################## ''' 867 | 868 | # Initialize momentum_encoder with parameters of encoder. 869 | momentum_step(m=0) 870 | 871 | def mapping_patches(patches, THRESHOLD): 872 | 873 | n_patches = len(patches) 874 | ratio = int(n_patches/THRESHOLD) 875 | #print(ratio) 876 | 877 | if (ratio>0): 878 | 879 | #idx_list = np.random.randint(0, n_patches, THRESHOLD) 880 | #idx = list(set(idx_list)) 881 | #new_patches = patches[idx] 882 | 883 | idx_list = np.random.choice(n_patches, THRESHOLD, replace = False) 884 | idx = list(set(idx_list)) 885 | idx.sort() 886 | new_patches = patches[idx] 887 | 888 | else: 889 | 890 | new_patches = patches 891 | 892 | return new_patches 893 | 894 | def validate(epoch, generator): 895 | #accumulator for validation set 896 | 897 | encoder.eval() 898 | momentum_encoder.eval() 899 | 900 | queue = [] 901 | dataloader_iterator = iter(validation_generator_bag_queue) 902 | 903 | wsis = next(dataloader_iterator) 904 | 905 | fnames_patches = [] 906 | 907 | new_patches = 0 908 | 909 | for wsi in wsis: 910 | 911 | fname = wsi 912 | 913 | print(fname) 914 | 915 | csv_fname = generate_list_instances(fname) 916 | csv_instances = pd.read_csv(csv_fname, sep=',', header=None).values 917 | 918 | l_csv = len(csv_instances) 919 | #csv_instances = mapping_patches(csv_instances,img_len) 920 | p_to_select = 512 921 | csv_instances = mapping_patches(csv_instances,p_to_select) 922 | 923 | new_patches = new_patches + len(csv_instances) 924 | 925 | fnames_patches = np.append(fnames_patches, csv_instances) 926 | 927 | #fnames_patches = np.reshape(fnames_patches, (new_patches, 7)) 928 | #print(fnames_patches.shape) 929 | 930 | num_workers = 2 931 | params_instance = {'batch_size': batch_size, 932 | 'shuffle': True, 933 | 'num_workers': num_workers} 934 | 935 | instances = Dataset_instance(fnames_patches, 'valid') 936 | generator_inst = data.DataLoader(instances, **params_instance) 937 | 938 | with torch.no_grad(): 939 | for i, (_, img, _) in enumerate(generator_inst): 940 | key_feature = momentum_encoder(img.to(device), 'valid', _) 941 | queue.append(key_feature) 942 | 943 | if i == (num_keys / batch_size) - 1: 944 | break 945 | queue = torch.cat(queue, dim=0) 946 | 947 | valid_loss = 0.0 948 | total_iters = 0 949 | dataloader_iterator = iter(generator) 950 | 951 | j = 0 952 | 953 | iterations = int(len(valid_dataset) / 
batch_size_bag) 954 | 955 | for i in range(iterations): 956 | print('[%d], %d / %d ' % (epoch, i, iterations)) 957 | 958 | try: 959 | wsis = next(dataloader_iterator) 960 | except StopIteration: 961 | dataloader_iterator = iter(training_generator_bag) 962 | wsis = next(dataloader_iterator) 963 | #inputs: bags, labels: labels of the bags 964 | 965 | fnames_patches = [] 966 | 967 | new_patches = 0 968 | 969 | for wsi in wsis: 970 | 971 | fname = wsi 972 | 973 | print(fname) 974 | 975 | csv_fname = generate_list_instances(fname) 976 | csv_instances = pd.read_csv(csv_fname, sep=',', header=None).values 977 | 978 | l_csv = len(csv_instances) 979 | #csv_instances = mapping_patches(csv_instances,img_len) 980 | p_to_select = max(int(l_csv/3), img_len) 981 | csv_instances = mapping_patches(csv_instances,p_to_select) 982 | 983 | new_patches = new_patches + len(csv_instances) 984 | 985 | fnames_patches = np.append(fnames_patches, csv_instances) 986 | 987 | #fnames_patches = np.reshape(fnames_patches, (new_patches, 7)) 988 | 989 | num_workers = 2 990 | params_instance = {'batch_size': batch_size, 991 | 'shuffle': True, 992 | 'num_workers': num_workers} 993 | 994 | instances = Dataset_instance(fnames_patches, mode) 995 | generator = data.DataLoader(instances, **params_instance) 996 | 997 | for a, (x_q, x_k, _) in enumerate(generator): 998 | # Preprocess 999 | 1000 | momentum_encoder.zero_grad() 1001 | encoder.zero_grad() 1002 | 1003 | # Shffled BN : shuffle x_k before distributing it among GPUs (Section. 3.3) 1004 | if shuffle_bn: 1005 | idx = torch.randperm(x_k.size(0)) 1006 | x_k = x_k[idx] 1007 | 1008 | # x_q, x_k : (N, 3, 64, 64) 1009 | x_q, x_k = x_q.to(device), x_k.to(device) 1010 | 1011 | q = encoder(x_q, 'valid', _) # q : (N, 128) 1012 | k = momentum_encoder(x_k, 'valid', _).detach() 1013 | 1014 | # Shuffled BN : unshuffle k (Section. 3.3) 1015 | if shuffle_bn: 1016 | k_temp = torch.zeros_like(k) 1017 | for a, j in enumerate(idx): 1018 | k_temp[j] = k[a] 1019 | k = k_temp 1020 | 1021 | """ 1022 | # positive logits: Nx1 1023 | l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1) 1024 | # negative logits: NxK 1025 | l_neg = torch.einsum('nc,ck->nk', [q, queue.t()]) 1026 | 1027 | # Positive sampling q & k 1028 | #l_pos = torch.sum(q * k, dim=1, keepdim=True) # (N, 1) 1029 | #print("l_pos",l_pos) 1030 | 1031 | # Negative sampling q & queue 1032 | #l_neg = torch.mm(q, queue.t()) # (N, 4096) 1033 | #print("l_neg",l_neg) 1034 | 1035 | # Logit and label 1036 | logits = torch.cat([l_pos, l_neg], dim=1) / temperature # (N, 4097) witi label [0, 0, ..., 0] 1037 | labels = torch.zeros(logits.size(0), dtype=torch.long).to(device) 1038 | 1039 | # Get loss and backprop 1040 | loss_moco = criterion(logits, labels) 1041 | """ 1042 | loss = loss_function(q, k, queue) 1043 | 1044 | # Encoder update 1045 | #optimizer.step() 1046 | 1047 | # Momentum encoder update 1048 | #momentum_step(m=moco_m) 1049 | 1050 | # Update dictionary 1051 | #queue = torch.cat([k, queue[:queue.size(0) - k.size(0)]], dim=0) 1052 | queue = update_queue(queue, k) 1053 | #print(queue.shape) 1054 | 1055 | # Print a training status, save a loss value, and plot a loss graph. 1056 | 1057 | valid_loss = valid_loss + ((1 / (total_iters+1)) * (loss.item() - valid_loss)) 1058 | total_iters = total_iters + 1 1059 | print('[Epoch : %d / Total iters : %d] : loss : %f ...' 
%(epoch, total_iters, valid_loss))
1060 | 
1061 |             momentum_encoder.zero_grad()
1062 |             encoder.zero_grad()
1063 | 
1064 |     torch.cuda.empty_cache()
1065 |     return valid_loss
1066 | 
1067 | 
1068 | # Training
1069 | print('\nStart training!')
1070 | epoch = 0
1071 | 
1072 | iterations_per_epoch = 8600
1073 | 
1074 | losses_train = []
1075 | 
1076 | #number of epochs without improvement
1077 | EARLY_STOP_NUM = 10
1078 | early_stop_cont = 0
1079 | epoch = 0
1080 | num_epochs = EPOCHS
1081 | validation_checkpoints = checkpoint_path+'validation_losses/'
1082 | create_dir(validation_checkpoints)
1083 | #number of epochs without improvement
1084 | epoch = 0
1085 | iterations = int(len(train_dataset) / batch_size_bag)#+100
1086 | #iterations = 600
1087 | 
1088 | tot_batches_training = iterations#int(len(train_dataset)/batch_size_bag)
1089 | best_loss = 100000.0
1090 | 
1091 | tot_iterations = num_epochs * iterations_per_epoch
1092 | cont_iterations_tot = 0
1093 | 
1094 | TEMPERATURE = 0.07
1095 | 
1096 | p_to_select = 512
1097 | 
1098 | while (epoch<num_epochs and early_stop_cont<EARLY_STOP_NUM):
1241 |         """
1242 |         # positive logits: Nx1
1243 |         l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
1244 |         # negative logits: NxK
1245 |         l_neg = torch.einsum('nc,ck->nk', [q, queue.t()])
1246 | 
1247 |         # Positive sampling q & k
1248 |         #l_pos = torch.sum(q * k, dim=1, keepdim=True) # (N, 1)
1249 |         #print("l_pos",l_pos)
1250 | 
1251 |         # Negative sampling q & queue
1252 |         #l_neg = torch.mm(q, queue.t()) # (N, 4096)
1253 |         #print("l_neg",l_neg)
1254 | 
1255 |         # Logit and label
1256 |         logits = torch.cat([l_pos, l_neg], dim=1) / temperature # (N, 4097) with label [0, 0, ..., 0]
1257 |         labels = torch.zeros(logits.size(0), dtype=torch.long).to(device)
1258 | 
1259 |         # Get loss and backprop
1260 |         loss_moco = criterion(logits, labels)
1261 |         """
1262 |         loss_moco = loss_function(q, k, queue)
1263 |         loss_domains = lambda_val * criterion_domain(he_q, he_staining)
1264 | 
1265 |         loss = loss_moco + loss_domains
1266 | 
1267 |         loss.backward()
1268 | 
1269 |         # Encoder update
1270 |         optimizer.step()
1271 | 
1272 |         momentum_encoder.zero_grad()
1273 |         encoder.zero_grad()
1274 | 
1275 |         # Momentum encoder update
1276 |         momentum_step(m=moco_m)
1277 | 
1278 |         # Update dictionary
1279 |         #queue = torch.cat([k, queue[:queue.size(0) - k.size(0)]], dim=0)
1280 |         queue = update_queue(queue, k)
1281 |         #print(queue.shape)
1282 | 
1283 |         # Print a training status, save a loss value, and plot a loss graph.
1284 | 
1285 |         train_loss_moco = train_loss_moco + ((1 / (total_iters+1)) * (loss_moco.item() - train_loss_moco))
1286 |         train_loss_domain = train_loss_domain + ((1 / (total_iters+1)) * (loss_domains.item() - train_loss_domain))
1287 |         total_iters = total_iters + 1
1288 |         cont_iterations_tot = cont_iterations_tot + 1
1289 |         train_loss = train_loss_moco + train_loss_domain
1290 | 
1291 |         print('[Epoch : %d / Total iters : %d] : loss_moco :%f, loss_domain :%f ...' %(epoch, total_iters, train_loss_moco, train_loss_domain))
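# Aside: a standalone sketch (not part of the original script) showing that
# the explicit pos/neg formulation of the InfoNCE loss above is equivalent
# to a cross-entropy over the concatenated logits [l_pos, l_neg] / T with
# target class 0 (the positive key). Shapes below are made up: 4 queries,
# a queue of 16 keys, 128-d features.
#
# import torch
# import torch.nn.functional as F
#
# T = 0.07
# q = F.normalize(torch.randn(4, 128), dim=1)
# k = F.normalize(torch.randn(4, 128), dim=1)
# queue = F.normalize(torch.randn(16, 128), dim=1)
#
# pos = torch.exp(torch.sum(q * k, dim=1) / T)                   # (4,)
# neg = torch.sum(torch.exp(torch.mm(q, queue.t()) / T), dim=1)  # (4,)
# loss_explicit = torch.mean(-torch.log(pos / (pos + neg)))
#
# logits = torch.cat([torch.sum(q * k, dim=1, keepdim=True),
#                     torch.mm(q, queue.t())], dim=1) / T
# loss_ce = F.cross_entropy(logits, torch.zeros(4, dtype=torch.long))
#
# print(loss_explicit.item(), loss_ce.item())  # the two values match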
1292 | 
1293 |         if (i % 10 == 0):
1294 |             #periodically check whether the running MoCo loss improved
1295 |             if (best_loss>train_loss_moco):
1296 |                 early_stop_cont = 0
1297 |                 print ("=> Saving a new best model")
1298 |                 print("previous loss : " + str(best_loss) + ", new loss function: " + str(train_loss_moco))
1299 |                 best_loss = train_loss_moco
1300 |                 try:
1301 |                     torch.save(encoder.state_dict(), model_weights_filename,_use_new_zipfile_serialization=False)
1302 |                 except:
1303 |                     torch.save(encoder.state_dict(), model_weights_filename)
1304 |             else:
1305 | 
1306 |                 try:
1307 |                     torch.save(encoder.state_dict(), model_weights_filename_temporary,_use_new_zipfile_serialization=False)
1308 |                 except:
1309 |                     torch.save(encoder.state_dict(), model_weights_filename_temporary)
1310 | 
1311 |     torch.cuda.empty_cache()
1312 | 
1313 |     # Update learning rate
1314 |     #update_lr(epoch)
1315 | 
1316 |     print("epoch "+str(epoch)+ " train loss: " + str(train_loss))
1317 | 
1318 |     print("evaluating validation")
1319 |     """
1320 |     valid_loss = validate(epoch, validation_generator_bag)
1321 | 
1322 |     #save validation
1323 |     filename_val = validation_checkpoints+'validation_value_'+str(epoch)+'.csv'
1324 |     array_val = [valid_loss]
1325 |     File = {'val':array_val}
1326 |     df = pd.DataFrame(File,columns=['val'])
1327 |     np.savetxt(filename_val, df.values, fmt='%s',delimiter=',')
1328 | 
1329 |     #save_hyperparameters
1330 |     filename_hyperparameters = checkpoint_path+'hyperparameters.csv'
1331 |     array_lr = [str(lr)]
1332 |     array_opt = [optimizer_str]
1333 |     array_wt_decay = [str(weight_decay)]
1334 |     array_embedding = [EMBEDDING_bool]
1335 |     File = {'opt':array_opt, 'lr':array_lr,'wt_decay':array_wt_decay,'array_embedding':array_embedding}
1336 | 
1337 |     df = pd.DataFrame(File,columns=['opt','lr','wt_decay','array_embedding'])
1338 |     np.savetxt(filename_hyperparameters, df.values, fmt='%s',delimiter=',')
1339 |     """
1340 | 
1341 | 
1342 | 
1343 | 
1344 | 
1345 |     epoch = epoch+1
1346 |     if (early_stop_cont == EARLY_STOP_NUM):
1347 |         print("EARLY STOPPING")
--------------------------------------------------------------------------------
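# Aside: a minimal sketch of reloading the pre-trained encoder saved above
# with torch.save(encoder.state_dict(), ...), e.g. to warm-start the MIL CNN.
# The path assumes the default output folder and the multilabel task; adjust
# to wherever the weights were actually written.
import torch

encoder = Encoder(dim=128)  # same architecture and moco_dim as during pre-training
state_dict = torch.load('./models/MIL_colon_multilabel.pt', map_location='cpu')
encoder.load_state_dict(state_dict, strict=False)  # strict=False: auxiliary heads may differ
encoder.eval()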