├── README.md ├── checkpoints └── exp0 │ └── .gitkeep ├── datasets ├── Polyp.py └── __init__.py ├── models ├── LDNet.py ├── LDNet_ResNet34.py ├── __init__.py ├── modules.py └── res2net.py ├── opt.py ├── test.py ├── train.py └── utils ├── __init__.py ├── comm.py ├── loss.py ├── metrics.py ├── transform.py └── transform_multi.py /README.md: -------------------------------------------------------------------------------- 1 | # Lesion-Aware Dynamic Kernel for Polyp Segmentation 2 | 3 | ## Introduction 4 | 5 | This repository contains the PyTorch implementation of: 6 | 7 | Lesion-Aware Dynamic Kernel for Polyp Segmentation, MICCAI 2022. 8 | 9 | ## Requirements 10 | 11 | * torch 12 | * torchvision 13 | * tqdm 14 | * opencv 15 | * scipy 16 | * skimage 17 | * PIL 18 | * numpy 19 | 20 | ## Usage 21 | 22 | #### 1. Training 23 | 24 | ```bash 25 | python train.py --root /path-to-project --mode train 26 | --train_data_dir /path-to-train_data --valid_data_dir /path-to-valid_data 27 | ``` 28 | 29 | 30 | 31 | #### 2. Inference 32 | 33 | ```bash 34 | python test.py --root /path-to-project --mode test --load_ckpt checkpoint 35 | --test_data_dir /path-to-test_data 36 | ``` 37 | 38 | 39 | 40 | ## Citation 41 | 42 | If you feel this work is helpful, please cite our paper 43 | 44 | ``` 45 | @inproceedings{zhang2022lesion, 46 | title={Lesion-Aware Dynamic Kernel for Polyp Segmentation}, 47 | author={Zhang, Ruifei and Lai, Peiwen and Wan, Xiang and Fan, De-Jun and Gao, Feng and Wu, Xiao-Jian and Li, Guanbin}, 48 | booktitle={International Conference on Medical Image Computing and Computer-Assisted Intervention}, 49 | pages={99--109}, 50 | year={2022}, 51 | organization={Springer} 52 | } 53 | ``` 54 | 55 | 56 | 57 | 58 | 59 | -------------------------------------------------------------------------------- /checkpoints/exp0/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ReaFly/LDNet/b673453a26b6e5219677b09292e441de2b676172/checkpoints/exp0/.gitkeep -------------------------------------------------------------------------------- /datasets/Polyp.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | from utils.transform_multi import * 4 | from torch.utils.data import Dataset 5 | from torchvision import transforms 6 | 7 | 8 | class PolypDataset(Dataset): 9 | def __init__(self, root, data_dir, mode='train', transform=None): 10 | super(PolypDataset, self).__init__() 11 | data_path = osp.join(root, data_dir) 12 | self.imglist = [] 13 | self.gtlist = [] 14 | 15 | datalist = os.listdir(osp.join(data_path, 'images')) 16 | for data in datalist: 17 | name = os.path.splitext(data)[0] 18 | self.imglist.append(osp.join(data_path+'/images', data)) 19 | self.gtlist.append(osp.join(data_path+'/masks', name+'.png')) 20 | 21 | if transform is None: 22 | if mode == 'train': 23 | transform = transforms.Compose([ 24 | Resize((256, 256)), 25 | RandomHorizontalFlip(), 26 | RandomVerticalFlip(), 27 | RandomRotation(90), 28 | RandomZoom((0.9, 1.1)), 29 | RandomCrop((224, 224)), 30 | ToTensor(), 31 | ]) 32 | elif mode == 'valid' or mode == 'test': 33 | transform = transforms.Compose([ 34 | Resize((224, 224)), 35 | ToTensor(), 36 | ]) 37 | self.transform = transform 38 | 39 | def __getitem__(self, index): 40 | img_path = self.imglist[index] 41 | gt_path = self.gtlist[index] 42 | name = img_path.split('/')[-1].split('.')[0] 43 | img = Image.open(img_path).convert('RGB') 44 | gt = 
Image.open(gt_path).convert('L') 45 | data = {'image': img, 'label': gt} 46 | if self.transform: 47 | data = self.transform(data) 48 | data['name'] = name 49 | return data 50 | 51 | def __len__(self): 52 | return len(self.imglist) 53 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .Polyp import PolypDataset 2 | -------------------------------------------------------------------------------- /models/LDNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from models.modules import LCA_blcok, ESA_blcok 6 | from models.res2net import res2net50_v1b_26w_4s 7 | 8 | 9 | class ConvBlock(nn.Module): 10 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding): 11 | super(ConvBlock, self).__init__() 12 | self.conv = nn.Conv2d(in_channels, out_channels,kernel_size=kernel_size,stride=stride,padding=padding) 13 | self.bn = nn.BatchNorm2d(out_channels) 14 | self.relu = nn.ReLU(inplace=True) 15 | 16 | def forward(self, x): 17 | x = self.conv(x) 18 | x = self.bn(x) 19 | x = self.relu(x) 20 | return x 21 | 22 | 23 | class DecoderBlock(nn.Module): 24 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1): 25 | super(DecoderBlock, self).__init__() 26 | 27 | self.conv1 = ConvBlock(in_channels, in_channels // 4, kernel_size=kernel_size, stride=stride, padding=padding) 28 | 29 | self.conv2 = ConvBlock(in_channels // 4, out_channels, kernel_size=kernel_size, stride=stride, padding=padding) 30 | 31 | self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 32 | 33 | def forward(self, x): 34 | x = self.conv1(x) 35 | x = self.conv2(x) 36 | x = self.upsample(x) 37 | return x 38 | 39 | 40 | class HeadUpdator(nn.Module): 41 | def __init__(self, in_channels=64, feat_channels=64, out_channels=None, conv_kernel_size=1): 42 | super(HeadUpdator, self).__init__() 43 | 44 | self.conv_kernel_size = conv_kernel_size 45 | 46 | # C == feat 47 | self.in_channels = in_channels 48 | self.feat_channels = feat_channels 49 | self.out_channels = out_channels if out_channels else in_channels 50 | # feat == in == out 51 | self.num_in = self.feat_channels 52 | self.num_out = self.feat_channels 53 | 54 | self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) 55 | 56 | self.pred_transform_layer = nn.Linear(self.in_channels, self.num_in + self.num_out) 57 | self.head_transform_layer = nn.Linear(self.in_channels, self.num_in + self.num_out, 1) 58 | 59 | self.pred_gate = nn.Linear(self.num_in, self.feat_channels, 1) 60 | self.head_gate = nn.Linear(self.num_in, self.feat_channels, 1) 61 | 62 | self.pred_norm_in = nn.LayerNorm(self.feat_channels) 63 | self.head_norm_in = nn.LayerNorm(self.feat_channels) 64 | self.pred_norm_out = nn.LayerNorm(self.feat_channels) 65 | self.head_norm_out = nn.LayerNorm(self.feat_channels) 66 | 67 | self.fc_layer = nn.Linear(self.feat_channels, self.out_channels, 1) 68 | self.fc_norm = nn.LayerNorm(self.feat_channels) 69 | self.activation = nn.ReLU(inplace=True) 70 | 71 | 72 | def forward(self, feat, head, pred): 73 | 74 | bs, num_classes = head.shape[:2] 75 | # C, H, W = feat.shape[-3:] 76 | 77 | pred = self.upsample(pred) 78 | pred = torch.sigmoid(pred) 79 | 80 | """ 81 | Head feature assemble 82 | - use prediction to assemble head-aware feature 83 | 
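        - concretely, assemble_feat[b, n, c] = sum over (h, w) of pred[b, n, h, w] * feat[b, c, h, w],
          i.e. the (sigmoid) prediction acts as a soft spatial mask that pools the feature map
          into one C-dimensional vector per class (implemented with the einsum below)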
""" 84 | 85 | # [B, N, C] 86 | assemble_feat = torch.einsum('bnhw,bchw->bnc', pred, feat) 87 | 88 | # [B, N, C, K, K] -> [B, N, C, K*K] -> [B, N, K*K, C] 89 | head = head.reshape(bs, num_classes, self.in_channels, -1).permute(0, 1, 3, 2) 90 | 91 | """ 92 | Update head 93 | - assemble_feat, head -> linear transform -> pred_feat, head_feat 94 | - both split into two parts: xxx_in & xxx_out 95 | - gate_feat = head_feat_in * pred_feat_in 96 | - gate_feat -> linear transform -> pred_gate, head_gate 97 | - update_head = pred_gate * pred_feat_out + head_gate * head_feat_out 98 | """ 99 | # [B, N, C] -> [B*N, C] 100 | assemble_feat = assemble_feat.reshape(-1, self.in_channels) 101 | bs_num = assemble_feat.size(0) 102 | 103 | # [B*N, C] -> [B*N, in+out] 104 | pred_feat = self.pred_transform_layer(assemble_feat) 105 | 106 | # [B*N, in] 107 | pred_feat_in = pred_feat[:, :self.num_in].view(-1, self.feat_channels) 108 | # [B*N, out] 109 | pred_feat_out = pred_feat[:, -self.num_out:].view(-1, self.feat_channels) 110 | 111 | # [B, N, K*K, C] -> [B*N, K*K, C] -> [B*N, K*K, in+out] 112 | head_feat = self.head_transform_layer( 113 | head.reshape(bs_num, -1, self.in_channels)) 114 | 115 | # [B*N, K*K, in] 116 | head_feat_in = head_feat[..., :self.num_in] 117 | # [B*N, K*K, out] 118 | head_feat_out = head_feat[..., -self.num_out:] 119 | 120 | # [B*N, K*K, in] * [B*N, 1, in] -> [B*N, K*K, in] 121 | gate_feat = head_feat_in * pred_feat_in.unsqueeze(-2) 122 | 123 | # [B*N, K*K, feat] 124 | head_gate = self.head_norm_in(self.head_gate(gate_feat)) 125 | pred_gate = self.pred_norm_in(self.pred_gate(gate_feat)) 126 | 127 | head_gate = torch.sigmoid(head_gate) 128 | pred_gate = torch.sigmoid(pred_gate) 129 | 130 | # [B*N, K*K, out] 131 | head_feat_out = self.head_norm_out(head_feat_out) 132 | # [B*N, out] 133 | pred_feat_out = self.pred_norm_out(pred_feat_out) 134 | 135 | # [B*N, K*K, feat] or [B*N, K*K, C] 136 | update_head = pred_gate * pred_feat_out.unsqueeze(-2) + head_gate * head_feat_out 137 | 138 | update_head = self.fc_layer(update_head) 139 | update_head = self.fc_norm(update_head) 140 | update_head = self.activation(update_head) 141 | 142 | # [B*N, K*K, C] -> [B, N, K*K, C] 143 | update_head = update_head.reshape(bs, num_classes, -1, self.feat_channels) 144 | # [B, N, K*K, C] -> [B, N, C, K*K] -> [B, N, C, K, K] 145 | update_head = update_head.permute(0, 1, 3, 2).reshape(bs, num_classes, self.feat_channels, self.conv_kernel_size, self.conv_kernel_size) 146 | 147 | return update_head 148 | 149 | 150 | class LDNet(nn.Module): 151 | def __init__(self, num_classes=1, unified_channels=64, conv_kernel_size=1): 152 | super(LDNet, self).__init__() 153 | self.num_classes = num_classes 154 | self.conv_kernel_size = conv_kernel_size 155 | self.unified_channels = unified_channels 156 | 157 | res2net = res2net50_v1b_26w_4s(pretrained=True) 158 | 159 | # Encoder 160 | self.encoder1_conv = res2net.conv1 161 | self.encoder1_bn = res2net.bn1 162 | self.encoder1_relu = res2net.relu 163 | self.maxpool = res2net.maxpool 164 | self.encoder2 = res2net.layer1 165 | self.encoder3 = res2net.layer2 166 | self.encoder4 = res2net.layer3 167 | self.encoder5 = res2net.layer4 168 | 169 | self.reduce2 = nn.Conv2d(256, 64, 1) 170 | self.reduce3 = nn.Conv2d(512, 128, 1) 171 | self.reduce4 = nn.Conv2d(1024, 256, 1) 172 | self.reduce5 = nn.Conv2d(2048, 512, 1) 173 | # Decoder 174 | self.decoder5 = DecoderBlock(in_channels=512, out_channels=512) 175 | self.decoder4 = DecoderBlock(in_channels=512+256, out_channels=256) 176 | self.decoder3 = 
DecoderBlock(in_channels=256+128, out_channels=128) 177 | self.decoder2 = DecoderBlock(in_channels=128+64, out_channels=64) 178 | self.decoder1 = DecoderBlock(in_channels=64+64, out_channels=64) 179 | 180 | # self.outconv = nn.Sequential( 181 | # ConvBlock(64, 32, kernel_size=3, stride=1, padding=1), 182 | # nn.Dropout2d(0.1), 183 | # nn.Conv2d(32, num_classes, 1) 184 | # ) 185 | 186 | self.gobal_average_pool = nn.Sequential( 187 | nn.GroupNorm(16, 512), 188 | nn.ReLU(inplace=True), 189 | nn.AdaptiveAvgPool2d(1), 190 | ) 191 | #self.gobal_average_pool = nn.AdaptiveAvgPool2d(1) 192 | self.generate_head = nn.Linear(512, self.num_classes*self.unified_channels*self.conv_kernel_size*self.conv_kernel_size) 193 | 194 | # self.pred_head = nn.Conv2d(64, self.num_classes, self.conv_kernel_size) 195 | 196 | self.headUpdators = nn.ModuleList() 197 | for i in range(4): 198 | self.headUpdators.append(HeadUpdator()) 199 | 200 | # Unified channel 201 | self.unify1 = nn.Conv2d(64, 64, 1) 202 | self.unify2 = nn.Conv2d(64, 64, 1) 203 | self.unify3 = nn.Conv2d(128, 64, 1) 204 | self.unify4 = nn.Conv2d(256, 64, 1) 205 | self.unify5 = nn.Conv2d(512, 64, 1) 206 | 207 | # Efficient self-attention block 208 | self.esa1 = ESA_blcok(dim=64) 209 | self.esa2 = ESA_blcok(dim=64) 210 | self.esa3 = ESA_blcok(dim=128) 211 | self.esa4 = ESA_blcok(dim=256) 212 | #self.esa5 = ESA_blcok(dim=512) 213 | # Lesion-aware cross-attention block 214 | self.lca1 = LCA_blcok(dim=64) 215 | self.lca2 = LCA_blcok(dim=128) 216 | self.lca3 = LCA_blcok(dim=256) 217 | self.lca4 = LCA_blcok(dim=512) 218 | 219 | self.decoderList = nn.ModuleList([self.decoder4, self.decoder3, self.decoder2, self.decoder1]) 220 | self.unifyList = nn.ModuleList([self.unify4, self.unify3, self.unify2, self.unify1]) 221 | self.esaList = nn.ModuleList([self.esa4, self.esa3, self.esa2, self.esa1]) 222 | self.lcaList = nn.ModuleList([self.lca4, self.lca3, self.lca2, self.lca1]) 223 | 224 | 225 | def forward(self, x): 226 | # x = H*W*3 227 | bs = x.shape[0] 228 | e1_ = self.encoder1_conv(x) # H/2*W/2*64 229 | e1_ = self.encoder1_bn(e1_) 230 | e1_ = self.encoder1_relu(e1_) 231 | e1_pool_ = self.maxpool(e1_) # H/4*W/4*64 232 | e2_ = self.encoder2(e1_pool_) # H/4*W/4*64 233 | e3_ = self.encoder3(e2_) # H/8*W/8*128 234 | e4_ = self.encoder4(e3_) # H/16*W/16*256 235 | e5_ = self.encoder5(e4_) # H/32*W/32*512 236 | 237 | e1 = e1_ 238 | e2 = self.reduce2(e2_) 239 | e3 = self.reduce3(e3_) 240 | e4 = self.reduce4(e4_) 241 | e5 = self.reduce5(e5_) 242 | 243 | #e5 = self.esa5(e5) 244 | d5 = self.decoder5(e5) # H/16*W/16*512 245 | 246 | feat5 = self.unify5(d5) 247 | 248 | decoder_out = [d5] 249 | encoder_out = [e4, e3, e2, e1] 250 | 251 | """ 252 | B = batch size (bs) 253 | N = number of classes (num_classes) 254 | C = feature channels 255 | K = conv kernel size 256 | """ 257 | # [B, 512, 1, 1] -> [B, 512] 258 | gobal_context = self.gobal_average_pool(e5) 259 | gobal_context = gobal_context.reshape(bs, -1) 260 | 261 | # [B, N*C*K*K] -> [B, N, C, K, K] 262 | head = self.generate_head(gobal_context) 263 | head = head.reshape(bs, self.num_classes, self.unified_channels, self.conv_kernel_size, self.conv_kernel_size) 264 | 265 | pred = [] 266 | for t in range(bs): 267 | pred.append(F.conv2d( 268 | feat5[t:t+1], 269 | head[t], 270 | padding=int(self.conv_kernel_size // 2))) 271 | pred = torch.cat(pred, dim=0) 272 | H, W = feat5.shape[-2:] 273 | # [B, N, H, W] 274 | pred = pred.reshape(bs, self.num_classes, H, W) 275 | stage_out = [pred] 276 | 277 | # feat size: [B, C, H, W] 278 | # feats 
= [feat4, feat3, feat2, feat1] 279 | feats = [] 280 | 281 | for i in range(4): 282 | esa_out = self.esaList[i](encoder_out[i]) 283 | lca_out = self.lcaList[i](decoder_out[-1], stage_out[-1]) 284 | comb = torch.cat([lca_out, esa_out], dim=1) 285 | 286 | d = self.decoderList[i](comb) 287 | decoder_out.append(d) 288 | 289 | feat = self.unifyList[i](d) 290 | feats.append(feat) 291 | 292 | head = self.headUpdators[i](feats[i], head, pred) 293 | pred = [] 294 | 295 | for j in range(bs): 296 | pred.append(F.conv2d( 297 | feats[i][j:j+1], 298 | head[j], 299 | padding=int(self.conv_kernel_size // 2))) 300 | pred = torch.cat(pred, dim=0) 301 | H, W = feats[i].shape[-2:] 302 | pred = pred.reshape(bs, self.num_classes, H, W) 303 | stage_out.append(pred) 304 | 305 | stage_out.reverse() 306 | #return stage_out[0], stage_out[1], stage_out[2], stage_out[3], stage_out[4] 307 | return torch.sigmoid(stage_out[0]), torch.sigmoid(stage_out[1]), torch.sigmoid(stage_out[2]), \ 308 | torch.sigmoid(stage_out[3]), torch.sigmoid(stage_out[4]) 309 | -------------------------------------------------------------------------------- /models/LDNet_ResNet34.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torchvision.models as torchmodels 5 | from models.modules import LCA_blcok, ESA_blcok 6 | 7 | 8 | class ConvBlock(nn.Module): 9 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding): 10 | super(ConvBlock, self).__init__() 11 | self.conv = nn.Conv2d(in_channels, out_channels,kernel_size=kernel_size,stride=stride,padding=padding) 12 | self.bn = nn.BatchNorm2d(out_channels) 13 | self.relu = nn.ReLU(inplace=True) 14 | 15 | def forward(self, x): 16 | x = self.conv(x) 17 | x = self.bn(x) 18 | x = self.relu(x) 19 | return x 20 | 21 | 22 | class DecoderBlock(nn.Module): 23 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1): 24 | super(DecoderBlock, self).__init__() 25 | 26 | self.conv1 = ConvBlock(in_channels, in_channels // 4, kernel_size=kernel_size, stride=stride, padding=padding) 27 | 28 | self.conv2 = ConvBlock(in_channels // 4, out_channels, kernel_size=kernel_size, stride=stride, padding=padding) 29 | 30 | self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 31 | 32 | def forward(self, x): 33 | x = self.conv1(x) 34 | x = self.conv2(x) 35 | x = self.upsample(x) 36 | return x 37 | 38 | 39 | class HeadUpdator(nn.Module): 40 | def __init__(self, in_channels=64, feat_channels=64, out_channels=None, conv_kernel_size=1): 41 | super(HeadUpdator, self).__init__() 42 | 43 | self.conv_kernel_size = conv_kernel_size 44 | 45 | # C == feat 46 | self.in_channels = in_channels 47 | self.feat_channels = feat_channels 48 | self.out_channels = out_channels if out_channels else in_channels 49 | # feat == in == out 50 | self.num_in = self.feat_channels 51 | self.num_out = self.feat_channels 52 | 53 | self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) 54 | 55 | self.pred_transform_layer = nn.Linear(self.in_channels, self.num_in + self.num_out) 56 | self.head_transform_layer = nn.Linear(self.in_channels, self.num_in + self.num_out, 1) 57 | 58 | self.pred_gate = nn.Linear(self.num_in, self.feat_channels, 1) 59 | self.head_gate = nn.Linear(self.num_in, self.feat_channels, 1) 60 | 61 | self.pred_norm_in = nn.LayerNorm(self.feat_channels) 62 | self.head_norm_in = nn.LayerNorm(self.feat_channels) 63 | 
self.pred_norm_out = nn.LayerNorm(self.feat_channels) 64 | self.head_norm_out = nn.LayerNorm(self.feat_channels) 65 | 66 | self.fc_layer = nn.Linear(self.feat_channels, self.out_channels, 1) 67 | self.fc_norm = nn.LayerNorm(self.feat_channels) 68 | self.activation = nn.ReLU(inplace=True) 69 | 70 | 71 | def forward(self, feat, head, pred): 72 | 73 | bs, num_classes = head.shape[:2] 74 | # C, H, W = feat.shape[-3:] 75 | 76 | pred = self.upsample(pred) 77 | pred = torch.sigmoid(pred) 78 | 79 | """ 80 | Head feature assemble 81 | - use prediction to assemble head-aware feature 82 | """ 83 | 84 | # [B, N, C] 85 | assemble_feat = torch.einsum('bnhw,bchw->bnc', pred, feat) 86 | 87 | # [B, N, C, K, K] -> [B, N, C, K*K] -> [B, N, K*K, C] 88 | head = head.reshape(bs, num_classes, self.in_channels, -1).permute(0, 1, 3, 2) 89 | 90 | """ 91 | Update head 92 | - assemble_feat, head -> linear transform -> pred_feat, head_feat 93 | - both split into two parts: xxx_in & xxx_out 94 | - gate_feat = head_feat_in * pred_feat_in 95 | - gate_feat -> linear transform -> pred_gate, head_gate 96 | - update_head = pred_gate * pred_feat_out + head_gate * head_feat_out 97 | """ 98 | # [B, N, C] -> [B*N, C] 99 | assemble_feat = assemble_feat.reshape(-1, self.in_channels) 100 | bs_num = assemble_feat.size(0) 101 | 102 | # [B*N, C] -> [B*N, in+out] 103 | pred_feat = self.pred_transform_layer(assemble_feat) 104 | 105 | # [B*N, in] 106 | pred_feat_in = pred_feat[:, :self.num_in].view(-1, self.feat_channels) 107 | # [B*N, out] 108 | pred_feat_out = pred_feat[:, -self.num_out:].view(-1, self.feat_channels) 109 | 110 | # [B, N, K*K, C] -> [B*N, K*K, C] -> [B*N, K*K, in+out] 111 | head_feat = self.head_transform_layer( 112 | head.reshape(bs_num, -1, self.in_channels)) 113 | 114 | # [B*N, K*K, in] 115 | head_feat_in = head_feat[..., :self.num_in] 116 | # [B*N, K*K, out] 117 | head_feat_out = head_feat[..., -self.num_out:] 118 | 119 | # [B*N, K*K, in] * [B*N, 1, in] -> [B*N, K*K, in] 120 | gate_feat = head_feat_in * pred_feat_in.unsqueeze(-2) 121 | 122 | # [B*N, K*K, feat] 123 | head_gate = self.head_norm_in(self.head_gate(gate_feat)) 124 | pred_gate = self.pred_norm_in(self.pred_gate(gate_feat)) 125 | 126 | head_gate = torch.sigmoid(head_gate) 127 | pred_gate = torch.sigmoid(pred_gate) 128 | 129 | # [B*N, K*K, out] 130 | head_feat_out = self.head_norm_out(head_feat_out) 131 | # [B*N, out] 132 | pred_feat_out = self.pred_norm_out(pred_feat_out) 133 | 134 | # [B*N, K*K, feat] or [B*N, K*K, C] 135 | update_head = pred_gate * pred_feat_out.unsqueeze(-2) + head_gate * head_feat_out 136 | 137 | update_head = self.fc_layer(update_head) 138 | update_head = self.fc_norm(update_head) 139 | update_head = self.activation(update_head) 140 | 141 | # [B*N, K*K, C] -> [B, N, K*K, C] 142 | update_head = update_head.reshape(bs, num_classes, -1, self.feat_channels) 143 | # [B, N, K*K, C] -> [B, N, C, K*K] -> [B, N, C, K, K] 144 | update_head = update_head.permute(0, 1, 3, 2).reshape(bs, num_classes, self.feat_channels, self.conv_kernel_size, self.conv_kernel_size) 145 | 146 | return update_head 147 | 148 | 149 | class LDNet_ResNet34(nn.Module): 150 | def __init__(self, num_classes=1, unified_channels=64, conv_kernel_size=1): 151 | super(LDNet_ResNet34, self).__init__() 152 | self.num_classes = num_classes 153 | self.conv_kernel_size = conv_kernel_size 154 | self.unified_channels = unified_channels 155 | 156 | resnet = torchmodels.resnet34(pretrained=True) 157 | 158 | # Encoder 159 | self.encoder1_conv = resnet.conv1 160 | self.encoder1_bn = 
resnet.bn1 161 | self.encoder1_relu = resnet.relu 162 | self.maxpool = resnet.maxpool 163 | self.encoder2 = resnet.layer1 164 | self.encoder3 = resnet.layer2 165 | self.encoder4 = resnet.layer3 166 | self.encoder5 = resnet.layer4 167 | 168 | # Decoder 169 | self.decoder5 = DecoderBlock(in_channels=512, out_channels=512) 170 | self.decoder4 = DecoderBlock(in_channels=512+256, out_channels=256) 171 | self.decoder3 = DecoderBlock(in_channels=256+128, out_channels=128) 172 | self.decoder2 = DecoderBlock(in_channels=128+64, out_channels=64) 173 | self.decoder1 = DecoderBlock(in_channels=64+64, out_channels=64) 174 | 175 | # self.outconv = nn.Sequential( 176 | # ConvBlock(64, 32, kernel_size=3, stride=1, padding=1), 177 | # nn.Dropout2d(0.1), 178 | # nn.Conv2d(32, num_classes, 1) 179 | # ) 180 | 181 | self.gobal_average_pool = nn.Sequential( 182 | nn.GroupNorm(16, 512), 183 | nn.ReLU(inplace=True), 184 | nn.AdaptiveAvgPool2d(1), 185 | ) 186 | 187 | self.generate_head = nn.Linear(512, self.num_classes*self.unified_channels*self.conv_kernel_size*self.conv_kernel_size) 188 | 189 | # self.pred_head = nn.Conv2d(64, self.num_classes, self.conv_kernel_size) 190 | 191 | self.headUpdators = nn.ModuleList() 192 | for i in range(4): 193 | self.headUpdators.append(HeadUpdator()) 194 | 195 | # Unified channel 196 | self.unify1 = nn.Conv2d(64, 64, 1) 197 | self.unify2 = nn.Conv2d(64, 64, 1) 198 | self.unify3 = nn.Conv2d(128, 64, 1) 199 | self.unify4 = nn.Conv2d(256, 64, 1) 200 | self.unify5 = nn.Conv2d(512, 64, 1) 201 | 202 | # Efficient self-attention block 203 | self.esa1 = ESA_blcok(dim=64) 204 | self.esa2 = ESA_blcok(dim=64) 205 | self.esa3 = ESA_blcok(dim=128) 206 | self.esa4 = ESA_blcok(dim=256) 207 | 208 | # Lesion-aware cross-attention block 209 | self.lca1 = LCA_blcok(dim=64) 210 | self.lca2 = LCA_blcok(dim=128) 211 | self.lca3 = LCA_blcok(dim=256) 212 | self.lca4 = LCA_blcok(dim=512) 213 | 214 | self.decoderList = nn.ModuleList([self.decoder4, self.decoder3, self.decoder2, self.decoder1]) 215 | self.unifyList = nn.ModuleList([self.unify4, self.unify3, self.unify2, self.unify1]) 216 | self.esaList = nn.ModuleList([self.esa4, self.esa3, self.esa2, self.esa1]) 217 | self.lcaList = nn.ModuleList([self.lca4, self.lca3, self.lca2, self.lca1]) 218 | 219 | 220 | def forward(self, x): 221 | # x = H*W*3 222 | bs = x.shape[0] 223 | e1 = self.encoder1_conv(x) # H/2*W/2*64 224 | e1 = self.encoder1_bn(e1) 225 | e1 = self.encoder1_relu(e1) 226 | e1_pool = self.maxpool(e1) # H/4*W/4*64 227 | e2 = self.encoder2(e1_pool) # H/4*W/4*64 228 | e3 = self.encoder3(e2) # H/8*W/8*128 229 | e4 = self.encoder4(e3) # H/16*W/16*256 230 | e5 = self.encoder5(e4) # H/32*W/32*512 231 | 232 | d5 = self.decoder5(e5) # H/16*W/16*512 233 | 234 | feat5 = self.unify5(d5) 235 | 236 | decoder_out = [d5] 237 | encoder_out = [e4, e3, e2, e1] 238 | 239 | """ 240 | B = batch size (bs) 241 | N = number of classes (num_classes) 242 | C = feature channels 243 | K = conv kernel size 244 | """ 245 | # [B, 512, 1, 1] -> [B, 512] 246 | gobal_context = self.gobal_average_pool(e5) 247 | gobal_context = gobal_context.reshape(bs, -1) 248 | 249 | # [B, N*C*K*K] -> [B, N, C, K, K] 250 | head = self.generate_head(gobal_context) 251 | head = head.reshape(bs, self.num_classes, self.unified_channels, self.conv_kernel_size, self.conv_kernel_size) 252 | 253 | pred = [] 254 | for t in range(bs): 255 | pred.append(F.conv2d( 256 | feat5[t:t+1], 257 | head[t], 258 | padding=int(self.conv_kernel_size // 2))) 259 | pred = torch.cat(pred, dim=0) 260 | H, W = 
feat5.shape[-2:] 261 | # [B, N, H, W] 262 | pred = pred.reshape(bs, self.num_classes, H, W) 263 | stage_out = [pred] 264 | 265 | # feat size: [B, C, H, W] 266 | # feats = [feat4, feat3, feat2, feat1] 267 | feats = [] 268 | 269 | for i in range(4): 270 | esa_out = self.esaList[i](encoder_out[i]) 271 | lca_out = self.lcaList[i](decoder_out[-1],stage_out[-1]) 272 | comb = torch.cat([lca_out, esa_out], dim=1) 273 | 274 | d = self.decoderList[i](comb) 275 | decoder_out.append(d) 276 | 277 | feat = self.unifyList[i](d) 278 | feats.append(feat) 279 | 280 | head = self.headUpdators[i](feats[i], head, pred) 281 | pred = [] 282 | 283 | for j in range(bs): 284 | pred.append(F.conv2d( 285 | feats[i][j:j+1], 286 | head[j], 287 | padding=int(self.conv_kernel_size // 2))) 288 | pred = torch.cat(pred, dim=0) 289 | H, W = feats[i].shape[-2:] 290 | pred = pred.reshape(bs, self.num_classes, H, W) 291 | stage_out.append(pred) 292 | 293 | stage_out.reverse() 294 | 295 | return torch.sigmoid(stage_out[0]), torch.sigmoid(stage_out[1]), torch.sigmoid(stage_out[2]), \ 296 | torch.sigmoid(stage_out[3]), torch.sigmoid(stage_out[4]) 297 | 298 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .LDNet import LDNet 2 | from .LDNet_ResNet34 import LDNet_ResNet34 3 | 4 | 5 | 6 | 7 | -------------------------------------------------------------------------------- /models/modules.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | from einops import rearrange, repeat 4 | from einops.layers.torch import Rearrange 5 | import torch.nn.functional as F 6 | 7 | class PreNorm(nn.Module): 8 | def __init__(self, dim, fn): 9 | super().__init__() 10 | self.norm = nn.LayerNorm(dim) 11 | self.fn = fn 12 | 13 | def forward(self, x, **kwargs): 14 | return self.fn(self.norm(x), **kwargs) 15 | 16 | 17 | class FeedForward(nn.Module): 18 | def __init__(self, dim, hidden_dim, dropout = 0.): 19 | super().__init__() 20 | self.net = nn.Sequential( 21 | nn.Linear(dim, hidden_dim), 22 | nn.GELU(), 23 | nn.Dropout(dropout), 24 | nn.Linear(hidden_dim, dim), 25 | nn.Dropout(dropout) 26 | ) 27 | def forward(self, x): 28 | return self.net(x) 29 | 30 | 31 | class PPM(nn.Module): 32 | def __init__(self, pooling_sizes=(1, 3, 5)): 33 | super().__init__() 34 | self.layer = nn.ModuleList([nn.AdaptiveAvgPool2d(output_size=(size,size)) for size in pooling_sizes]) 35 | 36 | def forward(self, feat): 37 | b, c, h, w = feat.shape 38 | output = [layer(feat).view(b, c, -1) for layer in self.layer] 39 | output = torch.cat(output, dim=-1) 40 | return output 41 | 42 | 43 | # Efficient self attention 44 | class ESA_layer(nn.Module): 45 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): 46 | super().__init__() 47 | inner_dim = dim_head * heads 48 | project_out = not (heads == 1 and dim_head == dim) 49 | 50 | self.heads = heads 51 | self.scale = dim_head ** -0.5 52 | 53 | self.attend = nn.Softmax(dim=-1) 54 | self.to_qkv = nn.Conv2d(dim, inner_dim * 3, kernel_size=1, stride=1, padding=0, bias=False) 55 | self.ppm = PPM(pooling_sizes=(1, 3, 5)) 56 | self.to_out = nn.Sequential( 57 | nn.Linear(inner_dim, dim), 58 | nn.Dropout(dropout) 59 | ) if project_out else nn.Identity() 60 | 61 | def forward(self, x): 62 | # input x (b, c, h, w) 63 | b, c, h, w = x.shape 64 | q, k, v = self.to_qkv(x).chunk(3, dim=1) # q/k/v shape: (b, inner_dim, h, w) 
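        # Note: q keeps the full h*w spatial resolution, while k and v are compressed by the
        # pyramid pooling module (PPM) below into a small, fixed number of tokens
        # (1*1 + 3*3 + 5*5 = 35 for the default pooling sizes), so the attention matrix is
        # (h*w) x 35 rather than (h*w) x (h*w) -- the "efficient" part of this self-attention.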
65 | q = rearrange(q, 'b (head d) h w -> b head (h w) d', head=self.heads) # q shape: (b, head, n_q, d) 66 | 67 | k, v = self.ppm(k), self.ppm(v) # k/v shape: (b, inner_dim, n_kv) 68 | k = rearrange(k, 'b (head d) n -> b head n d', head=self.heads) # k shape: (b, head, n_kv, d) 69 | v = rearrange(v, 'b (head d) n -> b head n d', head=self.heads) # v shape: (b, head, n_kv, d) 70 | 71 | dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale # shape: (b, head, n_q, n_kv) 72 | 73 | attn = self.attend(dots) 74 | 75 | out = torch.matmul(attn, v) # shape: (b, head, n_q, d) 76 | out = rearrange(out, 'b head n d -> b n (head d)') 77 | return self.to_out(out) 78 | 79 | 80 | class ESA_blcok(nn.Module): 81 | def __init__(self, dim, heads=8, dim_head=64, mlp_dim=512, dropout = 0.): 82 | super().__init__() 83 | self.ESAlayer = ESA_layer(dim, heads=heads, dim_head=dim_head, dropout=dropout) 84 | self.ff = PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)) 85 | 86 | 87 | def forward(self, x): 88 | b, c, h, w = x.shape 89 | out = rearrange(x, 'b c h w -> b (h w) c') 90 | out = self.ESAlayer(x) + out 91 | out = self.ff(out) + out 92 | out = rearrange(out, 'b (h w) c -> b c h w', h=h) 93 | 94 | return out 95 | 96 | 97 | def MaskAveragePooling(x, mask): 98 | mask = torch.sigmoid(mask) 99 | b, c, h, w = x.shape 100 | eps = 0.0005 101 | x_mask = x * mask 102 | h, w = x.shape[2], x.shape[3] 103 | area = F.avg_pool2d(mask, (h, w)) * h * w + eps 104 | x_feat = F.avg_pool2d(x_mask, (h, w)) * h * w / area 105 | x_feat = x_feat.view(b, c, -1) 106 | return x_feat 107 | 108 | 109 | # Lesion-aware Cross Attention 110 | class LCA_layer(nn.Module): 111 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): 112 | super().__init__() 113 | inner_dim = dim_head * heads 114 | project_out = not (heads == 1 and dim_head == dim) 115 | self.heads = heads 116 | self.scale = dim_head ** -0.5 117 | 118 | self.attend = nn.Softmax(dim=-1) 119 | self.to_qkv = nn.Conv2d(dim, inner_dim * 3, kernel_size=1, stride=1, padding=0, bias=False) 120 | self.to_out = nn.Sequential( 121 | nn.Linear(inner_dim, dim), 122 | nn.Dropout(dropout) 123 | ) if project_out else nn.Identity() 124 | 125 | def forward(self, x, mask): 126 | # input x (b, c, h, w) 127 | b, c, h, w = x.shape 128 | q, k, v = self.to_qkv(x).chunk(3, dim=1) # q/k/v shape: (b, inner_dim, h, w) 129 | q = rearrange(q, 'b (head d) h w -> b head (h w) d', head=self.heads) # q shape: (b, head, n_q, d) 130 | 131 | k, v = MaskAveragePooling(k, mask), MaskAveragePooling(v, mask) # k/v shape: (b, inner_dim, 1) 132 | k = rearrange(k, 'b (head d) n -> b head n d', head=self.heads) # k shape: (b, head, 1, d) 133 | v = rearrange(v, 'b (head d) n -> b head n d', head=self.heads) # v shape: (b, head, 1, d) 134 | 135 | dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale # shape: (b, head, n_q, n_kv) 136 | 137 | attn = self.attend(dots) 138 | 139 | out = torch.matmul(attn, v) # shape: (b, head, n_q, d) 140 | out = rearrange(out, 'b head n d -> b n (head d)') 141 | return self.to_out(out) 142 | 143 | 144 | class LCA_blcok(nn.Module): 145 | def __init__(self, dim, heads=8, dim_head=64, mlp_dim=512, dropout = 0.): 146 | super().__init__() 147 | self.LCAlayer = LCA_layer(dim, heads=heads, dim_head=dim_head, dropout=dropout) 148 | self.ff = PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)) 149 | 150 | def forward(self, x, mask): 151 | b, c, h, w = x.shape 152 | out = rearrange(x, 'b c h w -> b (h w) c') 153 | out = self.LCAlayer(x, mask) + out 154 | out = self.ff(out) + 
out 155 | out = rearrange(out, 'b (h w) c -> b c h w', h=h) 156 | 157 | return out 158 | 159 | 160 | 161 | # test 162 | if __name__ == '__main__': 163 | x = torch.rand((4, 3, 320, 320)) 164 | mask = torch.rand(4, 1, 320, 320) 165 | lca = LCA_blcok(dim=3) 166 | esa = ESA_blcok(dim=3) 167 | print(lca(x, mask).shape) 168 | print(esa(x).shape) 169 | 170 | -------------------------------------------------------------------------------- /models/res2net.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | import torch.utils.model_zoo as model_zoo 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | __all__ = ['Res2Net', 'res2net50_v1b', 'res2net101_v1b', 'res2net50_v1b_26w_4s'] 8 | 9 | model_urls = { 10 | 'res2net50_v1b_26w_4s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net50_v1b_26w_4s-3cf99910.pth', 11 | 'res2net101_v1b_26w_4s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net101_v1b_26w_4s-0812c246.pth', 12 | } 13 | 14 | 15 | class Bottle2neck(nn.Module): 16 | expansion = 4 17 | 18 | def __init__(self, inplanes, planes, stride=1, downsample=None, baseWidth=26, scale=4, stype='normal'): 19 | """ Constructor 20 | Args: 21 | inplanes: input channel dimensionality 22 | planes: output channel dimensionality 23 | stride: conv stride. Replaces pooling layer. 24 | downsample: None when stride = 1 25 | baseWidth: basic width of conv3x3 26 | scale: number of scale. 27 | type: 'normal': normal set. 'stage': first block of a new stage. 28 | """ 29 | super(Bottle2neck, self).__init__() 30 | 31 | width = int(math.floor(planes * (baseWidth / 64.0))) 32 | self.conv1 = nn.Conv2d(inplanes, width * scale, kernel_size=1, bias=False) 33 | self.bn1 = nn.BatchNorm2d(width * scale) 34 | 35 | if scale == 1: 36 | self.nums = 1 37 | else: 38 | self.nums = scale - 1 39 | if stype == 'stage': 40 | self.pool = nn.AvgPool2d(kernel_size=3, stride=stride, padding=1) 41 | convs = [] 42 | bns = [] 43 | for i in range(self.nums): 44 | convs.append(nn.Conv2d(width, width, kernel_size=3, stride=stride, padding=1, bias=False)) 45 | bns.append(nn.BatchNorm2d(width)) 46 | self.convs = nn.ModuleList(convs) 47 | self.bns = nn.ModuleList(bns) 48 | 49 | self.conv3 = nn.Conv2d(width * scale, planes * self.expansion, kernel_size=1, bias=False) 50 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 51 | 52 | self.relu = nn.ReLU(inplace=True) 53 | self.downsample = downsample 54 | self.stype = stype 55 | self.scale = scale 56 | self.width = width 57 | 58 | def forward(self, x): 59 | residual = x 60 | 61 | out = self.conv1(x) 62 | out = self.bn1(out) 63 | out = self.relu(out) 64 | 65 | spx = torch.split(out, self.width, 1) 66 | for i in range(self.nums): 67 | if i == 0 or self.stype == 'stage': 68 | sp = spx[i] 69 | else: 70 | sp = sp + spx[i] 71 | sp = self.convs[i](sp) 72 | sp = self.relu(self.bns[i](sp)) 73 | if i == 0: 74 | out = sp 75 | else: 76 | out = torch.cat((out, sp), 1) 77 | if self.scale != 1 and self.stype == 'normal': 78 | out = torch.cat((out, spx[self.nums]), 1) 79 | elif self.scale != 1 and self.stype == 'stage': 80 | out = torch.cat((out, self.pool(spx[self.nums])), 1) 81 | 82 | out = self.conv3(out) 83 | out = self.bn3(out) 84 | 85 | if self.downsample is not None: 86 | residual = self.downsample(x) 87 | 88 | out += residual 89 | out = self.relu(out) 90 | 91 | return out 92 | 93 | 94 | class Res2Net(nn.Module): 95 | 96 | def __init__(self, block, layers, baseWidth=26, scale=4, num_classes=1000): 97 | 
self.inplanes = 64 98 | super(Res2Net, self).__init__() 99 | self.baseWidth = baseWidth 100 | self.scale = scale 101 | self.conv1 = nn.Sequential( 102 | nn.Conv2d(3, 32, 3, 2, 1, bias=False), 103 | nn.BatchNorm2d(32), 104 | nn.ReLU(inplace=True), 105 | nn.Conv2d(32, 32, 3, 1, 1, bias=False), 106 | nn.BatchNorm2d(32), 107 | nn.ReLU(inplace=True), 108 | nn.Conv2d(32, 64, 3, 1, 1, bias=False) 109 | ) 110 | self.bn1 = nn.BatchNorm2d(64) 111 | self.relu = nn.ReLU() 112 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 113 | self.layer1 = self._make_layer(block, 64, layers[0]) 114 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 115 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 116 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 117 | self.avgpool = nn.AdaptiveAvgPool2d(1) 118 | self.fc = nn.Linear(512 * block.expansion, num_classes) 119 | 120 | for m in self.modules(): 121 | if isinstance(m, nn.Conv2d): 122 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 123 | elif isinstance(m, nn.BatchNorm2d): 124 | nn.init.constant_(m.weight, 1) 125 | nn.init.constant_(m.bias, 0) 126 | 127 | def _make_layer(self, block, planes, blocks, stride=1): 128 | downsample = None 129 | if stride != 1 or self.inplanes != planes * block.expansion: 130 | downsample = nn.Sequential( 131 | nn.AvgPool2d(kernel_size=stride, stride=stride, 132 | ceil_mode=True, count_include_pad=False), 133 | nn.Conv2d(self.inplanes, planes * block.expansion, 134 | kernel_size=1, stride=1, bias=False), 135 | nn.BatchNorm2d(planes * block.expansion), 136 | ) 137 | 138 | layers = [] 139 | layers.append(block(self.inplanes, planes, stride, downsample=downsample, 140 | stype='stage', baseWidth=self.baseWidth, scale=self.scale)) 141 | self.inplanes = planes * block.expansion 142 | for i in range(1, blocks): 143 | layers.append(block(self.inplanes, planes, baseWidth=self.baseWidth, scale=self.scale)) 144 | 145 | return nn.Sequential(*layers) 146 | 147 | def forward(self, x): 148 | x = self.conv1(x) 149 | x = self.bn1(x) 150 | x = self.relu(x) 151 | x = self.maxpool(x) 152 | 153 | x = self.layer1(x) 154 | x = self.layer2(x) 155 | x = self.layer3(x) 156 | x = self.layer4(x) 157 | 158 | x = self.avgpool(x) 159 | x = x.view(x.size(0), -1) 160 | x = self.fc(x) 161 | 162 | return x 163 | 164 | 165 | def res2net50_v1b(pretrained=False, **kwargs): 166 | """Constructs a Res2Net-50_v1b lib. 167 | Res2Net-50 refers to the Res2Net-50_v1b_26w_4s. 168 | Args: 169 | pretrained (bool): If True, returns a lib pre-trained on ImageNet 170 | """ 171 | model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth=26, scale=4, **kwargs) 172 | if pretrained: 173 | model.load_state_dict(model_zoo.load_url(model_urls['res2net50_v1b_26w_4s'])) 174 | return model 175 | 176 | 177 | def res2net101_v1b(pretrained=False, **kwargs): 178 | """Constructs a Res2Net-50_v1b_26w_4s lib. 179 | Args: 180 | pretrained (bool): If True, returns a lib pre-trained on ImageNet 181 | """ 182 | model = Res2Net(Bottle2neck, [3, 4, 23, 3], baseWidth=26, scale=4, **kwargs) 183 | if pretrained: 184 | model.load_state_dict(model_zoo.load_url(model_urls['res2net101_v1b_26w_4s'])) 185 | return model 186 | 187 | 188 | def res2net50_v1b_26w_4s(pretrained=False, **kwargs): 189 | """Constructs a Res2Net-50_v1b_26w_4s lib. 
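    Note: in this repository, pretrained=True loads the backbone weights from a hard-coded
    local checkpoint path below instead of downloading them via model_zoo, so that path has
    to be adjusted to your own environment.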
190 | Args: 191 | pretrained (bool): If True, returns a lib pre-trained on ImageNet 192 | """ 193 | model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth=26, scale=4, **kwargs) 194 | if pretrained: 195 | model_state = torch.load('/data2/zhangruifei/backbone/res2net50_v1b_26w_4s-3cf99910.pth') 196 | model.load_state_dict(model_state) 197 | # lib.load_state_dict(model_zoo.load_url(model_urls['res2net50_v1b_26w_4s'])) 198 | return model 199 | 200 | 201 | def res2net101_v1b_26w_4s(pretrained=False, **kwargs): 202 | """Constructs a Res2Net-50_v1b_26w_4s lib. 203 | Args: 204 | pretrained (bool): If True, returns a lib pre-trained on ImageNet 205 | """ 206 | model = Res2Net(Bottle2neck, [3, 4, 23, 3], baseWidth=26, scale=4, **kwargs) 207 | if pretrained: 208 | model.load_state_dict(model_zoo.load_url(model_urls['res2net101_v1b_26w_4s'])) 209 | return model 210 | 211 | 212 | def res2net152_v1b_26w_4s(pretrained=False, **kwargs): 213 | """Constructs a Res2Net-50_v1b_26w_4s lib. 214 | Args: 215 | pretrained (bool): If True, returns a lib pre-trained on ImageNet 216 | """ 217 | model = Res2Net(Bottle2neck, [3, 8, 36, 3], baseWidth=26, scale=4, **kwargs) 218 | if pretrained: 219 | model.load_state_dict(model_zoo.load_url(model_urls['res2net152_v1b_26w_4s'])) 220 | return model 221 | 222 | 223 | if __name__ == '__main__': 224 | images = torch.rand(1, 3, 224, 224).cuda(0) 225 | model = res2net50_v1b_26w_4s(pretrained=True) 226 | model = model.cuda(0) 227 | print(model(images).size()) -------------------------------------------------------------------------------- /opt.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | 5 | parse = argparse.ArgumentParser(description='PyTorch Polyp Segmentation') 6 | 7 | "-------------------data option--------------------------" 8 | parse.add_argument('--root', type=str, default='/data2/zhangruifei/polypseg') 9 | parse.add_argument('--dataset', type=str, default='PolypDataset') 10 | parse.add_argument('--train_data_dir', type=str, default='data/Kvasir_CVC-ClinicDB/train') 11 | parse.add_argument('--valid_data_dir', type=str, default='data/Kvasir_CVC-ClinicDB/valid') 12 | parse.add_argument('--test_data_dir', type=str, default='data/CVC-ColonDB') 13 | 14 | # Test set: 15 | # Kvasir/test 16 | # CVC-ClinicDB/test 17 | # CVC-ColonDB 18 | # ETIS-LaribPolypDB 19 | 20 | 21 | 22 | "-------------------training option-----------------------" 23 | parse.add_argument('--mode', type=str, default='train') 24 | parse.add_argument('--nEpoch', type=int, default=80) 25 | parse.add_argument('--batch_size', type=float, default=16) 26 | parse.add_argument('--num_workers', type=int, default=2) 27 | parse.add_argument('--use_gpu', type=bool, default=True) 28 | parse.add_argument('--gpu', type=str, default='0') 29 | parse.add_argument('--load_ckpt', type=str, default=None) 30 | parse.add_argument('--model', type=str, default='LDNet') 31 | parse.add_argument('--expID', type=int, default=0) 32 | parse.add_argument('--ckpt_period', type=int, default=5) 33 | 34 | "-------------------optimizer option-----------------------" 35 | parse.add_argument('--lr', type=float, default=1e-3) 36 | parse.add_argument('--weight_decay', type=float, default=1e-5) 37 | parse.add_argument('--mt', type=float, default=0.9) 38 | parse.add_argument('--power', type=float, default=0.9) 39 | 40 | parse.add_argument('--nclasses', type=int, default=1) 41 | parse.add_argument('--save_img', type=bool, default=False) 42 | 43 | opt = parse.parse_args() 44 | 
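A small caveat about the boolean options above (`--use_gpu`, `--save_img`): `argparse` applies `type=bool` to the raw argument string, and `bool('False')` is `True`, so these flags cannot be switched off from the command line by passing `False`; only their defaults behave as expected. A minimal standalone sketch of that behaviour (illustration only, not part of the repository):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--save_img', type=bool, default=False)

# Any non-empty string is converted by bool() to True, including the string 'False'.
print(parser.parse_args(['--save_img', 'False']).save_img)  # -> True
print(parser.parse_args([]).save_img)                       # -> False (default is used)
```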
-------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import os 4 | from tqdm import tqdm 5 | from opt import opt 6 | from utils.metrics import evaluate 7 | import datasets 8 | from torch.utils.data import DataLoader 9 | from utils.comm import generate_model 10 | from utils.metrics import Metrics 11 | 12 | 13 | def test(model, test_data_dir): 14 | test_data_name = test_data_dir.split("/")[1] 15 | 16 | print('Loading data......') 17 | test_data = getattr(datasets, opt.dataset)(opt.root, test_data_dir, mode='test') 18 | test_dataloader = DataLoader(test_data, batch_size=1, shuffle=False, num_workers=opt.num_workers) 19 | total_batch = int(len(test_data) / 1) 20 | 21 | model.eval() 22 | 23 | # metrics_logger initialization 24 | metrics = Metrics(['recall', 'specificity', 'precision', 'F1', 'F2', 'ACC_overall', 'IoU_poly', 'IoU_bg', 'IoU_mean']) 25 | 26 | print('Start testing') 27 | with torch.no_grad(): 28 | bar = tqdm(enumerate(test_dataloader), total=total_batch) 29 | for i, data in bar: 30 | img, gt, name = data['image'], data['label'], data['name'] 31 | 32 | if opt.use_gpu: 33 | img = img.cuda() 34 | gt = gt.cuda() 35 | 36 | output = model(img) 37 | _recall, _specificity, _precision, _F1, _F2, _ACC_overall, _IoU_poly, _IoU_bg, _IoU_mean = evaluate(output, gt, name, test_data_name) 38 | 39 | 40 | metrics.update(recall= _recall, specificity= _specificity, precision= _precision, F1= _F1, F2= _F2, ACC_overall= _ACC_overall, IoU_poly= _IoU_poly, IoU_bg= _IoU_bg, IoU_mean= _IoU_mean) 41 | 42 | metrics_result = metrics.mean(total_batch) 43 | 44 | results = open('./checkpoints/exp' + str(opt.expID) + "/testResults.txt", "a+") 45 | 46 | print("\n%s Test Result:" % test_data_name, file=results) 47 | print('recall: %.4f, specificity: %.4f, precision: %.4f, F1: %.4f, F2: %.4f, ACC_overall: %.4f, IoU_poly: %.4f, IoU_bg: %.4f, IoU_mean: %.4f' % (metrics_result['recall'], metrics_result['specificity'], metrics_result['precision'], metrics_result['F1'], metrics_result['F2'], metrics_result['ACC_overall'], metrics_result['IoU_poly'], metrics_result['IoU_bg'], metrics_result['IoU_mean']), file=results) 48 | print("\n%s Test Result:" % test_data_name) 49 | print('recall: %.4f, specificity: %.4f, precision: %.4f, F1: %.4f, F2: %.4f, ACC_overall: %.4f, IoU_poly: %.4f, IoU_bg: %.4f, IoU_mean: %.4f' % (metrics_result['recall'], metrics_result['specificity'], metrics_result['precision'], metrics_result['F1'], metrics_result['F2'], metrics_result['ACC_overall'], metrics_result['IoU_poly'], metrics_result['IoU_bg'], metrics_result['IoU_mean'])) 50 | 51 | results.close() 52 | 53 | 54 | if __name__ == '__main__': 55 | os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu 56 | 57 | print('Loading model......') 58 | model = generate_model(opt) 59 | 60 | #test_data_list = ["Kvasir/test", "CVC-ClinicDB/test", "CVC-ColonDB", "ETIS-LaribPolypDB"] 61 | 62 | if opt.mode == 'test': 63 | print('--- PolypSeg Test---') 64 | test(model, opt.test_data_dir) 65 | 66 | # you could also utilize the following loop operation 67 | # to directly evaluate the performance on all testsets 68 | 69 | # for test_data_dir in test_data_list: 70 | # test(model, "data/" + test_data_dir) 71 | 72 | print('Done') 73 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | 
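# Training pipeline: SGD with momentum and polynomial ("poly") learning-rate decay,
# deep-supervision loss over the five side outputs, validation after every epoch,
# periodic checkpoints plus an extra checkpoint whenever the validation mean IoU improves.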
import torch 2 | import torch.nn as nn 3 | from torch.utils.data import DataLoader 4 | from torch.optim.lr_scheduler import LambdaLR 5 | from tqdm import tqdm 6 | import datasets 7 | from utils.metrics import evaluate 8 | from opt import opt 9 | from utils.comm import generate_model 10 | from utils.loss import DeepSupervisionLoss 11 | from utils.metrics import Metrics 12 | import os 13 | import torch.nn.functional as F 14 | 15 | 16 | def valid(model, valid_dataloader, total_batch): 17 | 18 | model.eval() 19 | 20 | # Metrics_logger initialization 21 | metrics = Metrics(['recall', 'specificity', 'precision', 'F1', 'F2', 'ACC_overall', 'IoU_poly', 'IoU_bg', 'IoU_mean']) 22 | 23 | with torch.no_grad(): 24 | bar = tqdm(enumerate(valid_dataloader), total=total_batch) 25 | for i, data in bar: 26 | img, gt = data['image'], data['label'] 27 | 28 | if opt.use_gpu: 29 | img = img.cuda() 30 | gt = gt.cuda() 31 | 32 | output = model(img) 33 | 34 | _recall, _specificity, _precision, _F1, _F2, _ACC_overall, _IoU_poly, _IoU_bg, _IoU_mean = evaluate(output, gt) 35 | 36 | metrics.update(recall= _recall, specificity= _specificity, precision= _precision, F1= _F1, F2= _F2, ACC_overall= _ACC_overall, IoU_poly= _IoU_poly, IoU_bg= _IoU_bg, IoU_mean= _IoU_mean) 37 | 38 | metrics_result = metrics.mean(total_batch) 39 | 40 | return metrics_result 41 | 42 | 43 | def train(): 44 | 45 | # load model 46 | print('Loading model......') 47 | model = generate_model(opt) 48 | print('Load model:', opt.model) 49 | 50 | # load data 51 | print('Loading data......') 52 | train_data = getattr(datasets, opt.dataset)(opt.root, opt.train_data_dir, mode='train') 53 | train_dataloader = DataLoader(train_data, int(opt.batch_size), shuffle=True, num_workers=opt.num_workers) 54 | valid_data = getattr(datasets, opt.dataset)(opt.root, opt.valid_data_dir, mode='valid') 55 | valid_dataloader = DataLoader(valid_data, batch_size=1, shuffle=False, num_workers=opt.num_workers) 56 | val_total_batch = int(len(valid_data) / 1) 57 | 58 | # load optimizer and scheduler 59 | optimizer = torch.optim.SGD(model.parameters(), lr=opt.lr, momentum=opt.mt, weight_decay=opt.weight_decay) 60 | 61 | lr_lambda = lambda epoch: pow(1.0 - epoch / opt.nEpoch, opt.power) 62 | scheduler = LambdaLR(optimizer, lr_lambda) 63 | 64 | # train 65 | print('Start training') 66 | print('---------------------------------\n') 67 | 68 | results = open('./checkpoints/exp' + str(opt.expID) + "/validResults.txt", "a+") 69 | best_mIoU = 0 70 | best_idx = 0 71 | for epoch in range(opt.nEpoch): 72 | print('------ Epoch', epoch + 1 + 0) 73 | model.train() 74 | total_batch = int(len(train_data) / opt.batch_size) 75 | bar = tqdm(enumerate(train_dataloader), total=total_batch) 76 | 77 | for i, data in bar: 78 | img = data['image'] 79 | gt = data['label'] 80 | 81 | if opt.use_gpu: 82 | img = img.cuda() 83 | gt = gt.cuda() 84 | 85 | optimizer.zero_grad() 86 | output = model(img) 87 | loss = DeepSupervisionLoss(output, gt) 88 | loss.backward() 89 | 90 | optimizer.step() 91 | bar.set_postfix_str('loss: %.5s' % loss.item()) 92 | 93 | scheduler.step() 94 | 95 | metrics_result = valid(model, valid_dataloader, val_total_batch) 96 | 97 | print("\nValid Result of epoch %d:" % (epoch + 1 + 0), file=results) 98 | print('recall: %.4f, specificity: %.4f, precision: %.4f, F1: %.4f, F2: %.4f, ACC_overall: %.4f, IoU_poly: %.4f, IoU_bg: %.4f, IoU_mean: %.4f' % (metrics_result['recall'], metrics_result['specificity'], metrics_result['precision'], metrics_result['F1'], metrics_result['F2'], 
metrics_result['ACC_overall'], metrics_result['IoU_poly'], metrics_result['IoU_bg'], metrics_result['IoU_mean']), file=results) 99 | print("\nValid Result of epoch %d:" % (epoch + 1 + 0)) 100 | print('recall: %.4f, specificity: %.4f, precision: %.4f, F1: %.4f, F2: %.4f, ACC_overall: %.4f, IoU_poly: %.4f, IoU_bg: %.4f, IoU_mean: %.4f' % (metrics_result['recall'], metrics_result['specificity'], metrics_result['precision'], metrics_result['F1'], metrics_result['F2'], metrics_result['ACC_overall'], metrics_result['IoU_poly'], metrics_result['IoU_bg'], metrics_result['IoU_mean'])) 101 | 102 | if ((epoch + 1 + 0) % opt.ckpt_period == 0): 103 | torch.save(model.state_dict(), './checkpoints/exp' + str(opt.expID)+"/ck_{}.pth".format(epoch + 1 + 0)) 104 | 105 | if metrics_result['IoU_mean'] > best_mIoU: 106 | best_idx = epoch + 1 + 0 107 | best_mIoU = metrics_result['IoU_mean'] 108 | torch.save(model.state_dict(), './checkpoints/exp' + str(opt.expID)+"/ck_{}.pth".format(epoch + 1 + 0)) 109 | print("Epoch %d with best mIoU: %.4f" % (best_idx, best_mIoU)) 110 | print("\nEpoch %d with best mIoU: %.4f" % (best_idx, best_mIoU), file=results) 111 | 112 | results.close() 113 | 114 | 115 | if __name__ == '__main__': 116 | os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu 117 | 118 | if opt.mode == 'train': 119 | print('---PolypSeg Train---') 120 | train() 121 | 122 | print('Done') 123 | 124 | 125 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ReaFly/LDNet/b673453a26b6e5219677b09292e441de2b676172/utils/__init__.py -------------------------------------------------------------------------------- /utils/comm.py: -------------------------------------------------------------------------------- 1 | import models 2 | import torch 3 | import cv2 4 | import os 5 | from PIL import Image 6 | import numpy as np 7 | 8 | 9 | def generate_model(opt): 10 | model = getattr(models, opt.model)() 11 | if opt.use_gpu: 12 | model.cuda() 13 | torch.backends.cudnn.benchmark = True 14 | 15 | if opt.load_ckpt is not None: 16 | model_dict = model.state_dict() 17 | load_ckpt_path = os.path.join('./checkpoints/exp'+str(opt.expID)+'/', opt.load_ckpt + '.pth') 18 | assert os.path.isfile(load_ckpt_path), 'No checkpoint found.' 
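        # The checkpoint below is merged into the current state_dict key by key:
        # only entries whose names already exist in the model are copied, so a
        # checkpoint with extra or missing keys does not raise an error but silently
        # leaves unmatched parameters at their initialized values.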
19 | print('Loading checkpoint......') 20 | checkpoint = torch.load(load_ckpt_path) 21 | new_dict = {k : v for k, v in checkpoint.items() if k in model_dict.keys()} 22 | model_dict.update(new_dict) 23 | model.load_state_dict(model_dict) 24 | 25 | print('Done') 26 | 27 | return model 28 | 29 | 30 | def save_binary_img(x, testset, name): 31 | x = x.cpu().data.numpy() 32 | x = np.squeeze(x) # batch_size == 1 33 | x *= 255 34 | img_save_dir = './pred/'+testset 35 | 36 | im = Image.fromarray(x) 37 | if not os.path.exists(img_save_dir): 38 | os.makedirs(img_save_dir) 39 | 40 | if im.mode == 'F': 41 | im = im.convert('L') 42 | im.save(os.path.join(img_save_dir, name[0] + '.png')) 43 | 44 | 45 | 46 | -------------------------------------------------------------------------------- /utils/loss.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import torch 4 | 5 | 6 | """BCE loss""" 7 | class BCELoss(nn.Module): 8 | def __init__(self, weight=None, size_average=True): 9 | super(BCELoss, self).__init__() 10 | self.bceloss = nn.BCELoss(weight=weight, size_average=size_average) 11 | 12 | def forward(self, pred, target): 13 | size = pred.size(0) 14 | pred_flat = pred.view(size, -1) 15 | target_flat = target.view(size, -1) 16 | 17 | loss = self.bceloss(pred_flat, target_flat) 18 | 19 | return loss 20 | 21 | 22 | """Dice loss""" 23 | 24 | 25 | class DiceLoss(nn.Module): 26 | def __init__(self): 27 | super(DiceLoss, self).__init__() 28 | 29 | def forward(self, pred, target): 30 | smooth = 1 31 | 32 | size = pred.size(0) 33 | 34 | pred_flat = pred.view(size, -1) 35 | target_flat = target.view(size, -1) 36 | 37 | intersection = pred_flat * target_flat 38 | dice_score = (2 * intersection.sum(1) + smooth)/(pred_flat.sum(1) + target_flat.sum(1) + smooth) 39 | dice_loss = 1 - dice_score.sum()/size 40 | 41 | return dice_loss 42 | 43 | 44 | """BCE + DICE Loss""" 45 | 46 | 47 | class BceDiceLoss(nn.Module): 48 | def __init__(self, weight=None, size_average=True): 49 | super(BceDiceLoss, self).__init__() 50 | self.bce = BCELoss(weight, size_average) 51 | self.dice = DiceLoss() 52 | 53 | def forward(self, pred, target): 54 | #pred = torch.sigmoid(pred) 55 | bceloss = self.bce(pred, target) 56 | diceloss = self.dice(pred, target) 57 | 58 | loss = diceloss + bceloss 59 | 60 | return loss 61 | 62 | 63 | """ Deep Supervision Loss""" 64 | 65 | 66 | def DeepSupervisionLoss(pred, gt): 67 | 68 | d0, d1, d2, d3, d4 = pred[:] 69 | 70 | criterion = BceDiceLoss() 71 | loss0 = criterion(d0, gt) 72 | gt = F.interpolate(gt, scale_factor=0.5, mode='bilinear', align_corners=True) 73 | loss1 = criterion(d1, gt) 74 | gt = F.interpolate(gt, scale_factor=0.5, mode='bilinear', align_corners=True) 75 | loss2 = criterion(d2, gt) 76 | gt = F.interpolate(gt, scale_factor=0.5, mode='bilinear', align_corners=True) 77 | loss3 = criterion(d3, gt) 78 | gt = F.interpolate(gt, scale_factor=0.5, mode='bilinear', align_corners=True) 79 | loss4 = criterion(d4, gt) 80 | 81 | return loss0 + loss1 + loss2 + loss3 + loss4 82 | 83 | 84 | 85 | -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .comm import save_binary_img 3 | from opt import opt 4 | 5 | 6 | def evaluate(pred, gt, name=None, testset=None): 7 | if isinstance(pred, (list, tuple)): 8 | pred = pred[0] 9 | #pred = torch.sigmoid(pred) 10 | 
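    # The model already applies sigmoid to its outputs (see LDNet.forward), which is why the
    # commented-out sigmoid above is unnecessary; predictions are thresholded at 0.5 directly.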
pred_binary = (pred >= 0.5).float() 11 | pred_binary_inverse = (pred_binary == 0).float() 12 | 13 | gt_binary = (gt >= 0.5).float() 14 | gt_binary_inverse = (gt_binary == 0).float() 15 | 16 | if opt.save_img == True: 17 | save_binary_img(pred_binary, testset, name) 18 | 19 | TP = pred_binary.mul(gt_binary).sum() 20 | FP = pred_binary.mul(gt_binary_inverse).sum() 21 | TN = pred_binary_inverse.mul(gt_binary_inverse).sum() 22 | FN = pred_binary_inverse.mul(gt_binary).sum() 23 | 24 | if TP.item() == 0: 25 | TP = torch.Tensor([1]).cuda() 26 | 27 | # recall 28 | Recall = TP / (TP + FN) 29 | 30 | # Specificity or true negative rate 31 | Specificity = TN / (TN + FP) 32 | 33 | # Precision or positive predictive value 34 | Precision = TP / (TP + FP) 35 | 36 | # F1 score = Dice 37 | F1 = 2 * Precision * Recall / (Precision + Recall) 38 | 39 | # F2 score 40 | F2 = 5 * Precision * Recall / (4 * Precision + Recall) 41 | 42 | # Overall accuracy 43 | ACC_overall = (TP + TN) / (TP + FP + FN + TN) 44 | 45 | # IoU for poly 46 | IoU_poly = TP / (TP + FP + FN) 47 | 48 | # IoU for background 49 | IoU_bg = TN / (TN + FP + FN) 50 | 51 | # mean IoU 52 | IoU_mean = (IoU_poly + IoU_bg) / 2.0 53 | 54 | return Recall, Specificity, Precision, F1, F2, ACC_overall, IoU_poly, IoU_bg, IoU_mean 55 | 56 | 57 | class Metrics(object): 58 | def __init__(self, metrics_list): 59 | self.metrics = {} 60 | for metric in metrics_list: 61 | self.metrics[metric] = 0 62 | 63 | def update(self, **kwargs): 64 | for k, v in kwargs.items(): 65 | assert (k in self.metrics.keys()), "The k {} is not in metrics".format(k) 66 | if isinstance(v, torch.Tensor): 67 | v = v.item() 68 | 69 | self.metrics[k] += v 70 | 71 | def mean(self, total): 72 | mean_metrics = {} 73 | for k, v in self.metrics.items(): 74 | mean_metrics[k] = v / total 75 | return mean_metrics 76 | -------------------------------------------------------------------------------- /utils/transform.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.transforms.functional as F 3 | import scipy.ndimage 4 | import random 5 | from PIL import Image 6 | import numpy as np 7 | import cv2 8 | from skimage import transform as tf 9 | import numbers 10 | 11 | class ToTensor(object): 12 | 13 | def __call__(self, data): 14 | image, label = data['image'], data['label'] 15 | return {'image': F.to_tensor(image), 'label': F.to_tensor(label)} 16 | 17 | 18 | class Resize(object): 19 | 20 | def __init__(self, size): 21 | self.size = size 22 | 23 | def __call__(self, data): 24 | image, label = data['image'], data['label'] 25 | 26 | return {'image': F.resize(image, self.size), 'label': F.resize(label, self.size)} 27 | 28 | 29 | class RandomHorizontalFlip(object): 30 | def __init__(self, p=0.5): 31 | self.p = p 32 | 33 | def __call__(self, data): 34 | image, label = data['image'], data['label'] 35 | 36 | if random.random() < self.p: 37 | return {'image': F.hflip(image), 'label': F.hflip(label)} 38 | 39 | return {'image': image, 'label': label} 40 | 41 | 42 | class RandomVerticalFlip(object): 43 | def __init__(self, p=0.5): 44 | self.p = p 45 | 46 | def __call__(self, data): 47 | image, label = data['image'], data['label'] 48 | 49 | if random.random() < self.p: 50 | return {'image': F.vflip(image), 'label': F.vflip(label)} 51 | 52 | return {'image': image, 'label': label} 53 | 54 | 55 | class RandomRotation(object): 56 | 57 | def __init__(self, degrees, resample=False, expand=False, center=None): 58 | if 
isinstance(degrees,numbers.Number): 59 | if degrees < 0: 60 | raise ValueError("If degrees is a single number, it must be positive.") 61 | self.degrees = (-degrees, degrees) 62 | else: 63 | if len(degrees) != 2: 64 | raise ValueError("If degrees is a sequence, it must be of len 2.") 65 | self.degrees = degrees 66 | self.resample = resample 67 | self.expand = expand 68 | self.center = center 69 | 70 | @staticmethod 71 | def get_params(degrees): 72 | """Get parameters for ``rotate`` for a random rotation. 73 | 74 | Returns: 75 | sequence: params to be passed to ``rotate`` for random rotation. 76 | """ 77 | angle = random.uniform(degrees[0], degrees[1]) 78 | 79 | return angle 80 | 81 | def __call__(self, data): 82 | 83 | """ 84 | img (PIL Image): Image to be rotated. 85 | 86 | Returns: 87 | PIL Image: Rotated image. 88 | """ 89 | image, label = data['image'], data['label'] 90 | 91 | if random.random() < 0.5: 92 | angle = self.get_params(self.degrees) 93 | return {'image': F.rotate(image, angle, self.resample, self.expand, self.center), 94 | 'label': F.rotate(label, angle, self.resample, self.expand, self.center)} 95 | 96 | return {'image': image, 'label': label} 97 | 98 | 99 | class RandomZoom(object): 100 | def __init__(self, zoom=(0.8, 1.2)): 101 | self.min, self.max = zoom[0], zoom[1] 102 | 103 | def __call__(self, data): 104 | image, label = data['image'], data['label'] 105 | 106 | if random.random() < 0.5: 107 | image = np.array(image) 108 | label = np.array(label) 109 | 110 | zoom = random.uniform(self.min, self.max) 111 | zoom_image = clipped_zoom(image, zoom) 112 | zoom_label = clipped_zoom(label, zoom) 113 | 114 | zoom_image = Image.fromarray(zoom_image.astype('uint8'), 'RGB') 115 | zoom_label = Image.fromarray(zoom_label.astype('uint8'), 'L') 116 | return {'image': zoom_image, 'label': zoom_label} 117 | 118 | return {'image': image, 'label': label} 119 | 120 | 121 | def clipped_zoom(img, zoom_factor, **kwargs): 122 | h, w = img.shape[:2] 123 | 124 | # For multichannel images we don't want to apply the zoom factor to the RGB 125 | # dimension, so instead we create a tuple of zoom factors, one per array 126 | # dimension, with 1's for any trailing dimensions after the width and height. 
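    # e.g. an RGB array of shape (H, W, 3) gets zoom factors (zoom, zoom, 1), while a 2-D mask gets (zoom, zoom)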
127 | zoom_tuple = (zoom_factor,) * 2 + (1,) * (img.ndim - 2) 128 | 129 | # Zooming out 130 | if zoom_factor < 1: 131 | 132 | # Bounding box of the zoomed-out image within the output array 133 | zh = int(np.round(h * zoom_factor)) 134 | zw = int(np.round(w * zoom_factor)) 135 | top = (h - zh) // 2 136 | left = (w - zw) // 2 137 | 138 | # Zero-padding 139 | out = np.zeros_like(img) 140 | out[top:top + zh, left:left + zw] = scipy.ndimage.zoom(img, zoom_tuple, **kwargs) 141 | 142 | # Zooming in 143 | elif zoom_factor > 1: 144 | 145 | # Bounding box of the zoomed-in region within the input array 146 | zh = int(np.round(h / zoom_factor)) 147 | zw = int(np.round(w / zoom_factor)) 148 | top = (h - zh) // 2 149 | left = (w - zw) // 2 150 | 151 | zoom_in = scipy.ndimage.zoom(img[top:top + zh, left:left + zw], zoom_tuple, **kwargs) 152 | 153 | # `zoom_in` might still be slightly different with `img` due to rounding, so 154 | # trim off any extra pixels at the edges or zero-padding 155 | 156 | if zoom_in.shape[0] >= h: 157 | zoom_top = (zoom_in.shape[0] - h) // 2 158 | sh = h 159 | out_top = 0 160 | oh = h 161 | else: 162 | zoom_top = 0 163 | sh = zoom_in.shape[0] 164 | out_top = (h - zoom_in.shape[0]) // 2 165 | oh = zoom_in.shape[0] 166 | if zoom_in.shape[1] >= w: 167 | zoom_left = (zoom_in.shape[1] - w) // 2 168 | sw = w 169 | out_left = 0 170 | ow = w 171 | else: 172 | zoom_left = 0 173 | sw = zoom_in.shape[1] 174 | out_left = (w - zoom_in.shape[1]) // 2 175 | ow = zoom_in.shape[1] 176 | 177 | out = np.zeros_like(img) 178 | out[out_top:out_top + oh, out_left:out_left + ow] = zoom_in[zoom_top:zoom_top + sh, zoom_left:zoom_left + sw] 179 | 180 | # If zoom_factor == 1, just return the input array 181 | else: 182 | out = img 183 | return out 184 | 185 | 186 | class Translation(object): 187 | def __init__(self, translation): 188 | self.translation = translation 189 | 190 | def __call__(self, data): 191 | image, label = data['image'], data['label'] 192 | 193 | if random.random() < 0.5: 194 | image = np.array(image) 195 | label = np.array(label) 196 | rows, cols, ch = image.shape 197 | 198 | translation = random.uniform(0, self.translation) 199 | tr_x = translation / 2 200 | tr_y = translation / 2 201 | Trans_M = np.float32([[1, 0, tr_x], [0, 1, tr_y]]) 202 | 203 | translate_image = cv2.warpAffine(image, Trans_M, (cols, rows)) 204 | translate_label = cv2.warpAffine(label, Trans_M, (cols, rows)) 205 | 206 | translate_image = Image.fromarray(translate_image.astype('uint8'), 'RGB') 207 | translate_label = Image.fromarray(translate_label.astype('uint8'), 'L') 208 | 209 | return {'image': translate_image, 'label': translate_label} 210 | 211 | return {'image': image, 'label': label} 212 | 213 | 214 | class RandomCrop(object): 215 | def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode='constant'): 216 | if isinstance(size, numbers.Number): 217 | self.size = (int(size), int(size)) 218 | else: 219 | self.size = size 220 | self.padding = padding 221 | self.pad_if_needed = pad_if_needed 222 | self.fill = fill 223 | self.padding_mode = padding_mode 224 | 225 | @staticmethod 226 | def get_params(img, output_size): 227 | """Get parameters for ``crop`` for a random crop. 228 | Args: 229 | img (PIL Image): Image to be cropped. 230 | output_size (tuple): Expected output size of the crop. 231 | Returns: 232 | tuple: params (i, j, h, w) to be passed to ``crop`` for random crop. 
233 | """ 234 | w, h = img.size 235 | th, tw = output_size 236 | if w == tw and h == th: 237 | return 0, 0, h, w 238 | 239 | #i = torch.randint(0, h - th + 1, size=(1, )).item() 240 | #j = torch.randint(0, w - tw + 1, size=(1, )).item() 241 | i = random.randint(0, h - th) 242 | j = random.randint(0, w - tw) 243 | return i, j, th, tw 244 | 245 | 246 | def __call__(self, data): 247 | """ 248 | Args: 249 | img (PIL Image): Image to be cropped. 250 | Returns: 251 | PIL Image: Cropped image. 252 | """ 253 | img, label = data['image'], data['label'] 254 | if self.padding is not None: 255 | img = F.pad(img, self.padding, self.fill, self.padding_mode) 256 | label = F.pad(label, self.padding, self.fill, self.padding_mode) 257 | # pad the width if needed 258 | if self.pad_if_needed and img.size[0] < self.size[1]: 259 | img = F.pad(img, (self.size[1] - img.size[0], 0), self.fill, self.padding_mode) 260 | label = F.pad(label, (self.size[1] - label.size[0], 0), self.fill, self.padding_mode) 261 | 262 | # pad the height if needed 263 | if self.pad_if_needed and img.size[1] < self.size[0]: 264 | img = F.pad(img, (0, self.size[0] - img.size[1]), self.fill, self.padding_mode) 265 | label = F.pad(label, (0, self.size[0] - img.size[1]), self.fill, self.padding_mode) 266 | i, j, h, w = self.get_params(img, self.size) 267 | img = F.crop(img, i, j ,h ,w) 268 | label = F.crop(label, i, j, h, w) 269 | return {"image": img, "label": label} 270 | 271 | 272 | class Normalization(object): 273 | 274 | def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]): 275 | self.mean = mean 276 | self.std = std 277 | 278 | def __call__(self, sample): 279 | image, label = sample['image'], sample['label'] 280 | image = F.normalize(image, self.mean, self.std) 281 | return {'image': image, 'label': label} 282 | 283 | -------------------------------------------------------------------------------- /utils/transform_multi.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.transforms.functional as F 3 | import scipy.ndimage 4 | import random 5 | from PIL import Image 6 | import numpy as np 7 | import cv2 8 | from skimage import transform as tf 9 | import numbers 10 | 11 | class ToTensor(object): 12 | 13 | def __call__(self, data): 14 | trans_data = dict() 15 | for k, v in data.items(): 16 | trans_data[k] = F.to_tensor(v) 17 | 18 | return trans_data 19 | 20 | class ToPILImage(object): 21 | 22 | def __call__(self, data): 23 | trans_data = dict() 24 | for k, v in data.items(): 25 | trans_data[k] = F.to_pil_image(v) 26 | 27 | return trans_data 28 | 29 | 30 | class Resize(object): 31 | 32 | def __init__(self, size): 33 | self.size = size 34 | 35 | def __call__(self, data): 36 | trans_data = dict() 37 | for k, v in data.items(): 38 | trans_data[k] = F.resize(v, self.size) 39 | 40 | return trans_data 41 | 42 | 43 | class RandomHorizontalFlip(object): 44 | def __init__(self, p=0.5): 45 | self.p = p 46 | 47 | def __call__(self, data): 48 | trans_data = dict() 49 | 50 | if random.random() < self.p: 51 | for k, v in data.items(): 52 | trans_data[k] = F.hflip(v) 53 | else: 54 | trans_data = data 55 | 56 | return trans_data 57 | 58 | 59 | class RandomVerticalFlip(object): 60 | def __init__(self, p=0.5): 61 | self.p = p 62 | 63 | def __call__(self, data): 64 | trans_data = dict() 65 | 66 | if random.random() < self.p: 67 | for k, v in data.items(): 68 | trans_data[k] = F.vflip(v) 69 | else: 70 | trans_data = data 71 | 72 | return trans_data 73 | 74 | 75 | 
class RandomRotation(object): 76 | 77 | def __init__(self, degrees=90, resample=None, expand=False, center=None): 78 | if isinstance(degrees,numbers.Number): 79 | if degrees < 0: 80 | raise ValueError("If degrees is a single number, it must be positive.") 81 | self.degrees = (-degrees, degrees) 82 | else: 83 | if len(degrees) != 2: 84 | raise ValueError("If degrees is a sequence, it must be of len 2.") 85 | self.degrees = degrees 86 | self.resample = resample 87 | self.expand = expand 88 | self.center = center 89 | 90 | @staticmethod 91 | def get_params(degrees): 92 | """Get parameters for ``rotate`` for a random rotation. 93 | 94 | Returns: 95 | sequence: params to be passed to ``rotate`` for random rotation. 96 | """ 97 | angle = random.uniform(degrees[0], degrees[1]) 98 | 99 | return angle 100 | 101 | def __call__(self, data): 102 | 103 | """ 104 | img (PIL Image): Image to be rotated. 105 | 106 | Returns: 107 | PIL Image: Rotated image. 108 | """ 109 | trans_data = dict() 110 | 111 | if random.random() < 0.5: 112 | angle = self.get_params(self.degrees) 113 | for k, v in data.items(): 114 | trans_data[k] = F.rotate(v, angle=angle, resample=self.resample, expand=self.expand, center=self.center) 115 | else: 116 | trans_data = data 117 | 118 | return trans_data 119 | 120 | 121 | class RandomZoom(object): 122 | def __init__(self, zoom=(0.8, 1.2)): 123 | self.min, self.max = zoom[0], zoom[1] 124 | 125 | def __call__(self, data): 126 | trans_data = dict() 127 | 128 | if random.random() < 0.5: 129 | zoom = random.uniform(self.min, self.max) 130 | for k, v in data.items(): 131 | mode = v.mode 132 | v = np.array(v) 133 | zoom_v = clipped_zoom(v, zoom) 134 | zoom_v = Image.fromarray(zoom_v.astype('uint8'), mode) 135 | trans_data[k] = zoom_v 136 | else: 137 | trans_data = data 138 | 139 | return trans_data 140 | 141 | 142 | def clipped_zoom(img, zoom_factor, **kwargs): 143 | h, w = img.shape[:2] 144 | 145 | # For multichannel images we don't want to apply the zoom factor to the RGB 146 | # dimension, so instead we create a tuple of zoom factors, one per array 147 | # dimension, with 1's for any trailing dimensions after the width and height. 
148 | zoom_tuple = (zoom_factor,) * 2 + (1,) * (img.ndim - 2) 149 | 150 | # Zooming out 151 | if zoom_factor < 1: 152 | 153 | # Bounding box of the zoomed-out image within the output array 154 | zh = int(np.round(h * zoom_factor)) 155 | zw = int(np.round(w * zoom_factor)) 156 | top = (h - zh) // 2 157 | left = (w - zw) // 2 158 | 159 | # Zero-padding 160 | out = np.zeros_like(img) 161 | out[top:top + zh, left:left + zw] = scipy.ndimage.zoom(img, zoom_tuple, **kwargs) 162 | 163 | # Zooming in 164 | elif zoom_factor > 1: 165 | 166 | # Bounding box of the zoomed-in region within the input array 167 | zh = int(np.round(h / zoom_factor)) 168 | zw = int(np.round(w / zoom_factor)) 169 | top = (h - zh) // 2 170 | left = (w - zw) // 2 171 | 172 | zoom_in = scipy.ndimage.zoom(img[top:top + zh, left:left + zw], zoom_tuple, **kwargs) 173 | 174 | # `zoom_in` might still be slightly different with `img` due to rounding, so 175 | # trim off any extra pixels at the edges or zero-padding 176 | 177 | if zoom_in.shape[0] >= h: 178 | zoom_top = (zoom_in.shape[0] - h) // 2 179 | sh = h 180 | out_top = 0 181 | oh = h 182 | else: 183 | zoom_top = 0 184 | sh = zoom_in.shape[0] 185 | out_top = (h - zoom_in.shape[0]) // 2 186 | oh = zoom_in.shape[0] 187 | if zoom_in.shape[1] >= w: 188 | zoom_left = (zoom_in.shape[1] - w) // 2 189 | sw = w 190 | out_left = 0 191 | ow = w 192 | else: 193 | zoom_left = 0 194 | sw = zoom_in.shape[1] 195 | out_left = (w - zoom_in.shape[1]) // 2 196 | ow = zoom_in.shape[1] 197 | 198 | out = np.zeros_like(img) 199 | out[out_top:out_top + oh, out_left:out_left + ow] = zoom_in[zoom_top:zoom_top + sh, zoom_left:zoom_left + sw] 200 | 201 | # If zoom_factor == 1, just return the input array 202 | else: 203 | out = img 204 | return out 205 | 206 | 207 | 208 | 209 | class RandomCrop(object): 210 | def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode='constant'): 211 | if isinstance(size, numbers.Number): 212 | self.size = (int(size), int(size)) 213 | else: 214 | self.size = size 215 | self.padding = padding 216 | self.pad_if_needed = pad_if_needed 217 | self.fill = fill 218 | self.padding_mode = padding_mode 219 | 220 | @staticmethod 221 | def get_params(img, output_size): 222 | """Get parameters for ``crop`` for a random crop. 223 | Args: 224 | img (PIL Image): Image to be cropped. 225 | output_size (tuple): Expected output size of the crop. 226 | Returns: 227 | tuple: params (i, j, h, w) to be passed to ``crop`` for random crop. 228 | """ 229 | w, h = img.size 230 | 231 | th, tw = output_size 232 | if w == tw and h == th: 233 | return 0, 0, h, w 234 | 235 | #i = torch.randint(0, h - th + 1, size=(1, )).item() 236 | #j = torch.randint(0, w - tw + 1, size=(1, )).item() 237 | i = random.randint(0, h - th) 238 | j = random.randint(0, w - tw) 239 | return i, j, th, tw 240 | 241 | 242 | def __call__(self, data): 243 | """ 244 | Args: 245 | img (PIL Image): Image to be cropped. 246 | Returns: 247 | PIL Image: Cropped image. 
248 | """ 249 | trans_data = dict() 250 | i, j, h, w = self.get_params(data['image'], self.size) 251 | for k, v in data.items(): 252 | if self.padding is not None: 253 | v = F.pad(v, self.padding, self.fill, self.padding_mode) 254 | # pad the width if needed 255 | if self.pad_if_needed and v.size[0] < self.size[1]: 256 | v = F.pad(v, (self.size[1] - v.size[0], 0), self.fill, self.padding_mode) 257 | # pad the height if needed 258 | if self.pad_if_needed and v.size[1] < self.size[0]: 259 | v = F.pad(v, (0, self.size[0] - v.size[1]), self.fill, self.padding_mode) 260 | v = F.crop(v, i, j ,h ,w) 261 | trans_data[k] = v 262 | return trans_data 263 | 264 | 265 | class Normalize(object): 266 | 267 | def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]): 268 | self.mean = mean 269 | self.std = std 270 | 271 | def __call__(self, data): 272 | trans_data = data 273 | trans_data['image'] = F.normalize(trans_data['image'], self.mean, self.std) 274 | return trans_data 275 | 276 | --------------------------------------------------------------------------------
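The dict-based transforms above draw their random parameters once per call and apply them to every entry of the input dict, so the segmentation mask stays pixel-aligned with the image. The snippet below is a minimal usage sketch, not part of the repository: the image paths and the particular pipeline are illustrative assumptions, and it simply composes the classes defined in `utils/transform_multi.py`.

```python
# Minimal usage sketch (hypothetical): compose the paired transforms from
# utils/transform_multi.py. The file names below are placeholders, not files
# shipped with the repository.
from PIL import Image
from torchvision import transforms

from utils.transform_multi import (Resize, RandomHorizontalFlip, RandomVerticalFlip,
                                   RandomZoom, RandomCrop, ToTensor, Normalize)

train_transform = transforms.Compose([
    Resize((256, 256)),
    RandomHorizontalFlip(),
    RandomVerticalFlip(),
    RandomZoom((0.9, 1.1)),
    RandomCrop((224, 224)),
    ToTensor(),
    Normalize(),  # normalizes only the 'image' entry and leaves the mask untouched
])

# Both entries receive the same flip / zoom / crop parameters.
sample = {'image': Image.open('example.jpg').convert('RGB'),
          'label': Image.open('example_mask.png').convert('L')}
sample = train_transform(sample)

print(sample['image'].shape)  # torch.Size([3, 224, 224])
print(sample['label'].shape)  # torch.Size([1, 224, 224])
```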