├── .idea
│   ├── misc.xml
│   ├── modules.xml
│   ├── segmentation.iml
│   ├── vcs.xml
│   └── workspace.xml
├── README.md
├── data.py
├── data_stats.py
├── dataset
│   ├── data_reader.py
│   └── seg_data.py
├── layers
│   ├── layers_fcn_gcn.py
│   └── layers_unet.py
├── main.py
├── misc
│   ├── carvana_test.png
│   ├── carvana_test_overlay.png
│   ├── fcn_gcn.png
│   └── unet.png
├── models
│   ├── base_model.py
│   ├── fcn_gcn_net.py
│   ├── losses.py
│   └── u_net.py
└── utils
    ├── dataset_util.py
    ├── depricated.py
    └── tfrecord_util.py
/README.md:
--------------------------------------------------------------------------------
# Segmentation

This repository contains TensorFlow implementations of two models for semantic segmentation that are known to give high accuracy:

* **U-Net** (https://arxiv.org/abs/1505.04597): The network can be found in [u_net.py](./models/u_net.py). Here is the architecture, borrowed from the paper:
![U-Net architecture](./misc/unet.png)
There are a few minor differences in my implementation. I have used 'same' padding to simplify things. For the upsampling, I have simply resized the feature maps with TensorFlow's image-resize ops (see [layers_unet.py](./layers/layers_unet.py)). The full transpose-convolution (deconvolution) layer is implemented for the FCN described next.

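For reference, the whole upsampling layer, as implemented in [layers_unet.py](./layers/layers_unet.py), is just a nearest-neighbor resize:

```python
def upsample(input_, name, upscale_factor=(2, 2)):
    # resize feature maps to (H * f_h, W * f_w)
    H, W, _ = input_.get_shape().as_list()[1:]
    target_H = H * upscale_factor[0]
    target_W = W * upscale_factor[1]
    return tf.image.resize_nearest_neighbor(
        input_, (target_H, target_W), name="upsample_{}".format(name))
```
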
* **FCN** with **global convolution** (https://arxiv.org/abs/1703.02719): The network can be found in [fcn_gcn_net.py](./models/fcn_gcn_net.py). Here is the architecture, borrowed from the paper:
![FCN-GCN architecture](./misc/fcn_gcn.png)
Again, there are a few minor differences in my implementation. In particular, I have used a VGG-style encoder instead of ResNet blocks. All the layers/blocks used in the architecture (including the deconvolution layer) can be found in [layers_fcn_gcn.py](./layers/layers_fcn_gcn.py).

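The key building block is the global convolution module, which approximates a large k x k convolution with two cheap separable branches, (k x 1 -> 1 x k) and (1 x k -> k x 1), and sums them. This is the heart of `global_conv_module` in [layers_fcn_gcn.py](./layers/layers_fcn_gcn.py):

```python
branch_a = tf.layers.conv2d(net, num_classes, (k, 1), activation=None,
                            padding='same', name='conv_1a')
branch_a = tf.layers.conv2d(branch_a, num_classes, (1, k), activation=None,
                            padding='same', name='conv_2a')

branch_b = tf.layers.conv2d(net, num_classes, (1, k), activation=None,
                            padding='same', name='conv_1b')
branch_b = tf.layers.conv2d(branch_b, num_classes, (k, 1), activation=None,
                            padding='same', name='conv_2b')

net = tf.add(branch_a, branch_b, name='sum')
```
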
### Kaggle - Carvana image masking challenge
I applied these models to one of the Kaggle competitions, in which the background behind the object (in this case: cars) had to be removed. More details can be found here: [Kaggle: Carvana image masking challenge](https://www.kaggle.com/c/carvana-image-masking-challenge). Due to lack of time and resources, I ended up making only a *single submission* and got a score of **99.2%** (the winning solution scored 99.7%). For this particular challenge, since there is only one class, U-Net is the better model choice. Here is a sample result of U-Net applied to a test image:
![Carvana test image](./misc/carvana_test.png)
![U-Net prediction overlaid on the test image](./misc/carvana_test_overlay.png)


**Scope for improvement**:
There are several strategies that could have improved the score, but which I did not use due to lack of time:

* Image preprocessing and data augmentation
* Using higher-resolution images: I was using AWS, which wasn't fast enough to handle high-resolution images, so I had to scale the images down considerably, which lowers accuracy.
* Tiling sub-regions of the high-resolution image: this guarantees that each tile fits in GPU memory, but is obviously more time-consuming.
* Applying conditional random field (CRF) post-processing; a minimal sketch follows below.

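As an illustration of the last point, here is a rough, untested sketch of dense-CRF refinement using the third-party `pydensecrf` package (not part of this repository); `probs` is assumed to be the network's (2, H, W) softmax output and `img` the original RGB image:

```python
import numpy as np
import pydensecrf.densecrf as dcrf
from pydensecrf.utils import unary_from_softmax

def crf_refine(img, probs, n_iters=5):
    """Refine a 2-class softmax output with a fully-connected CRF."""
    h, w = img.shape[:2]
    d = dcrf.DenseCRF2D(w, h, 2)                 # width, height, n_classes
    d.setUnaryEnergy(unary_from_softmax(probs))  # -log(prob) unary potentials
    d.addPairwiseGaussian(sxy=3, compat=3)       # location-only smoothness kernel
    d.addPairwiseBilateral(sxy=80, srgb=13,      # color-aware appearance kernel
                           rgbim=np.ascontiguousarray(img), compat=10)
    q = np.array(d.inference(n_iters))           # (2, H*W) refined marginals
    return q.argmax(axis=0).reshape(h, w)        # refined binary mask
```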
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
import numpy as np
import os
from data_stats import prepare_data_stats
from tqdm import tqdm
import random
import cv2
import matplotlib.pyplot as plt
from PIL import Image
import math


class DataSet(object):
    def __init__(self, args, cfg, img_files, mask_files=None):
        self.params = args
        self.is_train = (args.phase == 'train')
        self.batch_size = args.batch_size
        self.cfg = cfg

        self.img_files = img_files
        self.mask_files = mask_files

        self.count = len(img_files)
        self.indices = list(range(self.count))
        self.current_index = 0

    def reset(self):
        """ Reset the dataset. """
        self.current_index = 0

        if self.is_train:
            np.random.shuffle(self.indices)

    def next_batch(self):
        """ Fetch the next batch. """
        start, end = self.current_index, self.current_index + self.batch_size
        current_indices = self.indices[start:end]
        img_files = np.array(self.img_files)[current_indices]

        if self.params.phase == 'test':
            self.current_index += self.batch_size
            return img_files, None
        else:
            mask_files = np.array(self.mask_files)[current_indices]
            self.current_index += self.batch_size
            return img_files, mask_files


def prepare_train_data(args, cfg):
    """ Prepare data for training the model. """
    print("Preparing data for training...")
    image_dir, mask_dir, data_dir, set_ = (args.train_image_dir, args.train_mask_dir,
                                           args.train_data_dir, args.set)

    train_data_dir = os.path.join(args.train_data_dir, str(set_).zfill(2))
    if not os.path.exists(train_data_dir):
        os.makedirs(train_data_dir)
    data_stats_file = os.path.join(train_data_dir, 'data_stats.npz')

    if not os.path.exists(data_stats_file):
        prepare_data_stats(args)

    img_files, mask_files = prepare_data(set_, image_dir, mask_dir)

    dataset = DataSet(args, cfg, img_files, mask_files)
    return dataset


def prepare_test_data(args, cfg):
    """ Prepare data for testing the model. """
    print("Preparing data for testing...")
    image_dir, set_ = (args.test_image_dir, args.set)
    basedir = os.path.join(image_dir, str(set_).zfill(2))
    img_files = os.listdir(basedir)
    img_files = [os.path.join(basedir, f) for f in img_files]
    dataset = DataSet(args, cfg, img_files)
    return dataset


def prepare_data(set_, image_dir, mask_dir):
    img_files = os.listdir(os.path.join(image_dir, str(set_).zfill(2)))
    mask_files = []
    img_files_abs = []

    print("Building data...")
    for f in tqdm(img_files):
        tag = f.split('.jpg')[0]
        mask_file = os.path.join(mask_dir, str(set_).zfill(2), tag + '_mask')
        # augmented masks are written as .png; original masks ship as .gif
        if "augment" in f:
            mask_file += ".png"
        else:
            mask_file += ".gif"
        mask_files.append(mask_file)
        img_files_abs.append(os.path.join(image_dir, str(set_).zfill(2), f))

    print("Dataset built.")
    return img_files_abs, mask_files


def augment(img, img_mask, data_stats_file, flip=False):
    data_stats = np.load(data_stats_file)
    left_min, left_max = data_stats['left_range']
    right_min, right_max = data_stats['right_range']
    top_min, top_max = data_stats['top_range']
    bottom_min, bottom_max = data_stats['bottom_range']
    height_min, height_max = data_stats['height_range']
    width_min, width_max = data_stats['width_range']

    # sample a new top-left corner and scale so the object stays within the
    # bounding-box ranges observed in the training data
    l = random.randint(left_min, left_max)
    t = random.randint(top_min, top_max)
    max_h = min(height_max, bottom_max - t)
    max_w = min(width_max, right_max - l)
    min_h = max(height_min, bottom_min - t)
    min_w = max(width_min, right_min - l)
    t0, l0 = np.min(np.nonzero(img_mask), axis=1)
    b0, r0 = np.max(np.nonzero(img_mask), axis=1)
    h0, w0 = (b0 - t0), (r0 - l0)
    rw_min = min_w / w0
    rw_max = max_w / w0
    rh_min = min_h / h0
    rh_max = max_h / h0
    r_min = max(rw_min, rh_min)
    r_max = min(rw_max, rh_max)
    ratio = random.uniform(r_min, r_max)
    r = l + w0 * ratio
    b = t + h0 * ratio
    pts1 = np.float32([[l0, t0], [r0, t0], [r0, b0]])
    pts2 = np.float32([[l, t], [r, t], [r, b]])
    mat = cv2.getAffineTransform(pts1, pts2)
    # area interpolation when shrinking, cubic when enlarging
    if ratio < 1:
        interpolation = cv2.INTER_AREA
    else:
        interpolation = cv2.INTER_CUBIC
    rows, cols, _ = img.shape
    new_img = cv2.warpAffine(img, mat, (cols, rows), flags=interpolation)
    new_img_mask = cv2.warpAffine(img_mask, mat, (cols, rows), flags=interpolation)

    # small random rotation
    rot_angle = random.uniform(-1, 1)
    mat = cv2.getRotationMatrix2D((cols / 2, rows / 2), rot_angle, 1)
    new_img = cv2.warpAffine(new_img, mat, (cols, rows))
    new_img_mask = cv2.warpAffine(new_img_mask, mat, (cols, rows))
    if flip:
        if random.randint(0, 1):
            new_img = cv2.flip(new_img, 1)
            new_img_mask = cv2.flip(new_img_mask, 1)
    # random hue / saturation / value jitter
    hsv = cv2.cvtColor(new_img, cv2.COLOR_RGB2HSV)
    hsv = np.float32(hsv)
    hue_shift = random.randint(-50, 50)
    hsv[:, :, 0][new_img_mask == 1] += hue_shift
    # wrap hue back into OpenCV's [0, 180) range
    hsv[:, :, 0][hsv[:, :, 0] < 0] += 180
    hsv[:, :, 0][hsv[:, :, 0] >= 180] -= 180
    val_scale = random.uniform(0.75, 1.25)
    hsv[:, :, 2] *= val_scale
    hsv[:, :, 2][hsv[:, :, 2] > 255] = 255
    sat_scale = random.uniform(0.75, 1.25)
    hsv[:, :, 1] *= sat_scale
    hsv[:, :, 1][hsv[:, :, 1] > 255] = 255
    hsv = np.uint8(hsv)
    new_img = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB)
    return new_img, new_img_mask


def augment_data(args):
    set_ = args.set
    image_dir = args.train_image_dir
    mask_dir = args.train_mask_dir
    data_stats_file = os.path.join(args.train_data_dir,
                                   str(set_).zfill(2), 'data_stats.npz')
    num_aug = args.augment_factor
    img_files = os.listdir(os.path.join(image_dir, str(set_).zfill(2)))
    print("Removing old augmentations...")
    for f in img_files:
        if "augment" in f:
            os.remove(os.path.join(image_dir, str(set_).zfill(2), f))

    flip_data = True
    num_aug_flipped = 0
    print("Creating augmented dataset...")

    # sets other than 1 and 9 have a mirror-image set (18 - set_), so half of
    # the augmentations are generated from flipped images of that set
    if (set_ != 1) and (set_ != 9):
        num_aug_flipped = int(math.ceil(num_aug / 2))
        flipped_set = 18 - set_
        flip_data = False
        flipped_img_files = os.listdir(os.path.join(image_dir, str(flipped_set).zfill(2)))
        for f in tqdm(flipped_img_files):
            if "augment" in f:
                continue  # only original images have the .gif masks used below
            tag = f.split('.jpg')[0]
            s = flipped_set
            img_file = os.path.join(image_dir, str(s).zfill(2), tag + '.jpg')
            mask_file = os.path.join(mask_dir, str(s).zfill(2), tag + '_mask.gif')
            img = plt.imread(img_file)
            img_mask = plt.imread(mask_file)[:, :, 0] // 255
            img = cv2.flip(img, 1)
            img_mask = cv2.flip(img_mask, 1)

            for n_aug in range(num_aug_flipped):
                new_img, new_img_mask = augment(img, img_mask, data_stats_file, flip_data)
                new_img_file = os.path.join(image_dir, str(set_).zfill(2), tag +
                                            '_augment' + str(n_aug).zfill(2) + '.jpg')
                new_mask_file = os.path.join(mask_dir, str(set_).zfill(2), tag +
                                             '_augment' + str(n_aug).zfill(2) + '_mask.png')
                new_img = Image.fromarray(new_img)
                new_img.save(new_img_file)
                cv2.imwrite(new_mask_file, 255 * new_img_mask.astype(np.uint8))

    img_files = os.listdir(os.path.join(image_dir, str(set_).zfill(2)))
    for f in tqdm(img_files):
        if "augment" in f:
            continue  # skip augmentations generated above
        tag = f.split('.jpg')[0]
        s = set_
        img_file = os.path.join(image_dir, str(s).zfill(2), tag + '.jpg')
        mask_file = os.path.join(mask_dir, str(s).zfill(2), tag + '_mask.gif')
        img = plt.imread(img_file)
        img_mask = plt.imread(mask_file)[:, :, 0] // 255

        for n_aug in range(num_aug_flipped, num_aug):
            new_img, new_img_mask = augment(img, img_mask, data_stats_file, flip_data)
            new_img_file = os.path.join(image_dir, str(set_).zfill(2), tag +
                                        '_augment' + str(n_aug).zfill(2) + '.jpg')
            new_mask_file = os.path.join(mask_dir, str(set_).zfill(2), tag +
                                         '_augment' + str(n_aug).zfill(2) + '_mask.png')
            new_img = Image.fromarray(new_img)
            new_img.save(new_img_file)
            cv2.imwrite(new_mask_file, 255 * new_img_mask.astype(np.uint8))
--------------------------------------------------------------------------------
/data_stats.py:
--------------------------------------------------------------------------------
import os
import numpy as np
from tqdm import tqdm
from utils import get_bbox

cfg = {
    'image_shape': [1280, 1918],
    'scaled_h': 640,
    'roi_h': 640,
    'buffer': 50
}


# def extract_color_stats(img_dir, set_):
#     img_mean, img_std = [], []

def extract_bbox(mask_dir, set_):
    gt_bboxes = []
    basedir = os.path.join(mask_dir, str(set_).zfill(2))
    mask_files = [os.path.join(basedir, f) for f in os.listdir(basedir) if 'augment' not in f]
    for mask_file in tqdm(mask_files):
        gt_bboxes.append(get_bbox(mask_file))
    gt_bboxes = np.array(gt_bboxes)
    return gt_bboxes


def prepare_data_stats(args):
    """Prepare data statistics and save them in a config file"""
    set_ = args.set
    mask_dir = args.train_mask_dir
    data_stats_file = os.path.join(args.train_data_dir,
                                   str(set_).zfill(2), 'data_stats.npz')
    print("Extracting ground truth bboxes for set {}...".format(set_))
    gt_bboxes = extract_bbox(mask_dir, set_)
    if set_ != 1 and set_ != 9:
        # sets other than 1 and 9 have a mirror-image set whose boxes can be
        # flipped horizontally and pooled with this set's statistics
        flipped_set = 18 - set_
        print("Extracting ground truth bboxes for set {}...".format(flipped_set))
        gt_bboxes_flipped = extract_bbox(mask_dir, flipped_set)
        w = cfg['image_shape'][1]
        gt_bboxes_flipped[:, 1] = w - gt_bboxes_flipped[:, 1] - gt_bboxes_flipped[:, 3]
        gt_bboxes = np.concatenate((gt_bboxes, gt_bboxes_flipped), axis=0)

    # bbox rows are (top, left, height, width)
    left_min, left_max = np.min(gt_bboxes[:, 1]), np.max(gt_bboxes[:, 1])
    top_min, top_max = np.min(gt_bboxes[:, 0]), np.max(gt_bboxes[:, 0])
    right_min = np.min(gt_bboxes[:, 1] + gt_bboxes[:, 3])
    right_max = np.max(gt_bboxes[:, 1] + gt_bboxes[:, 3])
    bottom_min = np.min(gt_bboxes[:, 0] + gt_bboxes[:, 2])
    bottom_max = np.max(gt_bboxes[:, 0] + gt_bboxes[:, 2])
    height_min, height_max = np.min(gt_bboxes[:, 2]), np.max(gt_bboxes[:, 2])
    width_min, width_max = np.min(gt_bboxes[:, 3]), np.max(gt_bboxes[:, 3])
    aspect = gt_bboxes[:, 3] / gt_bboxes[:, 2]
    aspect_min, aspect_max = np.min(aspect), np.max(aspect)
    aspect_median = np.median(aspect)
    top_std, left_std = np.std(gt_bboxes[:, :2], axis=0)
    bottom_std, right_std = np.std(gt_bboxes[:, :2] + gt_bboxes[:, 2:], axis=0)
    height_mean, width_mean = np.mean(gt_bboxes[:, 2:], axis=0)
    print("Saving data statistics to {}".format(data_stats_file))
    basedir = os.path.dirname(data_stats_file)
    if not os.path.exists(basedir):
        os.makedirs(basedir)
    np.savez(data_stats_file,
             left_range=np.array([left_min, left_max]),
             right_range=np.array([right_min, right_max]),
             top_range=np.array([top_min, top_max]),
             bottom_range=np.array([bottom_min, bottom_max]),
             height_range=np.array([height_min, height_max]),
             width_range=np.array([width_min, width_max]),
             aspect_stats=np.array([aspect_min, aspect_max, aspect_median]),
             mean=np.array([height_mean, width_mean]),
             std=np.array([left_std, right_std, top_std, bottom_std]))
    configure(args)


def configure(args):
    """Determines the ideal image preprocessing (cropping and resizing)
    using the data statistics."""
    set_ = args.set
    data_stats_file = os.path.join(args.train_data_dir,
                                   str(set_).zfill(2), 'data_stats.npz')
    config_file = os.path.join(args.train_data_dir,
                               str(set_).zfill(2), 'config.npy')
    stats = np.load(data_stats_file)

    img_height, img_width = cfg['image_shape']
    left_crop = stats['left_range'][0] - 0.5 * stats['std'][0]
    left_crop = max(0., left_crop)
    right_crop = img_width - (stats['right_range'][1] + 0.5 * stats['std'][1])
    right_crop = max(0., right_crop)
    top_crop = stats['top_range'][0] - 0.7 * stats['std'][2]
    top_crop = max(0., top_crop)
    bottom_crop = img_height - (stats['bottom_range'][1] + 0.7 * stats['std'][3])
    bottom_crop = max(0., bottom_crop)

    if set_ == 1 or set_ == 9:
        left_crop = min(left_crop, right_crop)
        right_crop = left_crop

    h_cropped = img_height - top_crop - bottom_crop
    w_cropped = img_width - left_crop - right_crop
    aspect = np.round(w_cropped / h_cropped, decimals=1)
    print("Aspect ratio : ", aspect)
    delta_h = (w_cropped / aspect - h_cropped) / 2
    top_crop = int(top_crop - delta_h)
    bottom_crop = int(bottom_crop - delta_h)
    left_crop = int(left_crop)
    right_crop = int(right_crop)
    crops = [left_crop, right_crop, top_crop, bottom_crop]
    print("Crop margins (l,r,t,b) : ", crops)
    cfg['crops'] = crops

    img_h = cfg['scaled_h']
    img_w = int(aspect * img_h)
    scale = img_w / w_cropped
    print("Scale factor : ", scale)
    print("Final image shape : ", [img_h, img_w])
    cfg['scale'] = scale
    cfg['scaled_img_shape'] = [img_h, img_w]

    roi_aspect = stats['aspect_stats'][2]  # median
    roi_aspect = np.round(roi_aspect, decimals=1)
    roi_h = cfg['roi_h']
    roi_w = int(roi_aspect * roi_h)
    roi_shape = [roi_h, roi_w]
    print("Input roi size for segmentation : ", roi_shape)
    cfg['roi_shape'] = roi_shape
    print("Saving config data to {}".format(config_file))

    np.save(config_file, cfg)


def load_config(train_data_dir, set_):
    cfg_file = os.path.join(train_data_dir,
                            str(set_).zfill(2), 'config.npy')
    # allow_pickle is required to load a dict saved with np.save
    return np.load(cfg_file, allow_pickle=True).item()

# img_h = 320
# img_w = int(aspect * img_h)
# if args.basic_model != 'vgg16':
#     img_h += 1
#     img_w += 1
# scale = img_w / w_cropped
# print("Scale factor : ", scale)
# print("Final image shape : ", [img_h, img_w])
# cfg['crops'] = crops
# cfg['scale'] = scale
# cfg['img_shape'] = [img_h, img_w]
#
# feat_stride = 16
# print("Using feature stride of ", feat_stride)
#
# r_median = np.round(stats['aspect_stats'][2], decimals=1)
# roi_h = 320
# roi_w = int(r_median * roi_h)
# roi_shape = [roi_h, roi_w]
# print("Input roi size for mask-RCN : ", roi_shape)
# cfg['roi_shape'] = roi_shape
#
# if args.basic_model == 'vgg16':
#     feat_shape = [img_h // feat_stride, img_w // feat_stride, 512]
#     roi_feat_shape = [roi_h // feat_stride, roi_w // feat_stride, 512]
# else:
#     feat_shape = [(img_h - 1) // feat_stride + 1, (img_w - 1) // feat_stride + 1, 2048]
#     roi_feat_shape = [(roi_h - 1) // feat_stride + 1, (roi_w - 1) // feat_stride + 1, 2048]
# print("Feature shape from RPN : ", feat_shape)
# print("Feature shape from mask-RCN : ", roi_feat_shape)
# cfg['feat_shape'] = feat_shape
# cfg['roi_feat_shape'] = roi_feat_shape


# class ImageLoader(object):
#     def __init__(self):
#         self.img_scale = cfg['scale']
#         self.crop_margin = cfg['crops']
#
#     def load_img(self, img_file):
#         """ Load and preprocess an image. """
#         img = cv2.imread(img_file)
#
#         if self.bgr:
#             cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
#
#         left, right, top, bottom = self.crop_margin
#         img = img[top:img_height - bottom, left:img_width - right, :]
#
#         h = img_height - top - bottom
#         w = img_width - left - right
#         img = cv2.resize(img, None, fx=self.img_scale, fy=self.img_scale,
#                          interpolation=cv2.INTER_AREA)
#         img = np.float32(img)
#
#         # subtract the VGG mean pixel
#         img[:, :, 0] -= 123.68
#         img[:, :, 1] -= 116.78
#         img[:, :, 2] -= 103.94
#
#         return img
#
#     def load_imgs(self, img_files):
#         """ Load and preprocess a list of images. """
#         imgs = []
#         for img_file in img_files:
#             imgs.append(self.load_img(img_file))
#         imgs = np.array(imgs, np.float32)
#         return imgs
--------------------------------------------------------------------------------
/dataset/data_reader.py:
--------------------------------------------------------------------------------
import os
import glob
import numpy as np
import tensorflow as tf
import functools
# NOTE: random_flip_left_right, random_crop and resize are assumed to be
# provided by utils.dataset_util alongside the other transforms imported here
from utils.dataset_util import (
    rotate, random_hue, random_contrast, random_brightness,
    random_flip_left_right, random_crop, resize)
from dataset.seg_data import SegData


slim_example_decoder = tf.contrib.slim.tfexample_decoder


class SegDataReader(object):
    def __init__(self, data_cfg):
        self.data_cfg = data_cfg
        self.datasets = []

        for dataset in data_cfg.datasets:
            data_dir = dataset['data_dir']
            name = dataset['name']
            weight = dataset['weight']
            # glob the image files; each mask is matched by basename
            img_files = glob.glob(os.path.join(data_dir, 'images', '*'))
            mask_files = [os.path.join(
                data_dir, 'masks', os.path.basename(f).split('.')[0] + '.png')
                for f in img_files]
            tfrecord_file = dataset['tfrecord_file']
            if (tfrecord_file is None) or dataset['overwrite_tfrecord']:
                tfrecord_name = os.path.basename(data_dir) + '.records'
                sub_dir = os.path.dirname(dataset['tfrecord_files'])
                tfrecord_path = os.path.join(data_dir, sub_dir, tfrecord_name)
                tfrecord_dir = os.path.dirname(tfrecord_path)
                if not os.path.exists(tfrecord_dir):
                    os.makedirs(tfrecord_dir)
                ds = self.add_dataset(name, img_files, mask_files)
                ds.create_tf_record(tfrecord_path)
            else:
                # fall back to the pre-built record file
                tfrecord_path = tfrecord_file
            self.datasets.append({'name': name,
                                  'tfrecord_path': tfrecord_path,
                                  'weight': weight})

    def add_dataset(self, name, img_files, mask_files):
        if name == 'png_objects':
            ds = SegData(self.data_cfg, img_files, mask_files)
        else:
            raise RuntimeError('Dataset not supported')
        return ds

    def _get_probs(self):
        # normalized sampling probabilities from the per-dataset weights
        probs = [ds['weight'] for ds in self.datasets]
        probs = np.array(probs)
        return probs / np.sum(probs)

    @staticmethod
    def _get_tensor(tensor):
        if isinstance(tensor, tf.SparseTensor):
            return tf.sparse_tensor_to_dense(tensor)
        return tensor

    @staticmethod
    def _image_decoder(keys_to_tensors):
        filename = keys_to_tensors['image/filename']
        image_string = tf.read_file(filename)
        image_decoded = tf.image.decode_jpeg(image_string)
        return image_decoded

    @staticmethod
    def _mask_decoder(keys_to_tensors):
        filename = keys_to_tensors['mask/filename']
        mask_string = tf.read_file(filename)
        mask_decoded = tf.image.decode_png(mask_string)
        return mask_decoded

    def _decoder(self):
        keys_to_features = {
            'image/filename':
                tf.FixedLenFeature((), tf.string, default_value=''),
            'mask/filename':
                tf.FixedLenFeature((), tf.string, default_value='')
        }
        items_to_handlers = {
            'image': slim_example_decoder.ItemHandlerCallback(
                'image/filename', self._image_decoder),
            'mask': slim_example_decoder.ItemHandlerCallback(
                'mask/filename', self._mask_decoder)
        }
        decoder = slim_example_decoder.TFExampleDecoder(keys_to_features,
                                                        items_to_handlers)
        return decoder

    def augment_data(self, dataset, train_cfg):
        aug_cfg = train_cfg.augmentation
        preprocess_cfg = train_cfg.preprocess
        img_size = preprocess_cfg['image_resize']
        if aug_cfg['flip_left_right']:
            # masks carry no keypoints, so the flip transform is used directly
            dataset = dataset.map(
                random_flip_left_right,
                num_parallel_calls=train_cfg.num_parallel_map_calls
            )
            dataset = dataset.prefetch(train_cfg.prefetch_size)
        if aug_cfg['random_crop']:
            random_crop_fn = functools.partial(
                random_crop,
                crop_size=img_size,
                scale_range=aug_cfg['scale_range']
            )
            dataset = dataset.map(
                random_crop_fn,
                num_parallel_calls=train_cfg.num_parallel_map_calls
            )
            dataset = dataset.prefetch(train_cfg.prefetch_size)
        if aug_cfg['random_brightness']:
            dataset = dataset.map(
                random_brightness,
                num_parallel_calls=train_cfg.num_parallel_map_calls
            )
            dataset = dataset.prefetch(train_cfg.prefetch_size)
        if aug_cfg['random_contrast']:
            dataset = dataset.map(
                random_contrast,
                num_parallel_calls=train_cfg.num_parallel_map_calls
            )
            dataset = dataset.prefetch(train_cfg.prefetch_size)
        return dataset

    def preprocess_data(self, dataset, train_cfg):
        preprocess_cfg = train_cfg.preprocess
        img_size = preprocess_cfg['image_resize']
        resize_fn = functools.partial(
            resize,
            target_image_size=img_size)
        dataset = dataset.map(
            resize_fn,
            num_parallel_calls=train_cfg.num_parallel_map_calls
        )
        dataset = dataset.prefetch(train_cfg.prefetch_size)
        return dataset

    def read_data(self, train_config):
        probs = self._get_probs()
        probs = tf.cast(probs, tf.float32)
        decoder = self._decoder()
        filenames = [ds['tfrecord_path'] for ds in self.datasets]
        file_ids = list(range(len(filenames)))
        dataset = tf.data.Dataset.from_tensor_slices((file_ids, filenames))
        # resample files so each dataset is drawn with its configured weight
        dataset = dataset.apply(tf.contrib.data.rejection_resample(
            class_func=lambda c, _: c,
            target_dist=probs,
            seed=42))
        dataset = dataset.map(lambda _, a: a[1])
        if train_config.shuffle:
            dataset = dataset.shuffle(
                train_config.filenames_shuffle_buffer_size)

        dataset = dataset.repeat(train_config.num_epochs or None)

        file_read_func = functools.partial(tf.data.TFRecordDataset,
                                           buffer_size=8 * 1000 * 1000)
        dataset = dataset.apply(
            tf.contrib.data.parallel_interleave(
                file_read_func, cycle_length=train_config.num_readers,
                block_length=train_config.read_block_length, sloppy=True))
        if train_config.shuffle:
            dataset = dataset.shuffle(train_config.shuffle_buffer_size)

        decode_fn = functools.partial(
            decoder.decode, items=['image', 'mask'])
        dataset = dataset.map(
            decode_fn, num_parallel_calls=train_config.num_parallel_map_calls)
        dataset = dataset.prefetch(train_config.prefetch_size)

        dataset = self.augment_data(dataset, train_config)
        dataset = self.preprocess_data(dataset, train_config)
        return dataset
--------------------------------------------------------------------------------
/dataset/seg_data.py:
--------------------------------------------------------------------------------
import numpy as np
from tqdm import tqdm
from utils import tfrecord_util
import tensorflow as tf


class SegData(object):
    def __init__(self, cfg, img_files, mask_files):
        self.cfg = cfg
        self.img_files = img_files
        self.mask_files = mask_files

    def _create_tf_example(self, img_file, mask_file):
        # the records store only file names; decoding happens at read time
        feature_dict = {
            'image/filename':
                tfrecord_util.bytes_feature(img_file.encode('utf8')),
            'mask/filename':
                tfrecord_util.bytes_feature(mask_file.encode('utf8'))
        }
        return tf.train.Example(features=tf.train.Features(feature=feature_dict))

    def create_tf_record(self, out_path, shuffle=True):
        print("Creating tf records : ", out_path)
        writer = tf.python_io.TFRecordWriter(out_path)
        if shuffle:
            # shuffle images and masks together so the pairing is preserved
            pairs = list(zip(self.img_files, self.mask_files))
            np.random.shuffle(pairs)
            self.img_files, self.mask_files = map(list, zip(*pairs))
        for img_file, mask_file in tqdm(zip(self.img_files, self.mask_files)):
            tf_example = self._create_tf_example(img_file, mask_file)
            writer.write(tf_example.SerializeToString())
        writer.close()
--------------------------------------------------------------------------------
/layers/layers_fcn_gcn.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import numpy as np
import math


def conv_module(input_, n_filters, training, name, pool=True, activation=tf.nn.relu,
                padding='same', batch_norm=True):
    """{Conv -> BN -> RELU} x 2 -> {Pool, optional}
    reference : https://github.com/kkweon/UNet-in-Tensorflow
    Args:
        input_ (4-D Tensor): (batch_size, H, W, C)
        n_filters (int): depth of output tensor
        training (bool): If True, run in training mode
        name (str): name postfix
        pool (bool): If True, MaxPool2D after last conv layer
        activation: Activation function
        padding (str): 'same' or 'valid'
        batch_norm (bool): If True, use batch-norm
    Returns:
        net: output of the convolution operations
        pool (optional): output of the max pooling operation
    """
    kernel_sizes = [3, 3]
    net = input_
    with tf.variable_scope("conv_module_{}".format(name)):
        for i, k_size in enumerate(kernel_sizes):
            net = tf.layers.conv2d(net, n_filters, (k_size, k_size), activation=None,
                                   padding=padding, name="conv_{}".format(i + 1))
            if batch_norm:
                net = tf.layers.batch_normalization(net, training=training, renorm=True,
                                                    name="bn_{}".format(i + 1))
            net = activation(net, name="relu_{}".format(i + 1))

        if pool is False:
            return net

        pool = tf.layers.max_pooling2d(net, (2, 2), strides=(2, 2), name="pool")

    return net, pool


def global_conv_module(input_, num_classes, training, name, k=13, padding='same'):
    """Global convolution network [https://arxiv.org/abs/1703.02719]
    Args:
        input_ (4-D Tensor): (batch_size, H, W, C)
        num_classes (int): number of classes to classify
        training (bool): unused; kept for interface consistency
        name (str): name postfix
        k (int): filter size for the 1 x k and k x 1 convolutions
        padding (str): 'same' or 'valid'
    Returns:
        net (4-D Tensor): (batch_size, H, W, num_classes)
    """
    net = input_
    n_filters = num_classes

    with tf.variable_scope("global_conv_module_{}".format(name)):
        # branch a: (k x 1) followed by (1 x k)
        branch_a = tf.layers.conv2d(net, n_filters, (k, 1), activation=None,
                                    padding=padding, name='conv_1a')
        branch_a = tf.layers.conv2d(branch_a, n_filters, (1, k), activation=None,
                                    padding=padding, name='conv_2a')

        # branch b: (1 x k) followed by (k x 1)
        branch_b = tf.layers.conv2d(net, n_filters, (1, k), activation=None,
                                    padding=padding, name='conv_1b')
        branch_b = tf.layers.conv2d(branch_b, n_filters, (k, 1), activation=None,
                                    padding=padding, name='conv_2b')

        net = tf.add(branch_a, branch_b, name='sum')

    return net


def boundary_refine(input_, training, name, activation=tf.nn.relu, batch_norm=True):
    """Boundary refinement network [https://arxiv.org/abs/1703.02719]
    Args:
        input_ (4-D Tensor): (batch_size, H, W, C)
        training (bool): If True, run in training mode
        name (str): name postfix
        activation: Activation function
        batch_norm (bool): whether to use batch norm
    Returns:
        net (4-D Tensor): output tensor of the same shape as input_
    """
    net = input_
    n_filters = input_.get_shape()[3].value

    with tf.variable_scope("boundary_refine_module_{}".format(name)):
        net = tf.layers.conv2d(net, n_filters, (3, 3), activation=None,
                               padding='SAME', name='conv_1')
        if batch_norm:
            net = tf.layers.batch_normalization(net, training=training,
                                                name='bn_1', renorm=True)
        net = activation(net, name='relu_1')

        net = tf.layers.conv2d(net, n_filters, (3, 3), activation=None,
                               padding='SAME', name='conv_2')
        # residual connection around the two convolutions
        net = tf.add(net, input_, name='sum')

    return net


def get_deconv_filter(name, n_channels, k_size):
    """Creates the weight-kernel initialization for the deconvolution layer
    reference: https://github.com/MarvinTeichmann/tensorflow-fcn
    Args:
        name (str): name postfix
        n_channels (int): number of input and output channels (they are the same)
        k_size (int): kernel size (~ 2 x stride for the FCN case)
    Returns:
        weight kernels (4-D Tensor): (k_size, k_size, n_channels, n_channels)
    """
    k = k_size
    filter_shape = [k, k, n_channels, n_channels]
    f = math.ceil(k / 2.0)
    c = (2 * f - 1 - f % 2) / (2.0 * f)
    # separable bilinear-interpolation kernel over the full k x k grid
    bilinear = np.zeros((k, k))
    for x in range(k):
        for y in range(k):
            bilinear[x, y] = (1 - abs(x / f - c)) * (1 - abs(y / f - c))
    weights = np.zeros(filter_shape)
    for i in range(n_channels):
        weights[:, :, i, i] = bilinear

    init = tf.constant_initializer(value=weights, dtype=tf.float32)
    var = tf.get_variable(name=name, initializer=init, shape=weights.shape)
    return var


def deconv_module(input_, name, stride=2, kernel_size=4, padding='SAME'):
    """ Convolutional-transpose layer for upsampling the score layer
    reference: https://github.com/MarvinTeichmann/tensorflow-fcn
    Args:
        input_ (4-D Tensor): (batch_size, H, W, C)
        name (str): name postfix
        stride (int): the upscaling factor (default is 2)
        kernel_size (int): (~ 2 x stride for the FCN case)
        padding (str): 'same' or 'valid'
    Returns:
        net: output of the transpose-convolution operation
    """
    n_channels = input_.get_shape()[3].value
    strides = [1, stride, stride, 1]
    in_shape = input_.get_shape()
    h = in_shape[1].value * stride
    w = in_shape[2].value * stride
    out_shape = tf.stack([in_shape[0].value, h, w, n_channels])
    with tf.variable_scope('deconv_{}'.format(name)):
        weights = get_deconv_filter('up_filter_kernel', n_channels, k_size=kernel_size)
        deconv = tf.nn.conv2d_transpose(input_, weights, output_shape=out_shape,
                                        strides=strides, padding=padding)
    return deconv
--------------------------------------------------------------------------------
/layers/layers_unet.py:
--------------------------------------------------------------------------------
import tensorflow as tf


def conv_module(input_, n_filters, training, name, pool=True, activation=tf.nn.relu,
                padding='same', batch_norm=True):
    """{Conv -> BN -> RELU} x 2 -> {Pool, optional}
    reference : https://github.com/kkweon/UNet-in-Tensorflow
    Args:
        input_ (4-D Tensor): (batch_size, H, W, C)
        n_filters (int): depth of output tensor
        training (bool): If True, run in training mode
        name (str): name postfix
        pool (bool): If True, MaxPool2D after last conv layer
        activation: Activation function
        padding (str): 'same' or 'valid'
        batch_norm (bool): If True, use batch-norm
    Returns:
        net: output of the convolution operations
        pool (optional): output of the max pooling operation
    """
    kernel_sizes = [3, 3]
    net = input_
    with tf.variable_scope("conv_module_{}".format(name)):
        for i, k_size in enumerate(kernel_sizes):
            net = tf.layers.conv2d(net, n_filters, (k_size, k_size), activation=None,
                                   padding=padding, name="conv_{}".format(i + 1))
            if batch_norm:
                net = tf.layers.batch_normalization(net, training=training, renorm=True,
                                                    name="bn_{}".format(i + 1))
            net = activation(net, name="relu_{}".format(i + 1))

        if pool is False:
            return net

        pool = tf.layers.max_pooling2d(net, (2, 2), strides=(2, 2), name="pool")

    return net, pool


def upsample(input_, name, upscale_factor=(2, 2)):
    """Nearest-neighbor upsampling by the given (height, width) factors."""
    H, W, _ = input_.get_shape().as_list()[1:]

    target_H = H * upscale_factor[0]
    target_W = W * upscale_factor[1]

    return tf.image.resize_nearest_neighbor(input_, (target_H, target_W),
                                            name="upsample_{}".format(name))
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python

import argparse
import sys
from model import Model
from data_stats import prepare_data_stats, load_config
from data import prepare_train_data, prepare_test_data, augment_data


def main(argv):
    parser = argparse.ArgumentParser()
    parser.add_argument('--net', default='fcn_gcn', help='NN to use: can be unet or fcn_gcn')
    parser.add_argument('--phase', default='train', help='Phase: can be train, val or test')
    parser.add_argument('--stage', type=int, default=1, help='Training stage')
    parser.add_argument('--load', action='store_true', default=False,
                        help='Turn on to load the pretrained model')

    parser.add_argument('--prepare_data_stats', action='store_true', default=False,
                        help='Turn on to prepare data statistics. Must be done before training for the first time.')

    parser.add_argument('--set', type=int, default=1,
                        help='Set for one of the zones/angles: can be an integer from 1 to 16')

    parser.add_argument('--train_image_dir', default='../data/train/images/',
                        help='Directory containing training images')
    parser.add_argument('--train_mask_dir', default='../data/train/masks/',
                        help='Directory containing masks for training images')
    parser.add_argument('--train_data_dir', default='../data/train/misc/',
                        help='Directory to store temporary training data')
    parser.add_argument('--test_image_dir', default='../data/test/images/',
                        help='Directory containing test images')
    parser.add_argument('--test_results_dir', default='../data/test/results/',
                        help='Directory containing results for the test set')

    parser.add_argument('--save_dir', default='./models/',
                        help='Directory to contain the trained model')

    parser.add_argument('--save_period', type=int, default=100,
                        help='Period to save the trained model')
    parser.add_argument('--display_period', type=int, default=20,
                        help='Period over which to display the loss')
    parser.add_argument('--num_epochs', type=int, default=100, help='Number of training epochs')
    parser.add_argument('--batch_size', type=int, default=1, help='Batch size')
    parser.add_argument('--learning_rate', type=float, default=0.001, help='Learning rate')
    parser.add_argument('--batch_norm', action='store_true', default=True,
                        help='Turn on to use batch normalization')
    parser.add_argument('--sce_weight', type=float, default=1.,
                        help='Adds softmax cross-entropy (SCE) loss when the weight is non-zero')
    parser.add_argument('--edge_factor', type=int, default=0,
                        help='Gives additional weight to edges when using SCE')

    parser.add_argument('--augment_data', action='store_true', default=False,
                        help='Turn on to generate augmented data for the first time')
    parser.add_argument('--augment_factor', type=int, default=1,
                        help='Factor by which to augment the original data')

    args = parser.parse_args()

    if args.prepare_data_stats:
        prepare_data_stats(args)

    if args.augment_data:
        augment_data(args)

    cfg = load_config(args.train_data_dir, args.set)
    model = Model(args, cfg)

    if args.phase == 'train':
        train_data = prepare_train_data(args, cfg)
        model.train(train_data)
    elif args.phase == 'val':
        assert args.batch_size == 1
        train_data = prepare_train_data(args, cfg)
        model.validate(train_data)
    elif args.phase == 'test':
        assert args.batch_size == 1
        test_data = prepare_test_data(args, cfg)
        model.test(test_data)
    else:
        return


if __name__ == "__main__":
    main(sys.argv)
--------------------------------------------------------------------------------
/misc/carvana_test.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/preritj/segmentation/fa0d8c3ac29cdc12e21983923d6364996b574532/misc/carvana_test.png
--------------------------------------------------------------------------------
/misc/carvana_test_overlay.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/preritj/segmentation/fa0d8c3ac29cdc12e21983923d6364996b574532/misc/carvana_test_overlay.png
--------------------------------------------------------------------------------
/misc/fcn_gcn.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/preritj/segmentation/fa0d8c3ac29cdc12e21983923d6364996b574532/misc/fcn_gcn.png
--------------------------------------------------------------------------------
/misc/unet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/preritj/segmentation/fa0d8c3ac29cdc12e21983923d6364996b574532/misc/unet.png
--------------------------------------------------------------------------------
/models/base_model.py:
--------------------------------------------------------------------------------
from models.losses import pixel_wise_loss
from abc import abstractmethod


class SegModel(object):
    def __init__(self, model_cfg):
        self.cfg = model_cfg

    @abstractmethod
    def preprocess(self, inputs):
        """Image preprocessing"""
        raise NotImplementedError("Not yet implemented")

    @abstractmethod
    def build_net(self, preprocessed_inputs, is_training=False):
        """Builds the network and returns per-pixel mask logits"""
        raise NotImplementedError("Not yet implemented")

    def predict(self, inputs, is_training=False):
        images = inputs['images']
        preprocessed_inputs = self.preprocess(images)
        mask_logits = self.build_net(
            preprocessed_inputs, is_training=is_training)
        prediction = {'mask_logits': mask_logits}
        return prediction

    def losses(self, prediction, ground_truth):
        mask_logits = prediction['mask_logits']
        masks_gt = ground_truth['masks']
        weights_gt = None
        if self.cfg.use_weights:
            weights_gt = ground_truth['weights']
        loss = pixel_wise_loss(mask_logits, masks_gt, pixel_weights=weights_gt)
        losses = {'CE_loss': loss}
        return losses


# def create_tf_placeholders(self):
#     roi_h, roi_w = self.cfg['scaled_img_shape']  # self.cfg['roi_shape']
#     roi_images = tf.placeholder(tf.float32, [self.batch_size, roi_h, roi_w, 3])
#     roi_masks = tf.placeholder(tf.float32, [self.batch_size, roi_h, roi_w, 2])
#     roi_weights = tf.placeholder(tf.float32, [self.batch_size, roi_h, roi_w])
#     self.tf_placeholders = {'images': roi_images,
#                             'masks': roi_masks,
#                             'weights': roi_weights}
#
#
# def make_train_op(self):
#     learning_rate = self.params.learning_rate
#     roi_masks = self.tf_placeholders["masks"]
#     roi_masks_pos = tf.slice(roi_masks, [0, 0, 0, 1], [-1, -1, -1, 1])
#     roi_masks_pos = tf.squeeze(roi_masks_pos, [-1])
#     roi_weights = self.tf_placeholders["weights"]
#     _, tf_mask = mask_prediction(self.mask_logits)
#     loss0 = dice_coef_loss(roi_masks_pos, tf_mask)
#     loss1 = pixel_wise_loss(self.mask_logits, roi_masks, pixel_weights=roi_weights)
#     loss = loss0 + self.params.sce_weight * loss1
#     solver = tf.train.AdamOptimizer(learning_rate, epsilon=1e-8)
#
#     update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
#     with tf.control_dependencies(update_ops):
#         self.train_op = solver.minimize(loss, global_step=self.global_step)
#     self.loss_op = [loss0, loss1]
#
# def make_eval_op(self):
#     pred_probs, pred_masks = mask_prediction(self.mask_logits)
#     self.eval_op = [pred_probs, pred_masks]
#
# def get_feed_dict(self, batch, perturb=True):
#     if self.stage == 1:
#         roi_images, roi_masks, roi_weights = \
#             self.image_loader.load_img_batch(batch, edge_factor=self.params.edge_factor)
#     else:
#         roi_images, roi_masks, roi_weights = \
#             self.image_loader.load_roi_batch(batch, perturb=perturb,
#                                              edge_factor=self.params.edge_factor)
#     tf_roi_images = self.tf_placeholders["images"]
#     if roi_masks is None:
#         return {tf_roi_images: roi_images}
#     tf_roi_masks = self.tf_placeholders["masks"]
#     tf_roi_weights = self.tf_placeholders["weights"]
#     return {tf_roi_images: roi_images,
#             tf_roi_masks: roi_masks,
#             tf_roi_weights: roi_weights}
#
# def train(self, data):
#     """ Train the model. """
#     params = self.params
#     save_dir = os.path.join(params.save_dir, str(params.set).zfill(2), 'stage_' + str(self.stage))
#     if not os.path.exists(save_dir):
#         os.makedirs(save_dir)
#     save_dir = os.path.join(save_dir, 'model')
#     self.make_train_op()
#
#     with tf.Session() as sess:
#         sess.run(tf.global_variables_initializer())
#         saver = tf.train.Saver()
#         if params.load:
#             self.load(sess, saver)
#
#         n_display = params.display_period
#         for i_epoch in tqdm(list(range(params.num_epochs)), desc='epoch'):
#             dice_loss, sce_loss, n_steps = 0, 0, 0
#             for _ in tqdm(list(range(0, data.count, self.batch_size)), desc='batch'):
#                 batch = data.next_batch()
#                 if len(batch[0]) < self.batch_size:
#                     continue
#                 ops = [self.train_op, self.global_step] + self.loss_op
#                 feed_dict = self.get_feed_dict(batch, perturb=True)
#                 _, global_step, loss0, loss1 = sess.run(ops, feed_dict=feed_dict)
#                 if n_steps + 1 == n_display:
#                     print("Dice coeff : {}, Cross entropy loss : {}"
#                           .format(-dice_loss / n_steps, sce_loss / n_steps))
#                     dice_loss, sce_loss, n_steps = 0, 0, 0
#                 else:
#                     dice_loss += loss0
#                     sce_loss += loss1
#                     n_steps += 1
#
#                 if (global_step + 1) % params.save_period == 0:
#                     print("Saving model in {}".format(save_dir))
#                     saver.save(sess, save_dir, global_step)
#             data.reset()
#             print("{} epochs finished.".format(i_epoch))
#
# def validate(self, data):
#     """ Validate the model. """
#     # params = self.params
#     self.make_eval_op()
#
#     with tf.Session() as sess:
#         sess.run(tf.global_variables_initializer())
#         saver = tf.train.Saver()
#         self.load(sess, saver)
#         for _ in tqdm(list(range(data.count)), desc='batch'):
#             batch = data.next_batch()
#             img_file, mask_file = batch[0][0], batch[1][0]
#
#             gt_bbox = self.image_loader.generate_rois([mask_file], perturb=False)[0]
#             feed_dict = self.get_feed_dict(batch, perturb=False)
#             pred_probs, _ = sess.run(self.eval_op, feed_dict=feed_dict)
#             # pred_mask = np.zeros_like(pred_probs, dtype=np.uint8)
#             # pred_mask[np.where(pred_mask > 0.5)] = 1
#             mask_pred = pred_probs[0, :, :, 1]
#             mask_pred[mask_pred > 0.5] = 1
#             mask_pred[mask_pred <= 0.5] = 0
#
#             if True:
#                 img = cv2.imread(img_file)
#                 real_mask = np.zeros_like(img, dtype=np.uint8)
#                 if self.stage == 1:
#                     img_h, img_w = self.cfg['image_shape']
#                     l, r, t, b = self.cfg['crops']
#                     pred_mask = imresize(mask_pred, (img_h - t - b, img_w - l - r)) / 255
#                     real_mask[t: img_h - b, l: img_w - r, 0] = np.uint8(np.round(pred_mask))
#                 else:
#                     y, x, h, w = gt_bbox
#                     pred_mask = cv2.resize(mask_pred, (w, h))
#                     real_mask[y:y + h, x:x + w, 0] = np.uint8(pred_mask)
#
#                 winname = 'Image %s' % img_file
#                 img = cv2.resize(img, (1438, 960))
#                 img_mask = cv2.resize(real_mask * 255, (1438, 960), interpolation=cv2.INTER_CUBIC)
#                 display_img = cv2.addWeighted(img, 0.2, img_mask, 0.8, 0)
#                 cv2.imshow(winname, display_img)
#                 cv2.moveWindow(winname, 100, 100)
#                 cv2.waitKey(1000)
#
#             gt_mask = self.image_loader.load_mask(mask_file)
#             print("Dice coefficient : ", dice_coef(gt_mask, real_mask[:, :, 0]))
#
# def test(self, data):
#     """ Test the model. """
#     params = self.params
#     self.make_eval_op()
#
#     res_dir = params.test_results_dir
#     res_dir = os.path.join(res_dir, str(params.set).zfill(2),
#                            'stage_' + str(params.stage))
#     if not os.path.exists(res_dir):
#         os.makedirs(res_dir)
#     img_names = []
#     rle_strings = []
#
#     with tf.Session() as sess:
#         sess.run(tf.global_variables_initializer())
#         saver = tf.train.Saver()
#         self.load(sess, saver)
#         for _ in tqdm(list(range(data.count)), desc='batch'):
#             batch = data.next_batch()
#             img_file = batch[0][0]
#             feed_dict = self.get_feed_dict(batch, perturb=False)
#             pred_probs, _ = sess.run(self.eval_op, feed_dict=feed_dict)
#             mask_pred = pred_probs[0, :, :]
#             # mask_pred[mask_pred > 0.5] = 1
#             # mask_pred[mask_pred <= 0.5] = 0
#             real_mask = self.image_loader.postprocess(mask_pred)
#             rle = rle_encode(real_mask)
#             rle_strings.append(rle_to_string(rle))
#
#             if 1:
#                 img = cv2.imread(img_file)
#                 img_mask = np.zeros_like(img)
#                 img_mask[:, :, 0] = real_mask * 255
#                 # y, x, h, w = gt_bbox
#
#                 winname = 'Image %s' % img_file
#                 img = cv2.resize(img, (1438, 960))
#                 img_mask = cv2.resize(img_mask, (1438, 960))
#                 display_img = cv2.addWeighted(img, 0.4, img_mask, 0.6, 0)
#                 cv2.imshow(winname, display_img)
#                 cv2.moveWindow(winname, 100, 100)
#                 cv2.waitKey(1000)
#
#             img_name = os.path.basename(img_file)
#             img_names.append(img_name)
#             # outfile = os.path.join(res_dir, str(img_name) + '.npy')
#             # np.save(outfile, mask_pred)
#     df = {'img': img_names, 'rle_mask': rle_strings}
#     df = pd.DataFrame(df)
#     outfile = os.path.join(res_dir, 'results.csv')
#     df.to_csv(outfile)
#
#
# def load(self, sess, saver):
#     """ Load the trained model. """
#     params = self.params
#     print("Loading model...")
#     load_dir = os.path.join(params.save_dir, str(params.set).zfill(2),
#                             'stage_' + str(params.stage), 'model')
#     checkpoint = tf.train.get_checkpoint_state(os.path.dirname(load_dir))
#     if checkpoint is None:
#         print("Error: No saved model found. Please train first.")
#         sys.exit(0)
#     saver.restore(sess, checkpoint.model_checkpoint_path)
--------------------------------------------------------------------------------
/models/fcn_gcn_net.py:
--------------------------------------------------------------------------------
import tensorflow as tf
from layers.layers_fcn_gcn import (
    conv_module, global_conv_module, boundary_refine, deconv_module)
from models.base_model import SegModel


class FCNGCNnet(SegModel):
    def __init__(self, cfg):
        super().__init__(cfg)

    def preprocess(self, inputs):
        """Image preprocessing"""
        h, w = self.cfg.input_shape
        inputs = tf.reshape(inputs, [-1, h, w, 3])
        return 2.0 * tf.to_float(inputs) / 255. - 1.0

    def build_net(self, input_, is_training=False):
        """Based on https://arxiv.org/abs/1703.02719 but using VGG style base
        Args:
            input_ (4-D Tensor): (N, H, W, C)
            is_training (bool): If True, run in training mode
        Returns:
            output (4-D Tensor): (N, H, W, n)
                Logits classifying each pixel as either 'car' (1) or 'not car' (0)
        """
        num_classes = self.cfg.num_classes  # number of classes
        k_gcn = self.cfg.k_gcn              # kernel size for global conv layer
        init_channels = self.cfg.init_channels  # number of channels in the first conv layer
        n_layers = self.cfg.n_layers        # number of times to downsample/upsample
        batch_norm = self.cfg.batch_norm    # if True, use batch-norm

        # color-space adjustment
        net = tf.layers.conv2d(input_, 3, (1, 1), name="color_space_adjust")
        n = n_layers

        # encoder
        feed = net
        ch = init_channels
        conv_blocks = []
        for i in range(n - 1):
            conv, feed = conv_module(feed, ch, is_training, name=str(i + 1),
                                     batch_norm=batch_norm)
            conv_blocks.append(conv)
            ch *= 2
        last_conv = conv_module(feed, ch, is_training, name=str(n), pool=False,
                                batch_norm=batch_norm)
        conv_blocks.append(last_conv)

        # global convolution network
        global_conv_blocks = []
        for i in range(n):
            global_conv_blocks.append(
                global_conv_module(conv_blocks[i], num_classes, is_training,
                                   k=k_gcn, name=str(i + 1)))

        # boundary refinement
        br_blocks = []
        for i in range(n):
            br_blocks.append(boundary_refine(global_conv_blocks[i], is_training,
                                             name=str(i + 1), batch_norm=batch_norm))

        # decoder / upsampling
        up_blocks = []
        last_br = br_blocks[-1]
        for i in range(n - 1, 0, -1):
            deconv = deconv_module(last_br, name=str(i + 1), stride=2, kernel_size=4)
            up = tf.add(deconv, br_blocks[i - 1])
            last_br = boundary_refine(up, is_training, name='up_' + str(i))
            up_blocks.append(up)

        logits = last_br
        return logits
--------------------------------------------------------------------------------
/models/losses.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 |
4 | def dice_coef(y_true, y_pred, axis=None, smooth=0.001):
5 | if axis is None:
6 | axis=[1,2]
7 | y_true_f = tf.cast(y_true, dtype=tf.float32)
8 | y_pred_f = tf.cast(y_pred, dtype=tf.float32)
9 | intersection = tf.reduce_sum(y_true_f * y_pred_f, axis=axis)
10 | dice = (2. * intersection + smooth) / (tf.reduce_sum(y_true_f, axis=axis)
11 | + tf.reduce_sum(y_pred_f, axis=axis) + smooth)
12 | return tf.reduce_mean(dice)
13 |
14 |
15 | def dice_coef_loss(y_true, y_pred):
16 | return -dice_coef(y_true, y_pred)
17 |
18 |
19 | def pixel_wise_loss(pixel_logits, gt_pixels, pixel_weights=None):
20 | """Calculates pixel-wise softmax cross entropy loss
21 | Args:
22 | pixel_logits (4-D Tensor): (N, H, W, 2)
23 |         gt_pixels (4-D Tensor): One-hot image masks of shape (N, H, W, 2)
24 |         pixel_weights (3-D Tensor): (N, H, W) Weights for each pixel
25 | Returns:
26 | scalar loss : softmax cross-entropy
27 | """
28 | logits = tf.reshape(pixel_logits, [-1, 2])
29 | labels = tf.reshape(gt_pixels, [-1, 2])
30 | loss = tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=labels)
31 | if pixel_weights is None:
32 | return tf.reduce_mean(loss)
33 | else:
34 | weights = tf.reshape(pixel_weights, [-1])
35 | return tf.reduce_sum(loss * weights) / tf.reduce_sum(weights)
36 |
37 |
38 | def mask_prediction(pixel_logits):
39 | """
40 | Args:
41 | pixel_logits (4-D Tensor): (N, H, W, 2)
42 | Returns:
43 | Predicted pixel-wise probabilities (3-D Tensor): (N, H, W)
44 | Predicted mask (3-D Tensor): (N, H, W)
45 | """
46 | probs = tf.nn.softmax(pixel_logits)
47 |     _, h, w, _ = probs.get_shape()  # batch dim may be dynamic (None)
48 | masks = tf.reshape(probs, [-1, 2])
49 | masks = tf.argmax(masks, axis=1)
50 |     masks = tf.reshape(masks, [-1, h.value, w.value])  # -1 keeps the batch dim dynamic
51 | probs = tf.slice(probs, [0, 0, 0, 1], [-1, -1, -1, 1])
52 | probs = tf.squeeze(probs, axis=-1)
53 | return probs, masks
54 |
--------------------------------------------------------------------------------
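A toy check of dice_coef from the file above, with the expected value worked out by hand from the formula (a sketch, run under TF1 graph semantics):

    import numpy as np
    import tensorflow as tf
    from models.losses import dice_coef

    # two 1x2x2 masks overlapping in exactly one pixel:
    # intersection = 1, sum(y_true) = 2, sum(y_pred) = 1
    # dice = (2*1 + 0.001) / (2 + 1 + 0.001) ~= 0.667
    y_true = np.array([[[1, 1], [0, 0]]], dtype=np.float32)
    y_pred = np.array([[[1, 0], [0, 0]]], dtype=np.float32)

    with tf.Session() as sess:
        print(sess.run(dice_coef(y_true, y_pred)))  # ~0.667
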
/models/u_net.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from layers.layers_unet import conv_module, upsample
3 | from models.base_model import SegModel
4 |
5 |
6 | class UNet(SegModel):
7 | def __init__(self, cfg):
8 | super().__init__(cfg)
9 |
10 | def preprocess(self, inputs):
11 | """Image preprocessing"""
12 | h, w = self.cfg.input_shape
13 | inputs = tf.reshape(inputs, [-1, h, w, 3])
14 | return 2.0 * tf.to_float(inputs) / 255. - 1.0
15 |
16 | def build_net(self, input_, is_training=False):
17 | """Based on https://arxiv.org/abs/1505.04597
18 | Args:
19 | input_ (4-D Tensor): (N, H, W, C)
20 | is_training (bool): If True, run in training mode
21 | Returns:
22 |             output (4-D Tensor): (N, H, W, num_classes)
23 | Logits classifying each pixel as either 'car' (1) or 'not car' (0)
24 | """
25 | num_classes = self.cfg.num_classes # Number of classes
26 | n_layers = self.cfg.n_layers # Number of times to downsample/upsample
27 | init_channels = self.cfg.init_channels # Number of channels in the first conv layer
28 | batch_norm = self.cfg.batch_norm # if True, use batch-norm
29 |
30 | # color-space adjustment
31 | net = tf.layers.conv2d(input_, 3, (1, 1), name="color_space_adjust")
32 |
33 | # encoder
34 | feed = net
35 | ch = init_channels
36 | conv_blocks = []
37 | for i in range(n_layers):
38 | conv, feed = conv_module(feed, ch, is_training, name='down_{}'.format(i + 1),
39 | batch_norm=batch_norm)
40 | conv_blocks.append(conv)
41 | ch *= 2
42 | last_conv = conv_module(feed, ch, is_training, name='down_{}'.format(n_layers+1),
43 | pool=False, batch_norm=batch_norm)
44 | conv_blocks.append(last_conv)
45 |
46 | # decoder / upsampling
47 | feed = conv_blocks[-1]
48 | for i in range(n_layers, 0, -1):
49 |             ch //= 2  # floor division: tf.layers.conv2d needs an integer filter count
50 | up = upsample(feed, name=str(i+1))
51 | concat = tf.concat([up, conv_blocks[i-1]], axis=-1, name="concat_{}".format(i))
52 | feed = conv_module(concat, ch, is_training, name='up_{}'.format(i), batch_norm=batch_norm,
53 | pool=False)
54 |
55 | logits = tf.layers.conv2d(feed, num_classes, (1, 1), name='logits', activation=None, padding='same')
56 | return logits
57 |
58 |
--------------------------------------------------------------------------------
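For intuition on the symmetry in build_net: each of the n_layers encoder stages halves the spatial size and doubles the channels, and the decoder reverses both. A small sketch of the implied schedule (assuming conv_module pools by 2 and upsample doubles, as their usage suggests):

    def unet_schedule(h, w, init_channels, n_layers):
        # encoder: halve spatial dims, double channels
        ch = init_channels
        for i in range(n_layers):
            print('down_{}: {}x{}x{}'.format(i + 1, h, w, ch))
            h, w, ch = h // 2, w // 2, ch * 2
        print('bottleneck: {}x{}x{}'.format(h, w, ch))
        # decoder: double spatial dims, halve channels
        for i in range(n_layers, 0, -1):
            ch //= 2
            h, w = h * 2, w * 2
            print('up_{}: {}x{}x{}'.format(i, h, w, ch))

    unet_schedule(256, 256, 16, 4)  # bottleneck: 16x16x256
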
/utils/dataset_util.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | import cv2
4 | import matplotlib.pyplot as plt
5 |
6 |
7 | def random_int(maxval, minval=0):
8 | return tf.random_uniform(
9 | shape=[], minval=minval, maxval=maxval, dtype=tf.int32)
10 |
11 |
12 | def random_rotate(image, mask):
13 | """random rotations in multiples of 90 degrees"""
14 | k = random_int(4)
15 | image = tf.image.rot90(image, k)
16 | mask = tf.image.rot90(mask, k)
17 | return image, mask
18 |
19 |
20 | def random_flip_left_right(img, mask):
21 | random_var = random_int(2)
22 | random_var = tf.cast(random_var, tf.bool)
23 | flipped_img = tf.cond(random_var,
24 | true_fn=lambda: tf.image.flip_left_right(img),
25 | false_fn=lambda: tf.identity(img))
26 |     mask = tf.expand_dims(mask, axis=2)  # flip_left_right needs a channel dim; mask is returned as (H, W, 1)
27 | flipped_mask = tf.cond(random_var,
28 | true_fn=lambda: tf.image.flip_left_right(mask),
29 | false_fn=lambda: tf.identity(mask))
30 | return flipped_img, flipped_mask
31 |
32 |
33 | def random_brightness(image, mask):
34 | image = tf.image.random_brightness(
35 | image,
36 | max_delta=0.1)
37 | return image, mask
38 |
39 |
40 | def random_contrast(image, mask):
41 | image = tf.image.random_contrast(
42 | image,
43 | lower=0.9,
44 | upper=1.1)
45 | return image, mask
46 |
47 |
48 | def random_hue(image, mask):
49 | image = tf.image.random_hue(
50 | image,
51 | max_delta=0.1)
52 | return image, mask
53 |
54 |
55 | # def resize(image, keypoints, bbox, mask,
56 | # target_image_size=(224, 224),
57 | # target_mask_size=None):
58 | # img_size = list(target_image_size)
59 | # if target_mask_size is None:
60 | # target_mask_size = img_size
61 | # mask_size = list(target_mask_size)
62 | # new_image = tf.image.resize_images(image, size=img_size)
63 | # new_mask = tf.expand_dims(mask, axis=2)
64 | # new_mask.set_shape([None, None, 1])
65 | # new_mask = tf.image.resize_images(new_mask, size=mask_size)
66 | # new_mask = tf.squeeze(new_mask)
67 | # return new_image, keypoints, bbox, new_mask
68 |
69 |
70 | ###################################################
71 | # Some other potentially useful functions
72 | ###################################################
73 | def rle_encode(mask_image):
74 | pixels = mask_image.flatten()
75 | # We avoid issues with '1' at the start or end (at the corners of
76 | # the original image) by setting those pixels to '0' explicitly.
77 | # We do not expect these to be non-zero for an accurate mask,
78 | # so this should not harm the score.
79 | pixels[0] = 0
80 | pixels[-1] = 0
81 | runs = np.where(pixels[1:] != pixels[:-1])[0] + 2
82 | runs[1::2] = runs[1::2] - runs[:-1:2]
83 | return runs
84 |
85 |
86 | def rle_to_string(runs):
87 | return ' '.join(str(x) for x in runs)
88 |
89 |
90 | def dice_coef(y_true, y_pred, smooth=0.001):
91 | y_true = np.array(y_true).flatten()
92 | y_pred = np.array(y_pred).flatten()
93 | intersection = np.sum(y_true * y_pred)
94 | dice = (2. * intersection + smooth) / (np.sum(y_true) + np.sum(y_pred) + smooth)
95 | return np.mean(dice)
96 |
97 |
98 | def get_bbox(mask_file):
99 | binary_img = plt.imread(mask_file)
100 | if binary_img.ndim > 2:
101 | binary_img = binary_img[:, :, 0] // 255
102 | ymin, xmin = np.min(np.nonzero(binary_img), axis=1)
103 | ymax, xmax = np.max(np.nonzero(binary_img), axis=1)
104 | return [ymin - 1, xmin - 1, ymax - ymin + 2, xmax - xmin + 2]
105 |
106 |
107 | def fix_aspect_ratio(cfg, rois):
108 | h_roi, w_roi = cfg['roi_shape']
109 | roi_aspect = w_roi / h_roi
110 |
111 | aspect_rois = rois[:, 3] / rois[:, 2]
112 | idx_a = np.where(aspect_rois > roi_aspect)[0]
113 | idx_b = np.where(aspect_rois < roi_aspect)[0]
114 |
115 | rois_a = rois[idx_a, :]
116 | desired_h = rois_a[:, 3] / roi_aspect
117 | delta_h = (desired_h - rois_a[:, 2]) / 2
118 | rois_a[:, 0] = rois_a[:, 0] - delta_h
119 | rois_a[:, 2] = desired_h
120 |
121 | rois_b = rois[idx_b, :]
122 | desired_w = rois_b[:, 2] * roi_aspect
123 | delta_w = (desired_w - rois_b[:, 3]) / 2
124 | rois_b[:, 1] = rois_b[:, 1] - delta_w
125 | rois_b[:, 3] = desired_w
126 |
127 | rois[idx_a, :] = rois_a
128 | rois[idx_b, :] = rois_b
129 |
130 | return rois
131 |
132 |
133 | def filter_mask(mask):
134 | kernel = np.ones((2, 2))
135 | mask_smooth = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
136 | mask_smooth = np.uint8(np.round(mask_smooth))
137 | blobs = cv2.connectedComponentsWithStats(mask_smooth, 4, cv2.CV_32S)
138 | stats = blobs[2]
139 | obj_label = None
140 | for i, stat in enumerate(stats):
141 |         if stat[4] < 10000:  # stat[4] is CC_STAT_AREA: skip components under 10000 px
142 |             continue
143 |         elif (stat[0] < 2) and (stat[1] < 2):  # stat[0]/stat[1] are bbox left/top: skip the background blob at the origin
144 |             continue
145 | else:
146 | obj_label = i
147 | break
148 | blobs = blobs[1]
149 | blobs[blobs != obj_label] = 0
150 | return np.uint8(blobs)
151 |
152 |
--------------------------------------------------------------------------------
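A worked example of the RLE helpers above; the output is 1-indexed (start, length) pairs over the flattened mask, the format used for Kaggle Carvana submissions:

    import numpy as np
    from utils.dataset_util import rle_encode, rle_to_string

    mask = np.array([[0, 1, 1, 0],
                     [0, 1, 1, 0]], dtype=np.uint8)
    # flattened: 0 1 1 0 0 1 1 0
    # runs start at 1-indexed pixels 2 and 6, each of length 2
    print(rle_to_string(rle_encode(mask)))  # "2 2 6 2"
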
/utils/depricated.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import numpy as np
3 | import cv2
4 | from scipy.misc import imresize
5 | from utils.dataset_util import filter_mask, get_bbox, fix_aspect_ratio
6 |
7 | class ImageLoader(object):
8 | def __init__(self, cfg):
9 |
10 |         self.cfg = cfg
11 |
12 |
13 | def postprocess(self, mask_pred):
14 | img_h, img_w = self.cfg['image_shape']
15 | l, r, t, b = self.cfg['crops']
16 | pred_mask = imresize(mask_pred, (img_h - t - b, img_w - l - r)) / 255
17 | real_mask = np.zeros((img_h, img_w))
18 | real_mask[t: img_h - b, l: img_w - r] = pred_mask
19 | real_mask = filter_mask(real_mask)
20 | return real_mask
21 |
22 | @staticmethod
23 | def load_img(img_file):
24 | img = cv2.imread(img_file)
25 | img = np.array(img, dtype=np.float32)
26 | return img / 127.5 - 1.
27 |
28 | @staticmethod
29 | def load_mask(mask_file):
30 | img_mask = plt.imread(mask_file)
31 | if img_mask.ndim > 2:
32 | img_mask = img_mask[:, :, 0] // 255
33 |         return np.float32(img_mask)
34 |
35 |
36 | def generate_rois(self, mask_files, perturb=True):
37 | rois = []
38 | for mask_file in mask_files:
39 | gt_bbox = get_bbox(mask_file)
40 | if perturb:
41 | pad = self.cfg['buffer']
42 | padded_bbox = gt_bbox + np.array([- pad, - pad, 2 * pad, 2 * pad])
43 | dy_c, dx_c = np.int32(0.1 * pad * (np.random.rand(2) - 0.5))
44 | dh, dw = np.int32(2 * pad * (np.random.rand(2) - 0.5))
45 | roi_bbox = (np.array(padded_bbox) +
46 | np.int32([dy_c - dh/2, dx_c - dw/2, dh, dw]))
47 | else:
48 | roi_bbox = gt_bbox
49 | rois.append(roi_bbox)
50 | rois = np.array(rois)
51 | rois = fix_aspect_ratio(self.cfg, rois)
52 | return rois
53 |
54 | def build_roi(self, img, mask, roi, edge_factor):
55 | img_h, img_w = self.cfg['image_shape']
56 | final_roi_h, final_roi_w = self.cfg['roi_shape']
57 | y, x, h, w = np.array(roi, dtype=np.int32)
58 | y_min = max(0, y)
59 | y_max = min(img_h, y + h)
60 | x_min = max(0, x)
61 | x_max = min(img_w, x + w)
62 | roi_img = np.zeros((h, w, 3))
63 | roi_mask = np.zeros((h, w))
64 | roi_h = y_max - y_min
65 | roi_w = x_max - x_min
66 | roi_y = (h - roi_h) // 2
67 | roi_x = (w - roi_w) // 2
68 | roi_img[roi_y:roi_y + roi_h, roi_x:roi_x + roi_w, :] = img[y_min:y_max, x_min: x_max, :]
69 | roi_img = cv2.resize(roi_img, (final_roi_w, final_roi_h))
70 | roi_mask[roi_y:roi_y + roi_h, roi_x:roi_x + roi_w] = mask[y_min:y_max, x_min: x_max]
71 | roi_mask = cv2.resize(roi_mask, (final_roi_w, final_roi_h))
72 | roi_mask = np.array(np.round(roi_mask), dtype=np.uint8)
73 |
74 | if edge_factor > 0:
75 | roi_mask_weight = self.apply_edge_weighting(roi_mask, edge_factor)
76 | else:
77 | roi_mask_weight = np.ones_like(roi_mask, dtype=np.float32)
78 |
79 | roi_mask_weight = self.apply_class_reweighting(roi_mask, roi_mask_weight)
80 | return roi_img, roi_mask, roi_mask_weight
81 |
82 | @staticmethod
83 | def apply_class_reweighting(mask, mask_weight):
84 | n_tot = mask.size
85 | n_pos = np.sum(mask == 1)
86 | n_neg = n_tot - n_pos
87 | pos_wt = n_pos / n_tot
88 | neg_wt = n_neg / n_tot
89 |         mask_weight[mask > 0.5] *= pos_wt  # note: scales each class by its own frequency, not its inverse
90 | mask_weight[mask < 0.5] *= neg_wt
91 | return mask_weight
92 |
93 | @staticmethod
94 | def apply_edge_weighting(mask, edge_factor):
95 | mask_weight = np.ones_like(mask, dtype=np.float32)
96 | kernel = np.ones((33, 33), np.uint8)
97 | erosion = cv2.erode(mask, kernel, iterations=5)
98 | dilation = cv2.dilate(mask, kernel, iterations=2)
99 | dilation[erosion > 0] = 0
100 | n = edge_factor
101 | mask_weight += n * 1.0 * dilation
102 | mask_weight /= (1. + n)
103 | return mask_weight
104 |
105 | def load_roi_batch(self, batch, perturb=True, edge_factor=4):
106 | img_files, mask_files = batch
107 | rois = self.generate_rois(mask_files, perturb=perturb)
108 | roi_imgs, roi_masks, roi_mask_weights = [], [], []
109 | for img_file, mask_file, roi in zip(img_files, mask_files, rois):
110 | img = self.load_img(img_file)
111 | mask = self.load_mask(mask_file)
112 | roi_img, roi_mask, roi_weight = self.build_roi(img, mask, roi, edge_factor)
113 | roi_imgs.append(roi_img)
114 | roi_masks.append(roi_mask)
115 | roi_mask_weights.append(roi_weight)
116 | return np.array(roi_imgs), np.array(roi_masks), np.array(roi_mask_weights)
117 |
118 | def preprocess_image(self, img, mask, edge_factor):
119 | img_h, img_w = self.cfg['image_shape']
120 | l, r, t, b = self.cfg['crops']
121 | new_img_h, new_img_w = self.cfg['scaled_img_shape']
122 | new_img = img[t : img_h - b, l: img_w - r, :]
123 | new_img = cv2.resize(new_img, (new_img_w, new_img_h), interpolation=cv2.INTER_AREA)
124 | if mask is None:
125 | return new_img
126 | new_mask = mask[t : img_h - b, l: img_w - r]
127 | new_mask = imresize(new_mask, (new_img_h, new_img_w)) / 255
128 | #new_mask = cv2.resize(new_mask, (new_img_w, new_img_h), interpolation=cv2.INTER_AREA)
129 | if edge_factor > 0:
130 | mask_weight = self.apply_edge_weighting(new_mask, edge_factor)
131 | else:
132 | mask_weight = np.ones_like(new_mask, dtype=np.float32)
133 | new_mask = np.stack((1. - new_mask, new_mask), axis=-1)
134 | return new_img, new_mask, mask_weight
135 |
136 | def load_img_batch(self, batch, edge_factor=4):
137 | img_files, mask_files = batch
138 | if mask_files is None:
139 | imgs = [self.load_img(f) for f in img_files]
140 | imgs = [self.preprocess_image(img, None, 0) for img in imgs]
141 | return np.array(imgs), None, None
142 | imgs, masks, mask_weights = [], [], []
143 | for img_file, mask_file in zip(img_files, mask_files):
144 | img = self.load_img(img_file)
145 | mask = self.load_mask(mask_file)
146 | img, mask, mask_weight = self.preprocess_image(img, mask, edge_factor)
147 | imgs.append(img)
148 | masks.append(mask)
149 | mask_weights.append(mask_weight)
150 | return np.array(imgs), np.array(masks), np.array(mask_weights)
151 |
--------------------------------------------------------------------------------
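To make the arithmetic in apply_edge_weighting concrete: pixels inside the dilated-minus-eroded edge band end up with weight (1 + n) / (1 + n) = 1, everything else with 1 / (1 + n). A toy check with a hand-built band standing in for the cv2 morphology:

    import numpy as np

    n = 4  # edge_factor
    dilation_band = np.array([0., 1., 0.])  # 1 inside the edge band
    mask_weight = np.ones(3, dtype=np.float32)
    mask_weight += n * dilation_band   # band -> 5, elsewhere -> 1
    mask_weight /= (1. + n)            # band -> 1.0, elsewhere -> 0.2
    print(mask_weight)                 # [0.2 1.  0.2]
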
/utils/tfrecord_util.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Utility functions for creating TFRecord data sets."""
17 |
18 | import tensorflow as tf
19 |
20 |
21 | def int64_feature(value):
22 | return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
23 |
24 |
25 | def int64_list_feature(value):
26 | return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
27 |
28 |
29 | def bytes_feature(value):
30 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
31 |
32 |
33 | def bytes_list_feature(value):
34 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))
35 |
36 |
37 | def float_list_feature(value):
38 | return tf.train.Feature(float_list=tf.train.FloatList(value=value))
39 |
40 |
41 | def read_examples_list(path):
42 | """Read list of training or validation examples.
43 | The file is assumed to contain a single example per line where the first
44 | token in the line is an identifier that allows us to find the image and
45 | annotation xml for that example.
46 | For example, the line:
47 | xyz 3
48 | would allow us to find files xyz.jpg and xyz.xml (the 3 would be ignored).
49 | Args:
50 | path: absolute path to examples list file.
51 | Returns:
52 | list of example identifiers (strings).
53 | """
54 | with tf.gfile.GFile(path) as fid:
55 | lines = fid.readlines()
56 | return [line.strip().split(' ')[0] for line in lines]
57 |
58 |
59 | def recursive_parse_xml_to_dict(xml):
60 | """Recursively parses XML contents to python dict.
61 | We assume that `object` tags are the only ones that can appear
62 | multiple times at the same level of a tree.
63 | Args:
64 | xml: xml tree obtained by parsing XML file contents using lxml.etree
65 | Returns:
66 | Python dictionary holding XML contents.
67 | """
68 | if not xml:
69 | return {xml.tag: xml.text}
70 | result = {}
71 | for child in xml:
72 | child_result = recursive_parse_xml_to_dict(child)
73 | if child.tag != 'object':
74 | result[child.tag] = child_result[child.tag]
75 | else:
76 | if child.tag not in result:
77 | result[child.tag] = []
78 | result[child.tag].append(child_result[child.tag])
79 | return {xml.tag: result}
80 |
81 |
82 | def make_initializable_iterator(dataset):
83 | """Creates an iterator, and initializes tables.
84 | This is useful in cases where make_one_shot_iterator wouldn't work because
85 | the graph contains a hash table that needs to be initialized.
86 | Args:
87 | dataset: A `tf.data.Dataset` object.
88 | Returns:
89 | A `tf.data.Iterator`.
90 | """
91 | iterator = dataset.make_initializable_iterator()
92 | tf.add_to_collection(tf.GraphKeys.TABLE_INITIALIZERS, iterator.initializer)
93 | return iterator
94 |
95 |
96 | def read_dataset(file_read_func, decode_func, input_files, config):
97 | """Reads a dataset, and handles repetition and shuffling.
98 | Args:
99 | file_read_func: Function to use in tf.data.Dataset.interleave, to read
100 | every individual file into a tf.data.Dataset.
101 | decode_func: Function to apply to all records.
102 | input_files: A list of file paths to read.
103 |     config: An input_reader_builder.InputReader object.
104 | Returns:
105 | A tf.data.Dataset based on config.
106 | """
107 | # Shard, shuffle, and read files.
108 | filenames = tf.concat([tf.matching_files(pattern) for pattern in input_files],
109 | 0)
110 | filename_dataset = tf.data.Dataset.from_tensor_slices(filenames)
111 | if config.shuffle:
112 | filename_dataset = filename_dataset.shuffle(
113 | config.filenames_shuffle_buffer_size)
114 | elif config.num_readers > 1:
115 | tf.logging.warning('`shuffle` is false, but the input data stream is '
116 | 'still slightly shuffled since `num_readers` > 1.')
117 |
118 | filename_dataset = filename_dataset.repeat(config.num_epochs or None)
119 |
120 | records_dataset = filename_dataset.apply(
121 | tf.contrib.data.parallel_interleave(
122 | file_read_func, cycle_length=config.num_readers,
123 | block_length=config.read_block_length, sloppy=True))
124 | if config.shuffle:
125 |         records_dataset = records_dataset.shuffle(config.shuffle_buffer_size)  # reassign: shuffle returns a new dataset
126 | tensor_dataset = records_dataset.map(
127 | decode_func, num_parallel_calls=config.num_parallel_map_calls)
128 | return tensor_dataset.prefetch(config.prefetch_size)
129 |
--------------------------------------------------------------------------------
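A sketch of how the feature helpers above combine into a tf.train.Example; the field names and file paths are illustrative only, not the schema this repo actually writes:

    import tensorflow as tf
    from utils.tfrecord_util import int64_feature, bytes_feature

    with open('img.jpg', 'rb') as f:    # hypothetical image file
        img_bytes = f.read()
    with open('mask.png', 'rb') as f:   # hypothetical mask file
        mask_bytes = f.read()

    example = tf.train.Example(features=tf.train.Features(feature={
        'image/encoded': bytes_feature(img_bytes),
        'image/height': int64_feature(1280),
        'image/width': int64_feature(1918),
        'mask/encoded': bytes_feature(mask_bytes),
    }))

    with tf.python_io.TFRecordWriter('train.tfrecord') as writer:
        writer.write(example.SerializeToString())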