├── LICENSE
├── README.md
├── __pycache__
│   ├── archs.cpython-36.pyc
│   ├── dataset.cpython-36.pyc
│   ├── losses.cpython-36.pyc
│   ├── metrics.cpython-36.pyc
│   └── utils.cpython-36.pyc
├── archs.py
├── config.py
├── dataset.py
├── environment.yml
├── imgs
│   ├── readme.md
│   └── unext.png
├── losses.py
├── metrics.py
├── post_process.py
├── train.py
├── utils.py
└── val.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Jeya Maria Jose
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # UNeXt
2 |
3 | Official PyTorch codebase for [UNeXt: MLP-based Rapid Medical Image Segmentation Network](https://arxiv.org/abs/2203.04967), MICCAI 2022
4 |
5 | [Paper](https://arxiv.org/abs/2203.04967) | [Project](https://jeya-maria-jose.github.io/UNext-web/)
6 |
7 | ## Introduction
8 |
9 | UNet and its latest extensions like TransUNet have been the leading medical image segmentation methods in recent years. However, these networks cannot be effectively adopted for rapid image segmentation in point-of-care applications, as they are parameter-heavy, computationally complex and slow to use. To this end, we propose UNeXt, a convolutional multilayer perceptron (MLP) based network for image segmentation. We design UNeXt in an effective way with an early convolutional stage and an MLP stage in the latent stage. We propose a tokenized MLP block where we efficiently tokenize and project the convolutional features and use MLPs to model the representation. To further boost performance, we propose shifting the channels of the inputs while feeding them into the MLPs so as to focus on learning local dependencies. Using tokenized MLPs in the latent space reduces the number of parameters and computational complexity while yielding a better representation that helps segmentation. The network also consists of skip connections between various levels of the encoder and decoder. We test UNeXt on multiple medical image segmentation datasets and show that we reduce the number of parameters by 72x, decrease the computational complexity by 68x, and improve the inference speed by 10x while also obtaining better segmentation performance over state-of-the-art medical image segmentation architectures.
10 |
11 | <p align="center">
12 |   <img src="imgs/unext.png" width="800"/>
13 | </p>
14 |
15 |
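A minimal standalone sketch (not part of the repository) of the axis-wise channel shift used inside the tokenized MLP block; it mirrors the pad / chunk / roll / crop sequence of `shiftmlp.forward` in `archs.py`, and the helper name `axial_shift` is only illustrative:

```python
import torch
import torch.nn.functional as F

def axial_shift(x, shift_size=5, dim=2):
    """Shift channel groups of a (B, C, H, W) tensor by different offsets along `dim`."""
    pad = shift_size // 2
    B, C, H, W = x.shape
    # Zero-pad H and W so the rolled features have room to move.
    xn = F.pad(x, (pad, pad, pad, pad), "constant", 0)
    # Split the channels into `shift_size` groups and roll each group by a different offset.
    xs = torch.chunk(xn, shift_size, dim=1)
    shifted = [torch.roll(x_c, s, dims=dim) for x_c, s in zip(xs, range(-pad, pad + 1))]
    x_cat = torch.cat(shifted, dim=1)
    # Crop the padding off again to recover the original spatial size.
    x_cat = torch.narrow(x_cat, 2, pad, H)
    x_cat = torch.narrow(x_cat, 3, pad, W)
    return x_cat

x = torch.randn(1, 160, 16, 16)          # (B, C, H, W) token feature map
print(axial_shift(x, dim=2).shape)       # torch.Size([1, 160, 16, 16])
```
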
16 | ## Using the code:
17 |
18 | The code is stable with Python 3.6.13 and CUDA >= 10.1.
19 |
20 | - Clone this repository:
21 | ```bash
22 | git clone https://github.com/jeya-maria-jose/UNeXt-pytorch
23 | cd UNeXt-pytorch
24 | ```
25 |
26 | To install all the dependencies using conda:
27 |
28 | ```bash
29 | conda env create -f environment.yml
30 | conda activate unext
31 | ```
32 |
33 | If you prefer pip, install the following versions:
34 |
35 | ```bash
36 | timm==0.3.2
37 | mmcv-full==1.2.7
38 | torch==1.7.1
39 | torchvision==0.8.2
40 | opencv-python==4.5.1.48
41 | ```
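
Equivalently, as a single command with the same pins:

```bash
pip install timm==0.3.2 mmcv-full==1.2.7 torch==1.7.1 torchvision==0.8.2 opencv-python==4.5.1.48
```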
42 |
43 | ## Datasets
44 |
45 | 1) ISIC 2018 - [Link](https://challenge.isic-archive.com/data/)
46 | 2) BUSI - [Link](https://www.kaggle.com/aryashah2k/breast-ultrasound-images-dataset)
47 |
48 | ## Data Format
49 |
50 | Make sure to organize the files in the following structure (e.g. when the number of classes is 2):
51 |
52 | ```
53 | inputs
54 | └── <dataset name>
55 |     ├── images
56 |     │   ├── 001.png
57 |     │   ├── 002.png
58 |     │   ├── 003.png
59 |     │   ├── ...
60 |     │
61 |     └── masks
62 |         ├── 0
63 |         │   ├── 001.png
64 |         │   ├── 002.png
65 |         │   ├── 003.png
66 |         │   ├── ...
67 |         │
68 |         └── 1
69 |             ├── 001.png
70 |             ├── 002.png
71 |             ├── 003.png
72 |             ├── ...
73 | ```
74 |
75 | For binary segmentation problems, just use folder 0.
76 |
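As a quick sanity check that the layout above is in place, here is a short sketch (not part of the repo; the dataset folder name `isic` is just the default `--dataset` value in `train.py`, substitute your own) that counts the images and verifies each class folder has a matching mask per image:

```python
import os
from glob import glob

dataset = 'isic'            # placeholder dataset folder name under inputs/
num_classes = 1             # use folder 0 only for binary segmentation
img_ext, mask_ext = '.png', '.png'

img_dir = os.path.join('inputs', dataset, 'images')
img_ids = [os.path.splitext(os.path.basename(p))[0]
           for p in glob(os.path.join(img_dir, '*' + img_ext))]
print('found %d images' % len(img_ids))

for c in range(num_classes):
    mask_dir = os.path.join('inputs', dataset, 'masks', str(c))
    missing = [i for i in img_ids
               if not os.path.isfile(os.path.join(mask_dir, i + mask_ext))]
    print('class %d: %d masks missing' % (c, len(missing)))
```
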
77 | ## Training and Validation
78 |
79 | 1. Train the model.
80 | ```
81 | python train.py --dataset <dataset name> --arch UNext --name <exp name> --img_ext .png --mask_ext .png --lr 0.0001 --epochs 500 --input_w 512 --input_h 512 --b 8
82 | ```
83 | 2. Evaluate.
84 | ```
85 | python val.py --name <exp name>
86 | ```
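
For example, a complete run on the BUSI data placed under `inputs/busi` could look like the following (the dataset folder and experiment name are placeholders; `busi_UNext_woDS` simply follows the default naming pattern in `train.py`):
```
python train.py --dataset busi --arch UNext --name busi_UNext_woDS --img_ext .png --mask_ext .png --lr 0.0001 --epochs 500 --input_w 512 --input_h 512 --b 8
python val.py --name busi_UNext_woDS
```
Checkpoints, the dumped config and the training log are written to `models/<name>/`, and `val.py` saves the predicted masks to `outputs/<name>/<class index>/`.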
87 |
88 | ### Acknowledgements:
89 |
90 | This codebase uses certain code blocks and helper functions from [UNet++](https://github.com/4uiiurz1/pytorch-nested-unet), [Segformer](https://github.com/NVlabs/SegFormer), and [AS-MLP](https://github.com/svip-lab/AS-MLP). Naming credits to [Poojan](https://scholar.google.co.in/citations?user=9dhBHuAAAAAJ&hl=en).
91 |
92 | ### Citation:
93 | ```
94 | @article{valanarasu2022unext,
95 | title={UNeXt: MLP-based Rapid Medical Image Segmentation Network},
96 | author={Valanarasu, Jeya Maria Jose and Patel, Vishal M},
97 | journal={arXiv preprint arXiv:2203.04967},
98 | year={2022}
99 | }
100 | ```
101 |
--------------------------------------------------------------------------------
/__pycache__/archs.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jeya-maria-jose/UNeXt-pytorch/6ad0855114a35afbf81decf5dc912cd8de70476a/__pycache__/archs.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/dataset.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jeya-maria-jose/UNeXt-pytorch/6ad0855114a35afbf81decf5dc912cd8de70476a/__pycache__/dataset.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/losses.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jeya-maria-jose/UNeXt-pytorch/6ad0855114a35afbf81decf5dc912cd8de70476a/__pycache__/losses.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/metrics.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jeya-maria-jose/UNeXt-pytorch/6ad0855114a35afbf81decf5dc912cd8de70476a/__pycache__/metrics.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jeya-maria-jose/UNeXt-pytorch/6ad0855114a35afbf81decf5dc912cd8de70476a/__pycache__/utils.cpython-36.pyc
--------------------------------------------------------------------------------
/archs.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch
4 | import torchvision
5 | from torch import nn
6 | from torch.autograd import Variable
7 | from torch.utils.data import DataLoader
8 | from torchvision import transforms
9 | from torchvision.utils import save_image
10 | import torch.nn.functional as F
11 | import os
12 | import matplotlib.pyplot as plt
13 | from utils import *
14 | __all__ = ['UNext']
15 |
16 | import timm
17 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_
18 | import types
19 | import math
20 | from abc import ABCMeta, abstractmethod
21 | from mmcv.cnn import ConvModule
22 | import pdb
23 |
24 |
25 |
26 | def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
27 | """1x1 convolution"""
28 |     return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
29 |
30 |
31 | def shift(xs, dim, pad, H, W):
32 |     x_shift = [torch.roll(x_c, s, dim) for x_c, s in zip(xs, range(-pad, pad + 1))]
33 |     x_cat = torch.cat(x_shift, 1)
34 |     x_cat = torch.narrow(x_cat, 2, pad, H)
35 |     x_cat = torch.narrow(x_cat, 3, pad, W)
36 |     return x_cat
37 |
38 | class shiftmlp(nn.Module):
39 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., shift_size=5):
40 | super().__init__()
41 | out_features = out_features or in_features
42 | hidden_features = hidden_features or in_features
43 | self.dim = in_features
44 | self.fc1 = nn.Linear(in_features, hidden_features)
45 | self.dwconv = DWConv(hidden_features)
46 | self.act = act_layer()
47 | self.fc2 = nn.Linear(hidden_features, out_features)
48 | self.drop = nn.Dropout(drop)
49 |
50 | self.shift_size = shift_size
51 | self.pad = shift_size // 2
52 |
53 |
54 | self.apply(self._init_weights)
55 |
56 | def _init_weights(self, m):
57 | if isinstance(m, nn.Linear):
58 | trunc_normal_(m.weight, std=.02)
59 | if isinstance(m, nn.Linear) and m.bias is not None:
60 | nn.init.constant_(m.bias, 0)
61 | elif isinstance(m, nn.LayerNorm):
62 | nn.init.constant_(m.bias, 0)
63 | nn.init.constant_(m.weight, 1.0)
64 | elif isinstance(m, nn.Conv2d):
65 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
66 | fan_out //= m.groups
67 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
68 | if m.bias is not None:
69 | m.bias.data.zero_()
70 |
71 | # def shift(x, dim):
72 | # x = F.pad(x, "constant", 0)
73 | # x = torch.chunk(x, shift_size, 1)
74 | # x = [ torch.roll(x_c, shift, dim) for x_s, shift in zip(x, range(-pad, pad+1))]
75 | # x = torch.cat(x, 1)
76 | # return x[:, :, pad:-pad, pad:-pad]
77 |
78 | def forward(self, x, H, W):
79 | # pdb.set_trace()
80 | B, N, C = x.shape
81 |
82 | xn = x.transpose(1, 2).view(B, C, H, W).contiguous()
83 | xn = F.pad(xn, (self.pad, self.pad, self.pad, self.pad) , "constant", 0)
84 | xs = torch.chunk(xn, self.shift_size, 1)
85 | x_shift = [torch.roll(x_c, shift, 2) for x_c, shift in zip(xs, range(-self.pad, self.pad+1))]
86 | x_cat = torch.cat(x_shift, 1)
87 | x_cat = torch.narrow(x_cat, 2, self.pad, H)
88 | x_s = torch.narrow(x_cat, 3, self.pad, W)
89 |
90 |
91 | x_s = x_s.reshape(B,C,H*W).contiguous()
92 | x_shift_r = x_s.transpose(1,2)
93 |
94 |
95 | x = self.fc1(x_shift_r)
96 |
97 | x = self.dwconv(x, H, W)
98 | x = self.act(x)
99 | x = self.drop(x)
100 |
101 | xn = x.transpose(1, 2).view(B, C, H, W).contiguous()
102 | xn = F.pad(xn, (self.pad, self.pad, self.pad, self.pad) , "constant", 0)
103 | xs = torch.chunk(xn, self.shift_size, 1)
104 | x_shift = [torch.roll(x_c, shift, 3) for x_c, shift in zip(xs, range(-self.pad, self.pad+1))]
105 | x_cat = torch.cat(x_shift, 1)
106 | x_cat = torch.narrow(x_cat, 2, self.pad, H)
107 | x_s = torch.narrow(x_cat, 3, self.pad, W)
108 | x_s = x_s.reshape(B,C,H*W).contiguous()
109 | x_shift_c = x_s.transpose(1,2)
110 |
111 | x = self.fc2(x_shift_c)
112 | x = self.drop(x)
113 | return x
114 |
115 |
116 |
117 | class shiftedBlock(nn.Module):
118 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
119 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1):
120 | super().__init__()
121 |
122 |
123 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
124 | self.norm2 = norm_layer(dim)
125 | mlp_hidden_dim = int(dim * mlp_ratio)
126 | self.mlp = shiftmlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
127 | self.apply(self._init_weights)
128 |
129 | def _init_weights(self, m):
130 | if isinstance(m, nn.Linear):
131 | trunc_normal_(m.weight, std=.02)
132 | if isinstance(m, nn.Linear) and m.bias is not None:
133 | nn.init.constant_(m.bias, 0)
134 | elif isinstance(m, nn.LayerNorm):
135 | nn.init.constant_(m.bias, 0)
136 | nn.init.constant_(m.weight, 1.0)
137 | elif isinstance(m, nn.Conv2d):
138 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
139 | fan_out //= m.groups
140 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
141 | if m.bias is not None:
142 | m.bias.data.zero_()
143 |
144 | def forward(self, x, H, W):
145 |
146 | x = x + self.drop_path(self.mlp(self.norm2(x), H, W))
147 | return x
148 |
149 |
150 | class DWConv(nn.Module):
151 | def __init__(self, dim=768):
152 | super(DWConv, self).__init__()
153 | self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)
154 |
155 | def forward(self, x, H, W):
156 | B, N, C = x.shape
157 | x = x.transpose(1, 2).view(B, C, H, W)
158 | x = self.dwconv(x)
159 | x = x.flatten(2).transpose(1, 2)
160 |
161 | return x
162 |
163 | class OverlapPatchEmbed(nn.Module):
164 | """ Image to Patch Embedding
165 | """
166 |
167 | def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768):
168 | super().__init__()
169 | img_size = to_2tuple(img_size)
170 | patch_size = to_2tuple(patch_size)
171 |
172 | self.img_size = img_size
173 | self.patch_size = patch_size
174 | self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
175 | self.num_patches = self.H * self.W
176 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
177 | padding=(patch_size[0] // 2, patch_size[1] // 2))
178 | self.norm = nn.LayerNorm(embed_dim)
179 |
180 | self.apply(self._init_weights)
181 |
182 | def _init_weights(self, m):
183 | if isinstance(m, nn.Linear):
184 | trunc_normal_(m.weight, std=.02)
185 | if isinstance(m, nn.Linear) and m.bias is not None:
186 | nn.init.constant_(m.bias, 0)
187 | elif isinstance(m, nn.LayerNorm):
188 | nn.init.constant_(m.bias, 0)
189 | nn.init.constant_(m.weight, 1.0)
190 | elif isinstance(m, nn.Conv2d):
191 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
192 | fan_out //= m.groups
193 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
194 | if m.bias is not None:
195 | m.bias.data.zero_()
196 |
197 | def forward(self, x):
198 | x = self.proj(x)
199 | _, _, H, W = x.shape
200 | x = x.flatten(2).transpose(1, 2)
201 | x = self.norm(x)
202 |
203 | return x, H, W
204 |
205 |
206 | class UNext(nn.Module):
207 |
208 | ## Conv 3 + MLP 2 + shifted MLP
209 |
210 | def __init__(self, num_classes, input_channels=3, deep_supervision=False,img_size=224, patch_size=16, in_chans=3, embed_dims=[ 128, 160, 256],
211 | num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0.,
212 | attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,
213 | depths=[1, 1, 1], sr_ratios=[8, 4, 2, 1], **kwargs):
214 | super().__init__()
215 |
216 | self.encoder1 = nn.Conv2d(3, 16, 3, stride=1, padding=1)
217 | self.encoder2 = nn.Conv2d(16, 32, 3, stride=1, padding=1)
218 | self.encoder3 = nn.Conv2d(32, 128, 3, stride=1, padding=1)
219 |
220 | self.ebn1 = nn.BatchNorm2d(16)
221 | self.ebn2 = nn.BatchNorm2d(32)
222 | self.ebn3 = nn.BatchNorm2d(128)
223 |
224 | self.norm3 = norm_layer(embed_dims[1])
225 | self.norm4 = norm_layer(embed_dims[2])
226 |
227 | self.dnorm3 = norm_layer(160)
228 | self.dnorm4 = norm_layer(128)
229 |
230 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
231 |
232 | self.block1 = nn.ModuleList([shiftedBlock(
233 | dim=embed_dims[1], num_heads=num_heads[0], mlp_ratio=1, qkv_bias=qkv_bias, qk_scale=qk_scale,
234 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[0], norm_layer=norm_layer,
235 | sr_ratio=sr_ratios[0])])
236 |
237 | self.block2 = nn.ModuleList([shiftedBlock(
238 | dim=embed_dims[2], num_heads=num_heads[0], mlp_ratio=1, qkv_bias=qkv_bias, qk_scale=qk_scale,
239 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[1], norm_layer=norm_layer,
240 | sr_ratio=sr_ratios[0])])
241 |
242 | self.dblock1 = nn.ModuleList([shiftedBlock(
243 | dim=embed_dims[1], num_heads=num_heads[0], mlp_ratio=1, qkv_bias=qkv_bias, qk_scale=qk_scale,
244 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[0], norm_layer=norm_layer,
245 | sr_ratio=sr_ratios[0])])
246 |
247 | self.dblock2 = nn.ModuleList([shiftedBlock(
248 | dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=1, qkv_bias=qkv_bias, qk_scale=qk_scale,
249 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[1], norm_layer=norm_layer,
250 | sr_ratio=sr_ratios[0])])
251 |
252 | self.patch_embed3 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0],
253 | embed_dim=embed_dims[1])
254 | self.patch_embed4 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, stride=2, in_chans=embed_dims[1],
255 | embed_dim=embed_dims[2])
256 |
257 | self.decoder1 = nn.Conv2d(256, 160, 3, stride=1,padding=1)
258 | self.decoder2 = nn.Conv2d(160, 128, 3, stride=1, padding=1)
259 | self.decoder3 = nn.Conv2d(128, 32, 3, stride=1, padding=1)
260 | self.decoder4 = nn.Conv2d(32, 16, 3, stride=1, padding=1)
261 | self.decoder5 = nn.Conv2d(16, 16, 3, stride=1, padding=1)
262 |
263 | self.dbn1 = nn.BatchNorm2d(160)
264 | self.dbn2 = nn.BatchNorm2d(128)
265 | self.dbn3 = nn.BatchNorm2d(32)
266 | self.dbn4 = nn.BatchNorm2d(16)
267 |
268 | self.final = nn.Conv2d(16, num_classes, kernel_size=1)
269 |
270 | self.soft = nn.Softmax(dim =1)
271 |
272 | def forward(self, x):
273 |
274 | B = x.shape[0]
275 | ### Encoder
276 | ### Conv Stage
277 |
278 | ### Stage 1
279 | out = F.relu(F.max_pool2d(self.ebn1(self.encoder1(x)),2,2))
280 | t1 = out
281 | ### Stage 2
282 | out = F.relu(F.max_pool2d(self.ebn2(self.encoder2(out)),2,2))
283 | t2 = out
284 | ### Stage 3
285 | out = F.relu(F.max_pool2d(self.ebn3(self.encoder3(out)),2,2))
286 | t3 = out
287 |
288 | ### Tokenized MLP Stage
289 | ### Stage 4
290 |
291 | out,H,W = self.patch_embed3(out)
292 | for i, blk in enumerate(self.block1):
293 | out = blk(out, H, W)
294 | out = self.norm3(out)
295 | out = out.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
296 | t4 = out
297 |
298 | ### Bottleneck
299 |
300 | out ,H,W= self.patch_embed4(out)
301 | for i, blk in enumerate(self.block2):
302 | out = blk(out, H, W)
303 | out = self.norm4(out)
304 | out = out.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
305 |
306 | ### Stage 4
307 |
308 | out = F.relu(F.interpolate(self.dbn1(self.decoder1(out)),scale_factor=(2,2),mode ='bilinear'))
309 |
310 | out = torch.add(out,t4)
311 | _,_,H,W = out.shape
312 | out = out.flatten(2).transpose(1,2)
313 | for i, blk in enumerate(self.dblock1):
314 | out = blk(out, H, W)
315 |
316 | ### Stage 3
317 |
318 | out = self.dnorm3(out)
319 | out = out.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
320 | out = F.relu(F.interpolate(self.dbn2(self.decoder2(out)),scale_factor=(2,2),mode ='bilinear'))
321 | out = torch.add(out,t3)
322 | _,_,H,W = out.shape
323 | out = out.flatten(2).transpose(1,2)
324 |
325 | for i, blk in enumerate(self.dblock2):
326 | out = blk(out, H, W)
327 |
328 | out = self.dnorm4(out)
329 | out = out.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
330 |
331 | out = F.relu(F.interpolate(self.dbn3(self.decoder3(out)),scale_factor=(2,2),mode ='bilinear'))
332 | out = torch.add(out,t2)
333 | out = F.relu(F.interpolate(self.dbn4(self.decoder4(out)),scale_factor=(2,2),mode ='bilinear'))
334 | out = torch.add(out,t1)
335 | out = F.relu(F.interpolate(self.decoder5(out),scale_factor=(2,2),mode ='bilinear'))
336 |
337 | return self.final(out)
338 |
339 |
340 | class UNext_S(nn.Module):
341 |
342 | ## Conv 3 + MLP 2 + shifted MLP w less parameters
343 |
344 | def __init__(self, num_classes, input_channels=3, deep_supervision=False,img_size=224, patch_size=16, in_chans=3, embed_dims=[32, 64, 128, 512],
345 | num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0.,
346 | attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,
347 | depths=[1, 1, 1], sr_ratios=[8, 4, 2, 1], **kwargs):
348 | super().__init__()
349 |
350 | self.encoder1 = nn.Conv2d(3, 8, 3, stride=1, padding=1)
351 | self.encoder2 = nn.Conv2d(8, 16, 3, stride=1, padding=1)
352 | self.encoder3 = nn.Conv2d(16, 32, 3, stride=1, padding=1)
353 |
354 | self.ebn1 = nn.BatchNorm2d(8)
355 | self.ebn2 = nn.BatchNorm2d(16)
356 | self.ebn3 = nn.BatchNorm2d(32)
357 |
358 | self.norm3 = norm_layer(embed_dims[1])
359 | self.norm4 = norm_layer(embed_dims[2])
360 |
361 | self.dnorm3 = norm_layer(64)
362 | self.dnorm4 = norm_layer(32)
363 |
364 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
365 |
366 | self.block1 = nn.ModuleList([shiftedBlock(
367 | dim=embed_dims[1], num_heads=num_heads[0], mlp_ratio=1, qkv_bias=qkv_bias, qk_scale=qk_scale,
368 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[0], norm_layer=norm_layer,
369 | sr_ratio=sr_ratios[0])])
370 |
371 | self.block2 = nn.ModuleList([shiftedBlock(
372 | dim=embed_dims[2], num_heads=num_heads[0], mlp_ratio=1, qkv_bias=qkv_bias, qk_scale=qk_scale,
373 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[1], norm_layer=norm_layer,
374 | sr_ratio=sr_ratios[0])])
375 |
376 | self.dblock1 = nn.ModuleList([shiftedBlock(
377 | dim=embed_dims[1], num_heads=num_heads[0], mlp_ratio=1, qkv_bias=qkv_bias, qk_scale=qk_scale,
378 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[0], norm_layer=norm_layer,
379 | sr_ratio=sr_ratios[0])])
380 |
381 | self.dblock2 = nn.ModuleList([shiftedBlock(
382 | dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=1, qkv_bias=qkv_bias, qk_scale=qk_scale,
383 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[1], norm_layer=norm_layer,
384 | sr_ratio=sr_ratios[0])])
385 |
386 | self.patch_embed3 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0],
387 | embed_dim=embed_dims[1])
388 | self.patch_embed4 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, stride=2, in_chans=embed_dims[1],
389 | embed_dim=embed_dims[2])
390 |
391 | self.decoder1 = nn.Conv2d(128, 64, 3, stride=1,padding=1)
392 | self.decoder2 = nn.Conv2d(64, 32, 3, stride=1, padding=1)
393 | self.decoder3 = nn.Conv2d(32, 16, 3, stride=1, padding=1)
394 | self.decoder4 = nn.Conv2d(16, 8, 3, stride=1, padding=1)
395 | self.decoder5 = nn.Conv2d(8, 8, 3, stride=1, padding=1)
396 |
397 | self.dbn1 = nn.BatchNorm2d(64)
398 | self.dbn2 = nn.BatchNorm2d(32)
399 | self.dbn3 = nn.BatchNorm2d(16)
400 | self.dbn4 = nn.BatchNorm2d(8)
401 |
402 | self.final = nn.Conv2d(8, num_classes, kernel_size=1)
403 |
404 | self.soft = nn.Softmax(dim =1)
405 |
406 | def forward(self, x):
407 |
408 | B = x.shape[0]
409 | ### Encoder
410 | ### Conv Stage
411 |
412 | ### Stage 1
413 | out = F.relu(F.max_pool2d(self.ebn1(self.encoder1(x)),2,2))
414 | t1 = out
415 | ### Stage 2
416 | out = F.relu(F.max_pool2d(self.ebn2(self.encoder2(out)),2,2))
417 | t2 = out
418 | ### Stage 3
419 | out = F.relu(F.max_pool2d(self.ebn3(self.encoder3(out)),2,2))
420 | t3 = out
421 |
422 | ### Tokenized MLP Stage
423 | ### Stage 4
424 |
425 | out,H,W = self.patch_embed3(out)
426 | for i, blk in enumerate(self.block1):
427 | out = blk(out, H, W)
428 | out = self.norm3(out)
429 | out = out.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
430 | t4 = out
431 |
432 | ### Bottleneck
433 |
434 | out ,H,W= self.patch_embed4(out)
435 | for i, blk in enumerate(self.block2):
436 | out = blk(out, H, W)
437 | out = self.norm4(out)
438 | out = out.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
439 |
440 | ### Stage 4
441 |
442 | out = F.relu(F.interpolate(self.dbn1(self.decoder1(out)),scale_factor=(2,2),mode ='bilinear'))
443 |
444 | out = torch.add(out,t4)
445 | _,_,H,W = out.shape
446 | out = out.flatten(2).transpose(1,2)
447 | for i, blk in enumerate(self.dblock1):
448 | out = blk(out, H, W)
449 |
450 | ### Stage 3
451 |
452 | out = self.dnorm3(out)
453 | out = out.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
454 | out = F.relu(F.interpolate(self.dbn2(self.decoder2(out)),scale_factor=(2,2),mode ='bilinear'))
455 | out = torch.add(out,t3)
456 | _,_,H,W = out.shape
457 | out = out.flatten(2).transpose(1,2)
458 |
459 | for i, blk in enumerate(self.dblock2):
460 | out = blk(out, H, W)
461 |
462 | out = self.dnorm4(out)
463 | out = out.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
464 |
465 | out = F.relu(F.interpolate(self.dbn3(self.decoder3(out)),scale_factor=(2,2),mode ='bilinear'))
466 | out = torch.add(out,t2)
467 | out = F.relu(F.interpolate(self.dbn4(self.decoder4(out)),scale_factor=(2,2),mode ='bilinear'))
468 | out = torch.add(out,t1)
469 | out = F.relu(F.interpolate(self.decoder5(out),scale_factor=(2,2),mode ='bilinear'))
470 |
471 | return self.final(out)
472 |
473 |
474 | #EOF
475 |
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Swin Transformer
3 | # Copyright (c) 2021 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Ze Liu
6 | # --------------------------------------------------------'
7 |
8 | import os
9 | import yaml
10 | from yacs.config import CfgNode as CN
11 |
12 | _C = CN()
13 |
14 | # Base config files
15 | _C.BASE = ['']
16 |
17 | # -----------------------------------------------------------------------------
18 | # Data settings
19 | # -----------------------------------------------------------------------------
20 | _C.DATA = CN()
21 | # Batch size for a single GPU, could be overwritten by command line argument
22 | _C.DATA.BATCH_SIZE = 1
23 | # Path to dataset, could be overwritten by command line argument
24 | _C.DATA.DATA_PATH = ''
25 | # Dataset name
26 | _C.DATA.DATASET = 'imagenet'
27 | # Input image size
28 | _C.DATA.IMG_SIZE = 256
29 | # Interpolation to resize image (random, bilinear, bicubic)
30 | _C.DATA.INTERPOLATION = 'bicubic'
31 | # Use zipped dataset instead of folder dataset
32 | # could be overwritten by command line argument
33 | _C.DATA.ZIP_MODE = False
34 | # Cache Data in Memory, could be overwritten by command line argument
35 | _C.DATA.CACHE_MODE = 'part'
36 | # Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.
37 | _C.DATA.PIN_MEMORY = True
38 | # Number of data loading threads
39 | _C.DATA.NUM_WORKERS = 8
40 |
41 | # -----------------------------------------------------------------------------
42 | # Model settings
43 | # -----------------------------------------------------------------------------
44 | _C.MODEL = CN()
45 | # Model type
46 | _C.MODEL.TYPE = 'swin'
47 | # Model name
48 | _C.MODEL.NAME = 'swin_tiny_patch4_window7_224'
49 | # Checkpoint to resume, could be overwritten by command line argument
50 | _C.MODEL.PRETRAIN_CKPT = './pretrained_ckpt/swin_tiny_patch4_window7_224.pth'
51 | _C.MODEL.RESUME = ''
52 | # Number of classes, overwritten in data preparation
53 | _C.MODEL.NUM_CLASSES = 1000
54 | # Dropout rate
55 | _C.MODEL.DROP_RATE = 0.0
56 | # Drop path rate
57 | _C.MODEL.DROP_PATH_RATE = 0.1
58 | # Label Smoothing
59 | _C.MODEL.LABEL_SMOOTHING = 0.1
60 |
61 | # Swin Transformer parameters
62 | _C.MODEL.SWIN = CN()
63 | _C.MODEL.SWIN.PATCH_SIZE = 4
64 | _C.MODEL.SWIN.IN_CHANS = 3
65 | _C.MODEL.SWIN.EMBED_DIM = 96
66 | _C.MODEL.SWIN.DEPTHS = [2, 2, 6, 2]
67 | _C.MODEL.SWIN.DECODER_DEPTHS = [2, 2, 6, 2]
68 | _C.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24]
69 | _C.MODEL.SWIN.WINDOW_SIZE = 4
70 | _C.MODEL.SWIN.MLP_RATIO = 4.
71 | _C.MODEL.SWIN.QKV_BIAS = True
72 | _C.MODEL.SWIN.QK_SCALE = False
73 | _C.MODEL.SWIN.APE = False
74 | _C.MODEL.SWIN.PATCH_NORM = True
75 | _C.MODEL.SWIN.FINAL_UPSAMPLE= "expand_first"
76 |
77 | # -----------------------------------------------------------------------------
78 | # Training settings
79 | # -----------------------------------------------------------------------------
80 | _C.TRAIN = CN()
81 | _C.TRAIN.START_EPOCH = 0
82 | _C.TRAIN.EPOCHS = 300
83 | _C.TRAIN.WARMUP_EPOCHS = 20
84 | _C.TRAIN.WEIGHT_DECAY = 0.05
85 | _C.TRAIN.BASE_LR = 5e-4
86 | _C.TRAIN.WARMUP_LR = 5e-7
87 | _C.TRAIN.MIN_LR = 5e-6
88 | # Clip gradient norm
89 | _C.TRAIN.CLIP_GRAD = 5.0
90 | # Auto resume from latest checkpoint
91 | _C.TRAIN.AUTO_RESUME = True
92 | # Gradient accumulation steps
93 | # could be overwritten by command line argument
94 | _C.TRAIN.ACCUMULATION_STEPS = 0
95 | # Whether to use gradient checkpointing to save memory
96 | # could be overwritten by command line argument
97 | _C.TRAIN.USE_CHECKPOINT = False
98 |
99 | # LR scheduler
100 | _C.TRAIN.LR_SCHEDULER = CN()
101 | _C.TRAIN.LR_SCHEDULER.NAME = 'cosine'
102 | # Epoch interval to decay LR, used in StepLRScheduler
103 | _C.TRAIN.LR_SCHEDULER.DECAY_EPOCHS = 30
104 | # LR decay rate, used in StepLRScheduler
105 | _C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1
106 |
107 | # Optimizer
108 | _C.TRAIN.OPTIMIZER = CN()
109 | _C.TRAIN.OPTIMIZER.NAME = 'adamw'
110 | # Optimizer Epsilon
111 | _C.TRAIN.OPTIMIZER.EPS = 1e-8
112 | # Optimizer Betas
113 | _C.TRAIN.OPTIMIZER.BETAS = (0.9, 0.999)
114 | # SGD momentum
115 | _C.TRAIN.OPTIMIZER.MOMENTUM = 0.9
116 |
117 | # -----------------------------------------------------------------------------
118 | # Augmentation settings
119 | # -----------------------------------------------------------------------------
120 | _C.AUG = CN()
121 | # Color jitter factor
122 | _C.AUG.COLOR_JITTER = 0.4
123 | # Use AutoAugment policy. "v0" or "original"
124 | _C.AUG.AUTO_AUGMENT = 'rand-m9-mstd0.5-inc1'
125 | # Random erase prob
126 | _C.AUG.REPROB = 0.25
127 | # Random erase mode
128 | _C.AUG.REMODE = 'pixel'
129 | # Random erase count
130 | _C.AUG.RECOUNT = 1
131 | # Mixup alpha, mixup enabled if > 0
132 | _C.AUG.MIXUP = 0.8
133 | # Cutmix alpha, cutmix enabled if > 0
134 | _C.AUG.CUTMIX = 1.0
135 | # Cutmix min/max ratio, overrides alpha and enables cutmix if set
136 | _C.AUG.CUTMIX_MINMAX = False
137 | # Probability of performing mixup or cutmix when either/both is enabled
138 | _C.AUG.MIXUP_PROB = 1.0
139 | # Probability of switching to cutmix when both mixup and cutmix enabled
140 | _C.AUG.MIXUP_SWITCH_PROB = 0.5
141 | # How to apply mixup/cutmix params. Per "batch", "pair", or "elem"
142 | _C.AUG.MIXUP_MODE = 'batch'
143 |
144 | # -----------------------------------------------------------------------------
145 | # Testing settings
146 | # -----------------------------------------------------------------------------
147 | _C.TEST = CN()
148 | # Whether to use center crop when testing
149 | _C.TEST.CROP = True
150 |
151 | # -----------------------------------------------------------------------------
152 | # Misc
153 | # -----------------------------------------------------------------------------
154 | # Mixed precision opt level, if O0, no amp is used ('O0', 'O1', 'O2')
155 | # overwritten by command line argument
156 | _C.AMP_OPT_LEVEL = ''
157 | # Path to output folder, overwritten by command line argument
158 | _C.OUTPUT = ''
159 | # Tag of experiment, overwritten by command line argument
160 | _C.TAG = 'default'
161 | # Frequency to save checkpoint
162 | _C.SAVE_FREQ = 1
163 | # Frequency to logging info
164 | _C.PRINT_FREQ = 10
165 | # Fixed random seed
166 | _C.SEED = 0
167 | # Perform evaluation only, overwritten by command line argument
168 | _C.EVAL_MODE = False
169 | # Test throughput only, overwritten by command line argument
170 | _C.THROUGHPUT_MODE = False
171 | # local rank for DistributedDataParallel, given by command line argument
172 | _C.LOCAL_RANK = 0
173 |
174 |
175 | def _update_config_from_file(config, cfg_file):
176 | config.defrost()
177 | with open(cfg_file, 'r') as f:
178 | yaml_cfg = yaml.load(f, Loader=yaml.FullLoader)
179 |
180 | for cfg in yaml_cfg.setdefault('BASE', ['']):
181 | if cfg:
182 | _update_config_from_file(
183 | config, os.path.join(os.path.dirname(cfg_file), cfg)
184 | )
185 | print('=> merge config from {}'.format(cfg_file))
186 | config.merge_from_file(cfg_file)
187 | config.freeze()
188 |
189 |
190 | def update_config(config, args):
191 | _update_config_from_file(config, args.cfg)
192 |
193 | config.defrost()
194 | if args.opts:
195 | config.merge_from_list(args.opts)
196 |
197 | # merge from specific arguments
198 | if args.batch_size:
199 | config.DATA.BATCH_SIZE = args.batch_size
200 | if args.zip:
201 | config.DATA.ZIP_MODE = True
202 | if args.cache_mode:
203 | config.DATA.CACHE_MODE = args.cache_mode
204 | if args.resume:
205 | config.MODEL.RESUME = args.resume
206 | if args.accumulation_steps:
207 | config.TRAIN.ACCUMULATION_STEPS = args.accumulation_steps
208 | if args.use_checkpoint:
209 | config.TRAIN.USE_CHECKPOINT = True
210 | if args.amp_opt_level:
211 | config.AMP_OPT_LEVEL = args.amp_opt_level
212 | if args.tag:
213 | config.TAG = args.tag
214 | if args.eval:
215 | config.EVAL_MODE = True
216 | if args.throughput:
217 | config.THROUGHPUT_MODE = True
218 |
219 | config.freeze()
220 |
221 |
222 | def get_config(args):
223 | """Get a yacs CfgNode object with default values."""
224 | # Return a clone so that the defaults will not be altered
225 | # This is for the "local variable" use pattern
226 | config = _C.clone()
227 | # update_config(config, args)
228 |
229 | return config
230 |
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import cv2
4 | import numpy as np
5 | import torch
6 | import torch.utils.data
7 |
8 |
9 | class Dataset(torch.utils.data.Dataset):
10 | def __init__(self, img_ids, img_dir, mask_dir, img_ext, mask_ext, num_classes, transform=None):
11 | """
12 | Args:
13 | img_ids (list): Image ids.
14 | img_dir: Image file directory.
15 | mask_dir: Mask file directory.
16 | img_ext (str): Image file extension.
17 | mask_ext (str): Mask file extension.
18 | num_classes (int): Number of classes.
19 | transform (Compose, optional): Compose transforms of albumentations. Defaults to None.
20 |
21 | Note:
22 | Make sure to put the files as the following structure:
23 |
24 | ├── images
25 | | ├── 0a7e06.jpg
26 | │ ├── 0aab0a.jpg
27 | │ ├── 0b1761.jpg
28 | │ ├── ...
29 | |
30 | └── masks
31 | ├── 0
32 | | ├── 0a7e06.png
33 | | ├── 0aab0a.png
34 | | ├── 0b1761.png
35 | | ├── ...
36 | |
37 | ├── 1
38 | | ├── 0a7e06.png
39 | | ├── 0aab0a.png
40 | | ├── 0b1761.png
41 | | ├── ...
42 | ...
43 | """
44 | self.img_ids = img_ids
45 | self.img_dir = img_dir
46 | self.mask_dir = mask_dir
47 | self.img_ext = img_ext
48 | self.mask_ext = mask_ext
49 | self.num_classes = num_classes
50 | self.transform = transform
51 |
52 | def __len__(self):
53 | return len(self.img_ids)
54 |
55 | def __getitem__(self, idx):
56 | img_id = self.img_ids[idx]
57 |
58 | img = cv2.imread(os.path.join(self.img_dir, img_id + self.img_ext))
59 |
60 | mask = []
61 | for i in range(self.num_classes):
62 | mask.append(cv2.imread(os.path.join(self.mask_dir, str(i),
63 | img_id + self.mask_ext), cv2.IMREAD_GRAYSCALE)[..., None])
64 | mask = np.dstack(mask)
65 |
66 | if self.transform is not None:
67 | augmented = self.transform(image=img, mask=mask)
68 | img = augmented['image']
69 | mask = augmented['mask']
70 |
71 | img = img.astype('float32') / 255
72 | img = img.transpose(2, 0, 1)
73 | mask = mask.astype('float32') / 255
74 | mask = mask.transpose(2, 0, 1)
75 |
76 | return img, mask, {'img_id': img_id}
77 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: unext
2 | channels:
3 | - defaults
4 | dependencies:
5 | - _libgcc_mutex=0.1=main
6 | - _openmp_mutex=4.5=1_gnu
7 | - ca-certificates=2021.10.26=h06a4308_2
8 | - certifi=2021.5.30=py36h06a4308_0
9 | - ld_impl_linux-64=2.35.1=h7274673_9
10 | - libffi=3.3=he6710b0_2
11 | - libgcc-ng=9.3.0=h5101ec6_17
12 | - libgomp=9.3.0=h5101ec6_17
13 | - libstdcxx-ng=9.3.0=hd4cf53a_17
14 | - ncurses=6.3=h7f8727e_2
15 | - openssl=1.1.1l=h7f8727e_0
16 | - pip=21.2.2=py36h06a4308_0
17 | - python=3.6.13=h12debd9_1
18 | - readline=8.1=h27cfd23_0
19 | - setuptools=58.0.4=py36h06a4308_0
20 | - sqlite=3.36.0=hc218d9a_0
21 | - tk=8.6.11=h1ccaba5_0
22 | - wheel=0.37.0=pyhd3eb1b0_1
23 | - xz=5.2.5=h7b6447c_0
24 | - zlib=1.2.11=h7b6447c_3
25 | - pip:
26 | - addict==2.4.0
27 | - dataclasses==0.8
28 | - mmcv-full==1.2.7
29 | - numpy==1.19.5
30 | - opencv-python==4.5.1.48
31 | - perceptual==0.1
32 | - pillow==8.4.0
33 | - scikit-image==0.17.2
34 | - scipy==1.5.4
35 | - tifffile==2020.9.3
36 | - timm==0.3.2
37 | - torch==1.7.1
38 | - torchvision==0.8.2
39 | - typing-extensions==4.0.0
40 | - yapf==0.31.0
41 | prefix: /home/jeyamariajose/anaconda3/envs/transweather
42 |
43 |
--------------------------------------------------------------------------------
/imgs/readme.md:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/imgs/unext.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jeya-maria-jose/UNeXt-pytorch/6ad0855114a35afbf81decf5dc912cd8de70476a/imgs/unext.png
--------------------------------------------------------------------------------
/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | try:
6 | from LovaszSoftmax.pytorch.lovasz_losses import lovasz_hinge
7 | except ImportError:
8 | pass
9 |
10 | __all__ = ['BCEDiceLoss', 'LovaszHingeLoss']
11 |
12 |
13 | class BCEDiceLoss(nn.Module):
14 | def __init__(self):
15 | super().__init__()
16 |
17 | def forward(self, input, target):
18 | bce = F.binary_cross_entropy_with_logits(input, target)
19 | smooth = 1e-5
20 | input = torch.sigmoid(input)
21 | num = target.size(0)
22 | input = input.view(num, -1)
23 | target = target.view(num, -1)
24 | intersection = (input * target)
25 | dice = (2. * intersection.sum(1) + smooth) / (input.sum(1) + target.sum(1) + smooth)
26 | dice = 1 - dice.sum() / num
27 | return 0.5 * bce + dice
28 |
29 |
30 | class LovaszHingeLoss(nn.Module):
31 | def __init__(self):
32 | super().__init__()
33 |
34 | def forward(self, input, target):
35 | input = input.squeeze(1)
36 | target = target.squeeze(1)
37 | loss = lovasz_hinge(input, target, per_image=True)
38 |
39 | return loss
40 |
--------------------------------------------------------------------------------
/metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn.functional as F
4 |
5 |
6 | def iou_score(output, target):
7 | smooth = 1e-5
8 |
9 | if torch.is_tensor(output):
10 | output = torch.sigmoid(output).data.cpu().numpy()
11 | if torch.is_tensor(target):
12 | target = target.data.cpu().numpy()
13 | output_ = output > 0.5
14 | target_ = target > 0.5
15 | intersection = (output_ & target_).sum()
16 | union = (output_ | target_).sum()
17 | iou = (intersection + smooth) / (union + smooth)
18 | dice = (2* iou) / (iou+1)
19 | return iou, dice
20 |
21 |
22 | def dice_coef(output, target):
23 | smooth = 1e-5
24 |
25 | output = torch.sigmoid(output).view(-1).data.cpu().numpy()
26 | target = target.view(-1).data.cpu().numpy()
27 | intersection = (output * target).sum()
28 |
29 | return (2. * intersection + smooth) / \
30 | (output.sum() + target.sum() + smooth)
31 |
--------------------------------------------------------------------------------
/post_process.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from glob import glob
4 |
5 | import cv2
6 | import torch
7 | import torch.backends.cudnn as cudnn
8 | import yaml
9 | from albumentations.augmentations import transforms
10 | from albumentations.core.composition import Compose
11 | from sklearn.model_selection import train_test_split
12 | from tqdm import tqdm
13 |
14 | import archs
15 | from dataset import Dataset
16 | from metrics import iou_score
17 | from utils import AverageMeter
18 | from albumentations import RandomRotate90,Resize
19 | import time
20 |
21 | def parse_args():
22 | parser = argparse.ArgumentParser()
23 |
24 | parser.add_argument('--name', default=None,
25 | help='model name')
26 |
27 | args = parser.parse_args()
28 |
29 | return args
30 |
31 |
32 | def main():
33 | args = parse_args()
34 |
35 | with open('models/%s/config.yml' % args.name, 'r') as f:
36 | config = yaml.load(f, Loader=yaml.FullLoader)
37 |
38 | print('-'*20)
39 | for key in config.keys():
40 | print('%s: %s' % (key, str(config[key])))
41 | print('-'*20)
42 |
43 | cudnn.benchmark = True
44 |
45 | # create model
46 | print("=> creating model %s" % config['arch'])
47 | model = archs.__dict__[config['arch']](config['num_classes'],
48 | config['input_channels'],
49 | config['deep_supervision'])
50 |
51 | model = model.cuda()
52 |
53 | # Data loading code
54 | img_ids = glob(os.path.join('inputs', config['dataset'], 'images', '*' + config['img_ext']))
55 | img_ids = [os.path.splitext(os.path.basename(p))[0] for p in img_ids]
56 |
57 | _, val_img_ids = train_test_split(img_ids, test_size=0.2, random_state=41)
58 |
59 | model.load_state_dict(torch.load('models/%s/model.pth' %
60 | config['name']))
61 | model.eval()
62 |
63 | val_transform = Compose([
64 | Resize(config['input_h'], config['input_w']),
65 | transforms.Normalize(),
66 | ])
67 |
68 | val_dataset = Dataset(
69 | img_ids=val_img_ids,
70 | img_dir=os.path.join('inputs', config['dataset'], 'images'),
71 | mask_dir=os.path.join('inputs', config['dataset'], 'masks'),
72 | img_ext=config['img_ext'],
73 | mask_ext=config['mask_ext'],
74 | num_classes=config['num_classes'],
75 | transform=val_transform)
76 | val_loader = torch.utils.data.DataLoader(
77 | val_dataset,
78 | batch_size=config['batch_size'],
79 | shuffle=False,
80 | num_workers=config['num_workers'],
81 | drop_last=False)
82 |
83 | iou_avg_meter = AverageMeter()
84 | dice_avg_meter = AverageMeter()
85 | gput = AverageMeter()
86 | cput = AverageMeter()
87 |
88 | count = 0
89 | for c in range(config['num_classes']):
90 | os.makedirs(os.path.join('outputs', config['name'], str(c)), exist_ok=True)
91 | with torch.no_grad():
92 | for input, target, meta in tqdm(val_loader, total=len(val_loader)):
93 | input = input.cuda()
94 | target = target.cuda()
95 | model = model.cuda()
96 |             # time the forward pass on GPU and CPU for the first few batches
97 |             if count <= 5:
98 |                 start = time.time()
99 |                 output = model(input)[-1] if config['deep_supervision'] else model(input)
100 |                 stop = time.time()
101 |                 gput.update(stop - start, input.size(0))
102 |
103 |                 # run the same batch on the CPU, then move the model back to the GPU
104 |                 start = time.time()
105 |                 model = model.cpu()
106 |                 output = model(input.cpu())
107 |                 stop = time.time()
108 |                 cput.update(stop - start, input.size(0))
109 |                 model = model.cuda()
110 |                 count = count + 1
111 |
112 |             # compute the output used for the metrics on every batch,
113 |             # not only for the first few timed ones
114 |             output = model(input)[-1] if config['deep_supervision'] else model(input)
115 |
116 |
117 | iou,dice = iou_score(output, target)
118 | iou_avg_meter.update(iou, input.size(0))
119 | dice_avg_meter.update(dice, input.size(0))
120 |
121 | output = torch.sigmoid(output).cpu().numpy()
122 |
123 | for i in range(len(output)):
124 | for c in range(config['num_classes']):
125 | cv2.imwrite(os.path.join('outputs', config['name'], str(c), meta['img_id'][i] + '.jpg'),
126 | (output[i, c] * 255).astype('uint8'))
127 |
128 | print('IoU: %.4f' % iou_avg_meter.avg)
129 | print('Dice: %.4f' % dice_avg_meter.avg)
130 |
131 | print('CPU: %.4f' %cput.avg)
132 | print('GPU: %.4f' %gput.avg)
133 |
134 | torch.cuda.empty_cache()
135 |
136 |
137 | if __name__ == '__main__':
138 | main()
139 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from collections import OrderedDict
4 | from glob import glob
5 |
6 | import pandas as pd
7 | import torch
8 | import torch.backends.cudnn as cudnn
9 | import torch.nn as nn
10 | import torch.optim as optim
11 | import yaml
12 | from albumentations.augmentations import transforms
13 | from albumentations.core.composition import Compose, OneOf
14 | from sklearn.model_selection import train_test_split
15 | from torch.optim import lr_scheduler
16 | from tqdm import tqdm
17 | from albumentations import RandomRotate90,Resize
18 | import archs
19 | import losses
20 | from dataset import Dataset
21 | from metrics import iou_score
22 | from utils import AverageMeter, str2bool
23 | from archs import UNext
24 |
25 |
26 | ARCH_NAMES = archs.__all__
27 | LOSS_NAMES = losses.__all__
28 | LOSS_NAMES.append('BCEWithLogitsLoss')
29 |
30 |
31 |
32 | def parse_args():
33 | parser = argparse.ArgumentParser()
34 |
35 | parser.add_argument('--name', default=None,
36 | help='model name: (default: arch+timestamp)')
37 | parser.add_argument('--epochs', default=100, type=int, metavar='N',
38 | help='number of total epochs to run')
39 | parser.add_argument('-b', '--batch_size', default=16, type=int,
40 | metavar='N', help='mini-batch size (default: 16)')
41 |
42 | # model
43 | parser.add_argument('--arch', '-a', metavar='ARCH', default='UNext')
44 | parser.add_argument('--deep_supervision', default=False, type=str2bool)
45 | parser.add_argument('--input_channels', default=3, type=int,
46 | help='input channels')
47 | parser.add_argument('--num_classes', default=1, type=int,
48 | help='number of classes')
49 | parser.add_argument('--input_w', default=256, type=int,
50 | help='image width')
51 | parser.add_argument('--input_h', default=256, type=int,
52 | help='image height')
53 |
54 | # loss
55 | parser.add_argument('--loss', default='BCEDiceLoss',
56 | choices=LOSS_NAMES,
57 | help='loss: ' +
58 | ' | '.join(LOSS_NAMES) +
59 | ' (default: BCEDiceLoss)')
60 |
61 | # dataset
62 | parser.add_argument('--dataset', default='isic',
63 | help='dataset name')
64 | parser.add_argument('--img_ext', default='.png',
65 | help='image file extension')
66 | parser.add_argument('--mask_ext', default='.png',
67 | help='mask file extension')
68 |
69 | # optimizer
70 | parser.add_argument('--optimizer', default='Adam',
71 | choices=['Adam', 'SGD'],
72 | help='loss: ' +
73 | ' | '.join(['Adam', 'SGD']) +
74 | ' (default: Adam)')
75 | parser.add_argument('--lr', '--learning_rate', default=1e-3, type=float,
76 | metavar='LR', help='initial learning rate')
77 | parser.add_argument('--momentum', default=0.9, type=float,
78 | help='momentum')
79 | parser.add_argument('--weight_decay', default=1e-4, type=float,
80 | help='weight decay')
81 | parser.add_argument('--nesterov', default=False, type=str2bool,
82 | help='nesterov')
83 |
84 | # scheduler
85 | parser.add_argument('--scheduler', default='CosineAnnealingLR',
86 | choices=['CosineAnnealingLR', 'ReduceLROnPlateau', 'MultiStepLR', 'ConstantLR'])
87 | parser.add_argument('--min_lr', default=1e-5, type=float,
88 | help='minimum learning rate')
89 | parser.add_argument('--factor', default=0.1, type=float)
90 | parser.add_argument('--patience', default=2, type=int)
91 | parser.add_argument('--milestones', default='1,2', type=str)
92 | parser.add_argument('--gamma', default=2/3, type=float)
93 | parser.add_argument('--early_stopping', default=-1, type=int,
94 | metavar='N', help='early stopping (default: -1)')
95 | parser.add_argument('--cfg', type=str, metavar="FILE", help='path to config file', )
96 |
97 | parser.add_argument('--num_workers', default=4, type=int)
98 |
99 | config = parser.parse_args()
100 |
101 | return config
102 |
103 | # args = parser.parse_args()
104 | def train(config, train_loader, model, criterion, optimizer):
105 | avg_meters = {'loss': AverageMeter(),
106 | 'iou': AverageMeter()}
107 |
108 | model.train()
109 |
110 | pbar = tqdm(total=len(train_loader))
111 | for input, target, _ in train_loader:
112 | input = input.cuda()
113 | target = target.cuda()
114 |
115 | # compute output
116 | if config['deep_supervision']:
117 | outputs = model(input)
118 | loss = 0
119 | for output in outputs:
120 | loss += criterion(output, target)
121 | loss /= len(outputs)
122 | iou,dice = iou_score(outputs[-1], target)
123 | else:
124 | output = model(input)
125 | loss = criterion(output, target)
126 | iou,dice = iou_score(output, target)
127 |
128 | # compute gradient and do optimizing step
129 | optimizer.zero_grad()
130 | loss.backward()
131 | optimizer.step()
132 |
133 | avg_meters['loss'].update(loss.item(), input.size(0))
134 | avg_meters['iou'].update(iou, input.size(0))
135 |
136 | postfix = OrderedDict([
137 | ('loss', avg_meters['loss'].avg),
138 | ('iou', avg_meters['iou'].avg),
139 | ])
140 | pbar.set_postfix(postfix)
141 | pbar.update(1)
142 | pbar.close()
143 |
144 | return OrderedDict([('loss', avg_meters['loss'].avg),
145 | ('iou', avg_meters['iou'].avg)])
146 |
147 |
148 | def validate(config, val_loader, model, criterion):
149 | avg_meters = {'loss': AverageMeter(),
150 | 'iou': AverageMeter(),
151 | 'dice': AverageMeter()}
152 |
153 | # switch to evaluate mode
154 | model.eval()
155 |
156 | with torch.no_grad():
157 | pbar = tqdm(total=len(val_loader))
158 | for input, target, _ in val_loader:
159 | input = input.cuda()
160 | target = target.cuda()
161 |
162 | # compute output
163 | if config['deep_supervision']:
164 | outputs = model(input)
165 | loss = 0
166 | for output in outputs:
167 | loss += criterion(output, target)
168 | loss /= len(outputs)
169 | iou,dice = iou_score(outputs[-1], target)
170 | else:
171 | output = model(input)
172 | loss = criterion(output, target)
173 | iou,dice = iou_score(output, target)
174 |
175 | avg_meters['loss'].update(loss.item(), input.size(0))
176 | avg_meters['iou'].update(iou, input.size(0))
177 | avg_meters['dice'].update(dice, input.size(0))
178 |
179 | postfix = OrderedDict([
180 | ('loss', avg_meters['loss'].avg),
181 | ('iou', avg_meters['iou'].avg),
182 | ('dice', avg_meters['dice'].avg)
183 | ])
184 | pbar.set_postfix(postfix)
185 | pbar.update(1)
186 | pbar.close()
187 |
188 | return OrderedDict([('loss', avg_meters['loss'].avg),
189 | ('iou', avg_meters['iou'].avg),
190 | ('dice', avg_meters['dice'].avg)])
191 |
192 |
193 | def main():
194 | config = vars(parse_args())
195 |
196 | if config['name'] is None:
197 | if config['deep_supervision']:
198 | config['name'] = '%s_%s_wDS' % (config['dataset'], config['arch'])
199 | else:
200 | config['name'] = '%s_%s_woDS' % (config['dataset'], config['arch'])
201 |
202 | os.makedirs('models/%s' % config['name'], exist_ok=True)
203 |
204 | print('-' * 20)
205 | for key in config:
206 | print('%s: %s' % (key, config[key]))
207 | print('-' * 20)
208 |
209 | with open('models/%s/config.yml' % config['name'], 'w') as f:
210 | yaml.dump(config, f)
211 |
212 | # define loss function (criterion)
213 | if config['loss'] == 'BCEWithLogitsLoss':
214 | criterion = nn.BCEWithLogitsLoss().cuda()
215 | else:
216 | criterion = losses.__dict__[config['loss']]().cuda()
217 |
218 | cudnn.benchmark = True
219 |
220 | # create model
221 | model = archs.__dict__[config['arch']](config['num_classes'],
222 | config['input_channels'],
223 | config['deep_supervision'])
224 |
225 | model = model.cuda()
226 |
227 | params = filter(lambda p: p.requires_grad, model.parameters())
228 | if config['optimizer'] == 'Adam':
229 | optimizer = optim.Adam(
230 | params, lr=config['lr'], weight_decay=config['weight_decay'])
231 | elif config['optimizer'] == 'SGD':
232 | optimizer = optim.SGD(params, lr=config['lr'], momentum=config['momentum'],
233 | nesterov=config['nesterov'], weight_decay=config['weight_decay'])
234 | else:
235 | raise NotImplementedError
236 |
237 | if config['scheduler'] == 'CosineAnnealingLR':
238 | scheduler = lr_scheduler.CosineAnnealingLR(
239 | optimizer, T_max=config['epochs'], eta_min=config['min_lr'])
240 | elif config['scheduler'] == 'ReduceLROnPlateau':
241 | scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, factor=config['factor'], patience=config['patience'],
242 | verbose=1, min_lr=config['min_lr'])
243 | elif config['scheduler'] == 'MultiStepLR':
244 | scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[int(e) for e in config['milestones'].split(',')], gamma=config['gamma'])
245 | elif config['scheduler'] == 'ConstantLR':
246 | scheduler = None
247 | else:
248 | raise NotImplementedError
249 |
250 | # Data loading code
251 | img_ids = glob(os.path.join('inputs', config['dataset'], 'images', '*' + config['img_ext']))
252 | img_ids = [os.path.splitext(os.path.basename(p))[0] for p in img_ids]
253 |
254 | train_img_ids, val_img_ids = train_test_split(img_ids, test_size=0.2, random_state=41)
255 |
256 | train_transform = Compose([
257 | RandomRotate90(),
258 | transforms.Flip(),
259 | Resize(config['input_h'], config['input_w']),
260 | transforms.Normalize(),
261 | ])
262 |
263 | val_transform = Compose([
264 | Resize(config['input_h'], config['input_w']),
265 | transforms.Normalize(),
266 | ])
267 |
268 | train_dataset = Dataset(
269 | img_ids=train_img_ids,
270 | img_dir=os.path.join('inputs', config['dataset'], 'images'),
271 | mask_dir=os.path.join('inputs', config['dataset'], 'masks'),
272 | img_ext=config['img_ext'],
273 | mask_ext=config['mask_ext'],
274 | num_classes=config['num_classes'],
275 | transform=train_transform)
276 | val_dataset = Dataset(
277 | img_ids=val_img_ids,
278 | img_dir=os.path.join('inputs', config['dataset'], 'images'),
279 | mask_dir=os.path.join('inputs', config['dataset'], 'masks'),
280 | img_ext=config['img_ext'],
281 | mask_ext=config['mask_ext'],
282 | num_classes=config['num_classes'],
283 | transform=val_transform)
284 |
285 | train_loader = torch.utils.data.DataLoader(
286 | train_dataset,
287 | batch_size=config['batch_size'],
288 | shuffle=True,
289 | num_workers=config['num_workers'],
290 | drop_last=True)
291 | val_loader = torch.utils.data.DataLoader(
292 | val_dataset,
293 | batch_size=config['batch_size'],
294 | shuffle=False,
295 | num_workers=config['num_workers'],
296 | drop_last=False)
297 |
298 | log = OrderedDict([
299 | ('epoch', []),
300 | ('lr', []),
301 | ('loss', []),
302 | ('iou', []),
303 | ('val_loss', []),
304 | ('val_iou', []),
305 | ('val_dice', []),
306 | ])
307 |
308 | best_iou = 0
309 | trigger = 0
310 | for epoch in range(config['epochs']):
311 | print('Epoch [%d/%d]' % (epoch, config['epochs']))
312 |
313 | # train for one epoch
314 | train_log = train(config, train_loader, model, criterion, optimizer)
315 | # evaluate on validation set
316 | val_log = validate(config, val_loader, model, criterion)
317 |
318 | if config['scheduler'] == 'CosineAnnealingLR':
319 | scheduler.step()
320 | elif config['scheduler'] == 'ReduceLROnPlateau':
321 | scheduler.step(val_log['loss'])
322 |
323 | print('loss %.4f - iou %.4f - val_loss %.4f - val_iou %.4f'
324 | % (train_log['loss'], train_log['iou'], val_log['loss'], val_log['iou']))
325 |
326 | log['epoch'].append(epoch)
327 | log['lr'].append(config['lr'])
328 | log['loss'].append(train_log['loss'])
329 | log['iou'].append(train_log['iou'])
330 | log['val_loss'].append(val_log['loss'])
331 | log['val_iou'].append(val_log['iou'])
332 | log['val_dice'].append(val_log['dice'])
333 |
334 | pd.DataFrame(log).to_csv('models/%s/log.csv' %
335 | config['name'], index=False)
336 |
337 | trigger += 1
338 |
339 | if val_log['iou'] > best_iou:
340 | torch.save(model.state_dict(), 'models/%s/model.pth' %
341 | config['name'])
342 | best_iou = val_log['iou']
343 | print("=> saved best model")
344 | trigger = 0
345 |
346 | # early stopping
347 | if config['early_stopping'] >= 0 and trigger >= config['early_stopping']:
348 | print("=> early stopping")
349 | break
350 |
351 | torch.cuda.empty_cache()
352 |
353 |
354 | if __name__ == '__main__':
355 | main()
356 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch.nn as nn
3 |
4 | class qkv_transform(nn.Conv1d):
5 | """Conv1d for qkv_transform"""
6 |
7 | def str2bool(v):
8 |     if v.lower() in ['true', '1']:
9 |         return True
10 |     elif v.lower() in ['false', '0']:
11 | return False
12 | else:
13 | raise argparse.ArgumentTypeError('Boolean value expected.')
14 |
15 |
16 | def count_params(model):
17 | return sum(p.numel() for p in model.parameters() if p.requires_grad)
18 |
19 |
20 | class AverageMeter(object):
21 | """Computes and stores the average and current value"""
22 |
23 | def __init__(self):
24 | self.reset()
25 |
26 | def reset(self):
27 | self.val = 0
28 | self.avg = 0
29 | self.sum = 0
30 | self.count = 0
31 |
32 | def update(self, val, n=1):
33 | self.val = val
34 | self.sum += val * n
35 | self.count += n
36 | self.avg = self.sum / self.count
37 |
--------------------------------------------------------------------------------
/val.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from glob import glob
4 |
5 | import cv2
6 | import torch
7 | import torch.backends.cudnn as cudnn
8 | import yaml
9 | from albumentations.augmentations import transforms
10 | from albumentations.core.composition import Compose
11 | from sklearn.model_selection import train_test_split
12 | from tqdm import tqdm
13 |
14 | import archs
15 | from dataset import Dataset
16 | from metrics import iou_score
17 | from utils import AverageMeter
18 | from albumentations import RandomRotate90,Resize
19 | import time
20 | from archs import UNext
21 |
22 |
23 | def parse_args():
24 | parser = argparse.ArgumentParser()
25 |
26 | parser.add_argument('--name', default=None,
27 | help='model name')
28 |
29 | args = parser.parse_args()
30 |
31 | return args
32 |
33 |
34 | def main():
35 | args = parse_args()
36 |
37 | with open('models/%s/config.yml' % args.name, 'r') as f:
38 | config = yaml.load(f, Loader=yaml.FullLoader)
39 |
40 | print('-'*20)
41 | for key in config.keys():
42 | print('%s: %s' % (key, str(config[key])))
43 | print('-'*20)
44 |
45 | cudnn.benchmark = True
46 |
47 | print("=> creating model %s" % config['arch'])
48 | model = archs.__dict__[config['arch']](config['num_classes'],
49 | config['input_channels'],
50 | config['deep_supervision'])
51 |
52 | model = model.cuda()
53 |
54 | # Data loading code
55 | img_ids = glob(os.path.join('inputs', config['dataset'], 'images', '*' + config['img_ext']))
56 | img_ids = [os.path.splitext(os.path.basename(p))[0] for p in img_ids]
57 |
58 | _, val_img_ids = train_test_split(img_ids, test_size=0.2, random_state=41)
59 |
60 | model.load_state_dict(torch.load('models/%s/model.pth' %
61 | config['name']))
62 | model.eval()
63 |
64 | val_transform = Compose([
65 | Resize(config['input_h'], config['input_w']),
66 | transforms.Normalize(),
67 | ])
68 |
69 | val_dataset = Dataset(
70 | img_ids=val_img_ids,
71 | img_dir=os.path.join('inputs', config['dataset'], 'images'),
72 | mask_dir=os.path.join('inputs', config['dataset'], 'masks'),
73 | img_ext=config['img_ext'],
74 | mask_ext=config['mask_ext'],
75 | num_classes=config['num_classes'],
76 | transform=val_transform)
77 | val_loader = torch.utils.data.DataLoader(
78 | val_dataset,
79 | batch_size=config['batch_size'],
80 | shuffle=False,
81 | num_workers=config['num_workers'],
82 | drop_last=False)
83 |
84 | iou_avg_meter = AverageMeter()
85 | dice_avg_meter = AverageMeter()
86 | gput = AverageMeter()
87 | cput = AverageMeter()
88 |
89 | count = 0
90 | for c in range(config['num_classes']):
91 | os.makedirs(os.path.join('outputs', config['name'], str(c)), exist_ok=True)
92 | with torch.no_grad():
93 | for input, target, meta in tqdm(val_loader, total=len(val_loader)):
94 | input = input.cuda()
95 | target = target.cuda()
96 | model = model.cuda()
97 | # compute output
98 | output = model(input)
99 |
100 |
101 | iou,dice = iou_score(output, target)
102 | iou_avg_meter.update(iou, input.size(0))
103 | dice_avg_meter.update(dice, input.size(0))
104 |
105 | output = torch.sigmoid(output).cpu().numpy()
106 | output[output>=0.5]=1
107 | output[output<0.5]=0
108 |
109 | for i in range(len(output)):
110 | for c in range(config['num_classes']):
111 | cv2.imwrite(os.path.join('outputs', config['name'], str(c), meta['img_id'][i] + '.jpg'),
112 | (output[i, c] * 255).astype('uint8'))
113 |
114 | print('IoU: %.4f' % iou_avg_meter.avg)
115 | print('Dice: %.4f' % dice_avg_meter.avg)
116 |
117 | torch.cuda.empty_cache()
118 |
119 |
120 | if __name__ == '__main__':
121 | main()
122 |
--------------------------------------------------------------------------------