├── .gitignore
├── CoreML_convert.py
├── README.md
├── model.py
├── pics
│   ├── VID_edited.gif
│   ├── VID_orig.gif
│   ├── ex_2_orig.png
│   ├── ex_2_transformed.png
│   ├── ex_3_edited_mask.png
│   ├── ex_3_orig_mask.png
│   ├── example_1.png
│   ├── girl_ex_blured.png
│   ├── girl_ex_orig.png
│   ├── mobilenetV2_loss.png
│   ├── mobilenetV2_metric.png
│   ├── resnet101_loss.png
│   └── resnet101_metric.png
├── predict.py
├── requirements.sh
├── train.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | /test/
2 | /models/
3 | .vscode
4 |
5 |
--------------------------------------------------------------------------------
/CoreML_convert.py:
--------------------------------------------------------------------------------
1 | import re
2 | import argparse
3 |
4 | import onnx
5 | import torch
6 | from onnx import onnx_pb
7 | from onnx_coreml import convert
8 |
9 | from model import *
10 |
11 | #https://github.com/akirasosa/mobile-semantic-segmentation/blob/master/coreml_converter.py
12 |
13 | # python3 CoreML_convert.py --tmp_onnx ./models/tmp.onnx --weights_path ./models/mobilenetV2_model/mobilenetV2_model_checkpoint_metric.pth
14 |
15 | def init_unet(state_dict):
16 | model = UnetMobilenetV2(pretrained=False, num_classes=1, num_filters=32, Dropout=.2)
17 | model.load_state_dict(state_dict["state_dict"])
18 | return model
19 |
20 | parser = argparse.ArgumentParser(description='Convert a trained segmentation model to CoreML')
21 | parser.add_argument('--tmp_onnx', type=str, required=True)
22 | parser.add_argument('--weights_path', type=str, required=True)
23 | parser.add_argument('--img_H', type=int, default= 320)
24 | parser.add_argument('--img_W', type=int, default= 256)
25 | args = parser.parse_args()
26 | globals().update(vars(args))
27 |
28 | coreml_path = re.sub(r'\.pth$', '.mlmodel', weights_path)
29 |
30 | #convert and save ONNX
31 | model = init_unet(torch.load(weights_path, map_location=lambda storage, loc: storage))
32 | torch.onnx.export(model,
33 | torch.randn(1, 3, img_H, img_W),
34 | tmp_onnx)
35 |
36 | # Convert ONNX to CoreML model
37 | model_file = open(tmp_onnx, 'rb')
38 | model_proto = onnx_pb.ModelProto()
39 | model_proto.ParseFromString(model_file.read())
40 | # 595 is the identifier of output.
41 | coreml_model = convert(model_proto,
42 | image_input_names=['0'],
43 | image_output_names=['595'])
44 | coreml_model.save(coreml_path)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # PicsArtHack-binary-segmentation
2 |
3 | The goal of the hackathon was to build an image-processing algorithm that could be useful for [PicsArt](https://picsart.com/?hl=en) applications.
4 | Here I publish the results of the first stage: segmenting people in selfies.
5 | PicsArt gave us a labeled dataset. The [Dice](https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient) coefficient was used as the evaluation metric.
6 | I noticed that a lot of images seem to have been labeled by another segmentation model, judging by the artifacts around the mask borders. The test set also contains copies of training-set images. So, after training, I did not expect good results on images "from the wild".
7 |
8 | ### 1. Loss
9 | For this problem I used the fairly common BCE-Dice loss. The idea is simple: take the logits output of the model and feed it both to a binary cross-entropy loss and to the natural logarithm of the Dice score (computed after applying the sigmoid). Then we only need to combine these two terms with weights:
10 | ```
11 | dice = (2. * intersection + eps) / (union + eps)
12 | loss = w * BCELoss - (1 - w) * log(dice)
13 | ```
14 | Also, in this case, we don't need to tune thresholds on the final pseudo-probabilities (after the sigmoid).
15 | Finally, we can assign extra weight to pixels near the mask border (I did it inside BCELoss), to penalize the model more for mistakes around the mask borders. For this purpose we can use OpenCV's erosion operation:
16 | ```
17 | def get_mask_weight(mask):
18 | mask_ = cv2.erode(mask, kernel=np.ones((8,8),np.uint8), iterations=1)
19 | mask_ = mask-mask_
20 | return mask_ + 1
21 | ```
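
Putting the two pieces together, the combined loss in PyTorch looks roughly like the sketch below. This is a simplified version of `BCESoftJaccardDice` from `utils.py`; the optional `weight` argument is the border weight map produced by `get_mask_weight`:
```
import torch

def bce_dice_loss(logits, targets, w=.5, weight=None, eps=1e-7, smooth=1.):
    # BCE is computed on the raw logits, Dice on the sigmoid probabilities
    bce = torch.nn.functional.binary_cross_entropy_with_logits(logits, targets, weight=weight)
    probs = torch.sigmoid(logits)
    intersection = (probs * targets).sum()
    union = probs.sum() + targets.sum()
    dice = (2. * intersection + smooth) / (union + smooth + eps)
    return w * bce - (1 - w) * torch.log(dice)
```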
22 | The picture below shows what the input data looks like:
23 |
24 | ### 2. Training
25 | I used a modification of **U-Net** (a well-proven architecture for binary semantic segmentation) with two encoders pretrained on ImageNet: resnet101 and [mobilenetV2](https://github.com/tonylins/pytorch-mobilenet-v2). One of the goals was to compare the performance of a "heavy" and a "light" encoder.
26 | You can check all the training params inside `train.py`.
27 |
28 | ```
29 | python3 train.py --train_path ./data/train_data --workdir ./data/ --model_type mobilenetV2
30 | ```
31 |
32 | Data augmentation was provided via the brilliant [albumentations](https://github.com/albu/albumentations) library.
33 | Inside `utils.py` you can find learning-rate scheduling, encoder weight freezing and some other useful tricks that help train networks more efficiently. Also, by passing the `model_type` parameter you can choose one of the predefined models based on resnet18, resnet34, resnet50, resnet101 or mobilenetV2.
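
Under the hood these options map to the two U-Net classes in `model.py`. A minimal sketch of building them directly (with `pretrained=False`, so nothing needs to be downloaded):
```
import torch
from model import UnetResNet, UnetMobilenetV2

heavy = UnetResNet(model="resnet101", pretrained=False, num_classes=1,
                   num_filters=32, Dropout=.2).eval()
light = UnetMobilenetV2(pretrained=False, num_classes=1,
                        num_filters=32, Dropout=.2).eval()

x = torch.randn(1, 3, 320, 256)   # the input size used throughout this repo
with torch.no_grad():
    print(heavy(x).shape)   # torch.Size([1, 1, 320, 256])
    print(light(x).shape)   # torch.Size([1, 1, 160, 128]) -- upsampled x2 outside the model
```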
34 |
35 | So, in the end I got two trained models with very close Dice values on the validation set. Here are a few numbers:
36 |
37 | Encoder: | ResNet101 | MobileNetV2
38 | :-------------------------:|:-------------------------:|:-------------------------:
39 | epochs (best of 200) | 177 | 173
40 | Dice | 0.987 (0.988) | 0.986 (0.988)
41 | loss | 0.029 (0.022) | 0.030 (0.024)
42 | No. of parameters | 120 131 745 | 4 682 912
43 |
44 | ResNet101 evaluation process:
45 |
46 | MobileNetV2 evaluation process:
47 |
48 |
49 | Although mobilenetV2 has ~26x fewer parameters, we were still able to get a model of very similar quality. Keep in mind, though, that this holds **for this particular problem on this particular dataset**, so I don't think it extends to any other problem.
50 |
51 | ### 3. Tests
52 | Inference time comparison on my workstation, with 320x256 input images from the test set:
53 |
54 | Device | ResNet101 | MobileNetV2
55 | :-------------------------:|:-------------------------:|:-------------------------:
56 | AMD Threadripper 1900X CPU (4 threads) | 2.08 s ± 7.58 ms | 345 ms ± 3.21 ms
57 | GTX 1080Ti GPU | 31.6 ms ± 897 µs | 22 ms ± 622 µs
58 |
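The exact benchmarking code isn't included in the repo, but a rough sketch of a comparable CPU measurement looks like this (untrained weights are fine here, since timing doesn't depend on them):
```
import time
import torch
from model import UnetMobilenetV2

torch.set_num_threads(4)                      # the CPU numbers above were measured with 4 threads
model = UnetMobilenetV2(pretrained=False).eval()
x = torch.randn(1, 3, 320, 256)

with torch.no_grad():
    model(x)                                  # warm-up pass
    start = time.perf_counter()
    for _ in range(10):
        model(x)
    print("%.1f ms / image" % ((time.perf_counter() - start) / 10 * 1000))
```
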
59 | Often, the output masks contain some noise near the borders (which becomes more noticeable on large images), so we can try to fix it by applying a morphological transform:
60 | ```
61 | kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5,5))
62 | y_pred[:, :, -1] = cv2.morphologyEx(y_pred[:, :, -1], cv2.MORPH_OPEN, kernel)
63 | ```
64 | Original | Transformed
65 | :-------------------------:|:-------------------------:
66 |
67 |
68 | Additionally, we can transform the segmented images. For instance, let's apply a Gaussian blur to the background:
69 | ```
70 | blurred = cv2.GaussianBlur(test_dataset[n],(21,21),0)
71 | dst = cv2.bitwise_and(blurred, blurred, mask=~out[0][:, :, -1])
72 | dst = cv2.add(cv2.bitwise_and(test_dataset[n], test_dataset[n], mask=out[0][:, :, -1]), dst)
73 | ```
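
The snippet above relies on `test_dataset` and the `out` list returned by `predict_mask`. A self-contained sketch of the same idea, with hypothetical file names standing in for an image and its 0/255 person mask:
```
import cv2

img = cv2.imread("example.jpg")              # any photo (hypothetical path)
mask = cv2.imread("example_mask.png", 0)     # 0/255 person mask from the model (hypothetical path)

blurred = cv2.GaussianBlur(img, (21, 21), 0)
background = cv2.bitwise_and(blurred, blurred, mask=~mask)   # blurred pixels outside the person
person = cv2.bitwise_and(img, img, mask=mask)                # sharp pixels inside the person
result = cv2.add(person, background)
cv2.imwrite("example_blurred.png", result)
```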
74 |
75 |
76 | We can process videos too (see `predict.py`). The example below is a video I shot with a cellphone (original frame size: 800x450):
77 |
78 |
79 |
80 | These results were obtained with the mobilenetV2 model. You can play with it too; here are its [weights and CoreML models](https://drive.google.com/file/d/1XSRaOaoWKKSllIuUgkW0BVsMKieQ8mbG/view?usp=sharing).
81 |
82 | ```
83 | python3 predict.py -p ./test --model_path ./models/mobilenetV2_model --gpu -1 --frame_rate 12 --denoise_borders --biggest_side 320
84 | ```
85 | This script reads all the data inside the `-p` folder: both pictures and videos.
86 |
87 | ### 4. Porting the model to an iOS device
88 | Finally, we can convert the trained mobilenetV2 model to CoreML to run inference on iOS devices. The pipeline is simple: torch --> ONNX --> CoreML. To make this work, don't keep the encoder layers in a separate container inside the model class - use them directly in the forward pass. Also, with these particular versions of torch and onnx (see `requirements.sh`), you can't convert upsampling / interpolation layers, so apply them outside the model as a post-processing step. Hopefully this will be fixed in future releases.
89 |
90 | ```
91 | python3 CoreML_convert.py --tmp_onnx ./models/tmp.onnx --weights_path ./models/mobilenetV2_model/mobilenetV2_model_checkpoint_metric.pth
92 | ```
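
For reference, here is roughly what that post-processing step looks like at inference time, mirroring what `utils.py` does after the model call: the exported graph stops at the raw logits, and the x2 upsampling plus sigmoid are applied outside it.
```
import torch
from torch import nn
from model import UnetMobilenetV2

model = UnetMobilenetV2(pretrained=False).eval()    # load the real checkpoint in practice
x = torch.randn(1, 3, 320, 256)
with torch.no_grad():
    logits = model(x)                               # this part is what gets exported to ONNX / CoreML
    logits = nn.functional.interpolate(logits, scale_factor=2,
                                       mode='bilinear', align_corners=True)
    mask = torch.sigmoid(logits)                    # post-processing kept outside the exported model
```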
93 |
94 | ### 5. Environment
95 | For your own experiments I highly recommend using [Deepo](https://github.com/ufoym/deepo) as a fast way to deploy a universal deep-learning environment inside a Docker container. The remaining dependencies can be installed with `requirements.sh`.
96 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | from torch import nn, cat
5 | import torchvision
6 |
7 | class ConvRelu(nn.Module):
8 | def __init__(self, in_: int, out: int, activate=True):
9 | super(ConvRelu, self).__init__()
10 | self.activate = activate
11 | self.conv = nn.Conv2d(in_, out, 3, padding=1)
12 | self.activation = nn.ReLU(inplace=True)
13 |
14 | def forward(self, x):
15 | x = self.conv(x)
16 | if self.activate:
17 | x = self.activation(x)
18 | return x
19 |
20 | class ResidualBlock(nn.Module):
21 |
22 | def __init__(self, in_channels: int, num_filters: int, batch_activate=False):
23 | super(ResidualBlock, self).__init__()
24 | self.batch_activate = batch_activate
25 | self.activation = nn.ReLU(inplace=True)
26 | self.conv_block = ConvRelu(in_channels, num_filters, activate=True)
27 |         # the second conv operates on the output of the first one, so its input size is num_filters
28 |         self.conv_block_na = ConvRelu(num_filters, num_filters, activate=False)
29 |
30 | def forward(self, inp):
31 | x = self.conv_block(inp)
32 | x = self.conv_block_na(x)
33 | x = x.add(inp)
34 | if self.batch_activate:
35 | x = self.activation(x)
36 | return x
37 |
38 | class DecoderBlockResnet(nn.Module):
39 | """
40 |     Parameters for the deconvolution were chosen to avoid checkerboard artifacts, following
41 | link https://distill.pub/2016/deconv-checkerboard/
42 | """
43 |
44 | def __init__(self, in_channels, middle_channels, out_channels):
45 | super(DecoderBlockResnet, self).__init__()
46 | self.in_channels = in_channels
47 |
48 | self.block = nn.Sequential(
49 | ConvRelu(in_channels, middle_channels, activate=True),
50 | nn.ConvTranspose2d(middle_channels, out_channels, kernel_size=4, stride=2, padding=1),
51 | nn.ReLU(inplace=True)
52 | )
53 |
54 | def forward(self, x):
55 | return self.block(x)
56 |
57 | class UnetResNet(nn.Module):
58 |
59 | def __init__(self, num_classes=1, num_filters=32, pretrained=True, Dropout=.2, model="resnet50"):
60 |
61 | super().__init__()
62 | if model == "resnet18":
63 | self.encoder = torchvision.models.resnet18(pretrained=pretrained)
64 | elif model == "resnet34":
65 | self.encoder = torchvision.models.resnet34(pretrained=pretrained)
66 | elif model == "resnet50":
67 | self.encoder = torchvision.models.resnet50(pretrained=pretrained)
68 | elif model == "resnet101":
69 | self.encoder = torchvision.models.resnet101(pretrained=pretrained)
70 |
71 | if model in ["resnet18", "resnet34"]: model = "resnet18-34"
72 | else: model = "resnet50-101"
73 |
74 | self.filters_dict = {
75 | "resnet18-34": [512, 512, 256, 128, 64],
76 | "resnet50-101": [2048, 2048, 1024, 512, 256]
77 | }
78 |
79 | self.num_classes = num_classes
80 | self.Dropout = Dropout
81 | self.pool = nn.MaxPool2d(2, 2)
82 | self.relu = nn.ReLU(inplace=True)
83 | self.conv1 = nn.Sequential(self.encoder.conv1,
84 | self.encoder.bn1,
85 | self.encoder.relu,
86 | self.pool)
87 | self.conv2 = self.encoder.layer1
88 | self.conv3 = self.encoder.layer2
89 | self.conv4 = self.encoder.layer3
90 | self.conv5 = self.encoder.layer4
91 |
92 | self.center = DecoderBlockResnet(self.filters_dict[model][0], num_filters * 8 * 2,
93 | num_filters * 8)
94 | self.dec5 = DecoderBlockResnet(self.filters_dict[model][1] + num_filters * 8,
95 | num_filters * 8 * 2, num_filters * 8)
96 | self.dec4 = DecoderBlockResnet(self.filters_dict[model][2] + num_filters * 8,
97 | num_filters * 8 * 2, num_filters * 8)
98 | self.dec3 = DecoderBlockResnet(self.filters_dict[model][3] + num_filters * 8,
99 | num_filters * 4 * 2, num_filters * 2)
100 | self.dec2 = DecoderBlockResnet(self.filters_dict[model][4] + num_filters * 2,
101 | num_filters * 2 * 2, num_filters * 2 * 2)
102 |
103 | self.dec1 = DecoderBlockResnet(num_filters * 2 * 2, num_filters * 2 * 2, num_filters)
104 | self.dec0 = ConvRelu(num_filters, num_filters)
105 |
106 | self.final = nn.Conv2d(num_filters, num_classes, kernel_size=1)
107 | self.dropout_2d = nn.Dropout2d(p=self.Dropout)
108 |
109 |
110 | def forward(self, x, z=None):
111 | conv1 = self.conv1(x)
112 | conv2 = self.dropout_2d(self.conv2(conv1))
113 | conv3 = self.dropout_2d(self.conv3(conv2))
114 | conv4 = self.dropout_2d(self.conv4(conv3))
115 | conv5 = self.dropout_2d(self.conv5(conv4))
116 |
117 | center = self.center(self.pool(conv5))
118 | dec5 = self.dec5(torch.cat([center, conv5], 1))
119 | dec4 = self.dec4(torch.cat([dec5, conv4], 1))
120 | dec3 = self.dec3(torch.cat([dec4, conv3], 1))
121 | dec2 = self.dec2(torch.cat([dec3, conv2], 1))
122 | dec2 = self.dropout_2d(dec2)
123 |
124 | dec1 = self.dec1(dec2)
125 | dec0 = self.dec0(dec1)
126 |
127 | return self.final(dec0)
128 |
129 | ###########################################################################
130 | # Mobile Net
131 | ###########################################################################
132 |
133 | def conv_bn(inp, oup, stride):
134 | return nn.Sequential(
135 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
136 | nn.BatchNorm2d(oup),
137 | nn.ReLU6(inplace=True)
138 | )
139 |
140 | def conv_1x1_bn(inp, oup):
141 | return nn.Sequential(
142 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
143 | nn.BatchNorm2d(oup),
144 | nn.ReLU6(inplace=True)
145 | )
146 |
147 | class InvertedResidual(nn.Module):
148 | def __init__(self, inp, oup, stride, expand_ratio):
149 | super(InvertedResidual, self).__init__()
150 | self.stride = stride
151 | assert stride in [1, 2]
152 |
153 | hidden_dim = round(inp * expand_ratio)
154 | self.use_res_connect = self.stride == 1 and inp == oup
155 |
156 | if expand_ratio == 1:
157 | self.conv = nn.Sequential(
158 | # dw
159 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
160 | nn.BatchNorm2d(hidden_dim),
161 | nn.ReLU6(inplace=True),
162 | # pw-linear
163 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
164 | nn.BatchNorm2d(oup),
165 | )
166 | else:
167 | self.conv = nn.Sequential(
168 | # pw
169 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
170 | nn.BatchNorm2d(hidden_dim),
171 | nn.ReLU6(inplace=True),
172 | # dw
173 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
174 | nn.BatchNorm2d(hidden_dim),
175 | nn.ReLU6(inplace=True),
176 | # pw-linear
177 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
178 | nn.BatchNorm2d(oup),
179 | )
180 |
181 | def forward(self, x):
182 | if self.use_res_connect:
183 | return x + self.conv(x)
184 | else:
185 | return self.conv(x)
186 |
187 | class MobileNetV2(nn.Module):
188 |
189 | """
190 | from MobileNetV2 import MobileNetV2
191 |
192 | net = MobileNetV2(n_class=1000)
193 | state_dict = torch.load('mobilenetv2.pth.tar') # add map_location='cpu' if no gpu
194 | net.load_state_dict(state_dict)
195 | """
196 |
197 | def __init__(self, n_class=1000, input_size=224, width_mult=1.):
198 | super(MobileNetV2, self).__init__()
199 | block = InvertedResidual
200 | input_channel = 32
201 | last_channel = 1280
202 | interverted_residual_setting = [
203 | # t, c, n, s
204 | [1, 16, 1, 1],
205 | [6, 24, 2, 2],
206 | [6, 32, 3, 2],
207 | [6, 64, 4, 2],
208 | [6, 96, 3, 1],
209 | [6, 160, 3, 2],
210 | [6, 320, 1, 1],
211 | ]
212 |
213 | # building first layer
214 | assert input_size % 32 == 0
215 | input_channel = int(input_channel * width_mult)
216 | self.last_channel = int(last_channel * width_mult) if width_mult > 1.0 else last_channel
217 | self.features = [conv_bn(3, input_channel, 2)]
218 | # building inverted residual blocks
219 | for t, c, n, s in interverted_residual_setting:
220 | output_channel = int(c * width_mult)
221 | for i in range(n):
222 | if i == 0:
223 | self.features.append(block(input_channel, output_channel, s, expand_ratio=t))
224 | else:
225 | self.features.append(block(input_channel, output_channel, 1, expand_ratio=t))
226 | input_channel = output_channel
227 | # building last several layers
228 | self.features.append(conv_1x1_bn(input_channel, self.last_channel))
229 | # make it nn.Sequential
230 | self.features = nn.Sequential(*self.features)
231 |
232 | # building classifier
233 | self.classifier = nn.Sequential(
234 | nn.Dropout(0.2),
235 | nn.Linear(self.last_channel, n_class),
236 | )
237 |
238 | self._initialize_weights()
239 |
240 | def forward(self, x):
241 | x = self.features(x)
242 | x = x.mean(3).mean(2)
243 | x = self.classifier(x)
244 | return x
245 |
246 | def _initialize_weights(self):
247 | for m in self.modules():
248 | if isinstance(m, nn.Conv2d):
249 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
250 | m.weight.data.normal_(0, math.sqrt(2. / n))
251 | if m.bias is not None:
252 | m.bias.data.zero_()
253 | elif isinstance(m, nn.BatchNorm2d):
254 | m.weight.data.fill_(1)
255 | m.bias.data.zero_()
256 | elif isinstance(m, nn.Linear):
257 | n = m.weight.size(1)
258 | m.weight.data.normal_(0, 0.01)
259 | m.bias.data.zero_()
260 |
261 | class UnetMobilenetV2(nn.Module):
262 | def __init__(self, num_classes=1, num_filters=32, pretrained=True,
263 | Dropout=.2, path='./data/mobilenet_v2.pth.tar'):
264 | super(UnetMobilenetV2, self).__init__()
265 |
266 | self.encoder = MobileNetV2(n_class=1000)
267 |
268 | self.num_classes = num_classes
269 |
270 | self.dconv1 = nn.ConvTranspose2d(1280, 96, 4, padding=1, stride=2)
271 | self.invres1 = InvertedResidual(192, 96, 1, 6)
272 |
273 | self.dconv2 = nn.ConvTranspose2d(96, 32, 4, padding=1, stride=2)
274 | self.invres2 = InvertedResidual(64, 32, 1, 6)
275 |
276 | self.dconv3 = nn.ConvTranspose2d(32, 24, 4, padding=1, stride=2)
277 | self.invres3 = InvertedResidual(48, 24, 1, 6)
278 |
279 | self.dconv4 = nn.ConvTranspose2d(24, 16, 4, padding=1, stride=2)
280 | self.invres4 = InvertedResidual(32, 16, 1, 6)
281 |
282 | self.conv_last = nn.Conv2d(16, 3, 1)
283 |
284 | self.conv_score = nn.Conv2d(3, 1, 1)
285 |
286 |         # not needed; kept only for compatibility
287 | self.dconv_final = nn.ConvTranspose2d(1, 1, 4, padding=1, stride=2)
288 |
289 | if pretrained:
290 | state_dict = torch.load(path)
291 | self.encoder.load_state_dict(state_dict)
292 | else: self._init_weights()
293 |
294 | def forward(self, x):
295 | for n in range(0, 2):
296 | x = self.encoder.features[n](x)
297 | x1 = x
298 |
299 | for n in range(2, 4):
300 | x = self.encoder.features[n](x)
301 | x2 = x
302 |
303 | for n in range(4, 7):
304 | x = self.encoder.features[n](x)
305 | x3 = x
306 |
307 | for n in range(7, 14):
308 | x = self.encoder.features[n](x)
309 | x4 = x
310 |
311 | for n in range(14, 19):
312 | x = self.encoder.features[n](x)
313 | x5 = x
314 |
315 | up1 = torch.cat([
316 | x4,
317 | self.dconv1(x)
318 | ], dim=1)
319 | up1 = self.invres1(up1)
320 |
321 | up2 = torch.cat([
322 | x3,
323 | self.dconv2(up1)
324 | ], dim=1)
325 | up2 = self.invres2(up2)
326 |
327 | up3 = torch.cat([
328 | x2,
329 | self.dconv3(up2)
330 | ], dim=1)
331 | up3 = self.invres3(up3)
332 |
333 | up4 = torch.cat([
334 | x1,
335 | self.dconv4(up3)
336 | ], dim=1)
337 | up4 = self.invres4(up4)
338 | x = self.conv_last(up4)
339 | x = self.conv_score(x)
340 |
341 | return x
342 |
343 | def _init_weights(self):
344 | for m in self.modules():
345 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
346 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
347 | m.weight.data.normal_(0, math.sqrt(2. / n))
348 | if m.bias is not None:
349 | m.bias.data.zero_()
350 | elif isinstance(m, nn.BatchNorm2d):
351 | m.weight.data.fill_(1)
352 | m.bias.data.zero_()
353 | elif isinstance(m, nn.Linear):
354 | m.weight.data.normal_(0, 0.01)
355 | m.bias.data.zero_()
--------------------------------------------------------------------------------
/pics/VID_edited.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gasparian/PicsArtHack-binary-segmentation/ecab001f334949d5082a79b8fbd1dc2fdb8b093e/pics/VID_edited.gif
--------------------------------------------------------------------------------
/pics/VID_orig.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gasparian/PicsArtHack-binary-segmentation/ecab001f334949d5082a79b8fbd1dc2fdb8b093e/pics/VID_orig.gif
--------------------------------------------------------------------------------
/pics/ex_2_orig.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gasparian/PicsArtHack-binary-segmentation/ecab001f334949d5082a79b8fbd1dc2fdb8b093e/pics/ex_2_orig.png
--------------------------------------------------------------------------------
/pics/ex_2_transformed.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gasparian/PicsArtHack-binary-segmentation/ecab001f334949d5082a79b8fbd1dc2fdb8b093e/pics/ex_2_transformed.png
--------------------------------------------------------------------------------
/pics/ex_3_edited_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gasparian/PicsArtHack-binary-segmentation/ecab001f334949d5082a79b8fbd1dc2fdb8b093e/pics/ex_3_edited_mask.png
--------------------------------------------------------------------------------
/pics/ex_3_orig_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gasparian/PicsArtHack-binary-segmentation/ecab001f334949d5082a79b8fbd1dc2fdb8b093e/pics/ex_3_orig_mask.png
--------------------------------------------------------------------------------
/pics/example_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gasparian/PicsArtHack-binary-segmentation/ecab001f334949d5082a79b8fbd1dc2fdb8b093e/pics/example_1.png
--------------------------------------------------------------------------------
/pics/girl_ex_blured.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gasparian/PicsArtHack-binary-segmentation/ecab001f334949d5082a79b8fbd1dc2fdb8b093e/pics/girl_ex_blured.png
--------------------------------------------------------------------------------
/pics/girl_ex_orig.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gasparian/PicsArtHack-binary-segmentation/ecab001f334949d5082a79b8fbd1dc2fdb8b093e/pics/girl_ex_orig.png
--------------------------------------------------------------------------------
/pics/mobilenetV2_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gasparian/PicsArtHack-binary-segmentation/ecab001f334949d5082a79b8fbd1dc2fdb8b093e/pics/mobilenetV2_loss.png
--------------------------------------------------------------------------------
/pics/mobilenetV2_metric.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gasparian/PicsArtHack-binary-segmentation/ecab001f334949d5082a79b8fbd1dc2fdb8b093e/pics/mobilenetV2_metric.png
--------------------------------------------------------------------------------
/pics/resnet101_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gasparian/PicsArtHack-binary-segmentation/ecab001f334949d5082a79b8fbd1dc2fdb8b093e/pics/resnet101_loss.png
--------------------------------------------------------------------------------
/pics/resnet101_metric.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gasparian/PicsArtHack-binary-segmentation/ecab001f334949d5082a79b8fbd1dc2fdb8b093e/pics/resnet101_metric.png
--------------------------------------------------------------------------------
/predict.py:
--------------------------------------------------------------------------------
1 | import os
2 | import io
3 | import sys
4 | import time
5 | import datetime
6 | import subprocess
7 | import argparse
8 |
9 | import numpy as np
10 | import cv2
11 |
12 | from utils import *
13 |
14 | # python3 predict.py -p ./test --model_path ./models/mobilenetV2_model --gpu -1 --frame_rate 12 --denoise_borders --biggest_side 320
15 |
16 | start = time.time()
17 |
18 | parser = argparse.ArgumentParser()
19 | parser.add_argument('-p', '--data_path', type=str, required=True)
20 | parser.add_argument('--model_path', type=str, required=True)
21 | parser.add_argument('--gpu', type=int, default=-1, required=False)
22 | parser.add_argument('--biggest_side', type=int, default=0, required=False)
23 | parser.add_argument('--delay', type=int, default=7, required=False)
24 | parser.add_argument('--frame_rate', type=int, default=12, required=False)
25 | parser.add_argument('--denoise_borders', action='store_true')
26 | args = parser.parse_args()
27 | globals().update(vars(args))
28 |
29 | biggest_side = None if not biggest_side else biggest_side
30 | delay = round(100/frame_rate + .5)
31 |
32 | trainer = Trainer(path=model_path, gpu=gpu)
33 | if gpu < 0:
34 | torch.set_num_threads(2)
35 | trainer.load_state(mode="metric")
36 | trainer.model.eval()
37 |
38 | files_list = os.listdir(data_path)
39 |
40 | images, vids = [], []
41 | if files_list:
42 | for fname in files_list:
43 | if fname.split(".")[-1] != "mp4": images.append(fname)
44 | elif fname.split(".")[-1] == "mp4": vids.append(fname)
45 |
46 | if images:
47 | for fname in images:
48 | img = cv2.imread(data_path+"/"+fname)
49 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
50 | img = np.array(img, dtype=np.uint8)
51 | out = trainer.predict_mask(img, biggest_side=biggest_side, denoise_borders=denoise_borders)
52 | cv2.imwrite('%s/%s_seg.png' % (data_path, fname.split(".")[0]), out[0])
53 | print(" [INFO] Images processed! ")
54 |
55 | if vids:
56 | for fname in vids:
57 | imgs = split_video(data_path+"/"+fname, frame_rate=frame_rate)
58 | out = trainer.predict_mask(imgs, biggest_side=biggest_side, denoise_borders=denoise_borders)
59 | vpath = data_path+"/%s" % fname.split(".")[0]
60 | os.mkdir(vpath)
61 | save_images(out, path=vpath)
62 | os.system(f"convert -delay {delay} -loop 0 -dispose Background {vpath}/*.png {vpath}/{fname.split('.')[0]}.gif")
63 | print(" [INFO] Videos processed! ")
64 |
65 | print(" [INFO] %s ms. " % round((time.time()-start)*1000, 0))
--------------------------------------------------------------------------------
/requirements.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | apt update && apt install -y ffmpeg imagemagick
3 | pip3 --no-cache-dir install --upgrade tqdm==4.28.1 \
4 | numpy==1.14.3 \
5 | scikit-image==0.13.1 \
6 | albumentations==0.1.7 \
7 | opencv-python==3.4.3.18 \
8 | torch==0.4.1 \
9 | torchvision==0.2.1 \
10 | onnx==1.3.0 \
11 | six==1.10.0 \
12 | onnx-coreml
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 |
4 | from numpy.random import RandomState
5 |
6 | from model import *
7 | from utils import *
8 |
9 | # python3 train.py --train_path ./data/train_data --workdir ./data/ --model_type mobilenetV2
10 |
11 | parser = argparse.ArgumentParser()
12 | parser.add_argument('--train_path', type=str, required=True)
13 | parser.add_argument('--workdir', type=str, required=True)
14 | parser.add_argument('--model_type', default="mobilenetV2", type=str)
15 | parser.add_argument('--batch_size', default=32, type=int)
16 | parser.add_argument('--max_lr', default=.5, type=float)
17 | parser.add_argument('--loss_window', default=10, type=int)
18 | parser.add_argument('--loss_growth_trsh', default=.5, type=float)
19 | parser.add_argument('--alpha', default=.1, type=float)
20 | parser.add_argument('--wd', default=0., type=float)
21 | parser.add_argument('--freeze_encoder', action='store_true')  # type=bool would treat any non-empty string as True
22 | parser.add_argument('--max_lr_decay', default=.8, type=float)
23 | parser.add_argument('--epoch', default=200, type=int)
24 | parser.add_argument('--learning_rate', default=1e-4, type=float)
25 | parser.add_argument('--bce_loss_weight', default=.5, type=float)
26 | parser.add_argument('--reduce_lr_patience', default=0, type=int)
27 | parser.add_argument('--reduce_lr_factor', default=0, type=int)
28 | parser.add_argument('--CLR', default=0, type=int)
29 | args = vars(parser.parse_args())  # keep as a dict, so the args[...] indexing and **args unpacking below work
30 |
31 | path_images = list(map(
32 | lambda x: x.split('.')[0],
33 | filter(lambda x: x.endswith('.jpg'), os.listdir(args["train_path"]))))
34 | prng = RandomState(42)
35 |
36 | path_images *= 3
37 | prng.shuffle(path_images)
38 | train_split = int(len(path_images)*.8)
39 | train_images, val_images = path_images[:train_split], path_images[train_split:]
40 |
41 | dataset = DatasetProcessor(
42 | args["train_path"], train_images, as_torch_tensor=True, augmentations=True, mask_weight=True)
43 | dataset_val = DatasetProcessor(
44 | args["train_path"], val_images, as_torch_tensor=True, augmentations=True, mask_weight=True)
45 |
46 | model_params = {
47 | "directory":args["workdir"],
48 | "model":args["model_type"],
49 | "model_name":"%s_model" % (args["model_type"]),
50 | "Dropout":.4,
51 | "device_idx":0,
52 | "pretrained":True,
53 | "num_classes":1,
54 | "num_filters":32,
55 | "reset":True,
56 | "ADAM":True
57 | }
58 |
59 | trainer = Trainer(**model_params)
60 | if args["CLR"] != 0:
61 | trainer.LR_finder(dataset, **args)
62 | trainer.show_lr_finder_out(save_only=True)
63 |
64 | trainer.fit(dataset, dataset_val, **args)
65 | trainer.plot_trainer_history(mode="loss", save_only=True)
66 | trainer.plot_trainer_history(mode="metric", save_only=True)
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, division
2 | import os
3 | import string
4 | import itertools
5 | import pickle
6 | import shutil
7 | from skimage.morphology import remove_small_objects, remove_small_holes
8 | import cv2
9 | from tqdm import tqdm
10 | import numpy as np
11 |
12 | from albumentations import (
13 | PadIfNeeded,
14 | HorizontalFlip,
15 | VerticalFlip,
16 | CenterCrop,
17 | Crop,
18 | Compose,
19 | Transpose,
20 | RandomRotate90,
21 | ElasticTransform,
22 | GridDistortion,
23 | OpticalDistortion,
24 | RandomSizedCrop,
25 | OneOf,
26 | CLAHE,
27 | RandomContrast,
28 | RandomGamma,
29 | ShiftScaleRotate,
30 | RandomBrightness
31 | )
32 |
33 | import matplotlib.pyplot as plt
34 | import torch
35 | from torchvision import transforms
36 | from torch.utils import data
37 | from torch.autograd import Variable
38 |
39 | from model import *
40 |
41 | class DatasetProcessor(data.Dataset):
42 |
43 | def __init__(self, root_path, file_list, is_test=False, as_torch_tensor=True, augmentations=False, mask_weight=True):
44 | self.is_test = is_test
45 | self.mask_weight = mask_weight
46 | self.root_path = root_path
47 | self.file_list = file_list
48 | self.as_torch_tensor = as_torch_tensor
49 | self.augmentations = augmentations
50 | self.norm = transforms.Compose([
51 | transforms.ToTensor(),
52 | transforms.Normalize(mean=[0.485, 0.456, 0.406],
53 | std=[0.229, 0.224, 0.225])
54 | ])
55 | self.been = []
56 |
57 | def clear_buff(self):
58 | self.been = []
59 |
60 | def __len__(self):
61 | return len(self.file_list)
62 |
63 | def transform(self, image, mask):
64 | aug = Compose([
65 | HorizontalFlip(p=0.9),
66 | RandomBrightness(p=.5,limit=0.3),
67 | RandomContrast(p=.5,limit=0.3),
68 | ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.2, rotate_limit=20,
69 | p=0.7, border_mode=0, interpolation=4)
70 | ])
71 |
72 | augmented = aug(image=image, mask=mask)
73 | return augmented['image'], augmented['mask']
74 |
75 | def get_mask_weight(self, mask):
76 | mask_ = cv2.erode(mask, kernel=np.ones((8,8),np.uint8), iterations=1)
77 | mask_ = mask-mask_
78 | return mask_ + 1
79 |
80 | def __getitem__(self, index):
81 |
82 | file_id = index
83 | if type(index) != str:
84 | file_id = self.file_list[index]
85 |
86 | image_folder = self.root_path
87 | image_path = os.path.join(image_folder, file_id + ".jpg")
88 |
89 | mask_folder = self.root_path[:-1] + "_mask/"
90 | mask_path = os.path.join(mask_folder, file_id + ".png")
91 |
92 | if self.as_torch_tensor:
93 |
94 | if not self.is_test:
95 | image = cv2.imread(str(image_path))
96 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
97 | mask = cv2.imread(str(mask_path))
98 | mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
99 |
100 | #resize to 320x256
101 | image = cv2.resize(image, (256, 320), interpolation=cv2.INTER_LANCZOS4)
102 | mask = cv2.resize(mask, (256, 320), interpolation=cv2.INTER_LANCZOS4)
103 |
104 | if self.augmentations:
105 | if file_id not in self.been:
106 | self.been.append(file_id)
107 | else:
108 | image, mask = self.transform(image, mask)
109 |
110 | mask = mask // 255
111 | mask = mask[:, :, np.newaxis]
112 | if self.mask_weight:
113 | mask_w = self.get_mask_weight(np.squeeze(mask))
114 | else:
115 | mask_w = np.ones((mask.shape[:-1]))
116 | mask_w = mask_w[:, :, np.newaxis]
117 |
118 | mask = torch.from_numpy(np.transpose(mask, (2, 0, 1)).astype('float32'))
119 | mask_w = torch.from_numpy(np.transpose(mask_w, (2, 0, 1)).astype('float32'))
120 | image = self.norm(image)
121 | return image, mask, mask_w
122 |
123 | else:
124 | image = cv2.imread(str(image_path))
125 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
126 | image = cv2.resize(image, (256, 320), interpolation=cv2.INTER_LANCZOS4)
127 | image = self.norm(image)
128 | return image
129 |
130 | else:
131 | image = cv2.imread(str(image_path))
132 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
133 | image = np.array(image, dtype=np.uint8)
134 | if not self.is_test:
135 | mask = cv2.imread(str(mask_path))
136 | mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
137 | if self.augmentations:
138 | if file_id not in self.been:
139 | self.been.append(file_id)
140 | else:
141 | image, mask = self.transform(image, mask)
142 | return image, mask
143 |
144 | else:
145 | if self.augmentations:
146 | if file_id not in self.been:
147 | self.been.append(file_id)
148 | else:
149 | image = self.transform(image)
150 | return image
151 |
152 | def save_checkpoint(checkpoint_path, model, optimizer):
153 | state = {'state_dict': model.state_dict(),
154 | 'optimizer' : optimizer.state_dict()}
155 | torch.save(state, checkpoint_path)
156 | print('model saved to %s' % checkpoint_path)
157 |
158 | def load_checkpoint(checkpoint_path, model, optimizer, cpu):
159 | if cpu:
160 | state = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
161 | else:
162 | state = torch.load(checkpoint_path)
163 | model.load_state_dict(state['state_dict'])
164 | if optimizer:
165 | optimizer.load_state_dict(state['optimizer'])
166 | print('model loaded from %s' % checkpoint_path)
167 |
168 | def jaccard(intersection, union, eps=1e-15):
169 | return (intersection) / (union - intersection + eps)
170 |
171 | def dice(intersection, union, eps=1e-15, smooth=1.):
172 | return (2. * intersection + smooth) / (union + smooth + eps)
173 |
174 | class BCESoftJaccardDice:
175 |
176 | def __init__(self, bce_weight=0.5, mode="dice", eps=1e-7, weight=None, smooth=1.):
177 | self.nll_loss = torch.nn.BCEWithLogitsLoss(weight=weight)
178 | self.bce_weight = bce_weight
179 | self.eps = eps
180 | self.mode = mode
181 | self.smooth = smooth
182 |
183 | def __call__(self, outputs, targets):
184 | loss = self.bce_weight * self.nll_loss(outputs, targets)
185 |
186 | if self.bce_weight < 1.:
187 | targets = (targets == 1).float()
188 | outputs = torch.sigmoid(outputs)
189 | intersection = (outputs * targets).sum()
190 | union = outputs.sum() + targets.sum()
191 | if self.mode == "dice":
192 | score = dice(intersection, union, self.eps, self.smooth)
193 | elif self.mode == "jaccard":
194 | score = jaccard(intersection, union, self.eps)
195 | loss -= (1 - self.bce_weight) * torch.log(score)
196 | return loss
197 |
198 | def get_metric(pred, targets):
199 | batch_size = targets.shape[0]
200 | metric = []
201 | for batch in range(batch_size):
202 | t, p = targets[batch].squeeze(1), pred[batch].squeeze(1)
203 | if np.count_nonzero(t) == 0 and np.count_nonzero(p) > 0:
204 | metric.append(0)
205 | continue
206 | if np.count_nonzero(t) == 0 and np.count_nonzero(p) == 0:
207 | metric.append(1)
208 | continue
209 |
210 | t = (t == 1).float()
211 | intersection = (p * t).sum()
212 | union = p.sum() + t.sum()
213 | m = dice(intersection, union, eps=1e-15)
214 | metric.append(m)
215 | return np.mean(metric)
216 |
217 | class Trainer:
218 |
219 | def __init__(self, path=None, gpu=-1, **kwargs):
220 |
221 | if path is not None:
222 | kwargs = pickle.load(open(path+"/model_params.pickle.dat", "rb"))
223 | kwargs["device_idx"] = gpu
224 | kwargs["pretrained"], kwargs["reset"] = False, False
225 | self.path = path
226 | else:
227 | self.directory = kwargs["directory"]
228 |             self.path = os.path.join(self.directory, kwargs["model_name"])
229 |
230 | self.model_name = kwargs["model_name"]
231 | self.model_type = kwargs["model"].lower()
232 | self.device_idx = kwargs["device_idx"]
233 | self.cpu = True if self.device_idx < 0 else False
234 | self.ADAM = kwargs["ADAM"]
235 | self.pretrained = kwargs["pretrained"]
236 | self.norm = transforms.Compose([
237 | transforms.ToTensor(),
238 | transforms.Normalize(mean=[0.485, 0.456, 0.406],
239 | std=[0.229, 0.224, 0.225])
240 | ])
241 |
242 | self.cp_counter_loss, self.cp_counter_metric = 0, 0
243 | self.max_lr = .5
244 |
245 | net_init_params = {k:v for k, v in kwargs.items()
246 | if k in ["Dropout", "pretrained", "num_classes", "num_filters"]
247 | }
248 |
249 | if self.model_type == "mobilenetv2":
250 | self.initial_model = UnetMobilenetV2(**net_init_params)
251 | else:
252 | net_init_params["model"] = self.model_type
253 | self.initial_model = UnetResNet(**net_init_params)
254 |
255 | if kwargs["reset"]:
256 | try:
257 | shutil.rmtree(self.path)
258 | except:
259 | pass
260 | os.mkdir(self.path)
261 | kwargs["reset"] = False
262 | pickle.dump(kwargs, open(self.path+"/model_params.pickle.dat", "wb"))
263 |
264 | self.model = self.get_model(self.initial_model)
265 | if self.ADAM:
266 | self.optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-4)
267 | else:
268 | self.optimizer = torch.optim.SGD(self.model.parameters(), lr=1e-4, momentum=0.9, nesterov=True)
269 |
270 | def dfs_freeze(self, model):
271 | for name, child in model.named_children():
272 | for param in child.parameters():
273 | param.requires_grad = False if self.freeze_encoder else True
274 | self.dfs_freeze(child)
275 |
276 | def get_model(self, model):
277 | model = model.train()
278 | if self.cpu:
279 | return model.cpu()
280 | return model.cuda(self.device_idx)
281 |
282 | def LR_finder(self, dataset, **kwargs):
283 |
284 | max_lr = kwargs["max_lr"]
285 | batch_size = kwargs["batch_size"]
286 | learning_rate = kwargs["learning_rate"]
287 | bce_loss_weight = kwargs["bce_loss_weight"]
288 | loss_growth_trsh = kwargs["loss_growth_trsh"]
289 | loss_window = kwargs["loss_window"]
290 | wd = kwargs["wd"]
291 | alpha = kwargs["alpha"]
292 |
293 | torch.cuda.empty_cache()
294 | dataset.clear_buff()
295 | self.model = self.get_model(self.initial_model)
296 |
297 | iterations = len(dataset) // batch_size
298 | it = 0
299 | lr_mult = (max_lr/learning_rate)**(1/iterations)
300 |
301 | if self.ADAM:
302 | optimizer = torch.optim.Adam(self.model.parameters(), lr=learning_rate)
303 | else:
304 | optimizer = torch.optim.SGD(self.model.parameters(), lr=learning_rate,
305 | momentum=0.9, nesterov=True)
306 |
307 | #max LR search
308 | print(" [INFO] Start max. learning rate search... ")
309 | min_loss, self.lr_finder_losses = (np.inf, learning_rate), [[], []]
310 | for image, mask, mask_w in tqdm(data.DataLoader(dataset, batch_size = batch_size, shuffle = True, num_workers=0)):
311 | image = image.type(torch.FloatTensor).cuda(self.device_idx)
312 |
313 | it += 1
314 | current_lr = learning_rate * (lr_mult**it)
315 |
316 | y_pred = self.model(Variable(image))
317 | if self.model_type == "mobilenetv2":
318 | y_pred = nn.functional.interpolate(y_pred, scale_factor=2, mode='bilinear', align_corners=True)
319 |
320 | loss_fn = BCESoftJaccardDice(bce_weight=bce_loss_weight,
321 | weight=mask_w.cuda(self.device_idx), mode="dice", eps=1.)
322 | loss = loss_fn(y_pred, Variable(mask.cuda(self.device_idx)))
323 |
324 | optimizer.zero_grad()
325 | loss.backward()
326 |
327 | #adjust learning rate and weights decay
328 | for param_group in optimizer.param_groups:
329 | param_group['lr'] = current_lr
330 | if wd:
331 | for param in param_group['params']:
332 | param.data = param.data.add(-wd * param_group['lr'], param.data)
333 |
334 | optimizer.step()
335 |
336 | if it > 1:
337 | current_loss = alpha * loss.item() + (1 - alpha) * current_loss
338 | else:
339 | current_loss = loss.item()
340 |
341 | self.lr_finder_losses[0].append(current_loss)
342 | self.lr_finder_losses[1].append(current_lr)
343 |
344 | if current_loss < min_loss[0]:
345 | min_loss = (current_loss, current_lr)
346 |
347 | if it >= loss_window:
348 | if (current_loss - min_loss[0]) / min_loss[0] >= loss_growth_trsh:
349 | break
350 |
351 | self.max_lr = round(min_loss[1], 5)
352 | print(" [INFO] max. lr = %.5f " % self.max_lr)
353 |
354 | def show_lr_finder_out(self, save_only=False):
355 | if not save_only:
356 | plt.show(block=False)
357 | plt.semilogx(self.lr_finder_losses[1], self.lr_finder_losses[0])
358 | plt.axvline(self.max_lr, c="gray")
359 | plt.savefig(self.path + '/lr_finder_out.png')
360 |
361 | def fit(self, dataset, dataset_val, **kwargs):
362 |
363 | epoch = kwargs["epoch"]
364 | learning_rate = kwargs["learning_rate"]
365 | batch_size = kwargs["batch_size"]
366 | bce_loss_weight = kwargs["bce_loss_weight"]
367 | CLR = kwargs["CLR"]
368 | wd = kwargs["wd"]
369 | reduce_lr_patience = kwargs["reduce_lr_patience"]
370 | reduce_lr_factor = kwargs["reduce_lr_factor"]
371 | max_lr_decay = kwargs["max_lr_decay"]
372 | self.freeze_encoder = kwargs["freeze_encoder"]
373 |
374 | torch.cuda.empty_cache()
375 | self.model = self.get_model(self.initial_model)
376 |
377 | if self.pretrained and self.freeze_encoder and self.model_type != "mobilenetv2":
378 | self.dfs_freeze(self.model.conv1)
379 | self.dfs_freeze(self.model.conv2)
380 | self.dfs_freeze(self.model.conv3)
381 | self.dfs_freeze(self.model.conv4)
382 | self.dfs_freeze(self.model.conv5)
383 |
384 | if self.ADAM:
385 | self.optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.model.parameters()),
386 | lr=learning_rate)
387 | else:
388 | self.optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, self.model.parameters()),
389 | lr=learning_rate, momentum=0.9, nesterov=True)
390 |
391 | max_lr = self.max_lr
392 | iterations = len(dataset) // batch_size
393 | if abs(CLR) == 1:
394 | iterations *= epoch
395 | lr_mult = (max_lr/learning_rate)**(1/iterations)
396 | current_rate = learning_rate
397 |
398 | checkpoint_metric, checkpoint_loss, it, k, cooldown = -np.inf, np.inf, 0, 1, 0
399 | self.history = {"loss":{"train":[], "test":[]}, "metric":{"train":[], "test":[]}}
400 |
401 | for e in range(epoch):
402 | torch.cuda.empty_cache()
403 | self.model.train()
404 |
405 | if e >= 2 and self.freeze_encoder and self.model_type != "mobilenetv2":
406 | self.freeze_encoder = False
407 | self.dfs_freeze(self.model.conv1)
408 | self.dfs_freeze(self.model.conv2)
409 | self.dfs_freeze(self.model.conv3)
410 | self.dfs_freeze(self.model.conv4)
411 | self.dfs_freeze(self.model.conv5)
412 |
413 | if self.ADAM:
414 | self.optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.model.parameters()),
415 | lr=current_rate)
416 | else:
417 | self.optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, self.model.parameters()),
418 | lr=current_rate, momentum=0.9, nesterov=True)
419 |
420 | if reduce_lr_patience and reduce_lr_factor:
421 | if not np.isinf(checkpoint_loss):
422 | if self.history["loss"]["test"][-1] >= checkpoint_loss:
423 | cooldown += 1
424 |
425 | if cooldown == reduce_lr_patience:
426 | learning_rate *= reduce_lr_factor; max_lr *= reduce_lr_factor
427 | lr_mult = (max_lr/learning_rate)**(1/iterations)
428 | cooldown = 0
429 | print(" [INFO] Learning rate has been reduced to: %.7f " % learning_rate)
430 |
431 | dataset.clear_buff()
432 | min_train_loss, train_loss, train_metric = np.inf, [], []
433 | for image, mask, mask_w in tqdm(data.DataLoader(dataset, batch_size = batch_size, shuffle = True, num_workers=0)):
434 | image = image.type(torch.FloatTensor).cuda(self.device_idx)
435 |
436 | if abs(CLR):
437 | it += 1; exp = it
438 | if CLR > 0:
439 | exp = iterations*k - it + 1
440 | current_rate = learning_rate * (lr_mult**exp)
441 |
442 | if abs(CLR) > 1:
443 | if iterations*k / it == 1:
444 | it = 0; k *= abs(CLR)
445 | if max_lr_decay < 1.:
446 | max_lr *= max_lr_decay
447 | lr_mult = (max_lr/learning_rate)**(1/(iterations*k))
448 |
449 |                         # re-init the optimizer to reset its internal state
450 | if self.ADAM:
451 | self.optimizer = torch.optim.Adam(self.model.parameters(), lr=current_rate)
452 | else:
453 | self.optimizer = torch.optim.SGD(self.model.parameters(),
454 | lr=current_rate, momentum=0.9, nesterov=True)
455 |
456 | y_pred = self.model(Variable(image))
457 | if self.model_type == "mobilenetv2":
458 | y_pred = nn.functional.interpolate(y_pred, scale_factor=2, mode='bilinear', align_corners=True)
459 |
460 | loss_fn = BCESoftJaccardDice(bce_weight=bce_loss_weight,
461 | weight=mask_w.cuda(self.device_idx), mode="dice")
462 | loss = loss_fn(y_pred, Variable(mask.cuda(self.device_idx)))
463 |
464 | self.optimizer.zero_grad()
465 | loss.backward()
466 |
467 | #adjust learning rate and weights decay
468 | for param_group in self.optimizer.param_groups:
469 |                     # current_rate holds the (possibly CLR-adjusted) learning rate for this iteration
470 |                     param_group['lr'] = current_rate
471 | if wd:
472 | for param in param_group['params']:
473 | param.data = param.data.add(-wd * param_group['lr'], param.data)
474 |
475 | self.optimizer.step()
476 | if loss.item() < min_train_loss:
477 | min_train_loss = loss.item()
478 | train_loss.append(loss.item())
479 | train_metric.append(get_metric((y_pred.cpu() > 0.).float(), mask))
480 |
481 | del y_pred; del image; del mask_w; del mask; del loss
482 |
483 | dataset_val.clear_buff()
484 | torch.cuda.empty_cache()
485 | self.model.eval()
486 | val_loss, val_metric = [], []
487 | for image, mask, mask_w in data.DataLoader(dataset_val, batch_size = batch_size // 2, shuffle = False, num_workers=0):
488 | image = image.cuda(self.device_idx)
489 |
490 | y_pred = self.model(Variable(image))
491 | if self.model_type == "mobilenetv2":
492 | y_pred = nn.functional.interpolate(y_pred, scale_factor=2, mode='bilinear', align_corners=True)
493 |
494 | loss_fn = BCESoftJaccardDice(bce_weight=bce_loss_weight,
495 | weight=mask_w.cuda(self.device_idx), mode="dice", eps=1.)
496 | loss = loss_fn(y_pred, Variable(mask.cuda(self.device_idx)))
497 |
498 | val_loss.append(loss.item())
499 | val_metric.append(get_metric((y_pred.cpu() > 0.).float(), mask))
500 |
501 | del y_pred; del image; del mask_w; del mask; del loss
502 |
503 | train_loss, train_metric, val_loss, val_metric = \
504 | np.mean(train_loss), np.mean(train_metric), np.mean(val_loss), np.mean(val_metric)
505 |
506 | if val_loss < checkpoint_loss:
507 | save_checkpoint(self.path+'/%s_checkpoint_loss.pth' % (self.model_name), self.model, self.optimizer)
508 | checkpoint_loss = val_loss
509 |
510 | if val_metric > checkpoint_metric:
511 | save_checkpoint(self.path+'/%s_checkpoint_metric.pth' % (self.model_name), self.model, self.optimizer)
512 | checkpoint_metric = val_metric
513 |
514 | self.history["loss"]["train"].append(train_loss)
515 | self.history["loss"]["test"].append(val_loss)
516 | self.history["metric"]["train"].append(train_metric)
517 | self.history["metric"]["test"].append(val_metric)
518 |
519 | message = "Epoch: %d, Train loss: %.3f, Train metric: %.3f, Val loss: %.3f, Val metric: %.3f" % (
520 | e, train_loss, train_metric, val_loss, val_metric)
521 | print(message); os.system("echo " + message)
522 |
523 | self.current_epoch = e
524 | save_checkpoint(self.path+'/last_checkpoint.pth', self.model, self.optimizer)
525 |
526 | pickle.dump(self.history, open(self.path+'/history.pickle.dat', 'wb'))
527 |
528 | def plot_trainer_history(self, mode="metric", save_only=False):
529 | if not save_only:
530 | plt.show(block=False)
531 | plt.plot(self.history[mode]["train"], label="train")
532 | plt.plot(self.history[mode]["test"], label="val")
533 | plt.xlabel("epoch")
534 | plt.ylabel(mode)
535 | plt.grid(True)
536 | plt.legend(loc="best")
537 | plt.savefig(self.path + '/%s_history.png' % mode)
538 |
539 | def load_state(self, path=None, mode="metric", load_optimizer=True):
540 | if load_optimizer: load_optimizer = self.optimizer
541 | if path is None:
542 | path = self.path+'/%s_checkpoint_%s.pth' % (self.model_name, mode)
543 | load_checkpoint(path, self.model, load_optimizer, self.cpu)
544 |
545 | def predict_mask(self, imgs, biggest_side=None, denoise_borders=False):
546 | if not self.cpu:
547 | torch.cuda.empty_cache()
548 | if imgs.ndim < 4:
549 | imgs = np.expand_dims(imgs, axis=0)
550 | l, h, w, c = imgs.shape
551 | w_n, h_n = w, h
552 | if biggest_side is not None:
553 | w_n = int(w/h * min(biggest_side, h))
554 | h_n = min(biggest_side, h)
555 |
556 | wd, hd = w_n % 32, h_n % 32
557 | if wd != 0: w_n += 32 - wd
558 | if hd != 0: h_n += 32 - hd
559 |
560 | kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5,5))
561 | all_predictions = []
562 | for i in range(imgs.shape[0]):
563 | img = self.norm(cv2.resize(imgs[i], (w_n, h_n), interpolation=cv2.INTER_LANCZOS4))
564 | img = img.unsqueeze_(0)
565 | if not self.cpu:
566 | img = img.type(torch.FloatTensor).cuda(self.device_idx)
567 | else:
568 | img = img.type(torch.FloatTensor)
569 | output = self.model(Variable(img))
570 | if self.model_type == "mobilenetv2":
571 | output = nn.functional.interpolate(output, scale_factor=2, mode='bilinear', align_corners=True)
572 | output = torch.sigmoid(output)
573 | output = output.cpu().data.numpy()
574 | y_pred = np.squeeze(output[0])
575 | y_pred = remove_small_holes(remove_small_objects(y_pred > .3))
576 | y_pred = (y_pred * 255).astype(np.uint8)
577 | y_pred = cv2.resize(y_pred, (w, h), interpolation=cv2.INTER_LANCZOS4)
578 |
579 | _,alpha = cv2.threshold(y_pred.astype(np.uint8),0,255,cv2.THRESH_BINARY)
580 | b, g, r = cv2.split(imgs[i])
581 | bgra = [r,g,b, alpha]
582 | y_pred = cv2.merge(bgra,4)
583 | if denoise_borders:
584 | #denoise mask borders
585 | y_pred[:, :, -1] = cv2.morphologyEx(y_pred[:, :, -1], cv2.MORPH_OPEN, kernel)
586 | all_predictions.append(y_pred)
587 | return all_predictions
588 |
589 | def split_video(filename, frame_rate=12):
590 | vidcap = cv2.VideoCapture(filename)
591 | frames = []
592 | succ, frame = vidcap.read()
593 | h, w = frame.shape[:2]
594 | center = (w / 2, h / 2)
595 | while succ:
596 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
597 | frame = np.transpose(frame[:, ::-1, :], axes=[1,0,2])
598 | frames.append(frame)
599 | succ, frame = vidcap.read()
600 |     return np.array(frames).astype(np.uint8)[::24 // frame_rate]  # assumes a ~24 fps source video
601 |
602 | def factorial(n):
603 | if n == 0:
604 | return 1
605 | else:
606 | return n * factorial(n-1)
607 |
608 | def n_unique_permuts(n, r):
609 | return factorial(n) / (factorial(r)*factorial(n-r))
610 |
611 | def save_images(out, path="./data/gif_test"):
612 | letters = string.ascii_lowercase
613 | r = 0; n_uniques = 0
614 | while n_uniques < len(out):
615 | r += 1
616 | n_uniques = n_unique_permuts(len(letters), r)
617 | names = list(itertools.combinations(letters, r))
618 | for im, fname in zip(out, names[:len(out)]):
619 | cv2.imwrite(path+"/%s.png" % ("".join(fname)), im)
--------------------------------------------------------------------------------