├── .gitignore
├── README.md
├── data
│   ├── __init__.py
│   ├── augmentations.py
│   ├── base.py
│   ├── city_utils.py
│   ├── cityscapes_loader.py
│   ├── dataset_processing.py
│   ├── pcontext_loader.py
│   ├── voc_dataset.py
│   └── voc_list
│       ├── train_aug.txt
│       ├── train_img.txt
│       ├── train_label.txt
│       ├── val.txt
│       ├── val_img.txt
│       └── val_label.txt
├── evaluate.py
├── mlmt_output
│   └── output_ema_p_1_0_voc_5.txt
├── model
│   ├── deeplabv2.py
│   ├── deeplabv3p.py
│   ├── discriminator.py
│   ├── utils.py
│   └── wider_resnet.py
├── splits
│   ├── city
│   │   └── split_0.pkl
│   ├── pc
│   │   └── split_0.pkl
│   └── voc
│       └── split_0.pkl
├── train_full.py
├── train_mlmt.py
├── train_s4GAN.py
├── train_s4GAN_wrn38.py
└── utils
    ├── loss.py
    ├── metric.py
    ├── misc.py
    └── visualize.py

/.gitignore:
--------------------------------------------------------------------------------
data/voc_dataset/*
data/city_dataset/*
data/pcontext_dataset/*
checkpoints/*
pretrained_models/*

*__pycache__*
*/__pycache__

train_pascal*
train_city*
train_pc*
train_voc*

evaluate_*
load_*
results/*
mlmt_output/
reports/

mlmt_data.py
mlmt_v1.py
mlmt_v2.py
mlmt_feat.py
*.sh
myprog*

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Semi-supervised Semantic Segmentation with High- and Low-level Consistency

This PyTorch repository contains the code for our work [Semi-supervised Semantic Segmentation with High- and Low-level Consistency](https://arxiv.org/pdf/1908.05724.pdf). The approach uses two network branches that link semi-supervised classification with semi-supervised segmentation, including self-training. It attains significant improvements over existing methods, especially when trained with very few labeled samples. On several standard benchmarks - PASCAL VOC 2012, PASCAL-Context, and Cityscapes - the approach achieves a new state of the art in semi-supervised learning.

We propose a two-branch approach to the task of semi-supervised semantic segmentation. The lower branch predicts pixel-wise class labels and is referred to as the Semi-Supervised Semantic Segmentation GAN (s4GAN). The upper branch performs image-level classification and is denoted as the Multi-Label Mean Teacher (MLMT).

This repository contains the source code for the s4GAN branch. The MLMT branch is adapted from the Mean Teacher work for semi-supervised classification. Instructions for setting up the MLMT branch are given below.

## Package pre-requisites
The code runs on Python 3 and PyTorch 0.4. The following packages are required:

```
pip install scipy tqdm matplotlib numpy opencv-python
```

## Dataset preparation

Download the ImageNet-pretrained ResNet-101 weights ([Link](https://download.pytorch.org/models/resnet101-5d3b4d8f.pth)) and place them in ```./pretrained_models/```

### PASCAL VOC
Download the dataset ([Link](https://lmb.informatik.uni-freiburg.de/resources/datasets/voc_dataset.tar.gz)) and extract it in ```./data/voc_dataset/```

### PASCAL Context
Download the annotations ([Link](https://lmb.informatik.uni-freiburg.de/resources/datasets/pascal_context_labels.tar.gz)) and extract them in ```./data/pcontext_dataset/```

### Cityscapes
Download the dataset from the Cityscapes dataset server ([Link](https://www.cityscapes-dataset.com/)). Download the files named 'gtFine_trainvaltest.zip' and 'leftImg8bit_trainvaltest.zip' and extract them in ```./data/city_dataset/```
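After extraction, the data loaders in this repository read the datasets from fixed relative paths. As a quick sanity check, the expected layout looks roughly as follows (a sketch based on the paths used in `data/voc_dataset.py`, `data/pcontext_loader.py` and `data/cityscapes_loader.py`; the exact file lists depend on the downloaded archives):

```
data/
├── voc_dataset/
│   ├── JPEGImages/                # *.jpg images
│   └── SegmentationClassAug/      # *.png label maps
├── pcontext_dataset/
│   ├── trainval_merged.json
│   └── JPEGImages/
└── city_dataset/
    ├── leftImg8bit/{train,val}/<city>/*_leftImg8bit.png
    └── gtFine_trainvaltest/gtFine/{train,val}/<city>/*_gtFine_labelIds.png
```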
## Training and Validation on PASCAL-VOC Dataset

Results in the paper are averaged over 3 random splits. The same splits are used to report the baseline performance, for a fair comparison.

### Training fully-supervised Baseline (FSL)
```
python train_full.py --dataset pascal_voc \
    --checkpoint-dir ./checkpoints/voc_full \
    --ignore-label 255 \
    --num-classes 21
```

### Training semi-supervised s4GAN (SSL)
```
python train_s4GAN.py --dataset pascal_voc \
    --checkpoint-dir ./checkpoints/voc_semi_0_125 \
    --labeled-ratio 0.125 \
    --ignore-label 255 \
    --num-classes 21
```

### Validation
```
python evaluate.py --dataset pascal_voc \
    --num-classes 21 \
    --restore-from ./checkpoints/voc_semi_0_125/VOC_30000.pth
```

### Training MLMT Branch
```
python train_mlmt.py \
    --batch-size-lab 16 \
    --batch-size-unlab 80 \
    --labeled-ratio 0.125 \
    --exp-name voc_semi_0_125_MLMT \
    --pkl-file ./checkpoints/voc_semi_0_125/train_voc_split.pkl
```

### Final Evaluation s4GAN + MLMT
```
python evaluate.py --dataset pascal_voc \
    --num-classes 21 \
    --restore-from ./checkpoints/voc_semi_0_125/VOC_30000.pth \
    --with-mlmt \
    --mlmt-file ./mlmt_output/voc_semi_0_125_MLMT/output_ema_raw_100.txt
```

## Training and Validation on PASCAL-Context Dataset
```
python train_full.py --dataset pascal_context \
    --checkpoint-dir ./checkpoints/pc_full \
    --ignore-label -1 \
    --num-classes 60

python train_s4GAN.py --dataset pascal_context \
    --checkpoint-dir ./checkpoints/pc_semi_0_125 \
    --labeled-ratio 0.125 \
    --ignore-label -1 \
    --num-classes 60 \
    --split-id ./splits/pc/split_0.pkl \
    --num-steps 60000

python evaluate.py --dataset pascal_context \
    --num-classes 60 \
    --restore-from ./checkpoints/pc_semi_0_125/VOC_40000.pth
```

## Training and Validation on Cityscapes Dataset
```
python train_full.py --dataset cityscapes \
    --checkpoint-dir ./checkpoints/city_full_0_125 \
    --ignore-label 250 \
    --num-classes 19 \
    --input-size '256,512'

python train_s4GAN.py --dataset cityscapes \
    --checkpoint-dir ./checkpoints/city_semi_0_125 \
    --labeled-ratio 0.125 \
    --ignore-label 250 \
    --num-classes 19 \
    --split-id ./splits/city/split_0.pkl \
    --input-size '256,512' \
    --threshold-st 0.7 \
    --learning-rate-D 1e-5

python evaluate.py --dataset cityscapes \
    --num-classes 19 \
    --restore-from ./checkpoints/city_semi_0_125/VOC_30000.pth
```

## Training and Validation on Cityscapes with DeepLabv3+ (WRN-38 backbone)
```
python train_s4GAN_wrn38.py --dataset cityscapes \
    --checkpoint-dir ./checkpoints/city_semi_v3_wrn38_0_10 \
    --labeled-ratio 0.10 \
    --ignore-label 250 \
    --num-classes 19 \
    --split-id ./splits/city/split_0.pkl \
    --input-size '256,512' \
    --threshold-st 0.55 \
    --learning-rate-D 1e-5 \
    --learning-rate 1e-3 \
    --out results/city_semi_v3_wrn38_0_10
```

## Acknowledgement

Parts of the code have been adapted from:
[DeepLab-Resnet-Pytorch](https://github.com/speedinghzl/Pytorch-Deeplab), [AdvSemiSeg](https://github.com/hfslyc/AdvSemiSeg), [PyTorch-Encoding](https://github.com/zhanghang1989/PyTorch-Encoding) 138 | 139 | 140 | ## Citation 141 | 142 | 143 | 144 | ``` 145 | @ARTICLE{8935407, 146 | author={S. {Mittal} and M. {Tatarchenko} and T. {Brox}}, 147 | journal={IEEE Transactions on Pattern Analysis and Machine Intelligence}, 148 | title={Semi-Supervised Semantic Segmentation With High- and Low-Level Consistency}, 149 | year={2021}, 150 | volume={43}, 151 | number={4}, 152 | pages={1369-1379}, 153 | doi={10.1109/TPAMI.2019.2960224}} 154 | ``` 155 | 156 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from data.base import * 4 | from data.pcontext_loader import ContextSegmentation 5 | from data.cityscapes_loader import cityscapesLoader 6 | 7 | def get_loader(name): 8 | """get_loader 9 | :param name: 10 | """ 11 | return { 12 | "cityscapes": cityscapesLoader, 13 | "pascal_context": ContextSegmentation, 14 | }[name] 15 | 16 | def get_data_path(name): 17 | """get_data_path 18 | :param name: 19 | :param config_file: 20 | """ 21 | if name == 'cityscapes': 22 | return './data/city_dataset/' 23 | if name == 'pascal_context': 24 | return './data/' 25 | -------------------------------------------------------------------------------- /data/augmentations.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/ZijunDeng/pytorch-semantic-segmentation/blob/master/utils/joint_transforms.py 2 | 3 | import math 4 | import numbers 5 | import random 6 | import numpy as np 7 | 8 | from PIL import Image, ImageOps 9 | 10 | 11 | class Compose(object): 12 | def __init__(self, augmentations): 13 | self.augmentations = augmentations 14 | 15 | def __call__(self, img, mask): 16 | img, mask = Image.fromarray(img, mode="RGB"), Image.fromarray(mask, mode="L") 17 | assert img.size == mask.size 18 | for a in self.augmentations: 19 | img, mask = a(img, mask) 20 | return np.array(img), np.array(mask, dtype=np.uint8) 21 | 22 | 23 | class RandomCrop(object): 24 | def __init__(self, size, padding=0): 25 | #if isinstance(size, numbers.Number): 26 | #self.size = (int(size), int(size)) 27 | #else: 28 | #self.size = size 29 | self.size = tuple(size) 30 | self.padding = padding 31 | 32 | def __call__(self, img, mask): 33 | if self.padding > 0: 34 | img = ImageOps.expand(img, border=self.padding, fill=0) 35 | mask = ImageOps.expand(mask, border=self.padding, fill=0) 36 | 37 | assert img.size == mask.size 38 | w, h = img.size 39 | th, tw = self.size 40 | if w == tw and h == th: 41 | return img, mask 42 | if w < tw or h < th: 43 | return ( 44 | img.resize((tw, th), Image.BILINEAR), 45 | mask.resize((tw, th), Image.NEAREST), 46 | ) 47 | 48 | x1 = random.randint(0, w - tw) 49 | y1 = random.randint(0, h - th) 50 | return ( 51 | img.crop((x1, y1, x1 + tw, y1 + th)), 52 | mask.crop((x1, y1, x1 + tw, y1 + th)), 53 | ) 54 | 55 | 56 | class RandomCrop_city(object): # used for results in the CVPR-19 submission 57 | def __init__(self, size, padding=0): 58 | #if isinstance(size, numbers.Number): 59 | #self.size = (int(size), int(size)) 60 | #else: 61 | #self.size = size 62 | self.size = tuple(size) 63 | self.padding = padding 64 | 65 | def __call__(self, img, mask): 66 | if self.padding > 0: 67 | img = ImageOps.expand(img, 
border=self.padding, fill=0) 68 | mask = ImageOps.expand(mask, border=self.padding, fill=0) 69 | 70 | assert img.size == mask.size 71 | w, h = img.size 72 | th, tw = self.size 73 | ''' 74 | if w == tw and h == th: 75 | return img, mask 76 | if w < tw or h < th: 77 | return ( 78 | img.resize((tw, th), Image.BILINEAR), 79 | mask.resize((tw, th), Image.NEAREST), 80 | ) 81 | ''' 82 | img = img.resize((int(w/2), int(h/2)), Image.BILINEAR) 83 | mask = mask.resize((int(w/2), int(h/2)), Image.NEAREST) 84 | #img = img.resize((600, 300), Image.BILINEAR) 85 | #mask = mask.resize((600, 300), Image.NEAREST) 86 | #img = img.resize((512, 256), Image.BILINEAR) 87 | #mask = mask.resize((512, 256), Image.NEAREST) 88 | 89 | x1 = random.randint(0, int(w/2) - tw) 90 | y1 = random.randint(0, int(h/2) - th) 91 | 92 | return ( 93 | img.crop((x1, y1, x1 + tw, y1 + th)), 94 | mask.crop((x1, y1, x1 + tw, y1 + th)), 95 | ) 96 | 97 | 98 | class RandomCrop_city_gnet(object): # used for gnet training 99 | def __init__(self, size, padding=0): 100 | #if isinstance(size, numbers.Number): 101 | #self.size = (int(size), int(size)) 102 | #else: 103 | #self.size = size 104 | self.size = tuple(size) 105 | self.padding = padding 106 | 107 | def __call__(self, img, mask): 108 | if self.padding > 0: 109 | img = ImageOps.expand(img, border=self.padding, fill=0) 110 | mask = ImageOps.expand(mask, border=self.padding, fill=0) 111 | 112 | assert img.size == mask.size 113 | w, h = img.size 114 | th, tw = self.size 115 | ''' 116 | if w == tw and h == th: 117 | return img, mask 118 | if w < tw or h < th: 119 | return ( 120 | img.resize((tw, th), Image.BILINEAR), 121 | mask.resize((tw, th), Image.NEAREST), 122 | ) 123 | ''' 124 | #img = img.resize((int(w/2), int(h/2)), Image.BILINEAR) 125 | #mask = mask.resize((int(w/2), int(h/2)), Image.NEAREST) 126 | img = img.resize((600, 300), Image.BILINEAR) 127 | mask = mask.resize((600, 300), Image.NEAREST) 128 | #img = img.resize((512, 256), Image.BILINEAR) 129 | #mask = mask.resize((512, 256), Image.NEAREST) 130 | 131 | x1 = random.randint(0, 600 - tw) 132 | y1 = random.randint(0, 300 - th) 133 | 134 | return ( 135 | img.crop((x1, y1, x1 + tw, y1 + th)), 136 | mask.crop((x1, y1, x1 + tw, y1 + th)), 137 | ) 138 | 139 | class CenterCrop(object): 140 | def __init__(self, size): 141 | ''' 142 | if isinstance(size, numbers.Number): 143 | self.size = (int(size), int(size)) 144 | else: 145 | self.size = size 146 | ''' 147 | self.size = tuple(size) 148 | 149 | def __call__(self, img, mask): 150 | assert img.size == mask.size 151 | w, h = img.size 152 | th, tw = self.size 153 | x1 = int(round((w - tw) / 2.)) 154 | y1 = int(round((h - th) / 2.)) 155 | return ( 156 | img.crop((x1, y1, x1 + tw, y1 + th)), 157 | mask.crop((x1, y1, x1 + tw, y1 + th)), 158 | ) 159 | 160 | 161 | class RandomHorizontallyFlip(object): 162 | def __call__(self, img, mask): 163 | if random.random() < 0.5: 164 | return ( 165 | img.transpose(Image.FLIP_LEFT_RIGHT), 166 | mask.transpose(Image.FLIP_LEFT_RIGHT), 167 | ) 168 | return img, mask 169 | 170 | 171 | class FreeScale(object): 172 | def __init__(self, size): 173 | self.size = tuple(reversed(size)) # size: (h, w) 174 | 175 | def __call__(self, img, mask): 176 | assert img.size == mask.size 177 | return ( 178 | img.resize(self.size, Image.BILINEAR), 179 | mask.resize(self.size, Image.NEAREST), 180 | ) 181 | 182 | 183 | class Scale(object): 184 | def __init__(self, size): 185 | self.size = tuple(size) 186 | 187 | def __call__(self, img, mask): 188 | assert img.size == mask.size 189 | 
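        # Resize so that the longer image side matches the corresponding target
        # dimension (f_w for landscape images, f_h for portrait ones), keeping the
        # aspect ratio; the mask is resized with NEAREST so label ids are preserved.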
f_w, f_h = self.size 190 | w, h = img.size 191 | if (w >= h and w == f_w) or (h >= w and h == f_h): 192 | return img, mask 193 | if w > h: 194 | ow = f_w 195 | oh = int(f_w * h / w) 196 | return ( 197 | img.resize((ow, oh), Image.BILINEAR), 198 | mask.resize((ow, oh), Image.NEAREST), 199 | ) 200 | else: 201 | oh = f_h 202 | ow = int(f_h * w / h) 203 | return ( 204 | img.resize((ow, oh), Image.BILINEAR), 205 | mask.resize((ow, oh), Image.NEAREST), 206 | ) 207 | 208 | 209 | class RSCrop(object): 210 | def __init__(self, size): 211 | self.size = size 212 | #self.size = tuple(size) 213 | 214 | def __call__(self, img, mask): 215 | assert img.size == mask.size 216 | #for attempt in range(10): 217 | #random scale (0.5 to 2.0) 218 | crop_size = self.size 219 | short_size = random.randint(int(self.size*0.5), int(self.size*2.0)) 220 | w, h = img.size 221 | if h > w: 222 | ow = short_size 223 | oh = int(1.0 * h * ow / w) 224 | else: 225 | oh = short_size 226 | ow = int(1.0 * w * oh / h) 227 | img = img.resize((ow, oh), Image.BILINEAR) 228 | mask = mask.resize((ow, oh), Image.NEAREST) 229 | 230 | #deg = random.uniform(-10, 10) 231 | #img = img.rotate(deg, resample=Image.BILINEAR) 232 | #mask = mask.rotate(deg, resample=Image.NEAREST) 233 | # pad crop 234 | if short_size < crop_size: 235 | padh = crop_size - oh if oh < crop_size else 0 236 | padw = crop_size - ow if ow < crop_size else 0 237 | img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0) 238 | mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=0) 239 | 240 | # random crop crop_size 241 | w, h = img.size 242 | x1 = random.randint(0, w - crop_size) 243 | y1 = random.randint(0, h - crop_size) 244 | img = img.crop((x1, y1, x1+crop_size, y1+crop_size)) 245 | mask = mask.crop((x1, y1, x1+crop_size, y1+crop_size)) 246 | 247 | return img, mask 248 | 249 | 250 | class RSCrop_city(object): 251 | def __init__(self, size): 252 | #self.size = size 253 | self.size = tuple(size) 254 | self.base_size = 1024 255 | 256 | def __call__(self, img, mask): 257 | assert img.size == mask.size 258 | #for attempt in range(10): 259 | #random scale (0.5 to 2.0) 260 | #crop_size = self.size 261 | short_size = random.randint(int(self.base_size*0.25), int(self.base_size*1.0)) 262 | w, h = img.size 263 | if h > w: 264 | ow = short_size 265 | oh = int(1.0 * h * ow / w) 266 | else: 267 | oh = short_size 268 | ow = int(1.0 * w * oh / h) 269 | img = img.resize((ow, oh), Image.BILINEAR) 270 | mask = mask.resize((ow, oh), Image.NEAREST) 271 | 272 | deg = random.uniform(-10, 10) 273 | img = img.rotate(deg, resample=Image.BILINEAR) 274 | mask = mask.rotate(deg, resample=Image.NEAREST) 275 | 276 | ''' 277 | # pad crop 278 | #if short_size < crop_size: 279 | padh = crop_size - oh if oh < crop_size else 0 280 | padw = crop_size - ow if ow < crop_size else 0 281 | img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0) 282 | mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=0) 283 | ''' 284 | # random crop crop_size 285 | #w, h = img.size 286 | x1 = random.randint(0, w - self.size[0]) 287 | y1 = random.randint(0, h - self.size[1]) 288 | img = img.crop((x1, y1, x1+self.size[0], y1+self.size[1])) 289 | mask = mask.crop((x1, y1, x1+self.size[0], y1+self.size[1])) 290 | 291 | return img, mask 292 | 293 | class RandomSizedCrop(object): 294 | def __init__(self, size): 295 | #self.size = size 296 | self.size = tuple(size) 297 | 298 | def __call__(self, img, mask): 299 | assert img.size == mask.size 300 | for attempt in range(10): 301 | area = img.size[0] * 
img.size[1] 302 | target_area = random.uniform(0.45, 1.0) * area 303 | aspect_ratio = random.uniform(0.5, 2) 304 | 305 | w = int(round(math.sqrt(target_area * aspect_ratio))) 306 | h = int(round(math.sqrt(target_area / aspect_ratio))) 307 | 308 | f_w, f_h = self.size 309 | 310 | if random.random() < 0.5: 311 | w, h = h, w 312 | 313 | if w <= img.size[0] and h <= img.size[1]: 314 | x1 = random.randint(0, img.size[0] - w) 315 | y1 = random.randint(0, img.size[1] - h) 316 | 317 | img = img.crop((x1, y1, x1 + w, y1 + h)) 318 | mask = mask.crop((x1, y1, x1 + w, y1 + h)) 319 | assert img.size == (w, h) 320 | 321 | return ( 322 | img.resize((f_w, f_h), Image.BILINEAR), 323 | mask.resize((f_w, f_h), Image.NEAREST), 324 | ) 325 | 326 | # Fallback 327 | scale = Scale(self.size) 328 | crop = CenterCrop(self.size) 329 | return crop(*scale(img, mask)) 330 | 331 | 332 | class RandomRotate(object): 333 | def __init__(self, degree): 334 | self.degree = degree 335 | 336 | def __call__(self, img, mask): 337 | rotate_degree = random.random() * 2 * self.degree - self.degree 338 | return ( 339 | img.rotate(rotate_degree, Image.BILINEAR), 340 | mask.rotate(rotate_degree, Image.NEAREST), 341 | ) 342 | 343 | 344 | class RandomSized(object): 345 | def __init__(self, size): 346 | self.size = size 347 | self.scale = Scale(self.size) 348 | self.crop = RandomCrop(self.size) 349 | 350 | def __call__(self, img, mask): 351 | assert img.size == mask.size 352 | 353 | w = int(random.uniform(0.5, 2) * img.size[0]) 354 | h = int(random.uniform(0.5, 2) * img.size[1]) 355 | 356 | img, mask = ( 357 | img.resize((w, h), Image.BILINEAR), 358 | mask.resize((w, h), Image.NEAREST), 359 | ) 360 | 361 | return self.crop(*self.scale(img, mask)) 362 | -------------------------------------------------------------------------------- /data/base.py: -------------------------------------------------------------------------------- 1 | ########################################################################### 2 | # Created by: Hang Zhang 3 | # Email: zhang.hang@rutgers.edu 4 | # Copyright (c) 2017 5 | ########################################################################### 6 | 7 | import random 8 | import numpy as np 9 | from PIL import Image, ImageOps, ImageFilter 10 | import torch 11 | import torch.utils.data as data 12 | 13 | __all__ = ['BaseDataset', 'test_batchify_fn'] 14 | 15 | class BaseDataset(data.Dataset): 16 | def __init__(self, root, split, mode=None, transform=None, 17 | target_transform=None, base_size=520, crop_size=480): 18 | self.root = root 19 | self.transform = transform 20 | self.target_transform = target_transform 21 | self.split = split 22 | self.mode = mode if mode is not None else split 23 | self.base_size = base_size 24 | self.crop_size = crop_size 25 | if self.mode == 'train': 26 | print('BaseDataset: base_size {}, crop_size {}'. 
\ 27 | format(base_size, crop_size)) 28 | 29 | def __getitem__(self, index): 30 | raise NotImplemented 31 | 32 | @property 33 | def num_class(self): 34 | return self.NUM_CLASS 35 | 36 | @property 37 | def pred_offset(self): 38 | raise NotImplemented 39 | 40 | def make_pred(self, x): 41 | return x + self.pred_offset 42 | 43 | def _val_sync_transform(self, img, mask): 44 | outsize = self.crop_size 45 | short_size = outsize 46 | w, h = img.size 47 | if w > h: 48 | oh = short_size 49 | ow = int(1.0 * w * oh / h) 50 | else: 51 | ow = short_size 52 | oh = int(1.0 * h * ow / w) 53 | img = img.resize((ow, oh), Image.BILINEAR) 54 | mask = mask.resize((ow, oh), Image.NEAREST) 55 | # center crop 56 | w, h = img.size 57 | x1 = int(round((w - outsize) / 2.)) 58 | y1 = int(round((h - outsize) / 2.)) 59 | img = img.crop((x1, y1, x1+outsize, y1+outsize)) 60 | mask = mask.crop((x1, y1, x1+outsize, y1+outsize)) 61 | # final transform 62 | return img, self._mask_transform(mask) 63 | 64 | def _sync_transform(self, img, mask): 65 | # random mirror 66 | if random.random() < 0.5: 67 | img = img.transpose(Image.FLIP_LEFT_RIGHT) 68 | mask = mask.transpose(Image.FLIP_LEFT_RIGHT) 69 | crop_size = self.crop_size 70 | # random scale (short edge from 480 to 720) 71 | short_size = random.randint(int(self.base_size*0.5), int(self.base_size*2.0)) 72 | w, h = img.size 73 | if h > w: 74 | ow = short_size 75 | oh = int(1.0 * h * ow / w) 76 | else: 77 | oh = short_size 78 | ow = int(1.0 * w * oh / h) 79 | img = img.resize((ow, oh), Image.BILINEAR) 80 | mask = mask.resize((ow, oh), Image.NEAREST) 81 | # pad crop 82 | if short_size < crop_size: 83 | padh = crop_size - oh if oh < crop_size else 0 84 | padw = crop_size - ow if ow < crop_size else 0 85 | img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0) 86 | mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=0) 87 | # random crop crop_size 88 | w, h = img.size 89 | x1 = random.randint(0, w - crop_size) 90 | y1 = random.randint(0, h - crop_size) 91 | img = img.crop((x1, y1, x1+crop_size, y1+crop_size)) 92 | mask = mask.crop((x1, y1, x1+crop_size, y1+crop_size)) 93 | # gaussian blur as in PSP 94 | if random.random() < 0.5: 95 | img = img.filter(ImageFilter.GaussianBlur( 96 | radius=random.random())) 97 | # final transform 98 | return img, self._mask_transform(mask) 99 | 100 | def _mask_transform(self, mask): 101 | return torch.from_numpy(np.array(mask)).long() 102 | 103 | 104 | def test_batchify_fn(data): 105 | error_msg = "batch must contain tensors, tuples or lists; found {}" 106 | if isinstance(data[0], (str, torch.Tensor)): 107 | return list(data) 108 | elif isinstance(data[0], (tuple, list)): 109 | data = zip(*data) 110 | return [test_batchify_fn(i) for i in data] 111 | raise TypeError((error_msg.format(type(batch[0])))) 112 | -------------------------------------------------------------------------------- /data/city_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Misc Utility functions 3 | """ 4 | from collections import OrderedDict 5 | import os 6 | import numpy as np 7 | 8 | 9 | def recursive_glob(rootdir=".", suffix=""): 10 | """Performs recursive glob with given suffix and rootdir 11 | :param rootdir is the root directory 12 | :param suffix is the suffix to be searched 13 | """ 14 | return [ 15 | os.path.join(looproot, filename) 16 | for looproot, _, filenames in os.walk(rootdir) 17 | for filename in filenames 18 | if filename.endswith(suffix) 19 | ] 20 | 21 | 22 | def poly_lr_scheduler( 23 | 
optimizer, init_lr, iter, lr_decay_iter=1, max_iter=30000, power=0.9 24 | ): 25 | """Polynomial decay of learning rate 26 | :param init_lr is base learning rate 27 | :param iter is a current iteration 28 | :param lr_decay_iter how frequently decay occurs, default is 1 29 | :param max_iter is number of maximum iterations 30 | :param power is a polymomial power 31 | 32 | """ 33 | if iter % lr_decay_iter or iter > max_iter: 34 | return optimizer 35 | 36 | for param_group in optimizer.param_groups: 37 | param_group["lr"] = init_lr * (1 - iter / max_iter) ** power 38 | 39 | 40 | def adjust_learning_rate(optimizer, init_lr, epoch): 41 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" 42 | lr = init_lr * (0.1 ** (epoch // 30)) 43 | for param_group in optimizer.param_groups: 44 | param_group["lr"] = lr 45 | 46 | 47 | def alpha_blend(input_image, segmentation_mask, alpha=0.5): 48 | """Alpha Blending utility to overlay RGB masks on RBG images 49 | :param input_image is a np.ndarray with 3 channels 50 | :param segmentation_mask is a np.ndarray with 3 channels 51 | :param alpha is a float value 52 | 53 | """ 54 | blended = np.zeros(input_image.size, dtype=np.float32) 55 | blended = input_image * alpha + segmentation_mask * (1 - alpha) 56 | return blended 57 | 58 | 59 | def convert_state_dict(state_dict): 60 | """Converts a state dict saved from a dataParallel module to normal 61 | module state_dict inplace 62 | :param state_dict is the loaded DataParallel model_state 63 | 64 | """ 65 | new_state_dict = OrderedDict() 66 | for k, v in state_dict.items(): 67 | name = k[7:] # remove `module.` 68 | new_state_dict[name] = v 69 | return new_state_dict 70 | -------------------------------------------------------------------------------- /data/cityscapes_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | import scipy.misc as m 5 | 6 | from torch.utils import data 7 | 8 | from data.city_utils import recursive_glob 9 | from data.augmentations import * 10 | 11 | class cityscapesLoader(data.Dataset): 12 | """cityscapesLoader 13 | 14 | https://www.cityscapes-dataset.com 15 | 16 | Data is derived from CityScapes, and can be downloaded from here: 17 | https://www.cityscapes-dataset.com/downloads/ 18 | 19 | Many Thanks to @fvisin for the loader repo: 20 | https://github.com/fvisin/dataset_loaders/blob/master/dataset_loaders/images/cityscapes.py 21 | """ 22 | 23 | colors = [ # [ 0, 0, 0], 24 | [128, 64, 128], 25 | [244, 35, 232], 26 | [70, 70, 70], 27 | [102, 102, 156], 28 | [190, 153, 153], 29 | [153, 153, 153], 30 | [250, 170, 30], 31 | [220, 220, 0], 32 | [107, 142, 35], 33 | [152, 251, 152], 34 | [0, 130, 180], 35 | [220, 20, 60], 36 | [255, 0, 0], 37 | [0, 0, 142], 38 | [0, 0, 70], 39 | [0, 60, 100], 40 | [0, 80, 100], 41 | [0, 0, 230], 42 | [119, 11, 32], 43 | ] 44 | 45 | label_colours = dict(zip(range(19), colors)) 46 | 47 | mean_rgb = {"cityscapes": [73.15835921, 82.90891754, 72.39239876],} 48 | 49 | def __init__( 50 | self, 51 | root, 52 | split="train", 53 | is_transform=False, 54 | img_size=(256, 512), 55 | img_norm=True, 56 | augmentations=None, 57 | version="cityscapes", 58 | ): 59 | """__init__ 60 | 61 | :param root: 62 | :param split: 63 | :param is_transform: 64 | :param img_size: 65 | :param augmentations 66 | """ 67 | self.root = root 68 | self.split = split 69 | self.is_transform = is_transform 70 | self.augmentations = augmentations 71 | self.img_norm = img_norm 72 | 
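        # Cityscapes is trained and evaluated on the 19 official classes;
        # encode_segmap() below remaps the raw gtFine label ids to [0, 18] and
        # maps all void classes to ignore_index (250).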
self.n_classes = 19 73 | self.img_size = ( 74 | img_size if isinstance(img_size, tuple) else (img_size, img_size) 75 | ) 76 | self.mean = np.array(self.mean_rgb[version]) 77 | self.files = {} 78 | 79 | self.images_base = os.path.join(self.root, "leftImg8bit", self.split) 80 | self.annotations_base = os.path.join( 81 | self.root, "gtFine_trainvaltest", "gtFine", self.split 82 | ) 83 | 84 | self.files[split] = recursive_glob(rootdir=self.images_base, suffix=".png") 85 | 86 | self.void_classes = [0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1] 87 | self.valid_classes = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33,] 88 | self.class_names = [ 89 | "unlabelled", 90 | "road", 91 | "sidewalk", 92 | "building", 93 | "wall", 94 | "fence", 95 | "pole", 96 | "traffic_light", 97 | "traffic_sign", 98 | "vegetation", 99 | "terrain", 100 | "sky", 101 | "person", 102 | "rider", 103 | "car", 104 | "truck", 105 | "bus", 106 | "train", 107 | "motorcycle", 108 | "bicycle", 109 | ] 110 | 111 | self.ignore_index = 250 112 | self.class_map = dict(zip(self.valid_classes, range(19))) 113 | 114 | if not self.files[split]: 115 | raise Exception( 116 | "No files for split=[%s] found in %s" % (split, self.images_base) 117 | ) 118 | 119 | print("Found %d %s images" % (len(self.files[split]), split)) 120 | 121 | def __len__(self): 122 | """__len__""" 123 | return len(self.files[self.split]) 124 | 125 | def __getitem__(self, index): 126 | """__getitem__ 127 | 128 | :param index: 129 | """ 130 | img_path = self.files[self.split][index].rstrip() 131 | lbl_path = os.path.join( 132 | self.annotations_base, 133 | img_path.split(os.sep)[-2], # temporary for cross validation 134 | os.path.basename(img_path)[:-15] + "gtFine_labelIds.png", 135 | ) 136 | 137 | img = m.imread(img_path) 138 | img = np.array(img, dtype=np.uint8) 139 | 140 | lbl = m.imread(lbl_path) 141 | lbl = np.array(lbl, dtype=np.uint8) 142 | lbl = self.encode_segmap(lbl) 143 | 144 | if self.augmentations is not None: 145 | img, lbl = self.augmentations(img, lbl) 146 | 147 | if self.is_transform: 148 | img, lbl = self.transform(img, lbl) 149 | 150 | img_name = img_path.split('/')[-1] 151 | return img, lbl, img_name, img_name, img_name 152 | 153 | def transform(self, img, lbl): 154 | """transform 155 | 156 | :param img: 157 | :param lbl: 158 | """ 159 | img = m.imresize( 160 | img, (self.img_size[0], self.img_size[1]) 161 | ) # uint8 with RGB mode 162 | img = img[:, :, ::-1] # RGB -> BGR 163 | img = img.astype(np.float64) 164 | img -= self.mean 165 | if self.img_norm: 166 | # Resize scales images from 0 to 255, thus we need 167 | # to divide by 255.0 168 | img = img.astype(float) / 255.0 169 | # NHWC -> NCHW 170 | img = img.transpose(2, 0, 1) 171 | 172 | classes = np.unique(lbl) 173 | lbl = lbl.astype(float) 174 | lbl = m.imresize(lbl, (self.img_size[0], self.img_size[1]), "nearest", mode="F") 175 | lbl = lbl.astype(int) 176 | 177 | if not np.all(classes == np.unique(lbl)): 178 | print("WARN: resizing labels yielded fewer classes") 179 | 180 | if not np.all(np.unique(lbl[lbl != self.ignore_index]) < self.n_classes): 181 | print("after det", classes, np.unique(lbl)) 182 | raise ValueError("Segmentation map contained invalid class values") 183 | 184 | img = torch.from_numpy(img).float() 185 | lbl = torch.from_numpy(lbl).long() 186 | 187 | return img, lbl 188 | 189 | def decode_segmap(self, temp): 190 | r = temp.copy() 191 | g = temp.copy() 192 | b = temp.copy() 193 | for l in range(0, self.n_classes): 194 | r[temp == l] = 
self.label_colours[l][0] 195 | g[temp == l] = self.label_colours[l][1] 196 | b[temp == l] = self.label_colours[l][2] 197 | 198 | rgb = np.zeros((temp.shape[0], temp.shape[1], 3)) 199 | rgb[:, :, 0] = r / 255.0 200 | rgb[:, :, 1] = g / 255.0 201 | rgb[:, :, 2] = b / 255.0 202 | return rgb 203 | 204 | def encode_segmap(self, mask): 205 | # Put all void classes to zero 206 | for _voidc in self.void_classes: 207 | mask[mask == _voidc] = self.ignore_index 208 | for _validc in self.valid_classes: 209 | mask[mask == _validc] = self.class_map[_validc] 210 | return mask 211 | 212 | ''' 213 | if __name__ == "__main__": 214 | import torchvision 215 | import matplotlib.pyplot as plt 216 | 217 | augmentations = Compose([Scale(2048), RandomRotate(10), RandomHorizontallyFlip()]) 218 | 219 | local_path = "./data/city_dataset/" 220 | dst = cityscapesLoader(local_path, is_transform=True, augmentations=augmentations) 221 | bs = 4 222 | trainloader = data.DataLoader(dst, batch_size=bs, num_workers=0) 223 | for i, data in enumerate(trainloader): 224 | imgs, labels = data 225 | imgs = imgs.numpy()[:, ::-1, :, :] 226 | imgs = np.transpose(imgs, [0, 2, 3, 1]) 227 | f, axarr = plt.subplots(bs, 2) 228 | for j in range(bs): 229 | axarr[j][0].imshow(imgs[j]) 230 | axarr[j][1].imshow(dst.decode_segmap(labels.numpy()[j])) 231 | plt.show() 232 | a = raw_input() 233 | if a == "ex": 234 | break 235 | else: 236 | plt.close() 237 | ''' 238 | -------------------------------------------------------------------------------- /data/dataset_processing.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | from PIL import Image 4 | from torch.utils.data.dataset import Dataset 5 | import numpy as np 6 | import pickle 7 | import random 8 | from PIL import ImageOps, ImageFilter 9 | 10 | class DatasetProcessing(Dataset): 11 | def __init__(self, data_path, img_path, img_filename, label_filename, transform=None, train=False): 12 | self.img_path = os.path.join(data_path, img_path) 13 | self.transform = transform 14 | # reading img file from file 15 | img_filepath = os.path.join(data_path, img_filename) 16 | fp = open(img_filepath, 'r') 17 | self.img_filename = [x.strip() for x in fp] 18 | fp.close() 19 | # reading labels from file 20 | label_filepath = os.path.join(data_path, label_filename) 21 | labels = np.loadtxt(label_filepath, dtype=np.int64) 22 | self.label = labels 23 | self.train = train 24 | 25 | def __getitem__(self, index): 26 | img = Image.open(os.path.join(self.img_path, self.img_filename[index])) 27 | img = img.convert('RGB') 28 | #if self.transform is not None: 29 | if self.train: 30 | img1, img2 = self.transform(img) 31 | else: 32 | img = self.transform(img) 33 | 34 | label = torch.from_numpy(self.label[index]) 35 | label = label.type(torch.FloatTensor) 36 | if self.train: 37 | return (img1, img2), label 38 | else: 39 | return img, label 40 | 41 | def __len__(self): 42 | return len(self.img_filename) 43 | 44 | def split_idxs(pkl_file, percent): 45 | 46 | train_ids = pickle.load(open(pkl_file, 'rb')) 47 | partial_size = int(percent*len(train_ids)) 48 | 49 | labeled_idxs = train_ids[:partial_size] 50 | unlabeled_idxs = train_ids[partial_size:] 51 | 52 | return labeled_idxs, unlabeled_idxs 53 | 54 | class GaussianBlur(object): 55 | def __init__(self, sigma=[.1, 2.]): 56 | self.sigma = sigma 57 | 58 | def __call__(self, x): 59 | sigma = random.uniform(self.sigma[0], self.sigma[1]) 60 | x = x.filter(ImageFilter.GaussianBlur(radius=sigma)) 61 | return x 62 | 63 | 
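# TransformTwice (below) returns two independently augmented views of the same
# image; in Mean Teacher style training (cf. train_mlmt.py) one view typically
# goes to the student and the other to the EMA teacher, whose weights are
# updated with update_ema_variables(). A minimal usage sketch, with hypothetical
# names (data_path, img_path, img_file, label_file, base_tf, aug_tf, student,
# teacher) used purely for illustration:
#
#   labeled_idxs, unlabeled_idxs = split_idxs('./splits/voc/split_0.pkl', 0.125)
#   dataset = DatasetProcessing(data_path, img_path, img_file, label_file,
#                               transform=TransformTwice(base_tf, aug_tf), train=True)
#   (view1, view2), target = dataset[0]
#   ...
#   update_ema_variables(student, teacher, alpha=0.999, global_step=step)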
class TransformTwice: 64 | def __init__(self, transform, aug_transform): 65 | self.transform = transform 66 | self.aug_transform = aug_transform 67 | 68 | def __call__(self, inp): 69 | out1 = self.transform(inp) 70 | out2 = self.aug_transform(inp) 71 | return out1, out2 72 | 73 | def update_ema_variables(model, ema_model, alpha, global_step): 74 | alpha = min(1 - 1 / (global_step + 1), alpha) 75 | for ema_param, param in zip(ema_model.parameters(), model.parameters()): 76 | ema_param.data.mul_(alpha).add_(1 - alpha, param.data) 77 | -------------------------------------------------------------------------------- /data/pcontext_loader.py: -------------------------------------------------------------------------------- 1 | ########################################################################### 2 | # Created by: Hang Zhang 3 | # Email: zhang.hang@rutgers.edu 4 | # Copyright (c) 2017 5 | ########################################################################### 6 | 7 | from PIL import Image, ImageOps, ImageFilter 8 | import os 9 | import math 10 | import random 11 | import numpy as np 12 | from tqdm import trange 13 | 14 | import torch 15 | from .base import BaseDataset 16 | 17 | import matplotlib 18 | matplotlib.use('agg') 19 | import matplotlib.pyplot as plt 20 | 21 | class ContextSegmentation(BaseDataset): 22 | BASE_DIR = 'pcontext_dataset' 23 | NUM_CLASS = 60 24 | def __init__(self, root, split='train', 25 | mode=None, transform=None, target_transform=None, **kwargs): 26 | super(ContextSegmentation, self).__init__( 27 | root, split, mode, transform, target_transform, **kwargs) 28 | from detail import Detail 29 | #from detail import mask 30 | root = os.path.join(root, self.BASE_DIR) 31 | annFile = os.path.join(root, 'trainval_merged.json') 32 | imgDir = os.path.join(root, 'JPEGImages') 33 | # training mode 34 | self.detail = Detail(annFile, imgDir, split) 35 | self.transform = transform 36 | self.target_transform = target_transform 37 | self.ids = self.detail.getImgs() 38 | # generate masks 39 | self._mapping = np.sort(np.array([ 40 | 0, 2, 259, 260, 415, 324, 9, 258, 144, 18, 19, 22, 41 | 23, 397, 25, 284, 158, 159, 416, 33, 162, 420, 454, 295, 296, 42 | 427, 44, 45, 46, 308, 59, 440, 445, 31, 232, 65, 354, 424, 43 | 68, 326, 72, 458, 34, 207, 80, 355, 85, 347, 220, 349, 360, 44 | 98, 187, 104, 105, 366, 189, 368, 113, 115])) 45 | self._key = np.array(range(len(self._mapping))).astype('uint8') 46 | mask_file = os.path.join(root, self.split+'.pth') 47 | print('mask_file:', mask_file) 48 | if os.path.exists(mask_file): 49 | self.masks = torch.load(mask_file) 50 | else: 51 | self.masks = self._preprocess(mask_file) 52 | 53 | def _class_to_index(self, mask): 54 | # assert the values 55 | values= np.unique(mask) 56 | for i in range(len(values)): 57 | assert(values[i] in self._mapping) 58 | index = np.digitize(mask.ravel(), self._mapping, right=True) 59 | return self._key[index].reshape(mask.shape) 60 | 61 | def _preprocess(self, mask_file): 62 | masks = {} 63 | tbar = trange(len(self.ids)) 64 | print("Preprocessing mask, this will take a while." 
+ \ 65 | "But don't worry, it only run once for each split.") 66 | for i in tbar: 67 | img_id = self.ids[i] 68 | mask = Image.fromarray(self._class_to_index( 69 | self.detail.getMask(img_id))) 70 | masks[img_id['image_id']] = mask 71 | tbar.set_description("Preprocessing masks {}".format(img_id['image_id'])) 72 | torch.save(masks, mask_file) 73 | return masks 74 | 75 | def __getitem__(self, index): 76 | img_id = self.ids[index] 77 | path = img_id['file_name'] 78 | iid = img_id['image_id'] 79 | img = Image.open(os.path.join(self.detail.img_folder, path)).convert('RGB') 80 | 81 | size = img.size 82 | if self.mode == 'test': 83 | if self.transform is not None: 84 | img = self.transform(img) 85 | return img, os.path.basename(path) 86 | # convert mask to 60 categories 87 | mask = self.masks[iid] 88 | # synchrosized transform 89 | if self.mode == 'train': 90 | img, mask = self._sync_transform(img, mask) 91 | elif self.mode == 'val': 92 | img, mask = self._val_sync_transform(img, mask) 93 | else: 94 | assert self.mode == 'testval' 95 | mask = self._mask_transform(mask) 96 | # general resize, normalize and toTensori 97 | #mask.save('./pc_gt_train/'+img_id['file_name']) 98 | 99 | 100 | if self.transform is not None: 101 | img = self.transform(img) 102 | if self.target_transform is not None: 103 | mask = self.target_transform(mask) 104 | return img, mask, size , size , size#, img_id['file_name'] 105 | 106 | def _mask_transform(self, mask): 107 | #target = np.array(mask).astype('int32') - 1 # for 59 classes 108 | target = np.array(mask).astype('int32') # for 60 classes 109 | 110 | return torch.from_numpy(target).long() 111 | 112 | def __len__(self): 113 | return len(self.ids) 114 | 115 | @property 116 | def pred_offset(self): 117 | return 1 118 | -------------------------------------------------------------------------------- /data/voc_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import numpy as np 4 | import random 5 | #import matplotlib.pyplot as plt 6 | import collections 7 | import torch 8 | import torchvision 9 | import cv2 10 | from torch.utils import data 11 | from PIL import Image 12 | 13 | class VOCDataSet(data.Dataset): 14 | def __init__(self, root, list_path, max_iters=None, crop_size=(321, 321), mean=(128, 128, 128), scale=True, mirror=True, ignore_label=255): 15 | self.root = root 16 | self.list_path = list_path 17 | self.crop_h, self.crop_w = crop_size 18 | self.scale = scale 19 | self.ignore_label = ignore_label 20 | self.mean = mean 21 | self.is_mirror = mirror 22 | self.img_ids = [i_id.strip() for i_id in open(list_path)] 23 | if not max_iters==None: 24 | self.img_ids = self.img_ids * int(np.ceil(float(max_iters) / len(self.img_ids))) 25 | self.files = [] 26 | # for split in ["train", "trainval", "val"]: 27 | for name in self.img_ids: 28 | img_file = osp.join(self.root, "JPEGImages/%s.jpg" % name) 29 | label_file = osp.join(self.root, "SegmentationClassAug/%s.png" % name) 30 | self.files.append({ 31 | "img": img_file, 32 | "label": label_file, 33 | "name": name 34 | }) 35 | 36 | def __len__(self): 37 | return len(self.files) 38 | 39 | def generate_scale_label(self, image, label): 40 | f_scale = 0.5 + random.randint(0, 11) / 10.0 41 | image = cv2.resize(image, None, fx=f_scale, fy=f_scale, interpolation = cv2.INTER_LINEAR) 42 | label = cv2.resize(label, None, fx=f_scale, fy=f_scale, interpolation = cv2.INTER_NEAREST) 43 | return image, label 44 | 45 | def __getitem__(self, index): 46 | 
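        # Load one image/label pair and apply the training-time augmentations
        # implemented below: optional random rescaling (factor 0.5-1.6), mean
        # subtraction, padding up to the crop size (label padded with
        # ignore_label), a random crop, channel reversal, HWC->CHW transpose,
        # and an optional horizontal flip applied jointly to image and label.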
datafiles = self.files[index] 47 | image = cv2.imread(datafiles["img"], cv2.IMREAD_COLOR) 48 | label = cv2.imread(datafiles["label"], cv2.IMREAD_GRAYSCALE) 49 | size = image.shape 50 | name = datafiles["name"] 51 | if self.scale: 52 | image, label = self.generate_scale_label(image, label) 53 | image = np.asarray(image, np.float32) 54 | image -= self.mean 55 | img_h, img_w = label.shape 56 | pad_h = max(self.crop_h - img_h, 0) 57 | pad_w = max(self.crop_w - img_w, 0) 58 | if pad_h > 0 or pad_w > 0: 59 | img_pad = cv2.copyMakeBorder(image, 0, pad_h, 0, 60 | pad_w, cv2.BORDER_CONSTANT, 61 | value=(0.0, 0.0, 0.0)) 62 | label_pad = cv2.copyMakeBorder(label, 0, pad_h, 0, 63 | pad_w, cv2.BORDER_CONSTANT, 64 | value=(self.ignore_label,)) 65 | else: 66 | img_pad, label_pad = image, label 67 | 68 | img_h, img_w = label_pad.shape 69 | h_off = random.randint(0, img_h - self.crop_h) 70 | w_off = random.randint(0, img_w - self.crop_w) 71 | image = np.asarray(img_pad[h_off : h_off+self.crop_h, w_off : w_off+self.crop_w], np.float32) 72 | label = np.asarray(label_pad[h_off : h_off+self.crop_h, w_off : w_off+self.crop_w], np.float32) 73 | image = image[:, :, ::-1] # change to BGR 74 | image = image.transpose((2, 0, 1)) 75 | if self.is_mirror: 76 | flip = np.random.choice(2) * 2 - 1 77 | image = image[:, :, ::flip] 78 | label = label[:, ::flip] 79 | 80 | return image.copy(), label.copy(), np.array(size), name, index 81 | 82 | 83 | class VOCGTDataSet(data.Dataset): 84 | def __init__(self, root, list_path, max_iters=None, crop_size=(321, 321), mean=(128, 128, 128), scale=True, mirror=True, ignore_label=255): 85 | self.root = root 86 | self.list_path = list_path 87 | self.crop_size = crop_size 88 | self.crop_h, self.crop_w = crop_size 89 | self.scale = scale 90 | self.ignore_label = ignore_label 91 | self.mean = mean 92 | self.is_mirror = mirror 93 | self.img_ids = [i_id.strip() for i_id in open(list_path)] 94 | if not max_iters==None: 95 | self.img_ids = self.img_ids * int(np.ceil(float(max_iters) / len(self.img_ids))) 96 | self.files = [] 97 | for name in self.img_ids: 98 | img_file = osp.join(self.root, "JPEGImages/%s.jpg" % name) 99 | label_file = osp.join(self.root, "SegmentationClassAug/%s.png" % name) 100 | self.files.append({ 101 | "img": img_file, 102 | "label": label_file, 103 | "name": name 104 | }) 105 | 106 | def __len__(self): 107 | return len(self.files) 108 | 109 | def generate_scale_label(self, image, label): 110 | f_scale = 0.5 + random.randint(0, 11) / 10.0 111 | image = cv2.resize(image, None, fx=f_scale, fy=f_scale, interpolation = cv2.INTER_LINEAR) 112 | label = cv2.resize(label, None, fx=f_scale, fy=f_scale, interpolation = cv2.INTER_NEAREST) 113 | return image, label 114 | 115 | def __getitem__(self, index): 116 | datafiles = self.files[index] 117 | image = cv2.imread(datafiles["img"], cv2.IMREAD_COLOR) 118 | label = cv2.imread(datafiles["label"], cv2.IMREAD_GRAYSCALE) 119 | #print (label) 120 | size = image.shape 121 | name = datafiles["name"] 122 | 123 | attempt = 0 124 | while attempt < 10 : 125 | if self.scale: 126 | image, label = self.generate_scale_label(image, label) 127 | 128 | img_h, img_w = label.shape 129 | pad_h = max(self.crop_h - img_h, 0) 130 | pad_w = max(self.crop_w - img_w, 0) 131 | if pad_h > 0 or pad_w > 0: 132 | attempt += 1 133 | continue 134 | else: 135 | break 136 | 137 | if attempt == 10 : 138 | image = cv2.resize(image, self.crop_size, interpolation = cv2.INTER_LINEAR) 139 | label = cv2.resize(label, self.crop_size, interpolation = cv2.INTER_NEAREST) 140 | 141 
| 142 | image = np.asarray(image, np.float32) 143 | image -= self.mean 144 | 145 | img_h, img_w = label.shape 146 | h_off = random.randint(0, img_h - self.crop_h) 147 | w_off = random.randint(0, img_w - self.crop_w) 148 | image = np.asarray(image[h_off : h_off+self.crop_h, w_off : w_off+self.crop_w], np.float32) 149 | label = np.asarray(label[h_off : h_off+self.crop_h, w_off : w_off+self.crop_w], np.float32) 150 | image = image[:, :, ::-1] # change to BGR 151 | image = image.transpose((2, 0, 1)) 152 | if self.is_mirror: 153 | flip = np.random.choice(2) * 2 - 1 154 | image = image[:, :, ::flip] 155 | label = label[:, ::flip] 156 | 157 | return image.copy(), label.copy(), np.array(size), name 158 | 159 | class VOCDataTestSet(data.Dataset): 160 | def __init__(self, root, list_path, crop_size=(505, 505), mean=(128, 128, 128)): 161 | self.root = root 162 | self.list_path = list_path 163 | self.crop_h, self.crop_w = crop_size 164 | self.mean = mean 165 | self.img_ids = [i_id.strip() for i_id in open(list_path)] 166 | self.files = [] 167 | # for split in ["train", "trainval", "val"]: 168 | for name in self.img_ids: 169 | img_file = osp.join(self.root, "JPEGImages/%s.jpg" % name) 170 | self.files.append({ 171 | "img": img_file 172 | }) 173 | 174 | def __len__(self): 175 | return len(self.files) 176 | 177 | def __getitem__(self, index): 178 | datafiles = self.files[index] 179 | image = cv2.imread(datafiles["img"], cv2.IMREAD_COLOR) 180 | size = image.shape 181 | name = osp.splitext(osp.basename(datafiles["img"]))[0] 182 | image = np.asarray(image, np.float32) 183 | image -= self.mean 184 | 185 | img_h, img_w, _ = image.shape 186 | pad_h = max(self.crop_h - img_h, 0) 187 | pad_w = max(self.crop_w - img_w, 0) 188 | if pad_h > 0 or pad_w > 0: 189 | image = cv2.copyMakeBorder(image, 0, pad_h, 0, 190 | pad_w, cv2.BORDER_CONSTANT, 191 | value=(0.0, 0.0, 0.0)) 192 | image = image.transpose((2, 0, 1)) 193 | return image, name, size 194 | 195 | 196 | if __name__ == '__main__': 197 | dst = VOCDataSet("./data", is_transform=True) 198 | trainloader = data.DataLoader(dst, batch_size=4) 199 | for i, data in enumerate(trainloader): 200 | imgs, labels = data 201 | if i == 0: 202 | img = torchvision.utils.make_grid(imgs).numpy() 203 | img = np.transpose(img, (1, 2, 0)) 204 | img = img[:, :, ::-1] 205 | #plt.imshow(img) 206 | #plt.show() 207 | -------------------------------------------------------------------------------- /data/voc_list/val.txt: -------------------------------------------------------------------------------- 1 | 2007_000033 2 | 2007_000042 3 | 2007_000061 4 | 2007_000123 5 | 2007_000129 6 | 2007_000175 7 | 2007_000187 8 | 2007_000323 9 | 2007_000332 10 | 2007_000346 11 | 2007_000452 12 | 2007_000464 13 | 2007_000491 14 | 2007_000529 15 | 2007_000559 16 | 2007_000572 17 | 2007_000629 18 | 2007_000636 19 | 2007_000661 20 | 2007_000663 21 | 2007_000676 22 | 2007_000727 23 | 2007_000762 24 | 2007_000783 25 | 2007_000799 26 | 2007_000804 27 | 2007_000830 28 | 2007_000837 29 | 2007_000847 30 | 2007_000862 31 | 2007_000925 32 | 2007_000999 33 | 2007_001154 34 | 2007_001175 35 | 2007_001239 36 | 2007_001284 37 | 2007_001288 38 | 2007_001289 39 | 2007_001299 40 | 2007_001311 41 | 2007_001321 42 | 2007_001377 43 | 2007_001408 44 | 2007_001423 45 | 2007_001430 46 | 2007_001457 47 | 2007_001458 48 | 2007_001526 49 | 2007_001568 50 | 2007_001585 51 | 2007_001586 52 | 2007_001587 53 | 2007_001594 54 | 2007_001630 55 | 2007_001677 56 | 2007_001678 57 | 2007_001717 58 | 2007_001733 59 | 2007_001761 60 | 
2007_001763 61 | 2007_001774 62 | 2007_001884 63 | 2007_001955 64 | 2007_002046 65 | 2007_002094 66 | 2007_002119 67 | 2007_002132 68 | 2007_002260 69 | 2007_002266 70 | 2007_002268 71 | 2007_002284 72 | 2007_002376 73 | 2007_002378 74 | 2007_002387 75 | 2007_002400 76 | 2007_002412 77 | 2007_002426 78 | 2007_002427 79 | 2007_002445 80 | 2007_002470 81 | 2007_002539 82 | 2007_002565 83 | 2007_002597 84 | 2007_002618 85 | 2007_002619 86 | 2007_002624 87 | 2007_002643 88 | 2007_002648 89 | 2007_002719 90 | 2007_002728 91 | 2007_002823 92 | 2007_002824 93 | 2007_002852 94 | 2007_002903 95 | 2007_003011 96 | 2007_003020 97 | 2007_003022 98 | 2007_003051 99 | 2007_003088 100 | 2007_003101 101 | 2007_003106 102 | 2007_003110 103 | 2007_003131 104 | 2007_003134 105 | 2007_003137 106 | 2007_003143 107 | 2007_003169 108 | 2007_003188 109 | 2007_003194 110 | 2007_003195 111 | 2007_003201 112 | 2007_003349 113 | 2007_003367 114 | 2007_003373 115 | 2007_003499 116 | 2007_003503 117 | 2007_003506 118 | 2007_003530 119 | 2007_003571 120 | 2007_003587 121 | 2007_003611 122 | 2007_003621 123 | 2007_003682 124 | 2007_003711 125 | 2007_003714 126 | 2007_003742 127 | 2007_003786 128 | 2007_003841 129 | 2007_003848 130 | 2007_003861 131 | 2007_003872 132 | 2007_003917 133 | 2007_003957 134 | 2007_003991 135 | 2007_004033 136 | 2007_004052 137 | 2007_004112 138 | 2007_004121 139 | 2007_004143 140 | 2007_004189 141 | 2007_004190 142 | 2007_004193 143 | 2007_004241 144 | 2007_004275 145 | 2007_004281 146 | 2007_004380 147 | 2007_004392 148 | 2007_004405 149 | 2007_004468 150 | 2007_004483 151 | 2007_004510 152 | 2007_004538 153 | 2007_004558 154 | 2007_004644 155 | 2007_004649 156 | 2007_004712 157 | 2007_004722 158 | 2007_004856 159 | 2007_004866 160 | 2007_004902 161 | 2007_004969 162 | 2007_005058 163 | 2007_005074 164 | 2007_005107 165 | 2007_005114 166 | 2007_005149 167 | 2007_005173 168 | 2007_005281 169 | 2007_005294 170 | 2007_005296 171 | 2007_005304 172 | 2007_005331 173 | 2007_005354 174 | 2007_005358 175 | 2007_005428 176 | 2007_005460 177 | 2007_005469 178 | 2007_005509 179 | 2007_005547 180 | 2007_005600 181 | 2007_005608 182 | 2007_005626 183 | 2007_005689 184 | 2007_005696 185 | 2007_005705 186 | 2007_005759 187 | 2007_005803 188 | 2007_005813 189 | 2007_005828 190 | 2007_005844 191 | 2007_005845 192 | 2007_005857 193 | 2007_005911 194 | 2007_005915 195 | 2007_005978 196 | 2007_006028 197 | 2007_006035 198 | 2007_006046 199 | 2007_006076 200 | 2007_006086 201 | 2007_006117 202 | 2007_006171 203 | 2007_006241 204 | 2007_006260 205 | 2007_006277 206 | 2007_006348 207 | 2007_006364 208 | 2007_006373 209 | 2007_006444 210 | 2007_006449 211 | 2007_006549 212 | 2007_006553 213 | 2007_006560 214 | 2007_006647 215 | 2007_006678 216 | 2007_006680 217 | 2007_006698 218 | 2007_006761 219 | 2007_006802 220 | 2007_006837 221 | 2007_006841 222 | 2007_006864 223 | 2007_006866 224 | 2007_006946 225 | 2007_007007 226 | 2007_007084 227 | 2007_007109 228 | 2007_007130 229 | 2007_007165 230 | 2007_007168 231 | 2007_007195 232 | 2007_007196 233 | 2007_007203 234 | 2007_007211 235 | 2007_007235 236 | 2007_007341 237 | 2007_007414 238 | 2007_007417 239 | 2007_007470 240 | 2007_007477 241 | 2007_007493 242 | 2007_007498 243 | 2007_007524 244 | 2007_007534 245 | 2007_007624 246 | 2007_007651 247 | 2007_007688 248 | 2007_007748 249 | 2007_007795 250 | 2007_007810 251 | 2007_007815 252 | 2007_007818 253 | 2007_007836 254 | 2007_007849 255 | 2007_007881 256 | 2007_007996 257 | 2007_008051 258 | 2007_008084 259 | 2007_008106 
260 | 2007_008110 261 | 2007_008204 262 | 2007_008222 263 | 2007_008256 264 | 2007_008260 265 | 2007_008339 266 | 2007_008374 267 | 2007_008415 268 | 2007_008430 269 | 2007_008543 270 | 2007_008547 271 | 2007_008596 272 | 2007_008645 273 | 2007_008670 274 | 2007_008708 275 | 2007_008722 276 | 2007_008747 277 | 2007_008802 278 | 2007_008815 279 | 2007_008897 280 | 2007_008944 281 | 2007_008964 282 | 2007_008973 283 | 2007_008980 284 | 2007_009015 285 | 2007_009068 286 | 2007_009084 287 | 2007_009088 288 | 2007_009096 289 | 2007_009221 290 | 2007_009245 291 | 2007_009251 292 | 2007_009252 293 | 2007_009258 294 | 2007_009320 295 | 2007_009323 296 | 2007_009331 297 | 2007_009346 298 | 2007_009392 299 | 2007_009413 300 | 2007_009419 301 | 2007_009446 302 | 2007_009458 303 | 2007_009521 304 | 2007_009562 305 | 2007_009592 306 | 2007_009654 307 | 2007_009655 308 | 2007_009684 309 | 2007_009687 310 | 2007_009691 311 | 2007_009706 312 | 2007_009750 313 | 2007_009756 314 | 2007_009764 315 | 2007_009794 316 | 2007_009817 317 | 2007_009841 318 | 2007_009897 319 | 2007_009911 320 | 2007_009923 321 | 2007_009938 322 | 2008_000009 323 | 2008_000016 324 | 2008_000073 325 | 2008_000075 326 | 2008_000080 327 | 2008_000107 328 | 2008_000120 329 | 2008_000123 330 | 2008_000149 331 | 2008_000182 332 | 2008_000213 333 | 2008_000215 334 | 2008_000223 335 | 2008_000233 336 | 2008_000234 337 | 2008_000239 338 | 2008_000254 339 | 2008_000270 340 | 2008_000271 341 | 2008_000345 342 | 2008_000359 343 | 2008_000391 344 | 2008_000401 345 | 2008_000464 346 | 2008_000469 347 | 2008_000474 348 | 2008_000501 349 | 2008_000510 350 | 2008_000533 351 | 2008_000573 352 | 2008_000589 353 | 2008_000602 354 | 2008_000630 355 | 2008_000657 356 | 2008_000661 357 | 2008_000662 358 | 2008_000666 359 | 2008_000673 360 | 2008_000700 361 | 2008_000725 362 | 2008_000731 363 | 2008_000763 364 | 2008_000765 365 | 2008_000782 366 | 2008_000795 367 | 2008_000811 368 | 2008_000848 369 | 2008_000853 370 | 2008_000863 371 | 2008_000911 372 | 2008_000919 373 | 2008_000943 374 | 2008_000992 375 | 2008_001013 376 | 2008_001028 377 | 2008_001040 378 | 2008_001070 379 | 2008_001074 380 | 2008_001076 381 | 2008_001078 382 | 2008_001135 383 | 2008_001150 384 | 2008_001170 385 | 2008_001231 386 | 2008_001249 387 | 2008_001260 388 | 2008_001283 389 | 2008_001308 390 | 2008_001379 391 | 2008_001404 392 | 2008_001433 393 | 2008_001439 394 | 2008_001478 395 | 2008_001491 396 | 2008_001504 397 | 2008_001513 398 | 2008_001514 399 | 2008_001531 400 | 2008_001546 401 | 2008_001547 402 | 2008_001580 403 | 2008_001629 404 | 2008_001640 405 | 2008_001682 406 | 2008_001688 407 | 2008_001715 408 | 2008_001821 409 | 2008_001874 410 | 2008_001885 411 | 2008_001895 412 | 2008_001966 413 | 2008_001971 414 | 2008_001992 415 | 2008_002043 416 | 2008_002152 417 | 2008_002205 418 | 2008_002212 419 | 2008_002239 420 | 2008_002240 421 | 2008_002241 422 | 2008_002269 423 | 2008_002273 424 | 2008_002358 425 | 2008_002379 426 | 2008_002383 427 | 2008_002429 428 | 2008_002464 429 | 2008_002467 430 | 2008_002492 431 | 2008_002495 432 | 2008_002504 433 | 2008_002521 434 | 2008_002536 435 | 2008_002588 436 | 2008_002623 437 | 2008_002680 438 | 2008_002681 439 | 2008_002775 440 | 2008_002778 441 | 2008_002835 442 | 2008_002859 443 | 2008_002864 444 | 2008_002900 445 | 2008_002904 446 | 2008_002929 447 | 2008_002936 448 | 2008_002942 449 | 2008_002958 450 | 2008_003003 451 | 2008_003026 452 | 2008_003034 453 | 2008_003076 454 | 2008_003105 455 | 2008_003108 456 | 2008_003110 457 | 
2008_003135 458 | 2008_003141 459 | 2008_003155 460 | 2008_003210 461 | 2008_003238 462 | 2008_003270 463 | 2008_003330 464 | 2008_003333 465 | 2008_003369 466 | 2008_003379 467 | 2008_003451 468 | 2008_003461 469 | 2008_003477 470 | 2008_003492 471 | 2008_003499 472 | 2008_003511 473 | 2008_003546 474 | 2008_003576 475 | 2008_003577 476 | 2008_003676 477 | 2008_003709 478 | 2008_003733 479 | 2008_003777 480 | 2008_003782 481 | 2008_003821 482 | 2008_003846 483 | 2008_003856 484 | 2008_003858 485 | 2008_003874 486 | 2008_003876 487 | 2008_003885 488 | 2008_003886 489 | 2008_003926 490 | 2008_003976 491 | 2008_004069 492 | 2008_004101 493 | 2008_004140 494 | 2008_004172 495 | 2008_004175 496 | 2008_004212 497 | 2008_004279 498 | 2008_004339 499 | 2008_004345 500 | 2008_004363 501 | 2008_004367 502 | 2008_004396 503 | 2008_004399 504 | 2008_004453 505 | 2008_004477 506 | 2008_004552 507 | 2008_004562 508 | 2008_004575 509 | 2008_004610 510 | 2008_004612 511 | 2008_004621 512 | 2008_004624 513 | 2008_004654 514 | 2008_004659 515 | 2008_004687 516 | 2008_004701 517 | 2008_004704 518 | 2008_004705 519 | 2008_004754 520 | 2008_004758 521 | 2008_004854 522 | 2008_004910 523 | 2008_004995 524 | 2008_005049 525 | 2008_005089 526 | 2008_005097 527 | 2008_005105 528 | 2008_005145 529 | 2008_005197 530 | 2008_005217 531 | 2008_005242 532 | 2008_005245 533 | 2008_005254 534 | 2008_005262 535 | 2008_005338 536 | 2008_005398 537 | 2008_005399 538 | 2008_005422 539 | 2008_005439 540 | 2008_005445 541 | 2008_005525 542 | 2008_005544 543 | 2008_005628 544 | 2008_005633 545 | 2008_005637 546 | 2008_005642 547 | 2008_005676 548 | 2008_005680 549 | 2008_005691 550 | 2008_005727 551 | 2008_005738 552 | 2008_005812 553 | 2008_005904 554 | 2008_005915 555 | 2008_006008 556 | 2008_006036 557 | 2008_006055 558 | 2008_006063 559 | 2008_006108 560 | 2008_006130 561 | 2008_006143 562 | 2008_006159 563 | 2008_006216 564 | 2008_006219 565 | 2008_006229 566 | 2008_006254 567 | 2008_006275 568 | 2008_006325 569 | 2008_006327 570 | 2008_006341 571 | 2008_006408 572 | 2008_006480 573 | 2008_006523 574 | 2008_006526 575 | 2008_006528 576 | 2008_006553 577 | 2008_006554 578 | 2008_006703 579 | 2008_006722 580 | 2008_006752 581 | 2008_006784 582 | 2008_006835 583 | 2008_006874 584 | 2008_006981 585 | 2008_006986 586 | 2008_007025 587 | 2008_007031 588 | 2008_007048 589 | 2008_007120 590 | 2008_007123 591 | 2008_007143 592 | 2008_007194 593 | 2008_007219 594 | 2008_007273 595 | 2008_007350 596 | 2008_007378 597 | 2008_007392 598 | 2008_007402 599 | 2008_007497 600 | 2008_007498 601 | 2008_007507 602 | 2008_007513 603 | 2008_007527 604 | 2008_007548 605 | 2008_007596 606 | 2008_007677 607 | 2008_007737 608 | 2008_007797 609 | 2008_007804 610 | 2008_007811 611 | 2008_007814 612 | 2008_007828 613 | 2008_007836 614 | 2008_007945 615 | 2008_007994 616 | 2008_008051 617 | 2008_008103 618 | 2008_008127 619 | 2008_008221 620 | 2008_008252 621 | 2008_008268 622 | 2008_008296 623 | 2008_008301 624 | 2008_008335 625 | 2008_008362 626 | 2008_008392 627 | 2008_008393 628 | 2008_008421 629 | 2008_008434 630 | 2008_008469 631 | 2008_008629 632 | 2008_008682 633 | 2008_008711 634 | 2008_008746 635 | 2009_000012 636 | 2009_000013 637 | 2009_000022 638 | 2009_000032 639 | 2009_000037 640 | 2009_000039 641 | 2009_000074 642 | 2009_000080 643 | 2009_000087 644 | 2009_000096 645 | 2009_000121 646 | 2009_000136 647 | 2009_000149 648 | 2009_000156 649 | 2009_000201 650 | 2009_000205 651 | 2009_000219 652 | 2009_000242 653 | 2009_000309 654 | 
2009_000318 655 | 2009_000335 656 | 2009_000351 657 | 2009_000354 658 | 2009_000387 659 | 2009_000391 660 | 2009_000412 661 | 2009_000418 662 | 2009_000421 663 | 2009_000426 664 | 2009_000440 665 | 2009_000446 666 | 2009_000455 667 | 2009_000457 668 | 2009_000469 669 | 2009_000487 670 | 2009_000488 671 | 2009_000523 672 | 2009_000573 673 | 2009_000619 674 | 2009_000628 675 | 2009_000641 676 | 2009_000664 677 | 2009_000675 678 | 2009_000704 679 | 2009_000705 680 | 2009_000712 681 | 2009_000716 682 | 2009_000723 683 | 2009_000727 684 | 2009_000730 685 | 2009_000731 686 | 2009_000732 687 | 2009_000771 688 | 2009_000825 689 | 2009_000828 690 | 2009_000839 691 | 2009_000840 692 | 2009_000845 693 | 2009_000879 694 | 2009_000892 695 | 2009_000919 696 | 2009_000924 697 | 2009_000931 698 | 2009_000935 699 | 2009_000964 700 | 2009_000989 701 | 2009_000991 702 | 2009_000998 703 | 2009_001008 704 | 2009_001082 705 | 2009_001108 706 | 2009_001160 707 | 2009_001215 708 | 2009_001240 709 | 2009_001255 710 | 2009_001278 711 | 2009_001299 712 | 2009_001300 713 | 2009_001314 714 | 2009_001332 715 | 2009_001333 716 | 2009_001363 717 | 2009_001391 718 | 2009_001411 719 | 2009_001433 720 | 2009_001505 721 | 2009_001535 722 | 2009_001536 723 | 2009_001565 724 | 2009_001607 725 | 2009_001644 726 | 2009_001663 727 | 2009_001683 728 | 2009_001684 729 | 2009_001687 730 | 2009_001718 731 | 2009_001731 732 | 2009_001765 733 | 2009_001768 734 | 2009_001775 735 | 2009_001804 736 | 2009_001816 737 | 2009_001818 738 | 2009_001850 739 | 2009_001851 740 | 2009_001854 741 | 2009_001941 742 | 2009_001991 743 | 2009_002012 744 | 2009_002035 745 | 2009_002042 746 | 2009_002082 747 | 2009_002094 748 | 2009_002097 749 | 2009_002122 750 | 2009_002150 751 | 2009_002155 752 | 2009_002164 753 | 2009_002165 754 | 2009_002171 755 | 2009_002185 756 | 2009_002202 757 | 2009_002221 758 | 2009_002238 759 | 2009_002239 760 | 2009_002265 761 | 2009_002268 762 | 2009_002291 763 | 2009_002295 764 | 2009_002317 765 | 2009_002320 766 | 2009_002346 767 | 2009_002366 768 | 2009_002372 769 | 2009_002382 770 | 2009_002390 771 | 2009_002415 772 | 2009_002445 773 | 2009_002487 774 | 2009_002521 775 | 2009_002527 776 | 2009_002535 777 | 2009_002539 778 | 2009_002549 779 | 2009_002562 780 | 2009_002568 781 | 2009_002571 782 | 2009_002573 783 | 2009_002584 784 | 2009_002591 785 | 2009_002594 786 | 2009_002604 787 | 2009_002618 788 | 2009_002635 789 | 2009_002638 790 | 2009_002649 791 | 2009_002651 792 | 2009_002727 793 | 2009_002732 794 | 2009_002749 795 | 2009_002753 796 | 2009_002771 797 | 2009_002808 798 | 2009_002856 799 | 2009_002887 800 | 2009_002888 801 | 2009_002928 802 | 2009_002936 803 | 2009_002975 804 | 2009_002982 805 | 2009_002990 806 | 2009_003003 807 | 2009_003005 808 | 2009_003043 809 | 2009_003059 810 | 2009_003063 811 | 2009_003065 812 | 2009_003071 813 | 2009_003080 814 | 2009_003105 815 | 2009_003123 816 | 2009_003193 817 | 2009_003196 818 | 2009_003217 819 | 2009_003224 820 | 2009_003241 821 | 2009_003269 822 | 2009_003273 823 | 2009_003299 824 | 2009_003304 825 | 2009_003311 826 | 2009_003323 827 | 2009_003343 828 | 2009_003378 829 | 2009_003387 830 | 2009_003406 831 | 2009_003433 832 | 2009_003450 833 | 2009_003466 834 | 2009_003481 835 | 2009_003494 836 | 2009_003498 837 | 2009_003504 838 | 2009_003507 839 | 2009_003517 840 | 2009_003523 841 | 2009_003542 842 | 2009_003549 843 | 2009_003551 844 | 2009_003564 845 | 2009_003569 846 | 2009_003576 847 | 2009_003589 848 | 2009_003607 849 | 2009_003640 850 | 2009_003666 851 | 
2009_003696 852 | 2009_003703 853 | 2009_003707 854 | 2009_003756 855 | 2009_003771 856 | 2009_003773 857 | 2009_003804 858 | 2009_003806 859 | 2009_003810 860 | 2009_003849 861 | 2009_003857 862 | 2009_003858 863 | 2009_003895 864 | 2009_003903 865 | 2009_003904 866 | 2009_003928 867 | 2009_003938 868 | 2009_003971 869 | 2009_003991 870 | 2009_004021 871 | 2009_004033 872 | 2009_004043 873 | 2009_004070 874 | 2009_004072 875 | 2009_004084 876 | 2009_004099 877 | 2009_004125 878 | 2009_004140 879 | 2009_004217 880 | 2009_004221 881 | 2009_004247 882 | 2009_004248 883 | 2009_004255 884 | 2009_004298 885 | 2009_004324 886 | 2009_004455 887 | 2009_004494 888 | 2009_004497 889 | 2009_004504 890 | 2009_004507 891 | 2009_004509 892 | 2009_004540 893 | 2009_004568 894 | 2009_004579 895 | 2009_004581 896 | 2009_004590 897 | 2009_004592 898 | 2009_004594 899 | 2009_004635 900 | 2009_004653 901 | 2009_004687 902 | 2009_004721 903 | 2009_004730 904 | 2009_004732 905 | 2009_004738 906 | 2009_004748 907 | 2009_004789 908 | 2009_004799 909 | 2009_004801 910 | 2009_004848 911 | 2009_004859 912 | 2009_004867 913 | 2009_004882 914 | 2009_004886 915 | 2009_004895 916 | 2009_004942 917 | 2009_004969 918 | 2009_004987 919 | 2009_004993 920 | 2009_004994 921 | 2009_005038 922 | 2009_005078 923 | 2009_005087 924 | 2009_005089 925 | 2009_005137 926 | 2009_005148 927 | 2009_005156 928 | 2009_005158 929 | 2009_005189 930 | 2009_005190 931 | 2009_005217 932 | 2009_005219 933 | 2009_005220 934 | 2009_005231 935 | 2009_005260 936 | 2009_005262 937 | 2009_005302 938 | 2010_000003 939 | 2010_000038 940 | 2010_000065 941 | 2010_000083 942 | 2010_000084 943 | 2010_000087 944 | 2010_000110 945 | 2010_000159 946 | 2010_000160 947 | 2010_000163 948 | 2010_000174 949 | 2010_000216 950 | 2010_000238 951 | 2010_000241 952 | 2010_000256 953 | 2010_000272 954 | 2010_000284 955 | 2010_000309 956 | 2010_000318 957 | 2010_000330 958 | 2010_000335 959 | 2010_000342 960 | 2010_000372 961 | 2010_000422 962 | 2010_000426 963 | 2010_000427 964 | 2010_000502 965 | 2010_000530 966 | 2010_000552 967 | 2010_000559 968 | 2010_000572 969 | 2010_000573 970 | 2010_000622 971 | 2010_000628 972 | 2010_000639 973 | 2010_000666 974 | 2010_000679 975 | 2010_000682 976 | 2010_000683 977 | 2010_000724 978 | 2010_000738 979 | 2010_000764 980 | 2010_000788 981 | 2010_000814 982 | 2010_000836 983 | 2010_000874 984 | 2010_000904 985 | 2010_000906 986 | 2010_000907 987 | 2010_000918 988 | 2010_000929 989 | 2010_000941 990 | 2010_000952 991 | 2010_000961 992 | 2010_001000 993 | 2010_001010 994 | 2010_001011 995 | 2010_001016 996 | 2010_001017 997 | 2010_001024 998 | 2010_001036 999 | 2010_001061 1000 | 2010_001069 1001 | 2010_001070 1002 | 2010_001079 1003 | 2010_001104 1004 | 2010_001124 1005 | 2010_001149 1006 | 2010_001151 1007 | 2010_001174 1008 | 2010_001206 1009 | 2010_001246 1010 | 2010_001251 1011 | 2010_001256 1012 | 2010_001264 1013 | 2010_001292 1014 | 2010_001313 1015 | 2010_001327 1016 | 2010_001331 1017 | 2010_001351 1018 | 2010_001367 1019 | 2010_001376 1020 | 2010_001403 1021 | 2010_001448 1022 | 2010_001451 1023 | 2010_001522 1024 | 2010_001534 1025 | 2010_001553 1026 | 2010_001557 1027 | 2010_001563 1028 | 2010_001577 1029 | 2010_001579 1030 | 2010_001646 1031 | 2010_001656 1032 | 2010_001692 1033 | 2010_001699 1034 | 2010_001734 1035 | 2010_001752 1036 | 2010_001767 1037 | 2010_001768 1038 | 2010_001773 1039 | 2010_001820 1040 | 2010_001830 1041 | 2010_001851 1042 | 2010_001908 1043 | 2010_001913 1044 | 2010_001951 1045 | 2010_001956 1046 
| 2010_001962 1047 | 2010_001966 1048 | 2010_001995 1049 | 2010_002017 1050 | 2010_002025 1051 | 2010_002030 1052 | 2010_002106 1053 | 2010_002137 1054 | 2010_002142 1055 | 2010_002146 1056 | 2010_002147 1057 | 2010_002150 1058 | 2010_002161 1059 | 2010_002200 1060 | 2010_002228 1061 | 2010_002232 1062 | 2010_002251 1063 | 2010_002271 1064 | 2010_002305 1065 | 2010_002310 1066 | 2010_002336 1067 | 2010_002348 1068 | 2010_002361 1069 | 2010_002390 1070 | 2010_002396 1071 | 2010_002422 1072 | 2010_002450 1073 | 2010_002480 1074 | 2010_002512 1075 | 2010_002531 1076 | 2010_002536 1077 | 2010_002538 1078 | 2010_002546 1079 | 2010_002623 1080 | 2010_002682 1081 | 2010_002691 1082 | 2010_002693 1083 | 2010_002701 1084 | 2010_002763 1085 | 2010_002792 1086 | 2010_002868 1087 | 2010_002900 1088 | 2010_002902 1089 | 2010_002921 1090 | 2010_002929 1091 | 2010_002939 1092 | 2010_002988 1093 | 2010_003014 1094 | 2010_003060 1095 | 2010_003123 1096 | 2010_003127 1097 | 2010_003132 1098 | 2010_003168 1099 | 2010_003183 1100 | 2010_003187 1101 | 2010_003207 1102 | 2010_003231 1103 | 2010_003239 1104 | 2010_003275 1105 | 2010_003276 1106 | 2010_003293 1107 | 2010_003302 1108 | 2010_003325 1109 | 2010_003362 1110 | 2010_003365 1111 | 2010_003381 1112 | 2010_003402 1113 | 2010_003409 1114 | 2010_003418 1115 | 2010_003446 1116 | 2010_003453 1117 | 2010_003468 1118 | 2010_003473 1119 | 2010_003495 1120 | 2010_003506 1121 | 2010_003514 1122 | 2010_003531 1123 | 2010_003532 1124 | 2010_003541 1125 | 2010_003547 1126 | 2010_003597 1127 | 2010_003675 1128 | 2010_003708 1129 | 2010_003716 1130 | 2010_003746 1131 | 2010_003758 1132 | 2010_003764 1133 | 2010_003768 1134 | 2010_003771 1135 | 2010_003772 1136 | 2010_003781 1137 | 2010_003813 1138 | 2010_003820 1139 | 2010_003854 1140 | 2010_003912 1141 | 2010_003915 1142 | 2010_003947 1143 | 2010_003956 1144 | 2010_003971 1145 | 2010_004041 1146 | 2010_004042 1147 | 2010_004056 1148 | 2010_004063 1149 | 2010_004104 1150 | 2010_004120 1151 | 2010_004149 1152 | 2010_004165 1153 | 2010_004208 1154 | 2010_004219 1155 | 2010_004226 1156 | 2010_004314 1157 | 2010_004320 1158 | 2010_004322 1159 | 2010_004337 1160 | 2010_004348 1161 | 2010_004355 1162 | 2010_004369 1163 | 2010_004382 1164 | 2010_004419 1165 | 2010_004432 1166 | 2010_004472 1167 | 2010_004479 1168 | 2010_004519 1169 | 2010_004520 1170 | 2010_004529 1171 | 2010_004543 1172 | 2010_004550 1173 | 2010_004551 1174 | 2010_004556 1175 | 2010_004559 1176 | 2010_004628 1177 | 2010_004635 1178 | 2010_004662 1179 | 2010_004697 1180 | 2010_004757 1181 | 2010_004763 1182 | 2010_004772 1183 | 2010_004783 1184 | 2010_004789 1185 | 2010_004795 1186 | 2010_004815 1187 | 2010_004825 1188 | 2010_004828 1189 | 2010_004856 1190 | 2010_004857 1191 | 2010_004861 1192 | 2010_004941 1193 | 2010_004946 1194 | 2010_004951 1195 | 2010_004980 1196 | 2010_004994 1197 | 2010_005013 1198 | 2010_005021 1199 | 2010_005046 1200 | 2010_005063 1201 | 2010_005108 1202 | 2010_005118 1203 | 2010_005159 1204 | 2010_005160 1205 | 2010_005166 1206 | 2010_005174 1207 | 2010_005180 1208 | 2010_005187 1209 | 2010_005206 1210 | 2010_005245 1211 | 2010_005252 1212 | 2010_005284 1213 | 2010_005305 1214 | 2010_005344 1215 | 2010_005353 1216 | 2010_005366 1217 | 2010_005401 1218 | 2010_005421 1219 | 2010_005428 1220 | 2010_005432 1221 | 2010_005433 1222 | 2010_005496 1223 | 2010_005501 1224 | 2010_005508 1225 | 2010_005531 1226 | 2010_005534 1227 | 2010_005575 1228 | 2010_005582 1229 | 2010_005606 1230 | 2010_005626 1231 | 2010_005644 1232 | 2010_005664 1233 | 
2010_005705 1234 | 2010_005706 1235 | 2010_005709 1236 | 2010_005718 1237 | 2010_005719 1238 | 2010_005727 1239 | 2010_005762 1240 | 2010_005788 1241 | 2010_005860 1242 | 2010_005871 1243 | 2010_005877 1244 | 2010_005888 1245 | 2010_005899 1246 | 2010_005922 1247 | 2010_005991 1248 | 2010_005992 1249 | 2010_006026 1250 | 2010_006034 1251 | 2010_006054 1252 | 2010_006070 1253 | 2011_000045 1254 | 2011_000051 1255 | 2011_000054 1256 | 2011_000066 1257 | 2011_000070 1258 | 2011_000112 1259 | 2011_000173 1260 | 2011_000178 1261 | 2011_000185 1262 | 2011_000226 1263 | 2011_000234 1264 | 2011_000238 1265 | 2011_000239 1266 | 2011_000248 1267 | 2011_000283 1268 | 2011_000291 1269 | 2011_000310 1270 | 2011_000312 1271 | 2011_000338 1272 | 2011_000396 1273 | 2011_000412 1274 | 2011_000419 1275 | 2011_000435 1276 | 2011_000436 1277 | 2011_000438 1278 | 2011_000455 1279 | 2011_000456 1280 | 2011_000479 1281 | 2011_000481 1282 | 2011_000482 1283 | 2011_000503 1284 | 2011_000512 1285 | 2011_000521 1286 | 2011_000526 1287 | 2011_000536 1288 | 2011_000548 1289 | 2011_000566 1290 | 2011_000585 1291 | 2011_000598 1292 | 2011_000607 1293 | 2011_000618 1294 | 2011_000638 1295 | 2011_000658 1296 | 2011_000661 1297 | 2011_000669 1298 | 2011_000747 1299 | 2011_000780 1300 | 2011_000789 1301 | 2011_000807 1302 | 2011_000809 1303 | 2011_000813 1304 | 2011_000830 1305 | 2011_000843 1306 | 2011_000874 1307 | 2011_000888 1308 | 2011_000900 1309 | 2011_000912 1310 | 2011_000953 1311 | 2011_000969 1312 | 2011_001005 1313 | 2011_001014 1314 | 2011_001020 1315 | 2011_001047 1316 | 2011_001060 1317 | 2011_001064 1318 | 2011_001069 1319 | 2011_001071 1320 | 2011_001082 1321 | 2011_001110 1322 | 2011_001114 1323 | 2011_001159 1324 | 2011_001161 1325 | 2011_001190 1326 | 2011_001232 1327 | 2011_001263 1328 | 2011_001276 1329 | 2011_001281 1330 | 2011_001287 1331 | 2011_001292 1332 | 2011_001313 1333 | 2011_001341 1334 | 2011_001346 1335 | 2011_001350 1336 | 2011_001407 1337 | 2011_001416 1338 | 2011_001421 1339 | 2011_001434 1340 | 2011_001447 1341 | 2011_001489 1342 | 2011_001529 1343 | 2011_001530 1344 | 2011_001534 1345 | 2011_001546 1346 | 2011_001567 1347 | 2011_001589 1348 | 2011_001597 1349 | 2011_001601 1350 | 2011_001607 1351 | 2011_001613 1352 | 2011_001614 1353 | 2011_001619 1354 | 2011_001624 1355 | 2011_001642 1356 | 2011_001665 1357 | 2011_001669 1358 | 2011_001674 1359 | 2011_001708 1360 | 2011_001713 1361 | 2011_001714 1362 | 2011_001722 1363 | 2011_001726 1364 | 2011_001745 1365 | 2011_001748 1366 | 2011_001775 1367 | 2011_001782 1368 | 2011_001793 1369 | 2011_001794 1370 | 2011_001812 1371 | 2011_001862 1372 | 2011_001863 1373 | 2011_001868 1374 | 2011_001880 1375 | 2011_001910 1376 | 2011_001984 1377 | 2011_001988 1378 | 2011_002002 1379 | 2011_002040 1380 | 2011_002041 1381 | 2011_002064 1382 | 2011_002075 1383 | 2011_002098 1384 | 2011_002110 1385 | 2011_002121 1386 | 2011_002124 1387 | 2011_002150 1388 | 2011_002156 1389 | 2011_002178 1390 | 2011_002200 1391 | 2011_002223 1392 | 2011_002244 1393 | 2011_002247 1394 | 2011_002279 1395 | 2011_002295 1396 | 2011_002298 1397 | 2011_002308 1398 | 2011_002317 1399 | 2011_002322 1400 | 2011_002327 1401 | 2011_002343 1402 | 2011_002358 1403 | 2011_002371 1404 | 2011_002379 1405 | 2011_002391 1406 | 2011_002498 1407 | 2011_002509 1408 | 2011_002515 1409 | 2011_002532 1410 | 2011_002535 1411 | 2011_002548 1412 | 2011_002575 1413 | 2011_002578 1414 | 2011_002589 1415 | 2011_002592 1416 | 2011_002623 1417 | 2011_002641 1418 | 2011_002644 1419 | 2011_002662 1420 | 
2011_002675 1421 | 2011_002685 1422 | 2011_002713 1423 | 2011_002730 1424 | 2011_002754 1425 | 2011_002812 1426 | 2011_002863 1427 | 2011_002879 1428 | 2011_002885 1429 | 2011_002929 1430 | 2011_002951 1431 | 2011_002975 1432 | 2011_002993 1433 | 2011_002997 1434 | 2011_003003 1435 | 2011_003011 1436 | 2011_003019 1437 | 2011_003030 1438 | 2011_003055 1439 | 2011_003085 1440 | 2011_003103 1441 | 2011_003114 1442 | 2011_003145 1443 | 2011_003146 1444 | 2011_003182 1445 | 2011_003197 1446 | 2011_003205 1447 | 2011_003240 1448 | 2011_003256 1449 | 2011_003271 1450 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import scipy 3 | from scipy import ndimage 4 | import cv2 5 | import numpy as np 6 | import sys 7 | from collections import OrderedDict 8 | import os 9 | 10 | import torch 11 | import torch.nn as nn 12 | from torch.autograd import Variable 13 | import torchvision.models as models 14 | import torch.nn.functional as F 15 | from torch.utils import data, model_zoo 16 | import torch.backends.cudnn as cudnn 17 | 18 | from model.deeplabv2 import Res_Deeplab 19 | #from model.deeplabv3p import Res_Deeplab 20 | from data.voc_dataset import VOCDataSet 21 | from data import get_data_path, get_loader 22 | import torchvision.transforms as transform 23 | 24 | from PIL import Image 25 | import scipy.misc 26 | 27 | 28 | IMG_MEAN = np.array((104.00698793,116.66876762,122.67891434), dtype=np.float32) 29 | 30 | DATASET = 'pascal_voc' # pascal_context 31 | 32 | MODEL = 'deeplabv2' # deeeplabv2, deeplabv3p 33 | DATA_DIRECTORY = './data/voc_dataset/' 34 | DATA_LIST_PATH = './data/voc_list/val.txt' 35 | IGNORE_LABEL = 255 36 | NUM_CLASSES = 21 # 60 for pascal context 37 | RESTORE_FROM = '' 38 | PRETRAINED_MODEL = None 39 | SAVE_DIRECTORY = 'results' 40 | MLMT_FILE = './mlmt_output/output_ema_p_1_0_voc_5.txt' 41 | 42 | def get_arguments(): 43 | """Parse all the arguments provided from the CLI. 44 | 45 | Returns: 46 | A list of parsed arguments. 
47 | """ 48 | parser = argparse.ArgumentParser(description="VOC evaluation script") 49 | parser.add_argument("--model", type=str, default=MODEL, 50 | help="available options : DeepLab/DRN") 51 | parser.add_argument("--dataset", type=str, default=DATASET, 52 | help="dataset name pascal_voc or pascal_context") 53 | parser.add_argument("--data-dir", type=str, default=DATA_DIRECTORY, 54 | help="Path to the directory containing the PASCAL VOC dataset.") 55 | parser.add_argument("--data-list", type=str, default=DATA_LIST_PATH, 56 | help="Path to the file listing the images in the dataset.") 57 | parser.add_argument("--ignore-label", type=int, default=IGNORE_LABEL, 58 | help="The index of the label to ignore during the training.") 59 | parser.add_argument("--num-classes", type=int, default=NUM_CLASSES, 60 | help="Number of classes to predict (including background).") 61 | parser.add_argument("--restore-from", type=str, default=RESTORE_FROM, 62 | help="Where restore model parameters from.") 63 | parser.add_argument("--mlmt-file", type=str, default=MLMT_FILE, 64 | help="Where MLMT output") 65 | parser.add_argument("--save-dir", type=str, default=SAVE_DIRECTORY, 66 | help="Directory to store results") 67 | parser.add_argument("--gpu", type=int, default=0, 68 | help="choose gpu device.") 69 | parser.add_argument("--with-mlmt", action="store_true", 70 | help="combine with Multi-Label Mean Teacher branch") 71 | parser.add_argument("--save-output-images", action="store_true", 72 | help="save output images") 73 | return parser.parse_args() 74 | 75 | 76 | class VOCColorize(object): 77 | def __init__(self, n=22): 78 | self.cmap = color_map(22) 79 | self.cmap = torch.from_numpy(self.cmap[:n]) 80 | 81 | def __call__(self, gray_image): 82 | size = gray_image.shape 83 | color_image = np.zeros((3, size[0], size[1]), dtype=np.uint8) 84 | 85 | for label in range(0, len(self.cmap)): 86 | mask = (label == gray_image) 87 | color_image[0][mask] = self.cmap[label][0] 88 | color_image[1][mask] = self.cmap[label][1] 89 | color_image[2][mask] = self.cmap[label][2] 90 | 91 | # handle void 92 | mask = (255 == gray_image) 93 | color_image[0][mask] = color_image[1][mask] = color_image[2][mask] = 255 94 | 95 | return color_image 96 | 97 | def color_map(N=256, normalized=False): 98 | def bitget(byteval, idx): 99 | return ((byteval & (1 << idx)) != 0) 100 | 101 | dtype = 'float32' if normalized else 'uint8' 102 | cmap = np.zeros((N, 3), dtype=dtype) 103 | for i in range(N): 104 | r = g = b = 0 105 | c = i 106 | for j in range(8): 107 | r = r | (bitget(c, 0) << 7-j) 108 | g = g | (bitget(c, 1) << 7-j) 109 | b = b | (bitget(c, 2) << 7-j) 110 | c = c >> 3 111 | 112 | cmap[i] = np.array([r, g, b]) 113 | 114 | cmap = cmap/255 if normalized else cmap 115 | return cmap 116 | 117 | def get_label_vector(target, nclass): 118 | # target is a 3D Variable BxHxW, output is 2D BxnClass 119 | hist, _ = np.histogram(target, bins=nclass, range=(0, nclass-1)) 120 | vect = hist>0 121 | vect_out = np.zeros((21,1)) 122 | for i in range(len(vect)): 123 | if vect[i] == True: 124 | vect_out[i] = 1 125 | else: 126 | vect_out[i] = 0 127 | 128 | return vect_out 129 | 130 | def get_iou(args, data_list, class_num, save_path=None): 131 | from multiprocessing import Pool 132 | from utils.metric import ConfusionMatrix 133 | 134 | ConfM = ConfusionMatrix(class_num) 135 | f = ConfM.generateM 136 | pool = Pool() 137 | m_list = pool.map(f, data_list) 138 | pool.close() 139 | pool.join() 140 | 141 | for m in m_list: 142 | ConfM.addM(m) 143 | 144 | aveJ, j_list, M = 
ConfM.jaccard() 145 | 146 | if args.dataset == 'pascal_voc': 147 | classes = np.array(('background', # always index 0 148 | 'aeroplane', 'bicycle', 'bird', 'boat', 149 | 'bottle', 'bus', 'car', 'cat', 'chair', 150 | 'cow', 'diningtable', 'dog', 'horse', 151 | 'motorbike', 'person', 'pottedplant', 152 | 'sheep', 'sofa', 'train', 'tvmonitor')) 153 | elif args.dataset == 'pascal_context': 154 | classes = np.array(('background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'table', 'dog', 'horse', 'motorbike', 'person', 155 | 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor', 'bag', 'bed', 'bench', 'book', 'building', 'cabinet' , 'ceiling', 'cloth', 'computer', 'cup', 156 | 'door', 'fence', 'floor', 'flower', 'food', 'grass', 'ground', 'keyboard', 'light', 'mountain', 'mouse', 'curtain', 'platform', 'sign', 'plate', 157 | 'road', 'rock', 'shelves', 'sidewalk', 'sky', 'snow', 'bedclothes', 'track', 'tree', 'truck', 'wall', 'water', 'window', 'wood')) 158 | elif args.dataset == 'cityscapes': 159 | classes = np.array(("road", "sidewalk", 160 | "building", "wall", "fence", "pole", 161 | "traffic_light", "traffic_sign", "vegetation", 162 | "terrain", "sky", "person", "rider", 163 | "car", "truck", "bus", 164 | "train", "motorcycle", "bicycle")) 165 | 166 | for i, iou in enumerate(j_list): 167 | if j_list[i] > 0: 168 | print('class {:2d} {:12} IU {:.2f}'.format(i, classes[i], j_list[i])) 169 | 170 | print('meanIOU: ' + str(aveJ) + '\n') 171 | if save_path: 172 | with open(save_path, 'w') as f: 173 | for i, iou in enumerate(j_list): 174 | f.write('class {:2d} {:12} IU {:.2f}'.format(i, classes[i], j_list[i]) + '\n') 175 | f.write('meanIOU: ' + str(aveJ) + '\n') 176 | 177 | def main(): 178 | """Create the model and start the evaluation process.""" 179 | 180 | args = get_arguments() 181 | gpu0 = args.gpu 182 | 183 | if not os.path.exists(args.save_dir): 184 | os.makedirs(args.save_dir) 185 | 186 | model = Res_Deeplab(num_classes=args.num_classes) 187 | model.cuda() 188 | 189 | model = torch.nn.DataParallel(model).cuda() 190 | cudnn.benchmark = True 191 | 192 | if args.restore_from[:4] == 'http' : 193 | saved_state_dict = model_zoo.load_url(args.restore_from) 194 | else: 195 | saved_state_dict = torch.load(args.restore_from) 196 | model.load_state_dict(saved_state_dict) 197 | 198 | model.eval() 199 | model.cuda(gpu0) 200 | 201 | if args.dataset == 'pascal_voc': 202 | testloader = data.DataLoader(VOCDataSet(args.data_dir, args.data_list, crop_size=(505, 505), mean=IMG_MEAN, scale=False, mirror=False), 203 | batch_size=1, shuffle=False, pin_memory=True) 204 | interp = nn.Upsample(size=(505, 505), mode='bilinear', align_corners=True) 205 | 206 | elif args.dataset == 'pascal_context': 207 | input_transform = transform.Compose([transform.ToTensor(), 208 | transform.Normalize([.485, .456, .406], [.229, .224, .225])]) 209 | data_kwargs = {'transform': input_transform, 'base_size': 512, 'crop_size': 512} 210 | data_loader = get_loader('pascal_context') 211 | data_path = get_data_path('pascal_context') 212 | test_dataset = data_loader(data_path, split='val', mode='val', **data_kwargs) 213 | testloader = data.DataLoader(test_dataset, batch_size=1, drop_last=False, shuffle=False, num_workers=1, pin_memory=True) 214 | interp = nn.Upsample(size=(512, 512), mode='bilinear', align_corners=True) 215 | 216 | elif args.dataset == 'cityscapes': 217 | data_loader = get_loader('cityscapes') 218 | data_path = get_data_path('cityscapes') 219 | test_dataset = data_loader( data_path, 
img_size=(512, 1024), is_transform=True, split='val') 220 | testloader = data.DataLoader(test_dataset, batch_size=1, shuffle=False, pin_memory=True) 221 | interp = nn.Upsample(size=(512, 1024), mode='bilinear', align_corners=True) 222 | 223 | data_list = [] 224 | colorize = VOCColorize() 225 | 226 | if args.with_mlmt: 227 | mlmt_preds = np.loadtxt(args.mlmt_file, dtype = float) 228 | 229 | mlmt_preds[mlmt_preds>=0.2] = 1 230 | mlmt_preds[mlmt_preds<0.2] = 0 231 | 232 | for index, batch in enumerate(testloader): 233 | if index % 100 == 0: 234 | print('%d processd'%(index)) 235 | image, label, size, name, _ = batch 236 | size = size[0] 237 | output = model(Variable(image, volatile=True).cuda(gpu0)) 238 | output = interp(output).cpu().data[0].numpy() 239 | 240 | if args.dataset == 'pascal_voc': 241 | output = output[:,:size[0],:size[1]] 242 | gt = np.asarray(label[0].numpy()[:size[0],:size[1]], dtype=np.int) 243 | elif args.dataset == 'pascal_context': 244 | gt = np.asarray(label[0].numpy(), dtype=np.int) 245 | elif args.dataset == 'cityscapes': 246 | gt = np.asarray(label[0].numpy(), dtype=np.int) 247 | 248 | if args.with_mlmt: 249 | for i in range(args.num_classes): 250 | output[i]= output[i]*mlmt_preds[index][i] 251 | 252 | output = output.transpose(1,2,0) 253 | output = np.asarray(np.argmax(output, axis=2), dtype=np.int) 254 | 255 | if args.save_output_images: 256 | if args.dataset == 'pascal_voc': 257 | filename = os.path.join(args.save_dir, '{}.png'.format(name[0])) 258 | color_file = Image.fromarray(colorize(output).transpose(1, 2, 0), 'RGB') 259 | color_file.save(filename) 260 | elif args.dataset == 'pascal_context': 261 | filename = os.path.join(args.save_dir, filename[0]) 262 | scipy.misc.imsave(filename, gt) 263 | 264 | data_list.append([gt.flatten(), output.flatten()]) 265 | 266 | filename = os.path.join(args.save_dir, 'result.txt') 267 | get_iou(args, data_list, args.num_classes, filename) 268 | 269 | 270 | if __name__ == '__main__': 271 | main() 272 | -------------------------------------------------------------------------------- /model/deeplabv2.py: -------------------------------------------------------------------------------- 1 | """ 2 | This is the implementation of DeepLabv2 without multi-scale inputs. This implementation uses ResNet-101 by default. 
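A minimal usage sketch, assuming the PASCAL-VOC defaults used elsewhere in this repository (21 classes, 321x321 crops): Res_Deeplab returns stride-8 class scores, so the caller upsamples them back to the input resolution, as evaluate.py does with nn.Upsample.

```
import torch
import torch.nn as nn
from model.deeplabv2 import Res_Deeplab

model = Res_Deeplab(num_classes=21).eval()
interp = nn.Upsample(size=(321, 321), mode='bilinear', align_corners=True)

with torch.no_grad():
    image = torch.randn(1, 3, 321, 321)   # dummy batch; the real loaders feed mean-subtracted crops
    scores = model(image)                 # stride-8 class scores (roughly 41x41 here)
    scores = interp(scores)               # back to input resolution
    pred = scores.argmax(dim=1)           # per-pixel class indices
```

The constructor itself starts from random weights; the ImageNet-pretrained ResNet-101 checkpoint is loaded separately by the training scripts.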
3 | """ 4 | 5 | import torch.nn as nn 6 | import math 7 | import torch.utils.model_zoo as model_zoo 8 | import torch 9 | import numpy as np 10 | affine_par = True 11 | 12 | 13 | def outS(i): 14 | i = int(i) 15 | i = (i+1)/2 16 | i = int(np.ceil((i+1)/2.0)) 17 | i = (i+1)/2 18 | return i 19 | 20 | def conv3x3(in_planes, out_planes, stride=1): 21 | "3x3 convolution with padding" 22 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 23 | padding=1, bias=False) 24 | 25 | 26 | class BasicBlock(nn.Module): 27 | expansion = 1 28 | 29 | def __init__(self, inplanes, planes, stride=1, downsample=None): 30 | super(BasicBlock, self).__init__() 31 | self.conv1 = conv3x3(inplanes, planes, stride) 32 | self.bn1 = nn.BatchNorm2d(planes, affine = affine_par) 33 | self.relu = nn.ReLU(inplace=True) 34 | self.conv2 = conv3x3(planes, planes) 35 | self.bn2 = nn.BatchNorm2d(planes, affine = affine_par) 36 | self.downsample = downsample 37 | self.stride = stride 38 | 39 | def forward(self, x): 40 | residual = x 41 | 42 | out = self.conv1(x) 43 | out = self.bn1(out) 44 | out = self.relu(out) 45 | 46 | out = self.conv2(out) 47 | out = self.bn2(out) 48 | 49 | if self.downsample is not None: 50 | residual = self.downsample(x) 51 | 52 | out += residual 53 | out = self.relu(out) 54 | 55 | return out 56 | 57 | 58 | class Bottleneck(nn.Module): 59 | expansion = 4 60 | 61 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None): 62 | super(Bottleneck, self).__init__() 63 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False) # change 64 | self.bn1 = nn.BatchNorm2d(planes,affine = affine_par) 65 | for i in self.bn1.parameters(): 66 | i.requires_grad = False 67 | 68 | padding = dilation 69 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, # change 70 | padding=padding, bias=False, dilation = dilation) 71 | self.bn2 = nn.BatchNorm2d(planes,affine = affine_par) 72 | for i in self.bn2.parameters(): 73 | i.requires_grad = False 74 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 75 | self.bn3 = nn.BatchNorm2d(planes * 4, affine = affine_par) 76 | for i in self.bn3.parameters(): 77 | i.requires_grad = False 78 | self.relu = nn.ReLU(inplace=True) 79 | self.downsample = downsample 80 | self.stride = stride 81 | 82 | 83 | def forward(self, x): 84 | residual = x 85 | 86 | out = self.conv1(x) 87 | out = self.bn1(out) 88 | out = self.relu(out) 89 | 90 | out = self.conv2(out) 91 | out = self.bn2(out) 92 | out = self.relu(out) 93 | 94 | out = self.conv3(out) 95 | out = self.bn3(out) 96 | 97 | if self.downsample is not None: 98 | residual = self.downsample(x) 99 | 100 | out += residual 101 | out = self.relu(out) 102 | 103 | return out 104 | 105 | class Classifier_Module(nn.Module): 106 | 107 | def __init__(self, dilation_series, padding_series, num_classes): 108 | super(Classifier_Module, self).__init__() 109 | self.conv2d_list = nn.ModuleList() 110 | for dilation, padding in zip(dilation_series, padding_series): 111 | self.conv2d_list.append(nn.Conv2d(2048, num_classes, kernel_size=3, stride=1, padding=padding, dilation=dilation, bias = True)) 112 | 113 | for m in self.conv2d_list: 114 | m.weight.data.normal_(0, 0.01) 115 | 116 | def forward(self, x): 117 | out = self.conv2d_list[0](x) 118 | for i in range(len(self.conv2d_list)-1): 119 | out += self.conv2d_list[i+1](x) 120 | return out 121 | 122 | 123 | 124 | class ResNet(nn.Module): 125 | def __init__(self, block, layers, num_classes): 126 | self.inplanes = 64 127 | 
super(ResNet, self).__init__() 128 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 129 | bias=False) 130 | self.bn1 = nn.BatchNorm2d(64, affine = affine_par) 131 | for i in self.bn1.parameters(): 132 | i.requires_grad = False 133 | self.relu = nn.ReLU(inplace=True) 134 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=True) # change 135 | self.layer1 = self._make_layer(block, 64, layers[0]) 136 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 137 | self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2) 138 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4) 139 | self.layer5 = self._make_pred_layer(Classifier_Module, [6,12,18,24],[6,12,18,24],num_classes) 140 | 141 | for m in self.modules(): 142 | if isinstance(m, nn.Conv2d): 143 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 144 | m.weight.data.normal_(0, 0.01) 145 | elif isinstance(m, nn.BatchNorm2d): 146 | m.weight.data.fill_(1) 147 | m.bias.data.zero_() 148 | # for i in m.parameters(): 149 | # i.requires_grad = False 150 | 151 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1): 152 | downsample = None 153 | if stride != 1 or self.inplanes != planes * block.expansion or dilation == 2 or dilation == 4: 154 | downsample = nn.Sequential( 155 | nn.Conv2d(self.inplanes, planes * block.expansion, 156 | kernel_size=1, stride=stride, bias=False), 157 | nn.BatchNorm2d(planes * block.expansion,affine = affine_par)) 158 | for i in downsample._modules['1'].parameters(): 159 | i.requires_grad = False 160 | layers = [] 161 | layers.append(block(self.inplanes, planes, stride,dilation=dilation, downsample=downsample)) 162 | self.inplanes = planes * block.expansion 163 | for i in range(1, blocks): 164 | layers.append(block(self.inplanes, planes, dilation=dilation)) 165 | 166 | return nn.Sequential(*layers) 167 | def _make_pred_layer(self,block, dilation_series, padding_series,num_classes): 168 | return block(dilation_series,padding_series,num_classes) 169 | 170 | def forward(self, x): 171 | x = self.conv1(x) 172 | x = self.bn1(x) 173 | x = self.relu(x) 174 | x = self.maxpool(x) 175 | x = self.layer1(x) 176 | x = self.layer2(x) 177 | x = self.layer3(x) 178 | x = self.layer4(x) 179 | x = self.layer5(x) 180 | 181 | return x 182 | 183 | def get_1x_lr_params_NOscale(self): 184 | """ 185 | This generator returns all the parameters of the net except for 186 | the last classification layer. 
Note that for each batchnorm layer, 187 | requires_grad is set to False in deeplab_resnet.py, therefore this function does not return 188 | any batchnorm parameter 189 | """ 190 | b = [] 191 | 192 | b.append(self.conv1) 193 | b.append(self.bn1) 194 | b.append(self.layer1) 195 | b.append(self.layer2) 196 | b.append(self.layer3) 197 | b.append(self.layer4) 198 | 199 | 200 | for i in range(len(b)): 201 | for j in b[i].modules(): 202 | jj = 0 203 | for k in j.parameters(): 204 | jj+=1 205 | if k.requires_grad: 206 | yield k 207 | 208 | def get_10x_lr_params(self): 209 | """ 210 | This generator returns all the parameters for the last layer of the net, 211 | which does the classification of pixel into classes 212 | """ 213 | b = [] 214 | b.append(self.layer5.parameters()) 215 | 216 | for j in range(len(b)): 217 | for i in b[j]: 218 | yield i 219 | 220 | 221 | 222 | def optim_parameters(self, args): 223 | return [{'params': self.get_1x_lr_params_NOscale(), 'lr': args.learning_rate}, 224 | {'params': self.get_10x_lr_params(), 'lr': 10*args.learning_rate}] 225 | 226 | 227 | def Res_Deeplab(num_classes=21): 228 | model = ResNet(Bottleneck,[3, 4, 23, 3], num_classes) 229 | return model 230 | 231 | -------------------------------------------------------------------------------- /model/deeplabv3p.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code Adapted from: 3 | https://github.com/sthalles/deeplab_v3 4 | Copyright 2020 Nvidia Corporation 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | 1. Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 2. Redistributions in binary form must reproduce the above copyright notice, 10 | this list of conditions and the following disclaimer in the documentation 11 | and/or other materials provided with the distribution. 12 | 3. Neither the name of the copyright holder nor the names of its contributors 13 | may be used to endorse or promote products derived from this software 14 | without specific prior written permission. 15 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 16 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 17 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 18 | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 19 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 20 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 21 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 22 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 23 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 24 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 25 | POSSIBILITY OF SUCH DAMAGE. 
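For orientation, a usage sketch for the DeepV3Plus model defined below. It assumes the ImageNet-pretrained WRN-38 checkpoint is already at ./pretrained_models/wider_resnet38.pth.tar, since the trunk loads it in its constructor (see model/wider_resnet.py). Note that forward() returns logits at the s2 feature resolution; the final upsample to image size is commented out and left to the caller.

```
import torch
from model.deeplabv3p import DeepV3PlusW38

net = DeepV3PlusW38(num_classes=19).eval()       # 19 classes, as for Cityscapes
with torch.no_grad():
    logits = net(torch.randn(1, 3, 256, 512))    # half-resolution logits (128x256 here)
```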
26 | """ 27 | import torch 28 | from torch import nn 29 | 30 | #from model.mynn import initialize_weights, Norm2d, Upsample 31 | from model.utils import get_aspp, get_trunk 32 | 33 | def initialize_weights(*models): 34 | """ 35 | Initialize Model Weights 36 | """ 37 | for model in models: 38 | for module in model.modules(): 39 | if isinstance(module, (nn.Conv2d, nn.Linear)): 40 | nn.init.kaiming_normal_(module.weight) 41 | if module.bias is not None: 42 | module.bias.data.zero_() 43 | elif isinstance(module, nn.BatchNorm2d): 44 | module.weight.data.fill_(1) 45 | module.bias.data.zero_() 46 | 47 | class DeepV3Plus(nn.Module): 48 | """ 49 | DeepLabV3+ with various trunks supported 50 | Always stride8 51 | """ 52 | def __init__(self, num_classes, trunk='wrn38', 53 | use_dpc=False, init_all=False): 54 | super(DeepV3Plus, self).__init__() 55 | self.backbone, s2_ch, _s4_ch, high_level_ch = get_trunk(trunk) 56 | self.aspp, aspp_out_ch = get_aspp(high_level_ch, 57 | bottleneck_ch=256, 58 | output_stride=8, 59 | dpc=use_dpc) 60 | self.bot_fine = nn.Conv2d(s2_ch, 48, kernel_size=1, bias=False) 61 | self.bot_aspp = nn.Conv2d(aspp_out_ch, 256, kernel_size=1, bias=False) 62 | self.final = nn.Sequential( 63 | nn.Conv2d(256 + 48, 256, kernel_size=3, padding=1, bias=False), 64 | nn.BatchNorm2d(256), 65 | nn.ReLU(inplace=True), 66 | nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False), 67 | nn.BatchNorm2d(256), 68 | nn.ReLU(inplace=True), 69 | nn.Conv2d(256, num_classes, kernel_size=1, bias=False)) 70 | 71 | if init_all: 72 | initialize_weights(self.aspp) 73 | initialize_weights(self.bot_aspp) 74 | initialize_weights(self.bot_fine) 75 | initialize_weights(self.final) 76 | else: 77 | initialize_weights(self.final) 78 | 79 | def forward(self, x): 80 | x_size = x.size() 81 | s2_features, _, final_features = self.backbone(x) 82 | 83 | aspp = self.aspp(final_features) 84 | conv_aspp = self.bot_aspp(aspp) 85 | conv_s2 = self.bot_fine(s2_features) 86 | conv_aspp = Upsample(conv_aspp, s2_features.size()[2:]) 87 | #print(conv_s2.size(), conv_aspp.size()) 88 | cat_s4 = [conv_s2, conv_aspp] 89 | cat_s4 = torch.cat(cat_s4, 1) 90 | final = self.final(cat_s4) 91 | #print(final.size()) 92 | #out = Upsample(final, x_size[2:]) 93 | return final 94 | 95 | def get_1x_lr_params_NOscale(self): 96 | """ 97 | This generator returns all the parameters of the net except for 98 | the last classification layer. Note that for each batchnorm layer, 99 | requires_grad is set to False in deeplab_resnet.py, therefore this function does not return 100 | any batchnorm parameter 101 | """ 102 | b = [] 103 | 104 | b.append(self.backbone) 105 | 106 | for i in range(len(b)): 107 | for j in b[i].modules(): 108 | jj = 0 109 | for k in j.parameters(): 110 | jj+=1 111 | if k.requires_grad: 112 | yield k 113 | 114 | def get_10x_lr_params(self): 115 | """ 116 | This generator returns all the parameters of the net except for 117 | the last classification layer. 
Note that for each batchnorm layer, 118 | requires_grad is set to False in deeplab_resnet.py, therefore this function does not return 119 | any batchnorm parameter 120 | """ 121 | b = [] 122 | 123 | b.append(self.aspp) 124 | b.append(self.bot_aspp) 125 | b.append(self.bot_fine) 126 | b.append(self.final) 127 | 128 | for i in range(len(b)): 129 | for j in b[i].modules(): 130 | jj = 0 131 | for k in j.parameters(): 132 | jj+=1 133 | if k.requires_grad: 134 | yield k 135 | 136 | def optim_parameters(self, args): 137 | return [{'params': self.get_1x_lr_params_NOscale(), 'lr': args.learning_rate}, 138 | {'params': self.get_10x_lr_params(), 'lr': 10*args.learning_rate}] 139 | 140 | 141 | def Upsample(x, size): 142 | """ 143 | Wrapper Around the Upsample Call 144 | """ 145 | return nn.functional.interpolate(x, size=size, mode='bilinear', 146 | align_corners=False) 147 | 148 | def DeepV3PlusW38(num_classes): 149 | return DeepV3Plus(num_classes, trunk='wrn38') 150 | 151 | 152 | def DeepV3PlusW38I(num_classes): 153 | return DeepV3Plus(num_classes, trunk='wrn38', init_all=True) 154 | 155 | -------------------------------------------------------------------------------- /model/discriminator.py: -------------------------------------------------------------------------------- 1 | from torch.autograd import Variable 2 | import torch.nn as nn 3 | 4 | class s4GAN_discriminator(nn.Module): 5 | 6 | def __init__(self, num_classes, dataset, ndf = 64): 7 | super(s4GAN_discriminator, self).__init__() 8 | 9 | self.conv1 = nn.Conv2d(num_classes+3, ndf, kernel_size=4, stride=2, padding=1) # 160 x 160 10 | self.conv2 = nn.Conv2d( ndf, ndf*2, kernel_size=4, stride=2, padding=1) # 80 x 80 11 | self.conv3 = nn.Conv2d(ndf*2, ndf*4, kernel_size=4, stride=2, padding=1) # 40 x 40 12 | self.conv4 = nn.Conv2d(ndf*4, ndf*8, kernel_size=4, stride=2, padding=1) # 20 x 20 13 | if dataset == 'pascal_voc' or dataset == 'pascal_context': 14 | self.avgpool = nn.AvgPool2d((20, 20)) 15 | elif dataset == 'cityscapes': 16 | self.avgpool = nn.AvgPool2d((16, 32)) 17 | self.fc = nn.Linear(ndf*8, 1) 18 | self.leaky_relu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 19 | self.drop = nn.Dropout2d(0.5) 20 | self.sigmoid = nn.Sigmoid() 21 | 22 | 23 | def forward(self, x): 24 | 25 | x = self.conv1(x) 26 | x = self.leaky_relu(x) 27 | x = self.drop(x) 28 | 29 | x = self.conv2(x) 30 | x = self.leaky_relu(x) 31 | x = self.drop(x) 32 | 33 | x = self.conv3(x) 34 | x = self.leaky_relu(x) 35 | x = self.drop(x) 36 | 37 | x = self.conv4(x) 38 | x = self.leaky_relu(x) 39 | 40 | maps = self.avgpool(x) 41 | conv4_maps = maps 42 | out = maps.view(maps.size(0), -1) 43 | out = self.sigmoid(self.fc(out)) 44 | 45 | return out, conv4_maps 46 | -------------------------------------------------------------------------------- /model/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2020 Nvidia Corporation 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | 1. Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | 2. Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | 3. 
Neither the name of the copyright holder nor the names of its contributors 15 | may be used to endorse or promote products derived from this software 16 | without specific prior written permission. 17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 21 | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 22 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 23 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 24 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 25 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 26 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 27 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 28 | POSSIBILITY OF SUCH DAMAGE. 29 | """ 30 | from collections import OrderedDict 31 | 32 | import torch 33 | import torch.nn.functional as F 34 | 35 | from torch import nn 36 | 37 | #from model.mynn import Norm2d, Upsample 38 | from model.wider_resnet import wrn38 39 | 40 | def Upsample(x, size): 41 | """ 42 | Wrapper Around the Upsample Call 43 | """ 44 | return nn.functional.interpolate(x, size=size, mode='bilinear', 45 | align_corners=False) 46 | 47 | def get_trunk(trunk_name, output_stride=8): 48 | """ 49 | Retrieve the network trunk and channel counts. 50 | """ 51 | assert output_stride == 8, 'Only stride8 supported right now' 52 | 53 | if trunk_name == 'wrn38': 54 | # 55 | # FIXME: pass in output_stride once we support stride 16 56 | # 57 | backbone = wrn38(pretrained=True) 58 | s2_ch = 128 59 | s4_ch = 256 60 | high_level_ch = 4096 61 | else: 62 | raise 'unknown backbone {}'.format(trunk_name) 63 | 64 | return backbone, s2_ch, s4_ch, high_level_ch 65 | 66 | class AtrousSpatialPyramidPoolingModule(nn.Module): 67 | """ 68 | operations performed: 69 | 1x1 x depth 70 | 3x3 x depth dilation 6 71 | 3x3 x depth dilation 12 72 | 3x3 x depth dilation 18 73 | image pooling 74 | concatenate all together 75 | Final 1x1 conv 76 | """ 77 | 78 | def __init__(self, in_dim, reduction_dim=256, output_stride=16, 79 | rates=(6, 12, 18)): 80 | super(AtrousSpatialPyramidPoolingModule, self).__init__() 81 | 82 | if output_stride == 8: 83 | rates = [2 * r for r in rates] 84 | elif output_stride == 16: 85 | pass 86 | else: 87 | raise 'output stride of {} not supported'.format(output_stride) 88 | 89 | self.features = [] 90 | # 1x1 91 | self.features.append( 92 | nn.Sequential(nn.Conv2d(in_dim, reduction_dim, kernel_size=1, 93 | bias=False), 94 | nn.BatchNorm2d(reduction_dim), nn.ReLU(inplace=True))) 95 | # other rates 96 | for r in rates: 97 | self.features.append(nn.Sequential( 98 | nn.Conv2d(in_dim, reduction_dim, kernel_size=3, 99 | dilation=r, padding=r, bias=False), 100 | nn.BatchNorm2d(reduction_dim), 101 | nn.ReLU(inplace=True) 102 | )) 103 | self.features = nn.ModuleList(self.features) 104 | 105 | # img level features 106 | self.img_pooling = nn.AdaptiveAvgPool2d(1) 107 | self.img_conv = nn.Sequential( 108 | nn.Conv2d(in_dim, reduction_dim, kernel_size=1, bias=False), 109 | nn.BatchNorm2d(reduction_dim), nn.ReLU(inplace=True)) 110 | 111 | def forward(self, x): 112 | x_size = x.size() 113 | 114 | img_features = self.img_pooling(x) 115 | img_features = self.img_conv(img_features) 116 | img_features = Upsample(img_features, x_size[2:]) 117 
| out = img_features 118 | 119 | for f in self.features: 120 | y = f(x) 121 | out = torch.cat((out, y), 1) 122 | return out 123 | 124 | def dpc_conv(in_dim, reduction_dim, dil, separable): 125 | if separable: 126 | groups = reduction_dim 127 | else: 128 | groups = 1 129 | 130 | return nn.Sequential( 131 | nn.Conv2d(in_dim, reduction_dim, kernel_size=3, dilation=dil, 132 | padding=dil, bias=False, groups=groups), 133 | nn.BatchNorm2d(reduction_dim), 134 | nn.ReLU(inplace=True) 135 | ) 136 | 137 | 138 | class DPC(nn.Module): 139 | ''' 140 | From: Searching for Efficient Multi-scale architectures for dense 141 | prediction 142 | ''' 143 | def __init__(self, in_dim, reduction_dim=256, output_stride=16, 144 | rates=[(1, 6), (18, 15), (6, 21), (1, 1), (6, 3)], 145 | dropout=False, separable=False): 146 | super(DPC, self).__init__() 147 | 148 | self.dropout = dropout 149 | if output_stride == 8: 150 | rates = [(2 * r[0], 2 * r[1]) for r in rates] 151 | elif output_stride == 16: 152 | pass 153 | else: 154 | raise 'output stride of {} not supported'.format(output_stride) 155 | 156 | self.a = dpc_conv(in_dim, reduction_dim, rates[0], separable) 157 | self.b = dpc_conv(reduction_dim, reduction_dim, rates[1], separable) 158 | self.c = dpc_conv(reduction_dim, reduction_dim, rates[2], separable) 159 | self.d = dpc_conv(reduction_dim, reduction_dim, rates[3], separable) 160 | self.e = dpc_conv(reduction_dim, reduction_dim, rates[4], separable) 161 | 162 | self.drop = nn.Dropout(p=0.1) 163 | 164 | def forward(self, x): 165 | a = self.a(x) 166 | b = self.b(a) 167 | c = self.c(a) 168 | d = self.d(a) 169 | e = self.e(b) 170 | out = torch.cat((a, b, c, d, e), 1) 171 | if self.dropout: 172 | out = self.drop(out) 173 | return out 174 | 175 | 176 | def get_aspp(high_level_ch, bottleneck_ch, output_stride, dpc=False): 177 | """ 178 | Create aspp block 179 | """ 180 | if dpc: 181 | aspp = DPC(high_level_ch, bottleneck_ch, output_stride=output_stride) 182 | else: 183 | aspp = AtrousSpatialPyramidPoolingModule(high_level_ch, bottleneck_ch, 184 | output_stride=output_stride) 185 | aspp_out_ch = 5 * bottleneck_ch 186 | return aspp, aspp_out_ch 187 | 188 | -------------------------------------------------------------------------------- /model/wider_resnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Code adapted from: 3 | # https://github.com/mapillary/inplace_abn/ 4 | # 5 | # BSD 3-Clause License 6 | # 7 | # Copyright (c) 2017, mapillary 8 | # All rights reserved. 9 | # 10 | # Redistribution and use in source and binary forms, with or without 11 | # modification, are permitted provided that the following conditions are met: 12 | # 13 | # * Redistributions of source code must retain the above copyright notice, this 14 | # list of conditions and the following disclaimer. 15 | # 16 | # * Redistributions in binary form must reproduce the above copyright notice, 17 | # this list of conditions and the following disclaimer in the documentation 18 | # and/or other materials provided with the distribution. 19 | # 20 | # * Neither the name of the copyright holder nor the names of its 21 | # contributors may be used to endorse or promote products derived from 22 | # this software without specific prior written permission. 
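A small sketch of how the get_aspp helper above (model/utils.py) is typically wired, assuming the 4096 high-level channels reported by get_trunk('wrn38'); the concatenated ASPP output carries 5 * bottleneck_ch channels at the same spatial size as its input.

```
import torch
from model.utils import get_aspp

aspp, aspp_out_ch = get_aspp(high_level_ch=4096, bottleneck_ch=256,
                             output_stride=8, dpc=False)
feats = torch.randn(1, 4096, 32, 64)   # stride-8 trunk features for a 256x512 input
out = aspp(feats)                      # shape (1, aspp_out_ch, 32, 64); aspp_out_ch == 1280
```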
23 | # 24 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 25 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 26 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 27 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 28 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 29 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 30 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 31 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 32 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 33 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 34 | """ 35 | import logging 36 | import sys 37 | from collections import OrderedDict 38 | from functools import partial 39 | import torch.nn as nn 40 | import torch 41 | #import model.mynn as mynn 42 | #from config import cfg 43 | 44 | 45 | def bnrelu(channels): 46 | """ 47 | Single Layer BN and Relui 48 | """ 49 | return nn.Sequential(nn.BatchNorm2d(channels), 50 | nn.ReLU(inplace=True)) 51 | 52 | 53 | class GlobalAvgPool2d(nn.Module): 54 | """ 55 | Global average pooling over the input's spatial dimensions 56 | """ 57 | 58 | def __init__(self): 59 | super(GlobalAvgPool2d, self).__init__() 60 | logging.info("Global Average Pooling Initialized") 61 | 62 | def forward(self, inputs): 63 | in_size = inputs.size() 64 | return inputs.view((in_size[0], in_size[1], -1)).mean(dim=2) 65 | 66 | 67 | class IdentityResidualBlock(nn.Module): 68 | """ 69 | Identity Residual Block for WideResnet 70 | """ 71 | def __init__(self, 72 | in_channels, 73 | channels, 74 | stride=1, 75 | dilation=1, 76 | groups=1, 77 | norm_act=bnrelu, 78 | dropout=None, 79 | dist_bn=False 80 | ): 81 | """Configurable identity-mapping residual block 82 | 83 | Parameters 84 | ---------- 85 | in_channels : int 86 | Number of input channels. 87 | channels : list of int 88 | Number of channels in the internal feature maps. 89 | Can either have two or three elements: if three construct 90 | a residual block with two `3 x 3` convolutions, 91 | otherwise construct a bottleneck block with `1 x 1`, then 92 | `3 x 3` then `1 x 1` convolutions. 93 | stride : int 94 | Stride of the first `3 x 3` convolution 95 | dilation : int 96 | Dilation to apply to the `3 x 3` convolutions. 97 | groups : int 98 | Number of convolution groups. 99 | This is used to create ResNeXt-style blocks and is only compatible with 100 | bottleneck blocks. 101 | norm_act : callable 102 | Function to create normalization / activation Module. 103 | dropout: callable 104 | Function to create Dropout Module. 
105 | dist_bn: Boolean 106 | A variable to enable or disable use of distributed BN 107 | """ 108 | super(IdentityResidualBlock, self).__init__() 109 | self.dist_bn = dist_bn 110 | 111 | # Check if we are using distributed BN and use the nn from encoding.nn 112 | # library rather than using standard pytorch.nn 113 | 114 | 115 | # Check parameters for inconsistencies 116 | if len(channels) != 2 and len(channels) != 3: 117 | raise ValueError("channels must contain either two or three values") 118 | if len(channels) == 2 and groups != 1: 119 | raise ValueError("groups > 1 are only valid if len(channels) == 3") 120 | 121 | is_bottleneck = len(channels) == 3 122 | need_proj_conv = stride != 1 or in_channels != channels[-1] 123 | 124 | self.bn1 = norm_act(in_channels) 125 | if not is_bottleneck: 126 | layers = [ 127 | ("conv1", nn.Conv2d(in_channels, 128 | channels[0], 129 | 3, 130 | stride=stride, 131 | padding=dilation, 132 | bias=False, 133 | dilation=dilation)), 134 | ("bn2", norm_act(channels[0])), 135 | ("conv2", nn.Conv2d(channels[0], channels[1], 136 | 3, 137 | stride=1, 138 | padding=dilation, 139 | bias=False, 140 | dilation=dilation)) 141 | ] 142 | if dropout is not None: 143 | layers = layers[0:2] + [("dropout", dropout())] + layers[2:] 144 | else: 145 | layers = [ 146 | ("conv1", 147 | nn.Conv2d(in_channels, 148 | channels[0], 149 | 1, 150 | stride=stride, 151 | padding=0, 152 | bias=False)), 153 | ("bn2", norm_act(channels[0])), 154 | ("conv2", nn.Conv2d(channels[0], 155 | channels[1], 156 | 3, stride=1, 157 | padding=dilation, bias=False, 158 | groups=groups, 159 | dilation=dilation)), 160 | ("bn3", norm_act(channels[1])), 161 | ("conv3", nn.Conv2d(channels[1], channels[2], 162 | 1, stride=1, padding=0, bias=False)) 163 | ] 164 | if dropout is not None: 165 | layers = layers[0:4] + [("dropout", dropout())] + layers[4:] 166 | self.convs = nn.Sequential(OrderedDict(layers)) 167 | 168 | if need_proj_conv: 169 | self.proj_conv = nn.Conv2d( 170 | in_channels, channels[-1], 1, stride=stride, padding=0, bias=False) 171 | 172 | def forward(self, x): 173 | """ 174 | This is the standard forward function for non-distributed batch norm 175 | """ 176 | if hasattr(self, "proj_conv"): 177 | bn1 = self.bn1(x) 178 | shortcut = self.proj_conv(bn1) 179 | else: 180 | shortcut = x.clone() 181 | bn1 = self.bn1(x) 182 | 183 | out = self.convs(bn1) 184 | out.add_(shortcut) 185 | return out 186 | 187 | 188 | 189 | 190 | class WiderResNet(nn.Module): 191 | """ 192 | WideResnet Global Module for Initialization 193 | """ 194 | def __init__(self, 195 | structure, 196 | norm_act=bnrelu, 197 | classes=0 198 | ): 199 | """Wider ResNet with pre-activation (identity mapping) blocks 200 | 201 | Parameters 202 | ---------- 203 | structure : list of int 204 | Number of residual blocks in each of the six modules of the network. 205 | norm_act : callable 206 | Function to create normalization / activation Module. 207 | classes : int 208 | If not `0` also include global average pooling and \ 209 | a fully-connected layer with `classes` outputs at the end 210 | of the network. 
211 | """ 212 | super(WiderResNet, self).__init__() 213 | self.structure = structure 214 | 215 | if len(structure) != 6: 216 | raise ValueError("Expected a structure with six values") 217 | 218 | # Initial layers 219 | self.mod1 = nn.Sequential(OrderedDict([ 220 | ("conv1", nn.Conv2d(3, 64, 3, stride=1, padding=1, bias=False)) 221 | ])) 222 | 223 | # Groups of residual blocks 224 | in_channels = 64 225 | channels = [(128, 128), (256, 256), (512, 512), (512, 1024), 226 | (512, 1024, 2048), (1024, 2048, 4096)] 227 | for mod_id, num in enumerate(structure): 228 | # Create blocks for module 229 | blocks = [] 230 | for block_id in range(num): 231 | blocks.append(( 232 | "block%d" % (block_id + 1), 233 | IdentityResidualBlock(in_channels, channels[mod_id], 234 | norm_act=norm_act) 235 | )) 236 | 237 | # Update channels and p_keep 238 | in_channels = channels[mod_id][-1] 239 | 240 | # Create module 241 | if mod_id <= 4: 242 | self.add_module("pool%d" % 243 | (mod_id + 2), nn.MaxPool2d(3, stride=2, padding=1)) 244 | self.add_module("mod%d" % (mod_id + 2), nn.Sequential(OrderedDict(blocks))) 245 | 246 | # Pooling and predictor 247 | self.bn_out = norm_act(in_channels) 248 | if classes != 0: 249 | self.classifier = nn.Sequential(OrderedDict([ 250 | ("avg_pool", GlobalAvgPool2d()), 251 | ("fc", nn.Linear(in_channels, classes)) 252 | ])) 253 | 254 | def forward(self, img): 255 | out = self.mod1(img) 256 | out = self.mod2(self.pool2(out)) 257 | out = self.mod3(self.pool3(out)) 258 | out = self.mod4(self.pool4(out)) 259 | out = self.mod5(self.pool5(out)) 260 | out = self.mod6(self.pool6(out)) 261 | out = self.mod7(out) 262 | out = self.bn_out(out) 263 | 264 | if hasattr(self, "classifier"): 265 | out = self.classifier(out) 266 | 267 | return out 268 | 269 | 270 | class WiderResNetA2(nn.Module): 271 | """ 272 | Wider ResNet with pre-activation (identity mapping) blocks 273 | 274 | This variant uses down-sampling by max-pooling in the first two blocks and 275 | by strided convolution in the others. 276 | 277 | Parameters 278 | ---------- 279 | structure : list of int 280 | Number of residual blocks in each of the six modules of the network. 281 | norm_act : callable 282 | Function to create normalization / activation Module. 283 | classes : int 284 | If not `0` also include global average pooling and a fully-connected layer 285 | with `classes` outputs at the end 286 | of the network. 287 | dilation : bool 288 | If `True` apply dilation to the last three modules and change the 289 | down-sampling factor from 32 to 8. 
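As a concrete instance of the structure/dilation arguments described here, the wrn38 wrapper at the bottom of this file builds the dilated 38-layer variant roughly as follows (a sketch; the real wrapper also wraps the model in DataParallel and loads ImageNet weights):

```
from model.wider_resnet import WiderResNetA2

# [3, 3, 6, 3, 1, 1] is the "38" entry of _NETS below; dilation=True keeps the
# last three modules dilated so the backbone downsamples by 8 instead of 32.
backbone = WiderResNetA2(structure=[3, 3, 6, 3, 1, 1], classes=1000, dilation=True)
```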
290 | """ 291 | def __init__(self, 292 | structure, 293 | norm_act=bnrelu, 294 | classes=0, 295 | dilation=False, 296 | dist_bn=False 297 | ): 298 | super(WiderResNetA2, self).__init__() 299 | self.dist_bn = dist_bn 300 | 301 | # If using distributed batch norm, use the encoding.nn as oppose to torch.nn 302 | nn.Dropout = nn.Dropout2d 303 | norm_act = bnrelu 304 | self.structure = structure 305 | self.dilation = dilation 306 | 307 | if len(structure) != 6: 308 | raise ValueError("Expected a structure with six values") 309 | 310 | # Initial layers 311 | self.mod1 = torch.nn.Sequential(OrderedDict([ 312 | ("conv1", nn.Conv2d(3, 64, 3, stride=1, padding=1, bias=False)) 313 | ])) 314 | 315 | # Groups of residual blocks 316 | in_channels = 64 317 | channels = [(128, 128), (256, 256), (512, 512), (512, 1024), (512, 1024, 2048), 318 | (1024, 2048, 4096)] 319 | for mod_id, num in enumerate(structure): 320 | # Create blocks for module 321 | blocks = [] 322 | for block_id in range(num): 323 | if not dilation: 324 | dil = 1 325 | stride = 2 if block_id == 0 and 2 <= mod_id <= 4 else 1 326 | else: 327 | if mod_id == 3: 328 | dil = 2 329 | elif mod_id > 3: 330 | dil = 4 331 | else: 332 | dil = 1 333 | stride = 2 if block_id == 0 and mod_id == 2 else 1 334 | 335 | if mod_id == 4: 336 | drop = partial(nn.Dropout, p=0.3) 337 | elif mod_id == 5: 338 | drop = partial(nn.Dropout, p=0.5) 339 | else: 340 | drop = None 341 | 342 | blocks.append(( 343 | "block%d" % (block_id + 1), 344 | IdentityResidualBlock(in_channels, 345 | channels[mod_id], norm_act=norm_act, 346 | stride=stride, dilation=dil, 347 | dropout=drop, dist_bn=self.dist_bn) 348 | )) 349 | 350 | # Update channels and p_keep 351 | in_channels = channels[mod_id][-1] 352 | 353 | # Create module 354 | if mod_id < 2: 355 | self.add_module("pool%d" % 356 | (mod_id + 2), nn.MaxPool2d(3, stride=2, padding=1)) 357 | self.add_module("mod%d" % (mod_id + 2), nn.Sequential(OrderedDict(blocks))) 358 | 359 | # Pooling and predictor 360 | self.bn_out = norm_act(in_channels) 361 | if classes != 0: 362 | self.classifier = nn.Sequential(OrderedDict([ 363 | ("avg_pool", GlobalAvgPool2d()), 364 | ("fc", nn.Linear(in_channels, classes)) 365 | ])) 366 | 367 | def forward(self, img): 368 | out = self.mod1(img) 369 | out = self.mod2(self.pool2(out)) # s2 370 | out = self.mod3(self.pool3(out)) # s4 371 | out = self.mod4(out) # s8 372 | out = self.mod5(out) 373 | out = self.mod6(out) 374 | out = self.mod7(out) 375 | out = self.bn_out(out) 376 | 377 | if hasattr(self, "classifier"): 378 | return self.classifier(out) 379 | return out 380 | 381 | 382 | _NETS = { 383 | "16": {"structure": [1, 1, 1, 1, 1, 1]}, 384 | "20": {"structure": [1, 1, 1, 3, 1, 1]}, 385 | "38": {"structure": [3, 3, 6, 3, 1, 1]}, 386 | } 387 | 388 | __all__ = [] 389 | for name, params in _NETS.items(): 390 | net_name = "wider_resnet" + name 391 | setattr(sys.modules[__name__], net_name, partial(WiderResNet, **params)) 392 | __all__.append(net_name) 393 | for name, params in _NETS.items(): 394 | net_name = "wider_resnet" + name + "_a2" 395 | setattr(sys.modules[__name__], net_name, partial(WiderResNetA2, **params)) 396 | __all__.append(net_name) 397 | 398 | 399 | class wrn38(nn.Module): 400 | """ 401 | This is wider resnet 38, output_stride=8 402 | """ 403 | def __init__(self, pretrained=True): 404 | super(wrn38, self).__init__() 405 | wide_resnet = wider_resnet38_a2(classes=1000, dilation=True) 406 | wide_resnet = torch.nn.DataParallel(wide_resnet) 407 | if pretrained: 408 | pretrained_model = 
'./pretrained_models/wider_resnet38.pth.tar' 409 | checkpoint = torch.load(pretrained_model, map_location='cpu') 410 | wide_resnet.load_state_dict(checkpoint['state_dict']) 411 | del checkpoint 412 | wide_resnet = wide_resnet.module 413 | # print(wide_resnet) 414 | self.mod1 = wide_resnet.mod1 415 | self.mod2 = wide_resnet.mod2 416 | self.mod3 = wide_resnet.mod3 417 | self.mod4 = wide_resnet.mod4 418 | self.mod5 = wide_resnet.mod5 419 | self.mod6 = wide_resnet.mod6 420 | self.mod7 = wide_resnet.mod7 421 | self.pool2 = wide_resnet.pool2 422 | self.pool3 = wide_resnet.pool3 423 | del wide_resnet 424 | 425 | def forward(self, x): 426 | x = self.mod1(x) 427 | x = self.mod2(self.pool2(x)) # s2 428 | s2_features = x 429 | x = self.mod3(self.pool3(x)) # s4 430 | s4_features = x 431 | x = self.mod4(x) 432 | x = self.mod5(x) 433 | x = self.mod6(x) 434 | x = self.mod7(x) 435 | return s2_features, s4_features, x 436 | 437 | 438 | class wrn38_gscnn(wrn38): 439 | def __init__(self, pretrained=True): 440 | super(wrn38_gscnn, self).__init__(pretrained=pretrained) 441 | 442 | def forward(self, x): 443 | m1 = self.mod1(x) 444 | m2 = self.mod2(self.pool2(m1)) 445 | m3 = self.mod3(self.pool3(m2)) 446 | m4 = self.mod4(m3) 447 | m5 = self.mod5(m4) 448 | m6 = self.mod6(m5) 449 | m7 = self.mod7(m6) 450 | return m1, m2, m3, m4, m5, m6, m7 451 | -------------------------------------------------------------------------------- /splits/city/split_0.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sud0301/semisup-semseg/586609b7bd4bf851c4b1f2691584ce8b11ba6d50/splits/city/split_0.pkl -------------------------------------------------------------------------------- /splits/pc/split_0.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sud0301/semisup-semseg/586609b7bd4bf851c4b1f2691584ce8b11ba6d50/splits/pc/split_0.pkl -------------------------------------------------------------------------------- /splits/voc/split_0.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sud0301/semisup-semseg/586609b7bd4bf851c4b1f2691584ce8b11ba6d50/splits/voc/split_0.pkl -------------------------------------------------------------------------------- /train_full.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import sys 4 | import os 5 | import os.path as osp 6 | import scipy.misc 7 | import random 8 | import timeit 9 | import pickle 10 | 11 | import torch 12 | import torch.nn as nn 13 | from torch.utils import data, model_zoo 14 | from torch.autograd import Variable 15 | import torch.optim as optim 16 | import torch.nn.functional as F 17 | import torch.backends.cudnn as cudnn 18 | import torchvision.transforms as transform 19 | 20 | from model.deeplabv2 import Res_Deeplab 21 | #from model.deeplabv3p import Res_Deeplab 22 | 23 | from utils.loss import CrossEntropy2d 24 | from data.voc_dataset import VOCDataSet, VOCGTDataSet 25 | from data import get_loader, get_data_path 26 | from data.augmentations import * 27 | 28 | start = timeit.default_timer() 29 | 30 | IMG_MEAN = np.array((104.00698793,116.66876762,122.67891434), dtype=np.float32) 31 | 32 | # dataset params 33 | NUM_CLASSES = 21 # 21 for PASCAL-VOC / 60 for PASCAL-Context 34 | 35 | DATASET = 'pascal_voc' #pascal_voc or pascal_context 36 | 37 | DATA_DIRECTORY = './data/voc_dataset/' 38 | 
DATA_LIST_PATH = './data/voc_list/train_aug.txt' 39 | CHECKPOINT_DIR = './checkpoints/voc_full/' 40 | 41 | MODEL = 'DeepLab' 42 | BATCH_SIZE = 10 43 | NUM_STEPS = 40000 44 | SAVE_PRED_EVERY = 5000 45 | 46 | INPUT_SIZE = '321, 321' 47 | IGNORE_LABEL = 255 # 255 for PASCAL-VOC / -1 for PASCAL-Context 48 | 49 | LEARNING_RATE = 2.5e-4 50 | WEIGHT_DECAY = 0.0005 51 | POWER = 0.9 52 | MOMENTUM = 0.9 53 | NUM_WORKERS = 4 54 | RANDOM_SEED = 1234 55 | 56 | RESTORE_FROM = './pretrained_models/resnet101-5d3b4d8f.pth' # ImageNet pretrained encoder 57 | 58 | SPLIT_ID = './splits/voc/split_0.pkl' 59 | LABELED_RATIO= None # use 100% labeled data 60 | 61 | def get_arguments(): 62 | """Parse all the arguments provided from the CLI. 63 | 64 | Returns: 65 | A list of parsed arguments. 66 | """ 67 | parser = argparse.ArgumentParser(description="DeepLab-ResNet Network") 68 | parser.add_argument("--model", type=str, default=MODEL, 69 | help="available options : DeepLab/DRN") 70 | parser.add_argument("--dataset", type=str, default=DATASET, 71 | help="dataset to be used") 72 | parser.add_argument("--batch-size", type=int, default=BATCH_SIZE, 73 | help="Number of images sent to the network in one step.") 74 | parser.add_argument("--num-workers", type=int, default=NUM_WORKERS, 75 | help="number of workers for multithread dataloading.") 76 | parser.add_argument("--data-dir", type=str, default=DATA_DIRECTORY, 77 | help="Path to the directory containing the PASCAL VOC dataset.") 78 | parser.add_argument("--data-list", type=str, default=DATA_LIST_PATH, 79 | help="Path to the file listing the images in the dataset.") 80 | parser.add_argument("--split-id", type=str, default=SPLIT_ID, 81 | help="name of split pickle file") 82 | parser.add_argument("--input-size", type=str, default=INPUT_SIZE, 83 | help="Comma-separated string with height and width of images.") 84 | parser.add_argument("--ignore-label", type=float, default=IGNORE_LABEL, 85 | help="label value to ignored for loss calculation") 86 | parser.add_argument("--labeled-ratio", type=float, default=LABELED_RATIO, 87 | help="ratio of labeled samples/total samples") 88 | parser.add_argument("--learning-rate", type=float, default=LEARNING_RATE, 89 | help="Base learning rate for training with polynomial decay.") 90 | parser.add_argument("--momentum", type=float, default=MOMENTUM, 91 | help="Momentum component of the optimiser.") 92 | parser.add_argument("--num-classes", type=int, default=NUM_CLASSES, 93 | help="Number of classes to predict (including background).") 94 | parser.add_argument("--num-steps", type=int, default=NUM_STEPS, 95 | help="Number of iterations") 96 | parser.add_argument("--power", type=float, default=POWER, 97 | help="Decay parameter to compute the learning rate.") 98 | parser.add_argument("--random-mirror", action="store_true", 99 | help="Whether to randomly mirror the inputs during the training.") 100 | parser.add_argument("--random-scale", action="store_true", 101 | help="Whether to randomly scale the inputs during the training.") 102 | parser.add_argument("--random-seed", type=int, default=RANDOM_SEED, 103 | help="Random seed to have reproducible results.") 104 | parser.add_argument("--restore-from", type=str, default=RESTORE_FROM, 105 | help="Where restore model parameters from.") 106 | parser.add_argument("--save-pred-every", type=int, default=SAVE_PRED_EVERY, 107 | help="Save summaries and checkpoint every often.") 108 | parser.add_argument("--checkpoint-dir", type=str, default=CHECKPOINT_DIR, 109 | help="Where to save checkpoints of the 
model.") 110 | parser.add_argument("--weight-decay", type=float, default=WEIGHT_DECAY, 111 | help="Regularisation parameter for L2-loss.") 112 | parser.add_argument("--gpu", type=int, default=0, 113 | help="choose gpu device.") 114 | return parser.parse_args() 115 | 116 | args = get_arguments() 117 | 118 | def loss_calc(pred, label, gpu): 119 | label = Variable(label.long()).cuda(gpu) 120 | criterion = CrossEntropy2d(ignore_label=args.ignore_label).cuda(gpu) 121 | return criterion(pred, label) 122 | 123 | def lr_poly(base_lr, iter, max_iter, power): 124 | return base_lr*((1-float(iter)/max_iter)**(power)) 125 | 126 | def adjust_learning_rate(optimizer, i_iter): 127 | lr = lr_poly(args.learning_rate, i_iter, args.num_steps, args.power) 128 | optimizer.param_groups[0]['lr'] = lr 129 | if len(optimizer.param_groups) > 1 : 130 | optimizer.param_groups[1]['lr'] = lr * 10 131 | 132 | def main(): 133 | 134 | h, w = map(int, args.input_size.split(',')) 135 | input_size = (h, w) 136 | 137 | cudnn.enabled = True 138 | gpu = args.gpu 139 | 140 | # create network 141 | model = Res_Deeplab(num_classes= args.num_classes) 142 | model.cuda() 143 | 144 | # load pretrained parameters 145 | saved_state_dict = torch.load(args.restore_from) 146 | 147 | # only copy the params that exist in current model (caffe-like) 148 | new_params = model.state_dict().copy() 149 | for name, param in new_params.items(): 150 | if name in saved_state_dict and param.size() == saved_state_dict[name].size(): 151 | new_params[name].copy_(saved_state_dict[name]) 152 | model.load_state_dict(new_params) 153 | 154 | model.train() 155 | model.cuda(args.gpu) 156 | 157 | cudnn.benchmark = True 158 | 159 | if not os.path.exists(args.checkpoint_dir): 160 | os.makedirs(args.checkpoint_dir) 161 | 162 | if args.dataset == 'pascal_voc': 163 | train_dataset = VOCDataSet(args.data_dir, args.data_list, crop_size=input_size, 164 | scale=args.random_scale, mirror=args.random_mirror, mean=IMG_MEAN) 165 | elif args.dataset == 'pascal_context': 166 | input_transform = transform.Compose([transform.ToTensor(), 167 | transform.Normalize([.485, .456, .406], [.229, .224, .225])]) 168 | data_kwargs = {'transform': input_transform, 'base_size': 505, 'crop_size': 321} 169 | #train_dataset = get_segmentation_dataset('pcontext', split='train', mode='train', **data_kwargs) 170 | data_loader = get_loader('pascal_context') 171 | data_path = get_data_path('pascal_context') 172 | train_dataset = data_loader(data_path, split='train', mode='train', **data_kwargs) 173 | 174 | elif args.dataset == 'cityscapes': 175 | data_loader = get_loader('cityscapes') 176 | data_path = get_data_path('cityscapes') 177 | data_aug = Compose([RandomCrop_city((256, 512)), RandomHorizontallyFlip()]) 178 | train_dataset = data_loader( data_path, is_transform=True, augmentations=data_aug) 179 | 180 | train_dataset_size = len(train_dataset) 181 | print ('dataset size: ', train_dataset_size) 182 | 183 | if args.labeled_ratio is None: 184 | trainloader = data.DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size, 185 | num_workers=4, pin_memory=True) 186 | else: 187 | partial_size = int(args.labeled_ratio * train_dataset_size) 188 | 189 | if args.split_id is not None: 190 | train_ids = pickle.load(open(args.split_id, 'rb')) 191 | print('loading train ids from {}'.format(args.split_id)) 192 | else: 193 | train_ids = np.arange(train_dataset_size) 194 | np.random.shuffle(train_ids) 195 | 196 | pickle.dump(train_ids, open(os.path.join(args.checkpoint_dir, 'split.pkl'), 'wb')) 197 | 198 
| train_sampler = data.sampler.SubsetRandomSampler(train_ids[:partial_size]) 199 | 200 | trainloader = data.DataLoader(train_dataset, 201 | batch_size=args.batch_size, sampler=train_sampler, num_workers=4, pin_memory=True) 202 | 203 | trainloader_iter = iter(trainloader) 204 | 205 | # optimizer for segmentation network 206 | optimizer = optim.SGD(model.optim_parameters(args), 207 | lr=args.learning_rate, momentum=args.momentum,weight_decay=args.weight_decay) 208 | optimizer.zero_grad() 209 | 210 | # loss/ bilinear upsampling 211 | interp = nn.Upsample(size=(input_size[0], input_size[1]), mode='bilinear', align_corners=True) 212 | 213 | for i_iter in range(args.num_steps): 214 | 215 | loss_value = 0 216 | optimizer.zero_grad() 217 | adjust_learning_rate(optimizer, i_iter) 218 | 219 | try: 220 | batch_lab = next(trainloader_iter) 221 | except: 222 | trainloader_iter = iter(trainloader) 223 | batch_lab = next(trainloader_iter) 224 | 225 | images, labels, _, _, index = batch_lab 226 | images = Variable(images).cuda(args.gpu) 227 | 228 | pred = interp(model(images)) 229 | loss = loss_calc(pred, labels, args.gpu) 230 | 231 | loss.backward() 232 | loss_value += loss.item() 233 | 234 | optimizer.step() 235 | 236 | print('iter = {0:8d}/{1:8d}, loss_seg = {2:.3f}'.format(i_iter, args.num_steps, loss_value)) 237 | 238 | if i_iter >= args.num_steps-1: 239 | print ('save model ...') 240 | torch.save(model.state_dict(),osp.join(args.checkpoint_dir, 'VOC_'+str(args.num_steps)+'.pth')) 241 | break 242 | 243 | if i_iter % args.save_pred_every == 0 and i_iter!=0: 244 | print ('saving checkpoint ...') 245 | torch.save(model.state_dict(),osp.join(args.checkpoint_dir, 'VOC_'+str(i_iter)+'.pth')) 246 | 247 | end = timeit.default_timer() 248 | print (end-start,'seconds') 249 | 250 | if __name__ == '__main__': 251 | main() 252 | -------------------------------------------------------------------------------- /train_mlmt.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import math 4 | import numpy as np 5 | import random 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import torch.backends.cudnn as cudnn 11 | from torch.autograd import Variable 12 | from torch.utils.data import DataLoader 13 | from torch.utils.data.sampler import SubsetRandomSampler 14 | import torchvision 15 | import torchvision.transforms as transforms 16 | import torchvision.models as models 17 | 18 | import data.dataset_processing as data 19 | from data.dataset_processing import TransformTwice, GaussianBlur, update_ema_variables 20 | 21 | global_step = 0 22 | 23 | TRAIN_DATA = 'train' 24 | TEST_DATA = 'val' 25 | TRAIN_IMG_FILE = 'train_img.txt' 26 | TEST_IMG_FILE = 'val_img.txt' 27 | TRAIN_LABEL_FILE = 'train_label.txt' 28 | TEST_LABEL_FILE = 'val_label.txt' 29 | 30 | m = nn.Sigmoid() 31 | 32 | def get_arguments(): 33 | parser = argparse.ArgumentParser(description="MLMT Network Branch") 34 | parser.add_argument("--lr", type=float, default=3e-2, help="learning rate") 35 | parser.add_argument("--eta-min", type=float, default=1e-4, help="minimum learning rate for the scheduler") 36 | parser.add_argument("--weight-decay", type=float, default=1e-5, help="optimizer: weight decay") 37 | parser.add_argument("--workers", type=int, default=4, help="number of workers") 38 | parser.add_argument("--num-classes", type=int, default=21, help="number of classes, For eg 21 in VOC") 39 | 40 | parser.add_argument("--batch-size-lab", type=int, 
default=16, help="minibatch size of labeled training set") 41 | parser.add_argument("--batch-size-unlab", type=int, default=80, help="minibatch size of unlabeled training set") 42 | parser.add_argument("--batch-size-val", type=int, default=32, help="minibatch size of validation set") 43 | 44 | parser.add_argument("--num-epochs", type=int, default=100, help="number of epochs") 45 | parser.add_argument("--burn-in-epochs", type=int, default=10, help="number of burn-in epochs") 46 | parser.add_argument("--evaluation-epochs", type=int, default=5, help="evaluation epochs") 47 | 48 | parser.add_argument('--exp-name', type=str, default='default', help="experiment name") 49 | parser.add_argument('--cons-loss', type=str, default='cosine', help="consistency loss type: cosine") 50 | parser.add_argument('--data-dir', type=str, default='./data/voc_dataset/', help="dataset directory path") 51 | parser.add_argument('--pkl-file', type=str, default='./checkpoints/voc_semi_0_125/train_voc_split.pkl', help="indexes of files") 52 | 53 | parser.add_argument("--w-cons", type=float, default=1.0, help="weightage consistency loss term") 54 | parser.add_argument("--ema-decay", type=float, default=0.999, help="decay rate of exponential moving average") 55 | parser.add_argument("--labeled-ratio", type=float, default=0.125, help="percent of labeled samples") 56 | parser.add_argument('--verbose', action='store_true', help='verbose') 57 | 58 | return parser.parse_args() 59 | 60 | args = get_arguments() 61 | 62 | if args.verbose: 63 | from utils.visualize import progress_bar 64 | 65 | def main(): 66 | global global_step 67 | 68 | train_loader_lab, train_loader_unlab, valloader = create_data_loaders() 69 | print ('data loaders ready !!') 70 | 71 | def create_model(ema=False): 72 | model = models.resnet101(pretrained=True) 73 | model.fc = nn.Linear(2048, args.num_classes) 74 | model = torch.nn.DataParallel(model) 75 | model.cuda() 76 | cudnn.benchmark = True 77 | 78 | if ema: 79 | for param in model.parameters(): 80 | param.detach_() 81 | return model 82 | 83 | model = create_model() 84 | model_mt = create_model(ema=True) 85 | 86 | optimizer = torch.optim.SGD(model.parameters(), 87 | args.lr, 88 | momentum=0.9, 89 | weight_decay=args.weight_decay, 90 | nesterov=True) 91 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 92 | T_max=args.num_epochs, 93 | eta_min=args.eta_min) 94 | 95 | for epoch in range(args.num_epochs): 96 | print ('Epoch#: ', epoch) 97 | 98 | train(train_loader_lab, train_loader_unlab, model, model_mt, optimizer, epoch) 99 | scheduler.step() 100 | 101 | if args.evaluation_epochs and (epoch + 1) % args.evaluation_epochs == 0: 102 | print ("Evaluating the primary model:") 103 | validate(valloader, 'val', model, epoch + 1) 104 | print ("Evaluating the MT model:") 105 | validate(valloader, 'ema', model_mt, epoch + 1) 106 | 107 | def create_data_loaders(): 108 | channel_stats = dict(mean=[.485, .456, .406], 109 | std=[.229, .224, .225]) 110 | 111 | transform_train = transforms.Compose([ 112 | transforms.Resize(size=(320, 320), interpolation=2), 113 | transforms.RandomHorizontalFlip(), 114 | transforms.ToTensor(), 115 | transforms.Normalize(**channel_stats) 116 | ]) 117 | 118 | transform_aug = transforms.Compose([ 119 | transforms.Resize(size=(320, 320), interpolation=2), 120 | transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1), 121 | transforms.RandomHorizontalFlip(), 122 | GaussianBlur(), 123 | transforms.ToTensor(), 124 | transforms.Normalize(**channel_stats) 125 | 
]) 126 | 127 | transform_test = transforms.Compose([ 128 | transforms.Resize(size=(320, 320), interpolation=2), 129 | transforms.ToTensor(), 130 | transforms.Normalize(**channel_stats) 131 | ]) 132 | 133 | transform_lab = TransformTwice(transform_train, transform_train) 134 | transform_unlab = TransformTwice(transform_train, transform_aug) 135 | 136 | print ('loading data ...') 137 | dataset = data.DatasetProcessing( 138 | args.data_dir, TRAIN_DATA, TRAIN_IMG_FILE, TRAIN_LABEL_FILE, transform_lab, train=True) 139 | 140 | dataset_aug = data.DatasetProcessing( 141 | args.data_dir, TRAIN_DATA, TRAIN_IMG_FILE, TRAIN_LABEL_FILE, transform_unlab, train=True) 142 | 143 | labeled_idxs, unlabeled_idxs = data.split_idxs(args.pkl_file, args.labeled_ratio) 144 | print ('number of labeled samples: ', len(labeled_idxs)) 145 | print ('number of unlabeled samples: ', len(unlabeled_idxs)) 146 | 147 | sampler_lab = SubsetRandomSampler(labeled_idxs) 148 | sampler_unlab = SubsetRandomSampler(unlabeled_idxs) 149 | 150 | trainloader_lab = torch.utils.data.DataLoader(dataset, 151 | batch_size=args.batch_size_lab, 152 | sampler=sampler_lab, 153 | num_workers=args.workers, 154 | pin_memory=True) 155 | 156 | trainloader_unlab = torch.utils.data.DataLoader(dataset_aug, 157 | batch_size=args.batch_size_unlab, 158 | sampler=sampler_unlab, 159 | num_workers=args.workers, 160 | pin_memory=True) 161 | 162 | dataset_test = data.DatasetProcessing( 163 | args.data_dir, TEST_DATA, TEST_IMG_FILE, TEST_LABEL_FILE, transform_test, train=False) 164 | 165 | valloader = torch.utils.data.DataLoader( 166 | dataset_test, 167 | batch_size=args.batch_size_val, 168 | shuffle=False, 169 | num_workers=2 * args.workers, 170 | pin_memory=True, 171 | drop_last=False) 172 | 173 | return trainloader_lab, trainloader_unlab, valloader 174 | 175 | def cosine_loss(p_logits, q_logits): 176 | return torch.nn.CosineEmbeddingLoss()(q_logits, p_logits.detach(), torch.ones(p_logits.shape[0]).cuda()) 177 | 178 | def train(trainloader_lab, trainloader_unlab, model, model_mt, optimizer, epoch): 179 | global global_step 180 | 181 | loss_sum = 0.0 182 | class_loss_sum = 0.0 183 | cons_loss_sum = 0.0 184 | avg_acc_sum = 0.0 185 | avg_acc_sum_mt = 0.0 186 | 187 | class_criterion = nn.BCELoss().cuda() 188 | 189 | # switch to train mode 190 | model.train() 191 | model_mt.train() 192 | 193 | trainloader_unlab_iter = iter(trainloader_unlab) 194 | 195 | for batch_idx, ((inputs, _), target) in enumerate(trainloader_lab): 196 | 197 | #target = target.squeeze(2).float() 198 | inputs, target = inputs.cuda(), target.cuda() 199 | 200 | model_out = m(model(inputs)) 201 | model_mt_out = m(model_mt(inputs)) 202 | 203 | class_loss = class_criterion(model_out, target) 204 | 205 | try: 206 | batch_unlab = next(trainloader_unlab_iter) 207 | except: 208 | trainloader_unlab_iter = iter(trainloader_unlab) 209 | batch_unlab = next(trainloader_unlab_iter) 210 | 211 | (inputs_unlab, inputs_unlab_aug), _ = batch_unlab 212 | inputs_unlab, inputs_unlab_aug = inputs_unlab.cuda(), inputs_unlab_aug.cuda() 213 | 214 | model_unlab_out_aug = model(inputs_unlab_aug) 215 | with torch.no_grad(): 216 | model_mt_unlab_out = model_mt(inputs_unlab) 217 | 218 | cons_loss = cosine_loss(model_mt_unlab_out, model_unlab_out_aug) 219 | 220 | if epoch>args.burn_in_epochs: 221 | w_cons = min(args.w_cons, (epoch-args.burn_in_epochs)*2/args.num_epochs) 222 | else: 223 | w_cons = 0.0 224 | 225 | loss = class_loss + w_cons*cons_loss 226 | 227 | class_loss_sum += class_loss.item() 228 | cons_loss_sum += 
cons_loss.item() 229 | loss_sum += loss.item() 230 | 231 | avg_acc, acc_zeros, acc_ones, acc = accuracy(model_out, target) 232 | avg_acc_mt, acc_zeros_mt, acc_ones_mt, acc_mt = accuracy(model_mt_out, target) 233 | 234 | avg_acc_sum += avg_acc 235 | avg_acc_sum_mt += avg_acc_mt 236 | 237 | # compute gradient and do SGD step 238 | optimizer.zero_grad() 239 | loss.backward() 240 | optimizer.step() 241 | global_step += 1 242 | update_ema_variables(model, model_mt, args.ema_decay, global_step) 243 | 244 | if args.verbose: 245 | progress_bar(batch_idx, len(trainloader_lab), 'Loss: %.3f | Class Loss: %.3f | Cons Loss: %.3f | Avg Acc: %.3f | Avg Acc MT: %.3f ' 246 | % (loss_sum/(batch_idx+1), class_loss_sum/(batch_idx+1), cons_loss_sum/(batch_idx+1), avg_acc_sum/(batch_idx+1), avg_acc_sum_mt/(batch_idx+1))) 247 | if not args.verbose: 248 | print('Loss: ', loss_sum/(batch_idx+1), ' Class Loss: ', class_loss_sum/(batch_idx+1), ' Cons Loss: ', cons_loss_sum/(batch_idx+1), ' Avg Acc: ', avg_acc_sum/(batch_idx+1), ' Avg Acc MT: ', avg_acc_sum_mt/(batch_idx+1)) 249 | 250 | 251 | def validate(eval_loader, mode, model, epoch): 252 | 253 | avg_acc_sum = 0.0 254 | ones_acc_sum = 0.0 255 | zeros_acc_sum = 0.0 256 | 257 | if mode=='val': 258 | filename_raw = 'output_val_raw_' + str(epoch) + '.txt' 259 | filename_bin = 'output_val_bin_' + str(epoch) + '.txt' 260 | if mode == 'ema': 261 | filename_raw = 'output_ema_raw_' + str(epoch) + '.txt' 262 | filename_bin = 'output_ema_bin_' + str(epoch) + '.txt' 263 | 264 | mlmt_output_path = os.path.join('./mlmt_output', args.exp_name) 265 | 266 | if not os.path.exists(mlmt_output_path): 267 | os.makedirs(mlmt_output_path) 268 | 269 | f_raw = open(os.path.join(mlmt_output_path, filename_raw), 'a') 270 | f_bin = open(os.path.join(mlmt_output_path, filename_bin), 'a') 271 | 272 | model.eval() 273 | 274 | with torch.no_grad(): 275 | for batch_idx, (inputs, target) in enumerate(eval_loader): 276 | 277 | inputs, target = inputs.cuda(), target.cuda() 278 | 279 | # compute output 280 | output = m(model(inputs)) 281 | 282 | if epoch%1 == 0: 283 | output_raw = output.cpu().numpy() 284 | output_raw = np.roll(output_raw, 1) 285 | output_bin = (output_raw>0.5)*1 286 | np.savetxt(f_raw, output_raw, fmt='%f') 287 | np.savetxt(f_bin, output_bin, fmt='%d') 288 | 289 | # measure accuracy and record loss 290 | avg_acc, acc_zeros, acc_ones, acc = accuracy(output, target) 291 | 292 | ones_acc_sum += acc_ones 293 | zeros_acc_sum += acc_zeros 294 | avg_acc_sum += avg_acc 295 | if args.verbose: 296 | progress_bar(batch_idx, len(eval_loader), '| Avg Acc: %.3f | Ones Acc: %.3f | Zeros Acc: %.3f |' 297 | % (avg_acc_sum/(batch_idx+1), ones_acc_sum/(batch_idx+1), zeros_acc_sum/(batch_idx+1))) 298 | if not args.verbose: 299 | print(batch_idx, len(eval_loader), ' Avg Acc: ', avg_acc_sum/(batch_idx+1)) 300 | 301 | f_raw.close() 302 | f_bin.close() 303 | 304 | def accuracy(outputs, targets): 305 | thres = torch.ones(targets.size(0), args.num_classes)*0.5 306 | thres = thres.cuda() 307 | 308 | cond = torch.ge(outputs, thres) 309 | 310 | count_label_ones = 0 311 | count_label_zeros = 0 312 | correct_ones = 0 313 | correct_zeros = 0 314 | correct = 0 315 | total = 0 316 | 317 | for i in range(targets.size(0)): 318 | for j in range(args.num_classes): 319 | if targets[i][j]==0: 320 | count_label_zeros +=1 321 | if targets[i][j]==1: 322 | count_label_ones +=1 323 | 324 | targets = targets.type(torch.ByteTensor).cuda() 325 | 326 | for i in range(targets.size(0)): 327 | for j in range(args.num_classes): 328 | 
if targets[i][j]==cond[i][j]: 329 | correct +=1 330 | if targets[i][j] == 0: 331 | correct_zeros +=1 332 | elif targets[i][j] ==1: 333 | correct_ones +=1 334 | total += targets.size(0)*args.num_classes 335 | 336 | total_acc = (correct_zeros + correct_ones)*100.0/total 337 | avg_acc = (correct_ones/count_label_ones + correct_zeros/count_label_zeros)*100.0/2.0 338 | acc_zeros = (100.*correct_zeros/count_label_zeros) 339 | acc_ones = (100.*correct_ones/count_label_ones) 340 | 341 | return avg_acc, acc_zeros, acc_ones, total_acc 342 | 343 | if __name__ == '__main__': 344 | main() 345 | -------------------------------------------------------------------------------- /train_s4GAN.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | import random 5 | import timeit 6 | 7 | import cv2 8 | import numpy as np 9 | import pickle 10 | import scipy.misc 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.optim as optim 15 | import torch.nn.functional as F 16 | import torch.backends.cudnn as cudnn 17 | from torch.utils import data, model_zoo 18 | from torch.autograd import Variable 19 | import torchvision.transforms as transform 20 | 21 | from model.deeplabv2 import Res_Deeplab 22 | #from model.deeplabv3p import Res_Deeplab 23 | 24 | from model.discriminator import s4GAN_discriminator 25 | from utils.loss import CrossEntropy2d 26 | from data.voc_dataset import VOCDataSet, VOCGTDataSet 27 | from data import get_loader, get_data_path 28 | from data.augmentations import * 29 | 30 | start = timeit.default_timer() 31 | 32 | DATA_DIRECTORY = './data/voc_dataset/' 33 | DATA_LIST_PATH = './data/voc_list/train_aug.txt' 34 | CHECKPOINT_DIR = './checkpoints/voc_semi_0_125/' 35 | 36 | IMG_MEAN = np.array((104.00698793,116.66876762,122.67891434), dtype=np.float32) 37 | NUM_CLASSES = 21 # 21 for PASCAL-VOC / 60 for PASCAL-Context / 19 Cityscapes 38 | DATASET = 'pascal_voc' #pascal_voc or pascal_context 39 | 40 | SPLIT_ID = None 41 | 42 | MODEL = 'DeepLab' 43 | BATCH_SIZE = 8 44 | NUM_STEPS = 40000 45 | SAVE_PRED_EVERY = 5000 46 | 47 | INPUT_SIZE = '321,321' 48 | IGNORE_LABEL = 255 # 255 for PASCAL-VOC / -1 for PASCAL-Context / 250 for Cityscapes 49 | 50 | RESTORE_FROM = './pretrained_models/resnet101-5d3b4d8f.pth' 51 | 52 | LEARNING_RATE = 2.5e-4 53 | LEARNING_RATE_D = 1e-4 54 | POWER = 0.9 55 | WEIGHT_DECAY = 0.0005 56 | MOMENTUM = 0.9 57 | NUM_WORKERS = 4 58 | RANDOM_SEED = 1234 59 | 60 | LAMBDA_FM = 0.1 61 | LAMBDA_ST = 1.0 62 | THRESHOLD_ST = 0.6 # 0.6 for PASCAL-VOC/Context / 0.7 for Cityscapes 63 | 64 | LABELED_RATIO = None #0.02 # 1/8 labeled data by default 65 | 66 | def get_arguments(): 67 | """Parse all the arguments provided from the CLI. 68 | 69 | Returns: 70 | A list of parsed arguments. 
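        (Concretely, this is the argparse.Namespace returned by
        parser.parse_args(), with one attribute per option defined below.)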
71 | """ 72 | parser = argparse.ArgumentParser(description="DeepLab-ResNet Network") 73 | parser.add_argument("--model", type=str, default=MODEL, 74 | help="available options : DeepLab/DRN") 75 | parser.add_argument("--dataset", type=str, default=DATASET, 76 | help="dataset to be used") 77 | parser.add_argument("--batch-size", type=int, default=BATCH_SIZE, 78 | help="Number of images sent to the network in one step.") 79 | parser.add_argument("--num-workers", type=int, default=NUM_WORKERS, 80 | help="number of workers for multithread dataloading.") 81 | parser.add_argument("--data-dir", type=str, default=DATA_DIRECTORY, 82 | help="Path to the directory containing the PASCAL VOC dataset.") 83 | parser.add_argument("--data-list", type=str, default=DATA_LIST_PATH, 84 | help="Path to the file listing the images in the dataset.") 85 | parser.add_argument("--labeled-ratio", type=float, default=LABELED_RATIO, 86 | help="ratio of the labeled data to full dataset") 87 | parser.add_argument("--split-id", type=str, default=SPLIT_ID, 88 | help="split order id") 89 | parser.add_argument("--input-size", type=str, default=INPUT_SIZE, 90 | help="Comma-separated string with height and width of images.") 91 | parser.add_argument("--learning-rate", type=float, default=LEARNING_RATE, 92 | help="Base learning rate for training with polynomial decay.") 93 | parser.add_argument("--learning-rate-D", type=float, default=LEARNING_RATE_D, 94 | help="Base learning rate for discriminator.") 95 | parser.add_argument("--lambda-fm", type=float, default=LAMBDA_FM, 96 | help="lambda_fm for feature-matching loss.") 97 | parser.add_argument("--lambda-st", type=float, default=LAMBDA_ST, 98 | help="lambda_st for self-training.") 99 | parser.add_argument("--threshold-st", type=float, default=THRESHOLD_ST, 100 | help="threshold_st for the self-training threshold.") 101 | parser.add_argument("--momentum", type=float, default=MOMENTUM, 102 | help="Momentum component of the optimiser.") 103 | parser.add_argument("--ignore-label", type=float, default=IGNORE_LABEL, 104 | help="label value to ignored for loss calculation") 105 | parser.add_argument("--num-classes", type=int, default=NUM_CLASSES, 106 | help="Number of classes to predict (including background).") 107 | parser.add_argument("--num-steps", type=int, default=NUM_STEPS, 108 | help="Number of iterations.") 109 | parser.add_argument("--power", type=float, default=POWER, 110 | help="Decay parameter to compute the learning rate.") 111 | parser.add_argument("--random-mirror", action="store_true", 112 | help="Whether to randomly mirror the inputs during the training.") 113 | parser.add_argument("--random-scale", action="store_true", 114 | help="Whether to randomly scale the inputs during the training.") 115 | parser.add_argument("--random-seed", type=int, default=RANDOM_SEED, 116 | help="Random seed to have reproducible results.") 117 | parser.add_argument("--restore-from", type=str, default=RESTORE_FROM, 118 | help="Where restore model parameters from.") 119 | parser.add_argument("--restore-from-D", type=str, default=None, 120 | help="Where restore model parameters from.") 121 | parser.add_argument("--save-pred-every", type=int, default=SAVE_PRED_EVERY, 122 | help="Save summaries and checkpoint every often.") 123 | parser.add_argument("--checkpoint-dir", type=str, default=CHECKPOINT_DIR, 124 | help="Where to save checkpoints of the model.") 125 | parser.add_argument("--weight-decay", type=float, default=WEIGHT_DECAY, 126 | help="Regularisation parameter for L2-loss.") 127 | 
parser.add_argument("--gpu", type=int, default=0, 128 | help="choose gpu device.") 129 | return parser.parse_args() 130 | 131 | args = get_arguments() 132 | 133 | def loss_calc(pred, label, gpu): 134 | label = Variable(label.long()).cuda(gpu) 135 | criterion = CrossEntropy2d(ignore_label=args.ignore_label).cuda(gpu) # Ignore label ?? 136 | return criterion(pred, label) 137 | 138 | def lr_poly(base_lr, iter, max_iter, power): 139 | return base_lr*((1-float(iter)/max_iter)**(power)) 140 | 141 | def adjust_learning_rate(optimizer, i_iter): 142 | lr = lr_poly(args.learning_rate, i_iter, args.num_steps, args.power) 143 | optimizer.param_groups[0]['lr'] = lr 144 | if len(optimizer.param_groups) > 1 : 145 | optimizer.param_groups[1]['lr'] = lr * 10 146 | 147 | def adjust_learning_rate_D(optimizer, i_iter): 148 | lr = lr_poly(args.learning_rate_D, i_iter, args.num_steps, args.power) 149 | optimizer.param_groups[0]['lr'] = lr 150 | if len(optimizer.param_groups) > 1 : 151 | optimizer.param_groups[1]['lr'] = lr * 10 152 | 153 | def one_hot(label): 154 | label = label.numpy() 155 | one_hot = np.zeros((label.shape[0], args.num_classes, label.shape[1], label.shape[2]), dtype=label.dtype) 156 | for i in range(args.num_classes): 157 | one_hot[:,i,...] = (label==i) 158 | #handle ignore labels 159 | return torch.FloatTensor(one_hot) 160 | 161 | def compute_argmax_map(output): 162 | output = output.detach().cpu().numpy() 163 | output = output.transpose((1,2,0)) 164 | output = np.asarray(np.argmax(output, axis=2), dtype=np.int) 165 | output = torch.from_numpy(output).float() 166 | return output 167 | 168 | def find_good_maps(D_outs, pred_all): 169 | count = 0 170 | for i in range(D_outs.size(0)): 171 | if D_outs[i] > args.threshold_st: 172 | count +=1 173 | 174 | if count > 0: 175 | print ('Above ST-Threshold : ', count, '/', args.batch_size) 176 | pred_sel = torch.Tensor(count, pred_all.size(1), pred_all.size(2), pred_all.size(3)) 177 | label_sel = torch.Tensor(count, pred_sel.size(2), pred_sel.size(3)) 178 | num_sel = 0 179 | for j in range(D_outs.size(0)): 180 | if D_outs[j] > args.threshold_st: 181 | pred_sel[num_sel] = pred_all[j] 182 | label_sel[num_sel] = compute_argmax_map(pred_all[j]) 183 | num_sel +=1 184 | return pred_sel.cuda(), label_sel.cuda(), count 185 | else: 186 | return 0, 0, count 187 | 188 | criterion = nn.BCELoss() 189 | 190 | def main(): 191 | print (args) 192 | 193 | h, w = map(int, args.input_size.split(',')) 194 | input_size = (h, w) 195 | 196 | cudnn.enabled = True 197 | gpu = args.gpu 198 | 199 | # create network 200 | model = Res_Deeplab(num_classes=args.num_classes) 201 | 202 | # load pretrained parameters 203 | saved_state_dict = torch.load(args.restore_from) 204 | 205 | new_params = model.state_dict().copy() 206 | for name, param in new_params.items(): 207 | if name in saved_state_dict and param.size() == saved_state_dict[name].size(): 208 | new_params[name].copy_(saved_state_dict[name]) 209 | model.load_state_dict(new_params) 210 | 211 | model.train() 212 | model.cuda(args.gpu) 213 | 214 | model = torch.nn.DataParallel(model).cuda() 215 | cudnn.benchmark = True 216 | 217 | # init D 218 | model_D = s4GAN_discriminator(num_classes=args.num_classes, dataset=args.dataset) 219 | 220 | if args.restore_from_D is not None: 221 | model_D.load_state_dict(torch.load(args.restore_from_D)) 222 | 223 | model_D = torch.nn.DataParallel(model_D).cuda() 224 | cudnn.benchmark = True 225 | 226 | model_D.train() 227 | model_D.cuda(args.gpu) 228 | 229 | if not os.path.exists(args.checkpoint_dir): 
230 | os.makedirs(args.checkpoint_dir) 231 | 232 | if args.dataset == 'pascal_voc': 233 | train_dataset = VOCDataSet(args.data_dir, args.data_list, crop_size=input_size, 234 | scale=args.random_scale, mirror=args.random_mirror, mean=IMG_MEAN) 235 | #train_gt_dataset = VOCGTDataSet(args.data_dir, args.data_list, crop_size=input_size, 236 | #scale=args.random_scale, mirror=args.random_mirror, mean=IMG_MEAN) 237 | 238 | elif args.dataset == 'pascal_context': 239 | input_transform = transform.Compose([transform.ToTensor(), 240 | transform.Normalize([.406, .456, .485], [.229, .224, .225])]) 241 | data_kwargs = {'transform': input_transform, 'base_size': 505, 'crop_size': 321} 242 | #train_dataset = get_segmentation_dataset('pcontext', split='train', mode='train', **data_kwargs) 243 | data_loader = get_loader('pascal_context') 244 | data_path = get_data_path('pascal_context') 245 | train_dataset = data_loader(data_path, split='train', mode='train', **data_kwargs) 246 | #train_gt_dataset = data_loader(data_path, split='train', mode='train', **data_kwargs) 247 | 248 | elif args.dataset == 'cityscapes': 249 | data_loader = get_loader('cityscapes') 250 | data_path = get_data_path('cityscapes') 251 | data_aug = Compose([RandomCrop_city((256, 512)), RandomHorizontallyFlip()]) 252 | train_dataset = data_loader( data_path, is_transform=True, augmentations=data_aug) 253 | #train_gt_dataset = data_loader( data_path, is_transform=True, augmentations=data_aug) 254 | 255 | train_dataset_size = len(train_dataset) 256 | print ('dataset size: ', train_dataset_size) 257 | 258 | if args.labeled_ratio is None: 259 | trainloader = data.DataLoader(train_dataset, 260 | batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=True) 261 | 262 | trainloader_gt = data.DataLoader(train_dataset, 263 | batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=True) 264 | 265 | trainloader_remain = data.DataLoader(train_dataset, 266 | batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=True) 267 | trainloader_remain_iter = iter(trainloader_remain) 268 | 269 | else: 270 | partial_size = int(args.labeled_ratio * train_dataset_size) 271 | 272 | if args.split_id is not None: 273 | train_ids = pickle.load(open(args.split_id, 'rb')) 274 | print('loading train ids from {}'.format(args.split_id)) 275 | else: 276 | train_ids = np.arange(train_dataset_size) 277 | np.random.shuffle(train_ids) 278 | 279 | pickle.dump(train_ids, open(os.path.join(args.checkpoint_dir, 'train_voc_split.pkl'), 'wb')) 280 | 281 | train_sampler = data.sampler.SubsetRandomSampler(train_ids[:partial_size]) 282 | train_remain_sampler = data.sampler.SubsetRandomSampler(train_ids[partial_size:]) 283 | train_gt_sampler = data.sampler.SubsetRandomSampler(train_ids[:partial_size]) 284 | 285 | trainloader = data.DataLoader(train_dataset, 286 | batch_size=args.batch_size, sampler=train_sampler, num_workers=4, pin_memory=True) 287 | trainloader_remain = data.DataLoader(train_dataset, 288 | batch_size=args.batch_size, sampler=train_remain_sampler, num_workers=4, pin_memory=True) 289 | trainloader_gt = data.DataLoader(train_dataset, 290 | batch_size=args.batch_size, sampler=train_gt_sampler, num_workers=4, pin_memory=True) 291 | 292 | trainloader_remain_iter = iter(trainloader_remain) 293 | 294 | trainloader_iter = iter(trainloader) 295 | trainloader_gt_iter = iter(trainloader_gt) 296 | 297 | # optimizer for segmentation network 298 | optimizer = optim.SGD(model.module.optim_parameters(args), 299 | lr=args.learning_rate, 
momentum=args.momentum,weight_decay=args.weight_decay) 300 | optimizer.zero_grad() 301 | 302 | # optimizer for discriminator network 303 | optimizer_D = optim.Adam(model_D.parameters(), lr=args.learning_rate_D, betas=(0.9,0.99)) 304 | optimizer_D.zero_grad() 305 | 306 | interp = nn.Upsample(size=(input_size[0], input_size[1]), mode='bilinear', align_corners=True) 307 | 308 | # labels for adversarial training 309 | pred_label = 0 310 | gt_label = 1 311 | 312 | y_real_, y_fake_ = Variable(torch.ones(args.batch_size, 1).cuda()), Variable(torch.zeros(args.batch_size, 1).cuda()) 313 | 314 | 315 | for i_iter in range(args.num_steps): 316 | 317 | loss_ce_value = 0 318 | loss_D_value = 0 319 | loss_fm_value = 0 320 | loss_S_value = 0 321 | 322 | optimizer.zero_grad() 323 | adjust_learning_rate(optimizer, i_iter) 324 | optimizer_D.zero_grad() 325 | adjust_learning_rate_D(optimizer_D, i_iter) 326 | 327 | # train Segmentation Network 328 | # don't accumulate grads in D 329 | for param in model_D.parameters(): 330 | param.requires_grad = False 331 | 332 | # training loss for labeled data only 333 | try: 334 | batch = next(trainloader_iter) 335 | except: 336 | trainloader_iter = iter(trainloader) 337 | batch = next(trainloader_iter) 338 | 339 | images, labels, _, _, _ = batch 340 | images = Variable(images).cuda(args.gpu) 341 | pred = interp(model(images)) 342 | loss_ce = loss_calc(pred, labels, args.gpu) # Cross entropy loss for labeled data 343 | 344 | #training loss for remaining unlabeled data 345 | try: 346 | batch_remain = next(trainloader_remain_iter) 347 | except: 348 | trainloader_remain_iter = iter(trainloader_remain) 349 | batch_remain = next(trainloader_remain_iter) 350 | 351 | images_remain, _, _, _, _ = batch_remain 352 | images_remain = Variable(images_remain).cuda(args.gpu) 353 | pred_remain = interp(model(images_remain)) 354 | 355 | # concatenate the prediction with the input images 356 | images_remain = (images_remain-torch.min(images_remain))/(torch.max(images_remain)- torch.min(images_remain)) 357 | #print (pred_remain.size(), images_remain.size()) 358 | pred_cat = torch.cat((F.softmax(pred_remain, dim=1), images_remain), dim=1) 359 | 360 | D_out_z, D_out_y_pred = model_D(pred_cat) # predicts the D ouput 0-1 and feature map for FM-loss 361 | 362 | # find predicted segmentation maps above threshold 363 | pred_sel, labels_sel, count = find_good_maps(D_out_z, pred_remain) 364 | 365 | # training loss on above threshold segmentation predictions (Cross Entropy Loss) 366 | if count > 0 and i_iter > 0: 367 | loss_st = loss_calc(pred_sel, labels_sel, args.gpu) 368 | else: 369 | loss_st = 0.0 370 | 371 | # Concatenates the input images and ground-truth maps for the Districrimator 'Real' input 372 | try: 373 | batch_gt = next(trainloader_gt_iter) 374 | except: 375 | trainloader_gt_iter = iter(trainloader_gt) 376 | batch_gt = next(trainloader_gt_iter) 377 | 378 | images_gt, labels_gt, _, _, _ = batch_gt 379 | # Converts grounth truth segmentation into 'num_classes' segmentation maps. 
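        # Shape sketch for the discriminator's 'real' input, assuming the
        # default VOC settings (batch_size=8, num_classes=21, input 321x321):
        #   one_hot(labels_gt)             -> (8, 21, 321, 321)
        #   images_gt (min-max normalised) -> (8,  3, 321, 321)
        #   torch.cat(..., dim=1)          -> (8, 24, 321, 321), fed to model_D
        # Pixels carrying the ignore label (255) match no class index, so their
        # one-hot vector stays all zeros (cf. the '#handle ignore labels' note
        # in one_hot above).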
380 | D_gt_v = Variable(one_hot(labels_gt)).cuda(args.gpu) 381 | 382 | images_gt = images_gt.cuda() 383 | images_gt = (images_gt - torch.min(images_gt))/(torch.max(images)-torch.min(images)) 384 | 385 | D_gt_v_cat = torch.cat((D_gt_v, images_gt), dim=1) 386 | D_out_z_gt , D_out_y_gt = model_D(D_gt_v_cat) 387 | 388 | # L1 loss for Feature Matching Loss 389 | loss_fm = torch.mean(torch.abs(torch.mean(D_out_y_gt, 0) - torch.mean(D_out_y_pred, 0))) 390 | 391 | if count > 0 and i_iter > 0: # if any good predictions found for self-training loss 392 | loss_S = loss_ce + args.lambda_fm*loss_fm + args.lambda_st*loss_st 393 | else: 394 | loss_S = loss_ce + args.lambda_fm*loss_fm 395 | 396 | loss_S.backward() 397 | loss_fm_value+= args.lambda_fm*loss_fm 398 | 399 | loss_ce_value += loss_ce.item() 400 | loss_S_value += loss_S.item() 401 | 402 | # train D 403 | for param in model_D.parameters(): 404 | param.requires_grad = True 405 | 406 | # train with pred 407 | pred_cat = pred_cat.detach() # detach does not allow the graddients to back propagate. 408 | 409 | D_out_z, _ = model_D(pred_cat) 410 | y_fake_ = Variable(torch.zeros(D_out_z.size(0), 1).cuda()) 411 | loss_D_fake = criterion(D_out_z, y_fake_) 412 | 413 | # train with gt 414 | D_out_z_gt , _ = model_D(D_gt_v_cat) 415 | y_real_ = Variable(torch.ones(D_out_z_gt.size(0), 1).cuda()) 416 | loss_D_real = criterion(D_out_z_gt, y_real_) 417 | 418 | loss_D = (loss_D_fake + loss_D_real)/2.0 419 | loss_D.backward() 420 | loss_D_value += loss_D.item() 421 | 422 | optimizer.step() 423 | optimizer_D.step() 424 | 425 | print('iter = {0:8d}/{1:8d}, loss_ce = {2:.3f}, loss_fm = {3:.3f}, loss_S = {4:.3f}, loss_D = {5:.3f}'.format(i_iter, args.num_steps, loss_ce_value, loss_fm_value, loss_S_value, loss_D_value)) 426 | 427 | if i_iter >= args.num_steps-1: 428 | print ('save model ...') 429 | torch.save(model.state_dict(),os.path.join(args.checkpoint_dir, 'VOC_'+str(args.num_steps)+'.pth')) 430 | torch.save(model_D.state_dict(),os.path.join(args.checkpoint_dir, 'VOC_'+str(args.num_steps)+'_D.pth')) 431 | break 432 | 433 | if i_iter % args.save_pred_every == 0 and i_iter!=0: 434 | print ('saving checkpoint ...') 435 | torch.save(model.state_dict(),os.path.join(args.checkpoint_dir, 'VOC_'+str(i_iter)+'.pth')) 436 | torch.save(model_D.state_dict(),os.path.join(args.checkpoint_dir, 'VOC_'+str(i_iter)+'_D.pth')) 437 | 438 | end = timeit.default_timer() 439 | print (end-start,'seconds') 440 | 441 | if __name__ == '__main__': 442 | main() 443 | -------------------------------------------------------------------------------- /train_s4GAN_wrn38.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | import random 5 | import timeit 6 | 7 | import cv2 8 | import numpy as np 9 | import pickle 10 | import scipy.misc 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.optim as optim 15 | import torch.nn.functional as F 16 | import torch.backends.cudnn as cudnn 17 | from torch.utils import data, model_zoo 18 | from torch.autograd import Variable 19 | import torchvision.transforms as transform 20 | from torch.optim.lr_scheduler import CosineAnnealingLR 21 | 22 | #from model.deeplabv2 import Res_Deeplab 23 | from model.deeplabv3 import DeepV3PlusW38 24 | 25 | from model.discriminator import s4GAN_discriminator 26 | from utils.loss import CrossEntropy2d 27 | from data.voc_dataset import VOCDataSet, VOCGTDataSet 28 | from data.ade20k_dataset import ADE20K 29 | from data import 
get_loader, get_data_path 30 | from data.augmentations import * 31 | 32 | from torch.utils.tensorboard import SummaryWriter 33 | from utils.misc import AverageMeter, accuracy 34 | from utils.metric import get_iou 35 | 36 | start = timeit.default_timer() 37 | 38 | DATA_DIRECTORY = './data/voc_dataset/' 39 | DATA_LIST_PATH = './data/voc_list/train_aug.txt' 40 | CHECKPOINT_DIR = './checkpoints/voc_semi_0_125/' 41 | 42 | IMG_MEAN = np.array((104.00698793,116.66876762,122.67891434), dtype=np.float32) 43 | NUM_CLASSES = 21 # 21 for PASCAL-VOC / 60 for PASCAL-Context / 19 Cityscapes 44 | DATASET = 'pascal_voc' #pascal_voc or pascal_context 45 | 46 | SPLIT_ID = None 47 | 48 | MODEL = 'DeepLab' 49 | BATCH_SIZE = 8 50 | NUM_STEPS = 40000 51 | SAVE_PRED_EVERY = 1000 52 | 53 | INPUT_SIZE = '321,321' 54 | IGNORE_LABEL = 255 # 255 for PASCAL-VOC / -1 for PASCAL-Context / 250 for Cityscapes 55 | 56 | RESTORE_FROM = './pretrained_models/resnet101-5d3b4d8f.pth' 57 | 58 | LEARNING_RATE = 2.5e-4 59 | LEARNING_RATE_D = 1e-4 60 | POWER = 0.9 61 | WEIGHT_DECAY = 0.0005 62 | MOMENTUM = 0.9 63 | NUM_WORKERS = 8 64 | RANDOM_SEED = 1234 65 | 66 | LAMBDA_FM = 0.1 67 | LAMBDA_ST = 1.0 68 | THRESHOLD_ST = 0.6 # 0.6 for PASCAL-VOC/Context / 0.7 for Cityscapes 69 | 70 | LABELED_RATIO = None #0.02 # 1/8 labeled data by default 71 | 72 | def get_arguments(): 73 | """Parse all the arguments provided from the CLI. 74 | 75 | Returns: 76 | A list of parsed arguments. 77 | """ 78 | parser = argparse.ArgumentParser(description="DeepLab-ResNet Network") 79 | parser.add_argument("--model", type=str, default=MODEL, 80 | help="available options : DeepLab/DRN") 81 | parser.add_argument("--dataset", type=str, default=DATASET, 82 | help="dataset to be used") 83 | parser.add_argument("--batch-size", type=int, default=BATCH_SIZE, 84 | help="Number of images sent to the network in one step.") 85 | parser.add_argument("--batch-size-unlab", type=int, default=BATCH_SIZE, 86 | help="Number of images sent to the network in one step.") 87 | parser.add_argument("--num-workers", type=int, default=NUM_WORKERS, 88 | help="number of workers for multithread dataloading.") 89 | parser.add_argument("--data-dir", type=str, default=DATA_DIRECTORY, 90 | help="Path to the directory containing the PASCAL VOC dataset.") 91 | parser.add_argument("--data-list", type=str, default=DATA_LIST_PATH, 92 | help="Path to the file listing the images in the dataset.") 93 | parser.add_argument("--labeled-ratio", type=float, default=LABELED_RATIO, 94 | help="ratio of the labeled data to full dataset") 95 | parser.add_argument("--split-id", type=str, default=SPLIT_ID, 96 | help="split order id") 97 | parser.add_argument("--input-size", type=str, default=INPUT_SIZE, 98 | help="Comma-separated string with height and width of images.") 99 | parser.add_argument("--learning-rate", type=float, default=LEARNING_RATE, 100 | help="Base learning rate for training with polynomial decay.") 101 | parser.add_argument("--eta-min-factor", type=float, default=10, 102 | help="Base learning rate for training with polynomial decay.") 103 | parser.add_argument("--learning-rate-D", type=float, default=LEARNING_RATE_D, 104 | help="Base learning rate for discriminator.") 105 | parser.add_argument("--lambda-fm", type=float, default=LAMBDA_FM, 106 | help="lambda_fm for feature-matching loss.") 107 | parser.add_argument("--lambda-st", type=float, default=LAMBDA_ST, 108 | help="lambda_st for self-training.") 109 | parser.add_argument("--threshold-st", type=float, default=THRESHOLD_ST, 110 | 
help="threshold_st for the self-training threshold.") 111 | parser.add_argument("--momentum", type=float, default=MOMENTUM, 112 | help="Momentum component of the optimiser.") 113 | parser.add_argument("--ignore-label", type=float, default=IGNORE_LABEL, 114 | help="label value to ignored for loss calculation") 115 | parser.add_argument("--num-classes", type=int, default=NUM_CLASSES, 116 | help="Number of classes to predict (including background).") 117 | parser.add_argument("--num-steps", type=int, default=NUM_STEPS, 118 | help="Number of iterations.") 119 | parser.add_argument("--power", type=float, default=POWER, 120 | help="Decay parameter to compute the learning rate.") 121 | parser.add_argument("--random-mirror", action="store_true", 122 | help="Whether to randomly mirror the inputs during the training.") 123 | parser.add_argument("--random-scale", action="store_true", 124 | help="Whether to randomly scale the inputs during the training.") 125 | parser.add_argument("--random-seed", type=int, default=RANDOM_SEED, 126 | help="Random seed to have reproducible results.") 127 | parser.add_argument("--restore-from", type=str, default=RESTORE_FROM, 128 | help="Where restore model parameters from.") 129 | parser.add_argument("--restore-from-D", type=str, default=None, 130 | help="Where restore model parameters from.") 131 | parser.add_argument("--save-pred-every", type=int, default=SAVE_PRED_EVERY, 132 | help="Save summaries and checkpoint every often.") 133 | parser.add_argument("--checkpoint-dir", type=str, default=CHECKPOINT_DIR, 134 | help="Where to save checkpoints of the model.") 135 | parser.add_argument("--weight-decay", type=float, default=WEIGHT_DECAY, 136 | help="Regularisation parameter for L2-loss.") 137 | parser.add_argument("--gpu", type=int, default=0, 138 | help="choose gpu device.") 139 | parser.add_argument('--out', default='result', 140 | help='directory to output the result') 141 | return parser.parse_args() 142 | 143 | args = get_arguments() 144 | 145 | def loss_calc(pred, label, gpu): 146 | label = Variable(label.long()).cuda(gpu) 147 | criterion = CrossEntropy2d(ignore_label=args.ignore_label).cuda(gpu) # Ignore label ?? 148 | return criterion(pred, label) 149 | 150 | def lr_poly(base_lr, iter, max_iter, power): 151 | return base_lr*((1-float(iter)/max_iter)**(power)) 152 | 153 | def adjust_learning_rate(optimizer, i_iter): 154 | lr = lr_poly(args.learning_rate, i_iter, args.num_steps, args.power) 155 | optimizer.param_groups[0]['lr'] = lr 156 | if len(optimizer.param_groups) > 1 : 157 | optimizer.param_groups[1]['lr'] = lr * 10 158 | 159 | def adjust_learning_rate_D(optimizer, i_iter): 160 | lr = lr_poly(args.learning_rate_D, i_iter, args.num_steps, args.power) 161 | optimizer.param_groups[0]['lr'] = lr 162 | if len(optimizer.param_groups) > 1 : 163 | optimizer.param_groups[1]['lr'] = lr * 10 164 | 165 | def adjust_threshold_st(i_iter): 166 | return 0.5 + (args.num_steps-i_iter)*0.2/args.num_steps 167 | 168 | 169 | def one_hot(label): 170 | label = label.numpy() 171 | one_hot = np.zeros((label.shape[0], args.num_classes, label.shape[1], label.shape[2]), dtype=label.dtype) 172 | for i in range(args.num_classes): 173 | one_hot[:,i,...] 
= (label==i) 174 | #handle ignore labels 175 | return torch.FloatTensor(one_hot) 176 | 177 | def compute_argmax_map(output): 178 | output = output.detach().cpu().numpy() 179 | output = output.transpose((1,2,0)) 180 | output = np.asarray(np.argmax(output, axis=2), dtype=np.int) 181 | output = torch.from_numpy(output).float() 182 | return output 183 | 184 | def find_good_maps(D_outs, pred_all): 185 | count = 0 186 | for i in range(D_outs.size(0)): 187 | if D_outs[i] > args.threshold_st: 188 | count +=1 189 | 190 | if count > 0: 191 | #print ('Above ST-Threshold : ', count, '/', args.batch_size) 192 | pred_sel = torch.Tensor(count, pred_all.size(1), pred_all.size(2), pred_all.size(3)) 193 | label_sel = torch.Tensor(count, pred_sel.size(2), pred_sel.size(3)) 194 | num_sel = 0 195 | for j in range(D_outs.size(0)): 196 | if D_outs[j] > args.threshold_st: 197 | pred_sel[num_sel] = pred_all[j] 198 | label_sel[num_sel] = compute_argmax_map(pred_all[j]) 199 | num_sel +=1 200 | return pred_sel.cuda(), label_sel.cuda(), count 201 | else: 202 | return 0, 0, count 203 | 204 | criterion = nn.BCELoss() 205 | 206 | def main(): 207 | print (args) 208 | 209 | os.makedirs(args.out, exist_ok=True) 210 | args.writer = SummaryWriter(args.out) 211 | 212 | h, w = map(int, args.input_size.split(',')) 213 | input_size = (h, w) 214 | 215 | cudnn.enabled = True 216 | gpu = args.gpu 217 | 218 | # create network 219 | #model = Res_Deeplab(num_classes=args.num_classes) 220 | model = DeepV3PlusW38(num_classes=args.num_classes) 221 | 222 | # load pretrained parameters 223 | saved_state_dict = torch.load(args.restore_from) 224 | 225 | new_params = model.state_dict().copy() 226 | for name, param in new_params.items(): 227 | if name in saved_state_dict and param.size() == saved_state_dict[name].size(): 228 | new_params[name].copy_(saved_state_dict[name]) 229 | model.load_state_dict(new_params) 230 | 231 | model.train() 232 | model.cuda(args.gpu) 233 | 234 | model = torch.nn.DataParallel(model).cuda() 235 | cudnn.benchmark = True 236 | 237 | # init D 238 | model_D = s4GAN_discriminator(num_classes=args.num_classes, dataset=args.dataset) 239 | 240 | if args.restore_from_D is not None: 241 | model_D.load_state_dict(torch.load(args.restore_from_D)) 242 | 243 | model_D = torch.nn.DataParallel(model_D).cuda() 244 | cudnn.benchmark = True 245 | 246 | model_D.train() 247 | model_D.cuda(args.gpu) 248 | 249 | if not os.path.exists(args.checkpoint_dir): 250 | os.makedirs(args.checkpoint_dir) 251 | 252 | if args.dataset == 'pascal_voc': 253 | train_dataset = VOCDataSet(args.data_dir, args.data_list, crop_size=input_size, 254 | scale=args.random_scale, mirror=args.random_mirror, mean=IMG_MEAN) 255 | #train_gt_dataset = VOCGTDataSet(args.data_dir, args.data_list, crop_size=input_size, 256 | #scale=args.random_scale, mirror=args.random_mirror, mean=IMG_MEAN) 257 | 258 | elif args.dataset == 'pascal_context': 259 | input_transform = transform.Compose([transform.ToTensor(), 260 | transform.Normalize([.406, .456, .485], [.229, .224, .225])]) 261 | data_kwargs = {'transform': input_transform, 'base_size': 505, 'crop_size': 321} 262 | #train_dataset = get_segmentation_dataset('pcontext', split='train', mode='train', **data_kwargs) 263 | data_loader = get_loader('pascal_context') 264 | data_path = get_data_path('pascal_context') 265 | train_dataset = data_loader(data_path, split='train', mode='train', **data_kwargs) 266 | #train_gt_dataset = data_loader(data_path, split='train', mode='train', **data_kwargs) 267 | 268 | elif args.dataset == 
'cityscapes': 269 | data_loader = get_loader('cityscapes') 270 | data_path = get_data_path('cityscapes') 271 | data_aug = Compose([RandomCrop_city((input_size[0], input_size[1])), RandomHorizontallyFlip()]) 272 | train_dataset = data_loader( data_path, is_transform=True, img_size=(input_size[0], input_size[1]), augmentations=data_aug) 273 | #train_gt_dataset = data_loader( data_path, is_transform=True, augmentations=data_aug) 274 | 275 | elif args.dataset == 'ade20k': 276 | train_dataset = ADE20K(mode='train', crop_size=input_size) 277 | 278 | train_dataset_size = len(train_dataset) 279 | print ('dataset size: ', train_dataset_size) 280 | 281 | if args.labeled_ratio is None: 282 | trainloader = data.DataLoader(train_dataset, 283 | batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=True, drop_last=True) 284 | 285 | trainloader_gt = data.DataLoader(train_dataset, 286 | batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=True, drop_last=True) 287 | 288 | trainloader_remain = data.DataLoader(train_dataset, 289 | batch_size=args.batch_size_unlab, shuffle=True, num_workers=4, pin_memory=True, drop_last=True) 290 | trainloader_remain_iter = iter(trainloader_remain) 291 | 292 | else: 293 | partial_size = int(args.labeled_ratio * train_dataset_size) 294 | print('labeled data: ', partial_size) 295 | print('unlabeled data: ', train_dataset_size - partial_size) 296 | 297 | if args.split_id is not None: 298 | train_ids = pickle.load(open(args.split_id, 'rb')) 299 | print('loading train ids from {}'.format(args.split_id)) 300 | else: 301 | train_ids = np.arange(train_dataset_size) 302 | np.random.shuffle(train_ids) 303 | 304 | pickle.dump(train_ids, open(os.path.join(args.checkpoint_dir, 'train_voc_split.pkl'), 'wb')) 305 | 306 | train_sampler = data.sampler.SubsetRandomSampler(train_ids[:partial_size]) 307 | train_remain_sampler = data.sampler.SubsetRandomSampler(train_ids[partial_size:]) 308 | train_gt_sampler = data.sampler.SubsetRandomSampler(train_ids[:partial_size]) 309 | 310 | trainloader = data.DataLoader(train_dataset, 311 | batch_size=args.batch_size, sampler=train_sampler, num_workers=4, pin_memory=True, drop_last=True) 312 | trainloader_remain = data.DataLoader(train_dataset, 313 | batch_size=args.batch_size_unlab, sampler=train_remain_sampler, num_workers=4, pin_memory=True, drop_last=True) 314 | trainloader_gt = data.DataLoader(train_dataset, 315 | batch_size=args.batch_size, sampler=train_gt_sampler, num_workers=4, pin_memory=True, drop_last=True) 316 | 317 | trainloader_remain_iter = iter(trainloader_remain) 318 | 319 | print('train dataloader created!') 320 | 321 | trainloader_iter = iter(trainloader) 322 | trainloader_gt_iter = iter(trainloader_gt) 323 | 324 | if args.dataset == 'pascal_voc': 325 | valloader = data.DataLoader(VOCDataSet(args.data_dir, args.data_list, crop_size=(505, 505), mean=IMG_MEAN, scale=False, mirror=False), 326 | batch_size=1, shuffle=False, pin_memory=True) 327 | interp_val = nn.Upsample(size=(505, 505), mode='bilinear', align_corners=True) 328 | elif args.dataset == 'cityscapes': 329 | val_dataset = data_loader( data_path, img_size=(512, 1024), is_transform=True, split='val') 330 | valloader = data.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, pin_memory=True) 331 | interp_val = nn.Upsample(size=(512, 1024), mode='bilinear', align_corners=True) 332 | elif args.dataset == 'ade20k': 333 | val_dataset = ADE20K(mode='val', crop_size=(505, 505)) 334 | valloader = data.DataLoader(val_dataset, 335 | 
batch_size=args.batch_size, 336 | shuffle=False, 337 | num_workers=4, 338 | pin_memory=True, 339 | drop_last=True) 340 | interp_val = nn.Upsample(size=(505, 505), mode='bilinear', align_corners=True) 341 | print('val dataloader created!') 342 | 343 | # optimizer for segmentation network 344 | optimizer = optim.SGD(model.module.optim_parameters(args), 345 | lr=args.learning_rate, momentum=args.momentum,weight_decay=args.weight_decay) 346 | scheduler = CosineAnnealingLR(optimizer, T_max=args.num_steps, eta_min=args.learning_rate/args.eta_min_factor) 347 | optimizer.zero_grad() 348 | 349 | # optimizer for discriminator network 350 | optimizer_D = optim.Adam(model_D.parameters(), lr=args.learning_rate_D, betas=(0.9,0.99)) 351 | scheduler_D = CosineAnnealingLR(optimizer_D, T_max=args.num_steps, eta_min=args.learning_rate_D/args.eta_min_factor) 352 | optimizer_D.zero_grad() 353 | 354 | interp = nn.Upsample(size=(input_size[0], input_size[1]), mode='bilinear', align_corners=True) 355 | 356 | # labels for adversarial training 357 | pred_label = 0 358 | gt_label = 1 359 | 360 | y_real_, y_fake_ = Variable(torch.ones(args.batch_size, 1).cuda()), Variable(torch.zeros(args.batch_size, 1).cuda()) 361 | 362 | losses_ce = AverageMeter() 363 | losses_st = AverageMeter() 364 | losses_S = AverageMeter() 365 | losses_D = AverageMeter() 366 | losses_fm = AverageMeter() 367 | counts = AverageMeter() 368 | 369 | for i_iter in range(args.num_steps): 370 | 371 | model.train() 372 | 373 | loss_ce_value = 0 374 | loss_D_value = 0 375 | loss_fm_value = 0 376 | loss_S_value = 0 377 | 378 | #args.threshold_st = adjust_threshold_st(i_iter) 379 | 380 | optimizer.zero_grad() 381 | adjust_learning_rate(optimizer, i_iter) 382 | optimizer_D.zero_grad() 383 | adjust_learning_rate_D(optimizer_D, i_iter) 384 | 385 | # train Segmentation Network 386 | # don't accumulate grads in D 387 | for param in model_D.parameters(): 388 | param.requires_grad = False 389 | 390 | # training loss for labeled data only 391 | try: 392 | batch = next(trainloader_iter) 393 | except: 394 | trainloader_iter = iter(trainloader) 395 | batch = next(trainloader_iter) 396 | 397 | images, labels, _, _, _ = batch 398 | images = images.cuda() 399 | pred = interp(model(images)) 400 | loss_ce = loss_calc(pred, labels, args.gpu) # Cross entropy loss for labeled data 401 | 402 | #training loss for remaining unlabeled data 403 | try: 404 | batch_remain = next(trainloader_remain_iter) 405 | except: 406 | trainloader_remain_iter = iter(trainloader_remain) 407 | batch_remain = next(trainloader_remain_iter) 408 | 409 | images_remain, _, _, _, _ = batch_remain 410 | images_remain = Variable(images_remain).cuda(args.gpu) 411 | pred_remain = interp(model(images_remain)) 412 | 413 | # concatenate the prediction with the input images 414 | images_remain = (images_remain-torch.min(images_remain))/(torch.max(images_remain)- torch.min(images_remain)) 415 | #print (pred_remain.size(), images_remain.size()) 416 | pred_cat = torch.cat((F.softmax(pred_remain, dim=1), images_remain), dim=1) 417 | 418 | D_out_z, D_out_y_pred = model_D(pred_cat) # predicts the D ouput 0-1 and feature map for FM-loss 419 | 420 | # find predicted segmentation maps above threshold 421 | pred_sel, labels_sel, count = find_good_maps(D_out_z, pred_remain) 422 | 423 | # training loss on above threshold segmentation predictions (Cross Entropy Loss) 424 | if count > 0 and i_iter > 0: 425 | loss_st = loss_calc(pred_sel, labels_sel, args.gpu) 426 | losses_st.update(loss_st.item()) 427 | else: 428 | loss_st 
= 0.0 429 | 430 | # Concatenates the input images and ground-truth maps for the Discriminator 'Real' input 431 | try: 432 | batch_gt = next(trainloader_gt_iter) 433 | except: 434 | trainloader_gt_iter = iter(trainloader_gt) 435 | batch_gt = next(trainloader_gt_iter) 436 | 437 | images_gt, labels_gt, _, _, _ = batch_gt 438 | # Converts ground-truth segmentation into 'num_classes' segmentation maps. 439 | D_gt_v = Variable(one_hot(labels_gt)).cuda(args.gpu) 440 | 441 | images_gt = images_gt.cuda() 442 | images_gt = (images_gt - torch.min(images_gt))/(torch.max(images_gt) - torch.min(images_gt)) 443 | 444 | D_gt_v_cat = torch.cat((D_gt_v, images_gt), dim=1) 445 | D_out_z_gt , D_out_y_gt = model_D(D_gt_v_cat) 446 | 447 | # L1 loss for Feature Matching Loss 448 | loss_fm = torch.mean(torch.abs(torch.mean(D_out_y_gt, 0) - torch.mean(D_out_y_pred, 0))) 449 | 450 | if count > 0 and i_iter > 0: # if any good predictions found for self-training loss 451 | loss_S = loss_ce + args.lambda_fm*loss_fm + args.lambda_st*loss_st 452 | else: 453 | loss_S = loss_ce + args.lambda_fm*loss_fm 454 | 455 | loss_S.backward() 456 | loss_fm_value += args.lambda_fm*loss_fm.item() 457 | 458 | loss_ce_value += loss_ce.item() 459 | loss_S_value += loss_S.item() 460 | 461 | # train D 462 | for param in model_D.parameters(): 463 | param.requires_grad = True 464 | 465 | # train with pred 466 | pred_cat = pred_cat.detach() # detach so that gradients do not back-propagate into the segmentation network 467 | 468 | D_out_z, _ = model_D(pred_cat) 469 | y_fake_ = Variable(torch.zeros(D_out_z.size(0), 1).cuda()) 470 | loss_D_fake = criterion(D_out_z, y_fake_) 471 | 472 | # train with gt 473 | D_out_z_gt , _ = model_D(D_gt_v_cat) 474 | y_real_ = Variable(torch.ones(D_out_z_gt.size(0), 1).cuda()) 475 | loss_D_real = criterion(D_out_z_gt, y_real_) 476 | 477 | loss_D = (loss_D_fake + loss_D_real)/2.0 478 | loss_D.backward() 479 | loss_D_value += loss_D.item() 480 | 481 | optimizer.step() 482 | #scheduler.step() 483 | optimizer_D.step() 484 | #scheduler_D.step() 485 | 486 | losses_ce.update(loss_ce.item()) 487 | losses_S.update(loss_S.item()) 488 | losses_D.update(loss_D.item()) 489 | losses_fm.update(loss_fm.item()) 490 | counts.update(count) 491 | 492 | if i_iter%10==0: 493 | log_idx = i_iter/10 494 | 495 | args.writer.add_scalar('train/1.train_loss_ce', losses_ce.avg, log_idx) 496 | args.writer.add_scalar('train/2.train_loss_st', losses_st.avg, log_idx) 497 | args.writer.add_scalar('train/3.train_loss_fm', losses_fm.avg, log_idx) 498 | args.writer.add_scalar('train/4.train_loss_S', losses_S.avg, log_idx) 499 | args.writer.add_scalar('train/5.train_loss_D', losses_D.avg, log_idx) 500 | args.writer.add_scalar('train/6.count', counts.avg, log_idx) 501 | args.writer.add_scalar('train/7.lr', optimizer.param_groups[0]['lr'], log_idx) 502 | 503 | 504 | losses_ce = AverageMeter() 505 | losses_st = AverageMeter() 506 | losses_S = AverageMeter() 507 | losses_D = AverageMeter() 508 | losses_fm = AverageMeter() 509 | counts = AverageMeter() 510 | 511 | print('iter = {0:8d}/{1:8d}, loss_ce = {2:.3f}, loss_fm = {3:.3f}, loss_S = {4:.3f}, loss_D = {5:.3f}'.format(i_iter, args.num_steps, loss_ce_value, loss_fm_value, loss_S_value, loss_D_value)) 512 | 513 | if i_iter%200==0: 514 | miou_val, loss_val = validate(valloader, interp_val, model) 515 | print('miou_val: ', miou_val, ' loss_val: ', loss_val) 516 | #mious.update(miou_val) 517 | #losses_val.update(loss_val) 518 | args.writer.add_scalar('val/1.val_miou', miou_val, i_iter/1000) 519 | args.writer.add_scalar('val/2.val_loss', 
loss_val, i_iter/1000) 520 | #mious = AverageMeter() 521 | #losses_val = AverageMeter() 522 | 523 | 524 | if i_iter >= args.num_steps-1: 525 | print ('save model ...') 526 | torch.save(model.state_dict(),os.path.join(args.checkpoint_dir, 'VOC_'+str(args.num_steps)+'.pth')) 527 | torch.save(model_D.state_dict(),os.path.join(args.checkpoint_dir, 'VOC_'+str(args.num_steps)+'_D.pth')) 528 | break 529 | 530 | if i_iter % args.save_pred_every == 0 and i_iter!=0: 531 | print ('saving checkpoint ...') 532 | torch.save(model.state_dict(),os.path.join(args.checkpoint_dir, 'VOC_'+str(i_iter)+'.pth')) 533 | torch.save(model_D.state_dict(),os.path.join(args.checkpoint_dir, 'VOC_'+str(i_iter)+'_D.pth')) 534 | end = timeit.default_timer() 535 | print (end-start,'seconds') 536 | 537 | def validate(valloader, interp_val, model): 538 | print('validating...') 539 | loss_val = 0 540 | data_list = [] 541 | model.eval() 542 | for index, batch in enumerate(valloader): 543 | if index ==50: 544 | break 545 | image, label, size, name, _ = batch 546 | size = size[0] 547 | with torch.no_grad(): 548 | output = model(image.cuda()) 549 | output = interp_val(output) 550 | #loss_ce = loss_calc(output, label, args.gpu) 551 | output = output.cpu().data[0].numpy() 552 | 553 | #loss_val += loss_ce.item() 554 | 555 | if args.dataset == 'pascal_voc': 556 | output = output[:,:size[0],:size[1]] 557 | gt = np.asarray(label[0].numpy()[:size[0],:size[1]], dtype=np.int) 558 | elif args.dataset == 'ade20k': 559 | output = output[:,:size[0],:size[1]] 560 | gt = np.asarray(label[0].numpy()[:size[0],:size[1]], dtype=np.int) 561 | elif args.dataset == 'pascal_context': 562 | gt = np.asarray(label[0].numpy(), dtype=np.int) 563 | elif args.dataset == 'cityscapes': 564 | gt = np.asarray(label[0].numpy(), dtype=np.int) 565 | 566 | output = output.transpose(1,2,0) 567 | output = np.asarray(np.argmax(output, axis=2), dtype=np.int) 568 | data_list.append([gt.flatten(), output.flatten()]) 569 | 570 | torch.cuda.empty_cache() 571 | 572 | filename = os.path.join(args.out, 'result.txt') 573 | miou_val = get_iou(data_list, args.num_classes, filename) # get_iou(data_list, class_num, save_path) as defined in utils/metric.py 574 | #return miou_val , loss_val/50 575 | return miou_val , 0 576 | 577 | if __name__ == '__main__': 578 | main() 579 | -------------------------------------------------------------------------------- /utils/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | from torch.autograd import Variable 5 | import numpy as np 6 | 7 | class CrossEntropy2d(nn.Module): 8 | 9 | def __init__(self, ignore_label=255): 10 | super(CrossEntropy2d, self).__init__() 11 | self.ignore_label = ignore_label 12 | 13 | def forward(self, predict, target, weight=None): 14 | """ 15 | Args: 16 | predict:(n, c, h, w) 17 | target:(n, h, w) 18 | weight (Tensor, optional): a manual rescaling weight given to each class. 
19 | If given, has to be a Tensor of size "nclasses" 20 | """ 21 | assert not target.requires_grad 22 | assert predict.dim() == 4 23 | assert target.dim() == 3 24 | n, c, h, w = predict.size() 25 | target_mask = (target >= 0) * (target != self.ignore_label) 26 | target = target[target_mask] 27 | if not target.data.dim(): 28 | return Variable(torch.zeros(1)) 29 | predict = predict.transpose(1, 2).transpose(2, 3).contiguous() 30 | predict = predict[target_mask.view(n, h, w, 1).repeat(1, 1, 1, c)].view(-1, c) 31 | loss = F.cross_entropy(predict, target, weight=weight, reduction='elementwise_mean') 32 | return loss 33 | -------------------------------------------------------------------------------- /utils/metric.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import numpy as np 3 | 4 | from multiprocessing import Pool 5 | #import copy_reg 6 | import pickle 7 | import types 8 | def _pickle_method(m): 9 | if m.im_self is None: 10 | return getattr, (m.im_class, m.im_func.func_name) 11 | else: 12 | return getattr, (m.im_self, m.im_func.func_name) 13 | 14 | #pickle.dump(types.MethodType, _pickle_method) 15 | 16 | def get_iou(data_list, class_num, save_path=None): 17 | from multiprocessing import Pool 18 | from utils.metric import ConfusionMatrix 19 | 20 | ConfM = ConfusionMatrix(class_num) 21 | f = ConfM.generateM 22 | pool = Pool() 23 | m_list = pool.map(f, data_list) 24 | pool.close() 25 | pool.join() 26 | 27 | for m in m_list: 28 | ConfM.addM(m) 29 | 30 | aveJ, j_list, M = ConfM.jaccard() 31 | 32 | classes = np.array(('background', # always index 0 33 | 'aeroplane', 'bicycle', 'bird', 'boat', 34 | 'bottle', 'bus', 'car', 'cat', 'chair', 35 | 'cow', 'diningtable', 'dog', 'horse', 36 | 'motorbike', 'person', 'pottedplant', 37 | 'sheep', 'sofa', 'train', 'tvmonitor')) 38 | 39 | for i, iou in enumerate(j_list): 40 | print('class {:2d} {:12} IU {:.2f}'.format(i, classes[i], j_list[i])) 41 | 42 | 43 | print('meanIOU: ' + str(aveJ) + '\n') 44 | if save_path: 45 | with open(save_path, 'w') as f: 46 | for i, iou in enumerate(j_list): 47 | f.write('class {:2d} {:12} IU {:.2f}'.format(i, classes[i], j_list[i]) + '\n') 48 | f.write('meanIOU: ' + str(aveJ) + '\n') 49 | 50 | class ConfusionMatrix(object): 51 | 52 | def __init__(self, nclass, classes=None): 53 | self.nclass = nclass 54 | self.classes = classes 55 | self.M = np.zeros((nclass, nclass)) 56 | 57 | def add(self, gt, pred): 58 | assert(np.max(pred) <= self.nclass) 59 | assert(len(gt) == len(pred)) 60 | for i in range(len(gt)): 61 | if not gt[i] == 255: 62 | self.M[gt[i], pred[i]] += 1.0 63 | 64 | def addM(self, matrix): 65 | assert(matrix.shape == self.M.shape) 66 | self.M += matrix 67 | 68 | def __str__(self): 69 | pass 70 | 71 | def recall(self): 72 | recall = 0.0 73 | for i in range(self.nclass): 74 | recall += self.M[i, i] / np.sum(self.M[:, i]) 75 | 76 | return recall/self.nclass 77 | 78 | def accuracy(self): 79 | accuracy = 0.0 80 | for i in range(self.nclass): 81 | accuracy += self.M[i, i] / np.sum(self.M[i, :]) 82 | 83 | return accuracy/self.nclass 84 | 85 | def jaccard(self): 86 | jaccard = 0.0 87 | jaccard_sum = [] 88 | jaccard_perclass = [] 89 | for i in range(self.nclass): 90 | jaccard_perclass.append(self.M[i, i] / (np.sum(self.M[i, :]) + np.sum(self.M[:, i]) - self.M[i, i])) 91 | if not self.M[i, i] == 0: 92 | jaccard_sum.append(self.M[i, i] / (np.sum(self.M[i, :]) + np.sum(self.M[:, i]) - self.M[i, i])) 93 | 94 | return np.sum(jaccard_sum)/self.nclass, jaccard_perclass, 
self.M 95 | 96 | def generateM(self, item): 97 | gt, pred = item 98 | m = np.zeros((self.nclass, self.nclass)) 99 | assert(len(gt) == len(pred)) 100 | for i in range(len(gt)): 101 | if gt[i] < self.nclass: #and pred[i] < self.nclass: 102 | m[gt[i], pred[i]] += 1.0 103 | return m 104 | 105 | 106 | if __name__ == '__main__': 107 | args = parse_args() 108 | 109 | m_list = [] 110 | data_list = [] 111 | test_ids = [i.strip() for i in open(args.test_ids) if not i.strip() == ''] 112 | for index, img_id in enumerate(test_ids): 113 | if index % 100 == 0: 114 | print('%d processd'%(index)) 115 | pred_img_path = os.path.join(args.pred_dir, img_id+'.png') 116 | gt_img_path = os.path.join(args.gt_dir, img_id+'.png') 117 | pred = cv2.imread(pred_img_path, cv2.IMREAD_GRAYSCALE) 118 | gt = cv2.imread(gt_img_path, cv2.IMREAD_GRAYSCALE) 119 | # show_all(gt, pred) 120 | data_list.append([gt.flatten(), pred.flatten()]) 121 | 122 | ConfM = ConfusionMatrix(args.class_num) 123 | f = ConfM.generateM 124 | pool = Pool() 125 | m_list = pool.map(f, data_list) 126 | pool.close() 127 | pool.join() 128 | 129 | for m in m_list: 130 | ConfM.addM(m) 131 | 132 | aveJ, j_list, M = ConfM.jaccard() 133 | with open(args.save_path, 'w') as f: 134 | f.write('meanIOU: ' + str(aveJ) + '\n') 135 | f.write(str(j_list)+'\n') 136 | f.write(str(M)+'\n') 137 | -------------------------------------------------------------------------------- /utils/misc.py: -------------------------------------------------------------------------------- 1 | '''Some helper functions for PyTorch, including: 2 | - get_mean_and_std: calculate the mean and std value of dataset. 3 | ''' 4 | import logging 5 | 6 | import torch 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | __all__ = ['get_mean_and_std', 'accuracy', 'AverageMeter'] 11 | 12 | 13 | def get_mean_and_std(dataset): 14 | '''Compute the mean and std value of dataset.''' 15 | dataloader = torch.utils.data.DataLoader( 16 | dataset, batch_size=1, shuffle=False, num_workers=4) 17 | 18 | mean = torch.zeros(3) 19 | std = torch.zeros(3) 20 | logger.info('==> Computing mean and std..') 21 | for inputs, targets in dataloader: 22 | for i in range(3): 23 | mean[i] += inputs[:, i, :, :].mean() 24 | std[i] += inputs[:, i, :, :].std() 25 | mean.div_(len(dataset)) 26 | std.div_(len(dataset)) 27 | return mean, std 28 | 29 | 30 | def accuracy(output, target, topk=(1,)): 31 | """Computes the precision@k for the specified values of k""" 32 | maxk = max(topk) 33 | batch_size = target.size(0) 34 | 35 | _, pred = output.topk(maxk, 1, True, True) 36 | pred = pred.t() 37 | correct = pred.eq(target.reshape(1, -1).expand_as(pred)) 38 | 39 | res = [] 40 | for k in topk: 41 | correct_k = correct[:k].reshape(-1).float().sum(0) 42 | res.append(correct_k.mul_(100.0 / batch_size)) 43 | return res 44 | 45 | 46 | class AverageMeter(object): 47 | """Computes and stores the average and current value 48 | Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262 49 | """ 50 | 51 | def __init__(self): 52 | self.reset() 53 | 54 | def reset(self): 55 | self.val = 0 56 | self.avg = 0 57 | self.sum = 0 58 | self.count = 0 59 | 60 | def update(self, val, n=1): 61 | self.val = val 62 | self.sum += val * n 63 | self.count += n 64 | self.avg = self.sum / self.count 65 | -------------------------------------------------------------------------------- /utils/visualize.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import math 5 
| 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.init as init 9 | 10 | _, term_width = os.popen('stty size', 'r').read().split() 11 | term_width = int(term_width) 12 | 13 | #term_width = 200 14 | TOTAL_BAR_LENGTH = 70. 15 | last_time = time.time() 16 | begin_time = last_time 17 | 18 | 19 | def progress_bar(current, total, msg=None): 20 | global last_time, begin_time 21 | if current == 0: 22 | begin_time = time.time() # Reset for new bar. 23 | 24 | cur_len = int(TOTAL_BAR_LENGTH*current/total) 25 | rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1 26 | 27 | sys.stdout.write(' [') 28 | for i in range(cur_len): 29 | sys.stdout.write('=') 30 | sys.stdout.write('>') 31 | for i in range(rest_len): 32 | sys.stdout.write('.') 33 | sys.stdout.write(']') 34 | 35 | cur_time = time.time() 36 | step_time = cur_time - last_time 37 | last_time = cur_time 38 | tot_time = cur_time - begin_time 39 | 40 | L = [] 41 | L.append(' Step: %s' % format_time(step_time)) 42 | L.append(' | Tot: %s' % format_time(tot_time)) 43 | if msg: 44 | L.append(' | ' + msg) 45 | 46 | msg = ''.join(L) 47 | sys.stdout.write(msg) 48 | for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3): 49 | sys.stdout.write(' ') 50 | 51 | # Go back to the center of the bar. 52 | for i in range(term_width-int(TOTAL_BAR_LENGTH/2)+2): 53 | sys.stdout.write('\b') 54 | sys.stdout.write(' %d/%d ' % (current+1, total)) 55 | 56 | if current < total-1: 57 | sys.stdout.write('\r') 58 | else: 59 | sys.stdout.write('\n') 60 | sys.stdout.flush() 61 | 62 | def format_time(seconds): 63 | days = int(seconds / 3600/24) 64 | seconds = seconds - days*3600*24 65 | hours = int(seconds / 3600) 66 | seconds = seconds - hours*3600 67 | minutes = int(seconds / 60) 68 | seconds = seconds - minutes*60 69 | secondsf = int(seconds) 70 | seconds = seconds - secondsf 71 | millis = int(seconds*1000) 72 | 73 | f = '' 74 | i = 1 75 | if days > 0: 76 | f += str(days) + 'D' 77 | i += 1 78 | if hours > 0 and i <= 2: 79 | f += str(hours) + 'h' 80 | i += 1 81 | if minutes > 0 and i <= 2: 82 | f += str(minutes) + 'm' 83 | i += 1 84 | if secondsf > 0 and i <= 2: 85 | f += str(secondsf) + 's' 86 | i += 1 87 | if millis > 0 and i <= 2: 88 | f += str(millis) + 'ms' 89 | i += 1 90 | if f == '': 91 | f = '0ms' 92 | return f 93 | 94 | --------------------------------------------------------------------------------
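Note on the self-training selection step: the minimal, self-contained sketch below is not part of the repository. It mirrors what `find_good_maps()` in `train_s4GAN_wrn38.py` does — keep only the unlabeled predictions whose discriminator score exceeds `threshold_st` and turn them into pseudo-labels with a per-pixel argmax — but in a vectorized form; the function name `select_confident_maps`, the default threshold, and the tensor shapes in the usage example are hypothetical.

```
import torch

def select_confident_maps(d_scores, preds, threshold_st=0.55):
    # d_scores: (N, 1) discriminator outputs in [0, 1]; preds: (N, C, H, W) segmentation logits.
    keep = d_scores.view(-1) > threshold_st        # batch mask of confident predictions
    count = int(keep.sum())
    if count == 0:
        return None, None, 0
    pred_sel = preds[keep]                         # selected prediction maps, (K, C, H, W)
    label_sel = pred_sel.argmax(dim=1).float()     # pseudo-labels via per-pixel argmax, (K, H, W)
    return pred_sel, label_sel, count

# Hypothetical shapes: 4 unlabeled images, 19 classes, 64x64 prediction maps.
scores = torch.rand(4, 1)
logits = torch.randn(4, 19, 64, 64)
pred_sel, label_sel, count = select_confident_maps(scores, logits)
print(count, None if pred_sel is None else tuple(pred_sel.shape))
```

In the training loop itself, the selected maps then feed the self-training cross-entropy term, which is added to the segmentation loss with weight `lambda_st` alongside the feature-matching term weighted by `lambda_fm`.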