├── README.md ├── dataloaders ├── __init__.py ├── combine_dbs.py ├── custom_transforms.py ├── helpers.py ├── pascal.py └── sbd.py ├── eval.py ├── evaluation ├── __init__.py ├── eval.py └── evaluation.py ├── ims ├── IOG.gif ├── cross_domain.gif ├── ims.png └── refinement.gif ├── mypath.py ├── networks ├── CoarseNet.py ├── FineNet.py ├── __init__.py ├── backbone │ ├── __init__.py │ └── resnet.py ├── loss.py ├── mainnetwork.py ├── refinementnetwork.py └── sync_batchnorm │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-35.pyc │ ├── batchnorm.cpython-35.pyc │ ├── comm.cpython-35.pyc │ └── replicate.cpython-35.pyc │ ├── batchnorm.py │ ├── comm.py │ ├── replicate.py │ └── unittest.py ├── test.py ├── test_refine.py ├── train.py └── train_refine.py /README.md: -------------------------------------------------------------------------------- 1 | # Inside-Outside-Guidance (IOG) 2 | This project hosts the code for the IOG algorithms for interactive segmentation. 3 | > [Interactive Object Segmentation with Inside-Outside Guidance](http://openaccess.thecvf.com/content_CVPR_2020/papers/Zhang_Interactive_Object_Segmentation_With_Inside-Outside_Guidance_CVPR_2020_paper.pdf) 4 | > Shiyin Zhang, Jun Hao Liew, Yunchao Wei, Shikui Wei, Yao Zhao 5 | 6 | **Updates:** 7 | - 2021.4.6 Create the interactive refinement branch for IOG. 8 | 9 | ![img](https://github.com/shiyinzhang/Inside-Outside-Guidance/blob/master/ims/ims.png "img") 10 | 11 | ### Abstract 12 | This paper explores how to harvest precise object segmentation masks while minimizing the human interaction cost. To achieve this, we propose an Inside-Outside Guidance (IOG) approach in this work. Concretely, we leverage an inside point that is clicked near the object center and two outside points at the symmetrical corner locations (top-left and bottom-right or top-right and bottom-left) of a tight bounding box that encloses the target object. This results in a total of one foreground click and four background clicks for segmentation. Our IOG not only achieves state-of-the-art performance on several popular benchmarks, but also demonstrates strong generalization capability across different domains such as street scenes, aerial imagery and medical images, without fine-tuning. In addition, we also propose a simple two-stage solution that enables our IOG to produce high quality instance segmentation masks from existing datasets with off-the-shelf bounding boxes such as ImageNet and Open Images, demonstrating the superiority of our IOG as an annotation tool. 13 | 14 | ### Demo 15 | 16 | 17 | 18 | 21 | 24 | 27 | 28 | 29 | 32 | 35 | 38 | 39 |
Demo GIFs (see the ims/ folder): IOG (3 points): ims/IOG.gif · IOG (Refinement): ims/refinement.gif · IOG (Cross domain): ims/cross_domain.gif
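Before running any of the training, testing or evaluation commands listed under Usage below, the dataset and model paths must be set in mypath.py (the full file appears later in this repository listing). The sketch below is a minimal example of such a configuration; every path shown is a placeholder to be replaced with your own local locations:

```python
# mypath.py -- example configuration only; all paths below are placeholders
class Path(object):
    @staticmethod
    def db_root_dir(database):
        if database == 'pascal':
            return '/data/PASCAL/VOC2012'   # folder that contains VOCdevkit/
        elif database == 'sbd':
            return '/data/SBD/'             # folder that contains benchmark_RELEASE/ (see dataloaders/sbd.py)
        else:
            print('Database {} not available.'.format(database))
            raise NotImplementedError

    @staticmethod
    def models_dir():
        return '/data/models/resnet101-5d3b4d8f.pth'  # ImageNet-pretrained ResNet-101 weights
```

The PASCAL and SBD dataloaders use Path.db_root_dir() as their default root argument, so pointing these two entries at local copies of the datasets is enough for the data side.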
40 | 41 | 42 | ### Installation
43 | 1. Install requirements
44 |    - PyTorch = 0.4
45 |    - python >= 3.5
46 |    - torchvision = 0.2
47 |    - pycocotools
48 | 2. Usage
49 | You can start training, testing and evaluation with the following commands:
50 | ```
51 | # training step
52 | python train.py
53 | python train_refine.py
54 |
55 | # testing step
56 | python test.py
57 | python test_refine.py
58 |
59 | # evaluation step
60 | python eval.py
61 |
62 | ```
63 | The paths to the PASCAL/SBD datasets and the pretrained backbone model are set in mypath.py (see the example configuration above).
64 |
65 | ### Pretrained models
66 | | Network |Dataset | Backbone | Download Link |
67 | |---------|---------|-------------|:-------------------------:|
68 | |IOG |PASCAL + SBD | ResNet-101 | [IOG_PASCAL_SBD.pth](https://drive.google.com/file/d/1Lm1hhMhhjjnNwO4Pf7SC6tXLayH2iH0l/view?usp=sharing) |
69 | |IOG |PASCAL | ResNet-101 | [IOG_PASCAL.pth](https://drive.google.com/file/d/1GLZIQlQ-3KUWaGTQ1g_InVcqesGfGcpW/view?usp=sharing) |
70 | |IOG-Refinement |PASCAL + SBD | ResNet-101 | [IOG_PASCAL_SBD_REFINEMENT.pth](https://drive.google.com/file/d/1VdOFUZZbtbYt9aIMugKhMKDA6EuqKG30/view?usp=sharing) |
71 |
72 | ### Dataset
73 | Using the ∼0.615M annotated bounding boxes of ILSVRC-LOC, we applied IOG to collect pixel-level annotations for these objects. The resulting dataset, named Pixel-ImageNet, is publicly available at https://github.com/shiyinzhang/Pixel-ImageNet.
74 | ### Citations
75 | Please consider citing our paper in your publications if it helps your research. The following is the BibTeX reference.
76 |
77 |     @inproceedings{zhang2020interactive,
78 |       title={Interactive Object Segmentation With Inside-Outside Guidance},
79 |       author={Zhang, Shiyin and Liew, Jun Hao and Wei, Yunchao and Wei, Shikui and Zhao, Yao},
80 |       booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
81 |       pages={12234--12244},
82 |       year={2020}
83 |     }
84 |
85 |
-------------------------------------------------------------------------------- /dataloaders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiyinzhang/Inside-Outside-Guidance/696ec2ddae2e994541cc9d81bb3c41984e233c64/dataloaders/__init__.py -------------------------------------------------------------------------------- /dataloaders/combine_dbs.py: --------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 |
3 |
4 | class CombineDBs(data.Dataset):
5 |     def __init__(self, dataloaders, excluded=None):
6 |         self.dataloaders = dataloaders
7 |         self.excluded = excluded
8 |         self.im_ids = []
9 |
10 |         # Combine object lists
11 |         for dl in dataloaders:
12 |             for elem in dl.im_ids:
13 |                 if elem not in self.im_ids:
14 |                     self.im_ids.append(elem)
15 |
16 |         # Exclude
17 |         if excluded:
18 |             for dl in excluded:
19 |                 for elem in dl.im_ids:
20 |                     if elem in self.im_ids:
21 |                         self.im_ids.remove(elem)
22 |
23 |         # Get object pointers
24 |         self.obj_list = []
25 |         self.im_list = []
26 |         new_im_ids = []
27 |         obj_counter = 0
28 |         num_images = 0
29 |         for ii, dl in enumerate(dataloaders):
30 |             for jj, curr_im_id in enumerate(dl.im_ids):
31 |                 if (curr_im_id in self.im_ids) and (curr_im_id not in new_im_ids):
32 |                     flag = False
33 |                     new_im_ids.append(curr_im_id)
34 |                     for kk in range(len(dl.obj_dict[curr_im_id])):
35 |                         if dl.obj_dict[curr_im_id][kk] != -1:
36 |                             self.obj_list.append({'db_ii': ii, 'obj_ii': dl.obj_list.index([jj, kk])})
37 |                             flag = True
38 |                             obj_counter += 1
39 |                     self.im_list.append({'db_ii': ii, 
'im_ii': jj}) 40 | if flag: 41 | num_images += 1 42 | 43 | self.im_ids = new_im_ids 44 | print('Combined number of images: {:d}\nCombined number of objects: {:d}'.format(num_images, len(self.obj_list))) 45 | 46 | def __getitem__(self, index): 47 | 48 | _db_ii = self.obj_list[index]["db_ii"] 49 | _obj_ii = self.obj_list[index]['obj_ii'] 50 | sample = self.dataloaders[_db_ii].__getitem__(_obj_ii) 51 | 52 | if 'meta' in sample.keys(): 53 | sample['meta']['db'] = str(self.dataloaders[_db_ii]) 54 | 55 | return sample 56 | 57 | def __len__(self): 58 | return len(self.obj_list) 59 | 60 | def __str__(self): 61 | include_db = [str(db) for db in self.dataloaders] 62 | exclude_db = [str(db) for db in self.excluded] 63 | return 'Included datasets:'+str(include_db)+'\n'+'Excluded datasets:'+str(exclude_db) 64 | -------------------------------------------------------------------------------- /dataloaders/custom_transforms.py: -------------------------------------------------------------------------------- 1 | import torch, cv2 2 | import numpy.random as random 3 | import numpy as np 4 | import dataloaders.helpers as helpers 5 | import scipy.misc as sm 6 | from dataloaders.helpers import * 7 | 8 | class ScaleNRotate(object): 9 | """Scale (zoom-in, zoom-out) and Rotate the image and the ground truth. 10 | Args: 11 | two possibilities: 12 | 1. rots (tuple): (minimum, maximum) rotation angle 13 | scales (tuple): (minimum, maximum) scale 14 | 2. rots [list]: list of fixed possible rotation angles 15 | scales [list]: list of fixed possible scales 16 | """ 17 | def __init__(self, rots=(-30, 30), scales=(.75, 1.25), semseg=False): 18 | assert (isinstance(rots, type(scales))) 19 | self.rots = rots 20 | self.scales = scales 21 | self.semseg = semseg 22 | 23 | def __call__(self, sample): 24 | 25 | if type(self.rots) == tuple: 26 | # Continuous range of scales and rotations 27 | rot = (self.rots[1] - self.rots[0]) * random.random() - \ 28 | (self.rots[1] - self.rots[0])/2 29 | 30 | sc = (self.scales[1] - self.scales[0]) * random.random() - \ 31 | (self.scales[1] - self.scales[0]) / 2 + 1 32 | elif type(self.rots) == list: 33 | # Fixed range of scales and rotations 34 | rot = self.rots[random.randint(0, len(self.rots))] 35 | sc = self.scales[random.randint(0, len(self.scales))] 36 | 37 | for elem in sample.keys(): 38 | if 'meta' in elem: 39 | continue 40 | 41 | tmp = sample[elem] 42 | 43 | h, w = tmp.shape[:2] 44 | center = (w / 2, h / 2) 45 | assert(center != 0) # Strange behaviour warpAffine 46 | M = cv2.getRotationMatrix2D(center, rot, sc) 47 | 48 | if ((tmp == 0) | (tmp == 1)).all(): 49 | flagval = cv2.INTER_NEAREST 50 | elif 'gt' in elem and self.semseg: 51 | flagval = cv2.INTER_NEAREST 52 | else: 53 | flagval = cv2.INTER_CUBIC 54 | tmp = cv2.warpAffine(tmp, M, (w, h), flags=flagval) 55 | 56 | sample[elem] = tmp 57 | 58 | return sample 59 | 60 | def __str__(self): 61 | return 'ScaleNRotate:(rot='+str(self.rots)+',scale='+str(self.scales)+')' 62 | 63 | 64 | class FixedResize(object): 65 | """Resize the image and the ground truth to specified resolution. 
66 | Args: 67 | resolutions (dict): the list of resolutions 68 | """ 69 | def __init__(self, resolutions=None, flagvals=None): 70 | self.resolutions = resolutions 71 | self.flagvals = flagvals 72 | if self.flagvals is not None: 73 | assert(len(self.resolutions) == len(self.flagvals)) 74 | 75 | def __call__(self, sample): 76 | 77 | # Fixed range of scales 78 | if self.resolutions is None: 79 | return sample 80 | 81 | elems = list(sample.keys()) 82 | 83 | for elem in elems: 84 | 85 | if 'meta' in elem or 'bbox' in elem or ('extreme_points_coord' in elem and elem not in self.resolutions): 86 | continue 87 | if 'extreme_points_coord' in elem and elem in self.resolutions: 88 | bbox = sample['bbox'] 89 | crop_size = np.array([bbox[3]-bbox[1]+1, bbox[4]-bbox[2]+1]) 90 | res = np.array(self.resolutions[elem]).astype(np.float32) 91 | sample[elem] = np.round(sample[elem]*res/crop_size).astype(np.int) 92 | continue 93 | if elem in self.resolutions: 94 | if self.resolutions[elem] is None: 95 | continue 96 | if isinstance(sample[elem], list): 97 | if sample[elem][0].ndim == 3: 98 | output_size = np.append(self.resolutions[elem], [3, len(sample[elem])]) 99 | else: 100 | output_size = np.append(self.resolutions[elem], len(sample[elem])) 101 | tmp = sample[elem] 102 | sample[elem] = np.zeros(output_size, dtype=np.float32) 103 | for ii, crop in enumerate(tmp): 104 | if self.flagvals is None: 105 | sample[elem][..., ii] = helpers.fixed_resize(crop, self.resolutions[elem]) 106 | else: 107 | sample[elem][..., ii] = helpers.fixed_resize(crop, self.resolutions[elem], flagval=self.flagvals[elem]) 108 | else: 109 | if self.flagvals is None: 110 | sample[elem] = helpers.fixed_resize(sample[elem], self.resolutions[elem]) 111 | else: 112 | sample[elem] = helpers.fixed_resize(sample[elem], self.resolutions[elem], flagval=self.flagvals[elem]) 113 | else: 114 | del sample[elem] 115 | 116 | return sample 117 | 118 | def __str__(self): 119 | return 'FixedResize:'+str(self.resolutions) 120 | 121 | 122 | class RandomHorizontalFlip(object): 123 | """Horizontally flip the given image and ground truth randomly with a probability of 0.5.""" 124 | 125 | def __call__(self, sample): 126 | 127 | if random.random() < 0.5: 128 | for elem in sample.keys(): 129 | if 'meta' in elem: 130 | continue 131 | tmp = sample[elem] 132 | tmp = cv2.flip(tmp, flipCode=1) 133 | sample[elem] = tmp 134 | 135 | return sample 136 | 137 | def __str__(self): 138 | return 'RandomHorizontalFlip' 139 | 140 | 141 | class IOGPoints(object): 142 | """ 143 | Returns the IOG Points (top-left and bottom-right or top-right and bottom-left) in a given binary mask 144 | sigma: sigma of Gaussian to create a heatmap from a point 145 | pad_pixel: number of pixels fo the maximum perturbation 146 | elem: which element of the sample to choose as the binary mask 147 | """ 148 | def __init__(self, sigma=10, elem='crop_gt',pad_pixel =10): 149 | self.sigma = sigma 150 | self.elem = elem 151 | self.pad_pixel =pad_pixel 152 | 153 | def __call__(self, sample): 154 | 155 | if sample[self.elem].ndim == 3: 156 | raise ValueError('IOGPoints not implemented for multiple object per image.') 157 | _target = sample[self.elem] 158 | 159 | targetshape=_target.shape 160 | if np.max(_target) == 0: 161 | sample['IOG_points'] = np.zeros([targetshape[0],targetshape[1],2], dtype=_target.dtype) # TODO: handle one_mask_per_point case 162 | else: 163 | _points = helpers.iog_points(_target, self.pad_pixel) 164 | sample['IOG_points'] = helpers.make_gt(_target, _points, sigma=self.sigma, 
one_mask_per_point=False) 165 | 166 | return sample 167 | 168 | def __str__(self): 169 | return 'IOGPoints:(sigma='+str(self.sigma)+', pad_pixel='+str(self.pad_pixel)+', elem='+str(self.elem)+')' 170 | 171 | 172 | class ConcatInputs(object): 173 | 174 | def __init__(self, elems=('image', 'point')): 175 | self.elems = elems 176 | 177 | def __call__(self, sample): 178 | 179 | res = sample[self.elems[0]] 180 | 181 | for elem in self.elems[1:]: 182 | assert(sample[self.elems[0]].shape[:2] == sample[elem].shape[:2]) 183 | 184 | # Check if third dimension is missing 185 | tmp = sample[elem] 186 | if tmp.ndim == 2: 187 | tmp = tmp[:, :, np.newaxis] 188 | 189 | res = np.concatenate((res, tmp), axis=2) 190 | 191 | sample['concat'] = res 192 | return sample 193 | 194 | def __str__(self): 195 | return 'ExtremePoints:'+str(self.elems) 196 | 197 | 198 | class CropFromMask(object): 199 | """ 200 | Returns image cropped in bounding box from a given mask 201 | """ 202 | def __init__(self, crop_elems=('image', 'gt','void_pixels'), 203 | mask_elem='gt', 204 | relax=0, 205 | zero_pad=False): 206 | 207 | self.crop_elems = crop_elems 208 | self.mask_elem = mask_elem 209 | self.relax = relax 210 | self.zero_pad = zero_pad 211 | 212 | def __call__(self, sample): 213 | _target = sample[self.mask_elem] 214 | if _target.ndim == 2: 215 | _target = np.expand_dims(_target, axis=-1) 216 | for elem in self.crop_elems: 217 | _img = sample[elem] 218 | _crop = [] 219 | if self.mask_elem == elem: 220 | if _img.ndim == 2: 221 | _img = np.expand_dims(_img, axis=-1) 222 | for k in range(0, _target.shape[-1]): 223 | _tmp_img = _img[..., k] 224 | _tmp_target = _target[..., k] 225 | if np.max(_target[..., k]) == 0: 226 | _crop.append(np.zeros(_tmp_img.shape, dtype=_img.dtype)) 227 | else: 228 | _crop.append(helpers.crop_from_mask(_tmp_img, _tmp_target, relax=self.relax, zero_pad=self.zero_pad)) 229 | else: 230 | for k in range(0, _target.shape[-1]): 231 | if np.max(_target[..., k]) == 0: 232 | _crop.append(np.zeros(_img.shape, dtype=_img.dtype)) 233 | else: 234 | _tmp_target = _target[..., k] 235 | _crop.append(helpers.crop_from_mask(_img, _tmp_target, relax=self.relax, zero_pad=self.zero_pad)) 236 | if len(_crop) == 1: 237 | sample['crop_' + elem] = _crop[0] 238 | else: 239 | sample['crop_' + elem] = _crop 240 | 241 | return sample 242 | 243 | def __str__(self): 244 | return 'CropFromMask:(crop_elems='+str(self.crop_elems)+', mask_elem='+str(self.mask_elem)+\ 245 | ', relax='+str(self.relax)+',zero_pad='+str(self.zero_pad)+')' 246 | 247 | 248 | class ToImage(object): 249 | """ 250 | Return the given elements between 0 and 255 251 | """ 252 | def __init__(self, norm_elem='image', custom_max=255.): 253 | self.norm_elem = norm_elem 254 | self.custom_max = custom_max 255 | 256 | def __call__(self, sample): 257 | if isinstance(self.norm_elem, tuple): 258 | for elem in self.norm_elem: 259 | tmp = sample[elem] 260 | sample[elem] = self.custom_max * (tmp - tmp.min()) / (tmp.max() - tmp.min() + 1e-10) 261 | else: 262 | tmp = sample[self.norm_elem] 263 | sample[self.norm_elem] = self.custom_max * (tmp - tmp.min()) / (tmp.max() - tmp.min() + 1e-10) 264 | return sample 265 | 266 | def __str__(self): 267 | return 'NormalizeImage' 268 | 269 | 270 | class ToTensor(object): 271 | """Convert ndarrays in sample to Tensors.""" 272 | 273 | def __call__(self, sample): 274 | 275 | for elem in sample.keys(): 276 | if 'meta' in elem: 277 | continue 278 | elif 'bbox' in elem: 279 | tmp = sample[elem] 280 | sample[elem] = torch.from_numpy(tmp) 281 | 
continue 282 | 283 | tmp = sample[elem] 284 | 285 | if tmp.ndim == 2: 286 | tmp = tmp[:, :, np.newaxis] 287 | 288 | # swap color axis because 289 | # numpy image: H x W x C 290 | # torch image: C X H X W 291 | tmp = tmp.transpose((2, 0, 1)) 292 | sample[elem] = torch.from_numpy(tmp) 293 | 294 | return sample 295 | 296 | def __str__(self): 297 | return 'ToTensor' 298 | -------------------------------------------------------------------------------- /dataloaders/helpers.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch, cv2 3 | import random 4 | import numpy as np 5 | 6 | 7 | def tens2image(im): 8 | if im.size()[0] == 1: 9 | tmp = np.squeeze(im.numpy(), axis=0) 10 | else: 11 | tmp = im.numpy() 12 | if tmp.ndim == 2: 13 | return tmp 14 | else: 15 | return tmp.transpose((1, 2, 0)) 16 | 17 | 18 | def crop2fullmask(crop_mask, bbox, im=None, im_size=None, zero_pad=False, relax=0, mask_relax=True, 19 | interpolation=cv2.INTER_CUBIC, scikit=False): 20 | if scikit: 21 | from skimage.transform import resize as sk_resize 22 | assert(not(im is None and im_size is None)), 'You have to provide an image or the image size True' 23 | if im is None: 24 | im_si = im_size 25 | else: 26 | im_si = im.shape 27 | # Borers of image 28 | bounds = (0, 0, im_si[1] - 1, im_si[0] - 1) 29 | 30 | # Valid bounding box locations as (x_min, y_min, x_max, y_max) 31 | bbox_valid = (max(bbox[0], bounds[0]), 32 | max(bbox[1], bounds[1]), 33 | min(bbox[2], bounds[2]), 34 | min(bbox[3], bounds[3])) 35 | 36 | # Bounding box of initial mask 37 | bbox_init = (bbox[0] + relax, 38 | bbox[1] + relax, 39 | bbox[2] - relax, 40 | bbox[3] - relax) 41 | 42 | if zero_pad: 43 | # Offsets for x and y 44 | offsets = (-bbox[0], -bbox[1]) 45 | else: 46 | # assert((bbox == bbox_valid).all()) 47 | offsets = (-bbox_valid[0], -bbox_valid[1]) 48 | 49 | # Simple per element addition in the tuple 50 | inds = tuple(map(sum, zip(bbox_valid, offsets + offsets))) 51 | 52 | if scikit: 53 | crop_mask = sk_resize(crop_mask, (bbox[3] - bbox[1] + 1, bbox[2] - bbox[0] + 1), order=0, mode='constant').astype(crop_mask.dtype) 54 | else: 55 | crop_mask = cv2.resize(crop_mask, (bbox[2] - bbox[0] + 1, bbox[3] - bbox[1] + 1), interpolation=interpolation) 56 | result_ = np.zeros(im_si) 57 | result_[bbox_valid[1]:bbox_valid[3] + 1, bbox_valid[0]:bbox_valid[2] + 1] = \ 58 | crop_mask[inds[1]:inds[3] + 1, inds[0]:inds[2] + 1] 59 | 60 | result = np.zeros(im_si) 61 | if mask_relax: 62 | result[bbox_init[1]:bbox_init[3]+1, bbox_init[0]:bbox_init[2]+1] = \ 63 | result_[bbox_init[1]:bbox_init[3]+1, bbox_init[0]:bbox_init[2]+1] 64 | else: 65 | result = result_ 66 | 67 | return result 68 | 69 | 70 | def overlay_mask(im, ma, colors=None, alpha=0.5): 71 | assert np.max(im) <= 1.0 72 | if colors is None: 73 | colors = np.load(os.path.join(os.path.dirname(__file__), 'pascal_map.npy'))/255. 
74 | else: 75 | colors = np.append([[0.,0.,0.]], colors, axis=0); 76 | 77 | if ma.ndim == 3: 78 | assert len(colors) >= ma.shape[0], 'Not enough colors' 79 | ma = ma.astype(np.bool) 80 | im = im.astype(np.float32) 81 | 82 | if ma.ndim == 2: 83 | fg = im * alpha+np.ones(im.shape) * (1 - alpha) * colors[1, :3] # np.array([0,0,255])/255.0 84 | else: 85 | fg = [] 86 | for n in range(ma.ndim): 87 | fg.append(im * alpha + np.ones(im.shape) * (1 - alpha) * colors[1+n, :3]) 88 | # Whiten background 89 | bg = im.copy() 90 | if ma.ndim == 2: 91 | bg[ma == 0] = im[ma == 0] 92 | bg[ma == 1] = fg[ma == 1] 93 | total_ma = ma 94 | else: 95 | total_ma = np.zeros([ma.shape[1], ma.shape[2]]) 96 | for n in range(ma.shape[0]): 97 | tmp_ma = ma[n, :, :] 98 | total_ma = np.logical_or(tmp_ma, total_ma) 99 | tmp_fg = fg[n] 100 | bg[tmp_ma == 1] = tmp_fg[tmp_ma == 1] 101 | bg[total_ma == 0] = im[total_ma == 0] 102 | 103 | # [-2:] is s trick to be compatible both with opencv 2 and 3 104 | contours = cv2.findContours(total_ma.copy().astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2:] 105 | cv2.drawContours(bg, contours[0], -1, (0.0, 0.0, 0.0), 1) 106 | 107 | return bg 108 | import PIL 109 | def overlay_masks(im, masks, alpha=0.5): 110 | colors = np.load(os.path.join(os.path.dirname(__file__), 'pascal_map.npy'))/255. 111 | 112 | if isinstance(masks, np.ndarray): 113 | masks = [masks] 114 | 115 | assert len(colors) >= len(masks), 'Not enough colors' 116 | 117 | ov = im.copy() 118 | ov_black = im.copy()*0 119 | 120 | imgZero = np.zeros(np.array(masks, dtype = np.uint8).shape,np.uint8) 121 | im = im.astype(np.float32) 122 | total_ma = np.zeros([im.shape[0], im.shape[1]]) 123 | i = 1 124 | for ma in masks: 125 | ma = ma.astype(np.bool) 126 | fg = im * alpha+np.ones(im.shape) * (1 - alpha) * colors[i, :3] # np.array([0,0,255])/255.0 127 | i = i + 1 128 | ov[ma == 1] = fg[ma == 1] 129 | total_ma += ma 130 | 131 | # [-2:] is s trick to be compatible both with opencv 2 and 3 132 | contours = cv2.findContours(ma.copy().astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2:] 133 | cv2.drawContours(ov, contours[0], -1, (0.0, 0.0, 0.0), 1) 134 | cv2.drawContours(ov_black, contours[0], -1, (255, 255, 255), -1)#only draw a round 135 | ov[total_ma == 0] = im[total_ma == 0] 136 | 137 | return ov_black 138 | 139 | from scipy import ndimage 140 | def getPositon(distance_transform): 141 | a = np.mat(distance_transform) 142 | raw, column = a.shape# get the matrix of a raw and column 143 | _positon = np.argmax(a)# get the index of max in the a 144 | m, n = divmod(_positon, column) 145 | raw=m 146 | column=n 147 | # print "The raw is " ,m 148 | # print "The column is ", n 149 | # print "The max of the a is ", a[m , n] 150 | # print(raw,column,a[m , n]) 151 | return raw,column 152 | 153 | def iog_points(mask, pad_pixel=10): 154 | def find_point(id_x, id_y, ids): 155 | sel_id = ids[0][random.randint(0, len(ids[0]) - 1)] 156 | return [id_x[sel_id], id_y[sel_id]] 157 | 158 | inds_y, inds_x = np.where(mask > 0.5) 159 | [h,w]=mask.shape 160 | left = find_point(inds_x, inds_y, np.where(inds_x <= np.min(inds_x))) 161 | right = find_point(inds_x, inds_y, np.where(inds_x >= np.max(inds_x))) 162 | top = find_point(inds_x, inds_y, np.where(inds_y <= np.min(inds_y))) 163 | bottom = find_point(inds_x, inds_y, np.where(inds_y >= np.max(inds_y))) 164 | 165 | x_min=left[0] 166 | x_max=right[0] 167 | y_min=top[1] 168 | y_max=bottom[1] 169 | 170 | map_xor = (mask > 0.5) 171 | h,w = map_xor.shape 172 | map_xor_new = np.zeros((h+2,w+2)) 173 
| map_xor_new[1:(h+1),1:(w+1)] = map_xor[:,:] 174 | distance_transform=ndimage.distance_transform_edt(map_xor_new) 175 | distance_transform_back = distance_transform[1:(h+1),1:(w+1)] 176 | raw,column=getPositon(distance_transform_back) 177 | center_point = [column,raw] 178 | 179 | left_top=[max(x_min-pad_pixel,0), max(y_min-pad_pixel,0)] 180 | left_bottom=[max(x_min-pad_pixel ,0), min(y_max+pad_pixel,h)] 181 | right_top=[min(x_max+pad_pixel,w), max(y_min-pad_pixel,0)] 182 | righr_bottom=[min(x_max+pad_pixel ,w), min(y_max+pad_pixel,h)] 183 | a=[center_point,left_top,left_bottom,right_top,righr_bottom] 184 | 185 | return np.array(a) 186 | 187 | 188 | def get_bbox(mask, points=None, pad=0, zero_pad=False): 189 | if points is not None: 190 | inds = np.flip(points.transpose(), axis=0) 191 | else: 192 | inds = np.where(mask > 0) 193 | 194 | if inds[0].shape[0] == 0: 195 | return None 196 | 197 | if zero_pad: 198 | x_min_bound = -np.inf 199 | y_min_bound = -np.inf 200 | x_max_bound = np.inf 201 | y_max_bound = np.inf 202 | else: 203 | x_min_bound = 0 204 | y_min_bound = 0 205 | x_max_bound = mask.shape[1] - 1 206 | y_max_bound = mask.shape[0] - 1 207 | 208 | x_min = max(inds[1].min() - pad, x_min_bound) 209 | y_min = max(inds[0].min() - pad, y_min_bound) 210 | x_max = min(inds[1].max() + pad, x_max_bound) 211 | y_max = min(inds[0].max() + pad, y_max_bound) 212 | 213 | return x_min, y_min, x_max, y_max 214 | 215 | 216 | def crop_from_bbox(img, bbox, zero_pad=False): 217 | # Borders of image 218 | bounds = (0, 0, img.shape[1] - 1, img.shape[0] - 1) 219 | 220 | # Valid bounding box locations as (x_min, y_min, x_max, y_max) 221 | bbox_valid = (max(bbox[0], bounds[0]), 222 | max(bbox[1], bounds[1]), 223 | min(bbox[2], bounds[2]), 224 | min(bbox[3], bounds[3])) 225 | 226 | if zero_pad: 227 | # Initialize crop size (first 2 dimensions) 228 | crop = np.zeros((bbox[3] - bbox[1] + 1, bbox[2] - bbox[0] + 1), dtype=img.dtype) 229 | 230 | # Offsets for x and y 231 | offsets = (-bbox[0], -bbox[1]) 232 | 233 | else: 234 | assert(bbox == bbox_valid) 235 | crop = np.zeros((bbox_valid[3] - bbox_valid[1] + 1, bbox_valid[2] - bbox_valid[0] + 1), dtype=img.dtype) 236 | offsets = (-bbox_valid[0], -bbox_valid[1]) 237 | 238 | # Simple per element addition in the tuple 239 | inds = tuple(map(sum, zip(bbox_valid, offsets + offsets))) 240 | 241 | img = np.squeeze(img) 242 | if img.ndim == 2: 243 | crop[inds[1]:inds[3] + 1, inds[0]:inds[2] + 1] = \ 244 | img[bbox_valid[1]:bbox_valid[3] + 1, bbox_valid[0]:bbox_valid[2] + 1] 245 | else: 246 | crop = np.tile(crop[:, :, np.newaxis], [1, 1, 3]) # Add 3 RGB Channels 247 | crop[inds[1]:inds[3] + 1, inds[0]:inds[2] + 1, :] = \ 248 | img[bbox_valid[1]:bbox_valid[3] + 1, bbox_valid[0]:bbox_valid[2] + 1, :] 249 | 250 | return crop 251 | 252 | 253 | def fixed_resize(sample, resolution, flagval=None): 254 | 255 | if flagval is None: 256 | if ((sample == 0) | (sample == 1)).all(): 257 | flagval = cv2.INTER_NEAREST 258 | else: 259 | flagval = cv2.INTER_CUBIC 260 | 261 | if isinstance(resolution, int): 262 | tmp = [resolution, resolution] 263 | tmp[np.argmax(sample.shape[:2])] = int(round(float(resolution)/np.min(sample.shape[:2])*np.max(sample.shape[:2]))) 264 | resolution = tuple(tmp) 265 | 266 | if sample.ndim == 2 or (sample.ndim == 3 and sample.shape[2] == 3): 267 | sample = cv2.resize(sample, resolution[::-1], interpolation=flagval) 268 | else: 269 | tmp = sample 270 | sample = np.zeros(np.append(resolution, tmp.shape[2]), dtype=np.float32) 271 | for ii in range(sample.shape[2]): 
272 | sample[:, :, ii] = cv2.resize(tmp[:, :, ii], resolution[::-1], interpolation=flagval) 273 | return sample 274 | 275 | 276 | def crop_from_mask(img, mask, relax=0, zero_pad=False): 277 | if mask.shape[:2] != img.shape[:2]: 278 | mask = cv2.resize(mask, dsize=tuple(reversed(img.shape[:2])), interpolation=cv2.INTER_NEAREST) 279 | 280 | assert(mask.shape[:2] == img.shape[:2]) 281 | bbox = get_bbox(mask, pad=relax, zero_pad=zero_pad) 282 | 283 | if bbox is None: 284 | return None 285 | 286 | crop = crop_from_bbox(img, bbox, zero_pad) 287 | 288 | return crop 289 | 290 | 291 | def make_gaussian(size, sigma=10, center=None, d_type=np.float64): 292 | """ Make a square gaussian kernel. 293 | size: is the dimensions of the output gaussian 294 | sigma: is full-width-half-maximum, which 295 | can be thought of as an effective radius. 296 | """ 297 | 298 | x = np.arange(0, size[1], 1, float) 299 | y = np.arange(0, size[0], 1, float) 300 | y = y[:, np.newaxis] 301 | 302 | if center is None: 303 | x0 = y0 = size[0] // 2 304 | else: 305 | x0 = center[0] 306 | y0 = center[1] 307 | 308 | return np.exp(-4 * np.log(2) * ((x - x0) ** 2 + (y - y0) ** 2) / sigma ** 2).astype(d_type) 309 | 310 | 311 | def make_gt(img, labels, sigma=10, one_mask_per_point=False): 312 | """ Make the ground-truth for landmark. 313 | img: the original color image 314 | labels: label with the Gaussian center(s) [[x0, y0],[x1, y1],...] 315 | sigma: sigma of the Gaussian. 316 | one_mask_per_point: masks for each point in different channels? 317 | """ 318 | 319 | h, w = img.shape[:2] 320 | if labels is None: 321 | gt = make_gaussian((h, w), center=(h//2, w//2), sigma=sigma) 322 | gt_0 = np.zeros(shape=(h, w), dtype=np.float64) 323 | gt_0 = gt 324 | gt_1 = np.zeros(shape=(h, w), dtype=np.float64) 325 | 326 | gtout = np.zeros(shape=(h, w, 2)) 327 | gtout[:, :, 0]=gt_0 328 | gtout[:, :, 1]=gt_1 329 | gtout = gtout.astype(dtype=img.dtype) #(0~1) 330 | return gtout 331 | else: 332 | labels = np.array(labels) 333 | if labels.ndim == 1: 334 | labels = labels[np.newaxis] 335 | gt_0 = np.zeros(shape=(h, w), dtype=np.float64) 336 | gt_1 = np.zeros(shape=(h, w), dtype=np.float64) 337 | gt_0 = np.maximum(gt_0, make_gaussian((h, w), center=labels[0, :], sigma=sigma)) 338 | 339 | else: 340 | gt_0 = np.zeros(shape=(h, w), dtype=np.float64) 341 | gt_1 = np.zeros(shape=(h, w), dtype=np.float64) 342 | for ii in range(1,labels.shape[0]): 343 | gt_1 = np.maximum(gt_1, make_gaussian((h, w), center=labels[ii, :], sigma=sigma)) 344 | gt_0 = np.maximum(gt_0, make_gaussian((h, w), center=labels[0, :], sigma=sigma)) 345 | 346 | gt = np.zeros(shape=(h, w, 2)) 347 | gt[:, :, 0]=gt_0 348 | gt[:, :, 1]=gt_1 349 | gt = gt.astype(dtype=img.dtype) #(0~1) 350 | return gt 351 | 352 | def cstm_normalize(im, max_value): 353 | """ 354 | Normalize image to range 0 - max_value 355 | """ 356 | imn = max_value*(im - im.min()) / max((im.max() - im.min()), 1e-8) 357 | return imn 358 | 359 | 360 | def generate_param_report(logfile, param): 361 | log_file = open(logfile, 'w') 362 | for key, val in param.items(): 363 | log_file.write(key+':'+str(val)+'\n') 364 | log_file.close() 365 | -------------------------------------------------------------------------------- /dataloaders/pascal.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import errno 3 | import hashlib 4 | import os 5 | import sys 6 | import tarfile 7 | import numpy as np 8 | 9 | import torch.utils.data as data 10 | from PIL import Image 11 | from six.moves 
import urllib 12 | import json 13 | from mypath import Path 14 | 15 | 16 | class VOCSegmentation(data.Dataset): 17 | 18 | URL = "http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar" 19 | FILE = "VOCtrainval_11-May-2012.tar" 20 | MD5 = '6cd6e144f989b92b3379bac3b3de84fd' 21 | BASE_DIR = 'VOCdevkit/VOC2012' 22 | 23 | category_names = ['background', 24 | 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 25 | 'bus', 'car', 'cat', 'chair', 'cow', 26 | 'diningtable', 'dog', 'horse', 'motorbike', 'person', 27 | 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'] 28 | 29 | def __init__(self, 30 | root=Path.db_root_dir('pascal'), 31 | split='val', 32 | transform=None, 33 | download=False, 34 | preprocess=False, 35 | area_thres=0, 36 | retname=True, 37 | suppress_void_pixels=True, 38 | default=False): 39 | 40 | self.root = root 41 | _voc_root = os.path.join(self.root, self.BASE_DIR) 42 | _mask_dir = os.path.join(_voc_root, 'SegmentationObject')#each object each color 43 | _cat_dir = os.path.join(_voc_root, 'SegmentationClass')#each class each color 44 | _image_dir = os.path.join(_voc_root, 'JPEGImages') 45 | self.transform = transform 46 | if isinstance(split, str): 47 | self.split = [split] 48 | else: 49 | split.sort() 50 | self.split = split 51 | self.area_thres = area_thres 52 | self.retname = retname 53 | self.suppress_void_pixels = suppress_void_pixels 54 | self.default = default 55 | 56 | # Build the ids file 57 | area_th_str = "" 58 | if self.area_thres != 0: 59 | area_th_str = '_area_thres-' + str(area_thres) 60 | 61 | self.obj_list_file = os.path.join(self.root, self.BASE_DIR, 'ImageSets', 'Segmentation', 62 | '_'.join(self.split) + '_instances' + area_th_str + '.txt') 63 | 64 | if download: 65 | self._download() 66 | 67 | if not self._check_integrity(): 68 | raise RuntimeError('Dataset not found or corrupted.' 
+ 69 | ' You can use download=True to download it') 70 | 71 | # train/val/test splits are pre-cut 72 | _splits_dir = os.path.join(_voc_root, 'ImageSets', 'Segmentation') 73 | 74 | self.im_ids = [] 75 | self.images = [] 76 | self.categories = [] 77 | self.masks = [] 78 | 79 | for splt in self.split: 80 | with open(os.path.join(os.path.join(_splits_dir, splt + '.txt')), "r") as f: 81 | lines = f.read().splitlines() 82 | 83 | for ii, line in enumerate(lines): 84 | _image = os.path.join(_image_dir, line + ".jpg") 85 | _cat = os.path.join(_cat_dir, line + ".png") 86 | _mask = os.path.join(_mask_dir, line + ".png") 87 | assert os.path.isfile(_image) 88 | assert os.path.isfile(_cat) 89 | assert os.path.isfile(_mask) 90 | self.im_ids.append(line.rstrip('\n')) 91 | self.images.append(_image) 92 | self.categories.append(_cat) 93 | self.masks.append(_mask) 94 | assert (len(self.images) == len(self.masks)) 95 | assert (len(self.images) == len(self.categories)) 96 | 97 | # Precompute the list of objects and their categories for each image 98 | if (not self._check_preprocess()) or preprocess: 99 | print('Preprocessing of PASCAL VOC dataset, this will take long, but it will be done only once.') 100 | self._preprocess() 101 | 102 | # Build the list of objects 103 | self.obj_list = [] 104 | num_images = 0 105 | for ii in range(len(self.im_ids)): 106 | flag = False 107 | for jj in range(len(self.obj_dict[self.im_ids[ii]])): 108 | if self.obj_dict[self.im_ids[ii]][jj] != -1: 109 | self.obj_list.append([ii, jj]) 110 | flag = True 111 | if flag: 112 | num_images += 1 113 | 114 | # Display stats 115 | print('Number of images: {:d}\nNumber of objects: {:d}'.format(num_images, len(self.obj_list))) 116 | 117 | def __getitem__(self, index): 118 | _img, _target, _void_pixels, _, _, _ = self._make_img_gt_point_pair(index) 119 | sample = {'image': _img, 'gt': _target, 'void_pixels': _void_pixels} 120 | 121 | if self.retname: 122 | _im_ii = self.obj_list[index][0] 123 | _obj_ii = self.obj_list[index][1] 124 | sample['meta'] = {'image': str(self.im_ids[_im_ii]), 125 | 'object': str(_obj_ii), 126 | 'category': self.obj_dict[self.im_ids[_im_ii]][_obj_ii], 127 | 'im_size': (_img.shape[0], _img.shape[1])} 128 | 129 | if self.transform is not None: 130 | sample = self.transform(sample) 131 | return sample 132 | 133 | def __len__(self): 134 | return len(self.obj_list) 135 | 136 | def _check_integrity(self): 137 | _fpath = os.path.join(self.root, self.FILE) 138 | if not os.path.isfile(_fpath): 139 | print("{} does not exist".format(_fpath)) 140 | return False 141 | _md5c = hashlib.md5(open(_fpath, 'rb').read()).hexdigest() 142 | if _md5c != self.MD5: 143 | print(" MD5({}) did not match MD5({}) expected for {}".format( 144 | _md5c, self.MD5, _fpath)) 145 | return False 146 | return True 147 | 148 | def _check_preprocess(self): 149 | _obj_list_file = self.obj_list_file 150 | if not os.path.isfile(_obj_list_file): 151 | return False 152 | else: 153 | self.obj_dict = json.load(open(_obj_list_file, 'r')) 154 | 155 | return list(np.sort([str(x) for x in self.obj_dict.keys()])) == list(np.sort(self.im_ids)) 156 | 157 | def _preprocess(self): 158 | self.obj_dict = {} 159 | obj_counter = 0 160 | for ii in range(len(self.im_ids)): 161 | # Read object masks and get number of objects 162 | _mask = np.array(Image.open(self.masks[ii])) 163 | _mask_ids = np.unique(_mask) 164 | if _mask_ids[-1] == 255: 165 | n_obj = _mask_ids[-2] 166 | else: 167 | n_obj = _mask_ids[-1] 168 | 169 | # Get the categories from these objects 170 | _cats = 
np.array(Image.open(self.categories[ii])) 171 | _cat_ids = [] 172 | for jj in range(n_obj): 173 | tmp = np.where(_mask == jj + 1) 174 | obj_area = len(tmp[0]) 175 | if obj_area > self.area_thres: 176 | _cat_ids.append(int(_cats[tmp[0][0], tmp[1][0]])) 177 | else: 178 | _cat_ids.append(-1) 179 | obj_counter += 1 180 | 181 | self.obj_dict[self.im_ids[ii]] = _cat_ids 182 | 183 | with open(self.obj_list_file, 'w') as outfile: 184 | outfile.write('{{\n\t"{:s}": {:s}'.format(self.im_ids[0], json.dumps(self.obj_dict[self.im_ids[0]]))) 185 | for ii in range(1, len(self.im_ids)): 186 | outfile.write(',\n\t"{:s}": {:s}'.format(self.im_ids[ii], json.dumps(self.obj_dict[self.im_ids[ii]]))) 187 | outfile.write('\n}\n') 188 | 189 | print('Preprocessing finished') 190 | 191 | def _download(self): 192 | _fpath = os.path.join(self.root, self.FILE) 193 | 194 | try: 195 | os.makedirs(self.root) 196 | except OSError as e: 197 | if e.errno == errno.EEXIST: 198 | pass 199 | else: 200 | raise 201 | 202 | if self._check_integrity(): 203 | print('Files already downloaded and verified') 204 | return 205 | else: 206 | print('Downloading ' + self.URL + ' to ' + _fpath) 207 | 208 | def _progress(count, block_size, total_size): 209 | sys.stdout.write('\r>> %s %.1f%%' % 210 | (_fpath, float(count * block_size) / 211 | float(total_size) * 100.0)) 212 | sys.stdout.flush() 213 | 214 | urllib.request.urlretrieve(self.URL, _fpath, _progress) 215 | 216 | # extract file 217 | cwd = os.getcwd() 218 | print('Extracting tar file') 219 | tar = tarfile.open(_fpath) 220 | os.chdir(self.root) 221 | tar.extractall() 222 | tar.close() 223 | os.chdir(cwd) 224 | print('Done!') 225 | 226 | def _make_img_gt_point_pair(self, index): 227 | _im_ii = self.obj_list[index][0] 228 | _obj_ii = self.obj_list[index][1] 229 | 230 | # Read Image 231 | _img = np.array(Image.open(self.images[_im_ii]).convert('RGB')).astype(np.float32) ###zsy open image imread 232 | 233 | # Read Target object 234 | _tmp = (np.array(Image.open(self.masks[_im_ii]))).astype(np.float32) 235 | _void_pixels = (_tmp == 255) 236 | _tmp[_void_pixels] = 0 237 | 238 | _other_same_class = np.zeros(_tmp.shape) 239 | _other_classes = np.zeros(_tmp.shape) 240 | 241 | if self.default: 242 | _target = _tmp 243 | _background = np.logical_and(_tmp == 0, ~_void_pixels) 244 | else: 245 | _target = (_tmp == (_obj_ii + 1)).astype(np.float32) 246 | _background = np.logical_and(_tmp == 0, ~_void_pixels) 247 | obj_cat = self.obj_dict[self.im_ids[_im_ii]][_obj_ii] 248 | for ii in range(1, np.max(_tmp).astype(np.int)+1): 249 | ii_cat = self.obj_dict[self.im_ids[_im_ii]][ii-1] 250 | if obj_cat == ii_cat and ii != _obj_ii+1: 251 | _other_same_class = np.logical_or(_other_same_class, _tmp == ii) 252 | elif ii != _obj_ii+1: 253 | _other_classes = np.logical_or(_other_classes, _tmp == ii) 254 | 255 | return _img, _target, _void_pixels.astype(np.float32), \ 256 | _other_classes.astype(np.float32), _other_same_class.astype(np.float32), \ 257 | _background.astype(np.float32) 258 | 259 | def __str__(self): 260 | return 'VOC2012(split=' + str(self.split) + ',area_thres=' + str(self.area_thres) + ')' 261 | 262 | 263 | if __name__ == '__main__': 264 | import matplotlib.pyplot as plt 265 | import dataloaders.helpers as helpers 266 | import torch 267 | import dataloaders.custom_transforms as tr 268 | from torchvision import transforms 269 | 270 | transform = transforms.Compose([tr.ToTensor()]) 271 | 272 | dataset = VOCSegmentation(split=['train', 'val'], transform=transform, retname=True) 273 | dataloader = 
torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0) 274 | 275 | for i, sample in enumerate(dataloader): 276 | plt.figure() 277 | overlay = helpers.overlay_mask(helpers.tens2image(sample["image"]) / 255., 278 | np.squeeze(helpers.tens2image(sample["gt"]))) 279 | plt.imshow(overlay) 280 | plt.title(dataset.category_names[sample["meta"]["category"][0]]) 281 | if i == 3: 282 | break 283 | 284 | plt.show(block=True) 285 | -------------------------------------------------------------------------------- /dataloaders/sbd.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | 3 | import torch, cv2 4 | import errno 5 | import hashlib 6 | import json 7 | import os 8 | import sys 9 | import tarfile 10 | 11 | import numpy as np 12 | import scipy.io 13 | import torch.utils.data as data 14 | from PIL import Image 15 | from six.moves import urllib 16 | from mypath import Path 17 | 18 | 19 | class SBDSegmentation(data.Dataset): 20 | 21 | URL = "http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/semantic_contours/benchmark.tgz" 22 | FILE = "benchmark.tgz" 23 | MD5 = '82b4d87ceb2ed10f6038a1cba92111cb' 24 | 25 | def __init__(self, 26 | root=Path.db_root_dir('sbd'), 27 | split='val', 28 | transform=None, 29 | download=False, 30 | preprocess=False, 31 | area_thres=0, 32 | retname=True): 33 | 34 | # Store parameters 35 | self.root = root 36 | self.transform = transform 37 | if isinstance(split, str): 38 | self.split = [split] 39 | else: 40 | split.sort() 41 | self.split = split 42 | self.area_thres = area_thres 43 | self.retname = retname 44 | 45 | # Where to find things according to the author's structure 46 | self.dataset_dir = os.path.join(self.root, 'benchmark_RELEASE', 'dataset') 47 | _mask_dir = os.path.join(self.dataset_dir, 'inst') 48 | _image_dir = os.path.join(self.dataset_dir, 'img') 49 | 50 | if self.area_thres != 0: 51 | self.obj_list_file = os.path.join(self.dataset_dir, '_'.join(self.split) + '_instances_area_thres-' + 52 | str(area_thres) + '.txt') 53 | else: 54 | self.obj_list_file = os.path.join(self.dataset_dir, '_'.join(self.split) + '_instances' + '.txt') 55 | 56 | # Download dataset? 
57 | if download: 58 | self._download() 59 | if not self._check_integrity(): 60 | raise RuntimeError('Dataset file downloaded is corrupted.') 61 | 62 | # Get list of all images from the split and check that the files exist 63 | self.im_ids = [] 64 | self.images = [] 65 | self.masks = [] 66 | for splt in self.split: 67 | with open(os.path.join(self.dataset_dir, splt+'.txt'), "r") as f: 68 | lines = f.read().splitlines() 69 | 70 | for line in lines: 71 | _image = os.path.join(_image_dir, line + ".jpg") 72 | _mask = os.path.join(_mask_dir, line + ".mat") 73 | assert os.path.isfile(_image) 74 | assert os.path.isfile(_mask) 75 | self.im_ids.append(line) 76 | self.images.append(_image) 77 | self.masks.append(_mask) 78 | 79 | assert (len(self.images) == len(self.masks)) 80 | 81 | # Precompute the list of objects and their categories for each image 82 | if (not self._check_preprocess()) or preprocess: 83 | print('Preprocessing SBD dataset, this will take long, but it will be done only once.') 84 | self._preprocess() 85 | 86 | # Build the list of objects 87 | self.obj_list = [] 88 | num_images = 0 89 | for ii in range(len(self.im_ids)): 90 | if self.im_ids[ii] in self.obj_dict.keys(): 91 | flag = False 92 | for jj in range(len(self.obj_dict[self.im_ids[ii]])): 93 | if self.obj_dict[self.im_ids[ii]][jj] != -1: 94 | self.obj_list.append([ii, jj]) 95 | flag = True 96 | if flag: 97 | num_images += 1 98 | 99 | # Display stats 100 | print('Number of images: {:d}\nNumber of objects: {:d}'.format(num_images, len(self.obj_list))) 101 | 102 | def __getitem__(self, index): 103 | 104 | _img, _target = self._make_img_gt_point_pair(index) 105 | _void_pixels = (_target == 255).astype(np.float32) 106 | sample = {'image': _img, 'gt': _target, 'void_pixels': _void_pixels} 107 | 108 | if self.retname: 109 | _im_ii = self.obj_list[index][0] 110 | _obj_ii = self.obj_list[index][1] 111 | sample['meta'] = {'image': str(self.im_ids[_im_ii]), 112 | 'object': str(_obj_ii), 113 | 'im_size': (_img.shape[0], _img.shape[1]), 114 | 'category': self.obj_dict[self.im_ids[_im_ii]][_obj_ii]} 115 | 116 | if self.transform is not None: 117 | sample = self.transform(sample) 118 | 119 | return sample 120 | 121 | def __len__(self): 122 | return len(self.obj_list) 123 | 124 | def _check_integrity(self): 125 | _fpath = os.path.join(self.root, self.FILE) 126 | if not os.path.isfile(_fpath): 127 | print("{} does not exist".format(_fpath)) 128 | return False 129 | _md5c = hashlib.md5(open(_fpath, 'rb').read()).hexdigest() 130 | if _md5c != self.MD5: 131 | print(" MD5({}) did not match MD5({}) expected for {}".format( 132 | _md5c, self.MD5, _fpath)) 133 | return False 134 | return True 135 | 136 | def _check_preprocess(self): 137 | # Check that the file with categories is there and with correct size 138 | _obj_list_file = self.obj_list_file 139 | if not os.path.isfile(_obj_list_file): 140 | return False 141 | else: 142 | self.obj_dict = json.load(open(_obj_list_file, 'r')) 143 | return list(np.sort([str(x) for x in self.obj_dict.keys()])) == list(np.sort(self.im_ids)) 144 | 145 | def _preprocess(self): 146 | # Get all object instances and their category 147 | self.obj_dict = {} 148 | obj_counter = 0 149 | for ii in range(len(self.im_ids)): 150 | # Read object masks and get number of objects 151 | tmp = scipy.io.loadmat(self.masks[ii]) 152 | _mask = tmp["GTinst"][0]["Segmentation"][0] 153 | _cat_ids = tmp["GTinst"][0]["Categories"][0].astype(int) 154 | 155 | _mask_ids = np.unique(_mask) 156 | n_obj = _mask_ids[-1] 157 | assert(n_obj == 
len(_cat_ids)) 158 | 159 | for jj in range(n_obj): 160 | temp = np.where(_mask == jj + 1) 161 | obj_area = len(temp[0]) 162 | if obj_area < self.area_thres: 163 | _cat_ids[jj] = -1 164 | obj_counter += 1 165 | 166 | self.obj_dict[self.im_ids[ii]] = np.squeeze(_cat_ids, 1).tolist() 167 | 168 | # Save it to file for future reference 169 | with open(self.obj_list_file, 'w') as outfile: 170 | outfile.write('{{\n\t"{:s}": {:s}'.format(self.im_ids[0], json.dumps(self.obj_dict[self.im_ids[0]]))) 171 | for ii in range(1, len(self.im_ids)): 172 | outfile.write(',\n\t"{:s}": {:s}'.format(self.im_ids[ii], json.dumps(self.obj_dict[self.im_ids[ii]]))) 173 | outfile.write('\n}\n') 174 | 175 | print('Pre-processing finished') 176 | 177 | def _download(self): 178 | _fpath = os.path.join(self.root, self.FILE) 179 | 180 | try: 181 | os.makedirs(self.root) 182 | except OSError as e: 183 | if e.errno == errno.EEXIST: 184 | pass 185 | else: 186 | raise 187 | 188 | if self._check_integrity(): 189 | print('Files already downloaded and verified') 190 | return 191 | else: 192 | print('Downloading ' + self.URL + ' to ' + _fpath) 193 | 194 | def _progress(count, block_size, total_size): 195 | sys.stdout.write('\r>> %s %.1f%%' % 196 | (_fpath, float(count * block_size) / 197 | float(total_size) * 100.0)) 198 | sys.stdout.flush() 199 | 200 | urllib.request.urlretrieve(self.URL, _fpath, _progress) 201 | 202 | # extract file 203 | cwd = os.getcwd() 204 | print('Extracting tar file') 205 | tar = tarfile.open(_fpath) 206 | os.chdir(self.root) 207 | tar.extractall() 208 | tar.close() 209 | os.chdir(cwd) 210 | print('Done!') 211 | 212 | def _make_img_gt_point_pair(self, index): 213 | _im_ii = self.obj_list[index][0] 214 | _obj_ii = self.obj_list[index][1] 215 | 216 | # Read Image 217 | _img = np.array(Image.open(self.images[_im_ii]).convert('RGB')).astype(np.float32) 218 | 219 | # Read Taret object 220 | _tmp = scipy.io.loadmat(self.masks[_im_ii])["GTinst"][0]["Segmentation"][0] 221 | _target = (_tmp == (_obj_ii + 1)).astype(np.float32) 222 | 223 | return _img, _target 224 | 225 | def __str__(self): 226 | return 'SBDSegmentation(split='+str(self.split)+', area_thres='+str(self.area_thres)+')' 227 | 228 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | 3 | from torch.utils.data import DataLoader 4 | from evaluation.eval import eval_one_result 5 | import dataloaders.pascal as pascal 6 | 7 | exp_root_dir = './' 8 | 9 | method_names = [] 10 | method_names.append('run_0') 11 | 12 | if __name__ == '__main__': 13 | 14 | # Dataloader 15 | dataset = pascal.VOCSegmentation(transform=None, retname=True) 16 | dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0) 17 | 18 | # Iterate through all the different methods 19 | for method in method_names: 20 | for ii in [0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9]: 21 | results_folder = os.path.join(exp_root_dir, method, 'Results') 22 | 23 | filename = os.path.join(exp_root_dir, 'eval_results', method.replace('/', '-') + '.txt') 24 | if not os.path.exists(os.path.join(exp_root_dir, 'eval_results')): 25 | os.makedirs(os.path.join(exp_root_dir, 'eval_results')) 26 | 27 | jaccards = eval_one_result(dataloader, results_folder, mask_thres=ii) 28 | val = jaccards["all_jaccards"].mean() 29 | 30 | # Show mean and store result 31 | print(ii) 32 | print("Result for {:<80}: {}".format(method, str.format("{0:.4f}", 100*val))) 33 | with 
open(filename, 'w') as f: 34 | f.write(str(val)) 35 | -------------------------------------------------------------------------------- /evaluation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiyinzhang/Inside-Outside-Guidance/696ec2ddae2e994541cc9d81bb3c41984e233c64/evaluation/__init__.py -------------------------------------------------------------------------------- /evaluation/eval.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import cv2 3 | import numpy as np 4 | from PIL import Image 5 | 6 | import dataloaders.helpers as helpers 7 | import evaluation.evaluation as evaluation 8 | 9 | def eval_one_result(loader, folder, one_mask_per_image=False, mask_thres=0.5, use_void_pixels=True, custom_box=False): 10 | def mAPr(per_cat, thresholds): 11 | n_cat = len(per_cat) 12 | all_apr = np.zeros(len(thresholds)) 13 | for ii, th in enumerate(thresholds): 14 | per_cat_recall = np.zeros(n_cat) 15 | for jj, categ in enumerate(per_cat.keys()): 16 | per_cat_recall[jj] = np.sum(np.array(per_cat[categ]) > th)/len(per_cat[categ]) 17 | 18 | all_apr[ii] = per_cat_recall.mean() 19 | 20 | return all_apr.mean() 21 | 22 | # Allocate 23 | eval_result = dict() 24 | eval_result["all_jaccards"] = np.zeros(len(loader)) 25 | eval_result["all_percent"] = np.zeros(len(loader)) 26 | eval_result["meta"] = [] 27 | eval_result["per_categ_jaccard"] = dict() 28 | 29 | # Iterate 30 | for i, sample in enumerate(loader): 31 | 32 | if i % 500 == 0: 33 | print('Evaluating: {} of {} objects'.format(i, len(loader))) 34 | 35 | # Load result 36 | if not one_mask_per_image: 37 | filename = os.path.join(folder, 38 | sample["meta"]["image"][0] + '-' + sample["meta"]["object"][0] + '.png') 39 | else: 40 | filename = os.path.join(folder, 41 | sample["meta"]["image"][0] + '.png') 42 | mask = np.array(Image.open(filename)).astype(np.float32) / 255. 
43 | gt = np.squeeze(helpers.tens2image(sample["gt"])) 44 | if use_void_pixels: 45 | void_pixels = np.squeeze(helpers.tens2image(sample["void_pixels"])) 46 | if mask.shape != gt.shape: 47 | mask = cv2.resize(mask, gt.shape[::-1], interpolation=cv2.INTER_CUBIC) 48 | 49 | # Threshold 50 | mask = (mask > mask_thres) 51 | if use_void_pixels: 52 | void_pixels = (void_pixels > 0.5) 53 | 54 | # Evaluate 55 | if use_void_pixels: 56 | eval_result["all_jaccards"][i] = evaluation.jaccard(gt, mask, void_pixels) 57 | else: 58 | eval_result["all_jaccards"][i] = evaluation.jaccard(gt, mask) 59 | 60 | if custom_box: 61 | box = np.squeeze(helpers.tens2image(sample["box"])) 62 | bb = helpers.get_bbox(box) 63 | else: 64 | bb = helpers.get_bbox(gt) 65 | 66 | mask_crop = helpers.crop_from_bbox(mask, bb) 67 | if use_void_pixels: 68 | non_void_pixels_crop = helpers.crop_from_bbox(np.logical_not(void_pixels), bb) 69 | gt_crop = helpers.crop_from_bbox(gt, bb) 70 | if use_void_pixels: 71 | eval_result["all_percent"][i] = np.sum((gt_crop != mask_crop) & non_void_pixels_crop)/np.sum(non_void_pixels_crop) 72 | else: 73 | eval_result["all_percent"][i] = np.sum((gt_crop != mask_crop))/mask_crop.size 74 | # Store in per category 75 | if "category" in sample["meta"]: 76 | cat = sample["meta"]["category"][0] 77 | else: 78 | cat = 1 79 | if cat not in eval_result["per_categ_jaccard"]: 80 | eval_result["per_categ_jaccard"][cat] = [] 81 | eval_result["per_categ_jaccard"][cat].append(eval_result["all_jaccards"][i]) 82 | 83 | # Store meta 84 | eval_result["meta"].append(sample["meta"]) 85 | 86 | # Compute some stats 87 | eval_result["mAPr0.5"] = mAPr(eval_result["per_categ_jaccard"], [0.5]) 88 | eval_result["mAPr0.7"] = mAPr(eval_result["per_categ_jaccard"], [0.7]) 89 | eval_result["mAPr-vol"] = mAPr(eval_result["per_categ_jaccard"], np.linspace(0.1, 0.9, 9)) 90 | 91 | return eval_result 92 | 93 | 94 | 95 | -------------------------------------------------------------------------------- /evaluation/evaluation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def jaccard(annotation, segmentation, void_pixels=None): 4 | 5 | assert(annotation.shape == segmentation.shape) 6 | 7 | if void_pixels is None: 8 | void_pixels = np.zeros_like(annotation) 9 | assert(void_pixels.shape == annotation.shape) 10 | 11 | annotation = annotation.astype(np.bool) 12 | segmentation = segmentation.astype(np.bool) 13 | void_pixels = void_pixels.astype(np.bool) 14 | if np.isclose(np.sum(annotation & np.logical_not(void_pixels)), 0) and np.isclose(np.sum(segmentation & np.logical_not(void_pixels)), 0): 15 | return 1 16 | else: 17 | return np.sum(((annotation & segmentation) & np.logical_not(void_pixels))) / \ 18 | np.sum(((annotation | segmentation) & np.logical_not(void_pixels)), dtype=np.float32) 19 | 20 | -------------------------------------------------------------------------------- /ims/IOG.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiyinzhang/Inside-Outside-Guidance/696ec2ddae2e994541cc9d81bb3c41984e233c64/ims/IOG.gif -------------------------------------------------------------------------------- /ims/cross_domain.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiyinzhang/Inside-Outside-Guidance/696ec2ddae2e994541cc9d81bb3c41984e233c64/ims/cross_domain.gif 
-------------------------------------------------------------------------------- /ims/ims.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiyinzhang/Inside-Outside-Guidance/696ec2ddae2e994541cc9d81bb3c41984e233c64/ims/ims.png -------------------------------------------------------------------------------- /ims/refinement.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiyinzhang/Inside-Outside-Guidance/696ec2ddae2e994541cc9d81bb3c41984e233c64/ims/refinement.gif -------------------------------------------------------------------------------- /mypath.py: -------------------------------------------------------------------------------- 1 | 2 | class Path(object): 3 | @staticmethod 4 | def db_root_dir(database): 5 | if database == 'pascal': 6 | return '/path/to/PASCAL/VOC2012' # folder that contains VOCdevkit/. 7 | 8 | elif database == 'sbd': 9 | return '/path/to/SBD/' # folder with img/, inst/, cls/, etc. 10 | else: 11 | print('Database {} not available.'.format(database)) 12 | raise NotImplementedError 13 | 14 | @staticmethod 15 | def models_dir(): 16 | return '/path/to/models/resnet101-5d3b4d8f.pth' 17 | #'resnet101-5d3b4d8f.pth' #resnet50-19c8e357.pth' 18 | -------------------------------------------------------------------------------- /networks/CoarseNet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import math 4 | 5 | class CoarseNet(nn.Module): 6 | def __init__(self, channel_settings, output_shape, num_class): 7 | super(CoarseNet, self).__init__() 8 | self.channel_settings = channel_settings 9 | laterals, upsamples, predict = [], [], [] 10 | for i in range(len(channel_settings)): 11 | laterals.append(self._lateral(channel_settings[i])) 12 | predict.append(self._predict(output_shape, num_class)) 13 | if i != len(channel_settings) - 1: 14 | upsamples.append(self._upsample()) 15 | self.laterals = nn.ModuleList(laterals) 16 | self.upsamples = nn.ModuleList(upsamples) 17 | self.predict = nn.ModuleList(predict) 18 | 19 | for m in self.modules(): 20 | if isinstance(m, nn.Conv2d): 21 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 22 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 23 | if m.bias is not None: 24 | m.bias.data.zero_() 25 | elif isinstance(m, nn.BatchNorm2d): 26 | m.weight.data.fill_(1) 27 | m.bias.data.zero_() 28 | 29 | def _lateral(self, input_size): 30 | layers = [] 31 | layers.append(nn.Conv2d(input_size, 256, 32 | kernel_size=1, stride=1, bias=False)) 33 | layers.append(nn.BatchNorm2d(256)) 34 | layers.append(nn.ReLU(inplace=True)) 35 | return nn.Sequential(*layers) 36 | 37 | def _upsample(self): 38 | layers = [] 39 | layers.append(torch.nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)) 40 | layers.append(torch.nn.Conv2d(256, 256, 41 | kernel_size=1, stride=1, bias=False)) 42 | layers.append(nn.BatchNorm2d(256)) 43 | return nn.Sequential(*layers) 44 | 45 | def _predict(self, output_shape, num_class): 46 | layers = [] 47 | layers.append(nn.Conv2d(256, 256, 48 | kernel_size=1, stride=1, bias=False)) 49 | layers.append(nn.BatchNorm2d(256)) 50 | layers.append(nn.ReLU(inplace=True)) 51 | layers.append(nn.Conv2d(256, num_class, 52 | kernel_size=3, stride=1, padding=1, bias=False)) 53 | layers.append(nn.Upsample(size=output_shape, mode='bilinear', align_corners=True)) 54 | layers.append(nn.BatchNorm2d(num_class)) 55 | return nn.Sequential(*layers) 56 | 57 | def forward(self, x): 58 | coarse_fms, coarse_outs = [], [] 59 | for i in range(len(self.channel_settings)): 60 | if i == 0: 61 | feature = self.laterals[i](x[i]) 62 | coarse_fms.append(feature) 63 | if i != len(self.channel_settings) - 1: 64 | up = feature 65 | feature = self.predict[i](feature) 66 | coarse_outs.append(feature) 67 | else: 68 | feature = self.laterals[i](x[i]) 69 | feature = feature+ up 70 | coarse_fms.append(feature) 71 | if i != len(self.channel_settings) - 1: 72 | up = self.upsamples[i](feature) 73 | feature = self.predict[i](feature) 74 | coarse_outs.append(feature) 75 | return coarse_fms, coarse_outs 76 | -------------------------------------------------------------------------------- /networks/FineNet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | class Bottleneck(nn.Module): 5 | expansion = 4 6 | 7 | def __init__(self, inplanes, planes, stride=1): 8 | super(Bottleneck, self).__init__() 9 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 10 | self.bn1 = nn.BatchNorm2d(planes) 11 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 12 | padding=1, bias=False) 13 | self.bn2 = nn.BatchNorm2d(planes) 14 | self.conv3 = nn.Conv2d(planes, planes * 2, kernel_size=1, bias=False) 15 | self.bn3 = nn.BatchNorm2d(planes * 2) 16 | self.relu = nn.ReLU(inplace=True) 17 | self.downsample = nn.Sequential( 18 | nn.Conv2d(inplanes, planes * 2, 19 | kernel_size=1, stride=stride, bias=False), 20 | nn.BatchNorm2d(planes * 2)) 21 | self.stride = stride 22 | 23 | def forward(self, x): 24 | residual = x 25 | 26 | out = self.conv1(x) 27 | out = self.bn1(out) 28 | out = self.relu(out) 29 | 30 | out = self.conv2(out) 31 | out = self.bn2(out) 32 | out = self.relu(out) 33 | 34 | out = self.conv3(out) 35 | out = self.bn3(out) 36 | 37 | if self.downsample is not None: 38 | residual = self.downsample(x) 39 | 40 | out += residual 41 | out = self.relu(out) 42 | 43 | return out 44 | 45 | class FineNet(nn.Module): 46 | def __init__(self, lateral_channel, out_shape, num_class): 47 | super(FineNet, self).__init__() 48 | cascade = [] 49 | num_cascade = 4 50 | for i in range(num_cascade): 51 | cascade.append(self._make_layer(lateral_channel, num_cascade-i-1, out_shape)) 52 
| self.cascade = nn.ModuleList(cascade) 53 | self.final_predict = self._predict(4*lateral_channel, num_class) 54 | 55 | def _make_layer(self, input_channel, num, output_shape): 56 | layers = [] 57 | for i in range(num): 58 | layers.append(Bottleneck(input_channel, 128)) 59 | layers.append(nn.Upsample(size=output_shape, mode='bilinear', align_corners=True)) 60 | return nn.Sequential(*layers) 61 | 62 | def _predict(self, input_channel, num_class): 63 | layers = [] 64 | layers.append(Bottleneck(input_channel, 128)) 65 | layers.append(nn.Conv2d(256, num_class, 66 | kernel_size=3, stride=1, padding=1, bias=False)) 67 | layers.append(nn.BatchNorm2d(num_class)) 68 | return nn.Sequential(*layers) 69 | 70 | def forward(self, x): 71 | fine_fms = [] 72 | for i in range(4): 73 | fine_fms.append(self.cascade[i](x[i])) 74 | out = torch.cat(fine_fms, dim=1) 75 | out = self.final_predict(out) 76 | return out 77 | -------------------------------------------------------------------------------- /networks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiyinzhang/Inside-Outside-Guidance/696ec2ddae2e994541cc9d81bb3c41984e233c64/networks/__init__.py -------------------------------------------------------------------------------- /networks/backbone/__init__.py: -------------------------------------------------------------------------------- 1 | from networks.backbone import resnet 2 | 3 | def build_backbone(backbone, output_stride, BatchNorm,nInputChannels,pretrained): 4 | if backbone == 'resnet101': 5 | return resnet.ResNet101(output_stride, BatchNorm,nInputChannels=nInputChannels,pretrained=pretrained) 6 | elif backbone == 'resnet50': 7 | return resnet.ResNet50(output_stride, BatchNorm,nInputChannels=nInputChannels,pretrained=pretrained) 8 | else: 9 | raise NotImplementedError 10 | -------------------------------------------------------------------------------- /networks/backbone/resnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch.nn as nn 3 | from networks.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 4 | 5 | class Bottleneck(nn.Module): 6 | expansion = 4 7 | 8 | def __init__(self, inplanes, planes, stride=1, dilation=1,downsample=None, BatchNorm=None): 9 | super(Bottleneck, self).__init__() 10 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 11 | self.bn1 = BatchNorm(planes) 12 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 13 | dilation=dilation, padding=dilation, bias=False) 14 | self.bn2 = BatchNorm(planes) 15 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 16 | self.bn3 = BatchNorm(planes * 4) 17 | self.relu = nn.ReLU(inplace=True) 18 | self.downsample = downsample 19 | self.stride = stride 20 | self.dilation = dilation 21 | 22 | def forward(self, x): 23 | residual = x 24 | 25 | out = self.conv1(x) 26 | out = self.bn1(out) 27 | out = self.relu(out) 28 | 29 | out = self.conv2(out) 30 | out = self.bn2(out) 31 | out = self.relu(out) 32 | 33 | out = self.conv3(out) 34 | out = self.bn3(out) 35 | 36 | if self.downsample is not None: 37 | residual = self.downsample(x) 38 | 39 | out += residual 40 | out = self.relu(out) 41 | 42 | return out 43 | 44 | class ResNet(nn.Module): 45 | 46 | def __init__(self, block, layers, output_stride, BatchNorm,nInputChannels=3, pretrained=False): 47 | self.inplanes = 64 48 | super(ResNet, self).__init__() 49 | blocks = [1, 2, 4] 50 | if output_stride 
== 16: 51 | strides = [1, 2, 2, 1] 52 | dilations = [1, 1, 1, 2] 53 | elif output_stride == 8: 54 | strides = [1, 2, 1, 1] 55 | dilations = [1, 1, 2, 4] 56 | else: 57 | raise NotImplementedError 58 | 59 | # Modules 60 | self.conv1 = nn.Conv2d(nInputChannels, 64, kernel_size=7, stride=2, padding=3, 61 | bias=False) 62 | self.bn1 = BatchNorm(64) 63 | self.relu = nn.ReLU(inplace=True) 64 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 65 | 66 | self.layer1 = self._make_layer(block, 64, layers[0], stride=strides[0], dilation=dilations[0], BatchNorm=BatchNorm) 67 | self.layer2 = self._make_layer(block, 128, layers[1], stride=strides[1], dilation=dilations[1], BatchNorm=BatchNorm) 68 | self.layer3 = self._make_layer(block, 256, layers[2], stride=strides[2], dilation=dilations[2], BatchNorm=BatchNorm) 69 | self.layer4 = self._make_MG_unit(block, 512, blocks=blocks, stride=strides[3], dilation=dilations[3], BatchNorm=BatchNorm) 70 | self._init_weight() 71 | 72 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None): 73 | downsample = None 74 | if stride != 1 or self.inplanes != planes * block.expansion: 75 | downsample = nn.Sequential( 76 | nn.Conv2d(self.inplanes, planes * block.expansion, 77 | kernel_size=1, stride=stride, bias=False), 78 | BatchNorm(planes * block.expansion), 79 | ) 80 | 81 | layers = [] 82 | layers.append(block(self.inplanes, planes, stride, dilation, downsample, BatchNorm)) 83 | self.inplanes = planes * block.expansion 84 | for i in range(1, blocks): 85 | layers.append(block(self.inplanes, planes, dilation=dilation, BatchNorm=BatchNorm)) 86 | 87 | return nn.Sequential(*layers) 88 | 89 | def _make_MG_unit(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None): 90 | downsample = None 91 | if stride != 1 or self.inplanes != planes * block.expansion: 92 | downsample = nn.Sequential( 93 | nn.Conv2d(self.inplanes, planes * block.expansion, 94 | kernel_size=1, stride=stride, bias=False), 95 | BatchNorm(planes * block.expansion), 96 | ) 97 | 98 | layers = [] 99 | layers.append(block(self.inplanes, planes, stride, dilation=blocks[0]*dilation, 100 | downsample=downsample, BatchNorm=BatchNorm)) 101 | self.inplanes = planes * block.expansion 102 | for i in range(1, len(blocks)): 103 | layers.append(block(self.inplanes, planes, stride=1, 104 | dilation=blocks[i]*dilation, BatchNorm=BatchNorm)) 105 | 106 | return nn.Sequential(*layers) 107 | 108 | def forward(self, input): 109 | x = self.conv1(input) 110 | x = self.bn1(x) 111 | x = self.relu(x) 112 | x = self.maxpool(x) 113 | 114 | x = self.layer1(x);low_level_feat_1 = x 115 | x = self.layer2(x);low_level_feat_2 = x 116 | x = self.layer3(x);low_level_feat_3 = x 117 | x = self.layer4(x) 118 | 119 | return [x, low_level_feat_3,low_level_feat_2,low_level_feat_1] 120 | 121 | def _init_weight(self): 122 | for m in self.modules(): 123 | if isinstance(m, nn.Conv2d): 124 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 125 | m.weight.data.normal_(0, math.sqrt(2. / n)) 126 | elif isinstance(m, SynchronizedBatchNorm2d): 127 | m.weight.data.fill_(1) 128 | m.bias.data.zero_() 129 | elif isinstance(m, nn.BatchNorm2d): 130 | m.weight.data.fill_(1) 131 | m.bias.data.zero_() 132 | 133 | 134 | 135 | 136 | 137 | def ResNet101(output_stride, BatchNorm,nInputChannels, pretrained=False): 138 | """Constructs a ResNet-101 model. 
139 | Args: 140 | pretrained (bool): If True, returns a model pre-trained on ImageNet 141 | """ 142 | model = ResNet(Bottleneck, [3, 4, 23, 3], output_stride, BatchNorm, nInputChannels=nInputChannels,pretrained=pretrained) 143 | return model 144 | 145 | def ResNet50(output_stride, BatchNorm,nInputChannels, pretrained=False): 146 | """Constructs a ResNet-101 model. 147 | Args: 148 | pretrained (bool): If True, returns a model pre-trained on ImageNet 149 | """ 150 | model = ResNet(Bottleneck, [3, 4, 6, 3], output_stride, BatchNorm, nInputChannels=nInputChannels,pretrained=pretrained) 151 | return model 152 | 153 | -------------------------------------------------------------------------------- /networks/loss.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | def class_cross_entropy_loss(output, label, size_average=False, batch_average=True, void_pixels=None): 5 | assert(output.size() == label.size()) 6 | labels = torch.ge(label, 0.5).float() 7 | num_labels_pos = torch.sum(labels) 8 | num_labels_neg = torch.sum(1.0 - labels) 9 | num_total = num_labels_pos + num_labels_neg 10 | output_gt_zero = torch.ge(output, 0).float() 11 | loss_val = torch.mul(output, (labels - output_gt_zero)) - torch.log( 12 | 1 + torch.exp(output - 2 * torch.mul(output, output_gt_zero))) 13 | if void_pixels is not None: 14 | w_void = torch.le(void_pixels, 0.5).float() 15 | final_loss = torch.mul(w_void, loss_val) 16 | else: 17 | final_loss=loss_val 18 | final_loss = torch.sum(-final_loss) 19 | if size_average: 20 | final_loss /= np.prod(label.size()) 21 | elif batch_average: 22 | final_loss /= label.size()[0] 23 | return final_loss 24 | 25 | -------------------------------------------------------------------------------- /networks/mainnetwork.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from mypath import Path 6 | from networks.backbone import build_backbone 7 | from networks.CoarseNet import CoarseNet 8 | from networks.FineNet import FineNet 9 | 10 | affine_par = True 11 | class PSPModule(nn.Module): 12 | """ 13 | Pyramid Scene Parsing module 14 | """ 15 | def __init__(self, in_features=2048, out_features=512, sizes=(1, 2, 3, 6), n_classes=1): 16 | super(PSPModule, self).__init__() 17 | self.stages = [] 18 | self.stages = nn.ModuleList([self._make_stage_1(in_features, size) for size in sizes]) 19 | self.bottleneck = self._make_stage_2(in_features * (len(sizes)//4 + 1), out_features) 20 | self.relu = nn.ReLU() 21 | 22 | def _make_stage_1(self, in_features, size): 23 | prior = nn.AdaptiveAvgPool2d(output_size=(size, size)) 24 | conv = nn.Conv2d(in_features, in_features//4, kernel_size=1, bias=False) 25 | bn = nn.BatchNorm2d(in_features//4, affine=affine_par) 26 | relu = nn.ReLU(inplace=True) 27 | return nn.Sequential(prior, conv, bn, relu) 28 | 29 | def _make_stage_2(self, in_features, out_features): 30 | conv = nn.Conv2d(in_features, out_features, kernel_size=1, bias=False) 31 | bn = nn.BatchNorm2d(out_features, affine=affine_par) 32 | relu = nn.ReLU(inplace=True) 33 | 34 | return nn.Sequential(conv, bn, relu) 35 | 36 | def forward(self, feats): 37 | h, w = feats.size(2), feats.size(3) 38 | priors = [F.upsample(input=stage(feats), size=(h, w), mode='bilinear', align_corners=True) for stage in self.stages] 39 | priors.append(feats) 40 | bottle = self.relu(self.bottleneck(torch.cat(priors, 1))) 41 
| return bottle 42 | 43 | class SegmentationNetwork(nn.Module): 44 | def __init__(self, backbone='resnet', output_stride=16, num_classes=21,nInputChannels=3, 45 | sync_bn=True, freeze_bn=False): 46 | super(SegmentationNetwork, self).__init__() 47 | output_shape = 128 48 | channel_settings = [512, 1024, 512, 256] 49 | self.Coarse_net = CoarseNet(channel_settings, output_shape, num_classes) 50 | self.Fine_net = FineNet(channel_settings[-1], output_shape, num_classes) 51 | BatchNorm = nn.BatchNorm2d 52 | self.backbone = build_backbone(backbone, output_stride, BatchNorm,nInputChannels,pretrained=False) 53 | self.psp4 = PSPModule(in_features=2048, out_features=512, sizes=(1, 2, 3, 6), n_classes=256) 54 | self.upsample = nn.Upsample(size=(512, 512), mode='bilinear', align_corners=True) 55 | if freeze_bn: 56 | self.freeze_bn() 57 | 58 | def forward(self, input): 59 | low_level_feat_4, low_level_feat_3,low_level_feat_2,low_level_feat_1 = self.backbone(input) 60 | low_level_feat_4 = self.psp4(low_level_feat_4) 61 | res_out = [low_level_feat_4, low_level_feat_3,low_level_feat_2,low_level_feat_1] 62 | coarse_fms, coarse_outs = self.Coarse_net(res_out) 63 | fine_out = self.Fine_net(coarse_fms) 64 | coarse_outs[0] = self.upsample(coarse_outs[0]) 65 | coarse_outs[1] = self.upsample(coarse_outs[1]) 66 | coarse_outs[2] = self.upsample(coarse_outs[2]) 67 | coarse_outs[3] = self.upsample(coarse_outs[3]) 68 | fine_out = self.upsample(fine_out) 69 | return coarse_outs[0],coarse_outs[1],coarse_outs[2],coarse_outs[3],fine_out 70 | 71 | def freeze_bn(self): 72 | for m in self.modules(): 73 | if isinstance(m, nn.BatchNorm2d): 74 | m.eval() 75 | 76 | def get_1x_lr_params(self): 77 | modules = [self.backbone] 78 | for i in range(len(modules)): 79 | for m in modules[i].named_modules(): 80 | if isinstance(m[1], nn.Conv2d) or isinstance(m[1], nn.BatchNorm2d): 81 | for p in m[1].parameters(): 82 | if p.requires_grad: 83 | yield p 84 | 85 | def get_10x_lr_params(self): 86 | modules = [self.Coarse_net,self.Fine_net,self.psp4,self.upsample] 87 | for i in range(len(modules)): 88 | for m in modules[i].named_modules(): 89 | if isinstance(m[1], nn.Conv2d) or isinstance(m[1], nn.BatchNorm2d): 90 | for p in m[1].parameters(): 91 | if p.requires_grad: 92 | yield p 93 | 94 | def Network(nInputChannels=5,num_classes=1,backbone='resnet101',output_stride=16, 95 | sync_bn=None,freeze_bn=False,pretrained=False): 96 | model = SegmentationNetwork(nInputChannels=nInputChannels,num_classes=num_classes,backbone=backbone, 97 | output_stride=output_stride,sync_bn=sync_bn,freeze_bn=freeze_bn) 98 | if pretrained: 99 | load_pth_name= Path.models_dir() 100 | pretrain_dict = torch.load( load_pth_name,map_location=lambda storage, loc: storage) 101 | conv1_weight_new=np.zeros( (64,5,7,7) ) 102 | conv1_weight_new[:,:3,:,:]=pretrain_dict['conv1.weight'].cpu().data 103 | pretrain_dict['conv1.weight']=torch.from_numpy(conv1_weight_new ) 104 | state_dict = model.state_dict() 105 | model_dict = state_dict 106 | for k, v in pretrain_dict.items(): 107 | kk='backbone.'+k 108 | if kk in state_dict: 109 | model_dict[kk] = v 110 | state_dict.update(model_dict) 111 | model.load_state_dict(state_dict) 112 | return model 113 | -------------------------------------------------------------------------------- /networks/refinementnetwork.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import scipy.misc as sm 6 | from mypath 
import Path 7 | from networks.backbone import build_backbone 8 | from networks.CoarseNet import CoarseNet 9 | from networks.FineNet import FineNet 10 | from dataloaders.helpers import * 11 | affine_par = True 12 | 13 | def make_gaussian(size, sigma=10, center=None, d_type=np.float64): 14 | x = np.arange(0, size[1], 1, float) 15 | y = np.arange(0, size[0], 1, float) 16 | y = y[:, np.newaxis] 17 | if center is None: 18 | x0 = y0 = size[0] // 2 19 | else: 20 | x0 = center[0] 21 | y0 = center[1] 22 | return np.exp(-4 * np.log(2) * ((x - x0) ** 2 + (y - y0) ** 2) / sigma ** 2).astype(d_type) 23 | 24 | def getPositon(distance_transform): 25 | a = np.mat(distance_transform) 26 | raw, column = a.shape# get the matrix of a raw and column 27 | _positon = np.argmax(a)# get the index of max in the a 28 | m, n = divmod(_positon, column) 29 | raw=m 30 | column=n 31 | return raw,column 32 | 33 | def generate_distance_map(map_xor,points_center,points_bg,gt): 34 | distance_transform=ndimage.distance_transform_edt(map_xor) 35 | raw,column=getPositon(distance_transform) 36 | gt_0 = np.zeros(shape=(gt.shape[0],gt.shape[0]), dtype=np.float64) 37 | gt_0 [column,raw]= 1 38 | map_center=np.sum(np.logical_and(gt_0 ,gt)) 39 | map_bg=np.sum(np.logical_and(gt_0 ,1-gt)) 40 | sigma=10 41 | if map_center==1: 42 | points_center = 255*np.maximum(points_center/255, make_gaussian((gt.shape[0],gt.shape[0]), center=[column,raw], sigma=sigma)) 43 | elif map_bg==1: 44 | points_bg = 255*np.maximum(points_bg/255, make_gaussian((gt.shape[0],gt.shape[0]), center=[column,raw], sigma=sigma)) 45 | else: 46 | print('error') 47 | pointsgt_new = np.zeros(shape=(gt.shape[0], gt.shape[0], 2)) 48 | pointsgt_new[:, :, 0]=points_center 49 | pointsgt_new[:, :, 1]=points_bg 50 | pointsgt_new = pointsgt_new.astype(dtype=np.uint8) 51 | pointsgt_new = pointsgt_new.transpose((2, 0, 1)) 52 | pointsgt_new = pointsgt_new[np.newaxis,:, :, :] 53 | pointsgt_new = torch.from_numpy(pointsgt_new) 54 | return pointsgt_new 55 | 56 | 57 | def iou_cal( pre, gts,extreme_points,mask_thres=0.5): 58 | iu_ave=0 59 | distance_map_new = torch.zeros(extreme_points.shape) 60 | for jj in range(int(pre.shape[0])): 61 | pred = np.transpose(pre.cpu().data.numpy()[jj, :, :, :], (1, 2, 0)) 62 | pred = 1 / (1 + np.exp(-pred)) 63 | pred = np.squeeze(pred) 64 | gts=gts.cpu() 65 | gt = tens2image(gts[jj, :, :, :]) 66 | extreme_points=extreme_points.cpu() 67 | points_center = tens2image(extreme_points[jj, 0:1, :, :]) 68 | points_bg = tens2image(extreme_points[jj, 1:2, :, :]) 69 | gt = (gt > mask_thres) 70 | pred= (pred > mask_thres) 71 | map_and=np.logical_and(pred ,gt) 72 | map_or=np.logical_or(pred ,gt) 73 | map_xor=np.bitwise_xor(pred,gt) 74 | if np.sum(map_or)==0: 75 | iu=0 76 | else: 77 | iu=np.sum(map_and)/np.sum(map_or) 78 | iu_ave=iu_ave+iu 79 | distance_map_new[jj,:,:,:]=generate_distance_map(map_xor,points_center,points_bg,gt) 80 | iu_ave=iu_ave/pre.shape[0] 81 | distance_map_new = distance_map_new.cuda() 82 | return iu_ave, distance_map_new 83 | 84 | class PSPModule(nn.Module): 85 | """ 86 | Pyramid Scene Parsing module 87 | """ 88 | def __init__(self, in_features=2048, out_features=512, sizes=(1, 2, 3, 6), n_classes=1): 89 | super(PSPModule, self).__init__() 90 | self.stages = [] 91 | self.stages = nn.ModuleList([self._make_stage_1(in_features, size) for size in sizes]) 92 | self.bottleneck = self._make_stage_2(in_features * (len(sizes)//4 + 1), out_features) 93 | self.relu = nn.ReLU() 94 | 95 | def _make_stage_1(self, in_features, size): 96 | prior = 
nn.AdaptiveAvgPool2d(output_size=(size, size)) 97 | conv = nn.Conv2d(in_features, in_features//4, kernel_size=1, bias=False) 98 | bn = nn.BatchNorm2d(in_features//4, affine=affine_par) 99 | relu = nn.ReLU(inplace=True) 100 | return nn.Sequential(prior, conv, bn, relu) 101 | 102 | def _make_stage_2(self, in_features, out_features): 103 | conv = nn.Conv2d(in_features, out_features, kernel_size=1, bias=False) 104 | bn = nn.BatchNorm2d(out_features, affine=affine_par) 105 | relu = nn.ReLU(inplace=True) 106 | 107 | return nn.Sequential(conv, bn, relu) 108 | 109 | def forward(self, feats): 110 | h, w = feats.size(2), feats.size(3) 111 | priors = [F.upsample(input=stage(feats), size=(h, w), mode='bilinear', align_corners=True) for stage in self.stages] 112 | priors.append(feats) 113 | bottle = self.relu(self.bottleneck(torch.cat(priors, 1))) 114 | return bottle 115 | 116 | class SegmentationNetwork(nn.Module): 117 | def __init__(self, backbone='resnet', output_stride=16, num_classes=21,nInputChannels=3, 118 | sync_bn=True, freeze_bn=False): 119 | super(SegmentationNetwork, self).__init__() 120 | output_shape = 128 121 | channel_settings = [512, 1024, 512, 256] 122 | self.Coarse_net = CoarseNet(channel_settings, output_shape, num_classes) 123 | self.Fine_net = FineNet(channel_settings[-1], output_shape, num_classes) 124 | BatchNorm = nn.BatchNorm2d 125 | self.backbone = build_backbone(backbone, output_stride, BatchNorm,nInputChannels,pretrained=False) 126 | self.psp4 = PSPModule(in_features=2048+64, out_features=512, sizes=(1, 2, 3, 6), n_classes=256) 127 | self.upsample = nn.Upsample(size=(512, 512), mode='bilinear', align_corners=True) 128 | self.iog_points = nn.Sequential(nn.Conv2d(2, 64, kernel_size=3, stride=2, padding=1, bias=False), 129 | nn.BatchNorm2d(64), 130 | nn.ReLU(), 131 | nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1, bias=False), 132 | nn.BatchNorm2d(128), 133 | nn.ReLU(), 134 | nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1, bias=False), 135 | nn.BatchNorm2d(256), 136 | nn.ReLU(), 137 | nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1, bias=False), 138 | nn.BatchNorm2d(256), 139 | nn.ReLU(), 140 | nn.Conv2d(256, 64, kernel_size=1, stride=1, bias=False), 141 | nn.BatchNorm2d(64), 142 | nn.ReLU()) 143 | 144 | if freeze_bn: 145 | self.freeze_bn() 146 | 147 | def forward(self, input,IOG_points,gts,refinement_num_max): 148 | low_level_feat_4_orig, low_level_feat_3_orig,low_level_feat_2_orig,low_level_feat_1_orig = self.backbone(input) 149 | feats_orig=low_level_feat_4_orig 150 | outlist=[] 151 | distance_map=IOG_points 152 | distance_map_512=distance_map 153 | for refinement_num in range(0,refinement_num_max): 154 | distance_map = self.iog_points(distance_map) 155 | feats_concat=torch.cat((feats_orig,distance_map),dim=1)#2048+64 156 | 157 | low_level_feat_4 = self.psp4(feats_concat) 158 | res_out = [low_level_feat_4, low_level_feat_3_orig,low_level_feat_2_orig,low_level_feat_1_orig] 159 | coarse_fms, coarse_outs = self.Coarse_net(res_out) 160 | fine_out = self.Fine_net(coarse_fms) 161 | 162 | out_512 = F.upsample(fine_out,size=(512, 512), mode='bilinear', align_corners=True) 163 | iou_i,distance_map_new = iou_cal(out_512,gts,distance_map_512) 164 | distance_map=distance_map_new 165 | distance_map_512 = distance_map 166 | out = [coarse_outs[0],coarse_outs[1],coarse_outs[2],coarse_outs[3],fine_out,iou_i] 167 | outlist.append(out) 168 | return outlist 169 | 170 | def freeze_bn(self): 171 | for m in self.modules(): 172 | if isinstance(m, nn.BatchNorm2d): 173 | 
m.eval() 174 | 175 | def get_1x_lr_params(self): 176 | modules = [self.backbone] 177 | for i in range(len(modules)): 178 | for m in modules[i].named_modules(): 179 | if isinstance(m[1], nn.Conv2d) or isinstance(m[1], nn.BatchNorm2d): 180 | for p in m[1].parameters(): 181 | if p.requires_grad: 182 | yield p 183 | 184 | def get_10x_lr_params(self): 185 | modules = [self.Coarse_net,self.Fine_net,self.psp4,self.upsample,self.iog_points] 186 | for i in range(len(modules)): 187 | for m in modules[i].named_modules(): 188 | if isinstance(m[1], nn.Conv2d) or isinstance(m[1], nn.BatchNorm2d): 189 | for p in m[1].parameters(): 190 | if p.requires_grad: 191 | yield p 192 | 193 | def Network(nInputChannels=5,num_classes=1,backbone='resnet101',output_stride=16, 194 | sync_bn=None,freeze_bn=False,pretrained=False): 195 | model = SegmentationNetwork(nInputChannels=nInputChannels,num_classes=num_classes,backbone=backbone, 196 | output_stride=output_stride,sync_bn=sync_bn,freeze_bn=freeze_bn) 197 | if pretrained: 198 | load_pth_name= Path.models_dir() 199 | pretrain_dict = torch.load( load_pth_name,map_location=lambda storage, loc: storage) 200 | conv1_weight_new=np.zeros( (64,5,7,7) ) 201 | conv1_weight_new[:,:3,:,:]=pretrain_dict['conv1.weight'].cpu().data 202 | pretrain_dict['conv1.weight']=torch.from_numpy(conv1_weight_new ) 203 | state_dict = model.state_dict() 204 | model_dict = state_dict 205 | for k, v in pretrain_dict.items(): 206 | kk='backbone.'+k 207 | if kk in state_dict: 208 | model_dict[kk] = v 209 | state_dict.update(model_dict) 210 | model.load_state_dict(state_dict) 211 | return model 212 | -------------------------------------------------------------------------------- /networks/sync_batchnorm/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : __init__.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 
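Both Network() factories above (networks/mainnetwork.py and networks/refinementnetwork.py) reuse an ImageNet-pretrained ResNet checkpoint even though the IOG input has five channels: the pretrained 3-channel conv1 kernel is copied into a zero-initialized 5-channel kernel before the state dict is matched against the 'backbone.'-prefixed keys. A minimal standalone sketch of that step; the checkpoint filename below is a placeholder (the real location comes from Path.models_dir() in mypath.py):

```
import numpy as np
import torch

# Placeholder path; in this repo the checkpoint location is read from mypath.Path.models_dir().
pretrain_dict = torch.load('resnet101_imagenet.pth', map_location=lambda storage, loc: storage)

# conv1 was trained on 3-channel RGB input. The IOG network takes 5 channels
# (RGB + inside-point heatmap + outside-points heatmap), so the two extra
# input channels of conv1 are simply left at zero.
conv1_new = np.zeros((64, 5, 7, 7), dtype=np.float32)
conv1_new[:, :3, :, :] = pretrain_dict['conv1.weight'].cpu().numpy()
pretrain_dict['conv1.weight'] = torch.from_numpy(conv1_new)
# The adapted dict is then filtered against model.state_dict() and loaded,
# much as done inside Network(pretrained=True).
```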
10 | 11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d 12 | from .replicate import DataParallelWithCallback, patch_replication_callback -------------------------------------------------------------------------------- /networks/sync_batchnorm/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiyinzhang/Inside-Outside-Guidance/696ec2ddae2e994541cc9d81bb3c41984e233c64/networks/sync_batchnorm/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /networks/sync_batchnorm/__pycache__/batchnorm.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiyinzhang/Inside-Outside-Guidance/696ec2ddae2e994541cc9d81bb3c41984e233c64/networks/sync_batchnorm/__pycache__/batchnorm.cpython-35.pyc -------------------------------------------------------------------------------- /networks/sync_batchnorm/__pycache__/comm.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiyinzhang/Inside-Outside-Guidance/696ec2ddae2e994541cc9d81bb3c41984e233c64/networks/sync_batchnorm/__pycache__/comm.cpython-35.pyc -------------------------------------------------------------------------------- /networks/sync_batchnorm/__pycache__/replicate.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiyinzhang/Inside-Outside-Guidance/696ec2ddae2e994541cc9d81bb3c41984e233c64/networks/sync_batchnorm/__pycache__/replicate.cpython-35.pyc -------------------------------------------------------------------------------- /networks/sync_batchnorm/batchnorm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : batchnorm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 
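batchnorm.py below provides the synchronized batch-norm variants exported by this package. A minimal usage sketch, assuming a machine with two GPUs (the toy module, tensor shapes and device ids are illustrative only):

```
import torch
import torch.nn as nn
from networks.sync_batchnorm import SynchronizedBatchNorm2d, DataParallelWithCallback

# Toy block. In this repo the ResNet backbone receives its norm layer as the
# BatchNorm argument, so SynchronizedBatchNorm2d can be passed in place of nn.BatchNorm2d.
block = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1, bias=False),
    SynchronizedBatchNorm2d(16),  # mean/var are reduced across all replicas during training
    nn.ReLU(inplace=True),
)

# DataParallelWithCallback calls __data_parallel_replicate__ on every replica,
# which wires up the master/slave pipes the synchronized layers rely on.
model = DataParallelWithCallback(block, device_ids=[0, 1]).cuda()
out = model(torch.randn(8, 3, 64, 64).cuda())
```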
10 | 11 | import collections 12 | 13 | import torch 14 | import torch.nn.functional as F 15 | 16 | from torch.nn.modules.batchnorm import _BatchNorm 17 | from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast 18 | 19 | from .comm import SyncMaster 20 | 21 | __all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d'] 22 | 23 | 24 | def _sum_ft(tensor): 25 | """sum over the first and last dimention""" 26 | return tensor.sum(dim=0).sum(dim=-1) 27 | 28 | 29 | def _unsqueeze_ft(tensor): 30 | """add new dementions at the front and the tail""" 31 | return tensor.unsqueeze(0).unsqueeze(-1) 32 | 33 | 34 | _ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size']) 35 | _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std']) 36 | 37 | 38 | class _SynchronizedBatchNorm(_BatchNorm): 39 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True): 40 | super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine) 41 | 42 | self._sync_master = SyncMaster(self._data_parallel_master) 43 | 44 | self._is_parallel = False 45 | self._parallel_id = None 46 | self._slave_pipe = None 47 | 48 | def forward(self, input): 49 | # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation. 50 | if not (self._is_parallel and self.training): 51 | return F.batch_norm( 52 | input, self.running_mean, self.running_var, self.weight, self.bias, 53 | self.training, self.momentum, self.eps) 54 | 55 | # Resize the input to (B, C, -1). 56 | input_shape = input.size() 57 | input = input.view(input.size(0), self.num_features, -1) 58 | 59 | # Compute the sum and square-sum. 60 | sum_size = input.size(0) * input.size(2) 61 | input_sum = _sum_ft(input) 62 | input_ssum = _sum_ft(input ** 2) 63 | 64 | # Reduce-and-broadcast the statistics. 65 | if self._parallel_id == 0: 66 | mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size)) 67 | else: 68 | mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size)) 69 | 70 | # Compute the output. 71 | if self.affine: 72 | # MJY:: Fuse the multiplication for speed. 73 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias) 74 | else: 75 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std) 76 | 77 | # Reshape it. 78 | return output.view(input_shape) 79 | 80 | def __data_parallel_replicate__(self, ctx, copy_id): 81 | self._is_parallel = True 82 | self._parallel_id = copy_id 83 | 84 | # parallel_id == 0 means master device. 85 | if self._parallel_id == 0: 86 | ctx.sync_master = self._sync_master 87 | else: 88 | self._slave_pipe = ctx.sync_master.register_slave(copy_id) 89 | 90 | def _data_parallel_master(self, intermediates): 91 | """Reduce the sum and square-sum, compute the statistics, and broadcast it.""" 92 | 93 | # Always using same "device order" makes the ReduceAdd operation faster. 
94 | # Thanks to:: Tete Xiao (http://tetexiao.com/) 95 | intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device()) 96 | 97 | to_reduce = [i[1][:2] for i in intermediates] 98 | to_reduce = [j for i in to_reduce for j in i] # flatten 99 | target_gpus = [i[1].sum.get_device() for i in intermediates] 100 | 101 | sum_size = sum([i[1].sum_size for i in intermediates]) 102 | sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce) 103 | mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size) 104 | 105 | broadcasted = Broadcast.apply(target_gpus, mean, inv_std) 106 | 107 | outputs = [] 108 | for i, rec in enumerate(intermediates): 109 | outputs.append((rec[0], _MasterMessage(*broadcasted[i * 2:i * 2 + 2]))) 110 | 111 | return outputs 112 | 113 | def _compute_mean_std(self, sum_, ssum, size): 114 | """Compute the mean and standard-deviation with sum and square-sum. This method 115 | also maintains the moving average on the master device.""" 116 | assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.' 117 | mean = sum_ / size 118 | sumvar = ssum - sum_ * mean 119 | unbias_var = sumvar / (size - 1) 120 | bias_var = sumvar / size 121 | 122 | self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data 123 | self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data 124 | 125 | return mean, bias_var.clamp(self.eps) ** -0.5 126 | 127 | 128 | class SynchronizedBatchNorm1d(_SynchronizedBatchNorm): 129 | r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a 130 | mini-batch. 131 | .. math:: 132 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 133 | This module differs from the built-in PyTorch BatchNorm1d as the mean and 134 | standard-deviation are reduced across all devices during training. 135 | For example, when one uses `nn.DataParallel` to wrap the network during 136 | training, PyTorch's implementation normalize the tensor on each device using 137 | the statistics only on that device, which accelerated the computation and 138 | is also easy to implement, but the statistics might be inaccurate. 139 | Instead, in this synchronized version, the statistics will be computed 140 | over all training samples distributed on multiple devices. 141 | 142 | Note that, for one-GPU or CPU-only case, this module behaves exactly same 143 | as the built-in PyTorch implementation. 144 | The mean and standard-deviation are calculated per-dimension over 145 | the mini-batches and gamma and beta are learnable parameter vectors 146 | of size C (where C is the input size). 147 | During training, this layer keeps a running estimate of its computed mean 148 | and variance. The running sum is kept with a default momentum of 0.1. 149 | During evaluation, this running mean/variance is used for normalization. 150 | Because the BatchNorm is done over the `C` dimension, computing statistics 151 | on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm 152 | Args: 153 | num_features: num_features from an expected input of size 154 | `batch_size x num_features [x width]` 155 | eps: a value added to the denominator for numerical stability. 156 | Default: 1e-5 157 | momentum: the value used for the running_mean and running_var 158 | computation. Default: 0.1 159 | affine: a boolean value that when set to ``True``, gives the layer learnable 160 | affine parameters. 
Default: ``True`` 161 | Shape: 162 | - Input: :math:`(N, C)` or :math:`(N, C, L)` 163 | - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input) 164 | Examples: 165 | >>> # With Learnable Parameters 166 | >>> m = SynchronizedBatchNorm1d(100) 167 | >>> # Without Learnable Parameters 168 | >>> m = SynchronizedBatchNorm1d(100, affine=False) 169 | >>> input = torch.autograd.Variable(torch.randn(20, 100)) 170 | >>> output = m(input) 171 | """ 172 | 173 | def _check_input_dim(self, input): 174 | if input.dim() != 2 and input.dim() != 3: 175 | raise ValueError('expected 2D or 3D input (got {}D input)' 176 | .format(input.dim())) 177 | super(SynchronizedBatchNorm1d, self)._check_input_dim(input) 178 | 179 | 180 | class SynchronizedBatchNorm2d(_SynchronizedBatchNorm): 181 | r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch 182 | of 3d inputs 183 | .. math:: 184 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 185 | This module differs from the built-in PyTorch BatchNorm2d as the mean and 186 | standard-deviation are reduced across all devices during training. 187 | For example, when one uses `nn.DataParallel` to wrap the network during 188 | training, PyTorch's implementation normalize the tensor on each device using 189 | the statistics only on that device, which accelerated the computation and 190 | is also easy to implement, but the statistics might be inaccurate. 191 | Instead, in this synchronized version, the statistics will be computed 192 | over all training samples distributed on multiple devices. 193 | 194 | Note that, for one-GPU or CPU-only case, this module behaves exactly same 195 | as the built-in PyTorch implementation. 196 | The mean and standard-deviation are calculated per-dimension over 197 | the mini-batches and gamma and beta are learnable parameter vectors 198 | of size C (where C is the input size). 199 | During training, this layer keeps a running estimate of its computed mean 200 | and variance. The running sum is kept with a default momentum of 0.1. 201 | During evaluation, this running mean/variance is used for normalization. 202 | Because the BatchNorm is done over the `C` dimension, computing statistics 203 | on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm 204 | Args: 205 | num_features: num_features from an expected input of 206 | size batch_size x num_features x height x width 207 | eps: a value added to the denominator for numerical stability. 208 | Default: 1e-5 209 | momentum: the value used for the running_mean and running_var 210 | computation. Default: 0.1 211 | affine: a boolean value that when set to ``True``, gives the layer learnable 212 | affine parameters. 
Default: ``True`` 213 | Shape: 214 | - Input: :math:`(N, C, H, W)` 215 | - Output: :math:`(N, C, H, W)` (same shape as input) 216 | Examples: 217 | >>> # With Learnable Parameters 218 | >>> m = SynchronizedBatchNorm2d(100) 219 | >>> # Without Learnable Parameters 220 | >>> m = SynchronizedBatchNorm2d(100, affine=False) 221 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45)) 222 | >>> output = m(input) 223 | """ 224 | 225 | def _check_input_dim(self, input): 226 | if input.dim() != 4: 227 | raise ValueError('expected 4D input (got {}D input)' 228 | .format(input.dim())) 229 | super(SynchronizedBatchNorm2d, self)._check_input_dim(input) 230 | 231 | 232 | class SynchronizedBatchNorm3d(_SynchronizedBatchNorm): 233 | r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch 234 | of 4d inputs 235 | .. math:: 236 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 237 | This module differs from the built-in PyTorch BatchNorm3d as the mean and 238 | standard-deviation are reduced across all devices during training. 239 | For example, when one uses `nn.DataParallel` to wrap the network during 240 | training, PyTorch's implementation normalize the tensor on each device using 241 | the statistics only on that device, which accelerated the computation and 242 | is also easy to implement, but the statistics might be inaccurate. 243 | Instead, in this synchronized version, the statistics will be computed 244 | over all training samples distributed on multiple devices. 245 | 246 | Note that, for one-GPU or CPU-only case, this module behaves exactly same 247 | as the built-in PyTorch implementation. 248 | The mean and standard-deviation are calculated per-dimension over 249 | the mini-batches and gamma and beta are learnable parameter vectors 250 | of size C (where C is the input size). 251 | During training, this layer keeps a running estimate of its computed mean 252 | and variance. The running sum is kept with a default momentum of 0.1. 253 | During evaluation, this running mean/variance is used for normalization. 254 | Because the BatchNorm is done over the `C` dimension, computing statistics 255 | on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm 256 | or Spatio-temporal BatchNorm 257 | Args: 258 | num_features: num_features from an expected input of 259 | size batch_size x num_features x depth x height x width 260 | eps: a value added to the denominator for numerical stability. 261 | Default: 1e-5 262 | momentum: the value used for the running_mean and running_var 263 | computation. Default: 0.1 264 | affine: a boolean value that when set to ``True``, gives the layer learnable 265 | affine parameters. 
Default: ``True`` 266 | Shape: 267 | - Input: :math:`(N, C, D, H, W)` 268 | - Output: :math:`(N, C, D, H, W)` (same shape as input) 269 | Examples: 270 | >>> # With Learnable Parameters 271 | >>> m = SynchronizedBatchNorm3d(100) 272 | >>> # Without Learnable Parameters 273 | >>> m = SynchronizedBatchNorm3d(100, affine=False) 274 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10)) 275 | >>> output = m(input) 276 | """ 277 | 278 | def _check_input_dim(self, input): 279 | if input.dim() != 5: 280 | raise ValueError('expected 5D input (got {}D input)' 281 | .format(input.dim())) 282 | super(SynchronizedBatchNorm3d, self)._check_input_dim(input) -------------------------------------------------------------------------------- /networks/sync_batchnorm/comm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : comm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import queue 12 | import collections 13 | import threading 14 | 15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] 16 | 17 | 18 | class FutureResult(object): 19 | """A thread-safe future implementation. Used only as one-to-one pipe.""" 20 | 21 | def __init__(self): 22 | self._result = None 23 | self._lock = threading.Lock() 24 | self._cond = threading.Condition(self._lock) 25 | 26 | def put(self, result): 27 | with self._lock: 28 | assert self._result is None, 'Previous result has\'t been fetched.' 29 | self._result = result 30 | self._cond.notify() 31 | 32 | def get(self): 33 | with self._lock: 34 | if self._result is None: 35 | self._cond.wait() 36 | 37 | res = self._result 38 | self._result = None 39 | return res 40 | 41 | 42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) 43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) 44 | 45 | 46 | class SlavePipe(_SlavePipeBase): 47 | """Pipe for master-slave communication.""" 48 | 49 | def run_slave(self, msg): 50 | self.queue.put((self.identifier, msg)) 51 | ret = self.result.get() 52 | self.queue.put(True) 53 | return ret 54 | 55 | 56 | class SyncMaster(object): 57 | """An abstract `SyncMaster` object. 58 | - During the replication, as the data parallel will trigger an callback of each module, all slave devices should 59 | call `register(id)` and obtain an `SlavePipe` to communicate with the master. 60 | - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected, 61 | and passed to a registered callback. 62 | - After receiving the messages, the master device should gather the information and determine to message passed 63 | back to each slave devices. 64 | """ 65 | 66 | def __init__(self, master_callback): 67 | """ 68 | Args: 69 | master_callback: a callback to be invoked after having collected messages from slave devices. 70 | """ 71 | self._master_callback = master_callback 72 | self._queue = queue.Queue() 73 | self._registry = collections.OrderedDict() 74 | self._activated = False 75 | 76 | def __getstate__(self): 77 | return {'master_callback': self._master_callback} 78 | 79 | def __setstate__(self, state): 80 | self.__init__(state['master_callback']) 81 | 82 | def register_slave(self, identifier): 83 | """ 84 | Register an slave device. 
85 | Args: 86 | identifier: an identifier, usually is the device id. 87 | Returns: a `SlavePipe` object which can be used to communicate with the master device. 88 | """ 89 | if self._activated: 90 | assert self._queue.empty(), 'Queue is not clean before next initialization.' 91 | self._activated = False 92 | self._registry.clear() 93 | future = FutureResult() 94 | self._registry[identifier] = _MasterRegistry(future) 95 | return SlavePipe(identifier, self._queue, future) 96 | 97 | def run_master(self, master_msg): 98 | """ 99 | Main entry for the master device in each forward pass. 100 | The messages were first collected from each devices (including the master device), and then 101 | an callback will be invoked to compute the message to be sent back to each devices 102 | (including the master device). 103 | Args: 104 | master_msg: the message that the master want to send to itself. This will be placed as the first 105 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. 106 | Returns: the message to be sent back to the master device. 107 | """ 108 | self._activated = True 109 | 110 | intermediates = [(0, master_msg)] 111 | for i in range(self.nr_slaves): 112 | intermediates.append(self._queue.get()) 113 | 114 | results = self._master_callback(intermediates) 115 | assert results[0][0] == 0, 'The first result should belongs to the master.' 116 | 117 | for i, res in results: 118 | if i == 0: 119 | continue 120 | self._registry[i].result.put(res) 121 | 122 | for i in range(self.nr_slaves): 123 | assert self._queue.get() is True 124 | 125 | return results[0][1] 126 | 127 | @property 128 | def nr_slaves(self): 129 | return len(self._registry) 130 | -------------------------------------------------------------------------------- /networks/sync_batchnorm/replicate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : replicate.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import functools 12 | 13 | from torch.nn.parallel.data_parallel import DataParallel 14 | 15 | __all__ = [ 16 | 'CallbackContext', 17 | 'execute_replication_callbacks', 18 | 'DataParallelWithCallback', 19 | 'patch_replication_callback' 20 | ] 21 | 22 | 23 | class CallbackContext(object): 24 | pass 25 | 26 | 27 | def execute_replication_callbacks(modules): 28 | """ 29 | Execute an replication callback `__data_parallel_replicate__` on each module created by original replication. 30 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 31 | Note that, as all modules are isomorphism, we assign each sub-module with a context 32 | (shared among multiple copies of this module on different devices). 33 | Through this context, different copies can share some information. 34 | We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback 35 | of any slave copies. 
36 | """ 37 | master_copy = modules[0] 38 | nr_modules = len(list(master_copy.modules())) 39 | ctxs = [CallbackContext() for _ in range(nr_modules)] 40 | 41 | for i, module in enumerate(modules): 42 | for j, m in enumerate(module.modules()): 43 | if hasattr(m, '__data_parallel_replicate__'): 44 | m.__data_parallel_replicate__(ctxs[j], i) 45 | 46 | 47 | class DataParallelWithCallback(DataParallel): 48 | """ 49 | Data Parallel with a replication callback. 50 | An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by 51 | original `replicate` function. 52 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 53 | Examples: 54 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 55 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 56 | # sync_bn.__data_parallel_replicate__ will be invoked. 57 | """ 58 | 59 | def replicate(self, module, device_ids): 60 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids) 61 | execute_replication_callbacks(modules) 62 | return modules 63 | 64 | 65 | def patch_replication_callback(data_parallel): 66 | """ 67 | Monkey-patch an existing `DataParallel` object. Add the replication callback. 68 | Useful when you have customized `DataParallel` implementation. 69 | Examples: 70 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 71 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) 72 | > patch_replication_callback(sync_bn) 73 | # this is equivalent to 74 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 75 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 76 | """ 77 | 78 | assert isinstance(data_parallel, DataParallel) 79 | 80 | old_replicate = data_parallel.replicate 81 | 82 | @functools.wraps(old_replicate) 83 | def new_replicate(module, device_ids): 84 | modules = old_replicate(module, device_ids) 85 | execute_replication_callbacks(modules) 86 | return modules 87 | 88 | data_parallel.replicate = new_replicate -------------------------------------------------------------------------------- /networks/sync_batchnorm/unittest.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : unittest.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 
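replicate.py above also ships patch_replication_callback for retrofitting an existing nn.DataParallel wrapper. A hedged sketch of applying it to this project's network; note that the shipped SegmentationNetwork builds its layers with plain nn.BatchNorm2d, so the callback only has an effect if the model is constructed with SynchronizedBatchNorm layers instead:

```
import torch.nn as nn
from networks.mainnetwork import Network
from networks.sync_batchnorm import patch_replication_callback

net = Network(nInputChannels=5, num_classes=1, backbone='resnet101', output_stride=16)
net = nn.DataParallel(net, device_ids=[0, 1]).cuda()
patch_replication_callback(net)  # equivalent to wrapping with DataParallelWithCallback
```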
10 | 11 | import unittest 12 | 13 | import numpy as np 14 | from torch.autograd import Variable 15 | 16 | 17 | def as_numpy(v): 18 | if isinstance(v, Variable): 19 | v = v.data 20 | return v.cpu().numpy() 21 | 22 | 23 | class TorchTestCase(unittest.TestCase): 24 | def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3): 25 | npa, npb = as_numpy(a), as_numpy(b) 26 | self.assertTrue( 27 | np.allclose(npa, npb, atol=atol), 28 | 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max()) 29 | ) 30 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | import scipy.misc as sm 3 | from collections import OrderedDict 4 | import glob 5 | import numpy as np 6 | import socket 7 | 8 | # PyTorch includes 9 | import torch 10 | import torch.optim as optim 11 | from torchvision import transforms 12 | from torch.utils.data import DataLoader 13 | 14 | # Custom includes 15 | from dataloaders.combine_dbs import CombineDBs as combine_dbs 16 | import dataloaders.pascal as pascal 17 | import dataloaders.sbd as sbd 18 | from dataloaders import custom_transforms as tr 19 | from networks.loss import class_cross_entropy_loss 20 | from dataloaders.helpers import * 21 | from networks.mainnetwork import * 22 | 23 | # Set gpu_id to -1 to run in CPU mode, otherwise set the id of the corresponding gpu 24 | gpu_id = 0 25 | device = torch.device("cuda:"+str(gpu_id) if torch.cuda.is_available() else "cpu") 26 | if torch.cuda.is_available(): 27 | print('Using GPU: {} '.format(gpu_id)) 28 | 29 | # Setting parameters 30 | resume_epoch = 100 # test epoch 31 | nInputChannels = 5 # Number of input channels (RGB + heatmap of IOG points) 32 | 33 | # Results and model directories (a new directory is generated for every run) 34 | save_dir_root = os.path.join(os.path.dirname(os.path.abspath(__file__))) 35 | exp_name = os.path.dirname(os.path.abspath(__file__)).split('/')[-1] 36 | if resume_epoch == 0: 37 | runs = sorted(glob.glob(os.path.join(save_dir_root, 'run_*'))) 38 | run_id = int(runs[-1].split('_')[-1]) + 1 if runs else 0 39 | else: 40 | run_id = 0 41 | save_dir = os.path.join(save_dir_root, 'run_' + str(run_id)) 42 | if not os.path.exists(os.path.join(save_dir, 'models')): 43 | os.makedirs(os.path.join(save_dir, 'models')) 44 | 45 | # Network definition 46 | modelName = 'IOG_pascal' 47 | net = Network(nInputChannels=nInputChannels,num_classes=1, 48 | backbone='resnet101', 49 | output_stride=16, 50 | sync_bn=None, 51 | freeze_bn=False) 52 | 53 | # load pretrain_dict 54 | pretrain_dict = torch.load(os.path.join(save_dir, 'models', modelName + '_epoch-' + str(resume_epoch - 1) + '.pth')) 55 | print("Initializing weights from: {}".format( 56 | os.path.join(save_dir, 'models', modelName + '_epoch-' + str(resume_epoch - 1) + '.pth'))) 57 | net.load_state_dict(pretrain_dict) 58 | net.to(device) 59 | 60 | # Generate result of the validation images 61 | net.eval() 62 | composed_transforms_ts = transforms.Compose([ 63 | tr.CropFromMask(crop_elems=('image', 'gt','void_pixels'), relax=30, zero_pad=True), 64 | tr.FixedResize(resolutions={'gt': None, 'crop_image': (512, 512), 'crop_gt': (512, 512), 'crop_void_pixels': (512, 512)},flagvals={'gt':cv2.INTER_LINEAR,'crop_image':cv2.INTER_LINEAR,'crop_gt':cv2.INTER_LINEAR,'crop_void_pixels': cv2.INTER_LINEAR}), 65 | tr.IOGPoints(sigma=10, 
elem='crop_gt',pad_pixel=10), 66 | tr.ToImage(norm_elem='IOG_points'), 67 | tr.ConcatInputs(elems=('crop_image', 'IOG_points')), 68 | tr.ToTensor()]) 69 | db_test = pascal.VOCSegmentation(split='val', transform=composed_transforms_ts, retname=True) 70 | testloader = DataLoader(db_test, batch_size=1, shuffle=False, num_workers=1) 71 | 72 | save_dir_res = os.path.join(save_dir, 'Results') 73 | if not os.path.exists(save_dir_res): 74 | os.makedirs(save_dir_res) 75 | save_dir_res_list=[save_dir_res] 76 | print('Testing Network') 77 | with torch.no_grad(): 78 | for ii, sample_batched in enumerate(testloader): 79 | inputs, gts, metas = sample_batched['concat'], sample_batched['gt'], sample_batched['meta'] 80 | inputs = inputs.to(device) 81 | coarse_outs1,coarse_outs2,coarse_outs3,coarse_outs4,fine_out = net.forward(inputs) 82 | outputs = fine_out.to(torch.device('cpu')) 83 | pred = np.transpose(outputs.data.numpy()[0, :, :, :], (1, 2, 0)) 84 | pred = 1 / (1 + np.exp(-pred)) 85 | pred = np.squeeze(pred) 86 | gt = tens2image(gts[0, :, :, :]) 87 | bbox = get_bbox(gt, pad=30, zero_pad=True) 88 | result = crop2fullmask(pred, bbox, gt, zero_pad=True, relax=0,mask_relax=False) 89 | sm.imsave(os.path.join(save_dir_res_list[0], metas['image'][0] + '-' + metas['object'][0] + '.png'), result) 90 | -------------------------------------------------------------------------------- /test_refine.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | import scipy.misc as sm 3 | from collections import OrderedDict 4 | import glob 5 | import numpy as np 6 | import socket 7 | import timeit 8 | 9 | # PyTorch includes 10 | import torch 11 | import torch.optim as optim 12 | from torchvision import transforms 13 | from torch.utils.data import DataLoader 14 | 15 | # Custom includes 16 | from dataloaders.combine_dbs import CombineDBs as combine_dbs 17 | import dataloaders.pascal as pascal 18 | import dataloaders.sbd as sbd 19 | from dataloaders import custom_transforms as tr 20 | from dataloaders.helpers import * 21 | from networks.loss import class_cross_entropy_loss 22 | from networks.refinementnetwork import * 23 | from torch.nn.functional import upsample 24 | 25 | # Set gpu_id to -1 to run in CPU mode, otherwise set the id of the corresponding gpu 26 | gpu_id = 0 27 | device = torch.device("cuda:"+str(gpu_id) if torch.cuda.is_available() else "cpu") 28 | if torch.cuda.is_available(): 29 | print('Using GPU: {} '.format(gpu_id)) 30 | 31 | # Setting parameters 32 | resume_epoch = 100 # test epoch 33 | nInputChannels = 5 # Number of input channels (RGB + heatmap of IOG points) 34 | refinement_num_max = 2 # the number of new points: 35 | 36 | # Results and model directories (a new directory is generated for every run) 37 | save_dir_root = os.path.join(os.path.dirname(os.path.abspath(__file__))) 38 | exp_name = os.path.dirname(os.path.abspath(__file__)).split('/')[-1] 39 | if resume_epoch == 0: 40 | runs = sorted(glob.glob(os.path.join(save_dir_root, 'run_*'))) 41 | run_id = int(runs[-1].split('_')[-1]) + 1 if runs else 0 42 | else: 43 | run_id = 0 44 | save_dir = os.path.join(save_dir_root, 'run_' + str(run_id)) 45 | if not os.path.exists(os.path.join(save_dir, 'models')): 46 | os.makedirs(os.path.join(save_dir, 'models')) 47 | 48 | # Network definition 49 | modelName = 'IOG_pascal_refinement' 50 | net = Network(nInputChannels=nInputChannels,num_classes=1, 51 | backbone='resnet101', 52 | output_stride=16, 53 | sync_bn=None, 54 | freeze_bn=False) 55 | 56 | 
# load pretrain_dict 57 | pretrain_dict = torch.load(os.path.join(save_dir, 'models', modelName + '_epoch-' + str(resume_epoch - 1) + '.pth')) 58 | print("Initializing weights from: {}".format( 59 | os.path.join(save_dir, 'models', modelName + '_epoch-' + str(resume_epoch - 1) + '.pth'))) 60 | net.load_state_dict(pretrain_dict) 61 | net.to(device) 62 | 63 | # Generate result of the validation images 64 | net.eval() 65 | composed_transforms_ts = transforms.Compose([ 66 | tr.CropFromMask(crop_elems=('image', 'gt','void_pixels'), relax=30, zero_pad=True), 67 | tr.FixedResize(resolutions={'gt': None, 'crop_image': (512, 512), 'crop_gt': (512, 512), 'crop_void_pixels': (512, 512)},flagvals={'gt':cv2.INTER_LINEAR,'crop_image':cv2.INTER_LINEAR,'crop_gt':cv2.INTER_LINEAR,'crop_void_pixels': cv2.INTER_LINEAR}), 68 | tr.IOGPoints(sigma=10, elem='crop_gt',pad_pixel=10), 69 | tr.ToImage(norm_elem='IOG_points'), 70 | tr.ConcatInputs(elems=('crop_image', 'IOG_points')), 71 | tr.ToTensor()]) 72 | db_test = pascal.VOCSegmentation(split='val', transform=composed_transforms_ts, retname=True) 73 | testloader = DataLoader(db_test, batch_size=1, shuffle=False, num_workers=1) 74 | 75 | save_dir_res_list=[] 76 | for add_clicks in range(0,refinement_num_max+1): 77 | save_dir_res = os.path.join(save_dir, 'Results-'+str(add_clicks)) 78 | if not os.path.exists(save_dir_res): 79 | os.makedirs(save_dir_res) 80 | save_dir_res_list.append(save_dir_res) 81 | 82 | print('Testing Network') 83 | with torch.no_grad(): 84 | # Main Testing Loop 85 | for ii, sample_batched in enumerate(testloader): 86 | metas = sample_batched['meta'] 87 | gts = sample_batched['gt'] 88 | gts_crop = sample_batched['crop_gt'] 89 | inputs = sample_batched['concat'] 90 | void_pixels = sample_batched['crop_void_pixels'] 91 | IOG_points = sample_batched['IOG_points'] 92 | inputs.requires_grad_() 93 | inputs, gts_crop ,void_pixels,IOG_points = inputs.to(device), gts_crop.to(device), void_pixels.to(device), IOG_points.to(device) 94 | out = net.forward(inputs,IOG_points,gts_crop,refinement_num_max+1) 95 | for i in range(0,refinement_num_max+1): 96 | glo1,glo2,glo3,glo4,refine,iou_i=out[i] 97 | output_refine = upsample(refine, size=(512, 512), mode='bilinear', align_corners=True) 98 | outputs = output_refine.to(torch.device('cpu')) 99 | pred = np.transpose(outputs.data.numpy()[0, :, :, :], (1, 2, 0)) 100 | pred = 1 / (1 + np.exp(-pred)) 101 | pred = np.squeeze(pred) 102 | gt = tens2image(gts[0, :, :, :]) 103 | bbox = get_bbox(gt, pad=30, zero_pad=True) 104 | result = crop2fullmask(pred, bbox, gt, zero_pad=True, relax=0,mask_relax=False) 105 | 106 | # Save the result, attention to the index 107 | sm.imsave(os.path.join(save_dir_res_list[i], metas['image'][0] + '-' + metas['object'][0] + '.png'), result) 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | import scipy.misc as sm 3 | from collections import OrderedDict 4 | import glob 5 | import numpy as np 6 | import socket 7 | import timeit 8 | 9 | # PyTorch includes 10 | import torch 11 | import torch.optim as optim 12 | from torchvision import transforms 13 | from torch.utils.data import DataLoader 14 | 15 | # Custom includes 16 | from dataloaders.combine_dbs import CombineDBs as combine_dbs 17 | import dataloaders.pascal as pascal 18 | import dataloaders.sbd as sbd 19 | from 
dataloaders import custom_transforms as tr 20 | from dataloaders.helpers import * 21 | from networks.loss import class_cross_entropy_loss 22 | from networks.mainnetwork import * 23 | 24 | # Set gpu_id to -1 to run in CPU mode, otherwise set the id of the corresponding gpu 25 | gpu_id = 0 26 | device = torch.device("cuda:"+str(gpu_id) if torch.cuda.is_available() else "cpu") 27 | if torch.cuda.is_available(): 28 | print('Using GPU: {} '.format(gpu_id)) 29 | 30 | # Setting parameters 31 | use_sbd = False # train with SBD 32 | nEpochs = 100 # Number of epochs for training 33 | resume_epoch = 0 # Default is 0, change if want to resume 34 | p = OrderedDict() # Parameters to include in report 35 | p['trainBatch'] = 5 # Training batch size 5 36 | snapshot = 10 # Store a model every snapshot epochs 37 | nInputChannels = 5 # Number of input channels (RGB + heatmap of extreme points) 38 | p['nAveGrad'] = 1 # Average the gradient of several iterations 39 | p['lr'] = 1e-8 # Learning rate 40 | p['wd'] = 0.0005 # Weight decay 41 | p['momentum'] = 0.9 # Momentum 42 | 43 | # Results and model directories (a new directory is generated for every run) 44 | save_dir_root = os.path.join(os.path.dirname(os.path.abspath(__file__))) 45 | exp_name = os.path.dirname(os.path.abspath(__file__)).split('/')[-1] 46 | if resume_epoch == 0: 47 | runs = sorted(glob.glob(os.path.join(save_dir_root, 'run_*'))) 48 | run_id = int(runs[-1].split('_')[-1]) + 1 if runs else 0 49 | else: 50 | run_id = 0 51 | save_dir = os.path.join(save_dir_root, 'run_' + str(run_id)) 52 | if not os.path.exists(os.path.join(save_dir, 'models')): 53 | os.makedirs(os.path.join(save_dir, 'models')) 54 | 55 | # Network definition 56 | modelName = 'IOG_pascal' 57 | net = Network(nInputChannels=nInputChannels,num_classes=1, 58 | backbone='resnet101', 59 | output_stride=16, 60 | sync_bn=None, 61 | freeze_bn=False, 62 | pretrained=True) 63 | if resume_epoch == 0: 64 | print("Initializing from pretrained model") 65 | else: 66 | print("Initializing weights from: {}".format( 67 | os.path.join(save_dir, 'models', modelName + '_epoch-' + str(resume_epoch - 1) + '.pth'))) 68 | net.load_state_dict( 69 | torch.load(os.path.join(save_dir, 'models', modelName + '_epoch-' + str(resume_epoch - 1) + '.pth'), 70 | map_location=lambda storage, loc: storage)) 71 | train_params = [{'params': net.get_1x_lr_params(), 'lr': p['lr']}, 72 | {'params': net.get_10x_lr_params(), 'lr': p['lr'] * 10}] 73 | net.to(device) 74 | 75 | if resume_epoch != nEpochs: 76 | # Logging into Tensorboard 77 | log_dir = os.path.join(save_dir, 'models', datetime.now().strftime('%b%d_%H-%M-%S') + '_' + socket.gethostname()) 78 | 79 | # Use the following optimizer 80 | optimizer = optim.SGD(train_params, lr=p['lr'], momentum=p['momentum'], weight_decay=p['wd']) 81 | p['optimizer'] = str(optimizer) 82 | 83 | # Preparation of the data loaders 84 | composed_transforms_tr = transforms.Compose([ 85 | tr.RandomHorizontalFlip(), 86 | tr.ScaleNRotate(rots=(-20, 20), scales=(.75, 1.25)), 87 | tr.CropFromMask(crop_elems=('image', 'gt','void_pixels'), relax=30, zero_pad=True), 88 | tr.FixedResize(resolutions={'crop_image': (512, 512), 'crop_gt': (512, 512), 'crop_void_pixels': (512, 512)},flagvals={'crop_image':cv2.INTER_LINEAR,'crop_gt':cv2.INTER_LINEAR,'crop_void_pixels': cv2.INTER_LINEAR}), 89 | tr.IOGPoints(sigma=10, elem='crop_gt',pad_pixel=10), 90 | tr.ToImage(norm_elem='IOG_points'), 91 | tr.ConcatInputs(elems=('crop_image', 'IOG_points')), 92 | tr.ToTensor()]) 93 | 94 | composed_transforms_ts = 
transforms.Compose([ 95 | tr.CropFromMask(crop_elems=('image', 'gt','void_pixels'), relax=30, zero_pad=True), 96 | tr.FixedResize(resolutions={'crop_image': (512, 512), 'crop_gt': (512, 512), 'crop_void_pixels': (512, 512)},flagvals={'crop_image':cv2.INTER_LINEAR,'crop_gt':cv2.INTER_LINEAR,'crop_void_pixels': cv2.INTER_LINEAR}), 97 | tr.IOGPoints(sigma=10, elem='crop_gt',pad_pixel=10), 98 | tr.ToImage(norm_elem='IOG_points'), 99 | tr.ConcatInputs(elems=('crop_image', 'IOG_points')), 100 | tr.ToTensor()]) 101 | 102 | voc_train = pascal.VOCSegmentation(split='train', transform=composed_transforms_tr) 103 | voc_val = pascal.VOCSegmentation(split='val', transform=composed_transforms_ts) 104 | if use_sbd: 105 | sbd = sbd.SBDSegmentation(split=['train', 'val'], transform=composed_transforms_tr, retname=True) 106 | db_train = combine_dbs([voc_train, sbd], excluded=[voc_val]) 107 | else: 108 | db_train = voc_train 109 | 110 | p['dataset_train'] = str(db_train) 111 | p['transformations_train'] = [str(tran) for tran in composed_transforms_tr.transforms] 112 | trainloader = DataLoader(db_train, batch_size=p['trainBatch'], shuffle=True, num_workers=2) 113 | 114 | # Train variables 115 | num_img_tr = len(trainloader) 116 | running_loss_tr = 0.0 117 | aveGrad = 0 118 | print("Training Network") 119 | for epoch in range(resume_epoch, nEpochs): 120 | start_time = timeit.default_timer() 121 | epoch_loss = [] 122 | net.train() 123 | for ii, sample_batched in enumerate(trainloader): 124 | gts = sample_batched['crop_gt'] 125 | inputs = sample_batched['concat'] 126 | void_pixels = sample_batched['crop_void_pixels'] 127 | inputs.requires_grad_() 128 | inputs, gts ,void_pixels = inputs.to(device), gts.to(device), void_pixels.to(device) 129 | coarse_outs1,coarse_outs2,coarse_outs3,coarse_outs4,fine_out = net.forward(inputs) 130 | 131 | # Compute the losses 132 | loss_coarse_outs1 = class_cross_entropy_loss(coarse_outs1, gts, void_pixels=void_pixels) 133 | loss_coarse_outs2 = class_cross_entropy_loss(coarse_outs2, gts, void_pixels=void_pixels) 134 | loss_coarse_outs3 = class_cross_entropy_loss(coarse_outs3, gts, void_pixels=void_pixels) 135 | loss_coarse_outs4 = class_cross_entropy_loss(coarse_outs4, gts, void_pixels=void_pixels) 136 | loss_fine_out = class_cross_entropy_loss(fine_out, gts, void_pixels=void_pixels) 137 | loss = loss_coarse_outs1+loss_coarse_outs2+ loss_coarse_outs3+loss_coarse_outs4+loss_fine_out 138 | 139 | if ii % 10 ==0: 140 | print('Epoch',epoch,'step',ii,'loss',loss) 141 | running_loss_tr += loss.item() 142 | 143 | # Print stuff 144 | if ii % num_img_tr == num_img_tr - 1 -p['trainBatch']: 145 | running_loss_tr = running_loss_tr / num_img_tr 146 | print('[Epoch: %d, numImages: %5d]' % (epoch, ii*p['trainBatch']+inputs.data.shape[0])) 147 | print('Loss: %f' % running_loss_tr) 148 | running_loss_tr = 0 149 | stop_time = timeit.default_timer() 150 | print("Execution time: " + str(stop_time - start_time)+"\n") 151 | 152 | # Backward the averaged gradient 153 | loss /= p['nAveGrad'] 154 | loss.backward() 155 | aveGrad += 1 156 | 157 | # Update the weights once in p['nAveGrad'] forward passes 158 | if aveGrad % p['nAveGrad'] == 0: 159 | optimizer.step() 160 | optimizer.zero_grad() 161 | aveGrad = 0 162 | 163 | # Save the model 164 | if (epoch % snapshot) == snapshot - 1 and epoch != 0: 165 | torch.save(net.state_dict(), os.path.join(save_dir, 'models', modelName + '_epoch-' + str(epoch) + '.pth')) 166 | -------------------------------------------------------------------------------- 
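For readers following train.py above, the objective it optimizes can be summarized in a short sketch. The helper below is illustrative only and is not part of the repository; `criterion` stands in for networks.loss.class_cross_entropy_loss. The four coarse side outputs and the fine output are each scored with the same void-aware cross-entropy and summed with equal weight, exactly as in the training loop.

```python
# Hedged sketch of the train.py objective (illustrative, not repository code).
def total_iog_loss(coarse_outs, fine_out, gts, void_pixels, criterion):
    # `coarse_outs` is the tuple (coarse_outs1, ..., coarse_outs4) returned by
    # net.forward(inputs); `criterion` plays the role of class_cross_entropy_loss.
    loss = sum(criterion(o, gts, void_pixels=void_pixels) for o in coarse_outs)
    return loss + criterion(fine_out, gts, void_pixels=void_pixels)
```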
/train_refine.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | import scipy.misc as sm 3 | from collections import OrderedDict 4 | import glob 5 | import numpy as np 6 | import socket 7 | import timeit 8 | 9 | # PyTorch includes 10 | import torch 11 | import torch.optim as optim 12 | from torchvision import transforms 13 | from torch.utils.data import DataLoader 14 | 15 | # Custom includes 16 | from dataloaders.combine_dbs import CombineDBs as combine_dbs 17 | import dataloaders.pascal as pascal 18 | import dataloaders.sbd as sbd 19 | from dataloaders import custom_transforms as tr 20 | from dataloaders.helpers import * 21 | from networks.loss import class_cross_entropy_loss 22 | from networks.refinementnetwork import * 23 | from torch.nn.functional import upsample 24 | # Set gpu_id to -1 to run in CPU mode, otherwise set the id of the corresponding gpu 25 | gpu_id = 0 26 | device = torch.device("cuda:"+str(gpu_id) if torch.cuda.is_available() else "cpu") 27 | if torch.cuda.is_available(): 28 | print('Using GPU: {} '.format(gpu_id)) 29 | 30 | # Setting parameters 31 | use_sbd = False # train with SBD 32 | nEpochs = 100 # Number of epochs for training 33 | resume_epoch = 0 # Default is 0, change if want to resume 34 | p = OrderedDict() # Parameters to include in report 35 | p['trainBatch'] = 2 # Training batch size 5 36 | snapshot = 10 # Store a model every snapshot epochs 37 | nInputChannels = 5 # Number of input channels (RGB + heatmap of extreme points) 38 | p['nAveGrad'] = 1 # Average the gradient of several iterations 39 | p['lr'] = 1e-8 # Learning rate 40 | p['wd'] = 0.0005 # Weight decay 41 | p['momentum'] = 0.9 # Momentum 42 | threshold=0.95 # loss 43 | refinement_num_max = 1 # the number of new points: 44 | # Results and model directories (a new directory is generated for every run) 45 | save_dir_root = os.path.join(os.path.dirname(os.path.abspath(__file__))) 46 | exp_name = os.path.dirname(os.path.abspath(__file__)).split('/')[-1] 47 | if resume_epoch == 0: 48 | runs = sorted(glob.glob(os.path.join(save_dir_root, 'run_*'))) 49 | run_id = int(runs[-1].split('_')[-1]) + 1 if runs else 0 50 | else: 51 | run_id = 0 52 | save_dir = os.path.join(save_dir_root, 'run_' + str(run_id)) 53 | if not os.path.exists(os.path.join(save_dir, 'models')): 54 | os.makedirs(os.path.join(save_dir, 'models')) 55 | 56 | # Network definition 57 | modelName = 'IOG_pascal_refinement' 58 | net = Network(nInputChannels=nInputChannels,num_classes=1, 59 | backbone='resnet101', 60 | output_stride=16, 61 | sync_bn=None, 62 | freeze_bn=False, 63 | pretrained=True) 64 | if resume_epoch == 0: 65 | print("Initializing from pretrained model") 66 | else: 67 | print("Initializing weights from: {}".format( 68 | os.path.join(save_dir, 'models', modelName + '_epoch-' + str(resume_epoch - 1) + '.pth'))) 69 | net.load_state_dict( 70 | torch.load(os.path.join(save_dir, 'models', modelName + '_epoch-' + str(resume_epoch - 1) + '.pth'), 71 | map_location=lambda storage, loc: storage)) 72 | train_params = [{'params': net.get_1x_lr_params(), 'lr': p['lr']}, 73 | {'params': net.get_10x_lr_params(), 'lr': p['lr'] * 10}] 74 | net.to(device) 75 | 76 | if resume_epoch != nEpochs: 77 | # Logging into Tensorboard 78 | log_dir = os.path.join(save_dir, 'models', datetime.now().strftime('%b%d_%H-%M-%S') + '_' + socket.gethostname()) 79 | 80 | # Use the following optimizer 81 | optimizer = optim.SGD(train_params, lr=p['lr'], momentum=p['momentum'], 
weight_decay=p['wd']) 82 | p['optimizer'] = str(optimizer) 83 | 84 | # Preparation of the data loaders 85 | composed_transforms_tr = transforms.Compose([ 86 | tr.RandomHorizontalFlip(), 87 | tr.ScaleNRotate(rots=(-20, 20), scales=(.75, 1.25)), 88 | tr.CropFromMask(crop_elems=('image', 'gt','void_pixels'), relax=30, zero_pad=True), 89 | tr.FixedResize(resolutions={'crop_image': (512, 512), 'crop_gt': (512, 512), 'crop_void_pixels': (512, 512)},flagvals={'crop_image':cv2.INTER_LINEAR,'crop_gt':cv2.INTER_LINEAR,'crop_void_pixels': cv2.INTER_LINEAR}), 90 | tr.IOGPoints(sigma=10, elem='crop_gt',pad_pixel=10), 91 | tr.ToImage(norm_elem='IOG_points'), 92 | tr.ConcatInputs(elems=('crop_image', 'IOG_points')), 93 | tr.ToTensor()]) 94 | 95 | composed_transforms_ts = transforms.Compose([ 96 | tr.CropFromMask(crop_elems=('image', 'gt','void_pixels'), relax=30, zero_pad=True), 97 | tr.FixedResize(resolutions={'crop_image': (512, 512), 'crop_gt': (512, 512), 'crop_void_pixels': (512, 512)},flagvals={'crop_image':cv2.INTER_LINEAR,'crop_gt':cv2.INTER_LINEAR,'crop_void_pixels': cv2.INTER_LINEAR}), 98 | tr.IOGPoints(sigma=10, elem='crop_gt',pad_pixel=10), 99 | tr.ToImage(norm_elem='IOG_points'), 100 | tr.ConcatInputs(elems=('crop_image', 'IOG_points')), 101 | tr.ToTensor()]) 102 | 103 | voc_train = pascal.VOCSegmentation(split='train', transform=composed_transforms_tr) 104 | voc_val = pascal.VOCSegmentation(split='val', transform=composed_transforms_ts) 105 | if use_sbd: 106 | sbd = sbd.SBDSegmentation(split=['train', 'val'], transform=composed_transforms_tr, retname=True) 107 | db_train = combine_dbs([voc_train, sbd], excluded=[voc_val]) 108 | else: 109 | db_train = voc_train 110 | 111 | p['dataset_train'] = str(db_train) 112 | p['transformations_train'] = [str(tran) for tran in composed_transforms_tr.transforms] 113 | trainloader = DataLoader(db_train, batch_size=p['trainBatch'], shuffle=True, num_workers=2) 114 | 115 | # Train variables 116 | num_img_tr = len(trainloader) 117 | running_loss_tr = 0.0 118 | aveGrad = 0 119 | print("Training Network") 120 | for epoch in range(resume_epoch, nEpochs): 121 | start_time = timeit.default_timer() 122 | epoch_loss = [] 123 | net.train() 124 | for ii, sample_batched in enumerate(trainloader): 125 | gts = sample_batched['crop_gt'] 126 | inputs = sample_batched['concat'] 127 | void_pixels = sample_batched['crop_void_pixels'] 128 | IOG_points = sample_batched['IOG_points'] 129 | inputs.requires_grad_() 130 | inputs, gts ,void_pixels,IOG_points = inputs.to(device), gts.to(device), void_pixels.to(device), IOG_points.to(device) 131 | out = net.forward(inputs,IOG_points,gts,refinement_num_max+1) 132 | for i in range(0,refinement_num_max+1): 133 | glo1,glo2,glo3,glo4,refine,iou_i=out[i] 134 | output_glo1 = upsample(glo1, size=(512, 512), mode='bilinear', align_corners=True) 135 | output_glo2 = upsample(glo2, size=(512, 512), mode='bilinear', align_corners=True) 136 | output_glo3 = upsample(glo3, size=(512, 512), mode='bilinear', align_corners=True) 137 | output_glo4 = upsample(glo4, size=(512, 512), mode='bilinear', align_corners=True) 138 | output_refine = upsample(refine, size=(512, 512), mode='bilinear', align_corners=True) 139 | 140 | # Compute the losses, side outputs and fuse 141 | loss_output_glo1 = class_cross_entropy_loss(output_glo1, gts, void_pixels=void_pixels,size_average=False, batch_average=True) 142 | loss_output_glo2 = class_cross_entropy_loss(output_glo2, gts, void_pixels=void_pixels,size_average=False, batch_average=True) 143 | loss_output_glo3 = 
class_cross_entropy_loss(output_glo3, gts, void_pixels=void_pixels,size_average=False, batch_average=True) 144 | 145 | loss_output_glo4 = class_cross_entropy_loss(output_glo4, gts, void_pixels=void_pixels,size_average=False, batch_average=True) 146 | loss_output_refine = class_cross_entropy_loss(output_refine, gts, void_pixels=void_pixels,size_average=False, batch_average=True) 147 | 148 | if i ==0: 149 | loss1 = loss_output_glo1+loss_output_glo2+loss_output_glo3+loss_output_glo4+loss_output_refine 150 | iou1 = iou_i 151 | if i ==1: 152 | loss2 = loss_output_glo1+loss_output_glo2+loss_output_glo3+loss_output_glo4+loss_output_refine 153 | iou2 = iou_i 154 | 155 | if iou1>=threshold: 156 | loss=loss1 157 | else: 158 | loss=0.5*loss1+0.5*loss2 159 | 160 | if ii % 10 ==0: 161 | print('Epoch',epoch,'step',ii,'loss',loss) 162 | running_loss_tr += loss.item() 163 | 164 | # Print stuff 165 | if ii % num_img_tr == num_img_tr - 1 -p['trainBatch']: 166 | running_loss_tr = running_loss_tr / num_img_tr 167 | print('[Epoch: %d, numImages: %5d]' % (epoch, ii*p['trainBatch']+inputs.data.shape[0])) 168 | print('Loss: %f' % running_loss_tr) 169 | running_loss_tr = 0 170 | stop_time = timeit.default_timer() 171 | print("Execution time: " + str(stop_time - start_time)+"\n") 172 | 173 | # Backward the averaged gradient 174 | loss /= p['nAveGrad'] 175 | loss.backward() 176 | aveGrad += 1 177 | 178 | # Update the weights once in p['nAveGrad'] forward passes 179 | if aveGrad % p['nAveGrad'] == 0: 180 | optimizer.step() 181 | optimizer.zero_grad() 182 | aveGrad = 0 183 | 184 | # Save the model 185 | if (epoch % snapshot) == snapshot - 1 and epoch != 0: 186 | torch.save(net.state_dict(), os.path.join(save_dir, 'models', modelName + '_epoch-' + str(epoch) + '.pth')) 187 | --------------------------------------------------------------------------------
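The IoU-gated weighting used in train_refine.py above can be read as a standalone rule: when the first forward pass already reaches the IoU threshold (0.95 in the script), only the first-round loss is kept; otherwise the first- and second-round losses are averaged. The function below is a hedged restatement for clarity, not a file in the repository.

```python
# Illustrative restatement of the refinement loss gate from train_refine.py.
def gate_refinement_loss(loss_round1, loss_round2, iou_round1, threshold=0.95):
    if iou_round1 >= threshold:
        # The first prediction is already good enough; keep only its loss.
        return loss_round1
    # Otherwise weight the two refinement rounds equally.
    return 0.5 * loss_round1 + 0.5 * loss_round2
```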