├── README.md
├── dataloaders
│   ├── __init__.py
│   ├── combine_dbs.py
│   ├── custom_transforms.py
│   ├── helpers.py
│   ├── pascal.py
│   └── sbd.py
├── eval.py
├── evaluation
│   ├── __init__.py
│   ├── eval.py
│   └── evaluation.py
├── ims
│   ├── IOG.gif
│   ├── cross_domain.gif
│   ├── ims.png
│   └── refinement.gif
├── mypath.py
├── networks
│   ├── CoarseNet.py
│   ├── FineNet.py
│   ├── __init__.py
│   ├── backbone
│   │   ├── __init__.py
│   │   └── resnet.py
│   ├── loss.py
│   ├── mainnetwork.py
│   ├── refinementnetwork.py
│   └── sync_batchnorm
│       ├── __init__.py
│       ├── __pycache__
│       │   ├── __init__.cpython-35.pyc
│       │   ├── batchnorm.cpython-35.pyc
│       │   ├── comm.cpython-35.pyc
│       │   └── replicate.cpython-35.pyc
│       ├── batchnorm.py
│       ├── comm.py
│       ├── replicate.py
│       └── unittest.py
├── test.py
├── test_refine.py
├── train.py
└── train_refine.py
/README.md:
--------------------------------------------------------------------------------
1 | # Inside-Outside-Guidance (IOG)
2 | This project hosts the code for the IOG algorithm for interactive object segmentation.
3 | > [Interactive Object Segmentation with Inside-Outside Guidance](http://openaccess.thecvf.com/content_CVPR_2020/papers/Zhang_Interactive_Object_Segmentation_With_Inside-Outside_Guidance_CVPR_2020_paper.pdf)
4 | > Shiyin Zhang, Jun Hao Liew, Yunchao Wei, Shikui Wei, Yao Zhao
5 |
6 | **Updates:**
7 | - 2021.4.6: Added the interactive refinement branch for IOG.
8 |
9 | ![ims](ims/ims.png)
10 |
11 | ### Abstract
12 | This paper explores how to harvest precise object segmentation masks while minimizing the human interaction cost. To achieve this, we propose an Inside-Outside Guidance (IOG) approach in this work. Concretely, we leverage an inside point that is clicked near the object center and two outside points at the symmetrical corner locations (top-left and bottom-right or top-right and bottom-left) of a tight bounding box that encloses the target object. This results in a total of one foreground click and four background clicks for segmentation. Our IOG not only achieves state-of-the-art performance on several popular benchmarks, but also demonstrates strong generalization capability across different domains such as street scenes, aerial imagery and medical images, without fine-tuning. In addition, we also propose a simple two-stage solution that enables our IOG to produce high quality instance segmentation masks from existing datasets with off-the-shelf bounding boxes such as ImageNet and Open Images, demonstrating the superiority of our IOG as an annotation tool.
13 |
14 | ### Demo
15 |
16 | | IOG (3 points) | IOG (Refinement) | IOG (Cross domain) |
17 | |:---:|:---:|:---:|
18 | | ![IOG (3 points)](ims/IOG.gif) | ![IOG (Refinement)](ims/refinement.gif) | ![IOG (Cross domain)](ims/cross_domain.gif) |
19 |
42 | ### Installation
43 | 1. Install requirements
44 | - PyTorch = 0.4
45 | - python >= 3.5
46 | - torchvision = 0.2
47 | - pycocotools
48 | 2. Usage
49 | You can start training, testing, and evaluation with the following commands:
50 | ```
51 | # training step
52 | python train.py
53 | python train_refine.py
54 |
55 | # testing step
56 | python test.py
57 | python test_refine.py
58 |
59 | # evaluation step
60 | python eval.py
62 | ```
63 | The paths to the PASCAL/SBD datasets and the pretrained backbone model are set in mypath.py.
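
The `db_root_dir` entries must point at the folders that contain `VOCdevkit/` (PASCAL) and `benchmark_RELEASE/` (SBD), and `models_dir` at the ImageNet-pretrained ResNet weights. A minimal sketch of `mypath.py` (the paths below are placeholders for your own machine):

```
class Path(object):
    @staticmethod
    def db_root_dir(database):
        if database == 'pascal':
            return '/home/you/data/VOC2012'   # folder that contains VOCdevkit/
        elif database == 'sbd':
            return '/home/you/data/SBD'       # folder that contains benchmark_RELEASE/
        else:
            raise NotImplementedError

    @staticmethod
    def models_dir():
        return '/home/you/models/resnet101-5d3b4d8f.pth'
```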
64 |
65 | ### Pretrained models
66 | | Network |Dataset | Backbone | Download Link |
67 | |---------|---------|-------------|:-------------------------:|
68 | |IOG |PASCAL + SBD | ResNet-101 | [IOG_PASCAL_SBD.pth](https://drive.google.com/file/d/1Lm1hhMhhjjnNwO4Pf7SC6tXLayH2iH0l/view?usp=sharing) |
69 | |IOG |PASCAL | ResNet-101 | [IOG_PASCAL.pth](https://drive.google.com/file/d/1GLZIQlQ-3KUWaGTQ1g_InVcqesGfGcpW/view?usp=sharing) |
70 | |IOG-Refinement |PASCAL + SBD | ResNet-101 | [IOG_PASCAL_SBD_REFINEMENT.pth](https://drive.google.com/file/d/1VdOFUZZbtbYt9aIMugKhMKDA6EuqKG30/view?usp=sharing) |
71 |
72 | ### Dataset
73 | With the ~0.615M annotated bounding boxes of ILSVRC-LOC, we applied IOG to collect pixel-level annotations, named Pixel-ImageNet, which are publicly available at https://github.com/shiyinzhang/Pixel-ImageNet.
74 | ### Citations
75 | Please consider citing our paper in your publications if it helps your research. The following is a BibTeX reference.
76 |
77 | ```
78 | @inproceedings{zhang2020interactive,
79 |   title={Interactive Object Segmentation With Inside-Outside Guidance},
80 |   author={Zhang, Shiyin and Liew, Jun Hao and Wei, Yunchao and Wei, Shikui and Zhao, Yao},
81 |   booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
82 |   pages={12234--12244},
83 |   year={2020}
84 | }
85 | ```
86 |
--------------------------------------------------------------------------------
/dataloaders/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shiyinzhang/Inside-Outside-Guidance/696ec2ddae2e994541cc9d81bb3c41984e233c64/dataloaders/__init__.py
--------------------------------------------------------------------------------
/dataloaders/combine_dbs.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 |
3 |
4 | class CombineDBs(data.Dataset):
5 | def __init__(self, dataloaders, excluded=None):
6 | self.dataloaders = dataloaders
7 | self.excluded = excluded
8 | self.im_ids = []
9 |
10 | # Combine object lists
11 | for dl in dataloaders:
12 | for elem in dl.im_ids:
13 | if elem not in self.im_ids:
14 | self.im_ids.append(elem)
15 |
16 | # Exclude
17 | if excluded:
18 | for dl in excluded:
19 | for elem in dl.im_ids:
20 | if elem in self.im_ids:
21 | self.im_ids.remove(elem)
22 |
23 | # Get object pointers
24 | self.obj_list = []
25 | self.im_list = []
26 | new_im_ids = []
27 | obj_counter = 0
28 | num_images = 0
29 | for ii, dl in enumerate(dataloaders):
30 | for jj, curr_im_id in enumerate(dl.im_ids):
31 | if (curr_im_id in self.im_ids) and (curr_im_id not in new_im_ids):
32 | flag = False
33 | new_im_ids.append(curr_im_id)
34 | for kk in range(len(dl.obj_dict[curr_im_id])):
35 | if dl.obj_dict[curr_im_id][kk] != -1:
36 | self.obj_list.append({'db_ii': ii, 'obj_ii': dl.obj_list.index([jj, kk])})
37 | flag = True
38 | obj_counter += 1
39 | self.im_list.append({'db_ii': ii, 'im_ii': jj})
40 | if flag:
41 | num_images += 1
42 |
43 | self.im_ids = new_im_ids
44 | print('Combined number of images: {:d}\nCombined number of objects: {:d}'.format(num_images, len(self.obj_list)))
45 |
46 | def __getitem__(self, index):
47 |
48 | _db_ii = self.obj_list[index]["db_ii"]
49 | _obj_ii = self.obj_list[index]['obj_ii']
50 | sample = self.dataloaders[_db_ii].__getitem__(_obj_ii)
51 |
52 | if 'meta' in sample.keys():
53 | sample['meta']['db'] = str(self.dataloaders[_db_ii])
54 |
55 | return sample
56 |
57 | def __len__(self):
58 | return len(self.obj_list)
59 |
60 | def __str__(self):
61 | include_db = [str(db) for db in self.dataloaders]
62 | exclude_db = [str(db) for db in self.excluded]
63 | return 'Included datasets:'+str(include_db)+'\n'+'Excluded datasets:'+str(exclude_db)
64 |
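65 | # --- Usage sketch (not part of the original file) --------------------------
66 | # CombineDBs merges the object lists of datasets that expose `im_ids`,
67 | # `obj_dict` and `obj_list` (e.g. dataloaders.pascal.VOCSegmentation and
68 | # dataloaders.sbd.SBDSegmentation), optionally excluding the images of other
69 | # datasets. The split choices below are illustrative, not prescribed.
70 | if __name__ == '__main__':
71 |     import dataloaders.pascal as pascal
72 |     import dataloaders.sbd as sbd
73 |
74 |     pascal_train = pascal.VOCSegmentation(split='train', retname=True)
75 |     sbd_train = sbd.SBDSegmentation(split=['train', 'val'], retname=True)
76 |     pascal_val = pascal.VOCSegmentation(split='val', retname=True)
77 |
78 |     # PASCAL train + SBD, with the PASCAL val images removed
79 |     combined = CombineDBs([pascal_train, sbd_train], excluded=[pascal_val])
80 |     print(combined)
81 |     print('Objects in combined set:', len(combined))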
--------------------------------------------------------------------------------
/dataloaders/custom_transforms.py:
--------------------------------------------------------------------------------
1 | import torch, cv2
2 | import numpy.random as random
3 | import numpy as np
4 | import dataloaders.helpers as helpers
5 | import scipy.misc as sm
6 | from dataloaders.helpers import *
7 |
8 | class ScaleNRotate(object):
9 | """Scale (zoom-in, zoom-out) and Rotate the image and the ground truth.
10 | Args:
11 | two possibilities:
12 | 1. rots (tuple): (minimum, maximum) rotation angle
13 | scales (tuple): (minimum, maximum) scale
14 | 2. rots [list]: list of fixed possible rotation angles
15 | scales [list]: list of fixed possible scales
16 | """
17 | def __init__(self, rots=(-30, 30), scales=(.75, 1.25), semseg=False):
18 | assert (isinstance(rots, type(scales)))
19 | self.rots = rots
20 | self.scales = scales
21 | self.semseg = semseg
22 |
23 | def __call__(self, sample):
24 |
25 | if type(self.rots) == tuple:
26 | # Continuous range of scales and rotations
27 | rot = (self.rots[1] - self.rots[0]) * random.random() - \
28 | (self.rots[1] - self.rots[0])/2
29 |
30 | sc = (self.scales[1] - self.scales[0]) * random.random() - \
31 | (self.scales[1] - self.scales[0]) / 2 + 1
32 | elif type(self.rots) == list:
33 | # Fixed range of scales and rotations
34 | rot = self.rots[random.randint(0, len(self.rots))]
35 | sc = self.scales[random.randint(0, len(self.scales))]
36 |
37 | for elem in sample.keys():
38 | if 'meta' in elem:
39 | continue
40 |
41 | tmp = sample[elem]
42 |
43 | h, w = tmp.shape[:2]
44 | center = (w / 2, h / 2)
45 | assert(center != 0) # Strange behaviour warpAffine
46 | M = cv2.getRotationMatrix2D(center, rot, sc)
47 |
48 | if ((tmp == 0) | (tmp == 1)).all():
49 | flagval = cv2.INTER_NEAREST
50 | elif 'gt' in elem and self.semseg:
51 | flagval = cv2.INTER_NEAREST
52 | else:
53 | flagval = cv2.INTER_CUBIC
54 | tmp = cv2.warpAffine(tmp, M, (w, h), flags=flagval)
55 |
56 | sample[elem] = tmp
57 |
58 | return sample
59 |
60 | def __str__(self):
61 | return 'ScaleNRotate:(rot='+str(self.rots)+',scale='+str(self.scales)+')'
62 |
63 |
64 | class FixedResize(object):
65 | """Resize the image and the ground truth to specified resolution.
66 | Args:
67 | resolutions (dict): target resolution for each element of the sample
68 | """
69 | def __init__(self, resolutions=None, flagvals=None):
70 | self.resolutions = resolutions
71 | self.flagvals = flagvals
72 | if self.flagvals is not None:
73 | assert(len(self.resolutions) == len(self.flagvals))
74 |
75 | def __call__(self, sample):
76 |
77 | # Fixed range of scales
78 | if self.resolutions is None:
79 | return sample
80 |
81 | elems = list(sample.keys())
82 |
83 | for elem in elems:
84 |
85 | if 'meta' in elem or 'bbox' in elem or ('extreme_points_coord' in elem and elem not in self.resolutions):
86 | continue
87 | if 'extreme_points_coord' in elem and elem in self.resolutions:
88 | bbox = sample['bbox']
89 | crop_size = np.array([bbox[3]-bbox[1]+1, bbox[4]-bbox[2]+1])
90 | res = np.array(self.resolutions[elem]).astype(np.float32)
91 | sample[elem] = np.round(sample[elem]*res/crop_size).astype(np.int)
92 | continue
93 | if elem in self.resolutions:
94 | if self.resolutions[elem] is None:
95 | continue
96 | if isinstance(sample[elem], list):
97 | if sample[elem][0].ndim == 3:
98 | output_size = np.append(self.resolutions[elem], [3, len(sample[elem])])
99 | else:
100 | output_size = np.append(self.resolutions[elem], len(sample[elem]))
101 | tmp = sample[elem]
102 | sample[elem] = np.zeros(output_size, dtype=np.float32)
103 | for ii, crop in enumerate(tmp):
104 | if self.flagvals is None:
105 | sample[elem][..., ii] = helpers.fixed_resize(crop, self.resolutions[elem])
106 | else:
107 | sample[elem][..., ii] = helpers.fixed_resize(crop, self.resolutions[elem], flagval=self.flagvals[elem])
108 | else:
109 | if self.flagvals is None:
110 | sample[elem] = helpers.fixed_resize(sample[elem], self.resolutions[elem])
111 | else:
112 | sample[elem] = helpers.fixed_resize(sample[elem], self.resolutions[elem], flagval=self.flagvals[elem])
113 | else:
114 | del sample[elem]
115 |
116 | return sample
117 |
118 | def __str__(self):
119 | return 'FixedResize:'+str(self.resolutions)
120 |
121 |
122 | class RandomHorizontalFlip(object):
123 | """Horizontally flip the given image and ground truth randomly with a probability of 0.5."""
124 |
125 | def __call__(self, sample):
126 |
127 | if random.random() < 0.5:
128 | for elem in sample.keys():
129 | if 'meta' in elem:
130 | continue
131 | tmp = sample[elem]
132 | tmp = cv2.flip(tmp, flipCode=1)
133 | sample[elem] = tmp
134 |
135 | return sample
136 |
137 | def __str__(self):
138 | return 'RandomHorizontalFlip'
139 |
140 |
141 | class IOGPoints(object):
142 | """
143 | Returns the IOG points for a given binary mask: one inside point near the object center and four outside points at the (padded) corners of the tight bounding box
144 | sigma: sigma of the Gaussian used to render each point as a heatmap
145 | pad_pixel: number of pixels by which the corner points are padded outward from the tight bounding box
146 | elem: which element of the sample to use as the binary mask
147 | """
148 | def __init__(self, sigma=10, elem='crop_gt',pad_pixel =10):
149 | self.sigma = sigma
150 | self.elem = elem
151 | self.pad_pixel =pad_pixel
152 |
153 | def __call__(self, sample):
154 |
155 | if sample[self.elem].ndim == 3:
156 | raise ValueError('IOGPoints not implemented for multiple objects per image.')
157 | _target = sample[self.elem]
158 |
159 | targetshape=_target.shape
160 | if np.max(_target) == 0:
161 | sample['IOG_points'] = np.zeros([targetshape[0],targetshape[1],2], dtype=_target.dtype) # TODO: handle one_mask_per_point case
162 | else:
163 | _points = helpers.iog_points(_target, self.pad_pixel)
164 | sample['IOG_points'] = helpers.make_gt(_target, _points, sigma=self.sigma, one_mask_per_point=False)
165 |
166 | return sample
167 |
168 | def __str__(self):
169 | return 'IOGPoints:(sigma='+str(self.sigma)+', pad_pixel='+str(self.pad_pixel)+', elem='+str(self.elem)+')'
170 |
171 |
172 | class ConcatInputs(object):
173 |
174 | def __init__(self, elems=('image', 'point')):
175 | self.elems = elems
176 |
177 | def __call__(self, sample):
178 |
179 | res = sample[self.elems[0]]
180 |
181 | for elem in self.elems[1:]:
182 | assert(sample[self.elems[0]].shape[:2] == sample[elem].shape[:2])
183 |
184 | # Check if third dimension is missing
185 | tmp = sample[elem]
186 | if tmp.ndim == 2:
187 | tmp = tmp[:, :, np.newaxis]
188 |
189 | res = np.concatenate((res, tmp), axis=2)
190 |
191 | sample['concat'] = res
192 | return sample
193 |
194 | def __str__(self):
195 | return 'ConcatInputs:'+str(self.elems)
196 |
197 |
198 | class CropFromMask(object):
199 | """
200 | Returns image cropped in bounding box from a given mask
201 | """
202 | def __init__(self, crop_elems=('image', 'gt','void_pixels'),
203 | mask_elem='gt',
204 | relax=0,
205 | zero_pad=False):
206 |
207 | self.crop_elems = crop_elems
208 | self.mask_elem = mask_elem
209 | self.relax = relax
210 | self.zero_pad = zero_pad
211 |
212 | def __call__(self, sample):
213 | _target = sample[self.mask_elem]
214 | if _target.ndim == 2:
215 | _target = np.expand_dims(_target, axis=-1)
216 | for elem in self.crop_elems:
217 | _img = sample[elem]
218 | _crop = []
219 | if self.mask_elem == elem:
220 | if _img.ndim == 2:
221 | _img = np.expand_dims(_img, axis=-1)
222 | for k in range(0, _target.shape[-1]):
223 | _tmp_img = _img[..., k]
224 | _tmp_target = _target[..., k]
225 | if np.max(_target[..., k]) == 0:
226 | _crop.append(np.zeros(_tmp_img.shape, dtype=_img.dtype))
227 | else:
228 | _crop.append(helpers.crop_from_mask(_tmp_img, _tmp_target, relax=self.relax, zero_pad=self.zero_pad))
229 | else:
230 | for k in range(0, _target.shape[-1]):
231 | if np.max(_target[..., k]) == 0:
232 | _crop.append(np.zeros(_img.shape, dtype=_img.dtype))
233 | else:
234 | _tmp_target = _target[..., k]
235 | _crop.append(helpers.crop_from_mask(_img, _tmp_target, relax=self.relax, zero_pad=self.zero_pad))
236 | if len(_crop) == 1:
237 | sample['crop_' + elem] = _crop[0]
238 | else:
239 | sample['crop_' + elem] = _crop
240 |
241 | return sample
242 |
243 | def __str__(self):
244 | return 'CropFromMask:(crop_elems='+str(self.crop_elems)+', mask_elem='+str(self.mask_elem)+\
245 | ', relax='+str(self.relax)+',zero_pad='+str(self.zero_pad)+')'
246 |
247 |
248 | class ToImage(object):
249 | """
250 | Rescale the given elements to the range [0, custom_max] (255 by default)
251 | """
252 | def __init__(self, norm_elem='image', custom_max=255.):
253 | self.norm_elem = norm_elem
254 | self.custom_max = custom_max
255 |
256 | def __call__(self, sample):
257 | if isinstance(self.norm_elem, tuple):
258 | for elem in self.norm_elem:
259 | tmp = sample[elem]
260 | sample[elem] = self.custom_max * (tmp - tmp.min()) / (tmp.max() - tmp.min() + 1e-10)
261 | else:
262 | tmp = sample[self.norm_elem]
263 | sample[self.norm_elem] = self.custom_max * (tmp - tmp.min()) / (tmp.max() - tmp.min() + 1e-10)
264 | return sample
265 |
266 | def __str__(self):
267 | return 'ToImage'
268 |
269 |
270 | class ToTensor(object):
271 | """Convert ndarrays in sample to Tensors."""
272 |
273 | def __call__(self, sample):
274 |
275 | for elem in sample.keys():
276 | if 'meta' in elem:
277 | continue
278 | elif 'bbox' in elem:
279 | tmp = sample[elem]
280 | sample[elem] = torch.from_numpy(tmp)
281 | continue
282 |
283 | tmp = sample[elem]
284 |
285 | if tmp.ndim == 2:
286 | tmp = tmp[:, :, np.newaxis]
287 |
288 | # swap color axis because
289 | # numpy image: H x W x C
290 | # torch image: C X H X W
291 | tmp = tmp.transpose((2, 0, 1))
292 | sample[elem] = torch.from_numpy(tmp)
293 |
294 | return sample
295 |
296 | def __str__(self):
297 | return 'ToTensor'
298 |
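299 | # --- Usage sketch (not part of the original file) --------------------------
300 | # One plausible way to chain these transforms into an IOG input pipeline:
301 | # crop around the ground-truth mask, resize, render the inside/outside click
302 | # heatmaps, and concatenate them with the RGB crop into a 5-channel input.
303 | # The resolutions, relax/sigma values and split below are illustrative only.
304 | if __name__ == '__main__':
305 |     from torchvision import transforms
306 |     import dataloaders.pascal as pascal
307 |
308 |     composed = transforms.Compose([
309 |         CropFromMask(crop_elems=('image', 'gt', 'void_pixels'), relax=30, zero_pad=True),
310 |         FixedResize(resolutions={'crop_image': (512, 512), 'crop_gt': (512, 512),
311 |                                  'crop_void_pixels': (512, 512)}),
312 |         IOGPoints(sigma=10, elem='crop_gt', pad_pixel=10),
313 |         ToImage(norm_elem='IOG_points'),
314 |         ConcatInputs(elems=('crop_image', 'IOG_points')),
315 |         ToTensor()])
316 |
317 |     dataset = pascal.VOCSegmentation(split='val', transform=composed, retname=True)
318 |     sample = dataset[0]
319 |     print(sample['concat'].shape)  # 5 x 512 x 512: RGB crop + 2 IOG point heatmaps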
--------------------------------------------------------------------------------
/dataloaders/helpers.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch, cv2
3 | import random
4 | import numpy as np
5 |
6 |
7 | def tens2image(im):
8 | if im.size()[0] == 1:
9 | tmp = np.squeeze(im.numpy(), axis=0)
10 | else:
11 | tmp = im.numpy()
12 | if tmp.ndim == 2:
13 | return tmp
14 | else:
15 | return tmp.transpose((1, 2, 0))
16 |
17 |
18 | def crop2fullmask(crop_mask, bbox, im=None, im_size=None, zero_pad=False, relax=0, mask_relax=True,
19 | interpolation=cv2.INTER_CUBIC, scikit=False):
20 | if scikit:
21 | from skimage.transform import resize as sk_resize
22 | assert(not(im is None and im_size is None)), 'You have to provide either an image or the image size'
23 | if im is None:
24 | im_si = im_size
25 | else:
26 | im_si = im.shape
27 | # Borders of image
28 | bounds = (0, 0, im_si[1] - 1, im_si[0] - 1)
29 |
30 | # Valid bounding box locations as (x_min, y_min, x_max, y_max)
31 | bbox_valid = (max(bbox[0], bounds[0]),
32 | max(bbox[1], bounds[1]),
33 | min(bbox[2], bounds[2]),
34 | min(bbox[3], bounds[3]))
35 |
36 | # Bounding box of initial mask
37 | bbox_init = (bbox[0] + relax,
38 | bbox[1] + relax,
39 | bbox[2] - relax,
40 | bbox[3] - relax)
41 |
42 | if zero_pad:
43 | # Offsets for x and y
44 | offsets = (-bbox[0], -bbox[1])
45 | else:
46 | # assert((bbox == bbox_valid).all())
47 | offsets = (-bbox_valid[0], -bbox_valid[1])
48 |
49 | # Simple per element addition in the tuple
50 | inds = tuple(map(sum, zip(bbox_valid, offsets + offsets)))
51 |
52 | if scikit:
53 | crop_mask = sk_resize(crop_mask, (bbox[3] - bbox[1] + 1, bbox[2] - bbox[0] + 1), order=0, mode='constant').astype(crop_mask.dtype)
54 | else:
55 | crop_mask = cv2.resize(crop_mask, (bbox[2] - bbox[0] + 1, bbox[3] - bbox[1] + 1), interpolation=interpolation)
56 | result_ = np.zeros(im_si)
57 | result_[bbox_valid[1]:bbox_valid[3] + 1, bbox_valid[0]:bbox_valid[2] + 1] = \
58 | crop_mask[inds[1]:inds[3] + 1, inds[0]:inds[2] + 1]
59 |
60 | result = np.zeros(im_si)
61 | if mask_relax:
62 | result[bbox_init[1]:bbox_init[3]+1, bbox_init[0]:bbox_init[2]+1] = \
63 | result_[bbox_init[1]:bbox_init[3]+1, bbox_init[0]:bbox_init[2]+1]
64 | else:
65 | result = result_
66 |
67 | return result
68 |
69 |
70 | def overlay_mask(im, ma, colors=None, alpha=0.5):
71 | assert np.max(im) <= 1.0
72 | if colors is None:
73 | colors = np.load(os.path.join(os.path.dirname(__file__), 'pascal_map.npy'))/255.
74 | else:
75 | colors = np.append([[0.,0.,0.]], colors, axis=0);
76 |
77 | if ma.ndim == 3:
78 | assert len(colors) >= ma.shape[0], 'Not enough colors'
79 | ma = ma.astype(np.bool)
80 | im = im.astype(np.float32)
81 |
82 | if ma.ndim == 2:
83 | fg = im * alpha+np.ones(im.shape) * (1 - alpha) * colors[1, :3] # np.array([0,0,255])/255.0
84 | else:
85 | fg = []
86 | for n in range(ma.ndim):
87 | fg.append(im * alpha + np.ones(im.shape) * (1 - alpha) * colors[1+n, :3])
88 | # Whiten background
89 | bg = im.copy()
90 | if ma.ndim == 2:
91 | bg[ma == 0] = im[ma == 0]
92 | bg[ma == 1] = fg[ma == 1]
93 | total_ma = ma
94 | else:
95 | total_ma = np.zeros([ma.shape[1], ma.shape[2]])
96 | for n in range(ma.shape[0]):
97 | tmp_ma = ma[n, :, :]
98 | total_ma = np.logical_or(tmp_ma, total_ma)
99 | tmp_fg = fg[n]
100 | bg[tmp_ma == 1] = tmp_fg[tmp_ma == 1]
101 | bg[total_ma == 0] = im[total_ma == 0]
102 |
103 | # [-2:] is a trick to be compatible with both OpenCV 2 and 3
104 | contours = cv2.findContours(total_ma.copy().astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2:]
105 | cv2.drawContours(bg, contours[0], -1, (0.0, 0.0, 0.0), 1)
106 |
107 | return bg
108 | import PIL
109 | def overlay_masks(im, masks, alpha=0.5):
110 | colors = np.load(os.path.join(os.path.dirname(__file__), 'pascal_map.npy'))/255.
111 |
112 | if isinstance(masks, np.ndarray):
113 | masks = [masks]
114 |
115 | assert len(colors) >= len(masks), 'Not enough colors'
116 |
117 | ov = im.copy()
118 | ov_black = im.copy()*0
119 |
120 | imgZero = np.zeros(np.array(masks, dtype = np.uint8).shape,np.uint8)
121 | im = im.astype(np.float32)
122 | total_ma = np.zeros([im.shape[0], im.shape[1]])
123 | i = 1
124 | for ma in masks:
125 | ma = ma.astype(np.bool)
126 | fg = im * alpha+np.ones(im.shape) * (1 - alpha) * colors[i, :3] # np.array([0,0,255])/255.0
127 | i = i + 1
128 | ov[ma == 1] = fg[ma == 1]
129 | total_ma += ma
130 |
131 | # [-2:] is a trick to be compatible with both OpenCV 2 and 3
132 | contours = cv2.findContours(ma.copy().astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2:]
133 | cv2.drawContours(ov, contours[0], -1, (0.0, 0.0, 0.0), 1)
134 | cv2.drawContours(ov_black, contours[0], -1, (255, 255, 255), -1)  # thickness=-1 fills the contour region
135 | ov[total_ma == 0] = im[total_ma == 0]
136 |
137 | return ov_black
138 |
139 | from scipy import ndimage
140 | def getPositon(distance_transform):
141 | a = np.mat(distance_transform)
142 | raw, column = a.shape  # number of rows and columns
143 | _positon = np.argmax(a)  # flat index of the maximum value
144 | m, n = divmod(_positon, column)
145 | raw = m
146 | column = n
147 | # the maximum of the distance transform sits at (raw, column), i.e. a[m, n]
148 | return raw, column
152 |
153 | def iog_points(mask, pad_pixel=10):
154 | def find_point(id_x, id_y, ids):
155 | sel_id = ids[0][random.randint(0, len(ids[0]) - 1)]
156 | return [id_x[sel_id], id_y[sel_id]]
157 |
158 | inds_y, inds_x = np.where(mask > 0.5)
159 | [h,w]=mask.shape
160 | left = find_point(inds_x, inds_y, np.where(inds_x <= np.min(inds_x)))
161 | right = find_point(inds_x, inds_y, np.where(inds_x >= np.max(inds_x)))
162 | top = find_point(inds_x, inds_y, np.where(inds_y <= np.min(inds_y)))
163 | bottom = find_point(inds_x, inds_y, np.where(inds_y >= np.max(inds_y)))
164 |
165 | x_min=left[0]
166 | x_max=right[0]
167 | y_min=top[1]
168 | y_max=bottom[1]
169 |
170 | map_xor = (mask > 0.5)
171 | h,w = map_xor.shape
172 | map_xor_new = np.zeros((h+2,w+2))
173 | map_xor_new[1:(h+1),1:(w+1)] = map_xor[:,:]
174 | distance_transform=ndimage.distance_transform_edt(map_xor_new)
175 | distance_transform_back = distance_transform[1:(h+1),1:(w+1)]
176 | raw,column=getPositon(distance_transform_back)
177 | center_point = [column,raw]
178 |
179 | left_top=[max(x_min-pad_pixel,0), max(y_min-pad_pixel,0)]
180 | left_bottom=[max(x_min-pad_pixel ,0), min(y_max+pad_pixel,h)]
181 | right_top=[min(x_max+pad_pixel,w), max(y_min-pad_pixel,0)]
182 | right_bottom=[min(x_max+pad_pixel, w), min(y_max+pad_pixel, h)]
183 | a=[center_point, left_top, left_bottom, right_top, right_bottom]
184 |
185 | return np.array(a)
186 |
187 |
188 | def get_bbox(mask, points=None, pad=0, zero_pad=False):
189 | if points is not None:
190 | inds = np.flip(points.transpose(), axis=0)
191 | else:
192 | inds = np.where(mask > 0)
193 |
194 | if inds[0].shape[0] == 0:
195 | return None
196 |
197 | if zero_pad:
198 | x_min_bound = -np.inf
199 | y_min_bound = -np.inf
200 | x_max_bound = np.inf
201 | y_max_bound = np.inf
202 | else:
203 | x_min_bound = 0
204 | y_min_bound = 0
205 | x_max_bound = mask.shape[1] - 1
206 | y_max_bound = mask.shape[0] - 1
207 |
208 | x_min = max(inds[1].min() - pad, x_min_bound)
209 | y_min = max(inds[0].min() - pad, y_min_bound)
210 | x_max = min(inds[1].max() + pad, x_max_bound)
211 | y_max = min(inds[0].max() + pad, y_max_bound)
212 |
213 | return x_min, y_min, x_max, y_max
214 |
215 |
216 | def crop_from_bbox(img, bbox, zero_pad=False):
217 | # Borders of image
218 | bounds = (0, 0, img.shape[1] - 1, img.shape[0] - 1)
219 |
220 | # Valid bounding box locations as (x_min, y_min, x_max, y_max)
221 | bbox_valid = (max(bbox[0], bounds[0]),
222 | max(bbox[1], bounds[1]),
223 | min(bbox[2], bounds[2]),
224 | min(bbox[3], bounds[3]))
225 |
226 | if zero_pad:
227 | # Initialize crop size (first 2 dimensions)
228 | crop = np.zeros((bbox[3] - bbox[1] + 1, bbox[2] - bbox[0] + 1), dtype=img.dtype)
229 |
230 | # Offsets for x and y
231 | offsets = (-bbox[0], -bbox[1])
232 |
233 | else:
234 | assert(bbox == bbox_valid)
235 | crop = np.zeros((bbox_valid[3] - bbox_valid[1] + 1, bbox_valid[2] - bbox_valid[0] + 1), dtype=img.dtype)
236 | offsets = (-bbox_valid[0], -bbox_valid[1])
237 |
238 | # Simple per element addition in the tuple
239 | inds = tuple(map(sum, zip(bbox_valid, offsets + offsets)))
240 |
241 | img = np.squeeze(img)
242 | if img.ndim == 2:
243 | crop[inds[1]:inds[3] + 1, inds[0]:inds[2] + 1] = \
244 | img[bbox_valid[1]:bbox_valid[3] + 1, bbox_valid[0]:bbox_valid[2] + 1]
245 | else:
246 | crop = np.tile(crop[:, :, np.newaxis], [1, 1, 3]) # Add 3 RGB Channels
247 | crop[inds[1]:inds[3] + 1, inds[0]:inds[2] + 1, :] = \
248 | img[bbox_valid[1]:bbox_valid[3] + 1, bbox_valid[0]:bbox_valid[2] + 1, :]
249 |
250 | return crop
251 |
252 |
253 | def fixed_resize(sample, resolution, flagval=None):
254 |
255 | if flagval is None:
256 | if ((sample == 0) | (sample == 1)).all():
257 | flagval = cv2.INTER_NEAREST
258 | else:
259 | flagval = cv2.INTER_CUBIC
260 |
261 | if isinstance(resolution, int):
262 | tmp = [resolution, resolution]
263 | tmp[np.argmax(sample.shape[:2])] = int(round(float(resolution)/np.min(sample.shape[:2])*np.max(sample.shape[:2])))
264 | resolution = tuple(tmp)
265 |
266 | if sample.ndim == 2 or (sample.ndim == 3 and sample.shape[2] == 3):
267 | sample = cv2.resize(sample, resolution[::-1], interpolation=flagval)
268 | else:
269 | tmp = sample
270 | sample = np.zeros(np.append(resolution, tmp.shape[2]), dtype=np.float32)
271 | for ii in range(sample.shape[2]):
272 | sample[:, :, ii] = cv2.resize(tmp[:, :, ii], resolution[::-1], interpolation=flagval)
273 | return sample
274 |
275 |
276 | def crop_from_mask(img, mask, relax=0, zero_pad=False):
277 | if mask.shape[:2] != img.shape[:2]:
278 | mask = cv2.resize(mask, dsize=tuple(reversed(img.shape[:2])), interpolation=cv2.INTER_NEAREST)
279 |
280 | assert(mask.shape[:2] == img.shape[:2])
281 | bbox = get_bbox(mask, pad=relax, zero_pad=zero_pad)
282 |
283 | if bbox is None:
284 | return None
285 |
286 | crop = crop_from_bbox(img, bbox, zero_pad)
287 |
288 | return crop
289 |
290 |
291 | def make_gaussian(size, sigma=10, center=None, d_type=np.float64):
292 | """ Make a square gaussian kernel.
293 | size: is the dimensions of the output gaussian
294 | sigma: is full-width-half-maximum, which
295 | can be thought of as an effective radius.
296 | """
297 |
298 | x = np.arange(0, size[1], 1, float)
299 | y = np.arange(0, size[0], 1, float)
300 | y = y[:, np.newaxis]
301 |
302 | if center is None:
303 | x0 = y0 = size[0] // 2
304 | else:
305 | x0 = center[0]
306 | y0 = center[1]
307 |
308 | return np.exp(-4 * np.log(2) * ((x - x0) ** 2 + (y - y0) ** 2) / sigma ** 2).astype(d_type)
309 |
310 |
311 | def make_gt(img, labels, sigma=10, one_mask_per_point=False):
312 | """ Make the ground-truth for landmark.
313 | img: the original color image
314 | labels: label with the Gaussian center(s) [[x0, y0],[x1, y1],...]
315 | sigma: sigma of the Gaussian.
316 | one_mask_per_point: masks for each point in different channels?
317 | """
318 |
319 | h, w = img.shape[:2]
320 | if labels is None:
321 | gt = make_gaussian((h, w), center=(h//2, w//2), sigma=sigma)
322 | gt_0 = np.zeros(shape=(h, w), dtype=np.float64)
323 | gt_0 = gt
324 | gt_1 = np.zeros(shape=(h, w), dtype=np.float64)
325 |
326 | gtout = np.zeros(shape=(h, w, 2))
327 | gtout[:, :, 0]=gt_0
328 | gtout[:, :, 1]=gt_1
329 | gtout = gtout.astype(dtype=img.dtype) #(0~1)
330 | return gtout
331 | else:
332 | labels = np.array(labels)
333 | if labels.ndim == 1:
334 | labels = labels[np.newaxis]
335 | gt_0 = np.zeros(shape=(h, w), dtype=np.float64)
336 | gt_1 = np.zeros(shape=(h, w), dtype=np.float64)
337 | gt_0 = np.maximum(gt_0, make_gaussian((h, w), center=labels[0, :], sigma=sigma))
338 |
339 | else:
340 | gt_0 = np.zeros(shape=(h, w), dtype=np.float64)
341 | gt_1 = np.zeros(shape=(h, w), dtype=np.float64)
342 | for ii in range(1,labels.shape[0]):
343 | gt_1 = np.maximum(gt_1, make_gaussian((h, w), center=labels[ii, :], sigma=sigma))
344 | gt_0 = np.maximum(gt_0, make_gaussian((h, w), center=labels[0, :], sigma=sigma))
345 |
346 | gt = np.zeros(shape=(h, w, 2))
347 | gt[:, :, 0]=gt_0
348 | gt[:, :, 1]=gt_1
349 | gt = gt.astype(dtype=img.dtype) #(0~1)
350 | return gt
351 |
352 | def cstm_normalize(im, max_value):
353 | """
354 | Normalize image to range 0 - max_value
355 | """
356 | imn = max_value*(im - im.min()) / max((im.max() - im.min()), 1e-8)
357 | return imn
358 |
359 |
360 | def generate_param_report(logfile, param):
361 | log_file = open(logfile, 'w')
362 | for key, val in param.items():
363 | log_file.write(key+':'+str(val)+'\n')
364 | log_file.close()
365 |
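366 | # --- Usage sketch (not part of the original file) --------------------------
367 | # How the inside-outside clicks are encoded for a toy mask: iog_points() picks
368 | # the inside point (argmax of the distance transform) plus four corner points
369 | # padded outward by `pad_pixel`, and make_gt() renders them as two Gaussian
370 | # heatmaps (channel 0: the inside point, channel 1: the four outside points).
371 | if __name__ == '__main__':
372 |     toy_mask = np.zeros((100, 100), dtype=np.float32)
373 |     toy_mask[30:70, 20:80] = 1.0
374 |
375 |     points = iog_points(toy_mask, pad_pixel=10)     # [center, tl, bl, tr, br] as (x, y)
376 |     heatmaps = make_gt(toy_mask, points, sigma=10)  # 100 x 100 x 2, values in [0, 1]
377 |     print(points)
378 |     print(heatmaps.shape, float(heatmaps.max()))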
--------------------------------------------------------------------------------
/dataloaders/pascal.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import errno
3 | import hashlib
4 | import os
5 | import sys
6 | import tarfile
7 | import numpy as np
8 |
9 | import torch.utils.data as data
10 | from PIL import Image
11 | from six.moves import urllib
12 | import json
13 | from mypath import Path
14 |
15 |
16 | class VOCSegmentation(data.Dataset):
17 |
18 | URL = "http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar"
19 | FILE = "VOCtrainval_11-May-2012.tar"
20 | MD5 = '6cd6e144f989b92b3379bac3b3de84fd'
21 | BASE_DIR = 'VOCdevkit/VOC2012'
22 |
23 | category_names = ['background',
24 | 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
25 | 'bus', 'car', 'cat', 'chair', 'cow',
26 | 'diningtable', 'dog', 'horse', 'motorbike', 'person',
27 | 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor']
28 |
29 | def __init__(self,
30 | root=Path.db_root_dir('pascal'),
31 | split='val',
32 | transform=None,
33 | download=False,
34 | preprocess=False,
35 | area_thres=0,
36 | retname=True,
37 | suppress_void_pixels=True,
38 | default=False):
39 |
40 | self.root = root
41 | _voc_root = os.path.join(self.root, self.BASE_DIR)
42 | _mask_dir = os.path.join(_voc_root, 'SegmentationObject')#each object each color
43 | _cat_dir = os.path.join(_voc_root, 'SegmentationClass')#each class each color
44 | _image_dir = os.path.join(_voc_root, 'JPEGImages')
45 | self.transform = transform
46 | if isinstance(split, str):
47 | self.split = [split]
48 | else:
49 | split.sort()
50 | self.split = split
51 | self.area_thres = area_thres
52 | self.retname = retname
53 | self.suppress_void_pixels = suppress_void_pixels
54 | self.default = default
55 |
56 | # Build the ids file
57 | area_th_str = ""
58 | if self.area_thres != 0:
59 | area_th_str = '_area_thres-' + str(area_thres)
60 |
61 | self.obj_list_file = os.path.join(self.root, self.BASE_DIR, 'ImageSets', 'Segmentation',
62 | '_'.join(self.split) + '_instances' + area_th_str + '.txt')
63 |
64 | if download:
65 | self._download()
66 |
67 | if not self._check_integrity():
68 | raise RuntimeError('Dataset not found or corrupted.' +
69 | ' You can use download=True to download it')
70 |
71 | # train/val/test splits are pre-cut
72 | _splits_dir = os.path.join(_voc_root, 'ImageSets', 'Segmentation')
73 |
74 | self.im_ids = []
75 | self.images = []
76 | self.categories = []
77 | self.masks = []
78 |
79 | for splt in self.split:
80 | with open(os.path.join(os.path.join(_splits_dir, splt + '.txt')), "r") as f:
81 | lines = f.read().splitlines()
82 |
83 | for ii, line in enumerate(lines):
84 | _image = os.path.join(_image_dir, line + ".jpg")
85 | _cat = os.path.join(_cat_dir, line + ".png")
86 | _mask = os.path.join(_mask_dir, line + ".png")
87 | assert os.path.isfile(_image)
88 | assert os.path.isfile(_cat)
89 | assert os.path.isfile(_mask)
90 | self.im_ids.append(line.rstrip('\n'))
91 | self.images.append(_image)
92 | self.categories.append(_cat)
93 | self.masks.append(_mask)
94 | assert (len(self.images) == len(self.masks))
95 | assert (len(self.images) == len(self.categories))
96 |
97 | # Precompute the list of objects and their categories for each image
98 | if (not self._check_preprocess()) or preprocess:
99 | print('Preprocessing the PASCAL VOC dataset; this can take a while, but it only runs once.')
100 | self._preprocess()
101 |
102 | # Build the list of objects
103 | self.obj_list = []
104 | num_images = 0
105 | for ii in range(len(self.im_ids)):
106 | flag = False
107 | for jj in range(len(self.obj_dict[self.im_ids[ii]])):
108 | if self.obj_dict[self.im_ids[ii]][jj] != -1:
109 | self.obj_list.append([ii, jj])
110 | flag = True
111 | if flag:
112 | num_images += 1
113 |
114 | # Display stats
115 | print('Number of images: {:d}\nNumber of objects: {:d}'.format(num_images, len(self.obj_list)))
116 |
117 | def __getitem__(self, index):
118 | _img, _target, _void_pixels, _, _, _ = self._make_img_gt_point_pair(index)
119 | sample = {'image': _img, 'gt': _target, 'void_pixels': _void_pixels}
120 |
121 | if self.retname:
122 | _im_ii = self.obj_list[index][0]
123 | _obj_ii = self.obj_list[index][1]
124 | sample['meta'] = {'image': str(self.im_ids[_im_ii]),
125 | 'object': str(_obj_ii),
126 | 'category': self.obj_dict[self.im_ids[_im_ii]][_obj_ii],
127 | 'im_size': (_img.shape[0], _img.shape[1])}
128 |
129 | if self.transform is not None:
130 | sample = self.transform(sample)
131 | return sample
132 |
133 | def __len__(self):
134 | return len(self.obj_list)
135 |
136 | def _check_integrity(self):
137 | _fpath = os.path.join(self.root, self.FILE)
138 | if not os.path.isfile(_fpath):
139 | print("{} does not exist".format(_fpath))
140 | return False
141 | _md5c = hashlib.md5(open(_fpath, 'rb').read()).hexdigest()
142 | if _md5c != self.MD5:
143 | print(" MD5({}) did not match MD5({}) expected for {}".format(
144 | _md5c, self.MD5, _fpath))
145 | return False
146 | return True
147 |
148 | def _check_preprocess(self):
149 | _obj_list_file = self.obj_list_file
150 | if not os.path.isfile(_obj_list_file):
151 | return False
152 | else:
153 | self.obj_dict = json.load(open(_obj_list_file, 'r'))
154 |
155 | return list(np.sort([str(x) for x in self.obj_dict.keys()])) == list(np.sort(self.im_ids))
156 |
157 | def _preprocess(self):
158 | self.obj_dict = {}
159 | obj_counter = 0
160 | for ii in range(len(self.im_ids)):
161 | # Read object masks and get number of objects
162 | _mask = np.array(Image.open(self.masks[ii]))
163 | _mask_ids = np.unique(_mask)
164 | if _mask_ids[-1] == 255:
165 | n_obj = _mask_ids[-2]
166 | else:
167 | n_obj = _mask_ids[-1]
168 |
169 | # Get the categories from these objects
170 | _cats = np.array(Image.open(self.categories[ii]))
171 | _cat_ids = []
172 | for jj in range(n_obj):
173 | tmp = np.where(_mask == jj + 1)
174 | obj_area = len(tmp[0])
175 | if obj_area > self.area_thres:
176 | _cat_ids.append(int(_cats[tmp[0][0], tmp[1][0]]))
177 | else:
178 | _cat_ids.append(-1)
179 | obj_counter += 1
180 |
181 | self.obj_dict[self.im_ids[ii]] = _cat_ids
182 |
183 | with open(self.obj_list_file, 'w') as outfile:
184 | outfile.write('{{\n\t"{:s}": {:s}'.format(self.im_ids[0], json.dumps(self.obj_dict[self.im_ids[0]])))
185 | for ii in range(1, len(self.im_ids)):
186 | outfile.write(',\n\t"{:s}": {:s}'.format(self.im_ids[ii], json.dumps(self.obj_dict[self.im_ids[ii]])))
187 | outfile.write('\n}\n')
188 |
189 | print('Preprocessing finished')
190 |
191 | def _download(self):
192 | _fpath = os.path.join(self.root, self.FILE)
193 |
194 | try:
195 | os.makedirs(self.root)
196 | except OSError as e:
197 | if e.errno == errno.EEXIST:
198 | pass
199 | else:
200 | raise
201 |
202 | if self._check_integrity():
203 | print('Files already downloaded and verified')
204 | return
205 | else:
206 | print('Downloading ' + self.URL + ' to ' + _fpath)
207 |
208 | def _progress(count, block_size, total_size):
209 | sys.stdout.write('\r>> %s %.1f%%' %
210 | (_fpath, float(count * block_size) /
211 | float(total_size) * 100.0))
212 | sys.stdout.flush()
213 |
214 | urllib.request.urlretrieve(self.URL, _fpath, _progress)
215 |
216 | # extract file
217 | cwd = os.getcwd()
218 | print('Extracting tar file')
219 | tar = tarfile.open(_fpath)
220 | os.chdir(self.root)
221 | tar.extractall()
222 | tar.close()
223 | os.chdir(cwd)
224 | print('Done!')
225 |
226 | def _make_img_gt_point_pair(self, index):
227 | _im_ii = self.obj_list[index][0]
228 | _obj_ii = self.obj_list[index][1]
229 |
230 | # Read Image
231 | _img = np.array(Image.open(self.images[_im_ii]).convert('RGB')).astype(np.float32)  # read the image as float32 RGB
232 |
233 | # Read Target object
234 | _tmp = (np.array(Image.open(self.masks[_im_ii]))).astype(np.float32)
235 | _void_pixels = (_tmp == 255)
236 | _tmp[_void_pixels] = 0
237 |
238 | _other_same_class = np.zeros(_tmp.shape)
239 | _other_classes = np.zeros(_tmp.shape)
240 |
241 | if self.default:
242 | _target = _tmp
243 | _background = np.logical_and(_tmp == 0, ~_void_pixels)
244 | else:
245 | _target = (_tmp == (_obj_ii + 1)).astype(np.float32)
246 | _background = np.logical_and(_tmp == 0, ~_void_pixels)
247 | obj_cat = self.obj_dict[self.im_ids[_im_ii]][_obj_ii]
248 | for ii in range(1, np.max(_tmp).astype(np.int)+1):
249 | ii_cat = self.obj_dict[self.im_ids[_im_ii]][ii-1]
250 | if obj_cat == ii_cat and ii != _obj_ii+1:
251 | _other_same_class = np.logical_or(_other_same_class, _tmp == ii)
252 | elif ii != _obj_ii+1:
253 | _other_classes = np.logical_or(_other_classes, _tmp == ii)
254 |
255 | return _img, _target, _void_pixels.astype(np.float32), \
256 | _other_classes.astype(np.float32), _other_same_class.astype(np.float32), \
257 | _background.astype(np.float32)
258 |
259 | def __str__(self):
260 | return 'VOC2012(split=' + str(self.split) + ',area_thres=' + str(self.area_thres) + ')'
261 |
262 |
263 | if __name__ == '__main__':
264 | import matplotlib.pyplot as plt
265 | import dataloaders.helpers as helpers
266 | import torch
267 | import dataloaders.custom_transforms as tr
268 | from torchvision import transforms
269 |
270 | transform = transforms.Compose([tr.ToTensor()])
271 |
272 | dataset = VOCSegmentation(split=['train', 'val'], transform=transform, retname=True)
273 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)
274 |
275 | for i, sample in enumerate(dataloader):
276 | plt.figure()
277 | overlay = helpers.overlay_mask(helpers.tens2image(sample["image"]) / 255.,
278 | np.squeeze(helpers.tens2image(sample["gt"])))
279 | plt.imshow(overlay)
280 | plt.title(dataset.category_names[sample["meta"]["category"][0]])
281 | if i == 3:
282 | break
283 |
284 | plt.show(block=True)
285 |
--------------------------------------------------------------------------------
/dataloaders/sbd.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, division
2 |
3 | import torch, cv2
4 | import errno
5 | import hashlib
6 | import json
7 | import os
8 | import sys
9 | import tarfile
10 |
11 | import numpy as np
12 | import scipy.io
13 | import torch.utils.data as data
14 | from PIL import Image
15 | from six.moves import urllib
16 | from mypath import Path
17 |
18 |
19 | class SBDSegmentation(data.Dataset):
20 |
21 | URL = "http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/semantic_contours/benchmark.tgz"
22 | FILE = "benchmark.tgz"
23 | MD5 = '82b4d87ceb2ed10f6038a1cba92111cb'
24 |
25 | def __init__(self,
26 | root=Path.db_root_dir('sbd'),
27 | split='val',
28 | transform=None,
29 | download=False,
30 | preprocess=False,
31 | area_thres=0,
32 | retname=True):
33 |
34 | # Store parameters
35 | self.root = root
36 | self.transform = transform
37 | if isinstance(split, str):
38 | self.split = [split]
39 | else:
40 | split.sort()
41 | self.split = split
42 | self.area_thres = area_thres
43 | self.retname = retname
44 |
45 | # Where to find things according to the author's structure
46 | self.dataset_dir = os.path.join(self.root, 'benchmark_RELEASE', 'dataset')
47 | _mask_dir = os.path.join(self.dataset_dir, 'inst')
48 | _image_dir = os.path.join(self.dataset_dir, 'img')
49 |
50 | if self.area_thres != 0:
51 | self.obj_list_file = os.path.join(self.dataset_dir, '_'.join(self.split) + '_instances_area_thres-' +
52 | str(area_thres) + '.txt')
53 | else:
54 | self.obj_list_file = os.path.join(self.dataset_dir, '_'.join(self.split) + '_instances' + '.txt')
55 |
56 | # Download dataset?
57 | if download:
58 | self._download()
59 | if not self._check_integrity():
60 | raise RuntimeError('Dataset file downloaded is corrupted.')
61 |
62 | # Get list of all images from the split and check that the files exist
63 | self.im_ids = []
64 | self.images = []
65 | self.masks = []
66 | for splt in self.split:
67 | with open(os.path.join(self.dataset_dir, splt+'.txt'), "r") as f:
68 | lines = f.read().splitlines()
69 |
70 | for line in lines:
71 | _image = os.path.join(_image_dir, line + ".jpg")
72 | _mask = os.path.join(_mask_dir, line + ".mat")
73 | assert os.path.isfile(_image)
74 | assert os.path.isfile(_mask)
75 | self.im_ids.append(line)
76 | self.images.append(_image)
77 | self.masks.append(_mask)
78 |
79 | assert (len(self.images) == len(self.masks))
80 |
81 | # Precompute the list of objects and their categories for each image
82 | if (not self._check_preprocess()) or preprocess:
83 | print('Preprocessing the SBD dataset; this can take a while, but it only runs once.')
84 | self._preprocess()
85 |
86 | # Build the list of objects
87 | self.obj_list = []
88 | num_images = 0
89 | for ii in range(len(self.im_ids)):
90 | if self.im_ids[ii] in self.obj_dict.keys():
91 | flag = False
92 | for jj in range(len(self.obj_dict[self.im_ids[ii]])):
93 | if self.obj_dict[self.im_ids[ii]][jj] != -1:
94 | self.obj_list.append([ii, jj])
95 | flag = True
96 | if flag:
97 | num_images += 1
98 |
99 | # Display stats
100 | print('Number of images: {:d}\nNumber of objects: {:d}'.format(num_images, len(self.obj_list)))
101 |
102 | def __getitem__(self, index):
103 |
104 | _img, _target = self._make_img_gt_point_pair(index)
105 | _void_pixels = (_target == 255).astype(np.float32)
106 | sample = {'image': _img, 'gt': _target, 'void_pixels': _void_pixels}
107 |
108 | if self.retname:
109 | _im_ii = self.obj_list[index][0]
110 | _obj_ii = self.obj_list[index][1]
111 | sample['meta'] = {'image': str(self.im_ids[_im_ii]),
112 | 'object': str(_obj_ii),
113 | 'im_size': (_img.shape[0], _img.shape[1]),
114 | 'category': self.obj_dict[self.im_ids[_im_ii]][_obj_ii]}
115 |
116 | if self.transform is not None:
117 | sample = self.transform(sample)
118 |
119 | return sample
120 |
121 | def __len__(self):
122 | return len(self.obj_list)
123 |
124 | def _check_integrity(self):
125 | _fpath = os.path.join(self.root, self.FILE)
126 | if not os.path.isfile(_fpath):
127 | print("{} does not exist".format(_fpath))
128 | return False
129 | _md5c = hashlib.md5(open(_fpath, 'rb').read()).hexdigest()
130 | if _md5c != self.MD5:
131 | print(" MD5({}) did not match MD5({}) expected for {}".format(
132 | _md5c, self.MD5, _fpath))
133 | return False
134 | return True
135 |
136 | def _check_preprocess(self):
137 | # Check that the file with categories is there and with correct size
138 | _obj_list_file = self.obj_list_file
139 | if not os.path.isfile(_obj_list_file):
140 | return False
141 | else:
142 | self.obj_dict = json.load(open(_obj_list_file, 'r'))
143 | return list(np.sort([str(x) for x in self.obj_dict.keys()])) == list(np.sort(self.im_ids))
144 |
145 | def _preprocess(self):
146 | # Get all object instances and their category
147 | self.obj_dict = {}
148 | obj_counter = 0
149 | for ii in range(len(self.im_ids)):
150 | # Read object masks and get number of objects
151 | tmp = scipy.io.loadmat(self.masks[ii])
152 | _mask = tmp["GTinst"][0]["Segmentation"][0]
153 | _cat_ids = tmp["GTinst"][0]["Categories"][0].astype(int)
154 |
155 | _mask_ids = np.unique(_mask)
156 | n_obj = _mask_ids[-1]
157 | assert(n_obj == len(_cat_ids))
158 |
159 | for jj in range(n_obj):
160 | temp = np.where(_mask == jj + 1)
161 | obj_area = len(temp[0])
162 | if obj_area < self.area_thres:
163 | _cat_ids[jj] = -1
164 | obj_counter += 1
165 |
166 | self.obj_dict[self.im_ids[ii]] = np.squeeze(_cat_ids, 1).tolist()
167 |
168 | # Save it to file for future reference
169 | with open(self.obj_list_file, 'w') as outfile:
170 | outfile.write('{{\n\t"{:s}": {:s}'.format(self.im_ids[0], json.dumps(self.obj_dict[self.im_ids[0]])))
171 | for ii in range(1, len(self.im_ids)):
172 | outfile.write(',\n\t"{:s}": {:s}'.format(self.im_ids[ii], json.dumps(self.obj_dict[self.im_ids[ii]])))
173 | outfile.write('\n}\n')
174 |
175 | print('Pre-processing finished')
176 |
177 | def _download(self):
178 | _fpath = os.path.join(self.root, self.FILE)
179 |
180 | try:
181 | os.makedirs(self.root)
182 | except OSError as e:
183 | if e.errno == errno.EEXIST:
184 | pass
185 | else:
186 | raise
187 |
188 | if self._check_integrity():
189 | print('Files already downloaded and verified')
190 | return
191 | else:
192 | print('Downloading ' + self.URL + ' to ' + _fpath)
193 |
194 | def _progress(count, block_size, total_size):
195 | sys.stdout.write('\r>> %s %.1f%%' %
196 | (_fpath, float(count * block_size) /
197 | float(total_size) * 100.0))
198 | sys.stdout.flush()
199 |
200 | urllib.request.urlretrieve(self.URL, _fpath, _progress)
201 |
202 | # extract file
203 | cwd = os.getcwd()
204 | print('Extracting tar file')
205 | tar = tarfile.open(_fpath)
206 | os.chdir(self.root)
207 | tar.extractall()
208 | tar.close()
209 | os.chdir(cwd)
210 | print('Done!')
211 |
212 | def _make_img_gt_point_pair(self, index):
213 | _im_ii = self.obj_list[index][0]
214 | _obj_ii = self.obj_list[index][1]
215 |
216 | # Read Image
217 | _img = np.array(Image.open(self.images[_im_ii]).convert('RGB')).astype(np.float32)
218 |
219 | # Read target object
220 | _tmp = scipy.io.loadmat(self.masks[_im_ii])["GTinst"][0]["Segmentation"][0]
221 | _target = (_tmp == (_obj_ii + 1)).astype(np.float32)
222 |
223 | return _img, _target
224 |
225 | def __str__(self):
226 | return 'SBDSegmentation(split='+str(self.split)+', area_thres='+str(self.area_thres)+')'
227 |
228 |
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | import os.path
2 |
3 | from torch.utils.data import DataLoader
4 | from evaluation.eval import eval_one_result
5 | import dataloaders.pascal as pascal
6 |
7 | exp_root_dir = './'
8 |
9 | method_names = []
10 | method_names.append('run_0')
11 |
12 | if __name__ == '__main__':
13 |
14 | # Dataloader
15 | dataset = pascal.VOCSegmentation(transform=None, retname=True)
16 | dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)
17 |
18 | # Iterate through all the different methods
19 | for method in method_names:
20 | for ii in [0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9]:
21 | results_folder = os.path.join(exp_root_dir, method, 'Results')
22 |
23 | filename = os.path.join(exp_root_dir, 'eval_results', method.replace('/', '-') + '.txt')
24 | if not os.path.exists(os.path.join(exp_root_dir, 'eval_results')):
25 | os.makedirs(os.path.join(exp_root_dir, 'eval_results'))
26 |
27 | jaccards = eval_one_result(dataloader, results_folder, mask_thres=ii)
28 | val = jaccards["all_jaccards"].mean()
29 |
30 | # Show mean and store result
31 | print(ii)
32 | print("Result for {:<80}: {}".format(method, str.format("{0:.4f}", 100*val)))
33 | with open(filename, 'w') as f:
34 | f.write(str(val))
35 |
--------------------------------------------------------------------------------
/evaluation/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shiyinzhang/Inside-Outside-Guidance/696ec2ddae2e994541cc9d81bb3c41984e233c64/evaluation/__init__.py
--------------------------------------------------------------------------------
/evaluation/eval.py:
--------------------------------------------------------------------------------
1 | import os.path
2 | import cv2
3 | import numpy as np
4 | from PIL import Image
5 |
6 | import dataloaders.helpers as helpers
7 | import evaluation.evaluation as evaluation
8 |
9 | def eval_one_result(loader, folder, one_mask_per_image=False, mask_thres=0.5, use_void_pixels=True, custom_box=False):
10 | def mAPr(per_cat, thresholds):
11 | n_cat = len(per_cat)
12 | all_apr = np.zeros(len(thresholds))
13 | for ii, th in enumerate(thresholds):
14 | per_cat_recall = np.zeros(n_cat)
15 | for jj, categ in enumerate(per_cat.keys()):
16 | per_cat_recall[jj] = np.sum(np.array(per_cat[categ]) > th)/len(per_cat[categ])
17 |
18 | all_apr[ii] = per_cat_recall.mean()
19 |
20 | return all_apr.mean()
21 |
22 | # Allocate
23 | eval_result = dict()
24 | eval_result["all_jaccards"] = np.zeros(len(loader))
25 | eval_result["all_percent"] = np.zeros(len(loader))
26 | eval_result["meta"] = []
27 | eval_result["per_categ_jaccard"] = dict()
28 |
29 | # Iterate
30 | for i, sample in enumerate(loader):
31 |
32 | if i % 500 == 0:
33 | print('Evaluating: {} of {} objects'.format(i, len(loader)))
34 |
35 | # Load result
36 | if not one_mask_per_image:
37 | filename = os.path.join(folder,
38 | sample["meta"]["image"][0] + '-' + sample["meta"]["object"][0] + '.png')
39 | else:
40 | filename = os.path.join(folder,
41 | sample["meta"]["image"][0] + '.png')
42 | mask = np.array(Image.open(filename)).astype(np.float32) / 255.
43 | gt = np.squeeze(helpers.tens2image(sample["gt"]))
44 | if use_void_pixels:
45 | void_pixels = np.squeeze(helpers.tens2image(sample["void_pixels"]))
46 | if mask.shape != gt.shape:
47 | mask = cv2.resize(mask, gt.shape[::-1], interpolation=cv2.INTER_CUBIC)
48 |
49 | # Threshold
50 | mask = (mask > mask_thres)
51 | if use_void_pixels:
52 | void_pixels = (void_pixels > 0.5)
53 |
54 | # Evaluate
55 | if use_void_pixels:
56 | eval_result["all_jaccards"][i] = evaluation.jaccard(gt, mask, void_pixels)
57 | else:
58 | eval_result["all_jaccards"][i] = evaluation.jaccard(gt, mask)
59 |
60 | if custom_box:
61 | box = np.squeeze(helpers.tens2image(sample["box"]))
62 | bb = helpers.get_bbox(box)
63 | else:
64 | bb = helpers.get_bbox(gt)
65 |
66 | mask_crop = helpers.crop_from_bbox(mask, bb)
67 | if use_void_pixels:
68 | non_void_pixels_crop = helpers.crop_from_bbox(np.logical_not(void_pixels), bb)
69 | gt_crop = helpers.crop_from_bbox(gt, bb)
70 | if use_void_pixels:
71 | eval_result["all_percent"][i] = np.sum((gt_crop != mask_crop) & non_void_pixels_crop)/np.sum(non_void_pixels_crop)
72 | else:
73 | eval_result["all_percent"][i] = np.sum((gt_crop != mask_crop))/mask_crop.size
74 | # Store in per category
75 | if "category" in sample["meta"]:
76 | cat = sample["meta"]["category"][0]
77 | else:
78 | cat = 1
79 | if cat not in eval_result["per_categ_jaccard"]:
80 | eval_result["per_categ_jaccard"][cat] = []
81 | eval_result["per_categ_jaccard"][cat].append(eval_result["all_jaccards"][i])
82 |
83 | # Store meta
84 | eval_result["meta"].append(sample["meta"])
85 |
86 | # Compute some stats
87 | eval_result["mAPr0.5"] = mAPr(eval_result["per_categ_jaccard"], [0.5])
88 | eval_result["mAPr0.7"] = mAPr(eval_result["per_categ_jaccard"], [0.7])
89 | eval_result["mAPr-vol"] = mAPr(eval_result["per_categ_jaccard"], np.linspace(0.1, 0.9, 9))
90 |
91 | return eval_result
92 |
93 |
94 |
95 |
--------------------------------------------------------------------------------
/evaluation/evaluation.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | def jaccard(annotation, segmentation, void_pixels=None):
4 |
5 | assert(annotation.shape == segmentation.shape)
6 |
7 | if void_pixels is None:
8 | void_pixels = np.zeros_like(annotation)
9 | assert(void_pixels.shape == annotation.shape)
10 |
11 | annotation = annotation.astype(np.bool)
12 | segmentation = segmentation.astype(np.bool)
13 | void_pixels = void_pixels.astype(np.bool)
14 | if np.isclose(np.sum(annotation & np.logical_not(void_pixels)), 0) and np.isclose(np.sum(segmentation & np.logical_not(void_pixels)), 0):
15 | return 1
16 | else:
17 | return np.sum(((annotation & segmentation) & np.logical_not(void_pixels))) / \
18 | np.sum(((annotation | segmentation) & np.logical_not(void_pixels)), dtype=np.float32)
19 |
20 |
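21 | # --- Worked example (not part of the original file) -------------------------
22 | # Two 6x6 squares overlapping in a 4x4 region: intersection = 16, union = 56,
23 | # so the Jaccard index is 16 / 56 = 0.2857. Void pixels are excluded from both.
24 | if __name__ == '__main__':
25 |     gt = np.zeros((10, 10))
26 |     gt[2:8, 2:8] = 1
27 |     pred = np.zeros((10, 10))
28 |     pred[4:10, 4:10] = 1
29 |     void = np.zeros((10, 10))
30 |     void[0:2, 0:2] = 1  # ignored region, away from both squares
31 |     print(jaccard(gt, pred, void))  # ~0.2857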
--------------------------------------------------------------------------------
/ims/IOG.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shiyinzhang/Inside-Outside-Guidance/696ec2ddae2e994541cc9d81bb3c41984e233c64/ims/IOG.gif
--------------------------------------------------------------------------------
/ims/cross_domain.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shiyinzhang/Inside-Outside-Guidance/696ec2ddae2e994541cc9d81bb3c41984e233c64/ims/cross_domain.gif
--------------------------------------------------------------------------------
/ims/ims.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shiyinzhang/Inside-Outside-Guidance/696ec2ddae2e994541cc9d81bb3c41984e233c64/ims/ims.png
--------------------------------------------------------------------------------
/ims/refinement.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shiyinzhang/Inside-Outside-Guidance/696ec2ddae2e994541cc9d81bb3c41984e233c64/ims/refinement.gif
--------------------------------------------------------------------------------
/mypath.py:
--------------------------------------------------------------------------------
1 |
2 | class Path(object):
3 | @staticmethod
4 | def db_root_dir(database):
5 | if database == 'pascal':
6 | return '/path/to/PASCAL/VOC2012' # folder that contains VOCdevkit/.
7 |
8 | elif database == 'sbd':
9 | return '/path/to/SBD/' # folder that contains benchmark_RELEASE/.
10 | else:
11 | print('Database {} not available.'.format(database))
12 | raise NotImplementedError
13 |
14 | @staticmethod
15 | def models_dir():
16 | return '/path/to/models/resnet101-5d3b4d8f.pth'
17 | #'resnet101-5d3b4d8f.pth' #resnet50-19c8e357.pth'
18 |
--------------------------------------------------------------------------------
/networks/CoarseNet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | import math
4 |
5 | class CoarseNet(nn.Module):
6 | def __init__(self, channel_settings, output_shape, num_class):
7 | super(CoarseNet, self).__init__()
8 | self.channel_settings = channel_settings
9 | laterals, upsamples, predict = [], [], []
10 | for i in range(len(channel_settings)):
11 | laterals.append(self._lateral(channel_settings[i]))
12 | predict.append(self._predict(output_shape, num_class))
13 | if i != len(channel_settings) - 1:
14 | upsamples.append(self._upsample())
15 | self.laterals = nn.ModuleList(laterals)
16 | self.upsamples = nn.ModuleList(upsamples)
17 | self.predict = nn.ModuleList(predict)
18 |
19 | for m in self.modules():
20 | if isinstance(m, nn.Conv2d):
21 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
22 | m.weight.data.normal_(0, math.sqrt(2. / n))
23 | if m.bias is not None:
24 | m.bias.data.zero_()
25 | elif isinstance(m, nn.BatchNorm2d):
26 | m.weight.data.fill_(1)
27 | m.bias.data.zero_()
28 |
29 | def _lateral(self, input_size):
30 | layers = []
31 | layers.append(nn.Conv2d(input_size, 256,
32 | kernel_size=1, stride=1, bias=False))
33 | layers.append(nn.BatchNorm2d(256))
34 | layers.append(nn.ReLU(inplace=True))
35 | return nn.Sequential(*layers)
36 |
37 | def _upsample(self):
38 | layers = []
39 | layers.append(torch.nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True))
40 | layers.append(torch.nn.Conv2d(256, 256,
41 | kernel_size=1, stride=1, bias=False))
42 | layers.append(nn.BatchNorm2d(256))
43 | return nn.Sequential(*layers)
44 |
45 | def _predict(self, output_shape, num_class):
46 | layers = []
47 | layers.append(nn.Conv2d(256, 256,
48 | kernel_size=1, stride=1, bias=False))
49 | layers.append(nn.BatchNorm2d(256))
50 | layers.append(nn.ReLU(inplace=True))
51 | layers.append(nn.Conv2d(256, num_class,
52 | kernel_size=3, stride=1, padding=1, bias=False))
53 | layers.append(nn.Upsample(size=output_shape, mode='bilinear', align_corners=True))
54 | layers.append(nn.BatchNorm2d(num_class))
55 | return nn.Sequential(*layers)
56 |
57 | def forward(self, x):
58 | coarse_fms, coarse_outs = [], []
59 | for i in range(len(self.channel_settings)):
60 | if i == 0:
61 | feature = self.laterals[i](x[i])
62 | coarse_fms.append(feature)
63 | if i != len(self.channel_settings) - 1:
64 | up = feature
65 | feature = self.predict[i](feature)
66 | coarse_outs.append(feature)
67 | else:
68 | feature = self.laterals[i](x[i])
69 |                 feature = feature + up
70 | coarse_fms.append(feature)
71 | if i != len(self.channel_settings) - 1:
72 | up = self.upsamples[i](feature)
73 | feature = self.predict[i](feature)
74 | coarse_outs.append(feature)
75 | return coarse_fms, coarse_outs
76 |
--------------------------------------------------------------------------------
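CoarseNet is a GlobalNet-style feature pyramid: each backbone level passes through a 1x1 lateral block, is summed with the feature carried over from the deeper level (upsampled when the resolution changes), and is predicted at a common `output_shape`. A hypothetical shape walk-through with the settings used in `networks/mainnetwork.py` (`channel_settings = [512, 1024, 512, 256]`, 512x512 crops, output stride 16):

```python
import torch
from networks.CoarseNet import CoarseNet

net = CoarseNet(channel_settings=[512, 1024, 512, 256], output_shape=128, num_class=1).eval()
feats = [torch.randn(1, 512, 32, 32),    # PSP-refined layer4 output
         torch.randn(1, 1024, 32, 32),   # layer3
         torch.randn(1, 512, 64, 64),    # layer2
         torch.randn(1, 256, 128, 128)]  # layer1
with torch.no_grad():
    coarse_fms, coarse_outs = net(feats)
print([tuple(f.shape) for f in coarse_fms])   # 256-channel pyramid features at 32/32/64/128 px
print([tuple(o.shape) for o in coarse_outs])  # four (1, 1, 128, 128) coarse predictions
```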
/networks/FineNet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 |
4 | class Bottleneck(nn.Module):
5 | expansion = 4
6 |
7 | def __init__(self, inplanes, planes, stride=1):
8 | super(Bottleneck, self).__init__()
9 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
10 | self.bn1 = nn.BatchNorm2d(planes)
11 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
12 | padding=1, bias=False)
13 | self.bn2 = nn.BatchNorm2d(planes)
14 | self.conv3 = nn.Conv2d(planes, planes * 2, kernel_size=1, bias=False)
15 | self.bn3 = nn.BatchNorm2d(planes * 2)
16 | self.relu = nn.ReLU(inplace=True)
17 | self.downsample = nn.Sequential(
18 | nn.Conv2d(inplanes, planes * 2,
19 | kernel_size=1, stride=stride, bias=False),
20 | nn.BatchNorm2d(planes * 2))
21 | self.stride = stride
22 |
23 | def forward(self, x):
24 | residual = x
25 |
26 | out = self.conv1(x)
27 | out = self.bn1(out)
28 | out = self.relu(out)
29 |
30 | out = self.conv2(out)
31 | out = self.bn2(out)
32 | out = self.relu(out)
33 |
34 | out = self.conv3(out)
35 | out = self.bn3(out)
36 |
37 | if self.downsample is not None:
38 | residual = self.downsample(x)
39 |
40 | out += residual
41 | out = self.relu(out)
42 |
43 | return out
44 |
45 | class FineNet(nn.Module):
46 | def __init__(self, lateral_channel, out_shape, num_class):
47 | super(FineNet, self).__init__()
48 | cascade = []
49 | num_cascade = 4
50 | for i in range(num_cascade):
51 | cascade.append(self._make_layer(lateral_channel, num_cascade-i-1, out_shape))
52 | self.cascade = nn.ModuleList(cascade)
53 | self.final_predict = self._predict(4*lateral_channel, num_class)
54 |
55 | def _make_layer(self, input_channel, num, output_shape):
56 | layers = []
57 | for i in range(num):
58 | layers.append(Bottleneck(input_channel, 128))
59 | layers.append(nn.Upsample(size=output_shape, mode='bilinear', align_corners=True))
60 | return nn.Sequential(*layers)
61 |
62 | def _predict(self, input_channel, num_class):
63 | layers = []
64 | layers.append(Bottleneck(input_channel, 128))
65 | layers.append(nn.Conv2d(256, num_class,
66 | kernel_size=3, stride=1, padding=1, bias=False))
67 | layers.append(nn.BatchNorm2d(num_class))
68 | return nn.Sequential(*layers)
69 |
70 | def forward(self, x):
71 | fine_fms = []
72 | for i in range(4):
73 | fine_fms.append(self.cascade[i](x[i]))
74 | out = torch.cat(fine_fms, dim=1)
75 | out = self.final_predict(out)
76 | return out
77 |
--------------------------------------------------------------------------------
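FineNet then fuses the four pyramid features returned by CoarseNet: each one goes through a short cascade of `Bottleneck` blocks, is resized to the common `out_shape`, and the concatenation is reduced to the final prediction. Continuing the hypothetical shapes from the sketch above:

```python
import torch
from networks.FineNet import FineNet

fine = FineNet(lateral_channel=256, out_shape=128, num_class=1).eval()
coarse_fms = [torch.randn(1, 256, 32, 32),
              torch.randn(1, 256, 32, 32),
              torch.randn(1, 256, 64, 64),
              torch.randn(1, 256, 128, 128)]
with torch.no_grad():
    out = fine(coarse_fms)
print(tuple(out.shape))  # (1, 1, 128, 128)
```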
/networks/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shiyinzhang/Inside-Outside-Guidance/696ec2ddae2e994541cc9d81bb3c41984e233c64/networks/__init__.py
--------------------------------------------------------------------------------
/networks/backbone/__init__.py:
--------------------------------------------------------------------------------
1 | from networks.backbone import resnet
2 |
3 | def build_backbone(backbone, output_stride, BatchNorm,nInputChannels,pretrained):
4 | if backbone == 'resnet101':
5 | return resnet.ResNet101(output_stride, BatchNorm,nInputChannels=nInputChannels,pretrained=pretrained)
6 | elif backbone == 'resnet50':
7 | return resnet.ResNet50(output_stride, BatchNorm,nInputChannels=nInputChannels,pretrained=pretrained)
8 | else:
9 | raise NotImplementedError
10 |
--------------------------------------------------------------------------------
/networks/backbone/resnet.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch.nn as nn
3 | from networks.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
4 |
5 | class Bottleneck(nn.Module):
6 | expansion = 4
7 |
8 | def __init__(self, inplanes, planes, stride=1, dilation=1,downsample=None, BatchNorm=None):
9 | super(Bottleneck, self).__init__()
10 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
11 | self.bn1 = BatchNorm(planes)
12 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
13 | dilation=dilation, padding=dilation, bias=False)
14 | self.bn2 = BatchNorm(planes)
15 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
16 | self.bn3 = BatchNorm(planes * 4)
17 | self.relu = nn.ReLU(inplace=True)
18 | self.downsample = downsample
19 | self.stride = stride
20 | self.dilation = dilation
21 |
22 | def forward(self, x):
23 | residual = x
24 |
25 | out = self.conv1(x)
26 | out = self.bn1(out)
27 | out = self.relu(out)
28 |
29 | out = self.conv2(out)
30 | out = self.bn2(out)
31 | out = self.relu(out)
32 |
33 | out = self.conv3(out)
34 | out = self.bn3(out)
35 |
36 | if self.downsample is not None:
37 | residual = self.downsample(x)
38 |
39 | out += residual
40 | out = self.relu(out)
41 |
42 | return out
43 |
44 | class ResNet(nn.Module):
45 |
46 | def __init__(self, block, layers, output_stride, BatchNorm,nInputChannels=3, pretrained=False):
47 | self.inplanes = 64
48 | super(ResNet, self).__init__()
49 | blocks = [1, 2, 4]
50 | if output_stride == 16:
51 | strides = [1, 2, 2, 1]
52 | dilations = [1, 1, 1, 2]
53 | elif output_stride == 8:
54 | strides = [1, 2, 1, 1]
55 | dilations = [1, 1, 2, 4]
56 | else:
57 | raise NotImplementedError
58 |
59 | # Modules
60 | self.conv1 = nn.Conv2d(nInputChannels, 64, kernel_size=7, stride=2, padding=3,
61 | bias=False)
62 | self.bn1 = BatchNorm(64)
63 | self.relu = nn.ReLU(inplace=True)
64 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
65 |
66 | self.layer1 = self._make_layer(block, 64, layers[0], stride=strides[0], dilation=dilations[0], BatchNorm=BatchNorm)
67 | self.layer2 = self._make_layer(block, 128, layers[1], stride=strides[1], dilation=dilations[1], BatchNorm=BatchNorm)
68 | self.layer3 = self._make_layer(block, 256, layers[2], stride=strides[2], dilation=dilations[2], BatchNorm=BatchNorm)
69 | self.layer4 = self._make_MG_unit(block, 512, blocks=blocks, stride=strides[3], dilation=dilations[3], BatchNorm=BatchNorm)
70 | self._init_weight()
71 |
72 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None):
73 | downsample = None
74 | if stride != 1 or self.inplanes != planes * block.expansion:
75 | downsample = nn.Sequential(
76 | nn.Conv2d(self.inplanes, planes * block.expansion,
77 | kernel_size=1, stride=stride, bias=False),
78 | BatchNorm(planes * block.expansion),
79 | )
80 |
81 | layers = []
82 | layers.append(block(self.inplanes, planes, stride, dilation, downsample, BatchNorm))
83 | self.inplanes = planes * block.expansion
84 | for i in range(1, blocks):
85 | layers.append(block(self.inplanes, planes, dilation=dilation, BatchNorm=BatchNorm))
86 |
87 | return nn.Sequential(*layers)
88 |
89 | def _make_MG_unit(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None):
90 | downsample = None
91 | if stride != 1 or self.inplanes != planes * block.expansion:
92 | downsample = nn.Sequential(
93 | nn.Conv2d(self.inplanes, planes * block.expansion,
94 | kernel_size=1, stride=stride, bias=False),
95 | BatchNorm(planes * block.expansion),
96 | )
97 |
98 | layers = []
99 | layers.append(block(self.inplanes, planes, stride, dilation=blocks[0]*dilation,
100 | downsample=downsample, BatchNorm=BatchNorm))
101 | self.inplanes = planes * block.expansion
102 | for i in range(1, len(blocks)):
103 | layers.append(block(self.inplanes, planes, stride=1,
104 | dilation=blocks[i]*dilation, BatchNorm=BatchNorm))
105 |
106 | return nn.Sequential(*layers)
107 |
108 | def forward(self, input):
109 | x = self.conv1(input)
110 | x = self.bn1(x)
111 | x = self.relu(x)
112 | x = self.maxpool(x)
113 |
114 | x = self.layer1(x);low_level_feat_1 = x
115 | x = self.layer2(x);low_level_feat_2 = x
116 | x = self.layer3(x);low_level_feat_3 = x
117 | x = self.layer4(x)
118 |
119 | return [x, low_level_feat_3,low_level_feat_2,low_level_feat_1]
120 |
121 | def _init_weight(self):
122 | for m in self.modules():
123 | if isinstance(m, nn.Conv2d):
124 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
125 | m.weight.data.normal_(0, math.sqrt(2. / n))
126 | elif isinstance(m, SynchronizedBatchNorm2d):
127 | m.weight.data.fill_(1)
128 | m.bias.data.zero_()
129 | elif isinstance(m, nn.BatchNorm2d):
130 | m.weight.data.fill_(1)
131 | m.bias.data.zero_()
132 |
133 |
134 |
135 |
136 |
137 | def ResNet101(output_stride, BatchNorm,nInputChannels, pretrained=False):
138 | """Constructs a ResNet-101 model.
139 | Args:
140 | pretrained (bool): If True, returns a model pre-trained on ImageNet
141 | """
142 | model = ResNet(Bottleneck, [3, 4, 23, 3], output_stride, BatchNorm, nInputChannels=nInputChannels,pretrained=pretrained)
143 | return model
144 |
145 | def ResNet50(output_stride, BatchNorm,nInputChannels, pretrained=False):
146 |     """Constructs a ResNet-50 model.
147 | Args:
148 | pretrained (bool): If True, returns a model pre-trained on ImageNet
149 | """
150 | model = ResNet(Bottleneck, [3, 4, 6, 3], output_stride, BatchNorm, nInputChannels=nInputChannels,pretrained=pretrained)
151 | return model
152 |
153 |
--------------------------------------------------------------------------------
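The backbone returns the four stage outputs deepest-first, which is the order CoarseNet expects. A hypothetical check of the multi-scale shapes for a 5-channel 512x512 input (random weights; pretrained loading is handled one level up in `mainnetwork.py`):

```python
import torch
import torch.nn as nn
from networks.backbone import build_backbone

backbone = build_backbone('resnet101', output_stride=16, BatchNorm=nn.BatchNorm2d,
                          nInputChannels=5, pretrained=False).eval()
with torch.no_grad():
    feats = backbone(torch.randn(1, 5, 512, 512))
print([tuple(f.shape) for f in feats])
# [(1, 2048, 32, 32), (1, 1024, 32, 32), (1, 512, 64, 64), (1, 256, 128, 128)]
```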
/networks/loss.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 | def class_cross_entropy_loss(output, label, size_average=False, batch_average=True, void_pixels=None):
5 | assert(output.size() == label.size())
6 | labels = torch.ge(label, 0.5).float()
7 | num_labels_pos = torch.sum(labels)
8 | num_labels_neg = torch.sum(1.0 - labels)
9 | num_total = num_labels_pos + num_labels_neg
10 | output_gt_zero = torch.ge(output, 0).float()
11 | loss_val = torch.mul(output, (labels - output_gt_zero)) - torch.log(
12 | 1 + torch.exp(output - 2 * torch.mul(output, output_gt_zero)))
13 | if void_pixels is not None:
14 | w_void = torch.le(void_pixels, 0.5).float()
15 | final_loss = torch.mul(w_void, loss_val)
16 | else:
17 |         final_loss = loss_val
18 | final_loss = torch.sum(-final_loss)
19 | if size_average:
20 | final_loss /= np.prod(label.size())
21 | elif batch_average:
22 | final_loss /= label.size()[0]
23 | return final_loss
24 |
25 |
--------------------------------------------------------------------------------
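The expression above is the numerically stable form of binary cross-entropy on raw logits (it never exponentiates a positive number), optionally masked by the void pixels and then summed. Ignoring the void mask, it should match PyTorch's built-in version; a small hypothetical check, assuming a PyTorch version that supports the `reduction` argument:

```python
import torch
import torch.nn.functional as F
from networks.loss import class_cross_entropy_loss

logits = torch.randn(2, 1, 8, 8)
labels = (torch.rand(2, 1, 8, 8) > 0.5).float()

# size_average=False, batch_average=False -> plain sum over all pixels
ours = class_cross_entropy_loss(logits, labels, size_average=False, batch_average=False)
ref = F.binary_cross_entropy_with_logits(logits, labels, reduction='sum')
print(torch.allclose(ours, ref, atol=1e-5))  # expected: True
```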
/networks/mainnetwork.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from mypath import Path
6 | from networks.backbone import build_backbone
7 | from networks.CoarseNet import CoarseNet
8 | from networks.FineNet import FineNet
9 |
10 | affine_par = True
11 | class PSPModule(nn.Module):
12 | """
13 | Pyramid Scene Parsing module
14 | """
15 | def __init__(self, in_features=2048, out_features=512, sizes=(1, 2, 3, 6), n_classes=1):
16 | super(PSPModule, self).__init__()
17 | self.stages = []
18 | self.stages = nn.ModuleList([self._make_stage_1(in_features, size) for size in sizes])
19 | self.bottleneck = self._make_stage_2(in_features * (len(sizes)//4 + 1), out_features)
20 | self.relu = nn.ReLU()
21 |
22 | def _make_stage_1(self, in_features, size):
23 | prior = nn.AdaptiveAvgPool2d(output_size=(size, size))
24 | conv = nn.Conv2d(in_features, in_features//4, kernel_size=1, bias=False)
25 | bn = nn.BatchNorm2d(in_features//4, affine=affine_par)
26 | relu = nn.ReLU(inplace=True)
27 | return nn.Sequential(prior, conv, bn, relu)
28 |
29 | def _make_stage_2(self, in_features, out_features):
30 | conv = nn.Conv2d(in_features, out_features, kernel_size=1, bias=False)
31 | bn = nn.BatchNorm2d(out_features, affine=affine_par)
32 | relu = nn.ReLU(inplace=True)
33 |
34 | return nn.Sequential(conv, bn, relu)
35 |
36 | def forward(self, feats):
37 | h, w = feats.size(2), feats.size(3)
38 | priors = [F.upsample(input=stage(feats), size=(h, w), mode='bilinear', align_corners=True) for stage in self.stages]
39 | priors.append(feats)
40 | bottle = self.relu(self.bottleneck(torch.cat(priors, 1)))
41 | return bottle
42 |
43 | class SegmentationNetwork(nn.Module):
44 | def __init__(self, backbone='resnet', output_stride=16, num_classes=21,nInputChannels=3,
45 | sync_bn=True, freeze_bn=False):
46 | super(SegmentationNetwork, self).__init__()
47 | output_shape = 128
48 | channel_settings = [512, 1024, 512, 256]
49 | self.Coarse_net = CoarseNet(channel_settings, output_shape, num_classes)
50 | self.Fine_net = FineNet(channel_settings[-1], output_shape, num_classes)
51 | BatchNorm = nn.BatchNorm2d
52 | self.backbone = build_backbone(backbone, output_stride, BatchNorm,nInputChannels,pretrained=False)
53 | self.psp4 = PSPModule(in_features=2048, out_features=512, sizes=(1, 2, 3, 6), n_classes=256)
54 | self.upsample = nn.Upsample(size=(512, 512), mode='bilinear', align_corners=True)
55 | if freeze_bn:
56 | self.freeze_bn()
57 |
58 | def forward(self, input):
59 | low_level_feat_4, low_level_feat_3,low_level_feat_2,low_level_feat_1 = self.backbone(input)
60 | low_level_feat_4 = self.psp4(low_level_feat_4)
61 | res_out = [low_level_feat_4, low_level_feat_3,low_level_feat_2,low_level_feat_1]
62 | coarse_fms, coarse_outs = self.Coarse_net(res_out)
63 | fine_out = self.Fine_net(coarse_fms)
64 | coarse_outs[0] = self.upsample(coarse_outs[0])
65 | coarse_outs[1] = self.upsample(coarse_outs[1])
66 | coarse_outs[2] = self.upsample(coarse_outs[2])
67 | coarse_outs[3] = self.upsample(coarse_outs[3])
68 | fine_out = self.upsample(fine_out)
69 | return coarse_outs[0],coarse_outs[1],coarse_outs[2],coarse_outs[3],fine_out
70 |
71 | def freeze_bn(self):
72 | for m in self.modules():
73 | if isinstance(m, nn.BatchNorm2d):
74 | m.eval()
75 |
76 | def get_1x_lr_params(self):
77 | modules = [self.backbone]
78 | for i in range(len(modules)):
79 | for m in modules[i].named_modules():
80 | if isinstance(m[1], nn.Conv2d) or isinstance(m[1], nn.BatchNorm2d):
81 | for p in m[1].parameters():
82 | if p.requires_grad:
83 | yield p
84 |
85 | def get_10x_lr_params(self):
86 | modules = [self.Coarse_net,self.Fine_net,self.psp4,self.upsample]
87 | for i in range(len(modules)):
88 | for m in modules[i].named_modules():
89 | if isinstance(m[1], nn.Conv2d) or isinstance(m[1], nn.BatchNorm2d):
90 | for p in m[1].parameters():
91 | if p.requires_grad:
92 | yield p
93 |
94 | def Network(nInputChannels=5,num_classes=1,backbone='resnet101',output_stride=16,
95 | sync_bn=None,freeze_bn=False,pretrained=False):
96 | model = SegmentationNetwork(nInputChannels=nInputChannels,num_classes=num_classes,backbone=backbone,
97 | output_stride=output_stride,sync_bn=sync_bn,freeze_bn=freeze_bn)
98 | if pretrained:
99 | load_pth_name= Path.models_dir()
100 | pretrain_dict = torch.load( load_pth_name,map_location=lambda storage, loc: storage)
101 | conv1_weight_new=np.zeros( (64,5,7,7) )
102 | conv1_weight_new[:,:3,:,:]=pretrain_dict['conv1.weight'].cpu().data
103 | pretrain_dict['conv1.weight']=torch.from_numpy(conv1_weight_new )
104 | state_dict = model.state_dict()
105 | model_dict = state_dict
106 | for k, v in pretrain_dict.items():
107 | kk='backbone.'+k
108 | if kk in state_dict:
109 | model_dict[kk] = v
110 | state_dict.update(model_dict)
111 | model.load_state_dict(state_dict)
112 | return model
113 |
--------------------------------------------------------------------------------
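`Network` assembles the full IOG model: a 5-channel input (RGB crop plus the two-channel inside/outside point heatmap), the ResNet backbone, the PSP module on the deepest features, and the coarse/fine heads, all upsampled back to 512x512. When `pretrained=True` the ImageNet weights are copied into the first three channels of `conv1`; the hypothetical smoke test below uses random weights (`pretrained=False`) so `mypath.py` does not need to point anywhere:

```python
import torch
from networks.mainnetwork import Network

net = Network(nInputChannels=5, num_classes=1, backbone='resnet101',
              output_stride=16, pretrained=False).eval()
x = torch.randn(1, 5, 512, 512)  # RGB crop + 2-channel IOG point heatmaps
with torch.no_grad():
    c1, c2, c3, c4, fine = net(x)
print(tuple(fine.shape))  # (1, 1, 512, 512)
```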
/networks/refinementnetwork.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import scipy.misc as sm
6 | from mypath import Path
7 | from networks.backbone import build_backbone
8 | from networks.CoarseNet import CoarseNet
9 | from networks.FineNet import FineNet
10 | from dataloaders.helpers import *
11 | affine_par = True
12 |
13 | def make_gaussian(size, sigma=10, center=None, d_type=np.float64):
14 | x = np.arange(0, size[1], 1, float)
15 | y = np.arange(0, size[0], 1, float)
16 | y = y[:, np.newaxis]
17 | if center is None:
18 | x0 = y0 = size[0] // 2
19 | else:
20 | x0 = center[0]
21 | y0 = center[1]
22 | return np.exp(-4 * np.log(2) * ((x - x0) ** 2 + (y - y0) ** 2) / sigma ** 2).astype(d_type)
23 |
24 | def getPositon(distance_transform):
25 | a = np.mat(distance_transform)
26 |     raw, column = a.shape  # number of rows and columns of the distance map
27 |     _positon = np.argmax(a)  # flat index of the maximum distance value
28 | m, n = divmod(_positon, column)
29 | raw=m
30 | column=n
31 | return raw,column
32 |
33 | def generate_distance_map(map_xor,points_center,points_bg,gt):
34 | distance_transform=ndimage.distance_transform_edt(map_xor)
35 | raw,column=getPositon(distance_transform)
36 | gt_0 = np.zeros(shape=(gt.shape[0],gt.shape[0]), dtype=np.float64)
37 |     gt_0[column, raw] = 1
38 | map_center=np.sum(np.logical_and(gt_0 ,gt))
39 | map_bg=np.sum(np.logical_and(gt_0 ,1-gt))
40 | sigma=10
41 | if map_center==1:
42 | points_center = 255*np.maximum(points_center/255, make_gaussian((gt.shape[0],gt.shape[0]), center=[column,raw], sigma=sigma))
43 | elif map_bg==1:
44 | points_bg = 255*np.maximum(points_bg/255, make_gaussian((gt.shape[0],gt.shape[0]), center=[column,raw], sigma=sigma))
45 | else:
46 | print('error')
47 | pointsgt_new = np.zeros(shape=(gt.shape[0], gt.shape[0], 2))
48 | pointsgt_new[:, :, 0]=points_center
49 | pointsgt_new[:, :, 1]=points_bg
50 | pointsgt_new = pointsgt_new.astype(dtype=np.uint8)
51 | pointsgt_new = pointsgt_new.transpose((2, 0, 1))
52 | pointsgt_new = pointsgt_new[np.newaxis,:, :, :]
53 | pointsgt_new = torch.from_numpy(pointsgt_new)
54 | return pointsgt_new
55 |
56 |
57 | def iou_cal( pre, gts,extreme_points,mask_thres=0.5):
58 | iu_ave=0
59 | distance_map_new = torch.zeros(extreme_points.shape)
60 | for jj in range(int(pre.shape[0])):
61 | pred = np.transpose(pre.cpu().data.numpy()[jj, :, :, :], (1, 2, 0))
62 | pred = 1 / (1 + np.exp(-pred))
63 | pred = np.squeeze(pred)
64 | gts=gts.cpu()
65 | gt = tens2image(gts[jj, :, :, :])
66 | extreme_points=extreme_points.cpu()
67 | points_center = tens2image(extreme_points[jj, 0:1, :, :])
68 | points_bg = tens2image(extreme_points[jj, 1:2, :, :])
69 | gt = (gt > mask_thres)
70 | pred= (pred > mask_thres)
71 | map_and=np.logical_and(pred ,gt)
72 | map_or=np.logical_or(pred ,gt)
73 | map_xor=np.bitwise_xor(pred,gt)
74 | if np.sum(map_or)==0:
75 | iu=0
76 | else:
77 | iu=np.sum(map_and)/np.sum(map_or)
78 | iu_ave=iu_ave+iu
79 | distance_map_new[jj,:,:,:]=generate_distance_map(map_xor,points_center,points_bg,gt)
80 | iu_ave=iu_ave/pre.shape[0]
81 | distance_map_new = distance_map_new.cuda()
82 | return iu_ave, distance_map_new
83 |
84 | class PSPModule(nn.Module):
85 | """
86 | Pyramid Scene Parsing module
87 | """
88 | def __init__(self, in_features=2048, out_features=512, sizes=(1, 2, 3, 6), n_classes=1):
89 | super(PSPModule, self).__init__()
90 | self.stages = []
91 | self.stages = nn.ModuleList([self._make_stage_1(in_features, size) for size in sizes])
92 | self.bottleneck = self._make_stage_2(in_features * (len(sizes)//4 + 1), out_features)
93 | self.relu = nn.ReLU()
94 |
95 | def _make_stage_1(self, in_features, size):
96 | prior = nn.AdaptiveAvgPool2d(output_size=(size, size))
97 | conv = nn.Conv2d(in_features, in_features//4, kernel_size=1, bias=False)
98 | bn = nn.BatchNorm2d(in_features//4, affine=affine_par)
99 | relu = nn.ReLU(inplace=True)
100 | return nn.Sequential(prior, conv, bn, relu)
101 |
102 | def _make_stage_2(self, in_features, out_features):
103 | conv = nn.Conv2d(in_features, out_features, kernel_size=1, bias=False)
104 | bn = nn.BatchNorm2d(out_features, affine=affine_par)
105 | relu = nn.ReLU(inplace=True)
106 |
107 | return nn.Sequential(conv, bn, relu)
108 |
109 | def forward(self, feats):
110 | h, w = feats.size(2), feats.size(3)
111 | priors = [F.upsample(input=stage(feats), size=(h, w), mode='bilinear', align_corners=True) for stage in self.stages]
112 | priors.append(feats)
113 | bottle = self.relu(self.bottleneck(torch.cat(priors, 1)))
114 | return bottle
115 |
116 | class SegmentationNetwork(nn.Module):
117 | def __init__(self, backbone='resnet', output_stride=16, num_classes=21,nInputChannels=3,
118 | sync_bn=True, freeze_bn=False):
119 | super(SegmentationNetwork, self).__init__()
120 | output_shape = 128
121 | channel_settings = [512, 1024, 512, 256]
122 | self.Coarse_net = CoarseNet(channel_settings, output_shape, num_classes)
123 | self.Fine_net = FineNet(channel_settings[-1], output_shape, num_classes)
124 | BatchNorm = nn.BatchNorm2d
125 | self.backbone = build_backbone(backbone, output_stride, BatchNorm,nInputChannels,pretrained=False)
126 | self.psp4 = PSPModule(in_features=2048+64, out_features=512, sizes=(1, 2, 3, 6), n_classes=256)
127 | self.upsample = nn.Upsample(size=(512, 512), mode='bilinear', align_corners=True)
128 | self.iog_points = nn.Sequential(nn.Conv2d(2, 64, kernel_size=3, stride=2, padding=1, bias=False),
129 | nn.BatchNorm2d(64),
130 | nn.ReLU(),
131 | nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1, bias=False),
132 | nn.BatchNorm2d(128),
133 | nn.ReLU(),
134 | nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1, bias=False),
135 | nn.BatchNorm2d(256),
136 | nn.ReLU(),
137 | nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1, bias=False),
138 | nn.BatchNorm2d(256),
139 | nn.ReLU(),
140 | nn.Conv2d(256, 64, kernel_size=1, stride=1, bias=False),
141 | nn.BatchNorm2d(64),
142 | nn.ReLU())
143 |
144 | if freeze_bn:
145 | self.freeze_bn()
146 |
147 | def forward(self, input,IOG_points,gts,refinement_num_max):
148 | low_level_feat_4_orig, low_level_feat_3_orig,low_level_feat_2_orig,low_level_feat_1_orig = self.backbone(input)
149 | feats_orig=low_level_feat_4_orig
150 | outlist=[]
151 | distance_map=IOG_points
152 | distance_map_512=distance_map
153 | for refinement_num in range(0,refinement_num_max):
154 | distance_map = self.iog_points(distance_map)
155 | feats_concat=torch.cat((feats_orig,distance_map),dim=1)#2048+64
156 |
157 | low_level_feat_4 = self.psp4(feats_concat)
158 | res_out = [low_level_feat_4, low_level_feat_3_orig,low_level_feat_2_orig,low_level_feat_1_orig]
159 | coarse_fms, coarse_outs = self.Coarse_net(res_out)
160 | fine_out = self.Fine_net(coarse_fms)
161 |
162 | out_512 = F.upsample(fine_out,size=(512, 512), mode='bilinear', align_corners=True)
163 | iou_i,distance_map_new = iou_cal(out_512,gts,distance_map_512)
164 | distance_map=distance_map_new
165 | distance_map_512 = distance_map
166 | out = [coarse_outs[0],coarse_outs[1],coarse_outs[2],coarse_outs[3],fine_out,iou_i]
167 | outlist.append(out)
168 | return outlist
169 |
170 | def freeze_bn(self):
171 | for m in self.modules():
172 | if isinstance(m, nn.BatchNorm2d):
173 | m.eval()
174 |
175 | def get_1x_lr_params(self):
176 | modules = [self.backbone]
177 | for i in range(len(modules)):
178 | for m in modules[i].named_modules():
179 | if isinstance(m[1], nn.Conv2d) or isinstance(m[1], nn.BatchNorm2d):
180 | for p in m[1].parameters():
181 | if p.requires_grad:
182 | yield p
183 |
184 | def get_10x_lr_params(self):
185 | modules = [self.Coarse_net,self.Fine_net,self.psp4,self.upsample,self.iog_points]
186 | for i in range(len(modules)):
187 | for m in modules[i].named_modules():
188 | if isinstance(m[1], nn.Conv2d) or isinstance(m[1], nn.BatchNorm2d):
189 | for p in m[1].parameters():
190 | if p.requires_grad:
191 | yield p
192 |
193 | def Network(nInputChannels=5,num_classes=1,backbone='resnet101',output_stride=16,
194 | sync_bn=None,freeze_bn=False,pretrained=False):
195 | model = SegmentationNetwork(nInputChannels=nInputChannels,num_classes=num_classes,backbone=backbone,
196 | output_stride=output_stride,sync_bn=sync_bn,freeze_bn=freeze_bn)
197 | if pretrained:
198 | load_pth_name= Path.models_dir()
199 | pretrain_dict = torch.load( load_pth_name,map_location=lambda storage, loc: storage)
200 | conv1_weight_new=np.zeros( (64,5,7,7) )
201 | conv1_weight_new[:,:3,:,:]=pretrain_dict['conv1.weight'].cpu().data
202 | pretrain_dict['conv1.weight']=torch.from_numpy(conv1_weight_new )
203 | state_dict = model.state_dict()
204 | model_dict = state_dict
205 | for k, v in pretrain_dict.items():
206 | kk='backbone.'+k
207 | if kk in state_dict:
208 | model_dict[kk] = v
209 | state_dict.update(model_dict)
210 | model.load_state_dict(state_dict)
211 | return model
212 |
--------------------------------------------------------------------------------
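The refinement network simulates interactive correction: after each forward pass, `iou_cal` finds the most wrongly labelled pixel (the maximum of the distance transform over the prediction/ground-truth XOR), decides whether it lies on missed foreground or spurious background, and `generate_distance_map` adds a new Gaussian click to the corresponding channel before the next round. A small hypothetical illustration of how such a click is rendered (`make_gaussian` only needs NumPy, though importing it pulls in the rest of the repository):

```python
import numpy as np
from networks.refinementnetwork import make_gaussian

# A corrective click at (x=200, y=300) on a 512x512 crop, sigma=10 as used above.
heatmap = make_gaussian((512, 512), sigma=10, center=[200, 300])
print(heatmap.shape, heatmap[300, 200], heatmap.max())  # (512, 512) 1.0 1.0
```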
/networks/sync_batchnorm/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : __init__.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
12 | from .replicate import DataParallelWithCallback, patch_replication_callback
--------------------------------------------------------------------------------
/networks/sync_batchnorm/__pycache__/__init__.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shiyinzhang/Inside-Outside-Guidance/696ec2ddae2e994541cc9d81bb3c41984e233c64/networks/sync_batchnorm/__pycache__/__init__.cpython-35.pyc
--------------------------------------------------------------------------------
/networks/sync_batchnorm/__pycache__/batchnorm.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shiyinzhang/Inside-Outside-Guidance/696ec2ddae2e994541cc9d81bb3c41984e233c64/networks/sync_batchnorm/__pycache__/batchnorm.cpython-35.pyc
--------------------------------------------------------------------------------
/networks/sync_batchnorm/__pycache__/comm.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shiyinzhang/Inside-Outside-Guidance/696ec2ddae2e994541cc9d81bb3c41984e233c64/networks/sync_batchnorm/__pycache__/comm.cpython-35.pyc
--------------------------------------------------------------------------------
/networks/sync_batchnorm/__pycache__/replicate.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shiyinzhang/Inside-Outside-Guidance/696ec2ddae2e994541cc9d81bb3c41984e233c64/networks/sync_batchnorm/__pycache__/replicate.cpython-35.pyc
--------------------------------------------------------------------------------
/networks/sync_batchnorm/batchnorm.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : batchnorm.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import collections
12 |
13 | import torch
14 | import torch.nn.functional as F
15 |
16 | from torch.nn.modules.batchnorm import _BatchNorm
17 | from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
18 |
19 | from .comm import SyncMaster
20 |
21 | __all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d']
22 |
23 |
24 | def _sum_ft(tensor):
25 |     """sum over the first and last dimension"""
26 | return tensor.sum(dim=0).sum(dim=-1)
27 |
28 |
29 | def _unsqueeze_ft(tensor):
30 |     """add new dimensions at the front and the tail"""
31 | return tensor.unsqueeze(0).unsqueeze(-1)
32 |
33 |
34 | _ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
35 | _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])
36 |
37 |
38 | class _SynchronizedBatchNorm(_BatchNorm):
39 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
40 | super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)
41 |
42 | self._sync_master = SyncMaster(self._data_parallel_master)
43 |
44 | self._is_parallel = False
45 | self._parallel_id = None
46 | self._slave_pipe = None
47 |
48 | def forward(self, input):
49 | # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
50 | if not (self._is_parallel and self.training):
51 | return F.batch_norm(
52 | input, self.running_mean, self.running_var, self.weight, self.bias,
53 | self.training, self.momentum, self.eps)
54 |
55 | # Resize the input to (B, C, -1).
56 | input_shape = input.size()
57 | input = input.view(input.size(0), self.num_features, -1)
58 |
59 | # Compute the sum and square-sum.
60 | sum_size = input.size(0) * input.size(2)
61 | input_sum = _sum_ft(input)
62 | input_ssum = _sum_ft(input ** 2)
63 |
64 | # Reduce-and-broadcast the statistics.
65 | if self._parallel_id == 0:
66 | mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
67 | else:
68 | mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))
69 |
70 | # Compute the output.
71 | if self.affine:
72 | # MJY:: Fuse the multiplication for speed.
73 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
74 | else:
75 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)
76 |
77 | # Reshape it.
78 | return output.view(input_shape)
79 |
80 | def __data_parallel_replicate__(self, ctx, copy_id):
81 | self._is_parallel = True
82 | self._parallel_id = copy_id
83 |
84 | # parallel_id == 0 means master device.
85 | if self._parallel_id == 0:
86 | ctx.sync_master = self._sync_master
87 | else:
88 | self._slave_pipe = ctx.sync_master.register_slave(copy_id)
89 |
90 | def _data_parallel_master(self, intermediates):
91 | """Reduce the sum and square-sum, compute the statistics, and broadcast it."""
92 |
93 | # Always using same "device order" makes the ReduceAdd operation faster.
94 | # Thanks to:: Tete Xiao (http://tetexiao.com/)
95 | intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())
96 |
97 | to_reduce = [i[1][:2] for i in intermediates]
98 | to_reduce = [j for i in to_reduce for j in i] # flatten
99 | target_gpus = [i[1].sum.get_device() for i in intermediates]
100 |
101 | sum_size = sum([i[1].sum_size for i in intermediates])
102 | sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
103 | mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)
104 |
105 | broadcasted = Broadcast.apply(target_gpus, mean, inv_std)
106 |
107 | outputs = []
108 | for i, rec in enumerate(intermediates):
109 | outputs.append((rec[0], _MasterMessage(*broadcasted[i * 2:i * 2 + 2])))
110 |
111 | return outputs
112 |
113 | def _compute_mean_std(self, sum_, ssum, size):
114 | """Compute the mean and standard-deviation with sum and square-sum. This method
115 | also maintains the moving average on the master device."""
116 | assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
117 | mean = sum_ / size
118 | sumvar = ssum - sum_ * mean
119 | unbias_var = sumvar / (size - 1)
120 | bias_var = sumvar / size
121 |
122 | self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
123 | self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
124 |
125 | return mean, bias_var.clamp(self.eps) ** -0.5
126 |
127 |
128 | class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
129 | r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
130 | mini-batch.
131 | .. math::
132 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
133 | This module differs from the built-in PyTorch BatchNorm1d as the mean and
134 | standard-deviation are reduced across all devices during training.
135 | For example, when one uses `nn.DataParallel` to wrap the network during
136 | training, PyTorch's implementation normalizes the tensor on each device using
137 | only the statistics on that device, which accelerates the computation and
138 | is also easy to implement, but the statistics might be inaccurate.
139 | Instead, in this synchronized version, the statistics will be computed
140 | over all training samples distributed on multiple devices.
141 | 
142 | Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
143 | as the built-in PyTorch implementation.
144 | The mean and standard-deviation are calculated per-dimension over
145 | the mini-batches and gamma and beta are learnable parameter vectors
146 | of size C (where C is the input size).
147 | During training, this layer keeps a running estimate of its computed mean
148 | and variance. The running sum is kept with a default momentum of 0.1.
149 | During evaluation, this running mean/variance is used for normalization.
150 | Because the BatchNorm is done over the `C` dimension, computing statistics
151 | on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm
152 | Args:
153 | num_features: num_features from an expected input of size
154 | `batch_size x num_features [x width]`
155 | eps: a value added to the denominator for numerical stability.
156 | Default: 1e-5
157 | momentum: the value used for the running_mean and running_var
158 | computation. Default: 0.1
159 | affine: a boolean value that when set to ``True``, gives the layer learnable
160 | affine parameters. Default: ``True``
161 | Shape:
162 | - Input: :math:`(N, C)` or :math:`(N, C, L)`
163 | - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)
164 | Examples:
165 | >>> # With Learnable Parameters
166 | >>> m = SynchronizedBatchNorm1d(100)
167 | >>> # Without Learnable Parameters
168 | >>> m = SynchronizedBatchNorm1d(100, affine=False)
169 | >>> input = torch.autograd.Variable(torch.randn(20, 100))
170 | >>> output = m(input)
171 | """
172 |
173 | def _check_input_dim(self, input):
174 | if input.dim() != 2 and input.dim() != 3:
175 | raise ValueError('expected 2D or 3D input (got {}D input)'
176 | .format(input.dim()))
177 | super(SynchronizedBatchNorm1d, self)._check_input_dim(input)
178 |
179 |
180 | class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
181 | r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
182 | of 3d inputs
183 | .. math::
184 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
185 | This module differs from the built-in PyTorch BatchNorm2d as the mean and
186 | standard-deviation are reduced across all devices during training.
187 | For example, when one uses `nn.DataParallel` to wrap the network during
188 | training, PyTorch's implementation normalizes the tensor on each device using
189 | only the statistics on that device, which accelerates the computation and
190 | is also easy to implement, but the statistics might be inaccurate.
191 | Instead, in this synchronized version, the statistics will be computed
192 | over all training samples distributed on multiple devices.
193 | 
194 | Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
195 | as the built-in PyTorch implementation.
196 | The mean and standard-deviation are calculated per-dimension over
197 | the mini-batches and gamma and beta are learnable parameter vectors
198 | of size C (where C is the input size).
199 | During training, this layer keeps a running estimate of its computed mean
200 | and variance. The running sum is kept with a default momentum of 0.1.
201 | During evaluation, this running mean/variance is used for normalization.
202 | Because the BatchNorm is done over the `C` dimension, computing statistics
203 | on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm
204 | Args:
205 | num_features: num_features from an expected input of
206 | size batch_size x num_features x height x width
207 | eps: a value added to the denominator for numerical stability.
208 | Default: 1e-5
209 | momentum: the value used for the running_mean and running_var
210 | computation. Default: 0.1
211 | affine: a boolean value that when set to ``True``, gives the layer learnable
212 | affine parameters. Default: ``True``
213 | Shape:
214 | - Input: :math:`(N, C, H, W)`
215 | - Output: :math:`(N, C, H, W)` (same shape as input)
216 | Examples:
217 | >>> # With Learnable Parameters
218 | >>> m = SynchronizedBatchNorm2d(100)
219 | >>> # Without Learnable Parameters
220 | >>> m = SynchronizedBatchNorm2d(100, affine=False)
221 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))
222 | >>> output = m(input)
223 | """
224 |
225 | def _check_input_dim(self, input):
226 | if input.dim() != 4:
227 | raise ValueError('expected 4D input (got {}D input)'
228 | .format(input.dim()))
229 | super(SynchronizedBatchNorm2d, self)._check_input_dim(input)
230 |
231 |
232 | class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
233 | r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
234 | of 4d inputs
235 | .. math::
236 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
237 | This module differs from the built-in PyTorch BatchNorm3d as the mean and
238 | standard-deviation are reduced across all devices during training.
239 | For example, when one uses `nn.DataParallel` to wrap the network during
240 | training, PyTorch's implementation normalizes the tensor on each device using
241 | only the statistics on that device, which accelerates the computation and
242 | is also easy to implement, but the statistics might be inaccurate.
243 | Instead, in this synchronized version, the statistics will be computed
244 | over all training samples distributed on multiple devices.
245 | 
246 | Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
247 | as the built-in PyTorch implementation.
248 | The mean and standard-deviation are calculated per-dimension over
249 | the mini-batches and gamma and beta are learnable parameter vectors
250 | of size C (where C is the input size).
251 | During training, this layer keeps a running estimate of its computed mean
252 | and variance. The running sum is kept with a default momentum of 0.1.
253 | During evaluation, this running mean/variance is used for normalization.
254 | Because the BatchNorm is done over the `C` dimension, computing statistics
255 | on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
256 | or Spatio-temporal BatchNorm
257 | Args:
258 | num_features: num_features from an expected input of
259 | size batch_size x num_features x depth x height x width
260 | eps: a value added to the denominator for numerical stability.
261 | Default: 1e-5
262 | momentum: the value used for the running_mean and running_var
263 | computation. Default: 0.1
264 | affine: a boolean value that when set to ``True``, gives the layer learnable
265 | affine parameters. Default: ``True``
266 | Shape:
267 | - Input: :math:`(N, C, D, H, W)`
268 | - Output: :math:`(N, C, D, H, W)` (same shape as input)
269 | Examples:
270 | >>> # With Learnable Parameters
271 | >>> m = SynchronizedBatchNorm3d(100)
272 | >>> # Without Learnable Parameters
273 | >>> m = SynchronizedBatchNorm3d(100, affine=False)
274 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))
275 | >>> output = m(input)
276 | """
277 |
278 | def _check_input_dim(self, input):
279 | if input.dim() != 5:
280 | raise ValueError('expected 5D input (got {}D input)'
281 | .format(input.dim()))
282 | super(SynchronizedBatchNorm3d, self)._check_input_dim(input)
--------------------------------------------------------------------------------
/networks/sync_batchnorm/comm.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : comm.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import queue
12 | import collections
13 | import threading
14 |
15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
16 |
17 |
18 | class FutureResult(object):
19 | """A thread-safe future implementation. Used only as one-to-one pipe."""
20 |
21 | def __init__(self):
22 | self._result = None
23 | self._lock = threading.Lock()
24 | self._cond = threading.Condition(self._lock)
25 |
26 | def put(self, result):
27 | with self._lock:
28 |             assert self._result is None, 'Previous result hasn\'t been fetched.'
29 | self._result = result
30 | self._cond.notify()
31 |
32 | def get(self):
33 | with self._lock:
34 | if self._result is None:
35 | self._cond.wait()
36 |
37 | res = self._result
38 | self._result = None
39 | return res
40 |
41 |
42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
44 |
45 |
46 | class SlavePipe(_SlavePipeBase):
47 | """Pipe for master-slave communication."""
48 |
49 | def run_slave(self, msg):
50 | self.queue.put((self.identifier, msg))
51 | ret = self.result.get()
52 | self.queue.put(True)
53 | return ret
54 |
55 |
56 | class SyncMaster(object):
57 | """An abstract `SyncMaster` object.
58 |     - During replication, the data parallel wrapper triggers a callback on each module, so all slave devices should
59 |     call `register(id)` and obtain a `SlavePipe` to communicate with the master.
60 |     - During the forward pass, the master device invokes `run_master`; all messages from the slave devices are collected
61 |     and passed to a registered callback.
62 |     - After receiving the messages, the master device gathers the information and determines the message to be passed
63 |     back to each slave device.
64 | """
65 |
66 | def __init__(self, master_callback):
67 | """
68 | Args:
69 | master_callback: a callback to be invoked after having collected messages from slave devices.
70 | """
71 | self._master_callback = master_callback
72 | self._queue = queue.Queue()
73 | self._registry = collections.OrderedDict()
74 | self._activated = False
75 |
76 | def __getstate__(self):
77 | return {'master_callback': self._master_callback}
78 |
79 | def __setstate__(self, state):
80 | self.__init__(state['master_callback'])
81 |
82 | def register_slave(self, identifier):
83 | """
84 |         Register a slave device.
85 | Args:
86 | identifier: an identifier, usually is the device id.
87 | Returns: a `SlavePipe` object which can be used to communicate with the master device.
88 | """
89 | if self._activated:
90 | assert self._queue.empty(), 'Queue is not clean before next initialization.'
91 | self._activated = False
92 | self._registry.clear()
93 | future = FutureResult()
94 | self._registry[identifier] = _MasterRegistry(future)
95 | return SlavePipe(identifier, self._queue, future)
96 |
97 | def run_master(self, master_msg):
98 | """
99 | Main entry for the master device in each forward pass.
100 |         The messages are first collected from each device (including the master device), and then
101 |         a callback is invoked to compute the message to be sent back to each device
102 |         (including the master device).
103 | Args:
104 | master_msg: the message that the master want to send to itself. This will be placed as the first
105 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
106 | Returns: the message to be sent back to the master device.
107 | """
108 | self._activated = True
109 |
110 | intermediates = [(0, master_msg)]
111 | for i in range(self.nr_slaves):
112 | intermediates.append(self._queue.get())
113 |
114 | results = self._master_callback(intermediates)
115 |         assert results[0][0] == 0, 'The first result should belong to the master.'
116 |
117 | for i, res in results:
118 | if i == 0:
119 | continue
120 | self._registry[i].result.put(res)
121 |
122 | for i in range(self.nr_slaves):
123 | assert self._queue.get() is True
124 |
125 | return results[0][1]
126 |
127 | @property
128 | def nr_slaves(self):
129 | return len(self._registry)
130 |
--------------------------------------------------------------------------------
/networks/sync_batchnorm/replicate.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : replicate.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import functools
12 |
13 | from torch.nn.parallel.data_parallel import DataParallel
14 |
15 | __all__ = [
16 | 'CallbackContext',
17 | 'execute_replication_callbacks',
18 | 'DataParallelWithCallback',
19 | 'patch_replication_callback'
20 | ]
21 |
22 |
23 | class CallbackContext(object):
24 | pass
25 |
26 |
27 | def execute_replication_callbacks(modules):
28 | """
29 |     Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication.
30 |     The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`.
31 |     Note that, as all modules are isomorphic, we assign each sub-module a context
32 |     (shared among multiple copies of this module on different devices).
33 |     Through this context, different copies can share some information.
34 |     We guarantee that the callback on the master copy (the first copy) is called before the callback
35 |     of any slave copies.
36 | """
37 | master_copy = modules[0]
38 | nr_modules = len(list(master_copy.modules()))
39 | ctxs = [CallbackContext() for _ in range(nr_modules)]
40 |
41 | for i, module in enumerate(modules):
42 | for j, m in enumerate(module.modules()):
43 | if hasattr(m, '__data_parallel_replicate__'):
44 | m.__data_parallel_replicate__(ctxs[j], i)
45 |
46 |
47 | class DataParallelWithCallback(DataParallel):
48 | """
49 | Data Parallel with a replication callback.
50 |     A replication callback `__data_parallel_replicate__` of each module will be invoked after being created by
51 | original `replicate` function.
52 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
53 | Examples:
54 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
55 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
56 | # sync_bn.__data_parallel_replicate__ will be invoked.
57 | """
58 |
59 | def replicate(self, module, device_ids):
60 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
61 | execute_replication_callbacks(modules)
62 | return modules
63 |
64 |
65 | def patch_replication_callback(data_parallel):
66 | """
67 | Monkey-patch an existing `DataParallel` object. Add the replication callback.
68 | Useful when you have customized `DataParallel` implementation.
69 | Examples:
70 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
71 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
72 | > patch_replication_callback(sync_bn)
73 | # this is equivalent to
74 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
75 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
76 | """
77 |
78 | assert isinstance(data_parallel, DataParallel)
79 |
80 | old_replicate = data_parallel.replicate
81 |
82 | @functools.wraps(old_replicate)
83 | def new_replicate(module, device_ids):
84 | modules = old_replicate(module, device_ids)
85 | execute_replication_callbacks(modules)
86 | return modules
87 |
88 | data_parallel.replicate = new_replicate
--------------------------------------------------------------------------------
/networks/sync_batchnorm/unittest.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : unittest.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import unittest
12 |
13 | import numpy as np
14 | from torch.autograd import Variable
15 |
16 |
17 | def as_numpy(v):
18 | if isinstance(v, Variable):
19 | v = v.data
20 | return v.cpu().numpy()
21 |
22 |
23 | class TorchTestCase(unittest.TestCase):
24 | def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3):
25 | npa, npb = as_numpy(a), as_numpy(b)
26 | self.assertTrue(
27 | np.allclose(npa, npb, atol=atol),
28 | 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max())
29 | )
30 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime
2 | import scipy.misc as sm
3 | from collections import OrderedDict
4 | import glob
5 | import numpy as np
6 | import socket
7 |
8 | # PyTorch includes
9 | import torch
10 | import torch.optim as optim
11 | from torchvision import transforms
12 | from torch.utils.data import DataLoader
13 |
14 | # Custom includes
15 | from dataloaders.combine_dbs import CombineDBs as combine_dbs
16 | import dataloaders.pascal as pascal
17 | import dataloaders.sbd as sbd
18 | from dataloaders import custom_transforms as tr
19 | from networks.loss import class_cross_entropy_loss
20 | from dataloaders.helpers import *
21 | from networks.mainnetwork import *
22 |
23 | # Set gpu_id to -1 to run in CPU mode, otherwise set the id of the corresponding gpu
24 | gpu_id = 0
25 | device = torch.device("cuda:"+str(gpu_id) if torch.cuda.is_available() else "cpu")
26 | if torch.cuda.is_available():
27 | print('Using GPU: {} '.format(gpu_id))
28 |
29 | # Setting parameters
30 | resume_epoch = 100 # test epoch
31 | nInputChannels = 5 # Number of input channels (RGB + heatmap of IOG points)
32 |
33 | # Results and model directories (a new directory is generated for every run)
34 | save_dir_root = os.path.join(os.path.dirname(os.path.abspath(__file__)))
35 | exp_name = os.path.dirname(os.path.abspath(__file__)).split('/')[-1]
36 | if resume_epoch == 0:
37 | runs = sorted(glob.glob(os.path.join(save_dir_root, 'run_*')))
38 | run_id = int(runs[-1].split('_')[-1]) + 1 if runs else 0
39 | else:
40 | run_id = 0
41 | save_dir = os.path.join(save_dir_root, 'run_' + str(run_id))
42 | if not os.path.exists(os.path.join(save_dir, 'models')):
43 | os.makedirs(os.path.join(save_dir, 'models'))
44 |
45 | # Network definition
46 | modelName = 'IOG_pascal'
47 | net = Network(nInputChannels=nInputChannels,num_classes=1,
48 | backbone='resnet101',
49 | output_stride=16,
50 | sync_bn=None,
51 | freeze_bn=False)
52 |
53 | # load pretrain_dict
54 | pretrain_dict = torch.load(os.path.join(save_dir, 'models', modelName + '_epoch-' + str(resume_epoch - 1) + '.pth'))
55 | print("Initializing weights from: {}".format(
56 | os.path.join(save_dir, 'models', modelName + '_epoch-' + str(resume_epoch - 1) + '.pth')))
57 | net.load_state_dict(pretrain_dict)
58 | net.to(device)
59 |
60 | # Generate result of the validation images
61 | net.eval()
62 | composed_transforms_ts = transforms.Compose([
63 | tr.CropFromMask(crop_elems=('image', 'gt','void_pixels'), relax=30, zero_pad=True),
64 | tr.FixedResize(resolutions={'gt': None, 'crop_image': (512, 512), 'crop_gt': (512, 512), 'crop_void_pixels': (512, 512)},flagvals={'gt':cv2.INTER_LINEAR,'crop_image':cv2.INTER_LINEAR,'crop_gt':cv2.INTER_LINEAR,'crop_void_pixels': cv2.INTER_LINEAR}),
65 | tr.IOGPoints(sigma=10, elem='crop_gt',pad_pixel=10),
66 | tr.ToImage(norm_elem='IOG_points'),
67 | tr.ConcatInputs(elems=('crop_image', 'IOG_points')),
68 | tr.ToTensor()])
69 | db_test = pascal.VOCSegmentation(split='val', transform=composed_transforms_ts, retname=True)
70 | testloader = DataLoader(db_test, batch_size=1, shuffle=False, num_workers=1)
71 |
72 | save_dir_res = os.path.join(save_dir, 'Results')
73 | if not os.path.exists(save_dir_res):
74 | os.makedirs(save_dir_res)
75 | save_dir_res_list=[save_dir_res]
76 | print('Testing Network')
77 | with torch.no_grad():
78 | for ii, sample_batched in enumerate(testloader):
79 | inputs, gts, metas = sample_batched['concat'], sample_batched['gt'], sample_batched['meta']
80 | inputs = inputs.to(device)
81 | coarse_outs1,coarse_outs2,coarse_outs3,coarse_outs4,fine_out = net.forward(inputs)
82 | outputs = fine_out.to(torch.device('cpu'))
83 | pred = np.transpose(outputs.data.numpy()[0, :, :, :], (1, 2, 0))
84 | pred = 1 / (1 + np.exp(-pred))
85 | pred = np.squeeze(pred)
86 | gt = tens2image(gts[0, :, :, :])
87 | bbox = get_bbox(gt, pad=30, zero_pad=True)
88 | result = crop2fullmask(pred, bbox, gt, zero_pad=True, relax=0,mask_relax=False)
89 | sm.imsave(os.path.join(save_dir_res_list[0], metas['image'][0] + '-' + metas['object'][0] + '.png'), result)
90 |
--------------------------------------------------------------------------------
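One practical note on the last line of the loop above: `scipy.misc.imsave` was removed in SciPy 1.2, so on newer environments the save call fails. A hypothetical drop-in (using the same `result`, `metas`, and `save_dir_res_list` variables from that loop) is to binarise the probability map and write it with `imageio` instead:

```python
# Hypothetical replacement for the sm.imsave(...) call on newer SciPy versions.
import imageio
import numpy as np

mask = (result > 0.5).astype(np.uint8) * 255  # threshold the sigmoid output
imageio.imwrite(os.path.join(save_dir_res_list[0],
                             metas['image'][0] + '-' + metas['object'][0] + '.png'), mask)
```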
/test_refine.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime
2 | import scipy.misc as sm
3 | from collections import OrderedDict
4 | import glob
5 | import numpy as np
6 | import socket
7 | import timeit
8 |
9 | # PyTorch includes
10 | import torch
11 | import torch.optim as optim
12 | from torchvision import transforms
13 | from torch.utils.data import DataLoader
14 |
15 | # Custom includes
16 | from dataloaders.combine_dbs import CombineDBs as combine_dbs
17 | import dataloaders.pascal as pascal
18 | import dataloaders.sbd as sbd
19 | from dataloaders import custom_transforms as tr
20 | from dataloaders.helpers import *
21 | from networks.loss import class_cross_entropy_loss
22 | from networks.refinementnetwork import *
23 | from torch.nn.functional import upsample
24 |
25 | # Set gpu_id to -1 to run in CPU mode, otherwise set the id of the corresponding gpu
26 | gpu_id = 0
27 | device = torch.device("cuda:"+str(gpu_id) if torch.cuda.is_available() else "cpu")
28 | if torch.cuda.is_available():
29 | print('Using GPU: {} '.format(gpu_id))
30 |
31 | # Setting parameters
32 | resume_epoch = 100 # test epoch
33 | nInputChannels = 5 # Number of input channels (RGB + heatmap of IOG points)
34 | refinement_num_max = 2 # number of additional refinement clicks
35 |
36 | # Results and model directories (a new directory is generated for every run)
37 | save_dir_root = os.path.join(os.path.dirname(os.path.abspath(__file__)))
38 | exp_name = os.path.dirname(os.path.abspath(__file__)).split('/')[-1]
39 | if resume_epoch == 0:
40 | runs = sorted(glob.glob(os.path.join(save_dir_root, 'run_*')))
41 | run_id = int(runs[-1].split('_')[-1]) + 1 if runs else 0
42 | else:
43 | run_id = 0
44 | save_dir = os.path.join(save_dir_root, 'run_' + str(run_id))
45 | if not os.path.exists(os.path.join(save_dir, 'models')):
46 | os.makedirs(os.path.join(save_dir, 'models'))
47 |
48 | # Network definition
49 | modelName = 'IOG_pascal_refinement'
50 | net = Network(nInputChannels=nInputChannels,num_classes=1,
51 | backbone='resnet101',
52 | output_stride=16,
53 | sync_bn=None,
54 | freeze_bn=False)
55 |
56 | # load pretrain_dict
57 | pretrain_dict = torch.load(os.path.join(save_dir, 'models', modelName + '_epoch-' + str(resume_epoch - 1) + '.pth'))
58 | print("Initializing weights from: {}".format(
59 | os.path.join(save_dir, 'models', modelName + '_epoch-' + str(resume_epoch - 1) + '.pth')))
60 | net.load_state_dict(pretrain_dict)
61 | net.to(device)
62 |
63 | # Generate result of the validation images
64 | net.eval()
65 | composed_transforms_ts = transforms.Compose([
66 | tr.CropFromMask(crop_elems=('image', 'gt','void_pixels'), relax=30, zero_pad=True),
67 | tr.FixedResize(resolutions={'gt': None, 'crop_image': (512, 512), 'crop_gt': (512, 512), 'crop_void_pixels': (512, 512)},flagvals={'gt':cv2.INTER_LINEAR,'crop_image':cv2.INTER_LINEAR,'crop_gt':cv2.INTER_LINEAR,'crop_void_pixels': cv2.INTER_LINEAR}),
68 | tr.IOGPoints(sigma=10, elem='crop_gt',pad_pixel=10),
69 | tr.ToImage(norm_elem='IOG_points'),
70 | tr.ConcatInputs(elems=('crop_image', 'IOG_points')),
71 | tr.ToTensor()])
72 | db_test = pascal.VOCSegmentation(split='val', transform=composed_transforms_ts, retname=True)
73 | testloader = DataLoader(db_test, batch_size=1, shuffle=False, num_workers=1)
74 |
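# One result directory per refinement round: Results-0 presumably holds the initial
# 3-point prediction and Results-k the prediction after k simulated extra clicks.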
75 | save_dir_res_list=[]
76 | for add_clicks in range(0,refinement_num_max+1):
77 | save_dir_res = os.path.join(save_dir, 'Results-'+str(add_clicks))
78 | if not os.path.exists(save_dir_res):
79 | os.makedirs(save_dir_res)
80 | save_dir_res_list.append(save_dir_res)
81 |
82 | print('Testing Network')
83 | with torch.no_grad():
84 | # Main Testing Loop
85 | for ii, sample_batched in enumerate(testloader):
86 | metas = sample_batched['meta']
87 | gts = sample_batched['gt']
88 | gts_crop = sample_batched['crop_gt']
89 | inputs = sample_batched['concat']
90 | void_pixels = sample_batched['crop_void_pixels']
91 | IOG_points = sample_batched['IOG_points']
92 | inputs.requires_grad_()
93 | inputs, gts_crop ,void_pixels,IOG_points = inputs.to(device), gts_crop.to(device), void_pixels.to(device), IOG_points.to(device)
94 | out = net.forward(inputs,IOG_points,gts_crop,refinement_num_max+1)
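# out is expected to contain one tuple per refinement round: the four coarse/global side
# outputs, the refined prediction and an IoU estimate, unpacked in the loop below.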
95 | for i in range(0,refinement_num_max+1):
96 | glo1,glo2,glo3,glo4,refine,iou_i=out[i]
97 | output_refine = upsample(refine, size=(512, 512), mode='bilinear', align_corners=True)
98 | outputs = output_refine.to(torch.device('cpu'))
99 | pred = np.transpose(outputs.data.numpy()[0, :, :, :], (1, 2, 0))
100 | pred = 1 / (1 + np.exp(-pred))
101 | pred = np.squeeze(pred)
102 | gt = tens2image(gts[0, :, :, :])
103 | bbox = get_bbox(gt, pad=30, zero_pad=True)
104 | result = crop2fullmask(pred, bbox, gt, zero_pad=True, relax=0,mask_relax=False)
105 |
106 | # Save the result; the index i selects the directory for this refinement round
107 | sm.imsave(os.path.join(save_dir_res_list[i], metas['image'][0] + '-' + metas['object'][0] + '.png'), result)
108 |
109 |
110 |
111 |
112 |
113 |
114 |
115 |
116 |
117 |
118 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime
2 | import scipy.misc as sm
3 | from collections import OrderedDict
4 | import glob
5 | import numpy as np
6 | import socket
7 | import timeit
8 |
9 | # PyTorch includes
10 | import torch
11 | import torch.optim as optim
12 | from torchvision import transforms
13 | from torch.utils.data import DataLoader
14 |
15 | # Custom includes
16 | from dataloaders.combine_dbs import CombineDBs as combine_dbs
17 | import dataloaders.pascal as pascal
18 | import dataloaders.sbd as sbd
19 | from dataloaders import custom_transforms as tr
20 | from dataloaders.helpers import *
21 | from networks.loss import class_cross_entropy_loss
22 | from networks.mainnetwork import *
23 |
24 | # Set gpu_id to -1 to run in CPU mode, otherwise set the id of the corresponding gpu
25 | gpu_id = 0
26 | device = torch.device("cuda:"+str(gpu_id) if torch.cuda.is_available() else "cpu")
27 | if torch.cuda.is_available():
28 | print('Using GPU: {} '.format(gpu_id))
29 |
30 | # Setting parameters
31 | use_sbd = False # train with SBD
32 | nEpochs = 100 # Number of epochs for training
33 | resume_epoch = 0 # Default is 0, change if want to resume
34 | p = OrderedDict() # Parameters to include in report
35 | p['trainBatch'] = 5 # Training batch size 5
36 | snapshot = 10 # Store a model every snapshot epochs
37 | nInputChannels = 5 # Number of input channels (RGB + heatmap of IOG points)
38 | p['nAveGrad'] = 1 # Average the gradient of several iterations
39 | p['lr'] = 1e-8 # Learning rate
40 | p['wd'] = 0.0005 # Weight decay
41 | p['momentum'] = 0.9 # Momentum
42 |
43 | # Results and model directories (a new directory is generated for every run)
44 | save_dir_root = os.path.join(os.path.dirname(os.path.abspath(__file__)))
45 | exp_name = os.path.dirname(os.path.abspath(__file__)).split('/')[-1]
46 | if resume_epoch == 0:
47 | runs = sorted(glob.glob(os.path.join(save_dir_root, 'run_*')))
48 | run_id = int(runs[-1].split('_')[-1]) + 1 if runs else 0
49 | else:
50 | run_id = 0
51 | save_dir = os.path.join(save_dir_root, 'run_' + str(run_id))
52 | if not os.path.exists(os.path.join(save_dir, 'models')):
53 | os.makedirs(os.path.join(save_dir, 'models'))
54 |
55 | # Network definition
56 | modelName = 'IOG_pascal'
57 | net = Network(nInputChannels=nInputChannels,num_classes=1,
58 | backbone='resnet101',
59 | output_stride=16,
60 | sync_bn=None,
61 | freeze_bn=False,
62 | pretrained=True)
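# pretrained=True presumably initializes the ResNet-101 backbone with ImageNet-pretrained
# weights (see networks/backbone/resnet.py).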
63 | if resume_epoch == 0:
64 | print("Initializing from pretrained model")
65 | else:
66 | print("Initializing weights from: {}".format(
67 | os.path.join(save_dir, 'models', modelName + '_epoch-' + str(resume_epoch - 1) + '.pth')))
68 | net.load_state_dict(
69 | torch.load(os.path.join(save_dir, 'models', modelName + '_epoch-' + str(resume_epoch - 1) + '.pth'),
70 | map_location=lambda storage, loc: storage))
71 | train_params = [{'params': net.get_1x_lr_params(), 'lr': p['lr']},
72 | {'params': net.get_10x_lr_params(), 'lr': p['lr'] * 10}]
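# Two learning-rate groups: get_1x_lr_params (presumably the backbone) uses the base lr,
# while get_10x_lr_params (presumably the decoder/heads) uses a 10x larger lr.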
73 | net.to(device)
74 |
75 | if resume_epoch != nEpochs:
76 | # Logging into Tensorboard
77 | log_dir = os.path.join(save_dir, 'models', datetime.now().strftime('%b%d_%H-%M-%S') + '_' + socket.gethostname())
78 |
79 | # Use the following optimizer
80 | optimizer = optim.SGD(train_params, lr=p['lr'], momentum=p['momentum'], weight_decay=p['wd'])
81 | p['optimizer'] = str(optimizer)
82 |
83 | # Preparation of the data loaders
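# The training transforms crop around the (augmented) ground-truth mask with a 30-pixel
# relaxation, resize the crop to 512x512, render the simulated inside/outside clicks as
# heatmaps (sigma=10), and concatenate them with the RGB crop into the 5-channel 'concat'
# input expected by the network.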
84 | composed_transforms_tr = transforms.Compose([
85 | tr.RandomHorizontalFlip(),
86 | tr.ScaleNRotate(rots=(-20, 20), scales=(.75, 1.25)),
87 | tr.CropFromMask(crop_elems=('image', 'gt','void_pixels'), relax=30, zero_pad=True),
88 | tr.FixedResize(resolutions={'crop_image': (512, 512), 'crop_gt': (512, 512), 'crop_void_pixels': (512, 512)},flagvals={'crop_image':cv2.INTER_LINEAR,'crop_gt':cv2.INTER_LINEAR,'crop_void_pixels': cv2.INTER_LINEAR}),
89 | tr.IOGPoints(sigma=10, elem='crop_gt',pad_pixel=10),
90 | tr.ToImage(norm_elem='IOG_points'),
91 | tr.ConcatInputs(elems=('crop_image', 'IOG_points')),
92 | tr.ToTensor()])
93 |
94 | composed_transforms_ts = transforms.Compose([
95 | tr.CropFromMask(crop_elems=('image', 'gt','void_pixels'), relax=30, zero_pad=True),
96 | tr.FixedResize(resolutions={'crop_image': (512, 512), 'crop_gt': (512, 512), 'crop_void_pixels': (512, 512)},flagvals={'crop_image':cv2.INTER_LINEAR,'crop_gt':cv2.INTER_LINEAR,'crop_void_pixels': cv2.INTER_LINEAR}),
97 | tr.IOGPoints(sigma=10, elem='crop_gt',pad_pixel=10),
98 | tr.ToImage(norm_elem='IOG_points'),
99 | tr.ConcatInputs(elems=('crop_image', 'IOG_points')),
100 | tr.ToTensor()])
101 |
102 | voc_train = pascal.VOCSegmentation(split='train', transform=composed_transforms_tr)
103 | voc_val = pascal.VOCSegmentation(split='val', transform=composed_transforms_ts)
104 | if use_sbd:
105 | sbd_train = sbd.SBDSegmentation(split=['train', 'val'], transform=composed_transforms_tr, retname=True)
106 | db_train = combine_dbs([voc_train, sbd_train], excluded=[voc_val])
107 | else:
108 | db_train = voc_train
109 |
110 | p['dataset_train'] = str(db_train)
111 | p['transformations_train'] = [str(tran) for tran in composed_transforms_tr.transforms]
112 | trainloader = DataLoader(db_train, batch_size=p['trainBatch'], shuffle=True, num_workers=2)
113 |
114 | # Train variables
115 | num_img_tr = len(trainloader)
116 | running_loss_tr = 0.0
117 | aveGrad = 0
118 | print("Training Network")
119 | for epoch in range(resume_epoch, nEpochs):
120 | start_time = timeit.default_timer()
121 | epoch_loss = []
122 | net.train()
123 | for ii, sample_batched in enumerate(trainloader):
124 | gts = sample_batched['crop_gt']
125 | inputs = sample_batched['concat']
126 | void_pixels = sample_batched['crop_void_pixels']
127 | inputs.requires_grad_()
128 | inputs, gts ,void_pixels = inputs.to(device), gts.to(device), void_pixels.to(device)
129 | coarse_outs1,coarse_outs2,coarse_outs3,coarse_outs4,fine_out = net.forward(inputs)
130 |
131 | # Compute the losses
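# Deep supervision: each coarse side output and the fine output is penalized with the same
# class_cross_entropy_loss against the cropped ground truth, and the total loss is their sum.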
132 | loss_coarse_outs1 = class_cross_entropy_loss(coarse_outs1, gts, void_pixels=void_pixels)
133 | loss_coarse_outs2 = class_cross_entropy_loss(coarse_outs2, gts, void_pixels=void_pixels)
134 | loss_coarse_outs3 = class_cross_entropy_loss(coarse_outs3, gts, void_pixels=void_pixels)
135 | loss_coarse_outs4 = class_cross_entropy_loss(coarse_outs4, gts, void_pixels=void_pixels)
136 | loss_fine_out = class_cross_entropy_loss(fine_out, gts, void_pixels=void_pixels)
137 | loss = loss_coarse_outs1+loss_coarse_outs2+ loss_coarse_outs3+loss_coarse_outs4+loss_fine_out
138 |
139 | if ii % 10 ==0:
140 | print('Epoch',epoch,'step',ii,'loss',loss)
141 | running_loss_tr += loss.item()
142 |
143 | # Print stuff
144 | if ii % num_img_tr == num_img_tr - 1 -p['trainBatch']:
145 | running_loss_tr = running_loss_tr / num_img_tr
146 | print('[Epoch: %d, numImages: %5d]' % (epoch, ii*p['trainBatch']+inputs.data.shape[0]))
147 | print('Loss: %f' % running_loss_tr)
148 | running_loss_tr = 0
149 | stop_time = timeit.default_timer()
150 | print("Execution time: " + str(stop_time - start_time)+"\n")
151 |
152 | # Backward the averaged gradient
153 | loss /= p['nAveGrad']
154 | loss.backward()
155 | aveGrad += 1
156 |
157 | # Update the weights once in p['nAveGrad'] forward passes
158 | if aveGrad % p['nAveGrad'] == 0:
159 | optimizer.step()
160 | optimizer.zero_grad()
161 | aveGrad = 0
162 |
163 | # Save the model
164 | if (epoch % snapshot) == snapshot - 1 and epoch != 0:
165 | torch.save(net.state_dict(), os.path.join(save_dir, 'models', modelName + '_epoch-' + str(epoch) + '.pth'))
166 |
--------------------------------------------------------------------------------
/train_refine.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime
2 | import scipy.misc as sm
3 | from collections import OrderedDict
4 | import glob
5 | import numpy as np
6 | import socket
7 | import timeit
8 |
9 | # PyTorch includes
10 | import torch
11 | import torch.optim as optim
12 | from torchvision import transforms
13 | from torch.utils.data import DataLoader
14 |
15 | # Custom includes
16 | from dataloaders.combine_dbs import CombineDBs as combine_dbs
17 | import dataloaders.pascal as pascal
18 | import dataloaders.sbd as sbd
19 | from dataloaders import custom_transforms as tr
20 | from dataloaders.helpers import *
21 | from networks.loss import class_cross_entropy_loss
22 | from networks.refinementnetwork import *
23 | from torch.nn.functional import upsample
24 | # Set gpu_id to -1 to run in CPU mode, otherwise set the id of the corresponding gpu
25 | gpu_id = 0
26 | device = torch.device("cuda:"+str(gpu_id) if torch.cuda.is_available() else "cpu")
27 | if torch.cuda.is_available():
28 | print('Using GPU: {} '.format(gpu_id))
29 |
30 | # Setting parameters
31 | use_sbd = False # train with SBD
32 | nEpochs = 100 # Number of epochs for training
33 | resume_epoch = 0 # Default is 0, change if want to resume
34 | p = OrderedDict() # Parameters to include in report
35 | p['trainBatch'] = 2 # Training batch size
36 | snapshot = 10 # Store a model every snapshot epochs
37 | nInputChannels = 5 # Number of input channels (RGB + heatmap of IOG points)
38 | p['nAveGrad'] = 1 # Average the gradient of several iterations
39 | p['lr'] = 1e-8 # Learning rate
40 | p['wd'] = 0.0005 # Weight decay
41 | p['momentum'] = 0.9 # Momentum
42 | threshold = 0.95 # IoU threshold: use only the first-pass loss when its IoU already exceeds this value
43 | refinement_num_max = 1 # maximum number of additional refinement clicks
44 | # Results and model directories (a new directory is generated for every run)
45 | save_dir_root = os.path.join(os.path.dirname(os.path.abspath(__file__)))
46 | exp_name = os.path.dirname(os.path.abspath(__file__)).split('/')[-1]
47 | if resume_epoch == 0:
48 | runs = sorted(glob.glob(os.path.join(save_dir_root, 'run_*')))
49 | run_id = int(runs[-1].split('_')[-1]) + 1 if runs else 0
50 | else:
51 | run_id = 0
52 | save_dir = os.path.join(save_dir_root, 'run_' + str(run_id))
53 | if not os.path.exists(os.path.join(save_dir, 'models')):
54 | os.makedirs(os.path.join(save_dir, 'models'))
55 |
56 | # Network definition
57 | modelName = 'IOG_pascal_refinement'
58 | net = Network(nInputChannels=nInputChannels,num_classes=1,
59 | backbone='resnet101',
60 | output_stride=16,
61 | sync_bn=None,
62 | freeze_bn=False,
63 | pretrained=True)
64 | if resume_epoch == 0:
65 | print("Initializing from pretrained model")
66 | else:
67 | print("Initializing weights from: {}".format(
68 | os.path.join(save_dir, 'models', modelName + '_epoch-' + str(resume_epoch - 1) + '.pth')))
69 | net.load_state_dict(
70 | torch.load(os.path.join(save_dir, 'models', modelName + '_epoch-' + str(resume_epoch - 1) + '.pth'),
71 | map_location=lambda storage, loc: storage))
72 | train_params = [{'params': net.get_1x_lr_params(), 'lr': p['lr']},
73 | {'params': net.get_10x_lr_params(), 'lr': p['lr'] * 10}]
74 | net.to(device)
75 |
76 | if resume_epoch != nEpochs:
77 | # Logging into Tensorboard
78 | log_dir = os.path.join(save_dir, 'models', datetime.now().strftime('%b%d_%H-%M-%S') + '_' + socket.gethostname())
79 |
80 | # Use the following optimizer
81 | optimizer = optim.SGD(train_params, lr=p['lr'], momentum=p['momentum'], weight_decay=p['wd'])
82 | p['optimizer'] = str(optimizer)
83 |
84 | # Preparation of the data loaders
85 | composed_transforms_tr = transforms.Compose([
86 | tr.RandomHorizontalFlip(),
87 | tr.ScaleNRotate(rots=(-20, 20), scales=(.75, 1.25)),
88 | tr.CropFromMask(crop_elems=('image', 'gt','void_pixels'), relax=30, zero_pad=True),
89 | tr.FixedResize(resolutions={'crop_image': (512, 512), 'crop_gt': (512, 512), 'crop_void_pixels': (512, 512)},flagvals={'crop_image':cv2.INTER_LINEAR,'crop_gt':cv2.INTER_LINEAR,'crop_void_pixels': cv2.INTER_LINEAR}),
90 | tr.IOGPoints(sigma=10, elem='crop_gt',pad_pixel=10),
91 | tr.ToImage(norm_elem='IOG_points'),
92 | tr.ConcatInputs(elems=('crop_image', 'IOG_points')),
93 | tr.ToTensor()])
94 |
95 | composed_transforms_ts = transforms.Compose([
96 | tr.CropFromMask(crop_elems=('image', 'gt','void_pixels'), relax=30, zero_pad=True),
97 | tr.FixedResize(resolutions={'crop_image': (512, 512), 'crop_gt': (512, 512), 'crop_void_pixels': (512, 512)},flagvals={'crop_image':cv2.INTER_LINEAR,'crop_gt':cv2.INTER_LINEAR,'crop_void_pixels': cv2.INTER_LINEAR}),
98 | tr.IOGPoints(sigma=10, elem='crop_gt',pad_pixel=10),
99 | tr.ToImage(norm_elem='IOG_points'),
100 | tr.ConcatInputs(elems=('crop_image', 'IOG_points')),
101 | tr.ToTensor()])
102 |
103 | voc_train = pascal.VOCSegmentation(split='train', transform=composed_transforms_tr)
104 | voc_val = pascal.VOCSegmentation(split='val', transform=composed_transforms_ts)
105 | if use_sbd:
106 | sbd_train = sbd.SBDSegmentation(split=['train', 'val'], transform=composed_transforms_tr, retname=True)
107 | db_train = combine_dbs([voc_train, sbd_train], excluded=[voc_val])
108 | else:
109 | db_train = voc_train
110 |
111 | p['dataset_train'] = str(db_train)
112 | p['transformations_train'] = [str(tran) for tran in composed_transforms_tr.transforms]
113 | trainloader = DataLoader(db_train, batch_size=p['trainBatch'], shuffle=True, num_workers=2)
114 |
115 | # Train variables
116 | num_img_tr = len(trainloader)
117 | running_loss_tr = 0.0
118 | aveGrad = 0
119 | print("Training Network")
120 | for epoch in range(resume_epoch, nEpochs):
121 | start_time = timeit.default_timer()
122 | epoch_loss = []
123 | net.train()
124 | for ii, sample_batched in enumerate(trainloader):
125 | gts = sample_batched['crop_gt']
126 | inputs = sample_batched['concat']
127 | void_pixels = sample_batched['crop_void_pixels']
128 | IOG_points = sample_batched['IOG_points']
129 | inputs.requires_grad_()
130 | inputs, gts ,void_pixels,IOG_points = inputs.to(device), gts.to(device), void_pixels.to(device), IOG_points.to(device)
131 | out = net.forward(inputs,IOG_points,gts,refinement_num_max+1)
132 | for i in range(0,refinement_num_max+1):
133 | glo1,glo2,glo3,glo4,refine,iou_i=out[i]
134 | output_glo1 = upsample(glo1, size=(512, 512), mode='bilinear', align_corners=True)
135 | output_glo2 = upsample(glo2, size=(512, 512), mode='bilinear', align_corners=True)
136 | output_glo3 = upsample(glo3, size=(512, 512), mode='bilinear', align_corners=True)
137 | output_glo4 = upsample(glo4, size=(512, 512), mode='bilinear', align_corners=True)
138 | output_refine = upsample(refine, size=(512, 512), mode='bilinear', align_corners=True)
139 |
140 | # Compute the losses, side outputs and fuse
141 | loss_output_glo1 = class_cross_entropy_loss(output_glo1, gts, void_pixels=void_pixels,size_average=False, batch_average=True)
142 | loss_output_glo2 = class_cross_entropy_loss(output_glo2, gts, void_pixels=void_pixels,size_average=False, batch_average=True)
143 | loss_output_glo3 = class_cross_entropy_loss(output_glo3, gts, void_pixels=void_pixels,size_average=False, batch_average=True)
144 |
145 | loss_output_glo4 = class_cross_entropy_loss(output_glo4, gts, void_pixels=void_pixels,size_average=False, batch_average=True)
146 | loss_output_refine = class_cross_entropy_loss(output_refine, gts, void_pixels=void_pixels,size_average=False, batch_average=True)
147 |
148 | if i == 0:
149 | loss1 = loss_output_glo1 + loss_output_glo2 + loss_output_glo3 + loss_output_glo4 + loss_output_refine
150 | iou1 = iou_i
151 | elif i == 1:
152 | loss2 = loss_output_glo1 + loss_output_glo2 + loss_output_glo3 + loss_output_glo4 + loss_output_refine
153 | iou2 = iou_i
154 |
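# If the first-pass IoU already exceeds the threshold, use only the first-pass loss;
# otherwise average the first-pass and refinement-pass losses.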
155 | if iou1>=threshold:
156 | loss=loss1
157 | else:
158 | loss=0.5*loss1+0.5*loss2
159 |
160 | if ii % 10 ==0:
161 | print('Epoch',epoch,'step',ii,'loss',loss)
162 | running_loss_tr += loss.item()
163 |
164 | # Print stuff
165 | if ii % num_img_tr == num_img_tr - 1 -p['trainBatch']:
166 | running_loss_tr = running_loss_tr / num_img_tr
167 | print('[Epoch: %d, numImages: %5d]' % (epoch, ii*p['trainBatch']+inputs.data.shape[0]))
168 | print('Loss: %f' % running_loss_tr)
169 | running_loss_tr = 0
170 | stop_time = timeit.default_timer()
171 | print("Execution time: " + str(stop_time - start_time)+"\n")
172 |
173 | # Backward the averaged gradient
174 | loss /= p['nAveGrad']
175 | loss.backward()
176 | aveGrad += 1
177 |
178 | # Update the weights once in p['nAveGrad'] forward passes
179 | if aveGrad % p['nAveGrad'] == 0:
180 | optimizer.step()
181 | optimizer.zero_grad()
182 | aveGrad = 0
183 |
184 | # Save the model
185 | if (epoch % snapshot) == snapshot - 1 and epoch != 0:
186 | torch.save(net.state_dict(), os.path.join(save_dir, 'models', modelName + '_epoch-' + str(epoch) + '.pth'))
187 |
--------------------------------------------------------------------------------