├── LICENSE ├── README.md ├── __pycache__ ├── archs.cpython-36.pyc ├── dataset.cpython-36.pyc ├── losses.cpython-36.pyc ├── metrics.cpython-36.pyc └── utils.cpython-36.pyc ├── archs.py ├── config.py ├── dataset.py ├── environment.yml ├── imgs ├── readme.md └── unext.png ├── losses.py ├── metrics.py ├── post_process.py ├── train.py ├── utils.py └── val.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Jeya Maria Jose 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # UNeXt 2 | 3 | Official Pytorch Code base for [UNeXt: MLP-based Rapid Medical Image Segmentation Network](https://arxiv.org/abs/2203.04967), MICCAI 2022 4 | 5 | [Paper](https://arxiv.org/abs/2203.04967) | [Project](https://jeya-maria-jose.github.io/UNext-web/) 6 | 7 | ## Introduction 8 | 9 | UNet and its latest extensions like TransUNet have been the leading medical image segmentation methods in recent years. However, these networks cannot be effectively adopted for rapid image segmentation in point-of-care applications as they are parameter-heavy, computationally complex and slow to use. To this end, we propose UNeXt which is a Convolutional multilayer perceptron (MLP) based network for image segmentation. We design UNeXt in an effective way with an early convolutional stage and a MLP stage in the latent stage. We propose a tokenized MLP block where we efficiently tokenize and project the convolutional features and use MLPs to model the representation. To further boost the performance, we propose shifting the channels of the inputs while feeding in to MLPs so as to focus on learning local dependencies. Using tokenized MLPs in latent space reduces the number of parameters and computational complexity while being able to result in a better representation to help segmentation. The network also consists of skip connections between various levels of encoder and decoder. We test UNeXt on multiple medical image segmentation datasets and show that we reduce the number of parameters by 72x, decrease the computational complexity by 68x, and improve the inference speed by 10x while also obtaining better segmentation performance over the state-of-the-art medical image segmentation architectures. 10 | 11 |

12 | [Figure: imgs/unext.png, overview of the UNeXt architecture] 13 |
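The channel shift described above is the core of the tokenized MLP block. As a rough illustration, here is a minimal PyTorch sketch that mirrors the logic of `shiftmlp.forward` in `archs.py`; the helper name and the tensor sizes are illustrative, not part of the repository:

```python
import torch
import torch.nn.functional as F

def axial_shift(x, shift_size=5, dim=2):
    """Roll channel groups of a (B, C, H, W) map by different offsets
    along one spatial axis (dim=2: height, dim=3: width)."""
    B, C, H, W = x.shape
    pad = shift_size // 2
    x = F.pad(x, (pad, pad, pad, pad), "constant", 0)      # zero-pad H and W
    groups = torch.chunk(x, shift_size, dim=1)             # split channels into groups
    shifted = [torch.roll(g, s, dims=dim)                  # each group gets its own offset
               for g, s in zip(groups, range(-pad, pad + 1))]
    x = torch.cat(shifted, dim=1)
    x = torch.narrow(x, 2, pad, H)                         # crop back to the original H
    return torch.narrow(x, 3, pad, W)                      # and the original W

feat = torch.randn(1, 160, 28, 28)
print(axial_shift(feat, dim=2).shape)   # torch.Size([1, 160, 28, 28])
```

In `shiftmlp`, this shift is applied along the height before the first linear projection and along the width before the second, so each projection mixes tokens that arrive from different spatial neighbours.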

14 | 15 | 16 | ## Using the code: 17 | 18 | The code is stable while using Python 3.6.13, CUDA >=10.1 19 | 20 | - Clone this repository: 21 | ```bash 22 | git clone https://github.com/jeya-maria-jose/UNeXt-pytorch 23 | cd UNeXt-pytorch 24 | ``` 25 | 26 | To install all the dependencies using conda: 27 | 28 | ```bash 29 | conda env create -f environment.yml 30 | conda activate unext 31 | ``` 32 | 33 | If you prefer pip, install following versions: 34 | 35 | ```bash 36 | timm==0.3.2 37 | mmcv-full==1.2.7 38 | torch==1.7.1 39 | torchvision==0.8.2 40 | opencv-python==4.5.1.48 41 | ``` 42 | 43 | ## Datasets 44 | 45 | 1) ISIC 2018 - [Link](https://challenge.isic-archive.com/data/) 46 | 2) BUSI - [Link](https://www.kaggle.com/aryashah2k/breast-ultrasound-images-dataset) 47 | 48 | ## Data Format 49 | 50 | Make sure to put the files as the following structure (e.g. the number of classes is 2): 51 | 52 | ``` 53 | inputs 54 | └── 55 | ├── images 56 | | ├── 001.png 57 | │ ├── 002.png 58 | │ ├── 003.png 59 | │ ├── ... 60 | | 61 | └── masks 62 | ├── 0 63 | | ├── 001.png 64 | | ├── 002.png 65 | | ├── 003.png 66 | | ├── ... 67 | | 68 | └── 1 69 | ├── 001.png 70 | ├── 002.png 71 | ├── 003.png 72 | ├── ... 73 | ``` 74 | 75 | For binary segmentation problems, just use folder 0. 76 | 77 | ## Training and Validation 78 | 79 | 1. Train the model. 80 | ``` 81 | python train.py --dataset --arch UNext --name --img_ext .png --mask_ext .png --lr 0.0001 --epochs 500 --input_w 512 --input_h 512 --b 8 82 | ``` 83 | 2. Evaluate. 84 | ``` 85 | python val.py --name 86 | ``` 87 | 88 | ### Acknowledgements: 89 | 90 | This code-base uses certain code-blocks and helper functions from [UNet++](https://github.com/4uiiurz1/pytorch-nested-unet), [Segformer](https://github.com/NVlabs/SegFormer), and [AS-MLP](https://github.com/svip-lab/AS-MLP). Naming credits to [Poojan](https://scholar.google.co.in/citations?user=9dhBHuAAAAAJ&hl=en). 
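A note on `train.py` / `val.py`: the `--dataset` flag takes the dataset folder name under `inputs/` and `--name` takes an experiment name, e.g. `python train.py --dataset isic --arch UNext --name isic_unext ...` followed by `python val.py --name isic_unext` (the folder and experiment names here are only examples). `val.py` then reuses the same 20% validation split as `train.py`, derived deterministically from the image filenames, which is why evaluation needs nothing beyond the experiment name. A minimal sketch of that split, assuming PNG images under `inputs/isic/images`:

```python
import os
from glob import glob
from sklearn.model_selection import train_test_split

# the same split train.py and val.py use (test_size=0.2, random_state=41)
img_ids = [os.path.splitext(os.path.basename(p))[0]
           for p in glob(os.path.join('inputs', 'isic', 'images', '*.png'))]
train_ids, val_ids = train_test_split(img_ids, test_size=0.2, random_state=41)
print(f'{len(train_ids)} training ids, {len(val_ids)} validation ids')
```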
91 | 92 | ### Citation: 93 | ``` 94 | @article{valanarasu2022unext, 95 | title={UNeXt: MLP-based Rapid Medical Image Segmentation Network}, 96 | author={Valanarasu, Jeya Maria Jose and Patel, Vishal M}, 97 | journal={arXiv preprint arXiv:2203.04967}, 98 | year={2022} 99 | } 100 | ``` 101 | -------------------------------------------------------------------------------- /__pycache__/archs.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jeya-maria-jose/UNeXt-pytorch/6ad0855114a35afbf81decf5dc912cd8de70476a/__pycache__/archs.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jeya-maria-jose/UNeXt-pytorch/6ad0855114a35afbf81decf5dc912cd8de70476a/__pycache__/dataset.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/losses.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jeya-maria-jose/UNeXt-pytorch/6ad0855114a35afbf81decf5dc912cd8de70476a/__pycache__/losses.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/metrics.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jeya-maria-jose/UNeXt-pytorch/6ad0855114a35afbf81decf5dc912cd8de70476a/__pycache__/metrics.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jeya-maria-jose/UNeXt-pytorch/6ad0855114a35afbf81decf5dc912cd8de70476a/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /archs.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch 4 | import torchvision 5 | from torch import nn 6 | from torch.autograd import Variable 7 | from torch.utils.data import DataLoader 8 | from torchvision import transforms 9 | from torchvision.utils import save_image 10 | import torch.nn.functional as F 11 | import os 12 | import matplotlib.pyplot as plt 13 | from utils import * 14 | __all__ = ['UNext'] 15 | 16 | import timm 17 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 18 | import types 19 | import math 20 | from abc import ABCMeta, abstractmethod 21 | from mmcv.cnn import ConvModule 22 | import pdb 23 | 24 | 25 | 26 | def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: 27 | """1x1 convolution""" 28 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, bias=False) 29 | 30 | 31 | def shift(dim): 32 | x_shift = [ torch.roll(x_c, shift, dim) for x_c, shift in zip(xs, range(-self.pad, self.pad+1))] 33 | x_cat = torch.cat(x_shift, 1) 34 | x_cat = torch.narrow(x_cat, 2, self.pad, H) 35 | x_cat = torch.narrow(x_cat, 3, self.pad, W) 36 | return x_cat 37 | 38 | class shiftmlp(nn.Module): 39 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., shift_size=5): 40 | super().__init__() 41 | out_features = out_features or in_features 42 | hidden_features = hidden_features or 
in_features 43 | self.dim = in_features 44 | self.fc1 = nn.Linear(in_features, hidden_features) 45 | self.dwconv = DWConv(hidden_features) 46 | self.act = act_layer() 47 | self.fc2 = nn.Linear(hidden_features, out_features) 48 | self.drop = nn.Dropout(drop) 49 | 50 | self.shift_size = shift_size 51 | self.pad = shift_size // 2 52 | 53 | 54 | self.apply(self._init_weights) 55 | 56 | def _init_weights(self, m): 57 | if isinstance(m, nn.Linear): 58 | trunc_normal_(m.weight, std=.02) 59 | if isinstance(m, nn.Linear) and m.bias is not None: 60 | nn.init.constant_(m.bias, 0) 61 | elif isinstance(m, nn.LayerNorm): 62 | nn.init.constant_(m.bias, 0) 63 | nn.init.constant_(m.weight, 1.0) 64 | elif isinstance(m, nn.Conv2d): 65 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 66 | fan_out //= m.groups 67 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 68 | if m.bias is not None: 69 | m.bias.data.zero_() 70 | 71 | # def shift(x, dim): 72 | # x = F.pad(x, "constant", 0) 73 | # x = torch.chunk(x, shift_size, 1) 74 | # x = [ torch.roll(x_c, shift, dim) for x_s, shift in zip(x, range(-pad, pad+1))] 75 | # x = torch.cat(x, 1) 76 | # return x[:, :, pad:-pad, pad:-pad] 77 | 78 | def forward(self, x, H, W): 79 | # pdb.set_trace() 80 | B, N, C = x.shape 81 | 82 | xn = x.transpose(1, 2).view(B, C, H, W).contiguous() 83 | xn = F.pad(xn, (self.pad, self.pad, self.pad, self.pad) , "constant", 0) 84 | xs = torch.chunk(xn, self.shift_size, 1) 85 | x_shift = [torch.roll(x_c, shift, 2) for x_c, shift in zip(xs, range(-self.pad, self.pad+1))] 86 | x_cat = torch.cat(x_shift, 1) 87 | x_cat = torch.narrow(x_cat, 2, self.pad, H) 88 | x_s = torch.narrow(x_cat, 3, self.pad, W) 89 | 90 | 91 | x_s = x_s.reshape(B,C,H*W).contiguous() 92 | x_shift_r = x_s.transpose(1,2) 93 | 94 | 95 | x = self.fc1(x_shift_r) 96 | 97 | x = self.dwconv(x, H, W) 98 | x = self.act(x) 99 | x = self.drop(x) 100 | 101 | xn = x.transpose(1, 2).view(B, C, H, W).contiguous() 102 | xn = F.pad(xn, (self.pad, self.pad, self.pad, self.pad) , "constant", 0) 103 | xs = torch.chunk(xn, self.shift_size, 1) 104 | x_shift = [torch.roll(x_c, shift, 3) for x_c, shift in zip(xs, range(-self.pad, self.pad+1))] 105 | x_cat = torch.cat(x_shift, 1) 106 | x_cat = torch.narrow(x_cat, 2, self.pad, H) 107 | x_s = torch.narrow(x_cat, 3, self.pad, W) 108 | x_s = x_s.reshape(B,C,H*W).contiguous() 109 | x_shift_c = x_s.transpose(1,2) 110 | 111 | x = self.fc2(x_shift_c) 112 | x = self.drop(x) 113 | return x 114 | 115 | 116 | 117 | class shiftedBlock(nn.Module): 118 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 119 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1): 120 | super().__init__() 121 | 122 | 123 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 124 | self.norm2 = norm_layer(dim) 125 | mlp_hidden_dim = int(dim * mlp_ratio) 126 | self.mlp = shiftmlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 127 | self.apply(self._init_weights) 128 | 129 | def _init_weights(self, m): 130 | if isinstance(m, nn.Linear): 131 | trunc_normal_(m.weight, std=.02) 132 | if isinstance(m, nn.Linear) and m.bias is not None: 133 | nn.init.constant_(m.bias, 0) 134 | elif isinstance(m, nn.LayerNorm): 135 | nn.init.constant_(m.bias, 0) 136 | nn.init.constant_(m.weight, 1.0) 137 | elif isinstance(m, nn.Conv2d): 138 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 139 | fan_out //= m.groups 140 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 141 | if m.bias is not None: 142 | m.bias.data.zero_() 143 | 144 | def forward(self, x, H, W): 145 | 146 | x = x + self.drop_path(self.mlp(self.norm2(x), H, W)) 147 | return x 148 | 149 | 150 | class DWConv(nn.Module): 151 | def __init__(self, dim=768): 152 | super(DWConv, self).__init__() 153 | self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) 154 | 155 | def forward(self, x, H, W): 156 | B, N, C = x.shape 157 | x = x.transpose(1, 2).view(B, C, H, W) 158 | x = self.dwconv(x) 159 | x = x.flatten(2).transpose(1, 2) 160 | 161 | return x 162 | 163 | class OverlapPatchEmbed(nn.Module): 164 | """ Image to Patch Embedding 165 | """ 166 | 167 | def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768): 168 | super().__init__() 169 | img_size = to_2tuple(img_size) 170 | patch_size = to_2tuple(patch_size) 171 | 172 | self.img_size = img_size 173 | self.patch_size = patch_size 174 | self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1] 175 | self.num_patches = self.H * self.W 176 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride, 177 | padding=(patch_size[0] // 2, patch_size[1] // 2)) 178 | self.norm = nn.LayerNorm(embed_dim) 179 | 180 | self.apply(self._init_weights) 181 | 182 | def _init_weights(self, m): 183 | if isinstance(m, nn.Linear): 184 | trunc_normal_(m.weight, std=.02) 185 | if isinstance(m, nn.Linear) and m.bias is not None: 186 | nn.init.constant_(m.bias, 0) 187 | elif isinstance(m, nn.LayerNorm): 188 | nn.init.constant_(m.bias, 0) 189 | nn.init.constant_(m.weight, 1.0) 190 | elif isinstance(m, nn.Conv2d): 191 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 192 | fan_out //= m.groups 193 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 194 | if m.bias is not None: 195 | m.bias.data.zero_() 196 | 197 | def forward(self, x): 198 | x = self.proj(x) 199 | _, _, H, W = x.shape 200 | x = x.flatten(2).transpose(1, 2) 201 | x = self.norm(x) 202 | 203 | return x, H, W 204 | 205 | 206 | class UNext(nn.Module): 207 | 208 | ## Conv 3 + MLP 2 + shifted MLP 209 | 210 | def __init__(self, num_classes, input_channels=3, deep_supervision=False,img_size=224, patch_size=16, in_chans=3, embed_dims=[ 128, 160, 256], 211 | num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0., 212 | attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, 213 | depths=[1, 1, 1], sr_ratios=[8, 4, 2, 1], **kwargs): 214 | super().__init__() 215 | 216 | self.encoder1 = nn.Conv2d(3, 16, 3, stride=1, padding=1) 217 | self.encoder2 = nn.Conv2d(16, 32, 3, stride=1, padding=1) 218 | self.encoder3 = nn.Conv2d(32, 128, 3, stride=1, padding=1) 219 | 220 | self.ebn1 = nn.BatchNorm2d(16) 221 | self.ebn2 = nn.BatchNorm2d(32) 222 | self.ebn3 
= nn.BatchNorm2d(128) 223 | 224 | self.norm3 = norm_layer(embed_dims[1]) 225 | self.norm4 = norm_layer(embed_dims[2]) 226 | 227 | self.dnorm3 = norm_layer(160) 228 | self.dnorm4 = norm_layer(128) 229 | 230 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] 231 | 232 | self.block1 = nn.ModuleList([shiftedBlock( 233 | dim=embed_dims[1], num_heads=num_heads[0], mlp_ratio=1, qkv_bias=qkv_bias, qk_scale=qk_scale, 234 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[0], norm_layer=norm_layer, 235 | sr_ratio=sr_ratios[0])]) 236 | 237 | self.block2 = nn.ModuleList([shiftedBlock( 238 | dim=embed_dims[2], num_heads=num_heads[0], mlp_ratio=1, qkv_bias=qkv_bias, qk_scale=qk_scale, 239 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[1], norm_layer=norm_layer, 240 | sr_ratio=sr_ratios[0])]) 241 | 242 | self.dblock1 = nn.ModuleList([shiftedBlock( 243 | dim=embed_dims[1], num_heads=num_heads[0], mlp_ratio=1, qkv_bias=qkv_bias, qk_scale=qk_scale, 244 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[0], norm_layer=norm_layer, 245 | sr_ratio=sr_ratios[0])]) 246 | 247 | self.dblock2 = nn.ModuleList([shiftedBlock( 248 | dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=1, qkv_bias=qkv_bias, qk_scale=qk_scale, 249 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[1], norm_layer=norm_layer, 250 | sr_ratio=sr_ratios[0])]) 251 | 252 | self.patch_embed3 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0], 253 | embed_dim=embed_dims[1]) 254 | self.patch_embed4 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, stride=2, in_chans=embed_dims[1], 255 | embed_dim=embed_dims[2]) 256 | 257 | self.decoder1 = nn.Conv2d(256, 160, 3, stride=1,padding=1) 258 | self.decoder2 = nn.Conv2d(160, 128, 3, stride=1, padding=1) 259 | self.decoder3 = nn.Conv2d(128, 32, 3, stride=1, padding=1) 260 | self.decoder4 = nn.Conv2d(32, 16, 3, stride=1, padding=1) 261 | self.decoder5 = nn.Conv2d(16, 16, 3, stride=1, padding=1) 262 | 263 | self.dbn1 = nn.BatchNorm2d(160) 264 | self.dbn2 = nn.BatchNorm2d(128) 265 | self.dbn3 = nn.BatchNorm2d(32) 266 | self.dbn4 = nn.BatchNorm2d(16) 267 | 268 | self.final = nn.Conv2d(16, num_classes, kernel_size=1) 269 | 270 | self.soft = nn.Softmax(dim =1) 271 | 272 | def forward(self, x): 273 | 274 | B = x.shape[0] 275 | ### Encoder 276 | ### Conv Stage 277 | 278 | ### Stage 1 279 | out = F.relu(F.max_pool2d(self.ebn1(self.encoder1(x)),2,2)) 280 | t1 = out 281 | ### Stage 2 282 | out = F.relu(F.max_pool2d(self.ebn2(self.encoder2(out)),2,2)) 283 | t2 = out 284 | ### Stage 3 285 | out = F.relu(F.max_pool2d(self.ebn3(self.encoder3(out)),2,2)) 286 | t3 = out 287 | 288 | ### Tokenized MLP Stage 289 | ### Stage 4 290 | 291 | out,H,W = self.patch_embed3(out) 292 | for i, blk in enumerate(self.block1): 293 | out = blk(out, H, W) 294 | out = self.norm3(out) 295 | out = out.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 296 | t4 = out 297 | 298 | ### Bottleneck 299 | 300 | out ,H,W= self.patch_embed4(out) 301 | for i, blk in enumerate(self.block2): 302 | out = blk(out, H, W) 303 | out = self.norm4(out) 304 | out = out.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 305 | 306 | ### Stage 4 307 | 308 | out = F.relu(F.interpolate(self.dbn1(self.decoder1(out)),scale_factor=(2,2),mode ='bilinear')) 309 | 310 | out = torch.add(out,t4) 311 | _,_,H,W = out.shape 312 | out = out.flatten(2).transpose(1,2) 313 | for i, blk in enumerate(self.dblock1): 314 | out = blk(out, H, W) 315 | 316 | ### Stage 3 317 | 318 | 
out = self.dnorm3(out) 319 | out = out.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 320 | out = F.relu(F.interpolate(self.dbn2(self.decoder2(out)),scale_factor=(2,2),mode ='bilinear')) 321 | out = torch.add(out,t3) 322 | _,_,H,W = out.shape 323 | out = out.flatten(2).transpose(1,2) 324 | 325 | for i, blk in enumerate(self.dblock2): 326 | out = blk(out, H, W) 327 | 328 | out = self.dnorm4(out) 329 | out = out.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 330 | 331 | out = F.relu(F.interpolate(self.dbn3(self.decoder3(out)),scale_factor=(2,2),mode ='bilinear')) 332 | out = torch.add(out,t2) 333 | out = F.relu(F.interpolate(self.dbn4(self.decoder4(out)),scale_factor=(2,2),mode ='bilinear')) 334 | out = torch.add(out,t1) 335 | out = F.relu(F.interpolate(self.decoder5(out),scale_factor=(2,2),mode ='bilinear')) 336 | 337 | return self.final(out) 338 | 339 | 340 | class UNext_S(nn.Module): 341 | 342 | ## Conv 3 + MLP 2 + shifted MLP w less parameters 343 | 344 | def __init__(self, num_classes, input_channels=3, deep_supervision=False,img_size=224, patch_size=16, in_chans=3, embed_dims=[32, 64, 128, 512], 345 | num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0., 346 | attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, 347 | depths=[1, 1, 1], sr_ratios=[8, 4, 2, 1], **kwargs): 348 | super().__init__() 349 | 350 | self.encoder1 = nn.Conv2d(3, 8, 3, stride=1, padding=1) 351 | self.encoder2 = nn.Conv2d(8, 16, 3, stride=1, padding=1) 352 | self.encoder3 = nn.Conv2d(16, 32, 3, stride=1, padding=1) 353 | 354 | self.ebn1 = nn.BatchNorm2d(8) 355 | self.ebn2 = nn.BatchNorm2d(16) 356 | self.ebn3 = nn.BatchNorm2d(32) 357 | 358 | self.norm3 = norm_layer(embed_dims[1]) 359 | self.norm4 = norm_layer(embed_dims[2]) 360 | 361 | self.dnorm3 = norm_layer(64) 362 | self.dnorm4 = norm_layer(32) 363 | 364 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] 365 | 366 | self.block1 = nn.ModuleList([shiftedBlock( 367 | dim=embed_dims[1], num_heads=num_heads[0], mlp_ratio=1, qkv_bias=qkv_bias, qk_scale=qk_scale, 368 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[0], norm_layer=norm_layer, 369 | sr_ratio=sr_ratios[0])]) 370 | 371 | self.block2 = nn.ModuleList([shiftedBlock( 372 | dim=embed_dims[2], num_heads=num_heads[0], mlp_ratio=1, qkv_bias=qkv_bias, qk_scale=qk_scale, 373 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[1], norm_layer=norm_layer, 374 | sr_ratio=sr_ratios[0])]) 375 | 376 | self.dblock1 = nn.ModuleList([shiftedBlock( 377 | dim=embed_dims[1], num_heads=num_heads[0], mlp_ratio=1, qkv_bias=qkv_bias, qk_scale=qk_scale, 378 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[0], norm_layer=norm_layer, 379 | sr_ratio=sr_ratios[0])]) 380 | 381 | self.dblock2 = nn.ModuleList([shiftedBlock( 382 | dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=1, qkv_bias=qkv_bias, qk_scale=qk_scale, 383 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[1], norm_layer=norm_layer, 384 | sr_ratio=sr_ratios[0])]) 385 | 386 | self.patch_embed3 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0], 387 | embed_dim=embed_dims[1]) 388 | self.patch_embed4 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, stride=2, in_chans=embed_dims[1], 389 | embed_dim=embed_dims[2]) 390 | 391 | self.decoder1 = nn.Conv2d(128, 64, 3, stride=1,padding=1) 392 | self.decoder2 = nn.Conv2d(64, 32, 3, stride=1, padding=1) 393 | self.decoder3 = nn.Conv2d(32, 16, 3, stride=1, padding=1) 
394 | self.decoder4 = nn.Conv2d(16, 8, 3, stride=1, padding=1) 395 | self.decoder5 = nn.Conv2d(8, 8, 3, stride=1, padding=1) 396 | 397 | self.dbn1 = nn.BatchNorm2d(64) 398 | self.dbn2 = nn.BatchNorm2d(32) 399 | self.dbn3 = nn.BatchNorm2d(16) 400 | self.dbn4 = nn.BatchNorm2d(8) 401 | 402 | self.final = nn.Conv2d(8, num_classes, kernel_size=1) 403 | 404 | self.soft = nn.Softmax(dim =1) 405 | 406 | def forward(self, x): 407 | 408 | B = x.shape[0] 409 | ### Encoder 410 | ### Conv Stage 411 | 412 | ### Stage 1 413 | out = F.relu(F.max_pool2d(self.ebn1(self.encoder1(x)),2,2)) 414 | t1 = out 415 | ### Stage 2 416 | out = F.relu(F.max_pool2d(self.ebn2(self.encoder2(out)),2,2)) 417 | t2 = out 418 | ### Stage 3 419 | out = F.relu(F.max_pool2d(self.ebn3(self.encoder3(out)),2,2)) 420 | t3 = out 421 | 422 | ### Tokenized MLP Stage 423 | ### Stage 4 424 | 425 | out,H,W = self.patch_embed3(out) 426 | for i, blk in enumerate(self.block1): 427 | out = blk(out, H, W) 428 | out = self.norm3(out) 429 | out = out.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 430 | t4 = out 431 | 432 | ### Bottleneck 433 | 434 | out ,H,W= self.patch_embed4(out) 435 | for i, blk in enumerate(self.block2): 436 | out = blk(out, H, W) 437 | out = self.norm4(out) 438 | out = out.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 439 | 440 | ### Stage 4 441 | 442 | out = F.relu(F.interpolate(self.dbn1(self.decoder1(out)),scale_factor=(2,2),mode ='bilinear')) 443 | 444 | out = torch.add(out,t4) 445 | _,_,H,W = out.shape 446 | out = out.flatten(2).transpose(1,2) 447 | for i, blk in enumerate(self.dblock1): 448 | out = blk(out, H, W) 449 | 450 | ### Stage 3 451 | 452 | out = self.dnorm3(out) 453 | out = out.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 454 | out = F.relu(F.interpolate(self.dbn2(self.decoder2(out)),scale_factor=(2,2),mode ='bilinear')) 455 | out = torch.add(out,t3) 456 | _,_,H,W = out.shape 457 | out = out.flatten(2).transpose(1,2) 458 | 459 | for i, blk in enumerate(self.dblock2): 460 | out = blk(out, H, W) 461 | 462 | out = self.dnorm4(out) 463 | out = out.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 464 | 465 | out = F.relu(F.interpolate(self.dbn3(self.decoder3(out)),scale_factor=(2,2),mode ='bilinear')) 466 | out = torch.add(out,t2) 467 | out = F.relu(F.interpolate(self.dbn4(self.decoder4(out)),scale_factor=(2,2),mode ='bilinear')) 468 | out = torch.add(out,t1) 469 | out = F.relu(F.interpolate(self.decoder5(out),scale_factor=(2,2),mode ='bilinear')) 470 | 471 | return self.final(out) 472 | 473 | 474 | #EOF 475 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # --------------------------------------------------------' 7 | 8 | import os 9 | import yaml 10 | from yacs.config import CfgNode as CN 11 | 12 | _C = CN() 13 | 14 | # Base config files 15 | _C.BASE = [''] 16 | 17 | # ----------------------------------------------------------------------------- 18 | # Data settings 19 | # ----------------------------------------------------------------------------- 20 | _C.DATA = CN() 21 | # Batch size for a single GPU, could be overwritten by command line argument 22 | _C.DATA.BATCH_SIZE = 1 23 | # Path to dataset, could be overwritten by command line argument 24 | 
_C.DATA.DATA_PATH = '' 25 | # Dataset name 26 | _C.DATA.DATASET = 'imagenet' 27 | # Input image size 28 | _C.DATA.IMG_SIZE = 256 29 | # Interpolation to resize image (random, bilinear, bicubic) 30 | _C.DATA.INTERPOLATION = 'bicubic' 31 | # Use zipped dataset instead of folder dataset 32 | # could be overwritten by command line argument 33 | _C.DATA.ZIP_MODE = False 34 | # Cache Data in Memory, could be overwritten by command line argument 35 | _C.DATA.CACHE_MODE = 'part' 36 | # Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU. 37 | _C.DATA.PIN_MEMORY = True 38 | # Number of data loading threads 39 | _C.DATA.NUM_WORKERS = 8 40 | 41 | # ----------------------------------------------------------------------------- 42 | # Model settings 43 | # ----------------------------------------------------------------------------- 44 | _C.MODEL = CN() 45 | # Model type 46 | _C.MODEL.TYPE = 'swin' 47 | # Model name 48 | _C.MODEL.NAME = 'swin_tiny_patch4_window7_224' 49 | # Checkpoint to resume, could be overwritten by command line argument 50 | _C.MODEL.PRETRAIN_CKPT = './pretrained_ckpt/swin_tiny_patch4_window7_224.pth' 51 | _C.MODEL.RESUME = '' 52 | # Number of classes, overwritten in data preparation 53 | _C.MODEL.NUM_CLASSES = 1000 54 | # Dropout rate 55 | _C.MODEL.DROP_RATE = 0.0 56 | # Drop path rate 57 | _C.MODEL.DROP_PATH_RATE = 0.1 58 | # Label Smoothing 59 | _C.MODEL.LABEL_SMOOTHING = 0.1 60 | 61 | # Swin Transformer parameters 62 | _C.MODEL.SWIN = CN() 63 | _C.MODEL.SWIN.PATCH_SIZE = 4 64 | _C.MODEL.SWIN.IN_CHANS = 3 65 | _C.MODEL.SWIN.EMBED_DIM = 96 66 | _C.MODEL.SWIN.DEPTHS = [2, 2, 6, 2] 67 | _C.MODEL.SWIN.DECODER_DEPTHS = [2, 2, 6, 2] 68 | _C.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24] 69 | _C.MODEL.SWIN.WINDOW_SIZE = 4 70 | _C.MODEL.SWIN.MLP_RATIO = 4. 
71 | _C.MODEL.SWIN.QKV_BIAS = True 72 | _C.MODEL.SWIN.QK_SCALE = False 73 | _C.MODEL.SWIN.APE = False 74 | _C.MODEL.SWIN.PATCH_NORM = True 75 | _C.MODEL.SWIN.FINAL_UPSAMPLE= "expand_first" 76 | 77 | # ----------------------------------------------------------------------------- 78 | # Training settings 79 | # ----------------------------------------------------------------------------- 80 | _C.TRAIN = CN() 81 | _C.TRAIN.START_EPOCH = 0 82 | _C.TRAIN.EPOCHS = 300 83 | _C.TRAIN.WARMUP_EPOCHS = 20 84 | _C.TRAIN.WEIGHT_DECAY = 0.05 85 | _C.TRAIN.BASE_LR = 5e-4 86 | _C.TRAIN.WARMUP_LR = 5e-7 87 | _C.TRAIN.MIN_LR = 5e-6 88 | # Clip gradient norm 89 | _C.TRAIN.CLIP_GRAD = 5.0 90 | # Auto resume from latest checkpoint 91 | _C.TRAIN.AUTO_RESUME = True 92 | # Gradient accumulation steps 93 | # could be overwritten by command line argument 94 | _C.TRAIN.ACCUMULATION_STEPS = 0 95 | # Whether to use gradient checkpointing to save memory 96 | # could be overwritten by command line argument 97 | _C.TRAIN.USE_CHECKPOINT = False 98 | 99 | # LR scheduler 100 | _C.TRAIN.LR_SCHEDULER = CN() 101 | _C.TRAIN.LR_SCHEDULER.NAME = 'cosine' 102 | # Epoch interval to decay LR, used in StepLRScheduler 103 | _C.TRAIN.LR_SCHEDULER.DECAY_EPOCHS = 30 104 | # LR decay rate, used in StepLRScheduler 105 | _C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1 106 | 107 | # Optimizer 108 | _C.TRAIN.OPTIMIZER = CN() 109 | _C.TRAIN.OPTIMIZER.NAME = 'adamw' 110 | # Optimizer Epsilon 111 | _C.TRAIN.OPTIMIZER.EPS = 1e-8 112 | # Optimizer Betas 113 | _C.TRAIN.OPTIMIZER.BETAS = (0.9, 0.999) 114 | # SGD momentum 115 | _C.TRAIN.OPTIMIZER.MOMENTUM = 0.9 116 | 117 | # ----------------------------------------------------------------------------- 118 | # Augmentation settings 119 | # ----------------------------------------------------------------------------- 120 | _C.AUG = CN() 121 | # Color jitter factor 122 | _C.AUG.COLOR_JITTER = 0.4 123 | # Use AutoAugment policy. "v0" or "original" 124 | _C.AUG.AUTO_AUGMENT = 'rand-m9-mstd0.5-inc1' 125 | # Random erase prob 126 | _C.AUG.REPROB = 0.25 127 | # Random erase mode 128 | _C.AUG.REMODE = 'pixel' 129 | # Random erase count 130 | _C.AUG.RECOUNT = 1 131 | # Mixup alpha, mixup enabled if > 0 132 | _C.AUG.MIXUP = 0.8 133 | # Cutmix alpha, cutmix enabled if > 0 134 | _C.AUG.CUTMIX = 1.0 135 | # Cutmix min/max ratio, overrides alpha and enables cutmix if set 136 | _C.AUG.CUTMIX_MINMAX = False 137 | # Probability of performing mixup or cutmix when either/both is enabled 138 | _C.AUG.MIXUP_PROB = 1.0 139 | # Probability of switching to cutmix when both mixup and cutmix enabled 140 | _C.AUG.MIXUP_SWITCH_PROB = 0.5 141 | # How to apply mixup/cutmix params. 
Per "batch", "pair", or "elem" 142 | _C.AUG.MIXUP_MODE = 'batch' 143 | 144 | # ----------------------------------------------------------------------------- 145 | # Testing settings 146 | # ----------------------------------------------------------------------------- 147 | _C.TEST = CN() 148 | # Whether to use center crop when testing 149 | _C.TEST.CROP = True 150 | 151 | # ----------------------------------------------------------------------------- 152 | # Misc 153 | # ----------------------------------------------------------------------------- 154 | # Mixed precision opt level, if O0, no amp is used ('O0', 'O1', 'O2') 155 | # overwritten by command line argument 156 | _C.AMP_OPT_LEVEL = '' 157 | # Path to output folder, overwritten by command line argument 158 | _C.OUTPUT = '' 159 | # Tag of experiment, overwritten by command line argument 160 | _C.TAG = 'default' 161 | # Frequency to save checkpoint 162 | _C.SAVE_FREQ = 1 163 | # Frequency to logging info 164 | _C.PRINT_FREQ = 10 165 | # Fixed random seed 166 | _C.SEED = 0 167 | # Perform evaluation only, overwritten by command line argument 168 | _C.EVAL_MODE = False 169 | # Test throughput only, overwritten by command line argument 170 | _C.THROUGHPUT_MODE = False 171 | # local rank for DistributedDataParallel, given by command line argument 172 | _C.LOCAL_RANK = 0 173 | 174 | 175 | def _update_config_from_file(config, cfg_file): 176 | config.defrost() 177 | with open(cfg_file, 'r') as f: 178 | yaml_cfg = yaml.load(f, Loader=yaml.FullLoader) 179 | 180 | for cfg in yaml_cfg.setdefault('BASE', ['']): 181 | if cfg: 182 | _update_config_from_file( 183 | config, os.path.join(os.path.dirname(cfg_file), cfg) 184 | ) 185 | print('=> merge config from {}'.format(cfg_file)) 186 | config.merge_from_file(cfg_file) 187 | config.freeze() 188 | 189 | 190 | def update_config(config, args): 191 | _update_config_from_file(config, args.cfg) 192 | 193 | config.defrost() 194 | if args.opts: 195 | config.merge_from_list(args.opts) 196 | 197 | # merge from specific arguments 198 | if args.batch_size: 199 | config.DATA.BATCH_SIZE = args.batch_size 200 | if args.zip: 201 | config.DATA.ZIP_MODE = True 202 | if args.cache_mode: 203 | config.DATA.CACHE_MODE = args.cache_mode 204 | if args.resume: 205 | config.MODEL.RESUME = args.resume 206 | if args.accumulation_steps: 207 | config.TRAIN.ACCUMULATION_STEPS = args.accumulation_steps 208 | if args.use_checkpoint: 209 | config.TRAIN.USE_CHECKPOINT = True 210 | if args.amp_opt_level: 211 | config.AMP_OPT_LEVEL = args.amp_opt_level 212 | if args.tag: 213 | config.TAG = args.tag 214 | if args.eval: 215 | config.EVAL_MODE = True 216 | if args.throughput: 217 | config.THROUGHPUT_MODE = True 218 | 219 | config.freeze() 220 | 221 | 222 | def get_config(args): 223 | """Get a yacs CfgNode object with default values.""" 224 | # Return a clone so that the defaults will not be altered 225 | # This is for the "local variable" use pattern 226 | config = _C.clone() 227 | # update_config(config, args) 228 | 229 | return config 230 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import cv2 4 | import numpy as np 5 | import torch 6 | import torch.utils.data 7 | 8 | 9 | class Dataset(torch.utils.data.Dataset): 10 | def __init__(self, img_ids, img_dir, mask_dir, img_ext, mask_ext, num_classes, transform=None): 11 | """ 12 | Args: 13 | img_ids (list): Image ids. 
14 | img_dir: Image file directory. 15 | mask_dir: Mask file directory. 16 | img_ext (str): Image file extension. 17 | mask_ext (str): Mask file extension. 18 | num_classes (int): Number of classes. 19 | transform (Compose, optional): Compose transforms of albumentations. Defaults to None. 20 | 21 | Note: 22 | Make sure to put the files as the following structure: 23 | 24 | ├── images 25 | | ├── 0a7e06.jpg 26 | │ ├── 0aab0a.jpg 27 | │ ├── 0b1761.jpg 28 | │ ├── ... 29 | | 30 | └── masks 31 | ├── 0 32 | | ├── 0a7e06.png 33 | | ├── 0aab0a.png 34 | | ├── 0b1761.png 35 | | ├── ... 36 | | 37 | ├── 1 38 | | ├── 0a7e06.png 39 | | ├── 0aab0a.png 40 | | ├── 0b1761.png 41 | | ├── ... 42 | ... 43 | """ 44 | self.img_ids = img_ids 45 | self.img_dir = img_dir 46 | self.mask_dir = mask_dir 47 | self.img_ext = img_ext 48 | self.mask_ext = mask_ext 49 | self.num_classes = num_classes 50 | self.transform = transform 51 | 52 | def __len__(self): 53 | return len(self.img_ids) 54 | 55 | def __getitem__(self, idx): 56 | img_id = self.img_ids[idx] 57 | 58 | img = cv2.imread(os.path.join(self.img_dir, img_id + self.img_ext)) 59 | 60 | mask = [] 61 | for i in range(self.num_classes): 62 | mask.append(cv2.imread(os.path.join(self.mask_dir, str(i), 63 | img_id + self.mask_ext), cv2.IMREAD_GRAYSCALE)[..., None]) 64 | mask = np.dstack(mask) 65 | 66 | if self.transform is not None: 67 | augmented = self.transform(image=img, mask=mask) 68 | img = augmented['image'] 69 | mask = augmented['mask'] 70 | 71 | img = img.astype('float32') / 255 72 | img = img.transpose(2, 0, 1) 73 | mask = mask.astype('float32') / 255 74 | mask = mask.transpose(2, 0, 1) 75 | 76 | return img, mask, {'img_id': img_id} 77 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: unext 2 | channels: 3 | - defaults 4 | dependencies: 5 | - _libgcc_mutex=0.1=main 6 | - _openmp_mutex=4.5=1_gnu 7 | - ca-certificates=2021.10.26=h06a4308_2 8 | - certifi=2021.5.30=py36h06a4308_0 9 | - ld_impl_linux-64=2.35.1=h7274673_9 10 | - libffi=3.3=he6710b0_2 11 | - libgcc-ng=9.3.0=h5101ec6_17 12 | - libgomp=9.3.0=h5101ec6_17 13 | - libstdcxx-ng=9.3.0=hd4cf53a_17 14 | - ncurses=6.3=h7f8727e_2 15 | - openssl=1.1.1l=h7f8727e_0 16 | - pip=21.2.2=py36h06a4308_0 17 | - python=3.6.13=h12debd9_1 18 | - readline=8.1=h27cfd23_0 19 | - setuptools=58.0.4=py36h06a4308_0 20 | - sqlite=3.36.0=hc218d9a_0 21 | - tk=8.6.11=h1ccaba5_0 22 | - wheel=0.37.0=pyhd3eb1b0_1 23 | - xz=5.2.5=h7b6447c_0 24 | - zlib=1.2.11=h7b6447c_3 25 | - pip: 26 | - addict==2.4.0 27 | - dataclasses==0.8 28 | - mmcv-full==1.2.7 29 | - numpy==1.19.5 30 | - opencv-python==4.5.1.48 31 | - perceptual==0.1 32 | - pillow==8.4.0 33 | - scikit-image==0.17.2 34 | - scipy==1.5.4 35 | - tifffile==2020.9.3 36 | - timm==0.3.2 37 | - torch==1.7.1 38 | - torchvision==0.8.2 39 | - typing-extensions==4.0.0 40 | - yapf==0.31.0 41 | prefix: /home/jeyamariajose/anaconda3/envs/transweather 42 | 43 | -------------------------------------------------------------------------------- /imgs/readme.md: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /imgs/unext.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jeya-maria-jose/UNeXt-pytorch/6ad0855114a35afbf81decf5dc912cd8de70476a/imgs/unext.png 
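
For reference, the `Dataset` class in `dataset.py` above can also be exercised on its own. A minimal sketch, assuming a few PNG images and class-0 masks already sit under `inputs/isic` (the ids and paths are illustrative):

```python
from albumentations import Resize
from albumentations.core.composition import Compose
from albumentations.augmentations import transforms

from dataset import Dataset

# same validation-time transform as val.py: resize, then normalize
val_transform = Compose([Resize(512, 512), transforms.Normalize()])
ds = Dataset(img_ids=['001', '002'],
             img_dir='inputs/isic/images',
             mask_dir='inputs/isic/masks',
             img_ext='.png', mask_ext='.png',
             num_classes=1, transform=val_transform)
img, mask, meta = ds[0]
print(img.shape, mask.shape, meta['img_id'])   # (3, 512, 512) (1, 512, 512) 001
```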
-------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | try: 6 | from LovaszSoftmax.pytorch.lovasz_losses import lovasz_hinge 7 | except ImportError: 8 | pass 9 | 10 | __all__ = ['BCEDiceLoss', 'LovaszHingeLoss'] 11 | 12 | 13 | class BCEDiceLoss(nn.Module): 14 | def __init__(self): 15 | super().__init__() 16 | 17 | def forward(self, input, target): 18 | bce = F.binary_cross_entropy_with_logits(input, target) 19 | smooth = 1e-5 20 | input = torch.sigmoid(input) 21 | num = target.size(0) 22 | input = input.view(num, -1) 23 | target = target.view(num, -1) 24 | intersection = (input * target) 25 | dice = (2. * intersection.sum(1) + smooth) / (input.sum(1) + target.sum(1) + smooth) 26 | dice = 1 - dice.sum() / num 27 | return 0.5 * bce + dice 28 | 29 | 30 | class LovaszHingeLoss(nn.Module): 31 | def __init__(self): 32 | super().__init__() 33 | 34 | def forward(self, input, target): 35 | input = input.squeeze(1) 36 | target = target.squeeze(1) 37 | loss = lovasz_hinge(input, target, per_image=True) 38 | 39 | return loss 40 | -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | 6 | def iou_score(output, target): 7 | smooth = 1e-5 8 | 9 | if torch.is_tensor(output): 10 | output = torch.sigmoid(output).data.cpu().numpy() 11 | if torch.is_tensor(target): 12 | target = target.data.cpu().numpy() 13 | output_ = output > 0.5 14 | target_ = target > 0.5 15 | intersection = (output_ & target_).sum() 16 | union = (output_ | target_).sum() 17 | iou = (intersection + smooth) / (union + smooth) 18 | dice = (2* iou) / (iou+1) 19 | return iou, dice 20 | 21 | 22 | def dice_coef(output, target): 23 | smooth = 1e-5 24 | 25 | output = torch.sigmoid(output).view(-1).data.cpu().numpy() 26 | target = target.view(-1).data.cpu().numpy() 27 | intersection = (output * target).sum() 28 | 29 | return (2. 
* intersection + smooth) / \ 30 | (output.sum() + target.sum() + smooth) 31 | -------------------------------------------------------------------------------- /post_process.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from glob import glob 4 | 5 | import cv2 6 | import torch 7 | import torch.backends.cudnn as cudnn 8 | import yaml 9 | from albumentations.augmentations import transforms 10 | from albumentations.core.composition import Compose 11 | from sklearn.model_selection import train_test_split 12 | from tqdm import tqdm 13 | 14 | import archs 15 | from dataset import Dataset 16 | from metrics import iou_score 17 | from utils import AverageMeter 18 | from albumentations import RandomRotate90,Resize 19 | import time 20 | 21 | def parse_args(): 22 | parser = argparse.ArgumentParser() 23 | 24 | parser.add_argument('--name', default=None, 25 | help='model name') 26 | 27 | args = parser.parse_args() 28 | 29 | return args 30 | 31 | 32 | def main(): 33 | args = parse_args() 34 | 35 | with open('models/%s/config.yml' % args.name, 'r') as f: 36 | config = yaml.load(f, Loader=yaml.FullLoader) 37 | 38 | print('-'*20) 39 | for key in config.keys(): 40 | print('%s: %s' % (key, str(config[key]))) 41 | print('-'*20) 42 | 43 | cudnn.benchmark = True 44 | 45 | # create model 46 | print("=> creating model %s" % config['arch']) 47 | model = archs.__dict__[config['arch']](config['num_classes'], 48 | config['input_channels'], 49 | config['deep_supervision']) 50 | 51 | model = model.cuda() 52 | 53 | # Data loading code 54 | img_ids = glob(os.path.join('inputs', config['dataset'], 'images', '*' + config['img_ext'])) 55 | img_ids = [os.path.splitext(os.path.basename(p))[0] for p in img_ids] 56 | 57 | _, val_img_ids = train_test_split(img_ids, test_size=0.2, random_state=41) 58 | 59 | model.load_state_dict(torch.load('models/%s/model.pth' % 60 | config['name'])) 61 | model.eval() 62 | 63 | val_transform = Compose([ 64 | Resize(config['input_h'], config['input_w']), 65 | transforms.Normalize(), 66 | ]) 67 | 68 | val_dataset = Dataset( 69 | img_ids=val_img_ids, 70 | img_dir=os.path.join('inputs', config['dataset'], 'images'), 71 | mask_dir=os.path.join('inputs', config['dataset'], 'masks'), 72 | img_ext=config['img_ext'], 73 | mask_ext=config['mask_ext'], 74 | num_classes=config['num_classes'], 75 | transform=val_transform) 76 | val_loader = torch.utils.data.DataLoader( 77 | val_dataset, 78 | batch_size=config['batch_size'], 79 | shuffle=False, 80 | num_workers=config['num_workers'], 81 | drop_last=False) 82 | 83 | iou_avg_meter = AverageMeter() 84 | dice_avg_meter = AverageMeter() 85 | gput = AverageMeter() 86 | cput = AverageMeter() 87 | 88 | count = 0 89 | for c in range(config['num_classes']): 90 | os.makedirs(os.path.join('outputs', config['name'], str(c)), exist_ok=True) 91 | with torch.no_grad(): 92 | for input, target, meta in tqdm(val_loader, total=len(val_loader)): 93 | input = input.cuda() 94 | target = target.cuda() 95 | model = model.cuda() 96 | # compute output 97 | 98 | if count<=5: 99 | start = time.time() 100 | if config['deep_supervision']: 101 | output = model(input)[-1] 102 | else: 103 | output = model(input) 104 | stop = time.time() 105 | 106 | gput.update(stop-start, input.size(0)) 107 | 108 | start = time.time() 109 | model = model.cpu() 110 | input = input.cpu() 111 | output = model(input) 112 | stop = time.time() 113 | 114 | cput.update(stop-start, input.size(0)) 115 | count=count+1 116 | 117 | iou,dice = 
iou_score(output, target) 118 | iou_avg_meter.update(iou, input.size(0)) 119 | dice_avg_meter.update(dice, input.size(0)) 120 | 121 | output = torch.sigmoid(output).cpu().numpy() 122 | 123 | for i in range(len(output)): 124 | for c in range(config['num_classes']): 125 | cv2.imwrite(os.path.join('outputs', config['name'], str(c), meta['img_id'][i] + '.jpg'), 126 | (output[i, c] * 255).astype('uint8')) 127 | 128 | print('IoU: %.4f' % iou_avg_meter.avg) 129 | print('Dice: %.4f' % dice_avg_meter.avg) 130 | 131 | print('CPU: %.4f' %cput.avg) 132 | print('GPU: %.4f' %gput.avg) 133 | 134 | torch.cuda.empty_cache() 135 | 136 | 137 | if __name__ == '__main__': 138 | main() 139 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from collections import OrderedDict 4 | from glob import glob 5 | 6 | import pandas as pd 7 | import torch 8 | import torch.backends.cudnn as cudnn 9 | import torch.nn as nn 10 | import torch.optim as optim 11 | import yaml 12 | from albumentations.augmentations import transforms 13 | from albumentations.core.composition import Compose, OneOf 14 | from sklearn.model_selection import train_test_split 15 | from torch.optim import lr_scheduler 16 | from tqdm import tqdm 17 | from albumentations import RandomRotate90,Resize 18 | import archs 19 | import losses 20 | from dataset import Dataset 21 | from metrics import iou_score 22 | from utils import AverageMeter, str2bool 23 | from archs import UNext 24 | 25 | 26 | ARCH_NAMES = archs.__all__ 27 | LOSS_NAMES = losses.__all__ 28 | LOSS_NAMES.append('BCEWithLogitsLoss') 29 | 30 | 31 | 32 | def parse_args(): 33 | parser = argparse.ArgumentParser() 34 | 35 | parser.add_argument('--name', default=None, 36 | help='model name: (default: arch+timestamp)') 37 | parser.add_argument('--epochs', default=100, type=int, metavar='N', 38 | help='number of total epochs to run') 39 | parser.add_argument('-b', '--batch_size', default=16, type=int, 40 | metavar='N', help='mini-batch size (default: 16)') 41 | 42 | # model 43 | parser.add_argument('--arch', '-a', metavar='ARCH', default='UNext') 44 | parser.add_argument('--deep_supervision', default=False, type=str2bool) 45 | parser.add_argument('--input_channels', default=3, type=int, 46 | help='input channels') 47 | parser.add_argument('--num_classes', default=1, type=int, 48 | help='number of classes') 49 | parser.add_argument('--input_w', default=256, type=int, 50 | help='image width') 51 | parser.add_argument('--input_h', default=256, type=int, 52 | help='image height') 53 | 54 | # loss 55 | parser.add_argument('--loss', default='BCEDiceLoss', 56 | choices=LOSS_NAMES, 57 | help='loss: ' + 58 | ' | '.join(LOSS_NAMES) + 59 | ' (default: BCEDiceLoss)') 60 | 61 | # dataset 62 | parser.add_argument('--dataset', default='isic', 63 | help='dataset name') 64 | parser.add_argument('--img_ext', default='.png', 65 | help='image file extension') 66 | parser.add_argument('--mask_ext', default='.png', 67 | help='mask file extension') 68 | 69 | # optimizer 70 | parser.add_argument('--optimizer', default='Adam', 71 | choices=['Adam', 'SGD'], 72 | help='loss: ' + 73 | ' | '.join(['Adam', 'SGD']) + 74 | ' (default: Adam)') 75 | parser.add_argument('--lr', '--learning_rate', default=1e-3, type=float, 76 | metavar='LR', help='initial learning rate') 77 | parser.add_argument('--momentum', default=0.9, type=float, 78 | help='momentum') 79 | 
parser.add_argument('--weight_decay', default=1e-4, type=float, 80 | help='weight decay') 81 | parser.add_argument('--nesterov', default=False, type=str2bool, 82 | help='nesterov') 83 | 84 | # scheduler 85 | parser.add_argument('--scheduler', default='CosineAnnealingLR', 86 | choices=['CosineAnnealingLR', 'ReduceLROnPlateau', 'MultiStepLR', 'ConstantLR']) 87 | parser.add_argument('--min_lr', default=1e-5, type=float, 88 | help='minimum learning rate') 89 | parser.add_argument('--factor', default=0.1, type=float) 90 | parser.add_argument('--patience', default=2, type=int) 91 | parser.add_argument('--milestones', default='1,2', type=str) 92 | parser.add_argument('--gamma', default=2/3, type=float) 93 | parser.add_argument('--early_stopping', default=-1, type=int, 94 | metavar='N', help='early stopping (default: -1)') 95 | parser.add_argument('--cfg', type=str, metavar="FILE", help='path to config file', ) 96 | 97 | parser.add_argument('--num_workers', default=4, type=int) 98 | 99 | config = parser.parse_args() 100 | 101 | return config 102 | 103 | # args = parser.parse_args() 104 | def train(config, train_loader, model, criterion, optimizer): 105 | avg_meters = {'loss': AverageMeter(), 106 | 'iou': AverageMeter()} 107 | 108 | model.train() 109 | 110 | pbar = tqdm(total=len(train_loader)) 111 | for input, target, _ in train_loader: 112 | input = input.cuda() 113 | target = target.cuda() 114 | 115 | # compute output 116 | if config['deep_supervision']: 117 | outputs = model(input) 118 | loss = 0 119 | for output in outputs: 120 | loss += criterion(output, target) 121 | loss /= len(outputs) 122 | iou,dice = iou_score(outputs[-1], target) 123 | else: 124 | output = model(input) 125 | loss = criterion(output, target) 126 | iou,dice = iou_score(output, target) 127 | 128 | # compute gradient and do optimizing step 129 | optimizer.zero_grad() 130 | loss.backward() 131 | optimizer.step() 132 | 133 | avg_meters['loss'].update(loss.item(), input.size(0)) 134 | avg_meters['iou'].update(iou, input.size(0)) 135 | 136 | postfix = OrderedDict([ 137 | ('loss', avg_meters['loss'].avg), 138 | ('iou', avg_meters['iou'].avg), 139 | ]) 140 | pbar.set_postfix(postfix) 141 | pbar.update(1) 142 | pbar.close() 143 | 144 | return OrderedDict([('loss', avg_meters['loss'].avg), 145 | ('iou', avg_meters['iou'].avg)]) 146 | 147 | 148 | def validate(config, val_loader, model, criterion): 149 | avg_meters = {'loss': AverageMeter(), 150 | 'iou': AverageMeter(), 151 | 'dice': AverageMeter()} 152 | 153 | # switch to evaluate mode 154 | model.eval() 155 | 156 | with torch.no_grad(): 157 | pbar = tqdm(total=len(val_loader)) 158 | for input, target, _ in val_loader: 159 | input = input.cuda() 160 | target = target.cuda() 161 | 162 | # compute output 163 | if config['deep_supervision']: 164 | outputs = model(input) 165 | loss = 0 166 | for output in outputs: 167 | loss += criterion(output, target) 168 | loss /= len(outputs) 169 | iou,dice = iou_score(outputs[-1], target) 170 | else: 171 | output = model(input) 172 | loss = criterion(output, target) 173 | iou,dice = iou_score(output, target) 174 | 175 | avg_meters['loss'].update(loss.item(), input.size(0)) 176 | avg_meters['iou'].update(iou, input.size(0)) 177 | avg_meters['dice'].update(dice, input.size(0)) 178 | 179 | postfix = OrderedDict([ 180 | ('loss', avg_meters['loss'].avg), 181 | ('iou', avg_meters['iou'].avg), 182 | ('dice', avg_meters['dice'].avg) 183 | ]) 184 | pbar.set_postfix(postfix) 185 | pbar.update(1) 186 | pbar.close() 187 | 188 | return OrderedDict([('loss', 
avg_meters['loss'].avg), 189 | ('iou', avg_meters['iou'].avg), 190 | ('dice', avg_meters['dice'].avg)]) 191 | 192 | 193 | def main(): 194 | config = vars(parse_args()) 195 | 196 | if config['name'] is None: 197 | if config['deep_supervision']: 198 | config['name'] = '%s_%s_wDS' % (config['dataset'], config['arch']) 199 | else: 200 | config['name'] = '%s_%s_woDS' % (config['dataset'], config['arch']) 201 | 202 | os.makedirs('models/%s' % config['name'], exist_ok=True) 203 | 204 | print('-' * 20) 205 | for key in config: 206 | print('%s: %s' % (key, config[key])) 207 | print('-' * 20) 208 | 209 | with open('models/%s/config.yml' % config['name'], 'w') as f: 210 | yaml.dump(config, f) 211 | 212 | # define loss function (criterion) 213 | if config['loss'] == 'BCEWithLogitsLoss': 214 | criterion = nn.BCEWithLogitsLoss().cuda() 215 | else: 216 | criterion = losses.__dict__[config['loss']]().cuda() 217 | 218 | cudnn.benchmark = True 219 | 220 | # create model 221 | model = archs.__dict__[config['arch']](config['num_classes'], 222 | config['input_channels'], 223 | config['deep_supervision']) 224 | 225 | model = model.cuda() 226 | 227 | params = filter(lambda p: p.requires_grad, model.parameters()) 228 | if config['optimizer'] == 'Adam': 229 | optimizer = optim.Adam( 230 | params, lr=config['lr'], weight_decay=config['weight_decay']) 231 | elif config['optimizer'] == 'SGD': 232 | optimizer = optim.SGD(params, lr=config['lr'], momentum=config['momentum'], 233 | nesterov=config['nesterov'], weight_decay=config['weight_decay']) 234 | else: 235 | raise NotImplementedError 236 | 237 | if config['scheduler'] == 'CosineAnnealingLR': 238 | scheduler = lr_scheduler.CosineAnnealingLR( 239 | optimizer, T_max=config['epochs'], eta_min=config['min_lr']) 240 | elif config['scheduler'] == 'ReduceLROnPlateau': 241 | scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, factor=config['factor'], patience=config['patience'], 242 | verbose=1, min_lr=config['min_lr']) 243 | elif config['scheduler'] == 'MultiStepLR': 244 | scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[int(e) for e in config['milestones'].split(',')], gamma=config['gamma']) 245 | elif config['scheduler'] == 'ConstantLR': 246 | scheduler = None 247 | else: 248 | raise NotImplementedError 249 | 250 | # Data loading code 251 | img_ids = glob(os.path.join('inputs', config['dataset'], 'images', '*' + config['img_ext'])) 252 | img_ids = [os.path.splitext(os.path.basename(p))[0] for p in img_ids] 253 | 254 | train_img_ids, val_img_ids = train_test_split(img_ids, test_size=0.2, random_state=41) 255 | 256 | train_transform = Compose([ 257 | RandomRotate90(), 258 | transforms.Flip(), 259 | Resize(config['input_h'], config['input_w']), 260 | transforms.Normalize(), 261 | ]) 262 | 263 | val_transform = Compose([ 264 | Resize(config['input_h'], config['input_w']), 265 | transforms.Normalize(), 266 | ]) 267 | 268 | train_dataset = Dataset( 269 | img_ids=train_img_ids, 270 | img_dir=os.path.join('inputs', config['dataset'], 'images'), 271 | mask_dir=os.path.join('inputs', config['dataset'], 'masks'), 272 | img_ext=config['img_ext'], 273 | mask_ext=config['mask_ext'], 274 | num_classes=config['num_classes'], 275 | transform=train_transform) 276 | val_dataset = Dataset( 277 | img_ids=val_img_ids, 278 | img_dir=os.path.join('inputs', config['dataset'], 'images'), 279 | mask_dir=os.path.join('inputs', config['dataset'], 'masks'), 280 | img_ext=config['img_ext'], 281 | mask_ext=config['mask_ext'], 282 | num_classes=config['num_classes'], 283 | 
transform=val_transform) 284 | 285 | train_loader = torch.utils.data.DataLoader( 286 | train_dataset, 287 | batch_size=config['batch_size'], 288 | shuffle=True, 289 | num_workers=config['num_workers'], 290 | drop_last=True) 291 | val_loader = torch.utils.data.DataLoader( 292 | val_dataset, 293 | batch_size=config['batch_size'], 294 | shuffle=False, 295 | num_workers=config['num_workers'], 296 | drop_last=False) 297 | 298 | log = OrderedDict([ 299 | ('epoch', []), 300 | ('lr', []), 301 | ('loss', []), 302 | ('iou', []), 303 | ('val_loss', []), 304 | ('val_iou', []), 305 | ('val_dice', []), 306 | ]) 307 | 308 | best_iou = 0 309 | trigger = 0 310 | for epoch in range(config['epochs']): 311 | print('Epoch [%d/%d]' % (epoch, config['epochs'])) 312 | 313 | # train for one epoch 314 | train_log = train(config, train_loader, model, criterion, optimizer) 315 | # evaluate on validation set 316 | val_log = validate(config, val_loader, model, criterion) 317 | 318 | if config['scheduler'] == 'CosineAnnealingLR': 319 | scheduler.step() 320 | elif config['scheduler'] == 'ReduceLROnPlateau': 321 | scheduler.step(val_log['loss']) 322 | 323 | print('loss %.4f - iou %.4f - val_loss %.4f - val_iou %.4f' 324 | % (train_log['loss'], train_log['iou'], val_log['loss'], val_log['iou'])) 325 | 326 | log['epoch'].append(epoch) 327 | log['lr'].append(config['lr']) 328 | log['loss'].append(train_log['loss']) 329 | log['iou'].append(train_log['iou']) 330 | log['val_loss'].append(val_log['loss']) 331 | log['val_iou'].append(val_log['iou']) 332 | log['val_dice'].append(val_log['dice']) 333 | 334 | pd.DataFrame(log).to_csv('models/%s/log.csv' % 335 | config['name'], index=False) 336 | 337 | trigger += 1 338 | 339 | if val_log['iou'] > best_iou: 340 | torch.save(model.state_dict(), 'models/%s/model.pth' % 341 | config['name']) 342 | best_iou = val_log['iou'] 343 | print("=> saved best model") 344 | trigger = 0 345 | 346 | # early stopping 347 | if config['early_stopping'] >= 0 and trigger >= config['early_stopping']: 348 | print("=> early stopping") 349 | break 350 | 351 | torch.cuda.empty_cache() 352 | 353 | 354 | if __name__ == '__main__': 355 | main() 356 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch.nn as nn 3 | 4 | class qkv_transform(nn.Conv1d): 5 | """Conv1d for qkv_transform""" 6 | 7 | def str2bool(v): 8 | if v.lower() in ['true', 1]: 9 | return True 10 | elif v.lower() in ['false', 0]: 11 | return False 12 | else: 13 | raise argparse.ArgumentTypeError('Boolean value expected.') 14 | 15 | 16 | def count_params(model): 17 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 18 | 19 | 20 | class AverageMeter(object): 21 | """Computes and stores the average and current value""" 22 | 23 | def __init__(self): 24 | self.reset() 25 | 26 | def reset(self): 27 | self.val = 0 28 | self.avg = 0 29 | self.sum = 0 30 | self.count = 0 31 | 32 | def update(self, val, n=1): 33 | self.val = val 34 | self.sum += val * n 35 | self.count += n 36 | self.avg = self.sum / self.count 37 | -------------------------------------------------------------------------------- /val.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from glob import glob 4 | 5 | import cv2 6 | import torch 7 | import torch.backends.cudnn as cudnn 8 | import yaml 9 | from albumentations.augmentations import 
transforms 10 | from albumentations.core.composition import Compose 11 | from sklearn.model_selection import train_test_split 12 | from tqdm import tqdm 13 | 14 | import archs 15 | from dataset import Dataset 16 | from metrics import iou_score 17 | from utils import AverageMeter 18 | from albumentations import RandomRotate90,Resize 19 | import time 20 | from archs import UNext 21 | 22 | 23 | def parse_args(): 24 | parser = argparse.ArgumentParser() 25 | 26 | parser.add_argument('--name', default=None, 27 | help='model name') 28 | 29 | args = parser.parse_args() 30 | 31 | return args 32 | 33 | 34 | def main(): 35 | args = parse_args() 36 | 37 | with open('models/%s/config.yml' % args.name, 'r') as f: 38 | config = yaml.load(f, Loader=yaml.FullLoader) 39 | 40 | print('-'*20) 41 | for key in config.keys(): 42 | print('%s: %s' % (key, str(config[key]))) 43 | print('-'*20) 44 | 45 | cudnn.benchmark = True 46 | 47 | print("=> creating model %s" % config['arch']) 48 | model = archs.__dict__[config['arch']](config['num_classes'], 49 | config['input_channels'], 50 | config['deep_supervision']) 51 | 52 | model = model.cuda() 53 | 54 | # Data loading code 55 | img_ids = glob(os.path.join('inputs', config['dataset'], 'images', '*' + config['img_ext'])) 56 | img_ids = [os.path.splitext(os.path.basename(p))[0] for p in img_ids] 57 | 58 | _, val_img_ids = train_test_split(img_ids, test_size=0.2, random_state=41) 59 | 60 | model.load_state_dict(torch.load('models/%s/model.pth' % 61 | config['name'])) 62 | model.eval() 63 | 64 | val_transform = Compose([ 65 | Resize(config['input_h'], config['input_w']), 66 | transforms.Normalize(), 67 | ]) 68 | 69 | val_dataset = Dataset( 70 | img_ids=val_img_ids, 71 | img_dir=os.path.join('inputs', config['dataset'], 'images'), 72 | mask_dir=os.path.join('inputs', config['dataset'], 'masks'), 73 | img_ext=config['img_ext'], 74 | mask_ext=config['mask_ext'], 75 | num_classes=config['num_classes'], 76 | transform=val_transform) 77 | val_loader = torch.utils.data.DataLoader( 78 | val_dataset, 79 | batch_size=config['batch_size'], 80 | shuffle=False, 81 | num_workers=config['num_workers'], 82 | drop_last=False) 83 | 84 | iou_avg_meter = AverageMeter() 85 | dice_avg_meter = AverageMeter() 86 | gput = AverageMeter() 87 | cput = AverageMeter() 88 | 89 | count = 0 90 | for c in range(config['num_classes']): 91 | os.makedirs(os.path.join('outputs', config['name'], str(c)), exist_ok=True) 92 | with torch.no_grad(): 93 | for input, target, meta in tqdm(val_loader, total=len(val_loader)): 94 | input = input.cuda() 95 | target = target.cuda() 96 | model = model.cuda() 97 | # compute output 98 | output = model(input) 99 | 100 | 101 | iou,dice = iou_score(output, target) 102 | iou_avg_meter.update(iou, input.size(0)) 103 | dice_avg_meter.update(dice, input.size(0)) 104 | 105 | output = torch.sigmoid(output).cpu().numpy() 106 | output[output>=0.5]=1 107 | output[output<0.5]=0 108 | 109 | for i in range(len(output)): 110 | for c in range(config['num_classes']): 111 | cv2.imwrite(os.path.join('outputs', config['name'], str(c), meta['img_id'][i] + '.jpg'), 112 | (output[i, c] * 255).astype('uint8')) 113 | 114 | print('IoU: %.4f' % iou_avg_meter.avg) 115 | print('Dice: %.4f' % dice_avg_meter.avg) 116 | 117 | torch.cuda.empty_cache() 118 | 119 | 120 | if __name__ == '__main__': 121 | main() 122 | --------------------------------------------------------------------------------
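
Taken together, a trained run (named e.g. `isic_unext`) leaves `config.yml` and `model.pth` under `models/isic_unext/`, and the checkpoint can be loaded outside `val.py` as well. A minimal single-image inference sketch; the experiment name, image path, and the simplified divide-by-255 preprocessing are assumptions rather than the exact pipeline in `val.py`:

```python
import cv2
import torch
import yaml

import archs

# load the saved experiment config and checkpoint the same way val.py does
with open('models/isic_unext/config.yml') as f:
    config = yaml.load(f, Loader=yaml.FullLoader)
model = archs.__dict__[config['arch']](config['num_classes'],
                                       config['input_channels'],
                                       config['deep_supervision']).cuda()
model.load_state_dict(torch.load('models/isic_unext/model.pth'))
model.eval()

# read and resize one image; preprocessing is simplified to scaling into [0, 1]
img = cv2.imread('inputs/isic/images/001.png')
img = cv2.resize(img, (config['input_w'], config['input_h']))
x = torch.from_numpy(img.transpose(2, 0, 1)).float().div(255).unsqueeze(0).cuda()

with torch.no_grad():
    prob = torch.sigmoid(model(x))[0, 0].cpu().numpy()
cv2.imwrite('pred_001.png', ((prob >= 0.5) * 255).astype('uint8'))
```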