├── README.md ├── LICENSE ├── .gitignore └── convnext_model.py /README.md: -------------------------------------------------------------------------------- 1 | # convnext_unet 2 | unet base on convnext aritecture 3 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 kornellewy 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /convnext_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | source: 3 | https://arxiv.org/pdf/2201.03545.pdf 4 | https://github.com/facebookresearch/ConvNeXt/blob/main/models/convnext.py 5 | """ 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from zmq import device 11 | 12 | 13 | class LayerNorm(nn.Module): 14 | r"""LayerNorm that supports two data formats: channels_last (default) or channels_first. 15 | The ordering of the dimensions in the inputs. channels_last corresponds to inputs with 16 | shape (batch_size, height, width, channels) while channels_first corresponds to inputs 17 | with shape (batch_size, channels, height, width). 18 | """ 19 | 20 | def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): 21 | super().__init__() 22 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 23 | self.bias = nn.Parameter(torch.zeros(normalized_shape)) 24 | self.eps = eps 25 | self.data_format = data_format 26 | if self.data_format not in ["channels_last", "channels_first"]: 27 | raise NotImplementedError 28 | self.normalized_shape = (normalized_shape,) 29 | 30 | def forward(self, x): 31 | if self.data_format == "channels_last": 32 | return F.layer_norm( 33 | x, self.normalized_shape, self.weight, self.bias, self.eps 34 | ) 35 | elif self.data_format == "channels_first": 36 | u = x.mean(1, keepdim=True) 37 | s = (x - u).pow(2).mean(1, keepdim=True) 38 | x = (x - u) / torch.sqrt(s + self.eps) 39 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 40 | return x 41 | 42 | 43 | class ConvNectBlock(nn.Module): 44 | r"""ConvNeXt ConvNectBlock. There are two equivalent implementations: 45 | (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) 46 | (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back 47 | We use (2) as we find it slightly faster in PyTorch 48 | 49 | Args: 50 | dim (int): Number of input channels. 51 | drop_path (float): Stochastic depth rate. Default: 0.0 52 | layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. 53 | """ 54 | 55 | def __init__(self, dim, drop_path=0.0, layer_scale_init_value=1e-6): 56 | super().__init__() 57 | self.dwconv = nn.Conv2d( 58 | dim, dim, kernel_size=7, padding=3, groups=dim 59 | ) # depthwise conv 60 | self.norm = LayerNorm(dim, eps=1e-6) 61 | self.pwconv1 = nn.Linear( 62 | dim, 4 * dim 63 | ) # pointwise/1x1 convs, implemented with linear layers 64 | self.act = nn.GELU() 65 | self.pwconv2 = nn.Linear(4 * dim, dim) 66 | self.gamma = ( 67 | nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True) 68 | if layer_scale_init_value > 0 69 | else None 70 | ) 71 | self.drop_path = nn.Identity() 72 | 73 | def forward(self, x): 74 | input = x 75 | x = self.dwconv(x) 76 | x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) 77 | x = self.norm(x) 78 | x = self.pwconv1(x) 79 | x = self.act(x) 80 | x = self.pwconv2(x) 81 | if self.gamma is not None: 82 | x = self.gamma * x 83 | x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) 84 | x = input + self.drop_path(x) 85 | return x 86 | 87 | 88 | class UpConvNext(nn.Module): 89 | def __init__(self, ch_in, ch_out): 90 | super().__init__() 91 | self.upscale_factor = 2 92 | self.up = nn.Upsample(scale_factor=2) 93 | self.conv = nn.Conv2d( 94 | ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True 95 | ) 96 | self.norm = LayerNorm(ch_out, eps=1e-6) 97 | self.act = nn.GELU() 98 | 99 | def forward(self, x): 100 | x = self.up(x) 101 | x = self.conv(x) 102 | x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) 103 | x = self.norm(x) 104 | x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) 105 | x = self.act(x) 106 | return x 107 | 108 | 109 | class UpConvNext2(nn.Module): 110 | def __init__(self, ch_in, ch_out): 111 | super().__init__() 112 | self.upscale_factor = 2 113 | self.pixel = nn.PixelShuffle(upscale_factor=self.upscale_factor) 114 | self.up = nn.ConvTranspose2d( 115 | ch_in // self.upscale_factor ** 2, 116 | ch_out, 117 | kernel_size=3, 118 | stride=1, 119 | padding=1, 120 | ) 121 | self.norm = LayerNorm(ch_out, eps=1e-6) 122 | self.act = nn.GELU() 123 | 124 | def forward(self, x): 125 | x = self.pixel(x) 126 | x = self.up(x) 127 | x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) 128 | x = self.norm(x) 129 | x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) 130 | x = self.act(x) 131 | return x 132 | 133 | 134 | class ConvNext_block(nn.Module): 135 | def __init__(self, ch_in, ch_out): 136 | super().__init__() 137 | self.conv_next = nn.Sequential(ConvNectBlock(dim=ch_out)) 138 | self.Conv_1x1 = nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=1, padding=0) 139 | 140 | def forward(self, x): 141 | x = self.Conv_1x1(x) 142 | return self.conv_next(x) 143 | 144 | 145 | class U_ConvNext(nn.Module): 146 | """ 147 | Version 1 using jast ConvNext_block. No bach normalization and no relu. 148 | """ 149 | 150 | def __init__(self, img_ch=3, output_ch=1, channels=24): 151 | super().__init__() 152 | self.Maxpool = nn.AvgPool2d(kernel_size=2, stride=2) 153 | self.dropout = nn.Dropout(0.5) 154 | 155 | self.Conv1 = ConvNext_block(ch_in=img_ch, ch_out=channels) 156 | self.Conv2 = ConvNext_block(ch_in=channels, ch_out=channels * 2) 157 | self.Conv3 = ConvNext_block(ch_in=channels * 2, ch_out=channels * 4) 158 | self.Conv4 = ConvNext_block(ch_in=channels * 4, ch_out=channels * 8) 159 | self.Conv5 = ConvNext_block(ch_in=channels * 8, ch_out=channels * 16) 160 | 161 | self.Up5 = UpConvNext2(ch_in=channels * 16, ch_out=channels * 8) 162 | self.Up_conv5 = ConvNext_block(ch_in=channels * 16, ch_out=channels * 8) 163 | 164 | self.Up4 = UpConvNext2(ch_in=channels * 8, ch_out=channels * 4) 165 | self.Up_conv4 = ConvNext_block(ch_in=channels * 8, ch_out=channels * 4) 166 | 167 | self.Up3 = UpConvNext2(ch_in=channels * 4, ch_out=channels * 2) 168 | self.Up_conv3 = ConvNext_block(ch_in=channels * 4, ch_out=channels * 2) 169 | 170 | self.Up2 = UpConvNext2(ch_in=channels * 2, ch_out=channels) 171 | self.Up_conv2 = ConvNext_block(ch_in=channels * 2, ch_out=channels) 172 | 173 | self.Conv_1x1 = nn.Conv2d( 174 | channels, output_ch, kernel_size=1, stride=1, padding=0 175 | ) 176 | self.last_activation = nn.Hardtanh() 177 | 178 | def forward(self, x): 179 | # encoding path 180 | x1 = self.Conv1(x) 181 | 182 | x2 = self.Maxpool(x1) 183 | x2 = self.Conv2(x2) 184 | x2 = self.dropout(x2) 185 | 186 | x3 = self.Maxpool(x2) 187 | x3 = self.Conv3(x3) 188 | x3 = self.dropout(x3) 189 | 190 | x4 = self.Maxpool(x3) 191 | x4 = self.Conv4(x4) 192 | x4 = self.dropout(x4) 193 | 194 | x5 = self.Maxpool(x4) 195 | x5 = self.Conv5(x5) 196 | x5 = self.dropout(x5) 197 | 198 | # decoding + concat path 199 | d5 = self.Up5(x5) 200 | d5 = torch.cat((x4, d5), dim=1) 201 | 202 | d5 = self.Up_conv5(d5) 203 | 204 | d4 = self.Up4(d5) 205 | d4 = torch.cat((x3, d4), dim=1) 206 | d4 = self.Up_conv4(d4) 207 | 208 | d3 = self.Up3(d4) 209 | d3 = torch.cat((x2, d3), dim=1) 210 | d3 = self.Up_conv3(d3) 211 | 212 | d2 = self.Up2(d3) 213 | d2 = torch.cat((x1, d2), dim=1) 214 | d2 = self.Up_conv2(d2) 215 | 216 | d1 = self.Conv_1x1(d2) 217 | d1 = self.last_activation(d1) 218 | 219 | return d1 220 | 221 | 222 | class ConvNextDiscriminator(nn.Module): 223 | def __init__(self, in_channels=3, channels=4): 224 | super().__init__() 225 | 226 | self.Maxpool = nn.AvgPool2d(kernel_size=2, stride=2) 227 | 228 | self.block1 = ConvNext_block(ch_in=in_channels, ch_out=channels * 2) 229 | self.block2 = ConvNext_block(ch_in=channels * 2, ch_out=channels * 4) 230 | self.block3 = ConvNext_block(ch_in=channels * 4, ch_out=channels * 8) 231 | 232 | self.zero_pad = nn.ZeroPad2d((1, 0, 1, 0)) 233 | self.last_conv = nn.Conv2d(channels * 8, 1, 4, padding=1, bias=False) 234 | 235 | def forward(self, img_A, img_B): 236 | # Concatenate image and condition image by channels to produce input 237 | x = torch.cat((img_A, img_B), 1) 238 | 239 | x = self.block1(x) 240 | x = self.Maxpool(x) 241 | 242 | x = self.block2(x) 243 | x = self.Maxpool(x) 244 | 245 | x = self.block3(x) 246 | x = self.Maxpool(x) 247 | 248 | x = self.zero_pad(x) 249 | 250 | x = self.last_conv(x) 251 | x = self.Maxpool(x) 252 | return x 253 | 254 | 255 | class GeneratorConvNext001(nn.Module): 256 | def __init__(self, img_shape=(3, 128, 128), blocks=9, img_ch=3, output_ch=3): 257 | super().__init__() 258 | # Initial convolution block 259 | model = [ 260 | nn.Conv2d(img_ch, 64, 7, stride=1, padding=3, bias=False), 261 | LayerNorm(64, eps=1e-6, data_format="channels_first"), 262 | ] 263 | 264 | # Downsampling 265 | curr_dim = 64 266 | for _ in range(2): 267 | model += [ 268 | LayerNorm(curr_dim, eps=1e-6, data_format="channels_first"), 269 | nn.Conv2d(curr_dim, curr_dim * 2, 4, stride=2, padding=1, bias=False), 270 | nn.Dropout(0.5), 271 | ] 272 | curr_dim *= 2 273 | 274 | for _ in range(blocks): 275 | model += [ConvNectBlock(curr_dim)] 276 | 277 | # Upsampling 278 | for _ in range(2): 279 | model += [ 280 | nn.ConvTranspose2d( 281 | curr_dim, curr_dim // 2, 4, stride=2, padding=1, bias=False 282 | ), 283 | LayerNorm(curr_dim // 2, eps=1e-6, data_format="channels_first"), 284 | nn.GELU(), 285 | ] 286 | curr_dim = curr_dim // 2 287 | 288 | # Output layer 289 | model += [nn.Conv2d(curr_dim, output_ch, 7, stride=1, padding=3), nn.Tanh()] 290 | 291 | self.model = nn.Sequential(*model) 292 | 293 | def forward(self, x): 294 | return self.model(x) 295 | 296 | 297 | class Attention_block2(nn.Module): 298 | def __init__(self, F_g, F_l, F_int): 299 | super().__init__() 300 | self.W_g = nn.Sequential( 301 | nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True), 302 | nn.BatchNorm2d(F_int), 303 | ) 304 | 305 | self.W_x = nn.Sequential( 306 | nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True), 307 | nn.BatchNorm2d(F_int), 308 | ) 309 | 310 | self.psi = nn.Sequential( 311 | nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True), 312 | nn.BatchNorm2d(1), 313 | nn.Sigmoid(), 314 | ) 315 | 316 | self.relu = nn.GELU() 317 | 318 | def forward(self, g, x): 319 | g1 = self.W_g(g) 320 | x1 = self.W_x(x) 321 | psi = self.relu(g1 + x1) 322 | psi = self.psi(psi) 323 | 324 | return x * psi 325 | 326 | 327 | class AttU_ConvNext(nn.Module): 328 | def __init__(self, img_ch=3, output_ch=1, channels=24): 329 | super().__init__() 330 | self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2) 331 | self.dropout = nn.Dropout(0.5) 332 | 333 | self.Conv1 = ConvNext_block(ch_in=img_ch, ch_out=channels) 334 | self.Conv2 = ConvNext_block(ch_in=channels, ch_out=channels * 2) 335 | self.Conv3 = ConvNext_block(ch_in=channels * 2, ch_out=channels * 4) 336 | self.Conv4 = ConvNext_block(ch_in=channels * 4, ch_out=channels * 8) 337 | self.Conv5 = ConvNext_block(ch_in=channels * 8, ch_out=channels * 16) 338 | 339 | self.Up5 = UpConvNext2(ch_in=channels * 16, ch_out=channels * 8) 340 | self.Att5 = Attention_block2( 341 | F_g=channels * 8, F_l=channels * 8, F_int=channels * 4 342 | ) 343 | self.Up_conv5 = ConvNext_block(ch_in=channels * 16, ch_out=channels * 8) 344 | 345 | self.Up4 = UpConvNext2(ch_in=channels * 8, ch_out=channels * 4) 346 | self.Att4 = Attention_block2( 347 | F_g=channels * 4, F_l=channels * 4, F_int=channels * 2 348 | ) 349 | self.Up_conv4 = ConvNext_block(ch_in=channels * 8, ch_out=channels * 4) 350 | 351 | self.Up3 = UpConvNext2(ch_in=channels * 4, ch_out=channels * 2) 352 | self.Att3 = Attention_block2(F_g=channels * 2, F_l=channels * 2, F_int=channels) 353 | self.Up_conv3 = ConvNext_block(ch_in=channels * 4, ch_out=channels * 2) 354 | 355 | self.Up2 = UpConvNext2(ch_in=channels * 2, ch_out=channels) 356 | self.Att2 = Attention_block2(F_g=channels, F_l=channels, F_int=channels // 2) 357 | self.Up_conv2 = ConvNext_block(ch_in=channels * 2, ch_out=channels) 358 | 359 | self.Conv_1x1 = nn.Conv2d( 360 | channels, output_ch, kernel_size=1, stride=1, padding=0 361 | ) 362 | self.last_activation = nn.Hardtanh() 363 | 364 | def forward(self, x): 365 | # encoding path 366 | x1 = self.Conv1(x) 367 | 368 | x2 = self.Maxpool(x1) 369 | x2 = self.Conv2(x2) 370 | x2 = self.dropout(x2) 371 | 372 | x3 = self.Maxpool(x2) 373 | x3 = self.Conv3(x3) 374 | x3 = self.dropout(x3) 375 | 376 | x4 = self.Maxpool(x3) 377 | x4 = self.Conv4(x4) 378 | x4 = self.dropout(x4) 379 | 380 | x5 = self.Maxpool(x4) 381 | x5 = self.Conv5(x5) 382 | x5 = self.dropout(x5) 383 | 384 | # decoding + concat path 385 | d5 = self.Up5(x5) 386 | x4 = self.Att5(g=d5, x=x4) 387 | d5 = torch.cat((x4, d5), dim=1) 388 | d5 = self.Up_conv5(d5) 389 | 390 | d4 = self.Up4(d5) 391 | x3 = self.Att4(g=d4, x=x3) 392 | d4 = torch.cat((x3, d4), dim=1) 393 | d4 = self.Up_conv4(d4) 394 | 395 | d3 = self.Up3(d4) 396 | x2 = self.Att3(g=d3, x=x2) 397 | d3 = torch.cat((x2, d3), dim=1) 398 | d3 = self.Up_conv3(d3) 399 | 400 | d2 = self.Up2(d3) 401 | x1 = self.Att2(g=d2, x=x1) 402 | d2 = torch.cat((x1, d2), dim=1) 403 | d2 = self.Up_conv2(d2) 404 | 405 | d1 = self.Conv_1x1(d2) 406 | d1 = self.last_activation(d1) 407 | 408 | return d1 409 | 410 | 411 | class Recurrent_block2(nn.Module): 412 | def __init__(self, ch_out, t=2): 413 | super().__init__() 414 | self.t = t 415 | self.ch_out = ch_out 416 | self.conv_next = ConvNectBlock(dim=ch_out) 417 | 418 | def forward(self, x): 419 | for i in range(self.t): 420 | if i == 0: 421 | x1 = self.conv_next(x) 422 | x1 = self.conv_next(x + x1) 423 | return x1 424 | 425 | 426 | class RRCNN_block2(nn.Module): 427 | def __init__(self, ch_in, ch_out, t=2): 428 | super().__init__() 429 | self.RCNN = nn.Sequential( 430 | Recurrent_block2(ch_out, t=t), Recurrent_block2(ch_out, t=t) 431 | ) 432 | self.Conv_1x1 = nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=1, padding=0) 433 | 434 | def forward(self, x): 435 | x = self.Conv_1x1(x) 436 | x1 = self.RCNN(x) 437 | return x + x1 438 | 439 | 440 | class R2U_ConvNext(nn.Module): 441 | def __init__(self, img_ch=3, output_ch=1, channels=8, t=2): 442 | super().__init__() 443 | 444 | self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2) 445 | self.dropout = nn.Dropout(0.5) 446 | 447 | self.RRCNN1 = RRCNN_block2(ch_in=img_ch, ch_out=channels, t=t) 448 | 449 | self.RRCNN2 = RRCNN_block2(ch_in=channels, ch_out=channels * 2, t=t) 450 | 451 | self.RRCNN3 = RRCNN_block2(ch_in=channels * 2, ch_out=channels * 4, t=t) 452 | 453 | self.RRCNN4 = RRCNN_block2(ch_in=channels * 4, ch_out=channels * 8, t=t) 454 | 455 | self.RRCNN5 = RRCNN_block2(ch_in=channels * 8, ch_out=channels * 16, t=t) 456 | 457 | self.Up5 = UpConvNext2(ch_in=channels * 16, ch_out=channels * 8) 458 | self.Up_RRCNN5 = RRCNN_block2(ch_in=channels * 16, ch_out=channels * 8, t=t) 459 | 460 | self.Up4 = UpConvNext2(ch_in=channels * 8, ch_out=channels * 4) 461 | self.Up_RRCNN4 = RRCNN_block2(ch_in=channels * 8, ch_out=channels * 4, t=t) 462 | 463 | self.Up3 = UpConvNext2(ch_in=channels * 4, ch_out=channels * 2) 464 | self.Up_RRCNN3 = RRCNN_block2(ch_in=channels * 4, ch_out=channels * 2, t=t) 465 | 466 | self.Up2 = UpConvNext2(ch_in=channels * 2, ch_out=channels) 467 | self.Up_RRCNN2 = RRCNN_block2(ch_in=channels * 2, ch_out=channels, t=t) 468 | 469 | self.Conv_1x1 = nn.Conv2d( 470 | channels, output_ch, kernel_size=1, stride=1, padding=0 471 | ) 472 | self.last_activation = nn.Hardtanh() 473 | 474 | def forward(self, x): 475 | # encoding path 476 | x1 = self.RRCNN1(x) 477 | 478 | x2 = self.Maxpool(x1) 479 | x2 = self.RRCNN2(x2) 480 | 481 | x3 = self.Maxpool(x2) 482 | x3 = self.RRCNN3(x3) 483 | 484 | x4 = self.Maxpool(x3) 485 | x4 = self.RRCNN4(x4) 486 | 487 | x5 = self.Maxpool(x4) 488 | x5 = self.RRCNN5(x5) 489 | 490 | # decoding + concat path 491 | d5 = self.Up5(x5) 492 | d5 = torch.cat((x4, d5), dim=1) 493 | d5 = self.Up_RRCNN5(d5) 494 | 495 | d4 = self.Up4(d5) 496 | d4 = torch.cat((x3, d4), dim=1) 497 | d4 = self.Up_RRCNN4(d4) 498 | 499 | d3 = self.Up3(d4) 500 | d3 = torch.cat((x2, d3), dim=1) 501 | d3 = self.Up_RRCNN3(d3) 502 | 503 | d2 = self.Up2(d3) 504 | d2 = torch.cat((x1, d2), dim=1) 505 | d2 = self.Up_RRCNN2(d2) 506 | 507 | d1 = self.Conv_1x1(d2) 508 | d1 = self.last_activation(d1) 509 | 510 | return d1 511 | 512 | 513 | class R2AttU_ConvNext(nn.Module): 514 | def __init__(self, img_ch=3, output_ch=1, channels=8, t=2): 515 | super().__init__() 516 | 517 | self.Maxpool = nn.AvgPool2d(kernel_size=2, stride=2) 518 | self.Upsample = nn.Upsample(scale_factor=2) 519 | 520 | self.RRCNN1 = RRCNN_block2(ch_in=img_ch, ch_out=channels, t=t) 521 | 522 | self.RRCNN2 = RRCNN_block2(ch_in=channels, ch_out=channels * 2, t=t) 523 | 524 | self.RRCNN3 = RRCNN_block2(ch_in=channels * 2, ch_out=channels * 4, t=t) 525 | 526 | self.RRCNN4 = RRCNN_block2(ch_in=channels * 4, ch_out=channels * 8, t=t) 527 | 528 | self.RRCNN5 = RRCNN_block2(ch_in=channels * 8, ch_out=channels * 16, t=t) 529 | 530 | self.Up5 = UpConvNext2(ch_in=channels * 16, ch_out=channels * 8) 531 | self.Att5 = Attention_block2( 532 | F_g=channels * 8, F_l=channels * 8, F_int=channels * 4 533 | ) 534 | self.Up_RRCNN5 = RRCNN_block2(ch_in=channels * 16, ch_out=channels * 8, t=t) 535 | 536 | self.Up4 = UpConvNext2(ch_in=channels * 8, ch_out=channels * 4) 537 | self.Att4 = Attention_block2( 538 | F_g=channels * 4, F_l=channels * 4, F_int=channels * 2 539 | ) 540 | self.Up_RRCNN4 = RRCNN_block2(ch_in=channels * 8, ch_out=channels * 4, t=t) 541 | 542 | self.Up3 = UpConvNext2(ch_in=channels * 4, ch_out=channels * 2) 543 | self.Att3 = Attention_block2(F_g=channels * 2, F_l=channels * 2, F_int=channels) 544 | self.Up_RRCNN3 = RRCNN_block2(ch_in=channels * 4, ch_out=channels * 2, t=t) 545 | 546 | self.Up2 = UpConvNext2(ch_in=channels * 2, ch_out=channels) 547 | self.Att2 = Attention_block2(F_g=channels, F_l=channels, F_int=channels // 2) 548 | self.Up_RRCNN2 = RRCNN_block2(ch_in=channels * 2, ch_out=channels, t=t) 549 | 550 | self.Conv_1x1 = nn.Conv2d( 551 | channels, output_ch, kernel_size=1, stride=1, padding=0 552 | ) 553 | self.last_activation = nn.Hardtanh() 554 | 555 | def forward(self, x): 556 | # encoding path 557 | x1 = self.RRCNN1(x) 558 | 559 | x2 = self.Maxpool(x1) 560 | x2 = self.RRCNN2(x2) 561 | 562 | x3 = self.Maxpool(x2) 563 | x3 = self.RRCNN3(x3) 564 | 565 | x4 = self.Maxpool(x3) 566 | x4 = self.RRCNN4(x4) 567 | 568 | x5 = self.Maxpool(x4) 569 | x5 = self.RRCNN5(x5) 570 | 571 | # decoding + concat path 572 | d5 = self.Up5(x5) 573 | x4 = self.Att5(g=d5, x=x4) 574 | d5 = torch.cat((x4, d5), dim=1) 575 | d5 = self.Up_RRCNN5(d5) 576 | 577 | d4 = self.Up4(d5) 578 | x3 = self.Att4(g=d4, x=x3) 579 | d4 = torch.cat((x3, d4), dim=1) 580 | d4 = self.Up_RRCNN4(d4) 581 | 582 | d3 = self.Up3(d4) 583 | x2 = self.Att3(g=d3, x=x2) 584 | d3 = torch.cat((x2, d3), dim=1) 585 | d3 = self.Up_RRCNN3(d3) 586 | 587 | d2 = self.Up2(d3) 588 | x1 = self.Att2(g=d2, x=x1) 589 | d2 = torch.cat((x1, d2), dim=1) 590 | d2 = self.Up_RRCNN2(d2) 591 | 592 | d1 = self.Conv_1x1(d2) 593 | d1 = self.last_activation(d1) 594 | 595 | return d1 596 | 597 | 598 | if __name__ == "__main__": 599 | device = torch.device("cuda") 600 | input_tensor = torch.rand(1, 3, 512, 512, device=device) 601 | model = R2AttU_ConvNext() 602 | model.to(device) 603 | output_tensor = model.forward(input_tensor) 604 | print(output_tensor) 605 | --------------------------------------------------------------------------------