├── .DS_Store ├── README.md ├── __pycache__ └── .DS_Store ├── architect.py ├── auto_deeplab.py ├── dataloaders ├── .DS_Store ├── __init__.py ├── custom_transforms.py ├── datasets │ ├── __init__.py │ ├── cityscapes.py │ ├── coco.py │ ├── combine_dbs.py │ ├── pascal.py │ └── sbd.py └── utils.py ├── genotypes.py ├── model_search.py ├── mypath.py ├── operations.py ├── train_autodeeplab.py └── train_voc.sh /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MenghaoGuo/AutoDeeplab/94659b26f708f2e694367785b9d2e75f175ab639/.DS_Store -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AutoDeeplab 2 | 3 | This is an implementation of [Auto-DeepLab](https://arxiv.org/abs/1901.02985) using PyTorch. 4 | 5 | ## Environment 6 | 7 | The implementation requires the following dependencies: 8 | 9 | - Python 3.7 10 | 11 | - PyTorch 0.4 12 | 13 | - TensorboardX 14 | 15 | Other basic dependencies such as matplotlib and tqdm are also needed. 16 | 17 | ## Installation 18 | 19 | First, clone the repository: 20 | 21 | git clone https://github.com/MenghaoGuo/AutoDeeplab.git 22 | 23 | Then enter the project directory: 24 | 25 | cd AutoDeeplab 26 | 27 | ## Train 28 | 29 | The dataloader module is built on this [repo](https://github.com/jfzhang95/pytorch-deeplab-xception). 30 | 31 | If you want to train this model on a different dataset, edit the `--dataset` parameter and then run: 32 | 33 | bash train_voc.sh 34 | 35 | 36 | ## Reference 37 | [1]: [Auto-DeepLab: Hierarchical Neural Architecture Search for Semantic Image Segmentation](https://arxiv.org/abs/1901.02985) 38 | 39 | 40 | [2]: [pytorch-deeplab-xception](https://github.com/jfzhang95/pytorch-deeplab-xception) 41 | -------------------------------------------------------------------------------- /__pycache__/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MenghaoGuo/AutoDeeplab/94659b26f708f2e694367785b9d2e75f175ab639/__pycache__/.DS_Store -------------------------------------------------------------------------------- /architect.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | 5 | 6 | class Architect () : 7 | def __init__(self, model, args): 8 | self.model = model 9 | self.optimizer = torch.optim.Adam(self.model.arch_parameters(), 10 | lr=args.arch_lr, betas=(0.9, 0.999), weight_decay=args.arch_weight_decay) 11 | 12 | def step (self, input_valid, target_valid) : 13 | self.optimizer.zero_grad () 14 | self._backward_step(input_valid, target_valid) 15 | self.optimizer.step() 16 | 17 | def _backward_step (self, input_valid, target_valid) : 18 | loss = self.model._loss (input_valid, target_valid) 19 | loss.backward () 20 | 21 | 22 | 23 | -------------------------------------------------------------------------------- /auto_deeplab.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import model_search 5 | from genotypes import PRIMITIVES 6 | from genotypes import Genotype 7 | import torch.nn.functional as F 8 | from operations import * 9 | 10 | class AutoDeeplab (nn.Module) : 11 | 12 | def __init__(self, num_classes, num_layers, criterion, num_channel = 40, multiplier = 5, step = 5, crop_size=None,
cell=model_search.Cell): 13 | super(AutoDeeplab, self).__init__() 14 | self.cells = nn.ModuleList() 15 | self._num_layers = num_layers 16 | self._num_classes = num_classes 17 | self._step = step 18 | self._multiplier = multiplier 19 | self._num_channel = num_channel 20 | self._criterion = criterion 21 | self._crop_size = crop_size 22 | self._arch_param_names = ["alphas_cell", "alphas_network"] 23 | self._initialize_alphas () 24 | self.stem0 = nn.Sequential( 25 | nn.Conv2d(3, 64, 3, stride=2, padding=1), 26 | nn.BatchNorm2d(64), 27 | nn.ReLU () 28 | ) 29 | self.stem1 = nn.Sequential( 30 | nn.Conv2d(64, 64, 3, padding=1), 31 | nn.BatchNorm2d(64), 32 | nn.ReLU () 33 | ) 34 | self.stem2 = nn.Sequential( 35 | nn.Conv2d(64, 128, 3, stride=2, padding=1), 36 | nn.BatchNorm2d(128), 37 | nn.ReLU () 38 | ) 39 | 40 | C_prev_prev = 64 41 | C_prev = 128 42 | for i in range (self._num_layers) : 43 | # def __init__(self, steps, multiplier, C_prev_prev, C_prev, C, rate) : rate = 0 , 1, 2 reduce rate 44 | 45 | if i == 0 : 46 | cell1 = cell (self._step, self._multiplier, -1, C_prev, self._num_channel, 1) 47 | cell2 = cell (self._step, self._multiplier, -1, C_prev, self._num_channel * 2, 2) 48 | self.cells += [cell1] 49 | self.cells += [cell2] 50 | elif i == 1 : 51 | cell1_1 = cell (self._step, self._multiplier, C_prev, self._num_channel, self._num_channel, 1) 52 | cell1_2 = cell (self._step, self._multiplier, C_prev, self._num_channel * 2, self._num_channel, 0) 53 | 54 | cell2_1 = cell (self._step, self._multiplier, -1, self._num_channel, self._num_channel * 2, 2) 55 | cell2_2 = cell (self._step, self._multiplier, -1, self._num_channel * 2, self._num_channel * 2, 1) 56 | 57 | cell3 = cell (self._step, self._multiplier, -1, self._num_channel * 2, self._num_channel * 4, 2) 58 | 59 | self.cells += [cell1_1] 60 | self.cells += [cell1_2] 61 | self.cells += [cell2_1] 62 | self.cells += [cell2_2] 63 | self.cells += [cell3] 64 | 65 | elif i == 2 : 66 | cell1_1 = cell (self._step, self._multiplier, self._num_channel, self._num_channel, self._num_channel, 1) 67 | cell1_2 = cell (self._step, self._multiplier, self._num_channel, self._num_channel * 2, self._num_channel, 0) 68 | 69 | cell2_1 = cell (self._step, self._multiplier, self._num_channel * 2, self._num_channel, self._num_channel * 2, 2) 70 | cell2_2 = cell (self._step, self._multiplier, self._num_channel * 2, self._num_channel * 2, self._num_channel * 2, 1) 71 | cell2_3 = cell (self._step, self._multiplier, self._num_channel * 2, self._num_channel * 4, self._num_channel * 2, 0) 72 | 73 | 74 | cell3_1 = cell (self._step, self._multiplier, -1, self._num_channel * 2, self._num_channel * 4, 2) 75 | cell3_2 = cell (self._step, self._multiplier, -1, self._num_channel * 4, self._num_channel * 4, 1) 76 | 77 | cell4 = cell (self._step, self._multiplier, -1, self._num_channel * 4, self._num_channel * 8, 2) 78 | 79 | self.cells += [cell1_1] 80 | self.cells += [cell1_2] 81 | self.cells += [cell2_1] 82 | self.cells += [cell2_2] 83 | self.cells += [cell2_3] 84 | self.cells += [cell3_1] 85 | self.cells += [cell3_2] 86 | self.cells += [cell4] 87 | 88 | 89 | 90 | elif i == 3 : 91 | cell1_1 = cell (self._step, self._multiplier, self._num_channel, self._num_channel, self._num_channel, 1) 92 | cell1_2 = cell (self._step, self._multiplier, self._num_channel, self._num_channel * 2, self._num_channel, 0) 93 | 94 | cell2_1 = cell (self._step, self._multiplier, self._num_channel * 2, self._num_channel, self._num_channel * 2, 2) 95 | cell2_2 = cell (self._step, self._multiplier, 
self._num_channel * 2, self._num_channel * 2, self._num_channel * 2, 1) 96 | cell2_3 = cell (self._step, self._multiplier, self._num_channel * 2, self._num_channel * 4, self._num_channel * 2, 0) 97 | 98 | 99 | cell3_1 = cell (self._step, self._multiplier, self._num_channel * 4, self._num_channel * 2, self._num_channel * 4, 2) 100 | cell3_2 = cell (self._step, self._multiplier, self._num_channel * 4, self._num_channel * 4, self._num_channel * 4, 1) 101 | cell3_3 = cell (self._step, self._multiplier, self._num_channel * 4, self._num_channel * 8, self._num_channel * 4, 0) 102 | 103 | 104 | cell4_1 = cell (self._step, self._multiplier, -1, self._num_channel * 4, self._num_channel * 8, 2) 105 | cell4_2 = cell (self._step, self._multiplier, -1, self._num_channel * 8, self._num_channel * 8, 1) 106 | 107 | self.cells += [cell1_1] 108 | self.cells += [cell1_2] 109 | self.cells += [cell2_1] 110 | self.cells += [cell2_2] 111 | self.cells += [cell2_3] 112 | self.cells += [cell3_1] 113 | self.cells += [cell3_2] 114 | self.cells += [cell3_3] 115 | self.cells += [cell4_1] 116 | self.cells += [cell4_2] 117 | 118 | else : 119 | cell1_1 = cell (self._step, self._multiplier, self._num_channel, self._num_channel, self._num_channel, 1) 120 | cell1_2 = cell (self._step, self._multiplier, self._num_channel, self._num_channel * 2, self._num_channel, 0) 121 | 122 | cell2_1 = cell (self._step, self._multiplier, self._num_channel * 2, self._num_channel, self._num_channel * 2, 2) 123 | cell2_2 = cell (self._step, self._multiplier, self._num_channel * 2, self._num_channel * 2, self._num_channel * 2, 1) 124 | cell2_3 = cell (self._step, self._multiplier, self._num_channel * 2, self._num_channel * 4, self._num_channel * 2, 0) 125 | 126 | 127 | cell3_1 = cell (self._step, self._multiplier, self._num_channel * 4, self._num_channel * 2, self._num_channel * 4, 2) 128 | cell3_2 = cell (self._step, self._multiplier, self._num_channel * 4, self._num_channel * 4, self._num_channel * 4, 1) 129 | cell3_3 = cell (self._step, self._multiplier, self._num_channel * 4, self._num_channel * 8, self._num_channel * 4, 0) 130 | 131 | 132 | cell4_1 = cell (self._step, self._multiplier, self._num_channel * 8, self._num_channel * 4, self._num_channel * 8, 2) 133 | cell4_2 = cell (self._step, self._multiplier, self._num_channel * 8, self._num_channel * 8, self._num_channel * 8, 1) 134 | 135 | self.cells += [cell1_1] 136 | self.cells += [cell1_2] 137 | self.cells += [cell2_1] 138 | self.cells += [cell2_2] 139 | self.cells += [cell2_3] 140 | self.cells += [cell3_1] 141 | self.cells += [cell3_2] 142 | self.cells += [cell3_3] 143 | self.cells += [cell4_1] 144 | self.cells += [cell4_2] 145 | self.aspp_4 = nn.Sequential ( 146 | ASPP (self._num_channel, 24, 24, self._num_classes) 147 | ) 148 | 149 | self.aspp_8 = nn.Sequential ( 150 | ASPP (self._num_channel * 2, 12, 12, self._num_classes) 151 | ) 152 | self.aspp_16 = nn.Sequential ( 153 | ASPP (self._num_channel * 4, 6, 6, self._num_classes) 154 | ) 155 | self.aspp_32 = nn.Sequential ( 156 | ASPP (self._num_channel * 8, 3, 3, self._num_classes) 157 | ) 158 | 159 | 160 | 161 | 162 | 163 | def forward (self, x) : 164 | self.level_2 = [] 165 | self.level_4 = [] 166 | self.level_8 = [] 167 | self.level_16 = [] 168 | self.level_32 = [] 169 | 170 | # self._init_level_arr (x) 171 | temp = self.stem0 (x) 172 | self.level_2.append (self.stem1 (temp)) 173 | self.level_4.append (self.stem2 (self.level_2[-1])) 174 | weight_cells = F.softmax(self.alphas_cell, dim=-1) 175 | weight_network = F.softmax 
(self.alphas_network, dim = -1) 176 | count = 0 177 | weight_network = F.softmax (self.alphas_network, dim = -1) 178 | weight_cells = F.softmax(self.alphas_cell, dim=-1) 179 | for layer in range (self._num_layers) : 180 | 181 | if layer == 0 : 182 | level4_new = self.cells[count] (None, self.level_4[-1], weight_cells) 183 | count += 1 184 | level8_new = self.cells[count] (None, self.level_4[-1], weight_cells) 185 | count += 1 186 | self.level_4.append (level4_new * self.alphas_network[layer][0][0]) 187 | self.level_8.append (level8_new * self.alphas_network[layer][0][1]) 188 | # print ((self.level_4[-2]).size (), (self.level_4[-1]).size()) 189 | elif layer == 1 : 190 | level4_new_1 = self.cells[count] (self.level_4[-2], self.level_4[-1], weight_cells) 191 | count += 1 192 | level4_new_2 = self.cells[count] (self.level_4[-2], self.level_8[-1], weight_cells) 193 | count += 1 194 | level4_new = self.alphas_network[layer][0][0] * level4_new_1 + self.alphas_network[layer][0][1] * level4_new_2 195 | 196 | level8_new_1 = self.cells[count] (None, self.level_4[-1], weight_cells) 197 | count += 1 198 | level8_new_2 = self.cells[count] (None, self.level_8[-1], weight_cells) 199 | count += 1 200 | level8_new = self.alphas_network[layer][1][0] * level8_new_1 + self.alphas_network[layer][1][1] * level8_new_2 201 | 202 | level16_new = self.cells[count] (None, self.level_8[-1], weight_cells) 203 | level16_new = level16_new * self.alphas_network[layer][1][2] 204 | count += 1 205 | 206 | 207 | self.level_4.append (level4_new) 208 | self.level_8.append (level8_new) 209 | self.level_16.append (level16_new) 210 | 211 | elif layer == 2 : 212 | level4_new_1 = self.cells[count] (self.level_4[-2], self.level_4[-1], weight_cells) 213 | count += 1 214 | level4_new_2 = self.cells[count] (self.level_4[-2], self.level_8[-1], weight_cells) 215 | count += 1 216 | level4_new = self.alphas_network[layer][0][0] * level4_new_1 + self.alphas_network[layer][0][1] * level4_new_2 217 | 218 | level8_new_1 = self.cells[count] (self.level_8[-2], self.level_4[-1], weight_cells) 219 | count += 1 220 | level8_new_2 = self.cells[count] (self.level_8[-2], self.level_8[-1], weight_cells) 221 | count += 1 222 | # print (self.level_8[-1].size(),self.level_16[-1].size()) 223 | level8_new_3 = self.cells[count] (self.level_8[-2], self.level_16[-1], weight_cells) 224 | count += 1 225 | level8_new = self.alphas_network[layer][1][0] * level8_new_1 + self.alphas_network[layer][1][1] * level8_new_2 + self.alphas_network[layer][1][2] * level8_new_3 226 | 227 | level16_new_1 = self.cells[count] (None, self.level_8[-1], weight_cells) 228 | count += 1 229 | level16_new_2 = self.cells[count] (None, self.level_16[-1], weight_cells) 230 | count += 1 231 | level16_new = self.alphas_network[layer][2][0] * level16_new_1 + self.alphas_network[layer][2][1] * level16_new_2 232 | 233 | 234 | level32_new = self.cells[count] (None, self.level_16[-1], weight_cells) 235 | level32_new = level32_new * self.alphas_network[layer][2][2] 236 | count += 1 237 | 238 | self.level_4.append (level4_new) 239 | self.level_8.append (level8_new) 240 | self.level_16.append (level16_new) 241 | self.level_32.append (level32_new) 242 | 243 | elif layer == 3 : 244 | level4_new_1 = self.cells[count] (self.level_4[-2], self.level_4[-1], weight_cells) 245 | count += 1 246 | level4_new_2 = self.cells[count] (self.level_4[-2], self.level_8[-1], weight_cells) 247 | count += 1 248 | level4_new = self.alphas_network[layer][0][0] * level4_new_1 + self.alphas_network[layer][0][1] * level4_new_2 
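# NOTE: every branch of this forward pass follows the same Auto-DeepLab update rule -- the new
# feature map kept at spatial stride s is a weighted sum of candidate cells whose most recent
# input comes from stride s/2 (downsampling), s (same resolution), or 2s (upsampling), with the
# scalar weights read from alphas_network[layer][level][connection]. Note that weight_network,
# the softmax-normalised copy computed at the top of forward(), is not actually used below; the
# raw alphas_network entries serve as the mixing coefficients. A minimal sketch of the
# recurrence, with beta_* standing in for those entries (illustrative notation only, not code
# from this file):
#   h_s[l] = beta_down * cell(h_s[l-2], h_{s/2}[l-1]) \
#          + beta_same * cell(h_s[l-2], h_s[l-1]) \
#          + beta_up   * cell(h_s[l-2], h_{2s}[l-1])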
249 | 250 | level8_new_1 = self.cells[count] (self.level_8[-2], self.level_4[-1], weight_cells) 251 | count += 1 252 | level8_new_2 = self.cells[count] (self.level_8[-2], self.level_8[-1], weight_cells) 253 | count += 1 254 | level8_new_3 = self.cells[count] (self.level_8[-2], self.level_16[-1], weight_cells) 255 | count += 1 256 | level8_new = self.alphas_network[layer][1][0] * level8_new_1 + self.alphas_network[layer][1][1] * level8_new_2 + self.alphas_network[layer][1][2] * level8_new_3 257 | 258 | level16_new_1 = self.cells[count] (self.level_16[-2], self.level_8[-1], weight_cells) 259 | count += 1 260 | level16_new_2 = self.cells[count] (self.level_16[-2], self.level_16[-1], weight_cells) 261 | count += 1 262 | level16_new_3 = self.cells[count] (self.level_16[-2], self.level_32[-1], weight_cells) 263 | count += 1 264 | level16_new = self.alphas_network[layer][2][0] * level16_new_1 + self.alphas_network[layer][2][1] * level16_new_2 + self.alphas_network[layer][2][2] * level16_new_3 265 | 266 | 267 | level32_new_1 = self.cells[count] (None, self.level_16[-1], weight_cells) 268 | count += 1 269 | level32_new_2 = self.cells[count] (None, self.level_32[-1], weight_cells) 270 | count += 1 271 | level32_new = self.alphas_network[layer][3][0] * level32_new_1 + self.alphas_network[layer][3][1] * level32_new_2 272 | 273 | 274 | self.level_4.append (level4_new) 275 | self.level_8.append (level8_new) 276 | self.level_16.append (level16_new) 277 | self.level_32.append (level32_new) 278 | 279 | 280 | else : 281 | level4_new_1 = self.cells[count] (self.level_4[-2], self.level_4[-1], weight_cells) 282 | count += 1 283 | level4_new_2 = self.cells[count] (self.level_4[-2], self.level_8[-1], weight_cells) 284 | count += 1 285 | level4_new = self.alphas_network[layer][0][0] * level4_new_1 + self.alphas_network[layer][0][1] * level4_new_2 286 | 287 | level8_new_1 = self.cells[count] (self.level_8[-2], self.level_4[-1], weight_cells) 288 | count += 1 289 | level8_new_2 = self.cells[count] (self.level_8[-2], self.level_8[-1], weight_cells) 290 | count += 1 291 | level8_new_3 = self.cells[count] (self.level_8[-2], self.level_16[-1], weight_cells) 292 | count += 1 293 | level8_new = self.alphas_network[layer][1][0] * level8_new_1 + self.alphas_network[layer][1][1] * level8_new_2 + self.alphas_network[layer][1][2] * level8_new_3 294 | 295 | level16_new_1 = self.cells[count] (self.level_16[-2], self.level_8[-1], weight_cells) 296 | count += 1 297 | level16_new_2 = self.cells[count] (self.level_16[-2], self.level_16[-1], weight_cells) 298 | count += 1 299 | level16_new_3 = self.cells[count] (self.level_16[-2], self.level_32[-1], weight_cells) 300 | count += 1 301 | level16_new = self.alphas_network[layer][2][0] * level16_new_1 + self.alphas_network[layer][2][1] * level16_new_2 + self.alphas_network[layer][2][2] * level16_new_3 302 | 303 | 304 | level32_new_1 = self.cells[count] (self.level_32[-2], self.level_16[-1], weight_cells) 305 | count += 1 306 | level32_new_2 = self.cells[count] (self.level_32[-2], self.level_32[-1], weight_cells) 307 | count += 1 308 | level32_new = self.alphas_network[layer][3][0] * level32_new_1 + self.alphas_network[layer][3][1] * level32_new_2 309 | 310 | 311 | self.level_4.append (level4_new) 312 | self.level_8.append (level8_new) 313 | self.level_16.append (level16_new) 314 | self.level_32.append (level32_new) 315 | # print (self.level_4[-1].size(),self.level_8[-1].size(),self.level_16[-1].size(),self.level_32[-1].size()) 316 | # concate_feature_map = torch.cat ([self.level_4[-1], 
self.level_8[-1],self.level_16[-1], self.level_32[-1]], 1) 317 | aspp_result_4 = self.aspp_4 (self.level_4[-1]) 318 | 319 | aspp_result_8 = self.aspp_8 (self.level_8[-1]) 320 | aspp_result_16 = self.aspp_16 (self.level_16[-1]) 321 | aspp_result_32 = self.aspp_32 (self.level_32[-1]) 322 | upsample = nn.Upsample(size=(self._crop_size,self._crop_size), mode='bilinear', align_corners=True) 323 | aspp_result_4 = upsample (aspp_result_4) 324 | aspp_result_8 = upsample (aspp_result_8) 325 | aspp_result_16 = upsample (aspp_result_16) 326 | aspp_result_32 = upsample (aspp_result_32) 327 | 328 | sum_feature_map1 = torch.add (aspp_result_4, aspp_result_8) 329 | sum_feature_map2 = torch.add (aspp_result_16, aspp_result_32) 330 | sum_feature_map = torch.add (sum_feature_map1, sum_feature_map2) 331 | return sum_feature_map 332 | 333 | 334 | def _initialize_alphas(self): 335 | k = sum(1 for i in range(self._step) for n in range(2+i)) 336 | num_ops = len(PRIMITIVES) 337 | alphas_cell = torch.tensor (1e-3*torch.randn(k, num_ops).cuda(), requires_grad=True) 338 | self.register_parameter(self._arch_param_names[0], nn.Parameter(alphas_cell)) 339 | 340 | # num_layer x num_spatial_levels x num_spatial_connections (down, level, up) 341 | alphas_network = torch.tensor (1e-3*torch.randn(self._num_layers, 4, 3).cuda(), requires_grad=True) 342 | self.register_parameter(self._arch_param_names[1], nn.Parameter(alphas_network)) 343 | self.alphas_network_mask = torch.ones(self._num_layers, 4, 3) 344 | 345 | 346 | def decode_network (self) : 347 | best_result = [] 348 | max_prop = 0 349 | def _parse (weight_network, layer, curr_value, curr_result, last) : 350 | nonlocal best_result 351 | nonlocal max_prop 352 | if layer == self._num_layers : 353 | if max_prop < curr_value : 354 | # print (curr_result) 355 | best_result = curr_result[:] 356 | max_prop = curr_value 357 | return 358 | 359 | if layer == 0 : 360 | print ('begin0') 361 | num = 0 362 | if last == num : 363 | curr_value = curr_value * weight_network[layer][num][0] 364 | curr_result.append ([num,0]) 365 | _parse (weight_network, layer + 1, curr_value, curr_result, 0) 366 | curr_value = curr_value / weight_network[layer][num][0] 367 | curr_result.pop () 368 | print ('end0-1') 369 | curr_value = curr_value * weight_network[layer][num][1] 370 | curr_result.append ([num,1]) 371 | _parse (weight_network, layer + 1, curr_value, curr_result, 1) 372 | curr_value = curr_value / weight_network[layer][num][1] 373 | curr_result.pop () 374 | 375 | elif layer == 1 : 376 | print ('begin1') 377 | 378 | num = 0 379 | if last == num : 380 | curr_value = curr_value * weight_network[layer][num][0] 381 | curr_result.append ([num,0]) 382 | _parse (weight_network, layer + 1, curr_value, curr_result, 0) 383 | curr_value = curr_value / weight_network[layer][num][0] 384 | curr_result.pop () 385 | print ('end1-1') 386 | 387 | curr_value = curr_value * weight_network[layer][num][1] 388 | curr_result.append ([num,1]) 389 | _parse (weight_network, layer + 1, curr_value, curr_result, 1) 390 | curr_value = curr_value / weight_network[layer][num][1] 391 | curr_result.pop () 392 | 393 | num = 1 394 | if last == num : 395 | curr_value = curr_value * weight_network[layer][num][0] 396 | curr_result.append ([num,0]) 397 | _parse (weight_network, layer + 1, curr_value, curr_result, 0) 398 | curr_value = curr_value / weight_network[layer][num][0] 399 | curr_result.pop () 400 | curr_value = curr_value * weight_network[layer][num][1] 401 | curr_result.append ([num,1]) 402 | _parse (weight_network, layer 
+ 1, curr_value, curr_result, 1) 403 | curr_value = curr_value / weight_network[layer][num][1] 404 | curr_result.pop () 405 | curr_value = curr_value * weight_network[layer][num][2] 406 | curr_result.append ([num,2]) 407 | _parse (weight_network, layer + 1, curr_value, curr_result, 2) 408 | curr_value = curr_value / weight_network[layer][num][2] 409 | curr_result.pop () 410 | 411 | 412 | elif layer == 2 : 413 | print ('begin2') 414 | 415 | num = 0 416 | if last == num : 417 | curr_value = curr_value * weight_network[layer][num][0] 418 | curr_result.append ([num,0]) 419 | _parse (weight_network, layer + 1, curr_value, curr_result, 0) 420 | curr_value = curr_value / weight_network[layer][num][0] 421 | curr_result.pop () 422 | print ('end2-1') 423 | curr_value = curr_value * weight_network[layer][num][1] 424 | curr_result.append ([num,1]) 425 | _parse (weight_network, layer + 1, curr_value, curr_result, 1) 426 | curr_value = curr_value / weight_network[layer][num][1] 427 | curr_result.pop () 428 | 429 | num = 1 430 | if last == num : 431 | curr_value = curr_value * weight_network[layer][num][0] 432 | curr_result.append ([num,0]) 433 | _parse (weight_network, layer + 1, curr_value, curr_result, 0) 434 | curr_value = curr_value / weight_network[layer][num][0] 435 | curr_result.pop () 436 | curr_value = curr_value * weight_network[layer][num][1] 437 | curr_result.append ([num,1]) 438 | _parse (weight_network, layer + 1, curr_value, curr_result, 1) 439 | curr_value = curr_value / weight_network[layer][num][1] 440 | curr_result.pop () 441 | curr_value = curr_value * weight_network[layer][num][2] 442 | curr_result.append ([num,2]) 443 | _parse (weight_network, layer + 1, curr_value, curr_result, 2) 444 | curr_value = curr_value / weight_network[layer][num][2] 445 | curr_result.pop () 446 | 447 | num = 2 448 | if last == num : 449 | curr_value = curr_value * weight_network[layer][num][0] 450 | curr_result.append ([num,0]) 451 | _parse (weight_network, layer + 1, curr_value, curr_result, 1) 452 | curr_value = curr_value / weight_network[layer][num][0] 453 | curr_result.pop () 454 | curr_value = curr_value * weight_network[layer][num][1] 455 | curr_result.append ([num,1]) 456 | _parse (weight_network, layer + 1, curr_value, curr_result, 2) 457 | curr_value = curr_value / weight_network[layer][num][1] 458 | curr_result.pop () 459 | curr_value = curr_value * weight_network[layer][num][2] 460 | curr_result.append ([num,2]) 461 | _parse (weight_network, layer + 1, curr_value, curr_result, 3) 462 | curr_value = curr_value / weight_network[layer][num][2] 463 | curr_result.pop () 464 | else : 465 | 466 | num = 0 467 | if last == num : 468 | curr_value = curr_value * weight_network[layer][num][0] 469 | curr_result.append ([num,0]) 470 | _parse (weight_network, layer + 1, curr_value, curr_result, 0) 471 | curr_value = curr_value / weight_network[layer][num][0] 472 | curr_result.pop () 473 | 474 | curr_value = curr_value * weight_network[layer][num][1] 475 | curr_result.append ([num,1]) 476 | _parse (weight_network, layer + 1, curr_value, curr_result, 1) 477 | curr_value = curr_value / weight_network[layer][num][1] 478 | curr_result.pop () 479 | 480 | num = 1 481 | if last == num : 482 | curr_value = curr_value * weight_network[layer][num][0] 483 | curr_result.append ([num,0]) 484 | _parse (weight_network, layer + 1, curr_value, curr_result, 0) 485 | curr_value = curr_value / weight_network[layer][num][0] 486 | curr_result.pop () 487 | 488 | curr_value = curr_value * weight_network[layer][num][1] 489 | 
curr_result.append ([num,1]) 490 | _parse (weight_network, layer + 1, curr_value, curr_result, 1) 491 | curr_value = curr_value / weight_network[layer][num][1] 492 | curr_result.pop () 493 | 494 | curr_value = curr_value * weight_network[layer][num][2] 495 | curr_result.append ([num,2]) 496 | _parse (weight_network, layer + 1, curr_value, curr_result, 2) 497 | curr_value = curr_value / weight_network[layer][num][2] 498 | curr_result.pop () 499 | 500 | num = 2 501 | if last == num : 502 | curr_value = curr_value * weight_network[layer][num][0] 503 | curr_result.append ([num,0]) 504 | _parse (weight_network, layer + 1, curr_value, curr_result, 1) 505 | curr_value = curr_value / weight_network[layer][num][0] 506 | curr_result.pop () 507 | 508 | curr_value = curr_value * weight_network[layer][num][1] 509 | curr_result.append ([num,1]) 510 | _parse (weight_network, layer + 1, curr_value, curr_result, 2) 511 | curr_value = curr_value / weight_network[layer][num][1] 512 | curr_result.pop () 513 | 514 | curr_value = curr_value * weight_network[layer][num][2] 515 | curr_result.append ([num,2]) 516 | _parse (weight_network, layer + 1, curr_value, curr_result, 3) 517 | curr_value = curr_value / weight_network[layer][num][2] 518 | curr_result.pop () 519 | 520 | num = 3 521 | if last == num : 522 | curr_value = curr_value * weight_network[layer][num][0] 523 | curr_result.append ([num,0]) 524 | _parse (weight_network, layer + 1, curr_value, curr_result, 2) 525 | curr_value = curr_value / weight_network[layer][num][0] 526 | curr_result.pop () 527 | 528 | curr_value = curr_value * weight_network[layer][num][1] 529 | curr_result.append ([num,1]) 530 | _parse (weight_network, layer + 1, curr_value, curr_result, 3) 531 | curr_value = curr_value / weight_network[layer][num][1] 532 | curr_result.pop () 533 | network_weight = F.softmax(self.alphas_network, dim=-1) * 5 534 | network_weight = network_weight.data.cpu().numpy() 535 | _parse (network_weight, 0, 1, [],0) 536 | print (max_prop) 537 | return best_result 538 | 539 | def arch_parameters(self): 540 | return [param for name, param in self.named_parameters() if name in self._arch_param_names] 541 | 542 | def weight_parameters(self): 543 | return [param for name, param in self.named_parameters() if name not in self._arch_param_names] 544 | 545 | def genotype(self): 546 | def _parse(weights): 547 | gene = [] 548 | n = 2 549 | start = 0 550 | for i in range(self._step): 551 | end = start + n 552 | W = weights[start:end].copy() 553 | edges = sorted (range(i + 2), key=lambda x: -max(W[x][k] for k in range(len(W[x])) if k != PRIMITIVES.index('none')))[:2] 554 | for j in edges: 555 | k_best = None 556 | for k in range(len(W[j])): 557 | if k != PRIMITIVES.index('none'): 558 | if k_best is None or W[j][k] > W[j][k_best]: 559 | k_best = k 560 | gene.append((PRIMITIVES[k_best], j)) 561 | start = end 562 | n += 1 563 | return gene 564 | 565 | gene_cell = _parse(F.softmax(self.alphas_cell, dim=-1).data.cpu().numpy()) 566 | concat = range(2+self._step-self._multiplier, self._step+2) 567 | genotype = Genotype( 568 | cell=gene_cell, cell_concat=concat 569 | ) 570 | 571 | return genotype 572 | 573 | def _loss (self, input, target) : 574 | logits = self (input) 575 | return self._criterion (logits, target) 576 | 577 | 578 | 579 | 580 | def main () : 581 | model = AutoDeeplab (5, 12, None) 582 | x = torch.tensor (torch.ones (4, 3, 224, 224)) 583 | result = model.decode_network () 584 | print (result) 585 | print (model.genotype()) 586 | # x = x.cuda() 587 | # y = model (x) 
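# NOTE: decode_network() walks every allowed path through the network-level weights and keeps
# the one with the highest product of (softmaxed) probabilities, returned as one
# [level, connection] pair per layer; genotype() keeps, for each intermediate node of the cell,
# its two strongest incoming edges and their best non-'none' operation. Because
# _initialize_alphas() creates the architecture parameters with .cuda(), this demo needs a CUDA
# device, and the commented-out forward pass also needs crop_size set so the final nn.Upsample
# has a target size. A minimal sketch, assuming a GPU is available:
#   model = AutoDeeplab(5, 12, None, crop_size=224).cuda()
#   y = model(torch.ones(4, 3, 224, 224).cuda())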
588 | # print (model.arch_parameters ()) 589 | # print (y.size()) 590 | 591 | if __name__ == '__main__' : 592 | main () 593 | -------------------------------------------------------------------------------- /dataloaders/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MenghaoGuo/AutoDeeplab/94659b26f708f2e694367785b9d2e75f175ab639/dataloaders/.DS_Store -------------------------------------------------------------------------------- /dataloaders/__init__.py: -------------------------------------------------------------------------------- 1 | from dataloaders.datasets import cityscapes, coco, combine_dbs, pascal, sbd 2 | from torch.utils.data import DataLoader 3 | 4 | def make_data_loader(args, **kwargs): 5 | 6 | if args.dataset == 'pascal': 7 | train_set = pascal.VOCSegmentation(args, split='train') 8 | val_set = pascal.VOCSegmentation(args, split='val') 9 | if args.use_sbd: 10 | sbd_train = sbd.SBDSegmentation(args, split=['train', 'val']) 11 | train_set = combine_dbs.CombineDBs([train_set, sbd_train], excluded=[val_set]) 12 | 13 | num_class = train_set.NUM_CLASSES 14 | train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, **kwargs) 15 | val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False, **kwargs) 16 | test_loader = None 17 | 18 | return train_loader, train_loader, val_loader, test_loader, num_class 19 | 20 | elif args.dataset == 'cityscapes': 21 | train_set1, train_set2 = cityscapes.sp(args, split='train') 22 | val_set = cityscapes.CityscapesSegmentation(args, split='val') 23 | test_set = cityscapes.CityscapesSegmentation(args, split='test') 24 | num_class = train_set1.NUM_CLASSES 25 | train_loader1 = DataLoader(train_set1, batch_size=args.batch_size, shuffle=True, **kwargs) 26 | train_loader2 = DataLoader(train_set2, batch_size=args.batch_size, shuffle=True, **kwargs) 27 | val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False, **kwargs) 28 | #test_loader = DataLoader(test_set, batch_size=args.batch_size, shuffle=False, **kwargs) 29 | 30 | #return train_loader1, train_loader2, val_loader, test_loader, num_class 31 | return train_loader1, train_loader2, val_loader, num_class 32 | 33 | elif args.dataset == 'coco': 34 | train_set = coco.COCOSegmentation(args, split='train') 35 | val_set = coco.COCOSegmentation(args, split='val') 36 | num_class = train_set.NUM_CLASSES 37 | train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, **kwargs) 38 | val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False, **kwargs) 39 | test_loader = None 40 | return train_loader, train_loader, val_loader, test_loader, num_class 41 | 42 | else: 43 | raise NotImplementedError 44 | 45 | -------------------------------------------------------------------------------- /dataloaders/custom_transforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import numpy as np 4 | 5 | from PIL import Image, ImageOps, ImageFilter 6 | 7 | class Normalize(object): 8 | """Normalize a tensor image with mean and standard deviation. 9 | Args: 10 | mean (tuple): means for each channel. 11 | std (tuple): standard deviations for each channel. 
12 | """ 13 | def __init__(self, mean=(0., 0., 0.), std=(1., 1., 1.)): 14 | self.mean = mean 15 | self.std = std 16 | 17 | def __call__(self, sample): 18 | img = sample['image'] 19 | mask = sample['label'] 20 | img = np.array(img).astype(np.float32) 21 | mask = np.array(mask).astype(np.float32) 22 | img /= 255.0 23 | img -= self.mean 24 | img /= self.std 25 | 26 | return {'image': img, 27 | 'label': mask} 28 | 29 | 30 | class ToTensor(object): 31 | """Convert ndarrays in sample to Tensors.""" 32 | 33 | def __call__(self, sample): 34 | # swap color axis because 35 | # numpy image: H x W x C 36 | # torch image: C X H X W 37 | img = sample['image'] 38 | mask = sample['label'] 39 | img = np.array(img).astype(np.float32).transpose((2, 0, 1)) 40 | mask = np.array(mask).astype(np.float32) 41 | 42 | img = torch.from_numpy(img).float() 43 | mask = torch.from_numpy(mask).float() 44 | 45 | return {'image': img, 46 | 'label': mask} 47 | 48 | 49 | class RandomHorizontalFlip(object): 50 | def __call__(self, sample): 51 | img = sample['image'] 52 | mask = sample['label'] 53 | if random.random() < 0.5: 54 | img = img.transpose(Image.FLIP_LEFT_RIGHT) 55 | mask = mask.transpose(Image.FLIP_LEFT_RIGHT) 56 | 57 | return {'image': img, 58 | 'label': mask} 59 | 60 | 61 | class RandomRotate(object): 62 | def __init__(self, degree): 63 | self.degree = degree 64 | 65 | def __call__(self, sample): 66 | img = sample['image'] 67 | mask = sample['label'] 68 | rotate_degree = random.uniform(-1*self.degree, self.degree) 69 | img = img.rotate(rotate_degree, Image.BILINEAR) 70 | mask = mask.rotate(rotate_degree, Image.NEAREST) 71 | 72 | return {'image': img, 73 | 'label': mask} 74 | 75 | 76 | class RandomGaussianBlur(object): 77 | def __call__(self, sample): 78 | img = sample['image'] 79 | mask = sample['label'] 80 | if random.random() < 0.5: 81 | img = img.filter(ImageFilter.GaussianBlur( 82 | radius=random.random())) 83 | 84 | return {'image': img, 85 | 'label': mask} 86 | 87 | 88 | class RandomScaleCrop(object): 89 | def __init__(self, base_size, crop_size, fill=0): 90 | self.base_size = base_size 91 | self.crop_size = crop_size 92 | self.fill = fill 93 | 94 | def __call__(self, sample): 95 | img = sample['image'] 96 | mask = sample['label'] 97 | # random scale (short edge) 98 | short_size = random.randint(int(self.base_size * 0.5), int(self.base_size * 2.0)) 99 | w, h = img.size 100 | if h > w: 101 | ow = short_size 102 | oh = int(1.0 * h * ow / w) 103 | else: 104 | oh = short_size 105 | ow = int(1.0 * w * oh / h) 106 | img = img.resize((ow, oh), Image.BILINEAR) 107 | mask = mask.resize((ow, oh), Image.NEAREST) 108 | # pad crop 109 | if short_size < self.crop_size: 110 | padh = self.crop_size - oh if oh < self.crop_size else 0 111 | padw = self.crop_size - ow if ow < self.crop_size else 0 112 | img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0) 113 | mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=self.fill) 114 | # random crop crop_size 115 | w, h = img.size 116 | x1 = random.randint(0, w - self.crop_size) 117 | y1 = random.randint(0, h - self.crop_size) 118 | img = img.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size)) 119 | mask = mask.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size)) 120 | 121 | return {'image': img, 122 | 'label': mask} 123 | 124 | 125 | class FixScaleCrop(object): 126 | def __init__(self, crop_size): 127 | self.crop_size = crop_size 128 | 129 | def __call__(self, sample): 130 | img = sample['image'] 131 | mask = sample['label'] 132 | w, h = img.size 133 | if w 
> h: 134 | oh = self.crop_size 135 | ow = int(1.0 * w * oh / h) 136 | else: 137 | ow = self.crop_size 138 | oh = int(1.0 * h * ow / w) 139 | img = img.resize((ow, oh), Image.BILINEAR) 140 | mask = mask.resize((ow, oh), Image.NEAREST) 141 | # center crop 142 | w, h = img.size 143 | x1 = int(round((w - self.crop_size) / 2.)) 144 | y1 = int(round((h - self.crop_size) / 2.)) 145 | img = img.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size)) 146 | mask = mask.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size)) 147 | 148 | return {'image': img, 149 | 'label': mask} 150 | 151 | class FixedResize(object): 152 | def __init__(self, size): 153 | self.size = (size, size) # size: (h, w) 154 | 155 | def __call__(self, sample): 156 | img = sample['image'] 157 | mask = sample['label'] 158 | 159 | assert img.size == mask.size 160 | 161 | img = img.resize(self.size, Image.BILINEAR) 162 | mask = mask.resize(self.size, Image.NEAREST) 163 | 164 | return {'image': img, 165 | 'label': mask} -------------------------------------------------------------------------------- /dataloaders/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MenghaoGuo/AutoDeeplab/94659b26f708f2e694367785b9d2e75f175ab639/dataloaders/datasets/__init__.py -------------------------------------------------------------------------------- /dataloaders/datasets/cityscapes.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import scipy.misc as m 4 | from PIL import Image 5 | from torch.utils import data 6 | from mypath import Path 7 | from torchvision import transforms 8 | from dataloaders import custom_transforms as tr 9 | import random 10 | 11 | def sp(args, split='train'): 12 | root=Path.db_root_dir('cityscapes') 13 | split="train" 14 | images_base = os.path.join(root, 'leftImg8bit', split) 15 | rootdir=images_base 16 | suffix='.png' 17 | 18 | ls = [os.path.join(looproot, filename) 19 | for looproot, _, filenames in os.walk(rootdir) 20 | for filename in filenames if filename.endswith(suffix)] 21 | random.shuffle(ls) 22 | split = 2975//2 23 | 24 | return CityscapesSegmentation(args, split='train', part=ls[split:]), CityscapesSegmentation(args, split='train', part=ls[:split]) 25 | 26 | class CityscapesSegmentation(data.Dataset): 27 | NUM_CLASSES = 19 28 | 29 | def __init__(self, args, root=Path.db_root_dir('cityscapes'), split="train", part=None): 30 | self.NUM_CLASSES = 19 31 | self.root = root 32 | self.split = split 33 | self.args = args 34 | self.files = {} 35 | self.part=part 36 | self.images_base = os.path.join(self.root, 'leftImg8bit', self.split) 37 | self.annotations_base = os.path.join(self.root, 'gtFine', self.split) 38 | 39 | if self.split=="train": 40 | self.files[split] = part 41 | else: 42 | self.files[split] = self.recursive_glob(rootdir=self.images_base, suffix='.png') 43 | 44 | 45 | self.void_classes = [0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1] 46 | self.valid_classes = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33] 47 | self.class_names = ['unlabelled', 'road', 'sidewalk', 'building', 'wall', 'fence', \ 48 | 'pole', 'traffic_light', 'traffic_sign', 'vegetation', 'terrain', \ 49 | 'sky', 'person', 'rider', 'car', 'truck', 'bus', 'train', \ 50 | 'motorcycle', 'bicycle'] 51 | 52 | self.ignore_index = 255 53 | self.class_map = dict(zip(self.valid_classes, range(self.NUM_CLASSES))) 54 | 55 | if not self.files[split]: 56 
| raise Exception("No files for split=[%s] found in %s" % (split, self.images_base)) 57 | 58 | print("Found %d %s images" % (len(self.files[split]), split)) 59 | 60 | def __len__(self): 61 | return len(self.files[self.split]) 62 | 63 | def __getitem__(self, index): 64 | 65 | img_path = self.files[self.split][index].rstrip() 66 | lbl_path = os.path.join(self.annotations_base, 67 | img_path.split(os.sep)[-2], 68 | os.path.basename(img_path)[:-15] + 'gtFine_labelIds.png') 69 | 70 | _img = Image.open(img_path).convert('RGB') 71 | _tmp = np.array(Image.open(lbl_path), dtype=np.uint8) 72 | _tmp = self.encode_segmap(_tmp) 73 | _target = Image.fromarray(_tmp) 74 | 75 | sample = {'image': _img, 'label': _target} 76 | 77 | if self.split == 'train': 78 | return self.transform_tr(sample) 79 | elif self.split == 'val': 80 | return self.transform_val(sample) 81 | elif self.split == 'test': 82 | return self.transform_ts(sample) 83 | 84 | def encode_segmap(self, mask): 85 | # Put all void classes to zero 86 | for _voidc in self.void_classes: 87 | mask[mask == _voidc] = self.ignore_index 88 | for _validc in self.valid_classes: 89 | mask[mask == _validc] = self.class_map[_validc] 90 | return mask 91 | 92 | def recursive_glob(self, rootdir='.', suffix=''): 93 | """Performs recursive glob with given suffix and rootdir 94 | :param rootdir is the root directory 95 | :param suffix is the suffix to be searched 96 | """ 97 | return [os.path.join(looproot, filename) 98 | for looproot, _, filenames in os.walk(rootdir) 99 | for filename in filenames if filename.endswith(suffix)] 100 | 101 | def transform_tr(self, sample): 102 | composed_transforms = transforms.Compose([ 103 | FixedResize(resize=self.args.resize), 104 | RandomCrop(crop_size=self.args.crop_size), 105 | #tr.RandomGaussianBlur(), 106 | tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 107 | tr.ToTensor()]) 108 | 109 | return composed_transforms(sample) 110 | 111 | def transform_val(self, sample): 112 | 113 | composed_transforms = transforms.Compose([ 114 | tr.FixScaleCrop(crop_size=self.args.crop_size), 115 | tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 116 | tr.ToTensor()]) 117 | 118 | return composed_transforms(sample) 119 | 120 | def transform_ts(self, sample): 121 | 122 | composed_transforms = transforms.Compose([ 123 | tr.FixedResize(size=self.args.crop_size), 124 | tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 125 | tr.ToTensor()]) 126 | 127 | return composed_transforms(sample) 128 | 129 | 130 | # resize to 512*1024 131 | class FixedResize(object): 132 | """change the short edge length to size""" 133 | def __init__(self, resize=512): 134 | self.size1 = resize # size= 512 135 | 136 | def __call__(self, sample): 137 | img = sample['image'] 138 | mask = sample['label'] 139 | assert img.size == mask.size 140 | 141 | w, h = img.size 142 | if w > h: 143 | oh = self.size1 144 | ow = int(1.0 * w * oh / h) 145 | else: 146 | ow = self.size1 147 | oh = int(1.0 * h * ow / w) 148 | img = img.resize((ow,oh), Image.BILINEAR) 149 | mask = mask.resize((ow,oh), Image.NEAREST) 150 | return {'image': img, 151 | 'label': mask} 152 | 153 | # random corp 321*321 154 | class RandomCrop(object): 155 | def __init__(self, crop_size=321): 156 | self.crop_size = crop_size 157 | 158 | def __call__(self, sample): 159 | img = sample['image'] 160 | mask = sample['label'] 161 | w, h = img.size 162 | x1 = random.randint(0, w - self.crop_size) 163 | y1 = random.randint(0, h - self.crop_size) 164 | img = img.crop((x1, 
y1, x1 + self.crop_size, y1 + self.crop_size)) 165 | mask = mask.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size)) 166 | return {'image': img, 167 | 'label': mask} 168 | 169 | 170 | if __name__ == '__main__': 171 | from dataloaders.utils import decode_segmap 172 | from torch.utils.data import DataLoader 173 | import matplotlib.pyplot as plt 174 | import argparse 175 | 176 | parser = argparse.ArgumentParser() 177 | args = parser.parse_args() 178 | args.base_size = 513 179 | args.crop_size = 513 180 | 181 | cityscapes_train = CityscapesSegmentation(args, split='train') 182 | 183 | dataloader = DataLoader(cityscapes_train, batch_size=2, shuffle=True, num_workers=2) 184 | 185 | for ii, sample in enumerate(dataloader): 186 | for jj in range(sample["image"].size()[0]): 187 | img = sample['image'].numpy() 188 | gt = sample['label'].numpy() 189 | tmp = np.array(gt[jj]).astype(np.uint8) 190 | segmap = decode_segmap(tmp, dataset='cityscapes') 191 | img_tmp = np.transpose(img[jj], axes=[1, 2, 0]) 192 | img_tmp *= (0.229, 0.224, 0.225) 193 | img_tmp += (0.485, 0.456, 0.406) 194 | img_tmp *= 255.0 195 | img_tmp = img_tmp.astype(np.uint8) 196 | plt.figure() 197 | plt.title('display') 198 | plt.subplot(211) 199 | plt.imshow(img_tmp) 200 | plt.subplot(212) 201 | plt.imshow(segmap) 202 | 203 | if ii == 1: 204 | break 205 | 206 | plt.show(block=True) 207 | 208 | -------------------------------------------------------------------------------- /dataloaders/datasets/coco.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.utils.data import Dataset 4 | from mypath import Path 5 | from tqdm import trange 6 | import os 7 | from pycocotools.coco import COCO 8 | from pycocotools import mask 9 | from torchvision import transforms 10 | from dataloaders import custom_transforms as tr 11 | from PIL import Image, ImageFile 12 | ImageFile.LOAD_TRUNCATED_IMAGES = True 13 | 14 | 15 | class COCOSegmentation(Dataset): 16 | NUM_CLASSES = 21 17 | CAT_LIST = [0, 5, 2, 16, 9, 44, 6, 3, 17, 62, 21, 67, 18, 19, 4, 18 | 1, 64, 20, 63, 7, 72] 19 | 20 | def __init__(self, 21 | args, 22 | base_dir=Path.db_root_dir('coco'), 23 | split='train', 24 | year='2017'): 25 | super().__init__() 26 | ann_file = os.path.join(base_dir, 'annotations/instances_{}{}.json'.format(split, year)) 27 | ids_file = os.path.join(base_dir, 'annotations/{}_ids_{}.pth'.format(split, year)) 28 | self.img_dir = os.path.join(base_dir, 'images/{}{}'.format(split, year)) 29 | self.split = split 30 | self.coco = COCO(ann_file) 31 | self.coco_mask = mask 32 | if os.path.exists(ids_file): 33 | self.ids = torch.load(ids_file) 34 | else: 35 | ids = list(self.coco.imgs.keys()) 36 | self.ids = self._preprocess(ids, ids_file) 37 | self.args = args 38 | 39 | def __getitem__(self, index): 40 | _img, _target = self._make_img_gt_point_pair(index) 41 | sample = {'image': _img, 'label': _target} 42 | 43 | if self.split == "train": 44 | return self.transform_tr(sample) 45 | elif self.split == 'val': 46 | return self.transform_val(sample) 47 | 48 | def _make_img_gt_point_pair(self, index): 49 | coco = self.coco 50 | img_id = self.ids[index] 51 | img_metadata = coco.loadImgs(img_id)[0] 52 | path = img_metadata['file_name'] 53 | _img = Image.open(os.path.join(self.img_dir, path)).convert('RGB') 54 | cocotarget = coco.loadAnns(coco.getAnnIds(imgIds=img_id)) 55 | _target = Image.fromarray(self._gen_seg_mask( 56 | cocotarget, img_metadata['height'], img_metadata['width'])) 57 | 58 | return 
_img, _target 59 | 60 | def _preprocess(self, ids, ids_file): 61 | print("Preprocessing mask, this will take a while. " + \ 62 | "But don't worry, it only run once for each split.") 63 | tbar = trange(len(ids)) 64 | new_ids = [] 65 | for i in tbar: 66 | img_id = ids[i] 67 | cocotarget = self.coco.loadAnns(self.coco.getAnnIds(imgIds=img_id)) 68 | img_metadata = self.coco.loadImgs(img_id)[0] 69 | mask = self._gen_seg_mask(cocotarget, img_metadata['height'], 70 | img_metadata['width']) 71 | # more than 1k pixels 72 | if (mask > 0).sum() > 1000: 73 | new_ids.append(img_id) 74 | tbar.set_description('Doing: {}/{}, got {} qualified images'. \ 75 | format(i, len(ids), len(new_ids))) 76 | print('Found number of qualified images: ', len(new_ids)) 77 | torch.save(new_ids, ids_file) 78 | return new_ids 79 | 80 | def _gen_seg_mask(self, target, h, w): 81 | mask = np.zeros((h, w), dtype=np.uint8) 82 | coco_mask = self.coco_mask 83 | for instance in target: 84 | rle = coco_mask.frPyObjects(instance['segmentation'], h, w) 85 | m = coco_mask.decode(rle) 86 | cat = instance['category_id'] 87 | if cat in self.CAT_LIST: 88 | c = self.CAT_LIST.index(cat) 89 | else: 90 | continue 91 | if len(m.shape) < 3: 92 | mask[:, :] += (mask == 0) * (m * c) 93 | else: 94 | mask[:, :] += (mask == 0) * (((np.sum(m, axis=2)) > 0) * c).astype(np.uint8) 95 | return mask 96 | 97 | def transform_tr(self, sample): 98 | composed_transforms = transforms.Compose([ 99 | tr.RandomHorizontalFlip(), 100 | tr.RandomScaleCrop(base_size=self.args.base_size, crop_size=self.args.crop_size), 101 | tr.RandomGaussianBlur(), 102 | tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 103 | tr.ToTensor()]) 104 | 105 | return composed_transforms(sample) 106 | 107 | def transform_val(self, sample): 108 | 109 | composed_transforms = transforms.Compose([ 110 | tr.FixScaleCrop(crop_size=self.args.crop_size), 111 | tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 112 | tr.ToTensor()]) 113 | 114 | return composed_transforms(sample) 115 | 116 | 117 | def __len__(self): 118 | return len(self.ids) 119 | 120 | 121 | 122 | if __name__ == "__main__": 123 | from dataloaders import custom_transforms as tr 124 | from dataloaders.utils import decode_segmap 125 | from torch.utils.data import DataLoader 126 | from torchvision import transforms 127 | import matplotlib.pyplot as plt 128 | import argparse 129 | 130 | parser = argparse.ArgumentParser() 131 | args = parser.parse_args() 132 | args.base_size = 513 133 | args.crop_size = 513 134 | 135 | coco_val = COCOSegmentation(args, split='val', year='2017') 136 | 137 | dataloader = DataLoader(coco_val, batch_size=4, shuffle=True, num_workers=0) 138 | 139 | for ii, sample in enumerate(dataloader): 140 | for jj in range(sample["image"].size()[0]): 141 | img = sample['image'].numpy() 142 | gt = sample['label'].numpy() 143 | tmp = np.array(gt[jj]).astype(np.uint8) 144 | segmap = decode_segmap(tmp, dataset='coco') 145 | img_tmp = np.transpose(img[jj], axes=[1, 2, 0]) 146 | img_tmp *= (0.229, 0.224, 0.225) 147 | img_tmp += (0.485, 0.456, 0.406) 148 | img_tmp *= 255.0 149 | img_tmp = img_tmp.astype(np.uint8) 150 | plt.figure() 151 | plt.title('display') 152 | plt.subplot(211) 153 | plt.imshow(img_tmp) 154 | plt.subplot(212) 155 | plt.imshow(segmap) 156 | 157 | if ii == 1: 158 | break 159 | 160 | plt.show(block=True) -------------------------------------------------------------------------------- /dataloaders/datasets/combine_dbs.py: 
-------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | 3 | 4 | class CombineDBs(data.Dataset): 5 | NUM_CLASSES = 21 6 | def __init__(self, dataloaders, excluded=None): 7 | self.dataloaders = dataloaders 8 | self.excluded = excluded 9 | self.im_ids = [] 10 | 11 | # Combine object lists 12 | for dl in dataloaders: 13 | for elem in dl.im_ids: 14 | if elem not in self.im_ids: 15 | self.im_ids.append(elem) 16 | 17 | # Exclude 18 | if excluded: 19 | for dl in excluded: 20 | for elem in dl.im_ids: 21 | if elem in self.im_ids: 22 | self.im_ids.remove(elem) 23 | 24 | # Get object pointers 25 | self.cat_list = [] 26 | self.im_list = [] 27 | new_im_ids = [] 28 | num_images = 0 29 | for ii, dl in enumerate(dataloaders): 30 | for jj, curr_im_id in enumerate(dl.im_ids): 31 | if (curr_im_id in self.im_ids) and (curr_im_id not in new_im_ids): 32 | num_images += 1 33 | new_im_ids.append(curr_im_id) 34 | self.cat_list.append({'db_ii': ii, 'cat_ii': jj}) 35 | 36 | self.im_ids = new_im_ids 37 | print('Combined number of images: {:d}'.format(num_images)) 38 | 39 | def __getitem__(self, index): 40 | 41 | _db_ii = self.cat_list[index]["db_ii"] 42 | _cat_ii = self.cat_list[index]['cat_ii'] 43 | sample = self.dataloaders[_db_ii].__getitem__(_cat_ii) 44 | 45 | if 'meta' in sample.keys(): 46 | sample['meta']['db'] = str(self.dataloaders[_db_ii]) 47 | 48 | return sample 49 | 50 | def __len__(self): 51 | return len(self.cat_list) 52 | 53 | def __str__(self): 54 | include_db = [str(db) for db in self.dataloaders] 55 | exclude_db = [str(db) for db in self.excluded] 56 | return 'Included datasets:'+str(include_db)+'\n'+'Excluded datasets:'+str(exclude_db) 57 | 58 | 59 | if __name__ == "__main__": 60 | import matplotlib.pyplot as plt 61 | from dataloaders.datasets import pascal, sbd 62 | from dataloaders import sbd 63 | import torch 64 | import numpy as np 65 | from dataloaders.utils import decode_segmap 66 | import argparse 67 | 68 | parser = argparse.ArgumentParser() 69 | args = parser.parse_args() 70 | args.base_size = 513 71 | args.crop_size = 513 72 | 73 | pascal_voc_val = pascal.VOCSegmentation(args, split='val') 74 | sbd = sbd.SBDSegmentation(args, split=['train', 'val']) 75 | pascal_voc_train = pascal.VOCSegmentation(args, split='train') 76 | 77 | dataset = CombineDBs([pascal_voc_train, sbd], excluded=[pascal_voc_val]) 78 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True, num_workers=0) 79 | 80 | for ii, sample in enumerate(dataloader): 81 | for jj in range(sample["image"].size()[0]): 82 | img = sample['image'].numpy() 83 | gt = sample['label'].numpy() 84 | tmp = np.array(gt[jj]).astype(np.uint8) 85 | segmap = decode_segmap(tmp, dataset='pascal') 86 | img_tmp = np.transpose(img[jj], axes=[1, 2, 0]) 87 | img_tmp *= (0.229, 0.224, 0.225) 88 | img_tmp += (0.485, 0.456, 0.406) 89 | img_tmp *= 255.0 90 | img_tmp = img_tmp.astype(np.uint8) 91 | plt.figure() 92 | plt.title('display') 93 | plt.subplot(211) 94 | plt.imshow(img_tmp) 95 | plt.subplot(212) 96 | plt.imshow(segmap) 97 | 98 | if ii == 1: 99 | break 100 | plt.show(block=True) -------------------------------------------------------------------------------- /dataloaders/datasets/pascal.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | import os 3 | from PIL import Image 4 | import numpy as np 5 | from torch.utils.data import Dataset 6 | from mypath import Path 7 | from torchvision 
import transforms 8 | from dataloaders import custom_transforms as tr 9 | 10 | class VOCSegmentation(Dataset): 11 | """ 12 | PascalVoc dataset 13 | """ 14 | NUM_CLASSES = 21 15 | 16 | def __init__(self, 17 | args, 18 | base_dir=Path.db_root_dir('pascal'), 19 | split='train', 20 | ): 21 | """ 22 | :param base_dir: path to VOC dataset directory 23 | :param split: train/val 24 | :param transform: transform to apply 25 | """ 26 | super().__init__() 27 | self._base_dir = base_dir 28 | self._image_dir = os.path.join(self._base_dir, 'JPEGImages') 29 | self._cat_dir = os.path.join(self._base_dir, 'SegmentationClass') 30 | 31 | if isinstance(split, str): 32 | self.split = [split] 33 | else: 34 | split.sort() 35 | self.split = split 36 | 37 | self.args = args 38 | 39 | _splits_dir = os.path.join(self._base_dir, 'ImageSets', 'Segmentation') 40 | 41 | self.im_ids = [] 42 | self.images = [] 43 | self.categories = [] 44 | 45 | for splt in self.split: 46 | with open(os.path.join(os.path.join(_splits_dir, splt + '.txt')), "r") as f: 47 | lines = f.read().splitlines() 48 | 49 | for ii, line in enumerate(lines): 50 | _image = os.path.join(self._image_dir, line + ".jpg") 51 | _cat = os.path.join(self._cat_dir, line + ".png") 52 | assert os.path.isfile(_image) 53 | assert os.path.isfile(_cat) 54 | self.im_ids.append(line) 55 | self.images.append(_image) 56 | self.categories.append(_cat) 57 | 58 | assert (len(self.images) == len(self.categories)) 59 | 60 | # Display stats 61 | print('Number of images in {}: {:d}'.format(split, len(self.images))) 62 | 63 | def __len__(self): 64 | return len(self.images) 65 | 66 | 67 | def __getitem__(self, index): 68 | _img, _target = self._make_img_gt_point_pair(index) 69 | sample = {'image': _img, 'label': _target} 70 | 71 | for split in self.split: 72 | if split == "train": 73 | return self.transform_tr(sample) 74 | elif split == 'val': 75 | return self.transform_val(sample) 76 | 77 | 78 | def _make_img_gt_point_pair(self, index): 79 | _img = Image.open(self.images[index]).convert('RGB') 80 | _target = Image.open(self.categories[index]) 81 | 82 | return _img, _target 83 | 84 | def transform_tr(self, sample): 85 | composed_transforms = transforms.Compose([ 86 | tr.RandomHorizontalFlip(), 87 | tr.RandomScaleCrop(base_size=self.args.base_size, crop_size=self.args.crop_size), 88 | tr.RandomGaussianBlur(), 89 | tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 90 | tr.ToTensor()]) 91 | 92 | return composed_transforms(sample) 93 | 94 | def transform_val(self, sample): 95 | 96 | composed_transforms = transforms.Compose([ 97 | tr.FixScaleCrop(crop_size=self.args.crop_size), 98 | tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 99 | tr.ToTensor()]) 100 | 101 | return composed_transforms(sample) 102 | 103 | def __str__(self): 104 | return 'VOC2012(split=' + str(self.split) + ')' 105 | 106 | 107 | if __name__ == '__main__': 108 | from dataloaders.utils import decode_segmap 109 | from torch.utils.data import DataLoader 110 | import matplotlib.pyplot as plt 111 | import argparse 112 | 113 | parser = argparse.ArgumentParser() 114 | args = parser.parse_args() 115 | args.base_size = 513 116 | args.crop_size = 513 117 | 118 | voc_train = VOCSegmentation(args, split='train') 119 | 120 | dataloader = DataLoader(voc_train, batch_size=5, shuffle=True, num_workers=0) 121 | 122 | for ii, sample in enumerate(dataloader): 123 | for jj in range(sample["image"].size()[0]): 124 | img = sample['image'].numpy() 125 | gt = sample['label'].numpy() 126 | tmp = 
np.array(gt[jj]).astype(np.uint8) 127 | segmap = decode_segmap(tmp, dataset='pascal') 128 | img_tmp = np.transpose(img[jj], axes=[1, 2, 0]) 129 | img_tmp *= (0.229, 0.224, 0.225) 130 | img_tmp += (0.485, 0.456, 0.406) 131 | img_tmp *= 255.0 132 | img_tmp = img_tmp.astype(np.uint8) 133 | plt.figure() 134 | plt.title('display') 135 | plt.subplot(211) 136 | plt.imshow(img_tmp) 137 | plt.subplot(212) 138 | plt.imshow(segmap) 139 | 140 | if ii == 1: 141 | break 142 | 143 | plt.show(block=True) 144 | 145 | 146 | -------------------------------------------------------------------------------- /dataloaders/datasets/sbd.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | import os 3 | 4 | import numpy as np 5 | import scipy.io 6 | import torch.utils.data as data 7 | from PIL import Image 8 | from mypath import Path 9 | 10 | from torchvision import transforms 11 | from dataloaders import custom_transforms as tr 12 | 13 | class SBDSegmentation(data.Dataset): 14 | NUM_CLASSES = 21 15 | 16 | def __init__(self, 17 | args, 18 | base_dir=Path.db_root_dir('sbd'), 19 | split='train', 20 | ): 21 | """ 22 | :param base_dir: path to VOC dataset directory 23 | :param split: train/val 24 | :param transform: transform to apply 25 | """ 26 | super().__init__() 27 | self._base_dir = base_dir 28 | self._dataset_dir = os.path.join(self._base_dir, 'dataset') 29 | self._image_dir = os.path.join(self._dataset_dir, 'img') 30 | self._cat_dir = os.path.join(self._dataset_dir, 'cls') 31 | 32 | 33 | if isinstance(split, str): 34 | self.split = [split] 35 | else: 36 | split.sort() 37 | self.split = split 38 | 39 | self.args = args 40 | 41 | # Get list of all images from the split and check that the files exist 42 | self.im_ids = [] 43 | self.images = [] 44 | self.categories = [] 45 | for splt in self.split: 46 | with open(os.path.join(self._dataset_dir, splt + '.txt'), "r") as f: 47 | lines = f.read().splitlines() 48 | 49 | for line in lines: 50 | _image = os.path.join(self._image_dir, line + ".jpg") 51 | _categ= os.path.join(self._cat_dir, line + ".mat") 52 | assert os.path.isfile(_image) 53 | assert os.path.isfile(_categ) 54 | self.im_ids.append(line) 55 | self.images.append(_image) 56 | self.categories.append(_categ) 57 | 58 | assert (len(self.images) == len(self.categories)) 59 | 60 | # Display stats 61 | print('Number of images: {:d}'.format(len(self.images))) 62 | 63 | 64 | def __getitem__(self, index): 65 | _img, _target = self._make_img_gt_point_pair(index) 66 | sample = {'image': _img, 'label': _target} 67 | 68 | return self.transform(sample) 69 | 70 | def __len__(self): 71 | return len(self.images) 72 | 73 | def _make_img_gt_point_pair(self, index): 74 | _img = Image.open(self.images[index]).convert('RGB') 75 | _target = Image.fromarray(scipy.io.loadmat(self.categories[index])["GTcls"][0]['Segmentation'][0]) 76 | 77 | return _img, _target 78 | 79 | def transform(self, sample): 80 | composed_transforms = transforms.Compose([ 81 | tr.RandomHorizontalFlip(), 82 | tr.RandomScaleCrop(base_size=self.args.base_size, crop_size=self.args.crop_size), 83 | tr.RandomGaussianBlur(), 84 | tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 85 | tr.ToTensor()]) 86 | 87 | return composed_transforms(sample) 88 | 89 | 90 | def __str__(self): 91 | return 'SBDSegmentation(split=' + str(self.split) + ')' 92 | 93 | 94 | if __name__ == '__main__': 95 | from dataloaders.utils import decode_segmap 96 | from torch.utils.data import 
DataLoader 97 | import matplotlib.pyplot as plt 98 | import argparse 99 | 100 | parser = argparse.ArgumentParser() 101 | args = parser.parse_args() 102 | args.base_size = 513 103 | args.crop_size = 513 104 | 105 | sbd_train = SBDSegmentation(args, split='train') 106 | dataloader = DataLoader(sbd_train, batch_size=2, shuffle=True, num_workers=2) 107 | 108 | for ii, sample in enumerate(dataloader): 109 | for jj in range(sample["image"].size()[0]): 110 | img = sample['image'].numpy() 111 | gt = sample['label'].numpy() 112 | tmp = np.array(gt[jj]).astype(np.uint8) 113 | segmap = decode_segmap(tmp, dataset='pascal') 114 | img_tmp = np.transpose(img[jj], axes=[1, 2, 0]) 115 | img_tmp *= (0.229, 0.224, 0.225) 116 | img_tmp += (0.485, 0.456, 0.406) 117 | img_tmp *= 255.0 118 | img_tmp = img_tmp.astype(np.uint8) 119 | plt.figure() 120 | plt.title('display') 121 | plt.subplot(211) 122 | plt.imshow(img_tmp) 123 | plt.subplot(212) 124 | plt.imshow(segmap) 125 | 126 | if ii == 1: 127 | break 128 | 129 | plt.show(block=True) -------------------------------------------------------------------------------- /dataloaders/utils.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | import torch 4 | 5 | def decode_seg_map_sequence(label_masks, dataset='pascal'): 6 | rgb_masks = [] 7 | for label_mask in label_masks: 8 | rgb_mask = decode_segmap(label_mask, dataset) 9 | rgb_masks.append(rgb_mask) 10 | rgb_masks = torch.from_numpy(np.array(rgb_masks).transpose([0, 3, 1, 2])) 11 | return rgb_masks 12 | 13 | 14 | def decode_segmap(label_mask, dataset, plot=False): 15 | """Decode segmentation class labels into a color image 16 | Args: 17 | label_mask (np.ndarray): an (M,N) array of integer values denoting 18 | the class label at each spatial location. 19 | plot (bool, optional): whether to show the resulting color image 20 | in a figure. 21 | Returns: 22 | (np.ndarray, optional): the resulting decoded color image. 23 | """ 24 | if dataset == 'pascal' or dataset == 'coco': 25 | n_classes = 21 26 | label_colours = get_pascal_labels() 27 | elif dataset == 'cityscapes': 28 | n_classes = 19 29 | label_colours = get_cityscapes_labels() 30 | else: 31 | raise NotImplementedError 32 | 33 | r = label_mask.copy() 34 | g = label_mask.copy() 35 | b = label_mask.copy() 36 | for ll in range(0, n_classes): 37 | r[label_mask == ll] = label_colours[ll, 0] 38 | g[label_mask == ll] = label_colours[ll, 1] 39 | b[label_mask == ll] = label_colours[ll, 2] 40 | rgb = np.zeros((label_mask.shape[0], label_mask.shape[1], 3)) 41 | rgb[:, :, 0] = r / 255.0 42 | rgb[:, :, 1] = g / 255.0 43 | rgb[:, :, 2] = b / 255.0 44 | if plot: 45 | plt.imshow(rgb) 46 | plt.show() 47 | else: 48 | return rgb 49 | 50 | 51 | def encode_segmap(mask): 52 | """Encode segmentation label images as pascal classes 53 | Args: 54 | mask (np.ndarray): raw segmentation label image of dimension 55 | (M, N, 3), in which the Pascal classes are encoded as colours. 56 | Returns: 57 | (np.ndarray): class map with dimensions (M,N), where the value at 58 | a given location is the integer denoting the class index. 
59 | """ 60 | mask = mask.astype(int) 61 | label_mask = np.zeros((mask.shape[0], mask.shape[1]), dtype=np.int16) 62 | for ii, label in enumerate(get_pascal_labels()): 63 | label_mask[np.where(np.all(mask == label, axis=-1))[:2]] = ii 64 | label_mask = label_mask.astype(int) 65 | return label_mask 66 | 67 | 68 | def get_cityscapes_labels(): 69 | return np.array([ 70 | [128, 64, 128], 71 | [244, 35, 232], 72 | [70, 70, 70], 73 | [102, 102, 156], 74 | [190, 153, 153], 75 | [153, 153, 153], 76 | [250, 170, 30], 77 | [220, 220, 0], 78 | [107, 142, 35], 79 | [152, 251, 152], 80 | [0, 130, 180], 81 | [220, 20, 60], 82 | [255, 0, 0], 83 | [0, 0, 142], 84 | [0, 0, 70], 85 | [0, 60, 100], 86 | [0, 80, 100], 87 | [0, 0, 230], 88 | [119, 11, 32]]) 89 | 90 | 91 | def get_pascal_labels(): 92 | """Load the mapping that associates pascal classes with label colors 93 | Returns: 94 | np.ndarray with dimensions (21, 3) 95 | """ 96 | return np.asarray([[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], 97 | [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128], 98 | [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0], 99 | [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128], 100 | [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0], 101 | [0, 64, 128]]) -------------------------------------------------------------------------------- /genotypes.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | 3 | Genotype = namedtuple('Genotype', 'cell cell_concat') 4 | 5 | PRIMITIVES = [ 6 | 'none', 7 | 'max_pool_3x3', 8 | 'avg_pool_3x3', 9 | 'skip_connect', 10 | 'sep_conv_3x3', 11 | 'sep_conv_5x5', 12 | 'dil_conv_3x3', 13 | 'dil_conv_5x5' 14 | ] 15 | 16 | 17 | -------------------------------------------------------------------------------- /model_search.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from operations import * 5 | from torch.autograd import Variable 6 | from genotypes import PRIMITIVES 7 | from genotypes import Genotype 8 | 9 | 10 | class MixedOp (nn.Module): 11 | 12 | def __init__(self, C, stride): 13 | super(MixedOp, self).__init__() 14 | self._ops = nn.ModuleList() 15 | for primitive in PRIMITIVES: 16 | op = OPS[primitive](C, stride, False) 17 | if 'pool' in primitive: 18 | op = nn.Sequential(op, nn.BatchNorm2d(C, affine=False)) 19 | self._ops.append(op) 20 | 21 | def forward(self, x, weights): 22 | return sum(w * op(x) for w, op in zip(weights, self._ops)) 23 | 24 | 25 | class Cell(nn.Module): 26 | 27 | def __init__(self, steps, multiplier, C_prev_prev, C_prev, C, rate): 28 | 29 | super(Cell, self).__init__() 30 | self.C_out = C 31 | if C_prev_prev != -1 : 32 | self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0, affine=False) 33 | 34 | if rate == 2 : 35 | self.preprocess1 = FactorizedReduce (C_prev, C, affine= False) 36 | elif rate == 0 : 37 | self.preprocess1 = FactorizedIncrease (C_prev, C) 38 | else : 39 | self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0, affine=False) 40 | self._steps = steps 41 | self._multiplier = multiplier 42 | self._ops = nn.ModuleList() 43 | for i in range(self._steps): 44 | for j in range(2+i): 45 | if C_prev_prev != -1 and j != 0: 46 | op = MixedOp(C, stride) 47 | else: 48 | stride = 1 49 | op = None 50 | stride = 1 51 | self._ops.append(op) 52 | self.ReLUConvBN = ReLUConvBN (self._multiplier * self.C_out, self.C_out, 1, 1, 0) 53 | 54 | def forward(self, s0, s1, weights): 55 
| if s0 is not None : 56 | s0 = self.preprocess0 (s0) 57 | s1 = self.preprocess1(s1) 58 | states = [s0, s1] 59 | 60 | offset = 0 61 | for i in range(self._steps): 62 | s = sum(self._ops[offset+j](h, weights[offset+j]) for j, h in enumerate(states) if h is not None) 63 | offset += len(states) 64 | states.append(s) 65 | 66 | concat_feature = torch.cat(states[-self._multiplier:], dim=1) 67 | return self.ReLUConvBN (concat_feature) 68 | 69 | 70 | 71 | -------------------------------------------------------------------------------- /mypath.py: -------------------------------------------------------------------------------- 1 | class Path(object): 2 | @staticmethod 3 | def db_root_dir(dataset): 4 | if dataset == 'pascal': 5 | return '/path/to/datasets/VOCdevkit/VOC2012/' # folder that contains VOCdevkit/. 6 | elif dataset == 'sbd': 7 | return '/path/to/datasets/benchmark_RELEASE/' # folder that contains dataset/. 8 | elif dataset == 'cityscapes': 9 | return '/path/to/datasets/cityscapes/' # foler that contains leftImg8bit/ 10 | elif dataset == 'coco': 11 | return '/path/to/datasets/coco/' 12 | else: 13 | print('Dataset {} not available.'.format(dataset)) 14 | raise NotImplementedError 15 | -------------------------------------------------------------------------------- /operations.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | OPS = { 5 | 'none' : lambda C, stride, affine: Zero(stride), 6 | 'avg_pool_3x3' : lambda C, stride, affine: nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False), 7 | 'max_pool_3x3' : lambda C, stride, affine: nn.MaxPool2d(3, stride=stride, padding=1), 8 | 'skip_connect' : lambda C, stride, affine: Identity() if stride == 1 else FactorizedReduce(C, C, affine=affine), 9 | 'sep_conv_3x3' : lambda C, stride, affine: SepConv(C, C, 3, stride, 1, affine=affine), 10 | 'sep_conv_5x5' : lambda C, stride, affine: SepConv(C, C, 5, stride, 2, affine=affine), 11 | 'dil_conv_3x3' : lambda C, stride, affine: DilConv(C, C, 3, stride, 2, 2, affine=affine), 12 | 'dil_conv_5x5' : lambda C, stride, affine: DilConv(C, C, 5, stride, 4, 2, affine=affine), 13 | } 14 | 15 | class ReLUConvBN(nn.Module): 16 | 17 | def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True): 18 | super(ReLUConvBN, self).__init__() 19 | self.op = nn.Sequential( 20 | nn.ReLU(inplace=False), 21 | nn.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, bias=False), 22 | nn.BatchNorm2d(C_out, affine=affine) 23 | ) 24 | 25 | def forward(self, x): 26 | return self.op(x) 27 | 28 | class DilConv(nn.Module): 29 | 30 | def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True): 31 | super(DilConv, self).__init__() 32 | self.op = nn.Sequential( 33 | nn.ReLU(inplace=False), 34 | nn.Conv2d(C_in, C_out, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, bias=False), 35 | nn.BatchNorm2d(C_out, affine=affine), 36 | ) 37 | 38 | def forward(self, x): 39 | return self.op(x) 40 | 41 | 42 | class SepConv(nn.Module): 43 | 44 | def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True): 45 | super(SepConv, self).__init__() 46 | self.op = nn.Sequential( 47 | nn.ReLU(inplace=False), 48 | nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, groups=C_in, bias=False), 49 | nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False), 50 | nn.BatchNorm2d(C_out, affine=affine), 51 | ) 52 | 53 | def forward(self, x): 54 | return 
self.op(x) 55 | 56 | 57 | class Identity(nn.Module): 58 | 59 | def __init__(self): 60 | super(Identity, self).__init__() 61 | 62 | def forward(self, x): 63 | return x 64 | 65 | 66 | class Zero(nn.Module): 67 | 68 | def __init__(self, stride): 69 | super(Zero, self).__init__() 70 | self.stride = stride 71 | 72 | def forward(self, x): 73 | if self.stride == 1: 74 | return x.mul(0.) 75 | return x[:,:,::self.stride,::self.stride].mul(0.) 76 | 77 | 78 | class FactorizedReduce(nn.Module): 79 | 80 | def __init__(self, C_in, C_out, affine=True): 81 | super(FactorizedReduce, self).__init__() 82 | assert C_out % 2 == 0 83 | self.relu = nn.ReLU(inplace=False) 84 | self.conv_1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False) 85 | self.conv_2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False) 86 | self.bn = nn.BatchNorm2d(C_out, affine=affine) 87 | 88 | def forward(self, x): 89 | x = self.relu(x) 90 | out = torch.cat([self.conv_1(x), self.conv_2(x[:,:,1:,1:])], dim=1) 91 | out = self.bn(out) 92 | return out 93 | 94 | class FactorizedIncrease (nn.Module) : 95 | def __init__ (self, in_channel, out_channel) : 96 | super(FactorizedIncrease, self).__init__() 97 | 98 | self._in_channel = in_channel 99 | self.op = nn.Sequential ( 100 | nn.Upsample(scale_factor=2, mode="bilinear"), 101 | nn.ReLU(inplace = False), 102 | nn.Conv2d(self._in_channel, out_channel, 1, stride=1, padding=0), 103 | nn.BatchNorm2d(out_channel) 104 | ) 105 | def forward (self, x) : 106 | return self.op (x) 107 | 108 | 109 | 110 | class ASPP(nn.Module): 111 | def __init__(self, in_channels, paddings, dilations, num_classes): 112 | # todo depthwise separable conv 113 | super(ASPP, self).__init__() 114 | self._num_classes =num_classes 115 | self.conv11 = nn.Sequential(nn.Conv2d(in_channels, in_channels, 1, bias=False), 116 | nn.BatchNorm2d(in_channels)) 117 | self.conv33 = nn.Sequential(nn.Conv2d(in_channels, in_channels, 3, 118 | padding=paddings, dilation=dilations, bias=False), 119 | nn.BatchNorm2d(in_channels)) 120 | self.conv_p = nn.Sequential(nn.Conv2d(in_channels, in_channels, 1, bias=False), 121 | nn.BatchNorm2d(in_channels)) 122 | 123 | self.concate_conv = nn.Sequential(nn.Conv2d(in_channels * 3, self._num_classes, 1, stride=1, padding=0)) 124 | 125 | def forward(self, x): 126 | conv11 = self.conv11(x) 127 | conv33 = self.conv33(x) 128 | 129 | # image pool and upsample 130 | image_pool = nn.AvgPool2d(kernel_size=x.size()[2:]) 131 | image_pool = image_pool(x) 132 | upsample = nn.Upsample(size=x.size()[2:], mode='bilinear', align_corners=True) 133 | upsample = upsample(image_pool) 134 | upsample = self.conv_p(upsample) 135 | 136 | 137 | # concate 138 | concate = torch.cat([conv11, conv33, upsample], dim=1) 139 | 140 | return self.concate_conv(concate) 141 | -------------------------------------------------------------------------------- /train_autodeeplab.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import numpy as np 4 | from tqdm import tqdm 5 | import torch.nn as nn 6 | from mypath import Path 7 | from dataloaders import make_data_loader 8 | from modeling.sync_batchnorm.replicate import patch_replication_callback 9 | from modeling.deeplab import * 10 | from utils.loss import SegmentationLosses 11 | from utils.calculate_weights import calculate_weigths_labels 12 | from utils.lr_scheduler import LR_Scheduler 13 | from utils.saver import Saver 14 | from utils.summaries import TensorboardSummary 15 | from utils.metrics import 
Evaluator 16 | from auto_deeplab import AutoDeeplab 17 | from architect import Architect 18 | 19 | class Trainer(object): 20 | def __init__(self, args): 21 | self.args = args 22 | 23 | # Define Saver 24 | self.saver = Saver(args) 25 | self.saver.save_experiment_config() 26 | # Define Tensorboard Summary 27 | self.summary = TensorboardSummary(self.saver.experiment_dir) 28 | self.writer = self.summary.create_summary() 29 | 30 | # Define Dataloader 31 | kwargs = {'num_workers': args.workers, 'pin_memory': True} 32 | #self.train_loader1, self.train_loader2, self.val_loader, self.test_loader, self.nclass = make_data_loader(args, **kwargs) 33 | self.train_loader1, self.train_loader2, self.val_loader, self.nclass = make_data_loader(args, **kwargs) 34 | 35 | # Define Criterion 36 | # whether to use class balanced weights 37 | if args.use_balanced_weights: 38 | classes_weights_path = os.path.join(Path.db_root_dir(args.dataset), args.dataset+'_classes_weights.npy') 39 | if os.path.isfile(classes_weights_path): 40 | weight = np.load(classes_weights_path) 41 | else: 42 | weight = calculate_weigths_labels(args.dataset, self.train_loader, self.nclass) 43 | weight = torch.from_numpy(weight.astype(np.float32)) 44 | else: 45 | weight = None 46 | self.criterion = SegmentationLosses(weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type) 47 | 48 | # Define network 49 | model = AutoDeeplab (self.nclass, 12, self.criterion, crop_size=self.args.crop_size) 50 | optimizer = torch.optim.SGD( 51 | model.weight_parameters(), 52 | args.lr, 53 | momentum=args.momentum, 54 | weight_decay=args.weight_decay 55 | ) 56 | self.model, self.optimizer = model, optimizer 57 | 58 | # Using cuda 59 | if args.cuda: 60 | self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids) 61 | patch_replication_callback(self.model) 62 | self.model = self.model.cuda() 63 | print ('cuda finished') 64 | 65 | 66 | # Define Optimizer 67 | 68 | 69 | self.model, self.optimizer = model, optimizer 70 | 71 | # Define Evaluator 72 | self.evaluator = Evaluator(self.nclass) 73 | # Define lr scheduler 74 | self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, 75 | args.epochs, len(self.train_loader1)) 76 | 77 | self.architect = Architect (self.model, args) 78 | # Resuming checkpoint 79 | self.best_pred = 0.0 80 | if args.resume is not None: 81 | if not os.path.isfile(args.resume): 82 | raise RuntimeError("=> no checkpoint found at '{}'" .format(args.resume)) 83 | checkpoint = torch.load(args.resume) 84 | args.start_epoch = checkpoint['epoch'] 85 | if args.cuda: 86 | self.model.load_state_dict(checkpoint['state_dict']) 87 | else: 88 | self.model.load_state_dict(checkpoint['state_dict']) 89 | if not args.ft: 90 | self.optimizer.load_state_dict(checkpoint['optimizer']) 91 | self.best_pred = checkpoint['best_pred'] 92 | print("=> loaded checkpoint '{}' (epoch {})" 93 | .format(args.resume, checkpoint['epoch'])) 94 | 95 | # Clear start epoch if fine-tuning 96 | if args.ft: 97 | args.start_epoch = 0 98 | 99 | def training(self, epoch): 100 | train_loss = 0.0 101 | self.model.train() 102 | tbar = tqdm(self.train_loader1) 103 | num_img_tr = len(self.train_loader1) 104 | for i, sample in enumerate(tbar): 105 | image, target = sample['image'], sample['label'] 106 | search = next (iter (self.train_loader2)) 107 | image_search, target_search = search['image'], search['label'] 108 | # print ('------------------------begin-----------------------') 109 | if self.args.cuda: 110 | image, target = image.cuda(), target.cuda() 111 | image_search, 
target_search = image_search.cuda (), target_search.cuda () 112 | # print ('cuda finish') 113 | self.scheduler(self.optimizer, i, epoch, self.best_pred) 114 | self.optimizer.zero_grad() 115 | output = self.model(image) 116 | torch.cuda.empty_cache() 117 | loss = self.criterion(output, target) 118 | loss.backward() 119 | self.optimizer.step() 120 | if epoch>19: 121 | self.architect.step (image_search, target_search) 122 | train_loss += loss.item() 123 | tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1))) 124 | self.writer.add_scalar('train/total_loss_iter', loss.item(), i + num_img_tr * epoch) 125 | 126 | # Show 10 * 3 inference results each epoch 127 | if i % (num_img_tr // 10) == 0: 128 | global_step = i + num_img_tr * epoch 129 | self.summary.visualize_image(self.writer, self.args.dataset, image, target, output, global_step) 130 | 131 | self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch) 132 | print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0])) 133 | print('Loss: %.3f' % train_loss) 134 | 135 | if self.args.no_val: 136 | # save checkpoint every epoch 137 | is_best = False 138 | self.saver.save_checkpoint({ 139 | 'epoch': epoch + 1, 140 | 'state_dict': self.model.state_dict(), 141 | 'optimizer': self.optimizer.state_dict(), 142 | 'best_pred': self.best_pred, 143 | }, is_best) 144 | 145 | 146 | def validation(self, epoch): 147 | self.model.eval() 148 | self.evaluator.reset() 149 | tbar = tqdm(self.val_loader, desc='\r') 150 | test_loss = 0.0 151 | for i, sample in enumerate(tbar): 152 | image, target = sample['image'], sample['label'] 153 | if self.args.cuda: 154 | image, target = image.cuda(), target.cuda() 155 | with torch.no_grad(): 156 | output = self.model(image) 157 | loss = self.criterion(output, target) 158 | test_loss += loss.item() 159 | tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1))) 160 | pred = output.data.cpu().numpy() 161 | target = target.cpu().numpy() 162 | pred = np.argmax(pred, axis=1) 163 | # Add batch sample into evaluator 164 | self.evaluator.add_batch(target, pred) 165 | 166 | # Fast test during the training 167 | Acc = self.evaluator.Pixel_Accuracy() 168 | Acc_class = self.evaluator.Pixel_Accuracy_Class() 169 | mIoU = self.evaluator.Mean_Intersection_over_Union() 170 | FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union() 171 | self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch) 172 | self.writer.add_scalar('val/mIoU', mIoU, epoch) 173 | self.writer.add_scalar('val/Acc', Acc, epoch) 174 | self.writer.add_scalar('val/Acc_class', Acc_class, epoch) 175 | self.writer.add_scalar('val/fwIoU', FWIoU, epoch) 176 | print('Validation:') 177 | print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0])) 178 | print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(Acc, Acc_class, mIoU, FWIoU)) 179 | print('Loss: %.3f' % test_loss) 180 | 181 | new_pred = mIoU 182 | if new_pred > self.best_pred: 183 | is_best = True 184 | self.best_pred = new_pred 185 | self.saver.save_checkpoint({ 186 | 'epoch': epoch + 1, 187 | 'state_dict': self.model.state_dict(), 188 | 'optimizer': self.optimizer.state_dict(), 189 | 'best_pred': self.best_pred, 190 | }, is_best) 191 | 192 | def main(): 193 | parser = argparse.ArgumentParser(description="PyTorch DeeplabV3Plus Training") 194 | parser.add_argument('--backbone', type=str, default='resnet', 195 | choices=['resnet', 'xception', 'drn', 'mobilenet'], 196 | help='backbone name (default: resnet)') 197 | 
parser.add_argument('--out_stride', type=int, default=16, 198 | help='network output stride (default: 16)') 199 | parser.add_argument('--dataset', type=str, default='pascal', 200 | choices=['pascal', 'coco', 'cityscapes'], 201 | help='dataset name (default: pascal)') 202 | parser.add_argument('--use_sbd', action='store_true', default=False, 203 | help='whether to use SBD dataset (default: False)') 204 | parser.add_argument('--workers', type=int, default=4, 205 | metavar='N', help='dataloader threads') 206 | parser.add_argument('--base_size', type=int, default=320, 207 | help='base image size') 208 | parser.add_argument('--crop_size', type=int, default=320, 209 | help='crop image size') 210 | parser.add_argument('--resize', type=int, default=512, 211 | help='resize image size') 212 | parser.add_argument('--sync_bn', type=bool, default=None, 213 | help='whether to use sync bn (default: auto)') 214 | parser.add_argument('--freeze_bn', type=bool, default=False, 215 | help='whether to freeze bn parameters (default: False)') 216 | parser.add_argument('--loss_type', type=str, default='ce', 217 | choices=['ce', 'focal'], 218 | help='loss function type (default: ce)') 219 | # training hyper params 220 | parser.add_argument('--epochs', type=int, default=None, metavar='N', 221 | help='number of epochs to train (default: auto)') 222 | parser.add_argument('--start_epoch', type=int, default=0, 223 | metavar='N', help='start epoch (default: 0)') 224 | parser.add_argument('--batch_size', type=int, default=None, 225 | metavar='N', help='input batch size for \ 226 | training (default: auto)') 227 | parser.add_argument('--test_batch_size', type=int, default=None, 228 | metavar='N', help='input batch size for \ 229 | testing (default: auto)') 230 | parser.add_argument('--use_balanced_weights', action='store_true', default=False, 231 | help='whether to use balanced weights (default: False)') 232 | # optimizer params 233 | parser.add_argument('--lr', type=float, default=0.025, metavar='LR', 234 | help='learning rate (default: 0.025)') 235 | parser.add_argument('--arch_lr', type=float, default=3e-3, 236 | help='learning rate for alpha and beta in the architecture search (default: 3e-3)') 237 | 238 | parser.add_argument('--lr_scheduler', type=str, default='cos', 239 | choices=['poly', 'step', 'cos'], 240 | help='lr scheduler mode (default: cos)') 241 | parser.add_argument('--momentum', type=float, default=0.9, 242 | metavar='M', help='momentum (default: 0.9)') 243 | parser.add_argument('--weight_decay', type=float, default=3e-4, 244 | metavar='M', help='w-decay (default: 3e-4)') 245 | parser.add_argument('--arch_weight_decay', type=float, default=1e-3, 246 | metavar='M', help='w-decay for the architecture parameters (default: 1e-3)') 247 | 248 | parser.add_argument('--nesterov', action='store_true', default=False, 249 | help='whether to use nesterov (default: False)') 250 | # cuda, seed and logging 251 | parser.add_argument('--no_cuda', action='store_true', default=False, 252 | help='disables CUDA training') 253 | parser.add_argument('--gpu_ids', nargs='*', type=int, default=[0], 254 | help='which GPUs to train on, e.g. --gpu_ids 0 1 2 3 (default: [0])') 255 | parser.add_argument('--seed', type=int, default=1, metavar='S', 256 | help='random seed (default: 1)') 257 | # checkpoint 258 | parser.add_argument('--resume', type=str, default=None, 259 | help='put the path to resuming file if needed') 260 | parser.add_argument('--checkname', type=str, default=None, 261 | help='set the checkpoint name') 262 | # finetuning pre-trained models 263 | parser.add_argument('--ft', action='store_true',
default=False, 264 | help='finetuning on a different dataset') 265 | # evaluation option 266 | parser.add_argument('--eval_interval', type=int, default=1, 267 | help='evaluation interval (default: 1)') 268 | parser.add_argument('--no_val', action='store_true', default=False, 269 | help='skip validation during training') 270 | 271 | args = parser.parse_args() 272 | args.cuda = not args.no_cuda and torch.cuda.is_available() 273 | 274 | if args.sync_bn is None: 275 | if args.cuda and len(args.gpu_ids) > 1: 276 | args.sync_bn = True 277 | else: 278 | args.sync_bn = False 279 | 280 | # default settings for epochs, batch_size and lr 281 | if args.epochs is None: 282 | epochs = { 283 | 'coco': 30, 284 | 'cityscapes': 200, 285 | 'pascal': 50, 286 | } 287 | args.epochs = epochs[args.dataset.lower()] 288 | 289 | if args.batch_size is None: 290 | args.batch_size = 4 * len(args.gpu_ids) 291 | 292 | if args.test_batch_size is None: 293 | args.test_batch_size = args.batch_size 294 | 295 | if args.lr is None: 296 | lrs = { 297 | 'coco': 0.1, 298 | 'cityscapes': 0.025, 299 | 'pascal': 0.007, 300 | } 301 | #args.lr = lrs[args.dataset.lower()] / (4 * len(args.gpu_ids)) * args.batch_size 302 | 303 | 304 | if args.checkname is None: 305 | args.checkname = 'deeplab-'+str(args.backbone) 306 | print(args) 307 | torch.manual_seed(args.seed) 308 | trainer = Trainer(args) 309 | print('Starting Epoch:', trainer.args.start_epoch) 310 | print('Total Epochs:', trainer.args.epochs) 311 | for epoch in range(trainer.args.start_epoch, trainer.args.epochs): 312 | trainer.training(epoch) 313 | if not trainer.args.no_val and epoch % args.eval_interval == (args.eval_interval - 1): 314 | trainer.validation(epoch) 315 | 316 | trainer.writer.close() 317 | 318 | if __name__ == "__main__": 319 | main() 320 | -------------------------------------------------------------------------------- /train_voc.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0,1,2,3 python train_autodeeplab.py --backbone resnet --lr 0.007 --workers 4 --epochs 40 --batch_size 2 --gpu_ids 0 1 2 3 --eval_interval 1 --dataset pascal 2 | --------------------------------------------------------------------------------
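
Note (added for illustration): the snippet below is a minimal, self-contained sketch, not a file from this repository. It shows how the DARTS-style MixedOp defined in model_search.py relaxes the discrete choice among the candidate operations in PRIMITIVES (genotypes.py / operations.py) into a softmax-weighted sum. The input size, channel count, and the standalone alpha vector are illustrative assumptions; in the actual search these weights come from the model's architecture parameters, which Architect.step updates on the search batches in train_autodeeplab.py.

    import torch
    import torch.nn.functional as F

    from genotypes import PRIMITIVES
    from model_search import MixedOp

    # Dummy feature map: batch of 1, 40 channels, 32x32 spatial resolution (all arbitrary).
    x = torch.randn(1, 40, 32, 32)

    # One logit per candidate primitive; softmax turns the discrete operation choice
    # into a continuous mixture that can be optimized with gradient descent.
    alphas = torch.randn(len(PRIMITIVES), requires_grad=True)
    weights = F.softmax(alphas, dim=-1)

    mixed = MixedOp(C=40, stride=1)   # instantiates every op in OPS for 40 channels, stride 1
    out = mixed(x, weights)           # weighted sum of all candidate operations
    print(out.shape)                  # torch.Size([1, 40, 32, 32])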