├── README.md ├── data ├── build_dataset.py ├── optic.py ├── polyp.py └── skin.py ├── main.py ├── models ├── __init__.py ├── build_model.py ├── module.py └── mymodel.py ├── opt.py └── utils ├── evaluate.py ├── loss.py ├── mytransforms.py └── save_img.py /README.md: -------------------------------------------------------------------------------- 1 | # Self-Supervised Correction Learning for Semi-Supervised Biomedical Image Segmentation 2 | 3 | ## Introduction 4 | 5 | This repository contains the PyTorch implementation of: 6 | 7 | Self-Supervised Correction Learning for Semi-Supervised Biomedical Image Segmentation, MICCAI 2021. 8 | 9 | ## Requirements 10 | 11 | * torch 12 | * torchvision 13 | * tqdm 14 | * opencv-python 15 | * scipy 16 | * scikit-image 17 | * Pillow 18 | * numpy 19 | 20 | ## Usage 21 | 22 | #### 1. Training 23 | 24 | ```bash 25 | python main.py --mode train --manner semi --ratio 2 \ 26 | --root {project_path} --dataset polyp --polyp {data_path} 27 | 28 | ``` 29 | 30 | Here `--ratio 2` uses 2/10 of the training set as labeled data and treats the rest as unlabeled; use `--manner full` to train on the fully labeled set. 31 | 32 | #### 2. Inference 33 | 34 | ```bash 35 | python main.py --mode test --manner test --load_ckpt checkpoint \ 36 | --root {project_path} --dataset polyp --polyp {data_path} 37 | ``` 38 | 39 | `--load_ckpt {name}` loads the checkpoint `{project_path}/checkpoint/exp{expID}/{name}.pth`. 40 | 41 | ## Citation 42 | 43 | If you find this work helpful, please cite our paper: 44 | 45 | ``` 46 | @inproceedings{zhang2021self, 47 | title={Self-supervised Correction Learning for Semi-supervised Biomedical Image Segmentation}, 48 | author={Zhang, Ruifei and Liu, Sishuo and Yu, Yizhou and Li, Guanbin}, 49 | booktitle={International Conference on Medical Image Computing and Computer-Assisted Intervention}, 50 | pages={134--144}, 51 | year={2021}, 52 | organization={Springer} 53 | } 54 | ``` 55 | 56 | 57 | 58 | 59 | 60 | -------------------------------------------------------------------------------- /data/build_dataset.py: -------------------------------------------------------------------------------- 1 | from .polyp import PolypDataSet 2 | from .skin import SkinDataSet 3 | from .optic import OpticDataSet 4 | 5 | 6 | def build_dataset(args): 7 | if args.manner == 'test': 8 | if args.dataset == 'polyp': 9 | test_data = PolypDataSet(args.root, args.polyp, mode='test') 10 | elif args.dataset == 'skin': 11 | test_data = SkinDataSet(args.root, args.skin, mode='test') 12 | elif args.dataset == 'optic': 13 | test_data = OpticDataSet(args.root, args.optic, mode='test') 14 | return test_data 15 | else: 16 | if args.dataset == 'polyp': 17 | train_data = PolypDataSet(args.root, args.polyp, mode='train', ratio=args.ratio, sign='label') 18 | valid_data = PolypDataSet(args.root, args.polyp, mode='valid') 19 | test_data = PolypDataSet(args.root, args.polyp, mode='test') 20 | train_u_data = None 21 | if args.manner == 'semi': 22 | train_u_data = PolypDataSet(args.root, args.polyp, mode='train', ratio=args.ratio, sign='unlabel') 23 | elif args.dataset == 'skin': 24 | train_data = SkinDataSet(args.root, args.skin, mode='train', ratio=args.ratio, sign='label') 25 | valid_data = None 26 | test_data = SkinDataSet(args.root, args.skin, mode='test') 27 | train_u_data = None 28 | if args.manner == 'semi': 29 | train_u_data = SkinDataSet(args.root, args.skin, mode='train', ratio=args.ratio, sign='unlabel') 30 | elif args.dataset == 'optic': 31 | train_data = OpticDataSet(args.root, args.optic, mode='train', ratio=args.ratio, sign='label') 32 | valid_data = None 33 | test_data = OpticDataSet(args.root, args.optic, mode='test') 34 | train_u_data = None 35 | if args.manner == 'semi': 36 | train_u_data = OpticDataSet(args.root, args.optic, mode='train', 
ratio=args.ratio, sign='unlabel') 37 | 38 | return train_data, train_u_data, valid_data, test_data 39 | 40 | -------------------------------------------------------------------------------- /data/optic.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | from torchvision import transforms 3 | import os 4 | from utils.mytransforms import * 5 | import json 6 | 7 | 8 | class OpticDataSet(Dataset): 9 | def __init__(self, root, data_dir, mode='train', ratio=10, sign='label', transform=None): 10 | super(OpticDataSet, self).__init__() 11 | self.mode = mode 12 | self.sign = sign 13 | root_path = os.path.join(root, data_dir) 14 | imgfile = os.path.join(root_path, mode + '_optic.json') 15 | with open(imgfile, 'r') as f: 16 | imglist = json.load(f) 17 | 18 | imglist = [os.path.join(root_path, mode, img) for img in imglist] 19 | 20 | if mode == 'train' and ratio < 10: 21 | split_file_path = os.path.join(root_path, "train_split_10.json") 22 | with open(split_file_path, 'r') as f: 23 | split_10_list = json.load(f) 24 | 25 | if sign == 'label': 26 | Limglist = [] 27 | for i in range(ratio): 28 | Limglist += split_10_list[str(i)] 29 | self.imglist = [imglist[index] for index in Limglist] 30 | if sign == 'unlabel': 31 | Uimglist = [] 32 | for i in range(ratio, 10): 33 | Uimglist += split_10_list[str(i)] 34 | self.imglist = [imglist[index] for index in Uimglist] 35 | else: 36 | self.imglist = imglist 37 | 38 | if transform is None: 39 | if mode == 'train' and sign == 'label': 40 | transform = transforms.Compose([ 41 | Resize((320, 320)), 42 | RandomHorizontalFlip(), 43 | RandomVerticalFlip(), 44 | RandomRotation(90), 45 | RandomZoom((0.9, 1.1)), 46 | RandomCrop((256, 256)), 47 | ToTensor() 48 | ]) 49 | elif mode == 'train' and sign == 'unlabel': 50 | transform = transforms.Compose([ 51 | transforms.Resize((320, 320)), 52 | transforms.RandomHorizontalFlip(), 53 | transforms.RandomVerticalFlip(), 54 | transforms.RandomRotation(90), 55 | transforms.RandomCrop((256, 256)), 56 | transforms.ToTensor() 57 | ]) 58 | elif mode == 'valid' or mode == 'test': 59 | transform = transforms.Compose([ 60 | Resize((320, 320)), 61 | ToTensor() 62 | ]) 63 | self.transform = transform 64 | 65 | def __getitem__(self, index): 66 | if self.mode == 'train' and self.sign == 'unlabel': 67 | img_path = self.imglist[index] 68 | img = Image.open(img_path).convert('RGB') 69 | if self.transform: 70 | return self.transform(img) 71 | else: 72 | img_path = self.imglist[index] 73 | gt_path = img_path.replace('images', 'masks') 74 | gt_path = gt_path.split('.')[0] + "-exp1.bmp" 75 | img = Image.open(img_path).convert('RGB') 76 | gt = Image.open(gt_path).convert('L') 77 | data = {'image': img, 'label': gt} 78 | if self.transform: 79 | data = self.transform(data) 80 | 81 | return data 82 | 83 | def __len__(self): 84 | return len(self.imglist) 85 | -------------------------------------------------------------------------------- /data/polyp.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | from torchvision import transforms 3 | import os 4 | from utils.mytransforms import * 5 | import json 6 | 7 | 8 | class PolypDataSet(Dataset): 9 | def __init__(self, root, data_dir, mode='train', ratio=10, sign='label', transform=None): 10 | super(PolypDataSet, self).__init__() 11 | self.mode = mode 12 | self.sign = sign 13 | root_path = os.path.join(root, data_dir) 14 | imgfile = os.path.join(root_path, 
mode + '_polyp.json') 15 | with open(imgfile, 'r') as f: 16 | imglist = json.load(f) 17 | 18 | imglist = [os.path.join(root_path, mode, img) for img in imglist] 19 | 20 | if mode == 'train' and ratio < 10: 21 | split_file_path = os.path.join(root_path, "train_split_10.json") 22 | with open(split_file_path, 'r') as f: 23 | split_10_list = json.load(f) 24 | 25 | if sign == 'label': 26 | Limglist = [] 27 | for i in range(ratio): 28 | Limglist += split_10_list[str(i)] 29 | self.imglist = [imglist[index] for index in Limglist] 30 | if sign == 'unlabel': 31 | Uimglist = [] 32 | for i in range(ratio, 10): 33 | Uimglist += split_10_list[str(i)] 34 | self.imglist = [imglist[index] for index in Uimglist] 35 | else: 36 | self.imglist = imglist 37 | 38 | if transform is None: 39 | if mode == 'train' and sign == 'label': 40 | transform = transforms.Compose([ 41 | Resize((320, 320)), 42 | RandomHorizontalFlip(), 43 | RandomVerticalFlip(), 44 | RandomRotation(90), 45 | RandomZoom((0.9, 1.1)), 46 | RandomCrop((256, 256)), 47 | ToTensor() 48 | ]) 49 | elif mode == 'train' and sign == 'unlabel': 50 | transform = transforms.Compose([ 51 | transforms.Resize((320, 320)), 52 | transforms.RandomHorizontalFlip(), 53 | transforms.RandomVerticalFlip(), 54 | transforms.RandomRotation(90), 55 | transforms.RandomCrop((256, 256)), 56 | transforms.ToTensor() 57 | ]) 58 | elif mode == 'valid' or mode == 'test': 59 | transform = transforms.Compose([ 60 | Resize((320, 320)), 61 | ToTensor() 62 | ]) 63 | self.transform = transform 64 | 65 | def __getitem__(self, index): 66 | if self.mode == 'train' and self.sign == 'unlabel': 67 | img_path = self.imglist[index] 68 | img = Image.open(img_path).convert('RGB') 69 | if self.transform: 70 | return self.transform(img) 71 | else: 72 | img_path = self.imglist[index] 73 | gt_path = img_path.replace('images', 'masks') 74 | img = Image.open(img_path).convert('RGB') 75 | gt = Image.open(gt_path).convert('L') 76 | data = {'image': img, 'label': gt} 77 | if self.transform: 78 | data = self.transform(data) 79 | 80 | return data 81 | 82 | def __len__(self): 83 | return len(self.imglist) 84 | -------------------------------------------------------------------------------- /data/skin.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | from torchvision import transforms 3 | import os 4 | from utils.mytransforms import * 5 | import json 6 | 7 | 8 | class SkinDataSet(Dataset): 9 | def __init__(self, root, data_dir, mode='train', ratio=10, sign='label', transform=None): 10 | super(SkinDataSet, self).__init__() 11 | self.mode = mode 12 | self.sign = sign 13 | root_path = os.path.join(root, data_dir) 14 | imgfile = os.path.join(root_path, mode + '_skin.json') 15 | with open(imgfile, 'r') as f: 16 | imglist = json.load(f) 17 | 18 | imglist = [os.path.join(root_path, mode, img) for img in imglist] 19 | 20 | if mode == 'train' and ratio < 10: 21 | split_file_path = os.path.join(root_path, "train_split_10.json") 22 | with open(split_file_path, 'r') as f: 23 | split_10_list = json.load(f) 24 | 25 | if sign == 'label': 26 | Limglist = [] 27 | for i in range(ratio): 28 | Limglist += split_10_list[str(i)] 29 | self.imglist = [imglist[index] for index in Limglist] 30 | if sign == 'unlabel': 31 | Uimglist = [] 32 | for i in range(ratio, 10): 33 | Uimglist += split_10_list[str(i)] 34 | self.imglist = [imglist[index] for index in Uimglist] 35 | else: 36 | self.imglist = imglist 37 | 38 | if transform is None: 39 | if mode == 'train' 
and sign == 'label': 40 | transform = transforms.Compose([ 41 | Resize((320, 320)), 42 | RandomHorizontalFlip(), 43 | RandomVerticalFlip(), 44 | RandomRotation(90), 45 | RandomZoom((0.9, 1.1)), 46 | RandomCrop((256, 256)), 47 | ToTensor() 48 | ]) 49 | elif mode == 'train' and sign == 'unlabel': 50 | transform = transforms.Compose([ 51 | transforms.Resize((320, 320)), 52 | transforms.RandomHorizontalFlip(), 53 | transforms.RandomVerticalFlip(), 54 | transforms.RandomRotation(90), 55 | transforms.RandomCrop((256, 256)), 56 | transforms.ToTensor() 57 | ]) 58 | elif mode == 'valid' or mode == 'test': 59 | transform = transforms.Compose([ 60 | Resize((320, 320)), 61 | ToTensor() 62 | ]) 63 | self.transform = transform 64 | 65 | def __getitem__(self, index): 66 | if self.mode == 'train' and self.sign == 'unlabel': 67 | img_path = self.imglist[index] 68 | img = Image.open(img_path).convert('RGB') 69 | if self.transform: 70 | return self.transform(img) 71 | else: 72 | img_path = self.imglist[index] 73 | gt_path = img_path.replace('images', 'masks') 74 | gt_path = gt_path.split(".")[0] + "_Segmentation.png" 75 | img = Image.open(img_path).convert('RGB') 76 | gt = Image.open(gt_path).convert('L') 77 | data = {'image': img, 'label': gt} 78 | if self.transform: 79 | data = self.transform(data) 80 | 81 | return data 82 | 83 | def __len__(self): 84 | return len(self.imglist) 85 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | from torch.utils.data import DataLoader 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from tqdm import tqdm 7 | from itertools import cycle 8 | from data.build_dataset import build_dataset 9 | from models.build_model import build_model 10 | from utils.evaluate import evaluate 11 | from opt import args 12 | from utils.loss import BceDiceLoss 13 | import math 14 | 15 | 16 | def DeepSupSeg(pred, gt): 17 | 18 | d0, d1, d2, d3, d4 = pred ## 19 | 20 | criterion = BceDiceLoss() 21 | 22 | loss0 = criterion(d0, gt) 23 | gt = F.interpolate(gt, scale_factor=0.5, mode='bilinear', align_corners=True) ## 24 | loss1 = criterion(d1, gt) 25 | gt = F.interpolate(gt, scale_factor=0.5, mode='bilinear', align_corners=True) 26 | loss2 = criterion(d2, gt) 27 | gt = F.interpolate(gt, scale_factor=0.5, mode='bilinear', align_corners=True) 28 | loss3 = criterion(d3, gt) 29 | gt = F.interpolate(gt, scale_factor=0.5, mode='bilinear', align_corners=True) 30 | loss4 = criterion(d4, gt) 31 | 32 | return loss0 + loss1 + loss2 + loss3 + loss4 33 | 34 | 35 | def DeepSupInp(pred, gt, mask): 36 | 37 | criterion = nn.L1Loss() 38 | loss = 0 39 | for i in range(len(pred)): 40 | select_pred = torch.masked_select(pred[i], mask[i]>0.5) 41 | select_target = torch.masked_select(gt, mask[i]>0.5) 42 | loss += criterion(select_pred, select_target) 43 | gt = F.interpolate(gt, scale_factor=0.5, mode='bilinear', align_corners=True) 44 | return loss 45 | 46 | 47 | def SupInp(pred, gt, mask): 48 | 49 | criterion = nn.L1Loss() 50 | 51 | select_pred = torch.masked_select(pred, mask>0.5) 52 | select_target = torch.masked_select(gt, mask>0.5) 53 | loss = criterion(select_pred, select_target) 54 | 55 | return loss 56 | 57 | 58 | def lr_poly(base_lr, iter, max_iter, power): 59 | return base_lr * ((1-float(iter)/max_iter)**power) 60 | 61 | 62 | def adjust_lr_rate(argsimizer, iter, total_batch): 63 | lr = lr_poly(args.lr, iter, args.nEpoch*total_batch, args.power) 
64 | argsimizer.param_groups[0]['lr'] = lr 65 | return lr 66 | 67 | 68 | def train(): 69 | 70 | 71 | """load data""" 72 | train_l_data, train_u_data, valid_data, test_data = build_dataset(args) 73 | 74 | 75 | train_l_dataloader = DataLoader(train_l_data, args.batch_size, shuffle=True, num_workers=args.num_workers) 76 | #train_u_dataloader = DataLoader(train_u_data, args.batch_size, shuffle=True, num_workers=args.num_workers) 77 | valid_sign = False 78 | if valid_data is not None: 79 | valid_sign = True 80 | valid_dataloader = DataLoader(valid_data, batch_size=1, shuffle=False, num_workers=args.num_workers) 81 | val_total_batch = int(len(valid_data) / 1) 82 | 83 | test_dataloader = DataLoader(test_data, batch_size=1, shuffle=False, num_workers=args.num_workers) 84 | test_total_batch = int(len(test_data) / 1) 85 | """load model""" 86 | model = build_model(args) 87 | 88 | optim = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.mt, weight_decay=args.weight_decay) 89 | 90 | # train 91 | print('\n---------------------------------') 92 | print('Start training') 93 | print('---------------------------------\n') 94 | 95 | F1_best, F1_test_best = 0, 0 96 | 97 | for epoch in range(args.nEpoch): 98 | model.train() 99 | 100 | print("Epoch: {}".format(epoch)) 101 | total_batch = math.ceil(len(train_l_data) / args.batch_size) 102 | bar = tqdm(enumerate(train_l_dataloader), total=total_batch) 103 | for batch_id, data_l in bar: 104 | #data_l, data_u = next(loader) 105 | 106 | #total_batch = len(train_u_dataloader) 107 | #total_batch = len(train_l_dataloader) 108 | itr = total_batch * epoch + batch_id 109 | 110 | img_l, gt = data_l['image'], data_l['label'] 111 | 112 | if args.GPUs: 113 | img_l = img_l.cuda() 114 | gt = gt.cuda() 115 | 116 | 117 | optim.zero_grad() 118 | 119 | pred_l = model(img_l) 120 | mask_l = pred_l[:5] 121 | inp_l = pred_l[5:] 122 | loss_l_seg = DeepSupSeg(mask_l, gt) 123 | 124 | loss_l = loss_l_seg 125 | 126 | 127 | loss = loss_l 128 | loss.backward() 129 | 130 | optim.step() 131 | 132 | 133 | adjust_lr_rate(optim, itr, total_batch) 134 | 135 | if valid_sign == True: 136 | recall, specificity, precision, F1, F2, \ 137 | ACC_overall, IoU_poly, IoU_bg, IoU_mean = evaluate(model, valid_dataloader, val_total_batch, args) 138 | 139 | print("Valid Result:") 140 | print('recall: %.4f, specificity: %.4f, precision: %.4f, F1: %.4f, F2: %.4f, ACC_overall: %.4f, IoU_poly: %.4f, IoU_bg: %.4f, IoU_mean: %.4f' \ 141 | % (recall, specificity, precision, F1, F2, ACC_overall, IoU_poly, IoU_bg, IoU_mean)) 142 | 143 | if (F1 > F1_best): 144 | F1_best = F1 145 | torch.save(model.state_dict(), args.root + "checkpoint/exp" + str(args.expID) + "/ck_%.4f.pth" % F1) 146 | 147 | else: 148 | recall_test, specificity_test, precision_test, F1_test, F2_test, \ 149 | ACC_overall_test, IoU_poly_test, IoU_bg_test, IoU_mean_test = evaluate(model, test_dataloader, test_total_batch, args) 150 | print('recall: %.4f, specificity: %.4f, precision: %.4f, F1: %.4f, F2: %.4f, ACC_overall: %.4f, IoU_poly: %.4f, IoU_bg: %.4f, IoU_mean: %.4f' \ 151 | % (recall_test, specificity_test, precision_test, F1_test, F2_test, ACC_overall_test, IoU_poly_test, IoU_bg_test, IoU_mean_test)) 152 | 153 | if (F1_test > F1_test_best): 154 | F1_test_best = F1_test 155 | torch.save(model.state_dict(), args.root + "checkpoint/exp" + str(args.expID) + "/ck_%.4f.pth" % F1_test) 156 | 157 | 158 | 159 | 160 | 161 | def train_semi(): 162 | 163 | 164 | """load data""" 165 | train_l_data, train_u_data, valid_data, test_data = 
build_dataset(args) 166 | 167 | 168 | train_l_dataloader = DataLoader(train_l_data, args.batch_size, shuffle=True, num_workers=args.num_workers) 169 | train_u_dataloader = DataLoader(train_u_data, args.batch_size, shuffle=True, num_workers=args.num_workers) 170 | valid_sign = False 171 | if valid_data is not None: 172 | valid_sign = True 173 | valid_dataloader = DataLoader(valid_data, batch_size=1, shuffle=False, num_workers=args.num_workers) 174 | val_total_batch = int(len(valid_data) / 1) 175 | 176 | test_dataloader = DataLoader(test_data, batch_size=1, shuffle=False, num_workers=args.num_workers) 177 | test_total_batch = int(len(test_data) / 1) 178 | """load model""" 179 | model = build_model(args) 180 | 181 | optim = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.mt, weight_decay=args.weight_decay) 182 | 183 | # train 184 | print('\n---------------------------------') 185 | print('Start training') 186 | print('---------------------------------\n') 187 | 188 | F1_best, F1_test_best = 0, 0 189 | 190 | for epoch in range(args.nEpoch): 191 | model.train() 192 | 193 | print("Epoch: {}".format(epoch)) 194 | 195 | loader = iter(zip(cycle(train_l_dataloader), train_u_dataloader)) 196 | #loader = iter(zip(train_l_dataloader, cycle(train_u_dataloader))) 197 | bar = tqdm(range(len(train_u_dataloader))) 198 | #bar = tqdm(range(len(train_l_dataloader))) 199 | for batch_id in bar: 200 | data_l, data_u = next(loader) 201 | 202 | total_batch = len(train_u_dataloader) 203 | #total_batch = len(train_l_dataloader) 204 | itr = total_batch * epoch + batch_id 205 | 206 | img_l, gt = data_l['image'], data_l['label'] 207 | img_u = data_u 208 | 209 | if args.GPUs: 210 | img_l = img_l.cuda() 211 | gt = gt.cuda() 212 | img_u = img_u.cuda() 213 | 214 | optim.zero_grad() 215 | 216 | pred_l = model(img_l) 217 | mask_l = pred_l[:5] 218 | inp_l = pred_l[5:] 219 | loss_l_seg = DeepSupSeg(mask_l, gt) 220 | 221 | loss_l = loss_l_seg 222 | 223 | pred_u = model(img_u) 224 | mask_u = pred_u[:5] 225 | inp_u = pred_u[5:] 226 | loss_u = DeepSupInp(inp_u, img_u.detach(), [m.detach() for m in mask_u]) 227 | 228 | 229 | loss = 2 * loss_l + loss_u 230 | loss.backward() 231 | 232 | optim.step() 233 | 234 | 235 | adjust_lr_rate(optim, itr, total_batch) 236 | 237 | if valid_sign == True: 238 | recall, specificity, precision, F1, F2, \ 239 | ACC_overall, IoU_poly, IoU_bg, IoU_mean = evaluate(model, valid_dataloader, val_total_batch, args) 240 | 241 | print("Valid Result:") 242 | print('recall: %.4f, specificity: %.4f, precision: %.4f, F1: %.4f, F2: %.4f, ACC_overall: %.4f, IoU_poly: %.4f, IoU_bg: %.4f, IoU_mean: %.4f' \ 243 | % (recall, specificity, precision, F1, F2, ACC_overall, IoU_poly, IoU_bg, IoU_mean)) 244 | 245 | if (F1 > F1_best): 246 | F1_best = F1 247 | torch.save(model.state_dict(), args.root + "checkpoint/exp" + str(args.expID) + "/ck_%.4f.pth" % F1) 248 | 249 | else: 250 | recall_test, specificity_test, precision_test, F1_test, F2_test, \ 251 | ACC_overall_test, IoU_poly_test, IoU_bg_test, IoU_mean_test = evaluate(model, test_dataloader, test_total_batch, args) 252 | print('recall: %.4f, specificity: %.4f, precision: %.4f, F1: %.4f, F2: %.4f, ACC_overall: %.4f, IoU_poly: %.4f, IoU_bg: %.4f, IoU_mean: %.4f' \ 253 | % (recall_test, specificity_test, precision_test, F1_test, F2_test, ACC_overall_test, IoU_poly_test, IoU_bg_test, IoU_mean_test)) 254 | 255 | if (F1_test > F1_test_best): 256 | F1_test_best = F1_test 257 | torch.save(model.state_dict(), args.root + "checkpoint/exp" + str(args.expID) + 
"/ck_%.4f.pth" % F1_test) 258 | 259 | 260 | 261 | def test(): 262 | 263 | print('loading data......') 264 | test_data = build_dataset(args) 265 | test_dataloader = DataLoader(test_data, batch_size=1, shuffle=False, num_workers=args.num_workers) 266 | total_batch = int(len(test_data) / 1) 267 | model = build_model(args) 268 | 269 | model.eval() 270 | 271 | recall, specificity, precision, F1, F2, \ 272 | ACC_overall, IoU_poly, IoU_bg, IoU_mean = evaluate(model, test_dataloader, total_batch, args) 273 | 274 | print("Test Result:") 275 | print('recall: %.4f, specificity: %.4f, precision: %.4f, F1: %.4f, F2: %.4f, ACC_overall: %.4f, IoU_poly: %.4f, IoU_bg: %.4f, IoU_mean: %.4f' \ 276 | % (recall, specificity, precision, F1, F2, ACC_overall, IoU_poly, IoU_bg, IoU_mean)) 277 | 278 | return recall, specificity, precision, F1, F2, ACC_overall, IoU_poly, IoU_bg, IoU_mean 279 | 280 | 281 | if __name__ == '__main__': 282 | 283 | os.environ['CUDA_VISIBLE_DEVICES'] = args.GPUs 284 | if args.manner == 'full': 285 | print('---{}-Seg Train---'.format(args.dataset)) 286 | train() 287 | elif args.manner =='semi': 288 | print('---{}-seg Semi-Train--'.format(args.dataset)) 289 | train_semi() 290 | elif args.manner == 'test': 291 | print('---{}-Seg Test---'.format(args.dataset)) 292 | test() 293 | print('Done') 294 | 295 | 296 | 297 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from . mymodel import MyModel 2 | -------------------------------------------------------------------------------- /models/build_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import models 3 | import os 4 | 5 | 6 | def build_model(args): 7 | model = getattr(models, args.model)(args.nclasses) 8 | # model = nn.DataParallel(model) 9 | if args.GPUs: 10 | model.cuda() 11 | torch.backends.cudnn.benchmark = True 12 | 13 | if args.load_ckpt is not None: 14 | 15 | model_dict = model.state_dict() 16 | load_ckpt_path = os.path.join(args.root, "checkpoint/exp" + str(args.expID), args.load_ckpt + '.pth') 17 | assert os.path.isfile(load_ckpt_path), 'No checkpoint found.' 18 | print('Loading checkpoint......') 19 | checkpoint = torch.load(load_ckpt_path) 20 | new_dict = {k: v for k, v in checkpoint.items() if k in model_dict.keys()} 21 | model_dict.update(new_dict) 22 | model.load_state_dict(model_dict) 23 | print('Done') 24 | 25 | return model 26 | -------------------------------------------------------------------------------- /models/module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class GFF(nn.Module): 6 | 7 | def __init__(self, in_channels, kernel_size=3, padding=1): 8 | super().__init__() 9 | 10 | self.reset_gate = nn.Conv2d(in_channels * 2, in_channels, kernel_size=kernel_size, padding=padding) 11 | self.select_gate = nn.Conv2d(in_channels * 2, in_channels, kernel_size=kernel_size, padding=padding) 12 | self.outconv = nn.Conv2d(in_channels * 2, in_channels, kernel_size=kernel_size, padding=padding) 13 | self.activate = nn.ReLU(inplace=True) 14 | 15 | nn.init.orthogonal_(self.reset_gate.weight) 16 | nn.init.orthogonal_(self.select_gate.weight) 17 | nn.init.orthogonal_(self.outconv.weight) 18 | nn.init.constant_(self.reset_gate.bias, 0.) 19 | nn.init.constant_(self.select_gate.bias, 0.) 20 | nn.init.constant_(self.outconv.bias, 0.) 
21 | 22 | def forward(self, seg_feat, inp_feat): 23 | 24 | concat_inputs = torch.cat([seg_feat, inp_feat], dim=1) 25 | reset = torch.sigmoid(self.reset_gate(concat_inputs)) 26 | select = torch.sigmoid(self.select_gate(concat_inputs)) 27 | out_inputs = self.activate(self.outconv(torch.cat([seg_feat, inp_feat * reset], dim=1))) 28 | out_final = seg_feat * (1 - select) + out_inputs * select 29 | 30 | return out_final 31 | 32 | 33 | 34 | 35 | -------------------------------------------------------------------------------- /models/mymodel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision.models as models 4 | import torch.nn.functional as F 5 | from .module import GFF 6 | 7 | 8 | def cat(x1, x2, x3=None, dim=1): 9 | if x3 == None: 10 | diffY = torch.tensor([x2.size()[2] - x1.size()[2]]) 11 | diffX = torch.tensor([x2.size()[3] - x1.size()[3]]) 12 | 13 | x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, 14 | diffY // 2, diffY - diffY // 2]) 15 | 16 | x = torch.cat([x1, x2], dim) 17 | return x 18 | else: 19 | diffY = torch.tensor([x2.size()[2] - x1.size()[2]]) 20 | diffX = torch.tensor([x2.size()[3] - x1.size()[3]]) 21 | 22 | x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, 23 | diffY // 2, diffY - diffY // 2]) 24 | 25 | x = torch.cat([x1, x2], dim) 26 | diffY = torch.tensor([x.size()[2] - x3.size()[2]]) 27 | diffX = torch.tensor([x.size()[3] - x3.size()[3]]) 28 | x3 = F.pad(x3, [diffX // 2, diffX - diffX // 2, 29 | diffY // 2, diffY - diffY // 2]) 30 | x = torch.cat([x, x3], dim=1) 31 | return x 32 | 33 | 34 | class ConvBlock(nn.Module): 35 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding): 36 | super(ConvBlock, self).__init__() 37 | self.conv = nn.Conv2d(in_channels, out_channels, 38 | kernel_size=kernel_size, 39 | stride=stride, 40 | padding=padding) 41 | self.bn = nn.BatchNorm2d(out_channels) 42 | self.relu = nn.ReLU(inplace=True) 43 | 44 | def forward(self, x): 45 | x = self.conv(x) 46 | x = self.bn(x) 47 | x = self.relu(x) 48 | return x 49 | 50 | 51 | class DecoderBlock(nn.Module): 52 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, transpose=False): 53 | super(DecoderBlock, self).__init__() 54 | 55 | self.conv1 = ConvBlock(in_channels, in_channels // 4, kernel_size=kernel_size, 56 | stride=stride, padding=padding) 57 | 58 | self.conv2 = ConvBlock(in_channels // 4, out_channels, kernel_size=kernel_size, 59 | stride=stride, padding=padding) 60 | 61 | if transpose: 62 | self.upsample = nn.Sequential( 63 | nn.ConvTranspose2d(in_channels // 4, 64 | in_channels // 4, 65 | kernel_size=3, 66 | stride=2, 67 | padding=1, 68 | output_padding=1, 69 | bias=False), 70 | nn.BatchNorm2d(in_channels // 4), 71 | nn.ReLU(inplace=True) 72 | ) 73 | else: 74 | self.upsample = nn.Upsample(scale_factor=2, mode='bilinear') 75 | 76 | def forward(self, x): 77 | x = self.conv1(x) 78 | x = self.conv2(x) 79 | x = self.upsample(x) 80 | 81 | return x 82 | 83 | 84 | class SideoutBlock(nn.Module): 85 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1): 86 | super(SideoutBlock, self).__init__() 87 | 88 | self.conv1 = ConvBlock(in_channels, in_channels // 4, kernel_size=kernel_size, 89 | stride=stride, padding=padding) 90 | 91 | self.dropout = nn.Dropout2d(0.1) 92 | 93 | self.conv2 = nn.Conv2d(in_channels // 4, out_channels, 1) 94 | 95 | def forward(self, x): 96 | x = self.conv1(x) 97 | x = self.dropout(x) 98 | x = self.conv2(x) 99 | 100 
| return x 101 | 102 | 103 | class Encoder(nn.Module): 104 | def __init__(self, in_channels): 105 | super(Encoder, self).__init__() 106 | resnet = models.resnet34(pretrained=False) 107 | resnet.load_state_dict(torch.load("../../backbone/resnet34.pth"))  # expects a locally downloaded ResNet-34 checkpoint; adjust this path to your setup 108 | 109 | if in_channels == 3: 110 | self.encoder1_conv = resnet.conv1 111 | else: 112 | self.encoder1_conv = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False) 113 | 114 | self.encoder1_bn = resnet.bn1 115 | self.encoder1_relu = resnet.relu 116 | self.maxpool = resnet.maxpool 117 | self.encoder2 = resnet.layer1 118 | self.encoder3 = resnet.layer2 119 | self.encoder4 = resnet.layer3 120 | self.encoder5 = resnet.layer4 121 | 122 | def forward(self, x): 123 | e1 = self.encoder1_conv(x) 124 | e1 = self.encoder1_bn(e1) 125 | e1 = self.encoder1_relu(e1) 126 | e1_maxpool = self.maxpool(e1) 127 | 128 | e2 = self.encoder2(e1_maxpool) 129 | e3 = self.encoder3(e2) 130 | e4 = self.encoder4(e3) 131 | e5 = self.encoder5(e4) 132 | return e1, e2, e3, e4, e5 133 | 134 | 135 | class MyModel(nn.Module): 136 | def __init__(self, 137 | num_classes=1, 138 | in_channels=3 139 | ): 140 | super().__init__() 141 | 142 | self.encoder = Encoder(in_channels=in_channels) 143 | 144 | # seg-Decoder 145 | self.segDecoder5 = DecoderBlock(512, 512) 146 | 147 | self.segDecoder4 = DecoderBlock(512 + 256, 256) 148 | 149 | self.segDecoder3 = DecoderBlock(256 + 128, 128) 150 | 151 | self.segDecoder2 = DecoderBlock(128 + 64, 64) 152 | 153 | self.segDecoder1 = DecoderBlock(64 + 64, 64) 154 | 155 | self.segSideout5 = SideoutBlock(512, 1) 156 | self.segSideout4 = SideoutBlock(256, 1) 157 | self.segSideout3 = SideoutBlock(128, 1) 158 | self.segSideout2 = SideoutBlock(64, 1) 159 | 160 | self.segconv = nn.Sequential(ConvBlock(64, 32, kernel_size=3, stride=1, padding=1), 161 | nn.Dropout2d(0.1), 162 | nn.Conv2d(32, num_classes, 1)) 163 | 164 | # inpaint-Decoder 165 | 166 | self.inpDecoder5 = DecoderBlock(512, 512, transpose=True) 167 | self.inpDecoder4 = DecoderBlock(512 + 256, 256, transpose=True) 168 | self.inpDecoder3 = DecoderBlock(256 + 128, 128, transpose=True) 169 | self.inpDecoder2 = DecoderBlock(128 + 64, 64, transpose=True) 170 | self.inpDecoder1 = DecoderBlock(64 + 64, 64, transpose=True) 171 | 172 | self.inpSideout5 = SideoutBlock(512, 3) 173 | self.inpSideout4 = SideoutBlock(256, 3) 174 | self.inpSideout3 = SideoutBlock(128, 3) 175 | self.inpSideout2 = SideoutBlock(64, 3) 176 | 177 | self.gate1 = GFF(64) 178 | self.gate2 = GFF(64) 179 | self.gate3 = GFF(128) 180 | self.gate4 = GFF(256) 181 | self.gate5 = GFF(512) 182 | 183 | self.inpconv = nn.Sequential(ConvBlock(64, 32, kernel_size=3, stride=1, padding=1), 184 | nn.Conv2d(32, in_channels, 1)) 185 | 186 | def forward(self, x): 187 | ori = x 188 | bs, C, H, W = x.shape[0], x.shape[1], x.shape[2], x.shape[3] 189 | """Seg-branch""" 190 | e1, e2, e3, e4, e5 = self.encoder(x) 191 | d5 = self.segDecoder5(e5) 192 | d4 = self.segDecoder4(cat(d5, e4)) 193 | d3 = self.segDecoder3(cat(d4, e3)) 194 | d2 = self.segDecoder2(cat(d3, e2)) 195 | d1 = self.segDecoder1(cat(d2, e1)) 196 | 197 | mask = self.segconv(d1) 198 | mask = torch.sigmoid(mask) 199 | #sidecopy = sidemask.detach().clone() 200 | mask_binary = (mask > 0.5).float() 201 | mask_rbinary = (mask < 0.5).float() 202 | cut_ori = ori * mask_rbinary 203 | inpe1, inpe2, inpe3, inpe4, inpe5 = self.encoder(cut_ori) 204 | # e1-e5 from the first encoder pass above are reused here; re-encoding x is unnecessary 205 | 206 | ge1 = self.gate1(e1, inpe1) 207 | ge2 = self.gate2(e2, inpe2) 208 | ge3 = self.gate3(e3, 
inpe3) 209 | ge4 = self.gate4(e4, inpe4) 210 | ge5 = self.gate5(e5, inpe5) 211 | 212 | d5 = self.segDecoder5(ge5) 213 | sidemask5 = self.segSideout5(d5) 214 | sidemask5 = torch.sigmoid(sidemask5) 215 | 216 | d4 = self.segDecoder4(cat(d5, ge4)) 217 | sidemask4 = self.segSideout4(d4) 218 | sidemask4 = torch.sigmoid(sidemask4) 219 | 220 | d3 = self.segDecoder3(cat(d4, ge3)) 221 | sidemask3 = self.segSideout3(d3) 222 | sidemask3 = torch.sigmoid(sidemask3) 223 | 224 | d2 = self.segDecoder2(cat(d3, ge2)) 225 | sidemask2 = self.segSideout2(d2) 226 | sidemask2 = torch.sigmoid(sidemask2) 227 | 228 | d1 = self.segDecoder1(cat(d2, ge1)) 229 | 230 | mask = self.segconv(d1) 231 | #maskh, maskw = mask.shape[2], mask.shape[3] 232 | #if maskh != H or maskw != W: 233 | # mask = F.interpolate(mask, size=(H, W), mode='bilinear') 234 | mask = torch.sigmoid(mask) 235 | 236 | 237 | inpd5 = self.inpDecoder5(ge5) 238 | inpimg5 = self.inpSideout5(inpd5) 239 | inpd4 = self.inpDecoder4(cat(inpd5, ge4)) 240 | inpimg4 = self.inpSideout4(inpd4) 241 | inpd3 = self.inpDecoder3(cat(inpd4, ge3)) 242 | inpimg3 = self.inpSideout3(inpd3) 243 | inpd2 = self.inpDecoder2(cat(inpd3, ge2)) 244 | inpimg2 = self.inpSideout2(inpd2) 245 | inpd1 = self.inpDecoder1(cat(inpd2, ge1)) 246 | inpimg = self.inpconv(inpd1) 247 | mask_binary = (mask > 0.5).float() 248 | mask_rbinary = (mask < 0.5).float() 249 | inpimg = inpimg * mask_binary + ori * mask_rbinary 250 | 251 | return mask, sidemask2, sidemask3, sidemask4, sidemask5, inpimg, inpimg2, inpimg3, inpimg4, inpimg5 252 | 253 | -------------------------------------------------------------------------------- /opt.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | 5 | parse = argparse.ArgumentParser(description='PyTorch Semi-Medical-Seg Implementation') 6 | 7 | "-------------------GPU option----------------------------" 8 | parse.add_argument('--GPUs', type=str, default='0') 9 | 10 | "-------------------data option--------------------------" 11 | parse.add_argument('--root', type=str, default='/data/ruifei/SemiMedSeg/') 12 | parse.add_argument('--dataset', type=str, default='polyp', choices=['polyp', 'skin', 'optic']) 13 | parse.add_argument('--ratio', type=int, default=10) 14 | parse.add_argument('--polyp', type=str, default='data_polyp') 15 | parse.add_argument('--skin', type=str, default='data_skin') 16 | parse.add_argument('--optic', type=str, default='data_optic') 17 | 18 | 19 | 20 | "-------------------training option-----------------------" 21 | parse.add_argument('--manner', type=str, default='full', choices=['full', 'semi', 'test']) 22 | parse.add_argument('--mode', type=str, default='train') 23 | parse.add_argument('--nEpoch', type=int, default=80) 24 | parse.add_argument('--batch_size', type=int, default=4)  # DataLoader requires an integer batch size 25 | parse.add_argument('--num_workers', type=int, default=2) 26 | parse.add_argument('--load_ckpt', type=str, default=None) 27 | parse.add_argument('--model', type=str, default='MyModel') 28 | parse.add_argument('--expID', type=int, default=0) 29 | 30 | "-------------------optimizer option-----------------------" 31 | parse.add_argument('--lr', type=float, default=1e-3) 32 | parse.add_argument('--power', type=float, default=0.9) 33 | parse.add_argument('--betas', default=(0.9, 0.999)) 34 | parse.add_argument('--weight_decay', type=float, default=1e-5) 35 | parse.add_argument('--eps', type=float, default=1e-8) 36 | parse.add_argument('--mt', type=float, default=0.9) 37 | 38 
parse.add_argument('--nclasses', type=int, default=1) 39 | parse.add_argument('--save_pred', default=False, action='store_true') 40 | 41 | args = parse.parse_args() 42 | -------------------------------------------------------------------------------- /utils/evaluate.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from .save_img import save_binary_img, save_img 4 | from tqdm import tqdm 5 | 6 | 7 | def evaluate(model, dataloader, total_batch, args): 8 | 9 | model.eval() 10 | 11 | recall = 0 12 | specificity = 0 13 | precision = 0 14 | F1 = 0 15 | F2 = 0 16 | ACC_overall = 0 17 | IoU_poly = 0 18 | IoU_bg = 0 19 | IoU_mean = 0 20 | 21 | with torch.no_grad(): 22 | bar = tqdm(enumerate(dataloader), total=total_batch) 23 | for i, data in bar: 24 | img, gt = data['image'], data['label'] 25 | inp = img.clone().detach() 26 | target = gt.clone().detach() 27 | if args.GPUs: 28 | inp = inp.cuda() 29 | target = target.cuda() 30 | 31 | output = model(inp) 32 | _recall, _specificity, _precision, _F1, _F2, \ 33 | _ACC_overall, _IoU_poly, _IoU_bg, _IoU_mean = evaluate_batch(output, target, args, i, inp) 34 | recall += _recall.item() 35 | specificity += _specificity.item() 36 | precision += _precision.item() 37 | F1 += _F1.item() 38 | F2 += _F2.item() 39 | ACC_overall += _ACC_overall.item() 40 | IoU_poly += _IoU_poly.item() 41 | IoU_bg += _IoU_bg.item() 42 | IoU_mean += _IoU_mean.item() 43 | recall /= total_batch 44 | specificity /= total_batch 45 | precision /= total_batch 46 | F1 /= total_batch 47 | F2 /= total_batch 48 | ACC_overall /= total_batch 49 | IoU_poly /= total_batch 50 | IoU_bg /= total_batch 51 | IoU_mean /= total_batch 52 | return recall, specificity, precision, F1, F2, ACC_overall, IoU_poly, IoU_bg, IoU_mean 53 | 54 | 55 | def evaluate_batch(output, gt, args, iid=None, img=None): 56 | pred = output[0] 57 | 58 | pred_binary = (pred >= 0.5).float() 59 | 60 | pred_binary_inverse = (pred_binary == 0).float() 61 | 62 | gt_binary = (gt >= 0.5).float() 63 | 64 | gt_binary_inverse = (gt_binary == 0).float() 65 | 66 | if args.save_pred == True: 67 | inpimg = output[-5] 68 | #cut_ori = output[-1] 69 | save_img(img, iid, 'img') 70 | save_img(inpimg, iid, 'inp') 71 | #save_img(cut_ori, iid, 'cut') 72 | save_binary_img(pred_binary, iid, 'pred') 73 | save_binary_img(gt_binary, iid, 'gt') 74 | 75 | 76 | TP = pred_binary.mul(gt_binary).sum() 77 | FP = pred_binary.mul(gt_binary_inverse).sum() 78 | TN = pred_binary_inverse.mul(gt_binary_inverse).sum() 79 | FN = pred_binary_inverse.mul(gt_binary).sum() 80 | 81 | if TP.item() == 0: 82 | TP = torch.Tensor([1]).cuda() 83 | 84 | # recall 85 | Recall = TP / (TP + FN) 86 | 87 | # Specificity or true negative rate 88 | Specificity = TN / (TN + FP) 89 | 90 | # Precision or positive predictive value 91 | Precision = TP / (TP + FP) 92 | 93 | # F1 score = Dice 94 | F1 = 2 * Precision * Recall / (Precision + Recall) 95 | 96 | # F2 score 97 | F2 = 5 * Precision * Recall / (4 * Precision + Recall) 98 | 99 | # Overall accuracy 100 | ACC_overall = (TP + TN) / (TP + FP + FN + TN) 101 | 102 | # IoU for poly 103 | IoU_poly = TP / (TP + FP + FN) 104 | 105 | # IoU for background 106 | IoU_bg = TN / (TN + FP + FN) 107 | 108 | # mean IoU 109 | IoU_mean = (IoU_poly + IoU_bg) / 2.0 110 | 111 | return Recall, Specificity, Precision, F1, F2, ACC_overall, IoU_poly, IoU_bg, IoU_mean 112 | -------------------------------------------------------------------------------- /utils/loss.py: 
-------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | # BCE loss 5 | class BCELoss(nn.Module): 6 | def __init__(self, weight=None, size_average=True): 7 | super(BCELoss, self).__init__() 8 | self.bceloss = nn.BCELoss(weight, size_average) 9 | 10 | def forward(self, pred, target): 11 | sizep = pred.size(0) 12 | sizet = target.size(0) 13 | pred_flat = pred.view(sizep, -1) 14 | target_flat = target.view(sizet, -1) 15 | 16 | loss = self.bceloss(pred_flat, target_flat) 17 | 18 | return loss 19 | 20 | #Dice loss 21 | 22 | class DiceLoss(nn.Module): 23 | def __init__(self,size_average=True): 24 | super(DiceLoss, self).__init__() 25 | 26 | def forward(self, pred, target): 27 | smooth = 1 28 | 29 | size = target.size(0) 30 | 31 | pred_flat = pred.view(size, -1) 32 | target_flat = target.view(size, -1) 33 | 34 | intersection = pred_flat * target_flat 35 | dice_score = (2 * intersection.sum(1) + smooth)/ (pred_flat.sum(1) + target_flat.sum(1) + smooth) 36 | dice_loss = 1 - dice_score.sum()/size 37 | 38 | return dice_loss 39 | 40 | # BCE + Dice loss 41 | class BceDiceLoss(nn.Module): 42 | def __init__(self, weight=None, size_average=True): 43 | super(BceDiceLoss, self).__init__() 44 | self.bce = BCELoss(weight, size_average) 45 | self.dice = DiceLoss(size_average) 46 | 47 | def forward(self, pred, target): 48 | bceloss = self.bce(pred,target) 49 | diceloss = self.dice(pred,target) 50 | 51 | #w = 0.6 52 | loss = diceloss + bceloss 53 | 54 | return loss 55 | 56 | -------------------------------------------------------------------------------- /utils/mytransforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.transforms.functional as F 3 | import scipy.ndimage 4 | import random 5 | from PIL import Image 6 | import numpy as np 7 | import cv2 8 | import numbers 9 | 10 | 11 | class ToTensor(object): 12 | 13 | def __call__(self, data): 14 | image, label = data['image'], data['label'] 15 | return {'image': F.to_tensor(image), 'label': F.to_tensor(label)} 16 | 17 | 18 | class Resize(object): 19 | 20 | def __init__(self, size): 21 | self.size = size 22 | 23 | def __call__(self, data): 24 | image, label = data['image'], data['label'] 25 | 26 | return {'image': F.resize(image, self.size), 'label': F.resize(label, self.size)} 27 | 28 | 29 | class RandomHorizontalFlip(object): 30 | def __init__(self, p=0.5): 31 | self.p = p 32 | 33 | def __call__(self, data): 34 | image, label = data['image'], data['label'] 35 | 36 | if random.random() < self.p: 37 | return {'image': F.hflip(image), 'label': F.hflip(label)} 38 | 39 | return {'image': image, 'label': label} 40 | 41 | 42 | class RandomVerticalFlip(object): 43 | def __init__(self, p=0.5): 44 | self.p = p 45 | 46 | def __call__(self, data): 47 | image, label = data['image'], data['label'] 48 | 49 | if random.random() < self.p: 50 | return {'image': F.vflip(image), 'label': F.vflip(label)} 51 | 52 | return {'image': image, 'label': label} 53 | 54 | 55 | class RandomRotation(object): 56 | 57 | def __init__(self, degrees, resample=False, expand=False, center=None): 58 | if isinstance(degrees, numbers.Number): 59 | if degrees < 0: 60 | raise ValueError("If degrees is a single number, it must be positive.") 61 | self.degrees = (-degrees, degrees) 62 | else: 63 | if len(degrees) != 2: 64 | raise ValueError("If degrees is a sequence, it must be of len 2.") 65 | self.degrees = degrees 66 | self.resample = resample 67 | self.expand = 
expand 68 | self.center = center 69 | 70 | @staticmethod 71 | def get_params(degrees): 72 | """Get parameters for ``rotate`` for a random rotation. 73 | 74 | Returns: 75 | sequence: params to be passed to ``rotate`` for random rotation. 76 | """ 77 | angle = random.uniform(degrees[0], degrees[1]) 78 | 79 | return angle 80 | 81 | def __call__(self, data): 82 | 83 | """ 84 | img (PIL Image): Image to be rotated. 85 | 86 | Returns: 87 | PIL Image: Rotated image. 88 | """ 89 | image, label = data['image'], data['label'] 90 | 91 | if random.random() < 0.5: 92 | angle = self.get_params(self.degrees) 93 | return {'image': F.rotate(image, angle, self.resample, self.expand, self.center), 94 | 'label': F.rotate(label, angle, self.resample, self.expand, self.center)} 95 | 96 | return {'image': image, 'label': label} 97 | 98 | 99 | class RandomZoom(object): 100 | def __init__(self, zoom=(0.8, 1.2)): 101 | self.min, self.max = zoom[0], zoom[1] 102 | 103 | def __call__(self, data): 104 | image, label = data['image'], data['label'] 105 | 106 | if random.random() < 0.5: 107 | image = np.array(image) 108 | label = np.array(label) 109 | 110 | zoom = random.uniform(self.min, self.max) 111 | zoom_image = clipped_zoom(image, zoom) 112 | zoom_label = clipped_zoom(label, zoom) 113 | 114 | zoom_image = Image.fromarray(zoom_image.astype('uint8'), 'RGB') 115 | zoom_label = Image.fromarray(zoom_label.astype('uint8'), 'L') 116 | return {'image': zoom_image, 'label': zoom_label} 117 | 118 | return {'image': image, 'label': label} 119 | 120 | 121 | def clipped_zoom(img, zoom_factor, **kwargs): 122 | h, w = img.shape[:2] 123 | 124 | # For multichannel images we don't want to apply the zoom factor to the RGB 125 | # dimension, so instead we create a tuple of zoom factors, one per array 126 | # dimension, with 1's for any trailing dimensions after the width and height. 
127 | zoom_tuple = (zoom_factor,) * 2 + (1,) * (img.ndim - 2) 128 | 129 | # Zooming out 130 | if zoom_factor < 1: 131 | 132 | # Bounding box of the zoomed-out image within the output array 133 | zh = int(np.round(h * zoom_factor)) 134 | zw = int(np.round(w * zoom_factor)) 135 | top = (h - zh) // 2 136 | left = (w - zw) // 2 137 | 138 | # Zero-padding 139 | out = np.zeros_like(img) 140 | out[top:top + zh, left:left + zw] = scipy.ndimage.zoom(img, zoom_tuple, **kwargs) 141 | 142 | # Zooming in 143 | elif zoom_factor > 1: 144 | 145 | # Bounding box of the zoomed-in region within the input array 146 | zh = int(np.round(h / zoom_factor)) 147 | zw = int(np.round(w / zoom_factor)) 148 | top = (h - zh) // 2 149 | left = (w - zw) // 2 150 | 151 | zoom_in = scipy.ndimage.zoom(img[top:top + zh, left:left + zw], zoom_tuple, **kwargs) 152 | 153 | # `zoom_in` might still be slightly different with `img` due to rounding, so 154 | # trim off any extra pixels at the edges or zero-padding 155 | 156 | if zoom_in.shape[0] >= h: 157 | zoom_top = (zoom_in.shape[0] - h) // 2 158 | sh = h 159 | out_top = 0 160 | oh = h 161 | else: 162 | zoom_top = 0 163 | sh = zoom_in.shape[0] 164 | out_top = (h - zoom_in.shape[0]) // 2 165 | oh = zoom_in.shape[0] 166 | if zoom_in.shape[1] >= w: 167 | zoom_left = (zoom_in.shape[1] - w) // 2 168 | sw = w 169 | out_left = 0 170 | ow = w 171 | else: 172 | zoom_left = 0 173 | sw = zoom_in.shape[1] 174 | out_left = (w - zoom_in.shape[1]) // 2 175 | ow = zoom_in.shape[1] 176 | 177 | out = np.zeros_like(img) 178 | out[out_top:out_top + oh, out_left:out_left + ow] = zoom_in[zoom_top:zoom_top + sh, zoom_left:zoom_left + sw] 179 | 180 | # If zoom_factor == 1, just return the input array 181 | else: 182 | out = img 183 | return out 184 | 185 | 186 | class Translation(object): 187 | def __init__(self, translation): 188 | self.translation = translation 189 | 190 | def __call__(self, data): 191 | image, label = data['image'], data['label'] 192 | 193 | if random.random() < 0.5: 194 | image = np.array(image) 195 | label = np.array(label) 196 | rows, cols, ch = image.shape 197 | 198 | translation = random.uniform(0, self.translation) 199 | tr_x = translation / 2 200 | tr_y = translation / 2 201 | Trans_M = np.float32([[1, 0, tr_x], [0, 1, tr_y]]) 202 | 203 | translate_image = cv2.warpAffine(image, Trans_M, (cols, rows)) 204 | translate_label = cv2.warpAffine(label, Trans_M, (cols, rows)) 205 | 206 | translate_image = Image.fromarray(translate_image.astype('uint8'), 'RGB') 207 | translate_label = Image.fromarray(translate_label.astype('uint8'), 'L') 208 | 209 | return {'image': translate_image, 'label': translate_label} 210 | 211 | return {'image': image, 'label': label} 212 | 213 | 214 | class RandomCrop(object): 215 | def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode='constant'): 216 | if isinstance(size, numbers.Number): 217 | self.size = (int(size), int(size)) 218 | else: 219 | self.size = size 220 | self.padding = padding 221 | self.pad_if_needed = pad_if_needed 222 | self.fill = fill 223 | self.padding_mode = padding_mode 224 | 225 | @staticmethod 226 | def get_params(img, output_size): 227 | """Get parameters for ``crop`` for a random crop. 228 | Args: 229 | img (PIL Image): Image to be cropped. 230 | output_size (tuple): Expected output size of the crop. 231 | Returns: 232 | tuple: params (i, j, h, w) to be passed to ``crop`` for random crop. 
233 | """ 234 | w, h = img.size 235 | th, tw = output_size 236 | if w == tw and h == th: 237 | return 0, 0, h, w 238 | 239 | # i = torch.randint(0, h - th + 1, size=(1, )).item() 240 | # j = torch.randint(0, w - tw + 1, size=(1, )).item() 241 | i = random.randint(0, h - th) 242 | j = random.randint(0, w - tw) 243 | return i, j, th, tw 244 | 245 | def __call__(self, data): 246 | """ 247 | Args: 248 | img (PIL Image): Image to be cropped. 249 | Returns: 250 | PIL Image: Cropped image. 251 | """ 252 | img, label = data['image'], data['label'] 253 | if self.padding is not None: 254 | img = F.pad(img, self.padding, self.fill, self.padding_mode) 255 | label = F.pad(label, self.padding, self.fill, self.padding_mode) 256 | # pad the width if needed 257 | if self.pad_if_needed and img.size[0] < self.size[1]: 258 | img = F.pad(img, (self.size[1] - img.size[0], 0), self.fill, self.padding_mode) 259 | label = F.pad(label, (self.size[1] - label.size[0], 0), self.fill, self.padding_mode) 260 | 261 | # pad the height if needed 262 | if self.pad_if_needed and img.size[1] < self.size[0]: 263 | img = F.pad(img, (0, self.size[0] - img.size[1]), self.fill, self.padding_mode) 264 | label = F.pad(label, (0, self.size[0] - img.size[1]), self.fill, self.padding_mode) 265 | i, j, h, w = self.get_params(img, self.size) 266 | img = F.crop(img, i, j, h, w) 267 | label = F.crop(label, i, j, h, w) 268 | return {"image": img, "label": label} 269 | 270 | 271 | class Normalization(object): 272 | 273 | def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]): 274 | self.mean = mean 275 | self.std = std 276 | 277 | def __call__(self, sample): 278 | image, label = sample['image'], sample['label'] 279 | image = F.normalize(image, self.mean, self.std) 280 | return {'image': image, 'label': label} 281 | 282 | -------------------------------------------------------------------------------- /utils/save_img.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import os 3 | from PIL import Image 4 | import numpy as np 5 | import time 6 | from torchvision import transforms 7 | 8 | 9 | def save_binary_img(x, iid, suffix): 10 | x = x.cpu().data.numpy() 11 | # print(x.shape) 12 | x = np.squeeze(x) 13 | # print(x.shape) 14 | img_save_dir = './img' 15 | x *= 255 16 | # timesign= time.strftime('%m%d_%H%M%S') 17 | im = Image.fromarray(x) 18 | if not os.path.exists(img_save_dir): 19 | os.mkdir(img_save_dir) 20 | 21 | if (im.mode == 'F'): 22 | im = im.convert('RGB') 23 | im.save(os.path.join(img_save_dir, str(iid) + '_' + suffix + '.jpg')) 24 | 25 | 26 | def save_img(x, iid, suffix): 27 | img = x.cpu().clone() 28 | # print(x.shape) 29 | img = img.squeeze(0) 30 | # print(img.shape) 31 | img = transforms.ToPILImage()(img) 32 | img_save_dir = './img' 33 | # x *= 255 34 | # timesign= time.strftime('%m%d_%H%M%S') 35 | # im = Image.fromarray(x) 36 | if not os.path.exists(img_save_dir): 37 | os.mkdir(img_save_dir) 38 | 39 | # if(im.mode == 'F'): 40 | # im = im.convert('RGB') 41 | img.save(os.path.join(img_save_dir, str(iid) + '_' + suffix + '.jpg')) 42 | 43 | --------------------------------------------------------------------------------