├── README.md
├── LICENSE
└── pytorch_prototyping.py

/README.md:
--------------------------------------------------------------------------------
# PyTorch modules with sane defaults for model prototyping

A collection of PyTorch modules (all based on prior work of the ML community) with sane default parameters that I find useful for model prototyping.

I'll continue to update this repository.

Contains:

* 2D U-Net with different options for how the feature maps are upsampled (to prevent checkerboard artifacts).
* 3D U-Net.
* 2D downsampling network.
* 2D upsampling network with different options for how the feature maps are upsampled (to prevent checkerboard artifacts).
* 2D conv layer with reflection padding that keeps the spatial dimensions of the feature map constant.

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2019 Vincent Sitzmann

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/pytorch_prototyping.py:
--------------------------------------------------------------------------------
'''A number of custom pytorch modules with sane defaults that I find useful for model prototyping.'''
import torch
import torch.nn as nn
from torch.nn import functional as F
import torchvision.utils

import numpy as np

import math
import numbers

class FCLayer(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_features, out_features),
            nn.LayerNorm([out_features]),
            nn.ReLU(inplace=True)
        )

    def forward(self, input):
        return self.net(input)


# From https://gist.github.com/wassname/ecd2dac6fc8f9918149853d17e3abf02
class LayerNormConv2d(nn.Module):

    def __init__(self, num_features, eps=1e-5, affine=True):
        super().__init__()
        self.num_features = num_features
        self.affine = affine
        self.eps = eps

        if self.affine:
            self.gamma = nn.Parameter(torch.Tensor(num_features).uniform_())
            self.beta = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        shape = [-1] + [1] * (x.dim() - 1)
        mean = x.view(x.size(0), -1).mean(1).view(*shape)
        std = x.view(x.size(0), -1).std(1).view(*shape)

        y = (x - mean) / (std + self.eps)
        if self.affine:
            shape = [1, -1] + [1] * (x.dim() - 2)
            y = self.gamma.view(*shape) * y + self.beta.view(*shape)
        return y


class FCBlock(nn.Module):
    def __init__(self,
                 hidden_ch,
                 num_hidden_layers,
                 in_features,
                 out_features,
                 outermost_linear=False):
        super().__init__()

        self.net = []
        self.net.append(FCLayer(in_features=in_features, out_features=hidden_ch))

        for i in range(num_hidden_layers):
            self.net.append(FCLayer(in_features=hidden_ch, out_features=hidden_ch))

        if outermost_linear:
            self.net.append(nn.Linear(in_features=hidden_ch, out_features=out_features))
        else:
            self.net.append(FCLayer(in_features=hidden_ch, out_features=out_features))

        self.net = nn.Sequential(*self.net)
        self.net.apply(self.init_weights)

    def __getitem__(self, item):
        return self.net[item]

    def init_weights(self, m):
        if type(m) == nn.Linear:
            nn.init.kaiming_normal_(m.weight, a=0.0, nonlinearity='relu', mode='fan_in')

    def forward(self, input):
        return self.net(input)


class DownBlock3D(nn.Module):
    '''A 3D convolutional downsampling block.
    '''

    def __init__(self, in_channels, out_channels, norm=nn.BatchNorm3d):
        super().__init__()

        self.net = [
            nn.ReplicationPad3d(1),
            nn.Conv3d(in_channels,
                      out_channels,
                      kernel_size=4,
                      padding=0,
                      stride=2,
                      bias=False if norm is not None else True),
        ]

        if norm is not None:
            self.net += [norm(out_channels, affine=True)]

        self.net += [nn.LeakyReLU(0.2, True)]
        self.net = nn.Sequential(*self.net)

    def forward(self, x):
        return self.net(x)


class UpBlock3D(nn.Module):
    '''A 3D convolutional upsampling block.
    '''

    def __init__(self, in_channels, out_channels, norm=nn.BatchNorm3d):
        super().__init__()

        self.net = [
            nn.ConvTranspose3d(in_channels,
                               out_channels,
                               kernel_size=4,
                               stride=2,
                               padding=1,
                               bias=False if norm is not None else True),
        ]

        if norm is not None:
            self.net += [norm(out_channels, affine=True)]

        self.net += [nn.ReLU(True)]
        self.net = nn.Sequential(*self.net)

    def forward(self, x, skipped=None):
        if skipped is not None:
            input = torch.cat([skipped, x], dim=1)
        else:
            input = x
        return self.net(input)


class Conv3dSame(torch.nn.Module):
    '''3D convolution that pads to keep spatial dimensions equal.
    Cannot deal with stride. Only cubic kernels (i.e. scalar kernel_size).
    '''

    def __init__(self, in_channels, out_channels, kernel_size, bias=True, padding_layer=nn.ReplicationPad3d):
        '''
        :param in_channels: Number of input channels
        :param out_channels: Number of output channels
        :param kernel_size: Scalar. Spatial dimensions of kernel (only cubic kernels supported).
        :param bias: Whether or not to use bias.
        :param padding_layer: Which padding to use. Default is replication padding.
        '''
        super().__init__()
        ka = kernel_size // 2
        kb = ka - 1 if kernel_size % 2 == 0 else ka
        self.net = nn.Sequential(
            padding_layer((ka, kb, ka, kb, ka, kb)),
            nn.Conv3d(in_channels, out_channels, kernel_size, bias=bias, stride=1)
        )

    def forward(self, x):
        return self.net(x)


class Conv2dSame(torch.nn.Module):
    '''2D convolution that pads to keep spatial dimensions equal.
    Cannot deal with stride. Only square kernels (i.e. scalar kernel_size).
    '''

    def __init__(self, in_channels, out_channels, kernel_size, bias=True, padding_layer=nn.ReflectionPad2d):
        '''
        :param in_channels: Number of input channels
        :param out_channels: Number of output channels
        :param kernel_size: Scalar. Spatial dimensions of kernel (only square kernels supported).
        :param bias: Whether or not to use bias.
        :param padding_layer: Which padding to use. Default is reflection padding.
        '''
        super().__init__()
        ka = kernel_size // 2
        kb = ka - 1 if kernel_size % 2 == 0 else ka
        self.net = nn.Sequential(
            padding_layer((ka, kb, ka, kb)),
            nn.Conv2d(in_channels, out_channels, kernel_size, bias=bias, stride=1)
        )

        self.weight = self.net[1].weight
        self.bias = self.net[1].bias

    def forward(self, x):
        return self.net(x)


class UpBlock(nn.Module):
    '''A 2d-conv upsampling block with a variety of options for upsampling, and following best practices / with
    reasonable defaults. (ReLU, kernel size a multiple of the stride)
    '''

    def __init__(self,
                 in_channels,
                 out_channels,
                 post_conv=True,
                 use_dropout=False,
                 dropout_prob=0.1,
                 norm=nn.BatchNorm2d,
                 upsampling_mode='transpose'):
        '''
        :param in_channels: Number of input channels
        :param out_channels: Number of output channels
        :param post_conv: Whether to have another convolutional layer after the upsampling layer.
        :param use_dropout: bool. Whether to use dropout or not.
        :param dropout_prob: Float. The dropout probability (if use_dropout is True)
        :param norm: Which norm to use. If None, no norm is used. Default is BatchNorm with affine parameters.
        :param upsampling_mode: Which upsampling mode:
                transpose: Upsampling with stride-2, kernel size 4 transpose convolutions.
                bilinear: Feature map is upsampled with bilinear upsampling, then a conv layer.
                nearest: Feature map is upsampled with nearest neighbor upsampling, then a conv layer.
                shuffle: Feature map is upsampled with pixel shuffling, then a conv layer.
        '''
        super().__init__()

        net = list()

        if upsampling_mode == 'transpose':
            net += [nn.ConvTranspose2d(in_channels,
                                       out_channels,
                                       kernel_size=4,
                                       stride=2,
                                       padding=1,
                                       bias=True if norm is None else False)]
        elif upsampling_mode == 'bilinear':
            net += [nn.UpsamplingBilinear2d(scale_factor=2)]
            net += [
                Conv2dSame(in_channels, out_channels, kernel_size=3, bias=True if norm is None else False)]
        elif upsampling_mode == 'nearest':
            net += [nn.UpsamplingNearest2d(scale_factor=2)]
            net += [
                Conv2dSame(in_channels, out_channels, kernel_size=3, bias=True if norm is None else False)]
        elif upsampling_mode == 'shuffle':
            net += [nn.PixelShuffle(upscale_factor=2)]
            net += [
                Conv2dSame(in_channels // 4, out_channels, kernel_size=3,
                           bias=True if norm is None else False)]
        else:
            raise ValueError("Unknown upsampling mode!")

        if norm is not None:
            net += [norm(out_channels, affine=True)]

        net += [nn.ReLU(True)]

        if use_dropout:
            net += [nn.Dropout2d(dropout_prob, False)]

        if post_conv:
            net += [Conv2dSame(out_channels,
                               out_channels,
                               kernel_size=3,
                               bias=True if norm is None else False)]

            if norm is not None:
                net += [norm(out_channels, affine=True)]

            net += [nn.ReLU(True)]

            if use_dropout:
                net += [nn.Dropout2d(dropout_prob, False)]

        self.net = nn.Sequential(*net)

    def forward(self, x, skipped=None):
        if skipped is not None:
            input = torch.cat([skipped, x], dim=1)
        else:
            input = x
        return self.net(input)


class DownBlock(nn.Module):
    '''A 2D-conv downsampling block following best practices / with reasonable defaults
    (LeakyReLU, kernel size multiple of stride)
    '''

    def __init__(self,
                 in_channels,
                 out_channels,
                 prep_conv=True,
                 middle_channels=None,
                 use_dropout=False,
                 dropout_prob=0.1,
                 norm=nn.BatchNorm2d):
        '''
        :param in_channels: Number of input channels
        :param out_channels: Number of output channels
        :param prep_conv: Whether to have another convolutional layer before the downsampling layer.
        :param middle_channels: If prep_conv is true, this sets the number of channels between the prep and downsampling
                                convs.
        :param use_dropout: bool. Whether to use dropout or not.
        :param dropout_prob: Float. The dropout probability (if use_dropout is True)
        :param norm: Which norm to use. If None, no norm is used. Default is BatchNorm with affine parameters.
        '''
        super().__init__()

        if middle_channels is None:
            middle_channels = in_channels

        net = list()

        if prep_conv:
            net += [nn.ReflectionPad2d(1),
                    nn.Conv2d(in_channels,
                              middle_channels,
                              kernel_size=3,
                              padding=0,
                              stride=1,
                              bias=True if norm is None else False)]

            if norm is not None:
                net += [norm(middle_channels, affine=True)]

            net += [nn.LeakyReLU(0.2, True)]

            if use_dropout:
                net += [nn.Dropout2d(dropout_prob, False)]

        net += [nn.ReflectionPad2d(1),
                nn.Conv2d(middle_channels,
                          out_channels,
                          kernel_size=4,
                          padding=0,
                          stride=2,
                          bias=True if norm is None else False)]

        if norm is not None:
            net += [norm(out_channels, affine=True)]

        net += [nn.LeakyReLU(0.2, True)]

        if use_dropout:
            net += [nn.Dropout2d(dropout_prob, False)]

        self.net = nn.Sequential(*net)

    def forward(self, x):
        return self.net(x)


class Unet3d(nn.Module):
    '''A 3d-Unet implementation with sane defaults.
    '''

    def __init__(self,
                 in_channels,
                 out_channels,
                 nf0,
                 num_down,
                 max_channels,
                 norm=nn.BatchNorm3d,
                 outermost_linear=False):
        '''
        :param in_channels: Number of input channels
        :param out_channels: Number of output channels
        :param nf0: Number of features at highest level of U-Net
        :param num_down: Number of downsampling stages.
        :param max_channels: Maximum number of channels (channels multiply by 2 with every downsampling stage)
        :param norm: Which norm to use. If None, no norm is used. Default is BatchNorm with affine parameters.
        :param outermost_linear: Whether the output layer should be a linear layer or a nonlinear one.
        '''
        super().__init__()

        assert (num_down > 0), "Need at least one downsampling layer in UNet3d."

        # Define the in block
        self.in_layer = [Conv3dSame(in_channels, nf0, kernel_size=3, bias=False)]

        if norm is not None:
            self.in_layer += [norm(nf0, affine=True)]

        self.in_layer += [nn.LeakyReLU(0.2, True)]
        self.in_layer = nn.Sequential(*self.in_layer)

        # Define the center UNet block. The feature map has height and width 1 --> no batchnorm.
        self.unet_block = UnetSkipConnectionBlock3d(int(min(2 ** (num_down - 1) * nf0, max_channels)),
                                                    int(min(2 ** (num_down - 1) * nf0, max_channels)),
                                                    norm=None)
        for i in list(range(0, num_down - 1))[::-1]:
            self.unet_block = UnetSkipConnectionBlock3d(int(min(2 ** i * nf0, max_channels)),
                                                        int(min(2 ** (i + 1) * nf0, max_channels)),
                                                        submodule=self.unet_block,
                                                        norm=norm)

        # Define the out layer. Each unet block concatenates its inputs with its outputs - so the output layer
        # automatically receives the output of the in_layer and the output of the last unet layer.
        self.out_layer = [Conv3dSame(2 * nf0,
                                     out_channels,
                                     kernel_size=3,
                                     bias=outermost_linear)]

        if not outermost_linear:
            if norm is not None:
                self.out_layer += [norm(out_channels, affine=True)]
            self.out_layer += [nn.ReLU(True)]
        self.out_layer = nn.Sequential(*self.out_layer)

    def forward(self, x):
        in_layer = self.in_layer(x)
        unet = self.unet_block(in_layer)
        out_layer = self.out_layer(unet)
        return out_layer


class UnetSkipConnectionBlock3d(nn.Module):
    '''Helper class for building a 3D unet.
    '''

    def __init__(self,
                 outer_nc,
                 inner_nc,
                 norm=nn.BatchNorm3d,
                 submodule=None):
        super().__init__()

        if submodule is None:
            model = [DownBlock3D(outer_nc, inner_nc, norm=norm),
                     UpBlock3D(inner_nc, outer_nc, norm=norm)]
        else:
            model = [DownBlock3D(outer_nc, inner_nc, norm=norm),
                     submodule,
                     UpBlock3D(2 * inner_nc, outer_nc, norm=norm)]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        forward_passed = self.model(x)
        return torch.cat([x, forward_passed], 1)


class UnetSkipConnectionBlock(nn.Module):
    '''Helper class for building a 2D unet.
    '''

    def __init__(self,
                 outer_nc,
                 inner_nc,
                 upsampling_mode,
                 norm=nn.BatchNorm2d,
                 submodule=None,
                 use_dropout=False,
                 dropout_prob=0.1):
        super().__init__()

        if submodule is None:
            model = [DownBlock(outer_nc, inner_nc, use_dropout=use_dropout, dropout_prob=dropout_prob, norm=norm),
                     UpBlock(inner_nc, outer_nc, use_dropout=use_dropout, dropout_prob=dropout_prob, norm=norm,
                             upsampling_mode=upsampling_mode)]
        else:
            model = [DownBlock(outer_nc, inner_nc, use_dropout=use_dropout, dropout_prob=dropout_prob, norm=norm),
                     submodule,
                     UpBlock(2 * inner_nc, outer_nc, use_dropout=use_dropout, dropout_prob=dropout_prob, norm=norm,
                             upsampling_mode=upsampling_mode)]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        forward_passed = self.model(x)
        return torch.cat([x, forward_passed], 1)


class Unet(nn.Module):
    '''A 2d-Unet implementation with sane defaults.
    '''

    def __init__(self,
                 in_channels,
                 out_channels,
                 nf0,
                 num_down,
                 max_channels,
                 use_dropout,
                 upsampling_mode='transpose',
                 dropout_prob=0.1,
                 norm=nn.BatchNorm2d,
                 outermost_linear=False):
        '''
        :param in_channels: Number of input channels
        :param out_channels: Number of output channels
        :param nf0: Number of features at highest level of U-Net
        :param num_down: Number of downsampling stages.
        :param max_channels: Maximum number of channels (channels multiply by 2 with every downsampling stage)
        :param use_dropout: Whether to use dropout or not.
        :param dropout_prob: Dropout probability if use_dropout=True.
        :param upsampling_mode: Which type of upsampling should be used. See "UpBlock" for documentation.
        :param norm: Which norm to use. If None, no norm is used. Default is BatchNorm with affine parameters.
        :param outermost_linear: Whether the output layer should be a linear layer or a nonlinear one.
        '''
        super().__init__()

        assert (num_down > 0), "Need at least one downsampling layer in UNet."

        # Define the in block
        self.in_layer = [Conv2dSame(in_channels, nf0, kernel_size=3, bias=True if norm is None else False)]
        if norm is not None:
            self.in_layer += [norm(nf0, affine=True)]
        self.in_layer += [nn.LeakyReLU(0.2, True)]

        if use_dropout:
            self.in_layer += [nn.Dropout2d(dropout_prob)]
        self.in_layer = nn.Sequential(*self.in_layer)

        # Define the center UNet block
        self.unet_block = UnetSkipConnectionBlock(min(2 ** (num_down - 1) * nf0, max_channels),
                                                  min(2 ** (num_down - 1) * nf0, max_channels),
                                                  use_dropout=use_dropout,
                                                  dropout_prob=dropout_prob,
                                                  norm=None,  # Innermost has no norm (spatial dimension 1)
                                                  upsampling_mode=upsampling_mode)

        for i in list(range(0, num_down - 1))[::-1]:
            self.unet_block = UnetSkipConnectionBlock(min(2 ** i * nf0, max_channels),
                                                      min(2 ** (i + 1) * nf0, max_channels),
                                                      use_dropout=use_dropout,
                                                      dropout_prob=dropout_prob,
                                                      submodule=self.unet_block,
                                                      norm=norm,
                                                      upsampling_mode=upsampling_mode)

        # Define the out layer. Each unet block concatenates its inputs with its outputs - so the output layer
        # automatically receives the output of the in_layer and the output of the last unet layer.
        self.out_layer = [Conv2dSame(2 * nf0,
                                     out_channels,
                                     kernel_size=3,
                                     bias=outermost_linear or (norm is None))]

        if not outermost_linear:
            if norm is not None:
                self.out_layer += [norm(out_channels, affine=True)]
            self.out_layer += [nn.ReLU(True)]

            if use_dropout:
                self.out_layer += [nn.Dropout2d(dropout_prob)]
        self.out_layer = nn.Sequential(*self.out_layer)

        self.out_layer_weight = self.out_layer[0].weight

    def forward(self, x):
        in_layer = self.in_layer(x)
        unet = self.unet_block(in_layer)
        out_layer = self.out_layer(unet)
        return out_layer


class Identity(nn.Module):
    '''Helper module to allow Downsampling and Upsampling nets to default to identity if they receive an empty list.'''

    def __init__(self):
        super().__init__()

    def forward(self, input):
        return input


class DownsamplingNet(nn.Module):
    '''A subnetwork that downsamples a 2D feature map with strided convolutions.
    '''

    def __init__(self,
                 per_layer_out_ch,
                 in_channels,
                 use_dropout,
                 dropout_prob=0.1,
                 last_layer_one=False,
                 norm=nn.BatchNorm2d):
        '''
        :param per_layer_out_ch: python list of integers. Defines the number of output channels per layer. Length of
                                 list defines number of downsampling steps (each step downsamples by a factor of 2).
        :param in_channels: Number of input channels.
        :param use_dropout: Whether or not to use dropout.
        :param dropout_prob: Dropout probability.
        :param last_layer_one: Whether the output of the last layer will have a spatial size of 1. In that case,
                               the last layer will not have a norm; otherwise, it will.
        :param norm: Which norm to use. Defaults to BatchNorm.
        '''
        super().__init__()

        if not len(per_layer_out_ch):
            self.downs = Identity()
        else:
            self.downs = list()
            self.downs.append(DownBlock(in_channels, per_layer_out_ch[0], use_dropout=use_dropout,
                                        dropout_prob=dropout_prob, middle_channels=per_layer_out_ch[0], norm=norm))
            for i in range(0, len(per_layer_out_ch) - 1):
                if last_layer_one and (i == len(per_layer_out_ch) - 2):
                    norm = None
                self.downs.append(DownBlock(per_layer_out_ch[i],
                                            per_layer_out_ch[i + 1],
                                            dropout_prob=dropout_prob,
                                            use_dropout=use_dropout,
                                            norm=norm))
            self.downs = nn.Sequential(*self.downs)

    def forward(self, input):
        return self.downs(input)


class UpsamplingNet(nn.Module):
    '''A subnetwork that upsamples a 2D feature map with a variety of upsampling options.
    '''

    def __init__(self,
                 per_layer_out_ch,
                 in_channels,
                 upsampling_mode,
                 use_dropout,
                 dropout_prob=0.1,
                 first_layer_one=False,
                 norm=nn.BatchNorm2d):
        '''
        :param per_layer_out_ch: python list of integers. Defines the number of output channels per layer. Length of
                                 list defines number of upsampling steps (each step upsamples by a factor of 2).
        :param in_channels: Number of input channels.
        :param upsampling_mode: Mode of upsampling. For documentation, see class "UpBlock"
        :param use_dropout: Whether or not to use dropout.
        :param dropout_prob: Dropout probability.
        :param first_layer_one: Whether the input to the first layer will have a spatial size of 1. In that case,
                                the first layer will not have a norm; otherwise, it will.
        :param norm: Which norm to use. Defaults to BatchNorm.
        '''
        super().__init__()

        if not len(per_layer_out_ch):
            self.ups = Identity()
        else:
            self.ups = list()
            self.ups.append(UpBlock(in_channels,
                                    per_layer_out_ch[0],
                                    use_dropout=use_dropout,
                                    dropout_prob=dropout_prob,
                                    norm=None if first_layer_one else norm,
                                    upsampling_mode=upsampling_mode))
            for i in range(0, len(per_layer_out_ch) - 1):
                self.ups.append(
                    UpBlock(per_layer_out_ch[i],
                            per_layer_out_ch[i + 1],
                            use_dropout=use_dropout,
                            dropout_prob=dropout_prob,
                            norm=norm,
                            upsampling_mode=upsampling_mode))
            self.ups = nn.Sequential(*self.ups)

    def forward(self, input):
        return self.ups(input)

--------------------------------------------------------------------------------
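For reference, here is a minimal usage sketch (not part of the repository) showing how the 2D U-Net and the Conv2dSame layer defined above might be instantiated. The channel counts, depth, and input size are arbitrary example values.

```python
# Usage sketch with hypothetical example values; the class names come from pytorch_prototyping.py above.
import torch

from pytorch_prototyping import Conv2dSame, Unet

# A 5-level 2D U-Net: 3-channel input, 3-channel output, 64 features at the top level,
# channels capped at 512. Spatial dimensions should be divisible by 2**num_down.
model = Unet(in_channels=3,
             out_channels=3,
             nf0=64,
             num_down=5,
             max_channels=512,
             use_dropout=False,
             upsampling_mode='transpose',  # or 'bilinear' / 'nearest' / 'shuffle'
             outermost_linear=True)

x = torch.randn(4, 3, 128, 128)
y = model(x)
print(y.shape)        # torch.Size([4, 3, 128, 128]) - spatial size is preserved end to end

# Conv2dSame: a 5x5 conv with reflection padding that keeps spatial dimensions constant.
conv = Conv2dSame(in_channels=3, out_channels=16, kernel_size=5)
print(conv(x).shape)  # torch.Size([4, 16, 128, 128])
```

The `upsampling_mode` argument selects between transpose convolutions and the upsample-then-convolve variants, which is the repository's main knob for avoiding checkerboard artifacts.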