├── .gitignore ├── LICENSE ├── README.md ├── create_tf_records_citypersons.py ├── detect.py ├── inference_aleatoric.py ├── inference_epistemic.py ├── inference_standard_yolov3.py ├── lib_yolo ├── __init__.py ├── darknet.py ├── data.py ├── data_augmentation.py ├── dataset_utils.py ├── layers.py ├── model.py ├── tfdata.py ├── train.py ├── utils.py └── yolov3.py ├── pretraining.py ├── uncertainty_training.py ├── vis_uncertainty.py └── yolov3_training.py /.gitignore: -------------------------------------------------------------------------------- 1 | checkpoints/ 2 | inference/ 3 | log/ 4 | tensorboard/ 5 | uncertainty_visualization/ 6 | test_images/ 7 | __pycache__/ 8 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Florian Kraus 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # bayesian-yolov3 2 | YOLOv3 object detection architecture with uncertainty estimation in TensorFlow. 3 | 4 | Accompanying code for: https://arxiv.org/abs/1905.10296 (IEEE Xplore: https://ieeexplore.ieee.org/document/8917494/) 5 | 6 | ### Citation 7 | If you find this work useful in your research, please consider citing: 8 | 9 | @inproceedings{kraus_uncertainty_2019, 10 | address = {Auckland, New Zealand}, 11 | title = {Uncertainty {Estimation} in {One}-{Stage} {Object} {Detection}}, 12 | url = {https://ieeexplore.ieee.org/document/8917494/}, 13 | doi = {10.1109/ITSC.2019.8917494}, 14 | booktitle = {2019 {IEEE} {Intelligent} {Transportation} {Systems} {Conference} ({ITSC})}, 15 | publisher = {IEEE}, 16 | author = {Kraus, Florian and Dietmayer, Klaus}, 17 | month = oct, 18 | year = {2019}, 19 | pages = {53--60} 20 | } 21 | 22 | ### Notes 23 | - Training examples with documentation: 24 | - pretraining.py - pretraining for models with uncertainty estimation 25 | - uncertainty_training.py - training for models with uncertainty estimation, 26 | use checkpoints produced by pretraining.py as a starting point. 27 | - yolov3_training.py - standard yolov3 without any uncertainty estimation 28 | - look for "edit" comments 29 | - Forward passes: 30 | - detect.py (processes a list of images) 31 | - inference_*.py scripts. 
They process tfrecord files and produce ECP (Euro City Persons) formatted json files.
32 |   - Note: NMS yields up to 1000 boxes, which might be slow. Change the "nms" functions if you want better performance.
33 |   - The current NMS implementation ignores classes.
34 |     Example code used in the paper is given as comments (it only works for two classes).
35 |   - look for "edit" comments
36 | - tfrecords format:
37 |   - same as for the tensorflow object detection API,
38 |     however we also support tfrecords files where the label ids start at 0 instead of 1.
39 |     This is controlled by setting "implicit_background_class" to True (start at 1) or False (start at 0).
40 |   - an example script to create tfrecords files is provided (create_tf_records_citypersons.py)
41 | - Pretrained yolov3 weights:
42 |   - you need to download the "darknet53.conv.74" file from the original yolov3 site (pjreddie).
43 | - Most things you can change should be marked with an "edit" comment.
44 | 
--------------------------------------------------------------------------------
/create_tf_records_citypersons.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import pickle
4 | import time
5 | from concurrent.futures import ThreadPoolExecutor
6 | 
7 | import numpy as np
8 | import scipy.io
9 | import tensorflow as tf
10 | 
11 | 
12 | def int64_feature(value):
13 |     return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
14 | 
15 | 
16 | def int64_list_feature(value):
17 |     return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
18 | 
19 | 
20 | def bytes_feature(value):
21 |     return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
22 | 
23 | 
24 | def bytes_list_feature(value):
25 |     return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))
26 | 
27 | 
28 | def float_list_feature(value):
29 |     return tf.train.Feature(float_list=tf.train.FloatList(value=value))
30 | 
31 | 
32 | class ExampleCreator:
33 |     def __init__(self, out_dir, dataset_name, label_to_text=None):
34 |         self._out_dir = out_dir
35 |         self._dataset_name = dataset_name
36 | 
37 |         # Create a single Session to run all image coding calls.
38 |         self._sess = tf.Session(config=tf.ConfigProto(device_count={'GPU': 1}))
39 | 
40 |         # Initializes function that decodes RGB PNG data.
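# Aside (illustrative sketch, not from this repo): the placeholder/op pairs
# below follow the usual TF1 build-once, run-many pattern. The decode/encode
# graph nodes are created a single time here and then executed repeatedly via
# self._sess.run(..., feed_dict=...) instead of adding new nodes per image.
# A minimal self-contained version of the same pattern (hypothetical names):
#
#     import tensorflow as tf
#     data_ph = tf.placeholder(dtype=tf.string)
#     decoded = tf.image.decode_png(data_ph, channels=3)
#     with tf.Session() as sess:
#         with tf.gfile.FastGFile('img.png', 'rb') as f:
#             img = sess.run(decoded, feed_dict={data_ph: f.read()})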
41 |         self._decode_data = tf.placeholder(dtype=tf.string)
42 |         self._decoded = tf.image.decode_png(self._decode_data, channels=3)
43 | 
44 |         self._encode_data = tf.placeholder(dtype=tf.uint8)
45 |         self._encoded = tf.image.encode_png(self._encode_data)
46 | 
47 |         self.label_to_text = label_to_text or [
48 |             'ignore',
49 |             'pedestrian',
50 |             'rider',
51 |             'sitting',
52 |             'unusual',
53 |             'group',
54 |         ]
55 | 
56 |     def get_shard_filename(self, shard, num_shards, split):
57 |         shard_name = '{}-{}-{:05d}-of-{:05d}'.format(self._dataset_name, split, shard, num_shards)
58 |         return os.path.join(self._out_dir, shard_name)
59 | 
60 |     def decode_png(self, img_data):
61 |         img = self._sess.run(self._decoded, feed_dict={self._decode_data: img_data})
62 |         assert len(img.shape) == 3
63 |         assert img.shape[2] == 3
64 |         return img
65 | 
66 |     def encode_png(self, img):
67 |         assert len(img.shape) == 3
68 |         assert img.shape[2] == 3
69 |         return self._sess.run(self._encoded, feed_dict={self._encode_data: img})
70 | 
71 |     def load_img(self, path):
72 |         ext = os.path.splitext(path)[1]
73 |         if path.endswith('.pgm'):
74 |             raise NotImplementedError('pgm not supported')
75 |         if path.endswith('.png'):
76 |             with tf.gfile.FastGFile(path, 'rb') as f:
77 |                 img_data = f.read()
78 |             # seems a bit wasteful to first decode and then re-encode the image, but it validates the data and keeps the pipeline simple
79 |             return self.decode_png(img_data), ext[1:]
80 |         else:
81 |             raise NotImplementedError('unknown file format: {}'.format(ext))
82 | 
83 |     def create_example(self, img_path, annotations):
84 |         img, format = self.load_img(img_path)
85 |         img_height, img_width = img.shape[:2]
86 |         assert img_height == 1024
87 |         assert img_width == 2048
88 |         encoded = self.encode_png(img)
89 | 
90 |         ymin, xmin, ymax, xmax, label, text, inst_id = [], [], [], [], [], [], []
91 | 
92 |         skipped_annotations = 0
93 |         box_cnt = 0
94 |         box_sizes = []
95 |         for anno in annotations:
96 |             anno = anno.astype(np.int64)  # this is important, otherwise overflows can occur
97 | 
98 |             class_label, x1, y1, w, h, instance_id, x1_vis, y1_vis, w_vis, h_vis = anno
99 | 
100 |             # we conform to the tf object detection API where 0 is reserved for the implicit background class
101 |             # this ensures that tfrecord files which work with the object detection API also work with this framework
102 |             if class_label == 2:
103 |                 # rider
104 |                 class_label = 2
105 |             elif class_label in [0, 5]:
106 |                 # skip: ignore and group
107 |                 skipped_annotations += 1
108 |                 continue
109 |             else:
110 |                 # pedestrian, sitting, unusual
111 |                 class_label = 1
112 | 
113 |             box_cnt += 1
114 | 
115 |             label_text = self.label_to_text[class_label]
116 |             ymin.append(float(y1) / img_height)
117 |             xmin.append(float(x1) / img_width)
118 |             ymax.append(float(y1 + h) / img_height)
119 |             xmax.append(float(x1 + w) / img_width)
120 |             label.append(class_label)
121 |             text.append(label_text.encode('utf8'))
122 |             inst_id.append(instance_id)
123 | 
124 |             if 'group' not in label_text and 'ignore' not in label_text:
125 |                 # do not add group or ignore boxes, we do not want these to affect the prior box calculation
126 |                 box_sizes.append((h, w))
127 | 
128 |         if skipped_annotations > 0:
129 |             logging.debug(
130 |                 'Skipped {}/{} annotations for img {}'.format(skipped_annotations, len(annotations), img_path))
131 | 
132 |         feature_dict = {
133 |             'image/height': int64_feature(img_height),
134 |             'image/width': int64_feature(img_width),
135 |             'image/filename': bytes_feature(img_path.encode('utf8')),
136 |             'image/source_id': bytes_feature(img_path.encode('utf8')),
137 |             'image/encoded':
                bytes_feature(encoded),
138 |             'image/format': bytes_feature('png'.encode('utf8')),
139 |             'image/object/bbox/xmin': float_list_feature(xmin),
140 |             'image/object/bbox/xmax': float_list_feature(xmax),
141 |             'image/object/bbox/ymin': float_list_feature(ymin),
142 |             'image/object/bbox/ymax': float_list_feature(ymax),
143 |             'image/object/class/text': bytes_list_feature(text),
144 |             'image/object/class/label': int64_list_feature(label),
145 |             'image/object/instance/id': int64_list_feature(inst_id),
146 |             'image/object/cnt': int64_feature(box_cnt),
147 |         }
148 | 
149 |         example = tf.train.Example(features=tf.train.Features(feature=feature_dict))
150 | 
151 |         return example, skipped_annotations, box_sizes, (img_height, img_width)
152 | 
153 | 
154 | def write_shard(args):
155 |     shard, num_shards, split, data, img_dir, example_creator = args
156 |     out_file = example_creator.get_shard_filename(shard, num_shards, split)
157 | 
158 |     writer = tf.python_io.TFRecordWriter(out_file)
159 |     logging.info('Creating shard {}-{}/{}'.format(split, shard, num_shards))
160 | 
161 |     skipped_annotations = 0
162 |     box_sizes = []
163 |     img_sizes = set()
164 |     cnt = 0
165 |     for cnt, datum in enumerate(data, start=1):
166 |         datum = datum[0][0]  # strange matlab file format
167 |         city = str(datum[0][0])
168 |         img_name = str(datum[1][0])
169 |         annotations = datum[2]
170 | 
171 |         img_path = os.path.join(img_dir, city, img_name)
172 | 
173 |         example, skipped, sizes, img_size = example_creator.create_example(img_path, annotations)
174 |         skipped_annotations += skipped
175 |         box_sizes.extend(sizes)
176 |         img_sizes.add(img_size)
177 | 
178 |         writer.write(example.SerializeToString())
179 |         if cnt % 10 == 0:
180 |             logging.info('Written {} examples for shard {}-{}/{}'.format(cnt, split, shard, num_shards))
181 | 
182 |     if skipped_annotations > 0:
183 |         logging.info('Skipped {} annotations for shard {}-{}/{}'.format(skipped_annotations, split, shard, num_shards))
184 |     writer.close()
185 |     logging.info(
186 |         'Finished shard {}-{}/{}: {} examples written and {} annotations skipped'.format(split, shard, num_shards, cnt,
187 |                                                                                          skipped_annotations))
188 |     return box_sizes, split, img_sizes
189 | 
190 | 
191 | def create_jobs(split, shuffle, annotations, img_dir, num_shards, example_creator):
192 |     if shuffle:
193 |         np.random.shuffle(annotations)
194 | 
195 |     # split into roughly even sized pieces
196 |     k, m = divmod(len(annotations), num_shards)
197 |     shards = [annotations[i * k + min(i, m):(i + 1) * k + min(i + 1, m)] for i in range(num_shards)]
198 | 
199 |     # check if we didn't f@#! it up
200 |     total_length = 0
201 |     for shard in shards:
202 |         total_length += shard.shape[0]
203 |     assert total_length == len(annotations)
204 | 
205 |     # create jobs (they are executed later by a thread pool, see process_dataset)
206 |     jobs = [(shard_id + 1, num_shards, split, data, img_dir, example_creator) for shard_id, data in enumerate(shards)]
207 |     return jobs
208 | 
209 | 
210 | def create_dirs(dirs):
211 |     for path in dirs:
212 |         try:
213 |             os.makedirs(path)
214 |         except OSError:
215 |             assert os.path.isdir(path), '{} exists but is not a directory'.format(path)
216 | 
217 | 
218 | def process_dataset(out_dir, dataset_name, anno_dir, img_dir, train_shards, val_shards, shuffle):
219 |     out_dir = os.path.expandvars(out_dir)
220 |     img_dir = os.path.expandvars(img_dir)
221 |     anno_dir = os.path.expandvars(anno_dir)
222 | 
223 |     create_dirs([out_dir])
224 | 
225 |     if shuffle:
226 |         with open(os.path.join(out_dir, '{}-np_random_state'.format(dataset_name)), 'wb') as f:
227 |             pickle.dump(np.random.get_state(), f)
228 | 
229 |     # prepare train and val splits
230 |     train_anno_path = os.path.join(anno_dir, 'annotations', 'anno_train.mat')
231 |     val_anno_path = os.path.join(anno_dir, 'annotations', 'anno_val.mat')
232 | 
233 |     train_img_dir_ = os.path.join(img_dir, 'leftImg8bit_trainvaltest', 'leftImg8bit', 'train')
234 |     val_img_dir = os.path.join(img_dir, 'leftImg8bit_trainvaltest', 'leftImg8bit', 'val')
235 | 
236 |     train_anno = scipy.io.loadmat(train_anno_path)['anno_train_aligned'][0]  # citypersons data format
237 |     val_anno = scipy.io.loadmat(val_anno_path)['anno_val_aligned'][0]  # citypersons data format
238 | 
239 |     # object which does all the hard work
240 |     example_creator = ExampleCreator(out_dir, dataset_name)
241 | 
242 |     # Process each split in a different thread
243 |     train_jobs = create_jobs('train', shuffle, train_anno, train_img_dir_, train_shards, example_creator)
244 |     val_jobs = create_jobs('val', shuffle, val_anno, val_img_dir, val_shards, example_creator)
245 | 
246 |     jobs = train_jobs + val_jobs
247 | 
248 |     with ThreadPoolExecutor() as executor:
249 |         result = executor.map(write_shard, jobs,
250 |                               chunksize=1)  # chunksize=1 is important, since our jobs are long running
251 | 
252 |     box_sizes = []
253 |     img_sizes = set()
254 |     for sizes, split, img_sizes_ in result:
255 |         img_sizes.update(img_sizes_)
256 |         if split == 'train':
257 |             box_sizes.extend(sizes)
258 | 
259 |     if len(img_sizes) > 1:
260 |         logging.error('Different image sizes detected: {}'.format(img_sizes))
261 | 
262 |     box_sizes = np.array(box_sizes, np.float64)
263 |     np.save(os.path.join(out_dir, '{}-train-box_sizes'.format(dataset_name)), box_sizes)
264 |     np.save(os.path.join(out_dir, '{}-img_size_height_width'.format(dataset_name)), list(img_sizes)[0])
265 | 
266 | 
267 | def main():
268 |     config = {
269 |         # Place to save the created files.
270 |         'out_dir': '$HOME/data/citypersons/tfrecords_test',
271 | 
272 |         # Name of the dataset, used to name the tfrecord files.
273 |         'dataset_name': 'citypersons',
274 | 
275 |         # Base directory which contains the citypersons annotations.
276 |         'anno_dir': '$HOME/data/citypersons',  # edit
277 | 
278 |         # Base directory which contains the cityscapes images.
279 |         'img_dir': '$HOME/data/cityscapes',
280 | 
281 |         # Number of training and validation shards.
282 |         'train_shards': 3,
283 |         'val_shards': 1,
284 | 
285 |         # Shuffle the data before writing it to tfrecord files.
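        # The shuffle is reproducible: process_dataset() above pickles the numpy
        # RNG state ('{dataset_name}-np_random_state') into out_dir before
        # shuffling. A hedged sketch of restoring that state to replay the exact
        # same shard assignment (file name as written by process_dataset):
        #
        #     with open(os.path.join(out_dir, 'citypersons-np_random_state'), 'rb') as f:
        #         np.random.set_state(pickle.load(f))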
286 | 'shuffle': True, 287 | } 288 | 289 | logging.info('Saving results to {}'.format(config['out_dir'])) 290 | logging.info('----- START -----') 291 | start = time.time() 292 | 293 | process_dataset(**config) 294 | 295 | end = time.time() 296 | elapsed = int(end - start) 297 | logging.info('----- FINISHED in {:02d}:{:02d}:{:02d} -----'.format(elapsed // 3600, 298 | (elapsed // 60) % 60, 299 | elapsed % 60)) 300 | 301 | 302 | if __name__ == '__main__': 303 | logging.basicConfig(level=logging.INFO, # edit change to DEBUG for more detailed output 304 | format='%(asctime)s, %(levelname)-8s %(message)s', 305 | datefmt='%a, %d %b %Y %H:%M:%S', 306 | ) 307 | 308 | main() 309 | -------------------------------------------------------------------------------- /detect.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import logging 3 | import os 4 | 5 | import cv2 6 | import matplotlib.pyplot as plt 7 | import numpy as np 8 | import tensorflow as tf 9 | 10 | import inference_aleatoric 11 | import inference_epistemic 12 | import inference_standard_yolov3 13 | from lib_yolo import yolov3 14 | 15 | 16 | def box_op_standard(model): 17 | bbox = inference_standard_yolov3.concat_bbox([det_layer.bbox for det_layer in model.det_layers]) 18 | nms = inference_standard_yolov3.nms(bbox, model) 19 | nms = nms[0, ...] 20 | return nms 21 | 22 | 23 | def box_op_aleatoric(model): 24 | bbox = inference_aleatoric.concat_bbox([det_layer.bbox for det_layer in model.det_layers]) 25 | nms = inference_aleatoric.nms(bbox, model) 26 | nms = nms[0, ...] 27 | return nms 28 | 29 | 30 | def box_op_bayes(model): 31 | bbox = inference_epistemic.concat_bbox([det_layer.bbox for det_layer in model.det_layers]) 32 | nms = inference_epistemic.nms(bbox, model) 33 | return nms 34 | 35 | 36 | def filter_boxes(boxes, obj_idx, thresh): 37 | return [box for box in boxes if box[obj_idx] > thresh] 38 | 39 | 40 | def preproces_boxes(img_size, boxes, obj_idx, cls_start_idx, cls_cnt, config, cls_mapping=None): 41 | out = [] 42 | for box in boxes: 43 | cls_idx = np.argmax(box[cls_start_idx:cls_start_idx + cls_cnt]) 44 | if config['implicit_background_class']: 45 | cls_idx += 1 46 | if cls_mapping: 47 | cls = cls_mapping[cls_idx] 48 | else: 49 | cls = cls_idx 50 | 51 | cls_score = box[cls_idx + cls_start_idx] 52 | out.append({ 53 | 'cls': cls, 54 | 'score': box[obj_idx] * cls_score, 55 | 'obj_score': box[obj_idx], 56 | 'cls_score': cls_score, 57 | 'y0': np.clip(box[0], 0, 1) * img_size[0], 58 | 'x0': np.clip(box[1], 0, 1) * img_size[1], 59 | 'y1': np.clip(box[2], 0, 1) * img_size[0], 60 | 'x1': np.clip(box[3], 0, 1) * img_size[1], 61 | }) 62 | 63 | return out 64 | 65 | 66 | def draw_boxes(img, boxes, color=(43, 219, 216), thickness=1): 67 | color = np.array(color) / 255. 
68 | for box in boxes: 69 | text = '{} {:4.3f}'.format(box['cls'], box['score']) 70 | size = 0.5 71 | cv2.putText(img, text, (int(box['x0']), int(box['y0'])), cv2.FONT_HERSHEY_SIMPLEX, size, color, thickness) 72 | 73 | cv2.rectangle(img, (int(box['x0']), int(box['y0'])), (int(box['x1']), int(box['y1'])), color, thickness) 74 | 75 | 76 | def load_img(config, img_size, filename): 77 | img = plt.imread(filename) # loads image as np.float32 array 78 | 79 | if config['crop']: 80 | y = (img.shape[0] - img_size[0]) // 2 81 | x = (img.shape[1] - img_size[1]) // 2 82 | img = img[y:y + img_size[0], x:x + img_size[1], :] 83 | 84 | img = np.expand_dims(img, axis=0) 85 | return img 86 | 87 | 88 | def load_model(sess, config, model_cls): 89 | if model_cls == yolov3.bayesian_yolov3_aleatoric: 90 | config['inference_mode'] = True 91 | 92 | yolo = model_cls(config) 93 | img_tensor = tf.placeholder(tf.float32, shape=(1, *yolo.img_size)) 94 | model = yolo.init_model(inputs=img_tensor, training=False).get_model() 95 | 96 | checkpoints = os.path.join(config['checkpoint_path'], config['run_id']) 97 | if config['step'] == 'last': 98 | checkpoint = tf.train.latest_checkpoint(checkpoints) 99 | else: 100 | checkpoint = None 101 | for cp in os.listdir(checkpoints): 102 | if cp.endswith('-{}.meta'.format(config['step'])): 103 | checkpoint = os.path.join(checkpoints, os.path.splitext(cp)[0]) 104 | break 105 | assert checkpoint is not None, 'could not find checkpoint' 106 | 107 | tf.train.Saver().restore(sess, checkpoint) 108 | 109 | return model, img_tensor 110 | 111 | 112 | def do_it(files, thresh, config, model_cls, cls_mapping): 113 | box_op = { 114 | yolov3.yolov3: box_op_standard, 115 | yolov3.yolov3_aleatoric: box_op_aleatoric, 116 | yolov3.bayesian_yolov3_aleatoric: box_op_bayes, 117 | }[model_cls] 118 | 119 | with tf.Session(config=tf.ConfigProto(device_count={'GPU': 1})) as sess: 120 | model, img_tensor = load_model(sess, config, model_cls) 121 | img_size = img_tensor.shape.as_list()[1:] 122 | for file in files: 123 | img = load_img(config, img_size, file) 124 | boxes, = sess.run([box_op(model)], feed_dict={img_tensor: img}) 125 | boxes = filter_boxes(boxes, model.obj_idx, thresh) 126 | boxes = preproces_boxes(img_size, boxes, model.obj_idx, model.cls_start_idx, model.cls_cnt, 127 | config, cls_mapping=cls_mapping) 128 | 129 | img = img[0, ...] 130 | draw_boxes(img, boxes) 131 | logging.info('{}: {}'.format(os.path.basename(file), boxes)) 132 | 133 | plt.imshow(img) 134 | plt.show() 135 | # plt.imsave(filename, img) # 136 | 137 | 138 | def main(): 139 | config = { 140 | 'checkpoint_path': './checkpoints/', 141 | 'run_id': 'epi_ale', # edit 142 | 'step': 'last', # edit: int or 'last' 143 | 'crop_img_size': [768, 1440, 3], 144 | 'full_img_size': [1024, 1920, 3], # edit if not ecp 145 | 'cls_cnt': 2, # edit if not ecp 146 | 'T': 35, # edit if OOM error, only relevant for bayesian model 147 | 'cpu_thread_cnt': 10, 148 | 'freeze_darknet53': False, # actual value irrelevant 149 | 'crop': False, # edit: less memory consumption if True 150 | 'training': False, 151 | 'aleatoric_loss': True, # actual value irrelevant 152 | 'priors': yolov3.ECP_9_PRIORS, # actual value irrelevant 153 | 'out_path': './uncertainty_visualization', # edit 154 | 'implicit_background_class': True, # whether the label ids start at 1 or 0. 
# True = 1, False = 0
155 |     }
156 | 
157 |     class_name_mapping_implicit_background_cls = {  # edit: change if you have more or different classes, or set to None
158 |         1: 'ped',
159 |         2: 'rider',
160 |     }
161 | 
162 |     # class_name_mapping_no_implicit_background_cls = {  # if your labels start at 0 instead of 1
163 |     #     0: 'ped',
164 |     #     1: 'rider',
165 |     # }
166 | 
167 |     thresh = 0.1  # edit
168 | 
169 |     files = glob.glob('./test_images/*')  # edit
170 | 
171 |     # EDIT: choose the appropriate model class
172 |     # model_cls = yolov3.yolov3
173 |     # model_cls = yolov3.yolov3_aleatoric
174 |     model_cls = yolov3.bayesian_yolov3_aleatoric
175 | 
176 |     do_it(files, thresh, config, model_cls, class_name_mapping_implicit_background_cls)
177 | 
178 | 
179 | if __name__ == '__main__':
180 |     logging.basicConfig(level=logging.DEBUG,
181 |                         format='%(asctime)s, pid: %(process)d, %(levelname)-8s %(message)s',
182 |                         datefmt='%a, %d %b %Y %H:%M:%S',
183 |                         )
184 |     main()
185 | 
--------------------------------------------------------------------------------
/inference_aleatoric.py:
--------------------------------------------------------------------------------
1 | """
2 | Inference script for the yolov3.yolov3_aleatoric class.
3 | 
4 | Produces detection files for each input image conforming to the ECP .json format.
5 | The output of this script can be directly used by the ECP evaluation code.
6 | """
7 | 
8 | import json
9 | import logging
10 | import os
11 | import threading
12 | import time
13 | 
14 | import numpy as np
15 | import tensorflow as tf
16 | 
17 | from lib_yolo import dataset_utils, yolov3
18 | 
19 | 
20 | class Inference:
21 |     def __init__(self, yolo, config):
22 |         self.batch_size = config['batch_size']
23 | 
24 |         dataset = dataset_utils.TestingDataset(config)
25 |         self.img_tensor, self.filename_tensor = dataset.iterator.get_next()
26 | 
27 |         checkpoints = os.path.join(config['checkpoint_path'], config['run_id'])
28 |         if config['step'] == 'last':
29 |             self.checkpoint = tf.train.latest_checkpoint(checkpoints)
30 |         else:
31 |             self.checkpoint = None
32 |             for cp in os.listdir(checkpoints):
33 |                 if cp.endswith('-{}.meta'.format(config['step'])):
34 |                     self.checkpoint = os.path.join(checkpoints, os.path.splitext(cp)[0])
35 |                     break
36 |             assert self.checkpoint is not None
37 | 
38 |         step = self.checkpoint.split('-')[-1]
39 | 
40 |         self.img_size = config['full_img_size']
41 |         assert not config['crop']
42 |         self.out_path = '{}_{}'.format(config['out_path'], step)
43 |         os.makedirs(self.out_path)
44 | 
45 |         self.config = config
46 |         self.worker_thread = None
47 | 
48 |         self.model = yolo.init_model(inputs=self.img_tensor, training=False).get_model()
49 | 
50 |         bbox = concat_bbox([self.model.det_layers[0].bbox,
51 |                             self.model.det_layers[1].bbox,
52 |                             self.model.det_layers[2].bbox])
53 |         self.nms = nms(bbox, self.model)
54 | 
55 |     def run(self):
56 |         with tf.Session(config=tf.ConfigProto(device_count={'GPU': 1})) as sess:
57 |             tf.train.Saver().restore(sess, self.checkpoint)
58 | 
59 |             step = 0
60 |             while True:
61 |                 try:
62 |                     step += 1
63 |                     processed = self.process_batch(sess)
64 | 
65 |                     logging.info('Processed {} images.'.format((step - 1) * self.batch_size + processed))
66 | 
67 |                 except tf.errors.OutOfRangeError:
68 |                     break
69 | 
70 |         if self.worker_thread:
71 |             self.worker_thread.join()
72 |         return self
73 | 
74 |     def process_batch(self, sess):
75 |         boxes, files = sess.run([self.nms, self.filename_tensor])
76 | 
77 |         if self.worker_thread:
78 |             self.worker_thread.join()
79 | 
80 |         self.worker_thread = threading.Thread(target=self.write_to_disc,
args=(boxes, files)) 81 | self.worker_thread.start() 82 | return len(files) 83 | 84 | def write_to_disc(self, all_boxes, files): 85 | for batch, filename in enumerate(files): 86 | filename = filename[0].decode('utf-8') 87 | boxes = all_boxes[batch] 88 | self.write_ecp_json(boxes, filename) 89 | 90 | def write_ecp_json(self, boxes, img_name): 91 | out_name = '{}.json'.format(os.path.splitext(os.path.basename(img_name))[0]) 92 | out_file = os.path.join(self.out_path, out_name) 93 | 94 | with open(out_file, 'w') as f: 95 | json.dump({ 96 | 'children': [bbox_to_ecp_format(bbox, self.img_size, self.model, self.config) for bbox in boxes], 97 | }, f, default=lambda x: x.tolist()) 98 | 99 | 100 | # -----------------------------------------------------------------# 101 | # helpers # 102 | # -----------------------------------------------------------------# 103 | 104 | def nms(all_boxes, model): 105 | def nms_op(boxes): 106 | # nms ignoring classes 107 | nms_indices = tf.image.non_max_suppression(boxes[:, :4], boxes[:, model.obj_idx], 1000) 108 | all_boxes = tf.gather(boxes, nms_indices, axis=0) 109 | all_boxes = tf.expand_dims(all_boxes, axis=0) 110 | 111 | # # nms for each class individually, works only for data with 2 classes (e.g. ECP dataset) 112 | # # this was used to produce the results for the paper 113 | # nms_boxes = None 114 | # for cls in ['ped', 'rider']: 115 | # if cls == 'ped': 116 | # tmp = tf.greater(b[:, model.cls_start_idx], b[:, model.cls_start_idx + 1]) 117 | # elif cls == 'rider': 118 | # tmp = tf.greater(b[:, model.cls_start_idx + 1], b[:, model.cls_start_idx]) 119 | # else: 120 | # raise ValueError('invalid class: {}'.format(cls)) 121 | # 122 | # cls_indices = tf.cast(tf.reshape(tf.where(tmp), [-1]), tf.int32) 123 | # 124 | # cls_boxes = tf.gather(b, cls_indices) 125 | # ind = tf.image.non_max_suppression(cls_boxes[:, :4], cls_boxes[:, model.obj_idx], 1000) 126 | # cls_boxes = tf.gather(cls_boxes, ind, axis=0) 127 | # 128 | # if nms_boxes is None: 129 | # nms_boxes = cls_boxes 130 | # else: 131 | # nms_boxes = tf.concat([nms_boxes, cls_boxes], axis=0) 132 | # 133 | # return nms_boxes 134 | 135 | return all_boxes 136 | 137 | body = lambda i, r: [i + 1, tf.concat([r, nms_op(all_boxes[i, ...])], axis=0)] 138 | 139 | r0 = nms_op(all_boxes[0, ...]) # do while 140 | i0 = tf.constant(1) # start with 1!!! 141 | cond = lambda i, m: i < tf.shape(all_boxes)[0] 142 | ilast, result = tf.while_loop(cond, body, loop_vars=[i0, r0], 143 | shape_invariants=[i0.get_shape(), tf.TensorShape([None, None, all_boxes.shape[2]])]) 144 | 145 | return result 146 | 147 | 148 | def bbox_to_ecp_format(bbox, img_size, model, config): 149 | img_height, img_width = img_size[:2] 150 | label_to_cls_name = { # edit if not ECP dataset 151 | 1: 'pedestrian', # starts at 0 if no implicit background class 152 | 2: 'rider', 153 | } 154 | cls_scores = bbox[model.cls_start_idx:model.cls_start_idx + model.cls_cnt] 155 | cls = np.argmax(cls_scores) 156 | 157 | cls_idx = cls 158 | if config['implicit_background_class']: 159 | cls += 1 160 | 161 | return { 162 | 'y0': float(bbox[0] * img_height), 163 | 'x0': float(bbox[1] * img_width), 164 | 'y1': float(bbox[2] * img_height), 165 | 'x1': float(bbox[3] * img_width), 166 | 'x_var': float(bbox[4]), # random value for models trained without aleatoric loss. 167 | 'y_var': float(bbox[5]), # random value for models trained without aleatoric loss. 168 | 'w_var': float(bbox[6]), # random value for models trained without aleatoric loss. 
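        # Layout of the aleatoric detection vector indexed in this dict
        # (inferred from the surrounding code, so treat this summary as an
        # assumption rather than a spec):
        #   bbox[0:4]           -> y0, x0, y1, x1 (normalized corners)
        #   bbox[4:8]           -> aleatoric x/y/w/h variances, bbox[8] -> total variance
        #   bbox[model.obj_idx] -> objectness score, +1 -> objectness entropy
        #   bbox[model.cls_start_idx : model.cls_start_idx + model.cls_cnt]
        #                       -> class scores, followed by class entropy,
        #                          layer id and prior id.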
169 |         'h_var': float(bbox[7]),  # random value for models trained without aleatoric loss.
170 |         'total_var': float(bbox[8]),  # random value for models trained without aleatoric loss.
171 |         'score': float(bbox[model.obj_idx]) * float(bbox[model.cls_start_idx + cls_idx]),
172 |         'obj_entropy': float(bbox[model.obj_idx + 1]),
173 |         'cls_scores': cls_scores,
174 |         'cls_entropy': float(bbox[model.cls_start_idx + model.cls_cnt]),
175 |         'layer_id': float(bbox[model.cls_start_idx + model.cls_cnt + 1]),
176 |         'prior_id': float(bbox[model.cls_start_idx + model.cls_cnt + 2]),
177 |         'identity': label_to_cls_name.get(cls, cls),
178 |     }
179 | 
180 | 
181 | def concat_bbox(net_out):
182 |     bbox = None
183 |     for det_layer in net_out:
184 |         for prior in det_layer:
185 |             batches, lw, lh, det_size = prior.shape.as_list()
186 |             tmp = tf.reshape(prior, shape=[-1, lw * lh, det_size])
187 |             if bbox is None:
188 |                 bbox = tmp
189 |             else:
190 |                 bbox = tf.concat([bbox, tmp], axis=1)
191 | 
192 |     return bbox
193 | 
194 | 
195 | # -----------------------------------------------------------------#
196 | #                               main                                #
197 | # -----------------------------------------------------------------#
198 | 
199 | 
200 | def inference(config):
201 |     assert not config['crop']
202 |     logging.info(json.dumps(config, indent=4, default=lambda x: str(x)))
203 | 
204 |     logging.info('----- START -----')
205 | 
206 |     start = time.time()
207 | 
208 |     yolo = yolov3.yolov3_aleatoric(config)
209 | 
210 |     Inference(yolo, config).run()
211 | 
212 |     end = time.time()
213 |     elapsed = int(end - start)
214 |     logging.info('----- FINISHED in {:02d}:{:02d}:{:02d} -----'.format(elapsed // 3600,
215 |                                                                        (elapsed // 60) % 60,
216 |                                                                        elapsed % 60))
217 | 
218 | 
219 | def main():
220 |     config = {
221 |         'checkpoint_path': './checkpoints',  # edit
222 |         'run_id': 'pretraining',  # edit
223 |         # 'step': 500000,  # edit
224 |         'step': 'last',
225 |         'full_img_size': [1024, 1920, 3],
226 |         'cls_cnt': 2,  # edit
227 |         'batch_size': 11,  # edit
228 |         'cpu_thread_cnt': 24,  # edit
229 |         'crop': False,
230 |         'training': False,
231 |         'aleatoric_loss': True,
232 |         'priors': yolov3.ECP_9_PRIORS,  # edit
233 |         'implicit_background_class': True,
234 |         'data': {
235 |             'path': '$HOME/data/ecp/tfrecords',  # edit
236 |             'file_pattern': 'ecp-day-val-*-of-*',  # edit
237 |         }
238 |     }
239 | 
240 |     config['data']['file_pattern'] = os.path.join(os.path.expandvars(config['data']['path']),
241 |                                                   config['data']['file_pattern'])
242 | 
243 |     config['out_path'] = os.path.join('./inference', config['run_id'])  # edit
244 | 
245 |     inference(config)
246 | 
247 | 
248 | if __name__ == '__main__':
249 |     np.set_printoptions(suppress=True, formatter={'float_kind': '{:5.3}'.format})
250 |     logging.basicConfig(level=logging.DEBUG,
251 |                         format='%(asctime)s, pid: %(process)d, %(levelname)-8s %(message)s',
252 |                         datefmt='%a, %d %b %Y %H:%M:%S',
253 |                         )
254 | 
255 |     main()
256 | 
--------------------------------------------------------------------------------
/inference_epistemic.py:
--------------------------------------------------------------------------------
1 | """
2 | Inference script for the yolov3.bayesian_yolov3_aleatoric class.
3 | 
4 | Produces detection files for each input image conforming to the ECP .json format.
5 | The output of this script can be directly used by the ECP evaluation code.
6 | """ 7 | 8 | import json 9 | import logging 10 | import os 11 | import threading 12 | import time 13 | 14 | import numpy as np 15 | import tensorflow as tf 16 | 17 | from lib_yolo import dataset_utils, yolov3 18 | 19 | 20 | class Inference: 21 | def __init__(self, yolo, config): 22 | self.batch_size = config['batch_size'] 23 | 24 | dataset = dataset_utils.TestingDataset(config) 25 | self.img_tensor, self.filename_tensor = dataset.iterator.get_next() 26 | 27 | checkpoints = os.path.join(config['checkpoint_path'], config['run_id']) 28 | if config['step'] == 'last': 29 | self.checkpoint = tf.train.latest_checkpoint(checkpoints) 30 | else: 31 | self.checkpoint = None 32 | for cp in os.listdir(checkpoints): 33 | if cp.endswith('-{}.meta'.format(config['step'])): 34 | self.checkpoint = os.path.join(checkpoints, os.path.splitext(cp)[0]) 35 | break 36 | assert self.checkpoint is not None 37 | 38 | step = self.checkpoint.split('-')[-1] 39 | 40 | self.img_size = config['full_img_size'] 41 | assert not config['crop'] 42 | self.out_path = '{}_{}'.format(config['out_path'], step) 43 | os.makedirs(self.out_path) 44 | 45 | self.config = config 46 | self.worker_thread = None 47 | 48 | assert config['inference_mode'] 49 | self.model = yolo.init_model(inputs=self.img_tensor, training=False).get_model() 50 | 51 | bbox = concat_bbox([self.model.det_layers[0].bbox, 52 | self.model.det_layers[1].bbox, 53 | self.model.det_layers[2].bbox]) 54 | self.nms = nms(bbox, self.model) 55 | 56 | def run(self): 57 | with tf.Session(config=tf.ConfigProto(device_count={'GPU': 1})) as sess: 58 | tf.train.Saver().restore(sess, self.checkpoint) 59 | 60 | step = 0 61 | while True: 62 | try: 63 | step += 1 64 | self.epistemic_forward_pass(sess) 65 | 66 | if step % 15 == 0: 67 | logging.info('Processed {} images.'.format(step)) 68 | 69 | except tf.errors.OutOfRangeError: 70 | break 71 | logging.info('Processed {} images.'.format(step)) 72 | 73 | self.worker_thread.join() 74 | 75 | def epistemic_forward_pass(self, sess): 76 | boxes, files = sess.run([self.nms, self.filename_tensor]) 77 | 78 | if self.worker_thread: 79 | self.worker_thread.join() 80 | 81 | img_name = files[0][0].decode('utf-8') 82 | self.worker_thread = threading.Thread(target=self.write_ecp_json, args=(boxes, img_name)) 83 | self.worker_thread.start() 84 | 85 | def write_ecp_json(self, boxes, img_name): 86 | out_name = '{}.json'.format(os.path.splitext(os.path.basename(img_name))[0]) 87 | out_file = os.path.join(self.out_path, out_name) 88 | 89 | with open(out_file, 'w') as f: 90 | json.dump({ 91 | 'children': [bbox_to_ecp_format(bbox, self.img_size, self.model, self.config) for bbox in boxes], 92 | }, f, default=lambda x: x.tolist()) 93 | 94 | 95 | # -----------------------------------------------------------------# 96 | # helpers # 97 | # -----------------------------------------------------------------# 98 | 99 | def nms(boxes, model): 100 | # nms ignoring classes 101 | nms_indices = tf.image.non_max_suppression(boxes[:, :4], boxes[:, model.obj_idx], 1000) 102 | all_boxes = tf.gather(boxes, nms_indices, axis=0) 103 | 104 | # # nms for each class individually, works only for data with 2 classes (e.g. 
ECP dataset) 105 | # # this was used to produce the results for the paper 106 | # nms_boxes = None 107 | # for cls in ['ped', 'rider']: 108 | # if cls == 'ped': 109 | # tmp = tf.greater(b[:, model.cls_start_idx], b[:, model.cls_start_idx + 1]) 110 | # elif cls == 'rider': 111 | # tmp = tf.greater(b[:, model.cls_start_idx + 1], b[:, model.cls_start_idx]) 112 | # else: 113 | # raise ValueError('invalid class: {}'.format(cls)) 114 | # 115 | # cls_indices = tf.cast(tf.reshape(tf.where(tmp), [-1]), tf.int32) 116 | # 117 | # cls_boxes = tf.gather(b, cls_indices) 118 | # ind = tf.image.non_max_suppression(cls_boxes[:, :4], cls_boxes[:, model.obj_idx], 1000) 119 | # cls_boxes = tf.gather(cls_boxes, ind, axis=0) 120 | # 121 | # if nms_boxes is None: 122 | # nms_boxes = cls_boxes 123 | # else: 124 | # nms_boxes = tf.concat([nms_boxes, cls_boxes], axis=0) 125 | # 126 | # return nms_boxes 127 | 128 | return all_boxes 129 | 130 | 131 | def bbox_to_ecp_format(bbox, img_size, model, config): 132 | img_height, img_width = img_size[:2] 133 | label_to_cls_name = { # edit if not ECP dataset 134 | 1: 'pedestrian', # starts at 0 if no implicit background class 135 | 2: 'rider', 136 | } 137 | cls_scores = bbox[model.cls_start_idx:model.cls_start_idx + model.cls_cnt] 138 | cls = np.argmax(cls_scores) 139 | 140 | cls_idx = cls 141 | if config['implicit_background_class']: 142 | cls += 1 143 | 144 | return { 145 | 'y0': float(bbox[0] * img_height), 146 | 'x0': float(bbox[1] * img_width), 147 | 'y1': float(bbox[2] * img_height), 148 | 'x1': float(bbox[3] * img_width), 149 | 'x_var_epi': float(bbox[4]), 150 | 'y_var_epi': float(bbox[5]), 151 | 'w_var_epi': float(bbox[6]), 152 | 'h_var_epi': float(bbox[7]), 153 | 'x_var_ale': float(bbox[8]), # random value for models trained without aleatoric loss. 154 | 'y_var_ale': float(bbox[9]), # random value for models trained without aleatoric loss. 155 | 'w_var_ale': float(bbox[10]), # random value for models trained without aleatoric loss. 156 | 'h_var_ale': float(bbox[11]), # random value for models trained without aleatoric loss. 157 | 'total_var_epi': float(bbox[12]), # not useful 158 | 'total_var_ale': float(bbox[13]), # not useful and random value for models trained without aleatoric loss. 
159 |         'score': float(bbox[model.obj_idx]) * float(bbox[model.cls_start_idx + cls_idx]),
160 |         'obj_mutual_info': float(bbox[model.obj_idx + 1]),
161 |         'obj_entropy': float(bbox[model.obj_idx + 2]),
162 |         'cls_scores': cls_scores,
163 |         'ped_score': float(bbox[17]),
164 |         'rider_score': float(bbox[18]),
165 |         'cls_mutual_info': float(bbox[model.cls_start_idx + model.cls_cnt]),
166 |         'cls_entropy': float(bbox[model.cls_start_idx + model.cls_cnt + 1]),
167 |         'layer_id': float(bbox[model.cls_start_idx + model.cls_cnt + 2]),
168 |         'prior_id': float(bbox[model.cls_start_idx + model.cls_cnt + 3]),
169 |         'identity': label_to_cls_name.get(cls, cls),
170 |     }
171 | 
172 | 
173 | def concat_bbox(net_out):
174 |     bbox = None
175 |     for det_layer in net_out:
176 |         for prior in det_layer:
177 |             lw, lh, det_size = prior.shape.as_list()
178 |             tmp = tf.reshape(prior, shape=[lw * lh, det_size])
179 |             if bbox is None:
180 |                 bbox = tmp
181 |             else:
182 |                 bbox = tf.concat([bbox, tmp], axis=0)
183 | 
184 |     return bbox
185 | 
186 | 
187 | # -----------------------------------------------------------------#
188 | #                               main                                #
189 | # -----------------------------------------------------------------#
190 | 
191 | 
192 | def inference(config):
193 |     assert config['batch_size'] == 1
194 |     assert not config['crop']
195 | 
196 |     logging.info(json.dumps(config, indent=4, default=lambda x: str(x)))
197 | 
198 | 
199 |     logging.info('----- START -----')
200 |     start = time.time()
201 | 
202 |     yolo = yolov3.bayesian_yolov3_aleatoric(config)
203 |     Inference(yolo, config).run()
204 | 
205 |     end = time.time()
206 |     elapsed = int(end - start)
207 |     logging.info('----- FINISHED in {:02d}:{:02d}:{:02d} -----'.format(elapsed // 3600,
208 |                                                                        (elapsed // 60) % 60,
209 |                                                                        elapsed % 60))
210 | 
211 | 
212 | def main():
213 |     config = {
214 |         'checkpoint_path': './checkpoints',  # edit
215 |         'run_id': 'epi_ale',  # edit
216 |         # 'step': 500000,  # edit
217 |         'step': 'last',  # edit
218 |         'full_img_size': [1024, 1920, 3],  # edit if not ECP dataset
219 |         'cls_cnt': 2,  # edit if not ECP dataset
220 |         'batch_size': 1,
221 |         'T': 50,  # edit if OOM errors
222 |         'inference_mode': True,
223 |         'cpu_thread_cnt': 24,  # edit
224 |         'crop': False,
225 |         'training': False,
226 |         'aleatoric_loss': False,
227 |         'priors': yolov3.ECP_9_PRIORS,  # edit
228 |         'implicit_background_class': True,
229 |         'data': {
230 |             'path': '$HOME/data/ecp/tfrecords',  # edit
231 |             'file_pattern': 'ecp-day-val-*-of-*',  # edit
232 |         }
233 |     }
234 | 
235 |     config['data']['file_pattern'] = os.path.join(os.path.expandvars(config['data']['path']),
236 |                                                   config['data']['file_pattern'])
237 | 
238 |     config['out_path'] = os.path.join('./inference', config['run_id'])  # edit
239 | 
240 |     inference(config)
241 | 
242 | 
243 | if __name__ == '__main__':
244 |     np.set_printoptions(suppress=True, formatter={'float_kind': '{:5.3}'.format})
245 |     logging.basicConfig(level=logging.DEBUG,
246 |                         format='%(asctime)s, pid: %(process)d, %(levelname)-8s %(message)s',
247 |                         datefmt='%a, %d %b %Y %H:%M:%S',
248 |                         )
249 |     main()
250 | 
--------------------------------------------------------------------------------
/inference_standard_yolov3.py:
--------------------------------------------------------------------------------
1 | """
2 | Inference script for the yolov3.yolov3 class (standard YOLOv3 without uncertainty estimation).
3 | 
4 | Produces detection files for each input image conforming to the ECP .json format.
5 | The output of this script can be directly used by the ECP evaluation code.
6 | """ 7 | 8 | import json 9 | import logging 10 | import os 11 | import threading 12 | import time 13 | 14 | import numpy as np 15 | import tensorflow as tf 16 | 17 | from lib_yolo import dataset_utils, yolov3 18 | 19 | 20 | class Inference: 21 | def __init__(self, yolo, config): 22 | self.batch_size = config['batch_size'] 23 | 24 | dataset = dataset_utils.TestingDataset(config) 25 | self.img_tensor, self.filename_tensor = dataset.iterator.get_next() 26 | 27 | checkpoints = os.path.join(config['checkpoint_path'], config['run_id']) 28 | if config['step'] == 'last': 29 | self.checkpoint = tf.train.latest_checkpoint(checkpoints) 30 | else: 31 | self.checkpoint = None 32 | for cp in os.listdir(checkpoints): 33 | if cp.endswith('-{}.meta'.format(config['step'])): 34 | self.checkpoint = os.path.join(checkpoints, os.path.splitext(cp)[0]) 35 | break 36 | assert self.checkpoint is not None 37 | 38 | step = self.checkpoint.split('-')[-1] 39 | 40 | self.img_size = config['full_img_size'] 41 | assert not config['crop'] 42 | self.out_path = '{}_{}'.format(config['out_path'], step) 43 | os.makedirs(self.out_path) 44 | 45 | self.config = config 46 | self.worker_thread = None 47 | 48 | self.model = yolo.init_model(inputs=self.img_tensor, training=False).get_model() 49 | 50 | bbox = concat_bbox([self.model.det_layers[0].bbox, 51 | self.model.det_layers[1].bbox, 52 | self.model.det_layers[2].bbox]) 53 | self.nms = nms(bbox, self.model) 54 | 55 | def run(self): 56 | with tf.Session(config=tf.ConfigProto(device_count={'GPU': 1})) as sess: 57 | tf.train.Saver().restore(sess, self.checkpoint) 58 | 59 | step = 0 60 | while True: 61 | try: 62 | step += 1 63 | processed = self.process_batch(sess) 64 | 65 | logging.info('Processed {} images.'.format((step - 1) * self.batch_size + processed)) 66 | 67 | except tf.errors.OutOfRangeError: 68 | break 69 | 70 | if self.worker_thread: 71 | self.worker_thread.join() 72 | return self 73 | 74 | def process_batch(self, sess): 75 | boxes, files = sess.run([self.nms, self.filename_tensor]) 76 | 77 | if self.worker_thread: 78 | self.worker_thread.join() 79 | 80 | self.worker_thread = threading.Thread(target=self.write_to_disc, args=(boxes, files)) 81 | self.worker_thread.start() 82 | return len(files) 83 | 84 | def write_to_disc(self, all_boxes, files): 85 | for batch, filename in enumerate(files): 86 | filename = filename[0].decode('utf-8') 87 | boxes = all_boxes[batch] 88 | self.write_ecp_json(boxes, filename) 89 | 90 | def write_ecp_json(self, boxes, img_name): 91 | out_name = '{}.json'.format(os.path.splitext(os.path.basename(img_name))[0]) 92 | out_file = os.path.join(self.out_path, out_name) 93 | 94 | with open(out_file, 'w') as f: 95 | json.dump({ 96 | 'children': [bbox_to_ecp_format(bbox, self.img_size, self.model, self.config) for bbox in boxes], 97 | }, f, default=lambda x: x.tolist()) 98 | 99 | 100 | # -----------------------------------------------------------------# 101 | # helpers # 102 | # -----------------------------------------------------------------# 103 | 104 | def nms(all_boxes, model): 105 | def nms_op(boxes): 106 | # nms ignoring classes 107 | nms_indices = tf.image.non_max_suppression(boxes[:, :4], boxes[:, model.obj_idx], 1000) 108 | all_boxes = tf.gather(boxes, nms_indices, axis=0) 109 | all_boxes = tf.expand_dims(all_boxes, axis=0) 110 | 111 | # # nms for each class individually, works only for data with 2 classes (e.g. 
ECP dataset) 112 | # # this was used to produce the results for the paper 113 | # nms_boxes = None 114 | # for cls in ['ped', 'rider']: 115 | # if cls == 'ped': 116 | # tmp = tf.greater(b[:, model.cls_start_idx], b[:, model.cls_start_idx + 1]) 117 | # elif cls == 'rider': 118 | # tmp = tf.greater(b[:, model.cls_start_idx + 1], b[:, model.cls_start_idx]) 119 | # else: 120 | # raise ValueError('invalid class: {}'.format(cls)) 121 | # 122 | # cls_indices = tf.cast(tf.reshape(tf.where(tmp), [-1]), tf.int32) 123 | # 124 | # cls_boxes = tf.gather(b, cls_indices) 125 | # ind = tf.image.non_max_suppression(cls_boxes[:, :4], cls_boxes[:, model.obj_idx], 1000) 126 | # cls_boxes = tf.gather(cls_boxes, ind, axis=0) 127 | # 128 | # if nms_boxes is None: 129 | # nms_boxes = cls_boxes 130 | # else: 131 | # nms_boxes = tf.concat([nms_boxes, cls_boxes], axis=0) 132 | # 133 | # return nms_boxes 134 | 135 | return all_boxes 136 | 137 | body = lambda i, r: [i + 1, tf.concat([r, nms_op(all_boxes[i, ...])], axis=0)] 138 | 139 | r0 = nms_op(all_boxes[0, ...]) # do while 140 | i0 = tf.constant(1) # start with 1!!! 141 | cond = lambda i, m: i < tf.shape(all_boxes)[0] 142 | ilast, result = tf.while_loop(cond, body, loop_vars=[i0, r0], 143 | shape_invariants=[i0.get_shape(), tf.TensorShape([None, None, all_boxes.shape[2]])]) 144 | 145 | return result 146 | 147 | 148 | def bbox_to_ecp_format(bbox, img_size, model, config): 149 | img_height, img_width = img_size[:2] 150 | label_to_cls_name = { # edit if not ECP dataset 151 | 1: 'pedestrian', # starts at 0 if no implicit background class 152 | 2: 'rider', 153 | } 154 | cls_scores = bbox[model.cls_start_idx:model.cls_start_idx + model.cls_cnt] 155 | cls = np.argmax(cls_scores) 156 | 157 | cls_idx = cls 158 | if config['implicit_background_class']: 159 | cls += 1 160 | 161 | return { 162 | 'y0': float(bbox[0] * img_height), 163 | 'x0': float(bbox[1] * img_width), 164 | 'y1': float(bbox[2] * img_height), 165 | 'x1': float(bbox[3] * img_width), 166 | 'score': float(bbox[model.obj_idx]) * float(bbox[model.cls_start_idx + cls_idx]), 167 | 'cls_scores': cls_scores, 168 | 'identity': label_to_cls_name.get(cls, cls), 169 | } 170 | 171 | 172 | def concat_bbox(net_out): 173 | bbox = None 174 | for det_layer in net_out: 175 | for prior in det_layer: 176 | batches, lw, lh, det_size = prior.shape.as_list() 177 | tmp = tf.reshape(prior, shape=[-1, lw * lh, det_size]) 178 | if bbox is None: 179 | bbox = tmp 180 | else: 181 | bbox = tf.concat([bbox, tmp], axis=1) 182 | 183 | return bbox 184 | 185 | 186 | # -----------------------------------------------------------------# 187 | # main # 188 | # -----------------------------------------------------------------# 189 | 190 | 191 | def inference(config): 192 | assert not config['crop'] 193 | logging.info(json.dumps(config, indent=4, default=lambda x: str(x))) 194 | 195 | logging.info('----- START -----') 196 | 197 | start = time.time() 198 | 199 | yolo = yolov3.yolov3(config) 200 | 201 | Inference(yolo, config).run() 202 | 203 | end = time.time() 204 | elapsed = int(end - start) 205 | logging.info('----- FINISHED in {:02d}:{:02d}:{:02d} -----'.format(elapsed // 3600, 206 | (elapsed // 60) % 60, 207 | elapsed % 60)) 208 | 209 | 210 | def main(): 211 | config = { 212 | 'checkpoint_path': './checkpoints', # edit 213 | 'run_id': 'yolo', # edit 214 | # 'step': 500000, # edit 215 | 'step': 'last', 216 | 'full_img_size': [1024, 1920, 3], 217 | 'cls_cnt': 2, # edit 218 | 'batch_size': 11, # edit 219 | 'cpu_thread_cnt': 24, # edit 220 | 'crop': 
False, 221 | 'training': False, 222 | 'aleatoric_loss': True, 223 | 'priors': yolov3.ECP_9_PRIORS, # edit 224 | 'implicit_background_class': True, 225 | 'data': { 226 | 'path': '$HOME/data/ecp/tfrecords', # edit 227 | 'file_pattern': 'ecp-day-val-*-of-*', # edit 228 | } 229 | } 230 | 231 | config['data']['file_pattern'] = os.path.join(os.path.expandvars(config['data']['path']), 232 | config['data']['file_pattern']) 233 | 234 | config['out_path'] = os.path.join('./inference', config['run_id']) # edit 235 | 236 | inference(config) 237 | 238 | 239 | if __name__ == '__main__': 240 | np.set_printoptions(suppress=True, formatter={'float_kind': '{:5.3}'.format}) 241 | logging.basicConfig(level=logging.DEBUG, 242 | format='%(asctime)s, pid: %(process)d, %(levelname)-8s %(message)s', 243 | datefmt='%a, %d %b %Y %H:%M:%S', 244 | ) 245 | main() 246 | -------------------------------------------------------------------------------- /lib_yolo/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flkraus/bayesian-yolov3/f9faa718542c3dd657f5acb23b6642f399c63645/lib_yolo/__init__.py -------------------------------------------------------------------------------- /lib_yolo/darknet.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import tensorflow as tf 5 | 6 | 7 | def darknet53(model_builder, training, trainable): 8 | mb = model_builder 9 | 10 | mb.make_darknet_conv_layer(32, 3, training, trainable) # 0 11 | 12 | # Downsample (factor 2) 13 | mb.make_darknet_downsample_layer(64, 3, training, trainable) # 1 14 | 15 | mb.make_darknet_residual_block(32, training, trainable) # 2 - 4 16 | 17 | # Downsample (factor 4) 18 | mb.make_darknet_downsample_layer(128, 3, training, trainable) # 5 19 | 20 | for i in range(2): 21 | mb.make_darknet_residual_block(64, training, trainable) # 6 - 11 22 | 23 | # Downsample (factor 8) 24 | mb.make_darknet_downsample_layer(256, 3, training, trainable) # 12 25 | 26 | for i in range(8): 27 | mb.make_darknet_residual_block(128, training, trainable) # 13 - 36 28 | 29 | # Downsample (factor 16) 30 | mb.make_darknet_downsample_layer(512, 3, training, trainable) # 37 31 | 32 | for i in range(8): 33 | mb.make_darknet_residual_block(256, training, trainable) # 38 - 61 34 | 35 | # Downsample (factor 32) 36 | mb.make_darknet_downsample_layer(1024, 3, training, trainable) # 62 37 | 38 | for i in range(4): 39 | mb.make_darknet_residual_block(512, training, trainable) # 63 - 74 40 | 41 | 42 | def load_darknet_weights(net_layers, weightfile): 43 | with open(weightfile, "rb") as f: 44 | header = np.fromfile(f, dtype=np.int32, count=5) 45 | weights = np.fromfile(f, dtype=np.float32) 46 | 47 | ptr = 0 48 | tmp = tf.global_variables() 49 | vars = {} 50 | for var in tmp: 51 | vars[var.name] = var 52 | 53 | assign_ops = [] 54 | 55 | for i, l in enumerate(net_layers): 56 | if 'LeakyRelu' not in l.name: 57 | continue 58 | 59 | batch_norm = 'detection' not in l.name 60 | load_bias = not batch_norm 61 | if batch_norm: 62 | ptr = _load_batch_norm(l, vars, ptr, weights, assign_ops) 63 | 64 | ptr = _load_conv2d(l, vars, ptr, weights, assign_ops, load_bias) 65 | 66 | assert ptr == len(weights) 67 | return assign_ops 68 | 69 | 70 | def _load_conv2d(l, vars, ptr, weights, assign_ops, load_bias): 71 | namespace = l.name.split('/') 72 | namespace = os.path.join(*namespace[:2], 'conv2d') 73 | 74 | kernel_name = os.path.join(namespace, 'kernel:0') 75 | kernel = 
vars[kernel_name]
76 | 
77 |     if load_bias:
78 |         bias_name = os.path.join(namespace, 'bias:0')
79 |         bias = vars[bias_name]
80 | 
81 |         bias_shape = bias.shape.as_list()
82 |         bias_params = np.prod(bias_shape)
83 |         bias_weights = weights[ptr:ptr + bias_params].reshape(bias_shape)
84 |         ptr += bias_params
85 |         assign_ops.append(tf.assign(bias, bias_weights, validate_shape=True))
86 | 
87 |     kernel_shape = kernel.shape.as_list()
88 |     kernel_params = np.prod(kernel_shape)
89 | 
90 |     [h, w, c, n] = kernel_shape
91 |     kernel_weights = weights[ptr:ptr + kernel_params].reshape([n, c, h, w])
92 |     # transpose to [h, w, c, n]
93 |     kernel_weights = np.transpose(kernel_weights, (2, 3, 1, 0))
94 | 
95 |     ptr += kernel_params
96 |     assign_ops.append(tf.assign(kernel, kernel_weights, validate_shape=True))
97 | 
98 |     return ptr
99 | 
100 | 
101 | def _load_batch_norm(l, vars, ptr, weights, assign_ops):
102 |     namespace = l.name.split('/')
103 |     namespace = os.path.join(*namespace[:2], 'batch_normalization')
104 | 
105 |     gamma = os.path.join(namespace, 'gamma:0')
106 |     beta = os.path.join(namespace, 'beta:0')
107 |     moving_mean = os.path.join(namespace, 'moving_mean:0')
108 |     moving_variance = os.path.join(namespace, 'moving_variance:0')
109 | 
110 |     gamma = vars[gamma]
111 |     beta = vars[beta]
112 |     moving_mean = vars[moving_mean]
113 |     moving_variance = vars[moving_variance]
114 | 
115 |     for var in [beta, gamma, moving_mean, moving_variance]:
116 |         shape = var.shape.as_list()
117 |         num_params = np.prod(shape)
118 |         var_weights = weights[ptr:ptr + num_params].reshape(shape)
119 |         ptr += num_params
120 |         assign_ops.append(tf.assign(var, var_weights, validate_shape=True))
121 | 
122 |     return ptr
123 | 
--------------------------------------------------------------------------------
/lib_yolo/data.py:
--------------------------------------------------------------------------------
1 | # This file contains mostly numpy reference implementations for ground truth bbox encoding for the yolo loss.
2 | 
3 | import numpy as np
4 | from scipy.special import logit, expit
5 | 
6 | 
7 | class Box:
8 |     def __init__(self):
9 |         self.xmin = None
10 |         self.ymin = None
11 |         self.xmax = None
12 |         self.ymax = None
13 | 
14 |         self.x_center = None
15 |         self.y_center = None
16 |         self.w = None
17 |         self.h = None
18 | 
19 |         self.area = None
20 |         self.cls = None
21 | 
22 |     def __repr__(self):
23 |         return '<Box: x_center={}, y_center={}, h={}, w={}, cls={}>'.format(self.x_center, self.y_center, self.h, self.w,
24 |                                                                             self.cls)
25 | 
26 |     @classmethod
27 |     def from_corners(cls, xmin, ymin, xmax, ymax, label=-1):
28 |         box = cls()
29 | 
30 |         box.xmin = xmin
31 |         box.ymin = ymin
32 |         box.xmax = xmax
33 |         box.ymax = ymax
34 | 
35 |         box.x_center = (box.xmin + box.xmax) / 2.
36 |         box.y_center = (box.ymin + box.ymax) / 2.
37 |         box.w = box.xmax - box.xmin
38 |         box.h = box.ymax - box.ymin
39 | 
40 |         box.area = box.w * box.h
41 | 
42 |         box.cls = label
43 |         return box
44 | 
45 |     @classmethod
46 |     def from_tf_image_format(cls, ymin, xmin, ymax, xmax, label=-1):
47 |         return cls.from_corners(xmin=xmin, ymin=ymin, xmax=xmax, ymax=ymax, label=label)
48 | 
49 |     @classmethod
50 |     def from_width_and_height(cls, x_center, y_center, w, h, label=-1):
51 |         box = cls()
52 | 
53 |         box.x_center = x_center
54 |         box.y_center = y_center
55 |         box.w = w
56 |         box.h = h
57 | 
58 |         w2 = w / 2.
59 |         h2 = h / 2.
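        # Center/size -> corners (next four lines): xmin/ymin = center - half
        # size, xmax/ymax = center + half size. Worked example: x_center=0.5,
        # y_center=0.5, w=0.2, h=0.4 yields (xmin, ymin, xmax, ymax) =
        # (0.4, 0.3, 0.6, 0.7).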
60 |         box.xmin = box.x_center - w2
61 |         box.ymin = box.y_center - h2
62 |         box.xmax = box.x_center + w2
63 |         box.ymax = box.y_center + h2
64 | 
65 |         box.area = box.w * box.h
66 | 
67 |         box.cls = label
68 |         return box
69 | 
70 | 
71 | class Cell:
72 |     def __init__(self, row, col):
73 |         self.row = row
74 |         self.col = col
75 | 
76 |     def __repr__(self):
77 |         return '<Cell: row={}, col={}>'.format(self.row, self.col)
78 | 
79 | 
80 | class Prior:
81 |     def __init__(self, h, w):
82 |         self.h = h
83 |         self.w = w
84 | 
85 |     def __repr__(self):
86 |         return '<Prior: h={}, w={}>'.format(self.h, self.w)
87 | 
88 | 
89 | class DetLayerInfo:
90 |     def __init__(self, h, w, priors):
91 |         self.h = h
92 |         self.w = w
93 |         self.priors = priors
94 | 
95 |     def __repr__(self):
96 |         return '<DetLayerInfo: h={}, w={}, priors={}>'.format(self.h, self.w, self.priors)
97 | 
98 | 
99 | def create_prior_box_grid(det_layer):
100 |     boxes_per_cell = len(det_layer.priors)
101 |     prior_box_grid = np.zeros((det_layer.h, det_layer.w, boxes_per_cell, 4))
102 |     for row in range(det_layer.h):
103 |         for col in range(det_layer.w):
104 |             for box, prior in enumerate(det_layer.priors):
105 |                 y_center = (row + 0.5) / det_layer.h
106 |                 x_center = (col + 0.5) / det_layer.w
107 |                 h2 = prior.h / 2.
108 |                 w2 = prior.w / 2.
109 | 
110 |                 ymin = y_center - h2
111 |                 xmin = x_center - w2
112 |                 ymax = y_center + h2
113 |                 xmax = x_center + w2
114 | 
115 |                 prior_box_grid[row, col, box, :] = [ymin, xmin, ymax, xmax]  # tf.image bbox format
116 |     return prior_box_grid
117 | 
118 | 
119 | def create_prior_data(det_layer):
120 |     boxes_per_cell = len(det_layer.priors)
121 | 
122 |     bboxes = np.zeros((det_layer.h, det_layer.w, boxes_per_cell, 4), dtype=np.float32)
123 |     bbox_areas = np.zeros((det_layer.h, det_layer.w, boxes_per_cell), dtype=np.float32)
124 |     cx = np.zeros((det_layer.h, det_layer.w, boxes_per_cell), dtype=np.float32)
125 |     cy = np.zeros((det_layer.h, det_layer.w, boxes_per_cell), dtype=np.float32)
126 |     pw = np.zeros((det_layer.h, det_layer.w, boxes_per_cell), dtype=np.float32)
127 |     ph = np.zeros((det_layer.h, det_layer.w, boxes_per_cell), dtype=np.float32)
128 |     lw = np.ones((det_layer.h, det_layer.w, boxes_per_cell), dtype=np.float32) * det_layer.w
129 |     lh = np.ones((det_layer.h, det_layer.w, boxes_per_cell), dtype=np.float32) * det_layer.h
130 |     center_x = np.zeros((det_layer.h, det_layer.w, boxes_per_cell), dtype=np.float32)
131 |     center_y = np.zeros((det_layer.h, det_layer.w, boxes_per_cell), dtype=np.float32)
132 | 
133 |     prior_areas = [p.h * p.w for p in det_layer.priors]
134 | 
135 |     for row in range(det_layer.h):
136 |         for col in range(det_layer.w):
137 |             for box, prior in enumerate(det_layer.priors):
138 |                 assert 0 <= prior.w <= 1, 'prior width must be specified as a number between 0 and 1'
139 |                 assert 0 <= prior.h <= 1, 'prior height must be specified as a number between 0 and 1'
140 |                 y_center = (row + 0.5) / det_layer.h
141 |                 x_center = (col + 0.5) / det_layer.w
142 |                 h2 = prior.h / 2.
143 |                 w2 = prior.w / 2.
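                # Each cell holds one prior box per configured prior, centered
                # on the cell center in normalized coordinates. E.g. for a 2x2
                # layer, cell (row=0, col=1) has its center at
                # ((1 + 0.5) / 2, (0 + 0.5) / 2) = (x=0.75, y=0.25).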
144 | 145 | ymin = y_center - h2 146 | xmin = x_center - w2 147 | ymax = y_center + h2 148 | xmax = x_center + w2 149 | 150 | cx[row, col, box] = col / float(det_layer.w) 151 | cy[row, col, box] = row / float(det_layer.h) 152 | pw[row, col, box] = prior.w 153 | ph[row, col, box] = prior.h 154 | center_x[row, col, box] = x_center 155 | center_y[row, col, box] = y_center 156 | 157 | bboxes[row, col, box, :] = [ymin, xmin, ymax, xmax] # tf.image bbox format 158 | bbox_areas[row, col, box] = prior_areas[box] 159 | return bboxes, bbox_areas, cx, cy, pw, ph, lw, lh, center_x, center_y # TODO this is ugly 160 | 161 | 162 | def calc_gt(gt_boxes, det_layers): 163 | gt = [] 164 | for layer in det_layers: 165 | boxes_per_cell = len(layer.priors) 166 | gt.append({ 167 | 'loc': np.zeros((layer.h, layer.w, boxes_per_cell, 4)), 168 | 'obj': np.zeros((layer.h, layer.w, boxes_per_cell)), 169 | 'cls': np.zeros((layer.h, layer.w, boxes_per_cell)), 170 | 'fp': np.zeros((layer.h, layer.w, boxes_per_cell)), 171 | 'ignore': np.ones((layer.h, layer.w, boxes_per_cell)), 172 | }) 173 | 174 | prior_grids = [] 175 | for layer in det_layers: 176 | prior_grids.append(create_prior_box_grid(layer)) 177 | 178 | used_cells = {} 179 | 180 | for gt_box in gt_boxes: 181 | res = find_responsible_layer_and_prior(det_layers, gt_box) 182 | l_idx = res['layer'] 183 | p_idx = res['prior'] 184 | layer = det_layers[l_idx] 185 | prior = layer.priors[p_idx] 186 | cell = find_responsible_cell(layer, gt_box) 187 | 188 | used_cells[(l_idx, p_idx, cell.row, cell.col)] = used_cells.get((l_idx, p_idx, cell.row, cell.col), 0) + 1 189 | 190 | cx = cell.col / float(layer.w) 191 | cy = cell.row / float(layer.h) 192 | tx = logit(gt_box.x_center - cx) 193 | ty = logit(gt_box.y_center - cy) 194 | if tx < -100 or tx > 100: 195 | assert False 196 | if ty < -100 or ty > 100: 197 | assert False 198 | tw = np.log(gt_box.w / prior.w) 199 | th = np.log(gt_box.h / prior.h) 200 | 201 | gt[l_idx]['loc'][cell.row, cell.col, p_idx, :] = [tx, ty, tw, th] 202 | gt[l_idx]['obj'][cell.row, cell.col, p_idx] = 1 203 | gt[l_idx]['cls'][cell.row, cell.col, p_idx] = gt_box.cls 204 | gt[l_idx]['fp'][cell.row, cell.col, p_idx] = 1 205 | 206 | # calc iou for all prior boxes for all layers with the gt_box 207 | ious = iou_multiboxes(gt_box, prior_grids) 208 | for i in range(len(det_layers)): 209 | gt[i]['ignore'][ious[i] > 0.7] = 0 # TODO ignore threshold 210 | 211 | for i in range(len(det_layers)): 212 | gt[i]['ignore'] = np.maximum(gt[i]['ignore'], gt[i]['fp']) 213 | 214 | return gt, used_cells 215 | 216 | 217 | def iou_multiboxes(gt_box, prior_grids): 218 | ious = [] 219 | for pg in prior_grids: 220 | iou_grid = np.zeros(pg.shape[:3]) 221 | for row in range(pg.shape[0]): 222 | for col in range(pg.shape[1]): 223 | for box in range(pg.shape[2]): 224 | iou_grid[row, col, box] = iou(gt_box, Box.from_tf_image_format(*pg[row, col, box, :])) 225 | ious.append(iou_grid) 226 | 227 | return ious 228 | 229 | 230 | def find_responsible_cell(det_layer, gt_box): 231 | row = int(det_layer.h * gt_box.y_center) 232 | col = int(det_layer.w * gt_box.x_center) 233 | return Cell(row, col) 234 | 235 | 236 | def find_responsible_layer_and_prior(det_layers, gt_box): 237 | gt_box = Box.from_width_and_height(0, 0, w=gt_box.w, h=gt_box.h) 238 | best_iou = 0 239 | best_layer = None 240 | for l_idx, layer in enumerate(det_layers): 241 | ious = [iou(gt_box, Box.from_width_and_height(0, 0, w=prior.w, h=prior.h)) for prior in layer.priors] 242 | if np.max(ious) > best_iou: 243 | best_prior = 
np.argmax(ious)
244 |             best_layer = l_idx
245 |             best_iou = np.max(ious)
246 | 
247 |     assert best_layer is not None
248 |     assert best_iou > 0
249 |     return {'layer': best_layer, 'prior': best_prior}
250 | 
251 | 
252 | def iou(b1, b2):
253 |     intersection = intersect(b1, b2)
254 |     if intersection == 0:  # TODO use np.isclose?
255 |         return 0.0
256 | 
257 |     union = b1.area + b2.area - intersection
258 | 
259 |     return intersection / union
260 | 
261 | 
262 | def intersect(b1, b2):
263 |     """
264 |     :param b1: Box
265 |     :param b2: Box
266 |     :return:
267 |     """
268 | 
269 |     xmin = np.maximum(b1.xmin, b2.xmin)
270 |     ymin = np.maximum(b1.ymin, b2.ymin)
271 |     xmax = np.minimum(b1.xmax, b2.xmax)
272 |     ymax = np.minimum(b1.ymax, b2.ymax)
273 | 
274 |     if xmax <= xmin:
275 |         return 0.0
276 | 
277 |     if ymax <= ymin:
278 |         return 0.0
279 | 
280 |     intersection_box = Box.from_corners(xmin=xmin, ymin=ymin, xmax=xmax, ymax=ymax)
281 |     return intersection_box.area
282 | 
283 | 
284 | def loc_to_boxes(loc, cls, fp, priors):
285 |     lh, lw, boxes_per_cell = fp.shape
286 |     boxes = []
287 |     for row in range(lh):
288 |         for col in range(lw):
289 |             for box in range(boxes_per_cell):
290 |                 if fp[row, col, box] == 1:
291 |                     cx = col / float(lw)
292 |                     cy = row / float(lh)
293 |                     x_center = expit(loc[row, col, box, 0]) + cx
294 |                     y_center = expit(loc[row, col, box, 1]) + cy
295 |                     w = np.exp(loc[row, col, box, 2]) * priors[box].w
296 |                     h = np.exp(loc[row, col, box, 3]) * priors[box].h
297 | 
298 |                     label = cls[row, col, box]
299 | 
300 |                     boxes.append(Box.from_width_and_height(x_center=x_center, y_center=y_center, w=w, h=h, label=label))
301 |     return boxes
302 | 
303 | 
304 | def loc_to_tf_records_format(loc, cls, fp, priors):
305 |     ymin, xmin, ymax, xmax, labels = [], [], [], [], []
306 |     boxes = loc_to_boxes(loc, cls, fp, priors)
307 |     for box in boxes:
308 |         ymin.append(box.ymin)
309 |         xmin.append(box.xmin)
310 |         ymax.append(box.ymax)
311 |         xmax.append(box.xmax)
312 |         labels.append(box.cls)
313 |     return [ymin, xmin, ymax, xmax], labels
314 | 
315 | 
316 | def create_boxes_from_tf_records_format(boxes):
317 |     out = []
318 |     for i in range(len(boxes['ymin'])):
319 |         out.append(Box.from_corners(xmin=boxes['xmin'][i],
320 |                                     ymin=boxes['ymin'][i],
321 |                                     xmax=boxes['xmax'][i],
322 |                                     ymax=boxes['ymax'][i],
323 |                                     label=boxes['cls'][i], ))
324 | 
325 |     return out
326 | 
--------------------------------------------------------------------------------
/lib_yolo/data_augmentation.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | 
3 | 
4 | class DataAugmenter:
5 |     def __init__(self, img_size):
6 |         self.__img_size = img_size
7 |         self.__ones_h_w_c = tf.ones(shape=img_size, dtype=tf.float32)
8 |         self.__ones_h_w = tf.ones(shape=img_size[:2], dtype=tf.float32)
9 |         self.__zeros_h_w_c = tf.zeros(shape=img_size, dtype=tf.float32)
10 |         self.__zeros_h_w = tf.zeros(shape=img_size[:2], dtype=tf.float32)
11 | 
12 |         self.__int = [tf.constant(num, dtype=tf.int32) for num in range(10)]
13 |         self.__rand_0_1 = {
14 |             'shape': (),
15 |             'minval': 0,
16 |             'maxval': 1,
17 |             'dtype': tf.float32
18 |         }
19 | 
20 |     def augment(self, img, bbox, label):
21 |         # flip 50%
22 |         img, bbox = tf.cond(tf.random_uniform(**self.__rand_0_1) < 0.5, true_fn=lambda: self.flip_lr(img, bbox),
23 |                             false_fn=lambda: (img, bbox))
24 | 
25 |         # blur 5% of images
26 |         img = tf.cond(tf.random_uniform(**self.__rand_0_1) < 0.05, true_fn=lambda: self.blur(img), false_fn=lambda: img)
27 | 
28 |         # add color augmentation (hue, brightness, ...)
to 5% of images (additional to blur)
29 |         img = tf.cond(tf.random_uniform(**self.__rand_0_1) < 0.05, true_fn=lambda: self.color_augmentations(img),
30 |                       false_fn=lambda: img)
31 | 
32 |         # add noise augmentation (salt&pepper, gaussian noise, ...) to 5% of images (additional to blur and color augm)
33 |         img = tf.cond(tf.random_uniform(**self.__rand_0_1) < 0.05, true_fn=lambda: self.noise_augmentations(img),
34 |                       false_fn=lambda: img)
35 | 
36 |         return img, bbox, label
37 | 
38 |     def color_augmentations(self, img):
39 |         rand_args = {
40 |             'shape': (),
41 |             'minval': 0,
42 |             'maxval': 3,
43 |             'dtype': tf.int32
44 |         }
45 |         choice = tf.random_uniform(**rand_args)
46 | 
47 |         img = tf.cond(tf.equal(choice, tf.constant(0, dtype=tf.int32)),
48 |                       true_fn=lambda: tf.image.random_saturation(img, 0.5, 1.5), false_fn=lambda: img)
49 |         img = tf.cond(tf.equal(choice, tf.constant(1, dtype=tf.int32)),
50 |                       true_fn=lambda: tf.image.random_brightness(img, 0.2), false_fn=lambda: img)
51 |         img = tf.cond(tf.equal(choice, tf.constant(2, dtype=tf.int32)),
52 |                       true_fn=lambda: tf.image.random_hue(img, 0.2), false_fn=lambda: img)
53 | 
54 |         return img
55 | 
56 |     def noise_augmentations(self, img):
57 |         rand_args = {
58 |             'shape': (),
59 |             'minval': 0,
60 |             'maxval': 3,
61 |             'dtype': tf.int32
62 |         }
63 |         choice = tf.random_uniform(**rand_args)
64 | 
65 |         img = tf.cond(tf.equal(choice, tf.constant(0, dtype=tf.int32)),
66 |                       true_fn=lambda: self.colored_salt_n_pepper(img), false_fn=lambda: img)
67 |         img = tf.cond(tf.equal(choice, tf.constant(1, dtype=tf.int32)),
68 |                       true_fn=lambda: self.salt_n_pepper(img), false_fn=lambda: img)
69 |         img = tf.cond(tf.equal(choice, tf.constant(2, dtype=tf.int32)),
70 |                       true_fn=lambda: self.additive_gaussian_noise(img), false_fn=lambda: img)
71 | 
72 |         return img
73 | 
74 |     def flip_lr(self, img, bbox):
75 |         img = tf.image.flip_left_right(img)
76 | 
77 |         ymin, xmin, ymax, xmax = tf.split(value=bbox, num_or_size_splits=4, axis=1)
78 |         flipped_xmin = 1.0 - xmax
79 |         flipped_xmax = 1.0 - xmin
80 |         bbox = tf.concat([ymin, flipped_xmin, ymax, flipped_xmax], axis=1)
81 | 
82 |         return img, bbox
83 | 
84 |     def colored_salt_n_pepper(self, img):
85 |         # season each channel individually
86 |         salt_mask = tf.random_uniform(shape=self.__img_size, minval=0, maxval=1)
87 |         pepper_mask = tf.random_uniform(shape=self.__img_size, minval=0, maxval=1)
88 |         amount = tf.random_uniform(shape=(), minval=0.0005, maxval=0.008)
89 |         img = tf.where(tf.less(salt_mask, amount), self.__ones_h_w_c, img)
90 |         img = tf.where(tf.less(pepper_mask, amount), self.__zeros_h_w_c, img)
91 | 
92 |         return img
93 | 
94 |     def salt_n_pepper(self, img):
95 |         # season all channels at the same time
96 |         size = self.__img_size[:2]
97 |         salt_mask = tf.random_uniform(shape=size, minval=0, maxval=1)
98 |         pepper_mask = tf.random_uniform(shape=size, minval=0, maxval=1)
99 |         amount = tf.random_uniform(shape=(), minval=0.0005, maxval=0.008)
100 | 
101 |         salt = tf.where(tf.less(salt_mask, amount), self.__ones_h_w, self.__zeros_h_w)
102 |         pepper = tf.where(tf.less(pepper_mask, amount), -self.__ones_h_w, self.__zeros_h_w)  # note the minus sign!
103 | 
104 |         # if a peppercorn and a salt crystal fall on the same spot they both vanish. Magic!
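        # (salt contributes +1 and pepper -1 at a given pixel, so together they add 0 and the
        # clip below leaves that pixel unchanged)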
105 | salt_n_pepper = salt + pepper 106 | img = img + tf.expand_dims(salt_n_pepper, 2) # season your image 107 | img = tf.clip_by_value(img, clip_value_min=0, clip_value_max=1) # remove excess seasoning 108 | return img 109 | 110 | def blur(self, img): 111 | kernel_size = tf.random_uniform(shape=(), minval=2, maxval=4, dtype=tf.int32) # 2 or 3 112 | kernel = tf.ones(shape=tf.stack((kernel_size, kernel_size, 1, 1))) 113 | kernel_size = tf.cast(kernel_size, tf.float32) 114 | kernel = kernel / (kernel_size * kernel_size) 115 | 116 | r, g, b = tf.unstack(tf.expand_dims(img, 0), num=3, axis=3) 117 | r = tf.expand_dims(r, 3) 118 | g = tf.expand_dims(g, 3) 119 | b = tf.expand_dims(b, 3) 120 | 121 | r = tf.nn.conv2d(r, filter=kernel, strides=[1, 1, 1, 1], padding='SAME') 122 | g = tf.nn.conv2d(g, filter=kernel, strides=[1, 1, 1, 1], padding='SAME') 123 | b = tf.nn.conv2d(b, filter=kernel, strides=[1, 1, 1, 1], padding='SAME') 124 | 125 | img = tf.concat([r, g, b], axis=3) 126 | return tf.squeeze(img, axis=0) 127 | 128 | def additive_gaussian_noise(self, img): 129 | stddev = tf.random_uniform(shape=(), minval=0.001, maxval=0.05, dtype=tf.float32) 130 | 131 | noise = tf.random_normal(shape=self.__img_size, mean=0.0, stddev=stddev) 132 | img = img + noise 133 | return img 134 | 135 | 136 | class ImageCropper: 137 | def __init__(self, config): 138 | self.config = config 139 | self.crop_height = config['crop_img_size'][0] 140 | self.crop_width = config['crop_img_size'][1] 141 | self.full_height = config['full_img_size'][0] 142 | self.full_width = config['full_img_size'][1] 143 | 144 | aspect_ratio_full_img = self.full_width / float(self.full_height) 145 | aspect_ratio_crop_img = self.crop_width / float(self.crop_height) 146 | 147 | # comparing two floats with "==" is always a smart idea (10/10, would do it again). 148 | assert aspect_ratio_full_img == aspect_ratio_crop_img, 'invalid crop aspect ratio, must be same as full image' 149 | 150 | def random_crop_and_sometimes_rescale(self, img, bbox, label): 151 | # randomly rescales the crop area 33% of the time, else normal random crop 152 | img, bbox, label = tf.cond(tf.random_uniform(shape=(), minval=0.0, maxval=1.0, dtype=tf.float32) < 0.33, 153 | true_fn=lambda: self.random_crop_with_rescale(img, bbox, label), 154 | false_fn=lambda: self.random_crop(img, bbox, label)) 155 | return img, bbox, label 156 | 157 | def random_crop_with_rescale(self, img, bbox, label): 158 | """ 159 | First select a crop of random size, then resize to desired crop_size. 160 | """ 161 | with tf.name_scope('random_crop_and_rescale_data'): 162 | scale = tf.clip_by_value(tf.random_normal(shape=(), mean=0, stddev=0.5, dtype=tf.float32), 163 | clip_value_min=-0.7, clip_value_max=0.7) 164 | crop_height = tf.cast(tf.minimum((1 + scale) * self.crop_height, self.full_height), tf.int32) 165 | crop_width = tf.cast(tf.minimum((1 + scale) * self.crop_width, self.full_width), tf.int32) 166 | 167 | # prefer crops in the middle of the image in y direction (use normal dist to select y) 168 | y_maxval = tf.cast(self.full_height - crop_height, tf.float32) 169 | y_mean = y_maxval / 2. 170 | stddev_y = y_maxval / 4. 
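            # with mean = y_maxval / 2 and stddev = y_maxval / 4, roughly 95% of the draws already
            # fall inside [0, y_maxval] (mean +/- 2 stddev); the clip below handles the tails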
171 | y_min_ind = tf.random_normal(shape=(), mean=y_mean, stddev=stddev_y, dtype=tf.float32) 172 | y_min_ind = tf.clip_by_value(y_min_ind, 0, y_maxval) 173 | y_min_ind = tf.cast(y_min_ind, tf.int32) 174 | 175 | # crops are uniformly distributed in x direction 176 | x_min_ind = tf.random_uniform(shape=(), minval=0, maxval=self.full_width - crop_width + 1, 177 | dtype=tf.int32) 178 | 179 | y_min = tf.cast(y_min_ind, tf.float32) / self.full_height 180 | x_min = tf.cast(x_min_ind, tf.float32) / self.full_width 181 | y_max = y_min + (tf.cast(crop_height, tf.float32) / float(self.full_height)) 182 | x_max = x_min + (tf.cast(crop_width, tf.float32) / float(self.full_width)) 183 | 184 | img = tf.image.crop_to_bounding_box(img, y_min_ind, x_min_ind, crop_height, crop_width) 185 | bbox, label = crop_boxes(bbox, label, y_min, x_min, y_max, x_max) 186 | 187 | img = tf.image.resize_images(img, [self.crop_height, self.crop_width]) 188 | 189 | return img, bbox, label 190 | 191 | def random_crop(self, img, bbox, label): 192 | with tf.name_scope('random_crop_data'): 193 | # prefer crops in the middle of the image in y direction (use normal dist to select y) 194 | y_maxval = self.full_height - self.crop_height 195 | y_mean = y_maxval / 2 196 | stddev_y = y_maxval / 4 197 | y_min_ind = tf.random_normal(shape=(), mean=y_mean, stddev=stddev_y, dtype=tf.float32) 198 | y_min_ind = tf.clip_by_value(y_min_ind, 0, y_maxval) 199 | y_min_ind = tf.cast(y_min_ind, tf.int32) 200 | 201 | # crops are uniformly distributed in x direction 202 | x_min_ind = tf.random_uniform(shape=(), minval=0, maxval=self.full_width - self.crop_width + 1, 203 | dtype=tf.int32) 204 | 205 | y_min = tf.cast(y_min_ind, tf.float32) / self.full_height 206 | x_min = tf.cast(x_min_ind, tf.float32) / self.full_width 207 | y_max = y_min + (self.crop_height / float(self.full_height)) 208 | x_max = x_min + (self.crop_width / float(self.full_width)) 209 | 210 | img = tf.image.crop_to_bounding_box(img, y_min_ind, x_min_ind, self.crop_height, self.crop_width) 211 | bbox, label = crop_boxes(bbox, label, y_min, x_min, y_max, x_max) 212 | 213 | return img, bbox, label 214 | 215 | def center_crop(self, img, bbox, label): 216 | with tf.name_scope('center_crop_data'): 217 | y_min_ind = (self.full_height - self.crop_height) // 2 218 | x_min_ind = (self.full_width - self.crop_width) // 2 219 | 220 | y_min = tf.cast(y_min_ind, tf.float32) / self.full_height 221 | x_min = tf.cast(x_min_ind, tf.float32) / self.full_width 222 | y_max = y_min + (self.crop_height / float(self.full_height)) 223 | x_max = x_min + (self.crop_width / float(self.full_width)) 224 | 225 | img = tf.image.crop_to_bounding_box(img, y_min_ind, x_min_ind, self.crop_height, self.crop_width) 226 | bbox, label = crop_boxes(bbox, label, y_min, x_min, y_max, x_max) 227 | 228 | return img, bbox, label 229 | 230 | 231 | def crop_boxes(boxes, labels, crop_y_min, crop_x_min, crop_y_max, crop_x_max, thresh=0.25): 232 | with tf.name_scope('crop_boxes'): 233 | y_min, x_min, y_max, x_max = tf.split(boxes, num_or_size_splits=4, axis=1) 234 | areas = tf.squeeze((y_max - y_min) * (x_max - x_min), [1]) 235 | y_min_clipped = tf.maximum(tf.minimum(y_min, crop_y_max), crop_y_min) 236 | y_max_clipped = tf.maximum(tf.minimum(y_max, crop_y_max), crop_y_min) 237 | x_min_clipped = tf.maximum(tf.minimum(x_min, crop_x_max), crop_x_min) 238 | x_max_clipped = tf.maximum(tf.minimum(x_max, crop_x_max), crop_x_min) 239 | clipped = tf.concat([(y_min_clipped - crop_y_min) / (crop_y_max - crop_y_min), 240 | (x_min_clipped - 
crop_x_min) / (crop_x_max - crop_x_min), 241 | (y_max_clipped - crop_y_min) / (crop_y_max - crop_y_min), 242 | (x_max_clipped - crop_x_min) / (crop_x_max - crop_x_min)], axis=1) 243 | 244 | areas_clipped = tf.squeeze((y_max_clipped - y_min_clipped) * (x_max_clipped - x_min_clipped), [1]) 245 | 246 | # remove all boxes which are less than xx% of their original area (thresh = 0.25 ~ 25%) 247 | nonzero_area_indices = tf.cast(tf.reshape(tf.where(tf.greater(areas_clipped / areas, thresh)), [-1]), tf.int32) 248 | clipped = tf.gather(clipped, nonzero_area_indices) 249 | clipped_labels = tf.gather(labels, nonzero_area_indices) 250 | 251 | return clipped, clipped_labels 252 | 253 | 254 | def box_area(boxes): 255 | y_min, x_min, y_max, x_max = tf.split(value=boxes, num_or_size_splits=4, axis=1) 256 | return tf.squeeze((y_max - y_min) * (x_max - x_min), [1]) 257 | -------------------------------------------------------------------------------- /lib_yolo/dataset_utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from lib_yolo import data_augmentation, tfdata 4 | 5 | 6 | def decode_img(encoded, shape): 7 | # decode image and scale to [0, 1) 8 | img = tf.image.decode_png(encoded, dtype=tf.uint8) 9 | img = tf.image.convert_image_dtype(img, dtype=tf.float32) # convert to [0, 1) 10 | img.set_shape(shape) 11 | return img 12 | 13 | 14 | def make_parse_fn(config): 15 | def parse_example(example): 16 | feature_map = { 17 | 'image/encoded': tf.FixedLenFeature([], dtype=tf.string, default_value=''), 18 | 'image/height': tf.FixedLenFeature([], dtype=tf.int64, default_value=-1), 19 | 'image/width': tf.FixedLenFeature([], dtype=tf.int64, default_value=-1), 20 | 'image/object/class/label': tf.VarLenFeature(dtype=tf.int64), 21 | # 'image/object/class/text': tf.VarLenFeature(dtype=tf.string), 22 | 'image/object/bbox/xmin': tf.VarLenFeature(dtype=tf.float32), 23 | 'image/object/bbox/ymin': tf.VarLenFeature(dtype=tf.float32), 24 | 'image/object/bbox/xmax': tf.VarLenFeature(dtype=tf.float32), 25 | 'image/object/bbox/ymax': tf.VarLenFeature(dtype=tf.float32), 26 | # 'image/object/cnt': tf.FixedLenFeature([], dtype=tf.int64, default_value=-1), 27 | } 28 | 29 | features = tf.parse_single_example(example, features=feature_map) 30 | 31 | img = decode_img(features['image/encoded'], config['full_img_size']) 32 | 33 | # assemble bbox 34 | xmin = tf.sparse_tensor_to_dense(features['image/object/bbox/xmin'], default_value=0) 35 | ymin = tf.sparse_tensor_to_dense(features['image/object/bbox/ymin'], default_value=0) 36 | xmax = tf.sparse_tensor_to_dense(features['image/object/bbox/xmax'], default_value=0) 37 | ymax = tf.sparse_tensor_to_dense(features['image/object/bbox/ymax'], default_value=0) 38 | bbox = tf.stack([ymin, xmin, ymax, xmax], axis=1) # we use the standard tf bbox format 39 | 40 | label = tf.cast(tf.sparse_tensor_to_dense(features['image/object/class/label'], default_value=-1), 41 | dtype=tf.int32) 42 | 43 | # Note regarding implicit background class: 44 | # The tensorflow object detection API enforces that the class labels start with 1. 45 | # The class 0 is reserved for an (implicit) background class. 46 | 47 | # yolo does not need a implicit background class. 48 | # To ensure compatibility with tf object detection API we support both: class ids starting at 1 or 0. 49 | implicit_background_class = config['implicit_background_class'] 50 | if implicit_background_class: 51 | label = label - 1 # shift class 1 -> 0, 2 -> 1, etc... 
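            # labels are now 0-based class ids in [0, cls_cnt), which is what
            # tf.nn.sparse_softmax_cross_entropy_with_logits in lib_yolo/layers.py expects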
52 | 53 | return img, bbox, label # this is a mess 54 | 55 | return parse_example 56 | 57 | 58 | def make_encode_bbox_fn(model_blueprint, config): 59 | def encode_bbox(img, bbox, label): 60 | gt = tfdata.encode_boxes(bbox, label, model_blueprint.det_layers, ign_thresh=config['ign_thresh']) 61 | return [img, *gt] 62 | 63 | return encode_bbox 64 | 65 | 66 | def zero_center(img, *gt): 67 | img = 2 * (img - 0.5) # [0, 1) -> [-1, 1) 68 | return [img, *gt] 69 | 70 | 71 | def make_stack_same_img_fn_encoded(batch_size): 72 | def stack_same_image_encoded(img, *gt): 73 | img = tf.stack([img] * batch_size, axis=0) 74 | new_gt = [] 75 | for gt_ in gt: 76 | new_gt.append({ 77 | 'loc': tf.stack([gt_['loc']] * batch_size, axis=0), 78 | 'obj': tf.stack([gt_['obj']] * batch_size, axis=0), 79 | 'cls': tf.stack([gt_['cls']] * batch_size, axis=0), 80 | 'fpm': tf.stack([gt_['fpm']] * batch_size, axis=0), 81 | 'ign': tf.stack([gt_['ign']] * batch_size, axis=0), 82 | }) 83 | return [img, *new_gt] 84 | 85 | return stack_same_image_encoded 86 | 87 | 88 | def make_stack_same_input_fn(batch_size): 89 | def stack_same_input(img, bbox, label): 90 | img = tf.stack([img] * batch_size, axis=0) 91 | bbox = tf.stack([bbox] * batch_size, axis=0) 92 | label = tf.stack([label] * batch_size, axis=0) 93 | return img, bbox, label 94 | 95 | return stack_same_input 96 | 97 | 98 | def create_dataset(config, dataset_key): 99 | # in 1.9 list_files can shuffle directly... 100 | info = config[dataset_key] 101 | files = tf.data.Dataset.list_files(info['file_pattern']).shuffle(info['num_shards']) 102 | 103 | # cycle_length is important if the whole dataset does not fit into memory 104 | dataset = files.interleave(tf.data.TFRecordDataset, cycle_length=2, block_length=1) 105 | 106 | dataset = dataset.map(make_parse_fn(config), num_parallel_calls=config['cpu_thread_cnt']) 107 | 108 | if info['cache']: 109 | dataset = dataset.cache() # this fails if the dataset does not fit into memory 110 | return dataset 111 | 112 | 113 | class TrainValDataset: 114 | def __init__(self, model_blueprint, config): 115 | encode_bbox_fn = make_encode_bbox_fn(model_blueprint, config) 116 | 117 | train_dataset = create_dataset(config, 'train') 118 | val_dataset = create_dataset(config, 'val') 119 | 120 | # process val dataset 121 | if config['crop']: 122 | val_dataset = val_dataset.map(config['val']['crop_fn'], num_parallel_calls=config['cpu_thread_cnt']) 123 | 124 | val_dataset = val_dataset.map(encode_bbox_fn, num_parallel_calls=config['cpu_thread_cnt']) 125 | 126 | val_dataset = val_dataset.shuffle(buffer_size=config['val']['shuffle_buffer_size']) 127 | val_dataset = val_dataset.repeat() 128 | val_dataset = val_dataset.prefetch(buffer_size=1) # needed for val dataset? 129 | val_dataset = val_dataset.batch(batch_size=config['batch_size']) 130 | 131 | # process train dataset 132 | if config['crop']: 133 | train_dataset = train_dataset.map(config['train']['crop_fn'], num_parallel_calls=config['cpu_thread_cnt']) 134 | 135 | img_size = config['crop_img_size'] if config['crop'] else config['full_img_size'] # TODO move to fn? 
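        # note: the augmentations below (flip, blur, color, noise - see data_augmentation.py) are
        # applied to the training pipeline only; the val pipeline above is just cropped and encoded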
136 | augmenter = data_augmentation.DataAugmenter(img_size) 137 | train_dataset = train_dataset.map(augmenter.augment, num_parallel_calls=config['cpu_thread_cnt']) 138 | 139 | train_dataset = train_dataset.map(encode_bbox_fn, num_parallel_calls=config['cpu_thread_cnt']) 140 | train_dataset = train_dataset.shuffle(buffer_size=config['train']['shuffle_buffer_size']) 141 | train_dataset = train_dataset.repeat() 142 | train_dataset = train_dataset.batch(batch_size=config['batch_size']) 143 | 144 | train_dataset = train_dataset.prefetch(buffer_size=1) 145 | 146 | self.__train_iterator = train_dataset.make_one_shot_iterator() 147 | self.__val_iterator = val_dataset.make_one_shot_iterator() 148 | 149 | # ---------------- # 150 | # public interface # 151 | # ---------------- # 152 | self.handle = tf.placeholder(tf.string, shape=[]) 153 | self.iterator = tf.data.Iterator.from_string_handle(self.handle, train_dataset.output_types, 154 | train_dataset.output_shapes) 155 | self.train_handle = None 156 | self.val_handle = None 157 | 158 | def init_dataset(self, sess): 159 | self.train_handle = sess.run(self.__train_iterator.string_handle()) 160 | self.val_handle = sess.run(self.__val_iterator.string_handle()) 161 | 162 | 163 | class ValDataset: 164 | def __init__(self, config, map_fns=tuple(), dataset_key='data'): 165 | val_dataset = create_dataset(config, dataset_key) 166 | 167 | # process val dataset 168 | if config['crop']: 169 | val_dataset = val_dataset.map(config['val']['crop_fn'], num_parallel_calls=config['cpu_thread_cnt']) 170 | 171 | val_dataset = val_dataset.shuffle(buffer_size=config['val']['shuffle_buffer_size']) 172 | val_dataset = val_dataset.repeat() 173 | 174 | for map_fn in map_fns: 175 | val_dataset = val_dataset.map(map_fn, num_parallel_calls=24) 176 | 177 | val_dataset = val_dataset.map(make_stack_same_input_fn(batch_size=config['batch_size']), 178 | num_parallel_calls=config['cpu_thread_cnt']) 179 | val_dataset = val_dataset.prefetch(buffer_size=1) # needed for val dataset? 180 | 181 | # ---------------- # 182 | # public interface # 183 | # ---------------- # 184 | self.handle = tf.placeholder(tf.string, shape=[]) 185 | self.iterator = val_dataset.make_one_shot_iterator() 186 | 187 | 188 | class TestingDataset: 189 | def __init__(self, config, config_key='data'): 190 | self.__config = config 191 | 192 | info = config[config_key] 193 | files = tf.data.Dataset.list_files(info['file_pattern']) 194 | # cycle_length is important if the whole dataset does not fit into memory 195 | dataset = files.interleave(tf.data.TFRecordDataset, cycle_length=2, block_length=1) 196 | dataset = dataset.map(self.parse_example, num_parallel_calls=config['cpu_thread_cnt']) 197 | 198 | dataset = dataset.batch(batch_size=config['batch_size']) 199 | dataset = dataset.prefetch(buffer_size=1) # needed for val dataset? 
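        # unlike the datasets above there is no shuffle/repeat here, so the iterator makes exactly
        # one pass over the tfrecords and then raises tf.errors.OutOfRangeError. A minimal
        # consumption sketch (sess and the variable names are hypothetical, not part of this class):
        #   img, filename = testing_dataset.iterator.get_next()
        #   while True:
        #       try:
        #           sess.run([img, filename])
        #       except tf.errors.OutOfRangeError:
        #           break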
200 | 201 | # ---------------- # 202 | # public interface # 203 | # ---------------- # 204 | self.iterator = dataset.make_one_shot_iterator() 205 | 206 | def parse_example(self, example): 207 | feature_map = { 208 | 'image/encoded': tf.FixedLenFeature([], dtype=tf.string, default_value=''), 209 | 'image/height': tf.FixedLenFeature([], dtype=tf.int64, default_value=-1), 210 | 'image/width': tf.FixedLenFeature([], dtype=tf.int64, default_value=-1), 211 | 'image/filename': tf.VarLenFeature(dtype=tf.string), 212 | } 213 | 214 | features = tf.parse_single_example(example, features=feature_map) 215 | 216 | img = decode_img(features['image/encoded'], self.__config['full_img_size']) 217 | 218 | filename = tf.sparse_tensor_to_dense(features['image/filename'], default_value='') 219 | return img, filename # this is a mess 220 | -------------------------------------------------------------------------------- /lib_yolo/layers.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | def flatten(*tensors): 5 | flattened = [] 6 | for t in tensors: 7 | flattened.append(tf.layers.flatten(t)) 8 | return flattened 9 | 10 | 11 | def split_detection(inputs, boxes_per_cell, cls_cnt): # vanilla YOLOv3 12 | # split the prediction feature map into localization, objectness prediction and class prediction 13 | 14 | loc, obj, cls = [], [], [] 15 | 16 | # split feature map to get the predictions for different priors 17 | split_boxes = tf.split(inputs, boxes_per_cell, axis=-1) 18 | 19 | for i, dets in enumerate(split_boxes): 20 | # for this to work we need 'channels_last' (axis=-1) 21 | loc_, obj_, cls_ = tf.split(dets, [4, 1, cls_cnt], axis=-1) 22 | loc.append(loc_) 23 | obj.append(obj_) 24 | cls.append(cls_) 25 | 26 | with tf.name_scope('loc'): 27 | loc = tf.stack(loc, axis=-2) # shape=(b, h, w, boxes_per_cell, loc_size), where loc_size=4 (x, y, w, h) 28 | with tf.name_scope('obj'): 29 | obj = tf.stack(obj, axis=-2) 30 | obj = tf.squeeze(obj, axis=-1) # shape=(b, h, w, boxes_per_cell) 31 | with tf.name_scope('cls'): 32 | cls = tf.stack(cls, axis=-2) # shape=(b, h, w, boxes_per_cell, cls_cnt) 33 | 34 | return { 35 | 'loc': loc, 36 | 'obj': obj, 37 | 'cls': cls, 38 | } 39 | 40 | 41 | def split_detection_aleatoric(inputs, boxes_per_cell, cls_cnt): 42 | # split the prediction feature map into localization, objectness prediction and class prediction 43 | 44 | loc, log_loc_var, obj, log_obj_stddev, cls, log_cls_stddev = [], [], [], [], [], [] 45 | 46 | # split feature map to get the predictions for different priors 47 | split_boxes = tf.split(inputs, boxes_per_cell, axis=-1) 48 | 49 | for i, dets in enumerate(split_boxes): 50 | # for this to work we need 'channels_last' (axis=-1) 51 | loc_, log_loc_var_, obj_, log_obj_stddev_, cls_, log_cls_stddev_ = tf.split(dets, 52 | [4, 4, 1, 1, cls_cnt, cls_cnt], 53 | axis=-1) 54 | loc.append(loc_) 55 | log_loc_var.append(log_loc_var_) 56 | obj.append(obj_) 57 | log_obj_stddev.append(log_obj_stddev_) 58 | cls.append(cls_) 59 | log_cls_stddev.append(log_cls_stddev_) 60 | 61 | with tf.name_scope('loc'): 62 | loc = tf.stack(loc, axis=-2) # shape=(b, h, w, boxes_per_cell, loc_size), where loc_size=4 (x, y, w, h) 63 | with tf.name_scope('log_loc_var'): 64 | log_loc_var = tf.stack(log_loc_var, 65 | axis=-2) # shape=(b, h, w, boxes_per_cell, loc_size), where loc_size=4 (x, y, w, h) 66 | with tf.name_scope('obj'): 67 | obj = tf.stack(obj, axis=-2) 68 | obj = tf.squeeze(obj, axis=-1) # shape=(b, h, w, boxes_per_cell) 69 | with 
tf.name_scope('log_obj_stddev'): 70 | log_obj_stddev = tf.stack(log_obj_stddev, axis=-2) 71 | log_obj_stddev = tf.squeeze(log_obj_stddev, axis=-1) # shape=(b, h, w, boxes_per_cell) 72 | with tf.name_scope('cls'): 73 | cls = tf.stack(cls, axis=-2) # shape=(b, h, w, boxes_per_cell, cls_cnt) 74 | with tf.name_scope('log_cls_stddev'): 75 | log_cls_stddev = tf.stack(log_cls_stddev, axis=-2) # shape=(b, h, w, boxes_per_cell, cls_cnt) 76 | 77 | return { 78 | 'loc': loc, 79 | 'log_loc_var': log_loc_var, 80 | 'obj': obj, 81 | 'log_obj_stddev': log_obj_stddev, 82 | 'cls': cls, 83 | 'log_cls_stddev': log_cls_stddev, 84 | } 85 | 86 | 87 | def aleatoric_obj_loss(det, gt): # aleatoric classification loss attenuation of Alex Kendall, not active 88 | T = 42 # this is completely random 89 | 90 | expected_value = tf.zeros_like(det['obj']) 91 | obj_stddev = tf.exp(tf.clip_by_value(det['log_obj_stddev'], -40, 40)) # this guarantees positive values 92 | for i in range(T): 93 | eps = tf.random_normal(tf.shape(det['obj']), mean=0.0, stddev=1.0) 94 | x = det['obj'] + (obj_stddev * eps) # sample logits 95 | s = tf.sigmoid(x) 96 | p = tf.where(gt['obj'] > 0.5, s, 1 - s) # sigmoid probability of true class 97 | expected_value = expected_value + p 98 | expected_value = expected_value / float(T) 99 | 100 | log_loss = - tf.log(expected_value) # don't forget the minus sign 101 | return log_loss 102 | 103 | 104 | def aleatoric_cls_loss(det, gt): # aleatoric classification loss attenuation of Alex Kendall, not active 105 | T = 42 # this is completely random 106 | 107 | cls_cnt = tf.shape(det['cls'])[-1] 108 | gt_one_hot = tf.one_hot(gt['cls'], cls_cnt) 109 | 110 | # [sic] we want to use shape of gt['cls'] not det['cls']! 111 | expected_value = tf.zeros_like(gt['cls'], dtype=tf.float32) 112 | cls_stddev = tf.exp(tf.clip_by_value(det['log_cls_stddev'], -40, 40)) # this guarantees positive values 113 | for i in range(T): 114 | eps = tf.random_normal(tf.shape(det['cls']), mean=0.0, stddev=1.0) 115 | x = det['cls'] + (cls_stddev * eps) # sample logits 116 | s = tf.nn.softmax(x) 117 | # calculate softmax probability of true class (this is a little bit hacky...) 
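        # e.g. (hypothetical numbers) s = [0.7, 0.2, 0.1] and gt_one_hot = [0, 1, 0] -> p = 0.2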
118 |         p = tf.reduce_sum(s * gt_one_hot, axis=-1)
119 |         expected_value = expected_value + p
120 |     expected_value = expected_value / float(T)
121 | 
122 |     log_loss = - tf.log(expected_value)  # don't forget the minus sign
123 |     return log_loss
124 | 
125 | 
126 | def loss_tf(det, gt, aleatoric_loss=False):
127 |     """
128 |     :param det: dict:
129 |         ['loc'] tensor with shape=(b, h, w, boxes_per_cell, loc_size), where loc_size=4 (x, y, w, h)
130 |         ['log_loc_var'] same shape as 'loc' but only guaranteed to be present if aleatoric_loss=True
131 |         ['obj'] tensor with shape=(b, h, w, boxes_per_cell)
132 |         ['cls'] tensor with shape=(b, h, w, boxes_per_cell, cls_cnt)
133 |     :param gt: dict:
134 |         ['loc'] tensor with shape=(b, h, w, boxes_per_cell, loc_size), where loc_size=4 (x, y, w, h)
135 |         ['obj'] tensor with shape=(b, h, w, boxes_per_cell)
136 |         ['cls'] tensor with shape=(b, h, w, boxes_per_cell) containing sparse class ids
137 |         ['ign'] tensor with shape=(b, h, w, boxes_per_cell), masks the objectness loss of ignored cells
138 |     :param aleatoric_loss: whether the aleatoric localization loss attenuation should be used
139 |     :return: dict with the scalar losses 'loc', 'obj' and 'cls'
140 |     """
141 | 
142 |     with tf.name_scope('batch_size'):
143 |         batch_size = tf.cast(tf.shape(det['loc'])[0], dtype=tf.float32)
144 | 
145 |     with tf.name_scope('loss'):
146 |         with tf.name_scope('localization'):
147 |             loc_loss = (gt['loc'] - det['loc'])
148 |             loc_loss = loc_loss ** 2
149 | 
150 |             if aleatoric_loss:
151 |                 log_loc_var = tf.clip_by_value(det['log_loc_var'], clip_value_min=-40, clip_value_max=40)
152 |                 loc_loss = loc_loss * tf.exp(-log_loc_var)
153 |                 loc_loss = loc_loss + log_loc_var
154 | 
155 |             loc_loss = loc_loss * tf.expand_dims(gt['obj'], axis=-1)  # only apply loss to cells if there is an object
156 |             loc_loss_all = tf.reduce_sum(loc_loss) / (2 * batch_size)
157 |             tf.summary.scalar('value', loc_loss_all)
158 |             tf.losses.add_loss(loc_loss_all)
159 | 
160 |         with tf.name_scope('objectness'):
161 |             # if aleatoric_loss:
162 |             #     obj_loss = aleatoric_obj_loss(det, gt)
163 |             # else:
164 |             obj_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=gt['obj'], logits=det['obj'])
165 | 
166 |             obj_loss = obj_loss * gt['ign']
167 |             obj_loss_all = tf.reduce_sum(obj_loss) / batch_size
168 |             tf.summary.scalar('value', obj_loss_all)
169 |             tf.losses.add_loss(obj_loss_all)
170 | 
171 |         with tf.name_scope('cls_x_entropy'):
172 |             # if aleatoric_loss:
173 |             #     cls_loss = aleatoric_cls_loss(det, gt)
174 |             # else:
175 |             cls_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=gt['cls'], logits=det['cls'])
176 | 
177 |             cls_loss = cls_loss * gt['obj']  # only apply loss to cells if there is an object
178 |             cls_loss_all = tf.reduce_sum(cls_loss) / batch_size
179 |             tf.summary.scalar('value', cls_loss_all)
180 |             tf.losses.add_loss(cls_loss_all)
181 | 
182 |         tf.summary.scalar('total', loc_loss_all + obj_loss_all + cls_loss_all)
183 |     loss = {
184 |         'loc': loc_loss_all,
185 |         'obj': obj_loss_all,
186 |         'cls': cls_loss_all,
187 |     }
188 |     return loss
189 | 
190 | 
191 | def decode_bbox_standard(det, priors):
192 |     """
193 |     Converts center x, center y, width and height values to coordinates of top left and bottom right points.
194 | 195 | :param predictions: outputs of YOLO v3 detector of shape (?, h, w, boxes_per_cell * (cls_cnt + 5)) 196 | :return: converted detections of same shape as input 197 | """ 198 | 199 | batches, lh, lw, box_cnt, cls_cnt = det['cls'].shape.as_list() 200 | assert box_cnt == len(priors) 201 | 202 | obj_ = tf.sigmoid(det['obj']) 203 | cls_ = tf.nn.softmax(det['cls']) 204 | 205 | loc_split = tf.split(det['loc'], [1] * box_cnt, axis=-2) 206 | obj_split = tf.split(obj_, [1] * box_cnt, axis=-1) 207 | cls_split = tf.split(cls_, [1] * box_cnt, axis=-2) 208 | 209 | # calculate x, y offsets 210 | grid_x = tf.range(lw, dtype=tf.float32) 211 | grid_y = tf.range(lh, dtype=tf.float32) 212 | x_offset, y_offset = tf.meshgrid(grid_x, grid_y) 213 | 214 | shape = x_offset.shape.as_list() 215 | assert shape[0] == lh and shape[1] == lw 216 | shape = y_offset.shape.as_list() 217 | assert shape[0] == lh and shape[1] == lw 218 | 219 | x_offset = tf.expand_dims(x_offset, axis=0) 220 | y_offset = tf.expand_dims(y_offset, axis=0) 221 | 222 | result = [] 223 | 224 | for idx, p in enumerate(priors): 225 | loc = loc_split[idx] 226 | obj = obj_split[idx] 227 | cls = cls_split[idx] 228 | 229 | x, y, w, h = tf.split(loc, [1, 1, 1, 1], axis=-1) 230 | 231 | # squeeze dim one axis from splitting by prior 232 | x = tf.squeeze(x, axis=[-2, -1]) # don't squeeze the batch axis 233 | y = tf.squeeze(y, axis=[-2, -1]) # don't squeeze the batch axis 234 | w = tf.squeeze(w, axis=[-2, -1]) # don't squeeze the batch axis 235 | h = tf.squeeze(h, axis=[-2, -1]) # don't squeeze the batch axis 236 | 237 | cls = tf.squeeze(cls, axis=[-2]) 238 | 239 | # calc bbox coordinates 240 | x = (x_offset + tf.sigmoid(x)) / lw 241 | y = (y_offset + tf.sigmoid(y)) / lh 242 | w = (tf.exp(w) * p.w) 243 | h = (tf.exp(h) * p.h) 244 | 245 | # center + width and height -> upper left and lower right corner 246 | w2 = w / 2 247 | h2 = h / 2 248 | x0 = x - w2 249 | y0 = y - h2 250 | x1 = x + w2 251 | y1 = y + h2 252 | 253 | # store everything in one tensor with dim N x (4 + 1 + cls_cnt) 254 | bbox = tf.stack([y0, x0, y1, x1], axis=-1) 255 | bbox = tf.concat([bbox, obj, cls], axis=-1) 256 | result.append(bbox) 257 | 258 | return result 259 | 260 | 261 | def decode_bbox_aleatoric(det, priors, layer_id): 262 | """ 263 | Converts center x, center y, width and height values to coordinates of top left and bottom right points. 
264 | 
265 |     :param det: dict of split detector outputs ('loc', 'log_loc_var', 'obj', 'cls', ...), see split_detection_aleatoric
266 |     :return: list with one decoded tensor per prior box (see the concat below for the exact channel layout)
267 |     """
268 | 
269 |     batches, lh, lw, box_cnt, cls_cnt = det['cls'].shape.as_list()
270 |     assert box_cnt == len(priors)
271 | 
272 |     loc_var_ = tf.exp(det['log_loc_var'])
273 |     obj_ = tf.sigmoid(det['obj'])
274 |     cls_ = tf.nn.softmax(det['cls'])
275 | 
276 |     obj_entropy_ = logistic_entropy(obj_)
277 |     cls_entropy_ = softmax_entropy(cls_)
278 | 
279 |     loc_split = tf.split(det['loc'], [1] * box_cnt, axis=-2)
280 |     loc_var_split = tf.split(loc_var_, [1] * box_cnt, axis=-2)
281 | 
282 |     obj_split = tf.split(obj_, [1] * box_cnt, axis=-1)
283 |     obj_entropy_split = tf.split(obj_entropy_, [1] * box_cnt, axis=-1)
284 | 
285 |     cls_split = tf.split(cls_, [1] * box_cnt, axis=-2)
286 |     cls_entropy_split = tf.split(cls_entropy_, [1] * box_cnt, axis=-1)
287 | 
288 |     # calculate x, y offsets
289 |     grid_x = tf.range(lw, dtype=tf.float32)
290 |     grid_y = tf.range(lh, dtype=tf.float32)
291 |     x_offset, y_offset = tf.meshgrid(grid_x, grid_y)
292 | 
293 |     shape = x_offset.shape.as_list()
294 |     assert shape[0] == lh and shape[1] == lw
295 |     shape = y_offset.shape.as_list()
296 |     assert shape[0] == lh and shape[1] == lw
297 | 
298 |     x_offset = tf.expand_dims(x_offset, axis=0)
299 |     y_offset = tf.expand_dims(y_offset, axis=0)
300 | 
301 |     result = []
302 | 
303 |     for idx, p in enumerate(priors):
304 |         loc = loc_split[idx]
305 |         loc_var = loc_var_split[idx]
306 |         obj = obj_split[idx]
307 |         obj_entropy = obj_entropy_split[idx]
308 |         cls = cls_split[idx]
309 |         cls_entropy = cls_entropy_split[idx]
310 | 
311 |         x, y, w, h = tf.split(loc, [1, 1, 1, 1], axis=-1)
312 | 
313 |         # squeeze dim one axis from splitting by prior
314 |         x = tf.squeeze(x, axis=[-2, -1])  # don't squeeze the batch axis
315 |         y = tf.squeeze(y, axis=[-2, -1])  # don't squeeze the batch axis
316 |         w = tf.squeeze(w, axis=[-2, -1])  # don't squeeze the batch axis
317 |         h = tf.squeeze(h, axis=[-2, -1])  # don't squeeze the batch axis
318 | 
319 |         loc_var = tf.squeeze(loc_var, axis=[-2])
320 |         cls = tf.squeeze(cls, axis=[-2])
321 | 
322 |         # calc bbox coordinates
323 |         x = (x_offset + tf.sigmoid(x)) / lw
324 |         y = (y_offset + tf.sigmoid(y)) / lh
325 |         w = (tf.exp(w) * p.w)
326 |         h = (tf.exp(h) * p.h)
327 | 
328 |         # center + width and height -> upper left and lower right corner
329 |         w2 = w / 2
330 |         h2 = h / 2
331 |         x0 = x - w2
332 |         y0 = y - h2
333 |         x1 = x + w2
334 |         y1 = y + h2
335 | 
336 |         loc_ale_total_var = tf.reduce_prod(loc_var, axis=-1)
337 | 
338 |         ones = tf.ones_like(cls_entropy)
339 | 
340 |         # store everything in one tensor with dim N x ((4 + 4 + 1) + (1 + 1) + (cls_cnt + 1) + 2),
341 |         # where the trailing 2 are layer_id and prior_id
342 |         bbox = tf.stack([y0, x0, y1, x1], axis=-1)
343 |         bbox = tf.concat([bbox, loc_var, tf.expand_dims(loc_ale_total_var, axis=-1), obj, obj_entropy, cls, cls_entropy,
344 |                           layer_id * ones, idx * ones], axis=-1)  # layer_id, prior_id
345 |         result.append(bbox)
346 | 
347 |     return result
348 | 
349 | 
350 | def logistic_entropy(scores):
351 |     no_obj = (1 - scores) * tf.log((1 - scores))
352 |     obj = scores * tf.log(scores)
353 |     entropy = -(no_obj + obj)  # Note: there is a minus sign!
354 |     return entropy
355 | 
356 | 
357 | def softmax_entropy(scores):
358 |     entropy = - tf.reduce_sum(scores * tf.log(scores), axis=-1)  # Note: there is a minus sign!
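    # mathematically H(p) = -sum_i p_i * log(p_i) is 0 for a one-hot p and log(cls_cnt) for a
    # uniform p; note that tf.log(0) = -inf, so an exact zero score would give 0 * -inf = NaN -
    # in practice sigmoid/softmax outputs are strictly positive, which is why no clipping is done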
358 |     return entropy
359 | 
360 | 
361 | def decode_epistemic(det):
362 |     """
363 |     Calc mean, var and classification uncertainty for T forward passes of the same image
364 |     """
365 | 
366 |     loc = det['loc']
367 |     loc_var = tf.exp(det['log_loc_var'])
368 |     obj = tf.sigmoid(det['obj'])
369 |     obj_stddev = tf.exp(det['log_obj_stddev'])  # ignore
370 |     cls = tf.nn.softmax(det['cls'])
371 |     cls_stddev = tf.exp(det['log_cls_stddev'])  # ignore
372 | 
373 |     # localization (co)variance
374 |     loc_col = tf.expand_dims(loc, axis=-1)  # last two dimensions represent a column vector
375 |     loc_row = tf.expand_dims(loc, axis=-2)  # last two dimensions represent a row vector
376 | 
377 |     ev_loc = tf.reduce_mean(loc, axis=0)
378 |     ev_loc_col = tf.expand_dims(ev_loc, axis=-1)  # last two dimensions represent a column vector
379 |     ev_loc_row = tf.expand_dims(ev_loc, axis=-2)  # last two dimensions represent a row vector
380 | 
381 |     ev_loc_locT = tf.reduce_mean(loc_col * loc_row, axis=0)  # E[loc_col * loc_row] (4 x 4)
382 | 
383 |     epi_covar_loc = ev_loc_locT - (ev_loc_col * ev_loc_row)
384 |     ale_var_loc = tf.reduce_mean(loc_var, axis=0)
385 | 
386 |     # class and objectness uncertainty
387 |     obj_mean = tf.reduce_mean(obj, axis=0)
388 |     obj_predictive_entropy = logistic_entropy(obj_mean)
389 |     obj_posterior_entropy = tf.reduce_mean(logistic_entropy(obj), axis=0)
390 |     obj_mutual_info = obj_predictive_entropy - obj_posterior_entropy
391 | 
392 |     cls_mean = tf.reduce_mean(cls, axis=0)
393 |     cls_predictive_entropy = softmax_entropy(cls_mean)
394 |     cls_posterior_entropy = tf.reduce_mean(softmax_entropy(cls), axis=0)
395 |     cls_mutual_info = cls_predictive_entropy - cls_posterior_entropy
396 | 
397 |     return {
398 |         'ev_loc': ev_loc,  # shape=(lh, lw, box_cnt, 4)
399 |         'epi_covar_loc': epi_covar_loc,  # shape=(lh, lw, box_cnt, 4, 4)
400 |         'ale_var_loc': ale_var_loc,  # shape=(lh, lw, box_cnt, 4)
401 | 
402 |         'obj_samples': obj,  # shape=(T, lh, lw, box_cnt)  # TODO currently irrelevant
403 |         'obj_mean': obj_mean,  # shape=(lh, lw, box_cnt)
404 |         'obj_mutual_info': obj_mutual_info,  # shape=(lh, lw, box_cnt)
405 |         'obj_entropy': obj_predictive_entropy,  # shape=(lh, lw, box_cnt)
406 | 
407 |         'cls_samples': cls,  # shape=(T, lh, lw, box_cnt, cls_cnt)  # TODO currently irrelevant
408 |         'cls_mean': cls_mean,  # shape=(lh, lw, box_cnt, cls_cnt)
409 |         'cls_mutual_info': cls_mutual_info,  # shape=(lh, lw, box_cnt)
410 |         'cls_entropy': cls_predictive_entropy,  # shape=(lh, lw, box_cnt)
411 |     }
412 | 
413 | 
414 | def decode_bbox_epistemic(det_epistemic, priors, layer_id):
415 |     T, lh, lw, box_cnt, cls_cnt = det_epistemic[
416 |         'cls_samples'].shape.as_list()  # T == number of forward passes for same image
417 |     assert box_cnt == len(priors)
418 | 
419 |     ev_loc_split = tf.split(det_epistemic['ev_loc'], [1] * box_cnt, axis=-2)
420 |     epi_covar_loc_split = tf.split(det_epistemic['epi_covar_loc'], [1] * box_cnt, axis=-3)
421 |     ale_var_loc_split = tf.split(det_epistemic['ale_var_loc'], [1] * box_cnt, axis=-2)
422 | 
423 |     obj_mean_split = tf.split(det_epistemic['obj_mean'], [1] * box_cnt, axis=-1)
424 |     obj_mutual_info_split = tf.split(det_epistemic['obj_mutual_info'], [1] * box_cnt, axis=-1)
425 |     obj_entropy_split = tf.split(det_epistemic['obj_entropy'], [1] * box_cnt, axis=-1)
426 | 
427 |     cls_mean_split = tf.split(det_epistemic['cls_mean'], [1] * box_cnt, axis=-2)
428 |     cls_mutual_info_split = tf.split(det_epistemic['cls_mutual_info'], [1] * box_cnt, axis=-1)
429 |     cls_entropy_split = tf.split(det_epistemic['cls_entropy'], [1] * box_cnt, axis=-1)
430 | 
431 |
result = [] 432 | 433 | # calculate x, y offsets 434 | grid_x = tf.range(lw, dtype=tf.float32) 435 | grid_y = tf.range(lh, dtype=tf.float32) 436 | x_offset, y_offset = tf.meshgrid(grid_x, grid_y) 437 | 438 | ones = tf.ones(shape=(lh, lw, 1), dtype=tf.float32) 439 | 440 | shape = x_offset.shape.as_list() 441 | assert shape[0] == lh and shape[1] == lw 442 | shape = y_offset.shape.as_list() 443 | assert shape[0] == lh and shape[1] == lw 444 | 445 | for idx, p in enumerate(priors): 446 | ev_loc = ev_loc_split[idx] 447 | 448 | epi_covar_loc = epi_covar_loc_split[idx] 449 | ale_var_loc = ale_var_loc_split[idx] 450 | 451 | obj_mean = obj_mean_split[idx] 452 | obj_mutual_info = obj_mutual_info_split[idx] 453 | obj_entropy = obj_entropy_split[idx] 454 | 455 | cls_mean = cls_mean_split[idx] 456 | cls_mutual_info = cls_mutual_info_split[idx] 457 | cls_entropy = cls_entropy_split[idx] 458 | 459 | x, y, w, h = tf.split(ev_loc, [1, 1, 1, 1], axis=-1) 460 | 461 | # squeeze dim one axis from splitting 462 | x = tf.squeeze(x, axis=[-2, -1]) 463 | y = tf.squeeze(y, axis=[-2, -1]) 464 | w = tf.squeeze(w, axis=[-2, -1]) 465 | h = tf.squeeze(h, axis=[-2, -1]) 466 | 467 | epi_covar_loc = tf.squeeze(epi_covar_loc, axis=[-3]) 468 | ale_var_loc = tf.squeeze(ale_var_loc, axis=[-2]) 469 | cls_mean = tf.squeeze(cls_mean, axis=[-2]) 470 | 471 | # calc bbox coordinates 472 | x = (x_offset + tf.sigmoid(x)) / lw 473 | y = (y_offset + tf.sigmoid(y)) / lh 474 | w = (tf.exp(w) * p.w) 475 | h = (tf.exp(h) * p.h) 476 | 477 | # center + width and height -> upper left and lower right corner 478 | w2 = w / 2 479 | h2 = h / 2 480 | x0 = x - w2 481 | y0 = y - h2 482 | x1 = x + w2 483 | y1 = y + h2 484 | # store everything in one tensor with dim: 485 | # N x ((4 + 4 + 4 + 1 + 1) + (1 + 1) + (cls_cnt + 1 + 1)) 486 | # (localization) + (obj) + (cls) 487 | 488 | loc_epi_total_var = tf.linalg.det(epi_covar_loc) 489 | loc_ale_var = tf.reduce_sum(ale_var_loc, axis=-1) 490 | bbox = tf.stack([y0, x0, y1, x1], axis=-1) 491 | 492 | epi_loc_var = tf.linalg.diag_part(epi_covar_loc) 493 | bbox = tf.concat([bbox, 494 | epi_loc_var, ale_var_loc, # epistemic and aleatoric var of x, y, w, h 495 | tf.expand_dims(loc_epi_total_var, axis=-1), # total var epi 496 | tf.expand_dims(loc_ale_var, axis=-1), # total var ale 497 | obj_mean, obj_mutual_info, obj_entropy, 498 | cls_mean, cls_mutual_info, cls_entropy, 499 | layer_id * ones, idx * ones], axis=-1) # layer_id, prior_id 500 | result.append(bbox) 501 | 502 | return result 503 | 504 | 505 | def residual(inputs, shortcut): 506 | inputs = inputs + shortcut 507 | return inputs 508 | 509 | 510 | def darknet_batch_norm(inputs, training, trainable): 511 | inputs = tf.layers.batch_normalization(inputs, training=training, trainable=trainable, epsilon=1e-05) 512 | return inputs 513 | 514 | 515 | def batch_norm(inputs, training): 516 | inputs = tf.layers.batch_normalization(inputs, training=training, 517 | epsilon=1e-05) 518 | return inputs 519 | 520 | 521 | def dropout(inputs, drop_prob, standard_test_dropout=False): 522 | training = not standard_test_dropout 523 | inputs = tf.layers.dropout(inputs, rate=drop_prob, training=training) # we always want dropout 524 | return inputs 525 | 526 | 527 | def darknet_conv(inputs, filters, kernel_size, strides, training, trainable, weight_regularizer): 528 | assert kernel_size in [1, 3], 'invalid kernel size' 529 | assert strides in [1, 2], 'invalid strides' 530 | if not trainable: 531 | assert not training 532 | 533 | if strides > 1: 534 | # the padding in tensorflow 
and darknet framework differ (darknet is the same as caffe)
535 |         # https://stackoverflow.com/questions/42924324/tensorflows-asymmetric-padding-assumptions
536 |         inputs = darknet_downsample_padding(inputs, kernel_size)
537 |         padding = 'VALID'
538 |     else:
539 |         padding = 'SAME'
540 | 
541 |     normalizer = {'type': 'darknet_bn', 'training': training}
542 |     return conv(inputs, filters, kernel_size, strides, normalizer, trainable, weight_regularizer, padding=padding)
543 | 
544 | 
545 | def conv(inputs, filters, kernel_size, strides, normalizer, trainable, weight_regularizer, padding='SAME'):
546 |     assert kernel_size in [1, 3], 'invalid kernel size'
547 |     assert strides in [1, 2], 'invalid strides'
548 | 
549 |     use_bias = False
550 |     inputs = tf.layers.conv2d(inputs, filters, kernel_size, strides=strides, activation=None, padding=padding,
551 |                               use_bias=use_bias,
552 |                               trainable=trainable,
553 |                               kernel_regularizer=weight_regularizer,
554 |                               bias_regularizer=weight_regularizer if use_bias else None)
555 | 
556 |     # check for multiple normalizers:
557 |     if isinstance(normalizer, dict):
558 |         normalizer = [normalizer]
559 | 
560 |     # possible to add multiple normalizers (first dropout then batch norm)
561 |     for n in normalizer:
562 |         if n['type'] == 'bn':
563 |             inputs = batch_norm(inputs, training=n['training'])
564 |         elif n['type'] == 'darknet_bn':
565 |             inputs = darknet_batch_norm(inputs, training=n['training'], trainable=trainable)
566 |         elif n['type'] == 'dropout':
567 |             if n.get('standard_test_dropout', False):
568 |                 inputs = dropout(inputs, n['drop_prob'], standard_test_dropout=True)
569 |             else:
570 |                 inputs = dropout(inputs, n['drop_prob'])
571 |         elif n['type'] is not None:
572 |             raise ValueError('Invalid regularizer type: {}'.format(n['type']))
573 | 
574 |     inputs = tf.nn.leaky_relu(inputs, alpha=0.1)
575 |     return inputs
576 | 
577 | 
578 | def upsample(inputs):
579 |     shape = tf.shape(inputs)  # NHWC
580 |     return tf.image.resize_nearest_neighbor(inputs, (2 * shape[1], 2 * shape[2]))  # upsample by factor 2
581 | 
582 | 
583 | def route(routes):
584 |     assert len(routes) < 3, 'too many routes'
585 |     assert len(routes), 'too few routes'
586 | 
587 |     if len(routes) > 1:
588 |         inputs = tf.concat(routes, axis=3)  # concatenate channels (if channels_first, then axis=1)
589 |     else:
590 |         inputs = tf.identity(routes[0])  # use identity layer to avoid name clashing when loading darknet weights
591 | 
592 |     return inputs
593 | 
594 | 
595 | def stack_feature_map(inputs, T):
596 |     inputs = tf.concat([inputs] * T, axis=0)
597 |     return inputs
598 | 
599 | 
600 | def detection(inputs, cls_cnt, box_cnt, weight_regularizer):
601 |     filters = box_cnt * (4 + 1 + cls_cnt)
602 |     # use linear activation!
603 |     inputs = tf.layers.conv2d(inputs, filters, 1, strides=1, activation=None, padding='SAME',
604 |                               kernel_regularizer=weight_regularizer, bias_regularizer=weight_regularizer)
605 |     return inputs
606 | 
607 | 
608 | def detection_aleatoric(inputs, cls_cnt, box_cnt, weight_regularizer):
609 |     filters = box_cnt * (2 * (4 + 1 + cls_cnt))
610 |     # use linear activation!
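    # channel layout per prior box (this is what the tf.split in split_detection_aleatoric expects):
    # [4 loc | 4 log_loc_var | 1 obj | 1 log_obj_stddev | cls_cnt cls | cls_cnt log_cls_stddev]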
611 |     inputs = tf.layers.conv2d(inputs, filters, 1, strides=1, activation=None, padding='SAME',
612 |                               kernel_regularizer=weight_regularizer, bias_regularizer=weight_regularizer)
613 |     return inputs
614 | 
615 | 
616 | def darknet_downsample_padding(inputs, kernel_size):
617 |     """
618 |     the padding in tensorflow and darknet framework differ, darknet is the same as caffe, see:
619 |     https://stackoverflow.com/questions/42924324/tensorflows-asymmetric-padding-assumptions
620 | 
621 |     For us this only makes a difference when downsampling (conv2d 3x3 filter and stride 2).
622 | 
623 |     Note: Maxpooling with a 2x2 kernel also differs between tensorflow and the darknet framework.
624 |     However since we do not use maxpool it is ignored here.
625 |     :param inputs:
626 |     :param kernel_size:
627 |     :return:
628 |     """
629 |     assert kernel_size == 3, 'invalid kernel size'
630 | 
631 |     pad_front = 1  # this differs from the standard tf padding when: stride = 2, kernel_size = 3 and input_size is even
632 |     pad_end = 1  # could be 0 if input_size is odd, but the overhead of padding is negligible
633 | 
634 |     inputs = tf.pad(inputs, [[0, 0], [pad_front, pad_end], [pad_front, pad_end], [0, 0]], mode='CONSTANT')
635 |     return inputs
636 | 
--------------------------------------------------------------------------------
/lib_yolo/model.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | 
3 | from lib_yolo import layers, data
4 | 
5 | 
6 | def img_size_and_priors_if_crop(config):  # TODO fn name is not very descriptive
7 |     img_size = config['crop_img_size'] if config['crop'] else config['full_img_size']
8 |     priors = config['priors']
9 | 
10 |     if config['crop']:
11 |         # priors are always defined for the full img => need to rescale if we crop the image
12 |         scale_h = config['full_img_size'][0] / float(config['crop_img_size'][0])
13 |         scale_w = config['full_img_size'][1] / float(config['crop_img_size'][1])
14 |         for stride, prs in priors.items():
15 |             priors[stride] = [data.Prior(h=p.h * scale_h, w=p.w * scale_w) for p in prs]
16 | 
17 |     return img_size, priors
18 | 
19 | 
20 | class ModelBuilder:
21 |     def __init__(self, inputs, cls_cnt, l2_scale=0.0005):
22 |         self.__layers = []
23 |         self.inputs = inputs
24 |         self.__cls_cnt = cls_cnt
25 |         self.__current_downsample = 1
26 |         self.__det_layers = []
27 |         self.__weight_regularizer = tf.contrib.layers.l2_regularizer(l2_scale)
28 | 
29 |         shape = inputs.shape.as_list()
30 |         assert len(shape) == 4
31 |         self.__input_size = [shape[1], shape[2]]  # h x w
32 | 
33 |     def layer_cnt(self):
34 |         return len(self.__layers)
35 | 
36 |     def get_model(self, obj_idx, cls_start_idx):
37 |         assert obj_idx < cls_start_idx
38 |         return Model(self.__layers, self.__det_layers, self.__cls_cnt, obj_idx, cls_start_idx)
39 | 
40 |     def __update_layers(self):
41 |         self.__layers.append(self.inputs)
42 | 
43 |     def __conv_layer(self, filters, kernel_size, strides, normalizer, variable_scope):
44 |         assert strides in [1, 2]
45 | 
46 |         with tf.variable_scope(None, default_name=variable_scope):
47 |             # use_bias = True
48 |             self.inputs = layers.conv(self.inputs, filters, kernel_size, strides, normalizer, trainable=True,
49 |                                       weight_regularizer=self.__weight_regularizer)
50 |             self.__update_layers()
51 | 
52 |     def make_conv_layer(self, filters, kernel_size, normalizer):
53 |         self.__conv_layer(filters, kernel_size, 1, normalizer, 'conv')
54 | 
55 |     def make_downsample_layer(self, filters, kernel_size, normalizer):
56 |         self.__current_downsample *= 2
57 |         self.__conv_layer(filters,
kernel_size, 2, normalizer, 'downsample') 58 | 59 | def __darknet_conv_layer(self, filters, kernel_size, strides, training, trainable, variable_scope): 60 | assert strides in [1, 2] 61 | if not trainable: 62 | assert not training 63 | 64 | with tf.variable_scope(None, default_name=variable_scope): 65 | self.inputs = layers.darknet_conv(self.inputs, filters, kernel_size, strides, training, trainable, 66 | self.__weight_regularizer) 67 | self.__update_layers() 68 | 69 | def make_darknet_conv_layer(self, filters, kernel_size, training, trainable): 70 | self.__darknet_conv_layer(filters, kernel_size, 1, training, trainable, 'conv') 71 | 72 | def make_darknet_downsample_layer(self, filters, kernel_size, training, trainable): 73 | self.__current_downsample *= 2 74 | self.__darknet_conv_layer(filters, kernel_size, 2, training, trainable, 'downsample') 75 | 76 | def make_route_layer(self, routes): 77 | with tf.variable_scope(None, default_name='route'): 78 | self.inputs = layers.route([self.__layers[route] for route in routes]) 79 | self.__update_layers() 80 | 81 | def make_stack_feature_map_layer(self, layer, T): 82 | with tf.variable_scope(None, default_name='stack_feature_map'): 83 | self.inputs = layers.stack_feature_map(self.__layers[layer], T) 84 | self.__update_layers() 85 | 86 | def make_residual_layer(self, shortcut): 87 | with tf.variable_scope(None, default_name='residual'): 88 | self.inputs = layers.residual(self.inputs, self.__layers[shortcut]) 89 | self.__update_layers() 90 | 91 | def make_residual_block(self, filters, normalizer): 92 | self.make_conv_layer(filters, 1, normalizer) 93 | self.make_conv_layer(2 * filters, 3, normalizer) 94 | self.make_residual_layer(-3) 95 | 96 | def make_darknet_residual_block(self, filters, training, trainable): 97 | self.make_darknet_conv_layer(filters, 1, training, trainable) 98 | self.make_darknet_conv_layer(2 * filters, 3, training, trainable) 99 | self.make_residual_layer(-3) 100 | 101 | def make_upsample_layer(self): 102 | self.__current_downsample //= 2 103 | with tf.variable_scope(None, default_name='upsample'): 104 | self.inputs = layers.upsample(self.inputs) 105 | self.__update_layers() 106 | 107 | def make_detection_layer(self, all_priors, gt=None): 108 | priors = all_priors[self.__current_downsample] 109 | box_cnt = len(priors) 110 | with tf.variable_scope('detection'): 111 | self.inputs = layers.detection(self.inputs, cls_cnt=self.__cls_cnt, box_cnt=box_cnt, 112 | weight_regularizer=self.__weight_regularizer) 113 | self.__update_layers() 114 | det = layers.split_detection(self.inputs, boxes_per_cell=box_cnt, 115 | cls_cnt=self.__cls_cnt) 116 | bbox = layers.decode_bbox_standard(det, priors) 117 | if gt: 118 | loss = layers.loss_tf(det, gt, aleatoric_loss=False) 119 | 120 | self.__det_layers.append(DetLayer( 121 | input_img_size=self.__input_size, 122 | downsample_factor=self.__current_downsample, 123 | priors=priors, 124 | loss=loss if gt else None, 125 | det=det, 126 | bbox=bbox, 127 | raw_output=self.inputs, 128 | )) 129 | 130 | return self.inputs 131 | 132 | def make_detection_layer_aleatoric(self, all_priors, aleatoric_loss, gt=None): 133 | priors = all_priors[self.__current_downsample] 134 | box_cnt = len(priors) 135 | with tf.variable_scope('detection'): 136 | self.inputs = layers.detection_aleatoric(self.inputs, cls_cnt=self.__cls_cnt, box_cnt=box_cnt, 137 | weight_regularizer=self.__weight_regularizer) 138 | self.__update_layers() 139 | det = layers.split_detection_aleatoric(self.inputs, boxes_per_cell=box_cnt, 140 | 
cls_cnt=self.__cls_cnt) 141 | bbox = layers.decode_bbox_aleatoric(det, priors, layer_id=len(self.__det_layers)) 142 | if gt: 143 | loss = layers.loss_tf(det, gt, aleatoric_loss) 144 | 145 | self.__det_layers.append(DetLayer( 146 | input_img_size=self.__input_size, 147 | downsample_factor=self.__current_downsample, 148 | priors=priors, 149 | loss=loss if gt else None, 150 | det=det, 151 | bbox=bbox, 152 | raw_output=self.inputs, 153 | )) 154 | 155 | return self.inputs 156 | 157 | def make_detection_layer_aleatoric_epistemic(self, all_priors, aleatoric_loss, gt=None, inference_mode=False): 158 | priors = all_priors[self.__current_downsample] 159 | box_cnt = len(priors) 160 | with tf.variable_scope('detection'): 161 | self.inputs = layers.detection_aleatoric(self.inputs, cls_cnt=self.__cls_cnt, box_cnt=box_cnt, 162 | weight_regularizer=self.__weight_regularizer) 163 | self.__update_layers() 164 | det = layers.split_detection_aleatoric(self.inputs, boxes_per_cell=box_cnt, 165 | cls_cnt=self.__cls_cnt) 166 | if inference_mode: 167 | det = layers.decode_epistemic(det) 168 | bbox = layers.decode_bbox_epistemic(det, priors, layer_id=len(self.__det_layers)) 169 | else: 170 | bbox = layers.decode_bbox_aleatoric(det, priors, layer_id=len(self.__det_layers)) 171 | 172 | if gt: 173 | loss = layers.loss_tf(det, gt, aleatoric_loss) 174 | 175 | self.__det_layers.append(DetLayer( 176 | input_img_size=self.__input_size, 177 | downsample_factor=self.__current_downsample, 178 | priors=priors, 179 | loss=loss if gt else None, 180 | det=det, 181 | bbox=bbox, 182 | raw_output=self.inputs, 183 | )) 184 | 185 | return self.inputs 186 | 187 | 188 | class Model: 189 | def __init__(self, layers, det_layers, cls_cnt, obj_idx, cls_start_idx): 190 | assert len(det_layers) > 0 191 | self.layers = layers 192 | self.det_layers = det_layers 193 | self.cls_cnt = cls_cnt 194 | self.obj_idx = obj_idx 195 | self.cls_start_idx = cls_start_idx 196 | 197 | if self.det_layers[0].loc_loss is not None: 198 | with tf.name_scope('global_loss'): 199 | self.detection_loss = tf.losses.get_total_loss(add_regularization_losses=False, name='detection_loss') 200 | self.regularization_loss = tf.losses.get_regularization_loss(name='regularization_loss') 201 | self.total_loss = tf.add(self.detection_loss, self.regularization_loss, name='total_loss') 202 | 203 | self.loc_loss = det_layers[0].loc_loss 204 | self.obj_loss = det_layers[0].obj_loss 205 | self.cls_loss = det_layers[0].cls_loss 206 | for l in self.det_layers[1:]: 207 | self.loc_loss += l.loc_loss 208 | self.obj_loss += l.obj_loss 209 | self.cls_loss += l.cls_loss 210 | 211 | tf.summary.scalar('loc', self.loc_loss) 212 | tf.summary.scalar('obj', self.obj_loss) 213 | tf.summary.scalar('cls', self.cls_loss) 214 | tf.summary.scalar('detection', self.detection_loss) 215 | tf.summary.scalar('l2_weight_reg', self.regularization_loss) 216 | tf.summary.scalar('total', self.total_loss) 217 | 218 | def matches_blueprint(self, blueprint): 219 | try: 220 | for dl, bpdl in zip(self.det_layers, blueprint.det_layers): 221 | assert dl.matches_blueprint(bpdl) 222 | assert self.cls_cnt == blueprint.cls_cnt 223 | except AssertionError: 224 | return False 225 | return True 226 | 227 | 228 | class DetLayer: 229 | def __init__(self, input_img_size, downsample_factor, priors, loss, det, bbox, raw_output): 230 | self.h = input_img_size[0] // downsample_factor 231 | self.w = input_img_size[1] // downsample_factor 232 | self.downsample = downsample_factor 233 | self.priors = priors 234 | self.loc_loss = 
loss['loc'] if loss else None 235 | self.obj_loss = loss['obj'] if loss else None 236 | self.cls_loss = loss['cls'] if loss else None 237 | 238 | self.det = det 239 | self.bbox = bbox 240 | 241 | self.raw_output = raw_output 242 | 243 | def matches_blueprint(self, blueprint): 244 | try: 245 | assert self.h == blueprint.h 246 | assert self.w == blueprint.w 247 | assert self.downsample == blueprint.downsample 248 | assert len(self.priors) == len(blueprint.priors) 249 | for p, bpp in zip(self.priors, blueprint.priors): 250 | assert p.h == bpp.h 251 | assert p.w == bpp.w 252 | except AssertionError: 253 | return False 254 | return True 255 | 256 | 257 | class ModelBlueprint: 258 | def __init__(self, det_layers, cls_cnt): 259 | self.det_layers = det_layers 260 | self.cls_cnt = cls_cnt 261 | 262 | 263 | class DetLayerBlueprint: 264 | def __init__(self, input_img_size, downsample_factor, priors): 265 | self.h = input_img_size[0] // downsample_factor 266 | self.w = input_img_size[1] // downsample_factor 267 | self.downsample = downsample_factor 268 | self.priors = priors 269 | -------------------------------------------------------------------------------- /lib_yolo/tfdata.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from scipy.special import logit 3 | 4 | from lib_yolo import data 5 | 6 | 7 | def logit(x): 8 | """ 9 | inverse of sigmoid function 10 | """ 11 | return - tf.log((1. / x) - 1.) 12 | 13 | 14 | def create_prior_data(det_layers): 15 | # create priors 16 | bboxes = bbox_areas = cx = cy = pw = ph = lw = lh = center_x = center_y = None 17 | for layer in det_layers: 18 | bboxes_, bbox_areas_, cx_, cy_, pw_, ph_, lw_, lh_, center_x_, center_y_ = data.create_prior_data(layer) 19 | 20 | bboxes_ = tf.constant(bboxes_, dtype=tf.float32) 21 | bbox_areas_ = tf.constant(bbox_areas_, dtype=tf.float32) 22 | cx_ = tf.constant(cx_, dtype=tf.float32) 23 | cy_ = tf.constant(cy_, dtype=tf.float32) 24 | pw_ = tf.constant(pw_, dtype=tf.float32) 25 | ph_ = tf.constant(ph_, dtype=tf.float32) 26 | center_x_ = tf.constant(center_x_, dtype=tf.float32) 27 | center_y_ = tf.constant(center_y_, dtype=tf.float32) 28 | 29 | bboxes_ = tf.reshape(bboxes_, shape=[-1, 4]) 30 | bbox_areas_ = tf.reshape(bbox_areas_, shape=[-1]) 31 | cx_ = tf.reshape(cx_, shape=[-1]) 32 | cy_ = tf.reshape(cy_, shape=[-1]) 33 | pw_ = tf.reshape(pw_, shape=[-1]) 34 | ph_ = tf.reshape(ph_, shape=[-1]) 35 | lw_ = tf.reshape(lw_, shape=[-1]) 36 | lh_ = tf.reshape(lh_, shape=[-1]) 37 | center_x_ = tf.reshape(center_x_, shape=[-1]) 38 | center_y_ = tf.reshape(center_y_, shape=[-1]) 39 | 40 | if bboxes is None: 41 | bboxes = bboxes_ 42 | bbox_areas = bbox_areas_ 43 | cx = cx_ 44 | cy = cy_ 45 | pw = pw_ 46 | ph = ph_ 47 | lw = lw_ 48 | lh = lh_ 49 | center_x = center_x_ 50 | center_y = center_y_ 51 | else: 52 | bboxes = tf.concat([bboxes, bboxes_], axis=0) 53 | bbox_areas = tf.concat([bbox_areas, bbox_areas_], axis=0) 54 | cx = tf.concat([cx, cx_], axis=0) 55 | cy = tf.concat([cy, cy_], axis=0) 56 | pw = tf.concat([pw, pw_], axis=0) 57 | ph = tf.concat([ph, ph_], axis=0) 58 | lw = tf.concat([lw, lw_], axis=0) 59 | lh = tf.concat([lh, lh_], axis=0) 60 | center_x = tf.concat([center_x, center_x_], axis=0) 61 | center_y = tf.concat([center_y, center_y_], axis=0) 62 | 63 | return { 64 | 'bboxes': bboxes, 65 | 'bbox_areas': bbox_areas, 66 | 'cx': cx, 67 | 'cy': cy, 68 | 'pw': pw, 69 | 'ph': ph, 70 | 'lw': lw, 71 | 'lh': lh, 72 | 'center_x': center_x, # TODO unused 73 | 'center_y': 
center_y, # TODO unused 74 | } 75 | 76 | 77 | def encode_boxes(bboxes, labels, det_layers, ign_thresh): 78 | prior_data = create_prior_data(det_layers) 79 | 80 | # initialize output 81 | total_box_cnt = 0 82 | layer_sizes = [] 83 | for layer in det_layers: 84 | boxes_per_cell = len(layer.priors) 85 | layer_sizes.append(layer.h * layer.w * boxes_per_cell) 86 | total_box_cnt += layer_sizes[-1] 87 | 88 | # loc = tf.zeros(shape=(total_box_cnt, 4), dtype=tf.float32) 89 | loc_x = tf.zeros(shape=(total_box_cnt,), dtype=tf.float32) 90 | loc_y = tf.zeros(shape=(total_box_cnt,), dtype=tf.float32) 91 | loc_w = tf.zeros(shape=(total_box_cnt,), dtype=tf.float32) 92 | loc_h = tf.zeros(shape=(total_box_cnt,), dtype=tf.float32) 93 | obj = tf.zeros(shape=(total_box_cnt,), dtype=tf.float32) 94 | cls = tf.zeros(shape=(total_box_cnt,), dtype=tf.int32) 95 | ign = tf.ones(shape=(total_box_cnt,), dtype=tf.float32) 96 | 97 | ones_float = tf.ones(shape=(total_box_cnt,), dtype=tf.float32) 98 | ones_int = tf.ones(shape=(total_box_cnt,), dtype=tf.int32) 99 | zeros = tf.zeros(shape=(total_box_cnt,), dtype=tf.float32) 100 | 101 | w = bboxes[..., 3] - bboxes[..., 1] 102 | h = bboxes[..., 2] - bboxes[..., 0] 103 | x = (bboxes[..., 3] + bboxes[..., 1]) / 2. 104 | y = (bboxes[..., 2] + bboxes[..., 0]) / 2. 105 | 106 | def loop_condition(idx_, *vargs): 107 | return tf.less(idx_, tf.shape(labels)[0]) 108 | 109 | # calc overlaps between gt bboxes and priors 110 | def loop_body(idx_, loc_x_, loc_y_, loc_w_, loc_h_, obj_, cls_, ign_): 111 | bbox = bboxes[idx_] 112 | 113 | # TODO can this be optimized? x[idx] * ones 114 | dist_to_cell_center_x = prior_data['lw'] * (x[idx_] - prior_data['cx']) 115 | dist_to_cell_center_y = prior_data['lh'] * (y[idx_] - prior_data['cy']) 116 | x_obj_mask = tf.logical_and(tf.greater_equal(dist_to_cell_center_x, 0), tf.less_equal(dist_to_cell_center_x, 1)) 117 | y_obj_mask = tf.logical_and(tf.greater_equal(dist_to_cell_center_y, 0), tf.less_equal(dist_to_cell_center_y, 1)) 118 | obj_mask = tf.logical_and(x_obj_mask, y_obj_mask) 119 | 120 | # calc best iou score 121 | iou = calc_iou(bbox, prior_data) 122 | best_ious = tf.greater_equal(iou, tf.reduce_max(iou)) # TODO performance 123 | 124 | # use location and iou to determine the correct prior and cell 125 | obj_mask = tf.logical_and(best_ious, obj_mask) 126 | 127 | # check if we got at least one obj 128 | # assertion = tf.assert_greater(tf.reduce_sum(tf.where(obj_mask, ones_float, zeros)), 129 | # 0.5, message='ERROR, skipped obj') 130 | 131 | ign_mask = tf.greater_equal(iou, ign_thresh) 132 | 133 | # with tf.control_dependencies([assertion]): 134 | eps = 1e-7 135 | loc_x_ = tf.where(obj_mask, logit(tf.clip_by_value(dist_to_cell_center_x, eps, 1 - eps)), loc_x_) 136 | loc_y_ = tf.where(obj_mask, logit(tf.clip_by_value(dist_to_cell_center_y, eps, 1 - eps)), loc_y_) 137 | 138 | loc_w_ = tf.where(obj_mask, tf.log(tf.maximum(w[idx_] / prior_data['pw'], eps)), loc_w_) 139 | loc_h_ = tf.where(obj_mask, tf.log(tf.maximum(h[idx_] / prior_data['ph'], eps)), loc_h_) 140 | cls_ = tf.where(obj_mask, labels[idx_] * ones_int, cls_) 141 | obj_ = tf.where(obj_mask, ones_float, obj_) # TODO make conditional on cls == ignore 142 | ign_ = tf.where(ign_mask, zeros, ign_) 143 | 144 | idx_ += 1 145 | 146 | return idx_, loc_x_, loc_y_, loc_w_, loc_h_, obj_, cls_, ign_ 147 | 148 | idx = 0 149 | [idx, loc_x, loc_y, loc_w, loc_h, obj, cls, ign] = tf.while_loop(loop_condition, loop_body, 150 | [idx, loc_x, loc_y, loc_w, loc_h, obj, cls, 151 | ign]) 152 | 153 | loc = 
tf.stack([loc_x, loc_y, loc_w, loc_h], axis=1)  # maybe this is unnecessary
154 |     ign = tf.maximum(ign, obj)
155 | 
156 |     loc = tf.split(loc, layer_sizes, axis=0)
157 |     cls = tf.split(cls, layer_sizes, axis=0)
158 |     obj = tf.split(obj, layer_sizes, axis=0)
159 |     ign = tf.split(ign, layer_sizes, axis=0)
160 | 
161 |     encoded = []
162 |     for i, layer in enumerate(det_layers):
163 |         shape = [layer.h, layer.w, len(layer.priors)]
164 |         encoded.append({
165 |             'loc': tf.reshape(loc[i], shape=[*shape, 4]),
166 |             'cls': tf.reshape(cls[i], shape=shape),
167 |             'obj': tf.reshape(obj[i], shape=shape),
168 |             'ign': tf.reshape(ign[i], shape=shape),
169 |         })
170 | 
171 |     return encoded
172 | 
173 | 
174 | def calc_iou(ref_bbox, prior_data):
175 |     bboxes = prior_data['bboxes']
176 |     bboxes_area = prior_data['bbox_areas']
177 | 
178 |     int_ymin = tf.maximum(bboxes[..., 0], ref_bbox[0])
179 |     int_xmin = tf.maximum(bboxes[..., 1], ref_bbox[1])
180 |     int_ymax = tf.minimum(bboxes[..., 2], ref_bbox[2])
181 |     int_xmax = tf.minimum(bboxes[..., 3], ref_bbox[3])
182 |     h = tf.maximum(int_ymax - int_ymin, 0.)
183 |     w = tf.maximum(int_xmax - int_xmin, 0.)
184 | 
185 |     inter = h * w
186 |     union = bboxes_area - inter + ((ref_bbox[2] - ref_bbox[0]) * (ref_bbox[3] - ref_bbox[1]))
187 |     iou = tf.div(inter, union)
188 |     return iou
189 | 
--------------------------------------------------------------------------------
/lib_yolo/train.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import json
3 | import logging
4 | import os
5 | 
6 | import numpy as np
7 | import tensorflow as tf
8 | 
9 | from lib_yolo import dataset_utils, data_augmentation
10 | 
11 | 
12 | def save_config(config, folder):
13 |     time_string = datetime.datetime.now().isoformat().split('.')[0]
14 |     file = os.path.join(folder, 'config_{}_{}.json'.format(time_string, config['run_id']))
15 | 
16 |     try:
17 |         os.makedirs(folder)
18 |     except IOError:
19 |         pass
20 | 
21 |     with open(file, 'w') as f:
22 |         json.dump(config, f, indent=4, default=lambda x: str(x))
23 | 
24 | 
25 | def start(model_cls, config):
26 |     if config['crop']:
27 |         cropper = data_augmentation.ImageCropper(config)
28 |         config['train']['crop_fn'] = cropper.random_crop_and_sometimes_rescale
29 |         config['val']['crop_fn'] = cropper.random_crop_and_sometimes_rescale
30 | 
31 |     model_factory = model_cls(config)
32 | 
33 |     dataset = dataset_utils.TrainValDataset(model_blueprint=model_factory.blueprint, config=config)
34 | 
35 |     # currently all models have 3 detection layers, so this works...
36 |     img, gt1, gt2, gt3 = dataset.iterator.get_next()
37 |     model = model_factory.init_model(inputs=img, training=True, gt1=gt1, gt2=gt2, gt3=gt3).get_model()
38 | 
39 |     # also all models are powered by darknet53...
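    # load_darknet53_weights follows the usual TF1 pattern for loading external
    # weights: it returns a list of tf.assign ops instead of assigning
    # immediately, so the caller decides when they run. A minimal sketch of the
    # idea (variable names purely illustrative):
    #   assign_ops = [tf.assign(var, value) for var, value in zip(backbone_vars, loaded_values)]
    #   sess.run(tf.global_variables_initializer())  # random init for all variables first
    #   sess.run(assign_ops)                          # then overwrite the backbone with pretrained weights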
40 | assign_ops = model_factory.load_darknet53_weights(config['darknet53_weights']) 41 | 42 | with tf.Session(config=tf.ConfigProto(device_count={'GPU': 1})) as sess: 43 | dataset.init_dataset(sess) 44 | try: 45 | train(sess, model, dataset, config, init_ops=assign_ops) 46 | except: 47 | logging.exception('ERROR') 48 | raise 49 | 50 | 51 | def train(sess, model, dataset, config, init_ops=None): 52 | def train_loop_body(): 53 | summary, tloss, dloss, rloss, lloss, oloss, closs = sess.run( 54 | [train_step, summary_op, model.total_loss, model.detection_loss, model.regularization_loss, 55 | model.loc_loss, model.obj_loss, model.cls_loss], feed_dict={dataset.handle: dataset.train_handle})[1:] 56 | if np.isnan(tloss) or np.isinf(tloss): 57 | logging.error('{:5d} >>> total_loss: {:8.2f}, det_loss: {:8.2f}, loc_loss: {:8.2f},' 58 | ' obj_loss: {:8.2f}, cls_loss: {:8.2f}, reg_loss: {:8.5f}'.format( 59 | step, tloss, dloss, lloss, oloss, closs, rloss)) 60 | return False 61 | 62 | if step % 25 == 0: 63 | writer_train.add_summary(summary, step) 64 | logging.info('{:5d} train >>> total_loss: {:8.2f}, det_loss: {:8.2f}, loc_loss: {:8.2f},' 65 | ' obj_loss: {:8.2f}, cls_loss: {:8.2f}, reg_loss: {:8.5f}'.format( 66 | step, tloss, dloss, lloss, oloss, closs, rloss)) 67 | 68 | if step % 100 == 0: 69 | summary, tloss, dloss, rloss, lloss, oloss, closs = sess.run( 70 | [summary_op, model.total_loss, model.detection_loss, model.regularization_loss, 71 | model.loc_loss, model.obj_loss, model.cls_loss], 72 | feed_dict={dataset.handle: dataset.val_handle}) 73 | writer_val.add_summary(summary, step) 74 | logging.info( 75 | '{:5d} val >>> total_loss: {:8.2f}, det_loss: {:8.2f}, loc_loss: {:8.2f},' 76 | ' obj_loss: {:8.2f}, cls_loss: {:8.2f}, reg_loss: {:8.5f}'.format( 77 | step, tloss, dloss, lloss, oloss, closs, rloss)) 78 | 79 | if step % config['checkpoint_interval'] == 0: 80 | saver.save(sess, os.path.join(save_path, config['run_id']), global_step=step) 81 | 82 | return True 83 | 84 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 85 | with tf.variable_scope('optimizer'): 86 | with tf.control_dependencies(update_ops): 87 | optimizer = tf.train.AdamOptimizer(learning_rate=config['lr']) 88 | train_step = optimizer.minimize(model.total_loss, var_list=tf.trainable_variables()) 89 | 90 | # TODO create hist summaries 91 | summary_op = tf.summary.merge_all() 92 | 93 | saver = tf.train.Saver(max_to_keep=config['ckp_max_to_keep']) 94 | save_path = os.path.join(config['checkpoint_path'], config['run_id']) 95 | save_config(config, save_path) 96 | 97 | if config['resume_training']: 98 | checkpoint = config['resume_checkpoint'] 99 | if checkpoint == 'last': 100 | checkpoint = tf.train.latest_checkpoint(save_path) 101 | step = int(checkpoint.split('-')[-1]) 102 | saver.restore(sess, checkpoint) 103 | else: 104 | # It is important to first run the global variable initializer and then the init_ops! 105 | # Otherwise the initializations of init_ops would be overridden. 
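        # In the reverse order the pretrained darknet53 weights loaded by the
        # init_ops would be overwritten with random values again, i.e. (sketch):
        #   sess.run(init_ops)                            # loads the pretrained backbone weights
        #   sess.run(tf.global_variables_initializer())  # silently re-randomizes them -> wrong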
106 |         sess.run(tf.global_variables_initializer())
107 |         if init_ops:
108 |             sess.run(init_ops)
109 |         step = 0
110 | 
111 |     writer_train = tf.summary.FileWriter(os.path.join(config['tensorboard_path'], config['run_id'], 'train'),
112 |                                          sess.graph)
113 |     writer_val = tf.summary.FileWriter(os.path.join(config['tensorboard_path'], config['run_id'], 'val'))
114 |     try:
115 |         while step < config['train_steps']:
116 |             step += 1
117 |             success = train_loop_body()
118 |             if not success:
119 |                 logging.error('An error occurred, abort training.')
120 |                 break
121 |     except KeyboardInterrupt:
122 |         # gracefully exit
123 |         logging.info('KeyboardInterrupt: Abort training.')
124 |         ans = ''
125 |         while ans.lower() not in ['yes', 'y', 'no', 'n']:
126 |             ans = input('Save checkpoint (yes/no): ')
127 |         if ans.lower() in ['no', 'n']:
128 |             return  # without saving checkpoint
129 |     except:
130 |         # try to save if an unexpected error occurs
131 |         logging.error('Unexpected error occurred, trying to save checkpoint.')
132 |         saver.save(sess, os.path.join(save_path, config['run_id']), global_step=step)
133 | 
134 |     # save if training ended
135 |     saver.save(sess, os.path.join(save_path, config['run_id']), global_step=step)
136 | 
--------------------------------------------------------------------------------
/lib_yolo/utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | 
4 | import matplotlib.pyplot as plt
5 | import numpy as np
6 | import tensorflow as tf
7 | 
8 | from lib_yolo import dataset_utils, yolov3, data_augmentation
9 | 
10 | 
11 | def logistic(x):
12 |     return 1 / (1 + np.exp(-x))
13 | 
14 | 
15 | def softmax(x):
16 |     exp_x = np.exp(x - np.max(x))
17 |     return exp_x / np.sum(exp_x)
18 | 
19 | 
20 | def _vis(sess, model, inputs, thresh, epistemic=False, feed_dict=None):
21 |     *bboxes, img = sess.run([*[det_layer.bbox for det_layer in model.det_layers], inputs], feed_dict=feed_dict)
22 | 
23 |     boxes = []
24 |     for bbox_all in bboxes:
25 |         for bbox_by_prior in bbox_all:
26 |             if epistemic:
27 |                 b = bbox_by_prior
28 |             else:
29 |                 b = bbox_by_prior[0, ...]
30 |             b = np.reshape(b, newshape=[b.shape[0] * b.shape[1], b.shape[2]])
31 |             boxes.append(b)
32 |     boxes = np.concatenate(boxes, axis=0)
33 | 
34 |     img = img[0, ...]
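    # keep only the first image of the batch and re-add the batch dimension,
    # since tf.image.draw_bounding_boxes below expects a 4-D batch tensor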
35 | img = np.expand_dims(img, 0) 36 | 37 | def draw_boxes(img_, boxes_): 38 | nms_ind = tf.image.non_max_suppression(boxes_[:, :4], boxes_[:, model.obj_idx], 1000) 39 | boxes_ = tf.gather(boxes_, nms_ind, axis=0) 40 | thresh_ind = tf.where(boxes_[:, model.obj_idx] > thresh) 41 | boxes_ = tf.gather(boxes_, thresh_ind, axis=0) 42 | boxes_ = tf.squeeze(boxes_, axis=[-2]) 43 | boxes_ = boxes_[:, :4] 44 | boxes_ = tf.expand_dims(boxes_, axis=0) 45 | return tf.image.draw_bounding_boxes(img_, boxes_) 46 | 47 | draw_op = tf.cond(tf.reduce_any(boxes[:, model.obj_idx] > thresh), 48 | true_fn=lambda: draw_boxes(img, boxes), 49 | false_fn=lambda: img) 50 | 51 | result = sess.run(draw_op) 52 | 53 | # mng = plt.get_current_fig_manager() 54 | # mng.resize(*mng.window.maxsize()) 55 | 56 | plt.imshow((255 * np.squeeze(result)).astype(np.uint8)) 57 | plt.show() 58 | 59 | 60 | def vis_standard(sess, model, inputs, thresh, feed_dict=None): 61 | _vis(sess, model, inputs, thresh, feed_dict=feed_dict) 62 | 63 | 64 | def vis_aleatoric(sess, model, inputs, thresh, feed_dict=None): 65 | _vis(sess, model, inputs, thresh, feed_dict=feed_dict) 66 | 67 | 68 | def vis_bayes(sess, model, inputs, thresh, feed_dict=None): 69 | _vis(sess, model, inputs, thresh, epistemic=True, feed_dict=feed_dict) 70 | 71 | 72 | def predictions_to_boxes_numpy_reference_implementation(predictions, cls_cnt, priors, box_format='xywh'): 73 | # predictions = np.array(predictions, dtype=np.float32) 74 | 75 | batches, lh, lw, c = predictions.shape 76 | det_size = 2 * (4 + 1 + cls_cnt) 77 | 78 | result = np.zeros([batches, len(priors) * lh * lw, det_size], dtype=np.float32) 79 | 80 | for b in range(batches): 81 | result_idx = 0 82 | for row in range(lh): 83 | for col in range(lw): 84 | anchor_offset = 0 85 | for p in priors: 86 | [x, y, w, h, x_var, y_var, w_var, h_var, 87 | obj, log_obj_stddev, *cls_] = predictions[b, row, col, anchor_offset:anchor_offset + det_size] 88 | cls = cls_[:cls_cnt] 89 | log_cls_stddev = cls_[cls_cnt:] 90 | cls = np.array(cls, dtype=np.float32) 91 | 92 | x = (col + logistic(x)) / lw 93 | y = (row + logistic(y)) / lh 94 | w = (np.exp(w) * p.w) 95 | h = (np.exp(h) * p.h) 96 | x_var = np.exp(x_var) 97 | y_var = np.exp(y_var) 98 | w_var = np.exp(w_var) 99 | h_var = np.exp(h_var) 100 | 101 | obj = logistic(obj) 102 | obj_stddev = np.exp(log_obj_stddev) 103 | cls = softmax(cls) 104 | cls_stddev = np.exp(log_cls_stddev) 105 | 106 | if box_format == 'xywh': 107 | result[b, result_idx, :] = [x, y, w, h, x_var, y_var, w_var, h_var, obj, obj_stddev, 108 | *cls, *cls_stddev] 109 | else: 110 | w2 = w / 2 111 | h2 = h / 2 112 | x0 = x - w2 113 | y0 = y - h2 114 | x1 = x + w2 115 | y1 = y + h2 116 | 117 | result[b, result_idx, :] = [y0, x0, y1, x1, x_var, y_var, w_var, h_var, obj, obj_stddev, 118 | *cls, *cls_stddev] 119 | 120 | result_idx += 1 121 | anchor_offset += det_size 122 | 123 | return result 124 | 125 | 126 | def qualitative_eval(model_cls, config): 127 | if config['crop']: 128 | cropper = data_augmentation.ImageCropper(config) 129 | config['val']['crop_fn'] = cropper.center_crop 130 | 131 | if model_cls == yolov3.bayesian_yolov3_aleatoric: 132 | config['inference_mode'] = True 133 | config.setdefault('T', 20) 134 | 135 | model_factory = model_cls(config) 136 | 137 | dataset = dataset_utils.ValDataset(config, dataset_key='val') 138 | img, bbox, label = dataset.iterator.get_next() 139 | 140 | model = model_factory.init_model(inputs=img, training=False).get_model() 141 | with 
tf.Session(config=tf.ConfigProto(device_count={'GPU': 1})) as sess: 142 | checkpoint = config.get('resume_checkpoint', 'last') 143 | if checkpoint == 'last': 144 | checkpoint = tf.train.latest_checkpoint(os.path.join(config['checkpoint_path'], config['run_id'])) 145 | tf.train.Saver().restore(sess, checkpoint) 146 | 147 | vis_fn = { 148 | yolov3.yolov3: _vis, 149 | yolov3.yolov3_aleatoric: vis_aleatoric, 150 | yolov3.bayesian_yolov3_aleatoric: vis_bayes, 151 | }[model_cls] 152 | for i in range(1000): 153 | vis_fn(sess, model, img, config['thresh']) 154 | 155 | 156 | def add_file_logging(config, override_existing=False): 157 | path = os.path.join(config['log_path'], '{}.log'.format(config['run_id'])) 158 | 159 | try: 160 | os.makedirs(config['log_path']) 161 | except IOError: 162 | pass 163 | 164 | if os.path.exists(path) and not override_existing: 165 | raise RuntimeError('Logging file {} already exists'.format(path)) 166 | 167 | file_handler = logging.FileHandler(path, 'w') 168 | file_handler.setLevel(logging.INFO) 169 | formatter = logging.Formatter(fmt='%(asctime)s, %(levelname)-8s %(message)s', 170 | datefmt='%a, %d %b %Y %H:%M:%S') 171 | file_handler.setFormatter(formatter) 172 | logging.getLogger('').addHandler(file_handler) 173 | -------------------------------------------------------------------------------- /lib_yolo/yolov3.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from lib_yolo import darknet, model, data 4 | 5 | 6 | def _get_city_persons_9_priors(): 7 | priors = [[495.27, 203.83], 8 | [297.84, 122.19], 9 | [197.44, 81.48], 10 | [141.07, 58.5], 11 | [102.72, 43.1], 12 | [75.78, 31.66], 13 | [54.24, 23.19], 14 | [37.55, 16.15], 15 | [22.55, 10.09]] 16 | 17 | priors = [[p[0] / 1024., p[1] / 2048.] 
for p in priors] # priors are calculated for original citypersons img size 18 | priors_32 = [data.Prior(h=p[0], w=p[1]) for p in priors[:3]] 19 | priors_16 = [data.Prior(h=p[0], w=p[1]) for p in priors[3:6]] 20 | priors_8 = [data.Prior(h=p[0], w=p[1]) for p in priors[6:]] 21 | 22 | return { 23 | 32: priors_32, 24 | 16: priors_16, 25 | 8: priors_8, 26 | } 27 | 28 | 29 | def _get_ecp_9_priors(): 30 | priors = [ 31 | [0.56643243, 0.13731691], 32 | [0.41022839, 0.09028599], 33 | [0.30508716, 0.06047965], 34 | [0.20774711, 0.04376083], 35 | [0.15475611, 0.02996197], 36 | [0.10878717, 0.02149197], 37 | [0.07694039, 0.01488527], 38 | [0.05248527, 0.01007212], 39 | [0.03272104, 0.00631827], 40 | ] 41 | 42 | # # in pixels: 43 | # priors = [[580.02680832, 263.64846719999997], 44 | # [420.07387136, 173.3491008], 45 | # [312.40925184, 116.120928], 46 | # [212.73304064, 84.0207936], 47 | # [158.47025664, 57.5269824], 48 | # [111.39806208, 41.264582399999995], 49 | # [78.78695936, 28.5797184], 50 | # [53.74491648, 19.338470400000002], 51 | # [33.50634496, 12.1310784]] 52 | 53 | priors_32 = [data.Prior(h=p[0], w=p[1]) for p in priors[:3]] 54 | priors_16 = [data.Prior(h=p[0], w=p[1]) for p in priors[3:6]] 55 | priors_8 = [data.Prior(h=p[0], w=p[1]) for p in priors[6:]] 56 | 57 | return { 58 | 32: priors_32, 59 | 16: priors_16, 60 | 8: priors_8, 61 | } 62 | 63 | 64 | def _get_ecp_night_9_priors(): 65 | priors = [ 66 | [0.6197282176953125, 0.14694562146874998], 67 | [0.4243941425683594, 0.09687759120833334], 68 | [0.3103862368359375, 0.06362734035416667], 69 | [0.23494613041992188, 0.043568554453125], 70 | [0.1634832566796875, 0.03293052755208333], 71 | [0.12444031231445313, 0.023274527578125], 72 | [0.08800429220703125, 0.016930080526041665], 73 | [0.06101826478515625, 0.011638404229166668], 74 | [0.03925641140625, 0.007475639645833334], 75 | ] 76 | 77 | # # in pixels: 78 | # priors = [[634.60169492, 282.13559322], 79 | # [434.57960199, 186.00497512], 80 | # [317.83550652, 122.16449348], 81 | # [240.58483755, 83.65162455], 82 | # [167.40685484, 63.2266129], 83 | # [127.42687981, 44.68709295], 84 | # [90.11639522, 32.50575461], 85 | # [62.48270314, 22.34573612], 86 | # [40.19856528, 14.35322812]] 87 | 88 | priors_32 = [data.Prior(h=p[0], w=p[1]) for p in priors[:3]] 89 | priors_16 = [data.Prior(h=p[0], w=p[1]) for p in priors[3:6]] 90 | priors_8 = [data.Prior(h=p[0], w=p[1]) for p in priors[6:]] 91 | 92 | return { 93 | 32: priors_32, 94 | 16: priors_16, 95 | 8: priors_8, 96 | } 97 | 98 | 99 | def _get_ecp_day_night_9_priors(): 100 | priors = [ 101 | [0.5728529907421875, 0.13943622409895834], 102 | [0.41761617583007815, 0.09156660707291667], 103 | [0.3015263176855469, 0.06248444700520834], 104 | [0.22101856140625, 0.042888710765625], 105 | [0.1533158565527344, 0.031196821406250002], 106 | [0.11255495265625, 0.021566710822916668], 107 | [0.07823327209960937, 0.015212825187500001], 108 | [0.0533416983203125, 0.010216603067708333], 109 | [0.0332035418359375, 0.006413999807291667] 110 | ] 111 | 112 | # # in pixels: 113 | # priors = [[586.60146252, 267.71755027], 114 | # [427.63896405, 175.80788558], 115 | # [308.76294931, 119.97013825], 116 | # [226.32300688, 82.34632467], 117 | # [156.99543711, 59.8978971], 118 | # [115.25627152, 41.40808478], 119 | # [80.11087063, 29.20862436], 120 | # [54.62189908, 19.61587789], 121 | # [34.00042684, 12.31487963]] 122 | 123 | priors_32 = [data.Prior(h=p[0], w=p[1]) for p in priors[:3]] 124 | priors_16 = [data.Prior(h=p[0], w=p[1]) for p in priors[3:6]] 125 | priors_8 = 
[data.Prior(h=p[0], w=p[1]) for p in priors[6:]] 126 | 127 | return { 128 | 32: priors_32, 129 | 16: priors_16, 130 | 8: priors_8, 131 | } 132 | 133 | 134 | def _get_ecp_with_bic_9_priors(): 135 | priors = [ 136 | [0.5541169062011718, 0.15767184942708334], 137 | [0.3872792363671875, 0.08849276056770834], 138 | [0.27297898112304686, 0.05552458755208333], 139 | [0.18570756796875, 0.034849724458333335], 140 | [0.13080457012695312, 0.052510955223958336], 141 | [0.12203939466796875, 0.02422101765625], 142 | [0.083340965234375, 0.01635016602083333], 143 | [0.055563667021484374, 0.010672233619791667], 144 | [0.03409191838867188, 0.006481136984375], 145 | ] 146 | 147 | # # in pixels: 148 | # priors = [[567.41571195, 302.7299509], 149 | # [396.57393804, 169.90610029], 150 | # [279.53047667, 106.6072081], 151 | # [190.1645496, 66.91147096], 152 | # [133.94387981, 100.82103403], 153 | # [124.96834014, 46.5043539], 154 | # [85.3411484, 31.39231876], 155 | # [56.89719503, 20.49068855], 156 | # [34.91012443, 12.44378301]] 157 | 158 | priors_32 = [data.Prior(h=p[0], w=p[1]) for p in priors[:3]] 159 | priors_16 = [data.Prior(h=p[0], w=p[1]) for p in priors[3:6]] 160 | priors_8 = [data.Prior(h=p[0], w=p[1]) for p in priors[6:]] 161 | 162 | return { 163 | 32: priors_32, 164 | 16: priors_16, 165 | 8: priors_8, 166 | } 167 | 168 | 169 | CITY_PERSONS_9_PRIORS = _get_city_persons_9_priors() 170 | ECP_9_PRIORS = _get_ecp_9_priors() 171 | ECP_NIGHT_9_PRIORS = _get_ecp_night_9_priors() 172 | ECP_DAY_NIGHT_9_PRIORS = _get_ecp_day_night_9_priors() 173 | ECP_BIC_9_PRIORS = _get_ecp_with_bic_9_priors() 174 | 175 | 176 | class yolov3: 177 | def __init__(self, config): 178 | self.__model = None 179 | self.img_size, self.__priors = model.img_size_and_priors_if_crop(config) 180 | self.__darknet53_layer_cnt = 0 181 | self.__freeze_darknet53 = config.get('freeze_darknet53', True) 182 | self.cls_cnt = config['cls_cnt'] 183 | self.obj_idx = 4 184 | self.cls_start_idx = 5 185 | 186 | self.blueprint = model.ModelBlueprint(det_layers=[ 187 | model.DetLayerBlueprint( 188 | input_img_size=self.img_size, 189 | downsample_factor=32, 190 | priors=self.__priors[32], 191 | ), model.DetLayerBlueprint( 192 | input_img_size=self.img_size, 193 | downsample_factor=16, 194 | priors=self.__priors[16], 195 | ), 196 | model.DetLayerBlueprint( 197 | input_img_size=self.img_size, 198 | downsample_factor=8, 199 | priors=self.__priors[8], 200 | ) 201 | ], cls_cnt=self.cls_cnt) 202 | 203 | # check that input_img_size is multiple of biggest stride 204 | # otherwise the following happens: 205 | # ValueError: Dimensions must be equal, but are 16 and 15 for \ 206 | # 'det_net_1/detection/loss/localization/sub' (op: 'Sub') with input shapes: [?,16,32,3,4], [?,15,31,3,4]. 207 | assert config['full_img_size'][0] % 32 == 0 208 | assert config['full_img_size'][1] % 32 == 0 209 | if config['crop']: 210 | assert config['crop_img_size'][0] % 32 == 0 211 | assert config['crop_img_size'][1] % 32 == 0 212 | 213 | def get_model(self): 214 | """ 215 | call init_model first! 216 | """ 217 | assert self.__model is not None, 'Call init_model first.' 218 | return self.__model 219 | 220 | def load_darknet53_weights(self, weightfile): 221 | assert self.__model is not None, 'Call init_model first.' 
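        # the returned assign ops cover only the first __darknet53_layer_cnt
        # layers (the darknet53 backbone); the detection head variables keep
        # their regular initialization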
222 | return darknet.load_darknet_weights(self.__model.layers[:self.__darknet53_layer_cnt], weightfile) 223 | 224 | def init_model(self, inputs, training, gt1=None, gt2=None, gt3=None): 225 | if self.__model is not None: 226 | raise Exception('model can only be initialized once!') 227 | 228 | self.__build_model(inputs, training, gt1, gt2, gt3) 229 | assert self.__model.matches_blueprint(self.blueprint), 'Model does not match blueprint' 230 | return self 231 | 232 | def __build_model(self, inputs, training, gt1=None, gt2=None, gt3=None): 233 | in_shape = inputs.get_shape().as_list() 234 | assert len(in_shape) == 4, 'invalid data format' 235 | # assert in_shape[3] == 3, 'invalid data format' 236 | 237 | mb = model.ModelBuilder(inputs=inputs, cls_cnt=self.cls_cnt) 238 | normalizer = {'type': 'bn', 'training': training} 239 | 240 | with tf.variable_scope('darknet53'): 241 | darknet53_training = False if self.__freeze_darknet53 else training 242 | darknet53_trainable = not self.__freeze_darknet53 243 | darknet.darknet53(mb, training=darknet53_training, trainable=darknet53_trainable) # 0 - 74 244 | 245 | dn_out = mb.inputs 246 | self.__darknet53_layer_cnt = mb.layer_cnt() 247 | 248 | with tf.variable_scope('det_net_1'): 249 | mb.make_conv_layer(512, 1, normalizer) # .................... # 75 250 | mb.make_conv_layer(1024, 3, normalizer) # ................... # 76 251 | 252 | mb.make_conv_layer(512, 1, normalizer) # .................... # 77 253 | mb.make_conv_layer(1024, 3, normalizer) # ................... # 78 254 | 255 | mb.make_conv_layer(512, 1, normalizer) # .................... # 79 256 | mb.make_conv_layer(1024, 3, normalizer) # ................... # 80 257 | 258 | mb.make_detection_layer(all_priors=self.__priors, gt=gt1) # .. # 81 259 | # YOLO LAYER # ............................................... # 82 260 | det_net_1_out = mb.inputs 261 | 262 | with tf.variable_scope('det_net_2'): 263 | # -3 instead of -4, since we don't add a YOLO layer to layers list. 264 | mb.make_route_layer([-3]) # .................................. # 83 265 | mb.make_conv_layer(256, 1, normalizer) # .................... # 84 266 | 267 | # Downsample (factor 16) 268 | mb.make_upsample_layer() # ................................... # 85 269 | mb.make_route_layer([-1, 61]) # .............................. # 86 270 | 271 | mb.make_conv_layer(256, 1, normalizer) # .................... # 87 272 | mb.make_conv_layer(512, 3, normalizer) # .................... # 88 273 | 274 | mb.make_conv_layer(256, 1, normalizer) # .................... # 89 275 | mb.make_conv_layer(512, 3, normalizer) # .................... # 90 276 | 277 | mb.make_conv_layer(256, 1, normalizer) # .................... # 91 278 | mb.make_conv_layer(512, 3, normalizer) # .................... # 92 279 | 280 | mb.make_detection_layer(all_priors=self.__priors, gt=gt2) # .. # 93 281 | # YOLO LAYER # ................................................ # 94 282 | det_net_2_out = mb.inputs 283 | 284 | with tf.variable_scope('det_net_3'): 285 | # -3 instead of -4, since we don't add a YOLO layer to layers list. 286 | mb.make_route_layer([-3]) # .................................. # 95 287 | mb.make_conv_layer(128, 1, normalizer) # .................... # 96 288 | 289 | # Downsample (factor 8) 290 | mb.make_upsample_layer() # ................................... # 97 291 | mb.make_route_layer([-1, 36]) # .............................. # 98 292 | 293 | mb.make_conv_layer(128, 1, normalizer) # .................... 
# 99 294 | mb.make_conv_layer(256, 3, normalizer) # .................... # 100 295 | 296 | mb.make_conv_layer(128, 1, normalizer) # .................... # 101 297 | mb.make_conv_layer(256, 3, normalizer) # .................... # 102 298 | 299 | mb.make_conv_layer(128, 1, normalizer) # .................... # 103 300 | mb.make_conv_layer(256, 3, normalizer) # .................... # 104 301 | 302 | mb.make_detection_layer(all_priors=self.__priors, gt=gt3) # .. # 105 303 | # YOLO LAYER # ............................................... # 106 304 | det_net_3_out = mb.inputs 305 | 306 | self.__model = mb.get_model(self.obj_idx, self.cls_start_idx) 307 | self.__model.dn_out = dn_out 308 | self.__model.det_net_1_out = det_net_1_out 309 | self.__model.det_net_2_out = det_net_2_out 310 | self.__model.det_net_3_out = det_net_3_out 311 | 312 | 313 | class yolov3_aleatoric: 314 | def __init__(self, config): 315 | self.__aleatoric_loss = config['aleatoric_loss'] 316 | self.__model = None 317 | self.img_size, self.__priors = model.img_size_and_priors_if_crop(config) 318 | self.__darknet53_layer_cnt = 0 319 | self.__freeze_darknet53 = config.get('freeze_darknet53', True) 320 | self.cls_cnt = config['cls_cnt'] 321 | self.obj_idx = 9 322 | self.cls_start_idx = 11 323 | 324 | self.blueprint = model.ModelBlueprint(det_layers=[ 325 | model.DetLayerBlueprint( 326 | input_img_size=self.img_size, 327 | downsample_factor=32, 328 | priors=self.__priors[32], 329 | ), model.DetLayerBlueprint( 330 | input_img_size=self.img_size, 331 | downsample_factor=16, 332 | priors=self.__priors[16], 333 | ), 334 | model.DetLayerBlueprint( 335 | input_img_size=self.img_size, 336 | downsample_factor=8, 337 | priors=self.__priors[8], 338 | ) 339 | ], cls_cnt=self.cls_cnt) 340 | 341 | # check that input_img_size is multiple of biggest stride 342 | # otherwise the following happens: 343 | # ValueError: Dimensions must be equal, but are 16 and 15 for \ 344 | # 'det_net_1/detection/loss/localization/sub' (op: 'Sub') with input shapes: [?,16,32,3,4], [?,15,31,3,4]. 345 | assert config['full_img_size'][0] % 32 == 0 346 | assert config['full_img_size'][1] % 32 == 0 347 | if config['crop']: 348 | assert config['crop_img_size'][0] % 32 == 0 349 | assert config['crop_img_size'][1] % 32 == 0 350 | 351 | def get_model(self): 352 | """ 353 | call init_model first! 354 | """ 355 | assert self.__model is not None, 'Call init_model first.' 356 | return self.__model 357 | 358 | def load_darknet53_weights(self, weightfile): 359 | assert self.__model is not None, 'Call init_model first.' 
360 | return darknet.load_darknet_weights(self.__model.layers[:self.__darknet53_layer_cnt], weightfile) 361 | 362 | def init_model(self, inputs, training, gt1=None, gt2=None, gt3=None): 363 | if self.__model is not None: 364 | raise Exception('model can only be initialized once!') 365 | 366 | self.__build_model(inputs, training, gt1, gt2, gt3) 367 | assert self.__model.matches_blueprint(self.blueprint), 'Model does not match blueprint' 368 | return self 369 | 370 | def __build_model(self, inputs, training, gt1=None, gt2=None, gt3=None): 371 | in_shape = inputs.get_shape().as_list() 372 | assert len(in_shape) == 4, 'invalid data format' 373 | # assert in_shape[3] == 3, 'invalid data format' 374 | 375 | mb = model.ModelBuilder(inputs=inputs, cls_cnt=self.cls_cnt) 376 | normalizer = {'type': 'bn', 'training': training} 377 | 378 | with tf.variable_scope('darknet53'): 379 | darknet53_training = False if self.__freeze_darknet53 else training 380 | darknet53_trainable = not self.__freeze_darknet53 381 | darknet.darknet53(mb, training=darknet53_training, trainable=darknet53_trainable) # 0 - 74 382 | 383 | dn_out = mb.inputs 384 | self.__darknet53_layer_cnt = mb.layer_cnt() 385 | 386 | with tf.variable_scope('det_net_1'): 387 | mb.make_conv_layer(512, 1, normalizer) # .................... # 75 388 | mb.make_conv_layer(1024, 3, normalizer) # ................... # 76 389 | 390 | mb.make_conv_layer(512, 1, normalizer) # .................... # 77 391 | mb.make_conv_layer(1024, 3, normalizer) # ................... # 78 392 | 393 | mb.make_conv_layer(512, 1, normalizer) # .................... # 79 394 | mb.make_conv_layer(1024, 3, normalizer) # ................... # 80 395 | 396 | mb.make_detection_layer_aleatoric(all_priors=self.__priors, aleatoric_loss=self.__aleatoric_loss, 397 | gt=gt1) # .. # 81 398 | # YOLO LAYER # ............................................... # 82 399 | det_net_1_out = mb.inputs 400 | 401 | with tf.variable_scope('det_net_2'): 402 | # -3 instead of -4, since we don't add a YOLO layer to layers list. 403 | mb.make_route_layer([-3]) # .................................. # 83 404 | mb.make_conv_layer(256, 1, normalizer) # .................... # 84 405 | 406 | # Downsample (factor 16) 407 | mb.make_upsample_layer() # ................................... # 85 408 | mb.make_route_layer([-1, 61]) # .............................. # 86 409 | 410 | mb.make_conv_layer(256, 1, normalizer) # .................... # 87 411 | mb.make_conv_layer(512, 3, normalizer) # .................... # 88 412 | 413 | mb.make_conv_layer(256, 1, normalizer) # .................... # 89 414 | mb.make_conv_layer(512, 3, normalizer) # .................... # 90 415 | 416 | mb.make_conv_layer(256, 1, normalizer) # .................... # 91 417 | mb.make_conv_layer(512, 3, normalizer) # .................... # 92 418 | 419 | mb.make_detection_layer_aleatoric(all_priors=self.__priors, aleatoric_loss=self.__aleatoric_loss, 420 | gt=gt2) # .. # 93 421 | # YOLO LAYER # ................................................ # 94 422 | det_net_2_out = mb.inputs 423 | 424 | with tf.variable_scope('det_net_3'): 425 | # -3 instead of -4, since we don't add a YOLO layer to layers list. 426 | mb.make_route_layer([-3]) # .................................. # 95 427 | mb.make_conv_layer(128, 1, normalizer) # .................... # 96 428 | 429 | # Downsample (factor 8) 430 | mb.make_upsample_layer() # ................................... # 97 431 | mb.make_route_layer([-1, 36]) # .............................. 
# 98 432 | 433 | mb.make_conv_layer(128, 1, normalizer) # .................... # 99 434 | mb.make_conv_layer(256, 3, normalizer) # .................... # 100 435 | 436 | mb.make_conv_layer(128, 1, normalizer) # .................... # 101 437 | mb.make_conv_layer(256, 3, normalizer) # .................... # 102 438 | 439 | mb.make_conv_layer(128, 1, normalizer) # .................... # 103 440 | mb.make_conv_layer(256, 3, normalizer) # .................... # 104 441 | 442 | mb.make_detection_layer_aleatoric(all_priors=self.__priors, aleatoric_loss=self.__aleatoric_loss, 443 | gt=gt3) # .. # 105 444 | # YOLO LAYER # ............................................... # 106 445 | det_net_3_out = mb.inputs 446 | 447 | self.__model = mb.get_model(self.obj_idx, self.cls_start_idx) 448 | self.__model.dn_out = dn_out 449 | self.__model.det_net_1_out = det_net_1_out 450 | self.__model.det_net_2_out = det_net_2_out 451 | self.__model.det_net_3_out = det_net_3_out 452 | 453 | 454 | class bayesian_yolov3_aleatoric: 455 | def __init__(self, config): 456 | self.__aleatoric_loss = config['aleatoric_loss'] 457 | self.__model = None 458 | self.img_size, self.__priors = model.img_size_and_priors_if_crop(config) 459 | self.__darknet53_layer_cnt = 0 460 | self.__freeze_darknet53 = config.get('freeze_darknet53', True) 461 | self.__inference_mode = config['inference_mode'] 462 | self.__drop_prob = 0.1 463 | self.cls_cnt = config['cls_cnt'] 464 | self.obj_idx = 14 465 | self.cls_start_idx = 17 466 | 467 | if self.__inference_mode: 468 | self.__T = config['T'] 469 | 470 | self.__standard_test_dropout = config.get('standard_test_dropout', False) 471 | 472 | self.blueprint = model.ModelBlueprint(det_layers=[ 473 | model.DetLayerBlueprint( 474 | input_img_size=self.img_size, 475 | downsample_factor=32, 476 | priors=self.__priors[32], 477 | ), model.DetLayerBlueprint( 478 | input_img_size=self.img_size, 479 | downsample_factor=16, 480 | priors=self.__priors[16], 481 | ), 482 | model.DetLayerBlueprint( 483 | input_img_size=self.img_size, 484 | downsample_factor=8, 485 | priors=self.__priors[8], 486 | ) 487 | ], cls_cnt=self.cls_cnt) 488 | 489 | # check that input_img_size is multiple of biggest stride 490 | # otherwise the following happens: 491 | # ValueError: Dimensions must be equal, but are 16 and 15 for \ 492 | # 'det_net_1/detection/loss/localization/sub' (op: 'Sub') with input shapes: [?,16,32,3,4], [?,15,31,3,4]. 493 | assert config['full_img_size'][0] % 32 == 0 494 | assert config['full_img_size'][1] % 32 == 0 495 | if config['crop']: 496 | assert config['crop_img_size'][0] % 32 == 0 497 | assert config['crop_img_size'][1] % 32 == 0 498 | 499 | def get_model(self): 500 | """ 501 | call init_model first! 502 | """ 503 | assert self.__model is not None, 'Call init_model first.' 504 | return self.__model 505 | 506 | def load_darknet53_weights(self, weightfile): 507 | assert self.__model is not None, 'Call init_model first.' 
508 | return darknet.load_darknet_weights(self.__model.layers[:self.__darknet53_layer_cnt], weightfile) 509 | 510 | def init_model(self, inputs, training, gt1=None, gt2=None, gt3=None): 511 | if self.__model is not None: 512 | raise Exception('model can only be initialized once!') 513 | 514 | self.__build_model(inputs, training, gt1, gt2, gt3) 515 | assert self.__model.matches_blueprint(self.blueprint), 'Model does not match blueprint' 516 | return self 517 | 518 | def __build_model(self, inputs, training, gt1=None, gt2=None, gt3=None): 519 | in_shape = inputs.get_shape().as_list() 520 | assert len(in_shape) == 4, 'invalid data format' 521 | # assert in_shape[3] == 3, 'invalid data format' 522 | 523 | mb = model.ModelBuilder(inputs=inputs, cls_cnt=self.cls_cnt) 524 | bn = {'type': 'bn', 'training': training} 525 | dropout_bn = [ 526 | {'type': 'dropout', 'drop_prob': self.__drop_prob, 'standard_test_dropout': self.__standard_test_dropout}, 527 | bn 528 | ] # add batch norm after dropout 529 | 530 | with tf.variable_scope('darknet53'): 531 | darknet53_training = False if self.__freeze_darknet53 else training 532 | darknet53_trainable = not self.__freeze_darknet53 533 | darknet.darknet53(mb, training=darknet53_training, trainable=darknet53_trainable) # 0 - 74 534 | 535 | dn_out = mb.inputs 536 | self.__darknet53_layer_cnt = mb.layer_cnt() 537 | 538 | if self.__inference_mode: 539 | # stack dn_out N times 540 | # make route layer to stacked dn => this messes with the layer numbering => should be fine in inference mode 541 | mb.make_stack_feature_map_layer(-1, self.__T) # additional layer, now the counting is of (shouldn't matter) 542 | 543 | with tf.variable_scope('det_net_1'): 544 | mb.make_conv_layer(512, 1, dropout_bn) # ..................... # 75 545 | mb.make_conv_layer(1024, 3, dropout_bn) # .................... # 76 546 | 547 | mb.make_conv_layer(512, 1, dropout_bn) # ..................... # 77 548 | mb.make_conv_layer(1024, 3, dropout_bn) # .................... # 78 549 | 550 | mb.make_conv_layer(512, 1, dropout_bn) # ..................... # 79 551 | mb.make_conv_layer(1024, 3, bn) # ............................ # 80 552 | 553 | mb.make_detection_layer_aleatoric_epistemic( 554 | all_priors=self.__priors, 555 | aleatoric_loss=self.__aleatoric_loss, 556 | gt=gt1, 557 | inference_mode=self.__inference_mode, 558 | ) # .......................................................... # 81 559 | # YOLO LAYER # ............................................... # 82 560 | det_net_1_out = mb.inputs 561 | 562 | with tf.variable_scope('det_net_2'): 563 | # -3 instead of -4, since we don't add a YOLO layer to layers list. 564 | mb.make_route_layer([-3]) # .................................. # 83 565 | mb.make_conv_layer(256, 1, bn) # ............................. # 84 566 | 567 | # Downsample (factor 16) 568 | mb.make_upsample_layer() # ................................... # 85 569 | if self.__inference_mode: 570 | mb.make_stack_feature_map_layer(61, self.__T) 571 | mb.make_route_layer([-2, -1]) # .......................... # 86 572 | else: 573 | mb.make_route_layer([-1, 61]) # .......................... # 86 574 | 575 | mb.make_conv_layer(256, 1, dropout_bn) # ..................... # 87 576 | mb.make_conv_layer(512, 3, dropout_bn) # ..................... # 88 577 | 578 | mb.make_conv_layer(256, 1, dropout_bn) # ..................... # 89 579 | mb.make_conv_layer(512, 3, dropout_bn) # ..................... # 90 580 | 581 | mb.make_conv_layer(256, 1, dropout_bn) # ..................... 
# 91 582 | mb.make_conv_layer(512, 3, bn) # ............................. # 92 583 | 584 | mb.make_detection_layer_aleatoric_epistemic( 585 | all_priors=self.__priors, 586 | aleatoric_loss=self.__aleatoric_loss, 587 | gt=gt2, 588 | inference_mode=self.__inference_mode, 589 | ) # .......................................................... # 93 590 | # YOLO LAYER # ................................................ # 94 591 | det_net_2_out = mb.inputs 592 | 593 | with tf.variable_scope('det_net_3'): 594 | # -3 instead of -4, since we don't add a YOLO layer to layers list. 595 | mb.make_route_layer([-3]) # .................................. # 95 596 | mb.make_conv_layer(128, 1, bn) # ............................. # 96 597 | 598 | # Downsample (factor 8) 599 | mb.make_upsample_layer() # ................................... # 97 600 | if self.__inference_mode: 601 | mb.make_stack_feature_map_layer(36, self.__T) 602 | mb.make_route_layer([-2, -1]) # .......................... # 98 603 | else: 604 | mb.make_route_layer([-1, 36]) # .......................... # 98 605 | 606 | mb.make_conv_layer(128, 1, dropout_bn) # ..................... # 99 607 | mb.make_conv_layer(256, 3, dropout_bn) # ..................... # 100 608 | 609 | mb.make_conv_layer(128, 1, dropout_bn) # ..................... # 101 610 | mb.make_conv_layer(256, 3, dropout_bn) # ..................... # 102 611 | 612 | mb.make_conv_layer(128, 1, dropout_bn) # ..................... # 103 613 | mb.make_conv_layer(256, 3, bn) # ............................. # 104 614 | 615 | mb.make_detection_layer_aleatoric_epistemic( 616 | all_priors=self.__priors, 617 | aleatoric_loss=self.__aleatoric_loss, 618 | gt=gt3, 619 | inference_mode=self.__inference_mode, 620 | ) # .......................................................... # 105 621 | # YOLO LAYER # ............................................... 
# 106
622 |             det_net_3_out = mb.inputs
623 | 
624 |         self.__model = mb.get_model(self.obj_idx, self.cls_start_idx)
625 |         self.__model.dn_out = dn_out
626 |         self.__model.det_net_1_out = det_net_1_out
627 |         self.__model.det_net_2_out = det_net_2_out
628 |         self.__model.det_net_3_out = det_net_3_out
629 | 
--------------------------------------------------------------------------------
/pretraining.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | import os
4 | 
5 | from lib_yolo import yolov3, train, utils
6 | 
7 | 
8 | def main():
9 |     config = {
10 |         'training': True,  # edit: set to False for qualitative evaluation
11 |         'resume_training': False,  # edit
12 |         'resume_checkpoint': 'last',  # edit: either a checkpoint filename or 'last' to resume a training
13 |         'run_id': 'pretraining',  # is used to identify the training process
14 |         'priors': yolov3.ECP_9_PRIORS,  # edit if not ECP dataset
15 |         'checkpoint_path': './checkpoints',
16 |         'tensorboard_path': './tensorboard',
17 |         'log_path': './log',
18 |         'ckp_max_to_keep': 102,  # edit: number of checkpoints to keep
19 |         'checkpoint_interval': 5000,  # edit
20 |         'ign_thresh': 0.7,
21 |         'crop_img_size': [768, 1440, 3],
22 |         'full_img_size': [1024, 1920, 3],  # edit if not ECP dataset
23 |         'train_steps': 500000,
24 |         'darknet53_weights': './darknet53.conv.74',
25 |         'batch_size': 8,  # edit
26 |         'lr': 1e-5,
27 |         'cpu_thread_cnt': 24,  # edit
28 |         'crop': True,  # edit: random crops and rescaling reduce memory consumption and improve training
29 |         'freeze_darknet53': True,  # if True the basenet weights are frozen during training
30 |         'aleatoric_loss': False,
31 |         'cls_cnt': 2,  # edit if not ECP dataset
32 |         'implicit_background_class': True,  # whether the label ids start at 1 or 0. True = 1, False = 0
33 |         'train': {
34 |             'file_pattern': os.path.expandvars('$HOME/data/ecp/tfrecords/ecp-day-train-*-of-*'),  # edit
35 |             'num_shards': 20,
36 |             'shuffle_buffer_size': 2000,
37 |             'cache': False,  # edit if you have enough memory, caches whole dataset in memory
38 |         },
39 |         'val': {
40 |             'file_pattern': os.path.expandvars('$HOME/data/ecp/tfrecords/ecp-day-val-*-of-*'),  # edit
41 |             'num_shards': 4,
42 |             'shuffle_buffer_size': 10,
43 |             'cache': False,  # edit if you have enough memory, caches whole dataset in memory
44 |         }
45 |     }
46 | 
47 |     # Note regarding implicit background class:
48 |     # The tensorflow object detection API enforces that the class labels start with 1.
49 |     # The class 0 is reserved for an (implicit) background class. We support both file formats.
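    # For example, with cls_cnt = 2 (say pedestrian and rider, names purely
    # illustrative) the expected label ids in the tfrecords are:
    #   implicit_background_class = True   ->  {1, 2}  (0 = implicit background)
    #   implicit_background_class = False  ->  {0, 1}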
50 | 51 | utils.add_file_logging(config, override_existing=True) 52 | logging.info(json.dumps(config, indent=4, default=lambda x: str(x))) 53 | 54 | model_cls = yolov3.yolov3_aleatoric 55 | 56 | if config['training']: 57 | train.start(model_cls, config) 58 | else: 59 | config['thresh'] = 0.1 # filter out boxes with objectness score less than thresh 60 | utils.qualitative_eval(model_cls, config) 61 | 62 | 63 | if __name__ == '__main__': 64 | logging.basicConfig(level=logging.INFO, 65 | format='%(asctime)s, %(levelname)-8s %(message)s', 66 | datefmt='%a, %d %b %Y %H:%M:%S', 67 | ) 68 | main() 69 | -------------------------------------------------------------------------------- /uncertainty_training.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | 5 | from lib_yolo import yolov3, train, utils 6 | 7 | 8 | def main(): 9 | config = { 10 | 'training': True, # edit: set to False for qualitative evaluation 11 | 'resume_training': True, 12 | 'resume_checkpoint': './checkpoints/pretraining/pretraining-125000', # edit 13 | # 'resume_checkpoint': 'last', # edit: either filename or 'last' to resume a training 14 | 'priors': yolov3.ECP_9_PRIORS, # edit if not ECP dataset 15 | 'checkpoint_path': './checkpoints', 16 | 'tensorboard_path': './tensorboard', 17 | 'log_path': './log', 18 | 'ckp_max_to_keep': 75, 19 | 'checkpoint_interval': 5000, 20 | 'ign_thresh': 0.7, 21 | 'crop_img_size': [768, 1440, 3], 22 | 'full_img_size': [1024, 1920, 3], # edit if not ECP dataset 23 | 'train_steps': 500000, 24 | 'darknet53_weights': './darknet53.conv.74', 25 | 'batch_size': 2, # edit 26 | 'lr': 1e-5, 27 | 'run_id': 'epi_ale', 28 | 'cpu_thread_cnt': 24, # edit 29 | 'crop': True, # edit, random crops and rescaling reduces memory consumption and improves training 30 | 'freeze_darknet53': True, # if True the basenet weights are frozen during training 31 | 'inference_mode': False, 32 | 'aleatoric_loss': True, 33 | 'cls_cnt': 2, # edit if not ECP dataset 34 | 'implicit_background_class': True, # whether the label ids start at 1 or 0. True = 1, False = 0 35 | 'train': { 36 | 'file_pattern': os.path.expandvars('$HOME/data/ecp/tfrecords/ecp-day-train-*-of-*'), # edit 37 | 'num_shards': 20, 38 | 'shuffle_buffer_size': 2000, 39 | 'cache': False, # edit if you have enough memory, caches whole dataset in memory 40 | }, 41 | 'val': { 42 | 'file_pattern': os.path.expandvars('$HOME/data/ecp/tfrecords/ecp-day-val-*-of-*'), # edit 43 | 'num_shards': 4, 44 | 'shuffle_buffer_size': 10, 45 | 'cache': False, # edit if you have enough memory, caches whole dataset in memory 46 | } 47 | } 48 | 49 | # Note regarding implicit background class: 50 | # The tensorflow object detection API enforces that the class labels start with 1. 51 | # The class 0 is reserved for an (implicit) background class. We support both file formats. 
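    # Note: this stage is meant to resume from a checkpoint produced by
    # pretraining.py (see 'resume_checkpoint' above); the dropout-equipped
    # detection heads are then trained with the aleatoric loss enabled.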
52 | 
53 |     utils.add_file_logging(config, override_existing=True)
54 |     logging.info(json.dumps(config, indent=4, default=lambda x: str(x)))
55 | 
56 |     model_cls = yolov3.bayesian_yolov3_aleatoric
57 | 
58 |     if config['training']:
59 |         train.start(model_cls, config)
60 |     else:
61 |         config['inference_mode'] = True
62 |         config['resume_checkpoint'] = 'last'
63 |         config['thresh'] = 0.1  # filter out boxes with objectness score less than thresh
64 |         config['T'] = 20  # increase if you have enough memory
65 |         utils.qualitative_eval(model_cls, config)
66 | 
67 | 
68 | if __name__ == '__main__':
69 |     logging.basicConfig(level=logging.INFO,
70 |                         format='%(asctime)s, %(levelname)-8s %(message)s',
71 |                         datefmt='%a, %d %b %Y %H:%M:%S',
72 |                         )
73 |     main()
74 | 
--------------------------------------------------------------------------------
/vis_uncertainty.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import logging
3 | import multiprocessing
4 | import os
5 | import time
6 | 
7 | import matplotlib.cm
8 | import numpy as np
9 | import tensorflow as tf
10 | from PIL import Image
11 | 
12 | from lib_yolo import yolov3
13 | 
14 | 
15 | def colorize(img, vmin=None, vmax=None, cmap='plasma'):
16 |     # normalize
17 |     vmin = tf.reduce_min(img) if vmin is None else vmin
18 |     vmax = tf.contrib.distributions.percentile(img, 99.) if vmax is None else vmax
19 |     img = (img - vmin) / (vmax - vmin)
20 | 
21 |     img = tf.squeeze(img, axis=[-1])
22 | 
23 |     # quantize
24 |     indices = tf.clip_by_value(tf.to_int32(tf.round(img * 255)), 0, 255)
25 | 
26 |     # gather
27 |     cm = matplotlib.cm.get_cmap(cmap if cmap is not None else 'gray')
28 |     colors = tf.constant(cm.colors, dtype=tf.float32)
29 |     img = tf.gather(colors, indices)
30 | 
31 |     return img
32 | 
33 | 
34 | def color_map(img, uncertainty, stride, vmin, vmax, alpha=0.7):
35 |     uncertainty = colorize(uncertainty, vmin, vmax)
36 |     uncertainty = tf.expand_dims(uncertainty, axis=0)
37 |     shape = uncertainty.shape
38 |     uncertainty = tf.image.resize_nearest_neighbor(uncertainty, size=(shape[1] * stride, shape[2] * stride))
39 | 
40 |     blended = alpha * img + (1 - alpha) * uncertainty
41 |     blended = tf.squeeze(blended, axis=0)  # drop the batch dimension (tf.squeeze returns a new tensor)
42 | 
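    # blended is now h x w x 3 in [0, 1]; alpha steers how strongly the input
    # image shows through the uncertainty heat map (alpha=0.7 keeps the scene
    # recognizable while the colormap stays prominent)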
43 | 44 | blended = tf.image.convert_image_dtype(blended, dtype=tf.uint8) # convert to [0, 255] 45 | 46 | return blended 47 | 48 | 49 | class Inference: 50 | def __init__(self, yolo, config): 51 | self.batch_size = config['batch_size'] 52 | self.img_size = yolo.img_size 53 | self.img_tensor = tf.placeholder(tf.float32, shape=(1, *self.img_size)) 54 | checkpoints = os.path.join(config['checkpoint_path'], config['run_id']) 55 | if config['step'] == 'last': 56 | self.checkpoint = tf.train.latest_checkpoint(checkpoints) 57 | else: 58 | self.checkpoint = None 59 | for cp in os.listdir(checkpoints): 60 | if cp.endswith('-{}.meta'.format(config['step'])): 61 | self.checkpoint = os.path.join(checkpoints, os.path.splitext(cp)[0]) 62 | break 63 | assert self.checkpoint is not None 64 | 65 | step = self.checkpoint.split('-')[-1] 66 | 67 | self.config = config 68 | self.worker_thread = None 69 | 70 | assert config['inference_mode'] 71 | self.model = yolo.init_model(inputs=self.img_tensor, training=False).get_model() 72 | 73 | self.grids = [] 74 | stats = [None] * 9 75 | 76 | ucty_idx = config.get('ucty_idx', -1) 77 | uncertainty_key = config['uncertainty_key'] 78 | 79 | # stride 32 80 | l = self.model.det_layers[0] 81 | if 'obj' in uncertainty_key or 'cls' in uncertainty_key: 82 | uncertainty = l.det[uncertainty_key] 83 | elif 'epi' in uncertainty_key: 84 | uncertainty = l.det[uncertainty_key][..., ucty_idx, ucty_idx] 85 | else: 86 | uncertainty = l.det[uncertainty_key][..., ucty_idx] 87 | lh, lw, box_cnt = uncertainty.shape.as_list() 88 | uncertainty = tf.split(uncertainty, [1] * box_cnt, axis=-1) 89 | 90 | self.grids.append( 91 | color_map(self.img_tensor, uncertainty[0], l.downsample, 0, stats[0])) 92 | self.grids.append( 93 | color_map(self.img_tensor, uncertainty[1], l.downsample, 0, stats[1])) 94 | self.grids.append( 95 | color_map(self.img_tensor, uncertainty[2], l.downsample, 0, stats[2])) 96 | 97 | # stride 16 98 | l = self.model.det_layers[1] 99 | if 'obj' in uncertainty_key or 'cls' in uncertainty_key: 100 | uncertainty = l.det[uncertainty_key] 101 | elif 'epi' in uncertainty_key: 102 | uncertainty = l.det[uncertainty_key][..., ucty_idx, ucty_idx] 103 | else: 104 | uncertainty = l.det[uncertainty_key][..., ucty_idx] 105 | lh, lw, box_cnt = uncertainty.shape.as_list() 106 | uncertainty = tf.split(uncertainty, [1] * box_cnt, axis=-1) 107 | 108 | self.grids.append( 109 | color_map(self.img_tensor, uncertainty[0], l.downsample, 0, stats[3])) 110 | self.grids.append( 111 | color_map(self.img_tensor, uncertainty[1], l.downsample, 0, stats[4])) 112 | self.grids.append( 113 | color_map(self.img_tensor, uncertainty[2], l.downsample, 0, stats[5])) 114 | 115 | # stride 8 116 | l = self.model.det_layers[2] 117 | if 'obj' in uncertainty_key or 'cls' in uncertainty_key: 118 | uncertainty = l.det[uncertainty_key] 119 | elif 'epi' in uncertainty_key: 120 | uncertainty = l.det[uncertainty_key][..., ucty_idx, ucty_idx] 121 | else: 122 | uncertainty = l.det[uncertainty_key][..., ucty_idx] 123 | lh, lw, box_cnt = uncertainty.shape.as_list() 124 | uncertainty = tf.split(uncertainty, [1] * box_cnt, axis=-1) 125 | 126 | self.grids.append( 127 | color_map(self.img_tensor, uncertainty[0], l.downsample, 0, stats[6])) 128 | self.grids.append( 129 | color_map(self.img_tensor, uncertainty[1], l.downsample, 0, stats[7])) 130 | self.grids.append( 131 | color_map(self.img_tensor, uncertainty[2], l.downsample, 0, stats[8])) 132 | 133 | self.sess = tf.Session(config=tf.ConfigProto(device_count={'GPU': 1})) 134 | 
        tf.train.Saver().restore(self.sess, self.checkpoint)

    def load_img(self, filename):
        img = Image.open(filename)
        img = np.array(img)
        img = img.astype(np.float32)

        if self.config['crop']:
            # center crop to the expected input size
            y = (img.shape[0] - self.img_size[0]) // 2
            x = (img.shape[1] - self.img_size[1]) // 2
            img = img[y:y + self.img_size[0], x:x + self.img_size[1], :]

        img = np.expand_dims(img, axis=0)
        img /= 255.
        return img

    def make_color_map(self, filename, config):
        img_data = self.load_img(filename)
        grids, = self.sess.run([self.grids], feed_dict={self.img_tensor: img_data})
        img_name = os.path.basename(filename)
        save_uncertainty_maps(grids, img_name, config)


def save_uncertainty_maps(grids, file_name, config):
    file_name = os.path.basename(file_name)
    for idx, img in enumerate(grids):
        result = Image.fromarray(img)
        path = os.path.join(config['out_path'],
                            '{}_prior{}_{}.png'.format(os.path.splitext(file_name)[0], idx, config['ucty']))
        result.save(path)


def worker(files, config):
    os.makedirs(config['out_path'], exist_ok=True)
    yolo = yolov3.bayesian_yolov3_aleatoric(config)
    inference = Inference(yolo, config)

    logging.info('Processing: {}'.format(config['ucty']))
    for file in files:
        logging.info('Processing file: {}'.format(file))
        inference.make_color_map(file, config)
    logging.info('Finished: {}'.format(config['ucty']))


def do_it(files, config):
    # each uncertainty measure runs in its own subprocess, so every run starts
    # from a fresh TensorFlow graph
    mapping = ['x', 'y', 'w', 'h']
    for uncertainty_key in ['epi_covar_loc', 'ale_var_loc']:
        for ucty_idx in range(4):
            ucty_type = 'epi' if 'epi' in uncertainty_key else 'ale'

            config['ucty'] = ucty_type + '_' + mapping[ucty_idx]
            config['ucty_idx'] = ucty_idx
            config['uncertainty_key'] = uncertainty_key

            p = multiprocessing.Process(target=worker, args=(files, config))
            p.start()
            p.join()

    for uncertainty_key in ['cls_mutual_info', 'obj_mean', 'obj_mutual_info']:
        config['uncertainty_key'] = uncertainty_key
        config['ucty'] = uncertainty_key

        p = multiprocessing.Process(target=worker, args=(files, config))
        p.start()
        p.join()


def main():
    config = {
        'checkpoint_path': './checkpoints/',
        'run_id': 'epi_ale',  # edit
        'step': 'last',  # edit, int or 'last'
        'crop_img_size': [768, 1440, 3],
        'full_img_size': [1024, 1920, 3],  # edit if not ECP dataset
        'cls_cnt': 2,
        'batch_size': 1,
        'T': 30,
        'inference_mode': True,
        'cpu_thread_cnt': 10,
        'freeze_darknet53': False,  # actual value irrelevant
        'crop': False,  # edit
        'training': False,
        'aleatoric_loss': True,  # actual value irrelevant
        'priors': yolov3.ECP_9_PRIORS,  # actual value irrelevant
        'out_path': './uncertainty_visualization',  # edit
    }

    # NOTE: this script only works for the bayesian_yolov3_aleatoric class
    assert config['batch_size'] == 1
    assert config['inference_mode']

    files = glob.glob('./test_images/*')  # edit

    logging.info('----- START -----')
    start = time.time()

    do_it(files, config)

    end = time.time()
    elapsed = int(end - start)
    logging.info('----- FINISHED in {:02d}:{:02d}:{:02d} -----'.format(elapsed // 3600,
                                                                       (elapsed // 60) % 60,
                                                                       elapsed % 60))

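# Output layout with the defaults above: for every image in ./test_images and
# every uncertainty measure, nine PNGs (3 detection layers x 3 prior boxes) are
# written to ./uncertainty_visualization, named '<image>_prior<idx>_<measure>.png'.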

if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO,
                        format='%(asctime)s, pid: %(process)d, %(levelname)-8s %(message)s',
                        datefmt='%a, %d %b %Y %H:%M:%S',
                        )
    main()

--------------------------------------------------------------------------------
/yolov3_training.py:
--------------------------------------------------------------------------------
import json
import logging
import os

from lib_yolo import yolov3, train, utils


def main():
    config = {
        'training': True,  # edit
        'resume_training': False,  # edit
        'resume_checkpoint': 'last',  # edit: either a checkpoint filename or 'last' to resume a training
        'priors': yolov3.ECP_9_PRIORS,  # edit if not ECP dataset
        'run_id': 'yolo',
        'checkpoint_path': './checkpoints',
        'tensorboard_path': './tensorboard',
        'log_path': './log',
        'ckp_max_to_keep': 102,
        'checkpoint_interval': 5000,
        'ign_thresh': 0.7,
        'crop_img_size': [768, 1440, 3],
        'full_img_size': [1024, 1920, 3],  # edit if not ECP dataset
        'train_steps': 500000,
        'darknet53_weights': './darknet53.conv.74',
        'batch_size': 8,  # edit
        'lr': 1e-5,
        'cpu_thread_cnt': 24,  # edit
        'crop': True,  # edit, random crops and rescaling reduce memory consumption and improve training
        'freeze_darknet53': True,  # if True, the basenet weights are frozen during training
        'aleatoric_loss': False,
        'cls_cnt': 2,  # edit if not ECP dataset
        'implicit_background_class': True,  # whether the label ids start at 1 or 0. True = 1, False = 0
        'train': {
            'file_pattern': os.path.expandvars('$HOME/data/ecp/tfrecords/ecp-day-train-*-of-*'),  # edit
            'num_shards': 20,
            'shuffle_buffer_size': 2000,
            'cache': False,  # edit if you have enough memory, caches the whole dataset in memory
        },
        'val': {
            'file_pattern': os.path.expandvars('$HOME/data/ecp/tfrecords/ecp-day-val-*-of-*'),  # edit
            'num_shards': 4,
            'shuffle_buffer_size': 10,
            'cache': False,  # edit if you have enough memory, caches the whole dataset in memory
        }
    }

    # Note regarding the implicit background class:
    # The tensorflow object detection API enforces that class labels start at 1.
    # Class 0 is reserved for an (implicit) background class. We support both file
    # formats: set 'implicit_background_class' to True if the label ids in your
    # tfrecords start at 1 (e.g. pedestrian = 1, rider = 2), False if they start at 0.

    utils.add_file_logging(config, override_existing=True)
    logging.info(json.dumps(config, indent=4, default=lambda x: str(x)))

    model_cls = yolov3.yolov3

    if config['training']:
        train.start(model_cls, config)
    else:
        config['thresh'] = 0.01  # filter out boxes with an objectness score below thresh
        utils.qualitative_eval(model_cls, config)


if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO,
                        format='%(asctime)s, %(levelname)-8s %(message)s',
                        datefmt='%a, %d %b %Y %H:%M:%S',
                        )
    main()

--------------------------------------------------------------------------------