├── scripts
│   ├── extend_kd_tree_offline.py
│   ├── Augmentation.py
│   └── Training_new_augmentation.py
└── README.md

/scripts/extend_kd_tree_offline.py:
--------------------------------------------------------------------------------
import pandas as pd
import numpy as np
from PIL import Image
from scipy.spatial import cKDTree
import pickle
import time
from tqdm import tqdm
import argparse

# parser parameters
parser = argparse.ArgumentParser(description='Configurations to extend the color-variation database.')
parser.add_argument('-i', '--INPUT', help='input pickle file (database to extend)', type=str, default='input.pickle')
parser.add_argument('-o', '--OUTPUT', help='where to store the extended database', type=str, default='output.pickle')
parser.add_argument('-d', '--DATA_TO_ADD', help='csv file listing the patches to add', type=str, default='data.csv')

args = parser.parse_args()

INPUT_CSV = args.DATA_TO_ADD
OUTPUT_FILE = args.OUTPUT
INPUT_FILE = args.INPUT

def H_E_Staining(img, Io=240, alpha=1, beta=0.15):
    # Macenko stain estimation: returns the 3x2 H&E stain matrix of a patch.

    # define height and width of image
    h, w, c = img.shape

    # reshape image
    img = img.reshape((-1, 3))

    # calculate optical density
    OD = -np.log((img.astype(float) + 1) / Io)

    # remove transparent pixels
    ODhat = OD[~np.any(OD < beta, axis=1)]

    # compute eigenvectors of the OD covariance matrix
    eigvals, eigvecs = np.linalg.eigh(np.cov(ODhat.T))

    # project on the plane spanned by the eigenvectors of the two largest eigenvalues
    That = ODhat.dot(eigvecs[:, 1:3])

    phi = np.arctan2(That[:, 1], That[:, 0])

    minPhi = np.percentile(phi, alpha)
    maxPhi = np.percentile(phi, 100 - alpha)

    vMin = eigvecs[:, 1:3].dot(np.array([(np.cos(minPhi), np.sin(minPhi))]).T)
    vMax = eigvecs[:, 1:3].dot(np.array([(np.cos(maxPhi), np.sin(maxPhi))]).T)

    # heuristic: the vector corresponding to hematoxylin first, eosin second
    if vMin[0] > vMax[0]:
        HE = np.array((vMin[:, 0], vMax[:, 0])).T
    else:
        HE = np.array((vMax[:, 0], vMin[:, 0])).T

    return HE

with open(INPUT_FILE, 'rb') as f:
    kdtree = pickle.load(f)

input_data = pd.read_csv(INPUT_CSV, sep=',', header=None).values

def extend_stains(kdtree, new_data, save_new_array=False, PERC=1.0):
    # Estimate the H&E stain matrix for a fraction PERC of the new patches,
    # append the flattened matrices to the database, and rebuild the kd-tree.

    HEs_new_stains = []
    threshold_value = int(len(new_data) * PERC)
    HEs_general = kdtree.data

    np.random.shuffle(new_data)

    for i in tqdm(range(threshold_value)):
        patch = new_data[i, 0]

        img = Image.open(patch)
        img_np = np.asarray(img)

        HE = H_E_Staining(img_np)
        HE = np.reshape(HE, 6)
        HEs_new_stains.append(HE)

        img.close()

    HEs_new_stains = np.array(HEs_new_stains)
    HEs_general = np.append(HEs_general, HEs_new_stains, axis=0)

    new_kdtree = cKDTree(HEs_general)

    if save_new_array:
        # persist the extended database
        with open(OUTPUT_FILE, 'wb') as f:
            pickle.dump(new_kdtree, f)

    return new_kdtree

start_time = time.time()
new_kdtree = extend_stains(kdtree, input_data, save_new_array=True, PERC=1.0)
elapsed_time = time.time() - start_time
print("elapsed time " + str(elapsed_time))
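# A quick sanity check of the extended database (a hedged sketch, where
# 'output.pickle' stands for whatever was passed to -o):
#
#   with open('output.pickle', 'rb') as f:
#       extended = pickle.load(f)
#   print(extended.data.shape)    # (n_patches, 6): flattened 3x2 H&E stain matrices
#   # neighbours of the first stored stain vector within radius 0.1
#   idxs = extended.query_ball_point(extended.data[0], r=0.1)
#   print(len(idxs))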
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Data_Driven_Color_Augmentation

Implementation of "Data-driven color augmentation for H&E stained images in computational pathology".

## Reference
If you find this repository useful in your research, please cite:

[1] Marini N., Otálora S., Wodzinski M., Tomassini S., Dragoni A.F., Marchand-Maillet S., Dominguez P., Duran-Lopez L., Vatrano S., Müller H. & Atzori M., Data-driven color augmentation for H&E stained images in computational pathology.

Paper link: https://www.sciencedirect.com/science/article/pii/S2153353922007830

## Requirements
Python==3.6.9, albumentations==0.1.8, numpy==1.17.3, opencv==4.2.0, pandas==0.25.2, pillow==6.1.0, torchvision==0.8.1, pytorch==1.7.0

## CSV Input Files
CSV files are used as input for the scripts. For each partition (train, validation, test), the csv file has path_to_image, class_label as columns.

For prostate experiments, the class_label can be:
- 0: benign
- 1: Gleason pattern 3
- 2: Gleason pattern 4
- 3: Gleason pattern 5

For colon experiments, the class_label can be:
- 0: cancer
- 1: dysplasia
- 2: normal glands

## Augmentation
Methods to perform data-driven color augmentation (Augmentation.py); a sketch of the shared acceptance test is shown after the list:
- new_color_augmentation (HSC color augmentation):
  * patch_np: numpy array of the input patch (224x224)
  * kdtree: the database where acceptable color variations are stored
  * alpha: minimum number of neighbors required to accept a perturbation
  * beta: radius of the neighborhood query
  * shift_value: perturbation applied to Hue, Saturation, Contrast
- new_stain_augmentation (perturbation of the H&E channels):
  * patch_np: numpy array of the input patch (224x224)
  * kdtree: the database where acceptable color variations are stored
  * alpha: minimum number of neighbors required to accept a perturbation
  * beta: radius of the neighborhood query
  * sigma1: range (-sigma1, sigma1) of the random values that multiply the H&E components
  * sigma2: range (-sigma2, sigma2) of the random values added to the H&E components
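Both methods share the same data-driven acceptance test: a random perturbation is kept only if the H&E stain matrix of the perturbed patch has enough neighbors in the database of acceptable color variations. The following is a minimal sketch of that test, not the exact body of `new_color_augmentation`; it assumes the `H_E_Staining` helper defined in the scripts and that the code is run from the `scripts/` folder:

```python
import albumentations as A
from Augmentation import H_E_Staining

def data_driven_color_augmentation(patch_np, kdtree, alpha, beta,
                                   shift_value=70, threshold=1000):
    transform = A.HueSaturationValue(
        hue_shift_limit=(-shift_value, shift_value),
        sat_shift_limit=(-shift_value, shift_value),
        val_shift_limit=(-shift_value, shift_value),
        always_apply=True)

    for _ in range(threshold):
        candidate = transform(image=patch_np)['image']
        he = H_E_Staining(candidate).reshape(6)   # flattened 3x2 stain matrix
        # accept only if at least `alpha` stored variations lie within radius `beta`
        if len(kdtree.query_ball_point(he, r=beta)) >= alpha:
            return candidate, he

    # no acceptable perturbation found: return the patch unchanged
    return patch_np, H_E_Staining(patch_np).reshape(6)
```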
## Database
Database including color variations: https://zenodo.org/record/7505727#.Y7ayO3bMJPY

Method to extend the database with new histopathology patches (extend_kd_tree_offline.py):
* -i: input pickle file (database to extend)
* -o: output pickle file
* -d: csv including the patches used to extend the database

## Training
Scripts to train the CNN at patch-level, in a fully-supervised fashion.
Some parameters must be manually changed, such as the number of classes (the output of the network).

- Training_new_augmentation.py -n -b -c -e -t -f -i -o -a -x -d. The script is used to train the CNN without any augmentation (no_augment) or with colour augmentation (augment). An illustrative invocation is shown after the list.
  * -n: number of the experiment for the training
  * -b: batch size (32)
  * -c: CNN backbone to use (densenet121)
  * -e: number of epochs (10)
  * -t: task of the network (no_augment, augment, normalizer)
  * -f: if True, an embedding layer with 128 nodes is inserted before the output layer
  * -i: path of the folder where the input csvs for training (train.csv), validation (valid.csv) and testing (test.csv) are stored
  * -o: path of the folder where to store the CNN’s weights
  * -a: new augmentation to use: color (HSC color augmentation), stain (H&E stain augmentation), he (H&E-adversarial CNN + HSC color augmentation)
  * -x: extend (False): add the color variations of the training data to the color-variation database
  * -d: database: the database including the color variations
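For example, a training run with the HSC color augmentation on a densenet121 backbone could look like the following (placeholder paths and values for illustration, not a command taken from the paper):

```
python Training_new_augmentation.py -n 1 -b 32 -c densenet121 -e 10 -t augment -f True -i /path/to/csv_folder/ -o /path/to/model_weights/ -a color -x False -d /path/to/kdtree.pickle
```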
## Acknowledgements
This project has received funding from the European Union’s Horizon 2020 research and innovation programme under grant agreement No. 825292 [ExaMode](http://www.examode.eu). Infrastructure from the SURFsara HPC center was used to train the CNN models in parallel.

--------------------------------------------------------------------------------
/scripts/Augmentation.py:
--------------------------------------------------------------------------------
import numpy as np
import pandas as pd
from PIL import Image
import albumentations as A
import warnings

from scipy.spatial import KDTree, cKDTree

warnings.filterwarnings("ignore")

def H_E_Staining(img, Io=240, alpha=1, beta=0.15):
    # Macenko stain estimation: returns the 3x2 H&E stain matrix of a patch.

    # define height and width of image
    h, w, c = img.shape

    # reshape image
    img = img.reshape((-1, 3))

    # calculate optical density
    OD = -np.log((img.astype(float) + 1) / Io)

    # remove transparent pixels
    ODhat = OD[~np.any(OD < beta, axis=1)]

    # compute eigenvectors of the OD covariance matrix
    eigvals, eigvecs = np.linalg.eigh(np.cov(ODhat.T))

    # project on the plane spanned by the eigenvectors of the two largest eigenvalues
    That = ODhat.dot(eigvecs[:, 1:3])

    phi = np.arctan2(That[:, 1], That[:, 0])

    minPhi = np.percentile(phi, alpha)
    maxPhi = np.percentile(phi, 100 - alpha)

    vMin = eigvecs[:, 1:3].dot(np.array([(np.cos(minPhi), np.sin(minPhi))]).T)
    vMax = eigvecs[:, 1:3].dot(np.array([(np.cos(maxPhi), np.sin(maxPhi))]).T)

    # heuristic: the vector corresponding to hematoxylin first, eosin second
    if vMin[0] > vMax[0]:
        HE = np.array((vMin[:, 0], vMax[:, 0])).T
    else:
        HE = np.array((vMax[:, 0], vMin[:, 0])).T

    return HE

def unique_elements(array):
    # True if every element of the array appears exactly once.
    _, counts = np.unique(array, return_counts=True)

    b = True
    for c in counts:
        if (c > 1):
            b = False

    return b

def normalizeStaining(img, HERef, Io=240, alpha=1, beta=0.15):
    # Macenko stain normalization towards the reference stain matrix HERef.

    maxCRef = np.array([1.9705, 1.0308])

    # define height and width of image
    h, w, c = img.shape

    # reshape image
    img = img.reshape((-1, 3))

    # calculate optical density
    OD = -np.log((img.astype(float) + 1) / Io)

    # remove transparent pixels
    ODhat = OD[~np.any(OD < beta, axis=1)]

    # compute eigenvectors
    eigvals, eigvecs = np.linalg.eigh(np.cov(ODhat.T))

    # project on the plane spanned by the eigenvectors of the two largest eigenvalues
    That = ODhat.dot(eigvecs[:, 1:3])

    phi = np.arctan2(That[:, 1], That[:, 0])

    minPhi = np.percentile(phi, alpha)
    maxPhi = np.percentile(phi, 100 - alpha)

    vMin = eigvecs[:, 1:3].dot(np.array([(np.cos(minPhi), np.sin(minPhi))]).T)
    vMax = eigvecs[:, 1:3].dot(np.array([(np.cos(maxPhi), np.sin(maxPhi))]).T)

    # heuristic: hematoxylin first, eosin second
    if vMin[0] > vMax[0]:
        HE = np.array((vMin[:, 0], vMax[:, 0])).T
    else:
        HE = np.array((vMax[:, 0], vMin[:, 0])).T

    # rows correspond to channels (RGB), columns to OD values
    Y = np.reshape(OD, (-1, 3)).T

    # determine concentrations of the individual stains
    C = np.linalg.lstsq(HE, Y, rcond=None)[0]

    # normalize stain concentrations
    maxC = np.array([np.percentile(C[0, :], 99), np.percentile(C[1, :], 99)])
    tmp = np.divide(maxC, maxCRef)
    C2 = np.divide(C, tmp[:, np.newaxis])

    # recreate the image using the reference mixing matrix
    Inorm = np.multiply(Io, np.exp(-HERef.dot(C2)))
    Inorm[Inorm > 255] = 254
    Inorm = np.reshape(Inorm.T, (h, w, 3)).astype(np.uint8)

    return Inorm
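# Hedged usage sketch: HERef is a 3x2 reference stain matrix; the values
# below are the common Macenko defaults, assumed here rather than taken
# from this repository.
#
#   HERef = np.array([[0.5626, 0.2159],
#                     [0.7201, 0.8012],
#                     [0.4062, 0.5581]])
#   img = np.asarray(Image.open('patch.png'))
#   normalized = normalizeStaining(img, HERef)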
def new_color_augmentation(patch_np, kdtree, alpha, beta, shift_value=70, threshold=1000):

    b = False
    i = 0

    pipeline_transform_ = A.Compose([
        A.HueSaturationValue(hue_shift_limit=(-shift_value, shift_value),
                             sat_shift_limit=(-shift_value, shift_value),
                             val_shift_limit=(-shift_value, shift_value),
                             always_apply=True),
    ])

    while (b == False and i < threshold):
        # ... (the body of this loop, new_stain_augmentation, and the preamble
        # of Training_new_augmentation.py are truncated in the source)

--------------------------------------------------------------------------------
/scripts/Training_new_augmentation.py:
--------------------------------------------------------------------------------
# (imports, argument parsing, data loading, the pre-trained backbone and the
# domain_predictor / ReverseLayerF definitions are truncated in the source)

class CNN_model_multitask(torch.nn.Module):
    def __init__(self):
        super(CNN_model_multitask, self).__init__()
        self.conv_layers = torch.nn.Sequential(*list(pre_trained_network.children())[:-1])

        if (torch.cuda.device_count() > 1):
            self.conv_layers = torch.nn.DataParallel(self.conv_layers)

        self.fc_feat_in = fc_input_features
        self.N_CLASSES = 4

        if (EMBEDDING_bool == True):

            if ('resnet18' in CNN_TO_USE):
                self.E = 128
                self.L = self.E
                self.D = 64
                self.K = self.N_CLASSES

            elif ('resnet34' in CNN_TO_USE):
                self.E = 128
                self.L = self.E
                self.D = 64
                self.K = self.N_CLASSES

            elif ('resnet50' in CNN_TO_USE):
                self.E = 256
                self.L = self.E
                self.D = 128
                self.K = self.N_CLASSES

            elif ('densenet121' in CNN_TO_USE):
                self.E = 128
                self.L = self.E
                self.D = 64
                self.K = self.N_CLASSES

            self.embedding = torch.nn.Linear(in_features=self.fc_feat_in, out_features=self.E)
            self.embedding_fc = torch.nn.Linear(in_features=self.E, out_features=self.N_CLASSES)

        else:
            self.fc = torch.nn.Linear(in_features=self.fc_feat_in, out_features=self.N_CLASSES)

            if ('resnet18' in CNN_TO_USE):
                self.L = fc_input_features
                self.D = 128
                self.K = self.N_CLASSES

            elif ('resnet34' in CNN_TO_USE):
                self.L = fc_input_features
                self.D = 128
                self.K = self.N_CLASSES

            elif ('resnet50' in CNN_TO_USE):
                self.L = fc_input_features
                self.D = 256
                self.K = self.N_CLASSES

            elif ('densenet121' in CNN_TO_USE):
                self.E = 128
                self.L = self.E
                self.D = 64
                self.K = self.N_CLASSES

        self.domain_predictor = domain_predictor(6)

    def forward(self, x, mode, alpha):
        """
        In the forward function we accept a Tensor of input data and we must
        return a Tensor of output data. We can use Modules defined in the
        constructor as well as arbitrary operators on Tensors.
        """
        m_multiclass = torch.nn.Softmax(dim=1)
        dropout = torch.nn.Dropout(p=0.2)

        if x is not None:
            conv_layers_out = self.conv_layers(x)

            if ('densenet' in CNN_TO_USE):
                n = torch.nn.AdaptiveAvgPool2d((1, 1))
                conv_layers_out = n(conv_layers_out)

            conv_layers_out = conv_layers_out.view(-1, self.fc_feat_in)

            if ('mobilenet' in CNN_TO_USE):
                conv_layers_out = dropout(conv_layers_out)

        if (EMBEDDING_bool == True):
            embedding_layer = self.embedding(conv_layers_out)
            features_to_return = embedding_layer

            embedding_layer = dropout(embedding_layer)
            logits = self.embedding_fc(embedding_layer)

        else:
            logits = self.fc(conv_layers_out)
            features_to_return = conv_layers_out

        output_fcn = m_multiclass(logits)

        if (mode == 'train'):
            # gradient reversal on the shared features feeds the domain (H&E) predictor
            reverse_feature = ReverseLayerF.apply(conv_layers_out, alpha)
            output_domain = self.domain_predictor(reverse_feature)

            return logits, output_fcn, output_domain

        return logits, output_fcn
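# The definition of ReverseLayerF is among the lines lost above; a standard
# DANN-style gradient-reversal layer (hedged sketch, after Ganin & Lempitsky)
# would look like:
#
#   class ReverseLayerF(torch.autograd.Function):
#       @staticmethod
#       def forward(ctx, x, alpha):
#           ctx.alpha = alpha
#           return x.view_as(x)
#
#       @staticmethod
#       def backward(ctx, grad_output):
#           # identity in the forward pass, negated (scaled) gradient backward
#           return grad_output.neg() * ctx.alpha, None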
class CNN_model(torch.nn.Module):
    def __init__(self):
        super(CNN_model, self).__init__()
        self.conv_layers = torch.nn.Sequential(*list(pre_trained_network.children())[:-1])

        if (torch.cuda.device_count() > 1):
            self.conv_layers = torch.nn.DataParallel(self.conv_layers)

        self.fc_feat_in = fc_input_features
        self.N_CLASSES = 4

        if (EMBEDDING_bool == True):
            if ('resnet18' in CNN_TO_USE):
                self.E = 128
                self.L = self.E
                self.D = 64
                self.K = self.N_CLASSES

            elif ('resnet34' in CNN_TO_USE):
                self.E = 128
                self.L = self.E
                self.D = 64
                self.K = self.N_CLASSES

            elif ('resnet50' in CNN_TO_USE):
                self.E = 256
                self.L = self.E
                self.D = 128
                self.K = self.N_CLASSES

            elif ('densenet121' in CNN_TO_USE):
                self.E = 128
                self.L = self.E
                self.D = 64
                self.K = self.N_CLASSES

            self.embedding = torch.nn.Linear(in_features=self.fc_feat_in, out_features=self.E)
            self.embedding_fc = torch.nn.Linear(in_features=self.E, out_features=self.N_CLASSES)

        else:
            self.fc = torch.nn.Linear(in_features=self.fc_feat_in, out_features=self.N_CLASSES)

            if ('resnet18' in CNN_TO_USE):
                self.L = fc_input_features
                self.D = 128
                self.K = self.N_CLASSES

            elif ('resnet34' in CNN_TO_USE):
                self.L = fc_input_features
                self.D = 128
                self.K = self.N_CLASSES

            elif ('resnet50' in CNN_TO_USE):
                self.L = fc_input_features
                self.D = 256
                self.K = self.N_CLASSES

            elif ('densenet121' in CNN_TO_USE):
                self.L = fc_input_features
                self.D = 64
                self.K = self.N_CLASSES

    def forward(self, x, conv_layers_out):
        """
        In the forward function we accept a Tensor of input data and we must
        return a Tensor of output data. We can use Modules defined in the
        constructor as well as arbitrary operators on Tensors.
        """
        m_multiclass = torch.nn.Softmax(dim=1)
        dropout = torch.nn.Dropout(p=0.2)

        if x is not None:
            conv_layers_out = self.conv_layers(x)

            if ('densenet' in CNN_TO_USE):
                n = torch.nn.AdaptiveAvgPool2d((1, 1))
                conv_layers_out = n(conv_layers_out)

            conv_layers_out = conv_layers_out.view(-1, self.fc_feat_in)

            if ('mobilenet' in CNN_TO_USE):
                conv_layers_out = dropout(conv_layers_out)

        if (EMBEDDING_bool == True):
            embedding_layer = self.embedding(conv_layers_out)
            features_to_return = embedding_layer

            embedding_layer = dropout(embedding_layer)
            logits = self.embedding_fc(embedding_layer)

        else:
            logits = self.fc(conv_layers_out)
            features_to_return = conv_layers_out

        output_fcn = m_multiclass(logits)

        return logits, output_fcn

if (TYPE_AUGMENTATION == 'he'):
    model = CNN_model_multitask()
else:
    model = CNN_model()

# DATA AUGMENTATION
from torchvision import transforms
prob = 0.5

pipeline_transform = A.Compose([
    A.VerticalFlip(p=prob),
    A.HorizontalFlip(p=prob),
    A.RandomRotate90(p=prob),
])

# DATA NORMALIZATION (ImageNet statistics)
preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

def extend_stains(kdtree, new_data, save_new_array=False, PERC=1.0):
    # Same extension routine as in extend_kd_tree_offline.py.

    HEs_new_stains = []
    threshold_value = int(len(new_data) * PERC)
    HEs_general = kdtree.data

    np.random.shuffle(new_data)

    for i in tqdm(range(threshold_value)):
        patch = new_data[i, 0]

        img = Image.open(patch)
        img_np = np.asarray(img)

        HE = H_E_Staining(img_np)
        HE = np.reshape(HE, 6)
        HEs_new_stains.append(HE)

        img.close()

    HEs_new_stains = np.array(HEs_new_stains)
    HEs_general = np.append(HEs_general, HEs_new_stains, axis=0)

    new_kdtree = cKDTree(HEs_general)

    if save_new_array:
        print("EXTENDING STAINS")
        fname = FOLDER_KD + 'kdtree_extended.pickle'

        with open(fname, 'wb') as f:
            pickle.dump(new_kdtree, f)

        print("EXTENSION DONE")

    return new_kdtree

sigma_perturb = 0.1
nearest_neighbours = 5

sigma1 = 0.7
sigma2 = 0.7

alpha = nearest_neighbours
beta = sigma_perturb

class Dataset_patches(data.Dataset):

    def __init__(self, list_IDs, labels, mode):
        self.labels = labels
        self.list_IDs = list_IDs
        self.mode = mode

    def __len__(self):
        return len(self.list_IDs)

    def __getitem__(self, index):
        # Select sample
        ID = self.list_IDs[index]
        # Load data and get label
        X = Image.open(ID)
        X = np.asarray(X)
        y = self.labels[index]

        # data augmentation
        if (self.mode == 'train'):
            X = pipeline_transform(image=X)['image']

            rand_val = np.random.rand(1)[0]

            if (rand_val > prob):

                if (TYPE_AUGMENTATION == 'color'):
                    X, _ = new_color_augmentation(X, kdtree, alpha, beta)

                elif (TYPE_AUGMENTATION == 'stain'):
                    X, _ = new_stain_augmentation(X, kdtree, alpha, beta, sigma1, sigma2)

                elif (TYPE_AUGMENTATION == 'he'):
                    X, h_e_matrix = new_color_augmentation(X, kdtree, alpha, beta)

        if (TYPE_AUGMENTATION == 'he'):
            # the H&E-adversarial CNN also receives the stain matrix of the patch
            h_e_matrix = H_E_Staining(X)
            h_e_matrix = np.reshape(h_e_matrix, 6)
            h_e_matrix = np.asarray(h_e_matrix)
        else:
            h_e_matrix = np.asarray([0])

        new_image = np.asarray(X)
        # data transformation
        input_tensor = preprocess(new_image)

        return input_tensor, np.asarray(y), h_e_matrix
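# Hedged sketch of how one item flows through the pieces above (CNN_model
# variant; the second forward argument is ignored when x is given):
#
#   ds = Dataset_patches(paths, labels_array, 'train')
#   x, y, he = ds[0]                      # (3, 224, 224) tensor, label, H&E vector
#   logits, probs = model(x.unsqueeze(0), None)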
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Parameters
num_workers = 2
params_train = {'batch_size': BATCH_SIZE,
                'sampler': ImbalancedDatasetSampler(train_dataset),
                'num_workers': num_workers}

params_valid = {'batch_size': BATCH_SIZE,
                'shuffle': True,
                'num_workers': num_workers}

params_test = {'batch_size': BATCH_SIZE,
               'shuffle': True,
               'num_workers': num_workers}

max_epochs = int(EPOCHS_str)

# CREATE GENERATORS
# train
training_set = Dataset_patches(train_dataset[:, 0], train_dataset[:, 1], 'train')
training_generator = data.DataLoader(training_set, **params_train)

validation_set = Dataset_patches(valid_dataset[:, 0], valid_dataset[:, 1], 'valid')
validation_generator = data.DataLoader(validation_set, **params_valid)

# Find total parameters and trainable parameters
total_params = sum(p.numel() for p in model.parameters())
print(f'{total_params:,} total parameters.')
total_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'{total_trainable_params:,} training parameters.')

class_sample_count = np.unique(train_dataset[:, 1], return_counts=True)[1]
weight = class_sample_count / len(train_dataset[:, 1])
# to avoid propagation of the fake benign class
samples_weight = torch.from_numpy(weight).type(torch.FloatTensor)

class RMSELoss(torch.nn.Module):
    def __init__(self, eps=1e-6):
        super().__init__()
        self.mse = torch.nn.MSELoss()
        self.eps = eps

    def forward(self, yhat, y):
        loss = torch.sqrt(self.mse(yhat, y) + self.eps)
        return loss

import torch.optim as optim

criterion_domain = RMSELoss()            # regression loss on the H&E stain matrix
criterion = torch.nn.CrossEntropyLoss()  # classification loss

num_epochs = EPOCHS
epoch = 0
early_stop_cont = 0
EARLY_STOP_NUM = 5
weight_decay = 0
lr = 1e-3

optimizer = optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.999), eps=1e-08,
                       weight_decay=weight_decay, amsgrad=True)
model.to(device)

if (EXTEND == True):
    new_kdtree = extend_stains(kdtree, train_dataset, save_new_array=False, PERC=0.5)
    kdtree = new_kdtree

else:
    fname = FOLDER_KD + 'kdtree_TCGA_ExaMode.pickle'

    with open(fname, 'rb') as f:
        kdtree = pickle.load(f)

def evaluate_validation_set(generator):
    # accumulators for the validation set
    y_pred = []
    y_true = []

    valid_loss = 0.0

    with torch.no_grad():
        j = 0
        for inputs, labels, _ in generator:
            inputs, labels = inputs.to(device), labels.to(device)

            logits, outputs = model(inputs, None)

            loss = criterion(logits, labels)

            # running mean of the validation loss
            valid_loss = valid_loss + ((1 / (j + 1)) * (loss.item() - valid_loss))

            outputs_np = outputs.cpu().data.numpy()
            labels_np = labels.cpu().data.numpy()
            outputs_np = np.argmax(outputs_np, axis=1)

            y_pred = np.append(y_pred, outputs_np)
            y_true = np.append(y_true, labels_np)

            j = j + 1

    acc_valid = metrics.accuracy_score(y_true=y_true, y_pred=y_pred)
    kappa_valid = metrics.cohen_kappa_score(y1=y_true, y2=y_pred, weights='quadratic')
    print("loss: " + str(valid_loss) + ", accuracy: " + str(acc_valid) + ", kappa score: " + str(kappa_valid))

    return valid_loss

best_loss_valid = 100000.0

losses_train = []
losses_valid = []

lambda_val = 0.5   # weight of the adversarial (domain) loss
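# The body of the training loop below is truncated in the source. A hedged
# sketch of its shape, inferred from the variables defined above (the exact
# schedule and logging are assumptions; note that `alpha` in
# CNN_model_multitask.forward is the gradient-reversal weight, not the
# kd-tree neighbour count):
#
#   model.train()
#   for inputs, labels, he in training_generator:
#       inputs, labels = inputs.to(device), labels.to(device)
#       optimizer.zero_grad()
#       if (TYPE_AUGMENTATION == 'he'):
#           logits, outputs, output_domain = model(inputs, 'train', lambda_val)
#           loss = criterion(logits, labels) + \
#                  lambda_val * criterion_domain(output_domain, he.float().to(device))
#       else:
#           logits, outputs = model(inputs, None)
#           loss = criterion(logits, labels)
#       loss.backward()
#       optimizer.step()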
while (epoch < num_epochs and early_stop_cont < EARLY_STOP_NUM):

    # ... (training-loop body truncated in the source) ...

    valid_loss = evaluate_validation_set(validation_generator)

    if (best_loss_valid > valid_loss):
        print("=> Saving a new best model")
        print("previous loss TMA: " + str(best_loss_valid) + ", new loss function TMA: " + str(valid_loss))
        best_loss_valid = valid_loss
        torch.save(model, model_path)
        early_stop_cont = 0
    else:
        early_stop_cont = early_stop_cont + 1

    epoch = epoch + 1

print('Finished Training')
--------------------------------------------------------------------------------