├── .gitignore ├── README.md ├── constants.py ├── core ├── models │ ├── deeplabv3_plus.py │ ├── mobilenet_v2_dilation.py │ ├── mobilenetv2.py │ ├── pspnet.py │ ├── resnet.py │ ├── unet.py │ ├── unet_paper.py │ └── unet_pytorch.py └── trainers │ └── trainer.py ├── dataloader ├── docunet.py ├── docunet_im2im.py └── docunet_inverted.py ├── environment.yml ├── main.py ├── parser_options.py ├── playground.py ├── readme_images ├── generating_deformed_images.PNG ├── output_examples.png └── overall_architecture.PNG ├── train.sh └── util ├── custom_transforms.py ├── general_functions.py ├── losses.py ├── lr_scheduler.py ├── ms_ssim.py ├── ssim.py └── summary.py /.gitignore: -------------------------------------------------------------------------------- 1 | results*/* 2 | images/* 3 | saved_models/* 4 | dataset/* 5 | .idea/* 6 | __pycache__/* 7 | **/__pycache__/* 8 | !.gitkeep -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Document Image unwarping 2 | 3 |

4 | ![Overall architecture](readme_images/overall_architecture.PNG) 5 |

6 | 7 | This repository contains an *unofficial* implementation of [DocUNet: Document Image Unwarping via a Stacked U-Net](http://openaccess.thecvf.com/content_cvpr_2018/html/Ma_DocUNet_Document_Image_CVPR_2018_paper.html). 8 | We extend this work by: 9 | * predicting the inverted vector fields directly, which saves computation time during inference 10 | * adding more networks to choose from: from UNet to DeepLabv3+ with different backbones 11 | * adding a second loss function (MS-SSIM / SSIM) that measures the similarity between the unwarped image and the target image 12 | * achieving real-time inference speed (300 ms) on CPU for DeepLabv3+ with MobileNetV2 as backbone 13 | 14 | ## Training dataset 15 | 16 | Unfortunately, I am not allowed to make the dataset public. However, I created a very small toy dataset to give you an idea of how the network input should look. 17 | You can find it [here](https://drive.google.com/file/d/16Ay3NVzFmsVe1saMOZam-9nHBE0xcmyA/view?usp=sharing). 18 | The idea is to create a 2D vector field that deforms a flat input image. The deformed image is used as the network input and the vector field is the network target. 19 | 20 |
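To make the inverted-field idea concrete, below is a minimal, self-contained sketch (not the repository's own code) of how a dense inverted vector field unwarps a deformed image with a single `grid_sample` call. It assumes the field stores normalized `(x, y)` source coordinates for every output pixel; the encoding actually used by this project (see `get_flat_images` in [util/general_functions.py](util/general_functions.py)) may differ, e.g. pixel offsets that have to be converted to a sampling grid first.

```python
import torch
import torch.nn.functional as F

def unwarp(deformed, inv_field):
    # deformed:  (N, C, H, W) warped document image.
    # inv_field: (N, 2, H, W); assumed to hold, for every output pixel, the source
    #            (x, y) location in the deformed image, normalized to [-1, 1].
    #            (Assumption for this sketch -- the repository's field format may differ.)
    grid = inv_field.permute(0, 2, 3, 1)  # grid_sample expects (N, H, W, 2)
    return F.grid_sample(deformed, grid, mode='bilinear', align_corners=True)

# Toy check: an identity field returns the input unchanged.
n, c, h, w = 1, 3, 8, 8
ys, xs = torch.meshgrid(torch.linspace(-1, 1, h), torch.linspace(-1, 1, w))
identity_field = torch.stack((xs, ys)).unsqueeze(0)  # (1, 2, H, W), channel order (x, y)
image = torch.rand(n, c, h, w)
assert torch.allclose(unwarp(image, identity_field), image, atol=1e-5)
```

Because the network predicts this inverse mapping directly, unwarping at inference time reduces to a single sampling pass over the deformed input instead of requiring an extra field-inversion step.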

21 | ![Generating deformed images](readme_images/generating_deformed_images.PNG) 22 |

23 | 24 | ## Training on your dataset 25 | 1. Check the [available parser options](parser_options.py). 26 | 2. Download the [toy dataset](https://drive.google.com/file/d/16Ay3NVzFmsVe1saMOZam-9nHBE0xcmyA/view?usp=sharing). 27 | 3. Set the path to your dataset in the [available parser options](parser_options.py). 28 | 4. Create the environment from the [conda file](environment.yml): `conda env create -f environment.yml` 29 | 5. Activate the conda environment: `conda activate unwarping_assignment` 30 | 6. Train the networks using either of the provided scripts: [main.py](main.py) or [train.sh](train.sh). The trained model is saved to the directory given by the `save_dir` command line argument. 31 | 7. Run the [inference script](playground.py) on your own set of images. The command line argument `inference_dir` should be used to provide the 32 | relative path to the folder that contains the images to be unwarped. 33 | 34 | ## Sample results 35 | 36 |
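Outputs like the ones below can be reproduced by pointing the [inference script](playground.py) at a folder of deformed images, for example `python playground.py --inference_dir <folder_with_deformed_images>`. This invocation is only an illustration; the exact flag spelling and the remaining options (model choice, checkpoint location, etc.) are defined in [parser_options.py](parser_options.py).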

37 | ![Output examples](readme_images/output_examples.png) 38 |

39 | -------------------------------------------------------------------------------- /constants.py: -------------------------------------------------------------------------------- 1 | # Model constants 2 | DEEPLAB = 'deeplab' 3 | DEEPLAB_50 = 'deeplab_50' 4 | DEEPLAB_34 = 'deeplab_34' 5 | DEEPLAB_18 = 'deeplab_18' 6 | DEEPLAB_MOBILENET = 'deeplab_mn' 7 | DEEPLAB_MOBILENET_DILATION = 'deeplab_mnd' 8 | UNET = 'unet' 9 | UNET_PAPER = 'unet_paper' 10 | UNET_PYTORCH = 'unet_torch' 11 | PSPNET = 'pspnet' 12 | 13 | # Dataset constants 14 | DOCUNET = 'docunet' 15 | DOCUNET_INVERTED = 'docunet_inverted' 16 | DOCUNET_IM2IM = 'docunet_im2im' 17 | 18 | # Dataset locations 19 | HAZMAT_DATASET = 'hazmat_dataset' 20 | ADDRESS_DATASET = 'address_dataset' 21 | LABELS_DATASET = 'labels_dataset' 22 | 23 | # loss constants 24 | DOCUNET_LOSS = 'docunet_loss' 25 | MS_SSIM_LOSS = 'ms_ssim_loss' 26 | MS_SSIM_LOSS_V2 = 'ms_ssim_loss_v2' 27 | SSIM_LOSS = 'ssim_loss' 28 | SSIM_LOSS_V2 = 'ssim_loss_v2' 29 | SMOOTH_L1_LOSS = 'smoothl1_loss' 30 | L1_LOSS = 'l1_loss' 31 | MSE_LOSS = 'mse_loss' 32 | 33 | # Optimizers 34 | SGD = 'sgd' 35 | AMSGRAD = 'amsgrad' 36 | ADAM = 'adam' 37 | RMSPROP = 'rmsprop' 38 | ADABOUND = 'adabound' 39 | 40 | # Normalization layers 41 | INSTANCE_NORM = 'instance' 42 | BATCH_NORM = 'batch' 43 | SYNC_BATCH_NORM = 'syncbn' 44 | 45 | # Init types 46 | NORMAL_INIT = 'normal' 47 | KAIMING_INIT = 'kaiming' 48 | XAVIER_INIT = 'xavier' 49 | ORTHOGONAL_INIT = 'orthogonal' 50 | 51 | # Downsampling methods 52 | MAXPOOL = 'maxpool' 53 | STRIDECONV = 'strided' 54 | 55 | # Split constants 56 | TRAIN = 'train' 57 | VAL = 'val' 58 | TEST = 'test' 59 | TRAINVAL = 'trainval' 60 | VISUALIZATION = 'visualization' -------------------------------------------------------------------------------- /core/models/deeplabv3_plus.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Adapted work from https://github.com/jfzhang95/pytorch-deeplab-xception 3 | ''' 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | import torch.nn as nn 8 | 9 | from core.models.resnet import ResNet101, ResNet50, ResNet34, ResNet18 10 | from core.models.mobilenetv2 import MobileNet_v2 11 | from core.models.mobilenet_v2_dilation import MobileNet_v2_dilation 12 | from constants import * 13 | 14 | class _ASPPModule(nn.Module): 15 | def __init__(self, inplanes, planes, kernel_size, padding, dilation, norm_layer=nn.BatchNorm2d): 16 | super(_ASPPModule, self).__init__() 17 | self.atrous_conv = nn.Conv2d(inplanes, planes, kernel_size=kernel_size, 18 | stride=1, padding=padding, dilation=dilation, bias=False) 19 | self.bn = norm_layer(planes) 20 | self.relu = nn.ReLU() 21 | 22 | def forward(self, x): 23 | x = self.atrous_conv(x) 24 | x = self.bn(x) 25 | 26 | return self.relu(x) 27 | 28 | 29 | class ASPP(nn.Module): 30 | def __init__(self, output_stride, norm_layer=nn.BatchNorm2d, inplanes=2048): 31 | super(ASPP, self).__init__() 32 | 33 | if output_stride == 16: 34 | dilations = [1, 6, 12, 18] 35 | elif output_stride == 8: 36 | dilations = [1, 12, 24, 36] 37 | else: 38 | raise NotImplementedError 39 | 40 | self.aspp1 = _ASPPModule(inplanes, 256, 1, padding=0, dilation=dilations[0], norm_layer=norm_layer) 41 | self.aspp2 = _ASPPModule(inplanes, 256, 3, padding=dilations[1], dilation=dilations[1], norm_layer=norm_layer) 42 | self.aspp3 = _ASPPModule(inplanes, 256, 3, padding=dilations[2], dilation=dilations[2], norm_layer=norm_layer) 43 | self.aspp4 = _ASPPModule(inplanes, 256, 3, 
padding=dilations[3], dilation=dilations[3], norm_layer=norm_layer) 44 | 45 | self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)), 46 | nn.Conv2d(inplanes, 256, 1, stride=1, bias=False), 47 | norm_layer(256), 48 | nn.ReLU()) 49 | 50 | self.conv1 = nn.Conv2d(1280, 256, 1, bias=False) 51 | self.bn1 = norm_layer(256) 52 | self.relu = nn.ReLU() 53 | self.dropout = nn.Dropout2d(0.5) 54 | 55 | def forward(self, x): 56 | x1 = self.aspp1(x) 57 | x2 = self.aspp2(x) 58 | x3 = self.aspp3(x) 59 | x4 = self.aspp4(x) 60 | x5 = self.global_avg_pool(x) 61 | 62 | x5 = F.interpolate(x5, size=x4.size()[2:], mode='bilinear', align_corners=True) 63 | x = torch.cat((x1, x2, x3, x4, x5), dim=1) 64 | 65 | x = self.conv1(x) 66 | x = self.bn1(x) 67 | x = self.relu(x) 68 | 69 | return self.dropout(x) 70 | 71 | class Decoder(nn.Module): 72 | def __init__(self, num_classes, norm_layer=nn.BatchNorm2d, inplanes=256, aspp_outplanes=256): 73 | super(Decoder, self).__init__() 74 | 75 | self.conv1 = nn.Conv2d(inplanes, 48, 1, bias=False) 76 | self.bn1 = norm_layer(48) 77 | self.relu = nn.ReLU() 78 | 79 | inplanes = 48 + aspp_outplanes 80 | 81 | self.last_conv = nn.Sequential(nn.Conv2d(inplanes, 256, kernel_size=3, stride=1, padding=1, bias=False), 82 | norm_layer(256), 83 | nn.ReLU(), 84 | nn.Dropout2d(0.5), 85 | nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False), 86 | norm_layer(256), 87 | nn.ReLU(), 88 | nn.Dropout2d(0.1), 89 | nn.Conv2d(256, num_classes, kernel_size=1, stride=1)) 90 | 91 | def forward(self, x, low_level_feat): 92 | low_level_feat = self.conv1(low_level_feat) 93 | low_level_feat = self.bn1(low_level_feat) 94 | low_level_feat = self.relu(low_level_feat) 95 | 96 | x = F.interpolate(x, size=low_level_feat.size()[2:], mode='bilinear', align_corners=True) 97 | x = torch.cat((x, low_level_feat), dim=1) 98 | x = self.last_conv(x) 99 | 100 | return x 101 | 102 | class DeepLabv3_plus(nn.Module): 103 | def __init__(self, args, num_classes=21, norm_layer=nn.BatchNorm2d, input_channels=3): 104 | super(DeepLabv3_plus, self).__init__() 105 | self.args = args 106 | 107 | if args.model == DEEPLAB: 108 | self.backbone = ResNet101(args.output_stride, norm_layer=norm_layer, pretrained=args.pretrained, input_channels=input_channels) 109 | self.aspp_inplanes = 2048 110 | self.decoder_inplanes = 256 111 | 112 | if self.args.refine_network: 113 | self.refine_backbone = ResNet101(args.output_stride, norm_layer=norm_layer, pretrained=args.pretrained, input_channels=input_channels + num_classes) 114 | elif args.model == DEEPLAB_50: 115 | self.backbone = ResNet50(args.output_stride, norm_layer=norm_layer, pretrained=args.pretrained) 116 | self.aspp_inplanes = 2048 117 | self.decoder_inplanes = 256 118 | 119 | if self.args.refine_network: 120 | self.refine_backbone = ResNet50(args.output_stride, norm_layer=norm_layer, pretrained=args.pretrained, input_channels=input_channels + num_classes) 121 | elif args.model == DEEPLAB_34: 122 | self.backbone = ResNet34(args.output_stride, norm_layer=norm_layer, pretrained=args.pretrained) 123 | self.aspp_inplanes = 2048 124 | self.decoder_inplanes = 256 125 | 126 | if self.args.refine_network: 127 | self.refine_backbone = ResNet34(args.output_stride, norm_layer=norm_layer, pretrained=args.pretrained, input_channels=input_channels + num_classes) 128 | elif args.model == DEEPLAB_18: 129 | self.backbone = ResNet18(args.output_stride, norm_layer=norm_layer, pretrained=args.pretrained) 130 | self.aspp_inplanes = 2048 131 | self.decoder_inplanes = 256 132 | 133 | if 
self.args.refine_network: 134 | self.refine_backbone = ResNet18(args.output_stride, norm_layer=norm_layer, pretrained=args.pretrained, input_channels=input_channels + num_classes) 135 | elif args.model == DEEPLAB_MOBILENET: 136 | self.backbone = MobileNet_v2(pretrained=args.pretrained, first_layer_input_channels=input_channels) 137 | self.aspp_inplanes = 320 138 | self.decoder_inplanes = 24 139 | 140 | if self.args.refine_network: 141 | self.refine_backbone = MobileNet_v2(pretrained=args.pretrained, first_layer_input_channels=input_channels + num_classes) 142 | 143 | elif args.model == DEEPLAB_MOBILENET_DILATION: 144 | self.backbone = MobileNet_v2_dilation(pretrained=args.pretrained, first_layer_input_channels=input_channels) 145 | self.aspp_inplanes = 320 146 | 147 | if self.args.refine_network: 148 | self.refine_backbone = MobileNet_v2_dilation(pretrained=args.pretrained, first_layer_input_channels=input_channels + num_classes) 149 | self.decoder_inplanes = 24 150 | else: 151 | raise NotImplementedError 152 | 153 | if self.args.use_aspp: 154 | self.aspp = ASPP(args.output_stride, norm_layer=norm_layer, inplanes=self.aspp_inplanes) 155 | 156 | aspp_outplanes = 256 if self.args.use_aspp else self.aspp_inplanes 157 | self.decoder = Decoder(num_classes, norm_layer=norm_layer, inplanes=self.decoder_inplanes, aspp_outplanes=aspp_outplanes) 158 | 159 | if self.args.learned_upsampling: 160 | self.learned_upsampling = nn.Sequential(nn.ConvTranspose2d(num_classes, num_classes, kernel_size=4, stride=2, padding=1), 161 | nn.ConvTranspose2d(num_classes, num_classes, kernel_size=4, stride=2, padding=1)) 162 | 163 | if self.args.refine_network: 164 | if self.args.use_aspp: 165 | self.refine_aspp = ASPP(args.output_stride, norm_layer=norm_layer, inplanes=self.aspp_inplanes) 166 | 167 | self.refine_decoder = Decoder(num_classes, norm_layer=norm_layer, inplanes=self.decoder_inplanes, aspp_outplanes=aspp_outplanes) 168 | 169 | if self.args.learned_upsampling: 170 | self.refine_learned_upsampling = nn.Sequential(nn.ConvTranspose2d(num_classes, num_classes, kernel_size=4, stride=2, padding=1), 171 | nn.ConvTranspose2d(num_classes, num_classes, kernel_size=4, stride=2, padding=1)) 172 | 173 | 174 | def forward(self, input): 175 | output, low_level_feat = self.backbone(input) 176 | 177 | if self.args.use_aspp: 178 | output = self.aspp(output) 179 | 180 | output = self.decoder(output, low_level_feat) 181 | 182 | if self.args.learned_upsampling: 183 | output = self.learned_upsampling(output) 184 | else: 185 | output = F.interpolate(output, size=input.size()[2:], mode='bilinear', align_corners=True) 186 | 187 | if self.args.refine_network: 188 | second_output, low_level_feat = self.refine_backbone(torch.cat((input, output), dim=1)) 189 | 190 | if self.args.use_aspp: 191 | second_output = self.refine_aspp(second_output) 192 | 193 | second_output = self.refine_decoder(second_output, low_level_feat) 194 | 195 | if self.args.learned_upsampling: 196 | second_output = self.refine_learned_upsampling(second_output) 197 | else: 198 | second_output = F.interpolate(second_output, size=input.size()[2:], mode='bilinear', align_corners=True) 199 | 200 | return output, second_output 201 | 202 | return output 203 | def get_train_parameters(self, lr): 204 | train_params = [{'params': self.parameters(), 'lr': lr}] 205 | 206 | return train_params -------------------------------------------------------------------------------- /core/models/mobilenet_v2_dilation.py: 
-------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torchvision.models.utils import load_state_dict_from_url 3 | import torch.nn.functional as F 4 | 5 | 6 | model_urls = { 7 | 'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth', 8 | } 9 | 10 | 11 | def _make_divisible(v, divisor, min_value=None): 12 | if min_value is None: 13 | min_value = divisor 14 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 15 | # Make sure that round down does not go down by more than 10%. 16 | if new_v < 0.9 * v: 17 | new_v += divisor 18 | return new_v 19 | 20 | 21 | class ConvBNReLU(nn.Sequential): 22 | def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, dilation=1, groups=1): 23 | #padding = (kernel_size - 1) // 2 24 | super(ConvBNReLU, self).__init__( 25 | nn.Conv2d(in_planes, out_planes, kernel_size, stride, 0, dilation=dilation, groups=groups, bias=False), 26 | nn.BatchNorm2d(out_planes), 27 | nn.ReLU6(inplace=True) 28 | ) 29 | 30 | def fixed_padding(kernel_size, dilation): 31 | kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1) 32 | pad_total = kernel_size_effective - 1 33 | pad_beg = pad_total // 2 34 | pad_end = pad_total - pad_beg 35 | return (pad_beg, pad_end, pad_beg, pad_end) 36 | 37 | class InvertedResidual(nn.Module): 38 | def __init__(self, inp, oup, stride, dilation, expand_ratio): 39 | super(InvertedResidual, self).__init__() 40 | self.stride = stride 41 | assert stride in [1, 2] 42 | 43 | hidden_dim = int(round(inp * expand_ratio)) 44 | self.use_res_connect = self.stride == 1 and inp == oup 45 | 46 | layers = [] 47 | if expand_ratio != 1: 48 | # pw 49 | layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1)) 50 | 51 | layers.extend([ 52 | # dw 53 | ConvBNReLU(hidden_dim, hidden_dim, stride=stride, dilation=dilation, groups=hidden_dim), 54 | # pw-linear 55 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 56 | nn.BatchNorm2d(oup), 57 | ]) 58 | self.conv = nn.Sequential(*layers) 59 | 60 | self.input_padding = fixed_padding( 3, dilation ) 61 | 62 | def forward(self, x): 63 | x_pad = F.pad(x, self.input_padding) 64 | if self.use_res_connect: 65 | return x + self.conv(x_pad) 66 | else: 67 | return self.conv(x_pad) 68 | 69 | class MobileNetV2(nn.Module): 70 | def __init__(self, first_layer_input_channels=3, output_stride=8, width_mult=1.0, inverted_residual_setting=None, round_nearest=8, block=None): 71 | super(MobileNetV2, self).__init__() 72 | 73 | if block is None: 74 | block = InvertedResidual 75 | 76 | input_channel = 32 77 | last_channel = 1280 78 | self.output_stride = output_stride 79 | current_stride = 1 80 | 81 | if inverted_residual_setting is None: 82 | inverted_residual_setting = [ 83 | # t, c, n, s 84 | [1, 16, 1, 1], 85 | [6, 24, 2, 2], 86 | [6, 32, 3, 2], 87 | [6, 64, 4, 2], 88 | [6, 96, 3, 1], 89 | [6, 160, 3, 2], 90 | [6, 320, 1, 1], 91 | ] 92 | 93 | # only check the first element, assuming user knows t,c,n,s are required 94 | if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4: 95 | raise ValueError("inverted_residual_setting should be non-empty " 96 | "or a 4-element list, got {}".format(inverted_residual_setting)) 97 | 98 | # building first layer 99 | input_channel = _make_divisible(input_channel * width_mult, round_nearest) 100 | self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest) 101 | features = [ConvBNReLU(first_layer_input_channels, input_channel, stride=2)] 102 | 
current_stride *= 2 103 | dilation=1 104 | previous_dilation = 1 105 | 106 | # building inverted residual blocks 107 | for t, c, n, s in inverted_residual_setting: 108 | output_channel = _make_divisible(c * width_mult, round_nearest) 109 | previous_dilation = dilation 110 | if current_stride == output_stride: 111 | stride = 1 112 | dilation *= s 113 | else: 114 | stride = s 115 | current_stride *= s 116 | output_channel = int(c * width_mult) 117 | 118 | for i in range(n): 119 | if i==0: 120 | features.append(block(input_channel, output_channel, stride, previous_dilation, expand_ratio=t)) 121 | else: 122 | features.append(block(input_channel, output_channel, 1, dilation, expand_ratio=t)) 123 | input_channel = output_channel 124 | 125 | # building last several layers 126 | features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1)) 127 | 128 | self.low_level_features = nn.Sequential(*features[0:4]) 129 | self.high_level_features = nn.Sequential(*features[4:-1]) 130 | 131 | def forward(self, x): 132 | x = self.low_level_features(x) 133 | low_level_feat = x 134 | x = self.high_level_features(x) 135 | 136 | return x, low_level_feat 137 | 138 | 139 | def MobileNet_v2_dilation(pretrained=False, **kwargs): 140 | model = MobileNetV2(**kwargs) 141 | if pretrained: 142 | _load_pretrained_model(model, model_urls['mobilenet_v2']) 143 | 144 | return model 145 | 146 | 147 | def _load_pretrained_model(model, url): 148 | pretrain_dict = load_state_dict_from_url(url) 149 | model_dict = {} 150 | state_dict = model.state_dict() 151 | for k, v in pretrain_dict.items(): 152 | if k in state_dict: 153 | model_dict[k] = v 154 | state_dict.update(model_dict) 155 | model.load_state_dict(state_dict) -------------------------------------------------------------------------------- /core/models/mobilenetv2.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torchvision.models.utils import load_state_dict_from_url 3 | 4 | 5 | __all__ = ['MobileNetV2', 'MobileNet_v2'] 6 | 7 | 8 | model_urls = { 9 | 'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth', 10 | } 11 | 12 | 13 | def _make_divisible(v, divisor, min_value=None): 14 | if min_value is None: 15 | min_value = divisor 16 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 17 | # Make sure that round down does not go down by more than 10%. 
18 | if new_v < 0.9 * v: 19 | new_v += divisor 20 | return new_v 21 | 22 | 23 | class ConvBNReLU(nn.Sequential): 24 | def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1): 25 | padding = (kernel_size - 1) // 2 26 | super(ConvBNReLU, self).__init__( 27 | nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False), 28 | nn.BatchNorm2d(out_planes), 29 | nn.ReLU6(inplace=True) 30 | ) 31 | 32 | 33 | class InvertedResidual(nn.Module): 34 | def __init__(self, inp, oup, stride, expand_ratio): 35 | super(InvertedResidual, self).__init__() 36 | self.stride = stride 37 | assert stride in [1, 2] 38 | 39 | hidden_dim = int(round(inp * expand_ratio)) 40 | self.use_res_connect = self.stride == 1 and inp == oup 41 | 42 | layers = [] 43 | if expand_ratio != 1: 44 | # pw 45 | layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1)) 46 | layers.extend([ 47 | # dw 48 | ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim), 49 | # pw-linear 50 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 51 | nn.BatchNorm2d(oup), 52 | ]) 53 | self.conv = nn.Sequential(*layers) 54 | 55 | def forward(self, x): 56 | if self.use_res_connect: 57 | return x + self.conv(x) 58 | else: 59 | return self.conv(x) 60 | 61 | 62 | class MobileNetV2(nn.Module): 63 | def __init__(self, first_layer_input_channels=3, width_mult=1.0, inverted_residual_setting=None, round_nearest=8, block=None): 64 | super(MobileNetV2, self).__init__() 65 | 66 | if block is None: 67 | block = InvertedResidual 68 | 69 | input_channel = 32 70 | last_channel = 1280 71 | 72 | if inverted_residual_setting is None: 73 | inverted_residual_setting = [ 74 | # t, c, n, s 75 | [1, 16, 1, 1], 76 | [6, 24, 2, 2], 77 | [6, 32, 3, 2], 78 | [6, 64, 4, 2], 79 | [6, 96, 3, 1], 80 | [6, 160, 3, 2], 81 | [6, 320, 1, 1], 82 | ] 83 | 84 | # only check the first element, assuming user knows t,c,n,s are required 85 | if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4: 86 | raise ValueError("inverted_residual_setting should be non-empty " 87 | "or a 4-element list, got {}".format(inverted_residual_setting)) 88 | 89 | # building first layer 90 | input_channel = _make_divisible(input_channel * width_mult, round_nearest) 91 | self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest) 92 | features = [ConvBNReLU(first_layer_input_channels, input_channel, stride=2)] 93 | 94 | # building inverted residual blocks 95 | for t, c, n, s in inverted_residual_setting: 96 | output_channel = _make_divisible(c * width_mult, round_nearest) 97 | for i in range(n): 98 | if i == 0: 99 | features.append(block(input_channel, output_channel, s, expand_ratio=t)) 100 | else: 101 | features.append(block(input_channel, output_channel, 1, expand_ratio=t)) 102 | input_channel = output_channel 103 | 104 | # building last several layers 105 | features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1)) 106 | 107 | self.low_level_features = nn.Sequential(*features[0:4]) 108 | self.high_level_features = nn.Sequential(*features[4:-1]) 109 | 110 | def forward(self, x): 111 | x = self.low_level_features(x) 112 | low_level_feat = x 113 | x = self.high_level_features(x) 114 | 115 | return x, low_level_feat 116 | 117 | def get_train_parameters(self, lr): 118 | train_params = [{'params': self.parameters(), 'lr': lr}] 119 | 120 | return train_params 121 | 122 | 123 | def MobileNet_v2(pretrained=False, **kwargs): 124 | model = MobileNetV2(**kwargs) 125 | if pretrained: 126 | 
_load_pretrained_model(model, model_urls['mobilenet_v2']) 127 | 128 | return model 129 | 130 | def _load_pretrained_model(model, url): 131 | pretrain_dict = load_state_dict_from_url(url) 132 | model_dict = {} 133 | state_dict = model.state_dict() 134 | for k, v in pretrain_dict.items(): 135 | if k in state_dict: 136 | model_dict[k] = v 137 | state_dict.update(model_dict) 138 | model.load_state_dict(state_dict) -------------------------------------------------------------------------------- /core/models/pspnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | from torchvision import models 5 | import torch.utils.model_zoo as model_zoo 6 | 7 | RESNET_101 = 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth' 8 | 9 | class _PyramidPoolingModule(nn.Module): 10 | def __init__(self, in_dim, reduction_dim, setting): 11 | super(_PyramidPoolingModule, self).__init__() 12 | self.features = [] 13 | for s in setting: 14 | self.features.append(nn.Sequential( 15 | nn.AdaptiveAvgPool2d(s), 16 | nn.Conv2d(in_dim, reduction_dim, kernel_size=1, bias=False), 17 | nn.BatchNorm2d(reduction_dim, momentum=.95), 18 | nn.ReLU(inplace=True) 19 | )) 20 | self.features = nn.ModuleList(self.features) 21 | 22 | def forward(self, x): 23 | x_size = x.size() 24 | out = [x] 25 | for f in self.features: 26 | out.append(F.interpolate(f(x), x_size[2:], mode='bilinear', align_corners=True)) 27 | out = torch.cat(out, 1) 28 | return out 29 | 30 | 31 | class PSPNet(nn.Module): 32 | def __init__(self, num_classes, args=None): 33 | super(PSPNet, self).__init__() 34 | resnet = models.resnet101() 35 | 36 | self.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool) 37 | self.layer1, self.layer2, self.layer3, self.layer4 = resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4 38 | 39 | for n, m in self.layer3.named_modules(): 40 | if 'conv2' in n: 41 | m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1) 42 | elif 'downsample.0' in n: 43 | m.stride = (1, 1) 44 | for n, m in self.layer4.named_modules(): 45 | if 'conv2' in n: 46 | m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1) 47 | elif 'downsample.0' in n: 48 | m.stride = (1, 1) 49 | 50 | self.ppm = _PyramidPoolingModule(2048, 512, (1, 2, 3, 6)) 51 | self.final = nn.Sequential( 52 | nn.Conv2d(4096, 512, kernel_size=3, padding=1, bias=False), 53 | nn.BatchNorm2d(512, momentum=.95), 54 | nn.ReLU(inplace=True), 55 | nn.Dropout(0.1), 56 | nn.Conv2d(512, num_classes, kernel_size=1) 57 | ) 58 | 59 | def forward(self, x): 60 | x_size = x.size() 61 | 62 | x = self.layer0(x) 63 | x = self.layer1(x) 64 | x = self.layer2(x) 65 | x = self.layer3(x) 66 | x = self.layer4(x) 67 | 68 | x = self.ppm(x) 69 | x = self.final(x) 70 | x = F.interpolate(x, x_size[2:], mode='bilinear', align_corners=True) 71 | 72 | return x -------------------------------------------------------------------------------- /core/models/resnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code source: torchvision repository resnet code 3 | """ 4 | 5 | import torch.nn as nn 6 | import torch.utils.model_zoo as model_zoo 7 | 8 | RESNET_101 = 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth' 9 | RESNET_50 = 'https://download.pytorch.org/models/resnet50-19c8e357.pth' 10 | RESNET_34 = 'https://download.pytorch.org/models/resnet34-333f7ec4.pth' 11 | RESNET_18 = 
'https://download.pytorch.org/models/resnet18-5c106cde.pth' 12 | 13 | class BasicBlock(nn.Module): 14 | expansion = 1 15 | __constants__ = ['downsample'] 16 | 17 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, norm_layer=None): 18 | super(BasicBlock, self).__init__() 19 | if norm_layer is None: 20 | norm_layer = nn.BatchNorm2d 21 | if dilation > 1: 22 | dilation = 1 23 | 24 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 25 | self.conv1 = nn.Conv2d(inplanes, planes, stride=stride, kernel_size=3, padding=1, bias=False) 26 | self.bn1 = norm_layer(planes) 27 | self.relu = nn.ReLU(inplace=True) 28 | self.conv2 = nn.Conv2d(planes, planes, stride=stride, kernel_size=3, padding=1, bias=False) 29 | self.bn2 = norm_layer(planes) 30 | self.downsample = downsample 31 | self.stride = stride 32 | 33 | def forward(self, x): 34 | identity = x 35 | 36 | out = self.conv1(x) 37 | out = self.bn1(out) 38 | out = self.relu(out) 39 | 40 | out = self.conv2(out) 41 | out = self.bn2(out) 42 | 43 | if self.downsample is not None: 44 | identity = self.downsample(x) 45 | 46 | out += identity 47 | out = self.relu(out) 48 | 49 | return out 50 | 51 | class Bottleneck(nn.Module): 52 | expansion = 4 53 | 54 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, norm_layer=nn.BatchNorm2d): 55 | super(Bottleneck, self).__init__() 56 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 57 | self.bn1 = norm_layer(planes) 58 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 59 | dilation=dilation, padding=dilation, bias=False) 60 | self.bn2 = norm_layer(planes) 61 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 62 | self.bn3 = norm_layer(planes * 4) 63 | self.relu = nn.ReLU(inplace=True) 64 | self.downsample = downsample 65 | self.stride = stride 66 | self.dilation = dilation 67 | 68 | def forward(self, x): 69 | residual = x 70 | 71 | out = self.conv1(x) 72 | out = self.bn1(out) 73 | out = self.relu(out) 74 | 75 | out = self.conv2(out) 76 | out = self.bn2(out) 77 | out = self.relu(out) 78 | 79 | out = self.conv3(out) 80 | out = self.bn3(out) 81 | 82 | if self.downsample is not None: 83 | residual = self.downsample(x) 84 | 85 | out += residual 86 | out = self.relu(out) 87 | 88 | return out 89 | 90 | class ResNet(nn.Module): 91 | 92 | def __init__(self, block, layers, output_stride, norm_layer=nn.BatchNorm2d, input_channels=3): 93 | self.inplanes = 64 94 | 95 | super(ResNet, self).__init__() 96 | blocks = [1, 2, 4] 97 | if output_stride == 16: 98 | strides = [1, 2, 2, 1] 99 | dilations = [1, 1, 1, 2] 100 | elif output_stride == 8: 101 | strides = [1, 2, 1, 1] 102 | dilations = [1, 1, 2, 4] 103 | else: 104 | raise NotImplementedError 105 | 106 | # Modules 107 | self.conv1 = nn.Conv2d(input_channels, 64, kernel_size=7, stride=2, padding=3, bias=False) 108 | self.bn1 = norm_layer(64) 109 | self.relu = nn.ReLU(inplace=True) 110 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 111 | 112 | self.layer1 = self._make_layer(block, 64, layers[0], stride=strides[0], dilation=dilations[0], norm_layer=norm_layer) 113 | self.layer2 = self._make_layer(block, 128, layers[1], stride=strides[1], dilation=dilations[1], norm_layer=norm_layer) 114 | self.layer3 = self._make_layer(block, 256, layers[2], stride=strides[2], dilation=dilations[2], norm_layer=norm_layer) 115 | self.layer4 = self._make_MG_unit(block, 512, blocks=blocks, stride=strides[3], dilation=dilations[3], 
norm_layer=norm_layer) 116 | # self.layer4 = self._make_layer(block, 512, layers[3], stride=strides[3], dilation=dilations[3], nn.BatchNorm2d=nn.BatchNorm2d) 117 | self._init_weight() 118 | 119 | def _init_weight(self): 120 | for m in self.modules(): 121 | if isinstance(m, nn.Conv2d): 122 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 123 | elif isinstance(m, nn.BatchNorm2d): 124 | nn.init.constant_(m.weight, 1) 125 | nn.init.constant_(m.bias, 0) 126 | 127 | 128 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, norm_layer=nn.BatchNorm2d): 129 | downsample = None 130 | if stride != 1 or self.inplanes != planes * block.expansion: 131 | downsample = nn.Sequential( 132 | nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False), 133 | norm_layer(planes * block.expansion), 134 | ) 135 | 136 | layers = [] 137 | layers.append(block(self.inplanes, planes, stride, dilation, downsample, norm_layer=norm_layer)) 138 | self.inplanes = planes * block.expansion 139 | 140 | for i in range(1, blocks): 141 | layers.append(block(self.inplanes, planes, dilation=dilation, norm_layer=norm_layer)) 142 | 143 | return nn.Sequential(*layers) 144 | 145 | def _make_MG_unit(self, block, planes, blocks, stride=1, dilation=1, norm_layer=nn.BatchNorm2d): 146 | downsample = None 147 | if stride != 1 or self.inplanes != planes * block.expansion: 148 | downsample = nn.Sequential( 149 | nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False), 150 | norm_layer(planes * block.expansion), 151 | ) 152 | 153 | layers = [] 154 | layers.append(block(self.inplanes, planes, stride, dilation=blocks[0]*dilation, downsample=downsample, norm_layer=norm_layer)) 155 | self.inplanes = planes * block.expansion 156 | 157 | for i in range(1, len(blocks)): 158 | layers.append(block(self.inplanes, planes, stride=1, dilation=blocks[i]*dilation, norm_layer=norm_layer)) 159 | 160 | return nn.Sequential(*layers) 161 | 162 | def forward(self, x): 163 | x = self.conv1(x) 164 | x = self.bn1(x) 165 | x = self.relu(x) 166 | x = self.maxpool(x) 167 | 168 | x = self.layer1(x) 169 | low_level_feat = x 170 | x = self.layer2(x) 171 | x = self.layer3(x) 172 | x = self.layer4(x) 173 | 174 | return x, low_level_feat 175 | 176 | 177 | def ResNet18(output_stride, norm_layer=nn.BatchNorm2d, pretrained=True, input_channels=3): 178 | model = ResNet(BasicBlock, [2, 2, 2, 2], output_stride, norm_layer=norm_layer, input_channels=input_channels) 179 | 180 | if pretrained: 181 | _load_pretrained_model(model, RESNET_18) 182 | 183 | return model 184 | 185 | 186 | def ResNet34(output_stride, norm_layer=nn.BatchNorm2d, pretrained=True, input_channels=3): 187 | model = ResNet(BasicBlock, [3, 4, 23, 3], output_stride, norm_layer=norm_layer, input_channels=input_channels) 188 | 189 | if pretrained: 190 | _load_pretrained_model(model, RESNET_34) 191 | 192 | return model 193 | 194 | def ResNet101(output_stride, norm_layer=nn.BatchNorm2d, pretrained=True, input_channels=3): 195 | model = ResNet(Bottleneck, [3, 4, 23, 3], output_stride, norm_layer=norm_layer, input_channels=input_channels) 196 | 197 | if pretrained: 198 | _load_pretrained_model(model, RESNET_101) 199 | 200 | return model 201 | 202 | 203 | def ResNet50(output_stride, norm_layer=nn.BatchNorm2d, pretrained=True, input_channels=3): 204 | model = ResNet(Bottleneck, [3, 4, 6, 3], output_stride, norm_layer=norm_layer, input_channels=input_channels) 205 | 206 | if pretrained: 207 | 
_load_pretrained_model(model, RESNET_50) 208 | 209 | return model 210 | 211 | 212 | def _load_pretrained_model(model, url): 213 | pretrain_dict = model_zoo.load_url(url) 214 | model_dict = {} 215 | state_dict = model.state_dict() 216 | for k, v in pretrain_dict.items(): 217 | if k in state_dict: 218 | model_dict[k] = v 219 | state_dict.update(model_dict) 220 | model.load_state_dict(state_dict) -------------------------------------------------------------------------------- /core/models/unet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Lighter U-net implementation that achieves same performance as the one reported in the paper: https://arxiv.org/abs/1505.04597 3 | Main differences: 4 | a) U-net downblock has only 1 convolution instead of 2 5 | b) U-net upblock has only 1 convolution instead of 3 6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | from constants import * 11 | 12 | class UNetDownBlock(nn.Module): 13 | """ 14 | Constructs a UNet downsampling block 15 | 16 | Parameters: 17 | input_nc (int) -- the number of input channels 18 | output_nc (int) -- the number of output channels 19 | norm_layer (str) -- normalization layer 20 | down_type (str) -- if we should use strided convolution or maxpool for reducing the feature map 21 | outermost (bool) -- if this module is the outermost module 22 | innermost (bool) -- if this module is the innermost module 23 | user_dropout (bool) -- if use dropout layers. 24 | kernel_size (int) -- convolution kernel size 25 | bias (boolean) -- if convolution should use bias 26 | """ 27 | def __init__(self, input_nc, output_nc, norm_layer=nn.BatchNorm2d, down_type=STRIDECONV, outermost=False, innermost=False, dropout=0.2, kernel_size=4, bias=True): 28 | super(UNetDownBlock, self).__init__() 29 | self.innermost = innermost 30 | self.outermost = outermost 31 | self.use_maxpool = down_type == MAXPOOL 32 | 33 | stride = 1 if self.use_maxpool else 2 34 | kernel_size = 3 if self.use_maxpool else 4 35 | self.conv = nn.Conv2d(input_nc, output_nc, kernel_size=kernel_size, stride=stride, padding=1, bias=bias) 36 | self.relu = nn.LeakyReLU(0.2) 37 | self.maxpool = nn.MaxPool2d(2) 38 | self.norm = norm_layer(output_nc) 39 | self.dropout = nn.Dropout2d(dropout) if dropout else None 40 | 41 | def forward(self, x): 42 | if self.outermost: 43 | x = self.conv(x) 44 | x = self.norm(x) 45 | elif self.innermost: 46 | x = self.relu(x) 47 | if self.dropout: x = self.dropout(x) 48 | x = self.conv(x) 49 | else: 50 | x = self.relu(x) 51 | if self.dropout: x = self.dropout(x) 52 | x = self.conv(x) 53 | x = self.norm(x) 54 | 55 | return x 56 | 57 | class UNetUpBlock(nn.Module): 58 | """ 59 | Constructs a UNet upsampling block 60 | 61 | Parameters: 62 | input_nc (int) -- the number of input channels 63 | output_nc (int) -- the number of output channels 64 | norm_layer -- normalization layer 65 | outermost (bool) -- if this module is the outermost module 66 | innermost (bool) -- if this module is the innermost module 67 | user_dropout (bool) -- if use dropout layers. 
68 | kernel_size (int) -- convolution kernel size 69 | """ 70 | def __init__(self, input_nc, output_nc, norm_layer=nn.BatchNorm2d, outermost=False, innermost=False, dropout=0.2, kernel_size=4, use_bias=True): 71 | super(UNetUpBlock, self).__init__() 72 | self.innermost = innermost 73 | self.outermost = outermost 74 | upconv_inner_nc = input_nc * 2 75 | 76 | if self.innermost: 77 | self.conv = nn.ConvTranspose2d(input_nc, output_nc, kernel_size=kernel_size, stride=2, padding=1, bias=use_bias) 78 | elif self.outermost: 79 | self.conv = nn.ConvTranspose2d(upconv_inner_nc, output_nc, kernel_size=kernel_size, stride=2, padding=1) 80 | else: 81 | self.conv = nn.ConvTranspose2d(upconv_inner_nc, output_nc, kernel_size=kernel_size, stride=2, padding=1, bias=use_bias) 82 | 83 | self.norm = norm_layer(output_nc) 84 | self.relu = nn.ReLU() 85 | self.dropout = nn.Dropout2d(dropout) if dropout else None 86 | 87 | def forward(self, x): 88 | if self.outermost: 89 | x = self.relu(x) 90 | if self.dropout: x = self.dropout(x) 91 | x = self.conv(x) 92 | elif self.innermost: 93 | x = self.relu(x) 94 | if self.dropout: x = self.dropout(x) 95 | x = self.conv(x) 96 | x = self.norm(x) 97 | else: 98 | x = self.relu(x) 99 | if self.dropout: x = self.dropout(x) 100 | x = self.conv(x) 101 | x = self.norm(x) 102 | 103 | return x 104 | 105 | class UNet(nn.Module): 106 | """Create a Unet-based Fully Convolutional Network 107 | X -------------------identity---------------------- 108 | |-- downsampling -- |submodule| -- upsampling --| 109 | 110 | Parameters: 111 | num_classes (int) -- the number of channels in output images 112 | norm_layer -- normalization layer 113 | input_nc -- number of channels of input image 114 | 115 | Args: 116 | mode (str) -- process single frames or sequence of frames 117 | timesteps (int) -- 118 | num_downs (int) -- the number of downsamplings in UNet. For example, # if |num_downs| == 7, 119 | image of size 128x128 will become of size 1x1 # at the bottleneck 120 | ngf (int) -- the number of filters in the last conv layer 121 | reconstruct (int [0,1])-- if we should reconstruct the next image or not 122 | sequence_model (str) -- the sequence model that for the sequence mode [] 123 | num_levels_tcn(int) -- number of levels of the TemporalConvNet 124 | """ 125 | 126 | def __init__(self, num_classes, args, norm_layer=nn.BatchNorm2d, input_nc=3): 127 | super(UNet, self).__init__() 128 | 129 | self.refine_network = args.refine_network 130 | self.num_downs = args.num_downs 131 | self.ngf = args.ngf 132 | 133 | self.encoder = self.build_encoder(self.num_downs, input_nc, self.ngf, norm_layer, down_type=args.down_type) 134 | self.decoder = self.build_decoder(self.num_downs, num_classes, self.ngf, norm_layer) 135 | 136 | if self.refine_network: 137 | self.refine_encoder = self.build_encoder(self.num_downs, input_nc + num_classes, self.ngf, norm_layer, down_type=args.down_type) 138 | self.refine_decoder = self.build_decoder(self.num_downs, num_classes, self.ngf, norm_layer) 139 | 140 | def build_encoder(self, num_downs, input_nc, ngf, norm_layer, down_type=STRIDECONV): 141 | """Constructs a UNet downsampling encoder, consisting of $num_downs UNetDownBlocks 142 | 143 | Parameters: 144 | num_downs (int) -- the number of downsamplings in UNet. 
For example, # if |num_downs| == 7, 145 | image of size 128x128 will become of size 1x1 # at the bottleneck 146 | input_nc (int) -- the number of input channels 147 | ngf (int) -- the number of filters in the last conv layer 148 | norm_layer (str) -- normalization layer 149 | down_type (str) -- if we should use strided convolution or maxpool for reducing the feature map 150 | Returns: 151 | nn.Sequential consisting of $num_downs UnetDownBlocks 152 | """ 153 | layers = [] 154 | layers.append(UNetDownBlock(input_nc=input_nc, output_nc=ngf, norm_layer=norm_layer, down_type=down_type, outermost=True)) 155 | layers.append(UNetDownBlock(input_nc=ngf, output_nc=ngf*2, norm_layer=norm_layer, down_type=down_type)) 156 | layers.append(UNetDownBlock(input_nc=ngf*2, output_nc=ngf*4, norm_layer=norm_layer, down_type=down_type)) 157 | layers.append(UNetDownBlock(input_nc=ngf*4, output_nc=ngf*8, norm_layer=norm_layer, down_type=down_type)) 158 | 159 | for i in range(num_downs - 5): # add intermediate layers with ngf * 8 filters 160 | layers.append(UNetDownBlock(input_nc=ngf*8, output_nc=ngf*8, norm_layer=norm_layer, down_type=down_type)) 161 | 162 | layers.append(UNetDownBlock(input_nc=ngf*8, output_nc=ngf*8, norm_layer=norm_layer, down_type=down_type, innermost=True)) 163 | 164 | return nn.Sequential(*layers) 165 | 166 | def build_decoder(self, num_downs, num_classes, ngf, norm_layer): 167 | """Constructs a UNet downsampling encoder, consisting of $num_downs UNetUpBlocks 168 | 169 | Parameters: 170 | num_downs (int) -- the number of downsamplings in UNet. For example, # if |num_downs| == 7, 171 | image of size 128x128 will become of size 1x1 # at the bottleneck 172 | num_classes (int) -- number of classes to classify 173 | output_nc (int) -- the number of output channels. 
outermost is ngf, innermost is ngf * 8 174 | norm_layer -- normalization layer 175 | 176 | Returns: 177 | nn.Sequential consisting of $num_downs UnetUpBlocks 178 | """ 179 | layers = [] 180 | layers.append(UNetUpBlock(input_nc=ngf * 8, output_nc=ngf * 8, norm_layer=norm_layer, innermost=True)) 181 | 182 | for i in range(num_downs - 5): # add intermediate layers with ngf * 8 filters 183 | layers.append(UNetUpBlock(input_nc=ngf * 8, output_nc=ngf * 8, norm_layer=norm_layer)) 184 | 185 | layers.append(UNetUpBlock(input_nc=ngf * 8, output_nc=ngf * 4, norm_layer=norm_layer)) 186 | layers.append(UNetUpBlock(input_nc=ngf * 4, output_nc=ngf * 2, norm_layer=norm_layer)) 187 | layers.append(UNetUpBlock(input_nc=ngf*2, output_nc=ngf, norm_layer=norm_layer)) 188 | layers.append(UNetUpBlock(input_nc=ngf, output_nc=num_classes, norm_layer=norm_layer, outermost=True)) 189 | 190 | return nn.Sequential(*layers) 191 | 192 | def encoder_forward(self, x, use_refine_network=False): 193 | skip_connections = [] 194 | model = self.refine_encoder if use_refine_network else self.encoder 195 | 196 | for i, down in enumerate(model): 197 | x = down(x) 198 | if down.use_maxpool: 199 | x = down.maxpool(x) 200 | 201 | if not down.innermost: 202 | skip_connections.append(x) 203 | 204 | return x, skip_connections 205 | 206 | def decoder_forward(self, x, skip_connections, use_refine_network=False): 207 | model = self.refine_decoder if use_refine_network else self.decoder 208 | 209 | for i, up in enumerate(model): 210 | if not up.innermost: 211 | skip = skip_connections[-i] 212 | out = torch.cat([skip, out], 1) 213 | out = up(out) 214 | else: 215 | out = up(x) 216 | 217 | return out 218 | 219 | def forward(self, x): 220 | output, skip_connections = self.encoder_forward(x) 221 | output = self.decoder_forward(output, skip_connections) 222 | 223 | if self.refine_network: 224 | second_output, skip_connections = self.encoder_forward(torch.cat((x, output), dim=1), use_refine_network=True) 225 | second_output = self.decoder_forward(second_output, skip_connections, use_refine_network=True) 226 | 227 | return output, second_output 228 | 229 | return output 230 | 231 | def get_train_parameters(self, lr): 232 | params = [{'params': self.parameters(), 'lr': lr}] 233 | 234 | return params -------------------------------------------------------------------------------- /core/models/unet_paper.py: -------------------------------------------------------------------------------- 1 | """ 2 | U-net implementation from the reported paper: https://arxiv.org/abs/1505.04597 3 | Model follows the paper implementation, except for the use of padding in order to keep feature maps size the same. 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | class UnetConvBlock(nn.Module): 11 | """ 12 | Constructs a UNet downsampling block 13 | 14 | Parameters: 15 | input_nc (int) -- the number of input channels 16 | output_nc (int) -- the number of output channels 17 | norm_layer -- normalization layer 18 | outermost (bool) -- if this module is the outermost module 19 | innermost (bool) -- if this module is the innermost module 20 | user_dropout (bool) -- if use dropout layers. 
21 | kernel_size (int) -- convolution kernel size 22 | """ 23 | def __init__(self, input_nc, output_nc, norm_layer=nn.BatchNorm2d, padding=0, innermost=False, dropout=0.2): 24 | super(UnetConvBlock, self).__init__() 25 | block = [] 26 | 27 | block.append(nn.Conv2d(input_nc, output_nc, kernel_size=3, padding=int(padding))) 28 | block.append(norm_layer(output_nc)) 29 | block.append(nn.ReLU()) 30 | 31 | block.append(nn.Conv2d(output_nc, output_nc, kernel_size=3, padding=int(padding))) 32 | block.append(norm_layer(output_nc)) 33 | block.append(nn.ReLU()) 34 | 35 | self.block = nn.Sequential(*block) 36 | self.innermost = innermost 37 | 38 | def forward(self, x): 39 | out = self.block(x) 40 | return out 41 | 42 | class UNetUpBlock(nn.Module): 43 | """ 44 | Constructs a UNet upsampling block 45 | 46 | Parameters: 47 | input_nc (int) -- the number of input channels 48 | output_nc (int) -- the number of output channels 49 | norm_layer -- normalization layer 50 | outermost (bool) -- if this module is the outermost module 51 | innermost (bool) -- if this module is the innermost module 52 | user_dropout (bool) -- if use dropout layers. 53 | kernel_size (int) -- convolution kernel size 54 | remove_skip (bool) -- if skip connections should be disabled or not 55 | """ 56 | def __init__(self, input_nc, output_nc, norm_layer=nn.BatchNorm2d, padding=1, remove_skip=False, outermost=False): 57 | super(UNetUpBlock, self).__init__() 58 | 59 | self.up = nn.ConvTranspose2d(input_nc, output_nc, kernel_size=2, stride=2) 60 | self.conv_block = UnetConvBlock(output_nc * 2, output_nc, norm_layer, padding) 61 | self.outermost = outermost 62 | 63 | def forward(self, x, skip=None): 64 | out = self.up(x) 65 | 66 | if skip is not None: 67 | out = torch.cat([out, skip], 1) 68 | out = self.conv_block(out) 69 | 70 | return out 71 | 72 | class UNet_paper(nn.Module): 73 | """Create a Unet-based Fully Convolutional Network 74 | X -------------------identity---------------------- 75 | |-- downsampling -- |submodule| -- upsampling --| 76 | 77 | Parameters: 78 | num_classes (int) -- the number of channels in output images 79 | norm_layer -- normalization layer 80 | input_nc -- number of channels of input image 81 | 82 | Args: 83 | mode (str) -- process single frames or sequence of frames 84 | timesteps (int) -- 85 | num_downs (int) -- the number of downsamplings in UNet. For example, # if |num_downs| == 7, 86 | image of size 128x128 will become of size 1x1 # at the bottleneck 87 | ngf (int) -- the number of filters in the last conv layer 88 | remove_skip (int [0,1])-- if skip connections should be disabled or not 89 | reconstruct (int [0,1])-- if we should reconstruct the next image or not 90 | sequence_model (str) -- the sequence model that for the sequence mode [] 91 | num_levels_tcn(int) -- number of levels of the TemporalConvNet 92 | """ 93 | 94 | def __init__(self, num_classes, args, norm_layer=nn.BatchNorm2d, input_nc=3): 95 | super(UNet_paper, self).__init__(args) 96 | 97 | self.num_downs = args.num_downs 98 | self.ngf = args.ngf 99 | 100 | self.encoder = self.build_encoder(self.num_downs, input_nc, self.ngf, norm_layer) 101 | self.decoder = self.build_decoder(self.num_downs, num_classes, self.ngf, norm_layer) 102 | self.decoder_last_conv = nn.Conv2d(self.ngf, num_classes, 1) 103 | 104 | def build_encoder(self, num_downs, input_nc, ngf, norm_layer): 105 | """Constructs a UNet downsampling encoder, consisting of $num_downs UNetDownBlocks 106 | 107 | Parameters: 108 | num_downs (int) -- the number of downsamplings in UNet. 
For example, # if |num_downs| == 7, 109 | image of size 128x128 will become of size 1x1 # at the bottleneck 110 | input_nc (int) -- the number of input channels 111 | ngf (int) -- the number of filters in the last conv layer 112 | norm_layer -- normalization layer 113 | Returns: 114 | nn.Sequential consisting of $num_downs UnetDownBlocks 115 | """ 116 | layers = [] 117 | layers.append(UnetConvBlock(input_nc=input_nc, output_nc=ngf, norm_layer=norm_layer, padding=1)) 118 | layers.append(UnetConvBlock(input_nc=ngf, output_nc=ngf * 2, norm_layer=norm_layer, padding=1)) 119 | layers.append(UnetConvBlock(input_nc=ngf * 2, output_nc=ngf * 4, norm_layer=norm_layer, padding=1)) 120 | layers.append(UnetConvBlock(input_nc=ngf * 4, output_nc=ngf * 8, norm_layer=norm_layer, padding=1)) 121 | 122 | for i in range(num_downs - 5): # add intermediate layers with ngf * 8 filters 123 | layers.append(UnetConvBlock(input_nc=ngf * 8, output_nc=ngf * 8, norm_layer=norm_layer, padding=1)) 124 | 125 | layers.append(UnetConvBlock(input_nc=ngf * 8, output_nc=ngf * 8, norm_layer=norm_layer, padding=1, innermost=True)) 126 | 127 | return nn.Sequential(*layers) 128 | 129 | def build_decoder(self, num_downs, num_classes, ngf, norm_layer, remove_skip=0): 130 | """Constructs a UNet downsampling encoder, consisting of $num_downs UNetUpBlocks 131 | 132 | Parameters: 133 | num_downs (int) -- the number of downsamplings in UNet. For example, # if |num_downs| == 7, 134 | image of size 128x128 will become of size 1x1 # at the bottleneck 135 | num_classes (int) -- number of classes to classify 136 | output_nc (int) -- the number of output channels. outermost is ngf, innermost is ngf * 8 137 | norm_layer -- normalization layer 138 | remove_skip (int) -- if skip connections should be disabled or not 139 | 140 | Returns: 141 | nn.Sequential consisting of $num_downs UnetUpBlocks 142 | """ 143 | layers = [] 144 | layers.append(UNetUpBlock(input_nc=ngf * 8, output_nc=ngf * 8, norm_layer=norm_layer, remove_skip=remove_skip)) 145 | 146 | for i in range(num_downs - 5): # add intermediate layers with ngf * 8 filters 147 | layers.append(UNetUpBlock(input_nc=ngf * 8, output_nc=ngf * 8, norm_layer=norm_layer, remove_skip=remove_skip)) 148 | 149 | layers.append(UNetUpBlock(input_nc=ngf * 8, output_nc=ngf * 4, norm_layer=norm_layer, remove_skip=remove_skip)) 150 | layers.append(UNetUpBlock(input_nc=ngf * 4, output_nc=ngf * 2, norm_layer=norm_layer, remove_skip=remove_skip)) 151 | layers.append(UNetUpBlock(input_nc=ngf*2, output_nc=ngf, norm_layer=norm_layer, remove_skip=remove_skip, outermost=True)) 152 | 153 | return nn.Sequential(*layers) 154 | 155 | def encoder_forward(self, x): 156 | skip_connections = [] 157 | for i, down in enumerate(self.encoder): 158 | x = down(x) 159 | 160 | if not down.innermost: 161 | skip_connections.append(x) 162 | x = F.max_pool2d(x, 2) 163 | 164 | return x, skip_connections 165 | 166 | def decoder_forward(self, x, skip_connections): 167 | out = None 168 | for i, up in enumerate(self.decoder): 169 | skip = skip_connections.pop() 170 | if out is None: 171 | out = up(x, skip) 172 | else: 173 | out = up(out, skip) 174 | 175 | out = self.decoder_last_conv(out) 176 | return out 177 | 178 | def forward(self, x): 179 | x, skip_connections = self.encoder_forward(x) 180 | out = self.decoder_forward(x, skip_connections) 181 | 182 | return out -------------------------------------------------------------------------------- /core/models/unet_pytorch.py: 
-------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | class UNet_torch(nn.Module): 7 | 8 | def __init__(self, num_classes=1, args=None, in_channels=3): 9 | super(UNet_torch, self).__init__() 10 | self.ngf = args.ngf 11 | 12 | self.encoder1 = UNet_torch._block(in_channels, self.ngf, name="enc1") 13 | self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2) 14 | 15 | self.encoder2 = UNet_torch._block(self.ngf, self.ngf * 2, name="enc2") 16 | self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2) 17 | 18 | self.encoder3 = UNet_torch._block(self.ngf * 2, self.ngf * 4, name="enc3") 19 | self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2) 20 | 21 | self.encoder4 = UNet_torch._block(self.ngf * 4, self.ngf * 8, name="enc4") 22 | self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2) 23 | 24 | self.bottleneck = UNet_torch._block(self.ngf * 8, self.ngf * 16, name="bottleneck") 25 | 26 | self.upconv4 = nn.ConvTranspose2d( 27 | self.ngf * 16, self.ngf * 8, kernel_size=2, stride=2 28 | ) 29 | self.decoder4 = UNet_torch._block((self.ngf * 8) * 2, self.ngf * 8, name="dec4") 30 | self.upconv3 = nn.ConvTranspose2d( 31 | self.ngf * 8, self.ngf * 4, kernel_size=2, stride=2 32 | ) 33 | self.decoder3 = UNet_torch._block((self.ngf * 4) * 2, self.ngf * 4, name="dec3") 34 | self.upconv2 = nn.ConvTranspose2d( 35 | self.ngf * 4, self.ngf * 2, kernel_size=2, stride=2 36 | ) 37 | self.decoder2 = UNet_torch._block((self.ngf * 2) * 2, self.ngf * 2, name="dec2") 38 | self.upconv1 = nn.ConvTranspose2d( 39 | self.ngf * 2, self.ngf, kernel_size=2, stride=2 40 | ) 41 | self.decoder1 = UNet_torch._block(self.ngf * 2, self.ngf, name="dec1") 42 | 43 | self.conv = nn.Conv2d( 44 | in_channels=self.ngf, out_channels=num_classes, kernel_size=1 45 | ) 46 | 47 | def forward(self, x): 48 | enc1 = self.encoder1(x) 49 | enc2 = self.encoder2(self.pool1(enc1)) 50 | enc3 = self.encoder3(self.pool2(enc2)) 51 | enc4 = self.encoder4(self.pool3(enc3)) 52 | bottleneck = self.bottleneck(self.pool4(enc4)) 53 | 54 | dec4 = self.upconv4(bottleneck) 55 | dec4 = torch.cat((dec4, enc4), dim=1) 56 | dec4 = self.decoder4(dec4) 57 | dec3 = self.upconv3(dec4) 58 | dec3 = torch.cat((dec3, enc3), dim=1) 59 | dec3 = self.decoder3(dec3) 60 | dec2 = self.upconv2(dec3) 61 | dec2 = torch.cat((dec2, enc2), dim=1) 62 | dec2 = self.decoder2(dec2) 63 | dec1 = self.upconv1(dec2) 64 | dec1 = torch.cat((dec1, enc1), dim=1) 65 | dec1 = self.decoder1(dec1) 66 | 67 | final = self.conv(dec1) 68 | return final 69 | 70 | @staticmethod 71 | def _block(in_channels, features, name): 72 | return nn.Sequential( 73 | OrderedDict( 74 | [ 75 | ( 76 | name + "conv1", 77 | nn.Conv2d( 78 | in_channels=in_channels, 79 | out_channels=features, 80 | kernel_size=3, 81 | padding=1, 82 | bias=False, 83 | ), 84 | ), 85 | (name + "norm1", nn.BatchNorm2d(num_features=features)), 86 | (name + "relu1", nn.ReLU(inplace=True)), 87 | ( 88 | name + "conv1", 89 | nn.Conv2d( 90 | in_channels=in_channels, 91 | out_channels=features, 92 | kernel_size=3, 93 | padding=1, 94 | bias=False, 95 | ), 96 | ), 97 | (name + "norm1", nn.BatchNorm2d(num_features=features)), 98 | (name + "relu1", nn.ReLU(inplace=True)), 99 | ] 100 | ) 101 | ) 102 | 103 | def get_train_parameters(self, lr): 104 | params = [{'params': self.parameters(), 'lr': lr}] 105 | 106 | return params -------------------------------------------------------------------------------- /core/trainers/trainer.py: 
-------------------------------------------------------------------------------- 1 | import copy 2 | from tqdm import tqdm 3 | import torch 4 | import math 5 | import matplotlib.pyplot as plt 6 | import time 7 | 8 | from util.general_functions import get_model, get_optimizer, make_data_loader, get_loss_function, get_flat_images 9 | from util.lr_scheduler import LR_Scheduler 10 | from util.summary import TensorboardSummary 11 | from constants import * 12 | from util.ssim import MS_SSIM, SSIM 13 | 14 | 15 | class Trainer(object): 16 | 17 | def __init__(self, args): 18 | self.args = args 19 | self.best_loss = math.inf 20 | self.summary = TensorboardSummary(args) 21 | self.model = get_model(args) 22 | 23 | if args.inference: 24 | self.model = self.summary.load_network(self.model) 25 | 26 | if args.save_best_model: 27 | self.best_model = copy.deepcopy(self.model) 28 | 29 | self.optimizer = get_optimizer(self.model, args) 30 | self.ssim, self.ms_ssim = SSIM(), MS_SSIM() 31 | 32 | if args.trainval: 33 | self.train_loader, self.val_loader = make_data_loader(args, TRAINVAL), make_data_loader(args, TEST) 34 | else: 35 | self.train_loader, self.test_loader = make_data_loader(args, TRAIN), make_data_loader(args, TEST) 36 | 37 | self.criterion = get_loss_function(args.loss_type) 38 | self.scheduler = LR_Scheduler(args.lr_policy, args.lr, args.epochs, len(self.train_loader)) 39 | 40 | if args.second_loss: 41 | self.second_criterion = get_loss_function(MS_SSIM_LOSS) 42 | 43 | def run_epoch(self, epoch, split=TRAIN): 44 | total_loss = 0.0 45 | ssim_values, ms_ssim_values = [], [] 46 | 47 | if split == TRAIN: 48 | self.model.train() 49 | loader = self.train_loader 50 | elif split == VAL: 51 | self.model.eval() 52 | loader = self.val_loader 53 | else: 54 | self.model.eval() 55 | loader = self.test_loader 56 | 57 | bar = tqdm(loader) 58 | num_img = len(loader) 59 | 60 | for i, sample in enumerate(bar): 61 | with torch.autograd.set_detect_anomaly(True): 62 | image = sample[0] 63 | target = sample[1] 64 | 65 | if self.args.cuda: 66 | image, target = image.cuda(), target.cuda() 67 | 68 | if split == TRAIN: 69 | self.scheduler(self.optimizer, i, epoch, self.best_loss) 70 | self.optimizer.zero_grad() 71 | 72 | if self.args.refine_network: 73 | first_output, output = self.model(image) 74 | else: 75 | output = self.model(image) 76 | else: 77 | with torch.no_grad(): 78 | if self.args.refine_network: 79 | first_output, output = self.model(image) 80 | else: 81 | output = self.model(image) 82 | 83 | loss = self.criterion(output, target) 84 | if self.args.refine_network: 85 | refine_loss = self.criterion(first_output, target) 86 | loss += refine_loss 87 | 88 | if self.args.second_loss: 89 | flat_output_img, flat_target_img = get_flat_images(self.args.dataset, image, output, target) 90 | second_loss = self.second_criterion(flat_output_img, flat_target_img) 91 | loss += self.args.second_loss_rate * second_loss 92 | 93 | if self.args.refine_network: 94 | flat_first_output_img, flat_target_img = get_flat_images(self.args.dataset, image, first_output, target) 95 | third_loss = self.second_criterion(flat_first_output_img, flat_target_img) 96 | loss += third_loss 97 | 98 | if split == TRAIN: 99 | loss.backward() 100 | 101 | if self.args.clip > 0: 102 | if self.args.gpu_ids: 103 | torch.nn.utils.clip_grad_norm_(self.model.module().parameters(), self.args.clip) 104 | else: 105 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.clip) 106 | 107 | self.optimizer.step() 108 | 109 | if split == TEST: 110 | 
ssim_values.append(self.ssim.forward(output, target)) 111 | ms_ssim_values.append(self.ms_ssim.forward(output, target)) 112 | 113 | # Show 10 * 3 inference results each epoch 114 | if split != VISUALIZATION and i % (num_img // 10) == 0: 115 | self.summary.visualize_image(image, target, output, split=split) 116 | elif split == VISUALIZATION: 117 | self.summary.visualize_image(image, target, output, split=split) 118 | 119 | total_loss += loss.item() 120 | bar.set_description(split +' loss: %.3f' % (loss.item())) 121 | 122 | if split == TEST: 123 | ssim = sum(ssim_values) / len(ssim_values) 124 | ms_ssim = sum(ms_ssim_values) / len(ms_ssim_values) 125 | self.summary.add_scalar(split + '/ssim', ssim, epoch) 126 | self.summary.add_scalar(split + '/ms_ssim', ms_ssim, epoch) 127 | 128 | if total_loss < self.best_loss: 129 | self.best_loss = total_loss 130 | if self.args.save_best_model: 131 | self.best_model = copy.deepcopy(self.model) 132 | 133 | self.summary.add_scalar(split + '/total_loss_epoch', total_loss, epoch) 134 | print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0])) 135 | 136 | def calculate_inference_speed(self, iterations): 137 | loader = self.test_loader 138 | bar = tqdm(loader) 139 | self.model.eval() 140 | times = [] 141 | 142 | for i, sample in enumerate(bar): 143 | image = sample[0] 144 | 145 | start = time.time() 146 | with torch.no_grad(): 147 | output = self.model(image) 148 | end = time.time() 149 | current_time = end - start 150 | print(current_time) 151 | times.append(current_time) 152 | 153 | if i>= iterations: 154 | break 155 | 156 | return sum(times) / len(times) 157 | 158 | def save_network(self): 159 | self.summary.save_network(self.model) -------------------------------------------------------------------------------- /dataloader/docunet.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | from torch.utils import data 4 | import torchvision.transforms as standard_transforms 5 | import util.custom_transforms as custom_transforms 6 | from scipy import io 7 | 8 | class Docunet(data.Dataset): 9 | 10 | NUM_CLASSES = 2 11 | CLASSES = ["foreground", "background"] 12 | ROOT = '../../../datasets/' 13 | DEFORMED = 'deformed_labels' 14 | DEFORMED_EXT = '.jpg' 15 | VECTOR_FIELD = 'target_vf' 16 | VECTOR_FIELD_EXT = '.mat' 17 | 18 | 19 | def __init__(self, args, split="train"): 20 | self.args = args 21 | self.split = split 22 | self.dataset = self.make_dataset() 23 | 24 | if len(self.dataset) == 0: 25 | raise RuntimeError('Found 0 images, please check the dataset') 26 | 27 | self.transform = self.get_transforms() 28 | 29 | def __len__(self): 30 | return len(self.dataset) 31 | 32 | def __getitem__(self, index): 33 | image_path, label_path = self.dataset[index] 34 | image = Image.open(image_path) 35 | label = io.loadmat(label_path)['vector_field'] 36 | 37 | if self.transform is not None: 38 | image = self.transform(image) 39 | 40 | image, label = standard_transforms.ToTensor()(image), standard_transforms.ToTensor()(label) 41 | 42 | return image, label 43 | 44 | def make_dataset(self): 45 | current_dir = os.path.dirname(__file__) 46 | images_path = os.path.join(current_dir, self.ROOT, self.args.dataset_dir, self.split, self.DEFORMED + '_' + 'x'.join(map(str, self.args.size))) 47 | labels_path = os.path.join(current_dir, self.ROOT, self.args.dataset_dir, self.split, self.VECTOR_FIELD + '_' + 'x'.join(map(str, self.args.size))) 48 | 49 | images_name = 
os.listdir(images_path) 50 | images_name = [image_name for image_name in images_name if image_name.endswith(self.DEFORMED_EXT)] 51 | items = [] 52 | 53 | for i in range(len(images_name)): 54 | image_name = images_name[i] 55 | label_name = image_name.replace(self.DEFORMED_EXT, self.VECTOR_FIELD_EXT) 56 | items.append((os.path.join(images_path, image_name), os.path.join(labels_path, label_name))) 57 | 58 | return items 59 | 60 | def get_transforms(self): 61 | return None 62 | 63 | 64 | 65 | 66 | -------------------------------------------------------------------------------- /dataloader/docunet_im2im.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | from torch.utils import data 4 | import torchvision.transforms as standard_transforms 5 | import util.custom_transforms as custom_transforms 6 | from scipy import io 7 | 8 | 9 | class DocunetIm2Im(data.Dataset): 10 | NUM_CLASSES = 2 11 | CLASSES = ["foreground", "background"] 12 | ROOT = '../../../datasets/' 13 | DEFORMED = 'deformed_labels' 14 | LABELS = 'cropped_labels' 15 | DEFORMED_EXT = '.jpg' 16 | LABEL_EXT = '.jpg' 17 | 18 | def __init__(self, args, split="train"): 19 | self.args = args 20 | self.split = split 21 | self.dataset = self.make_dataset() 22 | 23 | if len(self.dataset) == 0: 24 | raise RuntimeError('Found 0 images, please check the dataset') 25 | 26 | self.joint_transform = self.get_transforms() 27 | 28 | def __len__(self): 29 | return len(self.dataset) 30 | 31 | def __getitem__(self, index): 32 | image_path, label_path = self.dataset[index] 33 | image = Image.open(image_path) 34 | label = Image.open(label_path) 35 | 36 | if self.joint_transform is not None: 37 | image, label = self.joint_transform(image, label) 38 | 39 | image, label = standard_transforms.ToTensor()(image), standard_transforms.ToTensor()(label) 40 | 41 | return image, label 42 | 43 | def make_dataset(self): 44 | current_dir = os.path.dirname(__file__) 45 | images_path = os.path.join(current_dir, self.ROOT, self.args.dataset_dir, self.split, self.DEFORMED + '_' + 'x'.join(map(str, self.args.size))) 46 | labels_path = os.path.join(current_dir, self.ROOT, self.args.dataset_dir, self.split, self.LABELS) 47 | 48 | images_name = os.listdir(images_path) 49 | images_name = [image_name for image_name in images_name if image_name.endswith(self.DEFORMED_EXT)] 50 | items = [] 51 | 52 | for i in range(len(images_name)): 53 | image_name = images_name[i] 54 | label_name = '_'.join(image_name.split('_')[:-1]) + self.LABEL_EXT 55 | items.append((os.path.join(images_path, image_name), os.path.join(labels_path, label_name))) 56 | 57 | return items 58 | 59 | def get_transforms(self): 60 | if self.split == 'train': 61 | joint = custom_transforms.Compose([ 62 | custom_transforms.Resize(self.args.resize), 63 | custom_transforms.RandomHorizontallyFlip(), 64 | # custom_transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1), 65 | custom_transforms.RandomGaussianBlur() 66 | ]) 67 | elif self.split == 'val' or self.split == 'test' or self.split == 'demoVideo': 68 | joint = custom_transforms.Compose([ 69 | custom_transforms.Resize(self.args.resize), 70 | ]) 71 | else: 72 | raise RuntimeError('Invalid dataset mode') 73 | 74 | return joint 75 | 76 | 77 | 78 | 79 | -------------------------------------------------------------------------------- /dataloader/docunet_inverted.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | 
from torch.utils import data 4 | import torchvision.transforms as standard_transforms 5 | import util.custom_transforms as custom_transforms 6 | from scipy import io 7 | 8 | from constants import * 9 | 10 | class InvertedDocunet(data.Dataset): 11 | NUM_CLASSES = 2 12 | CLASSES = ["foreground", "background"] 13 | ROOT = '../../../datasets/' 14 | DEFORMED = 'deformed_labels' 15 | DEFORMED_EXT = '.jpg' 16 | VECTOR_FIELD = 'inverted_vf' 17 | VECTOR_FIELD_EXT = '.mat' 18 | 19 | def __init__(self, args, split=TRAIN): 20 | self.args = args 21 | self.split = split 22 | self.dataset = self.make_dataset() 23 | 24 | if len(self.dataset) == 0: 25 | raise RuntimeError('Found 0 images, please check the dataset') 26 | 27 | self.transform = self.get_transforms() 28 | 29 | def __len__(self): 30 | return len(self.dataset) 31 | 32 | def __getitem__(self, index): 33 | image_path, label_path = self.dataset[index] 34 | 35 | 36 | image = Image.open(image_path) 37 | label = io.loadmat(label_path)['inverted_vector_field'] 38 | 39 | if self.transform is not None: 40 | image = self.transform(image) 41 | 42 | image, label = standard_transforms.ToTensor()(image), standard_transforms.ToTensor()(label) 43 | 44 | label = label.float() 45 | return image, label 46 | 47 | def make_dataset(self): 48 | current_dir = os.path.dirname(__file__) 49 | images_path = os.path.join(current_dir, self.ROOT, self.args.dataset_dir, self.split, self.DEFORMED + '_' + 'x'.join(map(str, self.args.size))) 50 | labels_path = os.path.join(current_dir, self.ROOT, self.args.dataset_dir, self.split, self.VECTOR_FIELD + '_' + 'x'.join(map(str, self.args.size))) 51 | 52 | images_name = os.listdir(images_path) 53 | images_name = [image_name for image_name in images_name if image_name.endswith(self.DEFORMED_EXT)] 54 | items = [] 55 | 56 | for i in range(len(images_name)): 57 | image_name = images_name[i] 58 | label_name = image_name.replace(self.DEFORMED_EXT, self.VECTOR_FIELD_EXT) 59 | items.append((os.path.join(images_path, image_name), os.path.join(labels_path, label_name))) 60 | 61 | return items 62 | 63 | def get_transforms(self): 64 | return None 65 | 66 | 67 | 68 | 69 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: unwarping_assignment 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - _tflow_select=2.3.0=mkl 8 | - absl-py=0.8.1=py37_0 9 | - astor=0.7.1=py_0 10 | - blas=1.0=mkl 11 | - ca-certificates=2019.11.28=hecc5488_0 12 | - certifi=2019.11.28=py37_0 13 | - cffi=1.13.2=py37h7a1dbc1_0 14 | - cudatoolkit=10.1.243=h74a9793_0 15 | - cycler=0.10.0=py_2 16 | - freetype=2.9.1=ha9979f8_1 17 | - gast=0.2.2=py_0 18 | - google-pasta=0.1.8=py_0 19 | - grpcio=1.23.0=py37h3db2c7e_0 20 | - h5py=2.10.0=nompi_py37h422b98e_100 21 | - hdf5=1.10.5=nompi_ha405e13_1104 22 | - icc_rt=2019.0.0=h0cc432a_1 23 | - icu=64.2=he025d50_1 24 | - intel-openmp=2019.4=245 25 | - jpeg=9c=hfa6e2cd_1001 26 | - keras-applications=1.0.8=py_1 27 | - keras-preprocessing=1.1.0=py_0 28 | - kiwisolver=1.1.0=py37he980bc4_0 29 | - libblas=3.8.0=14_mkl 30 | - libcblas=3.8.0=14_mkl 31 | - libclang=9.0.0=default_hf44288c_4 32 | - liblapack=3.8.0=14_mkl 33 | - liblapacke=3.8.0=14_mkl 34 | - libmklml=2019.0.5=0 35 | - libopencv=4.1.2=py37_2 36 | - libpng=1.6.37=h2a8f88b_0 37 | - libprotobuf=3.11.1=h1a1b453_0 38 | - libtiff=4.1.0=h56a325e_0 39 | - libwebp=1.0.2=hfa6e2cd_4 40 | - m2w64-gcc-libgfortran=5.3.0=6 41 | - 
m2w64-gcc-libs=5.3.0=7 42 | - m2w64-gcc-libs-core=5.3.0=7 43 | - m2w64-gmp=6.1.0=2 44 | - m2w64-libwinpthread-git=5.0.0.4634.697f757=2 45 | - markdown=3.1.1=py_0 46 | - matplotlib=3.1.2=py37_1 47 | - matplotlib-base=3.1.2=py37h2981e6d_1 48 | - mkl=2019.4=245 49 | - mkl-service=2.3.0=py37hb782905_0 50 | - mkl_fft=1.0.15=py37h14836fe_0 51 | - mkl_random=1.1.0=py37h675688f_0 52 | - msys2-conda-epoch=20160418=1 53 | - ninja=1.9.0=py37h74a9793_0 54 | - numpy=1.17.4=py37h4320e6b_0 55 | - numpy-base=1.17.4=py37hc3f5095_0 56 | - olefile=0.46=py37_0 57 | - opencv=4.1.2=py37_2 58 | - openssl=1.1.1d=hfa6e2cd_0 59 | - opt_einsum=3.1.0=py_0 60 | - pillow=6.2.1=py37hdc69c19_0 61 | - pip=19.3.1=py37_0 62 | - protobuf=3.11.1=py37he025d50_0 63 | - py-opencv=4.1.2=py37h5ca1d4c_2 64 | - pycparser=2.19=py37_0 65 | - pyparsing=2.4.5=py_0 66 | - pyqt=5.12.3=py37h6538335_1 67 | - pyreadline=2.1=py37_1001 68 | - python=3.7.5=h8c8aaf0_0 69 | - python-dateutil=2.8.1=py_0 70 | - pytorch=1.3.1=py3.7_cuda101_cudnn7_0 71 | - qt=5.12.5=h7ef1ec2_0 72 | - scipy=1.3.2=py37h29ff71c_0 73 | - setuptools=42.0.2=py37_0 74 | - sip=4.19.8=py37h6538335_0 75 | - six=1.13.0=py37_0 76 | - sqlite=3.30.1=he774522_0 77 | - tensorboard=2.0.0=pyhb38c66f_1 78 | - tensorflow=2.0.0=mkl_py37he1bbcac_0 79 | - tensorflow-base=2.0.0=mkl_py37hd1d5974_0 80 | - tensorflow-estimator=2.0.0=pyh2649769_0 81 | - termcolor=1.1.0=py_2 82 | - tk=8.6.8=hfa6e2cd_0 83 | - torchvision=0.4.2=py37_cu101 84 | - tornado=6.0.3=py37hfa6e2cd_0 85 | - tqdm=4.40.0=py_0 86 | - vc=14.1=h0510ff6_4 87 | - vs2015_runtime=14.16.27012=hf0eaf9b_0 88 | - werkzeug=0.16.0=py_0 89 | - wheel=0.33.6=py37_0 90 | - wincertstore=0.2=py37_0 91 | - wrapt=1.11.2=py37hfa6e2cd_0 92 | - xz=5.2.4=h2fa13f4_4 93 | - zlib=1.2.11=h62dcd97_3 94 | - zstd=1.3.7=h508b16e_0 95 | - pip: 96 | - opencv-python==4.1.2.30 97 | - pyqt5-sip==4.19.18 98 | - pyqtwebengine==5.12.1 99 | prefix: C:\Users\r.sibechi\AppData\Local\Continuum\anaconda3\envs\primevision_torch 100 | 101 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from parser_options import ParserOptions 2 | from util.general_functions import print_training_info 3 | from constants import * 4 | from core.trainers.trainer import Trainer 5 | 6 | def main(): 7 | args = ParserOptions().parse() # get training options 8 | trainer = Trainer(args) 9 | 10 | print_training_info(args) 11 | 12 | for epoch in range(trainer.args.start_epoch, trainer.args.epochs): 13 | trainer.run_epoch(epoch, split=TRAIN) 14 | 15 | if epoch % args.eval_interval == (args.eval_interval - 1): 16 | trainer.run_epoch(epoch, split=TEST) 17 | 18 | trainer.run_epoch(trainer.args.epochs, split=VISUALIZATION) 19 | trainer.summary.writer.add_scalar('test/best_result', trainer.best_loss, args.epochs) 20 | trainer.summary.writer.close() 21 | trainer.save_network() 22 | 23 | 24 | if __name__ == "__main__": 25 | main() -------------------------------------------------------------------------------- /parser_options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import os 4 | import math 5 | 6 | from constants import * 7 | 8 | class ParserOptions(): 9 | """This class defines options that are used by the program""" 10 | 11 | def __init__(self): 12 | parser = argparse.ArgumentParser(description='PyTorch Semantic Video Segmentation training') 13 | 14 | # model specific 15 | 
parser.add_argument('--model', type=str, default=DEEPLAB_50, choices=[DEEPLAB, DEEPLAB_50, DEEPLAB_34, DEEPLAB_18, DEEPLAB_MOBILENET, DEEPLAB_MOBILENET_DILATION, UNET, UNET_PAPER, UNET_PYTORCH, PSPNET], help='model name (default:' + DEEPLAB + ')') 16 | parser.add_argument('--separable_conv', type=int, default=0, choices=[0,1], help='if we should convert normal convolutions to separable convolutions' ) 17 | parser.add_argument('--refine_network', type=int, default=0, choices=[0,1], help='if we should refine the first prediction with a second network ') 18 | parser.add_argument('--dataset', type=str, default=DOCUNET_INVERTED, choices=[DOCUNET, DOCUNET_IM2IM, DOCUNET_INVERTED], help='dataset (default:' + DOCUNET + ')') 19 | parser.add_argument('--dataset_dir', type=str, default=ADDRESS_DATASET, choices=[HAZMAT_DATASET, LABELS_DATASET, ADDRESS_DATASET], help='name of the dir in which the dataset is located') 20 | parser.add_argument('--loss_type', type=str, default=DOCUNET_LOSS, choices=[DOCUNET_LOSS, SSIM_LOSS, SSIM_LOSS_V2, MS_SSIM_LOSS, MS_SSIM_LOSS_V2, L1_LOSS, SMOOTH_L1_LOSS, MSE_LOSS], help='loss func type (default:' + DOCUNET_LOSS + ')') 21 | parser.add_argument('--second_loss', type=int, default=0, choices=[0,1], help='if we should use two losses') 22 | parser.add_argument('--second_loss_rate', type=float, default=10, help='used to tune the overall impact of the second loss') 23 | parser.add_argument('--norm_layer', type=str, default=BATCH_NORM, choices=[INSTANCE_NORM, BATCH_NORM, SYNC_BATCH_NORM]) 24 | parser.add_argument('--init_type', type=str, default=KAIMING_INIT, choices=[NORMAL_INIT, KAIMING_INIT, XAVIER_INIT, ORTHOGONAL_INIT]) 25 | parser.add_argument('--resize', type=str, default='64,64', help='image resize: h,w') 26 | parser.add_argument('--batch_size', type=int, default=2, metavar='N', help='input batch size for training (default: 2)') 27 | parser.add_argument('--optim', type=str, default=ADAM, choices=[SGD, ADAM, RMSPROP, AMSGRAD, ADABOUND]) 28 | parser.add_argument('--lr', type=float, default=0.0001, metavar='LR', help='learning rate (default: auto)') 29 | parser.add_argument('--lr_policy', type=str, default='poly', choices=['poly', 'step', 'cos', 'linear'], help='lr scheduler mode: (default: poly)') 30 | parser.add_argument('--weight-decay', type=float, default=5e-4, metavar='M', help='w-decay (default: 5e-4)') 31 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', help='momentum (default: 0.9)') 32 | parser.add_argument('--clip', type=float, default=0, help='gradient clip, 0 means no clip (default: 0)') 33 | 34 | # training specific 35 | parser.add_argument('--size', type=str, default='1024,1024', help='image size: h,w') 36 | parser.add_argument('--start_epoch', type=int, default=0, metavar='N', help='starting epoch') 37 | parser.add_argument('--epochs', type=int, default=100, metavar='N', help='number of epochs to train (default: auto)') 38 | parser.add_argument('--eval-interval', type=int, default=1, help='evaluation interval (default: 1)') 39 | parser.add_argument('--trainval', type=int, default=0, choices=[0,1], help='determines whether whe should use validation images as well for training') 40 | parser.add_argument('--inference', type=int, default=0, choices=[0,1], help='if we should run the model in inference mode') 41 | parser.add_argument('--debug', type=int, default=1) 42 | parser.add_argument('--results_root', type=str, default='..') 43 | parser.add_argument('--results_dir', type=str, default='results_final', help='models are saved here') 
44 | parser.add_argument('--save_dir', type=str, default='saved_models') 45 | parser.add_argument('--save_best_model', type=int, default=0, choices=[0,1], help='keep track of best model') 46 | parser.add_argument('--pretrained_models_dir', type=str, default='pretrained_models', help='root dir of the pretrained models location') 47 | 48 | # deeplab specific 49 | parser.add_argument('--output_stride', type=int, default=16, help='network output stride (default: 16)') 50 | parser.add_argument('--pretrained', type=int, default=0, choices=[0,1], help='if we should use a pretrained model or not') 51 | parser.add_argument('--learned_upsampling', type=int, default=0, choices=[0,1], help='if we should use bilinear upsampling or learned upsampling') 52 | parser.add_argument('--use_aspp', type=int, default=1, choices=[0,1], help='if we should aspp in the deeplab head or not') 53 | 54 | # unet specific 55 | parser.add_argument('--num_downs', type=int, default=8, help='number of unet encoder-decoder blocks') 56 | parser.add_argument('--ngf', type=int, default=128, help='# of gen filters in the last conv layer') 57 | parser.add_argument('--down_type', type=str, default=MAXPOOL, choices=[STRIDECONV, MAXPOOL], help='method to reduce feature map size') 58 | parser.add_argument('--dropout', type=float, default=0.2) 59 | 60 | args = parser.parse_args() 61 | args.size = tuple([int(x) for x in args.size.split(',')]) 62 | args.resize = tuple([int(x) for x in args.resize.split(',')]) 63 | 64 | if args.debug: 65 | args.results_dir = 'results_dummy' 66 | 67 | args.num_downs = int(math.log(args.resize[0])/math.log(2)) 68 | args.cuda = torch.cuda.is_available() 69 | args.gpu_ids = os.environ['CUDA_VISIBLE_DEVICES'] if ('CUDA_VISIBLE_DEVICES' in os.environ) else '' 70 | args.gpu_ids = list(range(len(args.gpu_ids.split(',')))) if (',' in args.gpu_ids and args.cuda) else None 71 | 72 | if args.gpu_ids and args.norm_layer == BATCH_NORM: 73 | args.norm_layer = SYNC_BATCH_NORM 74 | 75 | if args.dataset == DOCUNET or args.dataset == DOCUNET_INVERTED: 76 | args.size = args.resize 77 | 78 | if args.dataset == DOCUNET_IM2IM and args.loss_type == DOCUNET_LOSS: 79 | args.loss_type = MS_SSIM_LOSS 80 | args.double_loss = 0 81 | 82 | self.args = args 83 | 84 | def parse(self): 85 | return self.args -------------------------------------------------------------------------------- /playground.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | from scipy import io 3 | import random 4 | import torch 5 | import torchvision.transforms as standard_transforms 6 | import matplotlib.pyplot as plt 7 | from tqdm import tqdm 8 | import time 9 | import os 10 | 11 | from core.trainers.trainer import Trainer 12 | from parser_options import ParserOptions 13 | from util.general_functions import apply_transformation_to_image_cv, apply_transformation_to_image, invert_vector_field, get_model, make_data_loader 14 | from constants import * 15 | 16 | def cv2_invert(invert=True): 17 | deformed_label = cv2.imread('../deformed_label.jpg', cv2.IMREAD_COLOR) 18 | 19 | if invert: 20 | vector_field = io.loadmat('../fm.mat')['vector_field'] 21 | else: 22 | vector_field = io.loadmat('../fm.mat')['inverted_vector_field'] 23 | 24 | flatten_label = apply_transformation_to_image_cv(deformed_label, vector_field, invert=invert) 25 | cv2.imwrite('../our_flatten.jpg', flatten_label) 26 | 27 | def pytorch_invert(invert=False): 28 | deformed_label = cv2.imread('../deformed_label.jpg', cv2.IMREAD_COLOR) 29 | deformed_label 
= standard_transforms.ToTensor()(deformed_label).unsqueeze(dim=0) 30 | 31 | if invert: 32 | vector_field = io.loadmat('../fm.mat')['vector_field'] 33 | vector_field = invert_vector_field(vector_field) 34 | else: 35 | vector_field = io.loadmat('../fm.mat')['inverted_vector_field'] 36 | 37 | vector_field = torch.Tensor(vector_field).unsqueeze(dim=0) 38 | vector_field = vector_field.permute(0, 3, 1, 2) 39 | flatten_label = apply_transformation_to_image(deformed_label, vector_field) 40 | plt.imsave('../our_flatten_2.jpg', flatten_label.permute(1,2,0).cpu().numpy()) 41 | 42 | def check_duplicates(source_folder_name, destination_folder_name): 43 | source_files = set(os.listdir(source_folder_name)) 44 | destination_files = set(os.listdir(destination_folder_name)) 45 | intersection = source_files.intersection(destination_files) 46 | intersection.remove('Thumbs.db') 47 | 48 | if len(intersection) == 0: 49 | print("OK") 50 | else: 51 | print("NOT OK") 52 | 53 | def network_predict(iterations=51, pretrained_model=''): 54 | if not pretrained_model: 55 | raise NotImplementedError() 56 | 57 | args = ParserOptions().parse() 58 | args.cuda = False 59 | args.batch_size = 1 60 | args.inference = 1 61 | args.pretrained_models_dir = pretrained_model 62 | args.num_downs = 8 63 | args.resize, args.size = (256,256), (256,256) 64 | args.model = DEEPLAB_50 65 | #args.refine_network = 1 66 | trainer = Trainer(args) 67 | mean_time = trainer.calculate_inference_speed(iterations) 68 | print('Mean time', mean_time) 69 | 70 | if __name__ == "__main__": 71 | model = 'saved_models/deeplab_50.pth' 72 | network_predict(pretrained_model=model) -------------------------------------------------------------------------------- /readme_images/generating_deformed_images.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mhashas/Document-Image-Unwarping-pytorch/92b29172b981d132f7b31e767505524f8cc7af7a/readme_images/generating_deformed_images.PNG -------------------------------------------------------------------------------- /readme_images/output_examples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mhashas/Document-Image-Unwarping-pytorch/92b29172b981d132f7b31e767505524f8cc7af7a/readme_images/output_examples.png -------------------------------------------------------------------------------- /readme_images/overall_architecture.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mhashas/Document-Image-Unwarping-pytorch/92b29172b981d132f7b31e767505524f8cc7af7a/readme_images/overall_architecture.PNG -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | MODEL=$1 2 | DATASET=$2 3 | LOSS_TYPE=$3 4 | BATCH_SIZE=$4 5 | RESIZE=$5 6 | REFINE_NETWORK=${6:-0} 7 | SECOND_LOSS=${7:-0} 8 | 9 | python main.py --model $MODEL --dataset $DATASET --loss_type $LOSS_TYPE --batch_size $BATCH_SIZE \ 10 | --resize $RESIZE --epochs 100 --norm_layer batch --debug 0 \ 11 | --refine_network $REFINE_NETWORK --second_loss $SECOND_LOSS -------------------------------------------------------------------------------- /util/custom_transforms.py: -------------------------------------------------------------------------------- 1 | import types 2 | import random 3 | import torch 4 | 5 | from PIL import Image, ImageOps, ImageFilter 6 | from 
torchvision.transforms import functional as F 7 | import numpy as np 8 | 9 | class Lambda(object): 10 | """Apply a user-defined lambda as a transform. 11 | 12 | Args: 13 | lambd (function): Lambda/function to be used for transform. 14 | """ 15 | 16 | def __init__(self, lambd): 17 | assert isinstance(lambd, types.LambdaType) 18 | self.lambd = lambd 19 | 20 | def __call__(self, img): 21 | return self.lambd(img) 22 | 23 | def __repr__(self): 24 | return self.__class__.__name__ + '()' 25 | 26 | class Compose(object): 27 | def __init__(self, transforms): 28 | self.transforms = transforms 29 | 30 | def __call__(self, image, label=None): 31 | #assert image.size == label.size 32 | 33 | for t in self.transforms: 34 | image, label = t(image, label) 35 | return image, label 36 | 37 | 38 | class RandomCrop(object): 39 | def __init__(self, size, padding=0): 40 | self.size = size 41 | self.padding = padding 42 | 43 | def __call__(self, image, label): 44 | if self.padding > 0: 45 | image = ImageOps.expand(image, border=self.padding, fill=0) 46 | label = ImageOps.expand(label, border=self.padding, fill=0) 47 | # @TODO RADU 48 | #assert image.size == label.size 49 | w, h = image.size 50 | tw, th = self.size 51 | 52 | if w == tw and h == th: 53 | return image, label 54 | if w < tw or h < th: 55 | return image.resize((tw, th), Image.BILINEAR), label.resize((tw, th), Image.BILINEAR) 56 | 57 | x1 = random.randint(0, w - tw) 58 | y1 = random.randint(0, h - th) 59 | 60 | cropped_images, cropped_labels = image.crop((x1, y1, x1 + tw, y1 + th)), label.crop((x1, y1, x1 + tw, y1 + th)) 61 | return cropped_images, cropped_labels 62 | 63 | class RandomHorizontallyFlip(object): 64 | def __call__(self, image, label): 65 | if random.random() < 0.5: 66 | return image.transpose(Image.FLIP_LEFT_RIGHT), label.transpose(Image.FLIP_LEFT_RIGHT) 67 | return image, label 68 | 69 | class Scale(object): 70 | def __init__(self, size, label_resize_type=Image.BILINEAR): 71 | self.size = size 72 | self.label_resize_type = label_resize_type 73 | 74 | def __call__(self, image, label): 75 | assert image.size == label.size 76 | w, h = image.size 77 | 78 | if (w >= h and w == self.size) or (h >= w and h == self.size): 79 | return image, label 80 | if w > h: 81 | ow = self.size 82 | oh = int(self.size * h / w) 83 | return image.resize((ow, oh), Image.BILINEAR), label.resize((ow, oh), self.label_resize_type) 84 | else: 85 | oh = self.size 86 | ow = int(self.size * w / h) 87 | return image.resize((ow, oh), Image.BILINEAR), label.resize((ow, oh), self.label_resize_type) 88 | 89 | class Resize(object): 90 | def __init__(self, size, label_resize_type=Image.BILINEAR): 91 | self.size = size 92 | self.label_resize_type = label_resize_type 93 | 94 | def __call__(self, image, label): 95 | #assert image.size == label.size 96 | w, h = image.size 97 | w_label, h_label = label.size 98 | 99 | if not ((w >= h and w == self.size[0]) or (h >= w and h == self.size[1])): 100 | image = image.resize(self.size, Image.BILINEAR) 101 | 102 | if not ((w_label >= h_label and w_label == self.size[0]) or (h_label >= w_label and h_label == self.size[1])): 103 | label = label.resize(self.size, self.label_resize_type) 104 | 105 | return image, label 106 | 107 | class RandomGaussianBlur(object): 108 | def __call__(self, image, label=None): 109 | if random.random() < 0.5: 110 | radius = random.random() 111 | image = image.filter(ImageFilter.GaussianBlur(radius=radius)) 112 | 113 | if label is not None: 114 | return image, label 115 | return image 116 | 117 | class 
Normalize(object): 118 | """Normalize a tensor image with mean and standard deviation. 119 | Args: 120 | mean (tuple): means for each channel. 121 | std (tuple): standard deviations for each channel. 122 | """ 123 | def __init__(self, mean=(0., 0., 0.), std=(1., 1., 1.)): 124 | self.mean = mean 125 | self.std = std 126 | 127 | def __call__(self, img): 128 | img = np.array(img).astype(np.float32) 129 | img /= 255.0 130 | img -= self.mean 131 | img /= self.std 132 | 133 | return img 134 | 135 | class ColorJitter(object): 136 | """Randomly change the brightness, contrast and saturation of an image. 137 | 138 | Args: 139 | brightness (float): How much to jitter brightness. brightness_factor 140 | is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]. 141 | contrast (float): How much to jitter contrast. contrast_factor 142 | is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]. 143 | saturation (float): How much to jitter saturation. saturation_factor 144 | is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]. 145 | hue(float): How much to jitter hue. hue_factor is chosen uniformly from 146 | [-hue, hue]. Should be >=0 and <= 0.5. 147 | """ 148 | def __init__(self, brightness=0, contrast=0, saturation=0, hue=0): 149 | self.brightness = brightness 150 | self.contrast = contrast 151 | self.saturation = saturation 152 | self.hue = hue 153 | 154 | @staticmethod 155 | def get_params(brightness, contrast, saturation, hue): 156 | """Get a randomized transform to be applied on image. 157 | 158 | Arguments are same as that of __init__. 159 | 160 | Returns: 161 | Transform which randomly adjusts brightness, contrast and 162 | saturation in a random order. 163 | """ 164 | transforms = [] 165 | if brightness > 0: 166 | brightness_factor = random.uniform(max(0, 1 - brightness), 1 + brightness) 167 | transforms.append(Lambda(lambda img: F.adjust_brightness(img, brightness_factor))) 168 | 169 | if contrast > 0: 170 | contrast_factor = random.uniform(max(0, 1 - contrast), 1 + contrast) 171 | transforms.append(Lambda(lambda img: F.adjust_contrast(img, contrast_factor))) 172 | 173 | if saturation > 0: 174 | saturation_factor = random.uniform(max(0, 1 - saturation), 1 + saturation) 175 | transforms.append(Lambda(lambda img: F.adjust_saturation(img, saturation_factor))) 176 | 177 | if hue > 0: 178 | hue_factor = random.uniform(-hue, hue) 179 | transforms.append(Lambda(lambda img: F.adjust_hue(img, hue_factor))) 180 | 181 | random.shuffle(transforms) 182 | return transforms 183 | 184 | @staticmethod 185 | def forward_transforms(image, transforms): 186 | for transform in transforms: 187 | image = transform(image) 188 | 189 | return image 190 | 191 | def __call__(self, image, label): 192 | """ 193 | Args: 194 | images (PIL Image): Input image. 195 | 196 | Returns: 197 | PIL Image: Color jittered image. 
198 | """ 199 | transforms = self.get_params(self.brightness, self.contrast, 200 | self.saturation, self.hue) 201 | 202 | return self.forward_transforms(image, transforms), label -------------------------------------------------------------------------------- /util/general_functions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | from torch.utils.data import ConcatDataset 5 | from torch.utils.data import DataLoader 6 | import numpy as np 7 | import functools 8 | import cv2 9 | 10 | from core.models.deeplabv3_plus import DeepLabv3_plus 11 | from core.models.mobilenetv2 import MobileNet_v2 12 | from core.models.pspnet import PSPNet 13 | from core.models.unet import UNet 14 | from core.models.unet_paper import UNet_paper 15 | from core.models.unet_pytorch import UNet_torch 16 | 17 | from dataloader.docunet import Docunet 18 | from dataloader.docunet_inverted import InvertedDocunet 19 | from dataloader.docunet_im2im import DocunetIm2Im 20 | 21 | from util.losses import * 22 | from constants import * 23 | 24 | def make_data_loader(args, split=TRAIN): 25 | """ 26 | Builds the model based on the provided arguments 27 | 28 | Parameters: 29 | args (argparse) -- input arguments 30 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal 31 | gain (float) -- scaling factor for normal, xavier and orthogonal. 32 | """ 33 | if args.dataset == DOCUNET: 34 | dataset = Docunet 35 | elif args.dataset == DOCUNET_INVERTED: 36 | dataset = InvertedDocunet 37 | elif args.dataset == DOCUNET_IM2IM: 38 | dataset = DocunetIm2Im 39 | else: 40 | raise NotImplementedError 41 | 42 | if split == TRAINVAL: 43 | train_set = dataset(args, split=TRAIN) 44 | val_set = dataset(args, split=VAL) 45 | trainval_set = ConcatDataset([train_set, val_set]) 46 | loader = DataLoader(trainval_set, batch_size=args.batch_size, num_workers=1, shuffle=True) 47 | else: 48 | set = dataset(args, split=split) 49 | 50 | if split == TRAIN: 51 | loader = DataLoader(set, batch_size=args.batch_size, num_workers=1, shuffle=True) 52 | else: 53 | loader = DataLoader(set, batch_size=args.batch_size, num_workers=1, shuffle=False) 54 | 55 | return loader 56 | 57 | def get_model(args): 58 | """ 59 | Builds the model based on the provided arguments and returns the initialized model 60 | 61 | Parameters: 62 | args (argparse) -- command line arguments 63 | """ 64 | 65 | norm_layer = get_norm_layer(args.norm_layer) 66 | num_classes = get_num_classes(args.dataset) 67 | 68 | if DEEPLAB in args.model: 69 | model = DeepLabv3_plus(args, num_classes=num_classes, norm_layer=norm_layer) 70 | 71 | if args.model != DEEPLAB_MOBILENET and args.separable_conv: 72 | convert_to_separable_conv(model) 73 | 74 | model = init_model(model, args.init_type) 75 | elif UNET in args.model: 76 | if args.model == UNET: 77 | model = UNet(num_classes=num_classes, args=args, norm_layer=norm_layer) 78 | elif args.model == UNET_PAPER: 79 | model = UNet_paper(num_classes=num_classes, args=args, norm_layer=norm_layer) 80 | elif args.model == UNET_PYTORCH: 81 | model = UNet_torch(num_classes=num_classes, args=args) 82 | 83 | if args.separable_conv: 84 | convert_to_separable_conv(model) 85 | 86 | model = init_model(model, args.init_type) 87 | elif PSPNET in args.model: 88 | model = PSPNet(num_classes=num_classes, args=args) 89 | else: 90 | raise NotImplementedError 91 | 92 | print("Built " + args.model) 93 | 94 | if args.cuda: 95 | model = 
model.cuda() 96 | 97 | return model 98 | 99 | def convert_to_separable_conv(module): 100 | class SeparableConvolution(nn.Module): 101 | """ Separable Convolution 102 | """ 103 | 104 | def __init__(self, in_channels, out_channels, kernel_size, 105 | stride=1, padding=0, dilation=1, bias=True): 106 | super(SeparableConvolution, self).__init__() 107 | self.body = nn.Sequential( 108 | # Separable Conv 109 | nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size, stride=stride, padding=padding, 110 | dilation=dilation, bias=bias, groups=in_channels), 111 | # PointWise Conv 112 | nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=bias), 113 | ) 114 | 115 | def forward(self, x): 116 | return self.body(x) 117 | 118 | new_module = module 119 | if isinstance(module, nn.Conv2d) and module.kernel_size[0]>1: 120 | new_module = SeparableConvolution(module.in_channels, 121 | module.out_channels, 122 | module.kernel_size, 123 | module.stride, 124 | module.padding, 125 | module.dilation, 126 | module.bias is not None) 127 | 128 | for name, child in module.named_children(): 129 | new_module.add_module(name, convert_to_separable_conv(child)) 130 | return new_module 131 | 132 | def get_loss_function(mode): 133 | if mode == DOCUNET_LOSS: 134 | loss = DocunetLoss() 135 | elif mode == MS_SSIM_LOSS: 136 | loss = MS_SSIM_Loss() 137 | elif mode == MS_SSIM_LOSS_V2: 138 | loss = MS_SSIM_Loss_v2() 139 | elif mode == SSIM_LOSS: 140 | loss = SSIM_Loss() 141 | elif mode == L1_LOSS: 142 | loss = torch.nn.L1Loss() 143 | elif mode == SMOOTH_L1_LOSS: 144 | loss = torch.nn.SmoothL1Loss() 145 | elif mode == MSE_LOSS: 146 | loss = torch.nn.MSELoss() 147 | else: 148 | raise NotImplementedError 149 | 150 | return loss 151 | 152 | def get_num_classes(dataset): 153 | if dataset == DOCUNET or dataset == DOCUNET_INVERTED: 154 | num_classes = 2 155 | elif dataset == DOCUNET_IM2IM: 156 | num_classes = 3 157 | else: 158 | raise NotImplementedError 159 | 160 | return num_classes 161 | 162 | 163 | def get_optimizer(model, args): 164 | """ 165 | Builds the optimizer for the model based on the provided arguments and returns the optimizer 166 | 167 | Parameters: 168 | model -- the network to be optimized 169 | args -- command line arguments 170 | """ 171 | if args.gpu_ids: 172 | train_params = model.module.get_train_parameters(args.lr) 173 | else: 174 | train_params = model.get_train_parameters(args.lr) 175 | 176 | if args.optim == SGD: 177 | optimizer = optim.SGD(train_params, lr=args.lr, weight_decay=args.weight_decay, momentum=args.momentum, nesterov=True) 178 | elif args.optim == ADAM: 179 | optimizer = optim.Adam(train_params, lr=args.lr, weight_decay=args.weight_decay) 180 | elif args.optim == AMSGRAD: 181 | optimizer = optim.Adam(train_params, lr=args.lr, weight_decay=args.weight_decay, amsgrad=True) 182 | else: 183 | raise NotImplementedError 184 | 185 | return optimizer 186 | 187 | 188 | def get_norm_layer(norm_type=INSTANCE_NORM): 189 | """Returns a normalization layer 190 | 191 | Parameters: 192 | norm_type (str) -- the name of the normalization layer: batch | instance | none 193 | 194 | For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev). 195 | For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics. 
196 | """ 197 | 198 | if norm_type == BATCH_NORM: 199 | norm_layer = nn.BatchNorm2d 200 | elif norm_type == INSTANCE_NORM: 201 | norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False) 202 | elif norm_type == 'none': 203 | norm_layer = None 204 | else: 205 | raise NotImplementedError('normalization layer [%s] is not found' % norm_type) 206 | return norm_layer 207 | 208 | def init_model(net, init_type=NORMAL_INIT, init_gain=0.02): 209 | """Initialize the network weights 210 | 211 | Parameters: 212 | net (network) -- the network to be initialized 213 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal 214 | gain (float) -- scaling factor for normal, xavier and orthogonal. 215 | 216 | Return an initialized network. 217 | """ 218 | 219 | init_weights(net, init_type, init_gain=init_gain) 220 | return net 221 | 222 | def init_weights(net, init_type=NORMAL_INIT, init_gain=0.02): 223 | """Initialize network weights. 224 | 225 | Parameters: 226 | net (network) -- network to be initialized 227 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal 228 | init_gain (float) -- scaling factor for normal, xavier and orthogonal. 229 | """ 230 | 231 | def init_func(m): # define the initialization function 232 | classname = m.__class__.__name__ 233 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): 234 | if init_type == NORMAL_INIT: 235 | nn.init.normal_(m.weight.data, 0.0, init_gain) 236 | elif init_type == XAVIER_INIT: 237 | nn.init.xavier_normal_(m.weight.data, gain=init_gain) 238 | elif init_type == KAIMING_INIT: 239 | nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_out', nonlinearity='leaky_relu') 240 | elif init_type == ORTHOGONAL_INIT: 241 | nn.init.orthogonal_(m.weight.data, gain=init_gain) 242 | else: 243 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 244 | if hasattr(m, 'bias') and m.bias is not None: 245 | nn.init.constant_(m.bias.data, 0.0) 246 | elif hasattr(m, '_all_weights') and (classname.find('LSTM') != -1 or classname.find('GRU') != -1): 247 | for names in m._all_weights: 248 | for name in filter(lambda n: "weight" in n, names): 249 | weight = getattr(m, name) 250 | nn.init.xavier_normal_(weight.data, gain=init_gain) 251 | 252 | for name in filter(lambda n: "bias" in n, names): 253 | bias = getattr(m, name) 254 | nn.init.constant_(bias.data, 0.0) 255 | 256 | if classname.find('LSTM') != -1: 257 | n = bias.size(0) 258 | start, end = n // 4, n // 2 259 | nn.init.constant_(bias.data[start:end], 1.) 260 | elif classname.find('BatchNorm2d') != -1 or classname.find('SynchronizedBatchNorm2d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies. 
261 | nn.init.normal_(m.weight.data, 1.0, init_gain) 262 | nn.init.constant_(m.bias.data, 0.0) 263 | 264 | print('Initialized network with %s' % init_type) 265 | net.apply(init_func) # apply the initialization function 266 | 267 | def set_requires_grad(net, requires_grad=False): 268 | """Set requies_grad=False for the network to avoid unnecessary computations 269 | Parameters: 270 | net (network) 271 | requires_grad (bool) -- whether the networks require gradients or not 272 | """ 273 | if net is not None: 274 | for param in net.parameters(): 275 | param.requires_grad = requires_grad 276 | 277 | def tensor2im(input_image, imtype=np.uint8, return_tensor=True): 278 | """"Converts a Tensor array into a numpy image array. 279 | 280 | Parameters: 281 | input_image (tensor) -- the input image tensor array 282 | imtype (type) -- the desired type of the converted numpy array 283 | """ 284 | if not isinstance(input_image, np.ndarray): 285 | if isinstance(input_image, torch.Tensor): # get the data from a variable 286 | image_tensor = input_image.data 287 | else: 288 | return input_image 289 | image_numpy = image_tensor.cpu().float().numpy() # convert it into a numpy array 290 | if image_numpy.ndim == 3: 291 | image_numpy = (image_numpy - np.min(image_numpy))/(np.max(image_numpy)-np.min(image_numpy)) 292 | if image_numpy.shape[0] == 1: # grayscale to RGB 293 | image_numpy = np.tile(image_numpy, (3, 1, 1)) 294 | image_numpy = (image_numpy + 1) / 2.0 * 255.0 # post-processing: tranpose and scaling 295 | else: # if it is a numpy array, do nothing 296 | image_numpy = input_image 297 | 298 | return torch.from_numpy(image_numpy.astype(imtype)) if return_tensor else np.transpose(image_numpy, (1,2,0)) 299 | 300 | def get_flat_images(dataset, images, outputs, targets): 301 | if dataset == DOCUNET: 302 | pass 303 | elif dataset == DOCUNET_INVERTED: 304 | outputs = apply_transformation_to_image(images, outputs) 305 | targets = apply_transformation_to_image(images, targets) 306 | else: 307 | pass 308 | 309 | return outputs, targets 310 | 311 | def apply_transformation_to_image(img, vector_field): 312 | vector_field = scale_vector_field_tensor(vector_field) 313 | vector_field = vector_field.permute(0, 2, 3, 1) 314 | flatten_image = nn.functional.grid_sample(img, vector_field, mode='bilinear', align_corners=True) 315 | 316 | return flatten_image.squeeze() 317 | 318 | def apply_transformation_to_image_cv(img, vector_field, invert=False): 319 | if invert: 320 | vector_field = invert_vector_field(vector_field) 321 | 322 | map_x = vector_field[:, :, 0] 323 | map_y = vector_field[:, :, 1] 324 | transformed_img = cv2.remap(img, map_x, map_y, interpolation=cv2.INTER_LINEAR) 325 | 326 | return transformed_img 327 | 328 | def invert_vector_field(vector_field): 329 | vector_field_x = vector_field[:, :, 0] 330 | vector_field_y = vector_field[:, :, 1] 331 | 332 | assert(vector_field_x.shape == vector_field_y.shape) 333 | rows = vector_field_x.shape[0] 334 | cols = vector_field_x.shape[1] 335 | 336 | m_x = np.ones(vector_field_x.shape, dtype=vector_field_x.dtype) * -1 337 | m_y = np.ones(vector_field_y.shape, dtype=vector_field_y.dtype) * -1 338 | for i in range(rows): 339 | for j in range(cols): 340 | i_ = int(round(vector_field_y[i, j])) 341 | j_ = int(round(vector_field_x[i, j])) 342 | if 0 <= i_ < rows and 0 <= j_ < cols: 343 | m_x[i_, j_] = j 344 | m_y[i_, j_] = i 345 | return np.stack([m_x, m_y], axis=2) 346 | 347 | 348 | def scale_vector_field_tensor(vector_field): 349 | vector_field = torch.where(vector_field < 0, 
torch.tensor(3 * vector_field.shape[3], dtype=vector_field.dtype, device=vector_field.device), vector_field) 350 | vector_field = (vector_field / (vector_field.shape[3] / 2)) - 1 351 | 352 | return vector_field 353 | 354 | def print_training_info(args): 355 | print('Dataset', args.dataset) 356 | 357 | if 'unet' in args.model: 358 | print('Ngf', args.ngf) 359 | print('Num downs', args.num_downs) 360 | print('Down type', args.down_type) 361 | 362 | if 'deeplab' in args.model: 363 | print('Output stride', args.output_stride) 364 | print('Learned upsampling', args.learned_upsampling) 365 | print('Pretrained', args.pretrained) 366 | print('Use aspp', args.use_aspp) 367 | 368 | print('Refine network', args.refine_network) 369 | print('Separable conv', args.separable_conv) 370 | print('Optimizer', args.optim) 371 | print('Learning rate', args.lr) 372 | print('Second loss', args.second_loss) 373 | 374 | if args.clip > 0: 375 | print('Gradient clip', args.clip) 376 | 377 | print('Resize', args.resize) 378 | print('Batch size', args.batch_size) 379 | print('Norm layer', args.norm_layer) 380 | print('Using cuda', args.cuda) 381 | print('Using ' + args.loss_type + ' loss') 382 | print('Starting Epoch:', args.start_epoch) 383 | print('Total Epoches:', args.epochs) 384 | 385 | 386 | 387 | -------------------------------------------------------------------------------- /util/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from util.ssim import MS_SSIM, SSIM 6 | from util.ms_ssim import MS_SSIM_v2, SSIM_v2 7 | 8 | class DocunetLoss_v2(nn.Module): 9 | def __init__(self, r=0.1,reduction='mean'): 10 | super(DocunetLoss_v2, self).__init__() 11 | assert reduction in ['mean','sum'], " reduction must in ['mean','sum']" 12 | self.r = r 13 | self.reduction = reduction 14 | 15 | def forward(self, y, label): 16 | bs, n, h, w = y.size() 17 | d = y - label 18 | loss1 = [] 19 | for d_i in d: 20 | loss1.append(torch.abs(d_i).mean() - self.r * torch.abs(d_i.mean())) 21 | loss1 = torch.stack(loss1) 22 | # lossb1 = torch.max(y1, torch.zeros(y1.shape).to(y1.device)).mean() 23 | loss2 = F.mse_loss(y, label,reduction=self.reduction) 24 | 25 | if self.reduction == 'mean': 26 | loss1 = loss1.mean() 27 | elif self.reduction == 'sum': 28 | loss1= loss1.sum() 29 | return loss1 + loss2 30 | 31 | class DocunetLoss(nn.Module): 32 | def __init__(self, lamda=0.1, reduction='mean'): 33 | super(DocunetLoss, self).__init__() 34 | self.lamda = lamda 35 | self.reduction = reduction 36 | 37 | def forward(self, output, target): 38 | x = target[:, 0, :, :] 39 | y = target[:, 1, :, :] 40 | back_sign_x, back_sign_y = (x == -1).int(), (y == -1).int() 41 | # assert back_sign_x == back_sign_y 42 | 43 | back_sign = ((back_sign_x + back_sign_y) == 2).float() 44 | fore_sign = 1 - back_sign 45 | 46 | loss_term_1_x = torch.sum(torch.abs(output[:, 0, :, :] - x) * fore_sign) / torch.sum(fore_sign) 47 | loss_term_1_y = torch.sum(torch.abs(output[:, 1, :, :] - y) * fore_sign) / torch.sum(fore_sign) 48 | loss_term_1 = loss_term_1_x + loss_term_1_y 49 | 50 | loss_term_2_x = torch.abs(torch.sum((output[:, 0, :, :] - x) * fore_sign)) / torch.sum(fore_sign) 51 | loss_term_2_y = torch.abs(torch.sum((output[:, 1, :, :] - y) * fore_sign)) / torch.sum(fore_sign) 52 | loss_term_2 = loss_term_2_x + loss_term_2_y 53 | 54 | zeros_x = torch.zeros(x.size()).cuda() if torch.cuda.is_available() else torch.zeros(x.size()) 55 | zeros_y = 
torch.zeros(y.size()).cuda() if torch.cuda.is_available() else torch.zeros(y.size()) 56 | 57 | loss_term_3_x = torch.max(zeros_x, output[:, 0, :, :]) 58 | loss_term_3_y = torch.max(zeros_y, output[:, 1, :, :]) 59 | loss_term_3 = torch.sum((loss_term_3_x + loss_term_3_y) * back_sign) / torch.sum(back_sign) 60 | 61 | loss = loss_term_1 - self.lamda * loss_term_2 + loss_term_3 62 | 63 | return loss 64 | 65 | class MS_SSIM_Loss(MS_SSIM): 66 | def __init__(self, window_size=11, size_average=True, channel=3): 67 | super(MS_SSIM_Loss, self).__init__(window_size=window_size, size_average=size_average, channel=channel) 68 | 69 | def forward(self, img1, img2): 70 | ms_ssim = super(MS_SSIM_Loss, self).forward(img1, img2) 71 | return 100*( 1 - ms_ssim) 72 | 73 | class SSIM_Loss(SSIM): 74 | def __init__(self, window_size=11, size_average=True, val_range=None): 75 | super(SSIM_Loss, self).__init__(window_size=window_size, size_average=size_average, val_range=val_range) 76 | 77 | def forward(self, img1, img2): 78 | ssim = super(SSIM_Loss, self).forward(img1, img2) 79 | return 100*( 1 - ssim) 80 | 81 | class MS_SSIM_Loss_v2(MS_SSIM_v2): 82 | def __init__(self, window_size=11, size_average=True, channel=3): 83 | super(MS_SSIM_Loss_v2, self).__init__(win_size=window_size, size_average=size_average, channel=channel, data_range=255, nonnegative_ssim=True) 84 | 85 | def forward(self, img1, img2): 86 | ms_ssim = super(MS_SSIM_Loss_v2, self).forward(img1, img2) 87 | return 100*( 1 - ms_ssim) 88 | 89 | class SSIM_Loss_v2(SSIM_v2): 90 | def __init__(self, window_size=11, size_average=True): 91 | super(SSIM_Loss_v2, self).__init__(win_size=window_size, size_average=size_average, data_range=255, nonnegative_ssim=True) 92 | 93 | def forward(self, img1, img2): 94 | ssim = super(SSIM_Loss_v2, self).forward(img1, img2) 95 | return 100*( 1 - ssim) 96 | -------------------------------------------------------------------------------- /util/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Hang Zhang 3 | ## ECE Department, Rutgers University 4 | ## Email: zhang.hang@rutgers.edu 5 | ## Copyright (c) 2017 6 | ## 7 | ## This source code is licensed under the MIT-style license found in the 8 | ## LICENSE file in the root directory of this source tree 9 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 10 | 11 | import math 12 | 13 | class LR_Scheduler(object): 14 | """Learning Rate Scheduler 15 | 16 | Step mode: ``lr = baselr * 0.1 ^ {floor(epoch-1 / lr_step)}`` 17 | 18 | Cosine mode: ``lr = baselr * 0.5 * (1 + cos(iter/maxiter))`` 19 | 20 | Poly mode: ``lr = baselr * (1 - iter/maxiter) ^ 0.9`` 21 | 22 | Args: 23 | args: 24 | :attr:`args.lr_scheduler` lr scheduler mode (`cos`, `poly`), 25 | :attr:`args.lr` base learning rate, :attr:`args.epochs` number of epochs, 26 | :attr:`args.lr_step` 27 | 28 | iters_per_epoch: number of iterations per epoch 29 | """ 30 | def __init__(self, mode, base_lr, num_epochs, iters_per_epoch=0, 31 | lr_step=0, warmup_epochs=0): 32 | self.mode = mode 33 | print('Using {} LR Scheduler!'.format(self.mode)) 34 | self.lr = base_lr 35 | if mode == 'step': 36 | assert lr_step 37 | self.lr_step = lr_step 38 | self.iters_per_epoch = iters_per_epoch 39 | self.N = num_epochs * iters_per_epoch 40 | self.epoch = -1 41 | self.warmup_iters = warmup_epochs * iters_per_epoch 42 | 43 | def __call__(self, optimizer, i, epoch, best_pred): 44 | T = 
epoch * self.iters_per_epoch + i 45 | if self.mode == 'cos': 46 | lr = 0.5 * self.lr * (1 + math.cos(1.0 * T / self.N * math.pi)) 47 | elif self.mode == 'poly': 48 | lr = self.lr * pow((1 - 1.0 * T / self.N), 0.9) 49 | elif self.mode == 'step': 50 | lr = self.lr * (0.1 ** (epoch // self.lr_step)) 51 | else: 52 | raise NotImplemented 53 | # warm up lr schedule 54 | if self.warmup_iters > 0 and T < self.warmup_iters: 55 | lr = lr * 1.0 * T / self.warmup_iters 56 | if epoch > self.epoch: 57 | print('\n=>Epoches %i, learning rate = %.4f, \ 58 | previous best = %.4f' % (epoch, lr, best_pred)) 59 | self.epoch = epoch 60 | assert lr >= 0 61 | self._adjust_learning_rate(optimizer, lr) 62 | 63 | def _adjust_learning_rate(self, optimizer, lr): 64 | if len(optimizer.param_groups) == 1: 65 | optimizer.param_groups[0]['lr'] = lr 66 | else: 67 | # enlarge the lr at the head 68 | optimizer.param_groups[0]['lr'] = lr 69 | for i in range(1, len(optimizer.param_groups)): 70 | optimizer.param_groups[i]['lr'] = lr * 10 71 | -------------------------------------------------------------------------------- /util/ms_ssim.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def _fspecial_gauss_1d(size, sigma): 6 | r"""Create 1-D gauss kernel 7 | Args: 8 | size (int): the size of gauss kernel 9 | sigma (float): sigma of normal distribution 10 | 11 | Returns: 12 | torch.Tensor: 1D kernel 13 | """ 14 | coords = torch.arange(size).to(dtype=torch.float) 15 | coords -= size // 2 16 | 17 | g = torch.exp(-(coords ** 2) / (2 * sigma ** 2)) 18 | g /= g.sum() 19 | 20 | return g.unsqueeze(0).unsqueeze(0) 21 | 22 | 23 | def gaussian_filter(input, win): 24 | r""" Blur input with 1-D kernel 25 | Args: 26 | input (torch.Tensor): a batch of tensors to be blured 27 | window (torch.Tensor): 1-D gauss kernel 28 | 29 | Returns: 30 | torch.Tensor: blured tensors 31 | """ 32 | 33 | N, C, H, W = input.shape 34 | out = F.conv2d(input, win, stride=1, padding=0, groups=C) 35 | out = F.conv2d(out, win.transpose(2, 3), stride=1, padding=0, groups=C) 36 | return out 37 | 38 | 39 | def _ssim(X, Y, win, data_range=255, size_average=True, full=False, K=(0.01, 0.03), nonnegative_ssim=False): 40 | r""" Calculate ssim index for X and Y 41 | Args: 42 | X (torch.Tensor): images 43 | Y (torch.Tensor): images 44 | win (torch.Tensor): 1-D gauss kernel 45 | data_range (float or int, optional): value range of input images. (usually 1.0 or 255) 46 | size_average (bool, optional): if size_average=True, ssim of all images will be averaged as a scalar 47 | full (bool, optional): return sc or not 48 | nonnegative_ssim (bool, optional): force the ssim response to be nonnegative to avoid negative results. 
49 | 50 | Returns: 51 | torch.Tensor: ssim results 52 | """ 53 | K1, K2 = K 54 | batch, channel, height, width = X.shape 55 | compensation = 1.0 56 | 57 | C1 = (K1 * data_range) ** 2 58 | C2 = (K2 * data_range) ** 2 59 | 60 | win = win.to(X.device, dtype=X.dtype) 61 | 62 | mu1 = gaussian_filter(X, win) 63 | mu2 = gaussian_filter(Y, win) 64 | 65 | mu1_sq = mu1.pow(2) 66 | mu2_sq = mu2.pow(2) 67 | mu1_mu2 = mu1 * mu2 68 | 69 | sigma1_sq = compensation * (gaussian_filter(X * X, win) - mu1_sq) 70 | sigma2_sq = compensation * (gaussian_filter(Y * Y, win) - mu2_sq) 71 | sigma12 = compensation * (gaussian_filter(X * Y, win) - mu1_mu2) 72 | 73 | cs_map = (2 * sigma12 + C2) / (sigma1_sq + sigma2_sq + C2) # set alpha=beta=gamma=1 74 | if nonnegative_ssim: 75 | cs_map = F.relu(cs_map, inplace=True) 76 | ssim_map = ((2 * mu1_mu2 + C1) / (mu1_sq + mu2_sq + C1)) * cs_map 77 | 78 | if size_average: 79 | ssim_val = ssim_map.mean() 80 | cs = cs_map.mean() 81 | else: 82 | ssim_val = ssim_map.mean(-1).mean(-1).mean(-1) # reduce along CHW 83 | cs = cs_map.mean(-1).mean(-1).mean(-1) 84 | 85 | if full: 86 | return ssim_val, cs 87 | else: 88 | return ssim_val 89 | 90 | 91 | def ssim(X, Y, win_size=11, win_sigma=1.5, win=None, data_range=255, size_average=True, full=False, K=(0.01, 0.03), 92 | nonnegative_ssim=False): 93 | r""" interface of ssim 94 | Args: 95 | X (torch.Tensor): a batch of images, (N,C,H,W) 96 | Y (torch.Tensor): a batch of images, (N,C,H,W) 97 | win_size: (int, optional): the size of gauss kernel 98 | win_sigma: (float, optional): sigma of normal distribution 99 | win (torch.Tensor, optional): 1-D gauss kernel. if None, a new kernel will be created according to win_size and win_sigma 100 | data_range (float or int, optional): value range of input images. (usually 1.0 or 255) 101 | size_average (bool, optional): if size_average=True, ssim of all images will be averaged as a scalar 102 | full (bool, optional): return sc or not 103 | K (list or tuple, optional): scalar constants (K1, K2). Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN results. 104 | nonnegative_ssim (bool, optional): force the ssim response to be nonnegative to avoid negative results. 
105 | 106 | Returns: 107 | torch.Tensor: ssim results 108 | """ 109 | 110 | if len(X.shape) != 4: 111 | raise ValueError('Input images must be 4-d tensors.') 112 | 113 | if not X.type() == Y.type(): 114 | raise ValueError('Input images must have the same dtype.') 115 | 116 | if not X.shape == Y.shape: 117 | raise ValueError('Input images must have the same dimensions.') 118 | 119 | if not (win_size % 2 == 1): 120 | raise ValueError('Window size must be odd.') 121 | 122 | win_sigma = win_sigma 123 | if win is None: 124 | win = _fspecial_gauss_1d(win_size, win_sigma) 125 | win = win.repeat(X.shape[1], 1, 1, 1) 126 | else: 127 | win_size = win.shape[-1] 128 | 129 | ssim_val, cs = _ssim(X, Y, 130 | win=win, 131 | data_range=data_range, 132 | size_average=False, 133 | full=True, K=K, nonnegative_ssim=nonnegative_ssim) 134 | if size_average: 135 | ssim_val = ssim_val.mean() 136 | cs = cs.mean() 137 | 138 | if full: 139 | return ssim_val, cs 140 | else: 141 | return ssim_val 142 | 143 | 144 | def ms_ssim(X, Y, win_size=11, win_sigma=1.5, win=None, data_range=255, size_average=True, full=False, weights=None, 145 | K=(0.01, 0.03), nonnegative_ssim=False): 146 | r""" interface of ms-ssim 147 | Args: 148 | X (torch.Tensor): a batch of images, (N,C,H,W) 149 | Y (torch.Tensor): a batch of images, (N,C,H,W) 150 | win_size: (int, optional): the size of gauss kernel 151 | win_sigma: (float, optional): sigma of normal distribution 152 | win (torch.Tensor, optional): 1-D gauss kernel. if None, a new kernel will be created according to win_size and win_sigma 153 | data_range (float or int, optional): value range of input images. (usually 1.0 or 255) 154 | size_average (bool, optional): if size_average=True, ssim of all images will be averaged as a scalar 155 | full (bool, optional): return sc or not 156 | weights (list, optional): weights for different levels 157 | K (list or tuple, optional): scalar constants (K1, K2). Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN results. 158 | nonnegative_ssim (bool, optional): force the ssim response to be nonnegative to avoid NaN results. 
159 |     Returns:
160 |         torch.Tensor: ms-ssim results
161 |     """
162 |     if len(X.shape) != 4:
163 |         raise ValueError('Input images must be 4-d tensors.')
164 | 
165 |     if not X.type() == Y.type():
166 |         raise ValueError('Input images must have the same dtype.')
167 | 
168 |     if not X.shape == Y.shape:
169 |         raise ValueError('Input images must have the same dimensions.')
170 | 
171 |     if not (win_size % 2 == 1):
172 |         raise ValueError('Window size must be odd.')
173 | 
174 |     if weights is None:
175 |         weights = torch.FloatTensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]).to(X.device, dtype=X.dtype)
176 | 
177 |     win_sigma = win_sigma
178 |     if win is None:
179 |         win = _fspecial_gauss_1d(win_size, win_sigma)
180 |         win = win.repeat(X.shape[1], 1, 1, 1)
181 |     else:
182 |         win_size = win.shape[-1]
183 | 
184 |     levels = weights.shape[0]
185 |     mcs = []
186 |     for _ in range(levels):
187 |         ssim_val, cs = _ssim(X, Y,
188 |                              win=win,
189 |                              data_range=data_range,
190 |                              size_average=False,
191 |                              full=True, K=K, nonnegative_ssim=nonnegative_ssim)
192 |         mcs.append(cs)
193 | 
194 |         padding = (X.shape[2] % 2, X.shape[3] % 2)
195 |         X = F.avg_pool2d(X, kernel_size=2, padding=padding)
196 |         Y = F.avg_pool2d(Y, kernel_size=2, padding=padding)
197 | 
198 |     mcs = torch.stack(mcs, dim=0)  # mcs, (level, batch)
199 |     # weights, (level)
200 |     msssim_val = torch.prod((mcs[:-1] ** weights[:-1].unsqueeze(1))
201 |                             * (ssim_val ** weights[-1]), dim=0)  # (batch, )
202 | 
203 |     if size_average:
204 |         msssim_val = msssim_val.mean()
205 |     return msssim_val
206 | 
207 | 
208 | class SSIM_v2(torch.nn.Module):
209 |     def __init__(self, win_size=11, win_sigma=1.5, data_range=1, size_average=True, channel=3, K=(0.01, 0.03),
210 |                  nonnegative_ssim=False):
211 |         r""" class for ssim
212 |         Args:
213 |             win_size (int, optional): the size of the Gaussian kernel
214 |             win_sigma (float, optional): standard deviation of the Gaussian kernel
215 |             data_range (float or int, optional): value range of input images (usually 1.0 or 255)
216 |             size_average (bool, optional): if size_average=True, ssim of all images will be averaged as a scalar
217 |             channel (int, optional): input channels (default: 3)
218 |             K (list or tuple, optional): scalar constants (K1, K2). Try a larger K2 constant (e.g. 0.4) if you get negative or NaN results.
219 |             nonnegative_ssim (bool, optional): force the ssim response to be nonnegative to avoid negative results.
220 |         """
221 | 
222 |         super(SSIM_v2, self).__init__()
223 |         self.win = _fspecial_gauss_1d(
224 |             win_size, win_sigma).repeat(channel, 1, 1, 1)
225 |         self.size_average = size_average
226 |         self.data_range = data_range
227 |         self.K = K
228 |         self.nonnegative_ssim = nonnegative_ssim
229 | 
230 |     def forward(self, X, Y):
231 |         return ssim(X, Y, win=self.win, data_range=self.data_range, size_average=self.size_average, K=self.K,
232 |                     nonnegative_ssim=self.nonnegative_ssim)
233 | 
234 | 
235 | class MS_SSIM_v2(torch.nn.Module):
236 |     def __init__(self, win_size=11, win_sigma=1.5, data_range=1, size_average=True, channel=3, weights=None,
237 |                  K=(0.01, 0.03), nonnegative_ssim=False):
238 |         r""" class for ms-ssim
239 |         Args:
240 |             win_size (int, optional): the size of the Gaussian kernel
241 |             win_sigma (float, optional): standard deviation of the Gaussian kernel
242 |             data_range (float or int, optional): value range of input images (usually 1.0 or 255)
243 |             size_average (bool, optional): if size_average=True, ssim of all images will be averaged as a scalar
244 |             channel (int, optional): input channels (default: 3)
245 |             weights (list, optional): weights for different levels
246 |             K (list or tuple, optional): scalar constants (K1, K2). Try a larger K2 constant (e.g. 0.4) if you get negative or NaN results.
247 |             nonnegative_ssim (bool, optional): force the ssim response to be nonnegative to avoid NaN results.
248 |         """
249 | 
250 |         super(MS_SSIM_v2, self).__init__()
251 |         self.win = _fspecial_gauss_1d(
252 |             win_size, win_sigma).repeat(channel, 1, 1, 1)
253 |         self.size_average = size_average
254 |         self.data_range = data_range
255 |         self.weights = weights
256 |         self.K = K
257 |         self.nonnegative_ssim = nonnegative_ssim
258 | 
259 |     def forward(self, X, Y):
260 |         return ms_ssim(X, Y, win=self.win, size_average=self.size_average, data_range=self.data_range,
261 |                        weights=self.weights, K=self.K, nonnegative_ssim=self.nonnegative_ssim)
--------------------------------------------------------------------------------
/util/ssim.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from math import exp
4 | import numpy as np
5 | 
6 | def gaussian(window_size, sigma):
7 |     gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
8 |     return gauss/gauss.sum()
9 | 
10 | def create_window(window_size, channel=1):
11 |     _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
12 |     _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
13 |     window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
14 |     return window
15 | 
16 | def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None, normalize=False):
17 |     # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh).
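    # (Editor's note, for illustration: with a range L = 255 the stability constants computed
    #  further below become C1 = (0.01 * 255) ** 2 ≈ 6.5 and C2 = (0.03 * 255) ** 2 ≈ 58.5,
    #  while with L = 1 they shrink to 1e-4 and 9e-4.)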
18 |     if val_range is None:
19 |         if torch.max(img1) > 128:
20 |             max_val = 255
21 |         else:
22 |             max_val = 1
23 | 
24 |         if torch.min(img1) < -0.5:
25 |             min_val = -1
26 |         else:
27 |             min_val = 0
28 |         L = max_val - min_val
29 |     else:
30 |         L = val_range
31 | 
32 |     padd = 0
33 |     (_, channel, height, width) = img1.size()
34 |     if window is None:
35 |         real_size = min(window_size, height, width)
36 |         window = create_window(real_size, channel=channel).to(img1.device)
37 | 
38 |     mu1 = F.conv2d(img1, window, padding=padd, groups=channel)
39 |     mu2 = F.conv2d(img2, window, padding=padd, groups=channel)
40 | 
41 |     mu1_sq = mu1.pow(2)
42 |     mu2_sq = mu2.pow(2)
43 |     mu1_mu2 = mu1 * mu2
44 | 
45 |     sigma1_sq = F.conv2d(img1 * img1, window, padding=padd, groups=channel) - mu1_sq
46 |     sigma2_sq = F.conv2d(img2 * img2, window, padding=padd, groups=channel) - mu2_sq
47 |     sigma12 = F.conv2d(img1 * img2, window, padding=padd, groups=channel) - mu1_mu2
48 | 
49 |     C1 = (0.01 * L) ** 2
50 |     C2 = (0.03 * L) ** 2
51 | 
52 |     v1 = 2.0 * sigma12 + C2
53 |     v2 = sigma1_sq + sigma2_sq + C2
54 |     cs = torch.mean(v1 / v2)  # contrast sensitivity
55 | 
56 |     ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)
57 | 
58 |     if normalize:
59 |         ssim_map = F.relu(ssim_map, inplace=True)
60 | 
61 |     if size_average:
62 |         ret = ssim_map.mean()
63 |     else:
64 |         ret = ssim_map.mean(1).mean(1).mean(1)
65 | 
66 |     if full:
67 |         return ret, cs
68 |     return ret
69 | 
70 | 
71 | def msssim(img1, img2, window_size=11, size_average=True, val_range=None, normalize=False):
72 |     device = img1.device
73 |     weights = torch.FloatTensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]).to(device)
74 |     levels = weights.size()[0]
75 |     mssim = []
76 |     mcs = []
77 |     for _ in range(levels):
78 |         sim, cs = ssim(img1, img2, window_size=window_size, size_average=size_average, full=True, val_range=val_range, normalize=False)
79 |         mssim.append(sim)
80 |         mcs.append(cs)
81 | 
82 |         img1 = F.avg_pool2d(img1, (2, 2))
83 |         img2 = F.avg_pool2d(img2, (2, 2))
84 | 
85 |     mssim = torch.stack(mssim)
86 |     mcs = torch.stack(mcs)
87 | 
88 |     # Normalize (to avoid NaNs when training unstable models; not compliant with the original definition)
89 |     if normalize:
90 |         mssim = (mssim + 1) / 2
91 |         mcs = (mcs + 1) / 2
92 | 
93 |     pow1 = mcs ** weights
94 |     pow2 = mssim ** weights
95 |     # From Matlab implementation https://ece.uwaterloo.ca/~z70wang/research/iwssim/
96 |     output = torch.prod(pow1[:-1] * pow2[-1])
97 |     return output
98 | 
99 | 
100 | # Classes to re-use window
101 | class SSIM(torch.nn.Module):
102 |     def __init__(self, window_size=11, size_average=True, val_range=None):
103 |         super(SSIM, self).__init__()
104 |         self.window_size = window_size
105 |         self.size_average = size_average
106 |         self.val_range = val_range
107 | 
108 |         # Assume 1 channel for SSIM
109 |         self.channel = 1
110 |         self.window = create_window(window_size)
111 | 
112 |     def forward(self, img1, img2):
113 |         (_, channel, _, _) = img1.size()
114 | 
115 |         if channel == self.channel and self.window.dtype == img1.dtype:
116 |             window = self.window
117 |         else:
118 |             window = create_window(self.window_size, channel).to(img1.device).type(img1.dtype)
119 |             self.window = window
120 |             self.channel = channel
121 | 
122 |         return ssim(img1, img2, window=window, window_size=self.window_size, size_average=self.size_average)
123 | 
124 | class MS_SSIM(torch.nn.Module):
125 |     def __init__(self, window_size=11, size_average=True, channel=3):
126 |         super(MS_SSIM, self).__init__()
127 |         self.window_size = window_size
128 |         self.size_average = size_average
129 |         self.channel = channel
130 | 
131 |     def forward(self, img1, img2):
132 |         return msssim(img1, img2, window_size=self.window_size, size_average=self.size_average, normalize=True)
--------------------------------------------------------------------------------
/util/summary.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from torchvision.utils import make_grid
4 | from torch.utils.tensorboard import SummaryWriter
5 | import glob
6 | 
7 | from util.general_functions import tensor2im, get_flat_images
8 | 
9 | 
10 | from constants import *
11 | 
12 | class TensorboardSummary(object):
13 | 
14 |     def __init__(self, args):
15 |         self.args = args
16 |         self.experiment_dir = self.generate_directory(args)
17 |         self.writer = SummaryWriter(log_dir=os.path.join(self.experiment_dir))
18 | 
19 |         self.train_step = 0
20 |         self.test_step = 0
21 |         self.visualization_step = 0
22 | 
23 |     def generate_directory(self, args):
24 |         checkname = 'debug' if args.debug else ''
25 |         checkname += args.model
26 |         checkname += '_sc' if args.separable_conv else ''
27 |         checkname += '-refined' if args.refine_network else ''
28 | 
29 |         if 'deeplab' in args.model:
30 |             checkname += '-os_' + str(args.output_stride)
31 |             checkname += '-ls_1' if args.learned_upsampling else ''
32 |             checkname += '-pt_1' if args.pretrained else ''
33 |             checkname += '-aspp_0' if not args.use_aspp else ''
34 | 
35 |         if 'unet' in args.model:
36 |             checkname += '-downs_' + str(args.num_downs) + '-ngf_' + str(args.ngf) + '-type_' + str(args.down_type)
37 | 
38 |         checkname += '-loss_' + args.loss_type
39 |         checkname += '-sloss_' if args.second_loss else ''
40 | 
41 |         if args.clip > 0:
42 |             checkname += '-clipping_' + str(args.clip)
43 | 
44 |         if args.resize:
45 |             checkname += '-' + ','.join([str(x) for x in list(args.resize)])
46 |         checkname += '-epochs_' + str(args.epochs)
47 |         checkname += '-trainval' if args.trainval else ''
48 | 
49 |         current_dir = os.path.dirname(__file__)
50 |         directory = os.path.join(current_dir, args.results_root, args.results_dir, args.dataset_dir, args.dataset, args.model, checkname)
51 | 
52 |         runs = sorted(glob.glob(os.path.join(directory, 'experiment_*')))
53 |         run_id = int(runs[-1].split('_')[-1]) + 1 if runs else 0
54 |         experiment_dir = os.path.join(directory, 'experiment_{}'.format(str(run_id)))
55 | 
56 |         if not os.path.exists(experiment_dir):
57 |             os.makedirs(experiment_dir)
58 | 
59 |         return experiment_dir
60 | 
61 |     def add_scalar(self, tag, value, step):
62 |         self.writer.add_scalar(tag, value, step)
63 | 
64 |     def visualize_image(self, images, targets, outputs, split="train"):
65 |         step = self.get_step(split)
66 | 
67 |         outputs, targets = get_flat_images(self.args.dataset, images, outputs, targets)
68 | 
69 |         images = [tensor2im(image) for image in images]
70 |         outputs = [tensor2im(output)[:, : int(self.args.resize[0] / 2), : int(self.args.resize[1] / 2)] for output in outputs]
71 |         targets = [tensor2im(target)[:, : int(self.args.resize[0] / 2), : int(self.args.resize[1] / 2)] for target in targets]
72 | 
73 |         grid_image = make_grid(images)
74 |         self.writer.add_image(split + '/ZZ Image', grid_image, step)
75 | 
76 |         grid_image = make_grid(outputs)
77 |         self.writer.add_image(split + '/Predicted label', grid_image, step)
78 | 
79 |         grid_image = make_grid(targets)
80 |         self.writer.add_image(split + '/Groundtruth label', grid_image, step)
81 | 
82 |         return images, outputs, targets
83 | 
84 |     def save_network(self, model):
85 |         path = self.experiment_dir[self.experiment_dir.find(self.args.results_dir):].replace(self.args.results_dir, self.args.save_dir)
86 |         if not os.path.isdir(path):
87 |             os.makedirs(path)
88 | 
89 |         torch.save(model.state_dict(), path + '/' + 'network.pth')
90 | 
91 |     def load_network(self, model):
92 |         path = self.args.pretrained_models_dir
93 |         state_dict = torch.load(path) if self.args.cuda else torch.load(path, map_location=torch.device('cpu'))
94 |         model.load_state_dict(state_dict)
95 | 
96 |         return model
97 | 
98 |     def get_step(self, split):
99 |         if split == TRAIN:
100 |             self.train_step += 1
101 |             return self.train_step
102 |         elif split == TEST:
103 |             self.test_step += 1
104 |             return self.test_step
105 |         elif split == VISUALIZATION:
106 |             self.visualization_step += 1
107 |             return self.visualization_step
--------------------------------------------------------------------------------
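Editor's note: as a quick orientation for the loss modules defined in util/ms_ssim.py above, here is a minimal usage sketch. It is not part of the repository: the import path assumes the repository root is on PYTHONPATH, the tensor shapes are arbitrary stand-ins, and taking 1 - MS-SSIM as the training loss is one common choice rather than necessarily the exact formulation used in util/losses.py.

import torch

from util.ms_ssim import MS_SSIM_v2

# data_range=1 assumes inputs scaled to [0, 1]; channel must match the number of image channels.
criterion = MS_SSIM_v2(data_range=1, channel=3)

prediction = torch.rand(2, 3, 256, 256, requires_grad=True)  # stand-in for an unwarped network output, (N, C, H, W)
target = torch.rand(2, 3, 256, 256)                          # stand-in for the flat ground-truth image

similarity = criterion(prediction, target)  # scalar MS-SSIM score, 1.0 for identical images
loss = 1 - similarity                       # minimize dissimilarity between unwarped output and target
loss.backward()

The same pattern applies to SSIM_v2 and to the SSIM/MS_SSIM classes in util/ssim.py, provided the inputs are scaled to the value range each module expects.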