├── .gitignore
├── README.md
├── constants.py
├── core
│   ├── models
│   │   ├── deeplabv3_plus.py
│   │   ├── mobilenet_v2_dilation.py
│   │   ├── mobilenetv2.py
│   │   ├── pspnet.py
│   │   ├── resnet.py
│   │   ├── unet.py
│   │   ├── unet_paper.py
│   │   └── unet_pytorch.py
│   └── trainers
│       └── trainer.py
├── dataloader
│   ├── docunet.py
│   ├── docunet_im2im.py
│   └── docunet_inverted.py
├── environment.yml
├── main.py
├── parser_options.py
├── playground.py
├── readme_images
│   ├── generating_deformed_images.PNG
│   ├── output_examples.png
│   └── overall_architecture.PNG
├── train.sh
└── util
    ├── custom_transforms.py
    ├── general_functions.py
    ├── losses.py
    ├── lr_scheduler.py
    ├── ms_ssim.py
    ├── ssim.py
    └── summary.py
/.gitignore:
--------------------------------------------------------------------------------
1 | results*/*
2 | images/*
3 | saved_models/*
4 | dataset/*
5 | .idea/*
6 | __pycache__/*
7 | **/__pycache__/*
8 | !.gitkeep
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Document Image Unwarping
2 |
3 |
4 | ![Overall architecture](readme_images/overall_architecture.PNG)
5 |
6 |
7 | This repository contains an *unofficial* implementation of [DocUNet: Document Image Unwarping via a Stacked U-Net](http://openaccess.thecvf.com/content_cvpr_2018/html/Ma_DocUNet_Document_Image_CVPR_2018_paper.html).
8 | We extend this work by:
9 | * predicting the inverted vector fields directly, which saves computation time during inference
10 | * adding more networks to choose from, from UNet to DeepLabv3+ with different backbones
11 | * adding a second loss function (MS-SSIM / SSIM) to measure the similarity between the unwarped and the target image (a sketch of how the two losses are combined is shown below)
12 | * achieving real-time inference speed (300 ms) on CPU for DeepLabv3+ with MobileNetV2 as the backbone
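
The second loss is combined with the vector-field loss following the pattern in [the trainer](core/trainers/trainer.py): the input is unwarped with both the predicted and the ground-truth field, and the two resulting images are compared. The snippet below is only a minimal sketch of that idea; the callables and the `second_loss_rate` weight are placeholders (the real helpers live in `util/general_functions.py` and `util/ssim.py`, and the weight is a command line argument).

```python
def combined_loss(image, pred_field, target_field,
                  field_criterion, ssim_criterion, unwarp, second_loss_rate=0.5):
    """Sketch of the two-term loss: field regression + (MS-)SSIM on the unwarped images."""
    # Main term: distance between the predicted and the ground-truth vector fields.
    loss = field_criterion(pred_field, target_field)

    # Second term: unwarp the input with both fields and compare the flattened images.
    flat_pred = unwarp(image, pred_field)
    flat_target = unwarp(image, target_field)
    return loss + second_loss_rate * ssim_criterion(flat_pred, flat_target)
```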
13 |
14 | ## Training dataset
15 |
16 | Unfortunately, I am not allowed to make the dataset public. However, I created a very small toy dataset to give you an idea of how the network input should look.
17 | You can find this [here](https://drive.google.com/file/d/16Ay3NVzFmsVe1saMOZam-9nHBE0xcmyA/view?usp=sharing).
18 | The idea is to create a 2D vector field to deform a flat input image. The deformed image is used as network input and the vector field is the network target.
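
As a rough illustration (this repository does not ship the dataset-generation code), the sketch below warps an image with a dense 2D displacement field using `torch.nn.functional.grid_sample`; the pixel-displacement convention and the normalization are assumptions, not necessarily what was used to build the toy dataset.

```python
import torch
import torch.nn.functional as F

def deform(image, flow):
    """Warp `image` (N, C, H, W) with a per-pixel displacement field `flow` (N, 2, H, W).

    Displacements are assumed to be in pixels and are converted to the normalized
    [-1, 1] grid that grid_sample expects. Illustrative sketch only.
    """
    n, _, h, w = image.shape
    ys, xs = torch.meshgrid(torch.linspace(-1, 1, h), torch.linspace(-1, 1, w), indexing='ij')
    identity = torch.stack((xs, ys), dim=-1).unsqueeze(0).expand(n, -1, -1, -1)  # (N, H, W, 2)
    offsets = torch.stack((flow[:, 0] * 2 / w, flow[:, 1] * 2 / h), dim=-1)      # pixels -> [-1, 1]
    return F.grid_sample(image, identity + offsets, align_corners=True)

# The deformed image is the network input; the field (or its inverse) is the target.
```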
19 |
20 |
21 | ![Generating deformed images](readme_images/generating_deformed_images.PNG)
22 |
23 |
24 | ## Training on your dataset
25 | 1. Check the [available parser options](parser_options.py).
26 | 2. Download the [toy dataset](https://drive.google.com/file/d/16Ay3NVzFmsVe1saMOZam-9nHBE0xcmyA/view?usp=sharing).
27 | 3. Set the path to your dataset in the [available parser options](parser_options.py).
28 | 4. Create the environment from the [conda file](environment.yml): `conda env create -f environment.yml`
29 | 5. Activate the conda environment: `conda activate unwarping_assignment`
30 | 6. Train the networks using the provided scripts: [1](main.py), [2](train.sh). The trained model is saved to the directory given by the `save_dir` command line argument.
31 | 7. Run the [inference script](playground.py) on your own set. Use the `inference_dir` command line argument to provide the
32 | relative path to the folder that contains the images to be unwarped (a sketch of using the networks programmatically follows below).
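
The steps above rely on the provided scripts. The networks can also be used programmatically from `core/models`; the sketch below builds the DeepLabv3+ / MobileNetV2 variant inside the conda environment from [environment.yml](environment.yml), using a minimal stand-in for the parsed arguments. The attribute values shown are illustrative; the authoritative options and defaults live in [parser_options.py](parser_options.py).

```python
from types import SimpleNamespace

import torch

from constants import DEEPLAB_MOBILENET
from core.models.deeplabv3_plus import DeepLabv3_plus

# Minimal stand-in for the parsed command line arguments; values are illustrative.
args = SimpleNamespace(model=DEEPLAB_MOBILENET, output_stride=16, pretrained=False,
                       refine_network=False, use_aspp=True, learned_upsampling=False)

# num_classes=2: the network regresses a two-channel (x, y) vector field.
net = DeepLabv3_plus(args, num_classes=2).eval()
with torch.no_grad():
    field = net(torch.randn(1, 3, 256, 256))
print(field.shape)  # torch.Size([1, 2, 256, 256])
```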
33 |
34 | ## Sample results
35 |
36 |
37 | ![Output examples](readme_images/output_examples.png)
38 |
39 |
--------------------------------------------------------------------------------
/constants.py:
--------------------------------------------------------------------------------
1 | # Model constants
2 | DEEPLAB = 'deeplab'
3 | DEEPLAB_50 = 'deeplab_50'
4 | DEEPLAB_34 = 'deeplab_34'
5 | DEEPLAB_18 = 'deeplab_18'
6 | DEEPLAB_MOBILENET = 'deeplab_mn'
7 | DEEPLAB_MOBILENET_DILATION = 'deeplab_mnd'
8 | UNET = 'unet'
9 | UNET_PAPER = 'unet_paper'
10 | UNET_PYTORCH = 'unet_torch'
11 | PSPNET = 'pspnet'
12 |
13 | # Dataset constants
14 | DOCUNET = 'docunet'
15 | DOCUNET_INVERTED = 'docunet_inverted'
16 | DOCUNET_IM2IM = 'docunet_im2im'
17 |
18 | # Dataset locations
19 | HAZMAT_DATASET = 'hazmat_dataset'
20 | ADDRESS_DATASET = 'address_dataset'
21 | LABELS_DATASET = 'labels_dataset'
22 |
24 | # Loss constants
24 | DOCUNET_LOSS = 'docunet_loss'
25 | MS_SSIM_LOSS = 'ms_ssim_loss'
26 | MS_SSIM_LOSS_V2 = 'ms_ssim_loss_v2'
27 | SSIM_LOSS = 'ssim_loss'
28 | SSIM_LOSS_V2 = 'ssim_loss_v2'
29 | SMOOTH_L1_LOSS = 'smoothl1_loss'
30 | L1_LOSS = 'l1_loss'
31 | MSE_LOSS = 'mse_loss'
32 |
33 | # Optimizers
34 | SGD = 'sgd'
35 | AMSGRAD = 'amsgrad'
36 | ADAM = 'adam'
37 | RMSPROP = 'rmsprop'
38 | ADABOUND = 'adabound'
39 |
40 | # Normalization layers
41 | INSTANCE_NORM = 'instance'
42 | BATCH_NORM = 'batch'
43 | SYNC_BATCH_NORM = 'syncbn'
44 |
45 | # Init types
46 | NORMAL_INIT = 'normal'
47 | KAIMING_INIT = 'kaiming'
48 | XAVIER_INIT = 'xavier'
49 | ORTHOGONAL_INIT = 'orthogonal'
50 |
51 | # Downsampling methods
52 | MAXPOOL = 'maxpool'
53 | STRIDECONV = 'strided'
54 |
55 | # Split constants
56 | TRAIN = 'train'
57 | VAL = 'val'
58 | TEST = 'test'
59 | TRAINVAL = 'trainval'
60 | VISUALIZATION = 'visualization'
--------------------------------------------------------------------------------
/core/models/deeplabv3_plus.py:
--------------------------------------------------------------------------------
1 | '''
2 | Adapted work from https://github.com/jfzhang95/pytorch-deeplab-xception
3 | '''
4 |
5 | import torch
6 | import torch.nn.functional as F
7 | import torch.nn as nn
8 |
9 | from core.models.resnet import ResNet101, ResNet50, ResNet34, ResNet18
10 | from core.models.mobilenetv2 import MobileNet_v2
11 | from core.models.mobilenet_v2_dilation import MobileNet_v2_dilation
12 | from constants import *
13 |
14 | class _ASPPModule(nn.Module):
15 | def __init__(self, inplanes, planes, kernel_size, padding, dilation, norm_layer=nn.BatchNorm2d):
16 | super(_ASPPModule, self).__init__()
17 | self.atrous_conv = nn.Conv2d(inplanes, planes, kernel_size=kernel_size,
18 | stride=1, padding=padding, dilation=dilation, bias=False)
19 | self.bn = norm_layer(planes)
20 | self.relu = nn.ReLU()
21 |
22 | def forward(self, x):
23 | x = self.atrous_conv(x)
24 | x = self.bn(x)
25 |
26 | return self.relu(x)
27 |
28 |
29 | class ASPP(nn.Module):
30 | def __init__(self, output_stride, norm_layer=nn.BatchNorm2d, inplanes=2048):
31 | super(ASPP, self).__init__()
32 |
33 | if output_stride == 16:
34 | dilations = [1, 6, 12, 18]
35 | elif output_stride == 8:
36 | dilations = [1, 12, 24, 36]
37 | else:
38 | raise NotImplementedError
39 |
40 | self.aspp1 = _ASPPModule(inplanes, 256, 1, padding=0, dilation=dilations[0], norm_layer=norm_layer)
41 | self.aspp2 = _ASPPModule(inplanes, 256, 3, padding=dilations[1], dilation=dilations[1], norm_layer=norm_layer)
42 | self.aspp3 = _ASPPModule(inplanes, 256, 3, padding=dilations[2], dilation=dilations[2], norm_layer=norm_layer)
43 | self.aspp4 = _ASPPModule(inplanes, 256, 3, padding=dilations[3], dilation=dilations[3], norm_layer=norm_layer)
44 |
45 | self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
46 | nn.Conv2d(inplanes, 256, 1, stride=1, bias=False),
47 | norm_layer(256),
48 | nn.ReLU())
49 |
50 | self.conv1 = nn.Conv2d(1280, 256, 1, bias=False)
51 | self.bn1 = norm_layer(256)
52 | self.relu = nn.ReLU()
53 | self.dropout = nn.Dropout2d(0.5)
54 |
55 | def forward(self, x):
56 | x1 = self.aspp1(x)
57 | x2 = self.aspp2(x)
58 | x3 = self.aspp3(x)
59 | x4 = self.aspp4(x)
60 | x5 = self.global_avg_pool(x)
61 |
62 | x5 = F.interpolate(x5, size=x4.size()[2:], mode='bilinear', align_corners=True)
63 | x = torch.cat((x1, x2, x3, x4, x5), dim=1)
64 |
65 | x = self.conv1(x)
66 | x = self.bn1(x)
67 | x = self.relu(x)
68 |
69 | return self.dropout(x)
70 |
71 | class Decoder(nn.Module):
72 | def __init__(self, num_classes, norm_layer=nn.BatchNorm2d, inplanes=256, aspp_outplanes=256):
73 | super(Decoder, self).__init__()
74 |
75 | self.conv1 = nn.Conv2d(inplanes, 48, 1, bias=False)
76 | self.bn1 = norm_layer(48)
77 | self.relu = nn.ReLU()
78 |
79 | inplanes = 48 + aspp_outplanes
80 |
81 | self.last_conv = nn.Sequential(nn.Conv2d(inplanes, 256, kernel_size=3, stride=1, padding=1, bias=False),
82 | norm_layer(256),
83 | nn.ReLU(),
84 | nn.Dropout2d(0.5),
85 | nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
86 | norm_layer(256),
87 | nn.ReLU(),
88 | nn.Dropout2d(0.1),
89 | nn.Conv2d(256, num_classes, kernel_size=1, stride=1))
90 |
91 | def forward(self, x, low_level_feat):
92 | low_level_feat = self.conv1(low_level_feat)
93 | low_level_feat = self.bn1(low_level_feat)
94 | low_level_feat = self.relu(low_level_feat)
95 |
96 | x = F.interpolate(x, size=low_level_feat.size()[2:], mode='bilinear', align_corners=True)
97 | x = torch.cat((x, low_level_feat), dim=1)
98 | x = self.last_conv(x)
99 |
100 | return x
101 |
102 | class DeepLabv3_plus(nn.Module):
103 | def __init__(self, args, num_classes=21, norm_layer=nn.BatchNorm2d, input_channels=3):
104 | super(DeepLabv3_plus, self).__init__()
105 | self.args = args
106 |
107 | if args.model == DEEPLAB:
108 | self.backbone = ResNet101(args.output_stride, norm_layer=norm_layer, pretrained=args.pretrained, input_channels=input_channels)
109 | self.aspp_inplanes = 2048
110 | self.decoder_inplanes = 256
111 |
112 | if self.args.refine_network:
113 | self.refine_backbone = ResNet101(args.output_stride, norm_layer=norm_layer, pretrained=args.pretrained, input_channels=input_channels + num_classes)
114 | elif args.model == DEEPLAB_50:
115 | self.backbone = ResNet50(args.output_stride, norm_layer=norm_layer, pretrained=args.pretrained)
116 | self.aspp_inplanes = 2048
117 | self.decoder_inplanes = 256
118 |
119 | if self.args.refine_network:
120 | self.refine_backbone = ResNet50(args.output_stride, norm_layer=norm_layer, pretrained=args.pretrained, input_channels=input_channels + num_classes)
121 | elif args.model == DEEPLAB_34:
122 | self.backbone = ResNet34(args.output_stride, norm_layer=norm_layer, pretrained=args.pretrained)
123 |             self.aspp_inplanes = 512
124 |             self.decoder_inplanes = 64
125 |
126 | if self.args.refine_network:
127 | self.refine_backbone = ResNet34(args.output_stride, norm_layer=norm_layer, pretrained=args.pretrained, input_channels=input_channels + num_classes)
128 | elif args.model == DEEPLAB_18:
129 | self.backbone = ResNet18(args.output_stride, norm_layer=norm_layer, pretrained=args.pretrained)
130 |             self.aspp_inplanes = 512
131 |             self.decoder_inplanes = 64
132 |
133 | if self.args.refine_network:
134 | self.refine_backbone = ResNet18(args.output_stride, norm_layer=norm_layer, pretrained=args.pretrained, input_channels=input_channels + num_classes)
135 | elif args.model == DEEPLAB_MOBILENET:
136 | self.backbone = MobileNet_v2(pretrained=args.pretrained, first_layer_input_channels=input_channels)
137 | self.aspp_inplanes = 320
138 | self.decoder_inplanes = 24
139 |
140 | if self.args.refine_network:
141 | self.refine_backbone = MobileNet_v2(pretrained=args.pretrained, first_layer_input_channels=input_channels + num_classes)
142 |
143 | elif args.model == DEEPLAB_MOBILENET_DILATION:
144 | self.backbone = MobileNet_v2_dilation(pretrained=args.pretrained, first_layer_input_channels=input_channels)
145 | self.aspp_inplanes = 320
146 |
147 | if self.args.refine_network:
148 | self.refine_backbone = MobileNet_v2_dilation(pretrained=args.pretrained, first_layer_input_channels=input_channels + num_classes)
149 | self.decoder_inplanes = 24
150 | else:
151 | raise NotImplementedError
152 |
153 | if self.args.use_aspp:
154 | self.aspp = ASPP(args.output_stride, norm_layer=norm_layer, inplanes=self.aspp_inplanes)
155 |
156 | aspp_outplanes = 256 if self.args.use_aspp else self.aspp_inplanes
157 | self.decoder = Decoder(num_classes, norm_layer=norm_layer, inplanes=self.decoder_inplanes, aspp_outplanes=aspp_outplanes)
158 |
159 | if self.args.learned_upsampling:
160 | self.learned_upsampling = nn.Sequential(nn.ConvTranspose2d(num_classes, num_classes, kernel_size=4, stride=2, padding=1),
161 | nn.ConvTranspose2d(num_classes, num_classes, kernel_size=4, stride=2, padding=1))
162 |
163 | if self.args.refine_network:
164 | if self.args.use_aspp:
165 | self.refine_aspp = ASPP(args.output_stride, norm_layer=norm_layer, inplanes=self.aspp_inplanes)
166 |
167 | self.refine_decoder = Decoder(num_classes, norm_layer=norm_layer, inplanes=self.decoder_inplanes, aspp_outplanes=aspp_outplanes)
168 |
169 | if self.args.learned_upsampling:
170 | self.refine_learned_upsampling = nn.Sequential(nn.ConvTranspose2d(num_classes, num_classes, kernel_size=4, stride=2, padding=1),
171 | nn.ConvTranspose2d(num_classes, num_classes, kernel_size=4, stride=2, padding=1))
172 |
173 |
174 | def forward(self, input):
175 | output, low_level_feat = self.backbone(input)
176 |
177 | if self.args.use_aspp:
178 | output = self.aspp(output)
179 |
180 | output = self.decoder(output, low_level_feat)
181 |
182 | if self.args.learned_upsampling:
183 | output = self.learned_upsampling(output)
184 | else:
185 | output = F.interpolate(output, size=input.size()[2:], mode='bilinear', align_corners=True)
186 |
187 | if self.args.refine_network:
188 | second_output, low_level_feat = self.refine_backbone(torch.cat((input, output), dim=1))
189 |
190 | if self.args.use_aspp:
191 | second_output = self.refine_aspp(second_output)
192 |
193 | second_output = self.refine_decoder(second_output, low_level_feat)
194 |
195 | if self.args.learned_upsampling:
196 | second_output = self.refine_learned_upsampling(second_output)
197 | else:
198 | second_output = F.interpolate(second_output, size=input.size()[2:], mode='bilinear', align_corners=True)
199 |
200 | return output, second_output
201 |
202 | return output
203 | def get_train_parameters(self, lr):
204 | train_params = [{'params': self.parameters(), 'lr': lr}]
205 |
206 | return train_params
--------------------------------------------------------------------------------
/core/models/mobilenet_v2_dilation.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | from torchvision.models.utils import load_state_dict_from_url
3 | import torch.nn.functional as F
4 |
5 |
6 | model_urls = {
7 | 'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth',
8 | }
9 |
10 |
11 | def _make_divisible(v, divisor, min_value=None):
12 | if min_value is None:
13 | min_value = divisor
14 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
15 | # Make sure that round down does not go down by more than 10%.
16 | if new_v < 0.9 * v:
17 | new_v += divisor
18 | return new_v
19 |
20 |
21 | class ConvBNReLU(nn.Sequential):
22 | def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, dilation=1, groups=1):
23 | #padding = (kernel_size - 1) // 2
24 | super(ConvBNReLU, self).__init__(
25 | nn.Conv2d(in_planes, out_planes, kernel_size, stride, 0, dilation=dilation, groups=groups, bias=False),
26 | nn.BatchNorm2d(out_planes),
27 | nn.ReLU6(inplace=True)
28 | )
29 |
30 | def fixed_padding(kernel_size, dilation):
31 | kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1)
32 | pad_total = kernel_size_effective - 1
33 | pad_beg = pad_total // 2
34 | pad_end = pad_total - pad_beg
35 | return (pad_beg, pad_end, pad_beg, pad_end)
36 |
37 | class InvertedResidual(nn.Module):
38 | def __init__(self, inp, oup, stride, dilation, expand_ratio):
39 | super(InvertedResidual, self).__init__()
40 | self.stride = stride
41 | assert stride in [1, 2]
42 |
43 | hidden_dim = int(round(inp * expand_ratio))
44 | self.use_res_connect = self.stride == 1 and inp == oup
45 |
46 | layers = []
47 | if expand_ratio != 1:
48 | # pw
49 | layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
50 |
51 | layers.extend([
52 | # dw
53 | ConvBNReLU(hidden_dim, hidden_dim, stride=stride, dilation=dilation, groups=hidden_dim),
54 | # pw-linear
55 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
56 | nn.BatchNorm2d(oup),
57 | ])
58 | self.conv = nn.Sequential(*layers)
59 |
60 | self.input_padding = fixed_padding( 3, dilation )
61 |
62 | def forward(self, x):
63 | x_pad = F.pad(x, self.input_padding)
64 | if self.use_res_connect:
65 | return x + self.conv(x_pad)
66 | else:
67 | return self.conv(x_pad)
68 |
69 | class MobileNetV2(nn.Module):
70 | def __init__(self, first_layer_input_channels=3, output_stride=8, width_mult=1.0, inverted_residual_setting=None, round_nearest=8, block=None):
71 | super(MobileNetV2, self).__init__()
72 |
73 | if block is None:
74 | block = InvertedResidual
75 |
76 | input_channel = 32
77 | last_channel = 1280
78 | self.output_stride = output_stride
79 | current_stride = 1
80 |
81 | if inverted_residual_setting is None:
82 | inverted_residual_setting = [
83 | # t, c, n, s
84 | [1, 16, 1, 1],
85 | [6, 24, 2, 2],
86 | [6, 32, 3, 2],
87 | [6, 64, 4, 2],
88 | [6, 96, 3, 1],
89 | [6, 160, 3, 2],
90 | [6, 320, 1, 1],
91 | ]
92 |
93 | # only check the first element, assuming user knows t,c,n,s are required
94 | if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
95 | raise ValueError("inverted_residual_setting should be non-empty "
96 | "or a 4-element list, got {}".format(inverted_residual_setting))
97 |
98 | # building first layer
99 | input_channel = _make_divisible(input_channel * width_mult, round_nearest)
100 | self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
101 | features = [ConvBNReLU(first_layer_input_channels, input_channel, stride=2)]
102 | current_stride *= 2
103 | dilation=1
104 | previous_dilation = 1
105 |
106 | # building inverted residual blocks
107 | for t, c, n, s in inverted_residual_setting:
108 | output_channel = _make_divisible(c * width_mult, round_nearest)
109 | previous_dilation = dilation
110 | if current_stride == output_stride:
111 | stride = 1
112 | dilation *= s
113 | else:
114 | stride = s
115 | current_stride *= s
116 | output_channel = int(c * width_mult)
117 |
118 | for i in range(n):
119 | if i==0:
120 | features.append(block(input_channel, output_channel, stride, previous_dilation, expand_ratio=t))
121 | else:
122 | features.append(block(input_channel, output_channel, 1, dilation, expand_ratio=t))
123 | input_channel = output_channel
124 |
125 | # building last several layers
126 | features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1))
127 |
128 | self.low_level_features = nn.Sequential(*features[0:4])
129 | self.high_level_features = nn.Sequential(*features[4:-1])
130 |
131 | def forward(self, x):
132 | x = self.low_level_features(x)
133 | low_level_feat = x
134 | x = self.high_level_features(x)
135 |
136 | return x, low_level_feat
137 |
138 |
139 | def MobileNet_v2_dilation(pretrained=False, **kwargs):
140 | model = MobileNetV2(**kwargs)
141 | if pretrained:
142 | _load_pretrained_model(model, model_urls['mobilenet_v2'])
143 |
144 | return model
145 |
146 |
147 | def _load_pretrained_model(model, url):
148 | pretrain_dict = load_state_dict_from_url(url)
149 | model_dict = {}
150 | state_dict = model.state_dict()
151 | for k, v in pretrain_dict.items():
152 | if k in state_dict:
153 | model_dict[k] = v
154 | state_dict.update(model_dict)
155 | model.load_state_dict(state_dict)
--------------------------------------------------------------------------------
/core/models/mobilenetv2.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | from torchvision.models.utils import load_state_dict_from_url
3 |
4 |
5 | __all__ = ['MobileNetV2', 'MobileNet_v2']
6 |
7 |
8 | model_urls = {
9 | 'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth',
10 | }
11 |
12 |
13 | def _make_divisible(v, divisor, min_value=None):
14 | if min_value is None:
15 | min_value = divisor
16 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
17 | # Make sure that round down does not go down by more than 10%.
18 | if new_v < 0.9 * v:
19 | new_v += divisor
20 | return new_v
21 |
22 |
23 | class ConvBNReLU(nn.Sequential):
24 | def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
25 | padding = (kernel_size - 1) // 2
26 | super(ConvBNReLU, self).__init__(
27 | nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
28 | nn.BatchNorm2d(out_planes),
29 | nn.ReLU6(inplace=True)
30 | )
31 |
32 |
33 | class InvertedResidual(nn.Module):
34 | def __init__(self, inp, oup, stride, expand_ratio):
35 | super(InvertedResidual, self).__init__()
36 | self.stride = stride
37 | assert stride in [1, 2]
38 |
39 | hidden_dim = int(round(inp * expand_ratio))
40 | self.use_res_connect = self.stride == 1 and inp == oup
41 |
42 | layers = []
43 | if expand_ratio != 1:
44 | # pw
45 | layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
46 | layers.extend([
47 | # dw
48 | ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),
49 | # pw-linear
50 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
51 | nn.BatchNorm2d(oup),
52 | ])
53 | self.conv = nn.Sequential(*layers)
54 |
55 | def forward(self, x):
56 | if self.use_res_connect:
57 | return x + self.conv(x)
58 | else:
59 | return self.conv(x)
60 |
61 |
62 | class MobileNetV2(nn.Module):
63 | def __init__(self, first_layer_input_channels=3, width_mult=1.0, inverted_residual_setting=None, round_nearest=8, block=None):
64 | super(MobileNetV2, self).__init__()
65 |
66 | if block is None:
67 | block = InvertedResidual
68 |
69 | input_channel = 32
70 | last_channel = 1280
71 |
72 | if inverted_residual_setting is None:
73 | inverted_residual_setting = [
74 | # t, c, n, s
75 | [1, 16, 1, 1],
76 | [6, 24, 2, 2],
77 | [6, 32, 3, 2],
78 | [6, 64, 4, 2],
79 | [6, 96, 3, 1],
80 | [6, 160, 3, 2],
81 | [6, 320, 1, 1],
82 | ]
83 |
84 | # only check the first element, assuming user knows t,c,n,s are required
85 | if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
86 | raise ValueError("inverted_residual_setting should be non-empty "
87 | "or a 4-element list, got {}".format(inverted_residual_setting))
88 |
89 | # building first layer
90 | input_channel = _make_divisible(input_channel * width_mult, round_nearest)
91 | self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
92 | features = [ConvBNReLU(first_layer_input_channels, input_channel, stride=2)]
93 |
94 | # building inverted residual blocks
95 | for t, c, n, s in inverted_residual_setting:
96 | output_channel = _make_divisible(c * width_mult, round_nearest)
97 | for i in range(n):
98 | if i == 0:
99 | features.append(block(input_channel, output_channel, s, expand_ratio=t))
100 | else:
101 | features.append(block(input_channel, output_channel, 1, expand_ratio=t))
102 | input_channel = output_channel
103 |
104 | # building last several layers
105 | features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1))
106 |
107 | self.low_level_features = nn.Sequential(*features[0:4])
108 | self.high_level_features = nn.Sequential(*features[4:-1])
109 |
110 | def forward(self, x):
111 | x = self.low_level_features(x)
112 | low_level_feat = x
113 | x = self.high_level_features(x)
114 |
115 | return x, low_level_feat
116 |
117 | def get_train_parameters(self, lr):
118 | train_params = [{'params': self.parameters(), 'lr': lr}]
119 |
120 | return train_params
121 |
122 |
123 | def MobileNet_v2(pretrained=False, **kwargs):
124 | model = MobileNetV2(**kwargs)
125 | if pretrained:
126 | _load_pretrained_model(model, model_urls['mobilenet_v2'])
127 |
128 | return model
129 |
130 | def _load_pretrained_model(model, url):
131 | pretrain_dict = load_state_dict_from_url(url)
132 | model_dict = {}
133 | state_dict = model.state_dict()
134 | for k, v in pretrain_dict.items():
135 | if k in state_dict:
136 | model_dict[k] = v
137 | state_dict.update(model_dict)
138 | model.load_state_dict(state_dict)
--------------------------------------------------------------------------------
/core/models/pspnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch import nn
4 | from torchvision import models
5 | import torch.utils.model_zoo as model_zoo
6 |
7 | RESNET_101 = 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth'
8 |
9 | class _PyramidPoolingModule(nn.Module):
10 | def __init__(self, in_dim, reduction_dim, setting):
11 | super(_PyramidPoolingModule, self).__init__()
12 | self.features = []
13 | for s in setting:
14 | self.features.append(nn.Sequential(
15 | nn.AdaptiveAvgPool2d(s),
16 | nn.Conv2d(in_dim, reduction_dim, kernel_size=1, bias=False),
17 | nn.BatchNorm2d(reduction_dim, momentum=.95),
18 | nn.ReLU(inplace=True)
19 | ))
20 | self.features = nn.ModuleList(self.features)
21 |
22 | def forward(self, x):
23 | x_size = x.size()
24 | out = [x]
25 | for f in self.features:
26 | out.append(F.interpolate(f(x), x_size[2:], mode='bilinear', align_corners=True))
27 | out = torch.cat(out, 1)
28 | return out
29 |
30 |
31 | class PSPNet(nn.Module):
32 | def __init__(self, num_classes, args=None):
33 | super(PSPNet, self).__init__()
34 | resnet = models.resnet101()
35 |
36 | self.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool)
37 | self.layer1, self.layer2, self.layer3, self.layer4 = resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4
38 |
39 | for n, m in self.layer3.named_modules():
40 | if 'conv2' in n:
41 | m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
42 | elif 'downsample.0' in n:
43 | m.stride = (1, 1)
44 | for n, m in self.layer4.named_modules():
45 | if 'conv2' in n:
46 | m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1)
47 | elif 'downsample.0' in n:
48 | m.stride = (1, 1)
49 |
50 | self.ppm = _PyramidPoolingModule(2048, 512, (1, 2, 3, 6))
51 | self.final = nn.Sequential(
52 | nn.Conv2d(4096, 512, kernel_size=3, padding=1, bias=False),
53 | nn.BatchNorm2d(512, momentum=.95),
54 | nn.ReLU(inplace=True),
55 | nn.Dropout(0.1),
56 | nn.Conv2d(512, num_classes, kernel_size=1)
57 | )
58 |
59 | def forward(self, x):
60 | x_size = x.size()
61 |
62 | x = self.layer0(x)
63 | x = self.layer1(x)
64 | x = self.layer2(x)
65 | x = self.layer3(x)
66 | x = self.layer4(x)
67 |
68 | x = self.ppm(x)
69 | x = self.final(x)
70 | x = F.interpolate(x, x_size[2:], mode='bilinear', align_corners=True)
71 |
72 | return x
--------------------------------------------------------------------------------
/core/models/resnet.py:
--------------------------------------------------------------------------------
1 | """
2 | Code source: torchvision repository resnet code
3 | """
4 |
5 | import torch.nn as nn
6 | import torch.utils.model_zoo as model_zoo
7 |
8 | RESNET_101 = 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth'
9 | RESNET_50 = 'https://download.pytorch.org/models/resnet50-19c8e357.pth'
10 | RESNET_34 = 'https://download.pytorch.org/models/resnet34-333f7ec4.pth'
11 | RESNET_18 = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
12 |
13 | class BasicBlock(nn.Module):
14 | expansion = 1
15 | __constants__ = ['downsample']
16 |
17 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, norm_layer=None):
18 | super(BasicBlock, self).__init__()
19 | if norm_layer is None:
20 | norm_layer = nn.BatchNorm2d
21 | if dilation > 1:
22 | dilation = 1
23 |
24 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
25 | self.conv1 = nn.Conv2d(inplanes, planes, stride=stride, kernel_size=3, padding=1, bias=False)
26 | self.bn1 = norm_layer(planes)
27 | self.relu = nn.ReLU(inplace=True)
28 |         self.conv2 = nn.Conv2d(planes, planes, stride=1, kernel_size=3, padding=1, bias=False)
29 | self.bn2 = norm_layer(planes)
30 | self.downsample = downsample
31 | self.stride = stride
32 |
33 | def forward(self, x):
34 | identity = x
35 |
36 | out = self.conv1(x)
37 | out = self.bn1(out)
38 | out = self.relu(out)
39 |
40 | out = self.conv2(out)
41 | out = self.bn2(out)
42 |
43 | if self.downsample is not None:
44 | identity = self.downsample(x)
45 |
46 | out += identity
47 | out = self.relu(out)
48 |
49 | return out
50 |
51 | class Bottleneck(nn.Module):
52 | expansion = 4
53 |
54 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, norm_layer=nn.BatchNorm2d):
55 | super(Bottleneck, self).__init__()
56 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
57 | self.bn1 = norm_layer(planes)
58 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
59 | dilation=dilation, padding=dilation, bias=False)
60 | self.bn2 = norm_layer(planes)
61 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
62 | self.bn3 = norm_layer(planes * 4)
63 | self.relu = nn.ReLU(inplace=True)
64 | self.downsample = downsample
65 | self.stride = stride
66 | self.dilation = dilation
67 |
68 | def forward(self, x):
69 | residual = x
70 |
71 | out = self.conv1(x)
72 | out = self.bn1(out)
73 | out = self.relu(out)
74 |
75 | out = self.conv2(out)
76 | out = self.bn2(out)
77 | out = self.relu(out)
78 |
79 | out = self.conv3(out)
80 | out = self.bn3(out)
81 |
82 | if self.downsample is not None:
83 | residual = self.downsample(x)
84 |
85 | out += residual
86 | out = self.relu(out)
87 |
88 | return out
89 |
90 | class ResNet(nn.Module):
91 |
92 | def __init__(self, block, layers, output_stride, norm_layer=nn.BatchNorm2d, input_channels=3):
93 | self.inplanes = 64
94 |
95 | super(ResNet, self).__init__()
96 | blocks = [1, 2, 4]
97 | if output_stride == 16:
98 | strides = [1, 2, 2, 1]
99 | dilations = [1, 1, 1, 2]
100 | elif output_stride == 8:
101 | strides = [1, 2, 1, 1]
102 | dilations = [1, 1, 2, 4]
103 | else:
104 | raise NotImplementedError
105 |
106 | # Modules
107 | self.conv1 = nn.Conv2d(input_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
108 | self.bn1 = norm_layer(64)
109 | self.relu = nn.ReLU(inplace=True)
110 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
111 |
112 | self.layer1 = self._make_layer(block, 64, layers[0], stride=strides[0], dilation=dilations[0], norm_layer=norm_layer)
113 | self.layer2 = self._make_layer(block, 128, layers[1], stride=strides[1], dilation=dilations[1], norm_layer=norm_layer)
114 | self.layer3 = self._make_layer(block, 256, layers[2], stride=strides[2], dilation=dilations[2], norm_layer=norm_layer)
115 | self.layer4 = self._make_MG_unit(block, 512, blocks=blocks, stride=strides[3], dilation=dilations[3], norm_layer=norm_layer)
116 | # self.layer4 = self._make_layer(block, 512, layers[3], stride=strides[3], dilation=dilations[3], nn.BatchNorm2d=nn.BatchNorm2d)
117 | self._init_weight()
118 |
119 | def _init_weight(self):
120 | for m in self.modules():
121 | if isinstance(m, nn.Conv2d):
122 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
123 | elif isinstance(m, nn.BatchNorm2d):
124 | nn.init.constant_(m.weight, 1)
125 | nn.init.constant_(m.bias, 0)
126 |
127 |
128 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, norm_layer=nn.BatchNorm2d):
129 | downsample = None
130 | if stride != 1 or self.inplanes != planes * block.expansion:
131 | downsample = nn.Sequential(
132 | nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
133 | norm_layer(planes * block.expansion),
134 | )
135 |
136 | layers = []
137 | layers.append(block(self.inplanes, planes, stride, dilation, downsample, norm_layer=norm_layer))
138 | self.inplanes = planes * block.expansion
139 |
140 | for i in range(1, blocks):
141 | layers.append(block(self.inplanes, planes, dilation=dilation, norm_layer=norm_layer))
142 |
143 | return nn.Sequential(*layers)
144 |
145 | def _make_MG_unit(self, block, planes, blocks, stride=1, dilation=1, norm_layer=nn.BatchNorm2d):
146 | downsample = None
147 | if stride != 1 or self.inplanes != planes * block.expansion:
148 | downsample = nn.Sequential(
149 | nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
150 | norm_layer(planes * block.expansion),
151 | )
152 |
153 | layers = []
154 | layers.append(block(self.inplanes, planes, stride, dilation=blocks[0]*dilation, downsample=downsample, norm_layer=norm_layer))
155 | self.inplanes = planes * block.expansion
156 |
157 | for i in range(1, len(blocks)):
158 | layers.append(block(self.inplanes, planes, stride=1, dilation=blocks[i]*dilation, norm_layer=norm_layer))
159 |
160 | return nn.Sequential(*layers)
161 |
162 | def forward(self, x):
163 | x = self.conv1(x)
164 | x = self.bn1(x)
165 | x = self.relu(x)
166 | x = self.maxpool(x)
167 |
168 | x = self.layer1(x)
169 | low_level_feat = x
170 | x = self.layer2(x)
171 | x = self.layer3(x)
172 | x = self.layer4(x)
173 |
174 | return x, low_level_feat
175 |
176 |
177 | def ResNet18(output_stride, norm_layer=nn.BatchNorm2d, pretrained=True, input_channels=3):
178 | model = ResNet(BasicBlock, [2, 2, 2, 2], output_stride, norm_layer=norm_layer, input_channels=input_channels)
179 |
180 | if pretrained:
181 | _load_pretrained_model(model, RESNET_18)
182 |
183 | return model
184 |
185 |
186 | def ResNet34(output_stride, norm_layer=nn.BatchNorm2d, pretrained=True, input_channels=3):
187 |     model = ResNet(BasicBlock, [3, 4, 6, 3], output_stride, norm_layer=norm_layer, input_channels=input_channels)
188 |
189 | if pretrained:
190 | _load_pretrained_model(model, RESNET_34)
191 |
192 | return model
193 |
194 | def ResNet101(output_stride, norm_layer=nn.BatchNorm2d, pretrained=True, input_channels=3):
195 | model = ResNet(Bottleneck, [3, 4, 23, 3], output_stride, norm_layer=norm_layer, input_channels=input_channels)
196 |
197 | if pretrained:
198 | _load_pretrained_model(model, RESNET_101)
199 |
200 | return model
201 |
202 |
203 | def ResNet50(output_stride, norm_layer=nn.BatchNorm2d, pretrained=True, input_channels=3):
204 | model = ResNet(Bottleneck, [3, 4, 6, 3], output_stride, norm_layer=norm_layer, input_channels=input_channels)
205 |
206 | if pretrained:
207 | _load_pretrained_model(model, RESNET_50)
208 |
209 | return model
210 |
211 |
212 | def _load_pretrained_model(model, url):
213 | pretrain_dict = model_zoo.load_url(url)
214 | model_dict = {}
215 | state_dict = model.state_dict()
216 | for k, v in pretrain_dict.items():
217 | if k in state_dict:
218 | model_dict[k] = v
219 | state_dict.update(model_dict)
220 | model.load_state_dict(state_dict)
--------------------------------------------------------------------------------
/core/models/unet.py:
--------------------------------------------------------------------------------
1 | """
2 | Lighter U-net implementation that achieves same performance as the one reported in the paper: https://arxiv.org/abs/1505.04597
3 | Main differences:
4 | a) U-net downblock has only 1 convolution instead of 2
5 | b) U-net upblock has only 1 convolution instead of 3
6 | """
7 |
8 | import torch
9 | import torch.nn as nn
10 | from constants import *
11 |
12 | class UNetDownBlock(nn.Module):
13 | """
14 | Constructs a UNet downsampling block
15 |
16 | Parameters:
17 | input_nc (int) -- the number of input channels
18 | output_nc (int) -- the number of output channels
19 | norm_layer (str) -- normalization layer
20 | down_type (str) -- if we should use strided convolution or maxpool for reducing the feature map
21 | outermost (bool) -- if this module is the outermost module
22 | innermost (bool) -- if this module is the innermost module
23 |         dropout (float)     -- dropout probability; 0 or None disables dropout
24 | kernel_size (int) -- convolution kernel size
25 | bias (boolean) -- if convolution should use bias
26 | """
27 | def __init__(self, input_nc, output_nc, norm_layer=nn.BatchNorm2d, down_type=STRIDECONV, outermost=False, innermost=False, dropout=0.2, kernel_size=4, bias=True):
28 | super(UNetDownBlock, self).__init__()
29 | self.innermost = innermost
30 | self.outermost = outermost
31 | self.use_maxpool = down_type == MAXPOOL
32 |
33 | stride = 1 if self.use_maxpool else 2
34 | kernel_size = 3 if self.use_maxpool else 4
35 | self.conv = nn.Conv2d(input_nc, output_nc, kernel_size=kernel_size, stride=stride, padding=1, bias=bias)
36 | self.relu = nn.LeakyReLU(0.2)
37 | self.maxpool = nn.MaxPool2d(2)
38 | self.norm = norm_layer(output_nc)
39 | self.dropout = nn.Dropout2d(dropout) if dropout else None
40 |
41 | def forward(self, x):
42 | if self.outermost:
43 | x = self.conv(x)
44 | x = self.norm(x)
45 | elif self.innermost:
46 | x = self.relu(x)
47 | if self.dropout: x = self.dropout(x)
48 | x = self.conv(x)
49 | else:
50 | x = self.relu(x)
51 | if self.dropout: x = self.dropout(x)
52 | x = self.conv(x)
53 | x = self.norm(x)
54 |
55 | return x
56 |
57 | class UNetUpBlock(nn.Module):
58 | """
59 | Constructs a UNet upsampling block
60 |
61 | Parameters:
62 | input_nc (int) -- the number of input channels
63 | output_nc (int) -- the number of output channels
64 | norm_layer -- normalization layer
65 | outermost (bool) -- if this module is the outermost module
66 | innermost (bool) -- if this module is the innermost module
67 |         dropout (float)     -- dropout probability; 0 or None disables dropout
68 | kernel_size (int) -- convolution kernel size
69 | """
70 | def __init__(self, input_nc, output_nc, norm_layer=nn.BatchNorm2d, outermost=False, innermost=False, dropout=0.2, kernel_size=4, use_bias=True):
71 | super(UNetUpBlock, self).__init__()
72 | self.innermost = innermost
73 | self.outermost = outermost
74 | upconv_inner_nc = input_nc * 2
75 |
76 | if self.innermost:
77 | self.conv = nn.ConvTranspose2d(input_nc, output_nc, kernel_size=kernel_size, stride=2, padding=1, bias=use_bias)
78 | elif self.outermost:
79 | self.conv = nn.ConvTranspose2d(upconv_inner_nc, output_nc, kernel_size=kernel_size, stride=2, padding=1)
80 | else:
81 | self.conv = nn.ConvTranspose2d(upconv_inner_nc, output_nc, kernel_size=kernel_size, stride=2, padding=1, bias=use_bias)
82 |
83 | self.norm = norm_layer(output_nc)
84 | self.relu = nn.ReLU()
85 | self.dropout = nn.Dropout2d(dropout) if dropout else None
86 |
87 | def forward(self, x):
88 | if self.outermost:
89 | x = self.relu(x)
90 | if self.dropout: x = self.dropout(x)
91 | x = self.conv(x)
92 | elif self.innermost:
93 | x = self.relu(x)
94 | if self.dropout: x = self.dropout(x)
95 | x = self.conv(x)
96 | x = self.norm(x)
97 | else:
98 | x = self.relu(x)
99 | if self.dropout: x = self.dropout(x)
100 | x = self.conv(x)
101 | x = self.norm(x)
102 |
103 | return x
104 |
105 | class UNet(nn.Module):
106 | """Create a Unet-based Fully Convolutional Network
107 | X -------------------identity----------------------
108 | |-- downsampling -- |submodule| -- upsampling --|
109 |
110 | Parameters:
111 | num_classes (int) -- the number of channels in output images
112 | norm_layer -- normalization layer
113 | input_nc -- number of channels of input image
114 |
115 | Args:
116 | mode (str) -- process single frames or sequence of frames
117 | timesteps (int) --
118 | num_downs (int) -- the number of downsamplings in UNet. For example, # if |num_downs| == 7,
119 | image of size 128x128 will become of size 1x1 # at the bottleneck
120 | ngf (int) -- the number of filters in the last conv layer
121 | reconstruct (int [0,1])-- if we should reconstruct the next image or not
122 | sequence_model (str) -- the sequence model that for the sequence mode []
123 | num_levels_tcn(int) -- number of levels of the TemporalConvNet
124 | """
125 |
126 | def __init__(self, num_classes, args, norm_layer=nn.BatchNorm2d, input_nc=3):
127 | super(UNet, self).__init__()
128 |
129 | self.refine_network = args.refine_network
130 | self.num_downs = args.num_downs
131 | self.ngf = args.ngf
132 |
133 | self.encoder = self.build_encoder(self.num_downs, input_nc, self.ngf, norm_layer, down_type=args.down_type)
134 | self.decoder = self.build_decoder(self.num_downs, num_classes, self.ngf, norm_layer)
135 |
136 | if self.refine_network:
137 | self.refine_encoder = self.build_encoder(self.num_downs, input_nc + num_classes, self.ngf, norm_layer, down_type=args.down_type)
138 | self.refine_decoder = self.build_decoder(self.num_downs, num_classes, self.ngf, norm_layer)
139 |
140 | def build_encoder(self, num_downs, input_nc, ngf, norm_layer, down_type=STRIDECONV):
141 | """Constructs a UNet downsampling encoder, consisting of $num_downs UNetDownBlocks
142 |
143 | Parameters:
144 | num_downs (int) -- the number of downsamplings in UNet. For example, # if |num_downs| == 7,
145 | image of size 128x128 will become of size 1x1 # at the bottleneck
146 | input_nc (int) -- the number of input channels
147 | ngf (int) -- the number of filters in the last conv layer
148 | norm_layer (str) -- normalization layer
149 | down_type (str) -- if we should use strided convolution or maxpool for reducing the feature map
150 | Returns:
151 | nn.Sequential consisting of $num_downs UnetDownBlocks
152 | """
153 | layers = []
154 | layers.append(UNetDownBlock(input_nc=input_nc, output_nc=ngf, norm_layer=norm_layer, down_type=down_type, outermost=True))
155 | layers.append(UNetDownBlock(input_nc=ngf, output_nc=ngf*2, norm_layer=norm_layer, down_type=down_type))
156 | layers.append(UNetDownBlock(input_nc=ngf*2, output_nc=ngf*4, norm_layer=norm_layer, down_type=down_type))
157 | layers.append(UNetDownBlock(input_nc=ngf*4, output_nc=ngf*8, norm_layer=norm_layer, down_type=down_type))
158 |
159 | for i in range(num_downs - 5): # add intermediate layers with ngf * 8 filters
160 | layers.append(UNetDownBlock(input_nc=ngf*8, output_nc=ngf*8, norm_layer=norm_layer, down_type=down_type))
161 |
162 | layers.append(UNetDownBlock(input_nc=ngf*8, output_nc=ngf*8, norm_layer=norm_layer, down_type=down_type, innermost=True))
163 |
164 | return nn.Sequential(*layers)
165 |
166 | def build_decoder(self, num_downs, num_classes, ngf, norm_layer):
167 | """Constructs a UNet downsampling encoder, consisting of $num_downs UNetUpBlocks
168 |
169 | Parameters:
170 | num_downs (int) -- the number of downsamplings in UNet. For example, # if |num_downs| == 7,
171 | image of size 128x128 will become of size 1x1 # at the bottleneck
172 | num_classes (int) -- number of classes to classify
173 | output_nc (int) -- the number of output channels. outermost is ngf, innermost is ngf * 8
174 | norm_layer -- normalization layer
175 |
176 | Returns:
177 | nn.Sequential consisting of $num_downs UnetUpBlocks
178 | """
179 | layers = []
180 | layers.append(UNetUpBlock(input_nc=ngf * 8, output_nc=ngf * 8, norm_layer=norm_layer, innermost=True))
181 |
182 | for i in range(num_downs - 5): # add intermediate layers with ngf * 8 filters
183 | layers.append(UNetUpBlock(input_nc=ngf * 8, output_nc=ngf * 8, norm_layer=norm_layer))
184 |
185 | layers.append(UNetUpBlock(input_nc=ngf * 8, output_nc=ngf * 4, norm_layer=norm_layer))
186 | layers.append(UNetUpBlock(input_nc=ngf * 4, output_nc=ngf * 2, norm_layer=norm_layer))
187 | layers.append(UNetUpBlock(input_nc=ngf*2, output_nc=ngf, norm_layer=norm_layer))
188 | layers.append(UNetUpBlock(input_nc=ngf, output_nc=num_classes, norm_layer=norm_layer, outermost=True))
189 |
190 | return nn.Sequential(*layers)
191 |
192 | def encoder_forward(self, x, use_refine_network=False):
193 | skip_connections = []
194 | model = self.refine_encoder if use_refine_network else self.encoder
195 |
196 | for i, down in enumerate(model):
197 | x = down(x)
198 | if down.use_maxpool:
199 | x = down.maxpool(x)
200 |
201 | if not down.innermost:
202 | skip_connections.append(x)
203 |
204 | return x, skip_connections
205 |
206 | def decoder_forward(self, x, skip_connections, use_refine_network=False):
207 | model = self.refine_decoder if use_refine_network else self.decoder
208 |
209 | for i, up in enumerate(model):
210 | if not up.innermost:
211 | skip = skip_connections[-i]
212 | out = torch.cat([skip, out], 1)
213 | out = up(out)
214 | else:
215 | out = up(x)
216 |
217 | return out
218 |
219 | def forward(self, x):
220 | output, skip_connections = self.encoder_forward(x)
221 | output = self.decoder_forward(output, skip_connections)
222 |
223 | if self.refine_network:
224 | second_output, skip_connections = self.encoder_forward(torch.cat((x, output), dim=1), use_refine_network=True)
225 | second_output = self.decoder_forward(second_output, skip_connections, use_refine_network=True)
226 |
227 | return output, second_output
228 |
229 | return output
230 |
231 | def get_train_parameters(self, lr):
232 | params = [{'params': self.parameters(), 'lr': lr}]
233 |
234 | return params
--------------------------------------------------------------------------------
/core/models/unet_paper.py:
--------------------------------------------------------------------------------
1 | """
2 | U-net implementation from the reported paper: https://arxiv.org/abs/1505.04597
3 | Model follows the paper implementation, except for the use of padding in order to keep feature maps size the same.
4 | """
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 |
10 | class UnetConvBlock(nn.Module):
11 | """
12 | Constructs a UNet downsampling block
13 |
14 | Parameters:
15 | input_nc (int) -- the number of input channels
16 | output_nc (int) -- the number of output channels
17 | norm_layer -- normalization layer
18 | outermost (bool) -- if this module is the outermost module
19 | innermost (bool) -- if this module is the innermost module
20 | user_dropout (bool) -- if use dropout layers.
21 | kernel_size (int) -- convolution kernel size
22 | """
23 | def __init__(self, input_nc, output_nc, norm_layer=nn.BatchNorm2d, padding=0, innermost=False, dropout=0.2):
24 | super(UnetConvBlock, self).__init__()
25 | block = []
26 |
27 | block.append(nn.Conv2d(input_nc, output_nc, kernel_size=3, padding=int(padding)))
28 | block.append(norm_layer(output_nc))
29 | block.append(nn.ReLU())
30 |
31 | block.append(nn.Conv2d(output_nc, output_nc, kernel_size=3, padding=int(padding)))
32 | block.append(norm_layer(output_nc))
33 | block.append(nn.ReLU())
34 |
35 | self.block = nn.Sequential(*block)
36 | self.innermost = innermost
37 |
38 | def forward(self, x):
39 | out = self.block(x)
40 | return out
41 |
42 | class UNetUpBlock(nn.Module):
43 | """
44 | Constructs a UNet upsampling block
45 |
46 | Parameters:
47 | input_nc (int) -- the number of input channels
48 | output_nc (int) -- the number of output channels
49 | norm_layer -- normalization layer
50 | outermost (bool) -- if this module is the outermost module
51 | innermost (bool) -- if this module is the innermost module
52 | user_dropout (bool) -- if use dropout layers.
53 | kernel_size (int) -- convolution kernel size
54 | remove_skip (bool) -- if skip connections should be disabled or not
55 | """
56 | def __init__(self, input_nc, output_nc, norm_layer=nn.BatchNorm2d, padding=1, remove_skip=False, outermost=False):
57 | super(UNetUpBlock, self).__init__()
58 |
59 | self.up = nn.ConvTranspose2d(input_nc, output_nc, kernel_size=2, stride=2)
60 | self.conv_block = UnetConvBlock(output_nc * 2, output_nc, norm_layer, padding)
61 | self.outermost = outermost
62 |
63 | def forward(self, x, skip=None):
64 | out = self.up(x)
65 |
66 | if skip is not None:
67 | out = torch.cat([out, skip], 1)
68 | out = self.conv_block(out)
69 |
70 | return out
71 |
72 | class UNet_paper(nn.Module):
73 | """Create a Unet-based Fully Convolutional Network
74 | X -------------------identity----------------------
75 | |-- downsampling -- |submodule| -- upsampling --|
76 |
77 | Parameters:
78 | num_classes (int) -- the number of channels in output images
79 | norm_layer -- normalization layer
80 | input_nc -- number of channels of input image
81 |
82 | Args:
83 | mode (str) -- process single frames or sequence of frames
84 | timesteps (int) --
85 | num_downs (int) -- the number of downsamplings in UNet. For example, # if |num_downs| == 7,
86 | image of size 128x128 will become of size 1x1 # at the bottleneck
87 | ngf (int) -- the number of filters in the last conv layer
88 | remove_skip (int [0,1])-- if skip connections should be disabled or not
89 | reconstruct (int [0,1])-- if we should reconstruct the next image or not
90 | sequence_model (str) -- the sequence model that for the sequence mode []
91 | num_levels_tcn(int) -- number of levels of the TemporalConvNet
92 | """
93 |
94 | def __init__(self, num_classes, args, norm_layer=nn.BatchNorm2d, input_nc=3):
95 |         super(UNet_paper, self).__init__()
96 |
97 | self.num_downs = args.num_downs
98 | self.ngf = args.ngf
99 |
100 | self.encoder = self.build_encoder(self.num_downs, input_nc, self.ngf, norm_layer)
101 | self.decoder = self.build_decoder(self.num_downs, num_classes, self.ngf, norm_layer)
102 | self.decoder_last_conv = nn.Conv2d(self.ngf, num_classes, 1)
103 |
104 | def build_encoder(self, num_downs, input_nc, ngf, norm_layer):
105 | """Constructs a UNet downsampling encoder, consisting of $num_downs UNetDownBlocks
106 |
107 | Parameters:
108 | num_downs (int) -- the number of downsamplings in UNet. For example, # if |num_downs| == 7,
109 | image of size 128x128 will become of size 1x1 # at the bottleneck
110 | input_nc (int) -- the number of input channels
111 | ngf (int) -- the number of filters in the last conv layer
112 | norm_layer -- normalization layer
113 | Returns:
114 | nn.Sequential consisting of $num_downs UnetDownBlocks
115 | """
116 | layers = []
117 | layers.append(UnetConvBlock(input_nc=input_nc, output_nc=ngf, norm_layer=norm_layer, padding=1))
118 | layers.append(UnetConvBlock(input_nc=ngf, output_nc=ngf * 2, norm_layer=norm_layer, padding=1))
119 | layers.append(UnetConvBlock(input_nc=ngf * 2, output_nc=ngf * 4, norm_layer=norm_layer, padding=1))
120 | layers.append(UnetConvBlock(input_nc=ngf * 4, output_nc=ngf * 8, norm_layer=norm_layer, padding=1))
121 |
122 | for i in range(num_downs - 5): # add intermediate layers with ngf * 8 filters
123 | layers.append(UnetConvBlock(input_nc=ngf * 8, output_nc=ngf * 8, norm_layer=norm_layer, padding=1))
124 |
125 | layers.append(UnetConvBlock(input_nc=ngf * 8, output_nc=ngf * 8, norm_layer=norm_layer, padding=1, innermost=True))
126 |
127 | return nn.Sequential(*layers)
128 |
129 | def build_decoder(self, num_downs, num_classes, ngf, norm_layer, remove_skip=0):
130 | """Constructs a UNet downsampling encoder, consisting of $num_downs UNetUpBlocks
131 |
132 | Parameters:
133 | num_downs (int) -- the number of downsamplings in UNet. For example, # if |num_downs| == 7,
134 | image of size 128x128 will become of size 1x1 # at the bottleneck
135 | num_classes (int) -- number of classes to classify
136 | output_nc (int) -- the number of output channels. outermost is ngf, innermost is ngf * 8
137 | norm_layer -- normalization layer
138 | remove_skip (int) -- if skip connections should be disabled or not
139 |
140 | Returns:
141 | nn.Sequential consisting of $num_downs UnetUpBlocks
142 | """
143 | layers = []
144 | layers.append(UNetUpBlock(input_nc=ngf * 8, output_nc=ngf * 8, norm_layer=norm_layer, remove_skip=remove_skip))
145 |
146 | for i in range(num_downs - 5): # add intermediate layers with ngf * 8 filters
147 | layers.append(UNetUpBlock(input_nc=ngf * 8, output_nc=ngf * 8, norm_layer=norm_layer, remove_skip=remove_skip))
148 |
149 | layers.append(UNetUpBlock(input_nc=ngf * 8, output_nc=ngf * 4, norm_layer=norm_layer, remove_skip=remove_skip))
150 | layers.append(UNetUpBlock(input_nc=ngf * 4, output_nc=ngf * 2, norm_layer=norm_layer, remove_skip=remove_skip))
151 | layers.append(UNetUpBlock(input_nc=ngf*2, output_nc=ngf, norm_layer=norm_layer, remove_skip=remove_skip, outermost=True))
152 |
153 | return nn.Sequential(*layers)
154 |
155 | def encoder_forward(self, x):
156 | skip_connections = []
157 | for i, down in enumerate(self.encoder):
158 | x = down(x)
159 |
160 | if not down.innermost:
161 | skip_connections.append(x)
162 | x = F.max_pool2d(x, 2)
163 |
164 | return x, skip_connections
165 |
166 | def decoder_forward(self, x, skip_connections):
167 | out = None
168 | for i, up in enumerate(self.decoder):
169 | skip = skip_connections.pop()
170 | if out is None:
171 | out = up(x, skip)
172 | else:
173 | out = up(out, skip)
174 |
175 | out = self.decoder_last_conv(out)
176 | return out
177 |
178 | def forward(self, x):
179 | x, skip_connections = self.encoder_forward(x)
180 | out = self.decoder_forward(x, skip_connections)
181 |
182 | return out
--------------------------------------------------------------------------------
/core/models/unet_pytorch.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | class UNet_torch(nn.Module):
7 |
8 | def __init__(self, num_classes=1, args=None, in_channels=3):
9 | super(UNet_torch, self).__init__()
10 | self.ngf = args.ngf
11 |
12 | self.encoder1 = UNet_torch._block(in_channels, self.ngf, name="enc1")
13 | self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
14 |
15 | self.encoder2 = UNet_torch._block(self.ngf, self.ngf * 2, name="enc2")
16 | self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
17 |
18 | self.encoder3 = UNet_torch._block(self.ngf * 2, self.ngf * 4, name="enc3")
19 | self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
20 |
21 | self.encoder4 = UNet_torch._block(self.ngf * 4, self.ngf * 8, name="enc4")
22 | self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
23 |
24 | self.bottleneck = UNet_torch._block(self.ngf * 8, self.ngf * 16, name="bottleneck")
25 |
26 | self.upconv4 = nn.ConvTranspose2d(
27 | self.ngf * 16, self.ngf * 8, kernel_size=2, stride=2
28 | )
29 | self.decoder4 = UNet_torch._block((self.ngf * 8) * 2, self.ngf * 8, name="dec4")
30 | self.upconv3 = nn.ConvTranspose2d(
31 | self.ngf * 8, self.ngf * 4, kernel_size=2, stride=2
32 | )
33 | self.decoder3 = UNet_torch._block((self.ngf * 4) * 2, self.ngf * 4, name="dec3")
34 | self.upconv2 = nn.ConvTranspose2d(
35 | self.ngf * 4, self.ngf * 2, kernel_size=2, stride=2
36 | )
37 | self.decoder2 = UNet_torch._block((self.ngf * 2) * 2, self.ngf * 2, name="dec2")
38 | self.upconv1 = nn.ConvTranspose2d(
39 | self.ngf * 2, self.ngf, kernel_size=2, stride=2
40 | )
41 | self.decoder1 = UNet_torch._block(self.ngf * 2, self.ngf, name="dec1")
42 |
43 | self.conv = nn.Conv2d(
44 | in_channels=self.ngf, out_channels=num_classes, kernel_size=1
45 | )
46 |
47 | def forward(self, x):
48 | enc1 = self.encoder1(x)
49 | enc2 = self.encoder2(self.pool1(enc1))
50 | enc3 = self.encoder3(self.pool2(enc2))
51 | enc4 = self.encoder4(self.pool3(enc3))
52 | bottleneck = self.bottleneck(self.pool4(enc4))
53 |
54 | dec4 = self.upconv4(bottleneck)
55 | dec4 = torch.cat((dec4, enc4), dim=1)
56 | dec4 = self.decoder4(dec4)
57 | dec3 = self.upconv3(dec4)
58 | dec3 = torch.cat((dec3, enc3), dim=1)
59 | dec3 = self.decoder3(dec3)
60 | dec2 = self.upconv2(dec3)
61 | dec2 = torch.cat((dec2, enc2), dim=1)
62 | dec2 = self.decoder2(dec2)
63 | dec1 = self.upconv1(dec2)
64 | dec1 = torch.cat((dec1, enc1), dim=1)
65 | dec1 = self.decoder1(dec1)
66 |
67 | final = self.conv(dec1)
68 | return final
69 |
70 | @staticmethod
71 | def _block(in_channels, features, name):
72 | return nn.Sequential(
73 | OrderedDict(
74 | [
75 | (
76 | name + "conv1",
77 | nn.Conv2d(
78 | in_channels=in_channels,
79 | out_channels=features,
80 | kernel_size=3,
81 | padding=1,
82 | bias=False,
83 | ),
84 | ),
85 | (name + "norm1", nn.BatchNorm2d(num_features=features)),
86 | (name + "relu1", nn.ReLU(inplace=True)),
87 |                     (
88 |                         name + "conv2",
89 |                         nn.Conv2d(
90 |                             in_channels=features,
91 |                             out_channels=features,
92 |                             kernel_size=3,
93 |                             padding=1,
94 |                             bias=False,
95 |                         ),
96 |                     ),
97 |                     (name + "norm2", nn.BatchNorm2d(num_features=features)),
98 |                     (name + "relu2", nn.ReLU(inplace=True)),
99 | ]
100 | )
101 | )
102 |
103 | def get_train_parameters(self, lr):
104 | params = [{'params': self.parameters(), 'lr': lr}]
105 |
106 | return params
--------------------------------------------------------------------------------
/core/trainers/trainer.py:
--------------------------------------------------------------------------------
1 | import copy
2 | from tqdm import tqdm
3 | import torch
4 | import math
5 | import matplotlib.pyplot as plt
6 | import time
7 |
8 | from util.general_functions import get_model, get_optimizer, make_data_loader, get_loss_function, get_flat_images
9 | from util.lr_scheduler import LR_Scheduler
10 | from util.summary import TensorboardSummary
11 | from constants import *
12 | from util.ssim import MS_SSIM, SSIM
13 |
14 |
15 | class Trainer(object):
16 |
17 | def __init__(self, args):
18 | self.args = args
19 | self.best_loss = math.inf
20 | self.summary = TensorboardSummary(args)
21 | self.model = get_model(args)
22 |
23 | if args.inference:
24 | self.model = self.summary.load_network(self.model)
25 |
26 | if args.save_best_model:
27 | self.best_model = copy.deepcopy(self.model)
28 |
29 | self.optimizer = get_optimizer(self.model, args)
30 | self.ssim, self.ms_ssim = SSIM(), MS_SSIM()
31 |
32 | if args.trainval:
33 | self.train_loader, self.val_loader = make_data_loader(args, TRAINVAL), make_data_loader(args, TEST)
34 | else:
35 | self.train_loader, self.test_loader = make_data_loader(args, TRAIN), make_data_loader(args, TEST)
36 |
37 | self.criterion = get_loss_function(args.loss_type)
38 | self.scheduler = LR_Scheduler(args.lr_policy, args.lr, args.epochs, len(self.train_loader))
39 |
40 | if args.second_loss:
41 | self.second_criterion = get_loss_function(MS_SSIM_LOSS)
42 |
43 | def run_epoch(self, epoch, split=TRAIN):
44 | total_loss = 0.0
45 | ssim_values, ms_ssim_values = [], []
46 |
47 | if split == TRAIN:
48 | self.model.train()
49 | loader = self.train_loader
50 | elif split == VAL:
51 | self.model.eval()
52 | loader = self.val_loader
53 | else:
54 | self.model.eval()
55 | loader = self.test_loader
56 |
57 | bar = tqdm(loader)
58 | num_img = len(loader)
59 |
60 | for i, sample in enumerate(bar):
61 | with torch.autograd.set_detect_anomaly(True):
62 | image = sample[0]
63 | target = sample[1]
64 |
65 | if self.args.cuda:
66 | image, target = image.cuda(), target.cuda()
67 |
68 | if split == TRAIN:
69 | self.scheduler(self.optimizer, i, epoch, self.best_loss)
70 | self.optimizer.zero_grad()
71 |
72 | if self.args.refine_network:
73 | first_output, output = self.model(image)
74 | else:
75 | output = self.model(image)
76 | else:
77 | with torch.no_grad():
78 | if self.args.refine_network:
79 | first_output, output = self.model(image)
80 | else:
81 | output = self.model(image)
82 |
83 | loss = self.criterion(output, target)
84 | if self.args.refine_network:
85 | refine_loss = self.criterion(first_output, target)
86 | loss += refine_loss
87 |
88 | if self.args.second_loss:
89 | flat_output_img, flat_target_img = get_flat_images(self.args.dataset, image, output, target)
90 | second_loss = self.second_criterion(flat_output_img, flat_target_img)
91 | loss += self.args.second_loss_rate * second_loss
92 |
93 | if self.args.refine_network:
94 | flat_first_output_img, flat_target_img = get_flat_images(self.args.dataset, image, first_output, target)
95 | third_loss = self.second_criterion(flat_first_output_img, flat_target_img)
96 | loss += third_loss
97 |
98 | if split == TRAIN:
99 | loss.backward()
100 |
101 | if self.args.clip > 0:
102 | if self.args.gpu_ids:
103 | torch.nn.utils.clip_grad_norm_(self.model.module.parameters(), self.args.clip)
104 | else:
105 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.clip)
106 |
107 | self.optimizer.step()
108 |
109 | if split == TEST:
110 | ssim_values.append(self.ssim.forward(output, target))
111 | ms_ssim_values.append(self.ms_ssim.forward(output, target))
112 |
113 | # Show 10 * 3 inference results each epoch
114 | if split != VISUALIZATION and i % max(num_img // 10, 1) == 0:
115 | self.summary.visualize_image(image, target, output, split=split)
116 | elif split == VISUALIZATION:
117 | self.summary.visualize_image(image, target, output, split=split)
118 |
119 | total_loss += loss.item()
120 | bar.set_description(split +' loss: %.3f' % (loss.item()))
121 |
122 | if split == TEST:
123 | ssim = sum(ssim_values) / len(ssim_values)
124 | ms_ssim = sum(ms_ssim_values) / len(ms_ssim_values)
125 | self.summary.add_scalar(split + '/ssim', ssim, epoch)
126 | self.summary.add_scalar(split + '/ms_ssim', ms_ssim, epoch)
127 |
128 | if total_loss < self.best_loss:
129 | self.best_loss = total_loss
130 | if self.args.save_best_model:
131 | self.best_model = copy.deepcopy(self.model)
132 |
133 | self.summary.add_scalar(split + '/total_loss_epoch', total_loss, epoch)
134 | print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))
135 |
136 | def calculate_inference_speed(self, iterations):
137 | loader = self.test_loader
138 | bar = tqdm(loader)
139 | self.model.eval()
140 | times = []
141 |
142 | for i, sample in enumerate(bar):
143 | image = sample[0]
144 |
145 | start = time.time()
146 | with torch.no_grad():
147 | output = self.model(image)
148 | end = time.time()
149 | current_time = end - start
150 | print(current_time)
151 | times.append(current_time)
152 |
153 | if i >= iterations:
154 | break
155 |
156 | return sum(times) / len(times)
157 |
158 | def save_network(self):
159 | self.summary.save_network(self.model)
--------------------------------------------------------------------------------
/dataloader/docunet.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | from torch.utils import data
4 | import torchvision.transforms as standard_transforms
5 | import util.custom_transforms as custom_transforms
6 | from scipy import io
7 |
8 | class Docunet(data.Dataset):
9 |
10 | NUM_CLASSES = 2
11 | CLASSES = ["foreground", "background"]
12 | ROOT = '../../../datasets/'
13 | DEFORMED = 'deformed_labels'
14 | DEFORMED_EXT = '.jpg'
15 | VECTOR_FIELD = 'target_vf'
16 | VECTOR_FIELD_EXT = '.mat'
17 |
18 |
19 | def __init__(self, args, split="train"):
20 | self.args = args
21 | self.split = split
22 | self.dataset = self.make_dataset()
23 |
24 | if len(self.dataset) == 0:
25 | raise RuntimeError('Found 0 images, please check the dataset')
26 |
27 | self.transform = self.get_transforms()
28 |
29 | def __len__(self):
30 | return len(self.dataset)
31 |
32 | def __getitem__(self, index):
33 | image_path, label_path = self.dataset[index]
34 | image = Image.open(image_path)
35 | label = io.loadmat(label_path)['vector_field']
36 |
37 | if self.transform is not None:
38 | image = self.transform(image)
39 |
40 | image, label = standard_transforms.ToTensor()(image), standard_transforms.ToTensor()(label)
41 |
42 | return image, label
43 |
44 | def make_dataset(self):
45 | current_dir = os.path.dirname(__file__)
46 | images_path = os.path.join(current_dir, self.ROOT, self.args.dataset_dir, self.split, self.DEFORMED + '_' + 'x'.join(map(str, self.args.size)))
47 | labels_path = os.path.join(current_dir, self.ROOT, self.args.dataset_dir, self.split, self.VECTOR_FIELD + '_' + 'x'.join(map(str, self.args.size)))
48 |
49 | images_name = os.listdir(images_path)
50 | images_name = [image_name for image_name in images_name if image_name.endswith(self.DEFORMED_EXT)]
51 | items = []
52 |
53 | for i in range(len(images_name)):
54 | image_name = images_name[i]
55 | label_name = image_name.replace(self.DEFORMED_EXT, self.VECTOR_FIELD_EXT)
56 | items.append((os.path.join(images_path, image_name), os.path.join(labels_path, label_name)))
57 |
58 | return items
59 |
60 | def get_transforms(self):
61 | return None
62 |
63 |
64 |
65 |
66 |
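# Expected on-disk layout (a sketch; 'toy_dataset' and the 256x256 size are illustrative values
# for args.dataset_dir and args.size). make_dataset() pairs each deformed image with the .mat
# file of the same base name:
#   ../../../datasets/toy_dataset/train/deformed_labels_256x256/0001.jpg
#   ../../../datasets/toy_dataset/train/target_vf_256x256/0001.mat   # loaded via key 'vector_field'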
--------------------------------------------------------------------------------
/dataloader/docunet_im2im.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | from torch.utils import data
4 | import torchvision.transforms as standard_transforms
5 | import util.custom_transforms as custom_transforms
6 | from scipy import io
7 |
8 |
9 | class DocunetIm2Im(data.Dataset):
10 | NUM_CLASSES = 2
11 | CLASSES = ["foreground", "background"]
12 | ROOT = '../../../datasets/'
13 | DEFORMED = 'deformed_labels'
14 | LABELS = 'cropped_labels'
15 | DEFORMED_EXT = '.jpg'
16 | LABEL_EXT = '.jpg'
17 |
18 | def __init__(self, args, split="train"):
19 | self.args = args
20 | self.split = split
21 | self.dataset = self.make_dataset()
22 |
23 | if len(self.dataset) == 0:
24 | raise RuntimeError('Found 0 images, please check the dataset')
25 |
26 | self.joint_transform = self.get_transforms()
27 |
28 | def __len__(self):
29 | return len(self.dataset)
30 |
31 | def __getitem__(self, index):
32 | image_path, label_path = self.dataset[index]
33 | image = Image.open(image_path)
34 | label = Image.open(label_path)
35 |
36 | if self.joint_transform is not None:
37 | image, label = self.joint_transform(image, label)
38 |
39 | image, label = standard_transforms.ToTensor()(image), standard_transforms.ToTensor()(label)
40 |
41 | return image, label
42 |
43 | def make_dataset(self):
44 | current_dir = os.path.dirname(__file__)
45 | images_path = os.path.join(current_dir, self.ROOT, self.args.dataset_dir, self.split, self.DEFORMED + '_' + 'x'.join(map(str, self.args.size)))
46 | labels_path = os.path.join(current_dir, self.ROOT, self.args.dataset_dir, self.split, self.LABELS)
47 |
48 | images_name = os.listdir(images_path)
49 | images_name = [image_name for image_name in images_name if image_name.endswith(self.DEFORMED_EXT)]
50 | items = []
51 |
52 | for i in range(len(images_name)):
53 | image_name = images_name[i]
54 | label_name = '_'.join(image_name.split('_')[:-1]) + self.LABEL_EXT
55 | items.append((os.path.join(images_path, image_name), os.path.join(labels_path, label_name)))
56 |
57 | return items
58 |
59 | def get_transforms(self):
60 | if self.split == 'train':
61 | joint = custom_transforms.Compose([
62 | custom_transforms.Resize(self.args.resize),
63 | custom_transforms.RandomHorizontallyFlip(),
64 | # custom_transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
65 | custom_transforms.RandomGaussianBlur()
66 | ])
67 | elif self.split == 'val' or self.split == 'test' or self.split == 'demoVideo':
68 | joint = custom_transforms.Compose([
69 | custom_transforms.Resize(self.args.resize),
70 | ])
71 | else:
72 | raise RuntimeError('Invalid dataset mode')
73 |
74 | return joint
75 |
76 |
77 |
78 |
79 |
--------------------------------------------------------------------------------
/dataloader/docunet_inverted.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | from torch.utils import data
4 | import torchvision.transforms as standard_transforms
5 | import util.custom_transforms as custom_transforms
6 | from scipy import io
7 |
8 | from constants import *
9 |
10 | class InvertedDocunet(data.Dataset):
11 | NUM_CLASSES = 2
12 | CLASSES = ["foreground", "background"]
13 | ROOT = '../../../datasets/'
14 | DEFORMED = 'deformed_labels'
15 | DEFORMED_EXT = '.jpg'
16 | VECTOR_FIELD = 'inverted_vf'
17 | VECTOR_FIELD_EXT = '.mat'
18 |
19 | def __init__(self, args, split=TRAIN):
20 | self.args = args
21 | self.split = split
22 | self.dataset = self.make_dataset()
23 |
24 | if len(self.dataset) == 0:
25 | raise RuntimeError('Found 0 images, please check the dataset')
26 |
27 | self.transform = self.get_transforms()
28 |
29 | def __len__(self):
30 | return len(self.dataset)
31 |
32 | def __getitem__(self, index):
33 | image_path, label_path = self.dataset[index]
34 |
35 |
36 | image = Image.open(image_path)
37 | label = io.loadmat(label_path)['inverted_vector_field']
38 |
39 | if self.transform is not None:
40 | image = self.transform(image)
41 |
42 | image, label = standard_transforms.ToTensor()(image), standard_transforms.ToTensor()(label)
43 |
44 | label = label.float()
45 | return image, label
46 |
47 | def make_dataset(self):
48 | current_dir = os.path.dirname(__file__)
49 | images_path = os.path.join(current_dir, self.ROOT, self.args.dataset_dir, self.split, self.DEFORMED + '_' + 'x'.join(map(str, self.args.size)))
50 | labels_path = os.path.join(current_dir, self.ROOT, self.args.dataset_dir, self.split, self.VECTOR_FIELD + '_' + 'x'.join(map(str, self.args.size)))
51 |
52 | images_name = os.listdir(images_path)
53 | images_name = [image_name for image_name in images_name if image_name.endswith(self.DEFORMED_EXT)]
54 | items = []
55 |
56 | for i in range(len(images_name)):
57 | image_name = images_name[i]
58 | label_name = image_name.replace(self.DEFORMED_EXT, self.VECTOR_FIELD_EXT)
59 | items.append((os.path.join(images_path, image_name), os.path.join(labels_path, label_name)))
60 |
61 | return items
62 |
63 | def get_transforms(self):
64 | return None
65 |
66 |
67 |
68 |
69 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: unwarping_assignment
2 | channels:
3 | - pytorch
4 | - conda-forge
5 | - defaults
6 | dependencies:
7 | - _tflow_select=2.3.0=mkl
8 | - absl-py=0.8.1=py37_0
9 | - astor=0.7.1=py_0
10 | - blas=1.0=mkl
11 | - ca-certificates=2019.11.28=hecc5488_0
12 | - certifi=2019.11.28=py37_0
13 | - cffi=1.13.2=py37h7a1dbc1_0
14 | - cudatoolkit=10.1.243=h74a9793_0
15 | - cycler=0.10.0=py_2
16 | - freetype=2.9.1=ha9979f8_1
17 | - gast=0.2.2=py_0
18 | - google-pasta=0.1.8=py_0
19 | - grpcio=1.23.0=py37h3db2c7e_0
20 | - h5py=2.10.0=nompi_py37h422b98e_100
21 | - hdf5=1.10.5=nompi_ha405e13_1104
22 | - icc_rt=2019.0.0=h0cc432a_1
23 | - icu=64.2=he025d50_1
24 | - intel-openmp=2019.4=245
25 | - jpeg=9c=hfa6e2cd_1001
26 | - keras-applications=1.0.8=py_1
27 | - keras-preprocessing=1.1.0=py_0
28 | - kiwisolver=1.1.0=py37he980bc4_0
29 | - libblas=3.8.0=14_mkl
30 | - libcblas=3.8.0=14_mkl
31 | - libclang=9.0.0=default_hf44288c_4
32 | - liblapack=3.8.0=14_mkl
33 | - liblapacke=3.8.0=14_mkl
34 | - libmklml=2019.0.5=0
35 | - libopencv=4.1.2=py37_2
36 | - libpng=1.6.37=h2a8f88b_0
37 | - libprotobuf=3.11.1=h1a1b453_0
38 | - libtiff=4.1.0=h56a325e_0
39 | - libwebp=1.0.2=hfa6e2cd_4
40 | - m2w64-gcc-libgfortran=5.3.0=6
41 | - m2w64-gcc-libs=5.3.0=7
42 | - m2w64-gcc-libs-core=5.3.0=7
43 | - m2w64-gmp=6.1.0=2
44 | - m2w64-libwinpthread-git=5.0.0.4634.697f757=2
45 | - markdown=3.1.1=py_0
46 | - matplotlib=3.1.2=py37_1
47 | - matplotlib-base=3.1.2=py37h2981e6d_1
48 | - mkl=2019.4=245
49 | - mkl-service=2.3.0=py37hb782905_0
50 | - mkl_fft=1.0.15=py37h14836fe_0
51 | - mkl_random=1.1.0=py37h675688f_0
52 | - msys2-conda-epoch=20160418=1
53 | - ninja=1.9.0=py37h74a9793_0
54 | - numpy=1.17.4=py37h4320e6b_0
55 | - numpy-base=1.17.4=py37hc3f5095_0
56 | - olefile=0.46=py37_0
57 | - opencv=4.1.2=py37_2
58 | - openssl=1.1.1d=hfa6e2cd_0
59 | - opt_einsum=3.1.0=py_0
60 | - pillow=6.2.1=py37hdc69c19_0
61 | - pip=19.3.1=py37_0
62 | - protobuf=3.11.1=py37he025d50_0
63 | - py-opencv=4.1.2=py37h5ca1d4c_2
64 | - pycparser=2.19=py37_0
65 | - pyparsing=2.4.5=py_0
66 | - pyqt=5.12.3=py37h6538335_1
67 | - pyreadline=2.1=py37_1001
68 | - python=3.7.5=h8c8aaf0_0
69 | - python-dateutil=2.8.1=py_0
70 | - pytorch=1.3.1=py3.7_cuda101_cudnn7_0
71 | - qt=5.12.5=h7ef1ec2_0
72 | - scipy=1.3.2=py37h29ff71c_0
73 | - setuptools=42.0.2=py37_0
74 | - sip=4.19.8=py37h6538335_0
75 | - six=1.13.0=py37_0
76 | - sqlite=3.30.1=he774522_0
77 | - tensorboard=2.0.0=pyhb38c66f_1
78 | - tensorflow=2.0.0=mkl_py37he1bbcac_0
79 | - tensorflow-base=2.0.0=mkl_py37hd1d5974_0
80 | - tensorflow-estimator=2.0.0=pyh2649769_0
81 | - termcolor=1.1.0=py_2
82 | - tk=8.6.8=hfa6e2cd_0
83 | - torchvision=0.4.2=py37_cu101
84 | - tornado=6.0.3=py37hfa6e2cd_0
85 | - tqdm=4.40.0=py_0
86 | - vc=14.1=h0510ff6_4
87 | - vs2015_runtime=14.16.27012=hf0eaf9b_0
88 | - werkzeug=0.16.0=py_0
89 | - wheel=0.33.6=py37_0
90 | - wincertstore=0.2=py37_0
91 | - wrapt=1.11.2=py37hfa6e2cd_0
92 | - xz=5.2.4=h2fa13f4_4
93 | - zlib=1.2.11=h62dcd97_3
94 | - zstd=1.3.7=h508b16e_0
95 | - pip:
96 | - opencv-python==4.1.2.30
97 | - pyqt5-sip==4.19.18
98 | - pyqtwebengine==5.12.1
99 | prefix: C:\Users\r.sibechi\AppData\Local\Continuum\anaconda3\envs\primevision_torch
100 |
101 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | from parser_options import ParserOptions
2 | from util.general_functions import print_training_info
3 | from constants import *
4 | from core.trainers.trainer import Trainer
5 |
6 | def main():
7 | args = ParserOptions().parse() # get training options
8 | trainer = Trainer(args)
9 |
10 | print_training_info(args)
11 |
12 | for epoch in range(trainer.args.start_epoch, trainer.args.epochs):
13 | trainer.run_epoch(epoch, split=TRAIN)
14 |
15 | if epoch % args.eval_interval == (args.eval_interval - 1):
16 | trainer.run_epoch(epoch, split=TEST)
17 |
18 | trainer.run_epoch(trainer.args.epochs, split=VISUALIZATION)
19 | trainer.summary.writer.add_scalar('test/best_result', trainer.best_loss, args.epochs)
20 | trainer.summary.writer.close()
21 | trainer.save_network()
22 |
23 |
24 | if __name__ == "__main__":
25 | main()
--------------------------------------------------------------------------------
/parser_options.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import os
4 | import math
5 |
6 | from constants import *
7 |
8 | class ParserOptions():
9 | """This class defines options that are used by the program"""
10 |
11 | def __init__(self):
12 | parser = argparse.ArgumentParser(description='PyTorch Document Image Unwarping training')
13 |
14 | # model specific
15 | parser.add_argument('--model', type=str, default=DEEPLAB_50, choices=[DEEPLAB, DEEPLAB_50, DEEPLAB_34, DEEPLAB_18, DEEPLAB_MOBILENET, DEEPLAB_MOBILENET_DILATION, UNET, UNET_PAPER, UNET_PYTORCH, PSPNET], help='model name (default: ' + DEEPLAB_50 + ')')
16 | parser.add_argument('--separable_conv', type=int, default=0, choices=[0,1], help='if we should convert normal convolutions to separable convolutions' )
17 | parser.add_argument('--refine_network', type=int, default=0, choices=[0,1], help='if we should refine the first prediction with a second network ')
18 | parser.add_argument('--dataset', type=str, default=DOCUNET_INVERTED, choices=[DOCUNET, DOCUNET_IM2IM, DOCUNET_INVERTED], help='dataset (default: ' + DOCUNET_INVERTED + ')')
19 | parser.add_argument('--dataset_dir', type=str, default=ADDRESS_DATASET, choices=[HAZMAT_DATASET, LABELS_DATASET, ADDRESS_DATASET], help='name of the dir in which the dataset is located')
20 | parser.add_argument('--loss_type', type=str, default=DOCUNET_LOSS, choices=[DOCUNET_LOSS, SSIM_LOSS, SSIM_LOSS_V2, MS_SSIM_LOSS, MS_SSIM_LOSS_V2, L1_LOSS, SMOOTH_L1_LOSS, MSE_LOSS], help='loss func type (default:' + DOCUNET_LOSS + ')')
21 | parser.add_argument('--second_loss', type=int, default=0, choices=[0,1], help='if we should use two losses')
22 | parser.add_argument('--second_loss_rate', type=float, default=10, help='used to tune the overall impact of the second loss')
23 | parser.add_argument('--norm_layer', type=str, default=BATCH_NORM, choices=[INSTANCE_NORM, BATCH_NORM, SYNC_BATCH_NORM])
24 | parser.add_argument('--init_type', type=str, default=KAIMING_INIT, choices=[NORMAL_INIT, KAIMING_INIT, XAVIER_INIT, ORTHOGONAL_INIT])
25 | parser.add_argument('--resize', type=str, default='64,64', help='image resize: h,w')
26 | parser.add_argument('--batch_size', type=int, default=2, metavar='N', help='input batch size for training (default: 2)')
27 | parser.add_argument('--optim', type=str, default=ADAM, choices=[SGD, ADAM, RMSPROP, AMSGRAD, ADABOUND])
28 | parser.add_argument('--lr', type=float, default=0.0001, metavar='LR', help='learning rate (default: auto)')
29 | parser.add_argument('--lr_policy', type=str, default='poly', choices=['poly', 'step', 'cos', 'linear'], help='lr scheduler mode: (default: poly)')
30 | parser.add_argument('--weight-decay', type=float, default=5e-4, metavar='M', help='w-decay (default: 5e-4)')
31 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', help='momentum (default: 0.9)')
32 | parser.add_argument('--clip', type=float, default=0, help='gradient clip, 0 means no clip (default: 0)')
33 |
34 | # training specific
35 | parser.add_argument('--size', type=str, default='1024,1024', help='image size: h,w')
36 | parser.add_argument('--start_epoch', type=int, default=0, metavar='N', help='starting epoch')
37 | parser.add_argument('--epochs', type=int, default=100, metavar='N', help='number of epochs to train (default: auto)')
38 | parser.add_argument('--eval-interval', type=int, default=1, help='evaluation interval (default: 1)')
39 | parser.add_argument('--trainval', type=int, default=0, choices=[0,1], help='determines whether we should use validation images as well for training')
40 | parser.add_argument('--inference', type=int, default=0, choices=[0,1], help='if we should run the model in inference mode')
41 | parser.add_argument('--debug', type=int, default=1)
42 | parser.add_argument('--results_root', type=str, default='..')
43 | parser.add_argument('--results_dir', type=str, default='results_final', help='models are saved here')
44 | parser.add_argument('--save_dir', type=str, default='saved_models')
45 | parser.add_argument('--save_best_model', type=int, default=0, choices=[0,1], help='keep track of best model')
46 | parser.add_argument('--pretrained_models_dir', type=str, default='pretrained_models', help='root dir of the pretrained models location')
47 |
48 | # deeplab specific
49 | parser.add_argument('--output_stride', type=int, default=16, help='network output stride (default: 16)')
50 | parser.add_argument('--pretrained', type=int, default=0, choices=[0,1], help='if we should use a pretrained model or not')
51 | parser.add_argument('--learned_upsampling', type=int, default=0, choices=[0,1], help='if we should use bilinear upsampling or learned upsampling')
52 | parser.add_argument('--use_aspp', type=int, default=1, choices=[0,1], help='if we should use aspp in the deeplab head or not')
53 |
54 | # unet specific
55 | parser.add_argument('--num_downs', type=int, default=8, help='number of unet encoder-decoder blocks')
56 | parser.add_argument('--ngf', type=int, default=128, help='# of gen filters in the last conv layer')
57 | parser.add_argument('--down_type', type=str, default=MAXPOOL, choices=[STRIDECONV, MAXPOOL], help='method to reduce feature map size')
58 | parser.add_argument('--dropout', type=float, default=0.2)
59 |
60 | args = parser.parse_args()
61 | args.size = tuple([int(x) for x in args.size.split(',')])
62 | args.resize = tuple([int(x) for x in args.resize.split(',')])
63 |
64 | if args.debug:
65 | args.results_dir = 'results_dummy'
66 |
67 | args.num_downs = int(math.log(args.resize[0])/math.log(2))
68 | args.cuda = torch.cuda.is_available()
69 | args.gpu_ids = os.environ['CUDA_VISIBLE_DEVICES'] if ('CUDA_VISIBLE_DEVICES' in os.environ) else ''
70 | args.gpu_ids = list(range(len(args.gpu_ids.split(',')))) if (',' in args.gpu_ids and args.cuda) else None
71 |
72 | if args.gpu_ids and args.norm_layer == BATCH_NORM:
73 | args.norm_layer = SYNC_BATCH_NORM
74 |
75 | if args.dataset == DOCUNET or args.dataset == DOCUNET_INVERTED:
76 | args.size = args.resize
77 |
78 | if args.dataset == DOCUNET_IM2IM and args.loss_type == DOCUNET_LOSS:
79 | args.loss_type = MS_SSIM_LOSS
80 | args.second_loss = 0
81 |
82 | self.args = args
83 |
84 | def parse(self):
85 | return self.args
--------------------------------------------------------------------------------
/playground.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | from scipy import io
3 | import random
4 | import torch
5 | import torchvision.transforms as standard_transforms
6 | import matplotlib.pyplot as plt
7 | from tqdm import tqdm
8 | import time
9 | import os
10 |
11 | from core.trainers.trainer import Trainer
12 | from parser_options import ParserOptions
13 | from util.general_functions import apply_transformation_to_image_cv, apply_transformation_to_image, invert_vector_field, get_model, make_data_loader
14 | from constants import *
15 |
16 | def cv2_invert(invert=True):
17 | deformed_label = cv2.imread('../deformed_label.jpg', cv2.IMREAD_COLOR)
18 |
19 | if invert:
20 | vector_field = io.loadmat('../fm.mat')['vector_field']
21 | else:
22 | vector_field = io.loadmat('../fm.mat')['inverted_vector_field']
23 |
24 | flatten_label = apply_transformation_to_image_cv(deformed_label, vector_field, invert=invert)
25 | cv2.imwrite('../our_flatten.jpg', flatten_label)
26 |
27 | def pytorch_invert(invert=False):
28 | deformed_label = cv2.imread('../deformed_label.jpg', cv2.IMREAD_COLOR)
29 | deformed_label = standard_transforms.ToTensor()(deformed_label).unsqueeze(dim=0)
30 |
31 | if invert:
32 | vector_field = io.loadmat('../fm.mat')['vector_field']
33 | vector_field = invert_vector_field(vector_field)
34 | else:
35 | vector_field = io.loadmat('../fm.mat')['inverted_vector_field']
36 |
37 | vector_field = torch.Tensor(vector_field).unsqueeze(dim=0)
38 | vector_field = vector_field.permute(0, 3, 1, 2)
39 | flatten_label = apply_transformation_to_image(deformed_label, vector_field)
40 | plt.imsave('../our_flatten_2.jpg', flatten_label.permute(1,2,0).cpu().numpy())
41 |
42 | def check_duplicates(source_folder_name, destination_folder_name):
43 | source_files = set(os.listdir(source_folder_name))
44 | destination_files = set(os.listdir(destination_folder_name))
45 | intersection = source_files.intersection(destination_files)
46 | intersection.discard('Thumbs.db')
47 |
48 | if len(intersection) == 0:
49 | print("OK")
50 | else:
51 | print("NOT OK")
52 |
53 | def network_predict(iterations=51, pretrained_model=''):
54 | if not pretrained_model:
55 | raise ValueError('a path to a pretrained model must be provided')
56 |
57 | args = ParserOptions().parse()
58 | args.cuda = False
59 | args.batch_size = 1
60 | args.inference = 1
61 | args.pretrained_models_dir = pretrained_model
62 | args.num_downs = 8
63 | args.resize, args.size = (256,256), (256,256)
64 | args.model = DEEPLAB_50
65 | #args.refine_network = 1
66 | trainer = Trainer(args)
67 | mean_time = trainer.calculate_inference_speed(iterations)
68 | print('Mean time', mean_time)
69 |
70 | if __name__ == "__main__":
71 | model = 'saved_models/deeplab_50.pth'
72 | network_predict(pretrained_model=model)
--------------------------------------------------------------------------------
/readme_images/generating_deformed_images.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mhashas/Document-Image-Unwarping-pytorch/92b29172b981d132f7b31e767505524f8cc7af7a/readme_images/generating_deformed_images.PNG
--------------------------------------------------------------------------------
/readme_images/output_examples.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mhashas/Document-Image-Unwarping-pytorch/92b29172b981d132f7b31e767505524f8cc7af7a/readme_images/output_examples.png
--------------------------------------------------------------------------------
/readme_images/overall_architecture.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mhashas/Document-Image-Unwarping-pytorch/92b29172b981d132f7b31e767505524f8cc7af7a/readme_images/overall_architecture.PNG
--------------------------------------------------------------------------------
/train.sh:
--------------------------------------------------------------------------------
1 | MODEL=$1
2 | DATASET=$2
3 | LOSS_TYPE=$3
4 | BATCH_SIZE=$4
5 | RESIZE=$5
6 | REFINE_NETWORK=${6:-0}
7 | SECOND_LOSS=${7:-0}
8 |
9 | python main.py --model $MODEL --dataset $DATASET --loss_type $LOSS_TYPE --batch_size $BATCH_SIZE \
10 | --resize $RESIZE --epochs 100 --norm_layer batch --debug 0 \
11 | --refine_network $REFINE_NETWORK --second_loss $SECOND_LOSS
--------------------------------------------------------------------------------
/util/custom_transforms.py:
--------------------------------------------------------------------------------
1 | import types
2 | import random
3 | import torch
4 |
5 | from PIL import Image, ImageOps, ImageFilter
6 | from torchvision.transforms import functional as F
7 | import numpy as np
8 |
9 | class Lambda(object):
10 | """Apply a user-defined lambda as a transform.
11 |
12 | Args:
13 | lambd (function): Lambda/function to be used for transform.
14 | """
15 |
16 | def __init__(self, lambd):
17 | assert isinstance(lambd, types.LambdaType)
18 | self.lambd = lambd
19 |
20 | def __call__(self, img):
21 | return self.lambd(img)
22 |
23 | def __repr__(self):
24 | return self.__class__.__name__ + '()'
25 |
26 | class Compose(object):
27 | def __init__(self, transforms):
28 | self.transforms = transforms
29 |
30 | def __call__(self, image, label=None):
31 | #assert image.size == label.size
32 |
33 | for t in self.transforms:
34 | image, label = t(image, label)
35 | return image, label
36 |
37 |
38 | class RandomCrop(object):
39 | def __init__(self, size, padding=0):
40 | self.size = size
41 | self.padding = padding
42 |
43 | def __call__(self, image, label):
44 | if self.padding > 0:
45 | image = ImageOps.expand(image, border=self.padding, fill=0)
46 | label = ImageOps.expand(label, border=self.padding, fill=0)
47 | # @TODO RADU
48 | #assert image.size == label.size
49 | w, h = image.size
50 | tw, th = self.size
51 |
52 | if w == tw and h == th:
53 | return image, label
54 | if w < tw or h < th:
55 | return image.resize((tw, th), Image.BILINEAR), label.resize((tw, th), Image.BILINEAR)
56 |
57 | x1 = random.randint(0, w - tw)
58 | y1 = random.randint(0, h - th)
59 |
60 | cropped_images, cropped_labels = image.crop((x1, y1, x1 + tw, y1 + th)), label.crop((x1, y1, x1 + tw, y1 + th))
61 | return cropped_images, cropped_labels
62 |
63 | class RandomHorizontallyFlip(object):
64 | def __call__(self, image, label):
65 | if random.random() < 0.5:
66 | return image.transpose(Image.FLIP_LEFT_RIGHT), label.transpose(Image.FLIP_LEFT_RIGHT)
67 | return image, label
68 |
69 | class Scale(object):
70 | def __init__(self, size, label_resize_type=Image.BILINEAR):
71 | self.size = size
72 | self.label_resize_type = label_resize_type
73 |
74 | def __call__(self, image, label):
75 | assert image.size == label.size
76 | w, h = image.size
77 |
78 | if (w >= h and w == self.size) or (h >= w and h == self.size):
79 | return image, label
80 | if w > h:
81 | ow = self.size
82 | oh = int(self.size * h / w)
83 | return image.resize((ow, oh), Image.BILINEAR), label.resize((ow, oh), self.label_resize_type)
84 | else:
85 | oh = self.size
86 | ow = int(self.size * w / h)
87 | return image.resize((ow, oh), Image.BILINEAR), label.resize((ow, oh), self.label_resize_type)
88 |
89 | class Resize(object):
90 | def __init__(self, size, label_resize_type=Image.BILINEAR):
91 | self.size = size
92 | self.label_resize_type = label_resize_type
93 |
94 | def __call__(self, image, label):
95 | #assert image.size == label.size
96 | w, h = image.size
97 | w_label, h_label = label.size
98 |
99 | if not ((w >= h and w == self.size[0]) or (h >= w and h == self.size[1])):
100 | image = image.resize(self.size, Image.BILINEAR)
101 |
102 | if not ((w_label >= h_label and w_label == self.size[0]) or (h_label >= w_label and h_label == self.size[1])):
103 | label = label.resize(self.size, self.label_resize_type)
104 |
105 | return image, label
106 |
107 | class RandomGaussianBlur(object):
108 | def __call__(self, image, label=None):
109 | if random.random() < 0.5:
110 | radius = random.random()
111 | image = image.filter(ImageFilter.GaussianBlur(radius=radius))
112 |
113 | if label is not None:
114 | return image, label
115 | return image
116 |
117 | class Normalize(object):
118 | """Normalize a tensor image with mean and standard deviation.
119 | Args:
120 | mean (tuple): means for each channel.
121 | std (tuple): standard deviations for each channel.
122 | """
123 | def __init__(self, mean=(0., 0., 0.), std=(1., 1., 1.)):
124 | self.mean = mean
125 | self.std = std
126 |
127 | def __call__(self, img):
128 | img = np.array(img).astype(np.float32)
129 | img /= 255.0
130 | img -= self.mean
131 | img /= self.std
132 |
133 | return img
134 |
135 | class ColorJitter(object):
136 | """Randomly change the brightness, contrast and saturation of an image.
137 |
138 | Args:
139 | brightness (float): How much to jitter brightness. brightness_factor
140 | is chosen uniformly from [max(0, 1 - brightness), 1 + brightness].
141 | contrast (float): How much to jitter contrast. contrast_factor
142 | is chosen uniformly from [max(0, 1 - contrast), 1 + contrast].
143 | saturation (float): How much to jitter saturation. saturation_factor
144 | is chosen uniformly from [max(0, 1 - saturation), 1 + saturation].
145 | hue(float): How much to jitter hue. hue_factor is chosen uniformly from
146 | [-hue, hue]. Should be >=0 and <= 0.5.
147 | """
148 | def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
149 | self.brightness = brightness
150 | self.contrast = contrast
151 | self.saturation = saturation
152 | self.hue = hue
153 |
154 | @staticmethod
155 | def get_params(brightness, contrast, saturation, hue):
156 | """Get a randomized transform to be applied on image.
157 |
158 | Arguments are same as that of __init__.
159 |
160 | Returns:
161 | Transform which randomly adjusts brightness, contrast and
162 | saturation in a random order.
163 | """
164 | transforms = []
165 | if brightness > 0:
166 | brightness_factor = random.uniform(max(0, 1 - brightness), 1 + brightness)
167 | transforms.append(Lambda(lambda img: F.adjust_brightness(img, brightness_factor)))
168 |
169 | if contrast > 0:
170 | contrast_factor = random.uniform(max(0, 1 - contrast), 1 + contrast)
171 | transforms.append(Lambda(lambda img: F.adjust_contrast(img, contrast_factor)))
172 |
173 | if saturation > 0:
174 | saturation_factor = random.uniform(max(0, 1 - saturation), 1 + saturation)
175 | transforms.append(Lambda(lambda img: F.adjust_saturation(img, saturation_factor)))
176 |
177 | if hue > 0:
178 | hue_factor = random.uniform(-hue, hue)
179 | transforms.append(Lambda(lambda img: F.adjust_hue(img, hue_factor)))
180 |
181 | random.shuffle(transforms)
182 | return transforms
183 |
184 | @staticmethod
185 | def forward_transforms(image, transforms):
186 | for transform in transforms:
187 | image = transform(image)
188 |
189 | return image
190 |
191 | def __call__(self, image, label):
192 | """
193 | Args:
194 | images (PIL Image): Input image.
195 |
196 | Returns:
197 | PIL Image: Color jittered image.
198 | """
199 | transforms = self.get_params(self.brightness, self.contrast,
200 | self.saturation, self.hue)
201 |
202 | return self.forward_transforms(image, transforms), label
--------------------------------------------------------------------------------
/util/general_functions.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.optim as optim
4 | from torch.utils.data import ConcatDataset
5 | from torch.utils.data import DataLoader
6 | import numpy as np
7 | import functools
8 | import cv2
9 |
10 | from core.models.deeplabv3_plus import DeepLabv3_plus
11 | from core.models.mobilenetv2 import MobileNet_v2
12 | from core.models.pspnet import PSPNet
13 | from core.models.unet import UNet
14 | from core.models.unet_paper import UNet_paper
15 | from core.models.unet_pytorch import UNet_torch
16 |
17 | from dataloader.docunet import Docunet
18 | from dataloader.docunet_inverted import InvertedDocunet
19 | from dataloader.docunet_im2im import DocunetIm2Im
20 |
21 | from util.losses import *
22 | from constants import *
23 |
24 | def make_data_loader(args, split=TRAIN):
25 | """
26 | Builds the data loader for the selected dataset and split
27 |
28 | Parameters:
29 | args (argparse) -- input arguments
30 | split (str) -- dataset split to load (TRAIN, VAL, TEST or TRAINVAL)
31 | Returns a torch DataLoader over the selected split
32 | """
33 | if args.dataset == DOCUNET:
34 | dataset = Docunet
35 | elif args.dataset == DOCUNET_INVERTED:
36 | dataset = InvertedDocunet
37 | elif args.dataset == DOCUNET_IM2IM:
38 | dataset = DocunetIm2Im
39 | else:
40 | raise NotImplementedError
41 |
42 | if split == TRAINVAL:
43 | train_set = dataset(args, split=TRAIN)
44 | val_set = dataset(args, split=VAL)
45 | trainval_set = ConcatDataset([train_set, val_set])
46 | loader = DataLoader(trainval_set, batch_size=args.batch_size, num_workers=1, shuffle=True)
47 | else:
48 | data_set = dataset(args, split=split)
49 |
50 | if split == TRAIN:
51 | loader = DataLoader(data_set, batch_size=args.batch_size, num_workers=1, shuffle=True)
52 | else:
53 | loader = DataLoader(data_set, batch_size=args.batch_size, num_workers=1, shuffle=False)
54 |
55 | return loader
56 |
57 | def get_model(args):
58 | """
59 | Builds the model based on the provided arguments and returns the initialized model
60 |
61 | Parameters:
62 | args (argparse) -- command line arguments
63 | """
64 |
65 | norm_layer = get_norm_layer(args.norm_layer)
66 | num_classes = get_num_classes(args.dataset)
67 |
68 | if DEEPLAB in args.model:
69 | model = DeepLabv3_plus(args, num_classes=num_classes, norm_layer=norm_layer)
70 |
71 | if args.model != DEEPLAB_MOBILENET and args.separable_conv:
72 | convert_to_separable_conv(model)
73 |
74 | model = init_model(model, args.init_type)
75 | elif UNET in args.model:
76 | if args.model == UNET:
77 | model = UNet(num_classes=num_classes, args=args, norm_layer=norm_layer)
78 | elif args.model == UNET_PAPER:
79 | model = UNet_paper(num_classes=num_classes, args=args, norm_layer=norm_layer)
80 | elif args.model == UNET_PYTORCH:
81 | model = UNet_torch(num_classes=num_classes, args=args)
82 |
83 | if args.separable_conv:
84 | convert_to_separable_conv(model)
85 |
86 | model = init_model(model, args.init_type)
87 | elif PSPNET in args.model:
88 | model = PSPNet(num_classes=num_classes, args=args)
89 | else:
90 | raise NotImplementedError
91 |
92 | print("Built " + args.model)
93 |
94 | if args.cuda:
95 | model = model.cuda()
96 |
97 | return model
98 |
99 | def convert_to_separable_conv(module):
100 | class SeparableConvolution(nn.Module):
101 | """ Separable Convolution
102 | """
103 |
104 | def __init__(self, in_channels, out_channels, kernel_size,
105 | stride=1, padding=0, dilation=1, bias=True):
106 | super(SeparableConvolution, self).__init__()
107 | self.body = nn.Sequential(
108 | # Separable Conv
109 | nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size, stride=stride, padding=padding,
110 | dilation=dilation, bias=bias, groups=in_channels),
111 | # PointWise Conv
112 | nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=bias),
113 | )
114 |
115 | def forward(self, x):
116 | return self.body(x)
117 |
118 | new_module = module
119 | if isinstance(module, nn.Conv2d) and module.kernel_size[0]>1:
120 | new_module = SeparableConvolution(module.in_channels,
121 | module.out_channels,
122 | module.kernel_size,
123 | module.stride,
124 | module.padding,
125 | module.dilation,
126 | module.bias is not None)
127 |
128 | for name, child in module.named_children():
129 | new_module.add_module(name, convert_to_separable_conv(child))
130 | return new_module
131 |
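# Usage sketch for convert_to_separable_conv (hypothetical stand-alone module; the real call
# sites are in get_model above). Each Conv2d with kernel size > 1 is replaced by a depthwise
# convolution followed by a 1x1 pointwise convolution, keeping the receptive field while
# reducing parameters and FLOPs:
#   block = nn.Sequential(nn.Conv2d(64, 128, kernel_size=3, padding=1), nn.ReLU())
#   block = convert_to_separable_conv(block)   # ~64*128*3*3 weights -> 64*3*3 + 64*128 (ignoring biases)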
132 | def get_loss_function(mode):
133 | if mode == DOCUNET_LOSS:
134 | loss = DocunetLoss()
135 | elif mode == MS_SSIM_LOSS:
136 | loss = MS_SSIM_Loss()
137 | elif mode == MS_SSIM_LOSS_V2:
138 | loss = MS_SSIM_Loss_v2()
139 | elif mode == SSIM_LOSS:
140 | loss = SSIM_Loss()
141 | elif mode == L1_LOSS:
142 | loss = torch.nn.L1Loss()
143 | elif mode == SMOOTH_L1_LOSS:
144 | loss = torch.nn.SmoothL1Loss()
145 | elif mode == MSE_LOSS:
146 | loss = torch.nn.MSELoss()
147 | else:
148 | raise NotImplementedError
149 |
150 | return loss
151 |
152 | def get_num_classes(dataset):
153 | if dataset == DOCUNET or dataset == DOCUNET_INVERTED:
154 | num_classes = 2
155 | elif dataset == DOCUNET_IM2IM:
156 | num_classes = 3
157 | else:
158 | raise NotImplementedError
159 |
160 | return num_classes
161 |
162 |
163 | def get_optimizer(model, args):
164 | """
165 | Builds the optimizer for the model based on the provided arguments and returns the optimizer
166 |
167 | Parameters:
168 | model -- the network to be optimized
169 | args -- command line arguments
170 | """
171 | if args.gpu_ids:
172 | train_params = model.module.get_train_parameters(args.lr)
173 | else:
174 | train_params = model.get_train_parameters(args.lr)
175 |
176 | if args.optim == SGD:
177 | optimizer = optim.SGD(train_params, lr=args.lr, weight_decay=args.weight_decay, momentum=args.momentum, nesterov=True)
178 | elif args.optim == ADAM:
179 | optimizer = optim.Adam(train_params, lr=args.lr, weight_decay=args.weight_decay)
180 | elif args.optim == AMSGRAD:
181 | optimizer = optim.Adam(train_params, lr=args.lr, weight_decay=args.weight_decay, amsgrad=True)
182 | else:
183 | raise NotImplementedError
184 |
185 | return optimizer
186 |
187 |
188 | def get_norm_layer(norm_type=INSTANCE_NORM):
189 | """Returns a normalization layer
190 |
191 | Parameters:
192 | norm_type (str) -- the name of the normalization layer: batch | instance | none
193 |
194 | For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev).
195 | For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics.
196 | """
197 |
198 | if norm_type == BATCH_NORM:
199 | norm_layer = nn.BatchNorm2d
200 | elif norm_type == INSTANCE_NORM:
201 | norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
202 | elif norm_type == 'none':
203 | norm_layer = None
204 | else:
205 | raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
206 | return norm_layer
207 |
208 | def init_model(net, init_type=NORMAL_INIT, init_gain=0.02):
209 | """Initialize the network weights
210 |
211 | Parameters:
212 | net (network) -- the network to be initialized
213 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
214 | gain (float) -- scaling factor for normal, xavier and orthogonal.
215 |
216 | Return an initialized network.
217 | """
218 |
219 | init_weights(net, init_type, init_gain=init_gain)
220 | return net
221 |
222 | def init_weights(net, init_type=NORMAL_INIT, init_gain=0.02):
223 | """Initialize network weights.
224 |
225 | Parameters:
226 | net (network) -- network to be initialized
227 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
228 | init_gain (float) -- scaling factor for normal, xavier and orthogonal.
229 | """
230 |
231 | def init_func(m): # define the initialization function
232 | classname = m.__class__.__name__
233 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
234 | if init_type == NORMAL_INIT:
235 | nn.init.normal_(m.weight.data, 0.0, init_gain)
236 | elif init_type == XAVIER_INIT:
237 | nn.init.xavier_normal_(m.weight.data, gain=init_gain)
238 | elif init_type == KAIMING_INIT:
239 | nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_out', nonlinearity='leaky_relu')
240 | elif init_type == ORTHOGONAL_INIT:
241 | nn.init.orthogonal_(m.weight.data, gain=init_gain)
242 | else:
243 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
244 | if hasattr(m, 'bias') and m.bias is not None:
245 | nn.init.constant_(m.bias.data, 0.0)
246 | elif hasattr(m, '_all_weights') and (classname.find('LSTM') != -1 or classname.find('GRU') != -1):
247 | for names in m._all_weights:
248 | for name in filter(lambda n: "weight" in n, names):
249 | weight = getattr(m, name)
250 | nn.init.xavier_normal_(weight.data, gain=init_gain)
251 |
252 | for name in filter(lambda n: "bias" in n, names):
253 | bias = getattr(m, name)
254 | nn.init.constant_(bias.data, 0.0)
255 |
256 | if classname.find('LSTM') != -1:
257 | n = bias.size(0)
258 | start, end = n // 4, n // 2
259 | nn.init.constant_(bias.data[start:end], 1.)
260 | elif classname.find('BatchNorm2d') != -1 or classname.find('SynchronizedBatchNorm2d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
261 | nn.init.normal_(m.weight.data, 1.0, init_gain)
262 | nn.init.constant_(m.bias.data, 0.0)
263 |
264 | print('Initialized network with %s' % init_type)
265 | net.apply(init_func) # apply the initialization function
266 |
267 | def set_requires_grad(net, requires_grad=False):
268 | """Set requies_grad=False for the network to avoid unnecessary computations
269 | Parameters:
270 | net (network)
271 | requires_grad (bool) -- whether the networks require gradients or not
272 | """
273 | if net is not None:
274 | for param in net.parameters():
275 | param.requires_grad = requires_grad
276 |
277 | def tensor2im(input_image, imtype=np.uint8, return_tensor=True):
278 | """"Converts a Tensor array into a numpy image array.
279 |
280 | Parameters:
281 | input_image (tensor) -- the input image tensor array
282 | imtype (type) -- the desired type of the converted numpy array
283 | """
284 | if not isinstance(input_image, np.ndarray):
285 | if isinstance(input_image, torch.Tensor): # get the data from a variable
286 | image_tensor = input_image.data
287 | else:
288 | return input_image
289 | image_numpy = image_tensor.cpu().float().numpy() # convert it into a numpy array
290 | if image_numpy.ndim == 3:
291 | image_numpy = (image_numpy - np.min(image_numpy))/(np.max(image_numpy)-np.min(image_numpy))
292 | if image_numpy.shape[0] == 1: # grayscale to RGB
293 | image_numpy = np.tile(image_numpy, (3, 1, 1))
294 | image_numpy = (image_numpy + 1) / 2.0 * 255.0  # post-processing: transpose and scaling
295 | else: # if it is a numpy array, do nothing
296 | image_numpy = input_image
297 |
298 | return torch.from_numpy(image_numpy.astype(imtype)) if return_tensor else np.transpose(image_numpy, (1,2,0))
299 |
300 | def get_flat_images(dataset, images, outputs, targets):
301 | if dataset == DOCUNET:
302 | pass
303 | elif dataset == DOCUNET_INVERTED:
304 | outputs = apply_transformation_to_image(images, outputs)
305 | targets = apply_transformation_to_image(images, targets)
306 | else:
307 | pass
308 |
309 | return outputs, targets
310 |
311 | def apply_transformation_to_image(img, vector_field):
312 | vector_field = scale_vector_field_tensor(vector_field)
313 | vector_field = vector_field.permute(0, 2, 3, 1)
314 | flatten_image = nn.functional.grid_sample(img, vector_field, mode='bilinear', align_corners=True)
315 |
316 | return flatten_image.squeeze()
317 |
318 | def apply_transformation_to_image_cv(img, vector_field, invert=False):
319 | if invert:
320 | vector_field = invert_vector_field(vector_field)
321 |
322 | map_x = vector_field[:, :, 0]
323 | map_y = vector_field[:, :, 1]
324 | transformed_img = cv2.remap(img, map_x, map_y, interpolation=cv2.INTER_LINEAR)
325 |
326 | return transformed_img
327 |
328 | def invert_vector_field(vector_field):
329 | vector_field_x = vector_field[:, :, 0]
330 | vector_field_y = vector_field[:, :, 1]
331 |
332 | assert(vector_field_x.shape == vector_field_y.shape)
333 | rows = vector_field_x.shape[0]
334 | cols = vector_field_x.shape[1]
335 |
336 | m_x = np.ones(vector_field_x.shape, dtype=vector_field_x.dtype) * -1
337 | m_y = np.ones(vector_field_y.shape, dtype=vector_field_y.dtype) * -1
338 | for i in range(rows):
339 | for j in range(cols):
340 | i_ = int(round(vector_field_y[i, j]))
341 | j_ = int(round(vector_field_x[i, j]))
342 | if 0 <= i_ < rows and 0 <= j_ < cols:
343 | m_x[i_, j_] = j
344 | m_y[i_, j_] = i
345 | return np.stack([m_x, m_y], axis=2)
346 |
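# Note: invert_vector_field computes an approximate inverse by rounding. For every source pixel
# (i, j) it writes (x=j, y=i) into the target location that the forward field maps (i, j) to, so
# applying the result with cv2.remap warps the deformed image back towards the flat one. Target
# pixels that no source pixel rounds to keep the initial value -1 (treated as background).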
347 |
348 | def scale_vector_field_tensor(vector_field):
349 | vector_field = torch.where(vector_field < 0, torch.tensor(3 * vector_field.shape[3], dtype=vector_field.dtype, device=vector_field.device), vector_field)
350 | vector_field = (vector_field / (vector_field.shape[3] / 2)) - 1
351 |
352 | return vector_field
353 |
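# How the predicted field is fed to grid_sample in apply_transformation_to_image above (a note on
# the convention, assuming background pixels carry negative values in the raw field):
#   - negative entries are pushed to 3 * W so they sample far outside the image and come back as
#     zeros (grid_sample's default zero padding)
#   - pixel coordinates v in [0, W] are mapped to grid_sample's normalized range [-1, 1] via
#     v_norm = v / (W / 2) - 1, e.g. v = 0 -> -1, v = W / 2 -> 0, v = W -> +1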
354 | def print_training_info(args):
355 | print('Dataset', args.dataset)
356 |
357 | if 'unet' in args.model:
358 | print('Ngf', args.ngf)
359 | print('Num downs', args.num_downs)
360 | print('Down type', args.down_type)
361 |
362 | if 'deeplab' in args.model:
363 | print('Output stride', args.output_stride)
364 | print('Learned upsampling', args.learned_upsampling)
365 | print('Pretrained', args.pretrained)
366 | print('Use aspp', args.use_aspp)
367 |
368 | print('Refine network', args.refine_network)
369 | print('Separable conv', args.separable_conv)
370 | print('Optimizer', args.optim)
371 | print('Learning rate', args.lr)
372 | print('Second loss', args.second_loss)
373 |
374 | if args.clip > 0:
375 | print('Gradient clip', args.clip)
376 |
377 | print('Resize', args.resize)
378 | print('Batch size', args.batch_size)
379 | print('Norm layer', args.norm_layer)
380 | print('Using cuda', args.cuda)
381 | print('Using ' + args.loss_type + ' loss')
382 | print('Starting Epoch:', args.start_epoch)
383 | print('Total Epochs:', args.epochs)
384 |
385 |
386 |
387 |
--------------------------------------------------------------------------------
/util/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from util.ssim import MS_SSIM, SSIM
6 | from util.ms_ssim import MS_SSIM_v2, SSIM_v2
7 |
8 | class DocunetLoss_v2(nn.Module):
9 | def __init__(self, r=0.1,reduction='mean'):
10 | super(DocunetLoss_v2, self).__init__()
11 | assert reduction in ['mean','sum'], "reduction must be in ['mean','sum']"
12 | self.r = r
13 | self.reduction = reduction
14 |
15 | def forward(self, y, label):
16 | bs, n, h, w = y.size()
17 | d = y - label
18 | loss1 = []
19 | for d_i in d:
20 | loss1.append(torch.abs(d_i).mean() - self.r * torch.abs(d_i.mean()))
21 | loss1 = torch.stack(loss1)
22 | # lossb1 = torch.max(y1, torch.zeros(y1.shape).to(y1.device)).mean()
23 | loss2 = F.mse_loss(y, label,reduction=self.reduction)
24 |
25 | if self.reduction == 'mean':
26 | loss1 = loss1.mean()
27 | elif self.reduction == 'sum':
28 | loss1= loss1.sum()
29 | return loss1 + loss2
30 |
31 | class DocunetLoss(nn.Module):
32 | def __init__(self, lamda=0.1, reduction='mean'):
33 | super(DocunetLoss, self).__init__()
34 | self.lamda = lamda
35 | self.reduction = reduction
36 |
37 | def forward(self, output, target):
38 | x = target[:, 0, :, :]
39 | y = target[:, 1, :, :]
40 | back_sign_x, back_sign_y = (x == -1).int(), (y == -1).int()
41 | # assert back_sign_x == back_sign_y
42 |
43 | back_sign = ((back_sign_x + back_sign_y) == 2).float()
44 | fore_sign = 1 - back_sign
45 |
46 | loss_term_1_x = torch.sum(torch.abs(output[:, 0, :, :] - x) * fore_sign) / torch.sum(fore_sign)
47 | loss_term_1_y = torch.sum(torch.abs(output[:, 1, :, :] - y) * fore_sign) / torch.sum(fore_sign)
48 | loss_term_1 = loss_term_1_x + loss_term_1_y
49 |
50 | loss_term_2_x = torch.abs(torch.sum((output[:, 0, :, :] - x) * fore_sign)) / torch.sum(fore_sign)
51 | loss_term_2_y = torch.abs(torch.sum((output[:, 1, :, :] - y) * fore_sign)) / torch.sum(fore_sign)
52 | loss_term_2 = loss_term_2_x + loss_term_2_y
53 |
54 | zeros_x = torch.zeros(x.size()).cuda() if torch.cuda.is_available() else torch.zeros(x.size())
55 | zeros_y = torch.zeros(y.size()).cuda() if torch.cuda.is_available() else torch.zeros(y.size())
56 |
57 | loss_term_3_x = torch.max(zeros_x, output[:, 0, :, :])
58 | loss_term_3_y = torch.max(zeros_y, output[:, 1, :, :])
59 | loss_term_3 = torch.sum((loss_term_3_x + loss_term_3_y) * back_sign) / torch.sum(back_sign)
60 |
61 | loss = loss_term_1 - self.lamda * loss_term_2 + loss_term_3
62 |
63 | return loss
64 |
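# Usage sketch for DocunetLoss (a sketch, not a reference implementation). As written above, the
# loss combines an L1 term over foreground pixels, minus lamda times the absolute mean signed
# error, plus a hinge penalty on positive predictions at background pixels (marked -1 in the
# target). Note the zero tensors are created on CUDA whenever it is available, so inputs are
# expected on the same device:
#   criterion = DocunetLoss(lamda=0.1)
#   loss = criterion(output, target)   # output, target: (N, 2, H, W) vector fields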
65 | class MS_SSIM_Loss(MS_SSIM):
66 | def __init__(self, window_size=11, size_average=True, channel=3):
67 | super(MS_SSIM_Loss, self).__init__(window_size=window_size, size_average=size_average, channel=channel)
68 |
69 | def forward(self, img1, img2):
70 | ms_ssim = super(MS_SSIM_Loss, self).forward(img1, img2)
71 | return 100*( 1 - ms_ssim)
72 |
73 | class SSIM_Loss(SSIM):
74 | def __init__(self, window_size=11, size_average=True, val_range=None):
75 | super(SSIM_Loss, self).__init__(window_size=window_size, size_average=size_average, val_range=val_range)
76 |
77 | def forward(self, img1, img2):
78 | ssim = super(SSIM_Loss, self).forward(img1, img2)
79 | return 100*( 1 - ssim)
80 |
81 | class MS_SSIM_Loss_v2(MS_SSIM_v2):
82 | def __init__(self, window_size=11, size_average=True, channel=3):
83 | super(MS_SSIM_Loss_v2, self).__init__(win_size=window_size, size_average=size_average, channel=channel, data_range=255, nonnegative_ssim=True)
84 |
85 | def forward(self, img1, img2):
86 | ms_ssim = super(MS_SSIM_Loss_v2, self).forward(img1, img2)
87 | return 100*( 1 - ms_ssim)
88 |
89 | class SSIM_Loss_v2(SSIM_v2):
90 | def __init__(self, window_size=11, size_average=True):
91 | super(SSIM_Loss_v2, self).__init__(win_size=window_size, size_average=size_average, data_range=255, nonnegative_ssim=True)
92 |
93 | def forward(self, img1, img2):
94 | ssim = super(SSIM_Loss_v2, self).forward(img1, img2)
95 | return 100*( 1 - ssim)
96 |
--------------------------------------------------------------------------------
/util/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 | ## Created by: Hang Zhang
3 | ## ECE Department, Rutgers University
4 | ## Email: zhang.hang@rutgers.edu
5 | ## Copyright (c) 2017
6 | ##
7 | ## This source code is licensed under the MIT-style license found in the
8 | ## LICENSE file in the root directory of this source tree
9 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
10 |
11 | import math
12 |
13 | class LR_Scheduler(object):
14 | """Learning Rate Scheduler
15 |
16 | Step mode: ``lr = baselr * 0.1 ^ {floor(epoch-1 / lr_step)}``
17 |
18 | Cosine mode: ``lr = baselr * 0.5 * (1 + cos(iter/maxiter))``
19 |
20 | Poly mode: ``lr = baselr * (1 - iter/maxiter) ^ 0.9``
21 |
22 | Args:
23 | args:
24 | :attr:`args.lr_scheduler` lr scheduler mode (`cos`, `poly`),
25 | :attr:`args.lr` base learning rate, :attr:`args.epochs` number of epochs,
26 | :attr:`args.lr_step`
27 |
28 | iters_per_epoch: number of iterations per epoch
29 | """
30 | def __init__(self, mode, base_lr, num_epochs, iters_per_epoch=0,
31 | lr_step=0, warmup_epochs=0):
32 | self.mode = mode
33 | print('Using {} LR Scheduler!'.format(self.mode))
34 | self.lr = base_lr
35 | if mode == 'step':
36 | assert lr_step
37 | self.lr_step = lr_step
38 | self.iters_per_epoch = iters_per_epoch
39 | self.N = num_epochs * iters_per_epoch
40 | self.epoch = -1
41 | self.warmup_iters = warmup_epochs * iters_per_epoch
42 |
43 | def __call__(self, optimizer, i, epoch, best_pred):
44 | T = epoch * self.iters_per_epoch + i
45 | if self.mode == 'cos':
46 | lr = 0.5 * self.lr * (1 + math.cos(1.0 * T / self.N * math.pi))
47 | elif self.mode == 'poly':
48 | lr = self.lr * pow((1 - 1.0 * T / self.N), 0.9)
49 | elif self.mode == 'step':
50 | lr = self.lr * (0.1 ** (epoch // self.lr_step))
51 | else:
52 | raise NotImplementedError
53 | # warm up lr schedule
54 | if self.warmup_iters > 0 and T < self.warmup_iters:
55 | lr = lr * 1.0 * T / self.warmup_iters
56 | if epoch > self.epoch:
57 | print('\n=>Epoch %i, learning rate = %.4f, \
58 | previous best = %.4f' % (epoch, lr, best_pred))
59 | self.epoch = epoch
60 | assert lr >= 0
61 | self._adjust_learning_rate(optimizer, lr)
62 |
63 | def _adjust_learning_rate(self, optimizer, lr):
64 | if len(optimizer.param_groups) == 1:
65 | optimizer.param_groups[0]['lr'] = lr
66 | else:
67 | # enlarge the lr at the head
68 | optimizer.param_groups[0]['lr'] = lr
69 | for i in range(1, len(optimizer.param_groups)):
70 | optimizer.param_groups[i]['lr'] = lr * 10
71 |
--------------------------------------------------------------------------------
/util/ms_ssim.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 |
5 | def _fspecial_gauss_1d(size, sigma):
6 | r"""Create 1-D gauss kernel
7 | Args:
8 | size (int): the size of gauss kernel
9 | sigma (float): sigma of normal distribution
10 |
11 | Returns:
12 | torch.Tensor: 1D kernel
13 | """
14 | coords = torch.arange(size).to(dtype=torch.float)
15 | coords -= size // 2
16 |
17 | g = torch.exp(-(coords ** 2) / (2 * sigma ** 2))
18 | g /= g.sum()
19 |
20 | return g.unsqueeze(0).unsqueeze(0)
21 |
22 |
23 | def gaussian_filter(input, win):
24 | r""" Blur input with 1-D kernel
25 | Args:
26 | input (torch.Tensor): a batch of tensors to be blurred
27 | win (torch.Tensor): 1-D gauss kernel
28 |
29 | Returns:
30 | torch.Tensor: blured tensors
31 | """
32 |
33 | N, C, H, W = input.shape
34 | out = F.conv2d(input, win, stride=1, padding=0, groups=C)
35 | out = F.conv2d(out, win.transpose(2, 3), stride=1, padding=0, groups=C)
36 | return out
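
# Note: the blur is applied separably -- one 1-D convolution along the width with `win` and one
# along the height with its transpose -- which is equivalent to a full KxK Gaussian at a fraction
# of the cost. `win` is expected as a (C, 1, 1, K) kernel (ssim() below repeats the 1-D kernel per
# channel), and no padding is used, so the output shrinks by K - 1 in each spatial dimension.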
37 |
38 |
39 | def _ssim(X, Y, win, data_range=255, size_average=True, full=False, K=(0.01, 0.03), nonnegative_ssim=False):
40 | r""" Calculate ssim index for X and Y
41 | Args:
42 | X (torch.Tensor): images
43 | Y (torch.Tensor): images
44 | win (torch.Tensor): 1-D gauss kernel
45 | data_range (float or int, optional): value range of input images. (usually 1.0 or 255)
46 | size_average (bool, optional): if size_average=True, ssim of all images will be averaged as a scalar
47 | full (bool, optional): return cs (contrast sensitivity) or not
48 | nonnegative_ssim (bool, optional): force the ssim response to be nonnegative to avoid negative results.
49 |
50 | Returns:
51 | torch.Tensor: ssim results
52 | """
53 | K1, K2 = K
54 | batch, channel, height, width = X.shape
55 | compensation = 1.0
56 |
57 | C1 = (K1 * data_range) ** 2
58 | C2 = (K2 * data_range) ** 2
59 |
60 | win = win.to(X.device, dtype=X.dtype)
61 |
62 | mu1 = gaussian_filter(X, win)
63 | mu2 = gaussian_filter(Y, win)
64 |
65 | mu1_sq = mu1.pow(2)
66 | mu2_sq = mu2.pow(2)
67 | mu1_mu2 = mu1 * mu2
68 |
69 | sigma1_sq = compensation * (gaussian_filter(X * X, win) - mu1_sq)
70 | sigma2_sq = compensation * (gaussian_filter(Y * Y, win) - mu2_sq)
71 | sigma12 = compensation * (gaussian_filter(X * Y, win) - mu1_mu2)
72 |
73 | cs_map = (2 * sigma12 + C2) / (sigma1_sq + sigma2_sq + C2) # set alpha=beta=gamma=1
74 | if nonnegative_ssim:
75 | cs_map = F.relu(cs_map, inplace=True)
76 | ssim_map = ((2 * mu1_mu2 + C1) / (mu1_sq + mu2_sq + C1)) * cs_map
77 |
78 | if size_average:
79 | ssim_val = ssim_map.mean()
80 | cs = cs_map.mean()
81 | else:
82 | ssim_val = ssim_map.mean(-1).mean(-1).mean(-1) # reduce along CHW
83 | cs = cs_map.mean(-1).mean(-1).mean(-1)
84 |
85 | if full:
86 | return ssim_val, cs
87 | else:
88 | return ssim_val
89 |
90 |
91 | def ssim(X, Y, win_size=11, win_sigma=1.5, win=None, data_range=255, size_average=True, full=False, K=(0.01, 0.03),
92 | nonnegative_ssim=False):
93 | r""" interface of ssim
94 | Args:
95 | X (torch.Tensor): a batch of images, (N,C,H,W)
96 | Y (torch.Tensor): a batch of images, (N,C,H,W)
97 | win_size: (int, optional): the size of gauss kernel
98 | win_sigma: (float, optional): sigma of normal distribution
99 | win (torch.Tensor, optional): 1-D gauss kernel. if None, a new kernel will be created according to win_size and win_sigma
100 | data_range (float or int, optional): value range of input images. (usually 1.0 or 255)
101 | size_average (bool, optional): if size_average=True, ssim of all images will be averaged as a scalar
102 | full (bool, optional): also return the contrast sensitivity (cs)
103 | K (list or tuple, optional): scalar constants (K1, K2). Try a larger K2 constant (e.g. 0.4) if you get negative or NaN results.
104 | nonnegative_ssim (bool, optional): force the ssim response to be nonnegative to avoid negative results.
105 |
106 | Returns:
107 | torch.Tensor: ssim results
108 | """
109 |
110 | if len(X.shape) != 4:
111 | raise ValueError('Input images must be 4-d tensors.')
112 |
113 | if not X.type() == Y.type():
114 | raise ValueError('Input images must have the same dtype.')
115 |
116 | if not X.shape == Y.shape:
117 | raise ValueError('Input images must have the same dimensions.')
118 |
119 | if not (win_size % 2 == 1):
120 | raise ValueError('Window size must be odd.')
121 |
122 | win_sigma = win_sigma
123 | if win is None:
124 | win = _fspecial_gauss_1d(win_size, win_sigma)
125 | win = win.repeat(X.shape[1], 1, 1, 1)
126 | else:
127 | win_size = win.shape[-1]
128 |
129 | ssim_val, cs = _ssim(X, Y,
130 | win=win,
131 | data_range=data_range,
132 | size_average=False,
133 | full=True, K=K, nonnegative_ssim=nonnegative_ssim)
134 | if size_average:
135 | ssim_val = ssim_val.mean()
136 | cs = cs.mean()
137 |
138 | if full:
139 | return ssim_val, cs
140 | else:
141 | return ssim_val
142 |
143 |
144 | def ms_ssim(X, Y, win_size=11, win_sigma=1.5, win=None, data_range=255, size_average=True, full=False, weights=None,
145 | K=(0.01, 0.03), nonnegative_ssim=False):
146 | r""" interface of ms-ssim
147 | Args:
148 | X (torch.Tensor): a batch of images, (N,C,H,W)
149 | Y (torch.Tensor): a batch of images, (N,C,H,W)
150 | win_size: (int, optional): the size of gauss kernel
151 | win_sigma: (float, optional): sigma of normal distribution
152 | win (torch.Tensor, optional): 1-D gauss kernel. if None, a new kernel will be created according to win_size and win_sigma
153 | data_range (float or int, optional): value range of input images. (usually 1.0 or 255)
154 | size_average (bool, optional): if size_average=True, ssim of all images will be averaged as a scalar
155 | full (bool, optional): also return the contrast sensitivity (cs)
156 | weights (list, optional): weights for different levels
157 | K (list or tuple, optional): scalar constants (K1, K2). Try a larger K2 constant (e.g. 0.4) if you get negative or NaN results.
158 | nonnegative_ssim (bool, optional): force the ssim response to be nonnegative to avoid NaN results.
159 | Returns:
160 | torch.Tensor: ms-ssim results
161 | """
162 | if len(X.shape) != 4:
163 | raise ValueError('Input images must be 4-d tensors.')
164 |
165 | if not X.type() == Y.type():
166 | raise ValueError('Input images must have the same dtype.')
167 |
168 | if not X.shape == Y.shape:
169 | raise ValueError('Input images must have the same dimensions.')
170 |
171 | if not (win_size % 2 == 1):
172 | raise ValueError('Window size must be odd.')
173 |
174 | if weights is None:
175 | weights = torch.FloatTensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]).to(X.device, dtype=X.dtype)
176 |
177 | win_sigma = win_sigma
178 | if win is None:
179 | win = _fspecial_gauss_1d(win_size, win_sigma)
180 | win = win.repeat(X.shape[1], 1, 1, 1)
181 | else:
182 | win_size = win.shape[-1]
183 |
184 | levels = weights.shape[0]
185 | mcs = []
186 | for _ in range(levels):
187 | ssim_val, cs = _ssim(X, Y,
188 | win=win,
189 | data_range=data_range,
190 | size_average=False,
191 | full=True, K=K, nonnegative_ssim=nonnegative_ssim)
192 | mcs.append(cs)
193 |
194 | padding = (X.shape[2] % 2, X.shape[3] % 2)
195 | X = F.avg_pool2d(X, kernel_size=2, padding=padding)
196 | Y = F.avg_pool2d(Y, kernel_size=2, padding=padding)
197 |
198 | mcs = torch.stack(mcs, dim=0) # mcs, (level, batch)
199 | # weights, (level)
200 | msssim_val = torch.prod(mcs[:-1] ** weights[:-1].unsqueeze(1), dim=0) \
201 | * (ssim_val ** weights[-1])  # cs at scales 1..M-1, ssim only at the coarsest scale; (batch, )
202 |
203 | if size_average:
204 | msssim_val = msssim_val.mean()
205 | return msssim_val
206 |
207 |
208 | class SSIM_v2(torch.nn.Module):
209 | def __init__(self, win_size=11, win_sigma=1.5, data_range=1, size_average=True, channel=3, K=(0.01, 0.03),
210 | nonnegative_ssim=False):
211 | r""" class for ssim
212 | Args:
213 | win_size: (int, optional): the size of gauss kernel
214 | win_sigma: (float, optional): sigma of normal distribution
215 | data_range (float or int, optional): value range of input images. (usually 1.0 or 255)
216 | size_average (bool, optional): if size_average=True, ssim of all images will be averaged as a scalar
217 | channel (int, optional): input channels (default: 3)
218 | K (list or tuple, optional): scalar constants (K1, K2). Try a larger K2 constant (e.g. 0.4) if you get negative or NaN results.
219 | nonnegative_ssim (bool, optional): force the ssim response to be nonnegative to avoid negative results.
220 | """
221 |
222 | super(SSIM_v2, self).__init__()
223 | self.win = _fspecial_gauss_1d(
224 | win_size, win_sigma).repeat(channel, 1, 1, 1)
225 | self.size_average = size_average
226 | self.data_range = data_range
227 | self.K = K
228 | self.nonnegative_ssim = nonnegative_ssim
229 |
230 | def forward(self, X, Y):
231 | return ssim(X, Y, win=self.win, data_range=self.data_range, size_average=self.size_average, K=self.K,
232 | nonnegative_ssim=self.nonnegative_ssim)
233 |
234 |
235 | class MS_SSIM_v2(torch.nn.Module):
236 | def __init__(self, win_size=11, win_sigma=1.5, data_range=1, size_average=True, channel=3, weights=None,
237 | K=(0.01, 0.03), nonnegative_ssim=False):
238 | r""" class for ms-ssim
239 | Args:
240 | win_size: (int, optional): the size of gauss kernel
241 | win_sigma: (float, optional): sigma of normal distribution
242 | data_range (float or int, optional): value range of input images. (usually 1.0 or 255)
243 | size_average (bool, optional): if size_average=True, ssim of all images will be averaged as a scalar
244 | channel (int, optional): input channels (default: 3)
245 | weights (list, optional): weights for different levels
246 | K (list or tuple, optional): scalar constants (K1, K2). Try a larger K2 constant (e.g. 0.4) if you get negative or NaN results.
247 | nonnegative_ssim (bool, optional): force the ssim response to be nonnegative to avoid NaN results.
248 | """
249 |
250 | super(MS_SSIM_v2, self).__init__()
251 | self.win = _fspecial_gauss_1d(
252 | win_size, win_sigma).repeat(channel, 1, 1, 1)
253 | self.size_average = size_average
254 | self.data_range = data_range
255 | self.weights = weights
256 | self.K = K
257 | self.nonnegative_ssim = nonnegative_ssim
258 |
259 | def forward(self, X, Y):
260 | return ms_ssim(X, Y, win=self.win, size_average=self.size_average, data_range=self.data_range,
261 | weights=self.weights, K=self.K, nonnegative_ssim=self.nonnegative_ssim)
--------------------------------------------------------------------------------
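A minimal sketch of how SSIM_v2 / MS_SSIM_v2 above can be used as similarity losses (the repository's actual wiring in util/losses.py may differ; the tensor shapes and the [0, 1] value range are assumptions).

import torch

unwarped = torch.rand(4, 3, 256, 256, requires_grad=True)  # e.g. images unwarped with the predicted field
target = torch.rand(4, 3, 256, 256)                         # flat ground-truth images

ssim_module = SSIM_v2(data_range=1, channel=3)
ms_ssim_module = MS_SSIM_v2(data_range=1, channel=3)

# SSIM / MS-SSIM are similarities (1 means identical), so a common loss is 1 - similarity.
loss = 1 - 0.5 * (ssim_module(unwarped, target) + ms_ssim_module(unwarped, target))
loss.backward()
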
/util/ssim.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from math import exp
4 | import numpy as np
5 |
6 | def gaussian(window_size, sigma):
7 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
8 | return gauss/gauss.sum()
9 |
10 | def create_window(window_size, channel=1):
11 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
12 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
13 | window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
14 | return window
15 |
16 | def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None, normalize=False):
17 | # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh).
18 | if val_range is None:
19 | if torch.max(img1) > 128:
20 | max_val = 255
21 | else:
22 | max_val = 1
23 |
24 | if torch.min(img1) < -0.5:
25 | min_val = -1
26 | else:
27 | min_val = 0
28 | L = max_val - min_val
29 | else:
30 | L = val_range
31 |
32 | padd = 0
33 | (_, channel, height, width) = img1.size()
34 | if window is None:
35 | real_size = min(window_size, height, width)
36 | window = create_window(real_size, channel=channel).to(img1.device)
37 |
38 | mu1 = F.conv2d(img1, window, padding=padd, groups=channel)
39 | mu2 = F.conv2d(img2, window, padding=padd, groups=channel)
40 |
41 | mu1_sq = mu1.pow(2)
42 | mu2_sq = mu2.pow(2)
43 | mu1_mu2 = mu1 * mu2
44 |
45 | sigma1_sq = F.conv2d(img1 * img1, window, padding=padd, groups=channel) - mu1_sq
46 | sigma2_sq = F.conv2d(img2 * img2, window, padding=padd, groups=channel) - mu2_sq
47 | sigma12 = F.conv2d(img1 * img2, window, padding=padd, groups=channel) - mu1_mu2
48 |
49 | C1 = (0.01 * L) ** 2
50 | C2 = (0.03 * L) ** 2
51 |
52 | v1 = 2.0 * sigma12 + C2
53 | v2 = sigma1_sq + sigma2_sq + C2
54 | cs = torch.mean(v1 / v2) # contrast sensitivity
55 |
56 | ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)
57 |
58 | if normalize:
59 | ssim_map = F.relu(ssim_map, inplace=True)
60 |
61 | if size_average:
62 | ret = ssim_map.mean()
63 | else:
64 | ret = ssim_map.mean(1).mean(1).mean(1)
65 |
66 | if full:
67 | return ret, cs
68 | return ret
69 |
70 |
71 | def msssim(img1, img2, window_size=11, size_average=True, val_range=None, normalize=False):
72 | device = img1.device
73 | weights = torch.FloatTensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]).to(device)
74 | levels = weights.size()[0]
75 | mssim = []
76 | mcs = []
77 | for _ in range(levels):
78 | sim, cs = ssim(img1, img2, window_size=window_size, size_average=size_average, full=True, val_range=val_range, normalize=False)
79 | mssim.append(sim)
80 | mcs.append(cs)
81 |
82 | img1 = F.avg_pool2d(img1, (2, 2))
83 | img2 = F.avg_pool2d(img2, (2, 2))
84 |
85 | mssim = torch.stack(mssim)
86 | mcs = torch.stack(mcs)
87 |
88 | # Normalize (to avoid NaNs during training unstable models, not compliant with original definition)
89 | if normalize:
90 | mssim = (mssim + 1) / 2
91 | mcs = (mcs + 1) / 2
92 |
93 | pow1 = mcs ** weights
94 | pow2 = mssim ** weights
95 | # From Matlab implementation https://ece.uwaterloo.ca/~z70wang/research/iwssim/
96 | output = torch.prod(pow1[:-1]) * pow2[-1]  # cs at the finer scales, ssim only at the coarsest scale
97 | return output
98 |
99 |
100 | # Classes to re-use window
101 | class SSIM(torch.nn.Module):
102 | def __init__(self, window_size=11, size_average=True, val_range=None):
103 | super(SSIM, self).__init__()
104 | self.window_size = window_size
105 | self.size_average = size_average
106 | self.val_range = val_range
107 |
108 | # Assume 1 channel for SSIM
109 | self.channel = 1
110 | self.window = create_window(window_size)
111 |
112 | def forward(self, img1, img2):
113 | (_, channel, _, _) = img1.size()
114 |
115 | if channel == self.channel and self.window.dtype == img1.dtype:
116 | window = self.window.to(img1.device)
117 | else:
118 | window = create_window(self.window_size, channel).to(img1.device).type(img1.dtype)
119 | self.window = window
120 | self.channel = channel
121 |
122 | return ssim(img1, img2, window=window, window_size=self.window_size, size_average=self.size_average)
123 |
124 | class MS_SSIM(torch.nn.Module):
125 | def __init__(self, window_size=11, size_average=True, channel=3):
126 | super(MS_SSIM, self).__init__()
127 | self.window_size = window_size
128 | self.size_average = size_average
129 | self.channel = channel
130 |
131 | def forward(self, img1, img2):
132 | return msssim(img1, img2, window_size=self.window_size, size_average=self.size_average, normalize=True)
--------------------------------------------------------------------------------
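A minimal sketch for this second implementation, here on single-channel inputs (the values are arbitrary). Note that MS_SSIM always calls msssim with normalize=True, which rescales the intermediate terms and, as the comment in msssim states, is not compliant with the original definition.

import torch

pred = torch.rand(2, 1, 128, 128, requires_grad=True)
gt = torch.rand(2, 1, 128, 128)

ssim_metric = SSIM(window_size=11)        # caches its window across calls
ms_ssim_metric = MS_SSIM(window_size=11)

print(ssim_metric(pred, gt).item())       # similarity in (-1, 1], higher is better
print(ms_ssim_metric(pred, gt).item())
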
/util/summary.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from torchvision.utils import make_grid
4 | from torch.utils.tensorboard import SummaryWriter
5 | import glob
6 |
7 | from util.general_functions import tensor2im, get_flat_images
8 |
9 |
10 | from constants import *
11 |
12 | class TensorboardSummary(object):
13 |
14 | def __init__(self, args):
15 | self.args = args
16 | self.experiment_dir = self.generate_directory(args)
17 | self.writer = SummaryWriter(log_dir=os.path.join(self.experiment_dir))
18 |
19 | self.train_step = 0
20 | self.test_step = 0
21 | self.visualization_step = 0
22 |
23 | def generate_directory(self, args):
24 | checkname = 'debug' if args.debug else ''
25 | checkname += args.model
26 | checkname += '_sc' if args.separable_conv else ''
27 | checkname += '-refined' if args.refine_network else ''
28 |
29 | if 'deeplab' in args.model:
30 | checkname += '-os_' + str(args.output_stride)
31 | checkname += '-ls_1' if args.learned_upsampling else ''
32 | checkname += '-pt_1' if args.pretrained else ''
33 | checkname += '-aspp_0' if not args.use_aspp else ''
34 |
35 | if 'unet' in args.model:
36 | checkname += '-downs_' + str(args.num_downs) + '-ngf_' + str(args.ngf) + '-type_' + str(args.down_type)
37 |
38 | checkname += '-loss_' + args.loss_type
39 | checkname += '-sloss_' if args.second_loss else ''
40 |
41 | if args.clip > 0:
42 | checkname += '-clipping_' + str(args.clip)
43 |
44 | if args.resize:
45 | checkname += '-' + ','.join([str(x) for x in list(args.resize)])
46 | checkname += '-epochs_' + str(args.epochs)
47 | checkname += '-trainval' if args.trainval else ''
48 |
49 | current_dir = os.path.dirname(__file__)
50 | directory = os.path.join(current_dir, args.results_root, args.results_dir, args.dataset_dir, args.dataset, args.model, checkname)
51 |
52 | runs = sorted(glob.glob(os.path.join(directory, 'experiment_*')), key=lambda run: int(run.split('_')[-1]))
53 | run_id = int(runs[-1].split('_')[-1]) + 1 if runs else 0
54 | experiment_dir = os.path.join(directory, 'experiment_{}'.format(str(run_id)))
55 |
56 | if not os.path.exists(experiment_dir):
57 | os.makedirs(experiment_dir)
58 |
59 | return experiment_dir
60 |
61 | def add_scalar(self, tag, value, step):
62 | self.writer.add_scalar(tag, value, step)
63 |
64 | def visualize_image(self, images, targets, outputs, split="train"):
65 | step = self.get_step(split)
66 |
67 | outputs, targets = get_flat_images(self.args.dataset, images, outputs, targets)
68 |
69 | images = [tensor2im(image) for image in images]
70 | outputs = [tensor2im(output)[:, : int(self.args.resize[0] / 2), : int(self.args.resize[1] / 2)] for output in outputs]
71 | targets = [tensor2im(target)[:, : int(self.args.resize[0] / 2), : int(self.args.resize[1] / 2)] for target in targets]
72 |
73 | grid_image = make_grid(images)
74 | self.writer.add_image(split + '/ZZ Image', grid_image, step)
75 |
76 | grid_image = make_grid(outputs)
77 | self.writer.add_image(split + '/Predicted label', grid_image, step)
78 |
79 | grid_image = make_grid(targets)
80 | self.writer.add_image(split + '/Groundtruth label', grid_image, step)
81 |
82 | return images, outputs, targets
83 |
84 | def save_network(self, model):
85 | path = self.experiment_dir[self.experiment_dir.find(self.args.results_dir):].replace(self.args.results_dir, self.args.save_dir)
86 | if not os.path.isdir(path):
87 | os.makedirs(path)
88 |
89 | torch.save(model.state_dict(), path + '/' + 'network.pth')
90 |
91 | def load_network(self, model):
92 | path = self.args.pretrained_models_dir
93 | state_dict = torch.load(path) if self.args.cuda else torch.load(path, map_location=torch.device('cpu'))
94 | model.load_state_dict(state_dict)
95 |
96 | return model
97 |
98 | def get_step(self, split):
99 | if split == TRAIN:
100 | self.train_step += 1
101 | return self.train_step
102 | elif split == TEST:
103 | self.test_step += 1
104 | return self.test_step
105 | elif split == VISUALIZATION:
106 | self.visualization_step += 1
107 | return self.visualization_step
--------------------------------------------------------------------------------
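A minimal sketch of how TensorboardSummary above can be constructed and used. The args fields mirror the attributes accessed in generate_directory; the values below are placeholders, not the defaults from parser_options.py.

from argparse import Namespace

# Placeholder arguments; the real ones come from parser_options.py.
args = Namespace(debug=True, model='pspnet', separable_conv=False, refine_network=False,
                 loss_type='l1', second_loss=False, clip=0, resize=(256, 256), epochs=1,
                 trainval=False, results_root='results', results_dir='runs',
                 dataset_dir='dataset', dataset='docunet')

summary = TensorboardSummary(args)               # creates an experiment_<id> run directory
summary.add_scalar('train/loss_epoch', 0.42, 1)  # scalar logged to TensorBoard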