├── DuAT.py ├── Fig ├── fig1.png ├── fig2.png ├── fig3.png ├── fig4.png ├── fig5.png └── fig6.png ├── README.md ├── Test.py ├── Train.py ├── lib ├── DuAT.py └── pvtv2.py ├── test.sh ├── train.sh └── utils ├── Readme.md ├── dataloader.py ├── format_conversion.py └── utils.py /DuAT.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from lib.pvtv2 import pvt_v2_b2 5 | import os 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from mmcv.cnn import ConvModule 10 | from torch.nn import Conv2d, UpsamplingBilinear2d 11 | import warnings 12 | import torch 13 | from mmcv.cnn import constant_init, kaiming_init 14 | from torch import nn 15 | from torchvision.transforms.functional import normalize 16 | warnings.filterwarnings('ignore') 17 | 18 | 19 | class BasicConv2d(nn.Module): 20 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1): 21 | super(BasicConv2d, self).__init__() 22 | 23 | self.conv = nn.Conv2d(in_planes, out_planes, 24 | kernel_size=kernel_size, stride=stride, 25 | padding=padding, dilation=dilation, bias=False) 26 | self.bn = nn.BatchNorm2d(out_planes) 27 | self.relu = nn.ReLU(inplace=True) 28 | 29 | def forward(self, x): 30 | x = self.conv(x) 31 | x = self.bn(x) 32 | x = self.relu(x) 33 | return x 34 | 35 | class Block(nn.Sequential): 36 | def __init__(self, input_num, num1, num2, dilation_rate, drop_out, bn_start=True, norm_layer=nn.BatchNorm2d): 37 | super(Block, self).__init__() 38 | if bn_start: 39 | self.add_module('norm1', norm_layer(input_num)), 40 | 41 | self.add_module('relu1', nn.ReLU(inplace=True)), 42 | self.add_module('conv1', nn.Conv2d(in_channels=input_num, out_channels=num1, kernel_size=1)), 43 | 44 | self.add_module('norm2', norm_layer(num1)), 45 | self.add_module('relu2', nn.ReLU(inplace=True)), 46 | self.add_module('conv2', nn.Conv2d(in_channels=num1, out_channels=num2, kernel_size=3, 47 | dilation=dilation_rate, padding=dilation_rate)), 48 | self.drop_rate = drop_out 49 | 50 | def forward(self, _input): 51 | feature = super(Block, self).forward(_input) 52 | if self.drop_rate > 0: 53 | feature = F.dropout2d(feature, p=self.drop_rate, training=self.training) 54 | return feature 55 | 56 | 57 | def Upsample(x, size, align_corners = False): 58 | """ 59 | Wrapper Around the Upsample Call 60 | """ 61 | return nn.functional.interpolate(x, size=size, mode='bilinear', align_corners=align_corners) 62 | 63 | 64 | def last_zero_init(m): 65 | if isinstance(m, nn.Sequential): 66 | constant_init(m[-1], val=0) 67 | else: 68 | constant_init(m, val=0) 69 | 70 | 71 | class ContextBlock(nn.Module): 72 | 73 | def __init__(self, 74 | inplanes, 75 | ratio, 76 | pooling_type='att', 77 | fusion_types=('channel_mul', )): 78 | super(ContextBlock, self).__init__() 79 | assert pooling_type in ['avg', 'att'] 80 | assert isinstance(fusion_types, (list, tuple)) 81 | valid_fusion_types = ['channel_add', 'channel_mul'] 82 | assert all([f in valid_fusion_types for f in fusion_types]) 83 | assert len(fusion_types) > 0, 'at least one fusion should be used' 84 | self.inplanes = inplanes 85 | self.ratio = ratio 86 | self.planes = int(inplanes * ratio) 87 | self.pooling_type = pooling_type 88 | self.fusion_types = fusion_types 89 | if pooling_type == 'att': 90 | self.conv_mask = nn.Conv2d(inplanes, 1, kernel_size=1) 91 | self.softmax = nn.Softmax(dim=2) 92 | else: 93 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 94 | if 
'channel_add' in fusion_types: 95 | self.channel_add_conv = nn.Sequential( 96 | nn.Conv2d(self.inplanes, self.planes, kernel_size=1), 97 | nn.LayerNorm([self.planes, 1, 1]), 98 | nn.ReLU(inplace=True), # yapf: disable 99 | nn.Conv2d(self.planes, self.inplanes, kernel_size=1)) 100 | else: 101 | self.channel_add_conv = None 102 | if 'channel_mul' in fusion_types: 103 | self.channel_mul_conv = nn.Sequential( 104 | nn.Conv2d(self.inplanes, self.planes, kernel_size=1), 105 | nn.LayerNorm([self.planes, 1, 1]), 106 | nn.ReLU(inplace=True), # yapf: disable 107 | nn.Conv2d(self.planes, self.inplanes, kernel_size=1)) 108 | else: 109 | self.channel_mul_conv = None 110 | self.reset_parameters() 111 | 112 | def reset_parameters(self): 113 | if self.pooling_type == 'att': 114 | kaiming_init(self.conv_mask, mode='fan_in') 115 | self.conv_mask.inited = True 116 | 117 | if self.channel_add_conv is not None: 118 | last_zero_init(self.channel_add_conv) 119 | if self.channel_mul_conv is not None: 120 | last_zero_init(self.channel_mul_conv) 121 | 122 | def spatial_pool(self, x): 123 | batch, channel, height, width = x.size() 124 | if self.pooling_type == 'att': 125 | input_x = x 126 | # [N, C, H * W] 127 | input_x = input_x.view(batch, channel, height * width) 128 | # [N, 1, C, H * W] 129 | input_x = input_x.unsqueeze(1) 130 | # [N, 1, H, W] 131 | context_mask = self.conv_mask(x) 132 | # [N, 1, H * W] 133 | context_mask = context_mask.view(batch, 1, height * width) 134 | # [N, 1, H * W] 135 | context_mask = self.softmax(context_mask) 136 | # [N, 1, H * W, 1] 137 | context_mask = context_mask.unsqueeze(-1) 138 | # [N, 1, C, 1] 139 | context = torch.matmul(input_x, context_mask) 140 | # [N, C, 1, 1] 141 | context = context.view(batch, channel, 1, 1) 142 | else: 143 | # [N, C, 1, 1] 144 | context = self.avg_pool(x) 145 | 146 | return context 147 | 148 | def forward(self, x): 149 | # [N, C, 1, 1] 150 | context = self.spatial_pool(x) 151 | 152 | out = x 153 | if self.channel_mul_conv is not None: 154 | # [N, C, 1, 1] 155 | channel_mul_term = torch.sigmoid(self.channel_mul_conv(context)) 156 | out = out + out * channel_mul_term 157 | if self.channel_add_conv is not None: 158 | # [N, C, 1, 1] 159 | channel_add_term = self.channel_add_conv(context) 160 | out = out + channel_add_term 161 | 162 | return out 163 | 164 | 165 | 166 | class ChannelAttention(nn.Module): 167 | def __init__(self, in_planes, ratio=16): 168 | super(ChannelAttention, self).__init__() 169 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 170 | self.max_pool = nn.AdaptiveMaxPool2d(1) 171 | 172 | self.fc1 = nn.Conv2d(in_planes, in_planes // 16, 1, bias=False) 173 | self.relu1 = nn.ReLU() 174 | self.fc2 = nn.Conv2d(in_planes // 16, in_planes, 1, bias=False) 175 | 176 | self.sigmoid = nn.Sigmoid() 177 | 178 | def forward(self, x): 179 | avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x)))) 180 | max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x)))) 181 | out = avg_out + max_out 182 | return self.sigmoid(out) 183 | 184 | 185 | class SpatialAttention(nn.Module): 186 | def __init__(self, kernel_size=7): 187 | super(SpatialAttention, self).__init__() 188 | 189 | assert kernel_size in (3, 7), 'kernel size must be 3 or 7' 190 | padding = 3 if kernel_size == 7 else 1 191 | 192 | self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False) 193 | self.sigmoid = nn.Sigmoid() 194 | 195 | def forward(self, x): 196 | avg_out = torch.mean(x, dim=1, keepdim=True) 197 | max_out, _ = torch.max(x, dim=1, keepdim=True) 198 | x = torch.cat([avg_out, max_out], 
dim=1) 199 | x = self.conv1(x) 200 | return self.sigmoid(x) 201 | 202 | 203 | class ConvBranch(nn.Module): 204 | def __init__(self, in_features, hidden_features = None, out_features = None): 205 | super().__init__() 206 | hidden_features = hidden_features or in_features 207 | out_features = out_features or in_features 208 | self.conv1 = nn.Sequential( 209 | nn.Conv2d(in_features, hidden_features, 1, bias=False), 210 | nn.BatchNorm2d(hidden_features), 211 | nn.ReLU(inplace=True) 212 | ) 213 | self.conv2 = nn.Sequential( 214 | nn.Conv2d(hidden_features, hidden_features, 3, padding=1, groups=hidden_features, bias=False), 215 | nn.BatchNorm2d(hidden_features), 216 | nn.ReLU(inplace=True) 217 | ) 218 | self.conv3 = nn.Sequential( 219 | nn.Conv2d(hidden_features, hidden_features, 1, bias=False), 220 | nn.BatchNorm2d(hidden_features), 221 | nn.ReLU(inplace=True) 222 | ) 223 | self.conv4 = nn.Sequential( 224 | nn.Conv2d(hidden_features, hidden_features, 3, padding=1, groups=hidden_features, bias=False), 225 | nn.BatchNorm2d(hidden_features), 226 | nn.ReLU(inplace=True) 227 | ) 228 | self.conv5 = nn.Sequential( 229 | nn.Conv2d(hidden_features, hidden_features, 1, bias=False), 230 | nn.BatchNorm2d(hidden_features), 231 | nn.SiLU(inplace=True) 232 | ) 233 | self.conv6 = nn.Sequential( 234 | nn.Conv2d(hidden_features, hidden_features, 3, padding=1, groups=hidden_features, bias=False), 235 | nn.BatchNorm2d(hidden_features), 236 | nn.ReLU(inplace=True) 237 | ) 238 | self.conv7 = nn.Sequential( 239 | nn.Conv2d(hidden_features, out_features, 1, bias=False), 240 | nn.ReLU(inplace=True) 241 | ) 242 | self.ca = ChannelAttention(64) 243 | self.sa = SpatialAttention() 244 | self.sigmoid_spatial = nn.Sigmoid() 245 | 246 | def forward(self, x): 247 | res1 = x 248 | res2 = x 249 | x = self.conv1(x) 250 | x = x + self.conv2(x) 251 | x = self.conv3(x) 252 | x = x + self.conv4(x) 253 | x = self.conv5(x) 254 | x = x + self.conv6(x) 255 | x = self.conv7(x) 256 | x_mask = self.sigmoid_spatial(x) 257 | res1 = res1 * x_mask 258 | return res2 + res1 259 | 260 | 261 | class GLSA(nn.Module): 262 | 263 | def __init__(self, input_dim=512, embed_dim=32, k_s=3): 264 | super().__init__() 265 | 266 | self.conv1_1 = BasicConv2d(embed_dim*2,embed_dim, 1) 267 | self.conv1_1_1 = BasicConv2d(input_dim//2,embed_dim,1) 268 | self.local_11conv = nn.Conv2d(input_dim//2,embed_dim,1) 269 | self.global_11conv = nn.Conv2d(input_dim//2,embed_dim,1) 270 | self.GlobelBlock = ContextBlock(inplanes= embed_dim, ratio=2) 271 | self.local = ConvBranch(in_features = embed_dim, hidden_features = embed_dim, out_features = embed_dim) 272 | 273 | def forward(self, x): 274 | b, c, h, w = x.size() 275 | x_0, x_1 = x.chunk(2,dim = 1) 276 | 277 | # local block 278 | local = self.local(self.local_11conv(x_0)) 279 | 280 | # Globel block 281 | Globel = self.GlobelBlock(self.global_11conv(x_1)) 282 | 283 | # concat Globel + local 284 | x = torch.cat([local,Globel], dim=1) 285 | x = self.conv1_1(x) 286 | 287 | return x 288 | 289 | class SBA(nn.Module): 290 | 291 | def __init__(self,input_dim = 64): 292 | super().__init__() 293 | 294 | self.input_dim = input_dim 295 | 296 | self.d_in1 = BasicConv2d(input_dim//2, input_dim//2, 1) 297 | self.d_in2 = BasicConv2d(input_dim//2, input_dim//2, 1) 298 | 299 | 300 | self.conv = nn.Sequential(BasicConv2d(input_dim, input_dim, 3,1,1), nn.Conv2d(input_dim, 1, kernel_size=1, bias=False)) 301 | self.fc1 = nn.Conv2d(input_dim, input_dim//2, kernel_size=1, bias=False) 302 | self.fc2 = nn.Conv2d(input_dim, input_dim//2, 
kernel_size=1, bias=False) 303 | 304 | self.Sigmoid = nn.Sigmoid() 305 | 306 | def forward(self, H_feature, L_feature): 307 | 308 | L_feature = self.fc1(L_feature) 309 | H_feature = self.fc2(H_feature) 310 | 311 | g_L_feature = self.Sigmoid(L_feature) 312 | g_H_feature = self.Sigmoid(H_feature) 313 | 314 | L_feature = self.d_in1(L_feature) 315 | H_feature = self.d_in2(H_feature) 316 | 317 | 318 | L_feature = L_feature + L_feature * g_L_feature + (1 - g_L_feature) * Upsample(g_H_feature * H_feature, size= L_feature.size()[2:], align_corners=False) 319 | H_feature = H_feature + H_feature * g_H_feature + (1 - g_H_feature) * Upsample(g_L_feature * L_feature, size= H_feature.size()[2:], align_corners=False) 320 | 321 | H_feature = Upsample(H_feature, size = L_feature.size()[2:]) 322 | out = self.conv(torch.cat([H_feature,L_feature], dim=1)) 323 | return out 324 | 325 | 326 | class DuAT(nn.Module): 327 | def __init__(self, dim=32, dims= [64, 128, 320, 512]): 328 | super(DuAT, self).__init__() 329 | 330 | self.backbone = pvt_v2_b2() # [64, 128, 320, 512] 331 | path = './pretrained_pth/pvt_v2_b2.pth' 332 | save_model = torch.load(path) 333 | model_dict = self.backbone.state_dict() 334 | state_dict = {k: v for k, v in save_model.items() if k in model_dict.keys()} 335 | model_dict.update(state_dict) 336 | self.backbone.load_state_dict(model_dict) 337 | 338 | c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = dims[0], dims[1], dims[2], dims[3] 339 | 340 | self.GLSA_c4 = GLSA(input_dim=c4_in_channels, embed_dim=dim) 341 | self.GLSA_c3 = GLSA(input_dim=c3_in_channels, embed_dim=dim) 342 | self.GLSA_c2 = GLSA(input_dim=c2_in_channels, embed_dim=dim) 343 | self.L_feature = BasicConv2d(c1_in_channels,dim, 3,1,1) 344 | 345 | self.SBA = SBA(input_dim = dim) 346 | self.fuse = BasicConv2d(dim * 2, dim, 1) 347 | self.fuse2 = nn.Sequential(BasicConv2d(dim*3, dim, 1,1),nn.Conv2d(dim, 1, kernel_size=1, bias=False)) 348 | 349 | 350 | def forward(self, x): 351 | # backbone 352 | pvt = self.backbone(x) 353 | c1, c2, c3, c4 = pvt 354 | n, _, h, w = c4.shape 355 | _c4 = self.GLSA_c4(c4) # [1, 64, 11, 11] 356 | _c4 = Upsample(_c4, c3.size()[2:]) 357 | _c3 = self.GLSA_c3(c3) # [1, 64, 22, 22] 358 | _c2 = self.GLSA_c2(c2) # [1, 64, 44, 44] 359 | 360 | output = self.fuse2(torch.cat([Upsample(_c4, c2.size()[2:]), Upsample(_c3, c2.size()[2:]), _c2], dim=1)) 361 | 362 | L_feature = self.L_feature(c1) # [1, 64, 88, 88] 363 | H_feature = self.fuse(torch.cat([_c4, _c3], dim=1)) 364 | H_feature = Upsample(H_feature,c2.size()[2:]) 365 | 366 | output2 = self.SBA(H_feature,L_feature) 367 | 368 | output = F.interpolate(output, scale_factor=8, mode='bilinear') 369 | output2 = F.interpolate(output2, scale_factor=4, mode='bilinear') 370 | 371 | return output, output2 372 | 373 | 374 | 375 | if __name__ == '__main__': 376 | 377 | model = DuAT().to('cuda') 378 | from torchinfo import summary 379 | # summary(model, (1, 3, 352, 352)) 380 | from thop import profile 381 | import torch 382 | input = torch.randn(1, 3, 352, 352).to('cuda') 383 | macs, params = profile(model, inputs=(input,)) 384 | print('macs:', macs / 1000000000) 385 | print('params:', params / 1000000) 386 | 387 | # import time 388 | ## net = model() 389 | # model.eval() 390 | # time_count = 0.0 391 | # for i in range(1000): 392 | # image = torch.randn(1, 3, 352, 352).cuda() 393 | # torch.cuda.synchronize() 394 | # start_time = time.time() 395 | # pred_semantic = model(image) 396 | # torch.cuda.synchronize() 397 | # print(time.time() - start_time) 398 | # if i 
>= 100 and i <= 900: 399 | # time_count = time_count + time.time() - start_time 400 | # print("FPS:", 800 / time_count) 401 | 402 | 403 | 404 | 405 | 406 | -------------------------------------------------------------------------------- /Fig/fig1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Barrett-python/DuAT/a075fe21fabbb65352f8a730b9396b45ca00d41b/Fig/fig1.png -------------------------------------------------------------------------------- /Fig/fig2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Barrett-python/DuAT/a075fe21fabbb65352f8a730b9396b45ca00d41b/Fig/fig2.png -------------------------------------------------------------------------------- /Fig/fig3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Barrett-python/DuAT/a075fe21fabbb65352f8a730b9396b45ca00d41b/Fig/fig3.png -------------------------------------------------------------------------------- /Fig/fig4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Barrett-python/DuAT/a075fe21fabbb65352f8a730b9396b45ca00d41b/Fig/fig4.png -------------------------------------------------------------------------------- /Fig/fig5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Barrett-python/DuAT/a075fe21fabbb65352f8a730b9396b45ca00d41b/Fig/fig5.png -------------------------------------------------------------------------------- /Fig/fig6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Barrett-python/DuAT/a075fe21fabbb65352f8a730b9396b45ca00d41b/Fig/fig6.png -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DuAT 2 | Feilong Tang, Qiming Huang, Jinfeng Wang, Xianxu Hou, Jionglong Su, and Jingxin Liu 3 | 4 | This repo is the official implementation of ["DuAT: Dual-Aggregation Transformer Network for Medical Image Segmentation"](https://arxiv.org/abs/2212.11677). 
5 | 6 | 7 | 8 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/duat-dual-aggregation-transformer-network-for/medical-image-segmentation-on-2018-data)](https://paperswithcode.com/sota/medical-image-segmentation-on-2018-data?p=duat-dual-aggregation-transformer-network-for) 9 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/duat-dual-aggregation-transformer-network-for/medical-image-segmentation-on-cvc-clinicdb)](https://paperswithcode.com/sota/medical-image-segmentation-on-cvc-clinicdb?p=duat-dual-aggregation-transformer-network-for) 10 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/duat-dual-aggregation-transformer-network-for/medical-image-segmentation-on-etis)](https://paperswithcode.com/sota/medical-image-segmentation-on-etis?p=duat-dual-aggregation-transformer-network-for) 11 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/duat-dual-aggregation-transformer-network-for/lesion-segmentation-on-isic-2018)](https://paperswithcode.com/sota/lesion-segmentation-on-isic-2018?p=duat-dual-aggregation-transformer-network-for) 12 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/duat-dual-aggregation-transformer-network-for/medical-image-segmentation-on-cvc-colondb)](https://paperswithcode.com/sota/medical-image-segmentation-on-cvc-colondb?p=duat-dual-aggregation-transformer-network-for) 13 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/duat-dual-aggregation-transformer-network-for/medical-image-segmentation-on-kvasir-seg)](https://paperswithcode.com/sota/medical-image-segmentation-on-kvasir-seg?p=duat-dual-aggregation-transformer-network-for) 14 | 15 | 16 | ## 1. Introduction 17 | **DuAT** was initially described in our [PRCV 2023 paper](https://arxiv.org/pdf/2212.11677.pdf). 18 | 19 | Transformer-based models have been widely demonstrated to be successful in computer vision tasks by modelling long-range dependencies and capturing global representations. However, they are often dominated by features of large patterns, leading to the loss of local details (e.g., boundaries and small objects), which are critical in medical image segmentation. To alleviate this problem, we propose a Dual-Aggregation Transformer Network called DuAT, which is characterized by two innovative designs, namely, the Global-to-Local Spatial Aggregation (GLSA) and Selective Boundary Aggregation (SBA) modules. The GLSA module aggregates and represents both global and local spatial features, which are beneficial for locating large and small objects, respectively. The SBA module aggregates boundary characteristics from low-level features and semantic information from high-level features to better preserve boundary details and locate the re-calibrated objects. Extensive experiments on six benchmark datasets demonstrate that our proposed model outperforms state-of-the-art methods in the segmentation of skin lesions and polyps in colonoscopy images. In addition, our approach is more robust than existing methods in various challenging situations such as small object segmentation and ambiguous object boundaries. 20 | 21 | 22 | ## 2. Framework Overview 23 | ![](https://github.com/Barrett-python/DuAT/blob/main/Fig/fig1.png) 24 | 25 | ## 3. Results
26 | ### 3.1 Image-level Polyp Segmentation 27 | ![](https://github.com/Barrett-python/DuAT/blob/main/Fig/fig2.png) 28 | The polyp segmentation prediction results are available [here](https://drive.google.com/drive/folders/14IDwewAb12HWlxgOFtFB46aMJyqPaKpz?usp=sharing). 29 | 30 | ## 4. Usage: 31 | ### 4.1 Recommended environment: 32 | ``` 33 | Python 3.8 34 | Pytorch 1.7.1 35 | torchvision 0.8.2 36 | ``` 37 | ### 4.2 Data preparation: 38 | Download the training and testing datasets from [Google Drive](https://drive.google.com/file/d/1pFxb9NbM8mj_rlSawTlcXG1OdVGAbRQC/view?usp=sharing)/[Baidu Drive](https://pan.baidu.com/s/1OBVivLJAs9ZpnB5I2s3lNg) [code:dr1h] and move them into ./dataset/. 39 | 40 | 41 | ### 4.3 Pretrained model: 42 | Download the pretrained model from [Google Drive](https://drive.google.com/drive/folders/1Eu8v9vMRvt-dyCH0XSV2i77lAd62nPXV?usp=sharing)/[Baidu Drive](https://pan.baidu.com/s/1Vez7iT2v_g7VYsDxRGE8HA) [code:w4vk], and then put it in the './pretrained_pth' folder for initialization. 43 | 44 | ### 4.4 Training: 45 | Clone the repository and start training: 46 | ``` 47 | git clone https://github.com/Barrett-python/DuAT.git 48 | cd DuAT 49 | bash train.sh 50 | ``` 51 | 52 | ### 4.5 Testing: 53 | ``` 54 | cd DuAT 55 | bash test.sh 56 | ``` 57 | A minimal single-image inference sketch is provided at the end of this README. 58 | 59 | ### 4.6 Evaluating your trained model: 60 | 61 | Matlab: Please refer to the work of MICCAI2020 ([link](https://github.com/DengPingFan/PraNet)). 62 | 63 | Python: Please refer to the work of ACMMM2021 ([link](https://github.com/plemeri/UACANet)). 64 | 65 | Please note that we used the Matlab version for the evaluation reported in our paper. 66 | 67 | 68 | ### 4.7 Well-trained model: 69 | You can download the trained model from [Google Drive](https://drive.google.com/drive/folders/14IDwewAb12HWlxgOFtFB46aMJyqPaKpz) and put it in the './model_pth' directory. 70 | 71 | 72 | ## 5. Citation 73 | If you find this code or idea useful, please cite our work: 74 | ``` 75 | @inproceedings{tang2023duat, 76 | title={DuAT: Dual-aggregation transformer network for medical image segmentation}, 77 | author={Tang, Feilong and Xu, Zhongxing and Huang, Qiming and Wang, Jinfeng and Hou, Xianxu and Su, Jionglong and Liu, Jingxin}, 78 | booktitle={Chinese Conference on Pattern Recognition and Computer Vision (PRCV)}, 79 | pages={343--356}, 80 | year={2023}, 81 | organization={Springer} 82 | } 83 | ``` 84 | 85 | 86 | ## 6. Acknowledgement 87 | We are very grateful to the authors of these excellent works [PraNet](https://github.com/DengPingFan/PraNet), [Polyp-PVT](https://github.com/DengPingFan/Polyp-PVT) and [SSformer](https://github.com/Qiming-Huang/ssformer), which provide the basis for our framework.
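## Minimal inference sketch:

The snippet below is a minimal, editor-added sketch of single-image inference and is not part of the original code base. It assumes the PVTv2 backbone weights are already in `./pretrained_pth/pvt_v2_b2.pth` (they are loaded inside `DuAT.__init__`), and the checkpoint path `./model_pth/DuAT.pth` and the image name `example.png` are placeholders you should adjust. The preprocessing mirrors the `A.Normalize()` defaults used in `utils/dataloader.py`, and, as in `Test.py`, the two prediction maps are summed before the sigmoid.

```
import cv2
import numpy as np
import torch
import torch.nn.functional as F
from lib.DuAT import DuAT

model = DuAT().cuda()                                       # loads ./pretrained_pth/pvt_v2_b2.pth internally
model.load_state_dict(torch.load('./model_pth/DuAT.pth'))   # placeholder checkpoint path
model.eval()

img = cv2.cvtColor(cv2.imread('example.png'), cv2.COLOR_BGR2RGB)  # placeholder input image
h, w = img.shape[:2]

# Resize to the 352x352 train/test resolution and apply ImageNet normalization
# (the same defaults A.Normalize() uses in utils/dataloader.py).
x = torch.from_numpy(img).permute(2, 0, 1).float().unsqueeze(0) / 255.0
x = F.interpolate(x, size=(352, 352), mode='bilinear', align_corners=False)
mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
x = ((x - mean) / std).cuda()

with torch.no_grad():
    P1, P2 = model(x)                                       # two prediction maps, as in Test.py
    res = F.interpolate(P1 + P2, size=(h, w), mode='bilinear', align_corners=False)
    res = res.sigmoid().squeeze().cpu().numpy()
    res = (res - res.min()) / (res.max() - res.min() + 1e-8)  # min-max rescaling, as in Test.py

cv2.imwrite('example_mask.png', (res * 255).astype(np.uint8))
```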
88 | 89 | 91 | 92 | -------------------------------------------------------------------------------- /Test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | import os, argparse 5 | from scipy import misc 6 | from lib.pvt import PolypPVT 7 | from utils.dataloader import test_dataset 8 | import cv2 9 | 10 | if __name__ == '__main__': 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument('--testsize', type=int, default=352, help='testing size') 13 | parser.add_argument('--pth_path', type=str, default='./model_pth/PolypPVT.pth') 14 | opt = parser.parse_args() 15 | model = PolypPVT() 16 | model.load_state_dict(torch.load(opt.pth_path)) 17 | model.cuda() 18 | model.eval() 19 | for _data_name in ['CVC-300', 'CVC-ClinicDB', 'Kvasir', 'CVC-ColonDB', 'ETIS-LaribPolypDB']: 20 | 21 | ##### put data_path here ##### 22 | data_path = './dataset/TestDataset/{}'.format(_data_name) 23 | ##### save_path ##### 24 | save_path = './result_map/PolypPVT/{}/'.format(_data_name) 25 | 26 | if not os.path.exists(save_path): 27 | os.makedirs(save_path) 28 | image_root = '{}/images/'.format(data_path) 29 | gt_root = '{}/masks/'.format(data_path) 30 | num1 = len(os.listdir(gt_root)) 31 | test_loader = test_dataset(image_root, gt_root, 352) 32 | for i in range(num1): 33 | image, gt, name = test_loader.load_data() 34 | gt = np.asarray(gt, np.float32) 35 | gt /= (gt.max() + 1e-8) 36 | image = image.cuda() 37 | P1,P2 = model(image) 38 | res = F.upsample(P1+P2, size=gt.shape, mode='bilinear', align_corners=False) 39 | res = res.sigmoid().data.cpu().numpy().squeeze() 40 | res = (res - res.min()) / (res.max() - res.min() + 1e-8) 41 | cv2.imwrite(save_path+name, res*255) 42 | print(_data_name, 'Finish!') 43 | -------------------------------------------------------------------------------- /Train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | import os 4 | import argparse 5 | from datetime import datetime 6 | from lib.pvt import DuAT 7 | from utils.dataloader import get_loader, test_dataset 8 | from utils.utils import clip_gradient, adjust_lr, AvgMeter 9 | import torch.nn.functional as F 10 | import numpy as np 11 | import logging 12 | 13 | import matplotlib.pyplot as plt 14 | 15 | def structure_loss(pred, mask): 16 | weit = 1 + 5 * torch.abs(F.avg_pool2d(mask, kernel_size=31, stride=1, padding=15) - mask) 17 | wbce = F.binary_cross_entropy_with_logits(pred, mask, reduce='none') 18 | wbce = (weit * wbce).sum(dim=(2, 3)) / weit.sum(dim=(2, 3)) 19 | 20 | pred = torch.sigmoid(pred) 21 | inter = ((pred * mask) * weit).sum(dim=(2, 3)) 22 | union = ((pred + mask) * weit).sum(dim=(2, 3)) 23 | wiou = 1 - (inter + 1) / (union - inter + 1) 24 | 25 | return (wbce + wiou).mean() 26 | 27 | 28 | def test(model, path, dataset): 29 | 30 | data_path = os.path.join(path, dataset) 31 | image_root = '{}/images/'.format(data_path) 32 | gt_root = '{}/masks/'.format(data_path) 33 | model.eval() 34 | num1 = len(os.listdir(gt_root)) 35 | test_loader = test_dataset(image_root, gt_root, 352) 36 | DSC = 0.0 37 | for i in range(num1): 38 | image, gt, name = test_loader.load_data() 39 | gt = np.asarray(gt, np.float32) 40 | gt /= (gt.max() + 1e-8) 41 | image = image.cuda() 42 | 43 | res, res1 = model(image) 44 | # eval Dice 45 | res = F.upsample(res + res1 , size=gt.shape, mode='bilinear', align_corners=False) 46 | res = 
res.sigmoid().data.cpu().numpy().squeeze() 47 | res = (res - res.min()) / (res.max() - res.min() + 1e-8) 48 | input = res 49 | target = np.array(gt) 50 | N = gt.shape 51 | smooth = 1 52 | input_flat = np.reshape(input, (-1)) 53 | target_flat = np.reshape(target, (-1)) 54 | intersection = (input_flat * target_flat) 55 | dice = (2 * intersection.sum() + smooth) / (input.sum() + target.sum() + smooth) 56 | dice = '{:.4f}'.format(dice) 57 | dice = float(dice) 58 | DSC = DSC + dice 59 | 60 | return DSC / num1 61 | 62 | 63 | 64 | def train(train_loader, model, optimizer, epoch, test_path): 65 | model.train() 66 | global best 67 | size_rates = [0.75, 1, 1.25] 68 | loss_P2_record = AvgMeter() 69 | for i, pack in enumerate(train_loader, start=1): 70 | for rate in size_rates: 71 | optimizer.zero_grad() 72 | # ---- data prepare ---- 73 | images, gts = pack 74 | images = Variable(images).cuda() 75 | gts = Variable(gts).cuda() 76 | # ---- rescale ---- 77 | trainsize = int(round(opt.trainsize * rate / 32) * 32) 78 | if rate != 1: 79 | images = F.upsample(images, size=(trainsize, trainsize), mode='bilinear', align_corners=True) 80 | gts = F.upsample(gts, size=(trainsize, trainsize), mode='bilinear', align_corners=True) 81 | # ---- forward ---- 82 | P1, P2= model(images) 83 | # ---- loss function ---- 84 | loss_P1 = structure_loss(P1, gts) 85 | loss_P2 = structure_loss(P2, gts) 86 | loss = loss_P1 + loss_P2 87 | # ---- backward ---- 88 | loss.backward() 89 | clip_gradient(optimizer, opt.clip) 90 | optimizer.step() 91 | # ---- recording loss ---- 92 | if rate == 1: 93 | loss_P2_record.update(loss_P2.data, opt.batchsize) 94 | # ---- train visualization ---- 95 | if i % 20 == 0 or i == total_step: 96 | print('{} Epoch [{:03d}/{:03d}], Step [{:04d}/{:04d}], ' 97 | ' lateral-5: {:0.4f}]'. 
98 | format(datetime.now(), epoch, opt.epoch, i, total_step, 99 | loss_P2_record.show())) 100 | # save model 101 | save_path = (opt.train_save) 102 | if not os.path.exists(save_path): 103 | os.makedirs(save_path) 104 | torch.save(model.state_dict(), save_path +str(epoch)+ 'DuAT.pth') 105 | # choose the best model 106 | 107 | global dict_plot 108 | 109 | test1path = './dataset/TestDataset/' 110 | if (epoch + 1) % 1 == 0: 111 | for dataset in ['CVC-300', 'CVC-ClinicDB', 'Kvasir', 'CVC-ColonDB', 'ETIS-LaribPolypDB']: 112 | dataset_dice = test(model, test1path, dataset) 113 | logging.info('epoch: {}, dataset: {}, dice: {}'.format(epoch, dataset, dataset_dice)) 114 | print(dataset, ': ', dataset_dice) 115 | dict_plot[dataset].append(dataset_dice) 116 | meandice = test(model, test_path, 'test') 117 | dict_plot['test'].append(meandice) 118 | if meandice > best: 119 | best = meandice 120 | torch.save(model.state_dict(), save_path + 'DuAT.pth') 121 | torch.save(model.state_dict(), save_path +str(epoch)+ 'DuAT-best.pth') 122 | print('##############################################################################best', best) 123 | logging.info('##############################################################################best:{}'.format(best)) 124 | 125 | 126 | def plot_train(dict_plot=None, name = None): 127 | color = ['red', 'lawngreen', 'lime', 'gold', 'm', 'plum', 'blue'] 128 | line = ['-', "--"] 129 | for i in range(len(name)): 130 | plt.plot(dict_plot[name[i]], label=name[i], color=color[i], linestyle=line[(i + 1) % 2]) 131 | transfuse = {'CVC-300': 0.902, 'CVC-ClinicDB': 0.918, 'Kvasir': 0.918, 'CVC-ColonDB': 0.773,'ETIS-LaribPolypDB': 0.733, 'test':0.83} 132 | plt.axhline(y=transfuse[name[i]], color=color[i], linestyle='-') 133 | plt.xlabel("epoch") 134 | plt.ylabel("dice") 135 | plt.title('Train') 136 | plt.legend() 137 | plt.savefig('eval.png') 138 | # plt.show() 139 | 140 | 141 | if __name__ == '__main__': 142 | dict_plot = {'CVC-300':[], 'CVC-ClinicDB':[], 'Kvasir':[], 'CVC-ColonDB':[], 'ETIS-LaribPolypDB':[], 'test':[]} 143 | name = ['CVC-300', 'CVC-ClinicDB', 'Kvasir', 'CVC-ColonDB', 'ETIS-LaribPolypDB', 'test'] 144 | ##################model_name############################# 145 | model_name = 'DuAT' 146 | ############################################### 147 | parser = argparse.ArgumentParser() 148 | 149 | parser.add_argument('--epoch', type=int, 150 | default=100, help='epoch number') 151 | 152 | parser.add_argument('--lr', type=float, 153 | default=1e-4, help='learning rate') 154 | 155 | parser.add_argument('--optimizer', type=str, 156 | default='AdamW', help='choosing optimizer AdamW or SGD') 157 | 158 | parser.add_argument('--augmentation', 159 | default=False, help='choose to do random flip rotation') 160 | 161 | parser.add_argument('--batchsize', type=int, 162 | default=16, help='training batch size') 163 | 164 | parser.add_argument('--trainsize', type=int, 165 | default=352, help='training dataset size') 166 | 167 | parser.add_argument('--clip', type=float, 168 | default=0.5, help='gradient clipping margin') 169 | 170 | parser.add_argument('--decay_rate', type=float, 171 | default=0.1, help='decay rate of learning rate') 172 | 173 | parser.add_argument('--decay_epoch', type=int, 174 | default=50, help='every n epochs decay learning rate') 175 | 176 | parser.add_argument('--train_path', type=str, 177 | default='./dataset/TrainDataset/', 178 | help='path to train dataset') 179 | 180 | parser.add_argument('--test_path', type=str, 181 | default='./dataset/TestDataset/', 182 | 
help='path to testing Kvasir dataset') 183 | 184 | parser.add_argument('--train_save', type=str, 185 | default='./model_pth/'+model_name+'/') 186 | 187 | opt = parser.parse_args() 188 | logging.basicConfig(filename='train_log.log', 189 | format='[%(asctime)s-%(filename)s-%(levelname)s:%(message)s]', 190 | level=logging.INFO, filemode='a', datefmt='%Y-%m-%d %I:%M:%S %p') 191 | 192 | # ---- build models ---- 193 | # torch.cuda.set_device(0) # set your gpu device 194 | model = DuAT().cuda() 195 | 196 | best = 0 197 | 198 | params = model.parameters() 199 | 200 | if opt.optimizer == 'AdamW': 201 | optimizer = torch.optim.AdamW(params, opt.lr, weight_decay=1e-4) 202 | else: 203 | optimizer = torch.optim.SGD(params, opt.lr, weight_decay=1e-4, momentum=0.9) 204 | 205 | print(optimizer) 206 | image_root = '{}/images/'.format(opt.train_path) 207 | gt_root = '{}/masks/'.format(opt.train_path) 208 | 209 | train_loader = get_loader(image_root, gt_root, batchsize=opt.batchsize, trainsize=opt.trainsize, 210 | augmentation=opt.augmentation) 211 | total_step = len(train_loader) 212 | 213 | print("#" * 20, "Start Training", "#" * 20) 214 | 215 | for epoch in range(1, opt.epoch): 216 | adjust_lr(optimizer, opt.lr, epoch, 0.1, 200) 217 | train(train_loader, model, optimizer, epoch, opt.test_path) 218 | 219 | # plot the eval.png in the training stage 220 | # plot_train(dict_plot, name) 221 | -------------------------------------------------------------------------------- /lib/DuAT.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from lib.pvtv2 import pvt_v2_b2 5 | import os 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from mmcv.cnn import ConvModule 10 | from torch.nn import Conv2d, UpsamplingBilinear2d 11 | import warnings 12 | import torch 13 | from mmcv.cnn import constant_init, kaiming_init 14 | from torch import nn 15 | from torchvision.transforms.functional import normalize 16 | warnings.filterwarnings('ignore') 17 | 18 | 19 | class BasicConv2d(nn.Module): 20 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1): 21 | super(BasicConv2d, self).__init__() 22 | 23 | self.conv = nn.Conv2d(in_planes, out_planes, 24 | kernel_size=kernel_size, stride=stride, 25 | padding=padding, dilation=dilation, bias=False) 26 | self.bn = nn.BatchNorm2d(out_planes) 27 | self.relu = nn.ReLU(inplace=True) 28 | 29 | def forward(self, x): 30 | x = self.conv(x) 31 | x = self.bn(x) 32 | x = self.relu(x) 33 | return x 34 | 35 | class Block(nn.Sequential): 36 | def __init__(self, input_num, num1, num2, dilation_rate, drop_out, bn_start=True, norm_layer=nn.BatchNorm2d): 37 | super(Block, self).__init__() 38 | if bn_start: 39 | self.add_module('norm1', norm_layer(input_num)), 40 | 41 | self.add_module('relu1', nn.ReLU(inplace=True)), 42 | self.add_module('conv1', nn.Conv2d(in_channels=input_num, out_channels=num1, kernel_size=1)), 43 | 44 | self.add_module('norm2', norm_layer(num1)), 45 | self.add_module('relu2', nn.ReLU(inplace=True)), 46 | self.add_module('conv2', nn.Conv2d(in_channels=num1, out_channels=num2, kernel_size=3, 47 | dilation=dilation_rate, padding=dilation_rate)), 48 | self.drop_rate = drop_out 49 | 50 | def forward(self, _input): 51 | feature = super(Block, self).forward(_input) 52 | if self.drop_rate > 0: 53 | feature = F.dropout2d(feature, p=self.drop_rate, training=self.training) 54 | return feature 55 | 56 | 57 | def 
Upsample(x, size, align_corners = False): 58 | """ 59 | Wrapper Around the Upsample Call 60 | """ 61 | return nn.functional.interpolate(x, size=size, mode='bilinear', align_corners=align_corners) 62 | 63 | 64 | def last_zero_init(m): 65 | if isinstance(m, nn.Sequential): 66 | constant_init(m[-1], val=0) 67 | else: 68 | constant_init(m, val=0) 69 | 70 | 71 | class ContextBlock(nn.Module): 72 | 73 | def __init__(self, 74 | inplanes, 75 | ratio, 76 | pooling_type='att', 77 | fusion_types=('channel_mul', )): 78 | super(ContextBlock, self).__init__() 79 | assert pooling_type in ['avg', 'att'] 80 | assert isinstance(fusion_types, (list, tuple)) 81 | valid_fusion_types = ['channel_add', 'channel_mul'] 82 | assert all([f in valid_fusion_types for f in fusion_types]) 83 | assert len(fusion_types) > 0, 'at least one fusion should be used' 84 | self.inplanes = inplanes 85 | self.ratio = ratio 86 | self.planes = int(inplanes * ratio) 87 | self.pooling_type = pooling_type 88 | self.fusion_types = fusion_types 89 | if pooling_type == 'att': 90 | self.conv_mask = nn.Conv2d(inplanes, 1, kernel_size=1) 91 | self.softmax = nn.Softmax(dim=2) 92 | else: 93 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 94 | if 'channel_add' in fusion_types: 95 | self.channel_add_conv = nn.Sequential( 96 | nn.Conv2d(self.inplanes, self.planes, kernel_size=1), 97 | nn.LayerNorm([self.planes, 1, 1]), 98 | nn.ReLU(inplace=True), # yapf: disable 99 | nn.Conv2d(self.planes, self.inplanes, kernel_size=1)) 100 | else: 101 | self.channel_add_conv = None 102 | if 'channel_mul' in fusion_types: 103 | self.channel_mul_conv = nn.Sequential( 104 | nn.Conv2d(self.inplanes, self.planes, kernel_size=1), 105 | nn.LayerNorm([self.planes, 1, 1]), 106 | nn.ReLU(inplace=True), # yapf: disable 107 | nn.Conv2d(self.planes, self.inplanes, kernel_size=1)) 108 | else: 109 | self.channel_mul_conv = None 110 | self.reset_parameters() 111 | 112 | def reset_parameters(self): 113 | if self.pooling_type == 'att': 114 | kaiming_init(self.conv_mask, mode='fan_in') 115 | self.conv_mask.inited = True 116 | 117 | if self.channel_add_conv is not None: 118 | last_zero_init(self.channel_add_conv) 119 | if self.channel_mul_conv is not None: 120 | last_zero_init(self.channel_mul_conv) 121 | 122 | def spatial_pool(self, x): 123 | batch, channel, height, width = x.size() 124 | if self.pooling_type == 'att': 125 | input_x = x 126 | # [N, C, H * W] 127 | input_x = input_x.view(batch, channel, height * width) 128 | # [N, 1, C, H * W] 129 | input_x = input_x.unsqueeze(1) 130 | # [N, 1, H, W] 131 | context_mask = self.conv_mask(x) 132 | # [N, 1, H * W] 133 | context_mask = context_mask.view(batch, 1, height * width) 134 | # [N, 1, H * W] 135 | context_mask = self.softmax(context_mask) 136 | # [N, 1, H * W, 1] 137 | context_mask = context_mask.unsqueeze(-1) 138 | # [N, 1, C, 1] 139 | context = torch.matmul(input_x, context_mask) 140 | # [N, C, 1, 1] 141 | context = context.view(batch, channel, 1, 1) 142 | else: 143 | # [N, C, 1, 1] 144 | context = self.avg_pool(x) 145 | 146 | return context 147 | 148 | def forward(self, x): 149 | # [N, C, 1, 1] 150 | context = self.spatial_pool(x) 151 | 152 | out = x 153 | if self.channel_mul_conv is not None: 154 | # [N, C, 1, 1] 155 | channel_mul_term = torch.sigmoid(self.channel_mul_conv(context)) 156 | out = out + out * channel_mul_term 157 | if self.channel_add_conv is not None: 158 | # [N, C, 1, 1] 159 | channel_add_term = self.channel_add_conv(context) 160 | out = out + channel_add_term 161 | 162 | return out 163 | 164 | 165 | 166 | class 
ChannelAttention(nn.Module): 167 | def __init__(self, in_planes, ratio=16): 168 | super(ChannelAttention, self).__init__() 169 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 170 | self.max_pool = nn.AdaptiveMaxPool2d(1) 171 | 172 | self.fc1 = nn.Conv2d(in_planes, in_planes // 16, 1, bias=False) 173 | self.relu1 = nn.ReLU() 174 | self.fc2 = nn.Conv2d(in_planes // 16, in_planes, 1, bias=False) 175 | 176 | self.sigmoid = nn.Sigmoid() 177 | 178 | def forward(self, x): 179 | avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x)))) 180 | max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x)))) 181 | out = avg_out + max_out 182 | return self.sigmoid(out) 183 | 184 | 185 | class SpatialAttention(nn.Module): 186 | def __init__(self, kernel_size=7): 187 | super(SpatialAttention, self).__init__() 188 | 189 | assert kernel_size in (3, 7), 'kernel size must be 3 or 7' 190 | padding = 3 if kernel_size == 7 else 1 191 | 192 | self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False) 193 | self.sigmoid = nn.Sigmoid() 194 | 195 | def forward(self, x): 196 | avg_out = torch.mean(x, dim=1, keepdim=True) 197 | max_out, _ = torch.max(x, dim=1, keepdim=True) 198 | x = torch.cat([avg_out, max_out], dim=1) 199 | x = self.conv1(x) 200 | return self.sigmoid(x) 201 | 202 | 203 | class ConvBranch(nn.Module): 204 | def __init__(self, in_features, hidden_features = None, out_features = None): 205 | super().__init__() 206 | hidden_features = hidden_features or in_features 207 | out_features = out_features or in_features 208 | self.conv1 = nn.Sequential( 209 | nn.Conv2d(in_features, hidden_features, 1, bias=False), 210 | nn.BatchNorm2d(hidden_features), 211 | nn.ReLU(inplace=True) 212 | ) 213 | self.conv2 = nn.Sequential( 214 | nn.Conv2d(hidden_features, hidden_features, 3, padding=1, groups=hidden_features, bias=False), 215 | nn.BatchNorm2d(hidden_features), 216 | nn.ReLU(inplace=True) 217 | ) 218 | self.conv3 = nn.Sequential( 219 | nn.Conv2d(hidden_features, hidden_features, 1, bias=False), 220 | nn.BatchNorm2d(hidden_features), 221 | nn.ReLU(inplace=True) 222 | ) 223 | self.conv4 = nn.Sequential( 224 | nn.Conv2d(hidden_features, hidden_features, 3, padding=1, groups=hidden_features, bias=False), 225 | nn.BatchNorm2d(hidden_features), 226 | nn.ReLU(inplace=True) 227 | ) 228 | self.conv5 = nn.Sequential( 229 | nn.Conv2d(hidden_features, hidden_features, 1, bias=False), 230 | nn.BatchNorm2d(hidden_features), 231 | nn.SiLU(inplace=True) 232 | ) 233 | self.conv6 = nn.Sequential( 234 | nn.Conv2d(hidden_features, hidden_features, 3, padding=1, groups=hidden_features, bias=False), 235 | nn.BatchNorm2d(hidden_features), 236 | nn.ReLU(inplace=True) 237 | ) 238 | self.conv7 = nn.Sequential( 239 | nn.Conv2d(hidden_features, out_features, 1, bias=False), 240 | nn.ReLU(inplace=True) 241 | ) 242 | self.ca = ChannelAttention(64) 243 | self.sa = SpatialAttention() 244 | self.sigmoid_spatial = nn.Sigmoid() 245 | 246 | def forward(self, x): 247 | res1 = x 248 | res2 = x 249 | x = self.conv1(x) 250 | x = x + self.conv2(x) 251 | x = self.conv3(x) 252 | x = x + self.conv4(x) 253 | x = self.conv5(x) 254 | x = x + self.conv6(x) 255 | x = self.conv7(x) 256 | x_mask = self.sigmoid_spatial(x) 257 | res1 = res1 * x_mask 258 | return res2 + res1 259 | 260 | 261 | class GLSA(nn.Module): 262 | 263 | def __init__(self, input_dim=512, embed_dim=32, k_s=3): 264 | super().__init__() 265 | 266 | self.conv1_1 = BasicConv2d(embed_dim*2,embed_dim, 1) 267 | self.conv1_1_1 = BasicConv2d(input_dim//2,embed_dim,1) 268 | self.local_11conv = 
nn.Conv2d(input_dim//2,embed_dim,1) 269 | self.global_11conv = nn.Conv2d(input_dim//2,embed_dim,1) 270 | self.GlobelBlock = ContextBlock(inplanes= embed_dim, ratio=2) 271 | self.local = ConvBranch(in_features = embed_dim, hidden_features = embed_dim, out_features = embed_dim) 272 | 273 | def forward(self, x): 274 | b, c, h, w = x.size() 275 | x_0, x_1 = x.chunk(2,dim = 1) 276 | 277 | # local block 278 | local = self.local(self.local_11conv(x_0)) 279 | 280 | # Globel block 281 | Globel = self.GlobelBlock(self.global_11conv(x_1)) 282 | 283 | # concat Globel + local 284 | x = torch.cat([local,Globel], dim=1) 285 | x = self.conv1_1(x) 286 | 287 | return x 288 | 289 | class SBA(nn.Module): 290 | 291 | def __init__(self,input_dim = 64): 292 | super().__init__() 293 | 294 | self.input_dim = input_dim 295 | 296 | self.d_in1 = BasicConv2d(input_dim//2, input_dim//2, 1) 297 | self.d_in2 = BasicConv2d(input_dim//2, input_dim//2, 1) 298 | 299 | 300 | self.conv = nn.Sequential(BasicConv2d(input_dim, input_dim, 3,1,1), nn.Conv2d(input_dim, 1, kernel_size=1, bias=False)) 301 | self.fc1 = nn.Conv2d(input_dim, input_dim//2, kernel_size=1, bias=False) 302 | self.fc2 = nn.Conv2d(input_dim, input_dim//2, kernel_size=1, bias=False) 303 | 304 | self.Sigmoid = nn.Sigmoid() 305 | 306 | def forward(self, H_feature, L_feature): 307 | 308 | L_feature = self.fc1(L_feature) 309 | H_feature = self.fc2(H_feature) 310 | 311 | g_L_feature = self.Sigmoid(L_feature) 312 | g_H_feature = self.Sigmoid(H_feature) 313 | 314 | L_feature = self.d_in1(L_feature) 315 | H_feature = self.d_in2(H_feature) 316 | 317 | 318 | L_feature = L_feature + L_feature * g_L_feature + (1 - g_L_feature) * Upsample(g_H_feature * H_feature, size= L_feature.size()[2:], align_corners=False) 319 | H_feature = H_feature + H_feature * g_H_feature + (1 - g_H_feature) * Upsample(g_L_feature * L_feature, size= H_feature.size()[2:], align_corners=False) 320 | 321 | H_feature = Upsample(H_feature, size = L_feature.size()[2:]) 322 | out = self.conv(torch.cat([H_feature,L_feature], dim=1)) 323 | return out 324 | 325 | 326 | class DuAT(nn.Module): 327 | def __init__(self, dim=32, dims= [64, 128, 320, 512]): 328 | super(DuAT, self).__init__() 329 | 330 | self.backbone = pvt_v2_b2() # [64, 128, 320, 512] 331 | path = './pretrained_pth/pvt_v2_b2.pth' 332 | save_model = torch.load(path) 333 | model_dict = self.backbone.state_dict() 334 | state_dict = {k: v for k, v in save_model.items() if k in model_dict.keys()} 335 | model_dict.update(state_dict) 336 | self.backbone.load_state_dict(model_dict) 337 | 338 | c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = dims[0], dims[1], dims[2], dims[3] 339 | 340 | self.GLSA_c4 = GLSA(input_dim=c4_in_channels, embed_dim=dim) 341 | self.GLSA_c3 = GLSA(input_dim=c3_in_channels, embed_dim=dim) 342 | self.GLSA_c2 = GLSA(input_dim=c2_in_channels, embed_dim=dim) 343 | self.L_feature = BasicConv2d(c1_in_channels,dim, 3,1,1) 344 | 345 | self.SBA = SBA(input_dim = dim) 346 | self.fuse = BasicConv2d(dim * 2, dim, 1) 347 | self.fuse2 = nn.Sequential(BasicConv2d(dim*3, dim, 1,1),nn.Conv2d(dim, 1, kernel_size=1, bias=False)) 348 | 349 | 350 | def forward(self, x): 351 | # backbone 352 | pvt = self.backbone(x) 353 | c1, c2, c3, c4 = pvt 354 | n, _, h, w = c4.shape 355 | _c4 = self.GLSA_c4(c4) # [1, 64, 11, 11] 356 | _c4 = Upsample(_c4, c3.size()[2:]) 357 | _c3 = self.GLSA_c3(c3) # [1, 64, 22, 22] 358 | _c2 = self.GLSA_c2(c2) # [1, 64, 44, 44] 359 | 360 | output = self.fuse2(torch.cat([Upsample(_c4, c2.size()[2:]), Upsample(_c3, 
c2.size()[2:]), _c2], dim=1)) 361 | 362 | L_feature = self.L_feature(c1) # [1, 64, 88, 88] 363 | H_feature = self.fuse(torch.cat([_c4, _c3], dim=1)) 364 | H_feature = Upsample(H_feature,c2.size()[2:]) 365 | 366 | output2 = self.SBA(H_feature,L_feature) 367 | 368 | output = F.interpolate(output, scale_factor=8, mode='bilinear') 369 | output2 = F.interpolate(output2, scale_factor=4, mode='bilinear') 370 | 371 | return output, output2 372 | 373 | 374 | 375 | if __name__ == '__main__': 376 | 377 | model = DuAT().to('cuda') 378 | from torchinfo import summary 379 | # summary(model, (1, 3, 352, 352)) 380 | from thop import profile 381 | import torch 382 | input = torch.randn(1, 3, 352, 352).to('cuda') 383 | macs, params = profile(model, inputs=(input,)) 384 | print('macs:', macs / 1000000000) 385 | print('params:', params / 1000000) 386 | 387 | # import time 388 | ## net = model() 389 | # model.eval() 390 | # time_count = 0.0 391 | # for i in range(1000): 392 | # image = torch.randn(1, 3, 352, 352).cuda() 393 | # torch.cuda.synchronize() 394 | # start_time = time.time() 395 | # pred_semantic = model(image) 396 | # torch.cuda.synchronize() 397 | # print(time.time() - start_time) 398 | # if i >= 100 and i <= 900: 399 | # time_count = time_count + time.time() - start_time 400 | # print("FPS:", 800 / time_count) 401 | 402 | 403 | 404 | 405 | 406 | -------------------------------------------------------------------------------- /lib/pvtv2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from functools import partial 5 | 6 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 7 | from timm.models.registry import register_model 8 | from timm.models.vision_transformer import _cfg 9 | from timm.models.registry import register_model 10 | 11 | import math 12 | 13 | 14 | class Mlp(nn.Module): 15 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 16 | super().__init__() 17 | out_features = out_features or in_features 18 | hidden_features = hidden_features or in_features 19 | self.fc1 = nn.Linear(in_features, hidden_features) 20 | self.dwconv = DWConv(hidden_features) 21 | self.act = act_layer() 22 | self.fc2 = nn.Linear(hidden_features, out_features) 23 | self.drop = nn.Dropout(drop) 24 | 25 | self.apply(self._init_weights) 26 | 27 | def _init_weights(self, m): 28 | if isinstance(m, nn.Linear): 29 | trunc_normal_(m.weight, std=.02) 30 | if isinstance(m, nn.Linear) and m.bias is not None: 31 | nn.init.constant_(m.bias, 0) 32 | elif isinstance(m, nn.LayerNorm): 33 | nn.init.constant_(m.bias, 0) 34 | nn.init.constant_(m.weight, 1.0) 35 | elif isinstance(m, nn.Conv2d): 36 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 37 | fan_out //= m.groups 38 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 39 | if m.bias is not None: 40 | m.bias.data.zero_() 41 | 42 | def forward(self, x, H, W): 43 | x = self.fc1(x) 44 | x = self.dwconv(x, H, W) 45 | x = self.act(x) 46 | x = self.drop(x) 47 | x = self.fc2(x) 48 | x = self.drop(x) 49 | return x 50 | 51 | 52 | class Attention(nn.Module): 53 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1): 54 | super().__init__() 55 | assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." 
56 | 57 | self.dim = dim 58 | self.num_heads = num_heads 59 | head_dim = dim // num_heads 60 | self.scale = qk_scale or head_dim ** -0.5 61 | 62 | self.q = nn.Linear(dim, dim, bias=qkv_bias) 63 | self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) 64 | self.attn_drop = nn.Dropout(attn_drop) 65 | self.proj = nn.Linear(dim, dim) 66 | self.proj_drop = nn.Dropout(proj_drop) 67 | 68 | self.sr_ratio = sr_ratio 69 | if sr_ratio > 1: 70 | self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) 71 | self.norm = nn.LayerNorm(dim) 72 | 73 | self.apply(self._init_weights) 74 | 75 | def _init_weights(self, m): 76 | if isinstance(m, nn.Linear): 77 | trunc_normal_(m.weight, std=.02) 78 | if isinstance(m, nn.Linear) and m.bias is not None: 79 | nn.init.constant_(m.bias, 0) 80 | elif isinstance(m, nn.LayerNorm): 81 | nn.init.constant_(m.bias, 0) 82 | nn.init.constant_(m.weight, 1.0) 83 | elif isinstance(m, nn.Conv2d): 84 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 85 | fan_out //= m.groups 86 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 87 | if m.bias is not None: 88 | m.bias.data.zero_() 89 | 90 | def forward(self, x, H, W): 91 | B, N, C = x.shape 92 | q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 93 | 94 | if self.sr_ratio > 1: 95 | x_ = x.permute(0, 2, 1).reshape(B, C, H, W) 96 | x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1) 97 | x_ = self.norm(x_) 98 | kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 99 | else: 100 | kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 101 | k, v = kv[0], kv[1] 102 | 103 | attn = (q @ k.transpose(-2, -1)) * self.scale 104 | attn = attn.softmax(dim=-1) 105 | attn = self.attn_drop(attn) 106 | 107 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 108 | x = self.proj(x) 109 | x = self.proj_drop(x) 110 | 111 | return x 112 | 113 | 114 | class Block(nn.Module): 115 | 116 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 117 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1): 118 | super().__init__() 119 | self.norm1 = norm_layer(dim) 120 | self.attn = Attention( 121 | dim, 122 | num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, 123 | attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio) 124 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 125 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 126 | self.norm2 = norm_layer(dim) 127 | mlp_hidden_dim = int(dim * mlp_ratio) 128 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 129 | 130 | self.apply(self._init_weights) 131 | 132 | def _init_weights(self, m): 133 | if isinstance(m, nn.Linear): 134 | trunc_normal_(m.weight, std=.02) 135 | if isinstance(m, nn.Linear) and m.bias is not None: 136 | nn.init.constant_(m.bias, 0) 137 | elif isinstance(m, nn.LayerNorm): 138 | nn.init.constant_(m.bias, 0) 139 | nn.init.constant_(m.weight, 1.0) 140 | elif isinstance(m, nn.Conv2d): 141 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 142 | fan_out //= m.groups 143 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 144 | if m.bias is not None: 145 | m.bias.data.zero_() 146 | 147 | def forward(self, x, H, W): 148 | x = x + self.drop_path(self.attn(self.norm1(x), H, W)) 149 | x = x + self.drop_path(self.mlp(self.norm2(x), H, W)) 150 | 151 | return x 152 | 153 | 154 | class OverlapPatchEmbed(nn.Module): 155 | """ Image to Patch Embedding 156 | """ 157 | 158 | def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768): 159 | super().__init__() 160 | img_size = to_2tuple(img_size) 161 | patch_size = to_2tuple(patch_size) 162 | 163 | self.img_size = img_size 164 | self.patch_size = patch_size 165 | self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1] 166 | self.num_patches = self.H * self.W 167 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride, 168 | padding=(patch_size[0] // 2, patch_size[1] // 2)) 169 | self.norm = nn.LayerNorm(embed_dim) 170 | 171 | self.apply(self._init_weights) 172 | 173 | def _init_weights(self, m): 174 | if isinstance(m, nn.Linear): 175 | trunc_normal_(m.weight, std=.02) 176 | if isinstance(m, nn.Linear) and m.bias is not None: 177 | nn.init.constant_(m.bias, 0) 178 | elif isinstance(m, nn.LayerNorm): 179 | nn.init.constant_(m.bias, 0) 180 | nn.init.constant_(m.weight, 1.0) 181 | elif isinstance(m, nn.Conv2d): 182 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 183 | fan_out //= m.groups 184 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 185 | if m.bias is not None: 186 | m.bias.data.zero_() 187 | 188 | def forward(self, x): 189 | x = self.proj(x) 190 | _, _, H, W = x.shape 191 | x = x.flatten(2).transpose(1, 2) 192 | x = self.norm(x) 193 | 194 | return x, H, W 195 | 196 | 197 | class PyramidVisionTransformerImpr(nn.Module): 198 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512], 199 | num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0., 200 | attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, 201 | depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1]): 202 | super().__init__() 203 | self.num_classes = num_classes 204 | self.depths = depths 205 | 206 | # patch_embed 207 | self.patch_embed1 = OverlapPatchEmbed(img_size=img_size, patch_size=7, stride=4, in_chans=in_chans, 208 | embed_dim=embed_dims[0]) 209 | self.patch_embed2 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0], 210 | embed_dim=embed_dims[1]) 211 | self.patch_embed3 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, stride=2, in_chans=embed_dims[1], 212 | embed_dim=embed_dims[2]) 213 | self.patch_embed4 = OverlapPatchEmbed(img_size=img_size // 16, patch_size=3, stride=2, in_chans=embed_dims[2], 214 | embed_dim=embed_dims[3]) 215 | 
216 | # transformer encoder 217 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule 218 | cur = 0 219 | self.block1 = nn.ModuleList([Block( 220 | dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale, 221 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, 222 | sr_ratio=sr_ratios[0]) 223 | for i in range(depths[0])]) 224 | self.norm1 = norm_layer(embed_dims[0]) 225 | 226 | cur += depths[0] 227 | self.block2 = nn.ModuleList([Block( 228 | dim=embed_dims[1], num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale, 229 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, 230 | sr_ratio=sr_ratios[1]) 231 | for i in range(depths[1])]) 232 | self.norm2 = norm_layer(embed_dims[1]) 233 | 234 | cur += depths[1] 235 | self.block3 = nn.ModuleList([Block( 236 | dim=embed_dims[2], num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale, 237 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, 238 | sr_ratio=sr_ratios[2]) 239 | for i in range(depths[2])]) 240 | self.norm3 = norm_layer(embed_dims[2]) 241 | 242 | cur += depths[2] 243 | self.block4 = nn.ModuleList([Block( 244 | dim=embed_dims[3], num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale, 245 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, 246 | sr_ratio=sr_ratios[3]) 247 | for i in range(depths[3])]) 248 | self.norm4 = norm_layer(embed_dims[3]) 249 | 250 | # classification head 251 | # self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity() 252 | 253 | self.apply(self._init_weights) 254 | 255 | def _init_weights(self, m): 256 | if isinstance(m, nn.Linear): 257 | trunc_normal_(m.weight, std=.02) 258 | if isinstance(m, nn.Linear) and m.bias is not None: 259 | nn.init.constant_(m.bias, 0) 260 | elif isinstance(m, nn.LayerNorm): 261 | nn.init.constant_(m.bias, 0) 262 | nn.init.constant_(m.weight, 1.0) 263 | elif isinstance(m, nn.Conv2d): 264 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 265 | fan_out //= m.groups 266 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 267 | if m.bias is not None: 268 | m.bias.data.zero_() 269 | 270 | def init_weights(self, pretrained=None): 271 | if isinstance(pretrained, str): 272 | logger = 1 273 | #load_checkpoint(self, pretrained, map_location='cpu', strict=False, logger=logger) 274 | 275 | def reset_drop_path(self, drop_path_rate): 276 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))] 277 | cur = 0 278 | for i in range(self.depths[0]): 279 | self.block1[i].drop_path.drop_prob = dpr[cur + i] 280 | 281 | cur += self.depths[0] 282 | for i in range(self.depths[1]): 283 | self.block2[i].drop_path.drop_prob = dpr[cur + i] 284 | 285 | cur += self.depths[1] 286 | for i in range(self.depths[2]): 287 | self.block3[i].drop_path.drop_prob = dpr[cur + i] 288 | 289 | cur += self.depths[2] 290 | for i in range(self.depths[3]): 291 | self.block4[i].drop_path.drop_prob = dpr[cur + i] 292 | 293 | def freeze_patch_emb(self): 294 | self.patch_embed1.requires_grad = False 295 | 296 | @torch.jit.ignore 297 | def no_weight_decay(self): 298 | return {'pos_embed1', 'pos_embed2', 'pos_embed3', 'pos_embed4', 'cls_token'} # has pos_embed may be better 299 | 300 | def get_classifier(self): 301 | return 
self.head 302 | 303 | def reset_classifier(self, num_classes, global_pool=''): 304 | self.num_classes = num_classes 305 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 306 | 307 | # def _get_pos_embed(self, pos_embed, patch_embed, H, W): 308 | # if H * W == self.patch_embed1.num_patches: 309 | # return pos_embed 310 | # else: 311 | # return F.interpolate( 312 | # pos_embed.reshape(1, patch_embed.H, patch_embed.W, -1).permute(0, 3, 1, 2), 313 | # size=(H, W), mode="bilinear").reshape(1, -1, H * W).permute(0, 2, 1) 314 | 315 | def forward_features(self, x): 316 | B = x.shape[0] 317 | outs = [] 318 | 319 | # stage 1 320 | x, H, W = self.patch_embed1(x) 321 | for i, blk in enumerate(self.block1): 322 | x = blk(x, H, W) 323 | x = self.norm1(x) 324 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 325 | outs.append(x) 326 | 327 | # stage 2 328 | x, H, W = self.patch_embed2(x) 329 | for i, blk in enumerate(self.block2): 330 | x = blk(x, H, W) 331 | x = self.norm2(x) 332 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 333 | outs.append(x) 334 | 335 | # stage 3 336 | x, H, W = self.patch_embed3(x) 337 | for i, blk in enumerate(self.block3): 338 | x = blk(x, H, W) 339 | x = self.norm3(x) 340 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 341 | outs.append(x) 342 | 343 | # stage 4 344 | x, H, W = self.patch_embed4(x) 345 | for i, blk in enumerate(self.block4): 346 | x = blk(x, H, W) 347 | x = self.norm4(x) 348 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 349 | outs.append(x) 350 | 351 | return outs 352 | 353 | # return x.mean(dim=1) 354 | 355 | def forward(self, x): 356 | x = self.forward_features(x) 357 | # x = self.head(x) 358 | 359 | return x 360 | 361 | 362 | class DWConv(nn.Module): 363 | def __init__(self, dim=768): 364 | super(DWConv, self).__init__() 365 | self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) 366 | 367 | def forward(self, x, H, W): 368 | B, N, C = x.shape 369 | x = x.transpose(1, 2).view(B, C, H, W) 370 | x = self.dwconv(x) 371 | x = x.flatten(2).transpose(1, 2) 372 | 373 | return x 374 | 375 | 376 | def _conv_filter(state_dict, patch_size=16): 377 | """ convert patch embedding weight from manual patchify + linear proj to conv""" 378 | out_dict = {} 379 | for k, v in state_dict.items(): 380 | if 'patch_embed.proj.weight' in k: 381 | v = v.reshape((v.shape[0], 3, patch_size, patch_size)) 382 | out_dict[k] = v 383 | 384 | return out_dict 385 | 386 | 387 | @register_model 388 | class pvt_v2_b0(PyramidVisionTransformerImpr): 389 | def __init__(self, **kwargs): 390 | super(pvt_v2_b0, self).__init__( 391 | patch_size=4, embed_dims=[32, 64, 160, 256], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], 392 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], 393 | drop_rate=0.0, drop_path_rate=0.1) 394 | 395 | 396 | 397 | @register_model 398 | class pvt_v2_b1(PyramidVisionTransformerImpr): 399 | def __init__(self, **kwargs): 400 | super(pvt_v2_b1, self).__init__( 401 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], 402 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], 403 | drop_rate=0.0, drop_path_rate=0.1) 404 | 405 | @register_model 406 | class pvt_v2_b2(PyramidVisionTransformerImpr): 407 | def __init__(self, **kwargs): 408 | super(pvt_v2_b2, self).__init__( 409 | patch_size=4, embed_dims=[64, 128, 320, 512], 
num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], 410 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], 411 | drop_rate=0.0, drop_path_rate=0.1) 412 | 413 | @register_model 414 | class pvt_v2_b3(PyramidVisionTransformerImpr): 415 | def __init__(self, **kwargs): 416 | super(pvt_v2_b3, self).__init__( 417 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], 418 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1], 419 | drop_rate=0.0, drop_path_rate=0.1) 420 | 421 | @register_model 422 | class pvt_v2_b4(PyramidVisionTransformerImpr): 423 | def __init__(self, **kwargs): 424 | super(pvt_v2_b4, self).__init__( 425 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], 426 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1], 427 | drop_rate=0.0, drop_path_rate=0.1) 428 | 429 | 430 | @register_model 431 | class pvt_v2_b5(PyramidVisionTransformerImpr): 432 | def __init__(self, **kwargs): 433 | super(pvt_v2_b5, self).__init__( 434 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], 435 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 6, 40, 3], sr_ratios=[8, 4, 2, 1], 436 | drop_rate=0.0, drop_path_rate=0.1) 437 | 438 | 439 | 440 | 441 | -------------------------------------------------------------------------------- /test.sh: -------------------------------------------------------------------------------- 1 | python -W ignore Test.py -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | python -W ignore Train.py -------------------------------------------------------------------------------- /utils/Readme.md: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /utils/dataloader.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | import torch.utils.data as data 4 | import torchvision.transforms as transforms 5 | import numpy as np 6 | import random 7 | import torch 8 | import cv2 9 | import albumentations as A 10 | from albumentations.pytorch import ToTensorV2 11 | 12 | class PolypDataset(data.Dataset): 13 | def __init__(self, image_root, gt_root, trainsize, augmentations): 14 | self.image_root = image_root 15 | self.gt_root = gt_root 16 | self.samples = [name for name in os.listdir(image_root) if name[0]!="."] 17 | self.transform = A.Compose([ 18 | A.Normalize(), 19 | A.Resize(352, 352, interpolation=cv2.INTER_NEAREST), 20 | A.HorizontalFlip(p=0.2), 21 | A.VerticalFlip(p=0.2), 22 | # A.RandomRotate90(p=0.2), 23 | ToTensorV2() 24 | ]) 25 | 26 | self.color1, self.color2 = [], [] 27 | for name in self.samples: 28 | if name[:-4].isdigit(): 29 | self.color1.append(name) 30 | else: 31 | self.color2.append(name) 32 | 33 | def __getitem__(self, idx): 34 | name = self.samples[idx] 35 | image = cv2.imread(self.image_root+'/'+name) 36 | image = cv2.cvtColor(image, cv2.COLOR_BGR2LAB) 37 | 38 | name2 = self.color1[idx%len(self.color1)] if np.random.rand()<0.7 else self.color2[idx%len(self.color2)] 39 | image2 = cv2.imread(self.image_root+'/'+name2) 40 | image2 = cv2.cvtColor(image2, 
cv2.COLOR_BGR2LAB) 41 | 42 | mean, std = image.mean(axis=(0,1), keepdims=True), image.std(axis=(0,1), keepdims=True) 43 | mean2, std2 = image2.mean(axis=(0,1), keepdims=True), image2.std(axis=(0,1), keepdims=True) 44 | image = np.uint8((image-mean)/std*std2+mean2) # colour transfer: match this image's LAB statistics to the randomly chosen reference image 45 | image = cv2.cvtColor(image, cv2.COLOR_LAB2RGB) 46 | mask = cv2.imread(self.gt_root+'/'+name, cv2.IMREAD_GRAYSCALE)/255.0 47 | pair = self.transform(image=image, mask=mask) 48 | 49 | return pair['image'], pair['mask'] 50 | 51 | def __len__(self): 52 | return len(self.samples) 53 | 54 | 55 | def get_loader(image_root, gt_root, batchsize, trainsize, shuffle=True, num_workers=4, pin_memory=True, augmentation=False): 56 | 57 | dataset = PolypDataset(image_root, gt_root, trainsize, augmentation) 58 | data_loader = data.DataLoader(dataset=dataset, 59 | batch_size=batchsize, 60 | shuffle=shuffle, 61 | num_workers=num_workers, 62 | pin_memory=pin_memory) 63 | return data_loader 64 | 65 | 66 | class test_dataset: 67 | def __init__(self, image_root, gt_root, testsize): 68 | self.testsize = testsize 69 | self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg') or f.endswith('.png')] 70 | self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.tif') or f.endswith('.png')] 71 | self.images = sorted(self.images) 72 | self.gts = sorted(self.gts) 73 | self.transform = transforms.Compose([ 74 | transforms.Resize((self.testsize, self.testsize)), 75 | transforms.ToTensor(), 76 | transforms.Normalize([0.485, 0.456, 0.406], 77 | [0.229, 0.224, 0.225]) 78 | ]) 79 | self.gt_transform = transforms.ToTensor() 80 | self.size = len(self.images) 81 | self.index = 0 82 | 83 | def load_data(self): 84 | image = self.rgb_loader(self.images[self.index]) 85 | image = self.transform(image).unsqueeze(0) 86 | gt = self.binary_loader(self.gts[self.index]) 87 | name = self.images[self.index].split('/')[-1] 88 | if name.endswith('.jpg'): 89 | name = name.split('.jpg')[0] + '.png' 90 | self.index += 1 91 | return image, gt, name 92 | 93 | def rgb_loader(self, path): 94 | with open(path, 'rb') as f: 95 | img = Image.open(f) 96 | return img.convert('RGB') 97 | 98 | def binary_loader(self, path): 99 | with open(path, 'rb') as f: 100 | img = Image.open(f) 101 | return img.convert('L') 102 | -------------------------------------------------------------------------------- /utils/format_conversion.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | from libtiff import TIFF # pip install libtiff 4 | import imageio # pip install imageio; scipy.misc.imsave was removed in SciPy >= 1.2 5 | import random 6 | 7 | 8 | def tif2png(_src_path, _dst_path): 9 | """ 10 | Usage: 11 | converting a `tif/tiff` file to a `png` file 12 | :param _src_path: 13 | :param _dst_path: 14 | :return: 15 | """ 16 | tif = TIFF.open(_src_path, mode='r') 17 | image = tif.read_image() 18 | imageio.imwrite(_dst_path, image) 19 | 20 | 21 | def data_split(src_list): 22 | """ 23 | Usage: 24 | randomly splitting the dataset (samples 550 indices) 25 | :param src_list: 26 | :return: 27 | """ 28 | counter_list = random.sample(range(0, len(src_list)), 550) 29 | 30 | return counter_list 31 | 32 | 33 | if __name__ == '__main__': 34 | src_dir = '../Dataset/train_dataset/CVC-EndoSceneStill/CVC-612/test_split/masks_tif' 35 | dst_dir = '../Dataset/train_dataset/CVC-EndoSceneStill/CVC-612/test_split/masks' 36 | 37 | os.makedirs(dst_dir, exist_ok=True) 38 | for img_name in os.listdir(src_dir): 39 | tif2png(os.path.join(src_dir, img_name), 40 | os.path.join(dst_dir, img_name.replace('.tif', '.png'))) 41 
| -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from thop import profile 4 | from thop import clever_format 5 | 6 | 7 | def clip_gradient(optimizer, grad_clip): 8 | """ 9 | Clip each parameter's gradient to the range [-grad_clip, grad_clip] to stabilise training 10 | :param optimizer: 11 | :param grad_clip: 12 | :return: 13 | """ 14 | for group in optimizer.param_groups: 15 | for param in group['params']: 16 | if param.grad is not None: 17 | param.grad.data.clamp_(-grad_clip, grad_clip) 18 | 19 | 20 | def adjust_lr(optimizer, init_lr, epoch, decay_rate=0.1, decay_epoch=30): 21 | decay = decay_rate ** (epoch // decay_epoch) 22 | for param_group in optimizer.param_groups: 23 | param_group['lr'] *= decay # note: scales the current lr in place, so the decay compounds across calls 24 | 25 | 26 | class AvgMeter(object): 27 | def __init__(self, num=40): 28 | self.num = num 29 | self.reset() 30 | 31 | def reset(self): 32 | self.val = 0 33 | self.avg = 0 34 | self.sum = 0 35 | self.count = 0 36 | self.losses = [] 37 | 38 | def update(self, val, n=1): 39 | self.val = val 40 | self.sum += val * n 41 | self.count += n 42 | self.avg = self.sum / self.count 43 | self.losses.append(val) 44 | 45 | def show(self): 46 | return torch.mean(torch.stack(self.losses[np.maximum(len(self.losses)-self.num, 0):])) 47 | 48 | 49 | def CalParams(model, input_tensor): 50 | """ 51 | Usage: 52 | Calculate Params and FLOPs via [THOP](https://github.com/Lyken17/pytorch-OpCounter) 53 | Requires: 54 | from thop import profile 55 | from thop import clever_format 56 | :param model: 57 | :param input_tensor: 58 | :return: 59 | """ 60 | flops, params = profile(model, inputs=(input_tensor,)) 61 | flops, params = clever_format([flops, params], "%.3f") 62 | print('[Statistics Information]\nFLOPs: {}\nParams: {}'.format(flops, params)) --------------------------------------------------------------------------------
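
For reference, a minimal usage sketch of CalParams follows. It is an illustrative assumption rather than part of the repository: the stand-in model, the `from utils.utils import CalParams` import path, and the 1 x 3 x 352 x 352 dummy input (chosen to match the 352 x 352 resize in utils/dataloader.py) would normally be replaced by the DuAT network and its own input resolution.

# Hypothetical sketch: count FLOPs and parameters with CalParams (requires `pip install thop`).
import torch
import torch.nn as nn
from utils.utils import CalParams   # assumed import path; adjust to how the repo is laid out on disk

# Stand-in model used only for illustration; pass the DuAT model here in practice.
toy_model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.ReLU(inplace=True),
    nn.Conv2d(16, 1, kernel_size=1),
)

dummy_input = torch.randn(1, 3, 352, 352)    # same spatial size as the training resize
CalParams(toy_model, dummy_input)            # prints FLOPs and Params formatted by THOP's clever_format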