├── Figs
│   ├── Network.png
│   └── results.png
├── Preprocessing
│   ├── octaaug.py
│   └── pre_dis.m
├── README.md
├── basic_jointlearning.py
├── criterion.py
├── dataset.py
├── evaluate.py
├── inference
│   ├── aug.py
│   ├── boundry_loss.py
│   └── hausdirff.py
├── losses.py
├── models.py
├── smp_model.py
├── train_cls.py
├── train_seg_clf.py
├── train_smp.py
├── train_smp_y.py
└── utils.py

/Figs/Network.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/llmir/MultitaskOCTA/81d8811cd667402a2be649734358501222a20905/Figs/Network.png
--------------------------------------------------------------------------------
/Figs/results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/llmir/MultitaskOCTA/81d8811cd667402a2be649734358501222a20905/Figs/results.png
--------------------------------------------------------------------------------
/Preprocessing/octaaug.py:
--------------------------------------------------------------------------------
import albumentations as A
# from albumentations.pytorch import ToTensor
import os, shutil
import cv2
import numpy as np
import random


def mkdir(path):
    # strip leading/trailing whitespace
    path = path.strip()
    # strip any trailing backslash
    path = path.rstrip("\\")
    # check whether the path already exists:
    # True if it exists, False otherwise
    isExists = os.path.exists(path)
    if not isExists:
        # the directory does not exist, so create it
        os.makedirs(path)
        print(path + ' created successfully')
        return True
    else:
        # the directory already exists, so do not create it and just report that
        print(path + ' already exists')
        return False


if __name__ == "__main__":

    basename = './dataset3/originaldata/'
    IMG_DIR = basename + "image"
    MASK_DIR = basename + "mask"

    # AUG_MASK_DIR = basename.replace('_ori/','_aug/') + "mask"  # folder for the augmented annotations
    AUG_MASK_DIR = basename + "aug_mask"
    try:
        shutil.rmtree(AUG_MASK_DIR)
    except FileNotFoundError:
        pass
    mkdir(AUG_MASK_DIR)

    # AUG_IMG_DIR = basename.replace('_ori/','_aug/') + "image"  # folder for the augmented images
    AUG_IMG_DIR = basename + "aug_img"
    try:
        shutil.rmtree(AUG_IMG_DIR)
    except FileNotFoundError:
        pass
    mkdir(AUG_IMG_DIR)

    AUGCROP_IMG_DIR = basename + "augcrre_img"
    try:
        shutil.rmtree(AUGCROP_IMG_DIR)
    except FileNotFoundError:
        pass
    mkdir(AUGCROP_IMG_DIR)

    AUGCROP_MA_DIR = basename + "augcrre_mask"
    try:
        shutil.rmtree(AUGCROP_MA_DIR)
    except FileNotFoundError:
        pass
    mkdir(AUGCROP_MA_DIR)

    AUGLOOP = 30  # number of augmented copies generated per image

    aug = A.Compose([
        A.RandomRotate90(),
        # albu.Cutout(),
        A.HorizontalFlip(),
        A.VerticalFlip(),
        A.OneOf([
            A.augmentations.transforms.CLAHE(clip_limit=3),
            # A.augmentations.transforms.Downscale(scale_min=0.45, scale_max=0.95),
            A.augmentations.transforms.GaussNoise(var_limit=(20.0)),
            A.augmentations.transforms.RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.1),
            # A.imgaug.transforms.IAACropAndPad(percent=0.3, pad_mode="reflect"),
            A.imgaug.transforms.IAASharpen(alpha=(0.2, 0.5), lightness=(0.5, 1.0), always_apply=False, p=0.5),
            # A.imgaug.transforms.IAACropAndPad(percent=0.3, pad_mode="reflect"),
            A.augmentations.transforms.RandomGamma(gamma_limit=(80, 120), eps=None, always_apply=False, p=0.5),
            A.imgaug.transforms.IAAAdditiveGaussianNoise(loc=0,
                                                         scale=(2.5500000000000003, 12.75), per_channel=False, always_apply=False, p=0.5)
        ], p=0.6),
        A.OneOf([
            A.augmentations.transforms.Blur(blur_limit=3),
            A.augmentations.transforms.GaussianBlur(blur_limit=3, sigma_limit=0, always_apply=False, p=0.5),
            A.augmentations.transforms.MedianBlur(blur_limit=3, always_apply=False, p=0.5),
            # A.augmentations.transforms.Blur(blur_limit=3, always_apply=False, p=0.5),
            # A.augmentations.transforms.MotionBlur(blur_limit=3),
            # A.augmentations.transforms.GlassBlur(sigma=0.7, max_delta=4, iterations=2, always_apply=False, mode='fast', p=0.5)
            # A.augmentations.transforms.
        ]),
        A.imgaug.transforms.IAAAffine(scale=1.0, translate_percent=0, translate_px=None, rotate=(-90, 90), shear=0.0, order=1, cval=0, mode='reflect'),
        # A.augmentations.geometric.rotate.Rotate(limit=90, interpolation=1, border_mode=4, value=None, mask_value=None, always_apply=False, p=0.5)
        # A.augmentations.transforms.PadIfNeeded(min_height=224, min_width=224, pad_height_divisor=None, pad_width_divisor=None, border_mode=4, value=None, mask_value=None, always_apply=False, p=1.0)
    ], p=1)

    random.seed(2021)  # seed
    size = 192

    for root, sub_folders, files in os.walk(MASK_DIR):

        for name in files:
            print(name)

            '''bndbox = read_xml_annotation(MASK_DIR, name)
            shutil.copy(os.path.join(MASK_DIR, name), AUG_MASK_DIR)
            '''
            # shutil.copy(os.path.join(IMG_DIR, name[:-4] + '.png'), AUG_IMG_DIR)

            for epoch in range(AUGLOOP):
                # seq_det = seq.to_deterministic()  # keep the image and its annotations transformed in sync rather than independently at random
                # read the image and mask
                img = cv2.imread(os.path.join(IMG_DIR, name[:-4] + '.png'), cv2.IMREAD_GRAYSCALE)
                # sp = img.size

                mask = cv2.imread(os.path.join(MASK_DIR, name[:-4] + '.png'), cv2.IMREAD_GRAYSCALE)
                augmented = aug(image=img, mask=mask)

                img_aug = augmented['image']
                mask_aug = augmented['mask']

                ret, mask_aug_thres = cv2.threshold(mask_aug, 127, 255, cv2.THRESH_BINARY)

                img_aug_path = os.path.join(AUG_IMG_DIR, name[:-4] + '_' + str(epoch + 1) + '.png')
                mask_aug_path = os.path.join(AUG_MASK_DIR, name[:-4] + '_' + str(epoch + 1) + '.png')
                cv2.imwrite(img_aug_path, img_aug)
                cv2.imwrite(mask_aug_path, mask_aug_thres)

                if ('_3_' in name) or ('area3' in name):
                    img_re = cv2.resize(img_aug, (size, size), interpolation=cv2.INTER_CUBIC)
                    # NB: INTER_CUBIC can introduce non-binary values into the resized mask
                    mask_re = cv2.resize(mask_aug, (size, size), interpolation=cv2.INTER_CUBIC)
                    # mask_re = cv2.threshold(mask_re, 127, 255, cv2.THRESH_BINARY)
                elif ('_6_' in name) or ('area6' in name):
                    h, w = img_aug.shape
                    # a = w/2-size/2
                    # b = w/2+size/2
                    img_re = img_aug[100:300, 100:300]
                    mask_re = mask_aug_thres[100:300, 100:300]
                    img_re = cv2.resize(img_re, (size, size), interpolation=cv2.INTER_CUBIC)
                    mask_re = cv2.resize(mask_re, (size, size), interpolation=cv2.INTER_CUBIC)
                # note: img_re / mask_re are only defined for files matching the
                # '_3_'/'area3' or '_6_'/'area6' patterns above
                img_re_path = os.path.join(AUGCROP_IMG_DIR, name[:-4] + '_' + str(epoch + 1) + '.png')
                mask_re_path = os.path.join(AUGCROP_MA_DIR, name[:-4] + '_' + str(epoch + 1) + '.png')
                cv2.imwrite(img_re_path, img_re)
                cv2.imwrite(mask_re_path, mask_re)
--------------------------------------------------------------------------------
/Preprocessing/pre_dis.m:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/llmir/MultitaskOCTA/81d8811cd667402a2be649734358501222a20905/Preprocessing/pre_dis.m
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# MultitaskOCTA
This repository is an official PyTorch implementation of the papers:

"BSDA-Net: A Boundary Shape and Distance Aware Joint Learning Framework for Segmenting and Classifying OCTA Images", MICCAI 2021. [**Student Travel Award**] paper link: [BSDA-Net](https://www.researchgate.net/publication/354793161_BSDA-Net_A_Boundary_Shape_and_Distance_Aware_Joint_Learning_Framework_for_Segmenting_and_Classifying_OCTA_Images)

"Multi-task Learning Based Ocular Disease Discrimination and FAZ Segmentation Utilizing OCTA Images", EMBC 2021. paper link: [paper](https://www.researchgate.net/publication/356934116_Multi-task_Learning_Based_Ocular_Disease_Discrimination_and_FAZ_Segmentation_Utilizing_OCTA_Images)

MICCAI 2021

![Network](https://github.com/llmir/MultitaskOCTA/blob/master/Figs/Network.png)

EMBC 2021 [Coming soon]


## Dependencies

### Packages
* Python 3.7
* PyTorch >= 1.7.0
* NumPy
* Scikit-learn
* Segmentation Models Pytorch
* TensorboardX
* OpenCV
* Tqdm
* surface-distance

### Datasets

Our processed datasets are now available here: [AliDrive](https://www.aliyundrive.com/s/eHpKveH3jfH) and [GoogleDrive](https://drive.google.com/drive/folders/1PIlDncAQUCG6-ffINujYOgNSJHdqLVcu?usp=sharing)

### Data Preprocessing
Run *pre_dis.m* in MATLAB to preprocess the images and generate the *Boundary Heatmaps* and *Signed Distance Maps (SDMs)* used for training BSDA-Net.

Run *octaaug.py* to augment the OCTA images automatically using the preset directory values; the augmented images are stored in a local directory.

## Directory Structure
```bash
├── contour
│   └── 1.png
│   └── 2.png
│   ...
├── dist_contour
│   └── 1.mat
│   └── 2.mat
│   ...
├── dist_mask
│   └── 1.mat
│   └── 2.mat
│   ...
├── dist_signed_01
│   └── 1.mat
│   └── 2.mat
│   ...
├── dist_signed_11 (used in the MICCAI paper)
│   └── 1.mat
│   └── 2.mat
│   ...
├── image
│   └── 1.jpg
│   └── 2.jpg
│   ...
├── mask
│   └── 1.jpg
│   └── 2.jpg
│   ...
```

## Training Code
To start training, you should set the parameters used for training (a minimal example invocation is sketched after this list):
* train_path: Training image path.
* val_path: Validation image path.
* test_path: Testing image path.
* save_path: Path for saving results.
* train-type: Training type, including single classification & segmentation, cotraining or multitask.
* model_type: Used for single segmentation, cotraining or multitask. The segmentation architecture used for training.
* batch_size: Batch size for the training stage.
* val_batch_size: Batch size for validation.
* num_epochs: Total number of epochs for the training stage.
* use_pretrained: Whether to use weights pretrained on ImageNet.
* loss_type: Loss used for the training stage.
* LR_seg: Learning rate for the segmentation branch.
* LR_clf: Learning rate for the classification branch.
* classnum: Used for single classification, cotraining or multitask. Class number for classification.
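
A minimal sketch of a training invocation is shown below. The flag spellings and values here are illustrative, not verified against the argument parser (see *utils.create_train_arg_parser* for the exact names); *unet_smp* is taken from the model list in *basic_jointlearning.py*:
```bash
python train_seg_clf.py \
    --train_path ./dataset/train/image \
    --val_path ./dataset/val/image \
    --test_path ./dataset/test/image \
    --save_path ./results \
    --model_type unet_smp \
    --batch_size 8 \
    --val_batch_size 4 \
    --num_epochs 200 \
    --loss_type dice \
    --LR_seg 1e-4 \
    --LR_clf 1e-4 \
    --classnum 3
```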

To quickly start training, you can use our preset shell script *Demo.sh* with the prepared dataset stored in a local path, or set the parameters listed above to define your own training configuration. The results will be stored under the local path of your dataset, in a folder named *model_type+loss_type*.

## Testing Code
To start testing, you should set the same parameters that were used for training, so that the model file is loaded correctly:
* train_path: Training image path.
* val_path: Validation image path.
* test_path: Testing image path.
* save_path: Path for saving results.
* train-type: Training type, including single classification & segmentation, cotraining or multitask.
* model_type: Used for single segmentation, cotraining or multitask. The segmentation architecture used for training.
* val_batch_size: Batch size for validation.
* use_pretrained: Whether to use weights pretrained on ImageNet.
* loss_type: Loss used for the training stage.

## Results
From left to right: FAZ segmentation results produced by different models. The bottom row shows the corresponding boundary heatmaps and signed distance maps of the ground truth.

![Results](https://github.com/llmir/MultitaskOCTA/blob/master/Figs/results.png)

## Citation
L. Lin, Z. Wang, J. Wu, Y. Huang, J. Lyu, P. Cheng, J. Wu, X. Tang*, "BSDA-Net: a boundary shape and distance aware joint learning framework for segmenting and classifying OCTA images", In the 24th International Conference on Medical Image Computing and Computer Assisted Intervention (MICCAI), Strasbourg, France, September 2021.
## Contact
Li Lin (linli@eee.hku.hk)

Zhonghua Wang (Wzhjerry1112@gmail.com)

## Acknowledgements
Thanks to [segmentation models pytorch](https://github.com/qubvel/segmentation_models.pytorch) for the implementation of the segmentation codes.
--------------------------------------------------------------------------------
/basic_jointlearning.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
import time
from smp_model_basic import MyUnetModel, my_get_encoder
import torch
import os
import glob
from torch.optim import Adam
from tqdm import tqdm
import logging
from torch import nn
import random
from tensorboardX import SummaryWriter
from utils import create_train_arg_parser, define_loss
import segmentation_models_pytorch as smp
import numpy as np
from sklearn.metrics import cohen_kappa_score
from utils import generate_dataset

IN_MODELS = ['unet_smp', 'unet++', 'manet', 'linknet', 'fpn', 'pspnet', 'pan', 'deeplabv3', 'deeplabv3+']


def set_seed(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU
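# Note: seeding alone does not make GPU training fully deterministic; cuDNN
# would additionally need torch.backends.cudnn.deterministic = True and
# benchmark = False. main() below deliberately runs with benchmark=True and
# deterministic=False (via torch.backends.cudnn.flags), trading exact
# reproducibility for speed.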


def segmentation_iteration(model, optimizer, model_type, criterion, data_loader, device, writer, training=False):
    running_loss = 0.0
    total_size = 0

    if training:
        model.train()
        torch.set_grad_enabled(True)
    else:
        model.eval()
        torch.set_grad_enabled(False)

    for i, (img_file_name, inputs, targets1, targets2, targets3, targets4) in enumerate(tqdm(data_loader)):
        inputs = inputs.to(device)
        targets1, targets2 = targets1.to(device), targets2.to(device)
        targets3, targets4 = targets3.to(device), targets4.to(device)
        targets = [targets1, targets2, targets3, targets4]

        if training:
            optimizer.zero_grad()

        outputs = model(inputs)

        if model_type in IN_MODELS + ["unet"]:
            if not isinstance(outputs, list):
                outputs = [outputs]
            loss = criterion(outputs[0], targets[0])
            dsc_loss = smp.utils.losses.DiceLoss()
            # the Dice metric is only implemented for this branch, as an example
            # preds = torch.argmax(outputs[0].exp(), dim=1)
            preds = torch.argmax(torch.sigmoid(outputs[0]), dim=1)
            dsc = 1 - dsc_loss(preds, targets[0].squeeze(1))

        elif model_type == "dcan":
            loss = criterion(outputs[0], outputs[1], targets[0], targets[1])

        elif model_type == "dmtn":
            loss = criterion(outputs[0], outputs[1], targets[0], targets[2])

        elif model_type in ["psinet", "convmcd"]:
            loss = criterion(
                outputs[0], outputs[1], outputs[2], targets[0], targets[1], targets[2]
            )
        else:
            raise ValueError('error')

        if training:
            # with amp.scale_loss(loss, optimizer) as scaled_loss:
            #     scaled_loss.backward()
            loss.backward()
            optimizer.step()
            # scheduler.step()

        running_loss += loss.item() * inputs.size(0)
        total_size += inputs.size(0)

    epoch_loss = running_loss / total_size
    # print("total size:", total_size, training)

    # note: dsc holds the value from the last batch only, and is only defined
    # for the model types handled in the first branch above
    return epoch_loss, dsc


class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)


def seg_clf_iteration(epoch, model, optimizer, criterion, data_loader, device, writer, loss_weights, startpoint, training=False):
    seg_losses = AverageMeter("Loss", ".16f")
    seg_dices = AverageMeter("Dice", ".8f")
    seg_jaccards = AverageMeter("Jaccard", ".8f")
    clf_losses = AverageMeter("Loss", ".16f")
    clf_accs = AverageMeter("Acc", ".8f")
    clf_kappas = AverageMeter("Kappa", ".8f")

    if training:
        model.train()
        torch.set_grad_enabled(True)
    else:
        model.eval()
        torch.set_grad_enabled(False)

    for i, (img_file_name, inputs, targets1, targets2, targets3, targets4) in enumerate(tqdm(data_loader)):
        inputs = inputs.to(device)
        targets1, targets2 = targets1.to(device), targets2.to(device)
        targets3, targets4 = targets3.to(device), targets4.to(device)
        targets = [targets1, targets2, targets3, targets4]

        if training:
            optimizer.zero_grad()

        seg_outputs = 
model.seg_forward(inputs) 141 | # seg_preds = torch.argmax(seg_outputs[0].exp(), dim=1) 142 | if not isinstance(seg_outputs, list): 143 | seg_outputs = [seg_outputs] 144 | 145 | seg_preds = torch.round(seg_outputs[0]) 146 | clf_outputs = model.clf_forward(inputs, seg_outputs[1]) 147 | 148 | # print(seg_criterion)j 149 | # print(seg_outputs[0].shape, targets[0].shape) 150 | seg_criterion, dice_criterion, jaccard_criterion, clf_criterion = criterion[0], criterion[1], criterion[2], criterion[3] 151 | seg_loss = seg_criterion(seg_outputs[0], targets[0].to(torch.float32)) 152 | # print(seg_preds.shape, targets[0].squeeze(1).shape) 153 | seg_dice = 1 - dice_criterion(seg_preds.squeeze(1), targets[0].squeeze(1)) 154 | seg_jaccard = 1 - jaccard_criterion(seg_preds.squeeze(1), targets[0].squeeze(1)) 155 | # seg_iou = smp.utils.metrics.IoU(threshold=0.5) 156 | 157 | # print(clf_outputs.shape, targets[3].shape) 158 | clf_labels = torch.argmax(targets[3], dim=2).squeeze(1) 159 | clf_preds = torch.argmax(clf_outputs, dim=1) 160 | clf_loss = clf_criterion(clf_outputs, clf_labels) 161 | kappa = cohen_kappa_score(clf_preds.detach().cpu().numpy(), clf_labels.detach().cpu().numpy()) 162 | # print(targets[3]) 163 | # print(clf_labels) 164 | # print(clf_preds) 165 | acc = np.mean(clf_labels.detach().cpu().numpy() == clf_preds.detach().cpu().numpy()) 166 | 167 | if training: 168 | if epoch <= startpoint: 169 | loss = seg_loss 170 | else: 171 | loss = (seg_loss + clf_loss) 172 | loss.backward() 173 | # with amp.scale_loss(loss, optimizer) as scaled_loss: 174 | # scaled_loss.backward() 175 | optimizer.step() 176 | # scheduler.step() 177 | 178 | seg_losses.update(seg_loss.item(), inputs.size(0)) 179 | seg_dices.update(seg_dice.item(), inputs.size(0)) 180 | seg_jaccards.update(seg_jaccard.item(), inputs.size(0)) 181 | clf_losses.update(clf_loss.item(), inputs.size(0)) 182 | clf_accs.update(acc, inputs.size(0)) 183 | clf_kappas.update(kappa, inputs.size(0)) 184 | 185 | seg_epoch_loss = seg_losses.avg 186 | seg_epoch_dice = seg_dices.avg 187 | seg_epoch_jaccard = seg_jaccards.avg 188 | clf_epoch_loss = clf_losses.avg 189 | clf_epoch_acc = clf_accs.avg 190 | clf_epoch_kappa = clf_kappas.avg 191 | # clf_epoch_loss = clf_running_loss / total_size 192 | # print("total size:", total_size, training, seg_epoch_loss) 193 | 194 | return seg_epoch_loss, seg_epoch_dice, seg_epoch_jaccard, clf_epoch_loss, clf_epoch_acc, clf_epoch_kappa 195 | 196 | 197 | class CotrainingModel(nn.Module): 198 | 199 | def __init__(self, encoder, pretrain, usenorm, attention_type, classnum): 200 | super().__init__() 201 | self.seg_model = MyUnetModel( 202 | encoder_name=encoder, encoder_depth=5, encoder_weights=pretrain, decoder_use_batchnorm=usenorm, 203 | decoder_channels=(256, 128, 64, 32, 16), decoder_attention_type=attention_type, in_channels=1, classes=1, 204 | activation='sigmoid', aux_params=None 205 | ) 206 | self.clf_model = my_get_encoder(encoder, in_channels=1, depth=5, weights=pretrain, num_classes=classnum) 207 | 208 | def seg_forward(self, x): 209 | return self.seg_model(x) 210 | 211 | def clf_forward(self, x, decoder_features): 212 | return self.clf_model(x, decoder_features) 213 | 214 | 215 | def main(): 216 | with torch.backends.cudnn.flags(enabled=True, benchmark=True, deterministic=False, allow_tf32=False): 217 | torch.set_num_threads(4) 218 | set_seed(2021) 219 | 220 | args = create_train_arg_parser().parse_args() 221 | CUDA_SELECT = "cuda:{}".format(args.cuda_no) 222 | 223 | log_path = os.path.join(args.save_path, "summary/") 
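        # SummaryWriter below writes TensorBoard event files under
        # <save_path>/summary; inspect them with: tensorboard --logdir <log_path>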
224 | writer = SummaryWriter(log_dir=log_path) 225 | rq = time.strftime('%Y%m%d%H%M', time.localtime(time.time())) 226 | log_name = os.path.join(log_path, str(rq) + '.log') 227 | logging.basicConfig( 228 | filename=log_name, 229 | filemode="a", 230 | format="%(asctime)s %(levelname)s %(message)s", 231 | datefmt="%Y-%m-%d %H:%M", 232 | level=logging.INFO, 233 | ) 234 | logging.info(args) 235 | 236 | device = torch.device(CUDA_SELECT if torch.cuda.is_available() else "cpu") 237 | 238 | encoder = args.encoder 239 | usenorm = args.usenorm 240 | attention_type = args.attention 241 | if args.pretrain in ['imagenet', 'ssl', 'swsl']: 242 | pretrain = args.pretrain 243 | # preprocess_input = get_preprocessing_fn(encoder, pretrain) 244 | else: 245 | pretrain = None 246 | # preprocess_input = get_preprocessing_fn(encoder) 247 | model = CotrainingModel(encoder, pretrain, usenorm, attention_type, args.classnum).to(device) 248 | logging.info(model) 249 | # seg_criterion = smp.utils.losses.DiceLoss() 250 | # seg_dice_criterion = smp.utils.losses.DiceLoss() 251 | # clf_criterion = smp.utils.losses.CrossEntropyLoss() 252 | criterion = [ 253 | define_loss(args.loss_type), 254 | smp.utils.losses.DiceLoss(), 255 | smp.utils.losses.JaccardLoss(), 256 | smp.utils.losses.CrossEntropyLoss() 257 | ] 258 | 259 | optimizer = Adam([ 260 | {"params": model.seg_model.parameters(), "lr": args.LR_seg}, 261 | {"params": model.clf_model.parameters(), "lr": args.LR_clf} 262 | ]) 263 | 264 | # model, optimizer = amp.initialize(model, optimizer, opt_level='O1') 265 | 266 | train_file_names = glob.glob(os.path.join(args.train_path, "*.png")) 267 | random.shuffle(train_file_names) 268 | val_file_names = glob.glob(os.path.join(args.val_path, "*.png")) 269 | 270 | # train_dataset = DatasetImageMaskContourDist(train_file_names, args.distance_type) 271 | # valid_dataset = DatasetImageMaskContourDist(val_file_names, args.distance_type) 272 | # trainLoader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=8) 273 | # devLoader = DataLoader(valid_dataset, batch_size=args.val_batch_size, num_workers=4) 274 | trainLoader, devLoader = generate_dataset(train_file_names, val_file_names, args.batch_size, args.val_batch_size, args.distance_type, args.clahe) 275 | 276 | epoch_start = 0 277 | max_dice = 0.8 278 | max_acc = 0.6 279 | loss_weights = [1, 1, 1] 280 | startpoint = args.startpoint 281 | 282 | for epoch in range(epoch_start + 1, epoch_start + 1 + args.num_epochs): 283 | 284 | print('\nEpoch: {}'.format(epoch)) 285 | training_seg_loss, training_seg_dice, training_seg_jaccard, training_clf_loss, training_clf_acc, training_clf_kappa = seg_clf_iteration(epoch, model, optimizer, criterion, trainLoader, device, writer, loss_weights, startpoint, training=True) 286 | dev_seg_loss, dev_seg_dice, dev_seg_jaccard, dev_clf_loss, dev_clf_acc, dev_clf_kappa = seg_clf_iteration(epoch, model, optimizer, criterion, devLoader, device, writer, loss_weights, startpoint, training=False) 287 | 288 | epoch_info = "Epoch: {}".format(epoch) 289 | train_info = "TrainSeg Loss:{:.7f}, Dice: {:.7f}, Jaccard: {:.7f}, TrainClf Loss:{:.7f}, Acc: {:.7f}, Kappa:{:.7f}".format(training_seg_loss, training_seg_dice, training_seg_jaccard, training_clf_loss, training_clf_acc, training_clf_kappa) 290 | val_info = "ValSeg Loss:{:.7f}, Dice: {:.7f}, Jaccard: {:.7f}, ValClf Loss:{:.7f}, Acc: {:.7f}, Kappa:{:.7f}:".format(dev_seg_loss, dev_seg_dice, dev_seg_jaccard, dev_clf_loss, dev_clf_acc, dev_clf_kappa) 291 | print(train_info) 292 | print(val_info) 293 | 
logging.info(epoch_info) 294 | logging.info(train_info) 295 | logging.info(val_info) 296 | writer.add_scalar("trainseg_loss", training_seg_loss, epoch) 297 | writer.add_scalar("trainseg_dice", training_seg_dice, epoch) 298 | writer.add_scalar("trainseg_jaccard", training_seg_jaccard, epoch) 299 | writer.add_scalar("traincls_loss", training_clf_loss, epoch) 300 | writer.add_scalar("traincls_acc", training_clf_acc, epoch) 301 | writer.add_scalar("traincls_kappa", training_clf_kappa, epoch) 302 | 303 | writer.add_scalar("valseg_loss", dev_seg_loss, epoch) 304 | writer.add_scalar("valseg_dice", dev_seg_dice, epoch) 305 | writer.add_scalar("valseg_jaccard", dev_seg_jaccard, epoch) 306 | writer.add_scalar("valcls_loss", dev_clf_loss, epoch) 307 | writer.add_scalar("valcls_acc", dev_clf_acc, epoch) 308 | writer.add_scalar("valcls_kappa", dev_clf_kappa, epoch) 309 | 310 | best_name = os.path.join(args.save_path, "dice_" + str(round(dev_seg_dice, 5)) + "_jaccard_" + str(round(dev_seg_jaccard, 5)) + "_acc_" + str(round(dev_clf_acc, 4)) + "_kap_" + str(round(dev_clf_kappa, 4)) + ".pt") 311 | save_name = os.path.join(args.save_path, str(epoch) + "_dice_" + str(round(dev_seg_dice, 5)) + "_jaccard_" + str(round(dev_seg_jaccard, 5)) + "_acc_" + str(round(dev_clf_acc, 4)) + "_kap_" + str(round(dev_clf_kappa, 4)) + ".pt") 312 | 313 | if max_dice <= dev_seg_dice: 314 | max_dice = dev_seg_dice 315 | if torch.cuda.device_count() > 1: 316 | torch.save(model.module.state_dict(), best_name) 317 | else: 318 | torch.save(model.state_dict(), best_name) 319 | print('Best seg model saved!') 320 | logging.warning('Best seg model saved!') 321 | if max_acc <= dev_clf_acc: 322 | max_acc = dev_clf_acc 323 | if torch.cuda.device_count() > 1: 324 | torch.save(model.module.state_dict(), best_name) 325 | else: 326 | torch.save(model.state_dict(), best_name) 327 | print('Best clf model saved!') 328 | logging.warning('Best clf model saved!') 329 | 330 | if epoch % 50 == 0: 331 | if torch.cuda.device_count() > 1: 332 | torch.save(model.module.state_dict(), save_name) 333 | print('Epoch {} model saved!'.format(epoch)) 334 | else: 335 | torch.save(model.state_dict(), save_name) 336 | print('Epoch {} model saved!'.format(epoch)) 337 | 338 | 339 | if __name__ == "__main__": 340 | main() 341 | -------------------------------------------------------------------------------- /criterion.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Wed Aug 14 10:40:54 2019 5 | 6 | @author: wujon 7 | """ 8 | 9 | from __future__ import absolute_import 10 | from __future__ import division 11 | from __future__ import print_function 12 | 13 | import numpy as np 14 | import surface_distance 15 | # import nibabel as ni 16 | import scipy.io 17 | import scipy.spatial 18 | import xlwt 19 | import os 20 | import cv2 21 | from skimage import morphology 22 | from skimage.morphology import thin 23 | from sklearn.metrics import confusion_matrix, jaccard_score, f1_score 24 | 25 | 26 | os.chdir('./') 27 | 28 | 29 | predictName = 'cotrain_192_pad' 30 | predictPath = './smp/' + predictName + '/' 31 | labelPath = "./smp/mask_ori_f1/" 32 | name_experiment = 'exp_test' 33 | path_experiment = './' + name_experiment + '/' 34 | 35 | # labelPath = "./gt/" 36 | # outpredictPath = "./gt_poor_o_thin/" 37 | # outlabelPath = "./gt_o_thin/" 38 | 39 | 40 | def getDSC(testImage, resultImage): 41 | """Compute the Dice Similarity Coefficient.""" 42 | testArray = 
testImage.flatten() 43 | resultArray = resultImage.flatten() 44 | 45 | return 1.0 - scipy.spatial.distance.dice(testArray, resultArray) 46 | 47 | 48 | def getJaccard(testImage, resultImage): 49 | """Compute the Dice Similarity Coefficient.""" 50 | testArray = testImage.flatten() 51 | resultArray = resultImage.flatten() 52 | 53 | return 1.0 - scipy.spatial.distance.jaccard(testArray, resultArray) 54 | 55 | 56 | def getPrecisionAndRecall(testImage, resultImage): 57 | testArray = testImage.flatten() 58 | resultArray = resultImage.flatten() 59 | 60 | TP = np.sum(testArray*resultArray) 61 | FP = np.sum((1-testArray)*resultArray) 62 | FN = np.sum(testArray*(1-resultArray)) 63 | 64 | precision = TP/(TP+FP) 65 | recall = TP/(TP+FN) 66 | 67 | return precision, recall 68 | 69 | 70 | def intersection(testImage, resultImage): 71 | testSkel = morphology.skeletonize(testImage) 72 | testSkel = testSkel.astype(int) 73 | resultSkel = morphology.skeletonize(resultImage) 74 | resultSkel = resultSkel.astype(int) 75 | 76 | testArray = testImage.flatten() 77 | resultArray = resultImage.flatten() 78 | 79 | testSkel = testSkel.flatten() 80 | resultSkel = resultSkel.flatten() 81 | 82 | recall = np.sum(resultSkel * testArray) / (np.sum(testSkel)) 83 | precision = np.sum(resultArray * testSkel) / (np.sum(testSkel)) 84 | 85 | intersection = 2 * precision * recall / (precision + recall) 86 | return intersection 87 | 88 | 89 | if __name__ == "__main__": 90 | labelList = os.listdir(labelPath) 91 | # labelList.sort(key = lambda x: int(x[:-4])) 92 | img_nums = len(labelList) 93 | 94 | Q1 = [] 95 | Q2 = [] 96 | Q3 = [] 97 | Q4 = [] 98 | Q5 = [] 99 | Q6 = [] 100 | Q7 = [] 101 | Q8 = [] 102 | Q9 = [] 103 | Q10 = [] 104 | Q11 = [] 105 | Q12 = [] 106 | Q13 = [] 107 | Q14 = [] 108 | Q15 = [] 109 | Q16 = [] 110 | Q17 = [] 111 | 112 | book = xlwt.Workbook(encoding='utf-8', style_compression=0) 113 | sheet = book.add_sheet('mysheet', cell_overwrite_ok=True) 114 | row_num = 0 115 | sheet.write(row_num, 0, 'CaseName') 116 | sheet.write(row_num, 1, 'DSC') 117 | sheet.write(row_num, 2, 'Pre') 118 | sheet.write(row_num, 3, 'Recall') 119 | sheet.write(row_num, 4, 'HD') 120 | sheet.write(row_num, 5, 'ASSD') 121 | sheet.write(row_num, 6, 'surface_dice_0') 122 | sheet.write(row_num, 7, 'rel_overlap_gt') 123 | sheet.write(row_num, 8, 'rel_overlap_pred') 124 | sheet.write(row_num, 9, 'intersec') 125 | sheet.write(row_num, 10, 'HD_thin') 126 | sheet.write(row_num, 11, 'ASSD_thin') 127 | sheet.write(row_num, 12, 'surface_dice_1') 128 | sheet.write(row_num, 13, 'surface_dice_2') 129 | sheet.write(row_num, 14, 'Jaccard') 130 | sheet.write(row_num, 15, 'acc') 131 | sheet.write(row_num, 16, 'spe') 132 | sheet.write(row_num, 17, 'sen') 133 | 134 | for idx, filename in enumerate(labelList): 135 | label = cv2.imread(labelPath + filename, 0) 136 | # print (label.dtype) 137 | 138 | # label = cv2.imread(labelPath + filename) 139 | label[label < 50] = 0 140 | label[label >= 50] = 1 141 | 142 | thinned_label = thin(label) 143 | # cv2.imwrite(outlabelPath+filename,(thinned_label*255).astype(np.uint8)) 144 | # ret,label = cv2.threshold(label,127,255,cv2.THRESH_BINARY) 145 | predict = cv2.imread(predictPath + filename.replace('_manual.png', '_expert.png'), 0) 146 | # print(predictPath + filename) 147 | # print (predict.dtype) 148 | # ret,predict = cv2.threshold(predict,127,255,cv2.THRESH_BINARY) 149 | # predict = cv2.imread(predictPath + filename) 150 | # predict = predict / 255 151 | predict[predict < 127] = 0 152 | predict[predict >= 127] = 1 153 | 154 | # 
============================================================================================================================================================================== 155 | y_scores = cv2.imread(predictPath + filename.replace('_manual.png', '_expert.png'), 0) # ##################################################################### 156 | y_scores = np.asarray(y_scores.flatten())/255. 157 | y_scores = y_scores[:, np.newaxis] 158 | # print(y_scores.shape) 159 | y_true = cv2.imread(labelPath + filename, 0) 160 | y_true = np.asarray(y_true.flatten())/255. 161 | 162 | # fpr, tpr, thresholds = roc_curve((y_true), y_scores) 163 | # AUC_ROC = roc_auc_score(y_true, y_scores) 164 | # # test_integral = np.trapz(tpr,fpr) #trapz is numpy integration 165 | # print ("\nArea under the ROC curve: " +str(AUC_ROC)) 166 | # roc_curve =plt.figure() 167 | # plt.plot(fpr,tpr,'-',label='Area Under the Curve (AUC = %0.4f)' % AUC_ROC) 168 | # plt.title('ROC curve') 169 | # plt.xlabel("FPR (False Positive Rate)") 170 | # plt.ylabel("TPR (True Positive Rate)") 171 | # plt.legend(loc="lower right") 172 | # plt.savefig(path_experiment+"ROC.png") 173 | # precision, recall, thresholds = precision_recall_curve(y_true, y_scores) 174 | # precision = np.fliplr([precision])[0] #so the array is increasing (you won't get negative AUC) 175 | # recall = np.fliplr([recall])[0] #so the array is increasing (you won't get negative AUC) 176 | # AUC_prec_rec = np.trapz(precision,recall) 177 | # print ("\nArea under Precision-Recall curve: " +str(AUC_prec_rec)) 178 | # prec_rec_curve = plt.figure() 179 | # plt.plot(recall,precision,'-',label='Area Under the Curve (AUC = %0.4f)' % AUC_prec_rec) 180 | # plt.title('Precision - Recall curve') 181 | # plt.xlabel("Recall") 182 | # plt.ylabel("Precision") 183 | # plt.legend(loc="lower right") 184 | # plt.savefig(path_experiment+"Precision_recall.png") 185 | 186 | # def best_f1_threshold(precision, recall, thresholds): 187 | # best_f1=-1 188 | # for index in range(len(precision)): 189 | # curr_f1=2.*precision[index]*recall[index]/(precision[index]+recall[index]) 190 | # if best_f1= threshold_confusion: 209 | y_pred[i] = 1 210 | else: 211 | y_pred[i] = 0 212 | # print(np.unique(y_pred)) 213 | # print(np.unique(y_true)) 214 | confusion = confusion_matrix(y_true, y_pred) 215 | # print (confusion) 216 | accuracy = 0 217 | if float(np.sum(confusion)) != 0: 218 | accuracy = float(confusion[0, 0]+confusion[1, 1])/float(np.sum(confusion)) 219 | # print ("Global Accuracy: " +str(accuracy)) 220 | specificity = 0 221 | if float(confusion[0, 0]+confusion[0, 1]) != 0: # 00 tn 11 tp 10 fn 01 fp 222 | specificity = float(confusion[0, 0])/float(confusion[0, 0]+confusion[0, 1]) 223 | # print ("Specificity: " +str(specificity)) 224 | sensitivity = 0 225 | if float(confusion[1, 1]+confusion[1, 0]) != 0: 226 | sensitivity = float(confusion[1, 1])/float(confusion[1, 1]+confusion[1, 0]) 227 | # print ("Sensitivity: " +str(sensitivity)) 228 | precision = 0 229 | if float(confusion[1, 1]+confusion[0, 1]) != 0: 230 | precision = float(confusion[1, 1])/float(confusion[1, 1]+confusion[0, 1]) 231 | # print ("Precision: " +str(precision)) 232 | 233 | if float(confusion[1, 1]+confusion[0, 1]) != 0: 234 | PPV = float(confusion[1, 1])/float(confusion[1, 1]+confusion[0, 1]) 235 | # print ("PPV: " +str(PPV)) 236 | 237 | # Jaccard similarity index 238 | jaccard_index = jaccard_score(y_true, y_pred) 239 | print("\nJaccard similarity score: " + str(jaccard_index)) 240 | 241 | # F1 score 242 | F1_score = f1_score(y_true, 
y_pred, labels=None, average='binary', sample_weight=None) 243 | # print ("\nF1 score (F-measure): " +str(F1_score)) 244 | 245 | # Save the results 246 | # file_perf = open(path_experiment+'performances.txt', 'w') 247 | # # file_perf.write("Area under the ROC curve: "+str(AUC_ROC) 248 | # # + "\nArea under Precision-Recall curve: " +str(AUC_prec_rec) 249 | # # + "\nJaccard similarity score: " +str(jaccard_index) 250 | # # + "\nF1 score (F-measure): " +str(F1_score) 251 | # # +"\n\nConfusion matrix:" 252 | # # +str(confusion) 253 | # # +"\nACCURACY: " +str(accuracy) 254 | # # +"\nSENSITIVITY: " +str(sensitivity) 255 | # # +"\nSPECIFICITY: " +str(specificity) 256 | # # +"\nPRECISION: " +str(precision) 257 | # # +"\nRECALL: " +str(sensitivity) 258 | # # +"\nPPV: " +str(PPV) 259 | # # +"\nbest_th: " +str(best_threshold) 260 | # # +"\nbest_f1: " +str(best_f1) 261 | # # ) 262 | # file_perf.write( 263 | # "\nJaccard similarity score: " +str(jaccard_index) 264 | # + "\nF1 score (F-measure): " +str(F1_score) 265 | # +"\n\nConfusion matrix:" 266 | # +str(confusion) 267 | # +"\nACCURACY: " +str(accuracy) 268 | # +"\nSENSITIVITY: " +str(sensitivity) 269 | # +"\nSPECIFICITY: " +str(specificity) 270 | # +"\nPRECISION: " +str(precision) 271 | # +"\nRECALL: " +str(sensitivity) 272 | # +"\nPPV: " +str(PPV) 273 | # ) 274 | # file_perf.close() 275 | # #============================================================================================================================================================================== 276 | 277 | thinned_predict = thin(predict) 278 | # cv2.imwrite(outpredictPath+filename,(thinned_predict*255).astype(np.uint8)) 279 | 280 | # predict[predict>=1] = 1 281 | # dice = getDSC(predict, label) 282 | # print("filename:" , filename , "dice:" , dice) 283 | # dice_res = "the " + filename[:-4] + " image's DSC : " + str(round(dice,4)) + "\n" 284 | DSC = getDSC(label, predict) 285 | # surface_distances = surface_distance.compute_surface_distances(label, predict, spacing_mm=(1, 1, 1)) 286 | # HD = surface_distance.compute_robust_hausdorff(surface_distances, 95) 287 | 288 | # distances_gt_to_pred = surface_distances["distances_gt_to_pred"] 289 | # distances_pred_to_gt = surface_distances["distances_pred_to_gt"] 290 | # surfel_areas_gt = surface_distances["surfel_areas_gt"] 291 | # surfel_areas_pred = surface_distances["surfel_areas_pred"] 292 | 293 | # ASSD = (np.sum(distances_pred_to_gt * surfel_areas_pred) +np.sum(distances_gt_to_pred * surfel_areas_gt))/(np.sum(surfel_areas_gt)+np.sum(surfel_areas_pred)) 294 | Jaccard = getJaccard(label, predict) 295 | 296 | precision, recall = getPrecisionAndRecall(label, predict) 297 | intersec = intersection(label, predict) 298 | 299 | label = np.array(label, dtype=bool) 300 | predict = np.array(predict, dtype=bool) 301 | 302 | surface_distances = surface_distance.compute_surface_distances(label, predict, spacing_mm=(1, 1)) 303 | 304 | surface_distances_thin = surface_distance.compute_surface_distances(thinned_label, thinned_predict, spacing_mm=(1, 1)) 305 | 306 | HD = surface_distance.compute_robust_hausdorff(surface_distances, 95) 307 | 308 | HD_thin = surface_distance.compute_robust_hausdorff(surface_distances_thin, 95) 309 | 310 | surface_dice_2 = surface_distance.compute_surface_dice_at_tolerance(surface_distances, 2) 311 | rel_overlap_gt, rel_overlap_pred = surface_distance.compute_surface_overlap_at_tolerance(surface_distances, 2) 312 | surface_dice_1 = surface_distance.compute_surface_dice_at_tolerance(surface_distances, 1) 313 | 
surface_dice_0 = surface_distance.compute_surface_dice_at_tolerance(surface_distances, 0) 314 | surface_dice_3 = surface_distance.compute_surface_dice_at_tolerance(surface_distances, 3) 315 | 316 | distances_gt_to_pred = surface_distances["distances_gt_to_pred"] 317 | distances_pred_to_gt = surface_distances["distances_pred_to_gt"] 318 | surfel_areas_gt = surface_distances["surfel_areas_gt"] 319 | surfel_areas_pred = surface_distances["surfel_areas_pred"] 320 | 321 | ASSD = (np.sum(distances_pred_to_gt * surfel_areas_pred) + np.sum(distances_gt_to_pred * surfel_areas_gt))/(np.sum(surfel_areas_gt)+np.sum(surfel_areas_pred)) 322 | 323 | distances_gt_to_pred_t = surface_distances_thin["distances_gt_to_pred"] 324 | distances_pred_to_gt_t = surface_distances_thin["distances_pred_to_gt"] 325 | surfel_areas_gt_t = surface_distances_thin["surfel_areas_gt"] 326 | surfel_areas_pred_t = surface_distances_thin["surfel_areas_pred"] 327 | 328 | ASSD_thin = (np.sum(distances_pred_to_gt_t * surfel_areas_pred_t) + np.sum(distances_gt_to_pred_t * surfel_areas_gt_t))/(np.sum(surfel_areas_gt_t)+np.sum(surfel_areas_pred_t)) 329 | 330 | # print(surface_overlap) 331 | row_num += 1 332 | sheet.write(row_num, 0, filename) 333 | sheet.write(row_num, 1, DSC) 334 | sheet.write(row_num, 2, precision) 335 | sheet.write(row_num, 3, recall) 336 | sheet.write(row_num, 4, HD) 337 | sheet.write(row_num, 5, ASSD) 338 | sheet.write(row_num, 6, surface_dice_0) 339 | sheet.write(row_num, 7, rel_overlap_gt) 340 | sheet.write(row_num, 8, rel_overlap_pred) 341 | sheet.write(row_num, 9, intersec) 342 | sheet.write(row_num, 10, HD_thin) 343 | sheet.write(row_num, 11, ASSD_thin) 344 | sheet.write(row_num, 12, surface_dice_1) 345 | sheet.write(row_num, 13, surface_dice_2) 346 | # sheet.write(row_num, 14, surface_dice_3) 347 | sheet.write(row_num, 14, Jaccard) 348 | sheet.write(row_num, 15, accuracy) 349 | sheet.write(row_num, 16, specificity) 350 | sheet.write(row_num, 17, sensitivity) 351 | 352 | Q1.append(DSC) 353 | Q2.append(precision) 354 | Q3.append(recall) 355 | Q4.append(HD) 356 | Q5.append(ASSD) 357 | Q6.append(surface_dice_0) 358 | Q7.append(rel_overlap_gt) 359 | Q8.append(rel_overlap_pred) 360 | Q9.append(intersec) 361 | Q10.append(HD_thin) 362 | Q11.append(ASSD_thin) 363 | Q12.append(surface_dice_1) 364 | Q13.append(surface_dice_2) 365 | # Q14.append(surface_dice_3) 366 | Q14.append(Jaccard) 367 | Q15.append(accuracy) 368 | Q16.append(specificity) 369 | Q17.append(sensitivity) 370 | 371 | Q1 = np.array(Q1) 372 | Q2 = np.array(Q2) 373 | Q3 = np.array(Q3) 374 | Q4 = np.array(Q4) 375 | Q5 = np.array(Q5) 376 | Q6 = np.array(Q6) 377 | Q7 = np.array(Q7) 378 | Q8 = np.array(Q8) 379 | Q9 = np.array(Q9) 380 | Q10 = np.array(Q10) 381 | Q11 = np.array(Q11) 382 | Q12 = np.array(Q12) 383 | Q13 = np.array(Q13) 384 | Q14 = np.array(Q14) 385 | Q15 = np.array(Q15) 386 | Q16 = np.array(Q16) 387 | Q17 = np.array(Q17) 388 | 389 | row_num += 2 390 | sheet.write(row_num, 0, 'CaseName') 391 | sheet.write(row_num, 1, 'DSC') 392 | sheet.write(row_num, 2, 'Pre') 393 | sheet.write(row_num, 3, 'Recall') 394 | sheet.write(row_num, 4, 'HD') 395 | sheet.write(row_num, 5, 'ASSD') 396 | sheet.write(row_num, 6, 'surface_dice_0') 397 | sheet.write(row_num, 7, 'rel_overlap_gt') 398 | sheet.write(row_num, 8, 'rel_overlap_pred') 399 | sheet.write(row_num, 9, 'intersec') 400 | sheet.write(row_num, 10, 'HD_thin') 401 | sheet.write(row_num, 11, 'ASSD_thin') 402 | sheet.write(row_num, 12, 'surface_dice_1') 403 | sheet.write(row_num, 13, 'surface_dice_2') 404 | 
sheet.write(row_num, 14, 'Jaccard') 405 | sheet.write(row_num, 15, 'accuracy') 406 | sheet.write(row_num, 16, 'specificity') 407 | sheet.write(row_num, 17, 'sensitivity') 408 | 409 | row_num += 1 410 | sheet.write(row_num, 0, predictName) 411 | sheet.write(row_num, 1, Q1.mean()) 412 | sheet.write(row_num, 2, Q2.mean()) 413 | sheet.write(row_num, 3, Q3.mean()) 414 | sheet.write(row_num, 4, Q4.mean()) 415 | sheet.write(row_num, 5, Q5.mean()) 416 | sheet.write(row_num, 6, Q6.mean()) 417 | sheet.write(row_num, 7, Q7.mean()) 418 | sheet.write(row_num, 8, Q8.mean()) 419 | sheet.write(row_num, 9, Q9.mean()) 420 | sheet.write(row_num, 10, Q10.mean()) 421 | sheet.write(row_num, 11, Q11.mean()) 422 | sheet.write(row_num, 12, Q12.mean()) 423 | sheet.write(row_num, 13, Q13.mean()) 424 | sheet.write(row_num, 14, Q14.mean()) 425 | sheet.write(row_num, 15, Q15.mean()) 426 | sheet.write(row_num, 16, Q16.mean()) 427 | sheet.write(row_num, 17, Q17.mean()) 428 | 429 | book.save('./smp/' + predictName + '.xls') 430 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | from scipy.integrate._ivp.radau import P 2 | import torch 3 | import os 4 | import random 5 | import numpy as np 6 | import cv2 7 | import kornia 8 | import kornia.augmentation as K 9 | 10 | from tqdm import tqdm 11 | from PIL import Image 12 | from torch.utils.data import Dataset 13 | from torchvision import transforms 14 | from scipy import io 15 | from scipy.ndimage import distance_transform_edt 16 | 17 | 18 | def mean_and_std(paths): 19 | print('Calculating mean and std of training set for data normalization.') 20 | m_list, s_list = [], [] 21 | for img_filename in tqdm(paths): 22 | img = cv2.imread(img_filename) 23 | m, s = cv2.meanStdDev(img) 24 | m_list.append(m.reshape((3,))) 25 | s_list.append(s.reshape((3,))) 26 | m_array = np.array(m_list) 27 | s_array = np.array(s_list) 28 | m = m_array.mean(axis=0, keepdims=True) 29 | s = s_array.mean(axis=0, keepdims=True) 30 | m = m[0][::-1][0]/255 31 | s = s[0][::-1][0]/255 32 | print(m) 33 | print(s) 34 | 35 | return m, s 36 | 37 | 38 | class DatasetImageMaskContourDist(Dataset): 39 | 40 | # dataset_type(cup,disc,polyp), 41 | # distance_type(dist_mask,dist_contour,dist_signed) 42 | 43 | def __init__(self, file_names, distance_type, mean, std, clahe): 44 | 45 | self.file_names = file_names 46 | self.distance_type = distance_type 47 | self.mean = mean 48 | self.std = std 49 | self.clahe = clahe 50 | 51 | def __len__(self): 52 | 53 | return len(self.file_names) 54 | 55 | def __getitem__(self, idx): 56 | 57 | img_file_name = self.file_names[idx] 58 | image = load_image(img_file_name, self.mean, self.std, self.clahe) 59 | mask = load_mask(img_file_name) 60 | contour = load_contourheat(img_file_name) 61 | dist = load_distance(img_file_name, self.distance_type) 62 | cls = load_class(img_file_name) 63 | 64 | return img_file_name, image, mask, contour, dist, cls 65 | # return image, mask 66 | 67 | 68 | def clahe_equalized(imgs): 69 | # print(imgs.shape) 70 | # assert (len(imgs.shape)==4) #4D arrays 71 | # assert (imgs.shape[1]==1) #check the channel is 1 72 | # create a CLAHE object (Arguments are optional). 
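    # clipLimit caps the per-tile histogram height (limiting local contrast and
    # noise amplification), and tileGridSize sets the number of local regions the
    # image is divided into before histogram equalization.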
73 | clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)) 74 | imgs_equalized = np.empty(imgs.shape) 75 | for i in range(imgs.shape[0]): 76 | imgs_equalized[i, 0] = clahe.apply(np.array(imgs[i, 0], dtype=np.uint8)) 77 | return imgs_equalized 78 | 79 | 80 | def load_image(path, mean, std, clahe): 81 | 82 | img = cv2.imread(path, cv2.IMREAD_GRAYSCALE) 83 | # print(type(img)) 84 | # if clahe: 85 | # img = clahe_equalized(img) 86 | # img = np.array(img ,dtype=np.float32) 87 | 88 | data_transforms = transforms.Compose( 89 | [ 90 | # transforms.Resize(416), 91 | transforms.ToTensor(), 92 | # transforms.Normalize([0.445,], [0.222,]), 93 | transforms.Normalize([mean, ], [std, ]), 94 | ] 95 | ) 96 | img = data_transforms(img) 97 | 98 | return img 99 | 100 | 101 | def load_mask(path): 102 | 103 | mask = cv2.imread(path.replace("image", "mask").replace("png", "png"), 0) 104 | mask[mask == 255] = 1 105 | 106 | return torch.from_numpy(np.expand_dims(mask, 0)).long() 107 | 108 | 109 | def load_contour(path): 110 | 111 | contour = cv2.imread(path.replace("image", "contour").replace("png", "png"), 0) 112 | contour[contour == 255] = 1 113 | 114 | return torch.from_numpy(np.expand_dims(contour, 0)).float() 115 | 116 | 117 | def load_contourheat(path): 118 | 119 | path = path.replace("image", "contour").replace("png", "mat") 120 | contour = io.loadmat(path)["contour"] 121 | 122 | return torch.from_numpy(np.expand_dims(contour, 0)).float() 123 | 124 | 125 | def load_class(path): 126 | 127 | cls0 = [1, 0, 0] 128 | cls1 = [0, 1, 0] 129 | cls2 = [0, 0, 1] 130 | if 'N' in os.path.basename(path): 131 | cls = cls0 132 | if 'D' in os.path.basename(path): 133 | cls = cls1 134 | if 'M' in os.path.basename(path): 135 | cls = cls2 136 | 137 | return torch.from_numpy(np.expand_dims(cls, 0)).long() 138 | 139 | 140 | def load_distance(path, distance_type): 141 | 142 | if distance_type == "dist_mask": 143 | path = path.replace("image", "dis_mask").replace("png", "mat") 144 | # print (path) 145 | # print (io.loadmat(path)) 146 | dist = io.loadmat(path)["dis"] 147 | 148 | if distance_type == "dist_contour": 149 | path = path.replace("image", "dis_contour").replace("png", "mat") 150 | dist = io.loadmat(path)["c_dis"] 151 | 152 | if distance_type == "dist_signed01": 153 | path = path.replace("image", "dis_signed01").replace("png", "mat") 154 | dist = io.loadmat(path)["s_dis01"] 155 | 156 | if distance_type == "dist_signed11": 157 | path = path.replace("image", "dis_signed11").replace("png", "mat") 158 | dist = io.loadmat(path)["s_dis11"] 159 | 160 | if distance_type == "dist_fore": 161 | path = path.replace("image", "dis_fore").replace("png", "mat") 162 | dist = io.loadmat(path)["f_dis"] 163 | 164 | return torch.from_numpy(np.expand_dims(dist, 0)).float() 165 | 166 | 167 | class DatasetCornea(Dataset): 168 | 169 | def __init__(self, file_names, targetpaths): 170 | 171 | self.file_names = file_names 172 | self.targetpaths = targetpaths 173 | self.im_transform = transforms.Compose([ 174 | transforms.RandomHorizontalFlip(), 175 | transforms.RandomVerticalFlip(), 176 | transforms.Resize(512), 177 | transforms.ColorJitter(0.2, 0.2, 0.0, 0.0), 178 | transforms.ToTensor(), 179 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) 180 | ]) 181 | self.label_transform = transforms.Compose([ 182 | transforms.RandomHorizontalFlip(), 183 | transforms.RandomVerticalFlip(), 184 | transforms.Resize(512, interpolation=Image.NEAREST), 185 | # transforms.ToTensor() 186 | ]) 187 | 188 | def __len__(self): 189 | 190 | return 
len(self.file_names) 191 | 192 | def __getitem__(self, idx): 193 | 194 | name = os.path.split(self.file_names[idx])[1] 195 | img_file_name = os.path.splitext(name)[0] 196 | image = cv2.imread(self.file_names[idx])[..., ::-1] 197 | image = cv2.resize(image, (512, 512), interpolation=cv2.INTER_CUBIC) 198 | _target = cv2.imread(self.targetpaths[idx]) 199 | _target = cv2.resize(_target, (512, 512), interpolation=cv2.INTER_NEAREST) 200 | _target = (255 - _target)[..., 0] / 255. 201 | # _target[_target == 255] = 1 202 | 203 | im = Image.fromarray(np.uint8(image)) 204 | target = Image.fromarray(np.uint8(_target)).convert('L') 205 | 206 | seed = np.random.randint(2147483647) 207 | torch.manual_seed(seed) 208 | random.seed(seed) 209 | 210 | if self.im_transform is not None: 211 | im_t = self.im_transform(im) 212 | 213 | torch.manual_seed(seed) 214 | random.seed(seed) 215 | if self.label_transform is not None: 216 | target_t = self.label_transform(target) 217 | # target_t = torch.from_numpy(np.asfarray(target_t).copy()) 218 | target_t = torch.from_numpy(np.expand_dims(target_t, 0).copy()).float() 219 | 220 | # import imageio 221 | # im_np = (im_t.permute(1, 2, 0).numpy() * 0.5 + 0.5) * 255 222 | # target_np = (target_t.permute(1, 2, 0).numpy()) * 255 223 | # imageio.imwrite('./debug/im.png', np.array(im_np).astype(np.uint8)) 224 | # imageio.imwrite('./debug/gt.png', np.array(target_np).astype(np.uint8)) 225 | 226 | return img_file_name, im_t, target_t 227 | 228 | 229 | class distancedStainingImage(Dataset): 230 | def __init__(self, 231 | x, 232 | y, 233 | masks, 234 | # names, 235 | # args, 236 | train=False): 237 | assert len(x) == len(y) 238 | assert len(x) == len(masks) 239 | # assert len(x) == len(names) 240 | self.dataset_size = len(y) 241 | self.x = x 242 | self.y = y 243 | self.masks = masks 244 | # self.names = names 245 | self.train = train 246 | 247 | # augmentation 248 | self.hflip = K.RandomHorizontalFlip() 249 | self.vflip = K.RandomVerticalFlip() 250 | self.jit = K.ColorJitter(0.2, 0.2, 0.05, 0.05) 251 | self.resize = kornia.geometry.resize 252 | self.normalize = K.Normalize(mean=torch.tensor([0.5, 0.5, 0.5]), 253 | std=torch.tensor([0.5, 0.5, 0.5])) 254 | # inductive bias 255 | # self.inductive_bias = args.inductive_bias 256 | 257 | def __len__(self): 258 | return self.dataset_size 259 | 260 | def _get_index(self, idx): 261 | if self.train: 262 | return idx % self.dataset_size 263 | else: 264 | return idx 265 | 266 | def __getitem__(self, idx): 267 | if torch.is_tensor(idx): 268 | idx = idx.tolist() 269 | idx = self._get_index(idx) 270 | 271 | # BGR -> RGB -> PIL 272 | image = cv2.imread(self.x[idx])[..., ::-1] 273 | label = cv2.imread(self.y[idx]) 274 | label = (255 - label)[..., 0] 275 | mask = cv2.imread(self.masks[idx]) 276 | mask = (255 - mask)[..., 0] 277 | name = os.path.split(self.x[idx])[1] 278 | 279 | image = cv2.resize(image, (1024, 1024), interpolation=cv2.INTER_CUBIC) 280 | label = cv2.resize(label, (1024, 1024), interpolation=cv2.INTER_NEAREST) 281 | mask = cv2.resize(mask, (1024, 1024), interpolation=cv2.INTER_NEAREST) 282 | 283 | embed_map = distance_transform_edt(mask) 284 | embed_map = (embed_map / np.max(embed_map) - 0.5) / 0.5 285 | embed_map = torch.tensor(embed_map) 286 | 287 | image_t = torch.tensor(image / 255).permute(2, 0, 1).unsqueeze(0) 288 | label_t = torch.tensor(label // 255).unsqueeze(0).float() 289 | map_t = embed_map.clone().detach().reshape(1, -1, image_t.size(-2), image_t.size(-1)).float() 290 | 291 | if self.train: 292 | hflip_params = 
self.hflip.forward_parameters(image_t.shape) 293 | image_t = self.hflip(image_t, hflip_params) 294 | label_t = self.hflip(label_t, hflip_params) 295 | map_t = self.hflip(map_t, hflip_params) 296 | vflip_params = self.vflip.forward_parameters(image_t.shape) 297 | image_t = self.vflip(image_t, vflip_params) 298 | label_t = self.vflip(label_t, vflip_params) 299 | map_t = self.vflip(map_t, vflip_params) 300 | image_t = self.resize(image_t, size=512, interpolation='bilinear', align_corners=False) 301 | label_t = self.resize(label_t, size=512, interpolation='nearest') 302 | map_t = self.resize(map_t, size=512, interpolation='nearest') 303 | jit_params = self.jit.forward_parameters(image_t.shape) 304 | image_t = self.jit(image_t, jit_params) 305 | else: 306 | image_t = self.resize(image_t, size=512, interpolation='bilinear', align_corners=False) 307 | label_t = self.resize(label_t, size=512, interpolation='nearest') 308 | map_t = self.resize(map_t, size=512, interpolation='nearest') 309 | map_t = map_t.view(1, -1, 512, 512) 310 | 311 | image_t = self.normalize(image_t).squeeze(0).float() 312 | label_t = label_t.long().squeeze(0) 313 | map_t = map_t.squeeze(0).float() 314 | 315 | # io debug 316 | # import imageio 317 | # im_np = image_t.permute(1, 2, 0).numpy() 318 | # im_np = (im_np * 0.5 + 0.5) * 255 319 | # gt_np = label_t.numpy() * 255 320 | # map_np = (map_t.squeeze().numpy() * 0.5 + 0.5) * 255 321 | # imageio.imwrite('./debug/im.png', im_np.astype(np.uint8)) 322 | # imageio.imwrite('./debug/gt.png', gt_np.astype(np.uint8)) 323 | # imageio.imwrite('./debug/map.png', map_np.astype(np.uint8)) 324 | 325 | # if self.inductive_bias != '': 326 | image_t = torch.cat([image_t, map_t], dim=0) 327 | 328 | return name, image_t, label_t 329 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | from torch.utils.data import DataLoader 4 | from dataset import DatasetImageMaskContourDist, mean_and_std 5 | import glob 6 | from models import UNet, UNet_DCAN, UNet_DMTN, PsiNet, UNet_ConvMCD 7 | from tqdm import tqdm 8 | import numpy as np 9 | import cv2 10 | from utils import create_validation_arg_parser 11 | import scipy.io as scio 12 | from utils import AverageMeter 13 | import segmentation_models_pytorch as smp 14 | from sklearn.metrics import cohen_kappa_score, accuracy_score, confusion_matrix, recall_score, f1_score, classification_report, jaccard_score 15 | from train_seg_clf import CotrainingModelMulti 16 | import pandas as pd 17 | from scipy.special import softmax 18 | import surface_distance 19 | import scipy.spatial 20 | from numpy import mean 21 | 22 | 23 | def getDSC(testImage, resultImage): 24 | """Compute the Dice Similarity Coefficient.""" 25 | testArray = testImage.flatten() 26 | resultArray = resultImage.flatten() 27 | 28 | return 1.0 - scipy.spatial.distance.dice(testArray, resultArray) 29 | 30 | 31 | def getJaccard(testImage, resultImage): 32 | """Compute the Dice Similarity Coefficient.""" 33 | testArray = testImage.flatten() 34 | resultArray = resultImage.flatten() 35 | 36 | return 1.0 - scipy.spatial.distance.jaccard(testArray, resultArray) 37 | 38 | 39 | def getPrecisionAndRecall(testImage, resultImage): 40 | testArray = testImage.flatten() 41 | resultArray = resultImage.flatten() 42 | 43 | TP = np.sum(testArray*resultArray) 44 | FP = np.sum((1-testArray)*resultArray) 45 | FN = np.sum(testArray*(1-resultArray)) 46 | 47 | precision = 
TP/(TP+FP) 48 | recall = TP/(TP+FN) 49 | 50 | return precision, recall 51 | 52 | 53 | def build_model(model_type): 54 | 55 | if model_type == "unet": 56 | model = UNet(num_classes=2) 57 | if model_type == "dcan": 58 | model = UNet_DCAN(num_classes=2) 59 | if model_type == "dmtn": 60 | model = UNet_DMTN(num_classes=2) 61 | if model_type == "psinet": 62 | model = PsiNet(num_classes=2) 63 | if model_type == "convmcd": 64 | model = UNet_ConvMCD(num_classes=2) 65 | 66 | return model 67 | 68 | 69 | if __name__ == "__main__": 70 | 71 | args = create_validation_arg_parser().parse_args() 72 | 73 | test_path = os.path.join(args.test_path, "*.png") 74 | model_file = args.model_file 75 | save_path = args.save_path 76 | model_type = args.model_type 77 | distance_type = args.distance_type 78 | 79 | cuda_no = args.cuda_no 80 | CUDA_SELECT = "cuda:{}".format(cuda_no) 81 | device = torch.device(CUDA_SELECT if torch.cuda.is_available() else "cpu") 82 | train_file_names = glob.glob(os.path.join(args.train_path, "*.png")) 83 | train_mean, train_std = mean_and_std(train_file_names) 84 | test_file_names = glob.glob(test_path) 85 | test_dataset = DatasetImageMaskContourDist(test_file_names, distance_type, train_mean, train_std, args.clahe) 86 | testLoader = DataLoader(test_dataset, batch_size=4, num_workers=4, shuffle=True) 87 | 88 | if not os.path.exists(save_path): 89 | os.mkdir(save_path) 90 | 91 | clf_accs = AverageMeter("Acc", ".8f") 92 | clf_kappas = AverageMeter("Kappa", ".8f") 93 | 94 | encoder = args.encoder 95 | attention_type = args.attention 96 | if args.pretrain in ['imagenet', 'ssl', 'swsl', 'instagram']: 97 | pretrain = args.pretrain 98 | else: 99 | pretrain = None 100 | usenorm = args.usenorm 101 | print("clahe:", args.clahe) 102 | model = CotrainingModelMulti(encoder, pretrain, usenorm, attention_type, args.classnum).to(device) 103 | model.load_state_dict(torch.load(model_file)) 104 | model.eval() 105 | 106 | name = [] 107 | prob = [] 108 | label = [] 109 | pred = [] 110 | dice_1o = [] 111 | dice_2o = [] 112 | jaccard_1o = [] 113 | jaccard_2o = [] 114 | HD_o = [] 115 | ASSD_o = [] 116 | 117 | for i, (img_file_name, inputs, targets1, targets2, targets3, targets4) in enumerate(tqdm(testLoader)): 118 | 119 | inputs = inputs.to(device) 120 | seg_labels = targets1.numpy() 121 | targets1, targets2 = targets1.to(device), targets2.to(device) 122 | targets3, targets4 = targets3.to(device), targets4.to(device) 123 | targets = [targets1, targets2, targets3, targets4] 124 | 125 | seg_outputs = model.seg_forward(inputs) 126 | if not isinstance(seg_outputs, list): 127 | seg_outputs = [seg_outputs] 128 | 129 | clf_outputs = model.clf_forward(inputs, seg_outputs[3], seg_outputs[4], seg_outputs[5]) 130 | outputs1 = seg_outputs[0].detach().cpu().numpy().squeeze() 131 | outputs2 = seg_outputs[1].detach().cpu().numpy().squeeze() 132 | outputs3 = seg_outputs[2].detach().cpu().numpy().squeeze() 133 | seg_preds = np.round(outputs1) 134 | 135 | clf_labels = torch.argmax(targets[3], dim=2).squeeze(1).detach().cpu().item() 136 | clf_preds = torch.argmax(clf_outputs, dim=1).detach().cpu().numpy().item() 137 | 138 | dsc_loss = smp.utils.losses.DiceLoss() 139 | jac_loss = smp.utils.losses.JaccardLoss() 140 | seg_prs = seg_preds 141 | dice_1 = f1_score(seg_labels.squeeze(), seg_prs, average='micro') 142 | dice_2 = getDSC(seg_labels, seg_prs) 143 | jaccard_1 = jaccard_score(seg_labels.squeeze(), seg_prs, average='micro') 144 | jaccard_2 = getJaccard(seg_labels, seg_prs) 145 | 146 | label_seg = np.array(seg_labels.squeeze(), 
dtype=bool) 147 | predict = np.array(seg_preds, dtype=bool) 148 | 149 | surface_distances = surface_distance.compute_surface_distances(label_seg, predict, spacing_mm=(1, 1)) 150 | 151 | HD = surface_distance.compute_robust_hausdorff(surface_distances, 95) 152 | 153 | distances_gt_to_pred = surface_distances["distances_gt_to_pred"] 154 | distances_pred_to_gt = surface_distances["distances_pred_to_gt"] 155 | surfel_areas_gt = surface_distances["surfel_areas_gt"] 156 | surfel_areas_pred = surface_distances["surfel_areas_pred"] 157 | 158 | ASSD = (np.sum(distances_pred_to_gt * surfel_areas_pred) + np.sum(distances_gt_to_pred * surfel_areas_gt))/(np.sum(surfel_areas_gt)+np.sum(surfel_areas_pred)) 159 | 160 | output_path_m = os.path.join( 161 | save_path, "m_" + os.path.basename(img_file_name[0]) 162 | ) 163 | output_path_d = os.path.join( 164 | save_path, "d_" + os.path.basename(img_file_name[0]) 165 | ) 166 | output_path_dmat = os.path.join( 167 | save_path, "d_" + os.path.basename(img_file_name[0]).replace('.png', '.mat') 168 | ) 169 | output_path_p = os.path.join( 170 | save_path, os.path.basename(img_file_name[0]) 171 | ) 172 | output_path_b = os.path.join( 173 | save_path, "b_" + os.path.basename(img_file_name[0]) 174 | ) 175 | output_path_bmat = os.path.join( 176 | save_path, "b_" + os.path.basename(img_file_name[0]).replace('.png', '.mat') 177 | ) 178 | 179 | cv2.imwrite(output_path_p, (outputs1*255.)) 180 | cv2.imwrite(output_path_m, (seg_preds*255.)) 181 | cv2.imwrite(output_path_b, (outputs2*255.)) 182 | cv2.imwrite(output_path_d, (outputs3*255.)) 183 | scio.savemat(output_path_bmat, {'boundary': outputs2}) 184 | scio.savemat(output_path_dmat, {'dist': outputs3}) 185 | name.append(os.path.basename(img_file_name[0])) 186 | prob.append(softmax(clf_outputs.detach().cpu().numpy().squeeze())) 187 | label.append(clf_labels) 188 | pred.append(clf_preds) 189 | dice_1o.append(dice_1) 190 | dice_2o.append(dice_2) 191 | jaccard_1o.append(jaccard_1) 192 | jaccard_2o.append(jaccard_2) 193 | HD_o.append(HD) 194 | ASSD_o.append(ASSD) 195 | 196 | kappa = cohen_kappa_score(label, pred) 197 | acc = accuracy_score(label, pred) 198 | recall = recall_score(label, pred, average='micro') 199 | f1 = f1_score(label, pred, average='weighted') 200 | c_matrix = confusion_matrix(label, pred) 201 | if args.classnum == 3: 202 | target_names = ['N', 'D', 'M'] 203 | clas_report = classification_report(label, pred, target_names=target_names, digits=5) 204 | elif args.classnum == 2: 205 | target_names = ['N', 'D'] 206 | clas_report = classification_report(label, pred, target_names=target_names, digits=5) 207 | 208 | name_flag = args.val_path[11:12].replace('/', '_') 209 | print(name_flag) 210 | dataframe = pd.DataFrame({'case': name, 'prob': prob, 'label': label, 'pred': pred, 'dice1': dice_1o, 'dice2': dice_2o, 'jaccard1': jaccard_1o, 'jaccard2': jaccard_2o, 'HD': HD_o, 'ASSD': ASSD_o}) 211 | dataframe.to_csv(save_path + "/" + name_flag + "_class&seg.csv", index=False, sep=',') 212 | resultframe = pd.DataFrame({'acc': acc, 'kappa': kappa, 'recall': recall, 'f1score': f1, 'seg_dice1': mean(dice_1o), 'seg_dice2': mean(dice_2o), 'jaccard1': mean(jaccard_1o), 'jaccard2': mean(jaccard_2o), 'HD': mean(HD_o), 'ASSD': mean(ASSD_o)}, index=[1]) 213 | resultframe.to_csv(save_path + "/" + name_flag + "_acc_kappa.csv", index=0) 214 | with open(save_path + "/" + name_flag + "_cmatrix.txt", "w") as f: 215 | f.write(str(c_matrix)) 216 | with open(save_path + "/" + name_flag + "_clas_report.txt", "w") as f: 217 | 
f.write(str(clas_report))
218 |
-------------------------------------------------------------------------------- /inference/aug.py: --------------------------------------------------------------------------------
1 | import albumentations as albu
2 | from albumentations.pytorch import ToTensor
3 |
4 |
5 | def pre_transforms(image_size=416):
6 | return [albu.Resize(image_size, image_size, p=1)]
7 |
8 |
9 | def hard_transforms():
10 | result = [
11 | albu.RandomRotate90(),
12 | # albu.Cutout(),
13 | albu.HorizontalFlip(p=0.5),
14 | albu.VerticalFlip(p=0.5),
15 | albu.RandomBrightnessContrast(
16 | brightness_limit=0.2, contrast_limit=0.2, p=0.3
17 | ),
18 | albu.GridDistortion(p=0.3)
19 | # albu.
20 | ]
21 |
22 | return result
23 |
24 |
25 | def resize_transforms(image_size=224):
26 | # BORDER_CONSTANT = 0
27 | pre_size = int(image_size * 1.5)
28 |
29 | random_crop = albu.Compose([
30 | albu.SmallestMaxSize(pre_size, p=1),
31 | albu.RandomCrop(
32 | image_size, image_size, p=1
33 | )
34 |
35 | ])
36 |
37 | rescale = albu.Compose([albu.Resize(image_size, image_size, p=1)])
38 |
39 | random_crop_big = albu.Compose([
40 | albu.LongestMaxSize(pre_size, p=1),
41 | albu.RandomCrop(
42 | image_size, image_size, p=1
43 | )
44 |
45 | ])
46 |
47 | # Converts the image to a square of size image_size x image_size
48 | result = [
49 | albu.OneOf([
50 | random_crop,
51 | rescale,
52 | random_crop_big
53 | ], p=1)
54 | ]
55 |
56 | return result
57 |
58 |
59 | def post_transforms():
60 | # we use ImageNet image normalization
61 | # and convert it to torch.Tensor
62 | return [albu.Normalize(), ToTensor()]
63 |
64 |
65 | def compose(transforms_to_compose):
66 | # combine all augmentations into single pipeline
67 | result = albu.Compose([
68 | item for sublist in transforms_to_compose for item in sublist
69 | ])
70 | return result
71 |
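Editor's sketch (not a file in the repo): the helpers in aug.py are building blocks meant to be chained through compose(). A minimal example of assembling train and validation pipelines from them; the 224 crop size, the dummy arrays, and the `inference.aug` import path are illustrative assumptions, not fixed by the repo:

import numpy as np
from inference.aug import pre_transforms, hard_transforms, resize_transforms, post_transforms, compose

# training: random crop/rescale to a square, photometric jitter, then normalization + tensor conversion
train_transforms = compose([
    resize_transforms(image_size=224),
    hard_transforms(),
    post_transforms(),
])
# validation/test: deterministic resize + normalization only
valid_transforms = compose([
    pre_transforms(image_size=224),
    post_transforms(),
])

# albumentations pipelines apply the same spatial transform to image and mask
dummy = train_transforms(image=np.zeros((416, 416, 3), dtype=np.uint8),
                         mask=np.zeros((416, 416), dtype=np.uint8))
image_t, mask_t = dummy["image"], dummy["mask"]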
-------------------------------------------------------------------------------- /inference/boundry_loss.py: --------------------------------------------------------------------------------
1 | import torch
2 | # from nnunet.training.loss_functions.TopK_loss import TopKLoss
3 | from nnunet.utilities.nd_softmax import softmax_helper
4 | # from nnunet.training.loss_functions.ND_Crossentropy import CrossentropyND
5 | from nnunet.utilities.tensor_utilities import sum_tensor
6 | from torch import nn
7 | from scipy.ndimage import distance_transform_edt
8 | from skimage import segmentation as skimage_seg
9 | import numpy as np
10 |
11 |
12 | def get_tp_fp_fn(net_output, gt, axes=None, mask=None, square=False):
13 | """
14 | net_output must be (b, c, x, y(, z))
15 | gt must be a label map (shape (b, 1, x, y(, z)) OR shape (b, x, y(, z))) or one hot encoding (b, c, x, y(, z))
16 | if mask is provided it must have shape (b, 1, x, y(, z))
17 | :param net_output:
18 | :param gt:
19 | :param axes:
20 | :param mask: mask must be 1 for valid pixels and 0 for invalid pixels
21 | :param square: if True then fp, tp and fn will be squared before summation
22 | :return:
23 | """
24 | if axes is None:
25 | axes = tuple(range(2, len(net_output.size())))
26 |
27 | shp_x = net_output.shape
28 | shp_y = gt.shape
29 |
30 | with torch.no_grad():
31 | if len(shp_x) != len(shp_y):
32 | gt = gt.view((shp_y[0], 1, *shp_y[1:]))
33 |
34 | if all([i == j for i, j in zip(net_output.shape, gt.shape)]):
35 | # if this is the case then gt is probably already a one hot encoding
36 | y_onehot = gt
37 | else:
38 | gt = gt.long()
39 | y_onehot = torch.zeros(shp_x)
40 | if net_output.device.type == "cuda":
41 | y_onehot = y_onehot.cuda(net_output.device.index)
42 | y_onehot.scatter_(1, gt, 1)
43 |
44 | tp = net_output * y_onehot
45 | fp = net_output * (1 - y_onehot)
46 | fn = (1 - net_output) * y_onehot
47 |
48 | if mask is not None:
49 | tp = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(tp, dim=1)), dim=1)
50 | fp = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(fp, dim=1)), dim=1)
51 | fn = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(fn, dim=1)), dim=1)
52 |
53 | if square:
54 | tp = tp ** 2
55 | fp = fp ** 2
56 | fn = fn ** 2
57 |
58 | tp = sum_tensor(tp, axes, keepdim=False)
59 | fp = sum_tensor(fp, axes, keepdim=False)
60 | fn = sum_tensor(fn, axes, keepdim=False)
61 |
62 | return tp, fp, fn
63 |
64 |
65 | class SoftDiceLoss(nn.Module):
66 | def __init__(self, apply_nonlin=None, batch_dice=False, do_bg=True, smooth=1.,
67 | square=False):
68 | """
69 | """
70 | super(SoftDiceLoss, self).__init__()
71 |
72 | self.square = square
73 | self.do_bg = do_bg
74 | self.batch_dice = batch_dice
75 | self.apply_nonlin = apply_nonlin
76 | self.smooth = smooth
77 |
78 | def forward(self, x, y, loss_mask=None):
79 | shp_x = x.shape
80 |
81 | if self.batch_dice:
82 | axes = [0] + list(range(2, len(shp_x)))
83 | else:
84 | axes = list(range(2, len(shp_x)))
85 |
86 | if self.apply_nonlin is not None:
87 | x = self.apply_nonlin(x)
88 |
89 | tp, fp, fn = get_tp_fp_fn(x, y, axes, loss_mask, self.square)
90 |
91 | dc = (2 * tp + self.smooth) / (2 * tp + fp + fn + self.smooth)
92 |
93 | if not self.do_bg:
94 | if self.batch_dice:
95 | dc = dc[1:]
96 | else:
97 | dc = dc[:, 1:]
98 | dc = dc.mean()
99 |
100 | return 1-dc
101 |
102 |
103 | def compute_sdf(img_gt, out_shape):
104 | """
105 | compute the signed distance map of a binary mask
106 | input: segmentation, shape = (batch_size, class, x, y, z)
107 | output: the Signed Distance Map (SDM)
108 | sdf(x) = 0; x in segmentation boundary
109 | -inf|x-y|; x in segmentation
110 | +inf|x-y|; x out of segmentation
111 | """
112 |
113 | img_gt = img_gt.astype(np.uint8)
114 |
115 | gt_sdf = np.zeros(out_shape)
116 |
117 | for b in range(out_shape[0]): # batch size
118 | for c in range(1, out_shape[1]): # channel
119 | posmask = img_gt[b][c].astype(bool) # np.bool was removed in newer NumPy; plain bool is equivalent
120 | if posmask.any():
121 | negmask = ~posmask
122 | posdis = distance_transform_edt(posmask)
123 | negdis = distance_transform_edt(negmask)
124 | boundary = skimage_seg.find_boundaries(posmask, mode='inner').astype(np.uint8)
125 | sdf = negdis - posdis
126 | sdf[boundary == 1] = 0
127 | gt_sdf[b][c] = sdf
128 |
129 | return gt_sdf
130 |
131 |
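# --- Editor's sketch (not part of the original file): a quick sanity check of
# --- compute_sdf() above on a toy one-hot mask. The shapes are illustrative;
# --- call this by hand when debugging.
def _demo_compute_sdf():
    gt = np.zeros((1, 2, 1, 5, 5), dtype=np.uint8)
    gt[0, 1, 0, 1:4, 1:4] = 1  # a 3x3 foreground square in the foreground channel
    sdf = compute_sdf(gt, gt.shape)
    # per the docstring: non-positive inside the object, positive outside,
    # and exactly 0 on the inner boundary
    assert sdf[0, 1][gt[0, 1] == 1].max() == 0
    assert sdf[0, 1][gt[0, 1] == 0].min() > 0
    return sdf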
132 | class BDLoss(nn.Module):
133 | def __init__(self):
134 | """
135 | compute boundary loss
136 | only compute the loss of foreground
137 | ref: https://github.com/LIVIAETS/surface-loss/blob/108bd9892adca476e6cdf424124bc6268707498e/losses.py#L74
138 | """
139 | super(BDLoss, self).__init__()
140 | # self.do_bg = do_bg
141 |
142 | def forward(self, net_output, gt):
143 | """
144 | net_output: (batch_size, class, x,y,z)
145 | target: ground truth, shape: (batch_size, 1, x,y,z)
146 | bound: precomputed distance map, shape (batch_size, class, x,y,z)
147 | """
148 | net_output = softmax_helper(net_output)
149 | with torch.no_grad():
150 | if len(net_output.shape) != len(gt.shape):
151 | gt = gt.view((gt.shape[0], 1, *gt.shape[1:]))
152 |
153 | if all([i == j for i, j in zip(net_output.shape, gt.shape)]):
154 | # if this is the case then gt is probably already a one hot encoding
155 | y_onehot = gt
156 | else:
157 | gt = gt.long()
158 | y_onehot = torch.zeros(net_output.shape)
159 | if net_output.device.type == "cuda":
160 | y_onehot = y_onehot.cuda(net_output.device.index)
161 | y_onehot.scatter_(1, gt, 1)
162 | gt_sdf = compute_sdf(y_onehot.cpu().numpy(), net_output.shape)
163 |
164 | phi = torch.from_numpy(gt_sdf)
165 | if phi.device != net_output.device:
166 | phi = phi.to(net_output.device).type(torch.float32)
167 | # pred = net_output[:, 1:, ...].type(torch.float32)
168 | # phi = phi[:,1:, ...].type(torch.float32)
169 |
170 | multipled = torch.einsum("bcxyz,bcxyz->bcxyz", net_output[:, 1:, ...], phi[:, 1:, ...])
171 | bd_loss = multipled.mean()
172 |
173 | return bd_loss
174 |
175 |
176 | class DC_and_BD_loss(nn.Module):
177 | def __init__(self, soft_dice_kwargs, bd_kwargs, aggregate="sum"):
178 | super(DC_and_BD_loss, self).__init__()
179 | self.aggregate = aggregate
180 | self.dc = SoftDiceLoss(apply_nonlin=softmax_helper, **soft_dice_kwargs)
181 | self.bd = BDLoss(**bd_kwargs)
182 |
183 | def forward(self, net_output, target):
184 | dc_loss = self.dc(net_output, target)
185 | bd_loss = self.bd(net_output, target)
186 | if self.aggregate == "sum":
187 | result = 0.01*dc_loss + (1-0.01)*bd_loss
188 | else:
189 | raise NotImplementedError("nah son")
190 | return result
191 |
192 |
193 | #####################################################
194 |
195 | def compute_gt_dtm(img_gt, out_shape):
196 | """
197 | compute the distance transform map of foreground in ground truth.
198 | input: segmentation, shape = (batch_size, class, x, y, z)
199 | output: the foreground Distance Map (DTM)
200 | dtm(x) = 0; x in segmentation boundary
201 | inf|x-y|; x in segmentation
202 | """
203 |
204 | fg_dtm = np.zeros(out_shape)
205 |
206 | for b in range(out_shape[0]): # batch size
207 | for c in range(1, out_shape[1]): # class; exclude the background class
208 | posmask = img_gt[b][c].astype(bool) # np.bool was removed in newer NumPy; plain bool is equivalent
209 | if posmask.any():
210 | posdis = distance_transform_edt(posmask)
211 | fg_dtm[b][c] = posdis
212 |
213 | return fg_dtm
214 |
215 |
216 | def compute_pred_dtm(img_gt, out_shape):
217 | """
218 | compute the distance transform map of foreground in prediction.
219 | input: segmentation, shape = (batch_size, class, x, y, z)
220 | output: the foreground Distance Map (DTM)
221 | dtm(x) = 0; x in segmentation boundary
222 | inf|x-y|; x in segmentation
223 | """
224 |
225 | fg_dtm = np.zeros(out_shape)
226 |
227 | for b in range(out_shape[0]): # batch size
228 | for c in range(1, out_shape[1]): # class; exclude the background class
229 | posmask = img_gt[b][c] > 0.5
230 | if posmask.any():
231 | posdis = distance_transform_edt(posmask)
232 | fg_dtm[b][c] = posdis
233 |
234 | return fg_dtm
235 |
236 |
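# --- Editor's sketch (not part of the original file): how the combined loss
# --- above is called. Note the einsum patterns are written as "bcxyz", so as
# --- given these losses expect 5-D volumetric tensors; the shapes below are
# --- illustrative.
def _demo_dc_and_bd_loss():
    logits = torch.randn(2, 2, 8, 8, 8)        # raw scores, (b, c, x, y, z)
    gt = torch.randint(0, 2, (2, 1, 8, 8, 8))  # integer label map
    criterion = DC_and_BD_loss(soft_dice_kwargs={}, bd_kwargs={})
    return criterion(logits, gt)               # 0.01 * dice + 0.99 * boundary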
237 | class HDLoss(nn.Module):
238 | def __init__(self):
239 | """
240 | compute Hausdorff loss for binary segmentation
241 | https://arxiv.org/pdf/1904.10030v1.pdf
242 | """
243 | super(HDLoss, self).__init__()
244 |
245 | def forward(self, net_output, gt):
246 | """
247 | net_output: (batch_size, c, x,y,z)
248 | target: ground truth, shape: (batch_size, c, x,y,z)
249 | """
250 | net_output = softmax_helper(net_output)
251 | # one hot code for gt
252 | with torch.no_grad():
253 | if len(net_output.shape) != len(gt.shape):
254 | gt = gt.view((gt.shape[0], 1, *gt.shape[1:]))
255 |
256 | if all([i == j for i, j in zip(net_output.shape, gt.shape)]):
257 | # if this is the case then gt is probably already a one hot encoding
258 | y_onehot = gt
259 | else:
260 | gt = gt.long()
261 | y_onehot = torch.zeros(net_output.shape)
262 | if net_output.device.type == "cuda":
263 | y_onehot = y_onehot.cuda(net_output.device.index)
264 | y_onehot.scatter_(1, gt, 1)
265 | # print('hd loss.py', net_output.shape, y_onehot.shape)
266 |
267 | with torch.no_grad():
268 | pc_dist = compute_pred_dtm(net_output.cpu().numpy(), net_output.shape)
269 | gt_dist = compute_gt_dtm(y_onehot.cpu().numpy(), net_output.shape)
270 | dist = pc_dist**2 + gt_dist**2 # \alpha=2 in eq(8)
271 | # print('pc_dist.shape: ', pc_dist.shape, 'gt_dist.shape', gt_dist.shape)
272 |
273 | pred_error = (net_output - y_onehot)**2
274 |
275 | dist = torch.from_numpy(dist)
276 | if dist.device != pred_error.device:
277 | dist = dist.to(pred_error.device).type(torch.float32)
278 |
279 | multipled = torch.einsum("bcxyz,bcxyz->bcxyz", pred_error[:, 1:, ...], dist[:, 1:, ...])
280 | hd_loss = multipled.mean()
281 |
282 | return hd_loss
283 |
284 |
285 | class DC_and_HD_loss(nn.Module):
286 | def __init__(self, soft_dice_kwargs, hd_kwargs, aggregate="sum"):
287 | super(DC_and_HD_loss, self).__init__()
288 | self.aggregate = aggregate
289 | self.dc = SoftDiceLoss(apply_nonlin=softmax_helper, **soft_dice_kwargs)
290 | self.hd = HDLoss(**hd_kwargs)
291 |
292 | def forward(self, net_output, target):
293 | dc_loss = self.dc(net_output, target)
294 | hd_loss = self.hd(net_output, target)
295 | if self.aggregate == "sum":
296 | with torch.no_grad():
297 | alpha = hd_loss / (dc_loss + 1e-5)
298 | result = alpha * dc_loss + hd_loss
299 | else:
300 | raise NotImplementedError("nah son")
301 | return result
302 |
-------------------------------------------------------------------------------- /inference/hausdirff.py: --------------------------------------------------------------------------------
1 | import cv2 as cv
2 | import numpy as np
3 |
4 | import torch
5 | from torch import nn
6 |
7 | from scipy.ndimage.morphology import distance_transform_edt as edt
8 | from scipy.ndimage import convolve
9 |
10 | """
11 | Hausdorff loss implementation based on paper:
12 | https://arxiv.org/pdf/1904.10030.pdf
13 | copy pasted from - all credit goes to original authors:
14 | https://github.com/SilmarilBearer/HausdorffLoss
15 | """
16 |
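# --- Editor's sketch (not part of the original file): minimal usage of the two
# --- losses defined below. Both take probabilities with a single foreground
# --- channel, (b, 1, x, y) or (b, 1, x, y, z); the sizes here are illustrative.
def _demo_hausdorff_losses():
    pred = torch.rand(2, 1, 64, 64)                    # already-sigmoided probabilities
    target = (torch.rand(2, 1, 64, 64) > 0.5).float()  # binary ground truth
    dt_loss = HausdorffDTLoss(alpha=2.0)(pred, target)               # distance-transform variant
    er_loss = HausdorffERLoss(alpha=2.0, erosions=10)(pred, target)  # erosion variant
    return dt_loss, er_loss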
17 |
18 | class HausdorffDTLoss(nn.Module):
19 | """Binary Hausdorff loss based on distance transform"""
20 |
21 | def __init__(self, alpha=2.0, **kwargs):
22 | super(HausdorffDTLoss, self).__init__()
23 | self.alpha = alpha
24 |
25 | @torch.no_grad()
26 | def distance_field(self, img: np.ndarray) -> np.ndarray:
27 | field = np.zeros_like(img)
28 |
29 | for batch in range(len(img)):
30 | fg_mask = img[batch] > 0.5
31 |
32 | if fg_mask.any():
33 | bg_mask = ~fg_mask
34 |
35 | fg_dist = edt(fg_mask)
36 | bg_dist = edt(bg_mask)
37 |
38 | field[batch] = fg_dist + bg_dist
39 |
40 | return field
41 |
42 | def forward(
43 | self, pred: torch.Tensor, target: torch.Tensor, debug=False
44 | ) -> torch.Tensor:
45 | """
46 | Uses one binary channel: 1 - fg, 0 - bg
47 | pred: (b, 1, x, y, z) or (b, 1, x, y)
48 | target: (b, 1, x, y, z) or (b, 1, x, y)
49 | """
50 | assert pred.dim() == 4 or pred.dim() == 5, "Only 2D and 3D supported"
51 | assert (
52 | pred.dim() == target.dim()
53 | ), "Prediction and target need to be of same dimension"
54 |
55 | # pred = torch.sigmoid(pred)
56 |
57 | pred_dt = torch.from_numpy(self.distance_field(pred.cpu().numpy())).float()
58 | target_dt = torch.from_numpy(self.distance_field(target.cpu().numpy())).float()
59 |
60 | pred_error = (pred - target) ** 2
61 | distance = pred_dt ** self.alpha + target_dt ** self.alpha
62 |
63 | dt_field = pred_error * distance
64 | loss = dt_field.mean()
65 |
66 | if debug:
67 | return (
68 | loss.cpu().numpy(),
69 | (
70 | dt_field.cpu().numpy()[0, 0],
71 | pred_error.cpu().numpy()[0, 0],
72 | distance.cpu().numpy()[0, 0],
73 | pred_dt.cpu().numpy()[0, 0],
74 | target_dt.cpu().numpy()[0, 0],
75 | ),
76 | )
77 |
78 | else:
79 | return loss
80 |
81 |
82 | class HausdorffERLoss(nn.Module):
83 | """Binary Hausdorff loss based on morphological erosion"""
84 |
85 | def __init__(self, alpha=2.0, erosions=10, **kwargs):
86 | super(HausdorffERLoss, self).__init__()
87 | self.alpha = alpha
88 | self.erosions = erosions
89 | self.prepare_kernels()
90 |
91 | def prepare_kernels(self):
92 | cross = np.array([cv.getStructuringElement(cv.MORPH_CROSS, (3, 3))])
93 | bound = np.array([[[0, 0, 0], [0, 1, 0], [0, 0, 0]]])
94 |
95 | self.kernel2D = cross * 0.2
96 | self.kernel3D = np.array([bound, cross, bound]) * (1 / 7)
97 |
98 | @torch.no_grad()
99 | def perform_erosion(
100 | self, pred: np.ndarray, target: np.ndarray, debug
101 | ) -> np.ndarray:
102 | bound = (pred - target) ** 2
103 |
104 | if bound.ndim == 5:
105 | kernel = self.kernel3D
106 | elif bound.ndim == 4:
107 | kernel = self.kernel2D
108 | else:
109 | raise ValueError(f"Dimension {bound.ndim} is not supported.")
110 |
111 | eroted = np.zeros_like(bound)
112 | erosions = []
113 |
114 | for batch in range(len(bound)):
115 |
116 | # debug
117 | erosions.append(np.copy(bound[batch][0]))
118 |
119 | for k in range(self.erosions):
120 |
121 | # compute convolution with kernel
122 | dilation = convolve(bound[batch], kernel, mode="constant", cval=0.0)
123 |
124 | # apply soft thresholding at 0.5 and normalize
125 | erosion = dilation - 0.5
126 | erosion[erosion < 0] = 0
127 |
128 | if erosion.ptp() != 0:
129 | erosion = (erosion - erosion.min()) / erosion.ptp()
130 |
131 | # save erosion and add to loss
132 | bound[batch] = erosion
133 | eroted[batch] += erosion * (k + 1) ** self.alpha
134 |
135 | if debug:
136 | erosions.append(np.copy(erosion[0]))
137 |
138 | # image visualization in debug mode
139 | if debug:
140 | return eroted, erosions
141 | else:
142 | return eroted
143 |
144 | def 
forward( 145 | self, pred: torch.Tensor, target: torch.Tensor, debug=False 146 | ) -> torch.Tensor: 147 | """ 148 | Uses one binary channel: 1 - fg, 0 - bg 149 | pred: (b, 1, x, y, z) or (b, 1, x, y) 150 | target: (b, 1, x, y, z) or (b, 1, x, y) 151 | """ 152 | assert pred.dim() == 4 or pred.dim() == 5, "Only 2D and 3D supported" 153 | assert ( 154 | pred.dim() == target.dim() 155 | ), "Prediction and target need to be of same dimension" 156 | 157 | # pred = torch.sigmoid(pred) 158 | 159 | if debug: 160 | eroted, erosions = self.perform_erosion( 161 | pred.cpu().numpy(), target.cpu().numpy(), debug 162 | ) 163 | return eroted.mean(), erosions 164 | 165 | else: 166 | eroted = torch.from_numpy( 167 | self.perform_erosion(pred.cpu().numpy(), target.cpu().numpy(), debug) 168 | ).float() 169 | 170 | loss = eroted.mean() 171 | 172 | return loss 173 | -------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import numpy as np 4 | import segmentation_models_pytorch as smp 5 | 6 | 7 | class LossMulti: 8 | def __init__( 9 | self, jaccard_weight=0, class_weights=None, num_classes=1, device=None 10 | ): 11 | self.device = device 12 | if class_weights is not None: 13 | nll_weight = torch.from_numpy(class_weights.astype(np.float32)).to( 14 | self.device 15 | ) 16 | else: 17 | nll_weight = None 18 | 19 | self.nll_loss = nn.NLLLoss(weight=nll_weight) 20 | self.jaccard_weight = jaccard_weight 21 | self.num_classes = num_classes 22 | 23 | def __call__(self, outputs, targets): 24 | 25 | targets = targets.squeeze(1) 26 | loss = (1 - self.jaccard_weight) * self.nll_loss(outputs, targets) 27 | if self.jaccard_weight: 28 | eps = 1e-7 29 | for cls in range(self.num_classes): 30 | jaccard_target = (targets == cls).float() 31 | jaccard_output = outputs[:, cls].exp() 32 | intersection = (jaccard_output * jaccard_target).sum() 33 | 34 | union = jaccard_output.sum() + jaccard_target.sum() 35 | loss -= ( 36 | torch.log((intersection + eps) / (union - intersection + eps)) 37 | * self.jaccard_weight 38 | ) 39 | 40 | return loss 41 | 42 | 43 | class LossUNet: 44 | def __init__(self, weights=[1, 1, 1]): 45 | 46 | self.criterion = LossMulti(num_classes=2) 47 | 48 | def __call__(self, outputs, targets): 49 | 50 | criterion = self.criterion(outputs, targets) 51 | 52 | return criterion 53 | 54 | 55 | class LossDCAN: 56 | def __init__(self, weights=[1, 1, 1]): 57 | 58 | self.criterion1 = LossMulti(num_classes=2) 59 | self.criterion2 = LossMulti(num_classes=2) 60 | self.weights = weights 61 | 62 | def __call__(self, outputs1, outputs2, targets1, targets2): 63 | 64 | criterion = self.weights[0] * self.criterion1( 65 | outputs1, targets1 66 | ) + self.weights[1] * self.criterion2(outputs2, targets2) 67 | 68 | return criterion 69 | 70 | 71 | class LossDMTN: 72 | def __init__(self, weights=[1, 1, 1]): 73 | self.criterion1 = LossMulti(num_classes=2) 74 | self.criterion2 = nn.MSELoss() 75 | self.weights = weights 76 | 77 | def __call__(self, outputs1, outputs2, targets1, targets2): 78 | 79 | criterion = self.weights[0] * self.criterion1( 80 | outputs1, targets1 81 | ) + self.weights[1] * self.criterion2(outputs2, targets2) 82 | 83 | return criterion 84 | 85 | 86 | class LossPsiNet: 87 | def __init__(self, weights=[1, 1, 1]): # weights=[1,1,1] 88 | 89 | self.criterion1 = LossMulti(num_classes=2) 90 | self.criterion2 = LossMulti(num_classes=2) 91 | self.criterion3 = 
nn.MSELoss() 92 | # self.criterion3 = nn.SmoothL1Loss() 93 | self.weights = weights 94 | 95 | def __call__(self, outputs1, outputs2, outputs3, targets1, targets2, targets3): 96 | # print(self.weights) 97 | 98 | criterion = ( 99 | self.weights[0] * self.criterion1(outputs1, targets1) 100 | + self.weights[1] * self.criterion2(outputs2, targets2) 101 | + self.weights[2] * self.criterion3(outputs3, targets3) 102 | ) 103 | 104 | return criterion 105 | 106 | 107 | class My_multiLoss: 108 | def __init__(self, weights=[1, 1, 1]): # weights=[1,1,1] 109 | 110 | self.criterion1 = smp.utils.losses.DiceLoss() 111 | # self.criterion2 = smp.utils.losses.CrossEntropyLoss() 112 | # self.criterion2 = smp.utils.losses.BCEWithLogitsLoss() 113 | self.criterion2 = nn.MSELoss() 114 | self.criterion3 = nn.MSELoss() 115 | # self.criterion3 = nn.SmoothL1Loss() 116 | self.weights = weights 117 | 118 | def __call__(self, outputs1, outputs2, outputs3, targets1, targets2, targets3): 119 | # print(self.weights) 120 | 121 | criterion = ( 122 | self.weights[0] * self.criterion1(outputs1, targets1) 123 | + self.weights[1] * self.criterion2(outputs2, targets2) 124 | + self.weights[2] * self.criterion3(outputs3, targets3) 125 | ) 126 | 127 | return criterion 128 | 129 | 130 | # Lovasz loss 131 | def lovasz_grad(gt_sorted): 132 | """ 133 | Computes gradient of the Lovasz extension w.r.t sorted errors 134 | See Alg. 1 in paper 135 | """ 136 | p = len(gt_sorted) 137 | gts = gt_sorted.sum() 138 | intersection = gts - gt_sorted.float().cumsum(0) 139 | union = gts + (1 - gt_sorted).float().cumsum(0) 140 | jaccard = 1. - intersection / union 141 | if p > 1: # cover 1-pixel case 142 | jaccard[1:p] = jaccard[1:p] - jaccard[0:-1] 143 | return jaccard 144 | 145 | 146 | class LovaszSoftmax(nn.Module): 147 | def __init__(self, reduction='mean'): 148 | super(LovaszSoftmax, self).__init__() 149 | self.reduction = reduction 150 | 151 | def prob_flatten(self, input, target): 152 | assert input.dim() in [4, 5] 153 | num_class = input.size(1) 154 | if input.dim() == 4: 155 | input = input.permute(0, 2, 3, 1).contiguous() 156 | input_flatten = input.view(-1, num_class) 157 | elif input.dim() == 5: 158 | input = input.permute(0, 2, 3, 4, 1).contiguous() 159 | input_flatten = input.view(-1, num_class) 160 | target_flatten = target.view(-1) 161 | return input_flatten, target_flatten 162 | 163 | def lovasz_softmax_flat(self, inputs, targets): 164 | num_classes = inputs.size(1) 165 | losses = [] 166 | for c in range(num_classes): 167 | target_c = (targets == c).float() 168 | if num_classes == 1: 169 | input_c = inputs[:, 0] 170 | else: 171 | input_c = inputs[:, c] 172 | loss_c = (torch.autograd.Variable(target_c) - input_c).abs() 173 | loss_c_sorted, loss_index = torch.sort(loss_c, 0, descending=True) 174 | target_c_sorted = target_c[loss_index] 175 | losses.append(torch.dot(loss_c_sorted, torch.autograd.Variable(lovasz_grad(target_c_sorted)))) 176 | losses = torch.stack(losses) 177 | 178 | if self.reduction == 'none': 179 | loss = losses 180 | elif self.reduction == 'sum': 181 | loss = losses.sum() 182 | else: 183 | loss = losses.mean() 184 | return loss 185 | 186 | def forward(self, inputs, targets): 187 | # print(inputs.shape, targets.shape) # (batch size, class_num, x,y,z), (batch size, 1, x,y,z) 188 | inputs, targets = self.prob_flatten(inputs, targets) 189 | # print(inputs.shape, targets.shape) 190 | losses = self.lovasz_softmax_flat(inputs, targets) 191 | return losses 192 | 
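Editor's sketch (not a file in the repo): losses.py assigns one criterion per decoder head, so the three PsiNet outputs (mask log-softmax, contour log-softmax, sigmoid distance map) pair one-to-one with LossPsiNet's three targets. The shapes and the 96-pixel patch size below are illustrative:

import torch
from models import PsiNet
from losses import LossPsiNet

model = PsiNet(input_channels=1, num_classes=2)
image = torch.randn(2, 1, 96, 96)                  # grayscale OCTA patch (illustrative size)
mask_out, contour_out, dist_out = model(image)     # log-softmax, log-softmax, sigmoid heads
mask_gt = torch.randint(0, 2, (2, 1, 96, 96))      # class indices for the NLL term
contour_gt = torch.randint(0, 2, (2, 1, 96, 96))
dist_gt = torch.rand(2, 1, 96, 96)                 # normalized distance map for the MSE term
loss = LossPsiNet(weights=[1, 1, 1])(mask_out, contour_out, dist_out, mask_gt, contour_gt, dist_gt)
loss.backward()

My_multiLoss follows the same three-head pattern for the SMP-based co-training models, swapping in smp's DiceLoss for the mask head and MSE for both auxiliary heads.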
-------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch 3 | from torch.nn import functional as F 4 | 5 | 6 | def conv3x3(in_, out): 7 | return nn.Conv2d(in_, out, 3, padding=1) 8 | 9 | 10 | class Conv3BN(nn.Module): 11 | def __init__(self, in_: int, out: int, bn=False): # bn=False 12 | super().__init__() 13 | self.conv = conv3x3(in_, out) 14 | self.bn = nn.BatchNorm2d(out) if bn else None 15 | self.activation = nn.ReLU(inplace=True) 16 | 17 | def forward(self, x): 18 | x = self.conv(x) 19 | if self.bn is not None: 20 | x = self.bn(x) 21 | x = self.activation(x) 22 | return x 23 | 24 | 25 | class UNetModule(nn.Module): 26 | def __init__(self, in_: int, out: int): 27 | super().__init__() 28 | self.l1 = Conv3BN(in_, out) 29 | self.l2 = Conv3BN(out, out) 30 | 31 | def forward(self, x): 32 | x = self.l1(x) 33 | x = self.l2(x) 34 | return x 35 | 36 | 37 | class PsiNet(nn.Module): 38 | """ 39 | Adapted from Vanilla UNet implementation - https://github.com/lopuhin/mapillary-vistas-2017/blob/master/unet_models.py 40 | """ 41 | 42 | output_downscaled = 1 43 | module = UNetModule 44 | 45 | def __init__( 46 | self, 47 | input_channels: int = 1, 48 | filters_base: int = 32, 49 | down_filter_factors=(1, 2, 4, 8, 16), 50 | up_filter_factors=(1, 2, 4, 8, 16), 51 | bottom_s=4, 52 | num_classes=1, 53 | add_output=True, 54 | ): 55 | super().__init__() 56 | self.num_classes = num_classes 57 | assert len(down_filter_factors) == len(up_filter_factors) 58 | assert down_filter_factors[-1] == up_filter_factors[-1] 59 | down_filter_sizes = [filters_base * s for s in down_filter_factors] 60 | up_filter_sizes = [filters_base * s for s in up_filter_factors] 61 | self.down, self.up = nn.ModuleList(), nn.ModuleList() 62 | self.down.append(self.module(input_channels, down_filter_sizes[0])) 63 | for prev_i, nf in enumerate(down_filter_sizes[1:]): 64 | self.down.append(self.module(down_filter_sizes[prev_i], nf)) 65 | for prev_i, nf in enumerate(up_filter_sizes[1:]): 66 | self.up.append( 67 | self.module(down_filter_sizes[prev_i] + nf, up_filter_sizes[prev_i]) 68 | ) 69 | pool = nn.MaxPool2d(2, 2) 70 | pool_bottom = nn.MaxPool2d(bottom_s, bottom_s) 71 | upsample1 = nn.Upsample(scale_factor=2) 72 | upsample_bottom1 = nn.Upsample(scale_factor=bottom_s) 73 | upsample2 = nn.Upsample(scale_factor=2) 74 | upsample_bottom2 = nn.Upsample(scale_factor=bottom_s) 75 | upsample3 = nn.Upsample(scale_factor=2) 76 | upsample_bottom3 = nn.Upsample(scale_factor=bottom_s) 77 | 78 | self.downsamplers = [None] + [pool] * (len(self.down) - 1) 79 | self.downsamplers[-1] = pool_bottom 80 | self.upsamplers1 = [upsample1] * len(self.up) 81 | self.upsamplers1[-1] = upsample_bottom1 82 | self.upsamplers2 = [upsample2] * len(self.up) 83 | self.upsamplers2[-1] = upsample_bottom2 84 | self.upsamplers3 = [upsample3] * len(self.up) 85 | self.upsamplers3[-1] = upsample_bottom3 86 | 87 | self.add_output = add_output 88 | if add_output: 89 | self.conv_final1 = nn.Conv2d(up_filter_sizes[0], num_classes, 1) 90 | if add_output: 91 | self.conv_final2 = nn.Conv2d(up_filter_sizes[0], num_classes, 1) 92 | if add_output: 93 | self.conv_final3 = nn.Conv2d(up_filter_sizes[0], 1, 1) 94 | 95 | def forward(self, x): 96 | xs = [] 97 | for downsample, down in zip(self.downsamplers, self.down): 98 | x_in = x if downsample is None else downsample(xs[-1]) 99 | x_out = down(x_in) 100 | xs.append(x_out) 101 | 102 | x_out 
= xs[-1] 103 | x_out1 = x_out 104 | x_out2 = x_out 105 | x_out3 = x_out 106 | 107 | # Decoder mask segmentation 108 | for x_skip, upsample, up in reversed( 109 | list(zip(xs[:-1], self.upsamplers1, self.up)) 110 | ): 111 | x_out1 = upsample(x_out1) 112 | x_out1 = up(torch.cat([x_out1, x_skip], 1)) 113 | 114 | # Decoder contour estimation 115 | for x_skip, upsample, up in reversed( 116 | list(zip(xs[:-1], self.upsamplers2, self.up)) 117 | ): 118 | x_out2 = upsample(x_out2) 119 | x_out2 = up(torch.cat([x_out2, x_skip], 1)) 120 | 121 | # Regression 122 | for x_skip, upsample, up in reversed( 123 | list(zip(xs[:-1], self.upsamplers3, self.up)) 124 | ): 125 | x_out3 = upsample(x_out3) 126 | x_out3 = up(torch.cat([x_out3, x_skip], 1)) 127 | 128 | if self.add_output: 129 | x_out1 = self.conv_final1(x_out1) 130 | if self.num_classes > 1: 131 | x_out1 = F.log_softmax(x_out1, dim=1) 132 | 133 | if self.add_output: 134 | x_out2 = self.conv_final2(x_out2) 135 | if self.num_classes > 1: 136 | x_out2 = F.log_softmax(x_out2, dim=1) 137 | 138 | if self.add_output: 139 | x_out3 = self.conv_final3(x_out3) 140 | x_out3 = torch.sigmoid(x_out3) 141 | # x_out3 = torch.tanh(x_out3) 142 | 143 | return [x_out1, x_out2, x_out3] 144 | 145 | 146 | class UNet_DCAN(nn.Module): 147 | """ 148 | Adapted from Vanilla UNet implementation - https://github.com/lopuhin/mapillary-vistas-2017/blob/master/unet_models.py 149 | """ 150 | 151 | output_downscaled = 1 152 | module = UNetModule 153 | 154 | def __init__( 155 | self, 156 | input_channels: int = 1, 157 | filters_base: int = 32, 158 | down_filter_factors=(1, 2, 4, 8, 16), 159 | up_filter_factors=(1, 2, 4, 8, 16), 160 | bottom_s=4, 161 | num_classes=1, 162 | add_output=True, 163 | ): 164 | super().__init__() 165 | self.num_classes = num_classes 166 | assert len(down_filter_factors) == len(up_filter_factors) 167 | assert down_filter_factors[-1] == up_filter_factors[-1] 168 | down_filter_sizes = [filters_base * s for s in down_filter_factors] 169 | up_filter_sizes = [filters_base * s for s in up_filter_factors] 170 | self.down, self.up = nn.ModuleList(), nn.ModuleList() 171 | self.down.append(self.module(input_channels, down_filter_sizes[0])) 172 | for prev_i, nf in enumerate(down_filter_sizes[1:]): 173 | self.down.append(self.module(down_filter_sizes[prev_i], nf)) 174 | for prev_i, nf in enumerate(up_filter_sizes[1:]): 175 | self.up.append( 176 | self.module(down_filter_sizes[prev_i] + nf, up_filter_sizes[prev_i]) 177 | ) 178 | pool = nn.MaxPool2d(2, 2) 179 | pool_bottom = nn.MaxPool2d(bottom_s, bottom_s) 180 | upsample1 = nn.Upsample(scale_factor=2) 181 | upsample_bottom1 = nn.Upsample(scale_factor=bottom_s) 182 | upsample2 = nn.Upsample(scale_factor=2) 183 | upsample_bottom2 = nn.Upsample(scale_factor=bottom_s) 184 | 185 | self.downsamplers = [None] + [pool] * (len(self.down) - 1) 186 | self.downsamplers[-1] = pool_bottom 187 | self.upsamplers1 = [upsample1] * len(self.up) 188 | self.upsamplers1[-1] = upsample_bottom1 189 | self.upsamplers2 = [upsample2] * len(self.up) 190 | self.upsamplers2[-1] = upsample_bottom2 191 | 192 | self.add_output = add_output 193 | if add_output: 194 | self.conv_final1 = nn.Conv2d(up_filter_sizes[0], num_classes, 1) 195 | if add_output: 196 | self.conv_final2 = nn.Conv2d(up_filter_sizes[0], num_classes, 1) 197 | 198 | def forward(self, x): 199 | xs = [] 200 | for downsample, down in zip(self.downsamplers, self.down): 201 | x_in = x if downsample is None else downsample(xs[-1]) 202 | x_out = down(x_in) 203 | xs.append(x_out) 204 | 205 | x_out = 
xs[-1] 206 | x_out1 = x_out 207 | x_out2 = x_out 208 | 209 | # Decoder mask segmentation 210 | for x_skip, upsample, up in reversed( 211 | list(zip(xs[:-1], self.upsamplers1, self.up)) 212 | ): 213 | x_out1 = upsample(x_out1) 214 | x_out1 = up(torch.cat([x_out1, x_skip], 1)) 215 | 216 | # Decoder contour estimation 217 | for x_skip, upsample, up in reversed( 218 | list(zip(xs[:-1], self.upsamplers2, self.up)) 219 | ): 220 | x_out2 = upsample(x_out2) 221 | x_out2 = up(torch.cat([x_out2, x_skip], 1)) 222 | 223 | if self.add_output: 224 | x_out1 = self.conv_final1(x_out1) 225 | if self.num_classes > 1: 226 | x_out1 = F.log_softmax(x_out1, dim=1) 227 | 228 | if self.add_output: 229 | x_out2 = self.conv_final2(x_out2) 230 | if self.num_classes > 1: 231 | x_out2 = F.log_softmax(x_out2, dim=1) 232 | 233 | return [x_out1, x_out2] 234 | 235 | 236 | class UNet_DMTN(nn.Module): 237 | """ 238 | Adapted from Vanilla UNet implementation - https://github.com/lopuhin/mapillary-vistas-2017/blob/master/unet_models.py 239 | """ 240 | 241 | output_downscaled = 1 242 | module = UNetModule 243 | 244 | def __init__( 245 | self, 246 | input_channels=1, 247 | filters_base: int = 32, 248 | down_filter_factors=(1, 2, 4, 8, 16), 249 | up_filter_factors=(1, 2, 4, 8, 16), 250 | bottom_s=4, 251 | num_classes=1, 252 | add_output=True, 253 | ): 254 | super().__init__() 255 | self.num_classes = num_classes 256 | assert len(down_filter_factors) == len(up_filter_factors) 257 | assert down_filter_factors[-1] == up_filter_factors[-1] 258 | down_filter_sizes = [filters_base * s for s in down_filter_factors] 259 | up_filter_sizes = [filters_base * s for s in up_filter_factors] 260 | self.down, self.up = nn.ModuleList(), nn.ModuleList() 261 | self.down.append(self.module(input_channels, down_filter_sizes[0])) 262 | for prev_i, nf in enumerate(down_filter_sizes[1:]): 263 | self.down.append(self.module(down_filter_sizes[prev_i], nf)) 264 | for prev_i, nf in enumerate(up_filter_sizes[1:]): 265 | self.up.append( 266 | self.module(down_filter_sizes[prev_i] + nf, up_filter_sizes[prev_i]) 267 | ) 268 | pool = nn.MaxPool2d(2, 2) 269 | pool_bottom = nn.MaxPool2d(bottom_s, bottom_s) 270 | upsample1 = nn.Upsample(scale_factor=2) 271 | upsample_bottom1 = nn.Upsample(scale_factor=bottom_s) 272 | upsample2 = nn.Upsample(scale_factor=2) 273 | upsample_bottom2 = nn.Upsample(scale_factor=bottom_s) 274 | 275 | self.downsamplers = [None] + [pool] * (len(self.down) - 1) 276 | self.downsamplers[-1] = pool_bottom 277 | self.upsamplers1 = [upsample1] * len(self.up) 278 | self.upsamplers1[-1] = upsample_bottom1 279 | self.upsamplers2 = [upsample2] * len(self.up) 280 | self.upsamplers2[-1] = upsample_bottom2 281 | 282 | self.add_output = add_output 283 | if add_output: 284 | self.conv_final1 = nn.Conv2d(up_filter_sizes[0], num_classes, 1) 285 | if add_output: 286 | self.conv_final2 = nn.Conv2d(up_filter_sizes[0], 1, 1) 287 | 288 | def forward(self, x): 289 | xs = [] 290 | for downsample, down in zip(self.downsamplers, self.down): 291 | x_in = x if downsample is None else downsample(xs[-1]) 292 | x_out = down(x_in) 293 | xs.append(x_out) 294 | 295 | x_out = xs[-1] 296 | x_out1 = x_out 297 | x_out2 = x_out 298 | 299 | # Decoder mask segmentation 300 | for x_skip, upsample, up in reversed( 301 | list(zip(xs[:-1], self.upsamplers1, self.up)) 302 | ): 303 | x_out1 = upsample(x_out1) 304 | x_out1 = up(torch.cat([x_out1, x_skip], 1)) 305 | 306 | # Regression 307 | for x_skip, upsample, up in reversed( 308 | list(zip(xs[:-1], self.upsamplers2, self.up)) 309 | ): 
310 | x_out2 = upsample(x_out2) 311 | x_out2 = up(torch.cat([x_out2, x_skip], 1)) 312 | 313 | if self.add_output: 314 | x_out1 = self.conv_final1(x_out1) 315 | if self.num_classes > 1: 316 | x_out1 = F.log_softmax(x_out1, dim=1) 317 | 318 | if self.add_output: 319 | x_out2 = self.conv_final2(x_out2) 320 | x_out2 = torch.sigmoid(x_out2) 321 | # x_out2 = torch.tanh(x_out2) 322 | 323 | return [x_out1, x_out2] 324 | 325 | 326 | class UNet(nn.Module): 327 | """ 328 | Vanilla UNet. 329 | 330 | Implementation from https://github.com/lopuhin/mapillary-vistas-2017/blob/master/unet_models.py 331 | """ 332 | 333 | output_downscaled = 1 334 | module = UNetModule 335 | 336 | def __init__( 337 | self, 338 | input_channels=1, 339 | filters_base: int = 32, 340 | down_filter_factors=(1, 2, 4, 8, 16), 341 | up_filter_factors=(1, 2, 4, 8, 16), 342 | bottom_s=4, 343 | num_classes=1, 344 | padding=1, 345 | add_output=True, 346 | ): 347 | super().__init__() 348 | self.num_classes = num_classes 349 | assert len(down_filter_factors) == len(up_filter_factors) 350 | assert down_filter_factors[-1] == up_filter_factors[-1] 351 | down_filter_sizes = [filters_base * s for s in down_filter_factors] 352 | up_filter_sizes = [filters_base * s for s in up_filter_factors] 353 | self.down, self.up = nn.ModuleList(), nn.ModuleList() 354 | self.down.append(self.module(input_channels, down_filter_sizes[0])) 355 | for prev_i, nf in enumerate(down_filter_sizes[1:]): 356 | self.down.append(self.module(down_filter_sizes[prev_i], nf)) 357 | for prev_i, nf in enumerate(up_filter_sizes[1:]): 358 | self.up.append( 359 | self.module(down_filter_sizes[prev_i] + nf, up_filter_sizes[prev_i]) 360 | ) 361 | pool = nn.MaxPool2d(2, 2) 362 | pool_bottom = nn.MaxPool2d(bottom_s, bottom_s) 363 | upsample = nn.Upsample(scale_factor=2) 364 | upsample_bottom = nn.Upsample(scale_factor=bottom_s) 365 | self.downsamplers = [None] + [pool] * (len(self.down) - 1) 366 | self.downsamplers[-1] = pool_bottom 367 | self.upsamplers = [upsample] * len(self.up) 368 | self.upsamplers[-1] = upsample_bottom 369 | self.add_output = add_output 370 | if add_output: 371 | self.conv_final = nn.Conv2d(up_filter_sizes[0], num_classes, padding) 372 | 373 | def forward(self, x): 374 | xs = [] 375 | for downsample, down in zip(self.downsamplers, self.down): 376 | x_in = x if downsample is None else downsample(xs[-1]) 377 | x_out = down(x_in) 378 | xs.append(x_out) 379 | # print(x_out.shape) 380 | 381 | x_out = xs[-1] 382 | for x_skip, upsample, up in reversed( 383 | list(zip(xs[:-1], self.upsamplers, self.up)) 384 | ): 385 | x_out = upsample(x_out) 386 | x_out = up(torch.cat([x_out, x_skip], 1)) 387 | # print(x_out.shape) 388 | 389 | if self.add_output: 390 | x_out = self.conv_final(x_out) 391 | # print(x_out.shape) 392 | if self.num_classes > 1: 393 | x_out = F.log_softmax(x_out, dim=1) 394 | 395 | return [x_out] 396 | 397 | 398 | class UNet_ConvMCD(nn.Module): 399 | """ 400 | Vanilla UNet. 
401 | 402 | Implementation from https://github.com/lopuhin/mapillary-vistas-2017/blob/master/unet_models.py 403 | """ 404 | 405 | output_downscaled = 1 406 | module = UNetModule 407 | 408 | def __init__( 409 | self, 410 | input_channels: int = 1, 411 | filters_base: int = 32, 412 | down_filter_factors=(1, 2, 4, 8, 16), 413 | up_filter_factors=(1, 2, 4, 8, 16), 414 | bottom_s=4, 415 | num_classes=1, 416 | add_output=True, 417 | ): 418 | super().__init__() 419 | self.num_classes = num_classes 420 | assert len(down_filter_factors) == len(up_filter_factors) 421 | assert down_filter_factors[-1] == up_filter_factors[-1] 422 | down_filter_sizes = [filters_base * s for s in down_filter_factors] 423 | up_filter_sizes = [filters_base * s for s in up_filter_factors] 424 | self.down, self.up = nn.ModuleList(), nn.ModuleList() 425 | self.down.append(self.module(input_channels, down_filter_sizes[0])) 426 | for prev_i, nf in enumerate(down_filter_sizes[1:]): 427 | self.down.append(self.module(down_filter_sizes[prev_i], nf)) 428 | for prev_i, nf in enumerate(up_filter_sizes[1:]): 429 | self.up.append( 430 | self.module(down_filter_sizes[prev_i] + nf, up_filter_sizes[prev_i]) 431 | ) 432 | pool = nn.MaxPool2d(2, 2) 433 | pool_bottom = nn.MaxPool2d(bottom_s, bottom_s) 434 | upsample = nn.Upsample(scale_factor=2) 435 | upsample_bottom = nn.Upsample(scale_factor=bottom_s) 436 | self.downsamplers = [None] + [pool] * (len(self.down) - 1) 437 | self.downsamplers[-1] = pool_bottom 438 | self.upsamplers = [upsample] * len(self.up) 439 | self.upsamplers[-1] = upsample_bottom 440 | self.add_output = add_output 441 | if add_output: 442 | self.conv_final1 = nn.Conv2d(up_filter_sizes[0], num_classes, 1) 443 | self.conv_final2 = nn.Conv2d(up_filter_sizes[0], num_classes, 1) 444 | self.conv_final3 = nn.Conv2d(up_filter_sizes[0], 1, 1) 445 | 446 | def forward(self, x): 447 | xs = [] 448 | for downsample, down in zip(self.downsamplers, self.down): 449 | x_in = x if downsample is None else downsample(xs[-1]) 450 | x_out = down(x_in) 451 | xs.append(x_out) 452 | 453 | x_out = xs[-1] 454 | for x_skip, upsample, up in reversed( 455 | list(zip(xs[:-1], self.upsamplers, self.up)) 456 | ): 457 | x_out = upsample(x_out) 458 | x_out = up(torch.cat([x_out, x_skip], 1)) 459 | 460 | if self.add_output: 461 | x_out1 = self.conv_final1(x_out) 462 | x_out2 = self.conv_final2(x_out) 463 | x_out3 = self.conv_final3(x_out) 464 | if self.num_classes > 1: 465 | x_out1 = F.log_softmax(x_out1, dim=1) 466 | x_out2 = F.log_softmax(x_out2, dim=1) 467 | x_out3 = torch.sigmoid(x_out3) 468 | # x_out3 = torch.tanh(x_out3) 469 | 470 | # return x_out,x_out1,x_out2,x_out3 471 | return [x_out1, x_out2, x_out3] 472 | -------------------------------------------------------------------------------- /smp_model.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | import torch 3 | import torch.nn as nn 4 | import torch.utils.model_zoo as model_zoo 5 | import copy 6 | from segmentation_models_pytorch.encoders import get_encoder 7 | from segmentation_models_pytorch.unet.decoder import UnetDecoder, CenterBlock, DecoderBlock 8 | from segmentation_models_pytorch.base import SegmentationModel, SegmentationHead 9 | import segmentation_models_pytorch.base.initialization as init 10 | from segmentation_models_pytorch.encoders.resnet import ResNetEncoder, resnet_encoders 11 | from segmentation_models_pytorch.encoders.timm_resnest import ResNestEncoder, timm_resnest_encoders 12 | from 
segmentation_models_pytorch.encoders.vgg import VGGEncoder, vgg_encoders 13 | from segmentation_models_pytorch.encoders.timm_sknet import SkNetEncoder, timm_sknet_encoders 14 | from segmentation_models_pytorch.encoders.timm_res2net import Res2NetEncoder, timm_res2net_encoders 15 | 16 | 17 | class MyResNetEncoder(ResNetEncoder): 18 | 19 | def __init__(self, out_channels, depth, decoder_channels, **kwargs): 20 | super().__init__(out_channels, depth, **kwargs) 21 | self.my_avgpool = nn.AdaptiveAvgPool2d((1, 1)) 22 | self.my_fc = nn.Linear(512 * kwargs["block"].expansion, kwargs["num_classes"]) 23 | 24 | self.cat_convs = nn.Sequential( 25 | nn.Conv2d(out_channels[1] + decoder_channels[3], out_channels[1], kernel_size=1, stride=1, bias=False), 26 | nn.Conv2d(out_channels[2] + decoder_channels[2], out_channels[2], kernel_size=1, stride=1, bias=False), 27 | nn.Conv2d(out_channels[3] + decoder_channels[1], out_channels[3], kernel_size=1, stride=1, bias=False), 28 | nn.Conv2d(out_channels[4] + decoder_channels[0], out_channels[4], kernel_size=1, stride=1, bias=False) 29 | ) 30 | self._initialize() 31 | 32 | def _initialize(self): 33 | for m in self.modules(): 34 | if isinstance(m, nn.Conv2d): 35 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 36 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 37 | nn.init.constant_(m.weight, 1) 38 | nn.init.constant_(m.bias, 0) 39 | 40 | def forward(self, x, decoder_features): 41 | 42 | decoder_features = decoder_features[::-1] 43 | '''for j, fe in enumerate(decoder_features): 44 | print(j, fe.shape) 45 | print()''' 46 | 47 | stages = self.get_stages() 48 | 49 | features = [] 50 | for i in range(self._depth + 1): 51 | # print(stages[i]) 52 | x = stages[i](x) 53 | # print(i, x.shape) 54 | if i > 0 and i < 5: 55 | skip = decoder_features[i] 56 | # print(skip.shape) 57 | x = torch.cat([x, skip], dim=1) 58 | x = self.cat_convs[i - 1](x) 59 | # print(x.shape) 60 | # print() 61 | features.append(x) 62 | 63 | x = self.my_avgpool(x) 64 | x = torch.flatten(x, 1) 65 | x = self.my_fc(x) 66 | 67 | return x 68 | 69 | 70 | class MyResNetEncoderMulti(ResNetEncoder): 71 | 72 | def __init__(self, out_channels, depth, decoder_channels, **kwargs): 73 | super().__init__(out_channels, depth, **kwargs) 74 | self.my_avgpool = nn.AdaptiveAvgPool2d((1, 1)) 75 | self.my_fc = nn.Linear(512 * kwargs["block"].expansion, kwargs["num_classes"]) 76 | 77 | self.cat_convs = nn.Sequential( 78 | nn.Conv2d(out_channels[1] + 3*decoder_channels[3], out_channels[1], kernel_size=1, stride=1, bias=False), 79 | nn.Conv2d(out_channels[2] + 3*decoder_channels[2], out_channels[2], kernel_size=1, stride=1, bias=False), 80 | nn.Conv2d(out_channels[3] + 3*decoder_channels[1], out_channels[3], kernel_size=1, stride=1, bias=False), 81 | nn.Conv2d(out_channels[4] + 3*decoder_channels[0], out_channels[4], kernel_size=1, stride=1, bias=False) 82 | ) 83 | self._initialize() 84 | 85 | def _initialize(self): 86 | for m in self.modules(): 87 | if isinstance(m, nn.Conv2d): 88 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 89 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 90 | nn.init.constant_(m.weight, 1) 91 | nn.init.constant_(m.bias, 0) 92 | 93 | def forward(self, x, decoder1_features, decoder2_features, decoder3_features): 94 | 95 | decoder1_features = decoder1_features[::-1] 96 | decoder2_features = decoder2_features[::-1] 97 | decoder3_features = decoder3_features[::-1] 98 | '''for j, fe in enumerate(decoder_features): 99 | print(j, fe.shape) 100 | 
print()''' 101 | 102 | stages = self.get_stages() 103 | 104 | features = [] 105 | for i in range(self._depth + 1): 106 | # print(stages[i]) 107 | x = stages[i](x) 108 | # print(i, x.shape) 109 | if i > 0 and i < 5: 110 | skip1 = decoder1_features[i] 111 | skip2 = decoder2_features[i] 112 | skip3 = decoder3_features[i] 113 | # print(skip.shape) 114 | x = torch.cat([x, skip1, skip2, skip3], dim=1) 115 | x = self.cat_convs[i - 1](x) 116 | # print(x.shape) 117 | # print() 118 | features.append(x) 119 | 120 | x = self.my_avgpool(x) 121 | x = torch.flatten(x, 1) 122 | x = self.my_fc(x) 123 | 124 | return x 125 | 126 | 127 | class MyRes2NetEncoderMulti(Res2NetEncoder): 128 | 129 | def __init__(self, out_channels, depth, decoder_channels, **kwargs): 130 | super().__init__(out_channels, depth, **kwargs) 131 | self.my_avgpool = nn.AdaptiveAvgPool2d((1, 1)) 132 | self.my_fc = nn.Linear(512 * kwargs["block"].expansion, kwargs["num_classes"]) 133 | 134 | self.cat_convs = nn.Sequential( 135 | nn.Conv2d(out_channels[1] + 3*decoder_channels[3], out_channels[1], kernel_size=1, stride=1, bias=False), 136 | nn.Conv2d(out_channels[2] + 3*decoder_channels[2], out_channels[2], kernel_size=1, stride=1, bias=False), 137 | nn.Conv2d(out_channels[3] + 3*decoder_channels[1], out_channels[3], kernel_size=1, stride=1, bias=False), 138 | nn.Conv2d(out_channels[4] + 3*decoder_channels[0], out_channels[4], kernel_size=1, stride=1, bias=False) 139 | ) 140 | self._initialize() 141 | 142 | def _initialize(self): 143 | for n, m in self.named_modules(): 144 | if isinstance(m, nn.Conv2d): 145 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 146 | elif isinstance(m, nn.BatchNorm2d): 147 | nn.init.constant_(m.weight, 1.) 148 | nn.init.constant_(m.bias, 0.) 149 | 150 | def forward(self, x, decoder1_features, decoder2_features, decoder3_features): 151 | 152 | decoder1_features = decoder1_features[::-1] 153 | decoder2_features = decoder2_features[::-1] 154 | decoder3_features = decoder3_features[::-1] 155 | '''for j, fe in enumerate(decoder_features): 156 | print(j, fe.shape) 157 | print()''' 158 | 159 | stages = self.get_stages() 160 | 161 | features = [] 162 | for i in range(self._depth + 1): 163 | # print(stages[i]) 164 | x = stages[i](x) 165 | # print(i, x.shape) 166 | if i > 0 and i < 5: 167 | skip1 = decoder1_features[i] 168 | skip2 = decoder2_features[i] 169 | skip3 = decoder3_features[i] 170 | # print(skip.shape) 171 | x = torch.cat([x, skip1, skip2, skip3], dim=1) 172 | x = self.cat_convs[i - 1](x) 173 | # print(x.shape) 174 | # print() 175 | features.append(x) 176 | 177 | x = self.my_avgpool(x) 178 | x = torch.flatten(x, 1) 179 | x = self.my_fc(x) 180 | 181 | return x 182 | 183 | 184 | class MyResNestEncoderMulti(ResNestEncoder): 185 | 186 | def __init__(self, out_channels, depth, decoder_channels, **kwargs): 187 | super().__init__(out_channels, depth, **kwargs) 188 | self.my_avgpool = nn.AdaptiveAvgPool2d((1, 1)) 189 | self.my_fc = nn.Linear(512 * kwargs["block"].expansion, kwargs["num_classes"]) 190 | 191 | self.cat_convs = nn.Sequential( 192 | nn.Conv2d(out_channels[1] + 3*decoder_channels[3], out_channels[1], kernel_size=1, stride=1, bias=False), 193 | nn.Conv2d(out_channels[2] + 3*decoder_channels[2], out_channels[2], kernel_size=1, stride=1, bias=False), 194 | nn.Conv2d(out_channels[3] + 3*decoder_channels[1], out_channels[3], kernel_size=1, stride=1, bias=False), 195 | nn.Conv2d(out_channels[4] + 3*decoder_channels[0], out_channels[4], kernel_size=1, stride=1, bias=False) 196 | ) 197 | 
self._initialize() 198 | 199 | def _initialize(self): 200 | for m in self.modules(): 201 | if isinstance(m, nn.Conv2d): 202 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 203 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 204 | nn.init.constant_(m.weight, 1) 205 | nn.init.constant_(m.bias, 0) 206 | 207 | def forward(self, x, decoder1_features, decoder2_features, decoder3_features): 208 | 209 | decoder1_features = decoder1_features[::-1] 210 | decoder2_features = decoder2_features[::-1] 211 | decoder3_features = decoder3_features[::-1] 212 | '''for j, fe in enumerate(decoder_features): 213 | print(j, fe.shape) 214 | print()''' 215 | 216 | stages = self.get_stages() 217 | 218 | features = [] 219 | for i in range(self._depth + 1): 220 | # print(stages[i]) 221 | x = stages[i](x) 222 | # print(i, x.shape) 223 | if i > 0 and i < 5: 224 | skip1 = decoder1_features[i] 225 | skip2 = decoder2_features[i] 226 | skip3 = decoder3_features[i] 227 | # print(skip.shape) 228 | x = torch.cat([x, skip1, skip2, skip3], dim=1) 229 | x = self.cat_convs[i - 1](x) 230 | # print(x.shape) 231 | # print() 232 | features.append(x) 233 | 234 | x = self.my_avgpool(x) 235 | x = torch.flatten(x, 1) 236 | x = self.my_fc(x) 237 | 238 | return x 239 | 240 | 241 | class MyVggEncoder(VGGEncoder): 242 | 243 | def __init__(self, out_channels, config, decoder_channels, batch_norm=False, depth=5, **kwargs): 244 | super().__init__(out_channels, config, batch_norm, depth, **kwargs) 245 | self._out_channels = out_channels 246 | self._depth = depth 247 | self._in_channels = 3 248 | 249 | self.my_avgpool = nn.AdaptiveAvgPool2d((7, 7)) 250 | self.my_classifier = nn.Sequential( 251 | nn.Linear(512 * 7 * 7, 4096), 252 | nn.ReLU(True), 253 | nn.Dropout(), 254 | nn.Linear(4096, 4096), 255 | nn.ReLU(True), 256 | nn.Dropout(), 257 | nn.Linear(4096, kwargs["num_classes"]), 258 | ) 259 | 260 | self.cat_convs = nn.Sequential( 261 | nn.Conv2d(out_channels[1] + decoder_channels[3], out_channels[1], kernel_size=1, stride=1, bias=False), 262 | nn.Conv2d(out_channels[2] + decoder_channels[2], out_channels[2], kernel_size=1, stride=1, bias=False), 263 | nn.Conv2d(out_channels[3] + decoder_channels[1], out_channels[3], kernel_size=1, stride=1, bias=False), 264 | nn.Conv2d(out_channels[4] + decoder_channels[0], out_channels[4], kernel_size=1, stride=1, bias=False) 265 | ) 266 | self._initialize() 267 | 268 | def _initialize(self): 269 | for m in self.modules(): 270 | if isinstance(m, nn.Conv2d): 271 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 272 | if m.bias is not None: 273 | nn.init.constant_(m.bias, 0) 274 | elif isinstance(m, nn.BatchNorm2d): 275 | nn.init.constant_(m.weight, 1) 276 | nn.init.constant_(m.bias, 0) 277 | elif isinstance(m, nn.Linear): 278 | nn.init.normal_(m.weight, 0, 0.01) 279 | nn.init.constant_(m.bias, 0) 280 | 281 | def forward(self, x, decoder1_features): 282 | 283 | decoder1_features = decoder1_features[::-1] 284 | # decoder2_features = decoder2_features[::-1] 285 | # decoder3_features = decoder3_features[::-1] 286 | '''for j, fe in enumerate(decoder_features): 287 | print(j, fe.shape) 288 | print()''' 289 | 290 | stages = self.get_stages() 291 | 292 | features = [] 293 | for i in range(self._depth + 1): 294 | # print(stages[i]) 295 | x = stages[i](x) 296 | # print(i, x.shape) 297 | if i > 0 and i < 5: 298 | skip1 = decoder1_features[i] 299 | # skip2 = decoder2_features[i] 300 | # skip3 = decoder3_features[i] 301 | # print(skip.shape) 302 | x = torch.cat([x, skip1], 
dim=1) 303 | x = self.cat_convs[i - 1](x) 304 | # print(x.shape) 305 | # print() 306 | features.append(x) 307 | 308 | x = self.my_avgpool(x) 309 | x = torch.flatten(x, 1) 310 | x = self.my_classifier(x) 311 | 312 | return x 313 | 314 | 315 | class MyVggEncoderMulti(VGGEncoder): 316 | 317 | def __init__(self, out_channels, config, decoder_channels, batch_norm=False, depth=5, **kwargs): 318 | super().__init__(out_channels, config, batch_norm, depth, **kwargs) 319 | self._out_channels = out_channels 320 | self._depth = depth 321 | self._in_channels = 3 322 | 323 | self.my_avgpool = nn.AdaptiveAvgPool2d((7, 7)) 324 | self.my_classifier = nn.Sequential( 325 | nn.Linear(512 * 7 * 7, 4096), 326 | nn.ReLU(True), 327 | nn.Dropout(), 328 | nn.Linear(4096, 4096), 329 | nn.ReLU(True), 330 | nn.Dropout(), 331 | nn.Linear(4096, kwargs["num_classes"]), 332 | ) 333 | 334 | self.cat_convs = nn.Sequential( 335 | nn.Conv2d(out_channels[1] + 3*decoder_channels[3], out_channels[1], kernel_size=1, stride=1, bias=False), 336 | nn.Conv2d(out_channels[2] + 3*decoder_channels[2], out_channels[2], kernel_size=1, stride=1, bias=False), 337 | nn.Conv2d(out_channels[3] + 3*decoder_channels[1], out_channels[3], kernel_size=1, stride=1, bias=False), 338 | nn.Conv2d(out_channels[4] + 3*decoder_channels[0], out_channels[4], kernel_size=1, stride=1, bias=False) 339 | ) 340 | self._initialize() 341 | 342 | def _initialize(self): 343 | for m in self.modules(): 344 | if isinstance(m, nn.Conv2d): 345 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 346 | if m.bias is not None: 347 | nn.init.constant_(m.bias, 0) 348 | elif isinstance(m, nn.BatchNorm2d): 349 | nn.init.constant_(m.weight, 1) 350 | nn.init.constant_(m.bias, 0) 351 | elif isinstance(m, nn.Linear): 352 | nn.init.normal_(m.weight, 0, 0.01) 353 | nn.init.constant_(m.bias, 0) 354 | 355 | def forward(self, x, decoder1_features, decoder2_features, decoder3_features): 356 | 357 | decoder1_features = decoder1_features[::-1] 358 | decoder2_features = decoder2_features[::-1] 359 | decoder3_features = decoder3_features[::-1] 360 | '''for j, fe in enumerate(decoder_features): 361 | print(j, fe.shape) 362 | print()''' 363 | 364 | stages = self.get_stages() 365 | 366 | features = [] 367 | for i in range(self._depth + 1): 368 | # print(stages[i]) 369 | x = stages[i](x) 370 | # print(i, x.shape) 371 | if i > 0 and i < 5: 372 | skip1 = decoder1_features[i] 373 | skip2 = decoder2_features[i] 374 | skip3 = decoder3_features[i] 375 | # print(skip.shape) 376 | x = torch.cat([x, skip1, skip2, skip3], dim=1) 377 | x = self.cat_convs[i - 1](x) 378 | # print(x.shape) 379 | # print() 380 | features.append(x) 381 | 382 | x = self.my_avgpool(x) 383 | x = torch.flatten(x, 1) 384 | x = self.my_classifier(x) 385 | 386 | return x 387 | 388 | 389 | class MySKNetEncoderMulti(SkNetEncoder): 390 | 391 | def __init__(self, out_channels, depth, decoder_channels, **kwargs): 392 | super().__init__(out_channels, depth, **kwargs) 393 | self.my_avgpool = nn.AdaptiveAvgPool2d((1, 1)) 394 | self.my_fc = nn.Linear(512 * kwargs["block"].expansion, kwargs["num_classes"]) 395 | 396 | self.cat_convs = nn.Sequential( 397 | nn.Conv2d(out_channels[1] + 3*decoder_channels[3], out_channels[1], kernel_size=1, stride=1, bias=False), 398 | nn.Conv2d(out_channels[2] + 3*decoder_channels[2], out_channels[2], kernel_size=1, stride=1, bias=False), 399 | nn.Conv2d(out_channels[3] + 3*decoder_channels[1], out_channels[3], kernel_size=1, stride=1, bias=False), 400 | nn.Conv2d(out_channels[4] + 
3*decoder_channels[0], out_channels[4], kernel_size=1, stride=1, bias=False) 401 | ) 402 | self._initialize() 403 | 404 | def _initialize(self): 405 | for m in self.modules(): 406 | if isinstance(m, nn.Conv2d): 407 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 408 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 409 | nn.init.constant_(m.weight, 1) 410 | nn.init.constant_(m.bias, 0) 411 | 412 | def forward(self, x, decoder1_features, decoder2_features, decoder3_features): 413 | 414 | decoder1_features = decoder1_features[::-1] 415 | decoder2_features = decoder2_features[::-1] 416 | decoder3_features = decoder3_features[::-1] 417 | '''for j, fe in enumerate(decoder_features): 418 | print(j, fe.shape) 419 | print()''' 420 | 421 | stages = self.get_stages() 422 | 423 | features = [] 424 | for i in range(self._depth + 1): 425 | # print(stages[i]) 426 | x = stages[i](x) 427 | # print(i, x.shape) 428 | if i > 0 and i < 5: 429 | skip1 = decoder1_features[i] 430 | skip2 = decoder2_features[i] 431 | skip3 = decoder3_features[i] 432 | # print(skip.shape) 433 | x = torch.cat([x, skip1, skip2, skip3], dim=1) 434 | x = self.cat_convs[i - 1](x) 435 | # print(x.shape) 436 | # print() 437 | features.append(x) 438 | 439 | x = self.my_avgpool(x) 440 | x = torch.flatten(x, 1) 441 | x = self.my_fc(x) 442 | 443 | return x 444 | 445 | 446 | encoders = {} 447 | my_resnet_encoders = copy.deepcopy(resnet_encoders) 448 | my_resnest_encoders = copy.deepcopy(timm_resnest_encoders) 449 | my_vgg_encoders = copy.deepcopy(vgg_encoders) 450 | my_res2net_encoders = copy.deepcopy(timm_res2net_encoders) 451 | my_sknet_encoders = copy.deepcopy(timm_sknet_encoders) 452 | for name in my_resnet_encoders: 453 | my_resnet_encoders[name]["encoder"] = MyResNetEncoderMulti 454 | 455 | for name in my_resnest_encoders: 456 | my_resnest_encoders[name]["encoder"] = MyResNestEncoderMulti 457 | 458 | for name in my_vgg_encoders: 459 | my_vgg_encoders[name]["encoder"] = MyVggEncoderMulti 460 | 461 | for name in my_res2net_encoders: 462 | my_res2net_encoders[name]["encoder"] = MyRes2NetEncoderMulti 463 | 464 | for name in my_sknet_encoders: 465 | my_sknet_encoders[name]["encoder"] = MySKNetEncoderMulti 466 | 467 | encoders.update(my_resnet_encoders) 468 | encoders.update(my_resnest_encoders) 469 | encoders.update(my_vgg_encoders) 470 | encoders.update(my_res2net_encoders) 471 | encoders.update(my_sknet_encoders) 472 | 473 | 474 | def my_get_encoder(name, in_channels=3, depth=5, weights=None, decoder_channels=(256, 128, 64, 32, 16), num_classes=1): 475 | Encoder = encoders[name]["encoder"] 476 | params = encoders[name]["params"] 477 | params.update(depth=depth) 478 | params["decoder_channels"] = decoder_channels 479 | params["num_classes"] = num_classes 480 | encoder = Encoder(**params) 481 | 482 | if weights is not None: 483 | settings = encoders[name]["pretrained_settings"][weights] 484 | state_dict = encoder.state_dict() 485 | # for param in state_dict: 486 | # print(param, '\t', state_dict[param].size()) 487 | pretrain_state_dict = model_zoo.load_url(settings["url"]) 488 | # for param in pretrain_state_dict: 489 | # print(param, '\t', pretrain_state_dict[param].size()) 490 | state_dict.update(pretrain_state_dict) 491 | 492 | encoder.load_state_dict(state_dict) 493 | 494 | encoder.set_in_channels(in_channels) 495 | 496 | return encoder 497 | 498 | 499 | class MyUnetDecoder(UnetDecoder): 500 | 501 | def __init__( 502 | self, 503 | encoder_channels, 504 | decoder_channels, 505 | n_blocks=5, 506 | 
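# Defaults follow smp's UnetDecoder; this subclass only overrides forward()
# so that it also returns the per-stage decoder features that the
# classification encoders above fuse into their skip connections.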
use_batchnorm=True, 507 | attention_type=None, 508 | center=False, 509 | ): 510 | super().__init__(encoder_channels, decoder_channels, n_blocks, use_batchnorm, attention_type, center) 511 | 512 | def forward(self, *features): 513 | 514 | features = features[1:] # remove first skip with same spatial resolution 515 | features = features[::-1] # reverse channels to start from head of encoder 516 | 517 | head = features[0] 518 | skips = features[1:] 519 | 520 | decoder_features = [] 521 | 522 | x = self.center(head) 523 | for i, decoder_block in enumerate(self.blocks): 524 | skip = skips[i] if i < len(skips) else None 525 | '''if i < len(skips): 526 | print(x.shape, skip.shape) 527 | else: 528 | print(x.shape, skip)''' 529 | x = decoder_block(x, skip) 530 | # print(decoder_block, x.shape) 531 | decoder_features.append(x) 532 | 533 | return [x, decoder_features] 534 | 535 | 536 | class MyUnetDecoder_withfirstconnect(nn.Module): 537 | 538 | def __init__( 539 | self, 540 | encoder_channels, 541 | decoder_channels, 542 | n_blocks=5, 543 | use_batchnorm=True, 544 | attention_type=None, 545 | center=False, 546 | ): 547 | super().__init__() 548 | 549 | if n_blocks != len(decoder_channels): 550 | raise ValueError( 551 | "Model depth is {}, but you provide `decoder_channels` for {} blocks.".format( 552 | n_blocks, len(decoder_channels) 553 | ) 554 | ) 555 | 556 | encoder_channels = encoder_channels[0:] # remove first skip with same spatial resolution 557 | encoder_channels = encoder_channels[::-1] # reverse channels to start from head of encoder 558 | 559 | head_channels = encoder_channels[0] 560 | in_channels = [head_channels] + list(decoder_channels[:-1]) 561 | skip_channels = list(encoder_channels[1:]) + [0] 562 | out_channels = decoder_channels 563 | 564 | if center: 565 | self.center = CenterBlock( 566 | head_channels, head_channels, use_batchnorm=use_batchnorm 567 | ) 568 | else: 569 | self.center = nn.Identity() 570 | 571 | # combine decoder keyword arguments 572 | kwargs = dict(use_batchnorm=use_batchnorm, attention_type=attention_type) 573 | blocks = [ 574 | DecoderBlock(in_ch, skip_ch, out_ch, **kwargs) 575 | for in_ch, skip_ch, out_ch in zip(in_channels, skip_channels, out_channels) 576 | ] 577 | self.blocks = nn.ModuleList(blocks) 578 | 579 | def forward(self, *features): 580 | 581 | features = features[0:] # remove first skip with same spatial resolution 582 | features = features[::-1] # reverse channels to start from head of encoder 583 | 584 | head = features[0] 585 | skips = features[1:] 586 | 587 | decoder_features = [] 588 | 589 | x = self.center(head) 590 | for i, decoder_block in enumerate(self.blocks): 591 | skip = skips[i] if i < len(skips) else None 592 | '''if i < len(skips): 593 | print(x.shape, skip.shape) 594 | else: 595 | print(x.shape, skip)''' 596 | x = decoder_block(x, skip) 597 | # print(decoder_block, x.shape) 598 | decoder_features.append(x) 599 | 600 | return [x, decoder_features] 601 | 602 | 603 | class MyUnetModel(SegmentationModel): 604 | 605 | def __init__( 606 | self, 607 | encoder_name="resnet34", 608 | encoder_depth=5, 609 | encoder_weights="imagenet", 610 | decoder_use_batchnorm=True, 611 | decoder_channels=(256, 128, 64, 32, 16), 612 | decoder_attention_type=None, 613 | in_channels=3, 614 | classes=1, 615 | activation=None, 616 | aux_params=None, 617 | ): 618 | super().__init__() 619 | 620 | self.encoder = get_encoder( 621 | encoder_name, 622 | in_channels=in_channels, 623 | depth=encoder_depth, 624 | weights=encoder_weights, 625 | ) 626 | 627 | self.decoder1 
= MyUnetDecoder( 628 | encoder_channels=self.encoder.out_channels, 629 | decoder_channels=decoder_channels, 630 | n_blocks=encoder_depth, 631 | use_batchnorm=decoder_use_batchnorm, 632 | center=True if encoder_name.startswith("vgg") else False, 633 | attention_type=decoder_attention_type, 634 | ) 635 | 636 | '''self.decoder2 = MyUnetDecoder( 637 | encoder_channels=self.encoder.out_channels, 638 | decoder_channels=decoder_channels, 639 | n_blocks=encoder_depth, 640 | use_batchnorm=decoder_use_batchnorm, 641 | center=True if encoder_name.startswith("vgg") else False, 642 | attention_type=decoder_attention_type, 643 | ) 644 | 645 | self.decoder3= MyUnetDecoder( 646 | encoder_channels=self.encoder.out_channels, 647 | decoder_channels=decoder_channels, 648 | n_blocks=encoder_depth, 649 | use_batchnorm=decoder_use_batchnorm, 650 | center=True if encoder_name.startswith("vgg") else False, 651 | attention_type=decoder_attention_type, 652 | )''' 653 | 654 | self.segmentation_head1 = SegmentationHead( 655 | in_channels=decoder_channels[-1], 656 | out_channels=classes, 657 | activation=activation, 658 | kernel_size=3, 659 | ) 660 | 661 | '''self.segmentation_head2 = SegmentationHead( 662 | in_channels=decoder_channels[-1], 663 | out_channels=classes, 664 | activation=activation, 665 | kernel_size=3, 666 | ) 667 | 668 | self.segmentation_head3 = SegmentationHead( 669 | in_channels=decoder_channels[-1], 670 | out_channels=classes, 671 | activation=activation, 672 | kernel_size=3, 673 | )''' 674 | self.name = "u-{}".format(encoder_name) 675 | self.initialize() 676 | 677 | def initialize(self): 678 | init.initialize_decoder(self.decoder1) 679 | init.initialize_head(self.segmentation_head1) 680 | '''init.initialize_decoder(self.decoder2) 681 | init.initialize_head(self.segmentation_head2) 682 | init.initialize_decoder(self.decoder3) 683 | init.initialize_head(self.segmentation_head3)''' 684 | 685 | def forward(self, x): 686 | features = self.encoder(x) 687 | '''for f in features: 688 | print(f.shape)''' 689 | decoder1_output, decoder1_features = self.decoder1(*features) 690 | # decoder2_output = self.decoder2(*features) 691 | # decoder3_output = self.decoder3(*features) 692 | 693 | mask1 = self.segmentation_head1(decoder1_output) 694 | # mask2 = self.segmentation_head2(decoder2_output) 695 | # mask3 = self.segmentation_head3(decoder3_output) 696 | 697 | # return [mask1, mask2, mask3] 698 | return [mask1, decoder1_features] 699 | 700 | 701 | class MyMultibranchModel(SegmentationModel): 702 | 703 | def __init__( 704 | self, 705 | encoder_name="resnet34", 706 | encoder_depth=5, 707 | encoder_weights="imagenet", 708 | decoder_use_batchnorm=True, 709 | decoder_channels=(256, 128, 64, 32, 16), 710 | decoder_attention_type=None, 711 | in_channels=3, 712 | classes=1, 713 | activation=None, 714 | aux_params=None, 715 | ): 716 | super().__init__() 717 | 718 | self.encoder = get_encoder( 719 | encoder_name, 720 | in_channels=in_channels, 721 | depth=encoder_depth, 722 | weights=encoder_weights, 723 | ) 724 | 725 | self.decoder1 = MyUnetDecoder( 726 | encoder_channels=self.encoder.out_channels, 727 | decoder_channels=decoder_channels, 728 | n_blocks=encoder_depth, 729 | use_batchnorm=decoder_use_batchnorm, 730 | center=True if encoder_name.startswith("vgg") else False, 731 | attention_type=decoder_attention_type, 732 | ) 733 | 734 | self.decoder2 = MyUnetDecoder( 735 | encoder_channels=self.encoder.out_channels, 736 | decoder_channels=decoder_channels, 737 | n_blocks=encoder_depth, 738 | 
use_batchnorm=decoder_use_batchnorm, 739 | center=True if encoder_name.startswith("vgg") else False, 740 | attention_type=decoder_attention_type, # mind here!!! 741 | ) 742 | 743 | self.decoder3 = MyUnetDecoder( 744 | encoder_channels=self.encoder.out_channels, 745 | decoder_channels=decoder_channels, 746 | n_blocks=encoder_depth, 747 | use_batchnorm=decoder_use_batchnorm, 748 | center=True if encoder_name.startswith("vgg") else False, 749 | attention_type=decoder_attention_type, # mind here!!! 750 | ) 751 | 752 | self.cat_conv = nn.Conv2d(2 * decoder_channels[-1], decoder_channels[-1], kernel_size=3, padding=3 // 2) 753 | 754 | self.segmentation_head1 = SegmentationHead( 755 | in_channels=decoder_channels[-1], 756 | out_channels=classes, 757 | activation=activation, 758 | kernel_size=3, 759 | ) 760 | 761 | self.segmentation_head2 = SegmentationHead( 762 | in_channels=decoder_channels[-1], 763 | out_channels=1, # mind here!!! 764 | activation=None, # mind here!!! 765 | kernel_size=3, 766 | ) 767 | 768 | self.segmentation_head3 = SegmentationHead( 769 | in_channels=decoder_channels[-1], 770 | out_channels=classes, 771 | activation=None, # mind here!!! 772 | kernel_size=3, 773 | ) 774 | self.name = "u-{}".format(encoder_name) 775 | self.initialize() 776 | 777 | def initialize(self): 778 | init.initialize_decoder(self.decoder1) 779 | init.initialize_head(self.segmentation_head1) 780 | init.initialize_decoder(self.decoder2) 781 | init.initialize_head(self.segmentation_head2) 782 | init.initialize_decoder(self.decoder3) 783 | init.initialize_head(self.segmentation_head3) 784 | 785 | def forward(self, x): 786 | features = self.encoder(x) 787 | '''for f in features: 788 | print(f.shape)''' 789 | decoder1_output, decoder1_features = self.decoder1(*features) 790 | decoder2_output, decoder2_features = self.decoder2(*features) 791 | decoder3_output, decoder3_features = self.decoder3(*features) 792 | 793 | # mask = self.segmentation_head1(decoder1_output) 794 | # cat_outputs = torch.cat([decoder1_output, decoder2_output, decoder3_output], dim=1) 795 | cat_outputs = torch.cat([decoder1_output, decoder2_output], dim=1) 796 | cat_inputs = self.cat_conv(cat_outputs) 797 | mask = self.segmentation_head1(cat_inputs) # mind here 798 | boundary = self.segmentation_head2(decoder2_output) 799 | dist = self.segmentation_head3(decoder3_output) 800 | 801 | # return [mask1, mask2, mask3] 802 | return [mask, boundary, dist, decoder1_features, decoder2_features, decoder3_features] 803 | -------------------------------------------------------------------------------- /train_cls.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import time 4 | import logging 5 | import random 6 | import glob 7 | import numpy as np 8 | import torch.nn as nn 9 | import torchvision.models as models 10 | import segmentation_models_pytorch as smp 11 | from tqdm import tqdm 12 | from tensorboardX import SummaryWriter 13 | from utils import create_train_arg_parser, AverageMeter, generate_dataset 14 | from sklearn.metrics import cohen_kappa_score 15 | from resnest.torch import resnest50 16 | 17 | 18 | net_config = { 19 | 'vgg16': models.vgg16_bn, 20 | 'resnet50': models.resnet50, 21 | 'resnext50': models.resnext50_32x4d, 22 | # 'resnest50': torch.hub.load('zhanghang1989/ResNeSt', 'resnest50', pretrained=True), 23 | 'resnest50': resnest50(pretrained=True), 24 | 'args': {} 25 | } 26 | 27 | 28 | def make_layers(cfg, batch_norm=False): 29 | layers = [] 30 | in_channels = 1 31 | 
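# cfg is a VGG-style layer spec: each integer adds a 3x3 convolution with
# that many output channels (plus BatchNorm2d when batch_norm=True) followed
# by ReLU, and the token 'M' inserts a 2x2 max-pooling layer; in_channels
# starts at 1 because the inputs here are single-channel grayscale images.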
for v in cfg: 32 | if v == 'M': 33 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 34 | else: 35 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 36 | if batch_norm: 37 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 38 | else: 39 | layers += [conv2d, nn.ReLU(inplace=True)] 40 | in_channels = v 41 | return nn.Sequential(*layers) 42 | 43 | 44 | cfgs = { 45 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 46 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 47 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 48 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 49 | } 50 | 51 | 52 | class CustomizedModel(nn.Module): 53 | def __init__(self, name, backbone, num_classes, pretrained=False, **kwargs): 54 | super(CustomizedModel, self).__init__() 55 | 56 | if 'resnest' in name: 57 | net = resnest50(pretrained=True) 58 | net.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, 59 | bias=False) 60 | net.fc = nn.Linear(net.fc.in_features, num_classes) 61 | else: 62 | net = backbone(pretrained=pretrained, **kwargs) 63 | if 'resnet' in name or 'resnext' in name or 'shufflenet' in name: 64 | net.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, 65 | bias=False) 66 | net.fc = nn.Linear(net.fc.in_features, num_classes) 67 | elif 'densenet' in name: 68 | net.classifier = nn.Linear(net.classifier.in_features, num_classes) 69 | elif 'vgg' in name: 70 | net.features = make_layers(cfgs['D'], batch_norm=True) 71 | net.classifier = nn.Sequential( 72 | nn.Linear(512 * 7 * 7, 4096), 73 | nn.ReLU(True), 74 | nn.Dropout(), 75 | nn.Linear(4096, 4096), 76 | nn.ReLU(True), 77 | nn.Dropout(), 78 | nn.Linear(4096, num_classes), 79 | ) 80 | elif 'mobilenet' in name: 81 | net.classifier = nn.Sequential( 82 | nn.Dropout(0.2), 83 | nn.Linear(net.last_channel, num_classes), 84 | ) 85 | elif 'squeezenet' in name: 86 | net.classifier = nn.Sequential( 87 | nn.Dropout(p=0.5), 88 | nn.Conv2d(512, num_classes, kernel_size=1), 89 | nn.ReLU(inplace=True), 90 | nn.AdaptiveAvgPool2d((1, 1)) 91 | ) 92 | elif 'resnest' in name: 93 | pass  # unreachable: 'resnest' is already handled above 94 | else: 95 | raise NotImplementedError('Not implemented network.') 96 | self.net = net 97 | 98 | def forward(self, x): 99 | x = self.net(x) 100 | return x 101 | 102 | 103 | def generate_model(network, out_features, net_config, device, pretrained=False, checkpoint=None): 104 | if pretrained: 105 | print('Loading pretrained weights') 106 | if checkpoint: 107 | model = torch.load(checkpoint).to(device) 108 | print('Loaded weights from {}'.format(checkpoint)) 109 | else: 110 | if network not in net_config.keys(): 111 | raise NotImplementedError('Not implemented network.') 112 | 113 | model = CustomizedModel( 114 | network, 115 | net_config[network], 116 | out_features, 117 | pretrained, 118 | **net_config['args'] 119 | ).to(device) 120 | 121 | if device == 'cuda' and torch.cuda.device_count() > 1: 122 | model = torch.nn.DataParallel(model) 123 | 124 | return model 125 | 126 | 127 | def set_random_seed(seed): 128 | random.seed(seed) 129 | np.random.seed(seed) 130 | torch.manual_seed(seed) 131 | torch.cuda.manual_seed(seed) 132 | torch.backends.cudnn.deterministic = True 133 | 134 | 135 | def train(model, data_loader, optimizer, criterion, device, training=False): 136 | losses = AverageMeter("Cls_Loss", ".16f") 137 | accs = AverageMeter("Accuracy", ".8f") 138 | kappas = AverageMeter("Kappa",
".8f") 139 | 140 | if training: 141 | model.train() 142 | torch.set_grad_enabled(True) 143 | else: 144 | model.eval() 145 | torch.set_grad_enabled(False) 146 | 147 | process = tqdm(data_loader) 148 | for i, (img_file_name, inputs, targets1, targets2, targets3, targets4) in enumerate(process): 149 | inputs = inputs.to(device) 150 | targets1, targets2 = targets1.to(device), targets2.to(device) 151 | targets3, targets4 = targets3.to(device), targets4.to(device) 152 | targets = [targets1, targets2, targets3, targets4] 153 | 154 | if training: 155 | optimizer.zero_grad() 156 | 157 | outputs = model(inputs) 158 | 159 | labels = torch.argmax(targets[3], dim=2).squeeze(1) 160 | preds = torch.argmax(outputs, dim=1) 161 | 162 | loss = criterion(outputs, labels) 163 | predictions = preds.detach().cpu().numpy() 164 | target = labels.detach().cpu().numpy() 165 | acc = (predictions == target).sum() / len(predictions) 166 | kappa = cohen_kappa_score(predictions, target) 167 | 168 | if training: 169 | loss.backward() 170 | optimizer.step() 171 | 172 | losses.update(loss.item(), inputs.size(0)) 173 | accs.update(acc.item(), inputs.size(0)) 174 | kappas.update(kappa.item(), inputs.size(0)) 175 | 176 | process.set_description('Loss: ' + str(round(losses.avg, 4))) 177 | 178 | epoch_loss = losses.avg 179 | epoch_acc = accs.avg 180 | epoch_kappa = kappas.avg 181 | 182 | return epoch_loss, epoch_acc, epoch_kappa 183 | 184 | 185 | def main(): 186 | seed = 1234 187 | set_random_seed(seed) 188 | 189 | args = create_train_arg_parser().parse_args() 190 | CUDA_SELECT = "cuda:{}".format(args.cuda_no) 191 | 192 | if args.pretrain == 'True': 193 | pretrain = True 194 | else: 195 | pretrain = False 196 | 197 | log_path = os.path.join(args.save_path, "summary/") 198 | writer = SummaryWriter(log_dir=log_path) 199 | rq = time.strftime('%Y%m%d%H%M', time.localtime(time.time())) 200 | log_name = os.path.join(log_path, str(rq) + '.log') 201 | logging.basicConfig( 202 | filename=log_name, 203 | filemode="a", 204 | format="%(asctime)s %(levelname)s %(message)s", 205 | datefmt="%Y-%m-%d %H:%M", 206 | level=logging.INFO, 207 | ) 208 | logging.info(args) 209 | print(args) 210 | 211 | train_file_names = glob.glob(os.path.join(args.train_path, "*.png")) 212 | random.shuffle(train_file_names) 213 | val_file_names = glob.glob(os.path.join(args.val_path, "*.png")) 214 | 215 | device = torch.device(CUDA_SELECT if torch.cuda.is_available() else "cpu") 216 | 217 | model = generate_model(args.model_type, args.classnum, net_config, device, pretrained=pretrain, checkpoint=None) 218 | logging.info(model) 219 | model = model.to(device) 220 | 221 | train_loader, valid_loader = generate_dataset(train_file_names, val_file_names, args.batch_size, args.batch_size, args.distance_type, args.clahe) 222 | optimizer = torch.optim.Adam([ 223 | dict(params=model.parameters(), lr=args.LR_seg) 224 | ]) 225 | 226 | criterion = smp.utils.losses.CrossEntropyLoss() 227 | 228 | max_acc = 0 229 | epoch_start = 0 230 | 231 | for epoch in range(epoch_start + 1, epoch_start + 1 + args.num_epochs): 232 | 233 | print('\nEpoch: {}'.format(epoch)) 234 | 235 | train_loss, train_acc, train_kappa = train(model, train_loader, optimizer, criterion, device, training=True) 236 | val_loss, val_acc, val_kappa = train(model, valid_loader, optimizer, criterion, device, training=False) 237 | 238 | epoch_info = "Epoch: {}".format(epoch) 239 | train_info = "Training Loss: {:.4f}, Training Acc: {:.4f}, Training Kappa: {:.4f}".format(train_loss, train_acc, train_kappa) 240 | val_info = 
"Validation Loss: {:.4f}, Validation Acc: {:.4f}, Validation Kappa: {:.4f}".format(val_loss, val_acc, val_kappa) 241 | print(train_info) 242 | print(val_info) 243 | logging.info(epoch_info) 244 | logging.info(train_info) 245 | logging.info(val_info) 246 | writer.add_scalar("train_loss", train_loss, epoch) 247 | writer.add_scalar("train_acc", train_acc, epoch) 248 | writer.add_scalar("train_kappa", train_kappa, epoch) 249 | writer.add_scalar("val_loss", val_loss, epoch) 250 | writer.add_scalar("val_acc", val_acc, epoch) 251 | writer.add_scalar("val_kappa", val_kappa, epoch) 252 | 253 | best_name = os.path.join(args.save_path, "best_acc_" + str(round(val_acc, 4)) + "_kappa_" + str(round(val_kappa, 4)) + ".pt") 254 | save_name = os.path.join(args.save_path, str(epoch) + "_acc_" + str(round(val_acc, 4)) + "_kappa_" + str(round(val_kappa, 4)) + ".pt") 255 | 256 | if max_acc < val_acc: 257 | max_acc = val_acc 258 | if max_acc > 0.3: 259 | if torch.cuda.device_count() > 1: 260 | torch.save(model.module.state_dict(), best_name) 261 | else: 262 | torch.save(model.state_dict(), best_name) 263 | print('Best model saved!') 264 | logging.warning('Best model saved!') 265 | if epoch % 50 == 0: 266 | if torch.cuda.device_count() > 1: 267 | torch.save(model.module.state_dict(), save_name) 268 | print('Epoch {} model saved!'.format(epoch)) 269 | else: 270 | torch.save(model.state_dict(), save_name) 271 | print('Epoch {} model saved!'.format(epoch)) 272 | 273 | 274 | if __name__ == "__main__": 275 | main() 276 | -------------------------------------------------------------------------------- /train_seg_clf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import time 3 | import torch 4 | import os 5 | import glob 6 | from torch.optim import Adam 7 | from tqdm import tqdm 8 | import logging 9 | from torch import nn 10 | import random 11 | from tensorboardX import SummaryWriter 12 | from utils import create_train_arg_parser, define_loss, generate_dataset 13 | from losses import My_multiLoss 14 | import segmentation_models_pytorch as smp 15 | import numpy as np 16 | from sklearn.metrics import cohen_kappa_score 17 | from smp_model import MyUnetModel, my_get_encoder, MyMultibranchModel 18 | 19 | # os.environ["CUDA_VISIBLE_DEVICES"] = '2' 20 | # torch.backends.cuda.matmul.allow_tf32 = True 21 | # torch.backends.cudnn.benchmark = True 22 | # torch.backends.cudnn.deterministic = False 23 | # torch.backends.cudnn.allow_tf32 = True 24 | 25 | IN_MODELS = ['unet_smp', 'unet++', 'manet', 'linknet', 'fpn', 'pspnet', 'pan', 'deeplabv3', 'deeplabv3+'] 26 | 27 | 28 | def set_seed(seed): 29 | random.seed(seed) 30 | os.environ['PYTHONHASHSEED'] = str(seed) 31 | np.random.seed(seed) 32 | torch.manual_seed(seed) 33 | torch.cuda.manual_seed(seed) 34 | torch.cuda.manual_seed_all(seed) # if you are using multi-GPU. 
35 | 36 | 37 | def segmentation_iteration(model, optimizer, model_type, criterion, data_loader, device, writer, training=False): 38 | running_loss = 0.0 39 | total_size = 0 40 | dsc = 0.0  # default so dsc stays defined for model types whose branch below computes no Dice 41 | if training: 42 | model.train() 43 | torch.set_grad_enabled(True) 44 | else: 45 | model.eval() 46 | torch.set_grad_enabled(False) 47 | 48 | for i, (img_file_name, inputs, targets1, targets2, targets3, targets4) in enumerate(tqdm(data_loader)): 49 | inputs = inputs.to(device) 50 | targets1, targets2 = targets1.to(device), targets2.to(device) 51 | targets3, targets4 = targets3.to(device), targets4.to(device) 52 | targets = [targets1, targets2, targets3, targets4] 53 | 54 | if training: 55 | optimizer.zero_grad() 56 | 57 | outputs = model(inputs) 58 | 59 | if model_type in IN_MODELS + ["unet"]: 60 | if not isinstance(outputs, list): 61 | outputs = [outputs] 62 | loss = criterion(outputs[0], targets[0]) 63 | dsc_loss = smp.utils.losses.DiceLoss() 64 | # only one example metric (Dice) is implemented here 65 | # preds = torch.argmax(outputs[0].exp(), dim=1) 66 | preds = torch.argmax(torch.sigmoid(outputs[0]), dim=1) 67 | dsc = 1 - dsc_loss(preds, targets[0].squeeze(1)) 68 | 69 | elif model_type == "dcan": 70 | loss = criterion(outputs[0], outputs[1], targets[0], targets[1]) 71 | 72 | elif model_type == "dmtn": 73 | loss = criterion(outputs[0], outputs[1], targets[0], targets[2]) 74 | 75 | elif model_type in ["psinet", "convmcd"]: 76 | loss = criterion( 77 | outputs[0], outputs[1], outputs[2], targets[0], targets[1], targets[2] 78 | ) 79 | else: 80 | raise ValueError('Unknown model_type: {}'.format(model_type)) 81 | 82 | if training: 83 | # with amp.scale_loss(loss, optimizer) as scaled_loss: 84 | # scaled_loss.backward() 85 | loss.backward() 86 | optimizer.step() 87 | # scheduler.step() 88 | 89 | running_loss += loss.item() * inputs.size(0) 90 | total_size += inputs.size(0) 91 | 92 | epoch_loss = running_loss / total_size 93 | # print("total size:", total_size, training) 94 | 95 | return epoch_loss, dsc  # dsc is the Dice of the last batch only, not an epoch average 96 | 97 | 98 | class AverageMeter(object): 99 | """Computes and stores the average and current value""" 100 | def __init__(self, name, fmt=':f'): 101 | self.name = name 102 | self.fmt = fmt 103 | self.reset() 104 | 105 | def reset(self): 106 | self.val = 0 107 | self.avg = 0 108 | self.sum = 0 109 | self.count = 0 110 | 111 | def update(self, val, n=1): 112 | self.val = val 113 | self.sum += val * n 114 | self.count += n 115 | self.avg = self.sum / self.count 116 | 117 | def __str__(self): 118 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 119 | return fmtstr.format(**self.__dict__) 120 | 121 | 122 | def seg_clf_iteration(epoch, model, optimizer, criterion, data_loader, device, writer, loss_weights, startpoint, training=False): 123 | seg_losses = AverageMeter("Loss", ".16f") 124 | multi_losses = AverageMeter("multiLoss", ".16f") 125 | seg_dices = AverageMeter("Dice", ".8f") 126 | seg_jaccards = AverageMeter("Jaccard", ".8f") 127 | clf_losses = AverageMeter("Loss", ".16f") 128 | clf_accs = AverageMeter("Acc", ".8f") 129 | clf_kappas = AverageMeter("Kappa", ".8f") 130 | 131 | if training: 132 | model.train() 133 | torch.set_grad_enabled(True) 134 | else: 135 | model.eval() 136 | torch.set_grad_enabled(False) 137 | 138 | for i, (img_file_name, inputs, targets1, targets2, targets3, targets4) in enumerate(tqdm(data_loader)): 139 | inputs = inputs.to(device) 140 | targets1, targets2 = targets1.to(device), targets2.to(device) 141 | targets3, targets4 = targets3.to(device), targets4.to(device) 142 | targets = [targets1, targets2, targets3, targets4] 143 | 144 | if training: 145 |
optimizer.zero_grad() 146 | 147 | seg_outputs = model.seg_forward(inputs) 148 | # seg_preds = torch.argmax(seg_outputs[0].exp(), dim=1) 149 | if not isinstance(seg_outputs, list): 150 | seg_outputs = [seg_outputs] 151 | 152 | seg_preds = torch.round(seg_outputs[0]) 153 | # clf_outputs = model.clf_forward(inputs, seg_outputs[1]) 154 | clf_outputs = model.clf_forward(inputs, seg_outputs[3], seg_outputs[4], seg_outputs[5]) 155 | 156 | # print(seg_criterion)j 157 | # print(seg_outputs[0].shape, targets[0].shape) 158 | seg_criterion, dice_criterion, jaccard_criterion, clf_criterion = criterion[0], criterion[1], criterion[2], criterion[3] 159 | seg_loss = seg_criterion(seg_outputs[0], targets[0].to(torch.float32)) 160 | multi_criterion = My_multiLoss(loss_weights) 161 | multi_loss = multi_criterion( 162 | seg_outputs[0], seg_outputs[1], seg_outputs[2], targets[0].to(torch.float32), targets[1], targets[2] 163 | ) 164 | # print(seg_outputs[1].shape, targets[1].shape) 165 | seg_dice = 1 - dice_criterion(seg_preds.squeeze(1), targets[0].squeeze(1)) 166 | seg_jaccard = 1 - jaccard_criterion(seg_preds.squeeze(1), targets[0].squeeze(1)) 167 | # seg_iou = smp.utils.metrics.IoU(threshold=0.5) 168 | 169 | # print(clf_outputs.shape, targets[3].shape) 170 | clf_labels = torch.argmax(targets[3], dim=2).squeeze(1) 171 | clf_preds = torch.argmax(clf_outputs, dim=1) 172 | clf_loss = clf_criterion(clf_outputs, clf_labels) 173 | kappa = cohen_kappa_score(clf_labels.detach().cpu().numpy(), clf_preds.detach().cpu().numpy()) 174 | # print(targets[3]) 175 | # print(clf_labels) 176 | # print(clf_preds) 177 | acc = np.mean(clf_labels.detach().cpu().numpy() == clf_preds.detach().cpu().numpy()) 178 | 179 | if training: 180 | if epoch <= startpoint: 181 | # loss = seg_loss 182 | loss = multi_loss 183 | else: 184 | # loss = (seg_loss + clf_loss) 185 | loss = (multi_loss + clf_loss) 186 | loss.backward() 187 | # with amp.scale_loss(loss, optimizer) as scaled_loss: 188 | # scaled_loss.backward() 189 | optimizer.step() 190 | # scheduler.step() 191 | 192 | seg_losses.update(seg_loss.item(), inputs.size(0)) 193 | multi_losses.update(multi_loss.item(), inputs.size(0)) 194 | seg_dices.update(seg_dice.item(), inputs.size(0)) 195 | seg_jaccards.update(seg_jaccard.item(), inputs.size(0)) 196 | clf_losses.update(clf_loss.item(), inputs.size(0)) 197 | clf_accs.update(acc, inputs.size(0)) 198 | clf_kappas.update(kappa, inputs.size(0)) 199 | 200 | seg_epoch_loss = seg_losses.avg 201 | multi_epoch_loss = multi_losses.avg 202 | seg_epoch_dice = seg_dices.avg 203 | seg_epoch_jaccard = seg_jaccards.avg 204 | clf_epoch_loss = clf_losses.avg 205 | clf_epoch_acc = clf_accs.avg 206 | clf_epoch_kappa = clf_kappas.avg 207 | # clf_epoch_loss = clf_running_loss / total_size 208 | # print("total size:", total_size, training, seg_epoch_loss) 209 | 210 | return seg_epoch_loss, multi_epoch_loss, seg_epoch_dice, seg_epoch_jaccard, clf_epoch_loss, clf_epoch_acc, clf_epoch_kappa 211 | 212 | 213 | class CotrainingModel(nn.Module): 214 | 215 | def __init__(self, encoder, pretrain, classnum): 216 | super().__init__() 217 | self.seg_model = MyUnetModel( 218 | encoder_name=encoder, encoder_depth=5, encoder_weights=pretrain, decoder_use_batchnorm=True, 219 | decoder_channels=(256, 128, 64, 32, 16), decoder_attention_type=None, in_channels=1, classes=1, 220 | activation='sigmoid', aux_params=None 221 | ) 222 | self.clf_model = my_get_encoder(encoder, in_channels=1, depth=5, weights=pretrain, num_classes=classnum) 223 | 224 | def seg_forward(self, x): 225 | 
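# MyUnetModel returns [mask, decoder_features]; clf_forward() below passes
# those decoder features to the classification encoder, which concatenates
# them into its own stages.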
return self.seg_model(x) 226 | 227 | def clf_forward(self, x, decoder_features): 228 | return self.clf_model(x, decoder_features) 229 | 230 | 231 | class CotrainingModelMulti(nn.Module): 232 | 233 | def __init__(self, encoder, pretrain, usenorm, attention_type, classnum): 234 | super().__init__() 235 | self.seg_model = MyMultibranchModel( 236 | encoder_name=encoder, encoder_depth=5, encoder_weights=pretrain, decoder_use_batchnorm=usenorm, 237 | decoder_channels=(256, 128, 64, 32, 16), 238 | # decoder_channels=(512, 256, 128, 64, 32), 239 | decoder_attention_type=attention_type, in_channels=1, classes=1, 240 | activation='sigmoid', aux_params=None 241 | ) 242 | self.clf_model = my_get_encoder(encoder, in_channels=1, depth=5, weights=pretrain, decoder_channels=(256, 128, 64, 32, 16), num_classes=classnum) 243 | 244 | def seg_forward(self, x): 245 | return self.seg_model(x) 246 | 247 | def clf_forward(self, x, decoder1_features, decoder2_features, decoder3_features): 248 | return self.clf_model(x, decoder1_features, decoder2_features, decoder3_features) 249 | 250 | 251 | def main(): 252 | with torch.backends.cudnn.flags(enabled=True, benchmark=True, deterministic=False, allow_tf32=False): 253 | torch.set_num_threads(4) 254 | set_seed(2021) 255 | 256 | args = create_train_arg_parser().parse_args() 257 | CUDA_SELECT = "cuda:{}".format(args.cuda_no) 258 | print("cuda_count:", torch.cuda.device_count()) 259 | 260 | log_path = os.path.join(args.save_path, "summary/") 261 | writer = SummaryWriter(log_dir=log_path) 262 | rq = time.strftime('%Y%m%d%H%M', time.localtime(time.time())) 263 | log_name = os.path.join(log_path, str(rq) + '.log') 264 | logging.basicConfig( 265 | filename=log_name, 266 | filemode="a", 267 | format="%(asctime)s %(levelname)s %(message)s", 268 | datefmt="%Y-%m-%d %H:%M", 269 | level=logging.INFO, 270 | ) 271 | logging.info(args) 272 | 273 | device = torch.device(CUDA_SELECT if torch.cuda.is_available() else "cpu") 274 | 275 | encoder = args.encoder 276 | usenorm = args.usenorm 277 | attention_type = args.attention 278 | if args.pretrain in ['imagenet', 'ssl', 'swsl', 'instagram']: 279 | pretrain = args.pretrain 280 | # preprocess_input = get_preprocessing_fn(encoder, pretrain) 281 | else: 282 | pretrain = None 283 | # preprocess_input = get_preprocessing_fn(encoder) 284 | # model = CotrainingModel(encoder, pretrain).to(device) 285 | model = CotrainingModelMulti(encoder, pretrain, usenorm, attention_type, args.classnum).to(device) 286 | logging.info(model) 287 | # seg_criterion = smp.utils.losses.DiceLoss() 288 | # seg_dice_criterion = smp.utils.losses.DiceLoss() 289 | # clf_criterion = smp.utils.losses.CrossEntropyLoss() 290 | criterion = [ 291 | define_loss(args.loss_type), 292 | smp.utils.losses.DiceLoss(), 293 | smp.utils.losses.JaccardLoss(), 294 | smp.utils.losses.CrossEntropyLoss() 295 | ] 296 | 297 | optimizer = Adam([ 298 | {"params": model.seg_model.parameters(), "lr": args.LR_seg}, 299 | {"params": model.clf_model.parameters(), "lr": args.LR_clf} 300 | ]) 301 | 302 | # model, optimizer = amp.initialize(model, optimizer, opt_level='O1') 303 | 304 | train_file_names = glob.glob(os.path.join(args.train_path, "*.png")) 305 | # random.shuffle(train_file_names) 306 | val_file_names = glob.glob(os.path.join(args.val_path, "*.png")) 307 | random.shuffle(val_file_names) 308 | 309 | # train_dataset = DatasetImageMaskContourDist(train_file_names, args.distance_type) 310 | # valid_dataset = DatasetImageMaskContourDist(val_file_names, args.distance_type) 311 | # trainLoader = 
DataLoader(train_dataset, batch_size=args.batch_size, num_workers=8) 312 | # devLoader = DataLoader(valid_dataset, batch_size=args.val_batch_size, num_workers=4) 313 | trainLoader, devLoader = generate_dataset(train_file_names, val_file_names, args.batch_size, args.val_batch_size, args.distance_type, args.clahe) 314 | 315 | epoch_start = 0 316 | max_dice = 0.8 317 | max_acc = 0.6 318 | loss_weights = [3, 1, 1] 319 | logging.info(loss_weights) 320 | startpoint = args.startpoint 321 | 322 | for epoch in range(epoch_start + 1, epoch_start + 1 + args.num_epochs): 323 | 324 | print('\nEpoch: {}'.format(epoch)) 325 | training_seg_loss, training_multi_loss, training_seg_dice, training_seg_jaccard, training_clf_loss, training_clf_acc, training_clf_kappa = seg_clf_iteration(epoch, model, optimizer, criterion, trainLoader, device, writer, loss_weights, startpoint, training=True) 326 | dev_seg_loss, dev_multi_loss, dev_seg_dice, dev_seg_jaccard, dev_clf_loss, dev_clf_acc, dev_clf_kappa = seg_clf_iteration(epoch, model, optimizer, criterion, devLoader, device, writer, loss_weights, startpoint, training=False) 327 | 328 | epoch_info = "Epoch: {}".format(epoch) 329 | train_info = "TrainSeg Loss:{:.7f}, TrMultiLoss:{:.7f}, Dice: {:.7f}, Jaccard: {:.7f}, TrainClf Loss:{:.7f}, Acc: {:.7f}, Kappa:{:.7f}".format(training_seg_loss, training_multi_loss, training_seg_dice, training_seg_jaccard, training_clf_loss, training_clf_acc, training_clf_kappa) 330 | val_info = "ValSeg Loss:{:.7f}, ValMultiLoss:{:.7f}, Dice: {:.7f}, Jaccard: {:.7f}, ValClf Loss:{:.7f}, Acc: {:.7f}, Kappa:{:.7f}".format(dev_seg_loss, dev_multi_loss, dev_seg_dice, dev_seg_jaccard, dev_clf_loss, dev_clf_acc, dev_clf_kappa) 331 | print(train_info) 332 | print(val_info) 333 | logging.info(epoch_info) 334 | logging.info(train_info) 335 | logging.info(val_info) 336 | writer.add_scalar("trainseg_loss", training_seg_loss, epoch) 337 | writer.add_scalar("trainmulti_loss", training_multi_loss, epoch) 338 | writer.add_scalar("trainseg_dice", training_seg_dice, epoch) 339 | writer.add_scalar("trainseg_jaccard", training_seg_jaccard, epoch) 340 | writer.add_scalar("traincls_loss", training_clf_loss, epoch) 341 | writer.add_scalar("traincls_acc", training_clf_acc, epoch) 342 | writer.add_scalar("traincls_kappa", training_clf_kappa, epoch) 343 | 344 | writer.add_scalar("valseg_loss", dev_seg_loss, epoch) 345 | writer.add_scalar("valmulti_loss", dev_multi_loss, epoch) 346 | writer.add_scalar("valseg_dice", dev_seg_dice, epoch) 347 | writer.add_scalar("valseg_jaccard", dev_seg_jaccard, epoch) 348 | writer.add_scalar("valcls_loss", dev_clf_loss, epoch) 349 | writer.add_scalar("valcls_acc", dev_clf_acc, epoch) 350 | writer.add_scalar("valcls_kappa", dev_clf_kappa, epoch) 351 | 352 | best_name = os.path.join(args.save_path, "dice_" + str(round(dev_seg_dice, 5)) + "_jaccard_" + str(round(dev_seg_jaccard, 5)) + "_acc_" + str(round(dev_clf_acc, 4)) + "_kap_" + str(round(dev_clf_kappa, 4)) + ".pt") 353 | save_name = os.path.join(args.save_path, str(epoch) + "_dice_" + str(round(dev_seg_dice, 5)) + "_jaccard_" + str(round(dev_seg_jaccard, 5)) + "_acc_" + str(round(dev_clf_acc, 4)) + "_kap_" + str(round(dev_clf_kappa, 4)) + ".pt") 354 | 355 | if max_dice <= dev_seg_dice: 356 | max_dice = dev_seg_dice 357 | # if epoch > 10: 358 | if torch.cuda.device_count() > 1: 359 | torch.save(model.module.state_dict(), best_name) 360 | else: 361 | torch.save(model.state_dict(), best_name) 362 | print('Best seg model saved!') 363 | logging.warning('Best seg model saved!') 364 | if
max_acc <= dev_clf_acc: 365 | max_acc = dev_clf_acc 366 | # if epoch > 10: 367 | if torch.cuda.device_count() > 1: 368 | torch.save(model.module.state_dict(), best_name) 369 | else: 370 | torch.save(model.state_dict(), best_name) 371 | print('Best clf model saved!') 372 | logging.warning('Best clf model saved!') 373 | 374 | if epoch % 50 == 0: 375 | if torch.cuda.device_count() > 1: 376 | torch.save(model.module.state_dict(), save_name) 377 | print('Epoch {} model saved!'.format(epoch)) 378 | else: 379 | torch.save(model.state_dict(), save_name) 380 | print('Epoch {} model saved!'.format(epoch)) 381 | 382 | 383 | if __name__ == "__main__": 384 | main() 385 | -------------------------------------------------------------------------------- /train_smp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | from torch.utils.data import DataLoader 4 | import glob 5 | from torch.optim import Adam 6 | from tqdm import tqdm 7 | import logging 8 | from torch import nn 9 | import random 10 | from tensorboardX import SummaryWriter 11 | from utils import visualize, evaluate, create_train_arg_parser, evaluate_modi 12 | from losses import LossUNet, LossDCAN, LossDMTN, LossPsiNet 13 | from models import UNet, UNet_DCAN, UNet_DMTN, PsiNet, UNet_ConvMCD 14 | from dataset import DatasetImageMaskContourDist 15 | import segmentation_models_pytorch as smp 16 | from segmentation_models_pytorch.encoders import get_preprocessing_fn 17 | from apex import amp 18 | 19 | 20 | # os.environ["CUDA_VISIBLE_DEVICES"] = '2' 21 | # torch.backends.cuda.matmul.allow_tf32 = True 22 | # torch.backends.cudnn.benchmark = True 23 | # torch.backends.cudnn.deterministic = False 24 | # torch.backends.cudnn.allow_tf32 = True 25 | 26 | # SEED = 2021 27 | # utils.set_global_seed(SEED) 28 | # utils.prepare_cudnn(deterministic=True) 29 | 30 | def build_model(model_type, encoder, pretrain): 31 | 32 | if model_type == "unet": 33 | model = UNet(num_classes=2) 34 | print(model) 35 | if model_type == "dcan": 36 | model = UNet_DCAN(num_classes=2) 37 | print(model) 38 | if model_type == "dmtn": 39 | model = UNet_DMTN(num_classes=2) 40 | print(model) 41 | if model_type == "psinet": 42 | model = PsiNet(num_classes=2) 43 | print(model) 44 | if model_type == "convmcd": 45 | model = UNet_ConvMCD(num_classes=2) 46 | print(model) 47 | if model_type == "unet_smp": 48 | model = smp.Unet( 49 | encoder_name=encoder, 50 | encoder_depth=5, 51 | encoder_weights=pretrain, 52 | decoder_use_batchnorm=True, 53 | decoder_channels=(256, 128, 64, 32, 16), 54 | decoder_attention_type=None, 55 | in_channels=1, 56 | classes=1, 57 | activation=None, 58 | aux_params=None 59 | ) 60 | print(model) 61 | if model_type == "unet++": 62 | model = smp.UnetPlusPlus( 63 | encoder_name=encoder, 64 | encoder_depth=5, 65 | encoder_weights=pretrain, 66 | decoder_use_batchnorm=True, 67 | decoder_channels=(256, 128, 64, 32, 16), 68 | decoder_attention_type=None, 69 | in_channels=1, 70 | classes=1, 71 | activation=None, 72 | aux_params=None 73 | ) 74 | print(model) 75 | if model_type == "manet": 76 | model = smp.MAnet( 77 | encoder_name=encoder, 78 | encoder_depth=5, 79 | encoder_weights=pretrain, 80 | decoder_use_batchnorm=True, 81 | decoder_channels=(256, 128, 64, 32, 16), 82 | decoder_pab_channels=64, 83 | in_channels=1, 84 | classes=1, 85 | activation=None, 86 | aux_params=None 87 | ) 88 | print(model) 89 | if model_type == "linknet": 90 | model = smp.Linknet( 91 | encoder_name=encoder, 92 | encoder_depth=5, 93 | 
encoder_weights=pretrain, 94 | decoder_use_batchnorm=True, 95 | in_channels=1, 96 | classes=1, 97 | activation=None, 98 | aux_params=None) 99 | print(model) 100 | if model_type == "fpn": 101 | model = smp.FPN( 102 | encoder_name=encoder, 103 | encoder_depth=5, 104 | encoder_weights=pretrain, 105 | decoder_pyramid_channels=256, 106 | decoder_segmentation_channels=128, 107 | decoder_merge_policy='add', 108 | decoder_dropout=0.2, 109 | in_channels=1, 110 | classes=1, 111 | activation=None, 112 | upsampling=4, 113 | aux_params=None 114 | ) 115 | print(model) 116 | if model_type == "pspnet": 117 | model = smp.PSPNet( 118 | encoder_name=encoder, 119 | encoder_weights=pretrain, 120 | encoder_depth=3, 121 | psp_out_channels=512, 122 | psp_use_batchnorm=True, 123 | psp_dropout=0.2, 124 | in_channels=1, 125 | classes=1, 126 | activation=None, 127 | upsampling=8, 128 | aux_params=None 129 | ) 130 | print(model) 131 | if model_type == "pan": 132 | model = smp.PAN( 133 | encoder_name=encoder, 134 | encoder_weights=pretrain, 135 | encoder_dilation=True, 136 | decoder_channels=32, 137 | in_channels=1, 138 | classes=1, 139 | activation=None, 140 | upsampling=4, 141 | aux_params=None 142 | ) 143 | print(model) 144 | if model_type == "deeplabv3": 145 | model = smp.DeepLabV3( 146 | encoder_name=encoder, 147 | encoder_depth=5, 148 | encoder_weights=pretrain, 149 | decoder_channels=256, 150 | in_channels=1, 151 | classes=1, 152 | activation=None, 153 | upsampling=8, 154 | aux_params=None 155 | ) 156 | print(model) 157 | if model_type == "deeplabv3+": 158 | model = smp.DeepLabV3Plus( 159 | encoder_name=encoder, 160 | encoder_depth=5, 161 | encoder_weights=pretrain, 162 | encoder_output_stride=16, 163 | decoder_channels=256, 164 | decoder_atrous_rates=(12, 24, 36), 165 | in_channels=1, 166 | classes=1, 167 | activation=None, 168 | upsampling=4, 169 | aux_params=None 170 | ) 171 | print(model) 172 | return model 173 | 174 | 175 | def define_loss(loss_type, weights=[3, 1, 2]): 176 | 177 | if loss_type == "jaccard": 178 | criterion = smp.utils.losses.JaccardLoss() 179 | if loss_type == "dice": 180 | criterion = smp.utils.losses.DiceLoss() 181 | if loss_type == "ce": 182 | criterion = smp.utils.losses.CrossEntropyLoss() 183 | if loss_type == "bcewithlogit": 184 | criterion = smp.utils.losses.BCEWithLogitsLoss() 185 | if loss_type == "unet": 186 | criterion = LossUNet(weights) 187 | if loss_type == "dcan": 188 | criterion = LossDCAN(weights) 189 | if loss_type == "dmtn": 190 | criterion = LossDMTN(weights) 191 | if loss_type == "psinet" or loss_type == "convmcd": 192 | # Both psinet and convmcd uses same mask,contour and distance loss function 193 | criterion = LossPsiNet(weights) 194 | 195 | return criterion 196 | 197 | 198 | def train_model(model, targets, model_type, criterion, optimizer): 199 | in_models = ['unet_smp', 'unet++', 'manet', 'linknet', 'fpn', 'pspnet', 'pan', 'deeplabv3', 'deeplabv3+'] 200 | 201 | if model_type in in_models: 202 | 203 | optimizer.zero_grad() 204 | outputs = model(inputs) 205 | if not isinstance(outputs, list): 206 | outputs = [outputs] 207 | # print('\n', np.array(outputs).shape) 208 | # print('\n', np.array(targets).shape) 209 | loss = criterion(outputs[0], targets[0]) 210 | with amp.scale_loss(loss, optimizer) as scaled_loss: 211 | scaled_loss.backward() 212 | 213 | optimizer.step() 214 | # scheduler.step() 215 | 216 | elif model_type == "unet": 217 | optimizer.zero_grad() 218 | outputs = model(inputs) 219 | if not isinstance(outputs, list): 220 | outputs = [outputs] 221 | 222 | 
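# apex AMP: amp.scale_loss() wraps the loss so backward() runs on a
# dynamically scaled value, avoiding fp16 gradient underflow; the
# unscaled gradients are then applied by optimizer.step().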
loss = criterion(outputs[0], targets[0]) 223 | with amp.scale_loss(loss, optimizer) as scaled_loss: 224 | scaled_loss.backward() 225 | 226 | optimizer.step() 227 | # scheduler.step() 228 | 229 | elif model_type == "dcan": 230 | 231 | optimizer.zero_grad() 232 | outputs = model(inputs) 233 | loss = criterion(outputs[0], outputs[1], targets[0], targets[1]) 234 | with amp.scale_loss(loss, optimizer) as scaled_loss: 235 | scaled_loss.backward() 236 | 237 | optimizer.step() 238 | # scheduler.step() 239 | 240 | elif model_type == "dmtn": 241 | 242 | optimizer.zero_grad() 243 | outputs = model(inputs) 244 | loss = criterion(outputs[0], outputs[1], targets[0], targets[2]) 245 | with amp.scale_loss(loss, optimizer) as scaled_loss: 246 | scaled_loss.backward() 247 | 248 | optimizer.step() 249 | # scheduler.step() 250 | 251 | elif model_type == "psinet" or model_type == "convmcd": 252 | 253 | optimizer.zero_grad() 254 | outputs = model(inputs) 255 | loss = criterion( 256 | outputs[0], outputs[1], outputs[2], targets[0], targets[1], targets[2] 257 | ) 258 | with amp.scale_loss(loss, optimizer) as scaled_loss: 259 | scaled_loss.backward() 260 | 261 | optimizer.step() 262 | # scheduler.step() 263 | 264 | else: 265 | raise ValueError('Unknown model_type: {}'.format(model_type)) 266 | 267 | return loss 268 | 269 | # ==================================================== 270 | 271 | 272 | if __name__ == "__main__": 273 | 274 | args = create_train_arg_parser().parse_args() 275 | encoder = args.encoder 276 | if args.pretrain in ['imagenet', 'ssl', 'swsl']: 277 | pretrain = args.pretrain 278 | preprocess_input = get_preprocessing_fn(encoder, pretrain) 279 | else: 280 | pretrain = None 281 | # preprocess_input = get_preprocessing_fn(encoder) ## 282 | CUDA_SELECT = "cuda:{}".format(args.cuda_no) 283 | log_path = args.save_path + "/summary" 284 | writer = SummaryWriter(log_dir=log_path) 285 | 286 | logging.basicConfig( 287 | filename=os.path.join(log_path, '{}.log'.format(args.object_type)), 288 | filemode="a", 289 | format="%(asctime)s,%(msecs)d %(name)s %(levelname)s %(message)s", 290 | datefmt="%Y-%m-%d %H:%M", 291 | level=logging.INFO, 292 | ) 293 | logging.info(args) 294 | 295 | train_file_names = glob.glob(os.path.join(args.train_path, "*.png")) 296 | random.shuffle(train_file_names) 297 | val_file_names = glob.glob(os.path.join(args.val_path, "*.png")) 298 | 299 | device = torch.device(CUDA_SELECT if torch.cuda.is_available() else "cpu") 300 | 301 | if args.pretrain in ['imagenet', 'ssl', 'swsl']: 302 | model = build_model(args.model_type, args.encoder, pretrain) 303 | else: 304 | pretrain = None 305 | model = build_model(args.model_type, args.encoder, pretrain) 306 | 307 | model = model.to(device) 308 | 309 | optimizer = Adam(model.parameters(), args.LR) 310 | 311 | criterion = define_loss(args.loss_type) 312 | 313 | model, optimizer = amp.initialize(model, optimizer, opt_level='O1') 314 | # scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[150, 250, 400], gamma=0.2) ### 315 | if args.use_scheduler is True: 316 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=20, threshold=0.0001, threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08, verbose=True) 317 | 318 | # model, optimizer = amp.initialize(model, optimizer, opt_level='O1') 319 | 320 | if torch.cuda.device_count() > 1: 321 | print("Let's use", torch.cuda.device_count(), "GPUs!") 322 | model = nn.DataParallel(model) 323 | 324 | # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 325 | in_models = ['unet_smp', 'unet++',
'manet', 'linknet', 'fpn', 'pspnet', 'pan', 'deeplabv3', 'deeplabv3+'] 326 | 327 | epoch_start = "0" 328 | if args.use_pretrained: 329 | print("Loading Model {}".format(os.path.basename(args.pretrained_model_path))) 330 | model.load_state_dict(torch.load(args.pretrained_model_path)) 331 | epoch_start = os.path.basename(args.pretrained_model_path).split(".")[0] 332 | print(epoch_start) 333 | 334 | torch.set_num_threads(2) 335 | 336 | trainLoader = DataLoader( 337 | DatasetImageMaskContourDist(train_file_names, args.distance_type), 338 | batch_size=args.batch_size, num_workers=4 339 | ) 340 | devLoader = DataLoader( 341 | DatasetImageMaskContourDist(val_file_names, args.distance_type), num_workers=4 342 | ) 343 | displayLoader = DataLoader( 344 | DatasetImageMaskContourDist(val_file_names, args.distance_type), 345 | batch_size=args.val_batch_size, num_workers=4 346 | ) 347 | 348 | for epoch in range(int(epoch_start) + 1, int(epoch_start) + 1 + args.num_epochs): 349 | 350 | global_step = epoch * len(trainLoader) 351 | running_loss = 0.0 352 | 353 | model.train() 354 | 355 | for i, (img_file_name, inputs, targets1, targets2, targets3, targets4) in enumerate( 356 | tqdm(trainLoader) 357 | ): 358 | 359 | inputs = inputs.to(device) 360 | targets1 = targets1.to(device) 361 | targets2 = targets2.to(device) 362 | targets3 = targets3.to(device) 363 | targets4 = targets4.to(device) 364 | 365 | targets = [targets1, targets2, targets3, targets4] 366 | 367 | loss = train_model(model, targets, args.model_type, criterion, optimizer) 368 | 369 | writer.add_scalar("loss", loss.item(), epoch) 370 | 371 | running_loss += loss.item() * inputs.size(0) 372 | 373 | epoch_loss = running_loss / len(train_file_names) 374 | 375 | if epoch % 1 == 0: 376 | if args.model_type not in in_models: 377 | dev_loss, dsc, dev_time = evaluate(device, epoch, model, devLoader, writer) 378 | else: 379 | dev_loss, dsc, dev_time = evaluate_modi(device, epoch, model, devLoader, writer, criterion, args.model_type) 380 | writer.add_scalar("loss_valid", dev_loss, epoch) 381 | visualize(device, epoch, model, displayLoader, writer, args.val_batch_size) 382 | print("Global Loss:{} Val Loss:{} dsc:{}".format(epoch_loss, dev_loss, dsc)) 383 | else: 384 | print("Global Loss:{} ".format(epoch_loss)) 385 | 386 | if args.use_scheduler is True: 387 | scheduler.step(dev_loss) 388 | 389 | logging.info("epoch:{} train_loss:{}".format(epoch, epoch_loss)) 390 | 391 | if epoch % 25 == 0: 392 | if torch.cuda.device_count() > 1: 393 | torch.save( 394 | model.module.state_dict(), 395 | os.path.join(args.save_path, str(epoch) + "_val_" + str(round(dev_loss, 5)) + "_dsc_" + str(round(dsc, 5)) + ".pt") 396 | ) 397 | else: 398 | torch.save( 399 | model.state_dict(), os.path.join(args.save_path, str(epoch) + "_val_" + str(round(dev_loss, 5)) + "_dsc_" + str(round(dsc, 5)) + ".pt") 400 | ) 401 | -------------------------------------------------------------------------------- /train_smp_y.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import time 4 | import logging 5 | import random 6 | import glob 7 | import segmentation_models_pytorch as smp 8 | from tqdm import tqdm 9 | from tensorboardX import SummaryWriter 10 | from utils import create_train_arg_parser, build_model, define_loss, AverageMeter, generate_dataset 11 | from sklearn.metrics import cohen_kappa_score 12 | 13 | 14 | def train(epoch, model, data_loader, optimizer, criterion, device, training=False): 15 | seg_losses = 
AverageMeter("Seg_Loss", ".16f") 16 | dices = AverageMeter("Dice", ".8f") 17 | jaccards = AverageMeter("Jaccard", ".8f") 18 | clas_losses = AverageMeter("Clas_Loss", ".16f") 19 | accs = AverageMeter("Accuracy", ".8f") 20 | kappas = AverageMeter("Kappa", ".8f") 21 | 22 | if training: 23 | model.train() 24 | torch.set_grad_enabled(True) 25 | else: 26 | model.eval() 27 | torch.set_grad_enabled(False) 28 | 29 | process = tqdm(data_loader) 30 | for i, (img_file_name, inputs, targets1, targets2, targets3, targets4) in enumerate(process): 31 | inputs = inputs.to(device) 32 | targets1, targets2 = targets1.to(device), targets2.to(device) 33 | targets3, targets4 = targets3.to(device), targets4.to(device) 34 | targets = [targets1, targets2, targets3, targets4] 35 | 36 | if training: 37 | optimizer.zero_grad() 38 | 39 | outputs, label = model(inputs) 40 | if not isinstance(outputs, list): 41 | outputs = [outputs] 42 | preds = torch.round(outputs[0]) 43 | 44 | seg_criterion, dice_criterion, jaccard_criterion, clas_criterion = criterion[0], criterion[1], criterion[2], criterion[3] 45 | 46 | seg_loss = seg_criterion(outputs[0], targets[0].to(torch.float32)) 47 | dice = 1 - dice_criterion(preds.squeeze(1), targets[0].squeeze(1)) 48 | jaccard = 1 - jaccard_criterion(preds.squeeze(1), targets[0].squeeze(1)) 49 | 50 | clf_labels = torch.argmax(targets[3], dim=2).squeeze(1) 51 | clf_preds = torch.argmax(label, dim=1) 52 | clas_loss = clas_criterion(label, clf_labels) 53 | predictions = clf_preds.detach().cpu().numpy() 54 | target = clf_labels.detach().cpu().numpy() 55 | acc = (predictions == target).sum() / len(predictions) 56 | kappa = cohen_kappa_score(predictions, target) 57 | 58 | if training: 59 | if epoch < 30: 60 | total_loss = seg_loss 61 | else: 62 | total_loss = seg_loss + clas_loss 63 | total_loss.backward() 64 | optimizer.step() 65 | 66 | seg_losses.update(seg_loss.item(), inputs.size(0)) 67 | dices.update(dice.item(), inputs.size(0)) 68 | jaccards.update(jaccard.item(), inputs.size(0)) 69 | clas_losses.update(clas_loss.item(), inputs.size(0)) 70 | accs.update(acc.item(), inputs.size(0)) 71 | kappas.update(kappa.item(), inputs.size(0)) 72 | 73 | process.set_description('Seg Loss: ' + str(round(seg_losses.avg, 4)) + ' Clas Loss: ' + str(round(clas_losses.avg, 4))) 74 | 75 | epoch_seg_loss = seg_losses.avg 76 | epoch_dice = dices.avg 77 | epoch_jaccard = jaccards.avg 78 | epoch_clas_loss = clas_losses.avg 79 | epoch_acc = accs.avg 80 | epoch_kappa = kappas.avg 81 | 82 | return epoch_seg_loss, epoch_dice, epoch_jaccard, epoch_clas_loss, epoch_acc, epoch_kappa 83 | 84 | 85 | def main(): 86 | 87 | args = create_train_arg_parser().parse_args() 88 | CUDA_SELECT = "cuda:{}".format(args.cuda_no) 89 | 90 | log_path = os.path.join(args.save_path, "summary/") 91 | writer = SummaryWriter(log_dir=log_path) 92 | rq = time.strftime('%Y%m%d%H%M', time.localtime(time.time())) 93 | log_name = os.path.join(log_path, str(rq) + '.log') 94 | logging.basicConfig( 95 | filename=log_name, 96 | filemode="a", 97 | format="%(asctime)s %(levelname)s %(message)s", 98 | datefmt="%Y-%m-%d %H:%M", 99 | level=logging.INFO, 100 | ) 101 | logging.info(args) 102 | print(args) 103 | 104 | train_file_names = glob.glob(os.path.join(args.train_path, "*.png")) 105 | random.shuffle(train_file_names) 106 | val_file_names = glob.glob(os.path.join(args.val_path, "*.png")) 107 | 108 | device = torch.device(CUDA_SELECT if torch.cuda.is_available() else "cpu") 109 | 110 | if args.pretrain in ['imagenet', 'ssl', 'swsl']: 111 | pretrain = 
args.pretrain 112 | model = build_model(args.model_type, args.encoder, pretrain, aux=True) 113 | else: 114 | pretrain = None 115 | model = build_model(args.model_type, args.encoder, pretrain, aux=True) 116 | logging.info(model) 117 | model = model.to(device) 118 | 119 | train_loader, valid_loader = generate_dataset(train_file_names, val_file_names, args.batch_size, args.batch_size, args.distance_type, args.clahe) 120 | 121 | optimizer = torch.optim.Adam([ 122 | dict(params=model.parameters(), lr=args.LR_seg) 123 | ]) 124 | 125 | criterion = [ 126 | define_loss(args.loss_type), 127 | smp.utils.losses.DiceLoss(), 128 | smp.utils.losses.JaccardLoss(), 129 | smp.utils.losses.CrossEntropyLoss() 130 | ] 131 | 132 | max_dice = 0 133 | max_acc = 0 134 | epoch_start = 0 135 | 136 | for epoch in range(epoch_start + 1, epoch_start + 1 + args.num_epochs): 137 | 138 | print('\nEpoch: {}'.format(epoch)) 139 | 140 | train_seg_loss, train_dice, train_jaccard, train_clas_loss, train_acc, train_kappa = train(epoch, model, train_loader, optimizer, criterion, device, training=True) 141 | val_seg_loss, val_dice, val_jaccard, val_clas_loss, val_acc, val_kappa = train(epoch, model, valid_loader, optimizer, criterion, device, training=False) 142 | 143 | epoch_info = "Epoch: {}".format(epoch) 144 | train_seg_info = "Training Seg Loss: {:.4f}, Training Dice: {:.4f}, Training Jaccard: {:.4f}".format(train_seg_loss, train_dice, train_jaccard) 145 | train_clas_info = "Training Clas Loss: {:.4f}, Training Acc: {:.4f}, Training Kappa: {:.4f}".format(train_clas_loss, train_acc, train_kappa) 146 | val_seg_info = "Validation Seg Loss: {:.4f}, Validation Dice: {:.4f}, Validation Jaccard: {:.4f}".format(val_seg_loss, val_dice, val_jaccard) 147 | val_clas_info = "Validation Clas Loss: {:.4f}, Validation Acc: {:.4f}, Validation Kappa: {:.4f}".format(val_clas_loss, val_acc, val_kappa) 148 | print(train_seg_info) 149 | print(train_clas_info) 150 | print(val_seg_info) 151 | print(val_clas_info) 152 | logging.info(epoch_info) 153 | logging.info(train_seg_info) 154 | logging.info(train_clas_info) 155 | logging.info(val_seg_info) 156 | logging.info(val_clas_info) 157 | writer.add_scalar("train_seg_loss", train_seg_loss, epoch) 158 | writer.add_scalar("train_dice", train_dice, epoch) 159 | writer.add_scalar("train_jaccard", train_jaccard, epoch) 160 | writer.add_scalar("train_clas_loss", train_clas_loss, epoch) 161 | writer.add_scalar("train_acc", train_acc, epoch) 162 | writer.add_scalar("train_kappa", train_kappa, epoch) 163 | writer.add_scalar("val_seg_loss", val_seg_loss, epoch) 164 | writer.add_scalar("val_dice", val_dice, epoch) 165 | writer.add_scalar("val_jaccard", val_jaccard, epoch) 166 | writer.add_scalar("val_clas_loss", val_clas_loss, epoch) 167 | writer.add_scalar("val_acc", val_acc, epoch) 168 | writer.add_scalar("val_kappa", val_kappa, epoch) 169 | 170 | best_name = os.path.join(args.save_path, "best_dice_" + str(round(val_dice, 4)) + "_jaccard_" + str(round(val_jaccard, 4)) + "_acc_" + str(round(val_acc, 4)) + "_kappa_" + str(round(val_kappa, 4)) + ".pt") 171 | save_name = os.path.join(args.save_path, str(epoch) + "_dice_" + str(round(val_dice, 4)) + "_jaccard_" + str(round(val_jaccard, 4)) + "_acc_" + str(round(val_acc, 4)) + "_kappa_" + str(round(val_kappa, 4)) + ".pt") 172 | 173 | if max_dice < val_dice: 174 | max_dice = val_dice 175 | if max_dice > 0.5: 176 | if torch.cuda.device_count() > 1: 177 | torch.save(model.module.state_dict(), best_name) 178 | else: 179 | torch.save(model.state_dict(), best_name) 180 | 
-------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torchvision 4 | import time 5 | import argparse 6 | import segmentation_models_pytorch as smp 7 | from tqdm import tqdm 8 | from typing import List 9 | from torch import Tensor, einsum 10 | from torch.nn import functional as F 11 | from dataset import DatasetImageMaskContourDist, mean_and_std 12 | from dataset import DatasetCornea, distancedStainingImage 13 | from losses import LossUNet, LossDCAN, LossDMTN, LossPsiNet 14 | from models import UNet, UNet_DCAN, UNet_DMTN, PsiNet, UNet_ConvMCD 15 | from torch.utils.data import DataLoader 16 | from sklearn.model_selection import KFold, train_test_split 17 | 18 | 19 | class AverageMeter(object): 20 | """Computes and stores the average and current value""" 21 | def __init__(self, name, fmt=':f'): 22 | self.name = name 23 | self.fmt = fmt 24 | self.reset() 25 | 26 | def reset(self): 27 | self.val = 0 28 | self.avg = 0 29 | self.sum = 0 30 | self.count = 0 31 | 32 | def update(self, val, n=1): 33 | self.val = val 34 | self.sum += val * n 35 | self.count += n 36 | self.avg = self.sum / self.count 37 | 38 | def __str__(self): 39 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 40 | return fmtstr.format(**self.__dict__) 41 | 42 |
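A quick illustration of AverageMeter's sample-weighted running average, independent of the training loops:

    from utils import AverageMeter

    meter = AverageMeter("Dice", ".4f")
    meter.update(0.80, n=4)  # batch of 4 samples with mean Dice 0.80
    meter.update(0.90, n=2)  # batch of 2 samples with mean Dice 0.90
    print(meter)             # "Dice 0.9000 (0.8333)": val is the last batch, avg is weighted by n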
43 | def build_model(model_type, encoder, pretrain, aux=False): 44 | 45 | aux_params = dict( 46 | pooling='avg', # one of 'avg', 'max' 47 | dropout=0.5, # dropout ratio, default is None 48 | activation='sigmoid', # activation function, default is None 49 | classes=3, # define number of output labels 50 | ) 51 | 52 | if model_type == "unet": 53 | model = UNet(num_classes=2) 54 | elif model_type == "dcan": 55 | model = UNet_DCAN(num_classes=2) 56 | elif model_type == "dmtn": 57 | model = UNet_DMTN(num_classes=2) 58 | elif model_type == "psinet": 59 | model = PsiNet(num_classes=2) 60 | elif model_type == "convmcd": 61 | model = UNet_ConvMCD(num_classes=2) 62 | elif model_type == "unet_smp": 63 | model = smp.Unet( 64 | encoder_name=encoder, 65 | encoder_depth=5, 66 | encoder_weights=pretrain, 67 | decoder_use_batchnorm=True, 68 | decoder_channels=(256, 128, 64, 32, 16), 69 | decoder_attention_type=None, 70 | in_channels=3, 71 | classes=1, 72 | activation='sigmoid', 73 | aux_params=None if not aux else aux_params 74 | ) 75 | elif model_type == "unet++": 76 | model = smp.UnetPlusPlus( 77 | encoder_name=encoder, 78 | encoder_depth=5, 79 | encoder_weights=pretrain, 80 | decoder_use_batchnorm=True, 81 | decoder_channels=(256, 128, 64, 32, 16), 82 | decoder_attention_type=None, 83 | in_channels=3, 84 | classes=1, 85 | activation='sigmoid', 86 | aux_params=None if not aux else aux_params 87 | ) 88 | elif model_type == "manet": 89 | model = smp.MAnet( 90 | encoder_name=encoder, 91 | encoder_depth=5, 92 | encoder_weights=pretrain, 93 | decoder_use_batchnorm=True, 94 | decoder_channels=(256, 128, 64, 32, 16), 95 | decoder_pab_channels=64, 96 | in_channels=1, 97 | classes=1, 98 | activation='sigmoid', 99 | aux_params=None if not aux else aux_params 100 | ) 101 | elif model_type == "linknet": 102 | model = smp.Linknet( 103 | encoder_name=encoder, 104 | encoder_depth=5, 105 | encoder_weights=pretrain, 106 | decoder_use_batchnorm=True, 107 | in_channels=1, 108 | classes=1, 109 | activation='sigmoid', 110 | aux_params=None if not aux else aux_params 111 | ) 112 | elif model_type == "fpn": 113 | model = smp.FPN( 114 | encoder_name=encoder, 115 | encoder_depth=5, 116 | encoder_weights=pretrain, 117 | decoder_pyramid_channels=256, 118 | decoder_segmentation_channels=128, 119 | decoder_merge_policy='add', 120 | decoder_dropout=0.2, 121 | in_channels=3, 122 | classes=1, 123 | activation='sigmoid', 124 | upsampling=4, 125 | aux_params=None if not aux else aux_params 126 | ) 127 | elif model_type == "pspnet": 128 | model = smp.PSPNet( 129 | encoder_name=encoder, 130 | encoder_weights=pretrain, 131 | encoder_depth=3, 132 | psp_out_channels=512, 133 | psp_use_batchnorm=True, 134 | psp_dropout=0.2, 135 | in_channels=3, 136 | classes=1, 137 | activation='sigmoid', 138 | upsampling=8, 139 | aux_params=None if not aux else aux_params 140 | ) 141 | elif model_type == "pan": 142 | model = smp.PAN( 143 | encoder_name=encoder, 144 | encoder_weights=pretrain, 145 | encoder_dilation=True, 146 | decoder_channels=32, 147 | in_channels=1, 148 | classes=1, 149 | activation='sigmoid', 150 | upsampling=4, 151 | aux_params=None if not aux else aux_params 152 | ) 153 | elif model_type == "deeplabv3": 154 | model = smp.DeepLabV3( 155 | encoder_name=encoder, 156 | encoder_depth=5, 157 | encoder_weights=pretrain, 158 | decoder_channels=256, 159 | in_channels=3, 160 | classes=1, 161 | activation='sigmoid', 162 | upsampling=8, 163 | aux_params=None if not aux else aux_params 164 | ) 165 | elif model_type == "deeplabv3+": 166 | model = smp.DeepLabV3Plus( 167 | encoder_name=encoder, 168 | encoder_depth=5, 169 | encoder_weights=pretrain, 170 | encoder_output_stride=16, 171 | decoder_channels=256, 172 | decoder_atrous_rates=(12, 24, 36), 173 | in_channels=1, 174 | classes=1, 175 | activation='sigmoid', 176 | upsampling=4, 177 | aux_params=None if not aux else aux_params 178 | ) 179 | # print(model) 180 | 181 | return model 182 | 183 |
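One behavioural detail worth noting: with aux=True, segmentation_models_pytorch attaches the classification head described by aux_params, and the forward pass then returns a (mask, label) pair, which is exactly what the training loop above unpacks. Note also that some branches are configured with in_channels=1 (manet, linknet, pan, deeplabv3+) while the rest use in_channels=3, so the expected input shape depends on model_type. A minimal sketch (encoder choice and input size are arbitrary; 192 is used because depth-5 encoders need spatial dims divisible by 32):

    import torch
    from utils import build_model

    model = build_model("unet_smp", "resnet34", pretrain=None, aux=True)
    x = torch.randn(2, 3, 192, 192)  # unet_smp is built with in_channels=3
    mask, label = model(x)           # the aux head makes forward() return two tensors
    print(mask.shape)                # torch.Size([2, 1, 192, 192]), sigmoid-activated
    print(label.shape)               # torch.Size([2, 3]), one sigmoid score per class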
184 | def define_loss(loss_type, weights=(3, 1, 2)): 185 | 186 | if loss_type == "jaccard": 187 | criterion = smp.utils.losses.JaccardLoss() 188 | elif loss_type == "dice": 189 | criterion = smp.utils.losses.DiceLoss() 190 | elif loss_type == "ce": 191 | criterion = smp.utils.losses.CrossEntropyLoss() 192 | elif loss_type == "bcewithlogit": 193 | criterion = smp.utils.losses.BCEWithLogitsLoss() 194 | elif loss_type == "unet": 195 | criterion = LossUNet(weights) 196 | elif loss_type == "dcan": 197 | criterion = LossDCAN(weights) 198 | elif loss_type == "dmtn": 199 | criterion = LossDMTN(weights) 200 | elif loss_type == "psinet" or loss_type == "convmcd": 201 | # Both psinet and convmcd use the same mask, contour and distance loss function 202 | criterion = LossPsiNet(weights) 203 | 204 | return criterion 205 | 206 | 207 | def evaluate(device, epoch, model, data_loader, writer): 208 | model.eval() 209 | losses = [] 210 | dsces = [] 211 | start = time.perf_counter() 212 | with torch.no_grad(): 213 | 214 | for batch_idx, data in enumerate(tqdm(data_loader)): 215 | 216 | _, inputs, targets, _, _, _ = data 217 | inputs = inputs.to(device) 218 | targets = targets.to(device) 219 | outputs = model(inputs) 220 | loss = F.nll_loss(outputs[0], targets.squeeze(1)) 221 | dsc_loss = smp.utils.losses.DiceLoss() 222 | 223 | preds = torch.argmax(outputs[0].exp(), dim=1) 224 | dsc = 1 - dsc_loss(preds, targets.squeeze(1)) 225 | 226 | losses.append(loss.item()) 227 | dsces.append(dsc.item()) 228 | 229 | writer.add_scalar("Dev_Loss", np.mean(losses), epoch) 230 | 231 | return np.mean(losses), np.mean(dsces), time.perf_counter() - start 232 | 233 | 234 | def evaluate_modi(device, epoch, model, data_loader, writer, criterion, model_type): 235 | model.eval() 236 | losses = [] 237 | dsces = [] 238 | start = time.perf_counter() 239 | 240 | with torch.no_grad(): 241 | 242 | for batch_idx, data in enumerate(tqdm(data_loader)): 243 | 244 | _, inputs, targets, _, _, _ = data 245 | inputs = inputs.to(device) 246 | targets = targets.to(device) 247 | outputs = model(inputs) 248 | if not isinstance(outputs, list): 249 | outputs = [outputs] 250 | preds = torch.argmax(outputs[0].exp(), dim=1) 251 | loss = criterion(preds, targets.squeeze(1)) 252 | 253 | dsc_loss = smp.utils.losses.DiceLoss() 254 | 255 | dsc = 1 - dsc_loss(preds, targets.squeeze(1)) 256 | losses.append(loss.item()) 257 | dsces.append(dsc.item()) 258 | 259 | writer.add_scalar("Dev_Loss", np.mean(losses), epoch) 260 | 261 | return np.mean(losses), np.mean(dsces), time.perf_counter() - start 262 | 263 | 264 | def visualize(device, epoch, model, data_loader, writer, val_batch_size, train=False): 265 | def save_image(image, tag, val_batch_size): 266 | image -= image.min() 267 | image /= image.max() 268 | grid = torchvision.utils.make_grid( 269 | image, nrow=int(np.sqrt(val_batch_size)), pad_value=0, padding=25 270 | ) 271 | writer.add_image(tag, grid, epoch) 272 | 273 | model.eval() 274 | with torch.no_grad(): 275 | for batch_idx, data in enumerate(tqdm(data_loader)): 276 | _, inputs, targets, _, _, _ = data 277 | 278 | inputs = inputs.to(device) 279 | 280 | targets = targets.to(device) 281 | outputs = model(inputs) 282 | 283 | output_mask = outputs[0].detach().cpu().numpy() 284 | output_final = np.argmax(output_mask, axis=1).astype(float) 285 | 286 | output_final = torch.from_numpy(output_final).unsqueeze(1) 287 | 288 | if train in (True, "True"):  # accept the bool default as well as a "True" string 289 | save_image(targets.float(), "Target_train", val_batch_size) 290 | save_image(output_final, "Prediction_train", val_batch_size) 291 | else: 292 | save_image(targets.float(), "Target", val_batch_size) 293 | save_image(output_final, "Prediction", val_batch_size) 294 | 295 | break 296 | 297 | 298 | def generate_dataset(train_file_names, val_file_names, batch_size, val_batch_size, distance_type, do_clahe): 299 | train_mean, train_std = mean_and_std(train_file_names) 300 | 301 | train_dataset = DatasetImageMaskContourDist(train_file_names, distance_type, train_mean, train_std, do_clahe) 302 | valid_dataset = DatasetImageMaskContourDist(val_file_names, distance_type, train_mean, train_std, do_clahe) 303 | train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=8, shuffle=True, drop_last=True) 304 | valid_loader = DataLoader(valid_dataset, batch_size=val_batch_size, num_workers=4, shuffle=True) 305 | 306 | return train_loader, valid_loader 307 | 308 |
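Throughout these helpers, Dice is reported as 1 minus smp's DiceLoss, so perfect overlap scores 1. A tiny sanity check of that convention (smp's loss smooths the ratio with an eps term, so the value is close to, not exactly, the textbook 2*|P∩G|/(|P|+|G|)):

    import torch
    import segmentation_models_pytorch as smp

    dice_loss = smp.utils.losses.DiceLoss()
    pred = torch.tensor([[[0., 1.], [1., 1.]]])  # predicted binary mask
    gt = torch.tensor([[[0., 1.], [1., 0.]]])    # ground-truth mask
    dice = 1 - dice_loss(pred, gt)
    print(float(dice))  # ~0.83; the unsmoothed Dice would be 2*2/(3+2) = 0.8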
help="path to training img jpg files") 313 | parser.add_argument("--val_path", type=str, help="path to validation img jpg files") 314 | parser.add_argument("--test_path", type=str, help="path to test img jpg files") 315 | parser.add_argument( 316 | "--train_type", 317 | type=str, 318 | default="cotraining", 319 | help="Select training type, including single classification, segmentation, cotraining and multitask. ") 320 | parser.add_argument( 321 | "--model_type", 322 | type=str, 323 | help="select model type: unet,dcan,dmtn,psinet,convmcd", 324 | ) 325 | parser.add_argument("--object_type", type=str, help="Dataset.") 326 | parser.add_argument( 327 | "--distance_type", 328 | type=str, 329 | default="dist_mask", 330 | help="select distance transform type - dist_mask,dist_contour,dist_signed", 331 | ) 332 | parser.add_argument("--batch_size", type=int, default=64, help="train batch size") 333 | parser.add_argument( 334 | "--val_batch_size", type=int, default=64, help="validation batch size" 335 | ) 336 | parser.add_argument("--num_epochs", type=int, default=500, help="number of epochs") 337 | parser.add_argument("--cuda_no", type=int, default=0, help="cuda number") 338 | parser.add_argument( 339 | "--use_pretrained", type=bool, default=False, help="Load pretrained checkpoint." 340 | ) 341 | parser.add_argument( 342 | "--pretrained_model_path", 343 | type=str, 344 | default=None, 345 | help="If use_pretrained is true, provide checkpoint.", 346 | ) 347 | parser.add_argument("--save_path", type=str, help="Model save path.") 348 | parser.add_argument("--encoder", type=str, default=None, help="encoder.") 349 | parser.add_argument("--pretrain", type=str, default=None, help="choose pretrain.") 350 | parser.add_argument("--loss_type", type=str, default=None, help="loss type.") 351 | parser.add_argument("--local_rank", default=0, type=int, help='node rank for distributed training') 352 | parser.add_argument("--LR_seg", default=1e-4, type=float, help='learning rate.') 353 | parser.add_argument("--LR_clf", default=1e-4, type=float, help='learning rate.') 354 | parser.add_argument("--use_scheduler", type=bool, default=False, help="use_scheduler.") 355 | parser.add_argument("--aux", type=bool, default=False, help="choose to do classification") 356 | parser.add_argument("--attention", type=str, default=None, help="decoder_attention_type.") 357 | parser.add_argument("--usenorm", type=bool, default=True, help="encoder use bn") 358 | parser.add_argument("--startpoint", type=int, default=60, help="start cotraining point.") 359 | parser.add_argument("--clahe", type=bool, default=False, help="do clahe.") 360 | parser.add_argument("--classnum", type=int, default=3, help="clf class number.") 361 | parser.add_argument("--fold", type=str, default=0, help="Fold for training.") 362 | return parser 363 | 364 | 365 | def create_validation_arg_parser(): 366 | 367 | parser = argparse.ArgumentParser(description="train setup for segmentation") 368 | parser.add_argument( 369 | "--model_type", 370 | type=str, 371 | help="select model type: unet,dcan,dmtn,psinet,convmcd", 372 | ) 373 | parser.add_argument( 374 | "--distance_type", 375 | type=str, 376 | default="dist_signed", 377 | help="select distance transform type - dist_mask,dist_contour,dist_signed", 378 | ) 379 | parser.add_argument("--train_path", type=str, help="path to train img jpg files") 380 | parser.add_argument("--val_path", type=str, help="path to validation img jpg files") 381 | parser.add_argument("--test_path", type=str, help="path to test img jpg files") 382 
| parser.add_argument("--model_file", type=str, help="model_file") 383 | parser.add_argument("--save_path", type=str, help="results save path.") 384 | parser.add_argument("--cuda_no", type=int, default=0, help="cuda number") 385 | parser.add_argument("--encoder", type=str, default=None, help="encoder.") 386 | parser.add_argument("--pretrain", type=str, default=None, help="choose pretrain.") 387 | parser.add_argument("--attention", type=str, default=None, help="decoder_attention_type.") 388 | parser.add_argument("--val_batch_size", type=int, default=32, help="validation batch size") 389 | parser.add_argument("--usenorm", type=bool, default=True, help="encoder use bn") 390 | parser.add_argument("--clahe", type=bool, default=False, help="do clahe.") 391 | parser.add_argument("--classnum", type=int, default=3, help="clf class number.") 392 | return parser 393 | --------------------------------------------------------------------------------