├── Figures ├── Ablation.PNG ├── Cell_segmentation.png ├── Description ├── Method (Main).png ├── Method (Submodule).png ├── Skin + Cell.png └── Skin lesion_segmentation.png ├── Prepare_ISIC2017.py ├── Prepare_ISIC2018.py ├── Prepare_ph2.py ├── README.md ├── config_skin.yml ├── evaluate_skin.ipynb ├── evaluate_skin.py ├── loader.py ├── model ├── TransMUNet.py ├── __init__.py └── transformer.py ├── requirements.txt ├── results └── readme.txt ├── train_skin.ipynb ├── train_skin.py └── weights └── readme.txt /Figures/Ablation.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rezazad68/TMUnet/13ec19fb78dfd6889754d49666140541ee7ee8b6/Figures/Ablation.PNG -------------------------------------------------------------------------------- /Figures/Cell_segmentation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rezazad68/TMUnet/13ec19fb78dfd6889754d49666140541ee7ee8b6/Figures/Cell_segmentation.png -------------------------------------------------------------------------------- /Figures/Description: -------------------------------------------------------------------------------- 1 | Figures used in this repository. 2 | -------------------------------------------------------------------------------- /Figures/Method (Main).png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rezazad68/TMUnet/13ec19fb78dfd6889754d49666140541ee7ee8b6/Figures/Method (Main).png -------------------------------------------------------------------------------- /Figures/Method (Submodule).png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rezazad68/TMUnet/13ec19fb78dfd6889754d49666140541ee7ee8b6/Figures/Method (Submodule).png -------------------------------------------------------------------------------- /Figures/Skin + Cell.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rezazad68/TMUnet/13ec19fb78dfd6889754d49666140541ee7ee8b6/Figures/Skin + Cell.png -------------------------------------------------------------------------------- /Figures/Skin lesion_segmentation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rezazad68/TMUnet/13ec19fb78dfd6889754d49666140541ee7ee8b6/Figures/Skin lesion_segmentation.png -------------------------------------------------------------------------------- /Prepare_ISIC2017.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Sat Jun 8 18:15:43 2019 4 | @author: Reza Azad 5 | """ 6 | from __future__ import division 7 | import numpy as np 8 | import scipy.io as sio 9 | import scipy.misc as sc 10 | import glob 11 | 12 | # Parameters 13 | height = 256 14 | width = 256 15 | channels = 3 16 | 17 | ############################################################# Prepare ISIC 2017 data set ################################################# 18 | Dataset_add = '../dataset_isic17/' 19 | Tr_add = 'ISIC-2017_Training_Data' 20 | 21 | Tr_list = glob.glob(Dataset_add+ Tr_add+'/*.jpg') 22 | # It contains 2594 training samples 23 | Data_train_2017 = np.zeros([2000, height, width, channels]) 24 | Label_train_2017 = np.zeros([2000, height, width]) 25 | 26 | print('Reading ISIC 2017') 27 | for idx in range(len(Tr_list)): 28 | 
print(idx+1) 29 | img = sc.imread(Tr_list[idx]) 30 | 31 | 32 | img = np.double(sc.imresize(img, [height, width, channels], interp='bilinear', mode = 'RGB')) 33 | Data_train_2017[idx, :,:,:] = img 34 | 35 | 36 | b = Tr_list[idx] 37 | a = b[0:len(Dataset_add)] 38 | b = b[len(b)-16: len(b)-4] 39 | add = (a+ 'ISIC-2017_Training_Part1_GroundTruth/' + b +'_segmentation.png') 40 | img2 = sc.imread(add) 41 | img2 = np.double(sc.imresize(img2, [height, width], interp='bilinear')) 42 | Label_train_2017[idx, :,:] = img2 43 | 44 | print('Reading ISIC 2017 finished') 45 | 46 | ################################################################ Make the train and test sets ######################################## 47 | # We consider 1815 samples for training, 259 samples for validation and 520 samples for testing 48 | 49 | Train_img = Data_train_2017[0:1399,:,:,:] 50 | Validation_img = Data_train_2017[1399:1399+200,:,:,:] 51 | Test_img = Data_train_2017[1399+200:1999,:,:,:] 52 | 53 | Train_mask = Label_train_2017[0:1399,:,:] 54 | Validation_mask = Label_train_2017[1399:1399+200,:,:] 55 | Test_mask = Label_train_2017[1399+200:1999,:,:] 56 | 57 | 58 | np.save('data_train', Train_img) 59 | np.save('data_test' , Test_img) 60 | np.save('data_val' , Validation_img) 61 | 62 | np.save('mask_train', Train_mask) 63 | np.save('mask_test' , Test_mask) 64 | np.save('mask_val' , Validation_mask) 65 | 66 | 67 | -------------------------------------------------------------------------------- /Prepare_ISIC2018.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Sat Jun 8 18:15:43 2019 4 | @author: Reza Azad 5 | """ 6 | import h5py 7 | import numpy as np 8 | import scipy.io as sio 9 | import scipy.misc as sc 10 | import glob 11 | 12 | # Parameters 13 | height = 256 14 | width = 256 15 | channels = 3 16 | 17 | ############################################################# Prepare ISIC 2018 data set ################################################# 18 | Dataset_add = '/ISIC2018/' 19 | Tr_add = 'ISIC2018_Task1-2_Training_Input' 20 | 21 | Tr_list = glob.glob(Dataset_add+ Tr_add+'/*.jpg') 22 | # It contains 2594 training samples 23 | Data_train_2018 = np.zeros([2594, height, width, channels]) 24 | Label_train_2018 = np.zeros([2594, height, width]) 25 | 26 | print('Reading ISIC 2018') 27 | for idx in range(len(Tr_list)): 28 | print(idx+1) 29 | img = sc.imread(Tr_list[idx]) 30 | img = np.double(sc.imresize(img, [height, width, channels], interp='bilinear', mode = 'RGB')) 31 | Data_train_2018[idx, :,:,:] = img 32 | 33 | b = Tr_list[idx] 34 | a = b[0:len(Dataset_add)] 35 | b = b[len(b)-16: len(b)-4] 36 | add = (a+ 'ISIC2018_Task1_Training_GroundTruth/' + b +'_segmentation.png') 37 | img2 = sc.imread(add) 38 | img2 = np.double(sc.imresize(img2, [height, width], interp='bilinear')) 39 | Label_train_2018[idx, :,:] = img2 40 | 41 | print('Reading ISIC 2018 finished') 42 | 43 | ################################################################ Make the train and test sets ######################################## 44 | # We consider 1815 samples for training, 259 samples for validation and 520 samples for testing 45 | 46 | Train_img = Data_train_2018[0:1815,:,:,:] 47 | Validation_img = Data_train_2018[1815:1815+259,:,:,:] 48 | Test_img = Data_train_2018[1815+259:2594,:,:,:] 49 | 50 | Train_mask = Label_train_2018[0:1815,:,:] 51 | Validation_mask = Label_train_2018[1815:1815+259,:,:] 52 | Test_mask = Label_train_2018[1815+259:2594,:,:] 53 | 54 | 55 | 
np.save('data_train', Train_img) 56 | np.save('data_test' , Test_img) 57 | np.save('data_val' , Validation_img) 58 | 59 | np.save('mask_train', Train_mask) 60 | np.save('mask_test' , Test_mask) 61 | np.save('mask_val' , Validation_mask) 62 | 63 | 64 | -------------------------------------------------------------------------------- /Prepare_ph2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Sat Jun 8 18:15:43 2019 4 | @author: Reza Azad 5 | """ 6 | from __future__ import division 7 | import numpy as np 8 | import scipy.io as sio 9 | import scipy.misc as sc 10 | import glob 11 | import random 12 | #from sklearn.model_selection import train_test_split 13 | 14 | # Parameters 15 | height = 256 16 | width = 256 17 | channels = 3 18 | 19 | ############################################################# Prepare ph2 data set ################################################# 20 | Dataset_add = 'data/' 21 | Tr_add = 'lesions/' 22 | 23 | Tr_list = glob.glob(Dataset_add+ Tr_add+'/*.bmp') 24 | # It contains 2594 training samples 25 | Data_train = np.zeros([200, height, width, channels]) 26 | Label_train = np.zeros([200, height, width]) 27 | 28 | print('Reading Ph2') 29 | 30 | random.shuffle(Tr_list) 31 | 32 | for idx in range(len(Tr_list)): 33 | print(idx+1) 34 | print(Tr_list[idx]) 35 | img = sc.imread(Tr_list[idx]) 36 | 37 | 38 | img = np.double(sc.imresize(img, [height, width, channels], interp='bilinear', mode = 'RGB')) 39 | Data_train[idx, :,:,:] = img 40 | 41 | 42 | b = Tr_list[idx] 43 | 44 | #print(b) 45 | 46 | a = b[0:len(Dataset_add)] 47 | b = b[len(b)-10: len(b)-4] 48 | 49 | # print(a) 50 | # print(b) 51 | 52 | 53 | add = (a+ 'masks/' + b +'_lesion.bmp') 54 | img2 = sc.imread(add) 55 | img2 = np.double(sc.imresize(img2, [height, width], interp='bilinear')) 56 | Label_train[idx, :,:] = img2 57 | 58 | print('Reading Ph2 finished') 59 | 60 | ################################################################ Make the train and test sets ######################################## 61 | # We consider 80 samples for training, 20 samples for validation and 100 samples for testing 62 | 63 | 64 | Train_img = Data_train[0:80,:,:,:] 65 | Validation_img = Data_train[80:100,:,:,:] 66 | Test_img = Data_train[100:200,:,:,:] 67 | 68 | Train_mask = Label_train[0:80,:,:] 69 | Validation_mask = Label_train[80:100,:,:] 70 | Test_mask = Label_train[100:200,:,:] 71 | 72 | 73 | np.save('data_train', Train_img) 74 | np.save('data_test' , Test_img) 75 | np.save('data_val' , Validation_img) 76 | 77 | np.save('mask_train', Train_mask) 78 | np.save('mask_test' , Test_mask) 79 | np.save('mask_val' , Validation_mask) 80 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [Contextual Attention Network: Transformer Meets U-Net](https://arxiv.org/abs/2203.01932) 2 | 3 | Contexual attention network for medical image segmentation with state of the art results on skin lesion segmentation, multiple myeloma cell segmentation. This method incorpotrates the transformer module into a U-Net structure so as to concomitantly capture long-range dependency along with resplendent local informations. 4 | If this code helps with your research please consider citing the following paper: 5 |
6 | > [R. Azad](https://scholar.google.com/citations?hl=en&user=Qb5ildMAAAAJ&view_op=list_works&sortby=pubdate), [Moein Heidari](https://scholar.google.com/citations?user=mir8D5UAAAAJ&hl=en&oi=sra), [Yuli Wu](https://scholar.google.com/citations?user=qlun0AgAAAAJ) and [Dorit Merhof](https://scholar.google.com/citations?user=JH5HObAAAAAJ&sortby=pubdate), "Contextual Attention Network: Transformer Meets U-Net", download [link](https://arxiv.org/abs/2203.01932). 7 | 8 | 9 | ```python 10 | @article{reza2022contextual, 11 | title={Contextual Attention Network: Transformer Meets U-Net}, 12 | author={Reza, Azad and Moein, Heidari and Yuli, Wu and Dorit, Merhof}, 13 | journal={arXiv preprint arXiv:2203.01932}, 14 | year={2022} 15 | } 16 | 17 | ``` 18 | 19 | #### Please consider starring us if you found it useful. Thanks! 20 | 21 | ## Updates 22 | - February 27, 2022: First release (complete implementation for [Skin Lesion Segmentation on ISIC 2017](https://challenge.isic-archive.com/landing/2017/), [Skin Lesion Segmentation on ISIC 2018](https://challenge2018.isic-archive.com/), [Skin Lesion Segmentation on PH2](https://www.fc.up.pt/addi/ph2%20database.html) and the [Multiple Myeloma Cell Segmentation (SegPC 2021)](https://www.kaggle.com/sbilab/segpc2021dataset) dataset added.) 23 | 24 | This code has been implemented in Python using the PyTorch library and tested on Ubuntu, though it should be compatible with related environments. The following environment and libraries are needed to run the code: 25 | 26 | - Python 3 27 | - PyTorch 28 | 29 | 30 | ## Run Demo 31 | To train the deep model and evaluate it on each dataset, follow the steps below:
32 | 1- Download the ISIC 2018 training dataset from [this](https://challenge.isic-archive.com/data) link and extract both the training input and the ground-truth folders inside `dataset_isic18`.
33 | 2- Run `Prepare_ISIC2018.py` for data preparation and splitting the data into train, validation and test sets.
34 | 3- Run `train_skin.py` to train the model on the training and validation sets. The model will be trained for the number of epochs specified in `config_skin.yml` (50 in the provided config) and the best weights on the validation set will be saved.
35 | 4- For performance calculation and producing segmentation results, run `evaluate_skin.py`. It will report the performance measures and save the related visual results in the `results` folder. A minimal inference sketch using the saved checkpoint is shown below.
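The sketch below is not one of the repository scripts; it is a minimal, self-contained example of how a trained checkpoint can be restored for single-batch inference, mirroring the loading logic in `evaluate_skin.py`. The 0.5 threshold is an assumption for illustration (the evaluation script itself uses 0.43 followed by morphological post-processing).

```python
# Minimal inference sketch (illustrative, not a repo script); assumes the processed .npy files
# and the checkpoint referenced in config_skin.yml already exist.
import torch
import yaml
from torch.utils.data import DataLoader

from loader import isic_loader
from model.TransMUNet import TransMUNet

config = yaml.load(open('./config_skin.yml'), Loader=yaml.FullLoader)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Build the network and restore the best weights saved during training
net = TransMUNet(n_classes=int(config['number_classes'])).to(device)
net.load_state_dict(torch.load(config['saved_model'], map_location='cpu')['model_weights'])
net.eval()

# One test sample, preprocessed exactly as in evaluate_skin.py
test_loader = DataLoader(isic_loader(path_Data=config['path_to_data'], train=False, Test=True),
                         batch_size=1, shuffle=False)
batch = next(iter(test_loader))

with torch.no_grad():
    prob = net(batch['image'].to(device, dtype=torch.float))  # sigmoid probability map, (1, 1, 256, 256)

mask = (prob.cpu().numpy()[0, 0] >= 0.5)  # simple binarization; evaluate_skin.py uses 0.43 + morphology
print('predicted lesion pixels:', int(mask.sum()))
```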
36 | 37 | **Notice:** 38 | For training and evaluating on ISIC 2017 and ph2, follow the steps below: 39 | 40 | **ISIC 2017**- Download the ISIC 2017 training dataset from [this](https://challenge.isic-archive.com/data) link and extract both the training input and the ground-truth folders inside `dataset_isic17`.
Then run `Prepare_ISIC2017.py` for data preparation and splitting the data into train, validation and test sets.
41 | **ph2**- Download the ph2 dataset from [this](https://www.dropbox.com/s/k88qukc20ljnbuo/PH2Dataset.rar) link, extract it, and then run `Prepare_ph2.py` for data preparation and splitting the data into train, validation and test sets. A sketch of the fine-tuning mechanism used for ph2 is shown below.
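The ph2 fine-tuning relies on the `pretrained` flag in `config_skin.yml`: when it is set to 1, the training script restores the checkpoint given by `saved_model` before training continues on whatever `path_to_data` points to (see the corresponding cell in `train_skin.ipynb`). The snippet below is only a simplified, illustrative sketch of that mechanism; it keeps just the main BCE mask loss, whereas the actual training script also uses boundary and region supervision terms.

```python
# Simplified fine-tuning sketch (illustrative): load ISIC 2017 weights, then continue training on ph2.
# Assumes config_skin.yml points path_to_data at the prepared ph2 .npy files and saved_model at the
# ISIC 2017 checkpoint; only the main BCE mask loss is shown here.
import torch
import yaml
from torch.utils.data import DataLoader

from loader import isic_loader
from model.TransMUNet import TransMUNet

config = yaml.load(open('./config_skin.yml'), Loader=yaml.FullLoader)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

net = TransMUNet(n_classes=int(config['number_classes'])).to(device)
net.load_state_dict(torch.load(config['saved_model'], map_location='cpu')['model_weights'])  # ISIC 2017 weights

train_loader = DataLoader(isic_loader(path_Data=config['path_to_data'], train=True),
                          batch_size=int(config['batch_size_tr']), shuffle=True)
optimizer = torch.optim.Adam(net.parameters(), lr=float(config['lr']))
criterion = torch.nn.BCELoss()

net.train()
for batch in train_loader:  # one pass over the ph2 training split
    img = batch['image'].to(device, dtype=torch.float)
    msk = batch['mask'].to(device, dtype=torch.float)
    optimizer.zero_grad()
    loss = criterion(net(img), msk)  # mask term only in this sketch
    loss.backward()
    optimizer.step()
```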
42 | Follow steps 3 and 4 for model training and performance estimation. For the ph2 dataset you need to first train the model on the ISIC 2017 dataset and then fine-tune the trained model on ph2 (as sketched above). 43 | 44 | 45 | 46 | ## Quick Overview 47 | ![Diagram of the proposed method](https://github.com/rezazad68/TMUnet/blob/main/Figures/Method%20(Main).png) 48 | 49 | ### Perceptual visualization of the proposed Contextual Attention module. 50 | ![Diagram of the proposed method](https://github.com/rezazad68/TMUnet/blob/main/Figures/Method%20(Submodule).png) 51 | 52 | 53 | ## Results 54 | To evaluate the performance of the proposed method, two challenging tasks in medical image segmentation have been considered. The results of the proposed approach are illustrated below; a condensed sketch of how the reported metrics are computed in `evaluate_skin.py` follows. 55 |
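The skin lesion metrics reported in the tables (Dice/F1, sensitivity, specificity, accuracy) are computed by `evaluate_skin.py` from the binarized predictions and ground-truth masks. The block below is a condensed, self-contained sketch of that computation, with tiny dummy arrays standing in for the real flattened masks:

```python
# Condensed from evaluate_skin.py: Dice/F1, accuracy, sensitivity and specificity
# from binarized predictions vs. ground truth (dummy arrays stand in for the real masks).
import numpy as np
from sklearn.metrics import confusion_matrix, f1_score

predictions = np.array([0.1, 0.9, 0.8, 0.2, 0.7, 0.4])  # network probabilities (flattened)
gt          = np.array([0.0, 1.0, 1.0, 0.0, 0.0, 1.0])  # ground-truth mask values (flattened)

y_pred = np.where(predictions > 0.47, 1, 0)  # same threshold as the script
y_true = np.where(gt > 0.5, 1, 0)

dice = f1_score(y_true, y_pred, average='binary')  # Dice coincides with F1 for binary masks
tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
print('Dice:', dice,
      'Accuracy:', (tp + tn) / (tp + tn + fp + fn),
      'Sensitivity:', tp / (tp + fn),
      'Specificity:', tn / (tn + fp))
```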
56 | #### Task 1: Skin Lesion Segmentation 57 | 58 | 59 | #### Performance Comparison on Skin Lesion Segmentation 60 | In order to compare the proposed method with state-of-the-art approaches on skin lesion segmentation, we considered the ISIC 2017 dataset. 61 | 62 | Methods (On ISIC 2017) |Dice-Score | Sensitivity| Specificity| Accuracy 63 | ------------ | -------------|----|-----------------|---- 64 | Ronneberger et al. [U-net](https://arxiv.org/abs/1505.04597) |0.8159 |0.8172 |0.9680 |0.9164 65 | Oktay et al. [Attention U-net](https://arxiv.org/abs/1804.03999) |0.8082 |0.7998 |0.9776 |0.9145 66 | Lei et al. [DAGAN](https://www.sciencedirect.com/science/article/abs/pii/S1361841520300803) |0.8425 |0.8363 |0.9716 |0.9304 67 | Chen et al. [TransU-net](https://arxiv.org/abs/2102.04306) |0.8123 |0.8263 |0.9577 |0.9207 68 | Asadi et al. [MCGU-Net](https://arxiv.org/abs/2003.05056) |0.8927 | 0.8502 |**0.9855** |0.9570 69 | Valanarasu et al. [MedT](https://arxiv.org/abs/2102.10662) |0.8037 |0.8064 |0.9546 |0.9090 70 | Wu et al. [FAT-Net](https://www.sciencedirect.com/science/article/abs/pii/S1361841521003728) |0.8500 |0.8392 |0.9725 |0.9326 71 | Azad et al. [Proposed TMUnet](https://arxiv.org/abs/2203.01932) |**0.9164** | **0.9128** |0.9789 |**0.9660** 72 | ### For more results on the ISIC 2018 and PH2 datasets, please refer to [the paper](https://arxiv.org/abs/2203.01932) 73 | 74 | 75 | #### Skin lesion segmentation results on test data 76 | 77 | ![Skin lesion segmentation result](https://github.com/rezazad68/TMUnet/blob/main/Figures/Skin%20lesion_segmentation.png) 78 | (a) Input images. (b) Ground truth. (c) [U-net](https://arxiv.org/abs/1505.04597). (d) [Gated Axial-Attention](https://arxiv.org/abs/2102.10662). (e) Proposed method without the contextual attention module and (f) Proposed method. 79 | 80 | 81 | ## Multiple Myeloma Cell Segmentation 82 | 83 | #### Performance Evaluation on the Multiple Myeloma Cell Segmentation task 84 | 85 | Methods | mIOU 86 | ------------ | ------------- 87 | [Frequency recalibration U-Net](https://openaccess.thecvf.com/content/ICCV2021W/CVAMD/papers/Azad_Deep_Frequency_Re-Calibration_U-Net_for_Medical_Image_Segmentation_ICCVW_2021_paper.pdf) |0.9392 88 | [XLAB Insights](https://arxiv.org/abs/2105.06238) |0.9360 89 | [DSC-IITISM](https://arxiv.org/abs/2105.06238) |0.9356 90 | [Multi-scale attention deeplabv3+](https://arxiv.org/abs/2105.06238) |0.9065 91 | [U-Net](https://arxiv.org/abs/1505.04597) |0.7665 92 | [Baseline](https://arxiv.org/abs/2203.01932) |0.9172 93 | [Proposed](https://arxiv.org/abs/2203.01932) |**0.9395** 94 | 95 | 96 | 97 | #### Multiple Myeloma Cell Segmentation results 98 | 99 | ![Multiple Myeloma Cell Segmentation result](https://github.com/rezazad68/TMUnet/blob/main/Figures/Cell_segmentation.png) 100 | 101 | ### Model weights 102 | You can download the learned weights for each dataset from the following table. 103 | 104 | Dataset |Learned weights 105 | ------------ | ------------- 106 | [ISIC 2018]() |[TMUnet](https://drive.google.com/file/d/1EU4stQUtUt6bcSoWswBYpfTZd53sVAJy/view?usp=sharing) 107 | [ISIC 2017]() |[TMUnet](https://drive.google.com/file/d/1gEb8juWB2JjxAws91D3S0wxnrVwuMRZo/view?usp=sharing) 108 | [Ph2]() | [TMUnet](https://drive.google.com/file/d/1soZ6UYhZk7r5-klflJHZxtbdH6pKi7t6/view?usp=sharing) 109 | 110 | 111 | 112 | ### Query 113 | All implementations are done by Reza Azad and Moein Heidari. For any query, please contact us for more information.
114 | 115 | ```python 116 | rezazad68@gmail.com 117 | moeinheidari7829@gmail.com 118 | 119 | ``` 120 | 121 | -------------------------------------------------------------------------------- /config_skin.yml: -------------------------------------------------------------------------------- 1 | ## Config file 2 | lr: 1e-4 # Initial learning rate 3 | epochs: 50 # Number of epochs to train the model 4 | number_classes: 1 # Number of classes in the target dataset 5 | batch_size_tr: 4 # Batch size for train 6 | batch_size_va: 1 # Batch size for validationn 7 | saved_model: './weights/weights_isic18.model' # leave '' will use a default value 8 | path_to_data: './processed_data/isic18/' #path to dataset 9 | patience: 10 # number of epochs without improvement to do before finishing training early. 10 | save_result: './results/' # path to save results 11 | progress_p: 0.1 # value between 0-1 shows the number of time we need to report training progress in each epoch 12 | pretrained: 0 # load the previously trained weight or no value should either 1 or 0 13 | -------------------------------------------------------------------------------- /evaluate_skin.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | # In[1]: 5 | 6 | 7 | from __future__ import division 8 | import os 9 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 10 | import torch 11 | import torch.optim as optim 12 | from torch.utils.data import DataLoader 13 | from loader import * 14 | import glob 15 | import numpy as np 16 | import copy 17 | import yaml 18 | from sklearn.metrics import f1_score 19 | from tqdm import tqdm 20 | from model.TransMUNet import TransMUNet 21 | from sklearn.metrics import confusion_matrix 22 | from sklearn.metrics import f1_score 23 | from matplotlib import pyplot as plt 24 | get_ipython().run_line_magic('matplotlib', 'inline') 25 | from scipy.ndimage.morphology import binary_fill_holes, binary_opening 26 | 27 | 28 | # In[2]: 29 | 30 | 31 | ## Hyper parameters 32 | config = yaml.load(open('./config_skin.yml'), Loader=yaml.FullLoader) 33 | number_classes = int(config['number_classes']) 34 | input_channels = 3 35 | best_val_loss = np.inf 36 | patience = 0 37 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 38 | 39 | data_path = config['path_to_data'] 40 | 41 | test_dataset = isic_loader(path_Data = data_path, train = False, Test = True) 42 | test_loader = DataLoader(test_dataset, batch_size = 1, shuffle= True) 43 | 44 | 45 | # In[3]: 46 | 47 | 48 | Net = TransMUNet(n_classes = number_classes) 49 | Net = Net.to(device) 50 | Net.load_state_dict(torch.load(config['saved_model'], map_location='cpu')['model_weights']) 51 | 52 | 53 | # ## Quntitative performance 54 | 55 | # In[4]: 56 | 57 | 58 | predictions = [] 59 | gt = [] 60 | 61 | with torch.no_grad(): 62 | print('val_mode') 63 | val_loss = 0 64 | Net.eval() 65 | for itter, batch in tqdm(enumerate(test_loader)): 66 | img = batch['image'].to(device, dtype=torch.float) 67 | msk = batch['mask'] 68 | msk_pred = Net(img) 69 | 70 | gt.append(msk.numpy()[0, 0]) 71 | msk_pred = msk_pred.cpu().detach().numpy()[0, 0] 72 | msk_pred = np.where(msk_pred>=0.43, 1, 0) 73 | msk_pred = binary_opening(msk_pred, structure=np.ones((6,6))).astype(msk_pred.dtype) 74 | msk_pred = binary_fill_holes(msk_pred, structure=np.ones((6,6))).astype(msk_pred.dtype) 75 | predictions.append(msk_pred) 76 | 77 | 78 | 79 | predictions = np.array(predictions) 80 | gt = np.array(gt) 81 | 82 | y_scores = 
predictions.reshape(-1) 83 | y_true = gt.reshape(-1) 84 | 85 | y_scores2 = np.where(y_scores>0.47, 1, 0) 86 | y_true2 = np.where(y_true>0.5, 1, 0) 87 | 88 | #F1 score 89 | F1_score = f1_score(y_true2, y_scores2, labels=None, average='binary', sample_weight=None) 90 | print ("\nF1 score (F-measure) or DSC: " +str(F1_score)) 91 | confusion = confusion_matrix(np.int32(y_true), y_scores2) 92 | print (confusion) 93 | accuracy = 0 94 | if float(np.sum(confusion))!=0: 95 | accuracy = float(confusion[0,0]+confusion[1,1])/float(np.sum(confusion)) 96 | print ("Accuracy: " +str(accuracy)) 97 | specificity = 0 98 | if float(confusion[0,0]+confusion[0,1])!=0: 99 | specificity = float(confusion[0,0])/float(confusion[0,0]+confusion[0,1]) 100 | print ("Specificity: " +str(specificity)) 101 | sensitivity = 0 102 | if float(confusion[1,1]+confusion[1,0])!=0: 103 | sensitivity = float(confusion[1,1])/float(confusion[1,1]+confusion[1,0]) 104 | print ("Sensitivity: " +str(sensitivity)) 105 | 106 | 107 | 108 | # ## Visualization section 109 | 110 | # In[5]: 111 | 112 | 113 | def save_sample(img, msk, msk_pred, th=0.3, name=''): 114 | img2 = img.detach().cpu().numpy()[0] 115 | img2 = np.einsum('kij->ijk', img2) 116 | msk2 = msk.detach().cpu().numpy()[0,0] 117 | mskp = msk_pred.detach().cpu().numpy()[0,0] 118 | msk2 = np.where(msk2>0.5, 1., 0) 119 | mskp = np.where(mskp>=th, 1., 0) 120 | 121 | plt.figure(figsize=(7,15)) 122 | 123 | plt.subplot(3,1,1) 124 | plt.imshow(img2/255.) 125 | plt.axis('off') 126 | 127 | plt.subplot(3,1,2) 128 | plt.imshow(msk2*255, cmap= 'gray') 129 | plt.axis('off') 130 | 131 | plt.subplot(3,1,3) 132 | plt.imshow(mskp*255, cmap = 'gray') 133 | plt.axis('off') 134 | 135 | plt.savefig('./results/'+name+'.png') 136 | 137 | 138 | # In[6]: 139 | 140 | 141 | predictions = [] 142 | gt = [] 143 | 144 | N = 5 ## Number of samples to visualize 145 | with torch.no_grad(): 146 | print('val_mode') 147 | val_loss = 0 148 | Net.eval() 149 | for itter, batch in tqdm(enumerate(test_loader)): 150 | img = batch['image'].to(device, dtype=torch.float) 151 | msk = batch['mask'] 152 | msk_pred = Net(img) 153 | 154 | gt.append(msk.numpy()) 155 | predictions.append(msk_pred.cpu().detach().numpy()) 156 | save_sample(img, msk, msk_pred, th=0.5, name=str(itter+1)) 157 | if itter+1==N: 158 | break 159 | 160 | 161 | -------------------------------------------------------------------------------- /loader.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset, DataLoader 2 | import torch 3 | import numpy as np 4 | import random 5 | from einops.layers.torch import Rearrange 6 | from scipy.ndimage.morphology import binary_dilation 7 | 8 | # ===== normalize over the dataset 9 | def dataset_normalized(imgs): 10 | imgs_normalized = np.empty(imgs.shape) 11 | imgs_std = np.std(imgs) 12 | imgs_mean = np.mean(imgs) 13 | imgs_normalized = (imgs-imgs_mean)/imgs_std 14 | for i in range(imgs.shape[0]): 15 | imgs_normalized[i] = ((imgs_normalized[i] - np.min(imgs_normalized[i])) / (np.max(imgs_normalized[i])-np.min(imgs_normalized[i])))*255 16 | return imgs_normalized 17 | 18 | 19 | class weak_annotation(torch.nn.Module): 20 | def __init__(self, patch_size = 16, img_size = 256): 21 | super().__init__() 22 | self.arranger = Rearrange('c (ph h) (pw w) -> c (ph pw) h w', c=1, h=patch_size, ph=img_size//patch_size, w=patch_size, pw=img_size//patch_size) 23 | def forward(self, x): 24 | x = self.arranger(x) 25 | x = torch.sum(x, dim = [-2, -1]) 26 | x = x/x.max() 27 | 
return x 28 | 29 | def Bextraction(img): 30 | img = img[0].numpy() 31 | img2 = binary_dilation(img, structure=np.ones((7,7))).astype(img.dtype) 32 | img3 = img2 - img 33 | img3 = np.expand_dims(img3, axis = 0) 34 | return torch.tensor(img3.copy()) 35 | 36 | ## Temporary 37 | class isic_loader(Dataset): 38 | """ dataset class for Brats datasets 39 | """ 40 | def __init__(self, path_Data, train = True, Test = False): 41 | super(isic_loader, self) 42 | self.train = train 43 | if train: 44 | self.data = np.load(path_Data+'data_train.npy') 45 | self.mask = np.load(path_Data+'mask_train.npy') 46 | else: 47 | if Test: 48 | self.data = np.load(path_Data+'data_test.npy') 49 | self.mask = np.load(path_Data+'mask_test.npy') 50 | else: 51 | self.data = np.load(path_Data+'data_val.npy') 52 | self.mask = np.load(path_Data+'mask_val.npy') 53 | 54 | 55 | self.data = dataset_normalized(self.data) 56 | self.mask = np.expand_dims(self.mask, axis=3) 57 | self.mask = self.mask /255. 58 | self.weak_annotation = weak_annotation(patch_size = 16, img_size = 256) 59 | 60 | def __getitem__(self, indx): 61 | img = self.data[indx] 62 | seg = self.mask[indx] 63 | if self.train: 64 | img, seg = self.apply_augmentation(img, seg) 65 | 66 | seg = torch.tensor(seg.copy()) 67 | img = torch.tensor(img.copy()) 68 | img = img.permute( 2, 0, 1) 69 | seg = seg.permute( 2, 0, 1) 70 | 71 | weak_ann = self.weak_annotation(seg) 72 | boundary = Bextraction(seg) 73 | 74 | return {'image': img, 75 | 'weak_ann': weak_ann, 76 | 'boundary': boundary, 77 | 'mask' : seg} 78 | 79 | def apply_augmentation(self, img, seg): 80 | if random.random() < 0.5: 81 | img = np.flip(img, axis=1) 82 | seg = np.flip(seg, axis=1) 83 | return img, seg 84 | 85 | def __len__(self): 86 | return len(self.data) 87 | -------------------------------------------------------------------------------- /model/TransMUNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision 4 | resnet = torchvision.models.resnet.resnet50(pretrained=True) 5 | from .transformer import ViT 6 | import cv2 7 | import numpy as np 8 | 9 | 10 | class ConvBlock(nn.Module): 11 | """ 12 | Helper module that consists of a Conv -> BN -> ReLU 13 | """ 14 | 15 | def __init__(self, in_channels, out_channels, padding=1, kernel_size=3, stride=1, with_nonlinearity=True): 16 | super().__init__() 17 | self.conv = nn.Conv2d(in_channels, out_channels, padding=padding, kernel_size=kernel_size, stride=stride) 18 | self.bn = nn.BatchNorm2d(out_channels) 19 | self.relu = nn.ReLU() 20 | self.with_nonlinearity = with_nonlinearity 21 | 22 | def forward(self, x): 23 | x = self.conv(x) 24 | x = self.bn(x) 25 | if self.with_nonlinearity: 26 | x = self.relu(x) 27 | return x 28 | 29 | 30 | class Bridge(nn.Module): 31 | """ 32 | This is the middle layer of the UNet which just consists of some 33 | """ 34 | 35 | def __init__(self, in_channels, out_channels): 36 | super().__init__() 37 | self.bridge = nn.Sequential( 38 | ConvBlock(in_channels, out_channels), 39 | ConvBlock(out_channels, out_channels) 40 | ) 41 | 42 | def forward(self, x): 43 | return self.bridge(x) 44 | 45 | 46 | class UpBlockForUNetWithResNet50(nn.Module): 47 | """ 48 | Up block that encapsulates one up-sampling step which consists of Upsample -> ConvBlock -> ConvBlock 49 | """ 50 | 51 | def __init__(self, in_channels, out_channels, up_conv_in_channels=None, up_conv_out_channels=None, 52 | upsampling_method="conv_transpose"): 53 | super().__init__() 54 | 55 | if 
up_conv_in_channels == None: 56 | up_conv_in_channels = in_channels 57 | if up_conv_out_channels == None: 58 | up_conv_out_channels = out_channels 59 | 60 | if upsampling_method == "conv_transpose": 61 | self.upsample = nn.ConvTranspose2d(up_conv_in_channels, up_conv_out_channels, kernel_size=2, stride=2) 62 | elif upsampling_method == "bilinear": 63 | self.upsample = nn.Sequential( 64 | nn.Upsample(mode='bilinear', scale_factor=2), 65 | nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1) 66 | ) 67 | self.conv_block_1 = ConvBlock(in_channels, out_channels) 68 | self.conv_block_2 = ConvBlock(out_channels, out_channels) 69 | 70 | def forward(self, up_x, down_x): 71 | """ 72 | 73 | :param up_x: this is the output from the previous up block 74 | :param down_x: this is the output from the down block 75 | :return: upsampled feature map 76 | """ 77 | x = self.upsample(up_x) 78 | x = torch.cat([x, down_x], 1) 79 | x = self.conv_block_1(x) 80 | x = self.conv_block_2(x) 81 | return x 82 | 83 | 84 | class SE_Block(nn.Module): 85 | def __init__(self, c, r=16): 86 | super().__init__() 87 | self.squeeze = nn.AdaptiveAvgPool2d(1) 88 | self.excitation = nn.Sequential( 89 | nn.Linear(c, c // r, bias=False), 90 | nn.ReLU(inplace=True), 91 | nn.Linear(c // r, c, bias=False), 92 | nn.Sigmoid() 93 | ) 94 | 95 | def forward(self, x): 96 | bs, c, _, _ = x.shape 97 | y = self.squeeze(x).view(bs, c) 98 | y = self.excitation(y).view(bs, c, 1, 1) 99 | x = x * y.expand_as(x) 100 | return y 101 | 102 | 103 | class TransMUNet(nn.Module): 104 | DEPTH = 6 105 | 106 | def __init__(self, n_classes=2, 107 | patch_size: int = 16, 108 | emb_size: int = 512, 109 | img_size: int = 256, 110 | n_channels = 3, 111 | depth: int = 4, 112 | n_regions: int = (256//16)**2, 113 | output_ch: int = 1, 114 | bilinear=True): 115 | super().__init__() 116 | self.n_classes = n_classes 117 | self.transformer = ViT(in_channels= n_channels, 118 | patch_size=patch_size, 119 | emb_size=emb_size, 120 | img_size=img_size, 121 | depth=depth, 122 | n_regions=n_regions) 123 | resnet = torchvision.models.resnet.resnet50(pretrained=True) 124 | down_blocks = [] 125 | up_blocks = [] 126 | self.input_block = nn.Sequential(*list(resnet.children()))[:3] 127 | self.input_pool = list(resnet.children())[3] 128 | for bottleneck in list(resnet.children()): 129 | if isinstance(bottleneck, nn.Sequential): 130 | down_blocks.append(bottleneck) 131 | self.down_blocks = nn.ModuleList(down_blocks) 132 | self.bridge = Bridge(2048, 2048) 133 | up_blocks.append(UpBlockForUNetWithResNet50(2048, 1024)) 134 | up_blocks.append(UpBlockForUNetWithResNet50(1024, 512)) 135 | up_blocks.append(UpBlockForUNetWithResNet50(512, 256)) 136 | up_blocks.append(UpBlockForUNetWithResNet50(in_channels=128 + 64, out_channels=128, 137 | up_conv_in_channels=256, up_conv_out_channels=128)) 138 | up_blocks.append(UpBlockForUNetWithResNet50(in_channels=64 + 3, out_channels=64, 139 | up_conv_in_channels=128, up_conv_out_channels=64)) 140 | 141 | self.up_blocks = nn.ModuleList(up_blocks) 142 | 143 | self.out = nn.Conv2d(128, n_classes, kernel_size=1, stride=1) 144 | 145 | self.boundary = nn.Sequential(nn.Conv2d(64, 32, kernel_size=1, stride=1), 146 | nn.BatchNorm2d(32), nn.ReLU(inplace=True), 147 | nn.Conv2d(32, 1, kernel_size=1, stride=1, bias=False), 148 | nn.Sigmoid()) 149 | 150 | self.se = SE_Block(c =64) 151 | 152 | def forward(self, x, with_additional=False): 153 | [global_contexual, regional_distribution, region_coeff] = self.transformer(x) 154 | 155 | pre_pools = dict() 156 | 
pre_pools[f"layer_0"] = x 157 | x = self.input_block(x) 158 | pre_pools[f"layer_1"] = x 159 | x = self.input_pool(x) 160 | 161 | for i, block in enumerate(self.down_blocks, 2): 162 | x = block(x) 163 | if i == (TransMUNet.DEPTH - 1): 164 | continue 165 | pre_pools[f"layer_{i}"] = x 166 | 167 | x = self.bridge(x) 168 | 169 | for i, block in enumerate(self.up_blocks, 1): 170 | key = f"layer_{TransMUNet.DEPTH - 1 - i}" 171 | x = block(x, pre_pools[key]) 172 | 173 | B_out = self.boundary(x) 174 | B = B_out.repeat_interleave(int(x.shape[1]), dim=1) 175 | x = self.se(x) 176 | x = x+B 177 | att = regional_distribution.repeat_interleave(int(x.shape[1]), dim=1) 178 | x = x*att 179 | x = torch.cat((x, global_contexual), dim=1) 180 | x = self.out(x) 181 | del pre_pools 182 | x = torch.sigmoid(x) 183 | if with_additional: 184 | return x, B_out, region_coeff 185 | else: 186 | return x 187 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | from .TransMUNet import * 2 | -------------------------------------------------------------------------------- /model/transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | from torch import Tensor 5 | from einops import rearrange, reduce, repeat 6 | from einops.layers.torch import Rearrange, Reduce 7 | 8 | 9 | class DoubleConv(nn.Module): 10 | def __init__(self, in_channels, out_channels, mid_channels=None): 11 | super().__init__() 12 | if not mid_channels: 13 | mid_channels = out_channels 14 | self.double_conv = nn.Sequential( 15 | nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1), 16 | nn.BatchNorm2d(mid_channels), 17 | nn.ReLU(inplace=True), 18 | nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1), 19 | nn.BatchNorm2d(out_channels), 20 | nn.ReLU(inplace=True), 21 | nn.MaxPool2d(2) 22 | ) 23 | 24 | def forward(self, x): 25 | return self.double_conv(x) 26 | 27 | class Encoder_patch(nn.Module): 28 | def __init__(self, n_channels, emb_size= 512, bilinear=True): 29 | super(Encoder_patch, self).__init__() 30 | self.n_channels = n_channels 31 | self.emb_size = emb_size 32 | self.bilinear = bilinear 33 | 34 | self.conv1 = DoubleConv(n_channels, 128) 35 | self.conv2 = DoubleConv(128, 256) 36 | self.conv3 = DoubleConv(256, emb_size) 37 | 38 | def forward(self, x): 39 | x = self.conv1(x) 40 | x = self.conv2(x) 41 | x = self.conv3(x) 42 | x = torch.flatten(torch.nn.functional.adaptive_avg_pool2d(x, 1), start_dim = 1) 43 | return x 44 | 45 | class PatchEmbedding(nn.Module): 46 | def __init__(self, in_channels: int = 3, patch_size: int = 16, emb_size: int = 768, img_size: int = 224): 47 | self.patch_size = patch_size 48 | super().__init__() 49 | # self.encoder = Encoder_patch(n_channels = in_channels, emb_size= emb_size) 50 | self.projection = nn.Sequential( 51 | Rearrange('b c (ph h) (pw w) -> b c (ph pw) h w', c=in_channels, h=patch_size, ph=img_size//patch_size, w=patch_size, pw=img_size//patch_size), 52 | Rearrange('b c p h w -> (b p) c h w'), 53 | Encoder_patch(n_channels = in_channels, emb_size= emb_size), 54 | Rearrange('(b p) d-> b p d', p = (img_size//patch_size)**2), 55 | ) 56 | self.cls_token = nn.Parameter(torch.randn(1,1, emb_size)) 57 | self.positions = nn.Parameter(torch.randn((img_size // patch_size) **2 + 1, emb_size)) 58 | 59 | 60 | def forward(self, x: Tensor) -> Tensor: 61 | b, 
_, _, _ = x.shape 62 | x = self.projection(x) 63 | cls_tokens = repeat(self.cls_token, '() n e -> b n e', b=b) 64 | # prepend the cls token to the input 65 | x = torch.cat([cls_tokens, x], dim=1) 66 | # add position embedding 67 | x += self.positions 68 | return x 69 | 70 | class MultiHeadAttention(nn.Module): 71 | def __init__(self, emb_size: int = 768, num_heads: int = 8, dropout: float = 0): 72 | super().__init__() 73 | self.emb_size = emb_size 74 | self.num_heads = num_heads 75 | # fuse the queries, keys and values in one matrix 76 | self.qkv = nn.Linear(emb_size, emb_size * 3) 77 | self.att_drop = nn.Dropout(dropout) 78 | self.projection = nn.Linear(emb_size, emb_size) 79 | 80 | def forward(self, x : Tensor, mask: Tensor = None) -> Tensor: 81 | # split keys, queries and values in num_heads 82 | qkv = rearrange(self.qkv(x), "b n (h d qkv) -> (qkv) b h n d", h=self.num_heads, qkv=3) 83 | queries, keys, values = qkv[0], qkv[1], qkv[2] 84 | # sum up over the last axis 85 | energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys) # batch, num_heads, query_len, key_len 86 | if mask is not None: 87 | fill_value = torch.finfo(torch.float32).min 88 | energy.mask_fill(~mask, fill_value) 89 | 90 | scaling = self.emb_size ** (1/2) 91 | att = F.softmax(energy, dim=-1) / scaling 92 | att = self.att_drop(att) 93 | # sum up over the third axis 94 | out = torch.einsum('bhal, bhlv -> bhav ', att, values) 95 | out = rearrange(out, "b h n d -> b n (h d)") 96 | out = self.projection(out) 97 | return out 98 | 99 | class ResidualAdd(nn.Module): 100 | def __init__(self, fn): 101 | super().__init__() 102 | self.fn = fn 103 | 104 | def forward(self, x, **kwargs): 105 | res = x 106 | x = self.fn(x, **kwargs) 107 | x += res 108 | return x 109 | 110 | class FeedForwardBlock(nn.Sequential): 111 | def __init__(self, emb_size: int, expansion: int = 4, drop_p: float = 0.): 112 | super().__init__( 113 | nn.Linear(emb_size, expansion * emb_size), 114 | nn.GELU(), 115 | nn.Dropout(drop_p), 116 | nn.Linear(expansion * emb_size, emb_size), 117 | ) 118 | 119 | class TransformerEncoderBlock(nn.Sequential): 120 | def __init__(self, 121 | emb_size: int = 768, 122 | drop_p: float = 0., 123 | forward_expansion: int = 4, 124 | forward_drop_p: float = 0., 125 | ** kwargs): 126 | super().__init__( 127 | ResidualAdd(nn.Sequential( 128 | nn.LayerNorm(emb_size), 129 | MultiHeadAttention(emb_size, **kwargs), 130 | nn.Dropout(drop_p) 131 | )), 132 | ResidualAdd(nn.Sequential( 133 | nn.LayerNorm(emb_size), 134 | FeedForwardBlock( 135 | emb_size, expansion=forward_expansion, drop_p=forward_drop_p), 136 | nn.Dropout(drop_p) 137 | ) 138 | )) 139 | 140 | class TransformerEncoder(nn.Sequential): 141 | def __init__(self, depth: int = 12, **kwargs): 142 | super().__init__(*[TransformerEncoderBlock(**kwargs) for _ in range(depth)]) 143 | 144 | 145 | class dependencymap(nn.Sequential): 146 | def __init__(self, emb_size: int = 768, n_regions: int = 256, patch_size: int = 16, img_size: int = 256, output_ch: int=64, cuda=True): 147 | super().__init__() 148 | self.patch_size = patch_size 149 | self.img_size = img_size 150 | self.emb_size = emb_size 151 | self.output_ch = output_ch 152 | self.cuda = cuda 153 | self.outconv = nn.Sequential( 154 | nn.Conv2d(emb_size, output_ch, kernel_size=1, padding=0), 155 | nn.BatchNorm2d(output_ch), 156 | nn.Sigmoid() 157 | ) 158 | self.out2 = nn.Sigmoid() 159 | 160 | self.gpool = nn.AdaptiveAvgPool1d(1) 161 | def forward(self, x): 162 | x_gpool = self.gpool(x) 163 | coeff = torch.zeros((x.size()[0], 
self.emb_size, self.img_size, self.img_size)) 164 | coeff2 = torch.zeros((x.size()[0], 1, self.img_size, self.img_size)) 165 | if self.cuda: 166 | coeff = coeff.cuda() 167 | coeff2 = coeff2.cuda() 168 | for i in range(0, self.img_size//self.patch_size): 169 | for j in range(0, self.img_size//self.patch_size): 170 | value = x[:,(i*self.patch_size)+j] 171 | value = value.view(value.size()[0], value.size()[1], 1, 1) 172 | coeff[:,:,self.patch_size*i:self.patch_size*(i+1),self.patch_size*j:self.patch_size*(j+1)] = value.repeat(1, 1, self.patch_size, self.patch_size) 173 | 174 | value = x_gpool[:,(i*self.patch_size)+j] 175 | value = value.view(value.size()[0], value.size()[1], 1, 1) 176 | coeff2[:,:,self.patch_size*i:self.patch_size*(i+1),self.patch_size*j:self.patch_size*(j+1)] = value.repeat(1, 1, self.patch_size, self.patch_size) 177 | 178 | global_contexual = self.outconv(coeff) 179 | regional_distribution = self.out2(coeff2) 180 | return [global_contexual, regional_distribution, self.out2(x_gpool)] 181 | 182 | class ViT(nn.Sequential): 183 | def __init__(self, 184 | in_channels: int = 3, 185 | patch_size: int = 16, 186 | emb_size: int = 1024, 187 | img_size: int = 256, 188 | depth: int = 2, 189 | n_regions: int = (256//16)**2, 190 | output_ch: int = 64, 191 | cuda = True, 192 | **kwargs): 193 | super().__init__( 194 | PatchEmbedding(in_channels, patch_size, emb_size, img_size), 195 | TransformerEncoder(depth, emb_size=emb_size, **kwargs), 196 | dependencymap(emb_size, n_regions, patch_size, img_size, output_ch, cuda) 197 | ) 198 | 199 | 200 | 201 | 202 | 203 | 204 | 205 | 206 | 207 | 208 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | einops==0.4.1 2 | h5py==3.1.0 3 | matplotlib==3.2.2 4 | nibabel==3.0.2 5 | numpy==1.21.5 6 | opencv_python==4.1.2.30 7 | pandas==1.3.5 8 | PyYAML==6.0 9 | scikit_learn==1.0.2 10 | scipy==1.4.1 11 | torch==1.10.0+cu111 12 | torchvision==0.11.1+cu111 13 | tqdm==4.63.0 14 | -------------------------------------------------------------------------------- /results/readme.txt: -------------------------------------------------------------------------------- 1 | visualization results will be save here -------------------------------------------------------------------------------- /train_skin.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "palestinian-shadow", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "from __future__ import division\n", 11 | "import os\n", 12 | "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n", 13 | "import torch\n", 14 | "import torch.optim as optim\n", 15 | "from torch.utils.data import DataLoader\n", 16 | "from loader import *\n", 17 | "from model.TransMUNet import TransMUNet\n", 18 | "import pandas as pd\n", 19 | "import glob\n", 20 | "import nibabel as nib\n", 21 | "import numpy as np\n", 22 | "import copy\n", 23 | "import yaml" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": 2, 29 | "id": "expensive-courage", 30 | "metadata": {}, 31 | "outputs": [], 32 | "source": [ 33 | "## Loader\n", 34 | "## Hyper parameters\n", 35 | "config = yaml.load(open('./config_skin.yml'), Loader=yaml.FullLoader)\n", 36 | "number_classes = int(config['number_classes'])\n", 37 | "input_channels = 3\n", 38 | "best_val_loss = np.inf\n", 39 | "device = 'cuda' if 
torch.cuda.is_available() else 'cpu'\n", 40 | "\n", 41 | "data_path = config['path_to_data'] \n", 42 | "\n", 43 | "train_dataset = isic_loader(path_Data = data_path, train = True)\n", 44 | "train_loader = DataLoader(train_dataset, batch_size = int(config['batch_size_tr']), shuffle= True)\n", 45 | "val_dataset = isic_loader(path_Data = data_path, train = False)\n", 46 | "val_loader = DataLoader(val_dataset, batch_size = int(config['batch_size_va']), shuffle= False)\n" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": 3, 52 | "id": "southern-harvard", 53 | "metadata": {}, 54 | "outputs": [], 55 | "source": [ 56 | "Net = TransMUNet(n_classes = number_classes)\n", 57 | "\n", 58 | "Net = Net.to(device)\n", 59 | "if int(config['pretrained']):\n", 60 | " Net.load_state_dict(torch.load(config['saved_model'], map_location='cpu')['model_weights'])\n", 61 | " best_val_loss = torch.load(config['saved_model'], map_location='cpu')['val_loss']\n", 62 | "optimizer = optim.Adam(Net.parameters(), lr= float(config['lr']))\n", 63 | "scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor = 0.5, patience = config['patience'])\n", 64 | "criteria = torch.nn.BCELoss()\n", 65 | "criteria_boundary = torch.nn.BCELoss()\n", 66 | "criteria_region = torch.nn.MSELoss()" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": null, 72 | "id": "dramatic-separation", 73 | "metadata": {}, 74 | "outputs": [ 75 | { 76 | "name": "stdout", 77 | "output_type": "stream", 78 | "text": [ 79 | " Epoch>> 1 and itteration 1 Loss>> 0.4878859519958496\n", 80 | " Epoch>> 1 and itteration 46 Loss>> 0.436914870272512\n", 81 | " Epoch>> 1 and itteration 91 Loss>> 0.406768771645787\n", 82 | " Epoch>> 1 and itteration 136 Loss>> 0.387475683408625\n", 83 | " Epoch>> 1 and itteration 181 Loss>> 0.37029958166470184\n", 84 | " Epoch>> 1 and itteration 226 Loss>> 0.3621396487001824\n", 85 | " Epoch>> 1 and itteration 271 Loss>> 0.35190100401529967\n", 86 | " Epoch>> 1 and itteration 316 Loss>> 0.34527019302867634\n", 87 | " Epoch>> 1 and itteration 361 Loss>> 0.340015287214369\n", 88 | " Epoch>> 1 and itteration 406 Loss>> 0.3345703736975275\n", 89 | " Epoch>> 1 and itteration 451 Loss>> 0.3311662516017711\n", 90 | "val_mode\n", 91 | " validation on epoch>> 1 dice loss>> 0.2752935446539901\n", 92 | "New best loss, saving...\n", 93 | " Epoch>> 2 and itteration 1 Loss>> 0.28962594270706177\n", 94 | " Epoch>> 2 and itteration 46 Loss>> 0.29288280528524646\n", 95 | " Epoch>> 2 and itteration 91 Loss>> 0.2938939706309811\n", 96 | " Epoch>> 2 and itteration 136 Loss>> 0.2820981578791843\n", 97 | " Epoch>> 2 and itteration 181 Loss>> 0.2701566784092076\n", 98 | " Epoch>> 2 and itteration 226 Loss>> 0.2711107705151085\n", 99 | " Epoch>> 2 and itteration 271 Loss>> 0.26777574854363373\n", 100 | " Epoch>> 2 and itteration 316 Loss>> 0.26263950001212616\n", 101 | " Epoch>> 2 and itteration 361 Loss>> 0.26310517149783896\n", 102 | " Epoch>> 2 and itteration 406 Loss>> 0.2610443519988083\n", 103 | " Epoch>> 2 and itteration 451 Loss>> 0.256510941555389\n", 104 | "val_mode\n", 105 | " validation on epoch>> 2 dice loss>> 0.22702077907082197\n", 106 | "New best loss, saving...\n", 107 | " Epoch>> 3 and itteration 1 Loss>> 0.13973449170589447\n", 108 | " Epoch>> 3 and itteration 46 Loss>> 0.2345198366952979\n", 109 | " Epoch>> 3 and itteration 91 Loss>> 0.23095878385580504\n", 110 | " Epoch>> 3 and itteration 136 Loss>> 0.22887960920000777\n", 111 | " Epoch>> 3 and itteration 181 Loss>> 
0.22682313011825414\n", 112 | " Epoch>> 3 and itteration 226 Loss>> 0.23694280053661987\n", 113 | " Epoch>> 3 and itteration 271 Loss>> 0.23583715939653754\n", 114 | " Epoch>> 3 and itteration 316 Loss>> 0.2386164115388182\n", 115 | " Epoch>> 3 and itteration 361 Loss>> 0.23308741177465778\n", 116 | " Epoch>> 3 and itteration 406 Loss>> 0.2304859147980589\n", 117 | " Epoch>> 3 and itteration 451 Loss>> 0.22806577361343705\n", 118 | "val_mode\n", 119 | " validation on epoch>> 3 dice loss>> 0.20915593974953914\n", 120 | "New best loss, saving...\n", 121 | " Epoch>> 4 and itteration 1 Loss>> 0.12882468104362488\n", 122 | " Epoch>> 4 and itteration 46 Loss>> 0.20352134416284767\n", 123 | " Epoch>> 4 and itteration 91 Loss>> 0.20201156525821476\n", 124 | " Epoch>> 4 and itteration 136 Loss>> 0.20217801005963018\n", 125 | " Epoch>> 4 and itteration 181 Loss>> 0.20059176369760576\n", 126 | " Epoch>> 4 and itteration 226 Loss>> 0.20083754753644487\n", 127 | " Epoch>> 4 and itteration 271 Loss>> 0.20434071197830883\n", 128 | " Epoch>> 4 and itteration 316 Loss>> 0.20028786205604107\n", 129 | " Epoch>> 4 and itteration 361 Loss>> 0.19854767649457725\n", 130 | " Epoch>> 4 and itteration 406 Loss>> 0.19949371474129812\n", 131 | " Epoch>> 4 and itteration 451 Loss>> 0.20029027064531182\n", 132 | "val_mode\n", 133 | " validation on epoch>> 4 dice loss>> 0.2112294882751685\n", 134 | " Epoch>> 5 and itteration 1 Loss>> 0.14160121977329254\n", 135 | " Epoch>> 5 and itteration 46 Loss>> 0.20069558011448901\n", 136 | " Epoch>> 5 and itteration 91 Loss>> 0.19101668493105814\n", 137 | " Epoch>> 5 and itteration 136 Loss>> 0.18340545553056634\n", 138 | " Epoch>> 5 and itteration 181 Loss>> 0.19069095532209174\n", 139 | " Epoch>> 5 and itteration 226 Loss>> 0.19197250065286603\n", 140 | " Epoch>> 5 and itteration 271 Loss>> 0.1918776873292958\n", 141 | " Epoch>> 5 and itteration 316 Loss>> 0.1895883841605126\n", 142 | " Epoch>> 5 and itteration 361 Loss>> 0.18681040210159203\n", 143 | " Epoch>> 5 and itteration 406 Loss>> 0.1870680547920354\n", 144 | " Epoch>> 5 and itteration 451 Loss>> 0.18691231079085704\n", 145 | "val_mode\n", 146 | " validation on epoch>> 5 dice loss>> 0.18418753754103045\n", 147 | "New best loss, saving...\n", 148 | " Epoch>> 6 and itteration 1 Loss>> 0.0972343161702156\n", 149 | " Epoch>> 6 and itteration 46 Loss>> 0.19054791347488112\n", 150 | " Epoch>> 6 and itteration 91 Loss>> 0.1728426708461164\n", 151 | " Epoch>> 6 and itteration 136 Loss>> 0.18079832261976073\n", 152 | " Epoch>> 6 and itteration 181 Loss>> 0.18011006914614314\n", 153 | " Epoch>> 6 and itteration 226 Loss>> 0.18350923110821607\n", 154 | " Epoch>> 6 and itteration 271 Loss>> 0.1807868150492436\n", 155 | " Epoch>> 6 and itteration 316 Loss>> 0.1776462991403628\n", 156 | " Epoch>> 6 and itteration 361 Loss>> 0.177176372124878\n", 157 | " Epoch>> 6 and itteration 406 Loss>> 0.17609688527698586\n", 158 | " Epoch>> 6 and itteration 451 Loss>> 0.17529611433226358\n", 159 | "val_mode\n", 160 | " validation on epoch>> 6 dice loss>> 0.17005475121460367\n", 161 | "New best loss, saving...\n", 162 | " Epoch>> 7 and itteration 1 Loss>> 0.09515374898910522\n", 163 | " Epoch>> 7 and itteration 46 Loss>> 0.1728055507916471\n", 164 | " Epoch>> 7 and itteration 91 Loss>> 0.16099595466812888\n", 165 | " Epoch>> 7 and itteration 136 Loss>> 0.1645061452051296\n", 166 | " Epoch>> 7 and itteration 181 Loss>> 0.16411139643159361\n", 167 | " Epoch>> 7 and itteration 226 Loss>> 0.1669956852508857\n", 168 | " Epoch>> 7 and itteration 271 
Loss>> 0.16581149426773464\n", 169 | " Epoch>> 7 and itteration 316 Loss>> 0.164764616638422\n", 170 | " Epoch>> 7 and itteration 361 Loss>> 0.16483611638278511\n", 171 | " Epoch>> 7 and itteration 406 Loss>> 0.16375028892544102\n", 172 | " Epoch>> 7 and itteration 451 Loss>> 0.16221840408715335\n", 173 | "val_mode\n", 174 | " validation on epoch>> 7 dice loss>> 0.15854016915107555\n", 175 | "New best loss, saving...\n", 176 | " Epoch>> 8 and itteration 1 Loss>> 0.14681853353977203\n", 177 | " Epoch>> 8 and itteration 46 Loss>> 0.15058728777196095\n", 178 | " Epoch>> 8 and itteration 91 Loss>> 0.15467100729654124\n", 179 | " Epoch>> 8 and itteration 136 Loss>> 0.15571273216868148\n", 180 | " Epoch>> 8 and itteration 181 Loss>> 0.15117962238538332\n", 181 | " Epoch>> 8 and itteration 226 Loss>> 0.1503101315646045\n", 182 | " Epoch>> 8 and itteration 271 Loss>> 0.14998266829445794\n", 183 | " Epoch>> 8 and itteration 316 Loss>> 0.15000073913531967\n", 184 | " Epoch>> 8 and itteration 361 Loss>> 0.1535256127736575\n", 185 | " Epoch>> 8 and itteration 406 Loss>> 0.1531859316629142\n", 186 | " Epoch>> 8 and itteration 451 Loss>> 0.15230796429970841\n", 187 | "val_mode\n", 188 | " validation on epoch>> 8 dice loss>> 0.15324555431640055\n", 189 | "New best loss, saving...\n", 190 | " Epoch>> 9 and itteration 1 Loss>> 0.12041860818862915\n", 191 | " Epoch>> 9 and itteration 46 Loss>> 0.13735647208016852\n", 192 | " Epoch>> 9 and itteration 91 Loss>> 0.14442674799279853\n", 193 | " Epoch>> 9 and itteration 136 Loss>> 0.14633055480525775\n", 194 | " Epoch>> 9 and itteration 181 Loss>> 0.14017007574027415\n", 195 | " Epoch>> 9 and itteration 226 Loss>> 0.14329959132370695\n", 196 | " Epoch>> 9 and itteration 271 Loss>> 0.14341566736407826\n", 197 | " Epoch>> 9 and itteration 316 Loss>> 0.1468054106859844\n", 198 | " Epoch>> 9 and itteration 361 Loss>> 0.14839129310895863\n", 199 | " Epoch>> 9 and itteration 406 Loss>> 0.14842939013477616\n", 200 | " Epoch>> 9 and itteration 451 Loss>> 0.14841331706739583\n", 201 | "val_mode\n", 202 | " validation on epoch>> 9 dice loss>> 0.1478783944262098\n", 203 | "New best loss, saving...\n", 204 | " Epoch>> 10 and itteration 1 Loss>> 0.06400887668132782\n", 205 | " Epoch>> 10 and itteration 46 Loss>> 0.1386732875328997\n", 206 | " Epoch>> 10 and itteration 91 Loss>> 0.1310957330432567\n", 207 | " Epoch>> 10 and itteration 136 Loss>> 0.13258406530846567\n", 208 | " Epoch>> 10 and itteration 181 Loss>> 0.1302474557488992\n", 209 | " Epoch>> 10 and itteration 226 Loss>> 0.13106200245695304\n", 210 | " Epoch>> 10 and itteration 271 Loss>> 0.13732820812516547\n", 211 | " Epoch>> 10 and itteration 316 Loss>> 0.1382933629417344\n", 212 | " Epoch>> 10 and itteration 361 Loss>> 0.13946590896623617\n", 213 | " Epoch>> 10 and itteration 406 Loss>> 0.13961466785355153\n", 214 | " Epoch>> 10 and itteration 451 Loss>> 0.13947361076618245\n", 215 | "val_mode\n", 216 | " validation on epoch>> 10 dice loss>> 0.15285300796412823\n", 217 | " Epoch>> 11 and itteration 1 Loss>> 0.12822243571281433\n", 218 | " Epoch>> 11 and itteration 46 Loss>> 0.13029481632554013\n", 219 | " Epoch>> 11 and itteration 91 Loss>> 0.13056580517645722\n", 220 | " Epoch>> 11 and itteration 136 Loss>> 0.13051072973757982\n", 221 | " Epoch>> 11 and itteration 181 Loss>> 0.13123801955695968\n", 222 | " Epoch>> 11 and itteration 226 Loss>> 0.13130392167156776\n", 223 | " Epoch>> 11 and itteration 271 Loss>> 0.13287127794649767\n", 224 | " Epoch>> 11 and itteration 316 Loss>> 0.13399140213769448\n", 225 | " 
Epoch>> 11 and itteration 361 Loss>> 0.1365643096391184\n", 226 | " Epoch>> 11 and itteration 406 Loss>> 0.13595533071847385\n", 227 | " Epoch>> 11 and itteration 451 Loss>> 0.13607203920770378\n", 228 | "val_mode\n", 229 | " validation on epoch>> 11 dice loss>> 0.1465001785264685\n", 230 | "New best loss, saving...\n", 231 | " Epoch>> 12 and itteration 1 Loss>> 0.058516666293144226\n", 232 | " Epoch>> 12 and itteration 46 Loss>> 0.13626824999633042\n", 233 | " Epoch>> 12 and itteration 91 Loss>> 0.13372563963735495\n", 234 | " Epoch>> 12 and itteration 136 Loss>> 0.13410957825972752\n", 235 | " Epoch>> 12 and itteration 181 Loss>> 0.1295464351485118\n", 236 | " Epoch>> 12 and itteration 226 Loss>> 0.1312692841036921\n", 237 | " Epoch>> 12 and itteration 271 Loss>> 0.12950078386992106\n", 238 | " Epoch>> 12 and itteration 316 Loss>> 0.12675089216967927\n" 239 | ] 240 | }, 241 | { 242 | "name": "stdout", 243 | "output_type": "stream", 244 | "text": [ 245 | " Epoch>> 12 and itteration 361 Loss>> 0.12610800418786064\n", 246 | " Epoch>> 12 and itteration 406 Loss>> 0.12551013803283864\n", 247 | " Epoch>> 12 and itteration 451 Loss>> 0.12573659851188934\n", 248 | "val_mode\n", 249 | " validation on epoch>> 12 dice loss>> 0.14761795142744613\n", 250 | " Epoch>> 13 and itteration 1 Loss>> 0.06830105185508728\n", 251 | " Epoch>> 13 and itteration 46 Loss>> 0.1209399577068246\n", 252 | " Epoch>> 13 and itteration 91 Loss>> 0.11847573682502076\n", 253 | " Epoch>> 13 and itteration 136 Loss>> 0.1216593860539005\n", 254 | " Epoch>> 13 and itteration 181 Loss>> 0.12231733466791843\n", 255 | " Epoch>> 13 and itteration 226 Loss>> 0.12253736996347397\n", 256 | " Epoch>> 13 and itteration 271 Loss>> 0.12341048696329233\n", 257 | " Epoch>> 13 and itteration 316 Loss>> 0.12317021874875962\n", 258 | " Epoch>> 13 and itteration 361 Loss>> 0.12268629867242974\n", 259 | " Epoch>> 13 and itteration 406 Loss>> 0.1247509588473683\n", 260 | " Epoch>> 13 and itteration 451 Loss>> 0.125241905706883\n", 261 | "val_mode\n", 262 | " validation on epoch>> 13 dice loss>> 0.13717417760921147\n", 263 | "New best loss, saving...\n", 264 | " Epoch>> 14 and itteration 1 Loss>> 0.13640645146369934\n", 265 | " Epoch>> 14 and itteration 46 Loss>> 0.10579226886772591\n", 266 | " Epoch>> 14 and itteration 91 Loss>> 0.11497306872855176\n", 267 | " Epoch>> 14 and itteration 136 Loss>> 0.11856242041925297\n", 268 | " Epoch>> 14 and itteration 181 Loss>> 0.12347500472849245\n", 269 | " Epoch>> 14 and itteration 226 Loss>> 0.12016201444563612\n", 270 | " Epoch>> 14 and itteration 271 Loss>> 0.12008402851234942\n", 271 | " Epoch>> 14 and itteration 316 Loss>> 0.12076520530766324\n", 272 | " Epoch>> 14 and itteration 361 Loss>> 0.1228387365298258\n", 273 | " Epoch>> 14 and itteration 406 Loss>> 0.12062797438481758\n", 274 | " Epoch>> 14 and itteration 451 Loss>> 0.11947423479872903\n", 275 | "val_mode\n", 276 | " validation on epoch>> 14 dice loss>> 0.15339755314729503\n", 277 | " Epoch>> 15 and itteration 1 Loss>> 0.07482786476612091\n", 278 | " Epoch>> 15 and itteration 46 Loss>> 0.10358085896333923\n", 279 | " Epoch>> 15 and itteration 91 Loss>> 0.10589579144840712\n", 280 | " Epoch>> 15 and itteration 136 Loss>> 0.11440740583245368\n", 281 | " Epoch>> 15 and itteration 181 Loss>> 0.110930813857205\n", 282 | " Epoch>> 15 and itteration 226 Loss>> 0.10861900152094596\n", 283 | " Epoch>> 15 and itteration 271 Loss>> 0.1117622519248746\n", 284 | " Epoch>> 15 and itteration 316 Loss>> 0.11226264021913462\n", 285 | " Epoch>> 15 and 
itteration 361 Loss>> 0.11139493824661273\n", 286 | " Epoch>> 15 and itteration 406 Loss>> 0.11167366991709606\n", 287 | " Epoch>> 15 and itteration 451 Loss>> 0.11191047541384158\n", 288 | "val_mode\n", 289 | " validation on epoch>> 15 dice loss>> 0.1504231851087689\n", 290 | " Epoch>> 16 and itteration 1 Loss>> 0.2040601670742035\n", 291 | " Epoch>> 16 and itteration 46 Loss>> 0.10230100883737854\n", 292 | " Epoch>> 16 and itteration 91 Loss>> 0.09843452164268755\n", 293 | " Epoch>> 16 and itteration 136 Loss>> 0.09791147848591208\n", 294 | " Epoch>> 16 and itteration 181 Loss>> 0.10447321965961166\n", 295 | " Epoch>> 16 and itteration 226 Loss>> 0.10858650762686688\n", 296 | " Epoch>> 16 and itteration 271 Loss>> 0.10935131602016762\n", 297 | " Epoch>> 16 and itteration 316 Loss>> 0.10863794585619169\n", 298 | " Epoch>> 16 and itteration 361 Loss>> 0.11001722159643253\n", 299 | " Epoch>> 16 and itteration 406 Loss>> 0.11084756249644487\n", 300 | " Epoch>> 16 and itteration 451 Loss>> 0.11118052687553767\n", 301 | "val_mode\n", 302 | " validation on epoch>> 16 dice loss>> 0.1405838085482307\n", 303 | " Epoch>> 17 and itteration 1 Loss>> 0.07292813807725906\n", 304 | " Epoch>> 17 and itteration 46 Loss>> 0.09945778046613155\n", 305 | " Epoch>> 17 and itteration 91 Loss>> 0.1020904275905955\n", 306 | " Epoch>> 17 and itteration 136 Loss>> 0.1059761545894777\n", 307 | " Epoch>> 17 and itteration 181 Loss>> 0.10427628669537892\n", 308 | " Epoch>> 17 and itteration 226 Loss>> 0.1033167624598847\n", 309 | " Epoch>> 17 and itteration 271 Loss>> 0.10576806073094207\n", 310 | " Epoch>> 17 and itteration 316 Loss>> 0.10602264713401659\n", 311 | " Epoch>> 17 and itteration 361 Loss>> 0.10741545290910637\n", 312 | " Epoch>> 17 and itteration 406 Loss>> 0.10750829835830651\n", 313 | " Epoch>> 17 and itteration 451 Loss>> 0.10642406746679028\n", 314 | "val_mode\n", 315 | " validation on epoch>> 17 dice loss>> 0.1332140850643969\n", 316 | "New best loss, saving...\n", 317 | " Epoch>> 18 and itteration 1 Loss>> 0.08044999092817307\n", 318 | " Epoch>> 18 and itteration 46 Loss>> 0.11552384933051855\n", 319 | " Epoch>> 18 and itteration 91 Loss>> 0.11141193330615431\n", 320 | " Epoch>> 18 and itteration 136 Loss>> 0.10895795377847903\n", 321 | " Epoch>> 18 and itteration 181 Loss>> 0.10559947049831817\n", 322 | " Epoch>> 18 and itteration 226 Loss>> 0.10488712589825148\n", 323 | " Epoch>> 18 and itteration 271 Loss>> 0.10594684736480132\n", 324 | " Epoch>> 18 and itteration 316 Loss>> 0.10629495941809838\n", 325 | " Epoch>> 18 and itteration 361 Loss>> 0.10656408075026529\n", 326 | " Epoch>> 18 and itteration 406 Loss>> 0.10800657418597802\n", 327 | " Epoch>> 18 and itteration 451 Loss>> 0.10803435198516645\n", 328 | "val_mode\n", 329 | " validation on epoch>> 18 dice loss>> 0.13519006386456803\n", 330 | " Epoch>> 19 and itteration 1 Loss>> 0.05732952430844307\n", 331 | " Epoch>> 19 and itteration 46 Loss>> 0.10167763301211855\n", 332 | " Epoch>> 19 and itteration 91 Loss>> 0.10383454272216493\n", 333 | " Epoch>> 19 and itteration 136 Loss>> 0.10097106069545536\n", 334 | " Epoch>> 19 and itteration 181 Loss>> 0.09860320737101755\n", 335 | " Epoch>> 19 and itteration 226 Loss>> 0.09653991193766087\n", 336 | " Epoch>> 19 and itteration 271 Loss>> 0.09469419628436715\n", 337 | " Epoch>> 19 and itteration 316 Loss>> 0.09815880662136817\n", 338 | " Epoch>> 19 and itteration 361 Loss>> 0.09865882883359191\n", 339 | " Epoch>> 19 and itteration 406 Loss>> 0.09913877637258597\n", 340 | " Epoch>> 19 and itteration 
451 Loss>> 0.09957678326944554\n", 341 | "val_mode\n", 342 | " validation on epoch>> 19 dice loss>> 0.13232341616569107\n", 343 | "New best loss, saving...\n", 344 | " Epoch>> 20 and itteration 1 Loss>> 0.0865911990404129\n", 345 | " Epoch>> 20 and itteration 46 Loss>> 0.0969339231758014\n", 346 | " Epoch>> 20 and itteration 91 Loss>> 0.10038319578046327\n", 347 | " Epoch>> 20 and itteration 136 Loss>> 0.10063144419451847\n", 348 | " Epoch>> 20 and itteration 181 Loss>> 0.09972624748286621\n", 349 | " Epoch>> 20 and itteration 226 Loss>> 0.10034953394031103\n", 350 | " Epoch>> 20 and itteration 271 Loss>> 0.10118887075858803\n", 351 | " Epoch>> 20 and itteration 316 Loss>> 0.10326909971765325\n", 352 | " Epoch>> 20 and itteration 361 Loss>> 0.10289405798152543\n", 353 | " Epoch>> 20 and itteration 406 Loss>> 0.10356041107308395\n", 354 | " Epoch>> 20 and itteration 451 Loss>> 0.10402174591357322\n", 355 | "val_mode\n", 356 | " validation on epoch>> 20 dice loss>> 0.14737040128084222\n", 357 | " Epoch>> 21 and itteration 1 Loss>> 0.04281836748123169\n", 358 | " Epoch>> 21 and itteration 46 Loss>> 0.0923526119440794\n", 359 | " Epoch>> 21 and itteration 91 Loss>> 0.09573533328679892\n", 360 | " Epoch>> 21 and itteration 136 Loss>> 0.09949091756168534\n", 361 | " Epoch>> 21 and itteration 181 Loss>> 0.09941958175015055\n", 362 | " Epoch>> 21 and itteration 226 Loss>> 0.09971112489238777\n", 363 | " Epoch>> 21 and itteration 271 Loss>> 0.09894957272438985\n", 364 | " Epoch>> 21 and itteration 316 Loss>> 0.09760519322243673\n", 365 | " Epoch>> 21 and itteration 361 Loss>> 0.09618166372840424\n", 366 | " Epoch>> 21 and itteration 406 Loss>> 0.09669299443880913\n", 367 | " Epoch>> 21 and itteration 451 Loss>> 0.09753532782031797\n", 368 | "val_mode\n", 369 | " validation on epoch>> 21 dice loss>> 0.13773053639269933\n", 370 | " Epoch>> 22 and itteration 1 Loss>> 0.04927198961377144\n", 371 | " Epoch>> 22 and itteration 46 Loss>> 0.0968274253865947\n", 372 | " Epoch>> 22 and itteration 91 Loss>> 0.10006347052998595\n", 373 | " Epoch>> 22 and itteration 136 Loss>> 0.09693950915928273\n", 374 | " Epoch>> 22 and itteration 181 Loss>> 0.09374010291218099\n", 375 | " Epoch>> 22 and itteration 226 Loss>> 0.09120593114500553\n", 376 | " Epoch>> 22 and itteration 271 Loss>> 0.091096666974885\n", 377 | " Epoch>> 22 and itteration 316 Loss>> 0.09088313245933645\n", 378 | " Epoch>> 22 and itteration 361 Loss>> 0.08953358514693635\n", 379 | " Epoch>> 22 and itteration 406 Loss>> 0.09022152446797622\n", 380 | " Epoch>> 22 and itteration 451 Loss>> 0.09058853434584622\n", 381 | "val_mode\n", 382 | " validation on epoch>> 22 dice loss>> 0.14696211825416183\n", 383 | " Epoch>> 23 and itteration 1 Loss>> 0.0562395378947258\n", 384 | " Epoch>> 23 and itteration 46 Loss>> 0.09645650398148142\n", 385 | " Epoch>> 23 and itteration 91 Loss>> 0.09781748101442725\n", 386 | " Epoch>> 23 and itteration 136 Loss>> 0.09414313183001735\n", 387 | " Epoch>> 23 and itteration 181 Loss>> 0.0928338432880067\n", 388 | " Epoch>> 23 and itteration 226 Loss>> 0.09392687361852257\n", 389 | " Epoch>> 23 and itteration 271 Loss>> 0.09331841945538222\n", 390 | " Epoch>> 23 and itteration 316 Loss>> 0.09350466943947197\n", 391 | " Epoch>> 23 and itteration 361 Loss>> 0.09331819876368026\n", 392 | " Epoch>> 23 and itteration 406 Loss>> 0.0917831511131208\n", 393 | " Epoch>> 23 and itteration 451 Loss>> 0.0919910596142587\n", 394 | "val_mode\n", 395 | " validation on epoch>> 23 dice loss>> 0.13442262967969346\n", 396 | " Epoch>> 24 and 
itteration 1 Loss>> 0.0942787453532219\n", 397 | " Epoch>> 24 and itteration 46 Loss>> 0.08618788664107738\n", 398 | " Epoch>> 24 and itteration 91 Loss>> 0.08921602781821084\n", 399 | " Epoch>> 24 and itteration 136 Loss>> 0.08748522945953642\n" 400 | ] 401 | }, 402 | { 403 | "name": "stdout", 404 | "output_type": "stream", 405 | "text": [ 406 | " Epoch>> 24 and itteration 181 Loss>> 0.09348085230391329\n", 407 | " Epoch>> 24 and itteration 226 Loss>> 0.09155483698818535\n", 408 | " Epoch>> 24 and itteration 271 Loss>> 0.08984833234075691\n", 409 | " Epoch>> 24 and itteration 316 Loss>> 0.09034567181303908\n", 410 | " Epoch>> 24 and itteration 361 Loss>> 0.09153363255169913\n", 411 | " Epoch>> 24 and itteration 406 Loss>> 0.09166926263918725\n", 412 | " Epoch>> 24 and itteration 451 Loss>> 0.09071590649810704\n", 413 | "val_mode\n", 414 | " validation on epoch>> 24 dice loss>> 0.13262791471541446\n", 415 | " Epoch>> 25 and itteration 1 Loss>> 0.06768117845058441\n", 416 | " Epoch>> 25 and itteration 46 Loss>> 0.08301182301796001\n", 417 | " Epoch>> 25 and itteration 91 Loss>> 0.08083239103575329\n", 418 | " Epoch>> 25 and itteration 136 Loss>> 0.08231433286495946\n", 419 | " Epoch>> 25 and itteration 181 Loss>> 0.08278180564141405\n", 420 | " Epoch>> 25 and itteration 226 Loss>> 0.08338782239078948\n", 421 | " Epoch>> 25 and itteration 271 Loss>> 0.08366489682580272\n", 422 | " Epoch>> 25 and itteration 316 Loss>> 0.08633823188233979\n", 423 | " Epoch>> 25 and itteration 361 Loss>> 0.08759486152070711\n", 424 | " Epoch>> 25 and itteration 406 Loss>> 0.08750929120983103\n", 425 | " Epoch>> 25 and itteration 451 Loss>> 0.08905773478541565\n", 426 | "val_mode\n", 427 | " validation on epoch>> 25 dice loss>> 0.16805427329089767\n", 428 | " Epoch>> 26 and itteration 1 Loss>> 0.08180712163448334\n", 429 | " Epoch>> 26 and itteration 46 Loss>> 0.08897217160657696\n", 430 | " Epoch>> 26 and itteration 91 Loss>> 0.08448535797523928\n", 431 | " Epoch>> 26 and itteration 136 Loss>> 0.08396991964100915\n", 432 | " Epoch>> 26 and itteration 181 Loss>> 0.08841165690296922\n", 433 | " Epoch>> 26 and itteration 226 Loss>> 0.0892280321238579\n", 434 | " Epoch>> 26 and itteration 271 Loss>> 0.08973164710721407\n", 435 | " Epoch>> 26 and itteration 316 Loss>> 0.08999759369070016\n", 436 | " Epoch>> 26 and itteration 361 Loss>> 0.08908081136210473\n", 437 | " Epoch>> 26 and itteration 406 Loss>> 0.08970647834681819\n", 438 | " Epoch>> 26 and itteration 451 Loss>> 0.08900697940667295\n", 439 | "val_mode\n", 440 | " validation on epoch>> 26 dice loss>> 0.14426588203980945\n", 441 | " Epoch>> 27 and itteration 1 Loss>> 0.039223358035087585\n", 442 | " Epoch>> 27 and itteration 46 Loss>> 0.09218907048520834\n", 443 | " Epoch>> 27 and itteration 91 Loss>> 0.08865025754158314\n", 444 | " Epoch>> 27 and itteration 136 Loss>> 0.08524092196432106\n", 445 | " Epoch>> 27 and itteration 181 Loss>> 0.08426037130270216\n", 446 | " Epoch>> 27 and itteration 226 Loss>> 0.08363635900669393\n", 447 | " Epoch>> 27 and itteration 271 Loss>> 0.08351786265568979\n", 448 | " Epoch>> 27 and itteration 316 Loss>> 0.08404931607597237\n", 449 | " Epoch>> 27 and itteration 361 Loss>> 0.08397765027828659\n", 450 | " Epoch>> 27 and itteration 406 Loss>> 0.08332927789759313\n", 451 | " Epoch>> 27 and itteration 451 Loss>> 0.08273112727225884\n", 452 | "val_mode\n", 453 | " validation on epoch>> 27 dice loss>> 0.1383740948585067\n", 454 | " Epoch>> 28 and itteration 1 Loss>> 0.08311550319194794\n", 455 | " Epoch>> 28 and itteration 46 
Loss>> 0.1012978515709224\n", 456 | " Epoch>> 28 and itteration 91 Loss>> 0.09370828489517118\n", 457 | " Epoch>> 28 and itteration 136 Loss>> 0.08794473996385932\n", 458 | " Epoch>> 28 and itteration 181 Loss>> 0.08637622566842243\n", 459 | " Epoch>> 28 and itteration 226 Loss>> 0.08559680756478183\n", 460 | " Epoch>> 28 and itteration 271 Loss>> 0.08388490381056092\n", 461 | " Epoch>> 28 and itteration 316 Loss>> 0.08458453545323279\n", 462 | " Epoch>> 28 and itteration 361 Loss>> 0.08399707519570546\n", 463 | " Epoch>> 28 and itteration 406 Loss>> 0.08479345812843057\n", 464 | " Epoch>> 28 and itteration 451 Loss>> 0.08378367037605287\n", 465 | "val_mode\n", 466 | " validation on epoch>> 28 dice loss>> 0.14225212614407987\n", 467 | " Epoch>> 29 and itteration 1 Loss>> 0.1698126345872879\n", 468 | " Epoch>> 29 and itteration 46 Loss>> 0.08877985527657944\n", 469 | " Epoch>> 29 and itteration 91 Loss>> 0.08444935865290873\n", 470 | " Epoch>> 29 and itteration 136 Loss>> 0.0834582414587631\n", 471 | " Epoch>> 29 and itteration 181 Loss>> 0.08369870767902933\n", 472 | " Epoch>> 29 and itteration 226 Loss>> 0.08288452218789442\n", 473 | " Epoch>> 29 and itteration 271 Loss>> 0.08094785421591844\n", 474 | " Epoch>> 29 and itteration 316 Loss>> 0.08026386901170393\n", 475 | " Epoch>> 29 and itteration 361 Loss>> 0.08078846582043864\n", 476 | " Epoch>> 29 and itteration 406 Loss>> 0.0804786472604266\n", 477 | " Epoch>> 29 and itteration 451 Loss>> 0.07990459828098571\n", 478 | "val_mode\n", 479 | " validation on epoch>> 29 dice loss>> 0.14034965090115617\n", 480 | " Epoch>> 30 and itteration 1 Loss>> 0.048821527510881424\n", 481 | " Epoch>> 30 and itteration 46 Loss>> 0.08206759781941124\n", 482 | " Epoch>> 30 and itteration 91 Loss>> 0.07843461362542686\n", 483 | " Epoch>> 30 and itteration 136 Loss>> 0.08118060201077777\n", 484 | " Epoch>> 30 and itteration 181 Loss>> 0.07950749065088962\n", 485 | " Epoch>> 30 and itteration 226 Loss>> 0.07934703621848495\n", 486 | " Epoch>> 30 and itteration 271 Loss>> 0.07919699177695816\n", 487 | " Epoch>> 30 and itteration 316 Loss>> 0.08032545142017211\n", 488 | " Epoch>> 30 and itteration 361 Loss>> 0.08033098560084596\n", 489 | " Epoch>> 30 and itteration 406 Loss>> 0.08142869326795263\n", 490 | " Epoch>> 30 and itteration 451 Loss>> 0.08023580928186487\n", 491 | "val_mode\n", 492 | " validation on epoch>> 30 dice loss>> 0.13936084788292646\n", 493 | " Epoch>> 31 and itteration 1 Loss>> 0.03420330211520195\n", 494 | " Epoch>> 31 and itteration 46 Loss>> 0.07171482673805693\n", 495 | " Epoch>> 31 and itteration 91 Loss>> 0.07024175559098904\n", 496 | " Epoch>> 31 and itteration 136 Loss>> 0.07108417134184171\n", 497 | " Epoch>> 31 and itteration 181 Loss>> 0.07112861688130469\n", 498 | " Epoch>> 31 and itteration 226 Loss>> 0.07157239438224156\n", 499 | " Epoch>> 31 and itteration 271 Loss>> 0.07227904739964933\n", 500 | " Epoch>> 31 and itteration 316 Loss>> 0.07273332110924434\n", 501 | " Epoch>> 31 and itteration 361 Loss>> 0.07223353709334152\n", 502 | " Epoch>> 31 and itteration 406 Loss>> 0.07321686669266576\n", 503 | " Epoch>> 31 and itteration 451 Loss>> 0.07320670566279978\n", 504 | "val_mode\n", 505 | " validation on epoch>> 31 dice loss>> 0.1328000151372468\n", 506 | " Epoch>> 32 and itteration 1 Loss>> 0.04429105669260025\n", 507 | " Epoch>> 32 and itteration 46 Loss>> 0.0674717932453622\n", 508 | " Epoch>> 32 and itteration 91 Loss>> 0.06654717654481039\n", 509 | " Epoch>> 32 and itteration 136 Loss>> 0.06968597196699942\n", 510 | " Epoch>> 
32 and itteration 181 Loss>> 0.06938673424128011\n", 511 | " Epoch>> 32 and itteration 226 Loss>> 0.06820537329577239\n", 512 | " Epoch>> 32 and itteration 271 Loss>> 0.0676589975635284\n", 513 | " Epoch>> 32 and itteration 316 Loss>> 0.0683497280542609\n", 514 | " Epoch>> 32 and itteration 361 Loss>> 0.06838685330880646\n", 515 | " Epoch>> 32 and itteration 406 Loss>> 0.06989806027523256\n", 516 | " Epoch>> 32 and itteration 451 Loss>> 0.06962462215211465\n", 517 | "val_mode\n", 518 | " validation on epoch>> 32 dice loss>> 0.14101827520927465\n", 519 | " Epoch>> 33 and itteration 1 Loss>> 0.06949695944786072\n", 520 | " Epoch>> 33 and itteration 46 Loss>> 0.07160836798341377\n", 521 | " Epoch>> 33 and itteration 91 Loss>> 0.06877203277506672\n", 522 | " Epoch>> 33 and itteration 136 Loss>> 0.06727086432168589\n", 523 | " Epoch>> 33 and itteration 181 Loss>> 0.06699409693973499\n", 524 | " Epoch>> 33 and itteration 226 Loss>> 0.06862597190568932\n", 525 | " Epoch>> 33 and itteration 271 Loss>> 0.06815938901059962\n", 526 | " Epoch>> 33 and itteration 316 Loss>> 0.0675833185462729\n", 527 | " Epoch>> 33 and itteration 361 Loss>> 0.06728096429640401\n", 528 | " Epoch>> 33 and itteration 406 Loss>> 0.06670073129401859\n", 529 | " Epoch>> 33 and itteration 451 Loss>> 0.06722739359111982\n", 530 | "val_mode\n", 531 | " validation on epoch>> 33 dice loss>> 0.142667758204834\n", 532 | " Epoch>> 34 and itteration 1 Loss>> 0.038447119295597076\n", 533 | " Epoch>> 34 and itteration 46 Loss>> 0.06872635107973347\n", 534 | " Epoch>> 34 and itteration 91 Loss>> 0.06787633554047935\n", 535 | " Epoch>> 34 and itteration 136 Loss>> 0.0674415948554216\n", 536 | " Epoch>> 34 and itteration 181 Loss>> 0.06517846672409806\n", 537 | " Epoch>> 34 and itteration 226 Loss>> 0.06653912910219051\n", 538 | " Epoch>> 34 and itteration 271 Loss>> 0.06556408721060111\n", 539 | " Epoch>> 34 and itteration 316 Loss>> 0.06585314618187803\n", 540 | " Epoch>> 34 and itteration 361 Loss>> 0.06563657556771406\n", 541 | " Epoch>> 34 and itteration 406 Loss>> 0.06570544071367075\n", 542 | " Epoch>> 34 and itteration 451 Loss>> 0.06553075277967225\n", 543 | "val_mode\n", 544 | " validation on epoch>> 34 dice loss>> 0.1377450026406342\n", 545 | " Epoch>> 35 and itteration 1 Loss>> 0.04737517982721329\n", 546 | " Epoch>> 35 and itteration 46 Loss>> 0.0704373478403558\n", 547 | " Epoch>> 35 and itteration 91 Loss>> 0.06592574779067066\n", 548 | " Epoch>> 35 and itteration 136 Loss>> 0.06582330850719967\n", 549 | " Epoch>> 35 and itteration 181 Loss>> 0.06573568559277453\n", 550 | " Epoch>> 35 and itteration 226 Loss>> 0.06650599000761204\n", 551 | " Epoch>> 35 and itteration 271 Loss>> 0.06603594232914192\n", 552 | " Epoch>> 35 and itteration 316 Loss>> 0.0656188878619784\n", 553 | " Epoch>> 35 and itteration 361 Loss>> 0.06554140978159997\n", 554 | " Epoch>> 35 and itteration 406 Loss>> 0.06470170174840048\n", 555 | " Epoch>> 35 and itteration 451 Loss>> 0.06511842903831845\n", 556 | "val_mode\n", 557 | " validation on epoch>> 35 dice loss>> 0.14777558465923352\n", 558 | " Epoch>> 36 and itteration 1 Loss>> 0.04132430627942085\n" 559 | ] 560 | }, 561 | { 562 | "name": "stdout", 563 | "output_type": "stream", 564 | "text": [ 565 | " Epoch>> 36 and itteration 46 Loss>> 0.058433276479658874\n", 566 | " Epoch>> 36 and itteration 91 Loss>> 0.05948395443732267\n", 567 | " Epoch>> 36 and itteration 136 Loss>> 0.059972569675130004\n", 568 | " Epoch>> 36 and itteration 181 Loss>> 0.060879789303103206\n", 569 | " Epoch>> 36 and itteration 
226 Loss>> 0.06133900438559003\n", 570 | " Epoch>> 36 and itteration 271 Loss>> 0.06139697876116226\n", 571 | " Epoch>> 36 and itteration 316 Loss>> 0.06330230644067066\n", 572 | " Epoch>> 36 and itteration 361 Loss>> 0.06435676964923451\n", 573 | " Epoch>> 36 and itteration 406 Loss>> 0.06426818171984015\n", 574 | " Epoch>> 36 and itteration 451 Loss>> 0.06382929415188053\n", 575 | "val_mode\n", 576 | " validation on epoch>> 36 dice loss>> 0.1419757434857369\n", 577 | " Epoch>> 37 and itteration 1 Loss>> 0.06270346790552139\n", 578 | " Epoch>> 37 and itteration 46 Loss>> 0.05726564959015535\n", 579 | " Epoch>> 37 and itteration 91 Loss>> 0.06073792284907221\n", 580 | " Epoch>> 37 and itteration 136 Loss>> 0.062088602948386\n", 581 | " Epoch>> 37 and itteration 181 Loss>> 0.06590296876570467\n", 582 | " Epoch>> 37 and itteration 226 Loss>> 0.06604115374023671\n", 583 | " Epoch>> 37 and itteration 271 Loss>> 0.06668964986486628\n", 584 | " Epoch>> 37 and itteration 316 Loss>> 0.06628770464913378\n", 585 | " Epoch>> 37 and itteration 361 Loss>> 0.06571868572827852\n", 586 | " Epoch>> 37 and itteration 406 Loss>> 0.06575624925988208\n", 587 | " Epoch>> 37 and itteration 451 Loss>> 0.06574743659444104\n", 588 | "val_mode\n", 589 | " validation on epoch>> 37 dice loss>> 0.13125084782321308\n", 590 | "New best loss, saving...\n", 591 | " Epoch>> 38 and itteration 1 Loss>> 0.10123566538095474\n", 592 | " Epoch>> 38 and itteration 46 Loss>> 0.06181766036088052\n", 593 | " Epoch>> 38 and itteration 91 Loss>> 0.062176418419067674\n", 594 | " Epoch>> 38 and itteration 136 Loss>> 0.06123718883677879\n", 595 | " Epoch>> 38 and itteration 181 Loss>> 0.06050789042360546\n", 596 | " Epoch>> 38 and itteration 226 Loss>> 0.06238969752515575\n", 597 | " Epoch>> 38 and itteration 271 Loss>> 0.06263398154332848\n", 598 | " Epoch>> 38 and itteration 316 Loss>> 0.06234056523211206\n", 599 | " Epoch>> 38 and itteration 361 Loss>> 0.06325677641953788\n", 600 | " Epoch>> 38 and itteration 406 Loss>> 0.06316920334078702\n", 601 | " Epoch>> 38 and itteration 451 Loss>> 0.06299982049364077\n", 602 | "val_mode\n", 603 | " validation on epoch>> 38 dice loss>> 0.12651464964072562\n", 604 | "New best loss, saving...\n", 605 | " Epoch>> 39 and itteration 1 Loss>> 0.09684403985738754\n", 606 | " Epoch>> 39 and itteration 46 Loss>> 0.06142529031342786\n", 607 | " Epoch>> 39 and itteration 91 Loss>> 0.06302005197893787\n", 608 | " Epoch>> 39 and itteration 136 Loss>> 0.061821489485309404\n", 609 | " Epoch>> 39 and itteration 181 Loss>> 0.06229638904127297\n", 610 | " Epoch>> 39 and itteration 226 Loss>> 0.0626291784548522\n", 611 | " Epoch>> 39 and itteration 271 Loss>> 0.062372092844925245\n", 612 | " Epoch>> 39 and itteration 316 Loss>> 0.062032526417906526\n", 613 | " Epoch>> 39 and itteration 361 Loss>> 0.061909043000510526\n", 614 | " Epoch>> 39 and itteration 406 Loss>> 0.060673202888058324\n", 615 | " Epoch>> 39 and itteration 451 Loss>> 0.061088423622553206\n", 616 | "val_mode\n", 617 | " validation on epoch>> 39 dice loss>> 0.1418084614985698\n", 618 | " Epoch>> 40 and itteration 1 Loss>> 0.06737198680639267\n", 619 | " Epoch>> 40 and itteration 46 Loss>> 0.05826217575889567\n", 620 | " Epoch>> 40 and itteration 91 Loss>> 0.06189956708432554\n", 621 | " Epoch>> 40 and itteration 136 Loss>> 0.06234927545301616\n", 622 | " Epoch>> 40 and itteration 181 Loss>> 0.06159386576992043\n", 623 | " Epoch>> 40 and itteration 226 Loss>> 0.06031691571450339\n", 624 | " Epoch>> 40 and itteration 271 Loss>> 0.060595802213598005\n", 
625 | " Epoch>> 40 and itteration 316 Loss>> 0.06051208599340878\n", 626 | " Epoch>> 40 and itteration 361 Loss>> 0.06145136181646932\n", 627 | " Epoch>> 40 and itteration 406 Loss>> 0.06185557037188208\n", 628 | " Epoch>> 40 and itteration 451 Loss>> 0.061458513739533274\n", 629 | "val_mode\n", 630 | " validation on epoch>> 40 dice loss>> 0.15452070934086692\n", 631 | " Epoch>> 41 and itteration 1 Loss>> 0.03912198916077614\n", 632 | " Epoch>> 41 and itteration 46 Loss>> 0.05562007605381634\n", 633 | " Epoch>> 41 and itteration 91 Loss>> 0.05661109134882361\n", 634 | " Epoch>> 41 and itteration 136 Loss>> 0.057752931901418114\n", 635 | " Epoch>> 41 and itteration 181 Loss>> 0.05950716551578506\n", 636 | " Epoch>> 41 and itteration 226 Loss>> 0.06012223679077836\n", 637 | " Epoch>> 41 and itteration 271 Loss>> 0.059869088741136216\n", 638 | " Epoch>> 41 and itteration 316 Loss>> 0.059420389235255465\n", 639 | " Epoch>> 41 and itteration 361 Loss>> 0.05903769401095581\n", 640 | " Epoch>> 41 and itteration 406 Loss>> 0.059313742031017545\n", 641 | " Epoch>> 41 and itteration 451 Loss>> 0.05913250953239903\n", 642 | "val_mode\n", 643 | " validation on epoch>> 41 dice loss>> 0.14478289365452\n", 644 | " Epoch>> 42 and itteration 1 Loss>> 0.04954291880130768\n", 645 | " Epoch>> 42 and itteration 46 Loss>> 0.057358553066201835\n", 646 | " Epoch>> 42 and itteration 91 Loss>> 0.05634032947185275\n", 647 | " Epoch>> 42 and itteration 136 Loss>> 0.056479945587103855\n", 648 | " Epoch>> 42 and itteration 181 Loss>> 0.05754764411001574\n", 649 | " Epoch>> 42 and itteration 226 Loss>> 0.05610715812154576\n", 650 | " Epoch>> 42 and itteration 271 Loss>> 0.055131187222595586\n", 651 | " Epoch>> 42 and itteration 316 Loss>> 0.05504096161479814\n", 652 | " Epoch>> 42 and itteration 361 Loss>> 0.05585714593694811\n", 653 | " Epoch>> 42 and itteration 406 Loss>> 0.05582418507469699\n", 654 | " Epoch>> 42 and itteration 451 Loss>> 0.056325403325044925\n", 655 | "val_mode\n", 656 | " validation on epoch>> 42 dice loss>> 0.1351057186194167\n", 657 | " Epoch>> 43 and itteration 1 Loss>> 0.057237982749938965\n", 658 | " Epoch>> 43 and itteration 46 Loss>> 0.0656402307845976\n", 659 | " Epoch>> 43 and itteration 91 Loss>> 0.05910175396027146\n", 660 | " Epoch>> 43 and itteration 136 Loss>> 0.05813514297444593\n", 661 | " Epoch>> 43 and itteration 181 Loss>> 0.057465322912562615\n", 662 | " Epoch>> 43 and itteration 226 Loss>> 0.05784972177703032\n", 663 | " Epoch>> 43 and itteration 271 Loss>> 0.05769379498828821\n", 664 | " Epoch>> 43 and itteration 316 Loss>> 0.05719733570667007\n", 665 | " Epoch>> 43 and itteration 361 Loss>> 0.05716557138782624\n", 666 | " Epoch>> 43 and itteration 406 Loss>> 0.05725260248196683\n", 667 | " Epoch>> 43 and itteration 451 Loss>> 0.057393912638965046\n", 668 | "val_mode\n", 669 | " validation on epoch>> 43 dice loss>> 0.14248319152152908\n", 670 | " Epoch>> 44 and itteration 1 Loss>> 0.04875374585390091\n", 671 | " Epoch>> 44 and itteration 46 Loss>> 0.06167426907821842\n", 672 | " Epoch>> 44 and itteration 91 Loss>> 0.06074637145950244\n", 673 | " Epoch>> 44 and itteration 136 Loss>> 0.05821837202700622\n", 674 | " Epoch>> 44 and itteration 181 Loss>> 0.05800442931689939\n", 675 | " Epoch>> 44 and itteration 226 Loss>> 0.05671893154987983\n", 676 | " Epoch>> 44 and itteration 271 Loss>> 0.055853585853512876\n", 677 | " Epoch>> 44 and itteration 316 Loss>> 0.05678048001342936\n", 678 | " Epoch>> 44 and itteration 361 Loss>> 0.05700655480191483\n", 679 | " Epoch>> 44 and 
itteration 406 Loss>> 0.05674053115650938\n", 680 | " Epoch>> 44 and itteration 451 Loss>> 0.056278321024716034\n", 681 | "val_mode\n", 682 | " validation on epoch>> 44 dice loss>> 0.15530629281645\n", 683 | " Epoch>> 45 and itteration 1 Loss>> 0.09440137445926666\n", 684 | " Epoch>> 45 and itteration 46 Loss>> 0.05830118983336117\n", 685 | " Epoch>> 45 and itteration 91 Loss>> 0.05584482134289139\n", 686 | " Epoch>> 45 and itteration 136 Loss>> 0.05437445165315533\n", 687 | " Epoch>> 45 and itteration 181 Loss>> 0.05400324788829569\n", 688 | " Epoch>> 45 and itteration 226 Loss>> 0.05412796204003085\n", 689 | " Epoch>> 45 and itteration 271 Loss>> 0.05342981169804436\n", 690 | " Epoch>> 45 and itteration 316 Loss>> 0.05341213810127936\n", 691 | " Epoch>> 45 and itteration 361 Loss>> 0.05428867298500855\n", 692 | " Epoch>> 45 and itteration 406 Loss>> 0.05467352060958963\n", 693 | " Epoch>> 45 and itteration 451 Loss>> 0.05482614739176439\n", 694 | "val_mode\n", 695 | " validation on epoch>> 45 dice loss>> 0.1467072245589382\n", 696 | " Epoch>> 46 and itteration 1 Loss>> 0.05303681641817093\n", 697 | " Epoch>> 46 and itteration 46 Loss>> 0.054856372713718723\n", 698 | " Epoch>> 46 and itteration 91 Loss>> 0.05333662286892042\n", 699 | " Epoch>> 46 and itteration 136 Loss>> 0.05294356738929363\n", 700 | " Epoch>> 46 and itteration 181 Loss>> 0.053488003501658284\n", 701 | " Epoch>> 46 and itteration 226 Loss>> 0.053084083140489804\n", 702 | " Epoch>> 46 and itteration 271 Loss>> 0.05295949561926931\n", 703 | " Epoch>> 46 and itteration 316 Loss>> 0.05325435338353243\n", 704 | " Epoch>> 46 and itteration 361 Loss>> 0.05346437517832191\n", 705 | " Epoch>> 46 and itteration 406 Loss>> 0.05294467103231717\n", 706 | " Epoch>> 46 and itteration 451 Loss>> 0.05362181119339958\n", 707 | "val_mode\n", 708 | " validation on epoch>> 46 dice loss>> 0.15701383395789328\n", 709 | " Epoch>> 47 and itteration 1 Loss>> 0.038202621042728424\n", 710 | " Epoch>> 47 and itteration 46 Loss>> 0.05354265961796045\n", 711 | " Epoch>> 47 and itteration 91 Loss>> 0.05236436496232892\n", 712 | " Epoch>> 47 and itteration 136 Loss>> 0.05271057191523997\n", 713 | " Epoch>> 47 and itteration 181 Loss>> 0.05182616357068989\n", 714 | " Epoch>> 47 and itteration 226 Loss>> 0.05239239897974561\n", 715 | " Epoch>> 47 and itteration 271 Loss>> 0.05223012228953442\n", 716 | " Epoch>> 47 and itteration 316 Loss>> 0.053495453287481884\n", 717 | " Epoch>> 47 and itteration 361 Loss>> 0.05405835971917307\n" 718 | ] 719 | }, 720 | { 721 | "name": "stdout", 722 | "output_type": "stream", 723 | "text": [ 724 | " Epoch>> 47 and itteration 406 Loss>> 0.0537252500561511\n" 725 | ] 726 | } 727 | ], 728 | "source": [ 729 | "for ep in range(int(config['epochs'])):\n", 730 | " Net.train()\n", 731 | " epoch_loss = 0\n", 732 | " for itter, batch in enumerate(train_loader):\n", 733 | " img = batch['image'].to(device, dtype=torch.float)\n", 734 | " msk = batch['mask'].to(device)\n", 735 | " weak_ann = batch['weak_ann'].to(device)\n", 736 | " boundary = batch['boundary'].to(device)\n", 737 | " mask_type = torch.float32 if Net.n_classes == 1 else torch.long\n", 738 | " msk = msk.to(device=device, dtype=mask_type)\n", 739 | " weak_ann = weak_ann.to(device=device, dtype=mask_type)\n", 740 | " boundary = boundary.to(device=device, dtype=mask_type)\n", 741 | " msk_pred, B, R = Net(img, with_additional=True)\n", 742 | " loss = criteria(msk_pred, msk) \n", 743 | " loss_regions = criteria_region(weak_ann[:,0], R[:,:-1,0])\n", 744 | " loss_boundary = 
criteria_boundary(B, boundary) \n", 745 | " tloss = (0.7*(loss)) + (0.1* loss_regions) + (0.2*loss_boundary)\n", 746 | " optimizer.zero_grad()\n", 747 | " tloss.backward()\n", 748 | " epoch_loss += tloss.item()\n", 749 | " optimizer.step() \n", 750 | " if itter%int(float(config['progress_p']) * len(train_loader))==0:\n", 751 | " print(f' Epoch>> {ep+1} and itteration {itter+1} Loss>> {((epoch_loss/(itter+1)))}')\n", 752 | " ## Validation phase\n", 753 | " with torch.no_grad():\n", 754 | " print('val_mode')\n", 755 | " val_loss = 0\n", 756 | " Net.eval()\n", 757 | " for itter, batch in enumerate(val_loader):\n", 758 | " img = batch['image'].to(device, dtype=torch.float)\n", 759 | " msk = batch['mask'].to(device)\n", 760 | " mask_type = torch.float32 if Net.n_classes == 1 else torch.long\n", 761 | " msk = msk.to(device=device, dtype=mask_type)\n", 762 | " msk_pred = Net(img)\n", 763 | " loss = criteria(msk_pred, msk) \n", 764 | " val_loss += loss.item()\n", 765 | " print(f' validation on epoch>> {ep+1} dice loss>> {(abs(val_loss/(itter+1)))}') \n", 766 | " mean_val_loss = (val_loss/(itter+1))\n", 767 | " # Check the performance and save the model\n", 768 | " if (mean_val_loss) < best_val_loss:\n", 769 | " print('New best loss, saving...')\n", 770 | " best_val_loss = copy.deepcopy(mean_val_loss)\n", 771 | " state = copy.deepcopy({'model_weights': Net.state_dict(), 'val_loss': best_val_loss})\n", 772 | " torch.save(state, config['saved_model'])\n", 773 | "\n", 774 | " scheduler.step(mean_val_loss)\n", 775 | " \n", 776 | "print('Trainng phase finished') " 777 | ] 778 | } 779 | ], 780 | "metadata": { 781 | "kernelspec": { 782 | "display_name": "pytorch_cuda11", 783 | "language": "python", 784 | "name": "pytorch_cuda11" 785 | }, 786 | "language_info": { 787 | "codemirror_mode": { 788 | "name": "ipython", 789 | "version": 3 790 | }, 791 | "file_extension": ".py", 792 | "mimetype": "text/x-python", 793 | "name": "python", 794 | "nbconvert_exporter": "python", 795 | "pygments_lexer": "ipython3", 796 | "version": "3.9.6" 797 | } 798 | }, 799 | "nbformat": 4, 800 | "nbformat_minor": 5 801 | } 802 | -------------------------------------------------------------------------------- /train_skin.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | # In[1]: 5 | 6 | 7 | from __future__ import division 8 | import os 9 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 10 | import torch 11 | import torch.optim as optim 12 | from torch.utils.data import DataLoader 13 | from loader import * 14 | from model.TransMUNet import TransMUNet 15 | import pandas as pd 16 | import glob 17 | import nibabel as nib 18 | import numpy as np 19 | import copy 20 | import yaml 21 | 22 | 23 | # In[2]: 24 | 25 | 26 | ## Loader 27 | ## Hyper parameters 28 | config = yaml.load(open('./config_skin.yml'), Loader=yaml.FullLoader) 29 | number_classes = int(config['number_classes']) 30 | input_channels = 3 31 | best_val_loss = np.inf 32 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 33 | 34 | data_path = config['path_to_data'] 35 | 36 | train_dataset = isic_loader(path_Data = data_path, train = True) 37 | train_loader = DataLoader(train_dataset, batch_size = int(config['batch_size_tr']), shuffle= True) 38 | val_dataset = isic_loader(path_Data = data_path, train = False) 39 | val_loader = DataLoader(val_dataset, batch_size = int(config['batch_size_va']), shuffle= False) 40 | 41 | 42 | # In[3]: 43 | 44 | 45 | Net = TransMUNet(n_classes = number_classes) 
46 | 
47 | Net = Net.to(device)
48 | if int(config['pretrained']):  # optionally resume from an existing checkpoint
49 |     Net.load_state_dict(torch.load(config['saved_model'], map_location='cpu')['model_weights'])
50 |     best_val_loss = torch.load(config['saved_model'], map_location='cpu')['val_loss']
51 | optimizer = optim.Adam(Net.parameters(), lr=float(config['lr']))
52 | scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=int(config['patience']))
53 | criteria = torch.nn.BCELoss()           # main mask loss (binary cross-entropy)
54 | criteria_boundary = torch.nn.BCELoss()  # auxiliary boundary loss
55 | criteria_region = torch.nn.MSELoss()    # auxiliary region (weak-annotation) loss
56 | 
57 | 
58 | # In[ ]:
59 | 
60 | 
61 | for ep in range(int(config['epochs'])):
62 |     Net.train()
63 |     epoch_loss = 0
64 |     for itter, batch in enumerate(train_loader):
65 |         img = batch['image'].to(device, dtype=torch.float)
66 |         msk = batch['mask'].to(device)
67 |         weak_ann = batch['weak_ann'].to(device)
68 |         boundary = batch['boundary'].to(device)
69 |         mask_type = torch.float32 if Net.n_classes == 1 else torch.long
70 |         msk = msk.to(device=device, dtype=mask_type)
71 |         weak_ann = weak_ann.to(device=device, dtype=mask_type)
72 |         boundary = boundary.to(device=device, dtype=mask_type)
73 |         msk_pred, B, R = Net(img, with_additional=True)  # B, R: auxiliary boundary and region outputs
74 |         loss = criteria(msk_pred, msk)
75 |         loss_regions = criteria_region(weak_ann[:,0], R[:,:-1,0])
76 |         loss_boundary = criteria_boundary(B, boundary)
77 |         tloss = 0.7*loss + 0.1*loss_regions + 0.2*loss_boundary  # weighted sum: 0.7 mask + 0.1 region + 0.2 boundary
78 |         optimizer.zero_grad()
79 |         tloss.backward()
80 |         epoch_loss += tloss.item()
81 |         optimizer.step()
82 |         if itter % int(float(config['progress_p']) * len(train_loader)) == 0:
83 |             print(f' Epoch>> {ep+1} and iteration {itter+1} Loss>> {epoch_loss/(itter+1)}')
84 |     ## Validation phase
85 |     with torch.no_grad():
86 |         print('val_mode')
87 |         val_loss = 0
88 |         Net.eval()
89 |         for itter, batch in enumerate(val_loader):
90 |             img = batch['image'].to(device, dtype=torch.float)
91 |             msk = batch['mask'].to(device)
92 |             mask_type = torch.float32 if Net.n_classes == 1 else torch.long
93 |             msk = msk.to(device=device, dtype=mask_type)
94 |             msk_pred = Net(img)
95 |             loss = criteria(msk_pred, msk)
96 |             val_loss += loss.item()
97 |         print(f' validation on epoch>> {ep+1} dice loss>> {abs(val_loss/(itter+1))}')  # mean BCE loss over the validation set
98 |         mean_val_loss = val_loss/(itter+1)
99 |         # Check the performance and save the model
100 |         if mean_val_loss < best_val_loss:
101 |             print('New best loss, saving...')
102 |             best_val_loss = copy.deepcopy(mean_val_loss)
103 |             state = copy.deepcopy({'model_weights': Net.state_dict(), 'val_loss': best_val_loss})
104 |             torch.save(state, config['saved_model'])
105 | 
106 |     scheduler.step(mean_val_loss)  # reduce LR when the validation loss plateaus
107 | 
108 | print('Training phase finished')
109 | 
110 | 
--------------------------------------------------------------------------------
/weights/readme.txt:
--------------------------------------------------------------------------------
1 | Weights will be saved here
--------------------------------------------------------------------------------