├── .gitignore
├── CoreML_convert.py
├── README.md
├── model.py
├── pics
│   ├── VID_edited.gif
│   ├── VID_orig.gif
│   ├── ex_2_orig.png
│   ├── ex_2_transformed.png
│   ├── ex_3_edited_mask.png
│   ├── ex_3_orig_mask.png
│   ├── example_1.png
│   ├── girl_ex_blured.png
│   ├── girl_ex_orig.png
│   ├── mobilenetV2_loss.png
│   ├── mobilenetV2_metric.png
│   ├── resnet101_loss.png
│   └── resnet101_metric.png
├── predict.py
├── requirements.sh
├── train.py
└── utils.py

/.gitignore:
--------------------------------------------------------------------------------
/test/
/models/
.vscode

--------------------------------------------------------------------------------
/CoreML_convert.py:
--------------------------------------------------------------------------------
import re
import argparse

import onnx
import torch
from onnx import onnx_pb
from onnx_coreml import convert

from model import *

# https://github.com/akirasosa/mobile-semantic-segmentation/blob/master/coreml_converter.py

# python3 CoreML_convert.py --tmp_onnx ./models/tmp.onnx --weights_path ./models/mobilenetV2_model/mobilenetV2_model_checkpoint_metric.pth

def init_unet(state_dict):
    model = UnetMobilenetV2(pretrained=False, num_classes=1, num_filters=32, Dropout=.2)
    model.load_state_dict(state_dict["state_dict"])
    return model

parser = argparse.ArgumentParser(description='Convert a trained UnetMobilenetV2 checkpoint to CoreML')
parser.add_argument('--tmp_onnx', type=str, required=True)
parser.add_argument('--weights_path', type=str, required=True)
parser.add_argument('--img_H', type=int, default=320)
parser.add_argument('--img_W', type=int, default=256)
args = parser.parse_args()
globals().update(vars(args))

coreml_path = re.sub(r'\.pth$', '.mlmodel', weights_path)

# convert the model to ONNX and save it
model = init_unet(torch.load(weights_path, map_location=lambda storage, loc: storage))
torch.onnx.export(model,
                  torch.randn(1, 3, img_H, img_W),
                  tmp_onnx)

# convert ONNX to a CoreML model
model_file = open(tmp_onnx, 'rb')
model_proto = onnx_pb.ModelProto()
model_proto.ParseFromString(model_file.read())
# 595 is the identifier of the output node
coreml_model = convert(model_proto,
                       image_input_names=['0'],
                       image_output_names=['595'])
coreml_model.save(coreml_path)

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# PicsArtHack-binary-segmentation

The goal of the hackathon was to build an image-processing algorithm that could be useful for [PicsArt](https://picsart.com/?hl=en) applications.
Here I publish the results of the first stage: segmenting people in selfies.
PicsArt provided the labeled dataset, and the [Dice](https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient) coefficient was used as the evaluation metric.
Judging by the artifacts around the mask borders, a lot of the images seem to have been labeled by another segmentation model, and the test set contains copies of train-set images. So, after training, I did not expect good results on images "from the wild".

### 1. Loss
For this problem I used the fairly common BCE-Dice loss. The algorithm is simple: we take the logits output of the model and feed it into a binary cross-entropy loss and into the natural logarithm of the Dice score (computed after applying the sigmoid).
After that we only need to combine these losses with weights:
```
dice = (2. * intersection + eps) / (union + eps)
loss = w * bce_loss - (1 - w) * log(dice)
```
Also, in this case, we don't need to tune the threshold of the final pseudo-probabilities (after the sigmoid).
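Here is a minimal PyTorch sketch of this combined loss, mirroring the `BCESoftJaccardDice` class in `utils.py`; the weight `w` corresponds to the `--bce_loss_weight` argument of `train.py` (default 0.5):
```
import torch
import torch.nn.functional as F

def bce_dice_loss(logits, targets, w=0.5, eps=1e-7, smooth=1.):
    # BCE is computed directly on the logits for numerical stability
    bce = F.binary_cross_entropy_with_logits(logits, targets)
    probs = torch.sigmoid(logits)
    intersection = (probs * targets).sum()
    union = probs.sum() + targets.sum()
    dice = (2. * intersection + smooth) / (union + smooth + eps)
    # -log(dice) goes to 0 as the predicted mask approaches the target
    return w * bce - (1. - w) * torch.log(dice)
```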
Finally, we can put extra weight on the pixels around the mask borders (I did it inside the BCE loss) to penalize the model more for mistakes there. For this purpose we can use OpenCV's erosion operation:
```
def get_mask_weight(mask):
    mask_ = cv2.erode(mask, kernel=np.ones((8,8),np.uint8), iterations=1)
    mask_ = mask-mask_
    return mask_ + 1
```
The picture below shows what the input data looks like:

### 2. Training
I used a modified **U-Net** (an architecture that works well for binary semantic segmentation) with two encoders pretrained on ImageNet: ResNet101 and [mobilenetV2](https://github.com/tonylins/pytorch-mobilenet-v2). One of the goals was to compare the performance of a "heavy" and a "light" encoder.
You can check all training parameters inside `train.py`.

```
python3 train.py --train_path ./data/train_data --workdir ./data/ --model_type mobilenetV2
```

Data augmentation is done with the brilliant [albumentations](https://github.com/albu/albumentations) library; an example pipeline is sketched below.
Inside `utils.py` you can find learning-rate scheduling, freezing of the encoder weights and some other useful tricks that help to train the networks more efficiently. By passing the `model_type` parameter you can also choose one of the predefined models built on resnet18, resnet34, resnet50, resnet101 or mobilenetV2.
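For reference, this is essentially the augmentation pipeline used by `DatasetProcessor.transform()` in `utils.py`; the placeholder arrays are only there to make the snippet self-contained:
```
import numpy as np
from albumentations import (Compose, HorizontalFlip, RandomBrightness,
                            RandomContrast, ShiftScaleRotate)

# the same transform is applied jointly to the image and its mask
aug = Compose([
    HorizontalFlip(p=0.9),
    RandomBrightness(p=0.5, limit=0.3),
    RandomContrast(p=0.5, limit=0.3),
    ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.2, rotate_limit=20,
                     p=0.7, border_mode=0, interpolation=4),
])

image = np.zeros((320, 256, 3), dtype=np.uint8)  # placeholder image
mask = np.zeros((320, 256), dtype=np.uint8)      # placeholder mask
augmented = aug(image=image, mask=mask)
image, mask = augmented['image'], augmented['mask']
```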
So, in the end, I got two trained models with close Dice values on the validation set. Here are a few numbers:

Encoder: | ResNet101 | MobileNetV2
:-------------------------:|:-------------------------:|:-------------------------:
epochs (best of 200) | 177 | 173
Dice | 0.987 (0.988) | 0.986 (0.988)
loss | 0.029 (0.022) | 0.030 (0.024)
No. of parameters | 120 131 745 | 4 682 912

ResNet101 evaluation process:

MobileNetV2 evaluation process:

Even though mobilenetV2 has ~26x fewer parameters, it reaches pretty much the same quality as the ResNet101 encoder. Keep in mind, though, that this was shown **for this particular problem on this particular dataset**, so I don't expect the conclusion to carry over to any other problem.

### 3. Tests
Inference-time comparison on my workstation with 320x256 input images from the test set:

Device | ResNet101 | MobileNetV2
:-------------------------:|:-------------------------:|:-------------------------:
AMD Threadripper 1900X CPU (4 threads) | 2.08 s ± 7.58 ms | 345 ms ± 3.21 ms
GTX 1080Ti GPU | 31.6 ms ± 897 µs | 22 ms ± 622 µs

The output masks often contain some noise along the borders (which becomes more annoying on large images), so we can try to fix it by applying a morphological transform:
```
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5,5))
y_pred[:, :, -1] = cv2.morphologyEx(y_pred[:, :, -1], cv2.MORPH_OPEN, kernel)
```
Original | Transformed
:-------------------------:|:-------------------------:
 | 

Additionally, we can transform the segmented images. For instance, let's apply a Gaussian blur to the background:
```
blurred = cv2.GaussianBlur(test_dataset[n],(21,21),0)
dst = cv2.bitwise_and(blurred, blurred, mask=~out[0][:, :, -1])
dst = cv2.add(cv2.bitwise_and(test_dataset[n], test_dataset[n], mask=out[0][:, :, -1]), dst)
```

And we can actually process videos too (see `predict.py`). The example below is a video I shot with a cellphone (original frame size: 800x450):

These results were obtained with the mobilenetV2 model. You can play with it too: here are its [weights and CoreML models](https://drive.google.com/file/d/1XSRaOaoWKKSllIuUgkW0BVsMKieQ8mbG/view?usp=sharing).

```
python3 predict.py -p ./test --model_path ./models/mobilenetV2_model --gpu -1 --frame_rate 12 --denoise_borders --biggest_side 320
```
This script reads all the data inside the `-p` folder: both pictures and videos.

### 4. Porting the model to an iOS device
Finally, we can convert the trained mobilenetV2 model with CoreML to run inference on iOS devices. The pipeline is simple: torch --> ONNX --> CoreML. To make this work, don't keep the encoder layers separately inside the model class; use them in the forward pass. Also, with these versions of torch and onnx (see `requirements.sh`), upsampling / interpolation layers can't be converted, so place them outside the model as a post-processing step (see the sketch at the end of this README). Hopefully this will be fixed in future releases.

```
python3 CoreML_convert.py --tmp_onnx ./models/tmp.onnx --weights_path ./models/mobilenetV2_model/mobilenetV2_model_checkpoint_metric.pth
```

### 5. Environment
For your own experiments I highly recommend [Deepo](https://github.com/ufoym/deepo) as a fast way to deploy a universal deep-learning environment inside a Docker container. The remaining dependencies can be installed with `requirements.sh`.
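As mentioned in section 4, the upsampling that can't be converted is applied outside the network at inference time (see `Trainer.predict_mask()` in `utils.py`). Below is a minimal sketch of that post-processing step; the 0.5 threshold is only for illustration, since `predict_mask()` actually thresholds at 0.3 and additionally removes small objects and holes:
```
import torch
import torch.nn.functional as F

from model import UnetMobilenetV2

model = UnetMobilenetV2(pretrained=False, num_classes=1, num_filters=32, Dropout=.2)
model.eval()

x = torch.randn(1, 3, 320, 256)  # random tensor standing in for a normalized 320x256 RGB image
with torch.no_grad():
    logits = model(x)            # (1, 1, 160, 128): the net returns half-resolution logits
    # upsampling lives outside the model, so the torch -> ONNX -> CoreML
    # conversion never has to deal with an interpolation layer
    logits = F.interpolate(logits, scale_factor=2, mode='bilinear', align_corners=True)
    mask = (torch.sigmoid(logits) > 0.5).float()
```
On the device, the same upsampling (or a simple resize of the output mask) can be applied after the CoreML model has produced its half-resolution output.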
96 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch import nn, cat 5 | import torchvision 6 | 7 | class ConvRelu(nn.Module): 8 | def __init__(self, in_: int, out: int, activate=True): 9 | super(ConvRelu, self).__init__() 10 | self.activate = activate 11 | self.conv = nn.Conv2d(in_, out, 3, padding=1) 12 | self.activation = nn.ReLU(inplace=True) 13 | 14 | def forward(self, x): 15 | x = self.conv(x) 16 | if self.activate: 17 | x = self.activation(x) 18 | return x 19 | 20 | class ResidualBlock(nn.Module): 21 | 22 | def __init__(self, in_channels: int, num_filters: int, batch_activate=False): 23 | super(ResidualBlock, self).__init__() 24 | self.batch_activate = batch_activate 25 | self.activation = nn.ReLU(inplace=True) 26 | self.conv_block = ConvRelu(in_channels, num_filters, activate=True) 27 | self.conv_block_na = ConvRelu(in_channels, num_filters, activate=False) 28 | self.activation = nn.ReLU(inplace=True) 29 | 30 | def forward(self, inp): 31 | x = self.conv_block(inp) 32 | x = self.conv_block_na(x) 33 | x = x.add(inp) 34 | if self.batch_activate: 35 | x = self.activation(x) 36 | return x 37 | 38 | class DecoderBlockResnet(nn.Module): 39 | """ 40 | Paramaters for Deconvolution were chosen to avoid artifacts, following 41 | link https://distill.pub/2016/deconv-checkerboard/ 42 | """ 43 | 44 | def __init__(self, in_channels, middle_channels, out_channels): 45 | super(DecoderBlockResnet, self).__init__() 46 | self.in_channels = in_channels 47 | 48 | self.block = nn.Sequential( 49 | ConvRelu(in_channels, middle_channels, activate=True), 50 | nn.ConvTranspose2d(middle_channels, out_channels, kernel_size=4, stride=2, padding=1), 51 | nn.ReLU(inplace=True) 52 | ) 53 | 54 | def forward(self, x): 55 | return self.block(x) 56 | 57 | class UnetResNet(nn.Module): 58 | 59 | def __init__(self, num_classes=1, num_filters=32, pretrained=True, Dropout=.2, model="resnet50"): 60 | 61 | super().__init__() 62 | if model == "resnet18": 63 | self.encoder = torchvision.models.resnet18(pretrained=pretrained) 64 | elif model == "resnet34": 65 | self.encoder = torchvision.models.resnet34(pretrained=pretrained) 66 | elif model == "resnet50": 67 | self.encoder = torchvision.models.resnet50(pretrained=pretrained) 68 | elif model == "resnet101": 69 | self.encoder = torchvision.models.resnet101(pretrained=pretrained) 70 | 71 | if model in ["resnet18", "resnet34"]: model = "resnet18-34" 72 | else: model = "resnet50-101" 73 | 74 | self.filters_dict = { 75 | "resnet18-34": [512, 512, 256, 128, 64], 76 | "resnet50-101": [2048, 2048, 1024, 512, 256] 77 | } 78 | 79 | self.num_classes = num_classes 80 | self.Dropout = Dropout 81 | self.pool = nn.MaxPool2d(2, 2) 82 | self.relu = nn.ReLU(inplace=True) 83 | self.conv1 = nn.Sequential(self.encoder.conv1, 84 | self.encoder.bn1, 85 | self.encoder.relu, 86 | self.pool) 87 | self.conv2 = self.encoder.layer1 88 | self.conv3 = self.encoder.layer2 89 | self.conv4 = self.encoder.layer3 90 | self.conv5 = self.encoder.layer4 91 | 92 | self.center = DecoderBlockResnet(self.filters_dict[model][0], num_filters * 8 * 2, 93 | num_filters * 8) 94 | self.dec5 = DecoderBlockResnet(self.filters_dict[model][1] + num_filters * 8, 95 | num_filters * 8 * 2, num_filters * 8) 96 | self.dec4 = DecoderBlockResnet(self.filters_dict[model][2] + num_filters * 8, 97 | num_filters * 8 * 2, num_filters * 8) 98 | self.dec3 = 
DecoderBlockResnet(self.filters_dict[model][3] + num_filters * 8, 99 | num_filters * 4 * 2, num_filters * 2) 100 | self.dec2 = DecoderBlockResnet(self.filters_dict[model][4] + num_filters * 2, 101 | num_filters * 2 * 2, num_filters * 2 * 2) 102 | 103 | self.dec1 = DecoderBlockResnet(num_filters * 2 * 2, num_filters * 2 * 2, num_filters) 104 | self.dec0 = ConvRelu(num_filters, num_filters) 105 | 106 | self.final = nn.Conv2d(num_filters, num_classes, kernel_size=1) 107 | self.dropout_2d = nn.Dropout2d(p=self.Dropout) 108 | 109 | 110 | def forward(self, x, z=None): 111 | conv1 = self.conv1(x) 112 | conv2 = self.dropout_2d(self.conv2(conv1)) 113 | conv3 = self.dropout_2d(self.conv3(conv2)) 114 | conv4 = self.dropout_2d(self.conv4(conv3)) 115 | conv5 = self.dropout_2d(self.conv5(conv4)) 116 | 117 | center = self.center(self.pool(conv5)) 118 | dec5 = self.dec5(torch.cat([center, conv5], 1)) 119 | dec4 = self.dec4(torch.cat([dec5, conv4], 1)) 120 | dec3 = self.dec3(torch.cat([dec4, conv3], 1)) 121 | dec2 = self.dec2(torch.cat([dec3, conv2], 1)) 122 | dec2 = self.dropout_2d(dec2) 123 | 124 | dec1 = self.dec1(dec2) 125 | dec0 = self.dec0(dec1) 126 | 127 | return self.final(dec0) 128 | 129 | ########################################################################### 130 | # Mobile Net 131 | ########################################################################### 132 | 133 | def conv_bn(inp, oup, stride): 134 | return nn.Sequential( 135 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 136 | nn.BatchNorm2d(oup), 137 | nn.ReLU6(inplace=True) 138 | ) 139 | 140 | def conv_1x1_bn(inp, oup): 141 | return nn.Sequential( 142 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 143 | nn.BatchNorm2d(oup), 144 | nn.ReLU6(inplace=True) 145 | ) 146 | 147 | class InvertedResidual(nn.Module): 148 | def __init__(self, inp, oup, stride, expand_ratio): 149 | super(InvertedResidual, self).__init__() 150 | self.stride = stride 151 | assert stride in [1, 2] 152 | 153 | hidden_dim = round(inp * expand_ratio) 154 | self.use_res_connect = self.stride == 1 and inp == oup 155 | 156 | if expand_ratio == 1: 157 | self.conv = nn.Sequential( 158 | # dw 159 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), 160 | nn.BatchNorm2d(hidden_dim), 161 | nn.ReLU6(inplace=True), 162 | # pw-linear 163 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 164 | nn.BatchNorm2d(oup), 165 | ) 166 | else: 167 | self.conv = nn.Sequential( 168 | # pw 169 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False), 170 | nn.BatchNorm2d(hidden_dim), 171 | nn.ReLU6(inplace=True), 172 | # dw 173 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), 174 | nn.BatchNorm2d(hidden_dim), 175 | nn.ReLU6(inplace=True), 176 | # pw-linear 177 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 178 | nn.BatchNorm2d(oup), 179 | ) 180 | 181 | def forward(self, x): 182 | if self.use_res_connect: 183 | return x + self.conv(x) 184 | else: 185 | return self.conv(x) 186 | 187 | class MobileNetV2(nn.Module): 188 | 189 | """ 190 | from MobileNetV2 import MobileNetV2 191 | 192 | net = MobileNetV2(n_class=1000) 193 | state_dict = torch.load('mobilenetv2.pth.tar') # add map_location='cpu' if no gpu 194 | net.load_state_dict(state_dict) 195 | """ 196 | 197 | def __init__(self, n_class=1000, input_size=224, width_mult=1.): 198 | super(MobileNetV2, self).__init__() 199 | block = InvertedResidual 200 | input_channel = 32 201 | last_channel = 1280 202 | interverted_residual_setting = [ 203 | # t, c, n, s 204 | [1, 16, 1, 1], 205 | [6, 24, 
2, 2], 206 | [6, 32, 3, 2], 207 | [6, 64, 4, 2], 208 | [6, 96, 3, 1], 209 | [6, 160, 3, 2], 210 | [6, 320, 1, 1], 211 | ] 212 | 213 | # building first layer 214 | assert input_size % 32 == 0 215 | input_channel = int(input_channel * width_mult) 216 | self.last_channel = int(last_channel * width_mult) if width_mult > 1.0 else last_channel 217 | self.features = [conv_bn(3, input_channel, 2)] 218 | # building inverted residual blocks 219 | for t, c, n, s in interverted_residual_setting: 220 | output_channel = int(c * width_mult) 221 | for i in range(n): 222 | if i == 0: 223 | self.features.append(block(input_channel, output_channel, s, expand_ratio=t)) 224 | else: 225 | self.features.append(block(input_channel, output_channel, 1, expand_ratio=t)) 226 | input_channel = output_channel 227 | # building last several layers 228 | self.features.append(conv_1x1_bn(input_channel, self.last_channel)) 229 | # make it nn.Sequential 230 | self.features = nn.Sequential(*self.features) 231 | 232 | # building classifier 233 | self.classifier = nn.Sequential( 234 | nn.Dropout(0.2), 235 | nn.Linear(self.last_channel, n_class), 236 | ) 237 | 238 | self._initialize_weights() 239 | 240 | def forward(self, x): 241 | x = self.features(x) 242 | x = x.mean(3).mean(2) 243 | x = self.classifier(x) 244 | return x 245 | 246 | def _initialize_weights(self): 247 | for m in self.modules(): 248 | if isinstance(m, nn.Conv2d): 249 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 250 | m.weight.data.normal_(0, math.sqrt(2. / n)) 251 | if m.bias is not None: 252 | m.bias.data.zero_() 253 | elif isinstance(m, nn.BatchNorm2d): 254 | m.weight.data.fill_(1) 255 | m.bias.data.zero_() 256 | elif isinstance(m, nn.Linear): 257 | n = m.weight.size(1) 258 | m.weight.data.normal_(0, 0.01) 259 | m.bias.data.zero_() 260 | 261 | class UnetMobilenetV2(nn.Module): 262 | def __init__(self, num_classes=1, num_filters=32, pretrained=True, 263 | Dropout=.2, path='./data/mobilenet_v2.pth.tar'): 264 | super(UnetMobilenetV2, self).__init__() 265 | 266 | self.encoder = MobileNetV2(n_class=1000) 267 | 268 | self.num_classes = num_classes 269 | 270 | self.dconv1 = nn.ConvTranspose2d(1280, 96, 4, padding=1, stride=2) 271 | self.invres1 = InvertedResidual(192, 96, 1, 6) 272 | 273 | self.dconv2 = nn.ConvTranspose2d(96, 32, 4, padding=1, stride=2) 274 | self.invres2 = InvertedResidual(64, 32, 1, 6) 275 | 276 | self.dconv3 = nn.ConvTranspose2d(32, 24, 4, padding=1, stride=2) 277 | self.invres3 = InvertedResidual(48, 24, 1, 6) 278 | 279 | self.dconv4 = nn.ConvTranspose2d(24, 16, 4, padding=1, stride=2) 280 | self.invres4 = InvertedResidual(32, 16, 1, 6) 281 | 282 | self.conv_last = nn.Conv2d(16, 3, 1) 283 | 284 | self.conv_score = nn.Conv2d(3, 1, 1) 285 | 286 | #doesn't needed; obly for compatibility 287 | self.dconv_final = nn.ConvTranspose2d(1, 1, 4, padding=1, stride=2) 288 | 289 | if pretrained: 290 | state_dict = torch.load(path) 291 | self.encoder.load_state_dict(state_dict) 292 | else: self._init_weights() 293 | 294 | def forward(self, x): 295 | for n in range(0, 2): 296 | x = self.encoder.features[n](x) 297 | x1 = x 298 | 299 | for n in range(2, 4): 300 | x = self.encoder.features[n](x) 301 | x2 = x 302 | 303 | for n in range(4, 7): 304 | x = self.encoder.features[n](x) 305 | x3 = x 306 | 307 | for n in range(7, 14): 308 | x = self.encoder.features[n](x) 309 | x4 = x 310 | 311 | for n in range(14, 19): 312 | x = self.encoder.features[n](x) 313 | x5 = x 314 | 315 | up1 = torch.cat([ 316 | x4, 317 | self.dconv1(x) 318 | ], dim=1) 319 | up1 = 
self.invres1(up1) 320 | 321 | up2 = torch.cat([ 322 | x3, 323 | self.dconv2(up1) 324 | ], dim=1) 325 | up2 = self.invres2(up2) 326 | 327 | up3 = torch.cat([ 328 | x2, 329 | self.dconv3(up2) 330 | ], dim=1) 331 | up3 = self.invres3(up3) 332 | 333 | up4 = torch.cat([ 334 | x1, 335 | self.dconv4(up3) 336 | ], dim=1) 337 | up4 = self.invres4(up4) 338 | x = self.conv_last(up4) 339 | x = self.conv_score(x) 340 | 341 | return x 342 | 343 | def _init_weights(self): 344 | for m in self.modules(): 345 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): 346 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 347 | m.weight.data.normal_(0, math.sqrt(2. / n)) 348 | if m.bias is not None: 349 | m.bias.data.zero_() 350 | elif isinstance(m, nn.BatchNorm2d): 351 | m.weight.data.fill_(1) 352 | m.bias.data.zero_() 353 | elif isinstance(m, nn.Linear): 354 | m.weight.data.normal_(0, 0.01) 355 | m.bias.data.zero_() -------------------------------------------------------------------------------- /pics/VID_edited.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gasparian/PicsArtHack-binary-segmentation/ecab001f334949d5082a79b8fbd1dc2fdb8b093e/pics/VID_edited.gif -------------------------------------------------------------------------------- /pics/VID_orig.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gasparian/PicsArtHack-binary-segmentation/ecab001f334949d5082a79b8fbd1dc2fdb8b093e/pics/VID_orig.gif -------------------------------------------------------------------------------- /pics/ex_2_orig.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gasparian/PicsArtHack-binary-segmentation/ecab001f334949d5082a79b8fbd1dc2fdb8b093e/pics/ex_2_orig.png -------------------------------------------------------------------------------- /pics/ex_2_transformed.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gasparian/PicsArtHack-binary-segmentation/ecab001f334949d5082a79b8fbd1dc2fdb8b093e/pics/ex_2_transformed.png -------------------------------------------------------------------------------- /pics/ex_3_edited_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gasparian/PicsArtHack-binary-segmentation/ecab001f334949d5082a79b8fbd1dc2fdb8b093e/pics/ex_3_edited_mask.png -------------------------------------------------------------------------------- /pics/ex_3_orig_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gasparian/PicsArtHack-binary-segmentation/ecab001f334949d5082a79b8fbd1dc2fdb8b093e/pics/ex_3_orig_mask.png -------------------------------------------------------------------------------- /pics/example_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gasparian/PicsArtHack-binary-segmentation/ecab001f334949d5082a79b8fbd1dc2fdb8b093e/pics/example_1.png -------------------------------------------------------------------------------- /pics/girl_ex_blured.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gasparian/PicsArtHack-binary-segmentation/ecab001f334949d5082a79b8fbd1dc2fdb8b093e/pics/girl_ex_blured.png 
-------------------------------------------------------------------------------- /pics/girl_ex_orig.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gasparian/PicsArtHack-binary-segmentation/ecab001f334949d5082a79b8fbd1dc2fdb8b093e/pics/girl_ex_orig.png -------------------------------------------------------------------------------- /pics/mobilenetV2_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gasparian/PicsArtHack-binary-segmentation/ecab001f334949d5082a79b8fbd1dc2fdb8b093e/pics/mobilenetV2_loss.png -------------------------------------------------------------------------------- /pics/mobilenetV2_metric.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gasparian/PicsArtHack-binary-segmentation/ecab001f334949d5082a79b8fbd1dc2fdb8b093e/pics/mobilenetV2_metric.png -------------------------------------------------------------------------------- /pics/resnet101_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gasparian/PicsArtHack-binary-segmentation/ecab001f334949d5082a79b8fbd1dc2fdb8b093e/pics/resnet101_loss.png -------------------------------------------------------------------------------- /pics/resnet101_metric.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gasparian/PicsArtHack-binary-segmentation/ecab001f334949d5082a79b8fbd1dc2fdb8b093e/pics/resnet101_metric.png -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | import os 2 | import io 3 | import sys 4 | import time 5 | import datetime 6 | import subprocess 7 | import argparse 8 | 9 | import numpy as np 10 | import cv2 11 | 12 | from utils import * 13 | 14 | # python3 predict.py -p ./test --model_path ./models/mobilenetV2_model --gpu -1 --frame_rate 12 --denoise_borders --biggest_side 320 15 | 16 | start = time.time() 17 | 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument('-p', '--data_path', type=str, required=True) 20 | parser.add_argument('--model_path', type=str, required=True) 21 | parser.add_argument('--gpu', type=int, default=-1, required=False) 22 | parser.add_argument('--biggest_side', type=int, default=0, required=False) 23 | parser.add_argument('--delay', type=int, default=7, required=False) 24 | parser.add_argument('--frame_rate', type=int, default=12, required=False) 25 | parser.add_argument('--denoise_borders', action='store_true') 26 | args = parser.parse_args() 27 | globals().update(vars(args)) 28 | 29 | biggest_side = None if not biggest_side else biggest_side 30 | delay = round(100/frame_rate + .5) 31 | 32 | trainer = Trainer(path=model_path, gpu=gpu) 33 | if gpu < 0: 34 | torch.set_num_threads(2) 35 | trainer.load_state(mode="metric") 36 | trainer.model.eval() 37 | 38 | files_list = os.listdir(data_path) 39 | 40 | images, vids = [], [] 41 | if files_list: 42 | for fname in files_list: 43 | if fname.split(".")[-1] != "mp4": images.append(fname) 44 | elif fname.split(".")[-1] == "mp4": vids.append(fname) 45 | 46 | if images: 47 | for fname in images: 48 | img = cv2.imread(data_path+"/"+fname) 49 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 50 | img = np.array(img, dtype=np.uint8) 51 | out = trainer.predict_mask(img, 
biggest_side=biggest_side, denoise_borders=denoise_borders) 52 | cv2.imwrite('%s/%s_seg.png' % (data_path, fname.split(".")[0]), out[0]) 53 | print(" [INFO] Images processed! ") 54 | 55 | if vids: 56 | for fname in vids: 57 | imgs = split_video(data_path+"/"+fname, frame_rate=frame_rate) 58 | out = trainer.predict_mask(imgs, biggest_side=biggest_side, denoise_borders=denoise_borders) 59 | vpath = data_path+"/%s" % fname.split(".")[0] 60 | os.mkdir(vpath) 61 | save_images(out, path=vpath) 62 | os.system(f"convert -delay {delay} -loop 0 -dispose Background {vpath}/*.png {vpath}/{fname.split('.')[0]}.gif") 63 | print(" [INFO] Videos processed! ") 64 | 65 | print(" [INFO] %s ms. " % round((time.time()-start)*1000, 0)) -------------------------------------------------------------------------------- /requirements.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | apt update && apt install ffmpeg imagemagick 3 | pip3 --no-cache-dir install --upgrade tqdm==4.28.1 \ 4 | numpy==1.14.3 \ 5 | scikit-image==0.13.1 \ 6 | albumentations==0.1.7 \ 7 | opencv-python==3.4.3.18 \ 8 | torch==0.4.1 \ 9 | torchvision==0.2.1 \ 10 | onnx==1.3.0 \ 11 | six==1.10.0 \ 12 | onnx-coreml -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | from numpy.random import RandomState 5 | 6 | from model import * 7 | from utils import * 8 | 9 | # python3 train.py --train_path ./data/train_data --workdir ./data/ --model_type mobilenetV2 10 | 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument('--train_path', type=str, required=True) 13 | parser.add_argument('--workdir', type=str, required=True) 14 | parser.add_argument('--model_type', default="mobilenetV2", type=str) 15 | parser.add_argument('--batch_size', default=32, type=int) 16 | parser.add_argument('--max_lr', default=.5, type=float) 17 | parser.add_argument('--loss_window', default=10, type=int) 18 | parser.add_argument('--loss_growth_trsh', default=.5, type=float) 19 | parser.add_argument('--alpha', default=.1, type=float) 20 | parser.add_argument('--wd', default=0., type=float) 21 | parser.add_argument('--freeze_encoder', default=False, type=bool) 22 | parser.add_argument('--max_lr_decay', default=.8, type=float) 23 | parser.add_argument('--epoch', default=200, type=int) 24 | parser.add_argument('--learning_rate', default=1e-4, type=float) 25 | parser.add_argument('--bce_loss_weight', default=.5, type=float) 26 | parser.add_argument('--reduce_lr_patience', default=0, type=int) 27 | parser.add_argument('--reduce_lr_factor', default=0, type=int) 28 | parser.add_argument('--CLR', default=0, type=int) 29 | args = vars(parser.parse_args())  # convert the Namespace to a dict: args["..."] indexing is used below 30 | 31 | path_images = list(map( 32 | lambda x: x.split('.')[0], 33 | filter(lambda x: x.endswith('.jpg'), os.listdir(args["train_path"])))) 34 | prng = RandomState(42) 35 | 36 | path_images *= 3 37 | prng.shuffle(path_images) 38 | train_split = int(len(path_images)*.8) 39 | train_images, val_images = path_images[:train_split], path_images[train_split:] 40 | 41 | dataset = DatasetProcessor( 42 | args["train_path"], train_images, as_torch_tensor=True, augmentations=True, mask_weight=True) 43 | dataset_val = DatasetProcessor( 44 | args["train_path"], val_images, as_torch_tensor=True, augmentations=True, mask_weight=True) 45 | 46 | model_params = { 47 | "directory":args["workdir"], 48 | "model":args["model_type"], 49 |
"model_name":"%s_model" % (args["model_type"]), 50 | "Dropout":.4, 51 | "device_idx":0, 52 | "pretrained":True, 53 | "num_classes":1, 54 | "num_filters":32, 55 | "reset":True, 56 | "ADAM":True 57 | } 58 | 59 | trainer = Trainer(**model_params) 60 | if args["CLR"] != 0: 61 | trainer.LR_finder(dataset, **args) 62 | trainer.show_lr_finder_out(save_only=True) 63 | 64 | trainer.fit(dataset, dataset_val, **args) 65 | trainer.plot_trainer_history(mode="loss", save_only=True) 66 | trainer.plot_trainer_history(mode="metric", save_only=True) -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | import os 3 | import string 4 | import itertools 5 | import pickle 6 | 7 | from skimage.morphology import remove_small_objects, remove_small_holes 8 | import cv2 9 | from tqdm import tqdm 10 | import numpy as np 11 | 12 | from albumentations import ( 13 | PadIfNeeded, 14 | HorizontalFlip, 15 | VerticalFlip, 16 | CenterCrop, 17 | Crop, 18 | Compose, 19 | Transpose, 20 | RandomRotate90, 21 | ElasticTransform, 22 | GridDistortion, 23 | OpticalDistortion, 24 | RandomSizedCrop, 25 | OneOf, 26 | CLAHE, 27 | RandomContrast, 28 | RandomGamma, 29 | ShiftScaleRotate, 30 | RandomBrightness 31 | ) 32 | 33 | 34 | import torch 35 | from torchvision import transforms 36 | from torch.utils import data 37 | from torch.autograd import Variable 38 | 39 | from model import * 40 | 41 | class DatasetProcessor(data.Dataset): 42 | 43 | def __init__(self, root_path, file_list, is_test=False, as_torch_tensor=True, augmentations=False, mask_weight=True): 44 | self.is_test = is_test 45 | self.mask_weight = mask_weight 46 | self.root_path = root_path 47 | self.file_list = file_list 48 | self.as_torch_tensor = as_torch_tensor 49 | self.augmentations = augmentations 50 | self.norm = transforms.Compose([ 51 | transforms.ToTensor(), 52 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 53 | std=[0.229, 0.224, 0.225]) 54 | ]) 55 | self.been = [] 56 | 57 | def clear_buff(self): 58 | self.been = [] 59 | 60 | def __len__(self): 61 | return len(self.file_list) 62 | 63 | def transform(self, image, mask): 64 | aug = Compose([ 65 | HorizontalFlip(p=0.9), 66 | RandomBrightness(p=.5,limit=0.3), 67 | RandomContrast(p=.5,limit=0.3), 68 | ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.2, rotate_limit=20, 69 | p=0.7, border_mode=0, interpolation=4) 70 | ]) 71 | 72 | augmented = aug(image=image, mask=mask) 73 | return augmented['image'], augmented['mask'] 74 | 75 | def get_mask_weight(self, mask): 76 | mask_ = cv2.erode(mask, kernel=np.ones((8,8),np.uint8), iterations=1) 77 | mask_ = mask-mask_ 78 | return mask_ + 1 79 | 80 | def __getitem__(self, index): 81 | 82 | file_id = index 83 | if type(index) != str: 84 | file_id = self.file_list[index] 85 | 86 | image_folder = self.root_path 87 | image_path = os.path.join(image_folder, file_id + ".jpg") 88 | 89 | mask_folder = self.root_path[:-1] + "_mask/" 90 | mask_path = os.path.join(mask_folder, file_id + ".png") 91 | 92 | if self.as_torch_tensor: 93 | 94 | if not self.is_test: 95 | image = cv2.imread(str(image_path)) 96 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 97 | mask = cv2.imread(str(mask_path)) 98 | mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY) 99 | 100 | #resize to 320x256 101 | image = cv2.resize(image, (256, 320), interpolation=cv2.INTER_LANCZOS4) 102 | mask = cv2.resize(mask, (256, 320), 
interpolation=cv2.INTER_LANCZOS4) 103 | 104 | if self.augmentations: 105 | if file_id not in self.been: 106 | self.been.append(file_id) 107 | else: 108 | image, mask = self.transform(image, mask) 109 | 110 | mask = mask // 255 111 | mask = mask[:, :, np.newaxis] 112 | if self.mask_weight: 113 | mask_w = self.get_mask_weight(np.squeeze(mask)) 114 | else: 115 | mask_w = np.ones((mask.shape[:-1])) 116 | mask_w = mask_w[:, :, np.newaxis] 117 | 118 | mask = torch.from_numpy(np.transpose(mask, (2, 0, 1)).astype('float32')) 119 | mask_w = torch.from_numpy(np.transpose(mask_w, (2, 0, 1)).astype('float32')) 120 | image = self.norm(image) 121 | return image, mask, mask_w 122 | 123 | else: 124 | image = cv2.imread(str(image_path)) 125 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 126 | image = cv2.resize(image, (256, 320), interpolation=cv2.INTER_LANCZOS4) 127 | image = self.norm(image) 128 | return image 129 | 130 | else: 131 | image = cv2.imread(str(image_path)) 132 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 133 | image = np.array(image, dtype=np.uint8) 134 | if not self.is_test: 135 | mask = cv2.imread(str(mask_path)) 136 | mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY) 137 | if self.augmentations: 138 | if file_id not in self.been: 139 | self.been.append(file_id) 140 | else: 141 | image, mask = self.transform(image, mask) 142 | return image, mask 143 | 144 | else: 145 | if self.augmentations: 146 | if file_id not in self.been: 147 | self.been.append(file_id) 148 | else: 149 | image = self.transform(image) 150 | return image 151 | 152 | def save_checkpoint(checkpoint_path, model, optimizer): 153 | state = {'state_dict': model.state_dict(), 154 | 'optimizer' : optimizer.state_dict()} 155 | torch.save(state, checkpoint_path) 156 | print('model saved to %s' % checkpoint_path) 157 | 158 | def load_checkpoint(checkpoint_path, model, optimizer, cpu): 159 | if cpu: 160 | state = torch.load(checkpoint_path, map_location=lambda storage, loc: storage) 161 | else: 162 | state = torch.load(checkpoint_path) 163 | model.load_state_dict(state['state_dict']) 164 | if optimizer: 165 | optimizer.load_state_dict(state['optimizer']) 166 | print('model loaded from %s' % checkpoint_path) 167 | 168 | def jaccard(intersection, union, eps=1e-15): 169 | return (intersection) / (union - intersection + eps) 170 | 171 | def dice(intersection, union, eps=1e-15, smooth=1.): 172 | return (2. 
* intersection + smooth) / (union + smooth + eps) 173 | 174 | class BCESoftJaccardDice: 175 | 176 | def __init__(self, bce_weight=0.5, mode="dice", eps=1e-7, weight=None, smooth=1.): 177 | self.nll_loss = torch.nn.BCEWithLogitsLoss(weight=weight) 178 | self.bce_weight = bce_weight 179 | self.eps = eps 180 | self.mode = mode 181 | self.smooth = smooth 182 | 183 | def __call__(self, outputs, targets): 184 | loss = self.bce_weight * self.nll_loss(outputs, targets) 185 | 186 | if self.bce_weight < 1.: 187 | targets = (targets == 1).float() 188 | outputs = torch.sigmoid(outputs) 189 | intersection = (outputs * targets).sum() 190 | union = outputs.sum() + targets.sum() 191 | if self.mode == "dice": 192 | score = dice(intersection, union, self.eps, self.smooth) 193 | elif self.mode == "jaccard": 194 | score = jaccard(intersection, union, self.eps) 195 | loss -= (1 - self.bce_weight) * torch.log(score) 196 | return loss 197 | 198 | def get_metric(pred, targets): 199 | batch_size = targets.shape[0] 200 | metric = [] 201 | for batch in range(batch_size): 202 | t, p = targets[batch].squeeze(1), pred[batch].squeeze(1) 203 | if np.count_nonzero(t) == 0 and np.count_nonzero(p) > 0: 204 | metric.append(0) 205 | continue 206 | if np.count_nonzero(t) == 0 and np.count_nonzero(p) == 0: 207 | metric.append(1) 208 | continue 209 | 210 | t = (t == 1).float() 211 | intersection = (p * t).sum() 212 | union = p.sum() + t.sum() 213 | m = dice(intersection, union, eps=1e-15) 214 | metric.append(m) 215 | return np.mean(metric) 216 | 217 | class Trainer: 218 | 219 | def __init__(self, path=None, gpu=-1, **kwargs): 220 | 221 | if path is not None: 222 | kwargs = pickle.load(open(path+"/model_params.pickle.dat", "rb")) 223 | kwargs["device_idx"] = gpu 224 | kwargs["pretrained"], kwargs["reset"] = False, False 225 | self.path = path 226 | else: 227 | self.directory = kwargs["directory"] 228 | self.path = os.path.join(self.directory, self.model_name) 229 | 230 | self.model_name = kwargs["model_name"] 231 | self.model_type = kwargs["model"].lower() 232 | self.device_idx = kwargs["device_idx"] 233 | self.cpu = True if self.device_idx < 0 else False 234 | self.ADAM = kwargs["ADAM"] 235 | self.pretrained = kwargs["pretrained"] 236 | self.norm = transforms.Compose([ 237 | transforms.ToTensor(), 238 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 239 | std=[0.229, 0.224, 0.225]) 240 | ]) 241 | 242 | self.cp_counter_loss, self.cp_counter_metric = 0, 0 243 | self.max_lr = .5 244 | 245 | net_init_params = {k:v for k, v in kwargs.items() 246 | if k in ["Dropout", "pretrained", "num_classes", "num_filters"] 247 | } 248 | 249 | if self.model_type == "mobilenetv2": 250 | self.initial_model = UnetMobilenetV2(**net_init_params) 251 | else: 252 | net_init_params["model"] = self.model_type 253 | self.initial_model = UnetResNet(**net_init_params) 254 | 255 | if kwargs["reset"]: 256 | try: 257 | shutil.rmtree(self.path) 258 | except: 259 | pass 260 | os.mkdir(self.path) 261 | kwargs["reset"] = False 262 | pickle.dump(kwargs, open(self.path+"/model_params.pickle.dat", "wb")) 263 | else: 264 | self.model = self.get_model(self.initial_model) 265 | if self.ADAM: 266 | self.optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-4) 267 | else: 268 | self.optimizer = torch.optim.SGD(self.model.parameters(), lr=1e-4, momentum=0.9, nesterov=True) 269 | 270 | def dfs_freeze(self, model): 271 | for name, child in model.named_children(): 272 | for param in child.parameters(): 273 | param.requires_grad = False if self.freeze_encoder else True 
274 | self.dfs_freeze(child) 275 | 276 | def get_model(self, model): 277 | model = model.train() 278 | if self.cpu: 279 | return model.cpu() 280 | return model.cuda(self.device_idx) 281 | 282 | def LR_finder(self, dataset, **kwargs): 283 | 284 | max_lr = kwargs["max_lr"] 285 | batch_size = kwargs["batch_size"] 286 | learning_rate = kwargs["learning_rate"] 287 | bce_loss_weight = kwargs["bce_loss_weight"] 288 | loss_growth_trsh = kwargs["loss_growth_trsh"] 289 | loss_window = kwargs["loss_window"] 290 | wd = kwargs["wd"] 291 | alpha = kwargs["alpha"] 292 | 293 | torch.cuda.empty_cache() 294 | dataset.clear_buff() 295 | self.model = self.get_model(self.initial_model) 296 | 297 | iterations = len(dataset) // batch_size 298 | it = 0 299 | lr_mult = (max_lr/learning_rate)**(1/iterations) 300 | 301 | if self.ADAM: 302 | optimizer = torch.optim.Adam(self.model.parameters(), lr=learning_rate) 303 | else: 304 | optimizer = torch.optim.SGD(self.model.parameters(), lr=learning_rate, 305 | momentum=0.9, nesterov=True) 306 | 307 | #max LR search 308 | print(" [INFO] Start max. learning rate search... ") 309 | min_loss, self.lr_finder_losses = (np.inf, learning_rate), [[], []] 310 | for image, mask, mask_w in tqdm(data.DataLoader(dataset, batch_size = batch_size, shuffle = True, num_workers=0)): 311 | image = image.type(torch.FloatTensor).cuda(self.device_idx) 312 | 313 | it += 1 314 | current_lr = learning_rate * (lr_mult**it) 315 | 316 | y_pred = self.model(Variable(image)) 317 | if self.model_type == "mobilenetv2": 318 | y_pred = nn.functional.interpolate(y_pred, scale_factor=2, mode='bilinear', align_corners=True) 319 | 320 | loss_fn = BCESoftJaccardDice(bce_weight=bce_loss_weight, 321 | weight=mask_w.cuda(self.device_idx), mode="dice", eps=1.) 322 | loss = loss_fn(y_pred, Variable(mask.cuda(self.device_idx))) 323 | 324 | optimizer.zero_grad() 325 | loss.backward() 326 | 327 | #adjust learning rate and weights decay 328 | for param_group in optimizer.param_groups: 329 | param_group['lr'] = current_lr 330 | if wd: 331 | for param in param_group['params']: 332 | param.data = param.data.add(-wd * param_group['lr'], param.data) 333 | 334 | optimizer.step() 335 | 336 | if it > 1: 337 | current_loss = alpha * loss.item() + (1 - alpha) * current_loss 338 | else: 339 | current_loss = loss.item() 340 | 341 | self.lr_finder_losses[0].append(current_loss) 342 | self.lr_finder_losses[1].append(current_lr) 343 | 344 | if current_loss < min_loss[0]: 345 | min_loss = (current_loss, current_lr) 346 | 347 | if it >= loss_window: 348 | if (current_loss - min_loss[0]) / min_loss[0] >= loss_growth_trsh: 349 | break 350 | 351 | self.max_lr = round(min_loss[1], 5) 352 | print(" [INFO] max. 
lr = %.5f " % self.max_lr) 353 | 354 | def show_lr_finder_out(self, save_only=False): 355 | if not save_only: 356 | plt.show(block=False) 357 | plt.semilogx(self.lr_finder_losses[1], self.lr_finder_losses[0]) 358 | plt.axvline(self.max_lr, c="gray") 359 | plt.savefig(self.path + '/lr_finder_out.png') 360 | 361 | def fit(self, dataset, dataset_val, **kwargs): 362 | 363 | epoch = kwargs["epoch"] 364 | learning_rate = kwargs["learning_rate"] 365 | batch_size = kwargs["batch_size"] 366 | bce_loss_weight = kwargs["bce_loss_weight"] 367 | CLR = kwargs["CLR"] 368 | wd = kwargs["wd"] 369 | reduce_lr_patience = kwargs["reduce_lr_patience"] 370 | reduce_lr_factor = kwargs["reduce_lr_factor"] 371 | max_lr_decay = kwargs["max_lr_decay"] 372 | self.freeze_encoder = kwargs["freeze_encoder"] 373 | 374 | torch.cuda.empty_cache() 375 | self.model = self.get_model(self.initial_model) 376 | 377 | if self.pretrained and self.freeze_encoder and self.model_type != "mobilenetv2": 378 | self.dfs_freeze(self.model.conv1) 379 | self.dfs_freeze(self.model.conv2) 380 | self.dfs_freeze(self.model.conv3) 381 | self.dfs_freeze(self.model.conv4) 382 | self.dfs_freeze(self.model.conv5) 383 | 384 | if self.ADAM: 385 | self.optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.model.parameters()), 386 | lr=learning_rate) 387 | else: 388 | self.optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, self.model.parameters()), 389 | lr=learning_rate, momentum=0.9, nesterov=True) 390 | 391 | max_lr = self.max_lr 392 | iterations = len(dataset) // batch_size 393 | if abs(CLR) == 1: 394 | iterations *= epoch 395 | lr_mult = (max_lr/learning_rate)**(1/iterations) 396 | current_rate = learning_rate 397 | 398 | checkpoint_metric, checkpoint_loss, it, k, cooldown = -np.inf, np.inf, 0, 1, 0 399 | self.history = {"loss":{"train":[], "test":[]}, "metric":{"train":[], "test":[]}} 400 | 401 | for e in range(epoch): 402 | torch.cuda.empty_cache() 403 | self.model.train() 404 | 405 | if e >= 2 and self.freeze_encoder and self.model_type != "mobilenetv2": 406 | self.freeze_encoder = False 407 | self.dfs_freeze(self.model.conv1) 408 | self.dfs_freeze(self.model.conv2) 409 | self.dfs_freeze(self.model.conv3) 410 | self.dfs_freeze(self.model.conv4) 411 | self.dfs_freeze(self.model.conv5) 412 | 413 | if self.ADAM: 414 | self.optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.model.parameters()), 415 | lr=current_rate) 416 | else: 417 | self.optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, self.model.parameters()), 418 | lr=current_rate, momentum=0.9, nesterov=True) 419 | 420 | if reduce_lr_patience and reduce_lr_factor: 421 | if not np.isinf(checkpoint_loss): 422 | if self.history["loss"]["test"][-1] >= checkpoint_loss: 423 | cooldown += 1 424 | 425 | if cooldown == reduce_lr_patience: 426 | learning_rate *= reduce_lr_factor; max_lr *= reduce_lr_factor 427 | lr_mult = (max_lr/learning_rate)**(1/iterations) 428 | cooldown = 0 429 | print(" [INFO] Learning rate has been reduced to: %.7f " % learning_rate) 430 | 431 | dataset.clear_buff() 432 | min_train_loss, train_loss, train_metric = np.inf, [], [] 433 | for image, mask, mask_w in tqdm(data.DataLoader(dataset, batch_size = batch_size, shuffle = True, num_workers=0)): 434 | image = image.type(torch.FloatTensor).cuda(self.device_idx) 435 | 436 | if abs(CLR): 437 | it += 1; exp = it 438 | if CLR > 0: 439 | exp = iterations*k - it + 1 440 | current_rate = learning_rate * (lr_mult**exp) 441 | 442 | if abs(CLR) > 1: 443 | if iterations*k / it == 1: 444 
| it = 0; k *= abs(CLR) 445 | if max_lr_decay < 1.: 446 | max_lr *= max_lr_decay 447 | lr_mult = (max_lr/learning_rate)**(1/(iterations*k)) 448 | 449 | #re-init. optimzer to reset internal state 450 | if self.ADAM: 451 | self.optimizer = torch.optim.Adam(self.model.parameters(), lr=current_rate) 452 | else: 453 | self.optimizer = torch.optim.SGD(self.model.parameters(), 454 | lr=current_rate, momentum=0.9, nesterov=True) 455 | 456 | y_pred = self.model(Variable(image)) 457 | if self.model_type == "mobilenetv2": 458 | y_pred = nn.functional.interpolate(y_pred, scale_factor=2, mode='bilinear', align_corners=True) 459 | 460 | loss_fn = BCESoftJaccardDice(bce_weight=bce_loss_weight, 461 | weight=mask_w.cuda(self.device_idx), mode="dice") 462 | loss = loss_fn(y_pred, Variable(mask.cuda(self.device_idx))) 463 | 464 | self.optimizer.zero_grad() 465 | loss.backward() 466 | 467 | #adjust learning rate and weights decay 468 | for param_group in self.optimizer.param_groups: 469 | try: param_group['lr'] = current_lr 470 | except: pass 471 | if wd: 472 | for param in param_group['params']: 473 | param.data = param.data.add(-wd * param_group['lr'], param.data) 474 | 475 | self.optimizer.step() 476 | if loss.item() < min_train_loss: 477 | min_train_loss = loss.item() 478 | train_loss.append(loss.item()) 479 | train_metric.append(get_metric((y_pred.cpu() > 0.).float(), mask)) 480 | 481 | del y_pred; del image; del mask_w; del mask; del loss 482 | 483 | dataset_val.clear_buff() 484 | torch.cuda.empty_cache() 485 | self.model.eval() 486 | val_loss, val_metric = [], [] 487 | for image, mask, mask_w in data.DataLoader(dataset_val, batch_size = batch_size // 2, shuffle = False, num_workers=0): 488 | image = image.cuda(self.device_idx) 489 | 490 | y_pred = self.model(Variable(image)) 491 | if self.model_type == "mobilenetv2": 492 | y_pred = nn.functional.interpolate(y_pred, scale_factor=2, mode='bilinear', align_corners=True) 493 | 494 | loss_fn = BCESoftJaccardDice(bce_weight=bce_loss_weight, 495 | weight=mask_w.cuda(self.device_idx), mode="dice", eps=1.) 
496 | loss = loss_fn(y_pred, Variable(mask.cuda(self.device_idx))) 497 | 498 | val_loss.append(loss.item()) 499 | val_metric.append(get_metric((y_pred.cpu() > 0.).float(), mask)) 500 | 501 | del y_pred; del image; del mask_w; del mask; del loss 502 | 503 | train_loss, train_metric, val_loss, val_metric = \ 504 | np.mean(train_loss), np.mean(train_metric), np.mean(val_loss), np.mean(val_metric) 505 | 506 | if val_loss < checkpoint_loss: 507 | save_checkpoint(self.path+'/%s_checkpoint_loss.pth' % (self.model_name), self.model, self.optimizer) 508 | checkpoint_loss = val_loss 509 | 510 | if val_metric > checkpoint_metric: 511 | save_checkpoint(self.path+'/%s_checkpoint_metric.pth' % (self.model_name), self.model, self.optimizer) 512 | checkpoint_metric = val_metric 513 | 514 | self.history["loss"]["train"].append(train_loss) 515 | self.history["loss"]["test"].append(val_loss) 516 | self.history["metric"]["train"].append(train_metric) 517 | self.history["metric"]["test"].append(val_metric) 518 | 519 | message = "Epoch: %d, Train loss: %.3f, Train metric: %.3f, Val loss: %.3f, Val metric: %.3f" % ( 520 | e, train_loss, train_metric, val_loss, val_metric) 521 | print(message); os.system("echo " + message) 522 | 523 | self.current_epoch = e 524 | save_checkpoint(self.path+'/last_checkpoint.pth', self.model, self.optimizer) 525 | 526 | pickle.dump(self.history, open(self.path+'/history.pickle.dat', 'wb')) 527 | 528 | def plot_trainer_history(self, mode="metric", save_only=False): 529 | if not save_only: 530 | plt.show(block=False) 531 | plt.plot(self.history[mode]["train"], label="train") 532 | plt.plot(self.history[mode]["test"], label="val") 533 | plt.xlabel("epoch") 534 | plt.ylabel(mode) 535 | plt.grid(True) 536 | plt.legend(loc="best") 537 | plt.savefig(self.path + '/%s_history.png' % mode) 538 | 539 | def load_state(self, path=None, mode="metric", load_optimizer=True): 540 | if load_optimizer: load_optimizer = self.optimizer 541 | if path is None: 542 | path = self.path+'/%s_checkpoint_%s.pth' % (self.model_name, mode) 543 | load_checkpoint(path, self.model, load_optimizer, self.cpu) 544 | 545 | def predict_mask(self, imgs, biggest_side=None, denoise_borders=False): 546 | if not self.cpu: 547 | torch.cuda.empty_cache() 548 | if imgs.ndim < 4: 549 | imgs = np.expand_dims(imgs, axis=0) 550 | l, h, w, c = imgs.shape 551 | w_n, h_n = w, h 552 | if biggest_side is not None: 553 | w_n = int(w/h * min(biggest_side, h)) 554 | h_n = min(biggest_side, h) 555 | 556 | wd, hd = w_n % 32, h_n % 32 557 | if wd != 0: w_n += 32 - wd 558 | if hd != 0: h_n += 32 - hd 559 | 560 | kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5,5)) 561 | all_predictions = [] 562 | for i in range(imgs.shape[0]): 563 | img = self.norm(cv2.resize(imgs[i], (w_n, h_n), interpolation=cv2.INTER_LANCZOS4)) 564 | img = img.unsqueeze_(0) 565 | if not self.cpu: 566 | img = img.type(torch.FloatTensor).cuda(self.device_idx) 567 | else: 568 | img = img.type(torch.FloatTensor) 569 | output = self.model(Variable(img)) 570 | if self.model_type == "mobilenetv2": 571 | output = nn.functional.interpolate(output, scale_factor=2, mode='bilinear', align_corners=True) 572 | output = torch.sigmoid(output) 573 | output = output.cpu().data.numpy() 574 | y_pred = np.squeeze(output[0]) 575 | y_pred = remove_small_holes(remove_small_objects(y_pred > .3)) 576 | y_pred = (y_pred * 255).astype(np.uint8) 577 | y_pred = cv2.resize(y_pred, (w, h), interpolation=cv2.INTER_LANCZOS4) 578 | 579 | _,alpha = 
cv2.threshold(y_pred.astype(np.uint8),0,255,cv2.THRESH_BINARY) 580 | b, g, r = cv2.split(imgs[i]) 581 | bgra = [r,g,b, alpha] 582 | y_pred = cv2.merge(bgra,4) 583 | if denoise_borders: 584 | #denoise mask borders 585 | y_pred[:, :, -1] = cv2.morphologyEx(y_pred[:, :, -1], cv2.MORPH_OPEN, kernel) 586 | all_predictions.append(y_pred) 587 | return all_predictions 588 | 589 | def split_video(filename, frame_rate=12): 590 | vidcap = cv2.VideoCapture(filename) 591 | frames = [] 592 | succ, frame = vidcap.read() 593 | h, w = frame.shape[:2] 594 | center = (w / 2, h / 2) 595 | while succ: 596 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 597 | frame = np.transpose(frame[:, ::-1, :], axes=[1,0,2]) 598 | frames.append(frame) 599 | succ, frame = vidcap.read() 600 | return np.array(frames).astype(np.uint8)[::24 // frame_rate] 601 | 602 | def factorial(n): 603 | if n == 0: 604 | return 1 605 | else: 606 | return n * factorial(n-1) 607 | 608 | def n_unique_permuts(n, r): 609 | return factorial(n) / (factorial(r)*factorial(n-r)) 610 | 611 | def save_images(out, path="./data/gif_test"): 612 | letters = string.ascii_lowercase 613 | r = 0; n_uniques = 0 614 | while n_uniques < len(out): 615 | r += 1 616 | n_uniques = n_unique_permuts(len(letters), r) 617 | names = list(itertools.combinations(letters, r)) 618 | for im, fname in zip(out, names[:len(out)]): 619 | cv2.imwrite(path+"/%s.png" % ("".join(fname)), im) --------------------------------------------------------------------------------