├── LICENSE ├── README.md ├── arch_resnet38.py ├── base_class.py ├── data ├── trainaug_id.txt └── trainaug_labels.txt ├── figure └── ssdd_module.png ├── imutils.py ├── main_ssdd.py ├── network.py ├── precompute_sssdd.py ├── prepare_labels ├── README.md ├── make_aff_labels.py ├── make_cam_labels.py ├── network │ ├── __pycache__ │ │ ├── resnet38_aff.cpython-36.pyc │ │ ├── resnet38_cls.cpython-36.pyc │ │ ├── resnet38d.cpython-36.pyc │ │ ├── vgg16_aff.cpython-36.pyc │ │ ├── vgg16_cls.cpython-36.pyc │ │ └── vgg16d.cpython-36.pyc │ ├── resnet38_aff.py │ ├── resnet38_cls.py │ ├── resnet38d.py │ ├── vgg16_20M.prototxt │ ├── vgg16_aff.py │ ├── vgg16_cls.py │ └── vgg16d.py ├── tool │ ├── imutils.py │ ├── pyutils.py │ └── torchutils.py └── voc12 │ ├── __pycache__ │ └── data.cpython-36.pyc │ ├── cls_labels.npy │ ├── data.py │ ├── make_cls_labels.py │ ├── test.txt │ ├── train.txt │ ├── train_aug.txt │ └── val.txt ├── pretrained_models └── tmp.txt ├── script ├── dssdd.html ├── gen_html.py ├── gen_html_dssdd.py ├── gen_html_val.py └── val.html ├── ssdd_function.py ├── ssdd_val.py ├── train_dssdd.py ├── train_sssdd.py └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 shimoda-uec 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SSDD:Self-Supervised Difference Detection 2 | By Watal Shimoda and Keiji Yanai. 3 | [paper](https://arxiv.org/abs/1911.01370) 4 | 5 | ## Description 6 | This repository contains the codes for "Self-Supervised Difference Detection for Weakly Supervised Segmentation". 7 | It has been published at ICCV2019. 8 | 9 | We define an inputs of the pair of segmentation labels as Knowledge and Advice. 10 | The proposed method integrates the pair of segmentation labels by Self-Supervised Difference Detection(SSDD) module. 11 | In the paper, first, we integrate the segmentation labels of Pixel-level Semantic Affinity(PSA) and CRF applied segmentation masks by 12 | considering the labels as Knowledge and its CRF results as Advice. 13 | We denote this approach as static SSDD module. 14 | Furthermore, we develop the labels obtained in the previous step using two SSDD modules, 15 | and we train a segmentation model with the modules in an end-to-end manner. 16 | We denote this approach as dynamic SSDD module. 17 | In this dynamic module, we intend to adapt the proposed method to an iterative training approach proposed by Wei et al. : [arxiv](https://arxiv.org/abs/1509.03150). 18 | 19 | 20 | 21 | ## Visualization 22 | We provide the progress of training and inference on the validation set re-produced with this repository. 23 | The progress of training for dynamic ssdd module: [html](http://mm.cs.uec.ac.jp/shimoda-k/space0/wseg/ssdd/git/ssdd/script/dssdd.html). 24 | The inference on the validation: [html](http://mm.cs.uec.ac.jp/shimoda-k/space0/wseg/ssdd/git/ssdd/script/val.html). 25 | (65.4 in validation. There is only difference in the option of interpolation between this repository and conference version.) 26 | We also provide trained models, pre-computed results and evaluation results: [run.zip](http://mm.cs.uec.ac.jp/shimoda-k/space0/wseg/ssdd/git/ssdd/run.zip). 27 | 28 | ## Requirements 29 | Python 3.5 , Pytorch >= 0.4.1, [Pydensecrf](https://github.com/lucasb-eyer/pydensecrf) 30 | 31 | ## Advance preparation 32 | We assume that the root directory of Pascal VOC is located in the same directory with this name: "voc_root". 33 | So set a symblic link for the root directory. 34 | ``` 35 | ln -s "your_voc_root" voc_root 36 | ``` 37 | 38 | To train both of static ssdd module and dynamic ssdd module, seed labels are required. 39 | Please check this directory: [preparing_labels](https://github.com/shimoda-uec/ssdd/tree/master/prepare_labels). 40 | 41 | ## Usage 42 | First, train static SSDD module by following codes. (around half day) 43 | ``` 44 | python main_ssdd.py --mode=0 45 | ``` 46 | 47 | Second, compute the probability maps of the difference detection in advance. (around 1 hour) 48 | ``` 49 | python main_ssdd.py --mode=1 50 | ``` 51 | 52 | Third, train dynamic SSDD module. (around one day) 53 | ``` 54 | python main_ssdd.py --mode=2 55 | ``` 56 | 57 | After the training of the dynamic SSDD module, you can test your trained model. 58 | ``` 59 | python main_ssdd.py --mode=3 60 | ``` 61 | 62 | ## License and Citation 63 | Please cite our paper if it helps your research: 64 | ``` 65 | @inproceedings{shimodaICCV19, 66 | Author = {Wataru Shimoda and Keiji Yanai}, 67 | Title = {Self-Supervised Difference Detection for Weakly-Supervised Segmentation}, 68 | Booktitle = {Proceedings of the IEEE International Conference on Computer Vision}, 69 | Year = {2019} 70 | } 71 | ``` 72 | 73 | ## Acknowledgment 74 | Many codes of this repository have been derived from [PSA](https://github.com/jiwoon-ahn/psa). 75 | -------------------------------------------------------------------------------- /arch_resnet38.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import numpy as np 4 | 5 | import torch.nn.functional as F 6 | 7 | class ResBlock(nn.Module): 8 | def __init__(self, in_channels, mid_channels, out_channels, stride=1, first_dilation=None, dilation=1): 9 | super(ResBlock, self).__init__() 10 | 11 | self.same_shape = (in_channels == out_channels and stride == 1) 12 | 13 | if first_dilation == None: first_dilation = dilation 14 | 15 | self.bn_branch2a = nn.BatchNorm2d(in_channels) 16 | 17 | self.conv_branch2a = nn.Conv2d(in_channels, mid_channels, 3, stride, 18 | padding=first_dilation, dilation=first_dilation, bias=False) 19 | 20 | self.bn_branch2b1 = nn.BatchNorm2d(mid_channels) 21 | 22 | self.conv_branch2b1 = nn.Conv2d(mid_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False) 23 | 24 | if not self.same_shape: 25 | self.conv_branch1 = nn.Conv2d(in_channels, out_channels, 1, stride, bias=False) 26 | 27 | def forward(self, x, get_x_bn_relu=False): 28 | 29 | branch2 = self.bn_branch2a(x) 30 | branch2 = F.relu(branch2) 31 | 32 | x_bn_relu = branch2 33 | 34 | if not self.same_shape: 35 | branch1 = self.conv_branch1(branch2) 36 | else: 37 | branch1 = x 38 | 39 | branch2 = self.conv_branch2a(branch2) 40 | branch2 = self.bn_branch2b1(branch2) 41 | branch2 = F.relu(branch2) 42 | branch2 = self.conv_branch2b1(branch2) 43 | 44 | x = branch1 + branch2 45 | 46 | if get_x_bn_relu: 47 | return x, x_bn_relu 48 | 49 | return x 50 | 51 | def __call__(self, x, get_x_bn_relu=False): 52 | return self.forward(x, get_x_bn_relu=get_x_bn_relu) 53 | 54 | class ResBlock_bot(nn.Module): 55 | def __init__(self, in_channels, out_channels, stride=1, dilation=1, dropout=0.): 56 | super(ResBlock_bot, self).__init__() 57 | 58 | self.same_shape = (in_channels == out_channels and stride == 1) 59 | 60 | self.bn_branch2a = nn.BatchNorm2d(in_channels) 61 | self.conv_branch2a = nn.Conv2d(in_channels, out_channels//4, 1, stride, bias=False) 62 | 63 | self.bn_branch2b1 = nn.BatchNorm2d(out_channels//4) 64 | self.dropout_2b1 = torch.nn.Dropout2d(dropout) 65 | self.conv_branch2b1 = nn.Conv2d(out_channels//4, out_channels//2, 3, padding=dilation, dilation=dilation, bias=False) 66 | 67 | self.bn_branch2b2 = nn.BatchNorm2d(out_channels//2) 68 | self.dropout_2b2 = torch.nn.Dropout2d(dropout) 69 | self.conv_branch2b2 = nn.Conv2d(out_channels//2, out_channels, 1, bias=False) 70 | 71 | if not self.same_shape: 72 | self.conv_branch1 = nn.Conv2d(in_channels, out_channels, 1, stride, bias=False) 73 | 74 | def forward(self, x, get_x_bn_relu=False): 75 | 76 | branch2 = self.bn_branch2a(x) 77 | branch2 = F.relu(branch2) 78 | x_bn_relu = branch2 79 | 80 | branch1 = self.conv_branch1(branch2) 81 | 82 | branch2 = self.conv_branch2a(branch2) 83 | 84 | branch2 = self.bn_branch2b1(branch2) 85 | branch2 = F.relu(branch2) 86 | branch2 = self.dropout_2b1(branch2) 87 | branch2 = self.conv_branch2b1(branch2) 88 | 89 | branch2 = self.bn_branch2b2(branch2) 90 | branch2 = F.relu(branch2) 91 | branch2 = self.dropout_2b2(branch2) 92 | branch2 = self.conv_branch2b2(branch2) 93 | 94 | x = branch1 + branch2 95 | 96 | if get_x_bn_relu: 97 | return x, x_bn_relu 98 | 99 | return x 100 | 101 | def __call__(self, x, get_x_bn_relu=False): 102 | return self.forward(x, get_x_bn_relu=get_x_bn_relu) 103 | 104 | class Normalize(): 105 | def __init__(self, mean = (0.485, 0.456, 0.406), std = (0.229, 0.224, 0.225)): 106 | 107 | self.mean = mean 108 | self.std = std 109 | 110 | def __call__(self, img): 111 | imgarr = np.asarray(img) 112 | proc_img = np.empty_like(imgarr, np.float32) 113 | 114 | proc_img[..., 0] = (imgarr[..., 0] / 255. - self.mean[0]) / self.std[0] 115 | proc_img[..., 1] = (imgarr[..., 1] / 255. - self.mean[1]) / self.std[1] 116 | proc_img[..., 2] = (imgarr[..., 2] / 255. - self.mean[2]) / self.std[2] 117 | 118 | return proc_img 119 | 120 | class Resnet38(nn.Module): 121 | def __init__(self, first_pool=False): 122 | super(Resnet38, self).__init__() 123 | 124 | self.conv1a = nn.Conv2d(3, 64, 3, padding=1, bias=False) 125 | self.firtst_pool=first_pool 126 | self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 127 | 128 | self.b2 = ResBlock(64, 128, 128, stride=2) 129 | self.b2_1 = ResBlock(128, 128, 128) 130 | self.b2_2 = ResBlock(128, 128, 128) 131 | 132 | self.b3 = ResBlock(128, 256, 256, stride=2) 133 | self.b3_1 = ResBlock(256, 256, 256) 134 | self.b3_2 = ResBlock(256, 256, 256) 135 | 136 | self.b4 = ResBlock(256, 512, 512, stride=2) 137 | self.b4_1 = ResBlock(512, 512, 512) 138 | self.b4_2 = ResBlock(512, 512, 512) 139 | self.b4_3 = ResBlock(512, 512, 512) 140 | self.b4_4 = ResBlock(512, 512, 512) 141 | self.b4_5 = ResBlock(512, 512, 512) 142 | 143 | self.b5 = ResBlock(512, 512, 1024, stride=1, first_dilation=1, dilation=2) 144 | self.b5_1 = ResBlock(1024, 512, 1024, dilation=2) 145 | self.b5_2 = ResBlock(1024, 512, 1024, dilation=2) 146 | 147 | self.b6 = ResBlock_bot(1024, 2048, stride=1, dilation=4, dropout=0.3) 148 | 149 | self.b7 = ResBlock_bot(2048, 4096, dilation=4, dropout=0.5) 150 | 151 | self.bn7 = nn.BatchNorm2d(4096) 152 | 153 | self.not_training = [self.conv1a] 154 | 155 | self.normalize = Normalize() 156 | 157 | def forward(self, x): 158 | x = self.conv1a(x) 159 | if self.firtst_pool==True: 160 | x = self.pool(x) 161 | x = self.b2(x) 162 | x = self.b2_1(x) 163 | x = self.b2_2(x) 164 | 165 | x, conv2 = self.b3(x, get_x_bn_relu=True) 166 | x = self.b3_1(x) 167 | x = self.b3_2(x) 168 | 169 | x, conv3 = self.b4(x, get_x_bn_relu=True) 170 | x = self.b4_1(x) 171 | x = self.b4_2(x) 172 | x = self.b4_3(x) 173 | x = self.b4_4(x) 174 | conv4 = self.b4_5(x) 175 | 176 | x, _ = self.b5(conv4, get_x_bn_relu=True) 177 | x = self.b5_1(x) 178 | conv5 = self.b5_2(x) 179 | 180 | x, _ = self.b6(conv5, get_x_bn_relu=True) 181 | 182 | x = self.b7(x) 183 | conv6 = F.relu(self.bn7(x)) 184 | 185 | return (conv2, conv3, conv4, conv5, conv6) 186 | 187 | def conv5(self, x): 188 | x, _ = self.b5(x, get_x_bn_relu=True) 189 | x = self.b5_1(x) 190 | conv5 = self.b5_2(x) 191 | 192 | x, _ = self.b6(conv5, get_x_bn_relu=True) 193 | 194 | x = self.b7(x) 195 | conv6 = F.relu(self.bn7(x)) 196 | return conv6 197 | 198 | def train(self, mode=True): 199 | super().train(mode) 200 | for layer in self.not_training: 201 | if isinstance(layer, torch.nn.Conv2d): 202 | layer.weight.requires_grad = False 203 | elif isinstance(layer, torch.nn.Module): 204 | for c in layer.children(): 205 | c.weight.requires_grad = False 206 | if c.bias is not None: 207 | c.bias.requires_grad = False 208 | return 209 | 210 | def convert_mxnet_to_torch(filename): 211 | import mxnet 212 | 213 | save_dict = mxnet.nd.load(filename) 214 | 215 | renamed_dict = dict() 216 | 217 | bn_param_mx_pt = {'beta': 'bias', 'gamma': 'weight', 'mean': 'running_mean', 'var': 'running_var'} 218 | 219 | for k, v in save_dict.items(): 220 | 221 | v = torch.from_numpy(v.asnumpy()) 222 | toks = k.split('_') 223 | 224 | if 'conv1a' in toks[0]: 225 | renamed_dict['conv1a.weight'] = v 226 | 227 | elif 'linear1000' in toks[0]: 228 | pass 229 | 230 | elif 'branch' in toks[1]: 231 | 232 | pt_name = [] 233 | 234 | if toks[0][-1] != 'a': 235 | pt_name.append('b' + toks[0][-3] + '_' + toks[0][-1]) 236 | else: 237 | pt_name.append('b' + toks[0][-2]) 238 | 239 | if 'res' in toks[0]: 240 | layer_type = 'conv' 241 | last_name = 'weight' 242 | 243 | else: # 'bn' in toks[0]: 244 | layer_type = 'bn' 245 | last_name = bn_param_mx_pt[toks[-1]] 246 | 247 | pt_name.append(layer_type + '_' + toks[1]) 248 | 249 | pt_name.append(last_name) 250 | 251 | torch_name = '.'.join(pt_name) 252 | renamed_dict[torch_name] = v 253 | 254 | else: 255 | last_name = bn_param_mx_pt[toks[-1]] 256 | renamed_dict['bn7.' + last_name] = v 257 | 258 | return renamed_dict 259 | 260 | -------------------------------------------------------------------------------- /base_class.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from arch_resnet38 import Resnet38 4 | import torch.nn.functional as F 5 | from torchvision import transforms 6 | import numpy as np 7 | import imutils 8 | import os 9 | import re 10 | 11 | class BaseModel(nn.Module): 12 | def initialize_weights(self): 13 | for m in self.modules(): 14 | if isinstance(m, nn.Conv2d): 15 | nn.init.kaiming_normal(m.weight) 16 | if m.bias is not None: 17 | m.bias.data.zero_() 18 | elif isinstance(m, nn.BatchNorm2d): 19 | m.weight.data.fill_(1) 20 | m.bias.data.zero_() 21 | def load_resnet38_weights(self, filepath): 22 | print(filepath, os.path.exists(filepath)) 23 | if os.path.exists(filepath): 24 | state_dict = torch.load(filepath) 25 | new_params = self.state_dict().copy() 26 | for i in new_params: 27 | i_parts = i.split('.') 28 | for i in state_dict: 29 | i_parts = i.split('.') 30 | if re.fullmatch('(fc8)', i_parts[0]): 31 | pass 32 | else: 33 | tmp=i_parts.copy() 34 | tmp.insert(0,'encoder') 35 | tmp='.'.join(tmp) 36 | new_params[tmp] = state_dict[i] 37 | self.load_state_dict(new_params) 38 | def train(self, mode=True): 39 | super().train(mode) 40 | for layer in self.modules(): 41 | if isinstance(layer, torch.nn.BatchNorm2d): 42 | layer.eval() 43 | layer.bias.requires_grad = False 44 | layer.weight.requires_grad = False 45 | 46 | class SegBaseModel(BaseModel): 47 | def __init__(self, config): 48 | super(SegBaseModel, self).__init__() 49 | self.config = config 50 | self.encoder=Resnet38() 51 | def get_crf(self, img_org, seg, gt_class_mlabel): 52 | img_org=img_org.data.cpu().numpy().astype(np.uint8) 53 | seg_crf=np.zeros((seg.shape[0],seg.shape[1],self.config.OUT_SHAPE[0],self.config.OUT_SHAPE[1])) 54 | for i in range(len(seg)): 55 | prob=[] 56 | for j in range(gt_class_mlabel.shape[1]): 57 | if gt_class_mlabel[i,j].item()==1: 58 | prob.append(seg[i,j:j+1]) 59 | prob=F.softmax(torch.cat(prob),dim=0).data.cpu().numpy() 60 | crf_map = imutils.crf_inference(img_org[i].copy(order='C'),prob,labels=prob.shape[0]) 61 | cnt=0 62 | for j in range(gt_class_mlabel.shape[1]): 63 | if gt_class_mlabel[i,j].item()==1: 64 | seg_crf[i][j]=crf_map[cnt] 65 | cnt += 1 66 | seg_crf=torch.from_numpy(seg_crf).cuda().float() 67 | _, seg_crf_mask=torch.max(seg_crf,1) 68 | return seg_crf, seg_crf_mask 69 | 70 | def get_seg(self, segment_module, x5, gt_class_mlabel): 71 | seg, seg_head = segment_module(x5) 72 | seg_prob=F.softmax(seg,dim=1) 73 | gt_class_mlabel_maps = gt_class_mlabel.view(gt_class_mlabel.shape[0],gt_class_mlabel.shape[1],1,1).repeat(1,1,seg.shape[2],seg.shape[3]) 74 | seg_prob=seg_prob*gt_class_mlabel_maps+gt_class_mlabel_maps*1e-4 75 | _,seg_mask = torch.max(seg_prob,1) 76 | return (seg, seg_prob, seg_mask, seg_head) 77 | 78 | class SSDDBaseModel(BaseModel): 79 | def __init__(self, config): 80 | super(SSDDBaseModel, self).__init__() 81 | self.config = config 82 | 83 | class PascalDataset(torch.utils.data.Dataset): 84 | def __init__(self, dataset, config): 85 | self.image_ids = np.copy(dataset.image_ids) 86 | self.dataset = dataset 87 | self.config = config 88 | self.mean=(0.485, 0.456, 0.406) 89 | self.std=(0.229, 0.224, 0.225) 90 | self.joint_transform_list=[ 91 | None, 92 | imutils.RandomHorizontalFlip(), 93 | imutils.RandomResizeLong(512, 832), 94 | imutils.RandomCrop(448), 95 | None, 96 | ] 97 | self.img_transform_list=[ 98 | transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1), 99 | np.asarray, 100 | None, 101 | imutils.Normalize(mean = self.mean, std = self.std), 102 | imutils.HWC_to_CHW 103 | ] 104 | def img_label_resize(self, inputs): 105 | for joint_transform, img_transform in zip(self.joint_transform_list, self.img_transform_list): 106 | img_norm = inputs[0] 107 | if img_transform: 108 | img_norm = img_transform(img_norm) 109 | inputs[0]=img_norm 110 | if joint_transform: 111 | outputs = joint_transform(inputs) 112 | inputs=outputs 113 | return inputs 114 | def get_prob_label(self, prob, mlabel): 115 | # prob shape [HxWxC] 116 | # mlabel shape [C] 117 | prob_label=np.zeros((prob.shape[0],prob.shape[1],mlabel.shape[0])) 118 | cnt=0 119 | for i in range(0,mlabel.shape[0]): 120 | if mlabel[i]==1: 121 | prob_label[:,:,i]=prob[:,:,cnt] 122 | cnt+=1 123 | return prob_label 124 | def __len__(self): 125 | return self.image_ids.shape[0] 126 | 127 | -------------------------------------------------------------------------------- /figure/ssdd_module.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shimoda-uec/ssdd/564c3e08fae7a158516cdbd9f3599a74dc748aff/figure/ssdd_module.png -------------------------------------------------------------------------------- /imutils.py: -------------------------------------------------------------------------------- 1 | import PIL.Image 2 | import random 3 | import numpy as np 4 | import cv2 5 | 6 | class RandomHorizontalFlip(): 7 | def __init__(self): 8 | return 9 | 10 | def __call__(self, inputs): 11 | if bool(random.getrandbits(1)): 12 | outputs=[] 13 | for inp in inputs: 14 | out = np.fliplr(inp).copy() 15 | outputs.append(out) 16 | return outputs 17 | else: 18 | return inputs 19 | 20 | 21 | class RandomResizeLong(): 22 | def __init__(self, min_long, max_long): 23 | self.min_long = min_long 24 | self.max_long = max_long 25 | def __call__(self, inputs): 26 | img=inputs[0] 27 | target_long = random.randint(self.min_long, self.max_long) 28 | #w, h = img.size 29 | h, w, c = img.shape 30 | target_shape = (target_long, target_long) 31 | """ 32 | if w > h: 33 | target_shape = (int(round(w * target_long / h)), target_long) 34 | else: 35 | target_shape = (target_long, int(round(h * target_long / w))) 36 | """ 37 | outputs=[] 38 | for inp in inputs: 39 | out = cv2.resize(inp, target_shape) 40 | if len(out.shape)==2: 41 | out=np.expand_dims(out,2) 42 | outputs.append(out) 43 | return outputs 44 | 45 | class RandomCrop(): 46 | def __init__(self, cropsize): 47 | self.cropsize = cropsize 48 | def __call__(self, inputs): 49 | imgarr = np.concatenate(inputs, axis=-1) 50 | h, w, c = imgarr.shape 51 | ch = min(self.cropsize, h) 52 | cw = min(self.cropsize, w) 53 | w_space = w - self.cropsize 54 | h_space = h - self.cropsize 55 | if w_space > 0: 56 | cont_left = 0 57 | img_left = random.randrange(w_space+1) 58 | else: 59 | cont_left = random.randrange(-w_space+1) 60 | img_left = 0 61 | if h_space > 0: 62 | cont_top = 0 63 | img_top = random.randrange(h_space+1) 64 | else: 65 | cont_top = random.randrange(-h_space+1) 66 | img_top = 0 67 | 68 | outputs=[] 69 | for inp in inputs: 70 | container = np.zeros((self.cropsize, self.cropsize, inp.shape[-1]), np.float32) 71 | container[cont_top:cont_top+ch, cont_left:cont_left+cw] = \ 72 | inp[img_top:img_top+ch, img_left:img_left+cw] 73 | outputs.append(container) 74 | return outputs 75 | 76 | class Normalize(): 77 | def __init__(self, mean = (0.485, 0.456, 0.406), std = (0.229, 0.224, 0.225)): 78 | self.mean = mean 79 | self.std = std 80 | def __call__(self, img): 81 | imgarr = np.asarray(img) 82 | proc_img = np.empty_like(imgarr, np.float32) 83 | proc_img[..., 0] = (imgarr[..., 0] / 255. - self.mean[0]) / self.std[0] 84 | proc_img[..., 1] = (imgarr[..., 1] / 255. - self.mean[1]) / self.std[1] 85 | proc_img[..., 2] = (imgarr[..., 2] / 255. - self.mean[2]) / self.std[2] 86 | return proc_img 87 | 88 | 89 | def HWC_to_CHW(img): 90 | return np.transpose(img, (2, 0, 1)) 91 | 92 | class Rescale(): 93 | def __init__(self, scale): 94 | self.scale=scale 95 | def __call__(self, inputs): 96 | outputs=[] 97 | for inp in inputs: 98 | out = cv2.resize(inp, self.scale) 99 | if len(out.shape)==2: 100 | out=np.expand_dims(out,2) 101 | outputs.append(out) 102 | return outputs 103 | 104 | 105 | def crf_inference(img, probs, t=3, scale_factor=1, labels=21): 106 | import pydensecrf.densecrf as dcrf 107 | from pydensecrf.utils import unary_from_softmax 108 | h, w = img.shape[:2] 109 | n_labels = labels 110 | d = dcrf.DenseCRF2D(w, h, n_labels) 111 | unary = unary_from_softmax(probs) 112 | unary = np.ascontiguousarray(unary) 113 | d.setUnaryEnergy(unary) 114 | d.addPairwiseGaussian(sxy=3/scale_factor, compat=3) 115 | d.addPairwiseBilateral(sxy=80/scale_factor, srgb=13, rgbim=np.copy(img), compat=10) 116 | Q = d.inference(t) 117 | return np.array(Q).reshape((n_labels, h, w)) -------------------------------------------------------------------------------- /main_ssdd.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import numpy as np 4 | import torch.nn as nn 5 | import torch 6 | import ssdd_val as val 7 | #import ssdd_test as test 8 | import train_dssdd 9 | import train_sssdd 10 | import precompute_sssdd 11 | ROOT_DIR = os.getcwd() 12 | #VOC_ROOT = os.environ['voc_root'] 13 | VOC_ROOT = 'voc_root' 14 | DEFAULT_LOGS_DIR = os.path.join(ROOT_DIR, "logs") 15 | 16 | class Config(): 17 | OUT_SHAPE = (112,112) 18 | INP_SHAPE = (448,448) 19 | LEARNING_MOMENTUM = 0.9 20 | WEIGHT_DECAY = 2e-4 21 | NUM_CLASSES = 21 22 | LEARNING_RATE=1e-3 23 | 24 | ############################################################ 25 | # Dataset 26 | ############################################################ 27 | 28 | class PascalDataset(): 29 | def load(self): 30 | image_dir = VOC_ROOT +'/JPEGImages' 31 | fn='data/trainaug_id.txt' 32 | f = open(fn,'r') 33 | image_ids = f.read().splitlines() 34 | f.close() 35 | self.image_ids=image_ids 36 | label_listn='data/trainaug_labels.txt' 37 | label_list=np.loadtxt(label_listn) 38 | label_dic={} 39 | for i in range(len(image_ids)): 40 | label=label_list[i] 41 | label_dic[image_ids[i]]=label_list[i] 42 | self.label_dic=label_dic 43 | def load_val(self): 44 | image_dir = VOC_ROOT +'/JPEGImages' 45 | fn= VOC_ROOT +'/ImageSets/Segmentation/val.txt' 46 | f = open(fn,'r'); image_ids = f.read().splitlines(); f.close() 47 | self.image_ids=image_ids 48 | def load_test(self): 49 | image_dir = VOC_ROOT +'/JPEGImages' 50 | fn=VOC_ROOT +'/ImageSets/Segmentation/test.txt' 51 | f = open(fn,'r');image_ids = f.read().splitlines(); f.close() 52 | self.image_ids=image_ids 53 | 54 | 55 | ############################################################ 56 | # main 57 | ############################################################ 58 | 59 | 60 | if __name__ == '__main__': 61 | import argparse 62 | # Parse command line arguments 63 | parser = argparse.ArgumentParser(description='Train') 64 | parser.add_argument('--mode', required=True, 65 | default=0, 66 | metavar="<0-3>", 67 | help='mode', 68 | type=int) 69 | parser.add_argument('--bn', required=False, 70 | default=2, 71 | metavar="", 72 | type=int) 73 | parser.add_argument('--modelid', required=False, 74 | default='default', 75 | metavar="", 76 | help='An id for saving and loading ', 77 | type=str) 78 | args = parser.parse_args() 79 | 80 | def create_model(config, modellib, modeln, weight_file=None): 81 | model_factory = modellib.__dict__[modeln] 82 | model_params = dict(config=config, weight_file=weight_file) 83 | model = model_factory(**model_params) 84 | return model 85 | 86 | config = Config() 87 | config.VOC_ROOT=VOC_ROOT 88 | runner_name = os.path.basename(__file__).split(".")[0] 89 | if args.mode==0: 90 | print("Train the ssdd module for the difference between PSA and PSA with CRF") 91 | dataset_train=PascalDataset() 92 | dataset_train.load() 93 | weight_file='pretrained_models/res38_cls.pth' 94 | models=create_model(config, train_sssdd, 'models', weight_file) 95 | model_trainer=train_sssdd.Trainer(config=config, model_dir=DEFAULT_LOGS_DIR, model=models) 96 | model_trainer.config.BATCH=torch.cuda.device_count()*args.bn 97 | model_trainer.config.EPOCHS=16 98 | model_trainer.config.modelid=args.modelid 99 | model_trainer.set_log_dir('sssdd', args.modelid) 100 | model_trainer.train_model( 101 | dataset_train, 102 | ) 103 | elif args.mode==1: 104 | print("Precompute the prediction of the difference between PSA and PSA with CRF") 105 | dataset_train=PascalDataset() 106 | dataset_train.load() 107 | weight_file_seg='./logs/sssdd_default/models/seg_0010.pth' 108 | weight_file_ssdd='./logs/sssdd_default/models/ssdd_0010.pth' 109 | #weight_file_seg='sssdd_seg.pth' 110 | #weight_file_ssdd='sssdd_ssdd.pth' 111 | models=create_model(config, precompute_sssdd, 'models') 112 | model_precompute=precompute_sssdd.Precompute(config=config, model_dir=DEFAULT_LOGS_DIR, model=models, weight_files=(weight_file_seg, weight_file_ssdd)) 113 | model_precompute.config.BATCH=torch.cuda.device_count()*args.bn 114 | model_precompute.config.modelid=args.modelid 115 | model_precompute.set_log_dir('precompute', args.modelid) 116 | model_precompute.precompute_model( 117 | dataset_train, 118 | ) 119 | elif args.mode==2: 120 | print("Train the two ssdd modules and the segmentation model") 121 | dataset_train=PascalDataset() 122 | dataset_train.load() 123 | weight_file='pretrained_models/res38_cls.pth' 124 | models=create_model(config, train_dssdd, 'models', weight_file) 125 | config.BATCH=torch.cuda.device_count()*args.bn 126 | config.EPOCHS=41 127 | config.modelid=args.modelid 128 | model_trainer=train_dssdd.Trainer(config=config, model_dir=DEFAULT_LOGS_DIR, model=models) 129 | model_trainer.set_log_dir('dssdd', args.modelid) 130 | model_trainer.train_model( 131 | dataset_train, 132 | ) 133 | elif args.mode==3: 134 | print("Validation") 135 | dataset_val=PascalDataset() 136 | dataset_val.load_val() 137 | #weight_file='./segmodel_64pt9_val.pth' 138 | #weight_file='./logs/dssdd_default/models/seg_0030.pth' 139 | weight_file='dssdd_seg.pth' 140 | model=create_model(config, val, 'val') 141 | model=nn.DataParallel(model).cuda() 142 | state_dict = torch.load(weight_file) 143 | model.load_state_dict(state_dict,strict=False) 144 | model_evaluator=val.Evaluator(config=config, model=model) 145 | model_evaluator.config.BATCH=torch.cuda.device_count()*args.bn 146 | model_evaluator.config.modelid=args.modelid 147 | model_evaluator.set_log_dir('val', args.modelid) 148 | model_evaluator.eval_model( 149 | dataset_val, 150 | ) 151 | -------------------------------------------------------------------------------- /network.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.utils.model_zoo as model_zoo 6 | from torch.nn import init 7 | import cv2 8 | import numpy as np 9 | import time 10 | from torch.autograd import Variable 11 | 12 | class Conv2dbnPR(nn.Module): 13 | def __init__(self,in_channels,out_channels,kernel_size=1,stride=1,padding=0,dilation=1,bias=False): 14 | super(Conv2dbnPR,self).__init__() 15 | 16 | self.rpad = nn.ReflectionPad2d(padding) 17 | self.conv1 = nn.Conv2d(in_channels,out_channels,kernel_size,stride,0,dilation, bias=bias) 18 | self.bn1 = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.01) 19 | def forward(self,x): 20 | x = self.rpad(x) 21 | x = self.conv1(x) 22 | x = self.bn1(x) 23 | x = F.relu(x) 24 | return x 25 | 26 | class Conv2dbn(nn.Module): 27 | def __init__(self,in_channels,out_channels,kernel_size=1,stride=1,padding=0,dilation=1,bias=False): 28 | super(Conv2dbn,self).__init__() 29 | 30 | self.conv1 = nn.Conv2d(in_channels,out_channels,kernel_size,stride,padding,dilation, bias=bias) 31 | self.bn1 = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.01) 32 | def forward(self,x): 33 | x = self.conv1(x) 34 | x = self.bn1(x) 35 | x = F.relu(x) 36 | return x 37 | 38 | class Bottleneck(nn.Module): 39 | def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1, bias=False): 40 | super(Bottleneck, self).__init__() 41 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False) 42 | self.bn1 = nn.BatchNorm2d(out_planes, eps=0.001, momentum=0.01) 43 | self.rpad = nn.ReflectionPad2d(padding) 44 | self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=0, dilation=dilation, bias=bias) 45 | self.bn2 = nn.BatchNorm2d(out_planes, eps=0.001, momentum=0.01) 46 | self.conv3 = nn.Conv2d(out_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False) 47 | self.bn3 = nn.BatchNorm2d(out_planes, eps=0.001, momentum=0.01) 48 | self.relu = nn.ReLU(inplace=True) 49 | self.downsample = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False) 50 | self.bnd = nn.BatchNorm2d(out_planes, eps=0.001, momentum=0.01) 51 | self.stride = stride 52 | 53 | def forward(self, x): 54 | identity = x 55 | out = self.conv1(x) 56 | out = self.bn1(out) 57 | out = self.relu(out) 58 | out = self.rpad(out) 59 | out = self.conv2(out) 60 | out = self.bn2(out) 61 | out = self.relu(out) 62 | out = self.conv3(out) 63 | out = self.bn3(out) 64 | if self.downsample is not None: 65 | identity = self.downsample(x) 66 | identity = self.bnd(identity) 67 | out += identity 68 | out = self.relu(out) 69 | return out 70 | 71 | class PredictDiffHead(nn.Module): 72 | def __init__(self, config,cln=21, in_channel=256, dr_rate_a=0.5, in_channel2=128): 73 | super(PredictDiffHead, self).__init__() 74 | self.config=config 75 | chn=256 76 | self.conv1ab = Conv2dbnPR(in_channel2 + in_channel, chn, kernel_size=3, stride=1, padding=1) 77 | 78 | def forward(self, inputs): 79 | xa_in, xb_in = inputs 80 | xab=torch.cat((xa_in,xb_in),dim=1) 81 | xab=self.conv1ab(xab) 82 | return xab 83 | 84 | 85 | class PredictDiff(nn.Module): 86 | def __init__(self, config, cln=21, in_channel=256, in_channel2=128, dr_rate_d=0.5): 87 | super(PredictDiff, self).__init__() 88 | self.config=config 89 | chn=256 90 | self.conv1c = Conv2dbnPR(cln, chn, kernel_size=1, stride=1, padding=0) 91 | self.conv1abc = Bottleneck(chn*2,chn,kernel_size=3,padding=1) 92 | self.pred_abc = nn.Conv2d(chn,1,kernel_size=1,stride=1,padding=0,bias=False) 93 | 94 | def forward(self, inputs): 95 | xab, xc_in =inputs 96 | xc=self.conv1c(xc_in) 97 | xabc=torch.cat((xab,xc),dim=1) 98 | xabc=self.conv1abc(xabc) 99 | xabc=self.pred_abc(xabc) 100 | return xabc 101 | 102 | 103 | class SegmentationPsa(nn.Module): 104 | def __init__(self, config, num_classes, in_channel=4096, middle_channel=512, scale=8): 105 | super(SegmentationPsa, self).__init__() 106 | self.config=config 107 | self.seg1 = Conv2dbnPR(in_channel,middle_channel,3,1,padding=12, dilation=12, bias=True) 108 | self.rpad = nn.ReflectionPad2d(12) 109 | self.seg2 = nn.Conv2d(middle_channel,21,kernel_size=3,stride=1,padding=0,dilation=12, bias=True) 110 | self.upsample = nn.Upsample(scale_factor=scale, mode='bilinear') 111 | def forward(self, inputs): 112 | x=inputs 113 | seg_head=self.seg1(x) 114 | x=self.rpad(seg_head) 115 | x=self.seg2(x) 116 | seg_head=self.upsample(seg_head) 117 | x=self.upsample(x) 118 | return x, seg_head 119 | -------------------------------------------------------------------------------- /precompute_sssdd.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import random 4 | import re 5 | import cv2 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import torch.optim as optim 11 | import torch.utils.data 12 | from torch.autograd import Variable 13 | from torchvision import transforms 14 | import imutils 15 | import utils 16 | from base_class import BaseModel, SegBaseModel, SSDDBaseModel, PascalDataset 17 | import ssdd_function as ssddF 18 | import time 19 | from PIL import Image 20 | from network import SegmentationPsa, PredictDiff, PredictDiffHead 21 | import math 22 | import cv2 23 | cv2.setNumThreads(0) 24 | 25 | ############################################################ 26 | # dataset 27 | ############################################################ 28 | 29 | class SssddData(PascalDataset): 30 | def __init__(self, dataset, config): 31 | super().__init__(dataset, config) 32 | self.label_dic = dataset.label_dic 33 | self.joint_transform_list=[ 34 | None, 35 | imutils.RandomHorizontalFlip(), 36 | imutils.RandomResizeLong(448, 448), 37 | imutils.RandomCrop(448), 38 | None, 39 | ] 40 | def __getitem__(self, image_index): 41 | image_id = self.image_ids[image_index] 42 | impath = self.config.VOC_ROOT+'/JPEGImages/' 43 | imn = impath+image_id+'.jpg' 44 | img = Image.open(imn).convert("RGB") 45 | gt_class_mlabel = torch.from_numpy(self.label_dic[image_id]) 46 | gt_class_mlabel_bg = torch.from_numpy(np.concatenate(([1],self.label_dic[image_id]))) 47 | psan = 'prepare_labels/results/out_aff/'+image_id+'.npy' 48 | psa=np.array(list(np.load(psan).item().values())).transpose(1,2,0) 49 | psan = 'prepare_labels/results/out_aff_crf/'+image_id+'.npy' 50 | psa_crf=np.load(psan).transpose(1,2,0) 51 | 52 | h=psa.shape[0] 53 | w=psa.shape[1] 54 | img_norm, img_org, psa, psa_crf = self.img_label_resize([img, np.array(img), psa, psa_crf]) 55 | img_org = cv2.resize(img_org,self.config.OUT_SHAPE) 56 | psa = cv2.resize(psa,self.config.OUT_SHAPE) 57 | psa_crf = cv2.resize(psa_crf,self.config.OUT_SHAPE) 58 | psa=self.get_prob_label(psa, gt_class_mlabel_bg).transpose(2,0,1) 59 | psa_crf=self.get_prob_label(psa_crf, gt_class_mlabel_bg).transpose(2,0,1) 60 | psa_mask = np.argmax(psa,0) 61 | psa_crf_mask = np.argmax(psa_crf,0) 62 | return img_norm, img_org, gt_class_mlabel, gt_class_mlabel_bg, psa_mask, psa_crf_mask 63 | def __len__(self): 64 | return self.image_ids.shape[0] 65 | 66 | ############################################################ 67 | # Models 68 | ############################################################ 69 | class SegModel(SegBaseModel): 70 | def __init__(self, config): 71 | super(SegModel, self).__init__(config) 72 | self.config = config 73 | in_channel=4096 74 | self.seg_main = SegmentationPsa(config,num_classes=21, in_channel=in_channel, middle_channel=512, scale=2) 75 | def set_bn_fix(m): 76 | classname = m.__class__.__name__ 77 | if classname.find('BatchNorm') != -1: 78 | for p in m.parameters(): p.requires_grad = False 79 | self.apply(set_bn_fix) 80 | def forward(self, inputs): 81 | x, img_org, gt_class_mlabel = inputs 82 | feats = self.encoder(x) 83 | [x1,x2,x3,x4,x5] = feats 84 | seg_outs_main = self.get_seg(self.seg_main, x5, gt_class_mlabel) 85 | return seg_outs_main, feats 86 | 87 | class SSDDModel(SSDDBaseModel): 88 | def __init__(self, config): 89 | super(SSDDModel, self).__init__(config) 90 | self.dd_head0 = PredictDiffHead(config, in_channel=512, in_channel2=128) 91 | self.dd0 = PredictDiff(config, in_channel=256, in_channel2=128) 92 | def forward(self, inputs): 93 | (seg_outs_main, feats), psa_mask, psa_crf_mask, gt_class_mlabel = inputs 94 | [x1,x2,x3,x4,x5] = feats 95 | x1=F.avg_pool2d(x1, 2, 2) 96 | # first step 97 | seg_main, seg_prob_main, seg_mask_main, seg_head_main = seg_outs_main 98 | ignore_flags0=torch.from_numpy(ssddF.get_ignore_flags(psa_mask, psa_crf_mask, gt_class_mlabel)).cuda().float() 99 | dd_head0 = self.dd_head0((seg_head_main.detach(), x1.detach())) 100 | dd00 = ssddF.get_dd(self.dd0, dd_head0, psa_mask) 101 | dd01 = ssddF.get_dd(self.dd0, dd_head0, psa_crf_mask) 102 | dd_outs0 = ssddF.get_dd_mask(dd00, dd01, psa_mask, psa_crf_mask, ignore_flags0, dd_bias=0.1, bg_bias=0.1) 103 | return dd_outs0 104 | 105 | ############################################################ 106 | # Precompute 107 | ############################################################ 108 | class Precompute(): 109 | def __init__(self, config, model_dir, model, weight_files): 110 | super(Precompute, self).__init__() 111 | self.config = config 112 | self.model_dir = model_dir 113 | self.epoch = 0 114 | self.layer_regex = { 115 | "lr1": r"(encoder.*)", 116 | "lr10": r"(seg_main.*)", 117 | "dd": r"(dd0.*)|(dd_head0.*)", 118 | } 119 | lr_1x = self.layer_regex["lr1"] 120 | lr_10x = self.layer_regex["lr10"] 121 | dd = self.layer_regex['dd'] 122 | seg_model=model[0].cuda() 123 | ssdd_model=model[1].cuda() 124 | self.param_lr_1x = [param for name, param in seg_model.named_parameters() if bool(re.fullmatch(lr_1x, name)) and not 'bn' in name] 125 | self.param_lr_10x = [param for name, param in seg_model.named_parameters() if bool(re.fullmatch(lr_10x, name)) and not 'bn' in name] 126 | self.param_dd = [param for name, param in ssdd_model.named_parameters() if bool(re.fullmatch(dd, name)) and not 'bn' in name] 127 | lr=1e-3 128 | self.seg_model=nn.DataParallel(seg_model) 129 | self.ssdd_model=nn.DataParallel(ssdd_model) 130 | self.seg_model.load_state_dict(torch.load(weight_files[0])) 131 | self.ssdd_model.load_state_dict(torch.load(weight_files[1])) 132 | def precompute_model(self, train_dataset): 133 | # Data generators 134 | self.train_set = SssddData(train_dataset, self.config) 135 | train_generator = torch.utils.data.DataLoader(self.train_set, batch_size=self.config.BATCH, shuffle=False, num_workers=8, pin_memory=True) 136 | self.seg_model.eval() 137 | self.ssdd_model.eval() 138 | self.cnt=0 139 | for inputs in train_generator: 140 | self.precompute_step(inputs) 141 | def precompute_step(self, inputs): 142 | img_norm, img_org, gt_class_mlabel, gt_class_mlabel_bg, psa_mask, psa_crf_mask = inputs 143 | img_norm = Variable(img_norm).cuda().float() 144 | img_org = Variable(img_org).cuda().float() 145 | gt_class_mlabel = Variable(gt_class_mlabel).cuda().float() 146 | gt_class_mlabel_bg = Variable(gt_class_mlabel_bg).cuda().float() 147 | seg_outs = self.seg_model((img_norm, img_org, gt_class_mlabel_bg)) 148 | dd_outs = self.ssdd_model((seg_outs, psa_mask, psa_crf_mask, gt_class_mlabel)) 149 | seg_outs_main, feats = seg_outs 150 | seg_main, seg_prob_main, seg_mask_main, _ = seg_outs_main 151 | dd_outs0 = dd_outs 152 | (dd00, dd01, ignore_flags0, refine_mask0) = dd_outs0 153 | psa_mask = Variable(psa_mask).cuda().long() 154 | psa_crf_mask = Variable(psa_crf_mask).cuda().long() 155 | img_org=img_org.data.cpu().numpy()[...,::-1] 156 | for i in range(len(img_norm)): 157 | sid='_'+self.phase+'_'+self.saveid+'_'+str(self.cnt) 158 | saven = self.savedir + '/D'+sid+'.png' 159 | mask_png = utils.mask2png(saven, refine_mask0[i].squeeze().data.cpu().numpy()) 160 | saven = self.savedir + '/dk'+sid+'.png' 161 | tmp=F.sigmoid(dd00)[i].squeeze().data.cpu().numpy() 162 | cv2.imwrite(saven,tmp*255) 163 | saven = self.savedir + '/da'+sid+'.png' 164 | tmp=F.sigmoid(dd01)[i].squeeze().data.cpu().numpy() 165 | cv2.imwrite(saven,tmp*255) 166 | saven = self.savedir +'/dk'+sid 167 | np.save(saven,dd00[i].data.cpu().numpy()) 168 | saven = self.savedir +'/da'+sid 169 | np.save(saven,dd01[i].data.cpu().numpy()) 170 | print(self.cnt) 171 | self.cnt += 1 172 | def set_log_dir(self, phase, saveid, model_path=None): 173 | self.phase = phase 174 | self.saveid = saveid 175 | self.savedir = 'precompute/'+self.saveid 176 | print("save the results to "+self.savedir) 177 | if not os.path.exists(self.savedir): 178 | os.makedirs(self.savedir) 179 | 180 | def models(config, weight_file=None): 181 | seg_model = SegModel(config=config) 182 | seg_model.initialize_weights() 183 | ssdd_model = SSDDModel(config=config) 184 | ssdd_model.initialize_weights() 185 | return (seg_model, ssdd_model) 186 | -------------------------------------------------------------------------------- /prepare_labels/README.md: -------------------------------------------------------------------------------- 1 | # Preparing labels 2 | ## Description 3 | To reproduce the proposed method, preparing seed labels are required. 4 | For the seed labels, the proposed method uses an existing weakly-supervised segmentation method: Pixel-level Semantic Affinity(PSA): [arxiv](https://arxiv.org/abs/1803.10464). 5 | The codes in this directory are for generation of the seed labels. 6 | That are based on the [original implementation](https://github.com/jiwoon-ahn/psa). 7 | 8 | ## Usage 9 | The prepararion of the seed labels are consist of tow steps. 10 | First, obtain semantic probality maps using a classification model with CAM. 11 | Second, generate the seed labels by propagating semantic information using an affinity model. 12 | Each step requires a trained model, then download the trained models from following links. 13 | Classification model: [link](https://drive.google.com/file/d/1xESB7017zlZHqxEWuh1Rb89UhjTGIKOA/view) 14 | Affinity model: [link](https://drive.google.com/file/d/1xESB7017zlZHqxEWuh1Rb89UhjTGIKOA/view) 15 | Note that these links are derived from the author's [repository](https://github.com/jiwoon-ahn/psa) 16 | We recomend to save the models in this direcory: [pretrained_models](https://github.com/shimoda-uec/ssdd/tree/master/pretrained_models). 17 | 18 | After downloding the models, we can prepare the seed labels by following commands. 19 | ``` 20 | python make_cam_labels.py 21 | ``` 22 | ``` 23 | python make_aff_labels.py 24 | ``` 25 | -------------------------------------------------------------------------------- /prepare_labels/make_aff_labels.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | from tool import imutils 4 | 5 | import argparse 6 | import importlib 7 | import numpy as np 8 | import cv2 9 | import voc12.data 10 | from torch.utils.data import DataLoader 11 | import scipy.misc 12 | import torch.nn.functional as F 13 | import os.path 14 | from tool import imutils, pyutils 15 | import time 16 | from PIL import Image 17 | import os 18 | #voc12_root=os.environ['voc_root'] 19 | voc12_root="../voc_root" 20 | 21 | 22 | def get_palette(num_cls): 23 | n = num_cls 24 | palette = [0] * (n * 3) 25 | for j in range(0, n): 26 | lab = j 27 | palette[j * 3 + 0] = 0 28 | palette[j * 3 + 1] = 0 29 | palette[j * 3 + 2] = 0 30 | i = 0 31 | while lab: 32 | palette[j * 3 + 0] |= (((lab >> 0) & 1) << (7 - i)) 33 | palette[j * 3 + 1] |= (((lab >> 1) & 1) << (7 - i)) 34 | palette[j * 3 + 2] |= (((lab >> 2) & 1) << (7 - i)) 35 | i += 1 36 | lab >>= 3 37 | return palette 38 | 39 | def mask2png(mask,saven): 40 | palette = get_palette(256) 41 | mask=Image.fromarray(mask.astype(np.uint8)) 42 | mask.putpalette(palette) 43 | mask.save(saven) 44 | 45 | if __name__ == '__main__': 46 | 47 | parser = argparse.ArgumentParser() 48 | #parser.add_argument("--weights", required=True, type=str) 49 | parser.add_argument("--infer_list", default="voc12/train_aug.txt", type=str) 50 | parser.add_argument("--num_workers", default=8, type=int) 51 | parser.add_argument("--alpha", default=16, type=int) 52 | parser.add_argument("--beta", default=8, type=int) 53 | parser.add_argument("--logt", default=8, type=int) 54 | 55 | args = parser.parse_args() 56 | 57 | model = getattr(importlib.import_module("network.resnet38_aff"), 'Net')() 58 | 59 | model.load_state_dict(torch.load("../pretrained_models/res38_aff.pth")) 60 | 61 | model.eval() 62 | model.cuda() 63 | 64 | infer_dataset = voc12.data.VOC12ClsDataset(args.infer_list, voc12_root=voc12_root, 65 | transform=torchvision.transforms.Compose( 66 | [np.asarray, 67 | model.normalize, 68 | imutils.HWC_to_CHW])) 69 | infer_data_loader = DataLoader(infer_dataset, shuffle=False, num_workers=args.num_workers, pin_memory=True) 70 | #save_dir=str(args.compatg)+"_"+str(args.compatb)+"_"+str(args.gxy)+"_"+str(args.bxy)+"_"+str(args.brgb) 71 | save_dir_cam = "results/out_cam" 72 | save_dir_aff = "results/out_aff" 73 | os.makedirs(save_dir_aff, exist_ok=True) 74 | save_dir_aff_crf = "results/out_aff_crf" 75 | os.makedirs(save_dir_aff_crf, exist_ok=True) 76 | 77 | for iter, (name, img, label) in enumerate(infer_data_loader): 78 | name = name[0]; label = label[0] 79 | img_path=voc12_root+'/JPEGImages/'+name+'.jpg' 80 | orig_img = np.asarray(Image.open(img_path)) 81 | orig_img_size = orig_img.shape[:2] 82 | print(iter) 83 | orig_shape = img.shape 84 | 85 | padded_size = (int(np.ceil(img.shape[2]/8)*8), int(np.ceil(img.shape[3]/8)*8)) 86 | 87 | p2d = (0, padded_size[1] - img.shape[3], 0, padded_size[0] - img.shape[2]) 88 | img = F.pad(img, p2d) 89 | 90 | dheight = int(np.ceil(img.shape[2]/8)) 91 | dwidth = int(np.ceil(img.shape[3]/8)) 92 | cam = np.load(os.path.join(save_dir_cam, name + '.npy')).item() 93 | cam_full_arr = np.zeros((21, orig_shape[2], orig_shape[3]), np.float32) 94 | for k, v in cam.items(): 95 | cam_full_arr[k+1] = v 96 | cam_full_arr[0] = (1 - np.max(cam_full_arr[1:], (0), keepdims=False))**args.alpha 97 | cam_full_arr = np.pad(cam_full_arr, ((0, 0), (0, p2d[3]), (0, p2d[1])), mode='constant') 98 | with torch.no_grad(): 99 | aff_mat = torch.pow(model.forward(img.cuda(), True), args.beta) 100 | trans_mat = aff_mat / torch.sum(aff_mat, dim=0, keepdim=True)#D sum(W) 101 | for _ in range(args.logt): 102 | trans_mat = torch.matmul(trans_mat, trans_mat) 103 | cam_full_arr = torch.from_numpy(cam_full_arr) 104 | cam_full_arr = F.avg_pool2d(cam_full_arr, 8, 8) 105 | cam_vec = cam_full_arr.view(21, -1) 106 | cam_rw = torch.matmul(cam_vec.cuda(), trans_mat) 107 | cam_rw = cam_rw.view(1, 21, dheight, dwidth) 108 | cam_rw = torch.nn.Upsample((img.shape[2], img.shape[3]), mode='bilinear')(cam_rw) 109 | cam_rw = cam_rw[:,:,:orig_shape[2], :orig_shape[3]] 110 | cam_rw = cam_rw.squeeze().data.cpu().numpy() 111 | 112 | aff_dict = {} 113 | aff_dict[0] = cam_rw[0] 114 | for i in range(20): 115 | if label[i] > 1e-5: 116 | aff_dict[i+1] = cam_rw[i+1] 117 | np.save(os.path.join(save_dir_aff, name + '.npy'), aff_dict) 118 | mask=np.argmax(cam_rw,axis=0) 119 | mask2png(mask, os.path.join(save_dir_aff, name + '.png')) 120 | 121 | v = np.array(list(aff_dict.values())) 122 | aff_crf = imutils.crf_inference(orig_img, v, labels=v.shape[0]) 123 | aff_crf_full_arr = np.zeros((21, orig_shape[2], orig_shape[3]), np.float32) 124 | cnt=0 125 | for k, v in aff_dict.items(): 126 | aff_crf_full_arr[k] = aff_crf[cnt] 127 | cnt+=1 128 | np.save(os.path.join(save_dir_aff_crf, name + '.npy'), aff_crf) 129 | mask=np.argmax(aff_crf_full_arr,axis=0) 130 | mask2png(mask, os.path.join(save_dir_aff_crf, name + '.png')) 131 | -------------------------------------------------------------------------------- /prepare_labels/make_cam_labels.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import torch 4 | from torch.backends import cudnn 5 | cudnn.enabled = True 6 | import voc12.data 7 | import scipy.misc 8 | import importlib 9 | from torch.utils.data import DataLoader 10 | import torchvision 11 | from tool import imutils, pyutils 12 | import argparse 13 | from PIL import Image 14 | import torch.nn.functional as F 15 | import os.path 16 | import cv2 17 | #voc12_root=os.environ['voc_root'] 18 | voc12_root="../voc_root" 19 | if __name__ == '__main__': 20 | 21 | parser = argparse.ArgumentParser() 22 | #parser.add_argument("--weights", required=True, type=str) 23 | parser.add_argument("--network", default="network.resnet38_cls", type=str) 24 | parser.add_argument("--infer_list", default="voc12/train_aug.txt", type=str) 25 | parser.add_argument("--num_workers", default=8, type=int) 26 | parser.add_argument("--out_cam", default="out_cam", type=str) 27 | 28 | args = parser.parse_args() 29 | 30 | model = getattr(importlib.import_module(args.network), 'Net')() 31 | model.load_state_dict(torch.load("../pretrained_models/res38_cls.pth")) 32 | 33 | model.eval() 34 | model.cuda() 35 | 36 | infer_dataset = voc12.data.VOC12ClsDatasetMSF(args.infer_list, voc12_root=voc12_root, 37 | scales=(1, 0.5, 1.5, 2.0), 38 | inter_transform=torchvision.transforms.Compose( 39 | [np.asarray, 40 | model.normalize, 41 | imutils.HWC_to_CHW])) 42 | 43 | infer_data_loader = DataLoader(infer_dataset, shuffle=False, num_workers=args.num_workers, pin_memory=True) 44 | 45 | n_gpus = torch.cuda.device_count() 46 | model_replicas = torch.nn.parallel.replicate(model, list(range(n_gpus))) 47 | save_dir = "results/out_cam" 48 | os.makedirs(save_dir, exist_ok=True) 49 | for iter, (img_name, img_list, label) in enumerate(infer_data_loader): 50 | img_name = img_name[0]; label = label[0] 51 | 52 | img_path = voc12.data.get_img_path(img_name, voc12_root) 53 | orig_img = np.asarray(Image.open(img_path)) 54 | orig_img_size = orig_img.shape[:2] 55 | 56 | def _work(i, img): 57 | with torch.no_grad(): 58 | with torch.cuda.device(i%n_gpus): 59 | cam = model_replicas[i%n_gpus].forward_cam(img.cuda()) 60 | cam = F.upsample(cam, orig_img_size, mode='bilinear', align_corners=False)[0] 61 | cam = cam.cpu().numpy() * label.clone().view(20, 1, 1).numpy() 62 | if i % 2 == 1: 63 | cam = np.flip(cam, axis=-1) 64 | return cam 65 | 66 | thread_pool = pyutils.BatchThreader(_work, list(enumerate(img_list)), 67 | batch_size=12, prefetch_size=0, processes=args.num_workers) 68 | 69 | cam_list = thread_pool.pop_results() 70 | 71 | sum_cam = np.sum(cam_list, axis=0) 72 | norm_cam = sum_cam / (np.max(sum_cam, (1, 2), keepdims=True) + 1e-5) 73 | 74 | cam_dict = {} 75 | for i in range(20): 76 | if label[i] > 1e-5: 77 | cam_dict[i] = norm_cam[i] 78 | 79 | if args.out_cam is not None: 80 | np.save(os.path.join(save_dir, img_name + '.npy'), cam_dict) 81 | 82 | print(iter) 83 | -------------------------------------------------------------------------------- /prepare_labels/network/__pycache__/resnet38_aff.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shimoda-uec/ssdd/564c3e08fae7a158516cdbd9f3599a74dc748aff/prepare_labels/network/__pycache__/resnet38_aff.cpython-36.pyc -------------------------------------------------------------------------------- /prepare_labels/network/__pycache__/resnet38_cls.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shimoda-uec/ssdd/564c3e08fae7a158516cdbd9f3599a74dc748aff/prepare_labels/network/__pycache__/resnet38_cls.cpython-36.pyc -------------------------------------------------------------------------------- /prepare_labels/network/__pycache__/resnet38d.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shimoda-uec/ssdd/564c3e08fae7a158516cdbd9f3599a74dc748aff/prepare_labels/network/__pycache__/resnet38d.cpython-36.pyc -------------------------------------------------------------------------------- /prepare_labels/network/__pycache__/vgg16_aff.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shimoda-uec/ssdd/564c3e08fae7a158516cdbd9f3599a74dc748aff/prepare_labels/network/__pycache__/vgg16_aff.cpython-36.pyc -------------------------------------------------------------------------------- /prepare_labels/network/__pycache__/vgg16_cls.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shimoda-uec/ssdd/564c3e08fae7a158516cdbd9f3599a74dc748aff/prepare_labels/network/__pycache__/vgg16_cls.cpython-36.pyc -------------------------------------------------------------------------------- /prepare_labels/network/__pycache__/vgg16d.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shimoda-uec/ssdd/564c3e08fae7a158516cdbd9f3599a74dc748aff/prepare_labels/network/__pycache__/vgg16d.cpython-36.pyc -------------------------------------------------------------------------------- /prepare_labels/network/resnet38_aff.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.sparse as sparse 4 | import torch.nn.functional as F 5 | 6 | import network.resnet38d 7 | from tool import pyutils 8 | 9 | class Net(network.resnet38d.Net): 10 | def __init__(self): 11 | super(Net, self).__init__() 12 | 13 | self.f8_3 = torch.nn.Conv2d(512, 64, 1, bias=False) 14 | self.f8_4 = torch.nn.Conv2d(1024, 128, 1, bias=False) 15 | self.f8_5 = torch.nn.Conv2d(4096, 256, 1, bias=False) 16 | 17 | self.f9 = torch.nn.Conv2d(448, 448, 1, bias=False) 18 | 19 | torch.nn.init.kaiming_normal_(self.f8_3.weight) 20 | torch.nn.init.kaiming_normal_(self.f8_4.weight) 21 | torch.nn.init.kaiming_normal_(self.f8_5.weight) 22 | torch.nn.init.xavier_uniform_(self.f9.weight, gain=4) 23 | 24 | self.not_training = [self.conv1a, self.b2, self.b2_1, self.b2_2] 25 | 26 | self.from_scratch_layers = [self.f8_3, self.f8_4, self.f8_5, self.f9] 27 | 28 | self.predefined_featuresize = int(448//8) 29 | self.ind_from, self.ind_to = pyutils.get_indices_of_pairs(radius=5, size=(self.predefined_featuresize, self.predefined_featuresize)) 30 | self.ind_from = torch.from_numpy(self.ind_from); self.ind_to = torch.from_numpy(self.ind_to) 31 | 32 | return 33 | 34 | def forward(self, x, to_dense=False): 35 | 36 | d = super().forward_as_dict(x) 37 | 38 | f8_3 = F.elu(self.f8_3(d['conv4'])) 39 | f8_4 = F.elu(self.f8_4(d['conv5'])) 40 | f8_5 = F.elu(self.f8_5(d['conv6'])) 41 | x = F.elu(self.f9(torch.cat([f8_3, f8_4, f8_5], dim=1))) 42 | #print(x.shape) 43 | 44 | if x.size(2) == self.predefined_featuresize and x.size(3) == self.predefined_featuresize: 45 | ind_from = self.ind_from 46 | ind_to = self.ind_to 47 | else: 48 | ind_from, ind_to = pyutils.get_indices_of_pairs(5, (x.size(2), x.size(3))) 49 | ind_from = torch.from_numpy(ind_from); ind_to = torch.from_numpy(ind_to) 50 | 51 | x = x.view(x.size(0), x.size(1), -1) 52 | #print(x.shape) 53 | 54 | #print('ind',ind_from.shape) 55 | #print(ind_to.shape) 56 | ff = torch.index_select(x, dim=2, index=ind_from.cuda(non_blocking=True)) 57 | ft = torch.index_select(x, dim=2, index=ind_to.cuda(non_blocking=True)) 58 | #print('ff',ff.shape) 59 | #print(ft.shape) 60 | 61 | ff = torch.unsqueeze(ff, dim=2) 62 | ft = ft.view(ft.size(0), ft.size(1), -1, ff.size(3)) 63 | 64 | aff = torch.exp(-torch.mean(torch.abs(ft-ff), dim=1)) 65 | #print('aff',aff.shape) 66 | if to_dense: 67 | aff = aff.view(-1).cpu() 68 | 69 | ind_from_exp = torch.unsqueeze(ind_from, dim=0).expand(ft.size(2), -1).contiguous().view(-1) 70 | indices = torch.stack([ind_from_exp, ind_to]) 71 | indices_tp = torch.stack([ind_to, ind_from_exp]) 72 | 73 | area = x.size(2) 74 | indices_id = torch.stack([torch.arange(0, area).long(), torch.arange(0, area).long()]) 75 | 76 | aff_mat = sparse.FloatTensor(torch.cat([indices, indices_id, indices_tp], dim=1), 77 | torch.cat([aff, torch.ones([area]), aff])).to_dense().cuda() 78 | 79 | return aff_mat 80 | 81 | else: 82 | return aff 83 | 84 | 85 | def get_parameter_groups(self): 86 | groups = ([], [], [], []) 87 | 88 | for m in self.modules(): 89 | 90 | if (isinstance(m, nn.Conv2d) or isinstance(m, nn.modules.normalization.GroupNorm)): 91 | 92 | if m.weight.requires_grad: 93 | if m in self.from_scratch_layers: 94 | groups[2].append(m.weight) 95 | else: 96 | groups[0].append(m.weight) 97 | 98 | if m.bias is not None and m.bias.requires_grad: 99 | 100 | if m in self.from_scratch_layers: 101 | groups[3].append(m.bias) 102 | else: 103 | groups[1].append(m.bias) 104 | 105 | return groups 106 | 107 | 108 | 109 | -------------------------------------------------------------------------------- /prepare_labels/network/resnet38_cls.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import network.resnet38d 6 | 7 | 8 | class Net(network.resnet38d.Net): 9 | def __init__(self): 10 | super().__init__() 11 | 12 | self.dropout7 = torch.nn.Dropout2d(0.5) 13 | 14 | self.fc8 = nn.Conv2d(4096, 20, 1, bias=False) 15 | torch.nn.init.xavier_uniform_(self.fc8.weight) 16 | 17 | self.not_training = [self.conv1a, self.b2, self.b2_1, self.b2_2] 18 | self.from_scratch_layers = [self.fc8] 19 | 20 | 21 | def forward(self, x): 22 | x = super().forward(x) 23 | x = self.dropout7(x) 24 | 25 | x = F.avg_pool2d( 26 | x, kernel_size=(x.size(2), x.size(3)), padding=0) 27 | 28 | x = self.fc8(x) 29 | x = x.view(x.size(0), -1) 30 | 31 | return x 32 | 33 | def forward_cam(self, x): 34 | x = super().forward(x) 35 | 36 | x = F.conv2d(x, self.fc8.weight) 37 | x = F.relu(x) 38 | 39 | return x 40 | 41 | def get_parameter_groups(self): 42 | groups = ([], [], [], []) 43 | 44 | for m in self.modules(): 45 | 46 | if isinstance(m, nn.Conv2d): 47 | 48 | if m.weight.requires_grad: 49 | if m in self.from_scratch_layers: 50 | groups[2].append(m.weight) 51 | else: 52 | groups[0].append(m.weight) 53 | 54 | if m.bias is not None and m.bias.requires_grad: 55 | 56 | if m in self.from_scratch_layers: 57 | groups[3].append(m.bias) 58 | else: 59 | groups[1].append(m.bias) 60 | 61 | return groups -------------------------------------------------------------------------------- /prepare_labels/network/resnet38d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import numpy as np 4 | 5 | import torch.nn.functional as F 6 | 7 | class ResBlock(nn.Module): 8 | def __init__(self, in_channels, mid_channels, out_channels, stride=1, first_dilation=None, dilation=1): 9 | super(ResBlock, self).__init__() 10 | 11 | self.same_shape = (in_channels == out_channels and stride == 1) 12 | 13 | if first_dilation == None: first_dilation = dilation 14 | 15 | self.bn_branch2a = nn.BatchNorm2d(in_channels) 16 | 17 | self.conv_branch2a = nn.Conv2d(in_channels, mid_channels, 3, stride, 18 | padding=first_dilation, dilation=first_dilation, bias=False) 19 | 20 | self.bn_branch2b1 = nn.BatchNorm2d(mid_channels) 21 | 22 | self.conv_branch2b1 = nn.Conv2d(mid_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False) 23 | 24 | if not self.same_shape: 25 | self.conv_branch1 = nn.Conv2d(in_channels, out_channels, 1, stride, bias=False) 26 | 27 | def forward(self, x, get_x_bn_relu=False): 28 | 29 | branch2 = self.bn_branch2a(x) 30 | branch2 = F.relu(branch2) 31 | 32 | x_bn_relu = branch2 33 | 34 | if not self.same_shape: 35 | branch1 = self.conv_branch1(branch2) 36 | else: 37 | branch1 = x 38 | 39 | branch2 = self.conv_branch2a(branch2) 40 | branch2 = self.bn_branch2b1(branch2) 41 | branch2 = F.relu(branch2) 42 | branch2 = self.conv_branch2b1(branch2) 43 | 44 | x = branch1 + branch2 45 | 46 | if get_x_bn_relu: 47 | return x, x_bn_relu 48 | 49 | return x 50 | 51 | def __call__(self, x, get_x_bn_relu=False): 52 | return self.forward(x, get_x_bn_relu=get_x_bn_relu) 53 | 54 | class ResBlock_bot(nn.Module): 55 | def __init__(self, in_channels, out_channels, stride=1, dilation=1, dropout=0.): 56 | super(ResBlock_bot, self).__init__() 57 | 58 | self.same_shape = (in_channels == out_channels and stride == 1) 59 | 60 | self.bn_branch2a = nn.BatchNorm2d(in_channels) 61 | self.conv_branch2a = nn.Conv2d(in_channels, out_channels//4, 1, stride, bias=False) 62 | 63 | self.bn_branch2b1 = nn.BatchNorm2d(out_channels//4) 64 | self.dropout_2b1 = torch.nn.Dropout2d(dropout) 65 | self.conv_branch2b1 = nn.Conv2d(out_channels//4, out_channels//2, 3, padding=dilation, dilation=dilation, bias=False) 66 | 67 | self.bn_branch2b2 = nn.BatchNorm2d(out_channels//2) 68 | self.dropout_2b2 = torch.nn.Dropout2d(dropout) 69 | self.conv_branch2b2 = nn.Conv2d(out_channels//2, out_channels, 1, bias=False) 70 | 71 | if not self.same_shape: 72 | self.conv_branch1 = nn.Conv2d(in_channels, out_channels, 1, stride, bias=False) 73 | 74 | def forward(self, x, get_x_bn_relu=False): 75 | 76 | branch2 = self.bn_branch2a(x) 77 | branch2 = F.relu(branch2) 78 | x_bn_relu = branch2 79 | 80 | branch1 = self.conv_branch1(branch2) 81 | 82 | branch2 = self.conv_branch2a(branch2) 83 | 84 | branch2 = self.bn_branch2b1(branch2) 85 | branch2 = F.relu(branch2) 86 | branch2 = self.dropout_2b1(branch2) 87 | branch2 = self.conv_branch2b1(branch2) 88 | 89 | branch2 = self.bn_branch2b2(branch2) 90 | branch2 = F.relu(branch2) 91 | branch2 = self.dropout_2b2(branch2) 92 | branch2 = self.conv_branch2b2(branch2) 93 | 94 | x = branch1 + branch2 95 | 96 | if get_x_bn_relu: 97 | return x, x_bn_relu 98 | 99 | return x 100 | 101 | def __call__(self, x, get_x_bn_relu=False): 102 | return self.forward(x, get_x_bn_relu=get_x_bn_relu) 103 | 104 | class Normalize(): 105 | def __init__(self, mean = (0.485, 0.456, 0.406), std = (0.229, 0.224, 0.225)): 106 | 107 | self.mean = mean 108 | self.std = std 109 | 110 | def __call__(self, img): 111 | imgarr = np.asarray(img) 112 | proc_img = np.empty_like(imgarr, np.float32) 113 | 114 | proc_img[..., 0] = (imgarr[..., 0] / 255. - self.mean[0]) / self.std[0] 115 | proc_img[..., 1] = (imgarr[..., 1] / 255. - self.mean[1]) / self.std[1] 116 | proc_img[..., 2] = (imgarr[..., 2] / 255. - self.mean[2]) / self.std[2] 117 | 118 | return proc_img 119 | 120 | class Net(nn.Module): 121 | def __init__(self): 122 | super(Net, self).__init__() 123 | 124 | self.conv1a = nn.Conv2d(3, 64, 3, padding=1, bias=False) 125 | 126 | self.b2 = ResBlock(64, 128, 128, stride=2) 127 | self.b2_1 = ResBlock(128, 128, 128) 128 | self.b2_2 = ResBlock(128, 128, 128) 129 | 130 | self.b3 = ResBlock(128, 256, 256, stride=2) 131 | self.b3_1 = ResBlock(256, 256, 256) 132 | self.b3_2 = ResBlock(256, 256, 256) 133 | 134 | self.b4 = ResBlock(256, 512, 512, stride=2) 135 | self.b4_1 = ResBlock(512, 512, 512) 136 | self.b4_2 = ResBlock(512, 512, 512) 137 | self.b4_3 = ResBlock(512, 512, 512) 138 | self.b4_4 = ResBlock(512, 512, 512) 139 | self.b4_5 = ResBlock(512, 512, 512) 140 | 141 | self.b5 = ResBlock(512, 512, 1024, stride=1, first_dilation=1, dilation=2) 142 | self.b5_1 = ResBlock(1024, 512, 1024, dilation=2) 143 | self.b5_2 = ResBlock(1024, 512, 1024, dilation=2) 144 | 145 | self.b6 = ResBlock_bot(1024, 2048, stride=1, dilation=4, dropout=0.3) 146 | 147 | self.b7 = ResBlock_bot(2048, 4096, dilation=4, dropout=0.5) 148 | 149 | self.bn7 = nn.BatchNorm2d(4096) 150 | 151 | self.not_training = [self.conv1a] 152 | 153 | self.normalize = Normalize() 154 | 155 | return 156 | 157 | def forward(self, x): 158 | return self.forward_as_dict(x)['conv6'] 159 | 160 | def forward_as_dict(self, x): 161 | 162 | x = self.conv1a(x) 163 | 164 | x = self.b2(x) 165 | x = self.b2_1(x) 166 | x = self.b2_2(x) 167 | 168 | x = self.b3(x) 169 | x = self.b3_1(x) 170 | x = self.b3_2(x) 171 | 172 | x = self.b4(x) 173 | x = self.b4_1(x) 174 | x = self.b4_2(x) 175 | x = self.b4_3(x) 176 | x = self.b4_4(x) 177 | x = self.b4_5(x) 178 | 179 | x, conv4 = self.b5(x, get_x_bn_relu=True) 180 | x = self.b5_1(x) 181 | x = self.b5_2(x) 182 | 183 | x, conv5 = self.b6(x, get_x_bn_relu=True) 184 | 185 | x = self.b7(x) 186 | conv6 = F.relu(self.bn7(x)) 187 | 188 | return dict({'conv4': conv4, 'conv5': conv5, 'conv6': conv6}) 189 | 190 | 191 | def train(self, mode=True): 192 | 193 | super().train(mode) 194 | 195 | for layer in self.not_training: 196 | 197 | if isinstance(layer, torch.nn.Conv2d): 198 | layer.weight.requires_grad = False 199 | 200 | elif isinstance(layer, torch.nn.Module): 201 | for c in layer.children(): 202 | c.weight.requires_grad = False 203 | if c.bias is not None: 204 | c.bias.requires_grad = False 205 | 206 | for layer in self.modules(): 207 | 208 | if isinstance(layer, torch.nn.BatchNorm2d): 209 | layer.eval() 210 | layer.bias.requires_grad = False 211 | layer.weight.requires_grad = False 212 | 213 | return 214 | 215 | def convert_mxnet_to_torch(filename): 216 | import mxnet 217 | 218 | save_dict = mxnet.nd.load(filename) 219 | 220 | renamed_dict = dict() 221 | 222 | bn_param_mx_pt = {'beta': 'bias', 'gamma': 'weight', 'mean': 'running_mean', 'var': 'running_var'} 223 | 224 | for k, v in save_dict.items(): 225 | 226 | v = torch.from_numpy(v.asnumpy()) 227 | toks = k.split('_') 228 | 229 | if 'conv1a' in toks[0]: 230 | renamed_dict['conv1a.weight'] = v 231 | 232 | elif 'linear1000' in toks[0]: 233 | pass 234 | 235 | elif 'branch' in toks[1]: 236 | 237 | pt_name = [] 238 | 239 | if toks[0][-1] != 'a': 240 | pt_name.append('b' + toks[0][-3] + '_' + toks[0][-1]) 241 | else: 242 | pt_name.append('b' + toks[0][-2]) 243 | 244 | if 'res' in toks[0]: 245 | layer_type = 'conv' 246 | last_name = 'weight' 247 | 248 | else: # 'bn' in toks[0]: 249 | layer_type = 'bn' 250 | last_name = bn_param_mx_pt[toks[-1]] 251 | 252 | pt_name.append(layer_type + '_' + toks[1]) 253 | 254 | pt_name.append(last_name) 255 | 256 | torch_name = '.'.join(pt_name) 257 | renamed_dict[torch_name] = v 258 | 259 | else: 260 | last_name = bn_param_mx_pt[toks[-1]] 261 | renamed_dict['bn7.' + last_name] = v 262 | 263 | return renamed_dict 264 | 265 | -------------------------------------------------------------------------------- /prepare_labels/network/vgg16_20M.prototxt: -------------------------------------------------------------------------------- 1 | name: "VGG_ILSVRC_16_layers" 2 | input: "data" 3 | input_dim: 10 4 | input_dim: 3 5 | input_dim: 224 6 | input_dim: 224 7 | layer { 8 | name: "conv1_1" 9 | type: "Convolution" 10 | bottom: "data" 11 | top: "conv1_1a" 12 | convolution_param { 13 | num_output: 64 14 | pad: 1 15 | kernel_size: 3 16 | } 17 | } 18 | layer { 19 | name: "relu1_1" 20 | type: "ReLU" 21 | bottom: "conv1_1a" 22 | top: "conv1_1" 23 | } 24 | layer { 25 | name: "conv1_2" 26 | type: "Convolution" 27 | bottom: "conv1_1" 28 | top: "conv1_2" 29 | convolution_param { 30 | num_output: 64 31 | pad: 1 32 | kernel_size: 3 33 | } 34 | } 35 | layer { 36 | name: "relu1_2" 37 | type: "ReLU" 38 | bottom: "conv1_2" 39 | top: "conv1_2" 40 | } 41 | layer { 42 | name: "pool1" 43 | type: "Pooling" 44 | bottom: "conv1_2" 45 | top: "pool1" 46 | pooling_param { 47 | pool: MAX 48 | kernel_size: 2 49 | stride: 2 50 | } 51 | } 52 | layer { 53 | name: "conv2_1" 54 | type: "Convolution" 55 | bottom: "pool1" 56 | top: "conv2_1" 57 | convolution_param { 58 | num_output: 128 59 | pad: 1 60 | kernel_size: 3 61 | } 62 | } 63 | layer { 64 | name: "relu2_1" 65 | type: "ReLU" 66 | bottom: "conv2_1" 67 | top: "conv2_1" 68 | } 69 | layer { 70 | name: "conv2_2" 71 | type: "Convolution" 72 | bottom: "conv2_1" 73 | top: "conv2_2" 74 | convolution_param { 75 | num_output: 128 76 | pad: 1 77 | kernel_size: 3 78 | } 79 | } 80 | layer { 81 | name: "relu2_2" 82 | type: "ReLU" 83 | bottom: "conv2_2" 84 | top: "conv2_2" 85 | } 86 | layer { 87 | name: "pool2" 88 | type: "Pooling" 89 | bottom: "conv2_2" 90 | top: "pool2" 91 | pooling_param { 92 | pool: MAX 93 | kernel_size: 2 94 | stride: 2 95 | } 96 | } 97 | layer { 98 | name: "conv3_1" 99 | type: "Convolution" 100 | bottom: "pool2" 101 | top: "conv3_1" 102 | convolution_param { 103 | num_output: 256 104 | pad: 1 105 | kernel_size: 3 106 | } 107 | } 108 | layer { 109 | name: "relu3_1" 110 | type: "ReLU" 111 | bottom: "conv3_1" 112 | top: "conv3_1" 113 | } 114 | layer { 115 | name: "conv3_2" 116 | type: "Convolution" 117 | bottom: "conv3_1" 118 | top: "conv3_2" 119 | convolution_param { 120 | num_output: 256 121 | pad: 1 122 | kernel_size: 3 123 | } 124 | } 125 | layer { 126 | name: "relu3_2" 127 | type: "ReLU" 128 | bottom: "conv3_2" 129 | top: "conv3_2" 130 | } 131 | layer { 132 | name: "conv3_3" 133 | type: "Convolution" 134 | bottom: "conv3_2" 135 | top: "conv3_3" 136 | convolution_param { 137 | num_output: 256 138 | pad: 1 139 | kernel_size: 3 140 | } 141 | } 142 | layer { 143 | name: "relu3_3" 144 | type: "ReLU" 145 | bottom: "conv3_3" 146 | top: "conv3_3" 147 | } 148 | layer { 149 | name: "pool3" 150 | type: "Pooling" 151 | bottom: "conv3_3" 152 | top: "pool3" 153 | pooling_param { 154 | pool: MAX 155 | kernel_size: 2 156 | stride: 2 157 | } 158 | } 159 | layer { 160 | name: "conv4_1" 161 | type: "Convolution" 162 | bottom: "pool3" 163 | top: "conv4_1" 164 | convolution_param { 165 | num_output: 512 166 | pad: 1 167 | kernel_size: 3 168 | } 169 | } 170 | layer { 171 | name: "relu4_1" 172 | type: "ReLU" 173 | bottom: "conv4_1" 174 | top: "conv4_1" 175 | } 176 | layer { 177 | name: "conv4_2" 178 | type: "Convolution" 179 | bottom: "conv4_1" 180 | top: "conv4_2" 181 | convolution_param { 182 | num_output: 512 183 | pad: 1 184 | kernel_size: 3 185 | } 186 | } 187 | layer { 188 | name: "relu4_2" 189 | type: "ReLU" 190 | bottom: "conv4_2" 191 | top: "conv4_2" 192 | } 193 | layer { 194 | name: "conv4_3" 195 | type: "Convolution" 196 | bottom: "conv4_2" 197 | top: "conv4_3" 198 | convolution_param { 199 | num_output: 512 200 | pad: 1 201 | kernel_size: 3 202 | } 203 | } 204 | layer { 205 | name: "relu4_3" 206 | type: "ReLU" 207 | bottom: "conv4_3" 208 | top: "conv4_3" 209 | } 210 | layer { 211 | name: "pool4" 212 | type: "Pooling" 213 | bottom: "conv4_3" 214 | top: "pool4" 215 | pooling_param { 216 | pool: MAX 217 | kernel_size: 2 218 | stride: 2 219 | } 220 | } 221 | layer { 222 | name: "conv5_1" 223 | type: "Convolution" 224 | bottom: "pool4" 225 | top: "conv5_1" 226 | convolution_param { 227 | num_output: 512 228 | pad: 1 229 | kernel_size: 3 230 | } 231 | } 232 | layer { 233 | name: "relu5_1" 234 | type: "ReLU" 235 | bottom: "conv5_1" 236 | top: "conv5_1" 237 | } 238 | layer { 239 | name: "conv5_2" 240 | type: "Convolution" 241 | bottom: "conv5_1" 242 | top: "conv5_2" 243 | convolution_param { 244 | num_output: 512 245 | pad: 1 246 | kernel_size: 3 247 | } 248 | } 249 | layer { 250 | name: "relu5_2" 251 | type: "ReLU" 252 | bottom: "conv5_2" 253 | top: "conv5_2" 254 | } 255 | layer { 256 | name: "conv5_3" 257 | type: "Convolution" 258 | bottom: "conv5_2" 259 | top: "conv5_3" 260 | convolution_param { 261 | num_output: 512 262 | pad: 1 263 | kernel_size: 3 264 | } 265 | } 266 | layer { 267 | name: "relu5_3" 268 | type: "ReLU" 269 | bottom: "conv5_3" 270 | top: "conv5_3" 271 | } 272 | 273 | layer { 274 | bottom: "conv5_3" 275 | top: "pool5" 276 | name: "pool5" 277 | type: "Pooling" 278 | pooling_param { 279 | pool: MAX 280 | kernel_size: 3 281 | stride: 1 282 | pad: 1 283 | } 284 | } 285 | layer { 286 | bottom: "pool5" 287 | top: "pool5a" 288 | name: "pool5a" 289 | type: "Pooling" 290 | pooling_param { 291 | pool: AVE 292 | kernel_size: 3 293 | stride: 1 294 | pad: 1 295 | } 296 | } 297 | layer { 298 | name: "fc6" 299 | type: "Convolution" 300 | bottom: "pool5a" 301 | top: "fc6" 302 | convolution_param { 303 | num_output: 1024 304 | pad: 1 305 | kernel_size: 3 306 | } 307 | } 308 | layer { 309 | name: "relu6" 310 | type: "ReLU" 311 | bottom: "fc6" 312 | top: "fc6" 313 | } 314 | layer { 315 | name: "fc7" 316 | type: "Convolution" 317 | bottom: "fc6" 318 | top: "fc7" 319 | convolution_param { 320 | num_output: 1024 321 | kernel_size: 1 322 | } 323 | } 324 | layer { 325 | name: "relu7" 326 | type: "ReLU" 327 | bottom: "fc7" 328 | top: "fc7" 329 | } -------------------------------------------------------------------------------- /prepare_labels/network/vgg16_aff.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.sparse as sparse 4 | import torch.nn.functional as F 5 | from tool import pyutils 6 | 7 | import network.vgg16d 8 | 9 | class Net(network.vgg16d.Net): 10 | def __init__(self): 11 | super(Net, self).__init__(fc6_dilation=4) 12 | 13 | self.f8_3 = nn.Conv2d(512, 64, 1, bias=False) 14 | self.f8_4 = nn.Conv2d(512, 128, 1, bias=False) 15 | self.f8_5 = nn.Conv2d(1024, 256, 1, bias=False) 16 | self.gn8_3 = nn.modules.normalization.GroupNorm(8, 64) 17 | self.gn8_4 = nn.modules.normalization.GroupNorm(16, 128) 18 | self.gn8_5 = nn.modules.normalization.GroupNorm(32, 256) 19 | 20 | self.f9 = torch.nn.Conv2d(448, 448, 1, bias=False) 21 | 22 | torch.nn.init.kaiming_normal_(self.f8_3.weight) 23 | torch.nn.init.kaiming_normal_(self.f8_4.weight) 24 | torch.nn.init.kaiming_normal_(self.f8_5.weight) 25 | torch.nn.init.xavier_uniform_(self.f9.weight, gain=4) 26 | 27 | self.not_training = [self.conv1_1, self.conv1_2, self.conv2_1, self.conv2_2] 28 | self.from_scratch_layers = [self.f8_3, self.f8_4, self.f8_5, self.f9] 29 | 30 | self.predefined_featuresize = int(448//8) 31 | self.ind_from, self.ind_to = pyutils.get_indices_of_pairs(5, (self.predefined_featuresize, self.predefined_featuresize)) 32 | self.ind_from = torch.from_numpy(self.ind_from); self.ind_to = torch.from_numpy(self.ind_to) 33 | 34 | return 35 | 36 | 37 | def forward(self, x, to_dense=False): 38 | 39 | d = super().forward_as_dict(x) 40 | 41 | f8_3 = F.elu(self.gn8_3(self.f8_3(d['conv4']))) 42 | f8_4 = F.elu(self.gn8_4(self.f8_4(d['conv5']))) 43 | f8_5 = F.elu(self.gn8_5(self.f8_5(d['conv5fc']))) 44 | 45 | x = torch.cat([f8_3, f8_4, f8_5], dim=1) 46 | x = F.elu(self.f9(x)) 47 | 48 | if x.size(2) == self.predefined_featuresize and x.size(3) == self.predefined_featuresize: 49 | ind_from = self.ind_from 50 | ind_to = self.ind_to 51 | else: 52 | ind_from, ind_to = pyutils.get_indices_of_pairs(5, (x.size(2), x.size(3))) 53 | ind_from = torch.from_numpy(ind_from); ind_to = torch.from_numpy(ind_to) 54 | 55 | x = x.view(x.size(0), x.size(1), -1) 56 | 57 | ff = torch.index_select(x, dim=2, index=ind_from.cuda(non_blocking=True)) 58 | ft = torch.index_select(x, dim=2, index=ind_to.cuda(non_blocking=True)) 59 | 60 | ff = torch.unsqueeze(ff, dim=2) 61 | ft = ft.view(ft.size(0), ft.size(1), -1, ff.size(3)) 62 | 63 | aff = torch.exp(-torch.mean(torch.abs(ft-ff), dim=1)) 64 | 65 | if to_dense: 66 | aff = aff.view(-1).cpu() 67 | 68 | ind_from_exp = torch.unsqueeze(ind_from, dim=0).expand(ft.size(2), -1).contiguous().view(-1) 69 | indices = torch.stack([ind_from_exp, ind_to]) 70 | indices_tp = torch.stack([ind_to, ind_from_exp]) 71 | 72 | area = x.size(2) 73 | indices_id = torch.stack([torch.arange(0, area).long(), torch.arange(0, area).long()]) 74 | 75 | aff_mat = sparse.FloatTensor(torch.cat([indices, indices_id, indices_tp], dim=1), 76 | torch.cat([aff, torch.ones([area]), aff])).to_dense().cuda() 77 | return aff_mat 78 | 79 | else: 80 | return aff 81 | 82 | def get_parameter_groups(self): 83 | groups = ([], [], [], []) 84 | 85 | for m in self.modules(): 86 | 87 | if (isinstance(m, nn.Conv2d) or isinstance(m, nn.modules.normalization.GroupNorm)): 88 | 89 | if m.weight.requires_grad: 90 | if m in self.from_scratch_layers: 91 | groups[2].append(m.weight) 92 | else: 93 | groups[0].append(m.weight) 94 | 95 | if m.bias is not None and m.bias.requires_grad: 96 | 97 | if m in self.from_scratch_layers: 98 | groups[3].append(m.bias) 99 | else: 100 | groups[1].append(m.bias) 101 | 102 | return groups 103 | 104 | 105 | 106 | -------------------------------------------------------------------------------- /prepare_labels/network/vgg16_cls.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import network.vgg16d 6 | 7 | class Net(network.vgg16d.Net): 8 | 9 | def __init__(self): 10 | super(Net, self).__init__() 11 | 12 | self.drop7 = nn.Dropout2d(p=0.5) 13 | self.fc8 = nn.Conv2d(1024, 20, 1, bias=False) 14 | torch.nn.init.xavier_uniform_(self.fc8.weight) 15 | 16 | self.not_training = [self.conv1_1, self.conv1_2, 17 | self.conv2_1, self.conv2_2] 18 | self.from_scratch_layers = [self.fc8] 19 | 20 | def forward(self, x): 21 | x = super().forward(x) 22 | x = self.drop7(x) 23 | 24 | x = self.fc8(x) 25 | 26 | x = F.avg_pool2d(x, kernel_size=(x.size(2), x.size(3)), padding=0) 27 | 28 | x = x.view(-1, 20) 29 | 30 | return x 31 | 32 | def forward_cam(self, x): 33 | x = super().forward(x) 34 | x = self.fc8(x) 35 | x = F.relu(x) 36 | x = torch.sqrt(x) 37 | return x 38 | 39 | def fix_bn(self): 40 | self.bn8.eval() 41 | self.bn8.weight.requires_grad = False 42 | self.bn8.bias.requires_grad = False 43 | 44 | def get_parameter_groups(self): 45 | groups = ([], [], [], []) 46 | 47 | for m in self.modules(): 48 | 49 | if (isinstance(m, nn.Conv2d) or isinstance(m, nn.BatchNorm2d)): 50 | 51 | if m.weight is not None and m.weight.requires_grad: 52 | if m in self.from_scratch_layers: 53 | groups[2].append(m.weight) 54 | else: 55 | groups[0].append(m.weight) 56 | 57 | if m.bias is not None and m.bias.requires_grad: 58 | 59 | if m in self.from_scratch_layers: 60 | groups[3].append(m.bias) 61 | else: 62 | groups[1].append(m.bias) 63 | 64 | return groups -------------------------------------------------------------------------------- /prepare_labels/network/vgg16d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import numpy as np 4 | import torch.nn.functional as F 5 | 6 | class Normalize(): 7 | def __init__(self, mean = (122.675, 116.669, 104.008)): 8 | 9 | self.mean = mean 10 | 11 | def __call__(self, img): 12 | imgarr = np.asarray(img) 13 | proc_img = np.empty_like(imgarr, np.float32) 14 | 15 | proc_img[..., 0] = (imgarr[..., 2] - self.mean[2]) 16 | proc_img[..., 1] = (imgarr[..., 1] - self.mean[1]) 17 | proc_img[..., 2] = (imgarr[..., 0] - self.mean[0]) 18 | 19 | return proc_img 20 | 21 | class Net(nn.Module): 22 | def __init__(self, fc6_dilation = 1): 23 | super(Net, self).__init__() 24 | 25 | self.conv1_1 = nn.Conv2d(3,64,3,padding = 1) 26 | self.conv1_2 = nn.Conv2d(64,64,3,padding = 1) 27 | self.pool1 = nn.MaxPool2d(kernel_size = 3, stride = 2, padding=1) 28 | self.conv2_1 = nn.Conv2d(64,128,3,padding = 1) 29 | self.conv2_2 = nn.Conv2d(128,128,3,padding = 1) 30 | self.pool2 = nn.MaxPool2d(kernel_size = 3, stride = 2, padding=1) 31 | self.conv3_1 = nn.Conv2d(128,256,3,padding = 1) 32 | self.conv3_2 = nn.Conv2d(256,256,3,padding = 1) 33 | self.conv3_3 = nn.Conv2d(256,256,3,padding = 1) 34 | self.pool3 = nn.MaxPool2d(kernel_size = 3, stride = 2, padding=1) 35 | self.conv4_1 = nn.Conv2d(256,512,3,padding = 1) 36 | self.conv4_2 = nn.Conv2d(512,512,3,padding = 1) 37 | self.conv4_3 = nn.Conv2d(512,512,3,padding = 1) 38 | self.pool4 = nn.MaxPool2d(kernel_size = 3, stride = 1, padding=1) 39 | self.conv5_1 = nn.Conv2d(512,512,3,padding = 2, dilation = 2) 40 | self.conv5_2 = nn.Conv2d(512,512,3,padding = 2, dilation = 2) 41 | self.conv5_3 = nn.Conv2d(512,512,3,padding = 2, dilation = 2) 42 | self.pool5 = nn.MaxPool2d(kernel_size = 3, stride = 1, padding=1) 43 | self.pool5a = nn.AvgPool2d(kernel_size = 3, stride = 1, padding=1) 44 | 45 | self.fc6 = nn.Conv2d(512,1024, 3, padding = fc6_dilation, dilation = fc6_dilation) 46 | 47 | self.drop6 = nn.Dropout2d(p=0.5) 48 | self.fc7 = nn.Conv2d(1024,1024,1) 49 | 50 | self.normalize = Normalize() 51 | 52 | return 53 | 54 | def forward(self, x): 55 | return self.forward_as_dict(x)['conv5fc'] 56 | 57 | def forward_as_dict(self, x): 58 | 59 | x = F.relu(self.conv1_1(x)) 60 | x = F.relu(self.conv1_2(x)) 61 | x = self.pool1(x) 62 | 63 | x = F.relu(self.conv2_1(x)) 64 | x = F.relu(self.conv2_2(x)) 65 | x = self.pool2(x) 66 | 67 | x = F.relu(self.conv3_1(x)) 68 | x = F.relu(self.conv3_2(x)) 69 | x = F.relu(self.conv3_3(x)) 70 | x = self.pool3(x) 71 | 72 | x = F.relu(self.conv4_1(x)) 73 | x = F.relu(self.conv4_2(x)) 74 | x = F.relu(self.conv4_3(x)) 75 | conv4 = x 76 | 77 | x = self.pool4(x) 78 | 79 | x = F.relu(self.conv5_1(x)) 80 | x = F.relu(self.conv5_2(x)) 81 | x = F.relu(self.conv5_3(x)) 82 | conv5 = x 83 | 84 | x = F.relu(self.fc6(x)) 85 | x = self.drop6(x) 86 | x = F.relu(self.fc7(x)) 87 | 88 | conv5fc = x 89 | 90 | return dict({'conv4': conv4, 'conv5': conv5, 'conv5fc': conv5fc}) 91 | 92 | def train(self, mode=True): 93 | 94 | super().train(mode) 95 | 96 | for layer in self.not_training: 97 | 98 | if isinstance(layer, torch.nn.Conv2d): 99 | 100 | layer.weight.requires_grad = False 101 | layer.bias.requires_grad = False 102 | 103 | def convert_caffe_to_torch(caffemodel_path, prototxt_path='network/vgg16_20M.prototxt'): 104 | import caffe 105 | 106 | caffe_model = caffe.Net(prototxt_path, caffemodel_path, caffe.TEST) 107 | 108 | dict = {} 109 | for caffe_name in list(caffe_model.params.keys()): 110 | dict[caffe_name + '.weight'] = torch.from_numpy(caffe_model.params[caffe_name][0].data) 111 | dict[caffe_name + '.bias'] = torch.from_numpy(caffe_model.params[caffe_name][1].data) 112 | 113 | return dict 114 | 115 | 116 | 117 | 118 | -------------------------------------------------------------------------------- /prepare_labels/tool/imutils.py: -------------------------------------------------------------------------------- 1 | 2 | import PIL.Image 3 | import random 4 | import numpy as np 5 | 6 | class RandomResizeLong(): 7 | 8 | def __init__(self, min_long, max_long): 9 | self.min_long = min_long 10 | self.max_long = max_long 11 | 12 | def __call__(self, img): 13 | 14 | target_long = random.randint(self.min_long, self.max_long) 15 | w, h = img.size 16 | 17 | if w < h: 18 | target_shape = (int(round(w * target_long / h)), target_long) 19 | else: 20 | target_shape = (target_long, int(round(h * target_long / w))) 21 | 22 | img = img.resize(target_shape, resample=PIL.Image.CUBIC) 23 | 24 | return img 25 | 26 | 27 | class RandomCrop(): 28 | 29 | def __init__(self, cropsize): 30 | self.cropsize = cropsize 31 | 32 | def __call__(self, imgarr): 33 | 34 | h, w, c = imgarr.shape 35 | 36 | ch = min(self.cropsize, h) 37 | cw = min(self.cropsize, w) 38 | 39 | w_space = w - self.cropsize 40 | h_space = h - self.cropsize 41 | 42 | if w_space > 0: 43 | cont_left = 0 44 | img_left = random.randrange(w_space+1) 45 | else: 46 | cont_left = random.randrange(-w_space+1) 47 | img_left = 0 48 | 49 | if h_space > 0: 50 | cont_top = 0 51 | img_top = random.randrange(h_space+1) 52 | else: 53 | cont_top = random.randrange(-h_space+1) 54 | img_top = 0 55 | 56 | container = np.zeros((self.cropsize, self.cropsize, imgarr.shape[-1]), np.float32) 57 | container[cont_top:cont_top+ch, cont_left:cont_left+cw] = \ 58 | imgarr[img_top:img_top+ch, img_left:img_left+cw] 59 | 60 | return container 61 | 62 | def get_random_crop_box(imgsize, cropsize): 63 | h, w = imgsize 64 | 65 | ch = min(cropsize, h) 66 | cw = min(cropsize, w) 67 | 68 | w_space = w - cropsize 69 | h_space = h - cropsize 70 | 71 | if w_space > 0: 72 | cont_left = 0 73 | img_left = random.randrange(w_space + 1) 74 | else: 75 | cont_left = random.randrange(-w_space + 1) 76 | img_left = 0 77 | 78 | if h_space > 0: 79 | cont_top = 0 80 | img_top = random.randrange(h_space + 1) 81 | else: 82 | cont_top = random.randrange(-h_space + 1) 83 | img_top = 0 84 | 85 | return cont_top, cont_top+ch, cont_left, cont_left+cw, img_top, img_top+ch, img_left, img_left+cw 86 | 87 | def crop_with_box(img, box): 88 | if len(img.shape) == 3: 89 | img_cont = np.zeros((max(box[1]-box[0], box[4]-box[5]), max(box[3]-box[2], box[7]-box[6]), img.shape[-1]), dtype=img.dtype) 90 | else: 91 | img_cont = np.zeros((max(box[1] - box[0], box[4] - box[5]), max(box[3] - box[2], box[7] - box[6])), dtype=img.dtype) 92 | img_cont[box[0]:box[1], box[2]:box[3]] = img[box[4]:box[5], box[6]:box[7]] 93 | return img_cont 94 | 95 | 96 | def random_crop(images, cropsize, fills): 97 | if isinstance(images[0], PIL.Image.Image): 98 | imgsize = images[0].size[::-1] 99 | else: 100 | imgsize = images[0].shape[:2] 101 | box = get_random_crop_box(imgsize, cropsize) 102 | 103 | new_images = [] 104 | for img, f in zip(images, fills): 105 | 106 | if isinstance(img, PIL.Image.Image): 107 | img = img.crop((box[6], box[4], box[7], box[5])) 108 | cont = PIL.Image.new(img.mode, (cropsize, cropsize)) 109 | cont.paste(img, (box[2], box[0])) 110 | new_images.append(cont) 111 | 112 | else: 113 | if len(img.shape) == 3: 114 | cont = np.ones((cropsize, cropsize, img.shape[2]), img.dtype)*f 115 | else: 116 | cont = np.ones((cropsize, cropsize), img.dtype)*f 117 | cont[box[0]:box[1], box[2]:box[3]] = img[box[4]:box[5], box[6]:box[7]] 118 | new_images.append(cont) 119 | 120 | return new_images 121 | 122 | 123 | class AvgPool2d(): 124 | 125 | def __init__(self, ksize): 126 | self.ksize = ksize 127 | 128 | def __call__(self, img): 129 | import skimage.measure 130 | 131 | return skimage.measure.block_reduce(img, (self.ksize, self.ksize, 1), np.mean) 132 | 133 | 134 | class RandomHorizontalFlip(): 135 | def __init__(self): 136 | return 137 | 138 | def __call__(self, img): 139 | if bool(random.getrandbits(1)): 140 | img = np.fliplr(img).copy() 141 | return img 142 | 143 | 144 | class CenterCrop(): 145 | 146 | def __init__(self, cropsize, default_value=0): 147 | self.cropsize = cropsize 148 | self.default_value = default_value 149 | 150 | def __call__(self, npimg): 151 | 152 | h, w = npimg.shape[:2] 153 | 154 | ch = min(self.cropsize, h) 155 | cw = min(self.cropsize, w) 156 | 157 | sh = h - self.cropsize 158 | sw = w - self.cropsize 159 | 160 | if sw > 0: 161 | cont_left = 0 162 | img_left = int(round(sw / 2)) 163 | else: 164 | cont_left = int(round(-sw / 2)) 165 | img_left = 0 166 | 167 | if sh > 0: 168 | cont_top = 0 169 | img_top = int(round(sh / 2)) 170 | else: 171 | cont_top = int(round(-sh / 2)) 172 | img_top = 0 173 | 174 | if len(npimg.shape) == 2: 175 | container = np.ones((self.cropsize, self.cropsize), npimg.dtype)*self.default_value 176 | else: 177 | container = np.ones((self.cropsize, self.cropsize, npimg.shape[2]), npimg.dtype)*self.default_value 178 | 179 | container[cont_top:cont_top+ch, cont_left:cont_left+cw] = \ 180 | npimg[img_top:img_top+ch, img_left:img_left+cw] 181 | 182 | return container 183 | 184 | 185 | def HWC_to_CHW(img): 186 | return np.transpose(img, (2, 0, 1)) 187 | 188 | 189 | class RescaleNearest(): 190 | def __init__(self, scale): 191 | self.scale = scale 192 | 193 | def __call__(self, npimg): 194 | import cv2 195 | return cv2.resize(npimg, None, fx=self.scale, fy=self.scale, interpolation=cv2.INTER_NEAREST) 196 | 197 | 198 | 199 | 200 | def crf_inference(img, probs, t=10, scale_factor=1, labels=21): 201 | import pydensecrf.densecrf as dcrf 202 | from pydensecrf.utils import unary_from_softmax 203 | 204 | h, w = img.shape[:2] 205 | n_labels = labels 206 | 207 | d = dcrf.DenseCRF2D(w, h, n_labels) 208 | 209 | unary = unary_from_softmax(probs) 210 | unary = np.ascontiguousarray(unary) 211 | 212 | d.setUnaryEnergy(unary) 213 | d.addPairwiseGaussian(sxy=3/scale_factor, compat=3) 214 | d.addPairwiseBilateral(sxy=80/scale_factor, srgb=13, rgbim=np.copy(img), compat=10) 215 | Q = d.inference(t) 216 | 217 | return np.array(Q).reshape((n_labels, h, w)) -------------------------------------------------------------------------------- /prepare_labels/tool/pyutils.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import time 4 | import sys 5 | 6 | class Logger(object): 7 | def __init__(self, outfile): 8 | self.terminal = sys.stdout 9 | self.log = open(outfile, "w") 10 | sys.stdout = self 11 | 12 | def write(self, message): 13 | self.terminal.write(message) 14 | self.log.write(message) 15 | 16 | def flush(self): 17 | self.terminal.flush() 18 | 19 | 20 | class AverageMeter: 21 | def __init__(self, *keys): 22 | self.__data = dict() 23 | for k in keys: 24 | self.__data[k] = [0.0, 0] 25 | 26 | def add(self, dict): 27 | for k, v in dict.items(): 28 | self.__data[k][0] += v 29 | self.__data[k][1] += 1 30 | 31 | def get(self, *keys): 32 | if len(keys) == 1: 33 | return self.__data[keys[0]][0] / self.__data[keys[0]][1] 34 | else: 35 | v_list = [self.__data[k][0] / self.__data[k][1] for k in keys] 36 | return tuple(v_list) 37 | 38 | def pop(self, key=None): 39 | if key is None: 40 | for k in self.__data.keys(): 41 | self.__data[k] = [0.0, 0] 42 | else: 43 | v = self.get(key) 44 | self.__data[key] = [0.0, 0] 45 | return v 46 | 47 | 48 | class Timer: 49 | def __init__(self, starting_msg = None): 50 | self.start = time.time() 51 | self.stage_start = self.start 52 | 53 | if starting_msg is not None: 54 | print(starting_msg, time.ctime(time.time())) 55 | 56 | 57 | def update_progress(self, progress): 58 | self.elapsed = time.time() - self.start 59 | self.est_total = self.elapsed / progress 60 | self.est_remaining = self.est_total - self.elapsed 61 | self.est_finish = int(self.start + self.est_total) 62 | 63 | 64 | def str_est_finish(self): 65 | return str(time.ctime(self.est_finish)) 66 | 67 | def get_stage_elapsed(self): 68 | return time.time() - self.stage_start 69 | 70 | def reset_stage(self): 71 | self.stage_start = time.time() 72 | 73 | 74 | from multiprocessing.pool import ThreadPool 75 | 76 | class BatchThreader: 77 | 78 | def __init__(self, func, args_list, batch_size, prefetch_size=4, processes=12): 79 | self.batch_size = batch_size 80 | self.prefetch_size = prefetch_size 81 | 82 | self.pool = ThreadPool(processes=processes) 83 | self.async_result = [] 84 | 85 | self.func = func 86 | self.left_args_list = args_list 87 | self.n_tasks = len(args_list) 88 | 89 | # initial work 90 | self.__start_works(self.__get_n_pending_works()) 91 | 92 | 93 | def __start_works(self, times): 94 | for _ in range(times): 95 | args = self.left_args_list.pop(0) 96 | self.async_result.append( 97 | self.pool.apply_async(self.func, args)) 98 | 99 | 100 | def __get_n_pending_works(self): 101 | return min((self.prefetch_size + 1) * self.batch_size - len(self.async_result) 102 | , len(self.left_args_list)) 103 | 104 | 105 | 106 | def pop_results(self): 107 | 108 | n_inwork = len(self.async_result) 109 | 110 | n_fetch = min(n_inwork, self.batch_size) 111 | rtn = [self.async_result.pop(0).get() 112 | for _ in range(n_fetch)] 113 | 114 | to_fill = self.__get_n_pending_works() 115 | if to_fill == 0: 116 | self.pool.close() 117 | else: 118 | self.__start_works(to_fill) 119 | 120 | return rtn 121 | 122 | 123 | 124 | 125 | def get_indices_of_pairs(radius, size): 126 | 127 | search_dist = [] 128 | 129 | for x in range(1, radius): 130 | search_dist.append((0, x)) 131 | 132 | for y in range(1, radius): 133 | for x in range(-radius + 1, radius): 134 | if x * x + y * y < radius * radius: 135 | search_dist.append((y, x)) 136 | 137 | radius_floor = radius - 1 138 | 139 | full_indices = np.reshape(np.arange(0, size[0]*size[1], dtype=np.int64), 140 | (size[0], size[1])) 141 | 142 | cropped_height = size[0] - radius_floor 143 | cropped_width = size[1] - 2 * radius_floor 144 | 145 | indices_from = np.reshape(full_indices[:-radius_floor, radius_floor:-radius_floor], 146 | [-1]) 147 | 148 | indices_to_list = [] 149 | 150 | for dy, dx in search_dist: 151 | indices_to = full_indices[dy:dy + cropped_height, 152 | radius_floor + dx:radius_floor + dx + cropped_width] 153 | indices_to = np.reshape(indices_to, [-1]) 154 | 155 | indices_to_list.append(indices_to) 156 | 157 | concat_indices_to = np.concatenate(indices_to_list, axis=0) 158 | 159 | return indices_from, concat_indices_to 160 | 161 | -------------------------------------------------------------------------------- /prepare_labels/tool/torchutils.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from torch.utils.data import Dataset 4 | from PIL import Image 5 | import os.path 6 | import random 7 | import numpy as np 8 | from tool import imutils 9 | 10 | class PolyOptimizer(torch.optim.SGD): 11 | 12 | def __init__(self, params, lr, weight_decay, max_step, momentum=0.9): 13 | super().__init__(params, lr, weight_decay) 14 | 15 | self.global_step = 0 16 | self.max_step = max_step 17 | self.momentum = momentum 18 | 19 | self.__initial_lr = [group['lr'] for group in self.param_groups] 20 | 21 | 22 | def step(self, closure=None): 23 | 24 | if self.global_step < self.max_step: 25 | lr_mult = (1 - self.global_step / self.max_step) ** self.momentum 26 | 27 | for i in range(len(self.param_groups)): 28 | self.param_groups[i]['lr'] = self.__initial_lr[i] * lr_mult 29 | 30 | super().step(closure) 31 | 32 | self.global_step += 1 33 | 34 | 35 | class BatchNorm2dFixed(torch.nn.Module): 36 | 37 | def __init__(self, num_features, eps=1e-5): 38 | super(BatchNorm2dFixed, self).__init__() 39 | self.num_features = num_features 40 | self.eps = eps 41 | self.weight = torch.nn.Parameter(torch.Tensor(num_features)) 42 | self.bias = torch.nn.Parameter(torch.Tensor(num_features)) 43 | self.register_buffer('running_mean', torch.zeros(num_features)) 44 | self.register_buffer('running_var', torch.ones(num_features)) 45 | 46 | 47 | def forward(self, input): 48 | 49 | return F.batch_norm( 50 | input, self.running_mean, self.running_var, self.weight, self.bias, 51 | False, eps=self.eps) 52 | 53 | def __call__(self, x): 54 | return self.forward(x) 55 | 56 | 57 | class SegmentationDataset(Dataset): 58 | def __init__(self, img_name_list_path, img_dir, label_dir, rescale=None, flip=False, cropsize=None, 59 | img_transform=None, mask_transform=None): 60 | self.img_name_list_path = img_name_list_path 61 | self.img_dir = img_dir 62 | self.label_dir = label_dir 63 | 64 | self.img_transform = img_transform 65 | self.mask_transform = mask_transform 66 | 67 | self.img_name_list = open(self.img_name_list_path).read().splitlines() 68 | 69 | self.rescale = rescale 70 | self.flip = flip 71 | self.cropsize = cropsize 72 | 73 | def __len__(self): 74 | return len(self.img_name_list) 75 | 76 | def __getitem__(self, idx): 77 | 78 | name = self.img_name_list[idx] 79 | 80 | img = Image.open(os.path.join(self.img_dir, name + '.jpg')).convert("RGB") 81 | mask = Image.open(os.path.join(self.label_dir, name + '.png')) 82 | 83 | if self.rescale is not None: 84 | s = self.rescale[0] + random.random() * (self.rescale[1] - self.rescale[0]) 85 | adj_size = (round(img.size[0]*s/8)*8, round(img.size[1]*s/8)*8) 86 | img = img.resize(adj_size, resample=Image.CUBIC) 87 | mask = img.resize(adj_size, resample=Image.NEAREST) 88 | 89 | if self.img_transform is not None: 90 | img = self.img_transform(img) 91 | if self.mask_transform is not None: 92 | mask = self.mask_transform(mask) 93 | 94 | if self.cropsize is not None: 95 | img, mask = imutils.random_crop([img, mask], self.cropsize, (0, 255)) 96 | 97 | mask = imutils.RescaleNearest(0.125)(mask) 98 | 99 | if self.flip is True and bool(random.getrandbits(1)): 100 | img = np.flip(img, 1).copy() 101 | mask = np.flip(mask, 1).copy() 102 | 103 | img = np.transpose(img, (2, 0, 1)) 104 | 105 | return name, img, mask 106 | 107 | 108 | class ExtractAffinityLabelInRadius(): 109 | 110 | def __init__(self, cropsize, radius=5): 111 | self.radius = radius 112 | 113 | self.search_dist = [] 114 | 115 | for x in range(1, radius): 116 | self.search_dist.append((0, x)) 117 | 118 | for y in range(1, radius): 119 | for x in range(-radius+1, radius): 120 | if x*x + y*y < radius*radius: 121 | self.search_dist.append((y, x)) 122 | 123 | self.radius_floor = radius-1 124 | 125 | self.crop_height = cropsize - self.radius_floor 126 | self.crop_width = cropsize - 2 * self.radius_floor 127 | return 128 | 129 | def __call__(self, label): 130 | 131 | labels_from = label[:-self.radius_floor, self.radius_floor:-self.radius_floor] 132 | labels_from = np.reshape(labels_from, [-1]) 133 | 134 | labels_to_list = [] 135 | valid_pair_list = [] 136 | 137 | for dy, dx in self.search_dist: 138 | labels_to = label[dy:dy+self.crop_height, self.radius_floor+dx:self.radius_floor+dx+self.crop_width] 139 | labels_to = np.reshape(labels_to, [-1]) 140 | 141 | valid_pair = np.logical_and(np.less(labels_to, 255), np.less(labels_from, 255)) 142 | 143 | labels_to_list.append(labels_to) 144 | valid_pair_list.append(valid_pair) 145 | 146 | bc_labels_from = np.expand_dims(labels_from, 0) 147 | concat_labels_to = np.stack(labels_to_list) 148 | concat_valid_pair = np.stack(valid_pair_list) 149 | 150 | pos_affinity_label = np.equal(bc_labels_from, concat_labels_to) 151 | 152 | bg_pos_affinity_label = np.logical_and(pos_affinity_label, np.equal(bc_labels_from, 0)).astype(np.float32) 153 | 154 | fg_pos_affinity_label = np.logical_and(np.logical_and(pos_affinity_label, np.not_equal(bc_labels_from, 0)), concat_valid_pair).astype(np.float32) 155 | 156 | neg_affinity_label = np.logical_and(np.logical_not(pos_affinity_label), concat_valid_pair).astype(np.float32) 157 | 158 | return bg_pos_affinity_label, fg_pos_affinity_label, neg_affinity_label 159 | 160 | class AffinityFromMaskDataset(SegmentationDataset): 161 | def __init__(self, img_name_list_path, img_dir, label_dir, rescale=None, flip=False, cropsize=None, 162 | img_transform=None, mask_transform=None, radius=5): 163 | super().__init__(img_name_list_path, img_dir, label_dir, rescale, flip, cropsize, img_transform, mask_transform) 164 | 165 | self.radius = radius 166 | 167 | self.extract_aff_lab_func = ExtractAffinityLabelInRadius(cropsize=cropsize//8, radius=radius) 168 | 169 | def __getitem__(self, idx): 170 | name, img, mask = super().__getitem__(idx) 171 | 172 | aff_label = self.extract_aff_lab_func(mask) 173 | 174 | return name, img, aff_label 175 | -------------------------------------------------------------------------------- /prepare_labels/voc12/__pycache__/data.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shimoda-uec/ssdd/564c3e08fae7a158516cdbd9f3599a74dc748aff/prepare_labels/voc12/__pycache__/data.cpython-36.pyc -------------------------------------------------------------------------------- /prepare_labels/voc12/cls_labels.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shimoda-uec/ssdd/564c3e08fae7a158516cdbd9f3599a74dc748aff/prepare_labels/voc12/cls_labels.npy -------------------------------------------------------------------------------- /prepare_labels/voc12/data.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import torch 4 | from torch.utils.data import Dataset 5 | import PIL.Image 6 | import os.path 7 | import scipy.misc 8 | 9 | IMG_FOLDER_NAME = "JPEGImages" 10 | ANNOT_FOLDER_NAME = "Annotations" 11 | 12 | CAT_LIST = ['aeroplane', 'bicycle', 'bird', 'boat', 13 | 'bottle', 'bus', 'car', 'cat', 'chair', 14 | 'cow', 'diningtable', 'dog', 'horse', 15 | 'motorbike', 'person', 'pottedplant', 16 | 'sheep', 'sofa', 'train', 17 | 'tvmonitor'] 18 | 19 | CAT_NAME_TO_NUM = dict(zip(CAT_LIST,range(len(CAT_LIST)))) 20 | 21 | def load_image_label_from_xml(img_name, voc12_root): 22 | from xml.dom import minidom 23 | 24 | el_list = minidom.parse(os.path.join(voc12_root, ANNOT_FOLDER_NAME,img_name + '.xml')).getElementsByTagName('name') 25 | 26 | multi_cls_lab = np.zeros((20), np.float32) 27 | 28 | for el in el_list: 29 | cat_name = el.firstChild.data 30 | if cat_name in CAT_LIST: 31 | cat_num = CAT_NAME_TO_NUM[cat_name] 32 | multi_cls_lab[cat_num] = 1.0 33 | 34 | return multi_cls_lab 35 | 36 | def load_image_label_list_from_xml(img_name_list, voc12_root): 37 | 38 | return [load_image_label_from_xml(img_name, voc12_root) for img_name in img_name_list] 39 | 40 | def load_image_label_list_from_npy(img_name_list): 41 | 42 | cls_labels_dict = np.load('voc12/cls_labels.npy').item() 43 | 44 | return [cls_labels_dict[img_name] for img_name in img_name_list] 45 | 46 | def get_img_path(img_name, voc12_root): 47 | return os.path.join(voc12_root, IMG_FOLDER_NAME, img_name + '.jpg') 48 | 49 | def load_img_name_list(dataset_path): 50 | 51 | img_gt_name_list = open(dataset_path).read().splitlines() 52 | img_name_list = [img_gt_name.split(' ')[0][-15:-4] for img_gt_name in img_gt_name_list] 53 | 54 | return img_name_list 55 | 56 | class VOC12ImageDataset(Dataset): 57 | 58 | def __init__(self, img_name_list_path, voc12_root, transform=None): 59 | self.img_name_list = load_img_name_list(img_name_list_path) 60 | self.voc12_root = voc12_root 61 | self.transform = transform 62 | 63 | def __len__(self): 64 | return len(self.img_name_list) 65 | 66 | def __getitem__(self, idx): 67 | name = self.img_name_list[idx] 68 | 69 | img = PIL.Image.open(get_img_path(name, self.voc12_root)).convert("RGB") 70 | 71 | if self.transform: 72 | img = self.transform(img) 73 | 74 | return name, img 75 | 76 | 77 | class VOC12ClsDataset(VOC12ImageDataset): 78 | 79 | def __init__(self, img_name_list_path, voc12_root, transform=None): 80 | super().__init__(img_name_list_path, voc12_root, transform) 81 | self.label_list = load_image_label_list_from_npy(self.img_name_list) 82 | 83 | def __getitem__(self, idx): 84 | name, img = super().__getitem__(idx) 85 | 86 | label = torch.from_numpy(self.label_list[idx]) 87 | 88 | return name, img, label 89 | 90 | 91 | class VOC12ClsDatasetMSF(VOC12ClsDataset): 92 | 93 | def __init__(self, img_name_list_path, voc12_root, scales, inter_transform=None, unit=1): 94 | super().__init__(img_name_list_path, voc12_root, transform=None) 95 | self.scales = scales 96 | self.unit = unit 97 | self.inter_transform = inter_transform 98 | 99 | def __getitem__(self, idx): 100 | name, img, label = super().__getitem__(idx) 101 | 102 | rounded_size = (int(round(img.size[0]/self.unit)*self.unit), int(round(img.size[1]/self.unit)*self.unit)) 103 | 104 | ms_img_list = [] 105 | for s in self.scales: 106 | target_size = (round(rounded_size[0]*s), 107 | round(rounded_size[1]*s)) 108 | s_img = img.resize(target_size, resample=PIL.Image.CUBIC) 109 | ms_img_list.append(s_img) 110 | 111 | if self.inter_transform: 112 | for i in range(len(ms_img_list)): 113 | ms_img_list[i] = self.inter_transform(ms_img_list[i]) 114 | 115 | msf_img_list = [] 116 | for i in range(len(ms_img_list)): 117 | msf_img_list.append(ms_img_list[i]) 118 | msf_img_list.append(np.flip(ms_img_list[i], -1).copy()) 119 | 120 | return name, msf_img_list, label 121 | 122 | 123 | class ExtractAffinityLabelInRadius(): 124 | 125 | def __init__(self, cropsize, radius=5): 126 | self.radius = radius 127 | 128 | self.search_dist = [] 129 | 130 | for x in range(1, radius): 131 | self.search_dist.append((0, x)) 132 | 133 | for y in range(1, radius): 134 | for x in range(-radius+1, radius): 135 | if x*x + y*y < radius*radius: 136 | self.search_dist.append((y, x)) 137 | 138 | self.radius_floor = radius-1 139 | 140 | self.crop_height = cropsize - self.radius_floor 141 | self.crop_width = cropsize - 2 * self.radius_floor 142 | return 143 | 144 | def __call__(self, label): 145 | 146 | labels_from = label[:-self.radius_floor, self.radius_floor:-self.radius_floor] 147 | labels_from = np.reshape(labels_from, [-1]) 148 | 149 | labels_to_list = [] 150 | valid_pair_list = [] 151 | 152 | for dy, dx in self.search_dist: 153 | labels_to = label[dy:dy+self.crop_height, self.radius_floor+dx:self.radius_floor+dx+self.crop_width] 154 | labels_to = np.reshape(labels_to, [-1]) 155 | 156 | valid_pair = np.logical_and(np.less(labels_to, 255), np.less(labels_from, 255)) 157 | 158 | labels_to_list.append(labels_to) 159 | valid_pair_list.append(valid_pair) 160 | 161 | bc_labels_from = np.expand_dims(labels_from, 0) 162 | concat_labels_to = np.stack(labels_to_list) 163 | concat_valid_pair = np.stack(valid_pair_list) 164 | 165 | pos_affinity_label = np.equal(bc_labels_from, concat_labels_to) 166 | 167 | bg_pos_affinity_label = np.logical_and(pos_affinity_label, np.equal(bc_labels_from, 0)).astype(np.float32) 168 | 169 | fg_pos_affinity_label = np.logical_and(np.logical_and(pos_affinity_label, np.not_equal(bc_labels_from, 0)), concat_valid_pair).astype(np.float32) 170 | 171 | neg_affinity_label = np.logical_and(np.logical_not(pos_affinity_label), concat_valid_pair).astype(np.float32) 172 | 173 | return torch.from_numpy(bg_pos_affinity_label), torch.from_numpy(fg_pos_affinity_label), torch.from_numpy(neg_affinity_label) 174 | 175 | 176 | class VOC12AffDataset(VOC12ImageDataset): 177 | 178 | def __init__(self, img_name_list_path, label_la_dir, label_ha_dir, cropsize, voc12_root, radius=5, 179 | joint_transform_list=None, img_transform_list=None, label_transform_list=None): 180 | super().__init__(img_name_list_path, voc12_root, transform=None) 181 | 182 | self.label_la_dir = label_la_dir 183 | self.label_ha_dir = label_ha_dir 184 | self.voc12_root = voc12_root 185 | 186 | self.joint_transform_list = joint_transform_list 187 | self.img_transform_list = img_transform_list 188 | self.label_transform_list = label_transform_list 189 | 190 | self.extract_aff_lab_func = ExtractAffinityLabelInRadius(cropsize=cropsize//8, radius=radius) 191 | 192 | def __len__(self): 193 | return len(self.img_name_list) 194 | 195 | def __getitem__(self, idx): 196 | name, img = super().__getitem__(idx) 197 | 198 | label_la_path = os.path.join(self.label_la_dir, name + '.npy') 199 | 200 | label_ha_path = os.path.join(self.label_ha_dir, name + '.npy') 201 | 202 | label_la = np.load(label_la_path).item() 203 | label_ha = np.load(label_ha_path).item() 204 | 205 | label = np.array(list(label_la.values()) + list(label_ha.values())) 206 | label = np.transpose(label, (1, 2, 0)) 207 | 208 | for joint_transform, img_transform, label_transform \ 209 | in zip(self.joint_transform_list, self.img_transform_list, self.label_transform_list): 210 | 211 | if joint_transform: 212 | img_label = np.concatenate((img, label), axis=-1) 213 | img_label = joint_transform(img_label) 214 | img = img_label[..., :3] 215 | label = img_label[..., 3:] 216 | 217 | if img_transform: 218 | img = img_transform(img) 219 | if label_transform: 220 | label = label_transform(label) 221 | 222 | no_score_region = np.max(label, -1) < 1e-5 223 | label_la, label_ha = np.array_split(label, 2, axis=-1) 224 | label_la = np.argmax(label_la, axis=-1).astype(np.uint8) 225 | label_ha = np.argmax(label_ha, axis=-1).astype(np.uint8) 226 | label = label_la.copy() 227 | label[label_la == 0] = 255 228 | label[label_ha == 0] = 0 229 | label[no_score_region] = 255 # mostly outer of cropped region 230 | label = self.extract_aff_lab_func(label) 231 | 232 | return img, label 233 | 234 | 235 | class VOC12AffGtDataset(VOC12ImageDataset): 236 | 237 | def __init__(self, img_name_list_path, label_dir, cropsize, voc12_root, radius=5, 238 | joint_transform_list=None, img_transform_list=None, label_transform_list=None): 239 | super().__init__(img_name_list_path, voc12_root, transform=None) 240 | 241 | self.label_dir = label_dir 242 | self.voc12_root = voc12_root 243 | 244 | self.joint_transform_list = joint_transform_list 245 | self.img_transform_list = img_transform_list 246 | self.label_transform_list = label_transform_list 247 | 248 | self.extract_aff_lab_func = ExtractAffinityLabelInRadius(cropsize=cropsize//8, radius=radius) 249 | 250 | def __len__(self): 251 | return len(self.img_name_list) 252 | 253 | def __getitem__(self, idx): 254 | name, img = super().__getitem__(idx) 255 | 256 | label_path = os.path.join(self.label_dir, name + '.png') 257 | 258 | label = scipy.misc.imread(label_path) 259 | 260 | for joint_transform, img_transform, label_transform \ 261 | in zip(self.joint_transform_list, self.img_transform_list, self.label_transform_list): 262 | 263 | if joint_transform: 264 | img_label = np.concatenate((img, label), axis=-1) 265 | img_label = joint_transform(img_label) 266 | img = img_label[..., :3] 267 | label = img_label[..., 3:] 268 | 269 | if img_transform: 270 | img = img_transform(img) 271 | if label_transform: 272 | label = label_transform(label) 273 | 274 | label = self.extract_aff_lab_func(label) 275 | 276 | return img, label -------------------------------------------------------------------------------- /prepare_labels/voc12/make_cls_labels.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import voc12.data 3 | import numpy as np 4 | 5 | if __name__ == '__main__': 6 | 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument("--train_list", default='train_aug.txt', type=str) 9 | parser.add_argument("--val_list", default='val.txt', type=str) 10 | parser.add_argument("--out", default="cls_labels.npy", type=str) 11 | parser.add_argument("--voc12_root", required=True, type=str) 12 | args = parser.parse_args() 13 | 14 | img_name_list = voc12.data.load_img_name_list(args.train_list) 15 | img_name_list.extend(voc12.data.load_img_name_list(args.val_list)) 16 | label_list = voc12.data.load_image_label_list_from_xml(img_name_list, args.voc12_root) 17 | 18 | d = dict() 19 | for img_name, label in zip(img_name_list, label_list): 20 | d[img_name] = label 21 | 22 | np.save(args.out, d) -------------------------------------------------------------------------------- /pretrained_models/tmp.txt: -------------------------------------------------------------------------------- 1 | please download the models of PSA to here. 2 | -------------------------------------------------------------------------------- /script/gen_html.py: -------------------------------------------------------------------------------- 1 | import os 2 | print('') 3 | print('') 4 | print('') 5 | print('') 6 | print('title') 7 | print('') 8 | print('') 9 | print('') 10 | print('') 36 | btn_txt=['<<<','<<','<','>','>>','>>>'] 37 | for i in range(6): 38 | print('
'.format(i)) 39 | print('') 40 | print('
') 41 | print('') 42 | w=200 43 | h=200 44 | print('') 45 | for i in range(shown): 46 | print('') 47 | print('') 52 | print('') 57 | print('') 62 | print('') 63 | print('') 64 | -------------------------------------------------------------------------------- /script/gen_html_dssdd.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | modelid='default' 4 | 5 | print('') 6 | print('') 7 | print('') 8 | print('') 9 | print('title') 10 | print('') 11 | print('') 12 | print('') 13 | print('') 75 | print('epoch') 76 | btn_txt=['<','>'] 77 | for i in range(2): 78 | print('
'.format(i)) 79 | print('') 80 | print('
') 81 | print('

') 82 | print('count') 83 | btn_txt=['<','>'] 84 | for i in range(2): 85 | print('
'.format(i)) 86 | print('') 87 | print('
') 88 | print('

') 89 | print('

epoch 0

') 90 | print('

count 0

') 91 | print('
imagePSAPSA with CRF
') 48 | print('
'.format(i)) 49 | print(''.format(image_ids[i], w,h)) 50 | print('
') 51 | print('
') 53 | print('
'.format(i)) 54 | print(''.format(image_ids[i], w,h)) 55 | print('
') 56 | print('
') 58 | print('
'.format(i)) 59 | print(''.format(image_ids[i], w,h)) 60 | print('
') 61 | print('
') 92 | w=200 93 | h=200 94 | e=0 95 | print('') 96 | for i in range(shown): 97 | print('') 98 | print('') 103 | print('') 108 | print('') 113 | print('') 114 | print('') 115 | print('') 116 | print('') 117 | print('') 118 | print('') 119 | print('') 124 | print('') 129 | print('') 134 | print('') 139 | print('') 141 | print('') 146 | print('') 147 | print('') 148 | print('') 149 | print('') 150 | print('') 155 | print('') 160 | print('') 165 | print('') 170 | print('') 172 | print('') 177 | print('') 178 | 179 | print('') 180 | -------------------------------------------------------------------------------- /script/gen_html_val.py: -------------------------------------------------------------------------------- 1 | import os 2 | print('') 3 | print('') 4 | print('') 5 | print('') 6 | print('title') 7 | print('') 8 | print('') 9 | print('') 10 | print('') 39 | btn_txt=['<<<','<<','<','>','>>','>>>'] 40 | for i in range(6): 41 | print('
'.format(i)) 42 | print('') 43 | print('
') 44 | print('
imageseed mask(sssdd)integrated mask(second step)
') 99 | print('
'.format(i)) 100 | print(''.format(modelid, modelid, e,i, w,h)) 101 | print('
') 102 | print('
') 104 | print('
'.format(i)) 105 | print(''.format(modelid, modelid, e,i, w,h)) 106 | print('
') 107 | print('
') 109 | print('
'.format(i)) 110 | print(''.format(modelid, modelid, e,i, w,h)) 111 | print('
') 112 | print('
first step
K1d_kd_aA1integrated_mask1
') 120 | print('
'.format(i)) 121 | print(''.format(modelid, modelid, e,i, w,h)) 122 | print('
') 123 | print('
') 125 | print('
'.format(i)) 126 | print(''.format(modelid, modelid, e,i, w,h)) 127 | print('
') 128 | print('
') 130 | print('
'.format(i)) 131 | print(''.format(modelid, modelid, e,i, w,h)) 132 | print('
') 133 | print('
') 135 | print('
'.format(i)) 136 | print(''.format(modelid, modelid, e,i, w,h)) 137 | print('
') 138 | print('
  ') 140 | print('') 142 | print('
'.format(i)) 143 | print(''.format(modelid, modelid, e,i, w,h)) 144 | print('
') 145 | print('
second step
K2d_kd_aA2integrated_mask2
') 151 | print('
'.format(i)) 152 | print(''.format(modelid, modelid, e,i, w,h)) 153 | print('
') 154 | print('
') 156 | print('
'.format(i)) 157 | print(''.format(modelid, modelid, e,i, w,h)) 158 | print('
') 159 | print('
') 161 | print('
'.format(i)) 162 | print(''.format(modelid, modelid, e,i, w,h)) 163 | print('
') 164 | print('
') 166 | print('
'.format(i)) 167 | print(''.format(modelid, modelid, e,i, w,h)) 168 | print('
') 169 | print('
') 171 | print('') 173 | print('
'.format(i)) 174 | print(''.format(modelid, modelid, e,i, w,h)) 175 | print('
') 176 | print('
') 45 | w=200 46 | h=200 47 | print('') 48 | for i in range(shown): 49 | print('') 50 | print('') 55 | print('') 60 | print('') 65 | print('') 66 | print('') 67 | -------------------------------------------------------------------------------- /script/val.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | title 6 | 7 | 8 | 9 | 1645 |
1646 | 1647 |
1648 |
1649 | 1650 |
1651 |
1652 | 1653 |
1654 |
1655 | 1656 |
1657 |
1658 | 1659 |
1660 |
1661 | 1662 |
1663 |
ImageInferenceGround truth
') 51 | print('
'.format(i)) 52 | print(''.format(image_ids[i], w,h)) 53 | print('
') 54 | print('
') 56 | print('
'.format(i)) 57 | print('') 59 | print('
') 61 | print('
'.format(i)) 62 | print(''.format(image_ids[i], w,h)) 63 | print('
') 64 | print('
1664 | 1665 | 1666 | 1671 | 1676 | 1681 | 1682 | 1683 | 1688 | 1693 | 1698 | 1699 | 1700 | 1705 | 1710 | 1715 | 1716 | 1717 | 1722 | 1727 | 1732 | 1733 | 1734 | 1739 | 1744 | 1749 | 1750 | 1751 | -------------------------------------------------------------------------------- /ssdd_function.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | sigmoid = torch.nn.Sigmoid() 6 | def compute_sig_mask_loss(logits, bin_mask): 7 | bin_mask=bin_mask.float() 8 | logits=sigmoid(logits).squeeze(1) 9 | loc0=bin_mask==0 10 | loc1=bin_mask==1 11 | logits0=logits[loc0] 12 | logits1=logits[loc1] 13 | bin_mask0=bin_mask[loc0] 14 | bin_mask1=bin_mask[loc1] 15 | loss0=F.binary_cross_entropy(logits0, bin_mask0) 16 | loss1=F.binary_cross_entropy(logits1, bin_mask1) 17 | return (loss0 + loss1)/2 18 | 19 | def add_class_weights(pixel_weights, mask0, mask1, ignore_flags, bg_bias=0.00): 20 | for i in range(len(mask0)): 21 | pixel_weight = pixel_weights[i] 22 | pixel_weight -= (mask0[i]==(0)).float()*(bg_bias) 23 | pixel_weight += (mask1[i]==(0)).float()*(bg_bias) 24 | for j in range(1,ignore_flags.shape[1]): 25 | pixel_weight -= (mask0[i]==(j)).float()*(ignore_flags[i,j]*1.0) 26 | pixel_weight += (mask1[i]==(j)).float()*(ignore_flags[i,j]*1.0) 27 | return pixel_weights 28 | def get_dd_mask(dd0, dd1, mask0, mask1, ignore_flags, dd_bias=0.15, bg_bias=0.05): 29 | dd0_prob = sigmoid(dd0) 30 | dd1_prob = sigmoid(dd1) 31 | w = dd0_prob-dd1_prob+dd_bias 32 | w = add_class_weights(w, mask0, mask1, ignore_flags, bg_bias=bg_bias) 33 | refine_mask=Variable(torch.zeros_like(mask0))+255 34 | bsc=((w.squeeze(1)>=0)) 35 | bcs=bsc==0 36 | refine_mask[bsc]=mask1[bsc] 37 | refine_mask[bcs]=mask0[bcs] 38 | return (dd0, dd1, ignore_flags, refine_mask) 39 | def get_dd(dd, dd_head, mask): 40 | binmask = get_binarymask(mask) 41 | dd_pred = dd((dd_head, binmask.detach())) 42 | return dd_pred 43 | 44 | def get_ignore_flags(mask0, mask1, mlabel, th=0.5): 45 | ignore_flags=np.zeros((len(mask0),21,)) 46 | for i in range(len(mlabel)): 47 | for j in range(len(mlabel[0])): 48 | if mlabel[i][j]==1: 49 | loc0=torch.sum(mask0[i]==(j+1)).item() 50 | loc1=torch.sum(mask1[i]==(j+1)).item() 51 | rate=loc1/max(loc0,1) 52 | if rate < th: 53 | ignore_flags[i,j+1]=1 54 | return ignore_flags 55 | 56 | def get_binarymask(masks, chn=21): 57 | # input [NxHxW] 58 | N,H,W=masks.shape 59 | bin_masks=torch.zeros(N,chn,H,W).cuda() 60 | for n in range(N): 61 | mask = masks[n] 62 | for c in range(chn): 63 | bin_mask = bin_masks[n,c] 64 | loc = mask==c 65 | locn=torch.sum(loc) 66 | if locn.sum()>0: 67 | bin_mask[loc]=1 68 | return bin_masks 69 | 70 | def get_ddloss(dd, diff_mask, ignore_flags): 71 | loss_dd = Variable(torch.FloatTensor([0]),requires_grad=True).cuda() 72 | cnt=0 73 | for k in range(len(dd)): 74 | if torch.sum(ignore_flags[k,1:]).item()>0: 75 | continue 76 | cnt+=1 77 | loss_dd += compute_sig_mask_loss(dd[k:k+1], diff_mask[k:k+1]) 78 | if cnt >0: 79 | loss_dd /= cnt 80 | return loss_dd 81 | -------------------------------------------------------------------------------- /ssdd_val.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import math 3 | import os 4 | import random 5 | import re 6 | import cv2 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import torch.optim as optim 12 | import torch.utils.data 13 | from torch.autograd import Variable 14 | import imutils 15 | import utils 16 | from torchvision import transforms 17 | from torch.utils.data.dataloader import default_collate 18 | import time 19 | from PIL import Image 20 | from base_class import BaseModel, SegBaseModel, SSDDBaseModel, PascalDataset 21 | from network import SegmentationPsa, PredictDiff, PredictDiffHead 22 | import math 23 | 24 | ############################################################ 25 | # dataset 26 | ############################################################ 27 | 28 | class SSDDValData(PascalDataset): 29 | def __init__(self, dataset, config): 30 | super().__init__(dataset, config) 31 | self.joint_transform_list=[ 32 | imutils.Rescale(self.config.INP_SHAPE), 33 | None, 34 | None, 35 | ] 36 | self.img_transform_list=[ 37 | np.asarray, 38 | imutils.Normalize(mean = self.mean, std = self.std), 39 | imutils.HWC_to_CHW 40 | ] 41 | def __getitem__(self, image_index): 42 | image_id = self.image_ids[image_index] 43 | # Load image and mask 44 | impath= self.config.VOC_ROOT+'/JPEGImages/' 45 | imn=impath+image_id+'.jpg' 46 | img = Image.open(imn).convert("RGB") 47 | img = self.img_label_resize([img])[0] 48 | images = torch.from_numpy(img) 49 | return images, image_index 50 | 51 | def __len__(self): 52 | return self.image_ids.shape[0] 53 | 54 | 55 | ############################################################ 56 | # Model Class 57 | ############################################################ 58 | 59 | class SegModel(SegBaseModel): 60 | def __init__(self, config): 61 | super(SegModel, self).__init__(config) 62 | in_channel=4096 63 | self.seg_main = SegmentationPsa(config, in_channel=in_channel, middle_channel=512, num_classes=21) 64 | 65 | def forward(self, inputs): 66 | x = inputs 67 | [x1, x2, x3, x4, x5] = self.encoder(x) 68 | seg, seg_head = self.seg_main(x5) 69 | return seg 70 | 71 | class Evaluator(): 72 | def __init__(self, config, model): 73 | super(Evaluator, self).__init__() 74 | self.config = config 75 | self.model=model 76 | 77 | def eval_model(self, val_dataset): 78 | self.val_set = SSDDValData(val_dataset, self.config) 79 | val_generator = torch.utils.data.DataLoader(self.val_set, batch_size=self.config.BATCH, shuffle=False, num_workers=torch.cuda.device_count()*2, pin_memory=True) 80 | self.model.eval() 81 | self.eval(val_generator) 82 | 83 | def get_segmentation(self, img): 84 | segs = self.get_ms_segout(img) 85 | fimg = img[:,:,:,torch.arange(img.shape[3]-1,-1,-1)] 86 | fsegs = self.get_ms_segout(fimg) 87 | seg_all = torch.zeros(1,segs[0].shape[1],segs[0].shape[2],segs[0].shape[3]) 88 | for i in range(len(segs)): 89 | seg_all += segs[i] 90 | for i in range(len(segs)): 91 | seg_all += fsegs[i][:,:,:,torch.arange(fsegs[i].shape[3]-1,-1,-1)] 92 | return seg_all 93 | 94 | def get_ms_segout(self, img): 95 | scales = [1/2, 3/4, 1, 5/4, 3/2] 96 | segs = [] 97 | for i in range(len(scales)): 98 | scale=scales[i] 99 | simg = F.interpolate(img, (int(img.shape[2]*scale),int(img.shape[3]*scale)), mode='bilinear') 100 | seg = self.model(simg) 101 | seg = F.softmax(seg,dim=1) 102 | seg = F.interpolate(seg, (int(img.shape[2]),int(img.shape[3])), mode='bilinear') 103 | seg = seg.data.cpu() 104 | segs.append(seg) 105 | torch.cuda.empty_cache() 106 | return segs 107 | 108 | def eval(self, datagenerator): 109 | end = time.time() 110 | cnt=0 111 | for inputs in datagenerator: 112 | print(cnt) 113 | data_time = time.time() 114 | start=time.time() 115 | images, imgindex = inputs 116 | images = Variable(images).cuda() 117 | segs=[] 118 | with torch.no_grad(): 119 | for i in range(len(images)): 120 | # segmentation 121 | seg=self.get_segmentation(images[i:i+1]) 122 | # crf 123 | image_id = self.val_set.image_ids[imgindex[i]] 124 | impath=self.config.VOC_ROOT+'/JPEGImages/' 125 | imn=impath+image_id+'.jpg' 126 | img_org = np.asarray(Image.open(imn)) 127 | seg=F.interpolate(seg,(img_org.shape[0],img_org.shape[1]),mode='bilinear') 128 | prob=F.softmax(seg,dim=1)[0].data.cpu().numpy() 129 | seg_mask = np.argmax(prob,0) 130 | seg_crf_map = imutils.crf_inference(img_org, prob, labels=prob.shape[0], t=10) 131 | seg_crf_mask = np.argmax(seg_crf_map,axis=0) 132 | # save results 133 | cnt+=1 134 | saven = os.path.join(self.savedir, 'seg_val_'+self.saveid+'_'+str(cnt)+'.png') 135 | utils.mask2png(saven, seg_mask) 136 | saven = os.path.join(self.savedir, 'seg_val_'+self.saveid+'_'+str(cnt)+'.txt') 137 | np.savetxt(saven, seg_mask) 138 | saven = os.path.join(self.savedir, 'seg_val_crf_'+self.saveid+'_'+str(cnt)+'.png') 139 | utils.mask2png(saven, seg_crf_mask) 140 | saven = os.path.join(self.savedir, 'seg_val_crf_'+self.saveid+'_'+str(cnt)+'.txt') 141 | np.savetxt(saven, seg_crf_mask) 142 | 143 | 144 | 145 | def set_log_dir(self, phase, saveid, model_path=None): 146 | self.phase = phase 147 | self.saveid = saveid 148 | self.savedir = 'validation/'+self.saveid 149 | print("save the results to "+self.savedir) 150 | if not os.path.exists(self.savedir): 151 | os.makedirs(self.savedir) 152 | 153 | def val(config, weight_file=None): 154 | model = SegModel(config=config) 155 | return model 156 | -------------------------------------------------------------------------------- /train_dssdd.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import random 4 | import re 5 | import cv2 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import torch.optim as optim 11 | import torch.utils.data 12 | from torch.autograd import Variable 13 | from torchvision import transforms 14 | import imutils 15 | import utils 16 | from base_class import BaseModel, SegBaseModel, SSDDBaseModel, PascalDataset 17 | import ssdd_function as ssddF 18 | import time 19 | from PIL import Image 20 | from network import SegmentationPsa, PredictDiff, PredictDiffHead 21 | import math 22 | import cv2 23 | cv2.setNumThreads(0) 24 | 25 | ############################################################ 26 | # dataset 27 | ############################################################ 28 | 29 | class DssddData(PascalDataset): 30 | def __init__(self, dataset, config): 31 | super().__init__(dataset, config) 32 | self.label_dic = dataset.label_dic 33 | self.joint_transform_list=[ 34 | None, 35 | imutils.RandomHorizontalFlip(), 36 | imutils.RandomResizeLong(512, 768), 37 | imutils.RandomCrop(448), 38 | None, 39 | ] 40 | def __getitem__(self, image_index): 41 | image_id = self.image_ids[image_index] 42 | impath = self.config.VOC_ROOT+'/JPEGImages/' 43 | imn = impath+image_id+'.jpg' 44 | img = Image.open(imn).convert("RGB") 45 | gt_class_mlabel = torch.from_numpy(self.label_dic[image_id]) 46 | gt_class_mlabel_bg = torch.from_numpy(np.concatenate(([1],self.label_dic[image_id]))) 47 | psan = 'prepare_labels/results/out_aff/'+image_id+'.npy' 48 | psa=np.array(list(np.load(psan).item().values())).transpose(1,2,0) 49 | psan = 'prepare_labels/results/out_aff_crf/'+image_id+'.npy' 50 | psa_crf=np.load(psan).transpose(1,2,0) 51 | h=psa.shape[0] 52 | w=psa.shape[1] 53 | saven = 'precompute/'+self.config.modelid+'/da_precompute_'+self.config.modelid+'_'+str(image_index)+'.npy' 54 | dd0=np.load(saven).transpose(1,2,0) 55 | dd0=np.reshape(cv2.resize(dd0,(w,h)),(h,w,1)) 56 | saven = 'precompute/'+self.config.modelid+'/dk_precompute_'+self.config.modelid+'_'+str(image_index)+'.npy' 57 | dd1=np.load(saven).transpose(1,2,0) 58 | dd1=np.reshape(cv2.resize(dd1,(w,h)),(h,w,1)) 59 | # resize inputs 60 | img_norm, img_org, psa, psa_crf, dp0, dp1 = self.img_label_resize([img, np.array(img), psa, psa_crf, dd0, dd1]) 61 | img_org = cv2.resize(img_org,self.config.OUT_SHAPE) 62 | dd0 = cv2.resize(dd0,self.config.OUT_SHAPE) 63 | dd1 = cv2.resize(dd1,self.config.OUT_SHAPE) 64 | psa = cv2.resize(psa,self.config.OUT_SHAPE) 65 | psa_crf = cv2.resize(psa_crf,self.config.OUT_SHAPE) 66 | psa=self.get_prob_label(psa, gt_class_mlabel_bg).transpose(2,0,1) 67 | psa_crf=self.get_prob_label(psa_crf, gt_class_mlabel_bg).transpose(2,0,1) 68 | psa_mask = np.argmax(psa,0) 69 | psa_crf_mask = np.argmax(psa_crf,0) 70 | dd0 = torch.from_numpy(dd0).unsqueeze(0) 71 | dd1 = torch.from_numpy(dd1).unsqueeze(0) 72 | psa_mask = torch.from_numpy(psa_mask).unsqueeze(0) 73 | psa_crf_mask = torch.from_numpy(psa_crf_mask).unsqueeze(0) 74 | ignore_flags=torch.from_numpy(ssddF.get_ignore_flags(psa_mask, psa_crf_mask, [gt_class_mlabel])).float() 75 | # integration using sssdd module 76 | # the parameters are different from dssdd module 77 | (_, _, _, seed_mask) = ssddF.get_dd_mask(dd0, dd1, psa_mask, psa_crf_mask, ignore_flags, dd_bias=0.1, bg_bias=0.1) 78 | return img_norm, img_org, gt_class_mlabel, gt_class_mlabel_bg, seed_mask[0] 79 | def __len__(self): 80 | return self.image_ids.shape[0] 81 | 82 | ############################################################ 83 | # Models 84 | ############################################################ 85 | class SegModel(SegBaseModel): 86 | def __init__(self, config): 87 | super(SegModel, self).__init__(config) 88 | self.config = config 89 | in_channel=4096 90 | self.seg_main = SegmentationPsa(config,num_classes=21, in_channel=in_channel, middle_channel=512, scale=2) 91 | self.seg_sub = SegmentationPsa(config,num_classes=21, in_channel=in_channel, middle_channel=512, scale=2) 92 | def set_bn_fix(m): 93 | classname = m.__class__.__name__ 94 | if classname.find('BatchNorm') != -1: 95 | for p in m.parameters(): p.requires_grad = False 96 | self.apply(set_bn_fix) 97 | def forward(self, inputs): 98 | x, img_org, gt_class_mlabel = inputs 99 | feats = self.encoder(x) 100 | [x1,x2,x3,x4,x5] = feats 101 | seg_outs_main = self.get_seg(self.seg_main, x5, gt_class_mlabel) 102 | seg_outs_sub = self.get_seg(self.seg_sub, x5, gt_class_mlabel) 103 | seg_crf, seg_crf_mask = self.get_crf(img_org, seg_outs_main[0], gt_class_mlabel) 104 | return seg_outs_main, seg_outs_sub, seg_crf_mask, feats 105 | 106 | class SSDDModel(SSDDBaseModel): 107 | def __init__(self, config): 108 | super(SSDDModel, self).__init__(config) 109 | self.dd_head0 = PredictDiffHead(config, in_channel=512, in_channel2=128) 110 | self.dd_head1 = PredictDiffHead(config, in_channel=512, in_channel2=128) 111 | self.dd0 = PredictDiff(config, in_channel=256, in_channel2=128) 112 | self.dd1 = PredictDiff(config, in_channel=256, in_channel2=128) 113 | def forward(self, inputs): 114 | (seg_outs_main, seg_outs_sub, seg_crf_mask, feats), seed_mask, gt_class_mlabel = inputs 115 | [x1,x2,x3,x4,x5] = feats 116 | x1=F.avg_pool2d(x1, 2, 2) 117 | # first step 118 | seg_main, seg_prob_main, seg_mask_main, seg_head_main = seg_outs_main 119 | ignore_flags0=torch.from_numpy(ssddF.get_ignore_flags(seg_mask_main, seg_crf_mask, gt_class_mlabel)).cuda().float() 120 | dd_head0 = self.dd_head0((seg_head_main.detach(), x1.detach())) 121 | dd00 = ssddF.get_dd(self.dd0, dd_head0, seg_mask_main) 122 | dd01 = ssddF.get_dd(self.dd0, dd_head0, seg_crf_mask) 123 | dd_outs0 = ssddF.get_dd_mask(dd00, dd01, seg_mask_main, seg_crf_mask, ignore_flags0, dd_bias=0.4, bg_bias=0) 124 | (dd01, dd10, ignore_flags0, refine_mask0)=dd_outs0 125 | # second step 126 | seg_sub, seg_prob_sub, seg_mask_sub, seg_head_sub = seg_outs_sub 127 | dd_head1 = self.dd_head1((seg_head_sub.detach(), x1.detach())) 128 | dd10 = ssddF.get_dd(self.dd1, dd_head1, seed_mask) 129 | dd11 = ssddF.get_dd(self.dd1, dd_head1, refine_mask0) 130 | ignore_flags1 = torch.from_numpy(ssddF.get_ignore_flags(seed_mask, refine_mask0, gt_class_mlabel)).cuda().float() 131 | dd_outs1 = ssddF.get_dd_mask(dd10, dd11, seed_mask, refine_mask0, ignore_flags1, dd_bias=0.4, bg_bias=0) 132 | return dd_outs0, dd_outs1 133 | 134 | ############################################################ 135 | # Trainer 136 | ############################################################ 137 | class Trainer(): 138 | def __init__(self, config, model_dir, model): 139 | super(Trainer, self).__init__() 140 | self.config = config 141 | self.model_dir = model_dir 142 | self.epoch = 0 143 | self.layer_regex = { 144 | "lr1": r"(encoder.*)", 145 | "lr10": r"(seg_main.*)|(seg_sub.*)", 146 | "dd": r"(dd0.*)|(dd1.*)|(dd_head0.*)|(dd_head1.*)", 147 | } 148 | lr_1x = self.layer_regex["lr1"] 149 | lr_10x = self.layer_regex["lr10"] 150 | dd = self.layer_regex['dd'] 151 | seg_model=model[0].cuda() 152 | ssdd_model=model[1].cuda() 153 | self.param_lr_1x = [param for name, param in seg_model.named_parameters() if bool(re.fullmatch(lr_1x, name)) and not 'bn' in name] 154 | self.param_lr_10x = [param for name, param in seg_model.named_parameters() if bool(re.fullmatch(lr_10x, name)) and not 'bn' in name] 155 | self.param_dd = [param for name, param in ssdd_model.named_parameters() if bool(re.fullmatch(dd, name)) and not 'bn' in name] 156 | lr=1e-3 157 | self.seg_model=nn.DataParallel(seg_model) 158 | self.ssdd_model=nn.DataParallel(ssdd_model) 159 | def train_model(self, train_dataset): 160 | epochs=self.config.EPOCHS 161 | # Data generators 162 | self.train_set = DssddData(train_dataset, self.config) 163 | train_generator = torch.utils.data.DataLoader(self.train_set, batch_size=self.config.BATCH, shuffle=True, num_workers=8, pin_memory=True) 164 | self.config.LR_RAMPDOWN_EPOCHS=int(epochs*1.2) 165 | self.seg_model.train() 166 | self.ssdd_model.train() 167 | for epoch in range(0, epochs): 168 | print("Epoch {}/{}.".format(epoch,epochs)) 169 | # Training 170 | self.train_epoch(train_generator, epoch) 171 | # Save model 172 | if (epoch % 2 ==0) & (epoch>0): 173 | torch.save(self.seg_model.state_dict(), self.checkpoint_path_seg.format(epoch)) 174 | torch.save(self.ssdd_model.state_dict(), self.checkpoint_path_ssdd.format(epoch)) 175 | torch.cuda.empty_cache() 176 | def train_epoch(self, datagenerator, epoch): 177 | learning_rate=self.config.LEARNING_RATE 178 | self.cnt=0 179 | self.steps = len(datagenerator) 180 | self.step=0 181 | self.epoch=epoch 182 | end=time.time() 183 | for inputs in datagenerator: 184 | self.train_step(inputs, end) 185 | end=time.time() 186 | self.step += 1 187 | def train_step(self, inputs, end): 188 | start = time.time() 189 | # adjust learning rate 190 | lr=utils.adjust_learning_rate(self.config.LEARNING_RATE, self.epoch, self.config.LR_RAMPDOWN_EPOCHS, self.step, self.steps) 191 | self.optimizer = torch.optim.SGD([ 192 | {'params': self.param_lr_1x,'lr': lr*1, 'weight_decay': self.config.WEIGHT_DECAY}, 193 | {'params': self.param_lr_10x,'lr': lr*10, 'weight_decay': self.config.WEIGHT_DECAY}, 194 | ], lr=lr, momentum=self.config.LEARNING_MOMENTUM, weight_decay= self.config.WEIGHT_DECAY) 195 | self.optimizer_dd = torch.optim.SGD([ 196 | {'params': self.param_dd,'lr': lr*10, 'weight_decay': self.config.WEIGHT_DECAY}, 197 | ], lr=lr, momentum=self.config.LEARNING_MOMENTUM, weight_decay= self.config.WEIGHT_DECAY) 198 | # input items 199 | img_norm, img_org, gt_class_mlabels, gt_class_mlabels_bg, seed_mask = inputs 200 | img_norm = Variable(img_norm).cuda().float() 201 | img_org = Variable(img_org).cuda().float() 202 | seed_mask = Variable(seed_mask).cuda().long() 203 | gt_class_mlabels = Variable(gt_class_mlabels).cuda().float() 204 | gt_class_mlabels_bg = Variable(gt_class_mlabels_bg).cuda().float() 205 | # forward 206 | seg_outs = self.seg_model((img_norm, img_org, gt_class_mlabels_bg)) 207 | dd_outs = self.ssdd_model((seg_outs, seed_mask, gt_class_mlabels)) 208 | # get loss 209 | loss_seg, loss_dd = self.compute_loss(seg_outs, dd_outs, inputs) 210 | forward_time=time.time() 211 | # backward 212 | self.optimizer.zero_grad() 213 | loss_seg.backward() 214 | self.optimizer.step() 215 | forward_time=time.time() 216 | self.optimizer_dd.zero_grad() 217 | loss_dd.backward() 218 | self.optimizer_dd.step() 219 | forward_time=time.time() 220 | if (self.step%10==0): 221 | prefix="{}/{}/{}/{}".format(self.epoch, self.cnt, self.step + 1, self.steps) 222 | suffix="forward_time: {:.3f} data {:.3f} loss: {:.3f}".format( 223 | forward_time-start, (start-end),loss_seg.item()) 224 | print('%s %s' % (prefix, suffix), end = '\n') 225 | 226 | def compute_loss(self, seg_outs, dd_outs, inputs): 227 | seg_outs_main, seg_outs_sub, seg_crf_mask, feats = seg_outs 228 | seg_main, seg_prob_main, seg_mask_main, _ = seg_outs_main 229 | seg_sub, seg_prob_sub, seg_mask_sub, _ = seg_outs_sub 230 | dd_outs0, dd_outs1 = dd_outs 231 | images, img_org, gt_class_mlabels, gt_class_mlabels_bg, seed_mask = inputs 232 | seed_mask = Variable(seed_mask).cuda().long() 233 | (dd00, dd01, ignore_flags0, refine_mask0) = dd_outs0 234 | (dd10, dd11, ignore_flags1, refine_mask1) = dd_outs1 235 | # compute losses 236 | # segmentation loss 237 | loss_seg_main = F.cross_entropy(seg_main, refine_mask1, ignore_index=255) 238 | loss_seg_sub = 0.5*F.cross_entropy(seg_sub, seed_mask, ignore_index=255) + 0.5*F.cross_entropy(seg_sub, refine_mask1, ignore_index=255) 239 | loss_seg = loss_seg_main + loss_seg_sub 240 | # difference detection loss 241 | seg_crf_diff = seg_mask_main != seg_crf_mask 242 | loss_dd00 = ssddF.get_ddloss(dd00, seg_crf_diff, ignore_flags0) 243 | loss_dd01 = ssddF.get_ddloss(dd01, seg_crf_diff, ignore_flags0) 244 | loss_dd10 = ssddF.compute_sig_mask_loss(dd10, seed_mask != seg_mask_sub) 245 | loss_dd11 = ssddF.compute_sig_mask_loss(dd11, refine_mask1 != seg_mask_sub) 246 | loss_dd = (loss_dd00 + loss_dd01 + loss_dd10 + loss_dd11)/4 247 | # save temporary outputs 248 | if (self.step%30==0): 249 | sid='_'+self.phase+'_'+self.saveid+'_'+str(self.epoch)+'_'+str(self.cnt) 250 | img_org=img_org.data.cpu().numpy()[...,::-1] 251 | saven = self.log_dir_img + '/i'+sid+'.jpg' 252 | cv2.imwrite(saven,img_org[0]) 253 | saven = self.log_dir_img + '/D1'+sid+'.png' 254 | mask_png = utils.mask2png(saven, refine_mask0[0].squeeze().data.cpu().numpy()) 255 | saven = self.log_dir_img + '/K1'+sid+'.png' 256 | mask_png = utils.mask2png(saven, seg_mask_main[0].data.cpu().numpy().astype(np.float32)) 257 | saven = self.log_dir_img + '/A1'+sid+'.png' 258 | mask_png = utils.mask2png(saven, seg_crf_mask[0].squeeze().data.cpu().numpy()) 259 | 260 | saven = self.log_dir_img + '/dk1'+sid+'.png' 261 | tmp=F.sigmoid(dd00)[0].squeeze().data.cpu().numpy() 262 | cv2.imwrite(saven,tmp*255) 263 | saven = self.log_dir_img + '/da1'+sid+'.png' 264 | tmp=F.sigmoid(dd01)[0].squeeze().data.cpu().numpy() 265 | cv2.imwrite(saven,tmp*255) 266 | 267 | saven = self.log_dir_img + '/D2'+sid+'.png' 268 | mask_png = utils.mask2png(saven, refine_mask1[0].squeeze().data.cpu().numpy()) 269 | saven = self.log_dir_img + '/K2'+sid+'.png' 270 | mask_png = utils.mask2png(saven, seed_mask[0].data.cpu().numpy().astype(np.float32)) 271 | #saven = self.log_dir_img + '/A2'+sid+'.png' 272 | #mask_png = utils.mask2png(saven, refine_mask[0].squeeze().data.cpu().numpy()) 273 | 274 | saven = self.log_dir_img + '/dk2'+sid+'.png' 275 | tmp=F.sigmoid(dd10)[0].squeeze().data.cpu().numpy() 276 | cv2.imwrite(saven,tmp*255) 277 | saven = self.log_dir_img + '/da2'+sid+'.png' 278 | tmp=F.sigmoid(dd11)[0].squeeze().data.cpu().numpy() 279 | cv2.imwrite(saven,tmp*255) 280 | self.cnt += 1 281 | return loss_seg, loss_dd 282 | 283 | def set_log_dir(self, phase, saveid, model_path=None): 284 | self.epoch = 0 285 | self.phase = phase 286 | self.saveid = saveid 287 | self.log_dir = os.path.join(self.model_dir, "{}_{}".format(phase, saveid)) 288 | self.log_dir_model = self.log_dir +'/'+ 'models' 289 | if not os.path.exists(self.log_dir_model): 290 | os.makedirs(self.log_dir_model) 291 | self.log_dir_img = self.log_dir +'/'+ 'imgs' 292 | if not os.path.exists(self.log_dir_img): 293 | os.makedirs(self.log_dir_img) 294 | self.checkpoint_path_seg = os.path.join(self.log_dir_model, "seg_*epoch*.pth".format()) 295 | self.checkpoint_path_seg = self.checkpoint_path_seg.replace("*epoch*", "{:04d}") 296 | self.checkpoint_path_ssdd = os.path.join(self.log_dir_model, "ssdd_*epoch*.pth".format()) 297 | self.checkpoint_path_ssdd = self.checkpoint_path_ssdd.replace("*epoch*", "{:04d}") 298 | 299 | def models(config, weight_file=None): 300 | seg_model = SegModel(config=config) 301 | seg_model.initialize_weights() 302 | seg_model.load_resnet38_weights(weight_file) 303 | ssdd_model = SSDDModel(config=config) 304 | ssdd_model.initialize_weights() 305 | return (seg_model, ssdd_model) 306 | -------------------------------------------------------------------------------- /train_sssdd.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import random 4 | import re 5 | import cv2 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import torch.optim as optim 11 | import torch.utils.data 12 | from torch.autograd import Variable 13 | from torchvision import transforms 14 | import imutils 15 | import utils 16 | from base_class import BaseModel, SegBaseModel, SSDDBaseModel, PascalDataset 17 | import ssdd_function as ssddF 18 | import time 19 | from PIL import Image 20 | from network import SegmentationPsa, PredictDiff, PredictDiffHead 21 | import math 22 | import cv2 23 | cv2.setNumThreads(0) 24 | 25 | ############################################################ 26 | # dataset 27 | ############################################################ 28 | 29 | class SssddData(PascalDataset): 30 | def __init__(self, dataset, config): 31 | super().__init__(dataset, config) 32 | self.label_dic = dataset.label_dic 33 | self.joint_transform_list=[ 34 | None, 35 | imutils.RandomHorizontalFlip(), 36 | imutils.RandomResizeLong(448, 512), 37 | imutils.RandomCrop(448), 38 | None, 39 | ] 40 | def __getitem__(self, image_index): 41 | image_id = self.image_ids[image_index] 42 | impath = self.config.VOC_ROOT+'/JPEGImages/' 43 | imn = impath+image_id+'.jpg' 44 | img = Image.open(imn).convert("RGB") 45 | gt_class_mlabel = torch.from_numpy(self.label_dic[image_id]) 46 | gt_class_mlabel_bg = torch.from_numpy(np.concatenate(([1],self.label_dic[image_id]))) 47 | psan = 'prepare_labels/results/out_aff/'+image_id+'.npy' 48 | psa=np.array(list(np.load(psan).item().values())).transpose(1,2,0) 49 | psan = 'prepare_labels/results/out_aff_crf/'+image_id+'.npy' 50 | psa_crf=np.load(psan).transpose(1,2,0) 51 | h=psa.shape[0] 52 | w=psa.shape[1] 53 | # resize inputs 54 | img_norm, img_org, psa, psa_crf = self.img_label_resize([img, np.array(img), psa, psa_crf]) 55 | img_org = cv2.resize(img_org,self.config.OUT_SHAPE) 56 | psa = cv2.resize(psa,self.config.OUT_SHAPE) 57 | psa_crf = cv2.resize(psa_crf,self.config.OUT_SHAPE) 58 | psa=self.get_prob_label(psa, gt_class_mlabel_bg).transpose(2,0,1) 59 | psa_crf=self.get_prob_label(psa_crf, gt_class_mlabel_bg).transpose(2,0,1) 60 | psa_mask = np.argmax(psa,0) 61 | psa_crf_mask = np.argmax(psa_crf,0) 62 | return img_norm, img_org, gt_class_mlabel, gt_class_mlabel_bg, psa_mask, psa_crf_mask 63 | def __len__(self): 64 | return self.image_ids.shape[0] 65 | 66 | ############################################################ 67 | # Models 68 | ############################################################ 69 | class SegModel(SegBaseModel): 70 | def __init__(self, config): 71 | super(SegModel, self).__init__(config) 72 | self.config = config 73 | in_channel=4096 74 | self.seg_main = SegmentationPsa(config,num_classes=21, in_channel=in_channel, middle_channel=512, scale=2) 75 | def set_bn_fix(m): 76 | classname = m.__class__.__name__ 77 | if classname.find('BatchNorm') != -1: 78 | for p in m.parameters(): p.requires_grad = False 79 | self.apply(set_bn_fix) 80 | def forward(self, inputs): 81 | x, img_org, gt_class_mlabel = inputs 82 | feats = self.encoder(x) 83 | [x1,x2,x3,x4,x5] = feats 84 | seg_outs_main = self.get_seg(self.seg_main, x5, gt_class_mlabel) 85 | return seg_outs_main, feats 86 | 87 | class SSDDModel(SSDDBaseModel): 88 | def __init__(self, config): 89 | super(SSDDModel, self).__init__(config) 90 | self.dd_head0 = PredictDiffHead(config, in_channel=512, in_channel2=128) 91 | self.dd0 = PredictDiff(config, in_channel=256, in_channel2=128) 92 | def forward(self, inputs): 93 | (seg_outs_main, feats), psa_mask, psa_crf_mask, gt_class_mlabel = inputs 94 | [x1,x2,x3,x4,x5] = feats 95 | x1=F.avg_pool2d(x1, 2, 2) 96 | # first step 97 | seg_main, seg_prob_main, seg_mask_main, seg_head_main = seg_outs_main 98 | ignore_flags0=torch.from_numpy(ssddF.get_ignore_flags(psa_mask, psa_crf_mask, gt_class_mlabel)).cuda().float() 99 | dd_head0 = self.dd_head0((seg_head_main.detach(), x1.detach())) 100 | dd00 = ssddF.get_dd(self.dd0, dd_head0, psa_mask) 101 | dd01 = ssddF.get_dd(self.dd0, dd_head0, psa_crf_mask) 102 | dd_outs0 = ssddF.get_dd_mask(dd00, dd01, psa_mask, psa_crf_mask, ignore_flags0, dd_bias=0.4, bg_bias=0) 103 | return dd_outs0 104 | 105 | ############################################################ 106 | # Trainer 107 | ############################################################ 108 | class Trainer(): 109 | def __init__(self, config, model_dir, model): 110 | super(Trainer, self).__init__() 111 | self.config = config 112 | self.model_dir = model_dir 113 | self.epoch = 0 114 | self.layer_regex = { 115 | "lr1": r"(encoder.*)", 116 | "lr10": r"(seg_main.*)", 117 | "dd": r"(dd0.*)|(dd_head0.*)", 118 | } 119 | lr_1x = self.layer_regex["lr1"] 120 | lr_10x = self.layer_regex["lr10"] 121 | dd = self.layer_regex['dd'] 122 | seg_model=model[0].cuda() 123 | ssdd_model=model[1].cuda() 124 | self.param_lr_1x = [param for name, param in seg_model.named_parameters() if bool(re.fullmatch(lr_1x, name)) and not 'bn' in name] 125 | self.param_lr_10x = [param for name, param in seg_model.named_parameters() if bool(re.fullmatch(lr_10x, name)) and not 'bn' in name] 126 | self.param_dd = [param for name, param in ssdd_model.named_parameters() if bool(re.fullmatch(dd, name)) and not 'bn' in name] 127 | lr=1e-3 128 | self.seg_model=nn.DataParallel(seg_model) 129 | self.ssdd_model=nn.DataParallel(ssdd_model) 130 | def train_model(self, train_dataset): 131 | epochs=self.config.EPOCHS 132 | # Data generators 133 | self.train_set = SssddData(train_dataset, self.config) 134 | train_generator = torch.utils.data.DataLoader(self.train_set, batch_size=self.config.BATCH, shuffle=True, num_workers=8, pin_memory=True) 135 | self.config.LR_RAMPDOWN_EPOCHS=int(epochs*1.2) 136 | self.seg_model.train() 137 | self.ssdd_model.train() 138 | for epoch in range(0, epochs): 139 | print("Epoch {}/{}.".format(epoch,epochs)) 140 | # Training 141 | self.train_epoch(train_generator, epoch) 142 | # Save model 143 | if (epoch % 2 ==0) & (epoch>0): 144 | torch.save(self.seg_model.state_dict(), self.checkpoint_path_seg.format(epoch)) 145 | torch.save(self.ssdd_model.state_dict(), self.checkpoint_path_ssdd.format(epoch)) 146 | torch.cuda.empty_cache() 147 | def train_epoch(self, datagenerator, epoch): 148 | learning_rate=self.config.LEARNING_RATE 149 | self.cnt=0 150 | self.steps = len(datagenerator) 151 | self.step=0 152 | self.epoch=epoch 153 | end=time.time() 154 | for inputs in datagenerator: 155 | self.train_step(inputs, end) 156 | end=time.time() 157 | self.step += 1 158 | def train_step(self, inputs, end): 159 | start = time.time() 160 | # adjust learning rate 161 | lr=utils.adjust_learning_rate(self.config.LEARNING_RATE, self.epoch, self.config.LR_RAMPDOWN_EPOCHS, self.step, self.steps) 162 | self.optimizer = torch.optim.SGD([ 163 | {'params': self.param_lr_1x,'lr': lr*1, 'weight_decay': self.config.WEIGHT_DECAY}, 164 | {'params': self.param_lr_10x,'lr': lr*10, 'weight_decay': self.config.WEIGHT_DECAY}, 165 | ], lr=lr, momentum=self.config.LEARNING_MOMENTUM, weight_decay= self.config.WEIGHT_DECAY) 166 | self.optimizer_dd = torch.optim.SGD([ 167 | {'params': self.param_dd,'lr': lr*10, 'weight_decay': self.config.WEIGHT_DECAY}, 168 | ], lr=lr, momentum=self.config.LEARNING_MOMENTUM, weight_decay= self.config.WEIGHT_DECAY) 169 | # input items 170 | img_norm, img_org, gt_class_mlabel, gt_class_mlabel_bg, psa_mask, psa_crf_mask = inputs 171 | img_norm = Variable(img_norm).cuda().float() 172 | img_org = Variable(img_org).cuda().float() 173 | gt_class_mlabel = Variable(gt_class_mlabel).cuda().float() 174 | gt_class_mlabel_bg = Variable(gt_class_mlabel_bg).cuda().float() 175 | # forward 176 | seg_outs = self.seg_model((img_norm, img_org, gt_class_mlabel_bg)) 177 | dd_outs = self.ssdd_model((seg_outs, psa_mask, psa_crf_mask, gt_class_mlabel)) 178 | # get loss 179 | loss_seg, loss_dd = self.compute_loss(seg_outs, dd_outs, inputs) 180 | forward_time=time.time() 181 | # backward 182 | self.optimizer.zero_grad() 183 | loss_seg.backward() 184 | self.optimizer.step() 185 | forward_time=time.time() 186 | self.optimizer_dd.zero_grad() 187 | loss_dd.backward() 188 | self.optimizer_dd.step() 189 | forward_time=time.time() 190 | if (self.step%10==0): 191 | prefix="{}/{}/{}/{}".format(self.epoch, self.cnt, self.step + 1, self.steps) 192 | suffix="forward_time: {:.3f} time: {:.3f} data {:.3f} seg: {:.3f}".format( 193 | forward_time-start, (time.time()-start),(start-end),loss_seg.item()) 194 | print('\r%s %s' % (prefix, suffix), end = '\n') 195 | 196 | def compute_loss(self, seg_outs, dd_outs, inputs): 197 | seg_outs_main, feats = seg_outs 198 | seg_main, seg_prob_main, seg_mask_main, _ = seg_outs_main 199 | dd_outs0 = dd_outs 200 | img_norm, img_org, gt_class_mlabel, gt_class_mlabel_bg, psa_mask, psa_crf_mask = inputs 201 | (dd00, dd01, ignore_flags0, refine_mask0) = dd_outs0 202 | psa_mask = Variable(psa_mask).cuda().long() 203 | psa_crf_mask = Variable(psa_crf_mask).cuda().long() 204 | # compute losses 205 | # segmentation loss 206 | loss_seg_main = F.cross_entropy(seg_main, psa_mask, ignore_index=255) 207 | loss_seg = loss_seg_main 208 | # difference detection loss 209 | psa_diff = psa_mask != psa_crf_mask 210 | loss_dd00 = ssddF.get_ddloss(dd00, psa_diff, ignore_flags0) 211 | loss_dd01 = ssddF.get_ddloss(dd01, psa_diff, ignore_flags0) 212 | loss_dd = (loss_dd00 + loss_dd01)/2 213 | # save temporary outputs 214 | if (self.step%30==0): 215 | sid='_'+self.phase+'_'+self.saveid+'_'+str(self.epoch)+'_'+str(self.cnt) 216 | img_org=img_org.data.cpu().numpy()[...,::-1] 217 | saven = self.log_dir_img + '/i'+sid+'.jpg' 218 | cv2.imwrite(saven,img_org[0]) 219 | saven = self.log_dir_img + '/D'+sid+'.png' 220 | mask_png = utils.mask2png(saven, refine_mask0[0].squeeze().data.cpu().numpy()) 221 | saven = self.log_dir_img + '/K'+sid+'.png' 222 | mask_png = utils.mask2png(saven, psa_mask[0].data.cpu().numpy().astype(np.float32)) 223 | saven = self.log_dir_img + '/A'+sid+'.png' 224 | mask_png = utils.mask2png(saven, psa_crf_mask[0].squeeze().data.cpu().numpy()) 225 | 226 | saven = self.log_dir_img + 'da'+sid+'.png' 227 | tmp=F.sigmoid(dd00)[0].squeeze().data.cpu().numpy() 228 | cv2.imwrite(saven,tmp*255) 229 | saven = self.log_dir_img + 'dk'+sid+'.png' 230 | tmp=F.sigmoid(dd01)[0].squeeze().data.cpu().numpy() 231 | cv2.imwrite(saven,tmp*255) 232 | self.cnt += 1 233 | return loss_seg, loss_dd 234 | 235 | def set_log_dir(self, phase, saveid): 236 | self.epoch = 0 237 | self.phase = phase 238 | self.saveid = saveid 239 | self.log_dir = os.path.join(self.model_dir, "{}_{}".format(phase, saveid)) 240 | self.log_dir_model = self.log_dir +'/'+ 'models' 241 | if not os.path.exists(self.log_dir_model): 242 | os.makedirs(self.log_dir_model) 243 | self.log_dir_img = self.log_dir +'/'+ 'imgs' 244 | if not os.path.exists(self.log_dir_img): 245 | os.makedirs(self.log_dir_img) 246 | self.checkpoint_path_seg = os.path.join(self.log_dir_model, "seg_*epoch*.pth".format()) 247 | self.checkpoint_path_seg = self.checkpoint_path_seg.replace("*epoch*", "{:04d}") 248 | self.checkpoint_path_ssdd = os.path.join(self.log_dir_model, "ssdd_*epoch*.pth".format()) 249 | self.checkpoint_path_ssdd = self.checkpoint_path_ssdd.replace("*epoch*", "{:04d}") 250 | 251 | def models(config, weight_file=None): 252 | seg_model = SegModel(config=config) 253 | seg_model.initialize_weights() 254 | seg_model.load_resnet38_weights(weight_file) 255 | ssdd_model = SSDDModel(config=config) 256 | ssdd_model.initialize_weights() 257 | return (seg_model, ssdd_model) 258 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | from PIL import Image 6 | def adjust_learning_rate(lr, epoch, lr_rampdown_epochs, step_in_epoch, total_steps_in_epoch): 7 | epoch = epoch + step_in_epoch / total_steps_in_epoch 8 | def cosine_rampdown(current, rampdown_length): 9 | """Cosine rampdown from https://arxiv.org/abs/1608.03983""" 10 | assert 0 <= current <= rampdown_length 11 | return float(.5 * (np.cos(np.pi * current / rampdown_length) + 1)) 12 | lr *= cosine_rampdown(epoch, lr_rampdown_epochs) 13 | return lr 14 | 15 | def get_labeled_tensor(tensor, class_label): 16 | labeled_tensor=[] 17 | for i in range(len(tensor)): 18 | for i in range(class_mlabel.shape[1]): 19 | if gt_class_mlabel[i,j].item()==1: 20 | tmp_prob.append(tensor[i:i+1,j:j+1]) 21 | tmp_prob=torch.cat(tmp_prob) 22 | 23 | def mask2png(saven, mask): 24 | palette = get_palette(256) 25 | mask=Image.fromarray(mask.astype(np.uint8)) 26 | mask.putpalette(palette) 27 | mask.save(saven) 28 | 29 | def get_palette(num_cls): 30 | n = num_cls 31 | palette = [0] * (n * 3) 32 | for j in range(0, n): 33 | lab = j 34 | palette[j * 3 + 0] = 0 35 | palette[j * 3 + 1] = 0 36 | palette[j * 3 + 2] = 0 37 | i = 0 38 | while lab: 39 | palette[j * 3 + 0] |= (((lab >> 0) & 1) << (7 - i)) 40 | palette[j * 3 + 1] |= (((lab >> 1) & 1) << (7 - i)) 41 | palette[j * 3 + 2] |= (((lab >> 2) & 1) << (7 - i)) 42 | i += 1 43 | lab >>= 3 44 | return palette 45 | --------------------------------------------------------------------------------
ImageInferenceGround truth
1667 |
1668 | 1669 |
1670 |
1672 |
1673 | 1675 |
1677 |
1678 | 1679 |
1680 |
1684 |
1685 | 1686 |
1687 |
1689 |
1690 | 1692 |
1694 |
1695 | 1696 |
1697 |
1701 |
1702 | 1703 |
1704 |
1706 |
1707 | 1709 |
1711 |
1712 | 1713 |
1714 |
1718 |
1719 | 1720 |
1721 |
1723 |
1724 | 1726 |
1728 |
1729 | 1730 |
1731 |
1735 |
1736 | 1737 |
1738 |
1740 |
1741 | 1743 |
1745 |
1746 | 1747 |
1748 |