├── loss ├── __init__.py ├── bceLoss.py ├── diceLoss.py └── msssimLoss.py ├── utils ├── __init__.py ├── data_vis.py ├── eval.py └── dataset.py ├── .gitignore ├── data ├── test │ ├── 00001.png │ ├── 00002.png │ └── 00003.png └── train │ ├── imgs │ ├── 00001.png │ ├── 00002.png │ └── 00003.png │ └── masks │ ├── 00001_matte.png │ ├── 00002_matte.png │ └── 00003_matte.png ├── unet ├── __init__.py ├── unet_model.py ├── init_weights.py ├── unet_parts.py ├── UNet.py ├── layers.py ├── UNet2Plus.py └── UNet3Plus.py ├── requirements.txt ├── README.md ├── predict.py └── train.py /loss/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | __pycache__ 3 | ckpts/ 4 | log/ 5 | LR.png 6 | runs/ 7 | -------------------------------------------------------------------------------- /data/test/00001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avBuffer/UNet3plus_pth/HEAD/data/test/00001.png -------------------------------------------------------------------------------- /data/test/00002.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avBuffer/UNet3plus_pth/HEAD/data/test/00002.png -------------------------------------------------------------------------------- /data/test/00003.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avBuffer/UNet3plus_pth/HEAD/data/test/00003.png -------------------------------------------------------------------------------- /data/train/imgs/00001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avBuffer/UNet3plus_pth/HEAD/data/train/imgs/00001.png -------------------------------------------------------------------------------- /data/train/imgs/00002.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avBuffer/UNet3plus_pth/HEAD/data/train/imgs/00002.png -------------------------------------------------------------------------------- /data/train/imgs/00003.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avBuffer/UNet3plus_pth/HEAD/data/train/imgs/00003.png -------------------------------------------------------------------------------- /data/train/masks/00001_matte.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avBuffer/UNet3plus_pth/HEAD/data/train/masks/00001_matte.png -------------------------------------------------------------------------------- /data/train/masks/00002_matte.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avBuffer/UNet3plus_pth/HEAD/data/train/masks/00002_matte.png -------------------------------------------------------------------------------- /data/train/masks/00003_matte.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avBuffer/UNet3plus_pth/HEAD/data/train/masks/00003_matte.png -------------------------------------------------------------------------------- /unet/__init__.py: -------------------------------------------------------------------------------- 1 | from .unet_model import UNet 2 | from .UNet2Plus import UNet2Plus 3 | from .UNet3Plus import UNet3Plus, UNet3Plus_DeepSup, UNet3Plus_DeepSup_CGM 4 | -------------------------------------------------------------------------------- /loss/bceLoss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def BCE_loss(pred,label): 6 | bce_loss = nn.BCELoss(size_average=True) 7 | bce_out = bce_loss(pred, label) 8 | return bce_out 9 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | future==0.18.2 2 | MarkupSafe==1.1.1 3 | matplotlib==3.1.3 4 | numpy==1.16.0 5 | Pillow==6.2.2 6 | protobuf==3.11.3 7 | pyparsing==2.4.6 8 | python-dateutil==2.8.1 9 | six==1.14.0 10 | tensorboard==1.14.0 11 | torch>=1.1.0 12 | torchvision>=0.3.0 13 | tqdm==4.42.1 14 | -------------------------------------------------------------------------------- /utils/data_vis.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | 3 | 4 | def plot_img_and_mask(img, mask): 5 | classes = mask.shape[2] if len(mask.shape) > 2 else 1 6 | fig, ax = plt.subplots(1, classes + 1) 7 | ax[0].set_title('Input image') 8 | ax[0].imshow(img) 9 | 10 | if classes > 1: 11 | for i in range(classes): 12 | ax[i+1].set_title(f'Output mask (class {i + 1})') 13 | ax[i+1].imshow(mask[:, :, i]) 14 | else: 15 | ax[1].set_title(f'Output mask') 16 | ax[1].imshow(mask) 17 | 18 | plt.xticks([]) 19 | plt.yticks([]) 20 | plt.show() 21 | -------------------------------------------------------------------------------- /unet/unet_model.py: -------------------------------------------------------------------------------- 1 | """ Full assembly of the parts to form the complete network """ 2 | import torch.nn.functional as F 3 | from .unet_parts import * 4 | 5 | 6 | class UNet(nn.Module): 7 | def __init__(self, n_channels, n_classes, bilinear=True): 8 | super(UNet, self).__init__() 9 | self.n_channels = n_channels 10 | self.n_classes = n_classes 11 | self.bilinear = bilinear 12 | 13 | self.inc = DoubleConv(n_channels, 64) 14 | self.down1 = Down(64, 128) 15 | self.down2 = Down(128, 256) 16 | self.down3 = Down(256, 512) 17 | self.down4 = Down(512, 512) 18 | 19 | self.up1 = Up(1024, 256, bilinear) 20 | self.up2 = Up(512, 128, bilinear) 21 | self.up3 = Up(256, 64, bilinear) 22 | self.up4 = Up(128, 64, bilinear) 23 | self.outc = OutConv(64, n_classes) 24 | 25 | 26 | def forward(self, x): 27 | x1 = self.inc(x) 28 | x2 = self.down1(x1) 29 | x3 = self.down2(x2) 30 | x4 = self.down3(x3) 31 | x5 = self.down4(x4) 32 | x = self.up1(x5, x4) 33 | x = self.up2(x, x3) 34 | x = self.up3(x, x2) 35 | x = self.up4(x, x1) 36 | 37 | logits = self.outc(x) 38 | return logits 39 | -------------------------------------------------------------------------------- /utils/eval.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from tqdm import tqdm 4 | from loss.diceLoss import dice_coeff 5 | from loss.bceLoss import BCE_loss 6 | 7 | 8 | def eval_net(net, loader, device, n_val): 9 | """Evaluation without the densecrf with the dice coefficient""" 10 | net.eval() 11 | tot = 0 12 | 13 | with tqdm(total=n_val, desc='Validation round', unit='img', leave=False) as pbar: 14 | for batch in loader: 15 | imgs = batch['image'] 16 | true_masks = batch['mask'] 17 | 18 | imgs = imgs.to(device=device, dtype=torch.float32) 19 | mask_type = torch.float32 if net.n_classes == 1 else torch.long 20 | true_masks = true_masks.to(device=device, dtype=mask_type) 21 | mask_pred = net(imgs) 22 | 23 | for true_mask, pred in zip(true_masks, mask_pred): 24 | pred = (pred > 0.5).float() 25 | if net.n_classes > 1: 26 | tot += F.cross_entropy(pred.unsqueeze(dim=0), true_mask.unsqueeze(dim=0)).item() 27 | else: 28 | tot += dice_coeff(pred, true_mask.squeeze(dim=1)).item() 29 | 30 | pbar.update(imgs.shape[0]) 31 | return tot / n_val 32 | -------------------------------------------------------------------------------- /loss/diceLoss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | 4 | 5 | class DiceCoeff(Function): 6 | """Dice coeff for individual examples""" 7 | def forward(self, input, target): 8 | self.save_for_backward(input, target) 9 | eps = 0.0001 10 | self.inter = torch.dot(input.view(-1), target.view(-1)) 11 | self.union = torch.sum(input) + torch.sum(target) + eps 12 | t = (2 * self.inter.float() + eps) / self.union.float() 13 | return t 14 | 15 | 16 | # This function has only a single output, so it gets only one gradient 17 | def backward(self, grad_output): 18 | input, target = self.saved_variables 19 | grad_input = grad_target = None 20 | 21 | if self.needs_input_grad[0]: 22 | grad_input = grad_output * 2 * (target * self.union - self.inter) / (self.union * self.union) 23 | 24 | if self.needs_input_grad[1]: 25 | grad_target = None 26 | return grad_input, grad_target 27 | 28 | 29 | def dice_coeff(input, target): 30 | """Dice coeff for batches""" 31 | if input.is_cuda: 32 | s = torch.FloatTensor(1).cuda().zero_() 33 | else: 34 | s = torch.FloatTensor(1).zero_() 35 | 36 | for i, c in enumerate(zip(input, target)): 37 | s = s + DiceCoeff().forward(c[0], c[1]) 38 | return s / (i + 1) 39 | -------------------------------------------------------------------------------- /unet/init_weights.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import init 4 | 5 | 6 | def weights_init_normal(m): 7 | classname = m.__class__.__name__ 8 | if classname.find('Conv') != -1: 9 | init.normal_(m.weight.data, 0.0, 0.02) 10 | elif classname.find('Linear') != -1: 11 | init.normal_(m.weight.data, 0.0, 0.02) 12 | elif classname.find('BatchNorm') != -1: 13 | init.normal_(m.weight.data, 1.0, 0.02) 14 | init.constant_(m.bias.data, 0.0) 15 | 16 | 17 | def weights_init_xavier(m): 18 | classname = m.__class__.__name__ 19 | if classname.find('Conv') != -1: 20 | init.xavier_normal_(m.weight.data, gain=1) 21 | elif classname.find('Linear') != -1: 22 | init.xavier_normal_(m.weight.data, gain=1) 23 | elif classname.find('BatchNorm') != -1: 24 | init.normal_(m.weight.data, 1.0, 0.02) 25 | init.constant_(m.bias.data, 0.0) 26 | 27 | 28 | def weights_init_kaiming(m): 29 | classname = m.__class__.__name__ 30 | if classname.find('Conv') != -1: 31 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 32 | elif classname.find('Linear') != -1: 33 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 34 | elif classname.find('BatchNorm') != -1: 35 | init.normal_(m.weight.data, 1.0, 0.02) 36 | init.constant_(m.bias.data, 0.0) 37 | 38 | 39 | def weights_init_orthogonal(m): 40 | classname = m.__class__.__name__ 41 | if classname.find('Conv') != -1: 42 | init.orthogonal_(m.weight.data, gain=1) 43 | elif classname.find('Linear') != -1: 44 | init.orthogonal_(m.weight.data, gain=1) 45 | elif classname.find('BatchNorm') != -1: 46 | init.normal_(m.weight.data, 1.0, 0.02) 47 | init.constant_(m.bias.data, 0.0) 48 | 49 | 50 | def init_weights(net, init_type='normal'): 51 | if init_type == 'normal': 52 | net.apply(weights_init_normal) 53 | elif init_type == 'xavier': 54 | net.apply(weights_init_xavier) 55 | elif init_type == 'kaiming': 56 | net.apply(weights_init_kaiming) 57 | elif init_type == 'orthogonal': 58 | net.apply(weights_init_orthogonal) 59 | else: 60 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 61 | -------------------------------------------------------------------------------- /utils/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import logging 4 | import torch 5 | import numpy as np 6 | 7 | from os.path import splitext 8 | from os import listdir 9 | from glob import glob 10 | from torch.utils.data import Dataset 11 | from PIL import Image 12 | 13 | 14 | class BasicDataset(Dataset): 15 | def __init__(self, unet_type, imgs_dir, masks_dir, scale=1): 16 | self.unet_type = unet_type 17 | self.imgs_dir = imgs_dir 18 | self.masks_dir = masks_dir 19 | self.scale = scale 20 | 21 | assert 0 < scale <= 1, 'Scale must be between 0 and 1' 22 | self.ids = [splitext(file)[0] for file in listdir(imgs_dir) if not file.startswith('.')] 23 | logging.info(f'Creating dataset with {len(self.ids)} examples') 24 | 25 | 26 | def __len__(self): 27 | return len(self.ids) 28 | 29 | 30 | @classmethod 31 | def preprocess(cls, unet_type, pil_img, scale): 32 | w, h = pil_img.size 33 | newW, newH = int(scale * w), int(scale * h) 34 | assert newW > 0 and newH > 0, 'Scale is too small' 35 | 36 | if unet_type != 'v3': 37 | pil_img = pil_img.resize((newW, newH)) 38 | else: 39 | new_size = int(scale * 640) 40 | pil_img = pil_img.resize((new_size, new_size)) 41 | 42 | img_nd = np.array(pil_img) 43 | if len(img_nd.shape) == 2: 44 | img_nd = np.expand_dims(img_nd, axis=2) 45 | 46 | # HWC to CHW 47 | img_trans = img_nd.transpose((2, 0, 1)) 48 | if img_trans.max() > 1: 49 | img_trans = img_trans / 255 50 | return img_trans 51 | 52 | 53 | def __getitem__(self, i): 54 | idx = self.ids[i] 55 | mask_file = glob(self.masks_dir + idx + '*') 56 | img_file = glob(self.imgs_dir + idx + '*') 57 | 58 | assert len(mask_file) == 1, f'Either no mask or multiple masks found for the ID {idx}: {mask_file}' 59 | assert len(img_file) == 1, f'Either no image or multiple images found for the ID {idx}: {img_file}' 60 | mask = Image.open(mask_file[0]) 61 | img = Image.open(img_file[0]) 62 | 63 | assert img.size == mask.size, f'Image and mask {idx} should be the same size, but are {img.size} and {mask.size}' 64 | img = self.preprocess(self.unet_type, img, self.scale) 65 | mask = self.preprocess(self.unet_type, mask, self.scale) 66 | 67 | return {'image': torch.from_numpy(img), 'mask': torch.from_numpy(mask)} 68 | -------------------------------------------------------------------------------- /unet/unet_parts.py: -------------------------------------------------------------------------------- 1 | """ Parts of the U-Net model """ 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class DoubleConv(nn.Module): 8 | """(convolution => [BN] => ReLU) * 2""" 9 | def __init__(self, in_channels, out_channels): 10 | super().__init__() 11 | self.double_conv = nn.Sequential( 12 | nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), 13 | nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True), 14 | nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), 15 | nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True)) 16 | 17 | 18 | def forward(self, x): 19 | return self.double_conv(x) 20 | 21 | 22 | class Down(nn.Module): 23 | """Downscaling with maxpool then double conv""" 24 | def __init__(self, in_channels, out_channels): 25 | super().__init__() 26 | self.maxpool_conv = nn.Sequential(nn.MaxPool2d(2), DoubleConv(in_channels, out_channels)) 27 | 28 | 29 | def forward(self, x): 30 | return self.maxpool_conv(x) 31 | 32 | 33 | class Up(nn.Module): 34 | """Upscaling then double conv""" 35 | def __init__(self, in_channels, out_channels, bilinear=True): 36 | super().__init__() 37 | # if bilinear, use the normal convolutions to reduce the number of channels 38 | if bilinear: 39 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 40 | else: 41 | self.up = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, kernel_size=2, stride=2) 42 | self.conv = DoubleConv(in_channels, out_channels) 43 | 44 | 45 | def forward(self, x1, x2): 46 | x1 = self.up(x1) 47 | # input is CHW 48 | diffY = torch.tensor([x2.size()[2] - x1.size()[2]]) 49 | diffX = torch.tensor([x2.size()[3] - x1.size()[3]]) 50 | 51 | x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2]) 52 | # if you have padding issues, see 53 | # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a 54 | # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd 55 | x = torch.cat([x2, x1], dim=1) 56 | return self.conv(x) 57 | 58 | 59 | class OutConv(nn.Module): 60 | def __init__(self, in_channels, out_channels): 61 | super(OutConv, self).__init__() 62 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) 63 | 64 | 65 | def forward(self, x): 66 | return self.conv(x) 67 | -------------------------------------------------------------------------------- /unet/UNet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from .layers import unetConv2, unetUp, unetUp_origin 7 | from .init_weights import init_weights 8 | from torchvision import models 9 | 10 | 11 | class UNet(nn.Module): 12 | def __init__(self, n_channels=3, n_classes=1, bilinear=True, feature_scale=4, 13 | is_deconv=True, is_batchnorm=True): 14 | super(UNet, self).__init__() 15 | self.n_channels = n_channels 16 | self.n_classes = n_classes 17 | self.bilinear = bilinear 18 | self.feature_scale = feature_scale 19 | self.is_deconv = is_deconv 20 | self.is_batchnorm = is_batchnorm 21 | filters = [64, 128, 256, 512, 1024] 22 | 23 | # downsampling 24 | self.conv1 = unetConv2(self.n_channels, filters[0], self.is_batchnorm) 25 | self.maxpool1 = nn.MaxPool2d(kernel_size=2) 26 | 27 | self.conv2 = unetConv2(filters[0], filters[1], self.is_batchnorm) 28 | self.maxpool2 = nn.MaxPool2d(kernel_size=2) 29 | 30 | self.conv3 = unetConv2(filters[1], filters[2], self.is_batchnorm) 31 | self.maxpool3 = nn.MaxPool2d(kernel_size=2) 32 | 33 | self.conv4 = unetConv2(filters[2], filters[3], self.is_batchnorm) 34 | self.maxpool4 = nn.MaxPool2d(kernel_size=2) 35 | 36 | self.center = unetConv2(filters[3], filters[4], self.is_batchnorm) 37 | 38 | # upsampling 39 | self.up_concat4 = unetUp(filters[4], filters[3], self.is_deconv) 40 | self.up_concat3 = unetUp(filters[3], filters[2], self.is_deconv) 41 | self.up_concat2 = unetUp(filters[2], filters[1], self.is_deconv) 42 | self.up_concat1 = unetUp(filters[1], filters[0], self.is_deconv) 43 | self.outconv1 = nn.Conv2d(filters[0], 1, 3, padding=1) 44 | 45 | # initialise weights 46 | for m in self.modules(): 47 | if isinstance(m, nn.Conv2d): 48 | init_weights(m, init_type='kaiming') 49 | elif isinstance(m, nn.BatchNorm2d): 50 | init_weights(m, init_type='kaiming') 51 | 52 | 53 | def dotProduct(self,seg,cls): 54 | B, N, H, W = seg.size() 55 | seg = seg.view(B, N, H * W) 56 | final = torch.einsum("ijk,ij->ijk", [seg, cls]) 57 | final = final.view(B, N, H, W) 58 | return final 59 | 60 | 61 | def forward(self, inputs): 62 | conv1 = self.conv1(inputs) # 16*512*1024 63 | maxpool1 = self.maxpool1(conv1) # 16*256*512 64 | 65 | conv2 = self.conv2(maxpool1) # 32*256*512 66 | maxpool2 = self.maxpool2(conv2) # 32*128*256 67 | 68 | conv3 = self.conv3(maxpool2) # 64*128*256 69 | maxpool3 = self.maxpool3(conv3) # 64*64*128 70 | 71 | conv4 = self.conv4(maxpool3) # 128*64*128 72 | maxpool4 = self.maxpool4(conv4) # 128*32*64 73 | 74 | center = self.center(maxpool4) # 256*32*64 75 | 76 | up4 = self.up_concat4(center, conv4) # 128*64*128 77 | up3 = self.up_concat3(up4, conv3) # 64*128*256 78 | up2 = self.up_concat2(up3, conv2) # 32*256*512 79 | up1 = self.up_concat1(up2, conv1) # 16*512*1024 80 | 81 | d1 = self.outconv1(up1) # 256 82 | return F.sigmoid(d1) 83 | -------------------------------------------------------------------------------- /unet/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .init_weights import init_weights 5 | 6 | 7 | class unetConv2(nn.Module): 8 | def __init__(self, in_size, out_size, is_batchnorm, n=2, ks=3, stride=1, padding=1): 9 | super(unetConv2, self).__init__() 10 | self.n = n 11 | self.ks = ks 12 | self.stride = stride 13 | self.padding = padding 14 | s = stride 15 | p = padding 16 | 17 | if is_batchnorm: 18 | for i in range(1, n + 1): 19 | conv = nn.Sequential(nn.Conv2d(in_size, out_size, ks, s, p), 20 | nn.BatchNorm2d(out_size), nn.ReLU(inplace=True),) 21 | setattr(self, 'conv%d' % i, conv) 22 | in_size = out_size 23 | else: 24 | for i in range(1, n + 1): 25 | conv = nn.Sequential(nn.Conv2d(in_size, out_size, ks, s, p), nn.ReLU(inplace=True), ) 26 | setattr(self, 'conv%d' % i, conv) 27 | in_size = out_size 28 | 29 | # initialise the blocks 30 | for m in self.children(): 31 | init_weights(m, init_type='kaiming') 32 | 33 | 34 | def forward(self, inputs): 35 | x = inputs 36 | for i in range(1, self.n + 1): 37 | conv = getattr(self, 'conv%d' % i) 38 | x = conv(x) 39 | return x 40 | 41 | 42 | class unetUp(nn.Module): 43 | def __init__(self, in_size, out_size, is_deconv, n_concat=2): 44 | super(unetUp, self).__init__() 45 | self.conv = unetConv2(out_size * 2, out_size, False) 46 | 47 | if is_deconv: 48 | self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=4, stride=2, padding=1) 49 | else: 50 | self.up = nn.UpsamplingBilinear2d(scale_factor=2) 51 | 52 | # initialise the blocks 53 | for m in self.children(): 54 | if m.__class__.__name__.find('unetConv2') != -1: continue 55 | init_weights(m, init_type='kaiming') 56 | 57 | 58 | def forward(self, inputs0, *input): 59 | outputs0 = self.up(inputs0) 60 | for i in range(len(input)): 61 | outputs0 = torch.cat([outputs0, input[i]], 1) 62 | return self.conv(outputs0) 63 | 64 | 65 | class unetUp_origin(nn.Module): 66 | def __init__(self, in_size, out_size, is_deconv, n_concat=2): 67 | super(unetUp_origin, self).__init__() 68 | # self.conv = unetConv2(out_size*2, out_size, False) 69 | if is_deconv: 70 | self.conv = unetConv2(in_size + (n_concat - 2) * out_size, out_size, False) 71 | self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=4, stride=2, padding=1) 72 | else: 73 | self.conv = unetConv2(in_size + (n_concat - 2) * out_size, out_size, False) 74 | self.up = nn.UpsamplingBilinear2d(scale_factor=2) 75 | 76 | # initialise the blocks 77 | for m in self.children(): 78 | if m.__class__.__name__.find('unetConv2') != -1: continue 79 | init_weights(m, init_type='kaiming') 80 | 81 | def forward(self, inputs0, *input): 82 | # print(self.n_concat) 83 | # print(input) 84 | outputs0 = self.up(inputs0) 85 | for i in range(len(input)): 86 | outputs0 = torch.cat([outputs0, input[i]], 1) 87 | return self.conv(outputs0) 88 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # UNet3plus_pth 2 | UNet3+/UNet++/UNet, used in Deep Automatic Portrait Matting in Pytorth 3 | 4 | ## Dependencies 5 | 6 | - Python 3.6 7 | - PyTorch >= 1.1.0 8 | - Torchvision >= 0.3.0 9 | - future 0.18.2 10 | - matplotlib 3.1.3 11 | - numpy 1.16.0 12 | - Pillow 6.2.0 13 | - protobuf 3.11.3 14 | - tensorboard 1.14.0 15 | - tqdm==4.42.1 16 | 17 | ## Data 18 | This model was trained from scratch with 18000 images (data augmentation by 2000images) 19 | Training dataset was from [Deep Automatic Portrait Matting](http://www.cse.cuhk.edu.hk/leojia/projects/automatting/index.html). 20 | Your can download in baidu cloud [http://pan.baidu.com/s/1dE14537](http://pan.baidu.com/s/1dE14537). Password: ndg8 21 | **For academic communication only, if there is a quote, please inform the original author!** 22 | 23 | We augment the number of images by perturbing them withrotation and scaling. Four rotation angles{−45◦,−22◦,22◦,45◦}and four scales{0.6,0.8,1.2,1.5}are used. We also apply four different Gamma transforms toincrease color variation. The Gamma values are{0.5,0.8,1.2,1.5}. After thesetransforms, we have 18K training images. 24 | 25 | ## Run locally 26 | **Note : Use Python 3** 27 | 28 | ### Training 29 | 30 | ```shell script 31 | > python train.py -h 32 | usage: train.py [-h] [-g G] [-u U] [-e E] [-b [B]] [-l [LR]] [-f LOAD] [-s SCALE] [-v VAL] 33 | 34 | Train the UNet on images and target masks 35 | 36 | optional arguments: 37 | -h, --help show this help message and exit 38 | -g G, --gpu_id Number of gpu 39 | -u U, --unet\_type UNet type is unet/unet2/unet3 40 | -e E, --epochs E Number of epochs (default: 5) 41 | -b [B], --batch-size [B] 42 | Batch size (default: 1) 43 | -l [LR], --learning-rate [LR] 44 | Learning rate (default: 0.1) 45 | -f LOAD, --load LOAD Load model from a .pth file (default: False) 46 | -s SCALE, --scale SCALE 47 | Downscaling factor of the images (default: 0.5) 48 | -v VAL, --validation VAL 49 | Percent of the data that is used as validation (0-100) 50 | (default: 10.0) 51 | 52 | ``` 53 | By default, the `scale` is 0.5, so if you wish to obtain better results (but use more memory), set it to 1. 54 | 55 | The input images and target masks should be in the `data/imgs` and `data/masks` folders respectively. 56 | 57 | ### Notes on memory 58 | ```bash 59 | $ python train.py -g 0 -u v3 -e 200 -b 1 -l 0.1 -s 0.5 -v 15.0 60 | ``` 61 | 62 | ### Prediction 63 | 64 | You can easily test the output masks on your images via the CLI. 65 | 66 | To predict a single image and save it: 67 | 68 | ```bash 69 | $ python predict.py -i image.jpg -o output.jpg 70 | ``` 71 | 72 | To predict a multiple images and show them without saving them: 73 | 74 | ```bash 75 | $ python predict.py -i image1.jpg image2.jpg --viz --no-save 76 | ``` 77 | 78 | ```shell script 79 | 80 | > python predict.py -h 81 | usage: predict.py [-h] [--gpu_id 0] [--unet\_type unet/unet2/unet3] [--model FILE] --input INPUT [INPUT ...] [--output INPUT [INPUT ...]] [--viz] [--no-save] [--mask-threshold MASK_THRESHOLD] [--scale SCALE] 82 | 83 | Predict masks from input images 84 | 85 | optional arguments: 86 | -h, --help show this help message and exit 87 | -g G, --gpu_id Number of gpu 88 | --unet\_type, -u U UNet type is unet/unet2/unet3 89 | --model FILE, -m FILE 90 | Specify the file in which the model is stored 91 | (default: MODEL.pth) 92 | --input INPUT [INPUT ...], -i INPUT [INPUT ...] 93 | filenames of input images (default: None) 94 | --output INPUT [INPUT ...], -o INPUT [INPUT ...] 95 | Filenames of ouput images (default: None) 96 | --viz, -v Visualize the images as they are processed (default: 97 | False) 98 | --no-save, -n Do not save the output masks (default: False) 99 | --mask-threshold MASK_THRESHOLD, -t MASK_THRESHOLD 100 | Minimum probability value to consider a mask pixel 101 | white (default: 0.5) 102 | --scale SCALE, -s SCALE 103 | Scale factor for the input images (default: 0.5) 104 | ``` 105 | 106 | ## Reference 107 | 108 | [[2015] U-Net: Convolutional Networks for Biomedical Image Segmentation (MICCAI)](https://arxiv.org/pdf/1505.04597.pdf) 109 | 110 | [[2018] UNet++: A Nested U-Net Architecture for Medical Image Segmentation (MICCAI)](https://arxiv.org/pdf/1807.10165.pdf) 111 | 112 | [[2020] UNET 3+: A Full-Scale Connected UNet for Medical Image Segmentation (ICASSP 2020)](https://arxiv.org/pdf/2004.08790.pdf) -------------------------------------------------------------------------------- /unet/UNet2Plus.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from .layers import unetConv2, unetUp_origin 7 | from .init_weights import init_weights 8 | from torchvision import models 9 | 10 | 11 | class UNet2Plus(nn.Module): 12 | def __init__(self, n_channels=3, n_classes=1, bilinear=True, feature_scale=4, 13 | is_deconv=True, is_batchnorm=True, is_ds=True): 14 | super(UNet_2Plus, self).__init__() 15 | self.n_channels = n_channels 16 | self.n_classes = n_classes 17 | self.bilinear = bilinear 18 | self.feature_scale = feature_scale 19 | self.is_deconv = is_deconv 20 | self.is_batchnorm = is_batchnorm 21 | self.is_ds = is_ds 22 | filters = [64, 128, 256, 512, 1024] 23 | 24 | # downsampling 25 | self.conv00 = unetConv2(self.n_channels, filters[0], self.is_batchnorm) 26 | self.maxpool0 = nn.MaxPool2d(kernel_size=2) 27 | self.conv10 = unetConv2(filters[0], filters[1], self.is_batchnorm) 28 | self.maxpool1 = nn.MaxPool2d(kernel_size=2) 29 | self.conv20 = unetConv2(filters[1], filters[2], self.is_batchnorm) 30 | self.maxpool2 = nn.MaxPool2d(kernel_size=2) 31 | self.conv30 = unetConv2(filters[2], filters[3], self.is_batchnorm) 32 | self.maxpool3 = nn.MaxPool2d(kernel_size=2) 33 | self.conv40 = unetConv2(filters[3], filters[4], self.is_batchnorm) 34 | 35 | # upsampling 36 | self.up_concat01 = unetUp_origin(filters[1], filters[0], self.is_deconv) 37 | self.up_concat11 = unetUp_origin(filters[2], filters[1], self.is_deconv) 38 | self.up_concat21 = unetUp_origin(filters[3], filters[2], self.is_deconv) 39 | self.up_concat31 = unetUp_origin(filters[4], filters[3], self.is_deconv) 40 | 41 | self.up_concat02 = unetUp_origin(filters[1], filters[0], self.is_deconv, 3) 42 | self.up_concat12 = unetUp_origin(filters[2], filters[1], self.is_deconv, 3) 43 | self.up_concat22 = unetUp_origin(filters[3], filters[2], self.is_deconv, 3) 44 | 45 | self.up_concat03 = unetUp_origin(filters[1], filters[0], self.is_deconv, 4) 46 | self.up_concat13 = unetUp_origin(filters[2], filters[1], self.is_deconv, 4) 47 | 48 | self.up_concat04 = unetUp_origin(filters[1], filters[0], self.is_deconv, 5) 49 | 50 | # final conv (without any concat) 51 | self.final_1 = nn.Conv2d(filters[0], n_classes, 1) 52 | self.final_2 = nn.Conv2d(filters[0], n_classes, 1) 53 | self.final_3 = nn.Conv2d(filters[0], n_classes, 1) 54 | self.final_4 = nn.Conv2d(filters[0], n_classes, 1) 55 | 56 | # initialise weights 57 | for m in self.modules(): 58 | if isinstance(m, nn.Conv2d): 59 | init_weights(m, init_type='kaiming') 60 | elif isinstance(m, nn.BatchNorm2d): 61 | init_weights(m, init_type='kaiming') 62 | 63 | 64 | def forward(self, inputs): 65 | # column : 0 66 | X_00 = self.conv00(inputs) 67 | maxpool0 = self.maxpool0(X_00) 68 | X_10 = self.conv10(maxpool0) 69 | maxpool1 = self.maxpool1(X_10) 70 | X_20 = self.conv20(maxpool1) 71 | maxpool2 = self.maxpool2(X_20) 72 | X_30 = self.conv30(maxpool2) 73 | maxpool3 = self.maxpool3(X_30) 74 | X_40 = self.conv40(maxpool3) 75 | 76 | # column : 1 77 | X_01 = self.up_concat01(X_10, X_00) 78 | X_11 = self.up_concat11(X_20, X_10) 79 | X_21 = self.up_concat21(X_30, X_20) 80 | X_31 = self.up_concat31(X_40, X_30) 81 | 82 | # column : 2 83 | X_02 = self.up_concat02(X_11, X_00, X_01) 84 | X_12 = self.up_concat12(X_21, X_10, X_11) 85 | X_22 = self.up_concat22(X_31, X_20, X_21) 86 | 87 | # column : 3 88 | X_03 = self.up_concat03(X_12, X_00, X_01, X_02) 89 | X_13 = self.up_concat13(X_22, X_10, X_11, X_12) 90 | 91 | # column : 4 92 | X_04 = self.up_concat04(X_13, X_00, X_01, X_02, X_03) 93 | 94 | # final layer 95 | final_1 = self.final_1(X_01) 96 | final_2 = self.final_2(X_02) 97 | final_3 = self.final_3(X_03) 98 | final_4 = self.final_4(X_04) 99 | final = (final_1 + final_2 + final_3 + final_4) / 4 100 | 101 | if self.is_ds: 102 | return F.sigmoid(final) 103 | else: 104 | return F.sigmoid(final_4) 105 | 106 | 107 | if __name__ == '__main__': 108 | model = UNet2Plus() 109 | print('# generator parameters:', 1.0 * sum(param.numel() for param in model.parameters()) / 1000000) 110 | 111 | params = list(model.named_parameters()) 112 | for i in range(len(params)): 113 | name, param = params[i] 114 | print('name:', name, ' param.shape:', param.shape) 115 | -------------------------------------------------------------------------------- /loss/msssimLoss.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | from math import exp 5 | 6 | 7 | def gaussian(window_size, sigma): 8 | gauss = torch.Tensor([exp(-(x - window_size // 2)**2 / float(2 * sigma**2)) for x in range(window_size)]) 9 | return gauss/gauss.sum() 10 | 11 | 12 | def create_window(window_size, channel=1): 13 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 14 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 15 | window = _2D_window.expand(channel, 1, window_size, window_size).contiguous() 16 | return window 17 | 18 | 19 | def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None): 20 | # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh). 21 | if val_range is None: 22 | if torch.max(img1) > 128: 23 | max_val = 255 24 | else: 25 | max_val = 1 26 | 27 | if torch.min(img1) < -0.5: 28 | min_val = -1 29 | else: 30 | min_val = 0 31 | L = max_val - min_val 32 | else: 33 | L = val_range 34 | 35 | padd = 0 36 | _, channel, height, width = img1.size() 37 | if window is None: 38 | real_size = min(window_size, height, width) 39 | window = create_window(real_size, channel=channel).to(img1.device) 40 | 41 | mu1 = F.conv2d(img1, window, padding=padd, groups=channel) 42 | mu2 = F.conv2d(img2, window, padding=padd, groups=channel) 43 | 44 | mu1_sq = mu1.pow(2) 45 | mu2_sq = mu2.pow(2) 46 | mu1_mu2 = mu1 * mu2 47 | 48 | sigma1_sq = F.conv2d(img1 * img1, window, padding=padd, groups=channel) - mu1_sq 49 | sigma2_sq = F.conv2d(img2 * img2, window, padding=padd, groups=channel) - mu2_sq 50 | sigma12 = F.conv2d(img1 * img2, window, padding=padd, groups=channel) - mu1_mu2 51 | 52 | C1 = (0.01 * L) ** 2 53 | C2 = (0.03 * L) ** 2 54 | 55 | v1 = 2.0 * sigma12 + C2 56 | v2 = sigma1_sq + sigma2_sq + C2 57 | cs = torch.mean(v1 / v2) # contrast sensitivity 58 | ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2) 59 | 60 | if size_average: 61 | ret = ssim_map.mean() 62 | else: 63 | ret = ssim_map.mean(1).mean(1).mean(1) 64 | 65 | if full: 66 | return ret, cs 67 | return ret 68 | 69 | 70 | def msssim(img1, img2, window_size=11, size_average=True, val_range=None, normalize=False): 71 | device = img1.device 72 | weights = torch.FloatTensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]).to(device) 73 | levels = weights.size()[0] 74 | mssim = [] 75 | mcs = [] 76 | 77 | for _ in range(levels): 78 | sim, cs = ssim(img1, img2, window_size=window_size, size_average=size_average, full=True, val_range=val_range) 79 | mssim.append(sim) 80 | mcs.append(cs) 81 | img1 = F.avg_pool2d(img1, (2, 2)) 82 | img2 = F.avg_pool2d(img2, (2, 2)) 83 | 84 | mssim = torch.stack(mssim) 85 | mcs = torch.stack(mcs) 86 | 87 | # Normalize (to avoid NaNs during training unstable models, not compliant with original definition) 88 | if normalize: 89 | mssim = (mssim + 1) / 2 90 | mcs = (mcs + 1) / 2 91 | 92 | pow1 = mcs ** weights 93 | pow2 = mssim ** weights 94 | 95 | # From Matlab implementation https://ece.uwaterloo.ca/~z70wang/research/iwssim/ 96 | output = torch.prod(pow1[:-1] * pow2[-1]) 97 | return output 98 | 99 | 100 | # Classes to re-use window 101 | class SSIM(torch.nn.Module): 102 | def __init__(self, window_size=11, size_average=True, val_range=None): 103 | super(SSIM, self).__init__() 104 | self.window_size = window_size 105 | self.size_average = size_average 106 | self.val_range = val_range 107 | # Assume 1 channel for SSIM 108 | self.channel = 1 109 | self.window = create_window(window_size) 110 | 111 | 112 | def forward(self, img1, img2): 113 | _, channel, _, _ = img1.size() 114 | 115 | if channel == self.channel and self.window.dtype == img1.dtype: 116 | window = self.window 117 | else: 118 | window = create_window(self.window_size, channel).to(img1.device).type(img1.dtype) 119 | self.window = window 120 | self.channel = channel 121 | return ssim(img1, img2, window=window, window_size=self.window_size, size_average=self.size_average) 122 | 123 | 124 | class MSSSIM(torch.nn.Module): 125 | def __init__(self, window_size=11, size_average=True, channel=3): 126 | super(MSSSIM, self).__init__() 127 | self.window_size = window_size 128 | self.size_average = size_average 129 | self.channel = channel 130 | 131 | 132 | def forward(self, img1, img2): 133 | # TODO: store window between calls if possible 134 | return msssim(img1, img2, window_size=self.window_size, size_average=self.size_average, normalize=True) 135 | -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | import torchsummary as summary 8 | 9 | from PIL import Image 10 | from torchvision import transforms 11 | from unet import UNet 12 | from unet import UNet2Plus 13 | from unet import UNet3Plus, UNet3Plus_DeepSup, UNet3Plus_DeepSup_CGM 14 | from utils.data_vis import plot_img_and_mask 15 | from utils.dataset import BasicDataset 16 | 17 | 18 | def predict_img(unet_type, net, full_img, device, scale_factor=1, out_threshold=0.5): 19 | net.eval() 20 | img = torch.from_numpy(BasicDataset.preprocess(unet_type, full_img, scale_factor)) 21 | img = img.unsqueeze(0) 22 | img = img.to(device=device, dtype=torch.float32) 23 | 24 | with torch.no_grad(): 25 | output = net(img) 26 | if net.n_classes > 1: 27 | probs = F.softmax(output, dim=1) 28 | else: 29 | probs = torch.sigmoid(output) 30 | 31 | probs = probs.squeeze(0) 32 | tf = transforms.Compose([transforms.ToPILImage(), transforms.Resize(full_img.size[1]), 33 | transforms.ToTensor()]) 34 | probs = tf(probs.cpu()) 35 | full_mask = probs.squeeze().cpu().numpy() 36 | return full_mask > out_threshold 37 | 38 | 39 | def get_args(): 40 | parser = argparse.ArgumentParser(description='Predict masks from input images', 41 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 42 | parser.add_argument('--gpu_id', '-g', metavar='G', type=int, default=0, help='Number of gpu') 43 | parser.add_argument('--unet_type', '-u', metavar='U', default='v3', help='UNet type is v1/v2/v3 (unet unet++ unet3+)') 44 | parser.add_argument('--model', '-m', default='MODEL.pth', metavar='FILE', help='Specify the file in which the model is stored') 45 | 46 | parser.add_argument('--input', '-i', metavar='INPUT', nargs='+', help='filenames of input images', required=True) 47 | parser.add_argument('--output', '-o', metavar='OUTPUT', nargs='+', help='Filenames of ouput images') 48 | 49 | parser.add_argument('--viz', '-v', action='store_true', help='Visualize the images as they are processed', default=False) 50 | parser.add_argument('--no-save', '-n', action='store_true', help='Do not save the output masks', default=False) 51 | 52 | parser.add_argument('--mask-threshold', '-t', type=float, help='Minimum probability value to consider a mask pixel white', default=0.5) 53 | parser.add_argument('--scale', '-s', type=float, help='Scale factor for the input images', default=0.5) 54 | return parser.parse_args() 55 | 56 | 57 | def get_output_filenames(args): 58 | in_files = args.input 59 | out_files = [] 60 | 61 | if not args.output: 62 | for f in in_files: 63 | pathsplit = os.path.splitext(f) 64 | out_files.append('{}_OUT{}'.format(pathsplit[0], pathsplit[1])) 65 | elif len(in_files) != len(args.output): 66 | logging.error('Input files and output files are not of the same length') 67 | raise SystemExit() 68 | else: 69 | out_files = args.output 70 | return out_files 71 | 72 | 73 | def mask_to_image(mask): 74 | return Image.fromarray((mask * 255).astype(np.uint8)) 75 | 76 | 77 | if __name__ == '__main__': 78 | args = get_args() 79 | gpu_id = args.gpu_id 80 | unet_type = args.unet_type 81 | in_files = args.input 82 | out_files = get_output_filenames(args) 83 | 84 | os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id) 85 | if unet_type == 'v2': 86 | net = UNet2Plus(n_channels=3, n_classes=1) 87 | elif unet_type == 'v3': 88 | net = UNet3Plus(n_channels=3, n_classes=1) 89 | #net = UNet3Plus_DeepSup(n_channels=3, n_classes=1) 90 | #net = UNet3Plus_DeepSup_CGM(n_channels=3, n_classes=1) 91 | else: 92 | net = UNet(n_channels=3, n_classes=1) 93 | 94 | logging.info('Loading model {}'.format(args.model)) 95 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 96 | 97 | logging.info(f'Using device {device}') 98 | net.to(device=device) 99 | net.load_state_dict(torch.load(args.model, map_location=device)) 100 | logging.info('Model loaded !') 101 | 102 | for i, fn in enumerate(in_files): 103 | logging.info('\nPredicting image {} ...'.format(fn)) 104 | img = Image.open(fn) 105 | mask = predict_img(unet_type=unet_type, net=net, full_img=img, scale_factor=args.scale, 106 | out_threshold=args.mask_threshold, device=device) 107 | 108 | if not args.no_save: 109 | out_fn = out_files[i] 110 | result = mask_to_image(mask) 111 | result.save(out_files[i]) 112 | logging.info('Mask saved to {}'.format(out_files[i])) 113 | 114 | if args.viz: 115 | logging.info('Visualizing results for image {}, close to continue ...'.format(fn)) 116 | plot_img_and_mask(img, mask) 117 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | import sys 5 | import math 6 | import matplotlib.pyplot as plt 7 | import torch 8 | import torch.nn as nn 9 | import torch.optim.lr_scheduler as lr_scheduler 10 | from torch import optim 11 | from torch.utils.tensorboard import SummaryWriter 12 | from torch.utils.data import DataLoader, random_split 13 | from apex import amp 14 | from tqdm import tqdm 15 | from unet import UNet 16 | from unet import UNet2Plus 17 | from unet import UNet3Plus, UNet3Plus_DeepSup, UNet3Plus_DeepSup_CGM 18 | from utils.dataset import BasicDataset 19 | from utils.eval import eval_net 20 | 21 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE' 22 | 23 | dir_img = 'D:/datasets/Portraits/train/imgs/' 24 | dir_mask = 'D:/datasets/Portraits/train/masks/' 25 | dir_checkpoint = 'ckpts/' 26 | 27 | 28 | def train_net(unet_type, net, device, epochs=5, batch_size=1, lr=0.1, val_percent=0.1, save_cp=True, img_scale=0.5): 29 | dataset = BasicDataset(unet_type, dir_img, dir_mask, img_scale) 30 | n_val = int(len(dataset) * val_percent) 31 | n_train = len(dataset) - n_val 32 | 33 | train, val = random_split(dataset, [n_train, n_val]) 34 | train_loader = DataLoader(train, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True) 35 | val_loader = DataLoader(val, batch_size=batch_size, shuffle=False, num_workers=8, pin_memory=True) 36 | 37 | writer = SummaryWriter(comment=f'LR_{lr}_BS_{batch_size}_SCALE_{img_scale}') 38 | global_step = 0 39 | 40 | logging.info(f'''Starting training: 41 | UNet type: {unet_type} 42 | Epochs: {epochs} 43 | Batch size: {batch_size} 44 | Learning rate: {lr} 45 | Dataset size: {len(dataset)} 46 | Training size: {n_train} 47 | Validation size: {n_val} 48 | Checkpoints: {save_cp} 49 | Device: {device.type} 50 | Images scaling: {img_scale}''') 51 | 52 | optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay=1e-8) 53 | model, optimizer = amp.initialize(net, optimizer, opt_level="O1") 54 | 55 | # Scheduler https://arxiv.org/pdf/1812.01187.pdf 56 | lf = lambda x: (((1 + math.cos(x * math.pi / epochs)) / 2) ** 1.0) * 0.95 + 0.05 #cosine 57 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf) 58 | scheduler.last_epoch = global_step 59 | 60 | if net.n_classes > 1: 61 | criterion = nn.CrossEntropyLoss() 62 | else: 63 | criterion = nn.BCEWithLogitsLoss() 64 | 65 | lrs = [] 66 | best_loss = 10000 67 | for epoch in range(epochs): 68 | cur_lr = optimizer.param_groups[0]['lr'] 69 | print('\nEpoch=', (epoch + 1), ' lr=', cur_lr) 70 | net.train() 71 | epoch_loss = 0 72 | 73 | with tqdm(total=n_train, desc=f'Epoch {epoch + 1}/{epochs}', unit='img') as pbar: 74 | for batch in train_loader: 75 | imgs = batch['image'] 76 | true_masks = batch['mask'] 77 | 78 | assert imgs.shape[1] == net.n_channels, f'Network has been defined with {net.n_channels} input channels, ' \ 79 | f'but loaded images have {imgs.shape[1]} channels. Please check that the images are loaded correctly.' 80 | 81 | imgs = imgs.to(device=device, dtype=torch.float32) 82 | mask_type = torch.float32 if net.n_classes == 1 else torch.long 83 | true_masks = true_masks.to(device=device, dtype=mask_type) 84 | 85 | masks_pred = net(imgs) 86 | loss = criterion(masks_pred, true_masks) 87 | epoch_loss += loss.item() 88 | writer.add_scalar('Loss/train', loss.item(), global_step) 89 | pbar.set_postfix(**{'loss (batch)': loss.item()}) 90 | 91 | optimizer.zero_grad() 92 | #loss.backward() 93 | with amp.scale_loss(loss, optimizer) as scaled_loss: 94 | scaled_loss.backward() 95 | optimizer.step() 96 | 97 | pbar.update(imgs.shape[0]) 98 | global_step += 1 99 | dataset_len = len(dataset) 100 | 101 | a1 = dataset_len // 10 if dataset_len // 10 > 0 else 1 102 | a2 = dataset_len / 10 if dataset_len / 10 > 0 else 1 103 | b1 = global_step % a1 104 | b2 = global_step % a2 105 | 106 | if global_step % (len(dataset) // (10 * batch_size)) == 0: 107 | val_score = eval_net(net, val_loader, device, n_val) 108 | if net.n_classes > 1: 109 | logging.info('Validation cross entropy: {}'.format(val_score)) 110 | writer.add_scalar('Loss/test', val_score, global_step) 111 | else: 112 | logging.info('Validation Dice Coeff: {}'.format(val_score)) 113 | writer.add_scalar('Dice/test', val_score, global_step) 114 | 115 | writer.add_images('images', imgs, global_step) 116 | if net.n_classes == 1: 117 | writer.add_images('masks/true', true_masks, global_step) 118 | writer.add_images('masks/pred', torch.sigmoid(masks_pred) > 0.5, global_step) 119 | 120 | # update scheduler 121 | scheduler.step() 122 | lrs.append(cur_lr) 123 | 124 | if save_cp: 125 | try: 126 | os.mkdir(dir_checkpoint) 127 | logging.info('Created checkpoint directory') 128 | except OSError: 129 | pass 130 | 131 | if loss.item() < best_loss: 132 | torch.save(net.state_dict(), dir_checkpoint + f'CP_epoch{epoch + 1}_loss_{str(loss.item())}.pth') 133 | best_loss = loss.item() 134 | logging.info(f'Checkpoint {epoch + 1} saved ! loss (batch) = ' + str(loss.item())) 135 | 136 | # plot lr scheduler 137 | plt.plot(lrs , '.-', label='LambdaLR') 138 | plt.xlabel('epoch') 139 | plt.ylabel('LR') 140 | plt.tight_layout() 141 | plt.savefig('LR.png', dpi=300) 142 | 143 | writer.close() 144 | 145 | 146 | def get_args(): 147 | parser = argparse.ArgumentParser(description='Train the UNet on images and target masks', 148 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 149 | parser.add_argument('-g', '--gpu_id', dest='gpu_id', metavar='G', type=int, default=0, help='Number of gpu') 150 | parser.add_argument('-u', '--unet_type', dest='unet_type', metavar='U', type=str, default='v3', help='UNet type is v1/v2/v3 (unet unet++ unet3+)') 151 | 152 | parser.add_argument('-e', '--epochs', metavar='E', type=int, default=10000, help='Number of epochs', dest='epochs') 153 | parser.add_argument('-b', '--batch-size', metavar='B', type=int, nargs='?', default=2, help='Batch size', dest='batchsize') 154 | parser.add_argument('-l', '--learning-rate', metavar='LR', type=float, nargs='?', default=0.1, help='Learning rate', dest='lr') 155 | 156 | parser.add_argument('-f', '--load', dest='load', type=str, default=False, help='Load model from a .pth file') 157 | parser.add_argument('-s', '--scale', dest='scale', type=float, default=0.5, help='Downscaling factor of the images') 158 | parser.add_argument('-v', '--validation', dest='val', type=float, default=10.0, help='Percent of the data that is used as validation (0-100)') 159 | return parser.parse_args() 160 | 161 | 162 | if __name__ == '__main__': 163 | logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s') 164 | args = get_args() 165 | gpu_id = args.gpu_id 166 | unet_type = args.unet_type 167 | 168 | os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id) 169 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 170 | logging.info(f'Using device {device}') 171 | 172 | # Change here to adapt to your data 173 | # n_channels=3 for RGB images 174 | # n_classes is the number of probabilities you want to get per pixel 175 | # - For 1 class and background, use n_classes=1 176 | # - For 2 classes, use n_classes=1 177 | # - For N > 2 classes, use n_classes=N 178 | if unet_type == 'v2': 179 | net = UNet2Plus(n_channels=3, n_classes=1) 180 | elif unet_type == 'v3': 181 | net = UNet3Plus(n_channels=3, n_classes=1) 182 | #net = UNet3Plus_DeepSup(n_channels=3, n_classes=1) 183 | #net = UNet3Plus_DeepSup_CGM(n_channels=3, n_classes=1) 184 | else: 185 | net = UNet(n_channels=3, n_classes=1) 186 | 187 | logging.info(f'Network:\n' 188 | f'\t{net.n_channels} input channels\n' 189 | f'\t{net.n_classes} output channels (classes)\n') 190 | #f'\t{'Bilinear' if net.bilinear else 'Dilated conv'} upscaling') 191 | 192 | if args.load: 193 | net.load_state_dict(torch.load(args.load, map_location=device)) 194 | logging.info(f'Model loaded from {args.load}') 195 | 196 | net.to(device=device) 197 | # faster convolutions, but more memory 198 | # cudnn.benchmark = True 199 | try: 200 | train_net(unet_type=unet_type, net=net, epochs=args.epochs, batch_size=args.batchsize, 201 | lr=args.lr, device=device, img_scale=args.scale, val_percent=args.val / 100) 202 | except KeyboardInterrupt: 203 | torch.save(net.state_dict(), 'INTERRUPTED.pth') 204 | logging.info('Saved interrupt') 205 | try: 206 | sys.exit(0) 207 | except SystemExit: 208 | os.exit(0) 209 | -------------------------------------------------------------------------------- /unet/UNet3Plus.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from .layers import unetConv2 7 | from .init_weights import init_weights 8 | 9 | 10 | class UNet3Plus(nn.Module): 11 | def __init__(self, n_channels=3, n_classes=1, bilinear=True, feature_scale=4, 12 | is_deconv=True, is_batchnorm=True): 13 | super(UNet3Plus, self).__init__() 14 | self.n_channels = n_channels 15 | self.n_classes = n_classes 16 | self.bilinear = bilinear 17 | self.feature_scale = feature_scale 18 | self.is_deconv = is_deconv 19 | self.is_batchnorm = is_batchnorm 20 | filters = [64, 128, 256, 512, 1024] 21 | 22 | ## -------------Encoder-------------- 23 | self.conv1 = unetConv2(self.n_channels, filters[0], self.is_batchnorm) 24 | self.maxpool1 = nn.MaxPool2d(kernel_size=2) 25 | 26 | self.conv2 = unetConv2(filters[0], filters[1], self.is_batchnorm) 27 | self.maxpool2 = nn.MaxPool2d(kernel_size=2) 28 | 29 | self.conv3 = unetConv2(filters[1], filters[2], self.is_batchnorm) 30 | self.maxpool3 = nn.MaxPool2d(kernel_size=2) 31 | 32 | self.conv4 = unetConv2(filters[2], filters[3], self.is_batchnorm) 33 | self.maxpool4 = nn.MaxPool2d(kernel_size=2) 34 | 35 | self.conv5 = unetConv2(filters[3], filters[4], self.is_batchnorm) 36 | 37 | ## -------------Decoder-------------- 38 | self.CatChannels = filters[0] 39 | self.CatBlocks = 5 40 | self.UpChannels = self.CatChannels * self.CatBlocks 41 | 42 | '''stage 4d''' 43 | # h1->320*320, hd4->40*40, Pooling 8 times 44 | self.h1_PT_hd4 = nn.MaxPool2d(8, 8, ceil_mode=True) 45 | self.h1_PT_hd4_conv = nn.Conv2d(filters[0], self.CatChannels, 3, padding=1) 46 | self.h1_PT_hd4_bn = nn.BatchNorm2d(self.CatChannels) 47 | self.h1_PT_hd4_relu = nn.ReLU(inplace=True) 48 | 49 | # h2->160*160, hd4->40*40, Pooling 4 times 50 | self.h2_PT_hd4 = nn.MaxPool2d(4, 4, ceil_mode=True) 51 | self.h2_PT_hd4_conv = nn.Conv2d(filters[1], self.CatChannels, 3, padding=1) 52 | self.h2_PT_hd4_bn = nn.BatchNorm2d(self.CatChannels) 53 | self.h2_PT_hd4_relu = nn.ReLU(inplace=True) 54 | 55 | # h3->80*80, hd4->40*40, Pooling 2 times 56 | self.h3_PT_hd4 = nn.MaxPool2d(2, 2, ceil_mode=True) 57 | self.h3_PT_hd4_conv = nn.Conv2d(filters[2], self.CatChannels, 3, padding=1) 58 | self.h3_PT_hd4_bn = nn.BatchNorm2d(self.CatChannels) 59 | self.h3_PT_hd4_relu = nn.ReLU(inplace=True) 60 | 61 | # h4->40*40, hd4->40*40, Concatenation 62 | self.h4_Cat_hd4_conv = nn.Conv2d(filters[3], self.CatChannels, 3, padding=1) 63 | self.h4_Cat_hd4_bn = nn.BatchNorm2d(self.CatChannels) 64 | self.h4_Cat_hd4_relu = nn.ReLU(inplace=True) 65 | 66 | # hd5->20*20, hd4->40*40, Upsample 2 times 67 | self.hd5_UT_hd4 = nn.Upsample(scale_factor=2, mode='bilinear') # 14*14 68 | self.hd5_UT_hd4_conv = nn.Conv2d(filters[4], self.CatChannels, 3, padding=1) 69 | self.hd5_UT_hd4_bn = nn.BatchNorm2d(self.CatChannels) 70 | self.hd5_UT_hd4_relu = nn.ReLU(inplace=True) 71 | 72 | # fusion(h1_PT_hd4, h2_PT_hd4, h3_PT_hd4, h4_Cat_hd4, hd5_UT_hd4) 73 | self.conv4d_1 = nn.Conv2d(self.UpChannels, self.UpChannels, 3, padding=1) # 16 74 | self.bn4d_1 = nn.BatchNorm2d(self.UpChannels) 75 | self.relu4d_1 = nn.ReLU(inplace=True) 76 | 77 | '''stage 3d''' 78 | # h1->320*320, hd3->80*80, Pooling 4 times 79 | self.h1_PT_hd3 = nn.MaxPool2d(4, 4, ceil_mode=True) 80 | self.h1_PT_hd3_conv = nn.Conv2d(filters[0], self.CatChannels, 3, padding=1) 81 | self.h1_PT_hd3_bn = nn.BatchNorm2d(self.CatChannels) 82 | self.h1_PT_hd3_relu = nn.ReLU(inplace=True) 83 | 84 | # h2->160*160, hd3->80*80, Pooling 2 times 85 | self.h2_PT_hd3 = nn.MaxPool2d(2, 2, ceil_mode=True) 86 | self.h2_PT_hd3_conv = nn.Conv2d(filters[1], self.CatChannels, 3, padding=1) 87 | self.h2_PT_hd3_bn = nn.BatchNorm2d(self.CatChannels) 88 | self.h2_PT_hd3_relu = nn.ReLU(inplace=True) 89 | 90 | # h3->80*80, hd3->80*80, Concatenation 91 | self.h3_Cat_hd3_conv = nn.Conv2d(filters[2], self.CatChannels, 3, padding=1) 92 | self.h3_Cat_hd3_bn = nn.BatchNorm2d(self.CatChannels) 93 | self.h3_Cat_hd3_relu = nn.ReLU(inplace=True) 94 | 95 | # hd4->40*40, hd4->80*80, Upsample 2 times 96 | self.hd4_UT_hd3 = nn.Upsample(scale_factor=2, mode='bilinear') # 14*14 97 | self.hd4_UT_hd3_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1) 98 | self.hd4_UT_hd3_bn = nn.BatchNorm2d(self.CatChannels) 99 | self.hd4_UT_hd3_relu = nn.ReLU(inplace=True) 100 | 101 | # hd5->20*20, hd4->80*80, Upsample 4 times 102 | self.hd5_UT_hd3 = nn.Upsample(scale_factor=4, mode='bilinear') # 14*14 103 | self.hd5_UT_hd3_conv = nn.Conv2d(filters[4], self.CatChannels, 3, padding=1) 104 | self.hd5_UT_hd3_bn = nn.BatchNorm2d(self.CatChannels) 105 | self.hd5_UT_hd3_relu = nn.ReLU(inplace=True) 106 | 107 | # fusion(h1_PT_hd3, h2_PT_hd3, h3_Cat_hd3, hd4_UT_hd3, hd5_UT_hd3) 108 | self.conv3d_1 = nn.Conv2d(self.UpChannels, self.UpChannels, 3, padding=1) # 16 109 | self.bn3d_1 = nn.BatchNorm2d(self.UpChannels) 110 | self.relu3d_1 = nn.ReLU(inplace=True) 111 | 112 | '''stage 2d ''' 113 | # h1->320*320, hd2->160*160, Pooling 2 times 114 | self.h1_PT_hd2 = nn.MaxPool2d(2, 2, ceil_mode=True) 115 | self.h1_PT_hd2_conv = nn.Conv2d(filters[0], self.CatChannels, 3, padding=1) 116 | self.h1_PT_hd2_bn = nn.BatchNorm2d(self.CatChannels) 117 | self.h1_PT_hd2_relu = nn.ReLU(inplace=True) 118 | 119 | # h2->160*160, hd2->160*160, Concatenation 120 | self.h2_Cat_hd2_conv = nn.Conv2d(filters[1], self.CatChannels, 3, padding=1) 121 | self.h2_Cat_hd2_bn = nn.BatchNorm2d(self.CatChannels) 122 | self.h2_Cat_hd2_relu = nn.ReLU(inplace=True) 123 | 124 | # hd3->80*80, hd2->160*160, Upsample 2 times 125 | self.hd3_UT_hd2 = nn.Upsample(scale_factor=2, mode='bilinear') # 14*14 126 | self.hd3_UT_hd2_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1) 127 | self.hd3_UT_hd2_bn = nn.BatchNorm2d(self.CatChannels) 128 | self.hd3_UT_hd2_relu = nn.ReLU(inplace=True) 129 | 130 | # hd4->40*40, hd2->160*160, Upsample 4 times 131 | self.hd4_UT_hd2 = nn.Upsample(scale_factor=4, mode='bilinear') # 14*14 132 | self.hd4_UT_hd2_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1) 133 | self.hd4_UT_hd2_bn = nn.BatchNorm2d(self.CatChannels) 134 | self.hd4_UT_hd2_relu = nn.ReLU(inplace=True) 135 | 136 | # hd5->20*20, hd2->160*160, Upsample 8 times 137 | self.hd5_UT_hd2 = nn.Upsample(scale_factor=8, mode='bilinear') # 14*14 138 | self.hd5_UT_hd2_conv = nn.Conv2d(filters[4], self.CatChannels, 3, padding=1) 139 | self.hd5_UT_hd2_bn = nn.BatchNorm2d(self.CatChannels) 140 | self.hd5_UT_hd2_relu = nn.ReLU(inplace=True) 141 | 142 | # fusion(h1_PT_hd2, h2_Cat_hd2, hd3_UT_hd2, hd4_UT_hd2, hd5_UT_hd2) 143 | self.conv2d_1 = nn.Conv2d(self.UpChannels, self.UpChannels, 3, padding=1) # 16 144 | self.bn2d_1 = nn.BatchNorm2d(self.UpChannels) 145 | self.relu2d_1 = nn.ReLU(inplace=True) 146 | 147 | '''stage 1d''' 148 | # h1->320*320, hd1->320*320, Concatenation 149 | self.h1_Cat_hd1_conv = nn.Conv2d(filters[0], self.CatChannels, 3, padding=1) 150 | self.h1_Cat_hd1_bn = nn.BatchNorm2d(self.CatChannels) 151 | self.h1_Cat_hd1_relu = nn.ReLU(inplace=True) 152 | 153 | # hd2->160*160, hd1->320*320, Upsample 2 times 154 | self.hd2_UT_hd1 = nn.Upsample(scale_factor=2, mode='bilinear') # 14*14 155 | self.hd2_UT_hd1_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1) 156 | self.hd2_UT_hd1_bn = nn.BatchNorm2d(self.CatChannels) 157 | self.hd2_UT_hd1_relu = nn.ReLU(inplace=True) 158 | 159 | # hd3->80*80, hd1->320*320, Upsample 4 times 160 | self.hd3_UT_hd1 = nn.Upsample(scale_factor=4, mode='bilinear') # 14*14 161 | self.hd3_UT_hd1_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1) 162 | self.hd3_UT_hd1_bn = nn.BatchNorm2d(self.CatChannels) 163 | self.hd3_UT_hd1_relu = nn.ReLU(inplace=True) 164 | 165 | # hd4->40*40, hd1->320*320, Upsample 8 times 166 | self.hd4_UT_hd1 = nn.Upsample(scale_factor=8, mode='bilinear') # 14*14 167 | self.hd4_UT_hd1_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1) 168 | self.hd4_UT_hd1_bn = nn.BatchNorm2d(self.CatChannels) 169 | self.hd4_UT_hd1_relu = nn.ReLU(inplace=True) 170 | 171 | # hd5->20*20, hd1->320*320, Upsample 16 times 172 | self.hd5_UT_hd1 = nn.Upsample(scale_factor=16, mode='bilinear') # 14*14 173 | self.hd5_UT_hd1_conv = nn.Conv2d(filters[4], self.CatChannels, 3, padding=1) 174 | self.hd5_UT_hd1_bn = nn.BatchNorm2d(self.CatChannels) 175 | self.hd5_UT_hd1_relu = nn.ReLU(inplace=True) 176 | 177 | # fusion(h1_Cat_hd1, hd2_UT_hd1, hd3_UT_hd1, hd4_UT_hd1, hd5_UT_hd1) 178 | self.conv1d_1 = nn.Conv2d(self.UpChannels, self.UpChannels, 3, padding=1) # 16 179 | self.bn1d_1 = nn.BatchNorm2d(self.UpChannels) 180 | self.relu1d_1 = nn.ReLU(inplace=True) 181 | 182 | # output 183 | self.outconv1 = nn.Conv2d(self.UpChannels, n_classes, 3, padding=1) 184 | 185 | # initialise weights 186 | for m in self.modules(): 187 | if isinstance(m, nn.Conv2d): 188 | init_weights(m, init_type='kaiming') 189 | elif isinstance(m, nn.BatchNorm2d): 190 | init_weights(m, init_type='kaiming') 191 | 192 | 193 | def forward(self, inputs): 194 | ## -------------Encoder------------- 195 | h1 = self.conv1(inputs) # h1->320*320*64 196 | 197 | h2 = self.maxpool1(h1) 198 | h2 = self.conv2(h2) # h2->160*160*128 199 | 200 | h3 = self.maxpool2(h2) 201 | h3 = self.conv3(h3) # h3->80*80*256 202 | 203 | h4 = self.maxpool3(h3) 204 | h4 = self.conv4(h4) # h4->40*40*512 205 | 206 | h5 = self.maxpool4(h4) 207 | hd5 = self.conv5(h5) # h5->20*20*1024 208 | 209 | ## -------------Decoder------------- 210 | h1_PT_hd4 = self.h1_PT_hd4_relu(self.h1_PT_hd4_bn(self.h1_PT_hd4_conv(self.h1_PT_hd4(h1)))) 211 | h2_PT_hd4 = self.h2_PT_hd4_relu(self.h2_PT_hd4_bn(self.h2_PT_hd4_conv(self.h2_PT_hd4(h2)))) 212 | h3_PT_hd4 = self.h3_PT_hd4_relu(self.h3_PT_hd4_bn(self.h3_PT_hd4_conv(self.h3_PT_hd4(h3)))) 213 | h4_Cat_hd4 = self.h4_Cat_hd4_relu(self.h4_Cat_hd4_bn(self.h4_Cat_hd4_conv(h4))) 214 | hd5_UT_hd4 = self.hd5_UT_hd4_relu(self.hd5_UT_hd4_bn(self.hd5_UT_hd4_conv(self.hd5_UT_hd4(hd5)))) 215 | hd4 = self.relu4d_1(self.bn4d_1(self.conv4d_1(torch.cat((h1_PT_hd4, h2_PT_hd4, h3_PT_hd4, h4_Cat_hd4, hd5_UT_hd4), 1)))) # hd4->40*40*UpChannels 216 | 217 | h1_PT_hd3 = self.h1_PT_hd3_relu(self.h1_PT_hd3_bn(self.h1_PT_hd3_conv(self.h1_PT_hd3(h1)))) 218 | h2_PT_hd3 = self.h2_PT_hd3_relu(self.h2_PT_hd3_bn(self.h2_PT_hd3_conv(self.h2_PT_hd3(h2)))) 219 | h3_Cat_hd3 = self.h3_Cat_hd3_relu(self.h3_Cat_hd3_bn(self.h3_Cat_hd3_conv(h3))) 220 | hd4_UT_hd3 = self.hd4_UT_hd3_relu(self.hd4_UT_hd3_bn(self.hd4_UT_hd3_conv(self.hd4_UT_hd3(hd4)))) 221 | hd5_UT_hd3 = self.hd5_UT_hd3_relu(self.hd5_UT_hd3_bn(self.hd5_UT_hd3_conv(self.hd5_UT_hd3(hd5)))) 222 | hd3 = self.relu3d_1(self.bn3d_1(self.conv3d_1(torch.cat((h1_PT_hd3, h2_PT_hd3, h3_Cat_hd3, hd4_UT_hd3, hd5_UT_hd3), 1)))) # hd3->80*80*UpChannels 223 | 224 | h1_PT_hd2 = self.h1_PT_hd2_relu(self.h1_PT_hd2_bn(self.h1_PT_hd2_conv(self.h1_PT_hd2(h1)))) 225 | h2_Cat_hd2 = self.h2_Cat_hd2_relu(self.h2_Cat_hd2_bn(self.h2_Cat_hd2_conv(h2))) 226 | hd3_UT_hd2 = self.hd3_UT_hd2_relu(self.hd3_UT_hd2_bn(self.hd3_UT_hd2_conv(self.hd3_UT_hd2(hd3)))) 227 | hd4_UT_hd2 = self.hd4_UT_hd2_relu(self.hd4_UT_hd2_bn(self.hd4_UT_hd2_conv(self.hd4_UT_hd2(hd4)))) 228 | hd5_UT_hd2 = self.hd5_UT_hd2_relu(self.hd5_UT_hd2_bn(self.hd5_UT_hd2_conv(self.hd5_UT_hd2(hd5)))) 229 | hd2 = self.relu2d_1(self.bn2d_1(self.conv2d_1(torch.cat((h1_PT_hd2, h2_Cat_hd2, hd3_UT_hd2, hd4_UT_hd2, hd5_UT_hd2), 1)))) # hd2->160*160*UpChannels 230 | 231 | h1_Cat_hd1 = self.h1_Cat_hd1_relu(self.h1_Cat_hd1_bn(self.h1_Cat_hd1_conv(h1))) 232 | hd2_UT_hd1 = self.hd2_UT_hd1_relu(self.hd2_UT_hd1_bn(self.hd2_UT_hd1_conv(self.hd2_UT_hd1(hd2)))) 233 | hd3_UT_hd1 = self.hd3_UT_hd1_relu(self.hd3_UT_hd1_bn(self.hd3_UT_hd1_conv(self.hd3_UT_hd1(hd3)))) 234 | hd4_UT_hd1 = self.hd4_UT_hd1_relu(self.hd4_UT_hd1_bn(self.hd4_UT_hd1_conv(self.hd4_UT_hd1(hd4)))) 235 | hd5_UT_hd1 = self.hd5_UT_hd1_relu(self.hd5_UT_hd1_bn(self.hd5_UT_hd1_conv(self.hd5_UT_hd1(hd5)))) 236 | hd1 = self.relu1d_1(self.bn1d_1(self.conv1d_1(torch.cat((h1_Cat_hd1, hd2_UT_hd1, hd3_UT_hd1, hd4_UT_hd1, hd5_UT_hd1), 1)))) # hd1->320*320*UpChannels 237 | 238 | d1 = self.outconv1(hd1) # d1->320*320*n_classes 239 | return F.sigmoid(d1) 240 | 241 | 242 | #UNet 3+ with deep supervision 243 | class UNet3Plus_DeepSup(nn.Module): 244 | def __init__(self, n_channels=3, n_classes=1, bilinear=True, feature_scale=4, 245 | is_deconv=True, is_batchnorm=True): 246 | super(UNet3Plus_DeepSup, self).__init__() 247 | self.n_channels = n_channels 248 | self.n_classes = n_classes 249 | self.bilinear = bilinear 250 | self.feature_scale = feature_scale 251 | self.is_deconv = is_deconv 252 | self.is_batchnorm = is_batchnorm 253 | filters = [64, 128, 256, 512, 1024] 254 | 255 | ## -------------Encoder-------------- 256 | self.conv1 = unetConv2(self.n_channels, filters[0], self.is_batchnorm) 257 | self.maxpool1 = nn.MaxPool2d(kernel_size=2) 258 | 259 | self.conv2 = unetConv2(filters[0], filters[1], self.is_batchnorm) 260 | self.maxpool2 = nn.MaxPool2d(kernel_size=2) 261 | 262 | self.conv3 = unetConv2(filters[1], filters[2], self.is_batchnorm) 263 | self.maxpool3 = nn.MaxPool2d(kernel_size=2) 264 | 265 | self.conv4 = unetConv2(filters[2], filters[3], self.is_batchnorm) 266 | self.maxpool4 = nn.MaxPool2d(kernel_size=2) 267 | 268 | self.conv5 = unetConv2(filters[3], filters[4], self.is_batchnorm) 269 | 270 | ## -------------Decoder-------------- 271 | self.CatChannels = filters[0] 272 | self.CatBlocks = 5 273 | self.UpChannels = self.CatChannels * self.CatBlocks 274 | 275 | '''stage 4d''' 276 | # h1->320*320, hd4->40*40, Pooling 8 times 277 | self.h1_PT_hd4 = nn.MaxPool2d(8, 8, ceil_mode=True) 278 | self.h1_PT_hd4_conv = nn.Conv2d(filters[0], self.CatChannels, 3, padding=1) 279 | self.h1_PT_hd4_bn = nn.BatchNorm2d(self.CatChannels) 280 | self.h1_PT_hd4_relu = nn.ReLU(inplace=True) 281 | 282 | # h2->160*160, hd4->40*40, Pooling 4 times 283 | self.h2_PT_hd4 = nn.MaxPool2d(4, 4, ceil_mode=True) 284 | self.h2_PT_hd4_conv = nn.Conv2d(filters[1], self.CatChannels, 3, padding=1) 285 | self.h2_PT_hd4_bn = nn.BatchNorm2d(self.CatChannels) 286 | self.h2_PT_hd4_relu = nn.ReLU(inplace=True) 287 | 288 | # h3->80*80, hd4->40*40, Pooling 2 times 289 | self.h3_PT_hd4 = nn.MaxPool2d(2, 2, ceil_mode=True) 290 | self.h3_PT_hd4_conv = nn.Conv2d(filters[2], self.CatChannels, 3, padding=1) 291 | self.h3_PT_hd4_bn = nn.BatchNorm2d(self.CatChannels) 292 | self.h3_PT_hd4_relu = nn.ReLU(inplace=True) 293 | 294 | # h4->40*40, hd4->40*40, Concatenation 295 | self.h4_Cat_hd4_conv = nn.Conv2d(filters[3], self.CatChannels, 3, padding=1) 296 | self.h4_Cat_hd4_bn = nn.BatchNorm2d(self.CatChannels) 297 | self.h4_Cat_hd4_relu = nn.ReLU(inplace=True) 298 | 299 | # hd5->20*20, hd4->40*40, Upsample 2 times 300 | self.hd5_UT_hd4 = nn.Upsample(scale_factor=2, mode='bilinear') # 14*14 301 | self.hd5_UT_hd4_conv = nn.Conv2d(filters[4], self.CatChannels, 3, padding=1) 302 | self.hd5_UT_hd4_bn = nn.BatchNorm2d(self.CatChannels) 303 | self.hd5_UT_hd4_relu = nn.ReLU(inplace=True) 304 | 305 | # fusion(h1_PT_hd4, h2_PT_hd4, h3_PT_hd4, h4_Cat_hd4, hd5_UT_hd4) 306 | self.conv4d_1 = nn.Conv2d(self.UpChannels, self.UpChannels, 3, padding=1) # 16 307 | self.bn4d_1 = nn.BatchNorm2d(self.UpChannels) 308 | self.relu4d_1 = nn.ReLU(inplace=True) 309 | 310 | '''stage 3d''' 311 | # h1->320*320, hd3->80*80, Pooling 4 times 312 | self.h1_PT_hd3 = nn.MaxPool2d(4, 4, ceil_mode=True) 313 | self.h1_PT_hd3_conv = nn.Conv2d(filters[0], self.CatChannels, 3, padding=1) 314 | self.h1_PT_hd3_bn = nn.BatchNorm2d(self.CatChannels) 315 | self.h1_PT_hd3_relu = nn.ReLU(inplace=True) 316 | 317 | # h2->160*160, hd3->80*80, Pooling 2 times 318 | self.h2_PT_hd3 = nn.MaxPool2d(2, 2, ceil_mode=True) 319 | self.h2_PT_hd3_conv = nn.Conv2d(filters[1], self.CatChannels, 3, padding=1) 320 | self.h2_PT_hd3_bn = nn.BatchNorm2d(self.CatChannels) 321 | self.h2_PT_hd3_relu = nn.ReLU(inplace=True) 322 | 323 | # h3->80*80, hd3->80*80, Concatenation 324 | self.h3_Cat_hd3_conv = nn.Conv2d(filters[2], self.CatChannels, 3, padding=1) 325 | self.h3_Cat_hd3_bn = nn.BatchNorm2d(self.CatChannels) 326 | self.h3_Cat_hd3_relu = nn.ReLU(inplace=True) 327 | 328 | # hd4->40*40, hd4->80*80, Upsample 2 times 329 | self.hd4_UT_hd3 = nn.Upsample(scale_factor=2, mode='bilinear') # 14*14 330 | self.hd4_UT_hd3_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1) 331 | self.hd4_UT_hd3_bn = nn.BatchNorm2d(self.CatChannels) 332 | self.hd4_UT_hd3_relu = nn.ReLU(inplace=True) 333 | 334 | # hd5->20*20, hd4->80*80, Upsample 4 times 335 | self.hd5_UT_hd3 = nn.Upsample(scale_factor=4, mode='bilinear') # 14*14 336 | self.hd5_UT_hd3_conv = nn.Conv2d(filters[4], self.CatChannels, 3, padding=1) 337 | self.hd5_UT_hd3_bn = nn.BatchNorm2d(self.CatChannels) 338 | self.hd5_UT_hd3_relu = nn.ReLU(inplace=True) 339 | 340 | # fusion(h1_PT_hd3, h2_PT_hd3, h3_Cat_hd3, hd4_UT_hd3, hd5_UT_hd3) 341 | self.conv3d_1 = nn.Conv2d(self.UpChannels, self.UpChannels, 3, padding=1) # 16 342 | self.bn3d_1 = nn.BatchNorm2d(self.UpChannels) 343 | self.relu3d_1 = nn.ReLU(inplace=True) 344 | 345 | '''stage 2d ''' 346 | # h1->320*320, hd2->160*160, Pooling 2 times 347 | self.h1_PT_hd2 = nn.MaxPool2d(2, 2, ceil_mode=True) 348 | self.h1_PT_hd2_conv = nn.Conv2d(filters[0], self.CatChannels, 3, padding=1) 349 | self.h1_PT_hd2_bn = nn.BatchNorm2d(self.CatChannels) 350 | self.h1_PT_hd2_relu = nn.ReLU(inplace=True) 351 | 352 | # h2->160*160, hd2->160*160, Concatenation 353 | self.h2_Cat_hd2_conv = nn.Conv2d(filters[1], self.CatChannels, 3, padding=1) 354 | self.h2_Cat_hd2_bn = nn.BatchNorm2d(self.CatChannels) 355 | self.h2_Cat_hd2_relu = nn.ReLU(inplace=True) 356 | 357 | # hd3->80*80, hd2->160*160, Upsample 2 times 358 | self.hd3_UT_hd2 = nn.Upsample(scale_factor=2, mode='bilinear') # 14*14 359 | self.hd3_UT_hd2_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1) 360 | self.hd3_UT_hd2_bn = nn.BatchNorm2d(self.CatChannels) 361 | self.hd3_UT_hd2_relu = nn.ReLU(inplace=True) 362 | 363 | # hd4->40*40, hd2->160*160, Upsample 4 times 364 | self.hd4_UT_hd2 = nn.Upsample(scale_factor=4, mode='bilinear') # 14*14 365 | self.hd4_UT_hd2_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1) 366 | self.hd4_UT_hd2_bn = nn.BatchNorm2d(self.CatChannels) 367 | self.hd4_UT_hd2_relu = nn.ReLU(inplace=True) 368 | 369 | # hd5->20*20, hd2->160*160, Upsample 8 times 370 | self.hd5_UT_hd2 = nn.Upsample(scale_factor=8, mode='bilinear') # 14*14 371 | self.hd5_UT_hd2_conv = nn.Conv2d(filters[4], self.CatChannels, 3, padding=1) 372 | self.hd5_UT_hd2_bn = nn.BatchNorm2d(self.CatChannels) 373 | self.hd5_UT_hd2_relu = nn.ReLU(inplace=True) 374 | 375 | # fusion(h1_PT_hd2, h2_Cat_hd2, hd3_UT_hd2, hd4_UT_hd2, hd5_UT_hd2) 376 | self.conv2d_1 = nn.Conv2d(self.UpChannels, self.UpChannels, 3, padding=1) # 16 377 | self.bn2d_1 = nn.BatchNorm2d(self.UpChannels) 378 | self.relu2d_1 = nn.ReLU(inplace=True) 379 | 380 | '''stage 1d''' 381 | # h1->320*320, hd1->320*320, Concatenation 382 | self.h1_Cat_hd1_conv = nn.Conv2d(filters[0], self.CatChannels, 3, padding=1) 383 | self.h1_Cat_hd1_bn = nn.BatchNorm2d(self.CatChannels) 384 | self.h1_Cat_hd1_relu = nn.ReLU(inplace=True) 385 | 386 | # hd2->160*160, hd1->320*320, Upsample 2 times 387 | self.hd2_UT_hd1 = nn.Upsample(scale_factor=2, mode='bilinear') # 14*14 388 | self.hd2_UT_hd1_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1) 389 | self.hd2_UT_hd1_bn = nn.BatchNorm2d(self.CatChannels) 390 | self.hd2_UT_hd1_relu = nn.ReLU(inplace=True) 391 | 392 | # hd3->80*80, hd1->320*320, Upsample 4 times 393 | self.hd3_UT_hd1 = nn.Upsample(scale_factor=4, mode='bilinear') # 14*14 394 | self.hd3_UT_hd1_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1) 395 | self.hd3_UT_hd1_bn = nn.BatchNorm2d(self.CatChannels) 396 | self.hd3_UT_hd1_relu = nn.ReLU(inplace=True) 397 | 398 | # hd4->40*40, hd1->320*320, Upsample 8 times 399 | self.hd4_UT_hd1 = nn.Upsample(scale_factor=8, mode='bilinear') # 14*14 400 | self.hd4_UT_hd1_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1) 401 | self.hd4_UT_hd1_bn = nn.BatchNorm2d(self.CatChannels) 402 | self.hd4_UT_hd1_relu = nn.ReLU(inplace=True) 403 | 404 | # hd5->20*20, hd1->320*320, Upsample 16 times 405 | self.hd5_UT_hd1 = nn.Upsample(scale_factor=16, mode='bilinear') # 14*14 406 | self.hd5_UT_hd1_conv = nn.Conv2d(filters[4], self.CatChannels, 3, padding=1) 407 | self.hd5_UT_hd1_bn = nn.BatchNorm2d(self.CatChannels) 408 | self.hd5_UT_hd1_relu = nn.ReLU(inplace=True) 409 | 410 | # fusion(h1_Cat_hd1, hd2_UT_hd1, hd3_UT_hd1, hd4_UT_hd1, hd5_UT_hd1) 411 | self.conv1d_1 = nn.Conv2d(self.UpChannels, self.UpChannels, 3, padding=1) # 16 412 | self.bn1d_1 = nn.BatchNorm2d(self.UpChannels) 413 | self.relu1d_1 = nn.ReLU(inplace=True) 414 | 415 | # -------------Bilinear Upsampling-------------- 416 | self.upscore6 = nn.Upsample(scale_factor=32, mode='bilinear')### 417 | self.upscore5 = nn.Upsample(scale_factor=16, mode='bilinear') 418 | self.upscore4 = nn.Upsample(scale_factor=8, mode='bilinear') 419 | self.upscore3 = nn.Upsample(scale_factor=4, mode='bilinear') 420 | self.upscore2 = nn.Upsample(scale_factor=2, mode='bilinear') 421 | 422 | # DeepSup 423 | self.outconv1 = nn.Conv2d(self.UpChannels, n_classes, 3, padding=1) 424 | self.outconv2 = nn.Conv2d(self.UpChannels, n_classes, 3, padding=1) 425 | self.outconv3 = nn.Conv2d(self.UpChannels, n_classes, 3, padding=1) 426 | self.outconv4 = nn.Conv2d(self.UpChannels, n_classes, 3, padding=1) 427 | self.outconv5 = nn.Conv2d(filters[4], n_classes, 3, padding=1) 428 | 429 | # initialise weights 430 | for m in self.modules(): 431 | if isinstance(m, nn.Conv2d): 432 | init_weights(m, init_type='kaiming') 433 | elif isinstance(m, nn.BatchNorm2d): 434 | init_weights(m, init_type='kaiming') 435 | 436 | 437 | def forward(self, inputs): 438 | ## -------------Encoder------------- 439 | h1 = self.conv1(inputs) # h1->320*320*64 440 | 441 | h2 = self.maxpool1(h1) 442 | h2 = self.conv2(h2) # h2->160*160*128 443 | 444 | h3 = self.maxpool2(h2) 445 | h3 = self.conv3(h3) # h3->80*80*256 446 | 447 | h4 = self.maxpool3(h3) 448 | h4 = self.conv4(h4) # h4->40*40*512 449 | 450 | h5 = self.maxpool4(h4) 451 | hd5 = self.conv5(h5) # h5->20*20*1024 452 | 453 | ## -------------Decoder------------- 454 | h1_PT_hd4 = self.h1_PT_hd4_relu(self.h1_PT_hd4_bn(self.h1_PT_hd4_conv(self.h1_PT_hd4(h1)))) 455 | h2_PT_hd4 = self.h2_PT_hd4_relu(self.h2_PT_hd4_bn(self.h2_PT_hd4_conv(self.h2_PT_hd4(h2)))) 456 | h3_PT_hd4 = self.h3_PT_hd4_relu(self.h3_PT_hd4_bn(self.h3_PT_hd4_conv(self.h3_PT_hd4(h3)))) 457 | h4_Cat_hd4 = self.h4_Cat_hd4_relu(self.h4_Cat_hd4_bn(self.h4_Cat_hd4_conv(h4))) 458 | hd5_UT_hd4 = self.hd5_UT_hd4_relu(self.hd5_UT_hd4_bn(self.hd5_UT_hd4_conv(self.hd5_UT_hd4(hd5)))) 459 | hd4 = self.relu4d_1(self.bn4d_1(self.conv4d_1(torch.cat((h1_PT_hd4, h2_PT_hd4, h3_PT_hd4, h4_Cat_hd4, hd5_UT_hd4), 1)))) # hd4->40*40*UpChannels 460 | 461 | h1_PT_hd3 = self.h1_PT_hd3_relu(self.h1_PT_hd3_bn(self.h1_PT_hd3_conv(self.h1_PT_hd3(h1)))) 462 | h2_PT_hd3 = self.h2_PT_hd3_relu(self.h2_PT_hd3_bn(self.h2_PT_hd3_conv(self.h2_PT_hd3(h2)))) 463 | h3_Cat_hd3 = self.h3_Cat_hd3_relu(self.h3_Cat_hd3_bn(self.h3_Cat_hd3_conv(h3))) 464 | hd4_UT_hd3 = self.hd4_UT_hd3_relu(self.hd4_UT_hd3_bn(self.hd4_UT_hd3_conv(self.hd4_UT_hd3(hd4)))) 465 | hd5_UT_hd3 = self.hd5_UT_hd3_relu(self.hd5_UT_hd3_bn(self.hd5_UT_hd3_conv(self.hd5_UT_hd3(hd5)))) 466 | hd3 = self.relu3d_1(self.bn3d_1(self.conv3d_1(torch.cat((h1_PT_hd3, h2_PT_hd3, h3_Cat_hd3, hd4_UT_hd3, hd5_UT_hd3), 1)))) # hd3->80*80*UpChannels 467 | 468 | h1_PT_hd2 = self.h1_PT_hd2_relu(self.h1_PT_hd2_bn(self.h1_PT_hd2_conv(self.h1_PT_hd2(h1)))) 469 | h2_Cat_hd2 = self.h2_Cat_hd2_relu(self.h2_Cat_hd2_bn(self.h2_Cat_hd2_conv(h2))) 470 | hd3_UT_hd2 = self.hd3_UT_hd2_relu(self.hd3_UT_hd2_bn(self.hd3_UT_hd2_conv(self.hd3_UT_hd2(hd3)))) 471 | hd4_UT_hd2 = self.hd4_UT_hd2_relu(self.hd4_UT_hd2_bn(self.hd4_UT_hd2_conv(self.hd4_UT_hd2(hd4)))) 472 | hd5_UT_hd2 = self.hd5_UT_hd2_relu(self.hd5_UT_hd2_bn(self.hd5_UT_hd2_conv(self.hd5_UT_hd2(hd5)))) 473 | hd2 = self.relu2d_1(self.bn2d_1(self.conv2d_1(torch.cat((h1_PT_hd2, h2_Cat_hd2, hd3_UT_hd2, hd4_UT_hd2, hd5_UT_hd2), 1)))) # hd2->160*160*UpChannels 474 | 475 | h1_Cat_hd1 = self.h1_Cat_hd1_relu(self.h1_Cat_hd1_bn(self.h1_Cat_hd1_conv(h1))) 476 | hd2_UT_hd1 = self.hd2_UT_hd1_relu(self.hd2_UT_hd1_bn(self.hd2_UT_hd1_conv(self.hd2_UT_hd1(hd2)))) 477 | hd3_UT_hd1 = self.hd3_UT_hd1_relu(self.hd3_UT_hd1_bn(self.hd3_UT_hd1_conv(self.hd3_UT_hd1(hd3)))) 478 | hd4_UT_hd1 = self.hd4_UT_hd1_relu(self.hd4_UT_hd1_bn(self.hd4_UT_hd1_conv(self.hd4_UT_hd1(hd4)))) 479 | hd5_UT_hd1 = self.hd5_UT_hd1_relu(self.hd5_UT_hd1_bn(self.hd5_UT_hd1_conv(self.hd5_UT_hd1(hd5)))) 480 | hd1 = self.relu1d_1(self.bn1d_1(self.conv1d_1(torch.cat((h1_Cat_hd1, hd2_UT_hd1, hd3_UT_hd1, hd4_UT_hd1, hd5_UT_hd1), 1)))) # hd1->320*320*UpChannels 481 | 482 | d5 = self.outconv5(hd5) 483 | d5 = self.upscore5(d5) # 16->256 484 | 485 | d4 = self.outconv4(hd4) 486 | d4 = self.upscore4(d4) # 32->256 487 | 488 | d3 = self.outconv3(hd3) 489 | d3 = self.upscore3(d3) # 64->256 490 | 491 | d2 = self.outconv2(hd2) 492 | d2 = self.upscore2(d2) # 128->256 493 | 494 | d1 = self.outconv1(hd1) # 256 495 | return F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5) 496 | 497 | 498 | #UNet 3+ with deep supervision and class-guided module 499 | class UNet3Plus_DeepSup_CGM(nn.Module): 500 | def __init__(self, n_channels=3, n_classes=1, bilinear=True, feature_scale=4, 501 | is_deconv=True, is_batchnorm=True): 502 | super(UNet3Plus_DeepSup_CGM, self).__init__() 503 | self.n_channels = n_channels 504 | self.n_classes = n_classes 505 | self.bilinear = bilinear 506 | self.feature_scale = feature_scale 507 | self.is_deconv = is_deconv 508 | self.is_batchnorm = is_batchnorm 509 | filters = [64, 128, 256, 512, 1024] 510 | 511 | ## -------------Encoder-------------- 512 | self.conv1 = unetConv2(self.n_channels, filters[0], self.is_batchnorm) 513 | self.maxpool1 = nn.MaxPool2d(kernel_size=2) 514 | 515 | self.conv2 = unetConv2(filters[0], filters[1], self.is_batchnorm) 516 | self.maxpool2 = nn.MaxPool2d(kernel_size=2) 517 | 518 | self.conv3 = unetConv2(filters[1], filters[2], self.is_batchnorm) 519 | self.maxpool3 = nn.MaxPool2d(kernel_size=2) 520 | 521 | self.conv4 = unetConv2(filters[2], filters[3], self.is_batchnorm) 522 | self.maxpool4 = nn.MaxPool2d(kernel_size=2) 523 | 524 | self.conv5 = unetConv2(filters[3], filters[4], self.is_batchnorm) 525 | 526 | ## -------------Decoder-------------- 527 | self.CatChannels = filters[0] 528 | self.CatBlocks = 5 529 | self.UpChannels = self.CatChannels * self.CatBlocks 530 | 531 | '''stage 4d''' 532 | # h1->320*320, hd4->40*40, Pooling 8 times 533 | self.h1_PT_hd4 = nn.MaxPool2d(8, 8, ceil_mode=True) 534 | self.h1_PT_hd4_conv = nn.Conv2d(filters[0], self.CatChannels, 3, padding=1) 535 | self.h1_PT_hd4_bn = nn.BatchNorm2d(self.CatChannels) 536 | self.h1_PT_hd4_relu = nn.ReLU(inplace=True) 537 | 538 | # h2->160*160, hd4->40*40, Pooling 4 times 539 | self.h2_PT_hd4 = nn.MaxPool2d(4, 4, ceil_mode=True) 540 | self.h2_PT_hd4_conv = nn.Conv2d(filters[1], self.CatChannels, 3, padding=1) 541 | self.h2_PT_hd4_bn = nn.BatchNorm2d(self.CatChannels) 542 | self.h2_PT_hd4_relu = nn.ReLU(inplace=True) 543 | 544 | # h3->80*80, hd4->40*40, Pooling 2 times 545 | self.h3_PT_hd4 = nn.MaxPool2d(2, 2, ceil_mode=True) 546 | self.h3_PT_hd4_conv = nn.Conv2d(filters[2], self.CatChannels, 3, padding=1) 547 | self.h3_PT_hd4_bn = nn.BatchNorm2d(self.CatChannels) 548 | self.h3_PT_hd4_relu = nn.ReLU(inplace=True) 549 | 550 | # h4->40*40, hd4->40*40, Concatenation 551 | self.h4_Cat_hd4_conv = nn.Conv2d(filters[3], self.CatChannels, 3, padding=1) 552 | self.h4_Cat_hd4_bn = nn.BatchNorm2d(self.CatChannels) 553 | self.h4_Cat_hd4_relu = nn.ReLU(inplace=True) 554 | 555 | # hd5->20*20, hd4->40*40, Upsample 2 times 556 | self.hd5_UT_hd4 = nn.Upsample(scale_factor=2, mode='bilinear') # 14*14 557 | self.hd5_UT_hd4_conv = nn.Conv2d(filters[4], self.CatChannels, 3, padding=1) 558 | self.hd5_UT_hd4_bn = nn.BatchNorm2d(self.CatChannels) 559 | self.hd5_UT_hd4_relu = nn.ReLU(inplace=True) 560 | 561 | # fusion(h1_PT_hd4, h2_PT_hd4, h3_PT_hd4, h4_Cat_hd4, hd5_UT_hd4) 562 | self.conv4d_1 = nn.Conv2d(self.UpChannels, self.UpChannels, 3, padding=1) # 16 563 | self.bn4d_1 = nn.BatchNorm2d(self.UpChannels) 564 | self.relu4d_1 = nn.ReLU(inplace=True) 565 | 566 | '''stage 3d''' 567 | # h1->320*320, hd3->80*80, Pooling 4 times 568 | self.h1_PT_hd3 = nn.MaxPool2d(4, 4, ceil_mode=True) 569 | self.h1_PT_hd3_conv = nn.Conv2d(filters[0], self.CatChannels, 3, padding=1) 570 | self.h1_PT_hd3_bn = nn.BatchNorm2d(self.CatChannels) 571 | self.h1_PT_hd3_relu = nn.ReLU(inplace=True) 572 | 573 | # h2->160*160, hd3->80*80, Pooling 2 times 574 | self.h2_PT_hd3 = nn.MaxPool2d(2, 2, ceil_mode=True) 575 | self.h2_PT_hd3_conv = nn.Conv2d(filters[1], self.CatChannels, 3, padding=1) 576 | self.h2_PT_hd3_bn = nn.BatchNorm2d(self.CatChannels) 577 | self.h2_PT_hd3_relu = nn.ReLU(inplace=True) 578 | 579 | # h3->80*80, hd3->80*80, Concatenation 580 | self.h3_Cat_hd3_conv = nn.Conv2d(filters[2], self.CatChannels, 3, padding=1) 581 | self.h3_Cat_hd3_bn = nn.BatchNorm2d(self.CatChannels) 582 | self.h3_Cat_hd3_relu = nn.ReLU(inplace=True) 583 | 584 | # hd4->40*40, hd4->80*80, Upsample 2 times 585 | self.hd4_UT_hd3 = nn.Upsample(scale_factor=2, mode='bilinear') # 14*14 586 | self.hd4_UT_hd3_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1) 587 | self.hd4_UT_hd3_bn = nn.BatchNorm2d(self.CatChannels) 588 | self.hd4_UT_hd3_relu = nn.ReLU(inplace=True) 589 | 590 | # hd5->20*20, hd4->80*80, Upsample 4 times 591 | self.hd5_UT_hd3 = nn.Upsample(scale_factor=4, mode='bilinear') # 14*14 592 | self.hd5_UT_hd3_conv = nn.Conv2d(filters[4], self.CatChannels, 3, padding=1) 593 | self.hd5_UT_hd3_bn = nn.BatchNorm2d(self.CatChannels) 594 | self.hd5_UT_hd3_relu = nn.ReLU(inplace=True) 595 | 596 | # fusion(h1_PT_hd3, h2_PT_hd3, h3_Cat_hd3, hd4_UT_hd3, hd5_UT_hd3) 597 | self.conv3d_1 = nn.Conv2d(self.UpChannels, self.UpChannels, 3, padding=1) # 16 598 | self.bn3d_1 = nn.BatchNorm2d(self.UpChannels) 599 | self.relu3d_1 = nn.ReLU(inplace=True) 600 | 601 | '''stage 2d ''' 602 | # h1->320*320, hd2->160*160, Pooling 2 times 603 | self.h1_PT_hd2 = nn.MaxPool2d(2, 2, ceil_mode=True) 604 | self.h1_PT_hd2_conv = nn.Conv2d(filters[0], self.CatChannels, 3, padding=1) 605 | self.h1_PT_hd2_bn = nn.BatchNorm2d(self.CatChannels) 606 | self.h1_PT_hd2_relu = nn.ReLU(inplace=True) 607 | 608 | # h2->160*160, hd2->160*160, Concatenation 609 | self.h2_Cat_hd2_conv = nn.Conv2d(filters[1], self.CatChannels, 3, padding=1) 610 | self.h2_Cat_hd2_bn = nn.BatchNorm2d(self.CatChannels) 611 | self.h2_Cat_hd2_relu = nn.ReLU(inplace=True) 612 | 613 | # hd3->80*80, hd2->160*160, Upsample 2 times 614 | self.hd3_UT_hd2 = nn.Upsample(scale_factor=2, mode='bilinear') # 14*14 615 | self.hd3_UT_hd2_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1) 616 | self.hd3_UT_hd2_bn = nn.BatchNorm2d(self.CatChannels) 617 | self.hd3_UT_hd2_relu = nn.ReLU(inplace=True) 618 | 619 | # hd4->40*40, hd2->160*160, Upsample 4 times 620 | self.hd4_UT_hd2 = nn.Upsample(scale_factor=4, mode='bilinear') # 14*14 621 | self.hd4_UT_hd2_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1) 622 | self.hd4_UT_hd2_bn = nn.BatchNorm2d(self.CatChannels) 623 | self.hd4_UT_hd2_relu = nn.ReLU(inplace=True) 624 | 625 | # hd5->20*20, hd2->160*160, Upsample 8 times 626 | self.hd5_UT_hd2 = nn.Upsample(scale_factor=8, mode='bilinear') # 14*14 627 | self.hd5_UT_hd2_conv = nn.Conv2d(filters[4], self.CatChannels, 3, padding=1) 628 | self.hd5_UT_hd2_bn = nn.BatchNorm2d(self.CatChannels) 629 | self.hd5_UT_hd2_relu = nn.ReLU(inplace=True) 630 | 631 | # fusion(h1_PT_hd2, h2_Cat_hd2, hd3_UT_hd2, hd4_UT_hd2, hd5_UT_hd2) 632 | self.conv2d_1 = nn.Conv2d(self.UpChannels, self.UpChannels, 3, padding=1) # 16 633 | self.bn2d_1 = nn.BatchNorm2d(self.UpChannels) 634 | self.relu2d_1 = nn.ReLU(inplace=True) 635 | 636 | '''stage 1d''' 637 | # h1->320*320, hd1->320*320, Concatenation 638 | self.h1_Cat_hd1_conv = nn.Conv2d(filters[0], self.CatChannels, 3, padding=1) 639 | self.h1_Cat_hd1_bn = nn.BatchNorm2d(self.CatChannels) 640 | self.h1_Cat_hd1_relu = nn.ReLU(inplace=True) 641 | 642 | # hd2->160*160, hd1->320*320, Upsample 2 times 643 | self.hd2_UT_hd1 = nn.Upsample(scale_factor=2, mode='bilinear') # 14*14 644 | self.hd2_UT_hd1_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1) 645 | self.hd2_UT_hd1_bn = nn.BatchNorm2d(self.CatChannels) 646 | self.hd2_UT_hd1_relu = nn.ReLU(inplace=True) 647 | 648 | # hd3->80*80, hd1->320*320, Upsample 4 times 649 | self.hd3_UT_hd1 = nn.Upsample(scale_factor=4, mode='bilinear') # 14*14 650 | self.hd3_UT_hd1_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1) 651 | self.hd3_UT_hd1_bn = nn.BatchNorm2d(self.CatChannels) 652 | self.hd3_UT_hd1_relu = nn.ReLU(inplace=True) 653 | 654 | # hd4->40*40, hd1->320*320, Upsample 8 times 655 | self.hd4_UT_hd1 = nn.Upsample(scale_factor=8, mode='bilinear') # 14*14 656 | self.hd4_UT_hd1_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1) 657 | self.hd4_UT_hd1_bn = nn.BatchNorm2d(self.CatChannels) 658 | self.hd4_UT_hd1_relu = nn.ReLU(inplace=True) 659 | 660 | # hd5->20*20, hd1->320*320, Upsample 16 times 661 | self.hd5_UT_hd1 = nn.Upsample(scale_factor=16, mode='bilinear') # 14*14 662 | self.hd5_UT_hd1_conv = nn.Conv2d(filters[4], self.CatChannels, 3, padding=1) 663 | self.hd5_UT_hd1_bn = nn.BatchNorm2d(self.CatChannels) 664 | self.hd5_UT_hd1_relu = nn.ReLU(inplace=True) 665 | 666 | # fusion(h1_Cat_hd1, hd2_UT_hd1, hd3_UT_hd1, hd4_UT_hd1, hd5_UT_hd1) 667 | self.conv1d_1 = nn.Conv2d(self.UpChannels, self.UpChannels, 3, padding=1) # 16 668 | self.bn1d_1 = nn.BatchNorm2d(self.UpChannels) 669 | self.relu1d_1 = nn.ReLU(inplace=True) 670 | 671 | # -------------Bilinear Upsampling-------------- 672 | self.upscore6 = nn.Upsample(scale_factor=32,mode='bilinear')### 673 | self.upscore5 = nn.Upsample(scale_factor=16,mode='bilinear') 674 | self.upscore4 = nn.Upsample(scale_factor=8,mode='bilinear') 675 | self.upscore3 = nn.Upsample(scale_factor=4,mode='bilinear') 676 | self.upscore2 = nn.Upsample(scale_factor=2, mode='bilinear') 677 | 678 | # DeepSup 679 | self.outconv1 = nn.Conv2d(self.UpChannels, n_classes, 3, padding=1) 680 | self.outconv2 = nn.Conv2d(self.UpChannels, n_classes, 3, padding=1) 681 | self.outconv3 = nn.Conv2d(self.UpChannels, n_classes, 3, padding=1) 682 | self.outconv4 = nn.Conv2d(self.UpChannels, n_classes, 3, padding=1) 683 | self.outconv5 = nn.Conv2d(filters[4], n_classes, 3, padding=1) 684 | 685 | self.cls = nn.Sequential( 686 | nn.Dropout(p=0.5), 687 | nn.Conv2d(filters[4], 2, 1), 688 | nn.AdaptiveMaxPool2d(1), 689 | nn.Sigmoid()) 690 | 691 | # initialise weights 692 | for m in self.modules(): 693 | if isinstance(m, nn.Conv2d): 694 | init_weights(m, init_type='kaiming') 695 | elif isinstance(m, nn.BatchNorm2d): 696 | init_weights(m, init_type='kaiming') 697 | 698 | 699 | def dotProduct(self,seg,cls): 700 | B, N, H, W = seg.size() 701 | seg = seg.view(B, N, H * W) 702 | final = torch.einsum("ijk,ij->ijk", [seg, cls]) 703 | final = final.view(B, N, H, W) 704 | return final 705 | 706 | 707 | def forward(self, inputs): 708 | ## -------------Encoder------------- 709 | h1 = self.conv1(inputs) # h1->320*320*64 710 | 711 | h2 = self.maxpool1(h1) 712 | h2 = self.conv2(h2) # h2->160*160*128 713 | 714 | h3 = self.maxpool2(h2) 715 | h3 = self.conv3(h3) # h3->80*80*256 716 | 717 | h4 = self.maxpool3(h3) 718 | h4 = self.conv4(h4) # h4->40*40*512 719 | 720 | h5 = self.maxpool4(h4) 721 | hd5 = self.conv5(h5) # h5->20*20*1024 722 | 723 | # -------------Classification------------- 724 | cls_branch = self.cls(hd5).squeeze(3).squeeze(2) # (B,N,1,1)->(B,N) 725 | cls_branch_max = cls_branch.argmax(dim=1) 726 | cls_branch_max = cls_branch_max[:, np.newaxis].float() 727 | 728 | ## -------------Decoder------------- 729 | h1_PT_hd4 = self.h1_PT_hd4_relu(self.h1_PT_hd4_bn(self.h1_PT_hd4_conv(self.h1_PT_hd4(h1)))) 730 | h2_PT_hd4 = self.h2_PT_hd4_relu(self.h2_PT_hd4_bn(self.h2_PT_hd4_conv(self.h2_PT_hd4(h2)))) 731 | h3_PT_hd4 = self.h3_PT_hd4_relu(self.h3_PT_hd4_bn(self.h3_PT_hd4_conv(self.h3_PT_hd4(h3)))) 732 | h4_Cat_hd4 = self.h4_Cat_hd4_relu(self.h4_Cat_hd4_bn(self.h4_Cat_hd4_conv(h4))) 733 | hd5_UT_hd4 = self.hd5_UT_hd4_relu(self.hd5_UT_hd4_bn(self.hd5_UT_hd4_conv(self.hd5_UT_hd4(hd5)))) 734 | hd4 = self.relu4d_1(self.bn4d_1(self.conv4d_1(torch.cat((h1_PT_hd4, h2_PT_hd4, h3_PT_hd4, h4_Cat_hd4, hd5_UT_hd4), 1)))) # hd4->40*40*UpChannels 735 | 736 | h1_PT_hd3 = self.h1_PT_hd3_relu(self.h1_PT_hd3_bn(self.h1_PT_hd3_conv(self.h1_PT_hd3(h1)))) 737 | h2_PT_hd3 = self.h2_PT_hd3_relu(self.h2_PT_hd3_bn(self.h2_PT_hd3_conv(self.h2_PT_hd3(h2)))) 738 | h3_Cat_hd3 = self.h3_Cat_hd3_relu(self.h3_Cat_hd3_bn(self.h3_Cat_hd3_conv(h3))) 739 | hd4_UT_hd3 = self.hd4_UT_hd3_relu(self.hd4_UT_hd3_bn(self.hd4_UT_hd3_conv(self.hd4_UT_hd3(hd4)))) 740 | hd5_UT_hd3 = self.hd5_UT_hd3_relu(self.hd5_UT_hd3_bn(self.hd5_UT_hd3_conv(self.hd5_UT_hd3(hd5)))) 741 | hd3 = self.relu3d_1(self.bn3d_1(self.conv3d_1(torch.cat((h1_PT_hd3, h2_PT_hd3, h3_Cat_hd3, hd4_UT_hd3, hd5_UT_hd3), 1)))) # hd3->80*80*UpChannels 742 | 743 | h1_PT_hd2 = self.h1_PT_hd2_relu(self.h1_PT_hd2_bn(self.h1_PT_hd2_conv(self.h1_PT_hd2(h1)))) 744 | h2_Cat_hd2 = self.h2_Cat_hd2_relu(self.h2_Cat_hd2_bn(self.h2_Cat_hd2_conv(h2))) 745 | hd3_UT_hd2 = self.hd3_UT_hd2_relu(self.hd3_UT_hd2_bn(self.hd3_UT_hd2_conv(self.hd3_UT_hd2(hd3)))) 746 | hd4_UT_hd2 = self.hd4_UT_hd2_relu(self.hd4_UT_hd2_bn(self.hd4_UT_hd2_conv(self.hd4_UT_hd2(hd4)))) 747 | hd5_UT_hd2 = self.hd5_UT_hd2_relu(self.hd5_UT_hd2_bn(self.hd5_UT_hd2_conv(self.hd5_UT_hd2(hd5)))) 748 | hd2 = self.relu2d_1(self.bn2d_1(self.conv2d_1(torch.cat((h1_PT_hd2, h2_Cat_hd2, hd3_UT_hd2, hd4_UT_hd2, hd5_UT_hd2), 1)))) # hd2->160*160*UpChannels 749 | 750 | h1_Cat_hd1 = self.h1_Cat_hd1_relu(self.h1_Cat_hd1_bn(self.h1_Cat_hd1_conv(h1))) 751 | hd2_UT_hd1 = self.hd2_UT_hd1_relu(self.hd2_UT_hd1_bn(self.hd2_UT_hd1_conv(self.hd2_UT_hd1(hd2)))) 752 | hd3_UT_hd1 = self.hd3_UT_hd1_relu(self.hd3_UT_hd1_bn(self.hd3_UT_hd1_conv(self.hd3_UT_hd1(hd3)))) 753 | hd4_UT_hd1 = self.hd4_UT_hd1_relu(self.hd4_UT_hd1_bn(self.hd4_UT_hd1_conv(self.hd4_UT_hd1(hd4)))) 754 | hd5_UT_hd1 = self.hd5_UT_hd1_relu(self.hd5_UT_hd1_bn(self.hd5_UT_hd1_conv(self.hd5_UT_hd1(hd5)))) 755 | hd1 = self.relu1d_1(self.bn1d_1(self.conv1d_1(torch.cat((h1_Cat_hd1, hd2_UT_hd1, hd3_UT_hd1, hd4_UT_hd1, hd5_UT_hd1), 1)))) # hd1->320*320*UpChannels 756 | 757 | d5 = self.outconv5(hd5) 758 | d5 = self.upscore5(d5) # 16->256 759 | 760 | d4 = self.outconv4(hd4) 761 | d4 = self.upscore4(d4) # 32->256 762 | 763 | d3 = self.outconv3(hd3) 764 | d3 = self.upscore3(d3) # 64->256 765 | 766 | d2 = self.outconv2(hd2) 767 | d2 = self.upscore2(d2) # 128->256 768 | 769 | d1 = self.outconv1(hd1) # 256 770 | 771 | d1 = self.dotProduct(d1, cls_branch_max) 772 | d2 = self.dotProduct(d2, cls_branch_max) 773 | d3 = self.dotProduct(d3, cls_branch_max) 774 | d4 = self.dotProduct(d4, cls_branch_max) 775 | d5 = self.dotProduct(d5, cls_branch_max) 776 | return F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5) 777 | --------------------------------------------------------------------------------