├── .github └── FUNDING.yml ├── .gitignore ├── LICENSE ├── LICENSE.fuck ├── README.md ├── checkpoint ├── .DS_Store └── checkpoint ├── convert_weight.py ├── core ├── __init__.py ├── backbone.py ├── common.py ├── config.py ├── dataset.py ├── utils.py └── yolov3.py ├── data ├── anchors │ ├── basline_anchors.txt │ └── coco_anchors.txt ├── classes │ ├── coco.names │ └── voc.names └── dataset │ ├── voc_test.txt │ └── voc_train.txt ├── docs ├── Box-Clustering.ipynb ├── images │ ├── .jpg │ ├── 611_result.jpg │ ├── road.jpeg │ └── road.mp4 └── requirements.txt ├── evaluate.py ├── freeze_graph.py ├── from_darknet_weights_to_ckpt.py ├── from_darknet_weights_to_pb.py ├── image_demo.py ├── mAP ├── __init__.py ├── extra │ ├── README.md │ ├── class_list.txt │ ├── convert_gt_xml.py │ ├── convert_gt_yolo.py │ ├── convert_keras-yolo3.py │ ├── convert_pred_darkflow_json.py │ ├── convert_pred_yolo.py │ ├── find_class.py │ ├── intersect-gt-and-pred.py │ ├── remove_class.py │ ├── remove_delimiter_char.py │ ├── remove_space.py │ ├── rename_class.py │ └── result.txt └── main.py ├── scripts ├── show_bboxes.py └── voc_annotation.py ├── train.py └── video_demo.py /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | 3 | github: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2] 4 | patreon: # Replace with a single Patreon username 5 | open_collective: # Replace with a single Open Collective username 6 | ko_fi: # Replace with a single Ko-fi username 7 | tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel 8 | community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry 9 | liberapay: # Replace with a single Liberapay username 10 | issuehunt: # Replace with a single IssueHunt username 11 | otechie: # Replace with a single Otechie username 12 | custom: https://github.com/YunYang1994/TensorFlow2.0-Examples/blob/master/.github/Sponsor.md 13 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Build and Release Folders 2 | bin-debug/ 3 | bin-release/ 4 | [Oo]bj/ 5 | [Bb]in/ 6 | 7 | # Other files and folders 8 | .settings/ 9 | 10 | # Executables 11 | *.swf 12 | *.air 13 | *.ipa 14 | *.apk 15 | 16 | # Project files, i.e. `.project`, `.actionScriptProperties` and `.flexProperties` 17 | # should NOT be excluded as they contain compiler settings and other important 18 | # information for Eclipse / Flash Builder. 19 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 YangYun 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/LICENSE.fuck:
--------------------------------------------------------------------------------
1 | DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE
2 | Version 1, JUNE 2019
3 |
4 | Copyright (C) 2019 YunYang1994
5 |
6 | Everyone is permitted to copy and distribute verbatim or modified
7 | copies of this license document, and changing it is allowed as long
8 | as the name is changed.
9 |
10 | DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE
11 | TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION
12 |
13 | 0. You just DO WHAT THE FUCK YOU WANT TO. DON'T ASK ME, JUST DO IT.
14 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | ## 🆕 Are you looking for a new YOLOv3 implemented in TF 2.0?
3 |
4 | >If you really hate the fucking TensorFlow 1.x, no worries! I have implemented **a new YOLOv3 repo with TF 2.0**, and I also wrote a Chinese blog post on how to implement a YOLOv3 object detector from scratch.
5 | [code](https://github.com/YunYang1994/TensorFlow2.0-Examples/tree/master/4-Object_Detection/YOLOV3) | [blog](https://yunyang1994.gitee.io/2018/12/28/YOLOv3-算法的一点理解/) | [issue](https://github.com/YunYang1994/tensorflow-yolov3/issues/39)
6 |
7 | ## part 1. Quick start
8 | 1. Clone this repository
9 | ```bash
10 | $ git clone https://github.com/YunYang1994/tensorflow-yolov3.git
11 | ```
12 | 2. Install the required dependencies before working with the code.
13 | ```bash
14 | $ cd tensorflow-yolov3
15 | $ pip install -r ./docs/requirements.txt
16 | ```
17 | 3. Export the pretrained COCO weights as a TF checkpoint (`yolov3_coco.ckpt`)【[BaiduCloud](https://pan.baidu.com/s/11mwiUy8KotjUVQXqkGGPFQ&shfl=sharepset)】
18 | ```bash
19 | $ cd checkpoint
20 | $ wget https://github.com/YunYang1994/tensorflow-yolov3/releases/download/v1.0/yolov3_coco.tar.gz
21 | $ tar -xvf yolov3_coco.tar.gz
22 | $ cd ..
23 | $ python convert_weight.py
24 | $ python freeze_graph.py
25 | ```
26 | 4. You will then find the `.pb` files in the root directory; run the demo scripts:
27 | ```bash
28 | $ python image_demo.py
29 | $ python video_demo.py # to use the camera, set video_path = 0
30 | ```
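If you want a feel for what the demo scripts do under the hood, here is a minimal inference sketch (not part of the repo). It only uses functions from `core/utils.py`; the `.pb` file name and the four tensor names are assumptions based on the conventions of `image_demo.py` and `freeze_graph.py`, so verify them against your own frozen graph:

```python
# Minimal sketch: run the frozen COCO model on one image.
import cv2
import numpy as np
import tensorflow as tf
import core.utils as utils

pb_file = "./yolov3_coco.pb"      # assumed output name of freeze_graph.py
input_size = 416
num_classes = 80                  # COCO
return_elements = ["input/input_data:0", "pred_sbbox/concat_2:0",
                   "pred_mbbox/concat_2:0", "pred_lbbox/concat_2:0"]  # assumed names

graph = tf.Graph()
tensors = utils.read_pb_return_tensors(graph, pb_file, return_elements)

image = cv2.imread("./docs/images/road.jpeg")
image_data = utils.image_preporcess(np.copy(image), [input_size, input_size])
image_data = image_data[np.newaxis, ...]   # add the batch dimension

with tf.Session(graph=graph) as sess:
    pred_sbbox, pred_mbbox, pred_lbbox = sess.run(
        tensors[1:], feed_dict={tensors[0]: image_data})

# flatten the three detection scales into one (N, 5 + num_classes) array
pred_bbox = np.concatenate([np.reshape(pred_sbbox, (-1, 5 + num_classes)),
                            np.reshape(pred_mbbox, (-1, 5 + num_classes)),
                            np.reshape(pred_lbbox, (-1, 5 + num_classes))], axis=0)

bboxes = utils.postprocess_boxes(pred_bbox, image.shape[:2], input_size, 0.3)
bboxes = utils.nms(bboxes, 0.45, method='nms')
image = utils.draw_bbox(image, bboxes)
cv2.imwrite("result.jpg", image)
```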

32 | *(demo detection result images)*
33 |
34 |

35 |
36 | ## part 2. Train your own dataset
37 | Two files are required, as follows:
38 |
39 | - [`dataset.txt`](https://raw.githubusercontent.com/YunYang1994/tensorflow-yolov3/master/data/dataset/voc_train.txt) (a parsing sketch is given at the end of this part):
40 |
41 | ```
42 | xxx/xxx.jpg 18.19,6.32,424.13,421.83,20 323.86,2.65,640.0,421.94,20
43 | xxx/xxx.jpg 48,240,195,371,11 8,12,352,498,14
44 | # image_path x_min,y_min,x_max,y_max,class_id x_min,y_min,...,class_id
45 | # make sure that x_max < width and y_max < height
46 | ```
47 |
48 | - [`class.names`](https://github.com/YunYang1994/tensorflow-yolov3/blob/master/data/classes/coco.names):
49 |
50 | ```
51 | person
52 | bicycle
53 | car
54 | ...
55 | toothbrush
56 | ```
57 |
58 | ### 2.1 Train on VOC dataset
59 | Download the PASCAL VOC trainval and test data:
60 | ```bash
61 | $ wget http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar
62 | $ wget http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar
63 | $ wget http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar
64 | ```
65 | Extract all of these tars into one directory and rename the subdirectories so that it has the following basic structure:
66 |
67 | ```bash
68 |
69 | VOC           # path: /home/yang/dataset/VOC
70 | ├── test
71 | |    └──VOCdevkit
72 | |        └──VOC2007 (from VOCtest_06-Nov-2007.tar)
73 | └── train
74 |     └──VOCdevkit
75 |         └──VOC2007 (from VOCtrainval_06-Nov-2007.tar)
76 |         └──VOC2012 (from VOCtrainval_11-May-2012.tar)
77 |
78 | $ python scripts/voc_annotation.py --data_path /home/yang/test/VOC
79 | ```
80 | Then edit `./core/config.py` to make the necessary configuration changes:
81 |
82 | ```python
83 | __C.YOLO.CLASSES = "./data/classes/voc.names"
84 | __C.TRAIN.ANNOT_PATH = "./data/dataset/voc_train.txt"
85 | __C.TEST.ANNOT_PATH = "./data/dataset/voc_test.txt"
86 | ```
87 | There are two ways to train the model:
88 |
89 | ##### (1) train from scratch:
90 |
91 | ```bash
92 | $ python train.py
93 | $ tensorboard --logdir ./data
94 | ```
95 | ##### (2) train from COCO weights (recommended):
96 |
97 | ```bash
98 | $ cd checkpoint
99 | $ wget https://github.com/YunYang1994/tensorflow-yolov3/releases/download/v1.0/yolov3_coco.tar.gz
100 | $ tar -xvf yolov3_coco.tar.gz
101 | $ cd ..
102 | $ python convert_weight.py --train_from_coco
103 | $ python train.py
104 | ```
105 | ### 2.2 Evaluate on VOC dataset
106 |
107 | ```bash
108 | $ python evaluate.py
109 | $ cd mAP
110 | $ python main.py -na
111 | ```
112 |
113 | The mAP on the VOC2012 dataset:
114 |
115 |

116 | *(mAP results image)*
117 |
118 |
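As promised in the `dataset.txt` description above, here is a minimal parsing sketch (not part of the repo) showing how each annotation line is interpreted; it mirrors `Dataset.parse_annotation` in `core/dataset.py`:

```python
import numpy as np

line = "xxx/xxx.jpg 48,240,195,371,11 8,12,352,498,14".split()
image_path = line[0]
# one row per box: x_min, y_min, x_max, y_max, class_id
# (int(float(x)) also handles float coordinates such as 18.19)
bboxes = np.array([list(map(lambda x: int(float(x)), box.split(',')))
                   for box in line[1:]])
print(bboxes)
# [[ 48 240 195 371  11]
#  [  8  12 352 498  14]]
```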

119 |
120 |
121 | ## part 3. Other Implementations
122 |
123 | [- **`YOLOv3 object detection now has a TensorFlow implementation that can be trained on your own data`**](https://mp.weixin.qq.com/s/cq7g1-4oFTftLbmKcpi_aQ)
124 |
125 | [- **`Stronger-yolo`**](https://github.com/Stinky-Tofu/Stronger-yolo)
126 | 127 | [- **`Implementing YOLO v3 in Tensorflow (TF-Slim)`**](https://itnext.io/implementing-yolo-v3-in-tensorflow-tf-slim-c3c55ff59dbe) 128 | 129 | [- **`YOLOv3_TensorFlow`**](https://github.com/wizyoung/YOLOv3_TensorFlow) 130 | 131 | [- **`Object Detection using YOLOv2 on Pascal VOC2012`**](https://fairyonice.github.io/Part_1_Object_Detection_with_Yolo_for_VOC_2014_data_anchor_box_clustering.html) 132 | 133 | [-**`Understanding YOLO`**](https://hackernoon.com/understanding-yolo-f5a74bbc7967) 134 | 135 | -------------------------------------------------------------------------------- /checkpoint/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunYang1994/tensorflow-yolov3/03cb272af2e26d598c553f3a2d38024fc6f67a0b/checkpoint/.DS_Store -------------------------------------------------------------------------------- /checkpoint/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "yolov3_test_loss=2530.1914.ckpt-1" 2 | all_model_checkpoint_paths: "yolov3_test_loss=2530.1914.ckpt-1" 3 | -------------------------------------------------------------------------------- /convert_weight.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | # coding=utf-8 3 | #================================================================ 4 | # Copyright (C) 2019 * Ltd. All rights reserved. 5 | # 6 | # Editor : VIM 7 | # File name : convert_weight.py 8 | # Author : YunYang1994 9 | # Created date: 2019-02-28 13:51:31 10 | # Description : 11 | # 12 | #================================================================ 13 | 14 | import argparse 15 | import tensorflow as tf 16 | from core.yolov3 import YOLOV3 17 | from core.config import cfg 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument("--train_from_coco", action='store_true') 20 | flag = parser.parse_args() 21 | 22 | org_weights_path = cfg.YOLO.ORIGINAL_WEIGHT 23 | cur_weights_path = cfg.YOLO.DEMO_WEIGHT 24 | preserve_cur_names = ['conv_sbbox', 'conv_mbbox', 'conv_lbbox'] 25 | preserve_org_names = ['Conv_6', 'Conv_14', 'Conv_22'] 26 | 27 | 28 | org_weights_mess = [] 29 | tf.Graph().as_default() 30 | load = tf.train.import_meta_graph(org_weights_path + '.meta') 31 | with tf.Session() as sess: 32 | load.restore(sess, org_weights_path) 33 | for var in tf.global_variables(): 34 | var_name = var.op.name 35 | var_name_mess = str(var_name).split('/') 36 | var_shape = var.shape 37 | if flag.train_from_coco: 38 | if (var_name_mess[-1] not in ['weights', 'gamma', 'beta', 'moving_mean', 'moving_variance']) or \ 39 | (var_name_mess[1] == 'yolo-v3' and (var_name_mess[-2] in preserve_org_names)): continue 40 | org_weights_mess.append([var_name, var_shape]) 41 | print("=> " + str(var_name).ljust(50), var_shape) 42 | print() 43 | tf.reset_default_graph() 44 | 45 | cur_weights_mess = [] 46 | tf.Graph().as_default() 47 | with tf.name_scope('input'): 48 | input_data = tf.placeholder(dtype=tf.float32, shape=(1, 416, 416, 3), name='input_data') 49 | training = tf.placeholder(dtype=tf.bool, name='trainable') 50 | model = YOLOV3(input_data, training) 51 | for var in tf.global_variables(): 52 | var_name = var.op.name 53 | var_name_mess = str(var_name).split('/') 54 | var_shape = var.shape 55 | print(var_name_mess[0]) 56 | if flag.train_from_coco: 57 | if var_name_mess[0] in preserve_cur_names: continue 58 | cur_weights_mess.append([var_name, var_shape]) 59 | 
print("=> " + str(var_name).ljust(50), var_shape) 60 | 61 | org_weights_num = len(org_weights_mess) 62 | cur_weights_num = len(cur_weights_mess) 63 | if cur_weights_num != org_weights_num: 64 | raise RuntimeError 65 | 66 | print('=> Number of weights that will rename:\t%d' % cur_weights_num) 67 | cur_to_org_dict = {} 68 | for index in range(org_weights_num): 69 | org_name, org_shape = org_weights_mess[index] 70 | cur_name, cur_shape = cur_weights_mess[index] 71 | if cur_shape != org_shape: 72 | print(org_weights_mess[index]) 73 | print(cur_weights_mess[index]) 74 | raise RuntimeError 75 | cur_to_org_dict[cur_name] = org_name 76 | print("=> " + str(cur_name).ljust(50) + ' : ' + org_name) 77 | 78 | with tf.name_scope('load_save'): 79 | name_to_var_dict = {var.op.name: var for var in tf.global_variables()} 80 | restore_dict = {cur_to_org_dict[cur_name]: name_to_var_dict[cur_name] for cur_name in cur_to_org_dict} 81 | load = tf.train.Saver(restore_dict) 82 | save = tf.train.Saver(tf.global_variables()) 83 | for var in tf.global_variables(): 84 | print("=> " + var.op.name) 85 | 86 | with tf.Session() as sess: 87 | sess.run(tf.global_variables_initializer()) 88 | print('=> Restoring weights from:\t %s' % org_weights_path) 89 | load.restore(sess, org_weights_path) 90 | save.save(sess, cur_weights_path) 91 | tf.reset_default_graph() 92 | 93 | 94 | -------------------------------------------------------------------------------- /core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunYang1994/tensorflow-yolov3/03cb272af2e26d598c553f3a2d38024fc6f67a0b/core/__init__.py -------------------------------------------------------------------------------- /core/backbone.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | # coding=utf-8 3 | #================================================================ 4 | # Copyright (C) 2019 * Ltd. All rights reserved. 
5 | # 6 | # Editor : VIM 7 | # File name : backbone.py 8 | # Author : YunYang1994 9 | # Created date: 2019-02-17 11:03:35 10 | # Description : 11 | # 12 | #================================================================ 13 | 14 | import core.common as common 15 | import tensorflow as tf 16 | 17 | 18 | def darknet53(input_data, trainable): 19 | 20 | with tf.variable_scope('darknet'): 21 | 22 | input_data = common.convolutional(input_data, filters_shape=(3, 3, 3, 32), trainable=trainable, name='conv0') 23 | input_data = common.convolutional(input_data, filters_shape=(3, 3, 32, 64), 24 | trainable=trainable, name='conv1', downsample=True) 25 | 26 | for i in range(1): 27 | input_data = common.residual_block(input_data, 64, 32, 64, trainable=trainable, name='residual%d' %(i+0)) 28 | 29 | input_data = common.convolutional(input_data, filters_shape=(3, 3, 64, 128), 30 | trainable=trainable, name='conv4', downsample=True) 31 | 32 | for i in range(2): 33 | input_data = common.residual_block(input_data, 128, 64, 128, trainable=trainable, name='residual%d' %(i+1)) 34 | 35 | input_data = common.convolutional(input_data, filters_shape=(3, 3, 128, 256), 36 | trainable=trainable, name='conv9', downsample=True) 37 | 38 | for i in range(8): 39 | input_data = common.residual_block(input_data, 256, 128, 256, trainable=trainable, name='residual%d' %(i+3)) 40 | 41 | route_1 = input_data 42 | input_data = common.convolutional(input_data, filters_shape=(3, 3, 256, 512), 43 | trainable=trainable, name='conv26', downsample=True) 44 | 45 | for i in range(8): 46 | input_data = common.residual_block(input_data, 512, 256, 512, trainable=trainable, name='residual%d' %(i+11)) 47 | 48 | route_2 = input_data 49 | input_data = common.convolutional(input_data, filters_shape=(3, 3, 512, 1024), 50 | trainable=trainable, name='conv43', downsample=True) 51 | 52 | for i in range(4): 53 | input_data = common.residual_block(input_data, 1024, 512, 1024, trainable=trainable, name='residual%d' %(i+19)) 54 | 55 | return route_1, route_2, input_data 56 | 57 | 58 | 59 | 60 | -------------------------------------------------------------------------------- /core/common.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | # coding=utf-8 3 | #================================================================ 4 | # Copyright (C) 2019 * Ltd. All rights reserved. 
5 | # 6 | # Editor : VIM 7 | # File name : common.py 8 | # Author : YunYang1994 9 | # Created date: 2019-02-28 09:56:29 10 | # Description : 11 | # 12 | #================================================================ 13 | 14 | import tensorflow as tf 15 | 16 | 17 | def convolutional(input_data, filters_shape, trainable, name, downsample=False, activate=True, bn=True): 18 | 19 | with tf.variable_scope(name): 20 | if downsample: 21 | pad_h, pad_w = (filters_shape[0] - 2) // 2 + 1, (filters_shape[1] - 2) // 2 + 1 22 | paddings = tf.constant([[0, 0], [pad_h, pad_h], [pad_w, pad_w], [0, 0]]) 23 | input_data = tf.pad(input_data, paddings, 'CONSTANT') 24 | strides = (1, 2, 2, 1) 25 | padding = 'VALID' 26 | else: 27 | strides = (1, 1, 1, 1) 28 | padding = "SAME" 29 | 30 | weight = tf.get_variable(name='weight', dtype=tf.float32, trainable=True, 31 | shape=filters_shape, initializer=tf.random_normal_initializer(stddev=0.01)) 32 | conv = tf.nn.conv2d(input=input_data, filter=weight, strides=strides, padding=padding) 33 | 34 | if bn: 35 | conv = tf.layers.batch_normalization(conv, beta_initializer=tf.zeros_initializer(), 36 | gamma_initializer=tf.ones_initializer(), 37 | moving_mean_initializer=tf.zeros_initializer(), 38 | moving_variance_initializer=tf.ones_initializer(), training=trainable) 39 | else: 40 | bias = tf.get_variable(name='bias', shape=filters_shape[-1], trainable=True, 41 | dtype=tf.float32, initializer=tf.constant_initializer(0.0)) 42 | conv = tf.nn.bias_add(conv, bias) 43 | 44 | if activate == True: conv = tf.nn.leaky_relu(conv, alpha=0.1) 45 | 46 | return conv 47 | 48 | 49 | def residual_block(input_data, input_channel, filter_num1, filter_num2, trainable, name): 50 | 51 | short_cut = input_data 52 | 53 | with tf.variable_scope(name): 54 | input_data = convolutional(input_data, filters_shape=(1, 1, input_channel, filter_num1), 55 | trainable=trainable, name='conv1') 56 | input_data = convolutional(input_data, filters_shape=(3, 3, filter_num1, filter_num2), 57 | trainable=trainable, name='conv2') 58 | 59 | residual_output = input_data + short_cut 60 | 61 | return residual_output 62 | 63 | 64 | 65 | def route(name, previous_output, current_output): 66 | 67 | with tf.variable_scope(name): 68 | output = tf.concat([current_output, previous_output], axis=-1) 69 | 70 | return output 71 | 72 | 73 | def upsample(input_data, name, method="deconv"): 74 | assert method in ["resize", "deconv"] 75 | 76 | if method == "resize": 77 | with tf.variable_scope(name): 78 | input_shape = tf.shape(input_data) 79 | output = tf.image.resize_nearest_neighbor(input_data, (input_shape[1] * 2, input_shape[2] * 2)) 80 | 81 | if method == "deconv": 82 | # replace resize_nearest_neighbor with conv2d_transpose To support TensorRT optimization 83 | numm_filter = input_data.shape.as_list()[-1] 84 | output = tf.layers.conv2d_transpose(input_data, numm_filter, kernel_size=2, padding='same', 85 | strides=(2,2), kernel_initializer=tf.random_normal_initializer()) 86 | 87 | return output 88 | 89 | 90 | 91 | -------------------------------------------------------------------------------- /core/config.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | # coding=utf-8 3 | #================================================================ 4 | # Copyright (C) 2019 * Ltd. All rights reserved. 
5 | # 6 | # Editor : VIM 7 | # File name : config.py 8 | # Author : YunYang1994 9 | # Created date: 2019-02-28 13:06:54 10 | # Description : 11 | # 12 | #================================================================ 13 | 14 | from easydict import EasyDict as edict 15 | 16 | 17 | __C = edict() 18 | # Consumers can get config by: from config import cfg 19 | 20 | cfg = __C 21 | 22 | # YOLO options 23 | __C.YOLO = edict() 24 | 25 | # Set the class name 26 | __C.YOLO.CLASSES = "./data/classes/coco.names" 27 | __C.YOLO.ANCHORS = "./data/anchors/basline_anchors.txt" 28 | __C.YOLO.MOVING_AVE_DECAY = 0.9995 29 | __C.YOLO.STRIDES = [8, 16, 32] 30 | __C.YOLO.ANCHOR_PER_SCALE = 3 31 | __C.YOLO.IOU_LOSS_THRESH = 0.5 32 | __C.YOLO.UPSAMPLE_METHOD = "resize" 33 | __C.YOLO.ORIGINAL_WEIGHT = "./checkpoint/yolov3_coco.ckpt" 34 | __C.YOLO.DEMO_WEIGHT = "./checkpoint/yolov3_coco_demo.ckpt" 35 | 36 | # Train options 37 | __C.TRAIN = edict() 38 | 39 | __C.TRAIN.ANNOT_PATH = "./data/dataset/voc_train.txt" 40 | __C.TRAIN.BATCH_SIZE = 6 41 | __C.TRAIN.INPUT_SIZE = [320, 352, 384, 416, 448, 480, 512, 544, 576, 608] 42 | __C.TRAIN.DATA_AUG = True 43 | __C.TRAIN.LEARN_RATE_INIT = 1e-4 44 | __C.TRAIN.LEARN_RATE_END = 1e-6 45 | __C.TRAIN.WARMUP_EPOCHS = 2 46 | __C.TRAIN.FISRT_STAGE_EPOCHS = 20 47 | __C.TRAIN.SECOND_STAGE_EPOCHS = 30 48 | __C.TRAIN.INITIAL_WEIGHT = "./checkpoint/yolov3_coco_demo.ckpt" 49 | 50 | 51 | 52 | # TEST options 53 | __C.TEST = edict() 54 | 55 | __C.TEST.ANNOT_PATH = "./data/dataset/voc_test.txt" 56 | __C.TEST.BATCH_SIZE = 2 57 | __C.TEST.INPUT_SIZE = 544 58 | __C.TEST.DATA_AUG = False 59 | __C.TEST.WRITE_IMAGE = True 60 | __C.TEST.WRITE_IMAGE_PATH = "./data/detection/" 61 | __C.TEST.WRITE_IMAGE_SHOW_LABEL = True 62 | __C.TEST.WEIGHT_FILE = "./checkpoint/yolov3_test_loss=9.2099.ckpt-5" 63 | __C.TEST.SHOW_LABEL = True 64 | __C.TEST.SCORE_THRESHOLD = 0.3 65 | __C.TEST.IOU_THRESHOLD = 0.45 66 | 67 | 68 | 69 | 70 | 71 | 72 | -------------------------------------------------------------------------------- /core/dataset.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | # coding=utf-8 3 | #================================================================ 4 | # Copyright (C) 2019 * Ltd. All rights reserved. 
5 | # 6 | # Editor : VIM 7 | # File name : dataset.py 8 | # Author : YunYang1994 9 | # Created date: 2019-03-15 18:05:03 10 | # Description : 11 | # 12 | #================================================================ 13 | 14 | import os 15 | import cv2 16 | import random 17 | import numpy as np 18 | import tensorflow as tf 19 | import core.utils as utils 20 | from core.config import cfg 21 | 22 | 23 | 24 | class Dataset(object): 25 | """implement Dataset here""" 26 | def __init__(self, dataset_type): 27 | self.annot_path = cfg.TRAIN.ANNOT_PATH if dataset_type == 'train' else cfg.TEST.ANNOT_PATH 28 | self.input_sizes = cfg.TRAIN.INPUT_SIZE if dataset_type == 'train' else cfg.TEST.INPUT_SIZE 29 | self.batch_size = cfg.TRAIN.BATCH_SIZE if dataset_type == 'train' else cfg.TEST.BATCH_SIZE 30 | self.data_aug = cfg.TRAIN.DATA_AUG if dataset_type == 'train' else cfg.TEST.DATA_AUG 31 | 32 | self.train_input_sizes = cfg.TRAIN.INPUT_SIZE 33 | self.strides = np.array(cfg.YOLO.STRIDES) 34 | self.classes = utils.read_class_names(cfg.YOLO.CLASSES) 35 | self.num_classes = len(self.classes) 36 | self.anchors = np.array(utils.get_anchors(cfg.YOLO.ANCHORS)) 37 | self.anchor_per_scale = cfg.YOLO.ANCHOR_PER_SCALE 38 | self.max_bbox_per_scale = 150 39 | 40 | self.annotations = self.load_annotations(dataset_type) 41 | self.num_samples = len(self.annotations) 42 | self.num_batchs = int(np.ceil(self.num_samples / self.batch_size)) 43 | self.batch_count = 0 44 | 45 | 46 | def load_annotations(self, dataset_type): 47 | with open(self.annot_path, 'r') as f: 48 | txt = f.readlines() 49 | annotations = [line.strip() for line in txt if len(line.strip().split()[1:]) != 0] 50 | np.random.shuffle(annotations) 51 | return annotations 52 | 53 | def __iter__(self): 54 | return self 55 | 56 | def __next__(self): 57 | 58 | with tf.device('/cpu:0'): 59 | self.train_input_size = random.choice(self.train_input_sizes) 60 | self.train_output_sizes = self.train_input_size // self.strides 61 | 62 | batch_image = np.zeros((self.batch_size, self.train_input_size, self.train_input_size, 3)) 63 | 64 | batch_label_sbbox = np.zeros((self.batch_size, self.train_output_sizes[0], self.train_output_sizes[0], 65 | self.anchor_per_scale, 5 + self.num_classes)) 66 | batch_label_mbbox = np.zeros((self.batch_size, self.train_output_sizes[1], self.train_output_sizes[1], 67 | self.anchor_per_scale, 5 + self.num_classes)) 68 | batch_label_lbbox = np.zeros((self.batch_size, self.train_output_sizes[2], self.train_output_sizes[2], 69 | self.anchor_per_scale, 5 + self.num_classes)) 70 | 71 | batch_sbboxes = np.zeros((self.batch_size, self.max_bbox_per_scale, 4)) 72 | batch_mbboxes = np.zeros((self.batch_size, self.max_bbox_per_scale, 4)) 73 | batch_lbboxes = np.zeros((self.batch_size, self.max_bbox_per_scale, 4)) 74 | 75 | num = 0 76 | if self.batch_count < self.num_batchs: 77 | while num < self.batch_size: 78 | index = self.batch_count * self.batch_size + num 79 | if index >= self.num_samples: index -= self.num_samples 80 | annotation = self.annotations[index] 81 | image, bboxes = self.parse_annotation(annotation) 82 | label_sbbox, label_mbbox, label_lbbox, sbboxes, mbboxes, lbboxes = self.preprocess_true_boxes(bboxes) 83 | 84 | batch_image[num, :, :, :] = image 85 | batch_label_sbbox[num, :, :, :, :] = label_sbbox 86 | batch_label_mbbox[num, :, :, :, :] = label_mbbox 87 | batch_label_lbbox[num, :, :, :, :] = label_lbbox 88 | batch_sbboxes[num, :, :] = sbboxes 89 | batch_mbboxes[num, :, :] = mbboxes 90 | batch_lbboxes[num, :, :] = lbboxes 91 | num += 1 
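                # the mini-batch is now full: advance the batch counter and return the assembled arrays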
92 |             self.batch_count += 1
93 |             return batch_image, batch_label_sbbox, batch_label_mbbox, batch_label_lbbox, \
94 |                    batch_sbboxes, batch_mbboxes, batch_lbboxes
95 |         else:
96 |             self.batch_count = 0
97 |             np.random.shuffle(self.annotations)
98 |             raise StopIteration
99 |
100 |     def random_horizontal_flip(self, image, bboxes):
101 |
102 |         if random.random() < 0.5:
103 |             _, w, _ = image.shape
104 |             image = image[:, ::-1, :]
105 |             bboxes[:, [0,2]] = w - bboxes[:, [2,0]]
106 |
107 |         return image, bboxes
108 |
109 |     def random_crop(self, image, bboxes):
110 |
111 |         if random.random() < 0.5:
112 |             h, w, _ = image.shape
113 |             max_bbox = np.concatenate([np.min(bboxes[:, 0:2], axis=0), np.max(bboxes[:, 2:4], axis=0)], axis=-1)
114 |
115 |             max_l_trans = max_bbox[0]
116 |             max_u_trans = max_bbox[1]
117 |             max_r_trans = w - max_bbox[2]
118 |             max_d_trans = h - max_bbox[3]
119 |
120 |             crop_xmin = max(0, int(max_bbox[0] - random.uniform(0, max_l_trans)))
121 |             crop_ymin = max(0, int(max_bbox[1] - random.uniform(0, max_u_trans)))
122 |             crop_xmax = min(w, int(max_bbox[2] + random.uniform(0, max_r_trans)))  # clamp the crop window inside the image
123 |             crop_ymax = min(h, int(max_bbox[3] + random.uniform(0, max_d_trans)))
124 |
125 |             image = image[crop_ymin : crop_ymax, crop_xmin : crop_xmax]
126 |
127 |             bboxes[:, [0, 2]] = bboxes[:, [0, 2]] - crop_xmin
128 |             bboxes[:, [1, 3]] = bboxes[:, [1, 3]] - crop_ymin
129 |
130 |         return image, bboxes
131 |
132 |     def random_translate(self, image, bboxes):
133 |
134 |         if random.random() < 0.5:
135 |             h, w, _ = image.shape
136 |             max_bbox = np.concatenate([np.min(bboxes[:, 0:2], axis=0), np.max(bboxes[:, 2:4], axis=0)], axis=-1)
137 |
138 |             max_l_trans = max_bbox[0]
139 |             max_u_trans = max_bbox[1]
140 |             max_r_trans = w - max_bbox[2]
141 |             max_d_trans = h - max_bbox[3]
142 |
143 |             tx = random.uniform(-(max_l_trans - 1), (max_r_trans - 1))
144 |             ty = random.uniform(-(max_u_trans - 1), (max_d_trans - 1))
145 |
146 |             M = np.array([[1, 0, tx], [0, 1, ty]])
147 |             image = cv2.warpAffine(image, M, (w, h))
148 |
149 |             bboxes[:, [0, 2]] = bboxes[:, [0, 2]] + tx
150 |             bboxes[:, [1, 3]] = bboxes[:, [1, 3]] + ty
151 |
152 |         return image, bboxes
153 |
154 |     def parse_annotation(self, annotation):
155 |
156 |         line = annotation.split()
157 |         image_path = line[0]
158 |         if not os.path.exists(image_path):
159 |             raise KeyError("%s does not exist ... " % image_path)
160 |         image = np.array(cv2.imread(image_path))
161 |         bboxes = np.array([list(map(lambda x: int(float(x)), box.split(','))) for box in line[1:]])
162 |
163 |         if self.data_aug:
164 |             image, bboxes = self.random_horizontal_flip(np.copy(image), np.copy(bboxes))
165 |             image, bboxes = self.random_crop(np.copy(image), np.copy(bboxes))
166 |             image, bboxes = self.random_translate(np.copy(image), np.copy(bboxes))
167 |
168 |         image, bboxes = utils.image_preporcess(np.copy(image), [self.train_input_size, self.train_input_size], np.copy(bboxes))
169 |
170 |         updated_bb = []
171 |         for bb in bboxes:
172 |             x1, y1, x2, y2, cls_label = bb
173 |
174 |             if x2 <= x1 or y2 <= y1:
175 |                 # don't use such boxes, as they may cause nan loss.
176 |                 continue
177 |
178 |             x1 = int(np.clip(x1, 0, image.shape[1]))
179 |             y1 = int(np.clip(y1, 0, image.shape[0]))
180 |             x2 = int(np.clip(x2, 0, image.shape[1]))
181 |             y2 = int(np.clip(y2, 0, image.shape[0]))
182 |             # clip coordinates to the image dimensions: negative values or values
183 |             # greater than the image dimensions may cause nan loss.
184 |             updated_bb.append([x1, y1, x2, y2, cls_label])
185 |
186 |         return image, np.array(updated_bb)
187 |
188 |     def bbox_iou(self, boxes1, boxes2):
189 |
190 |         boxes1 = np.array(boxes1)
191 |         boxes2 = np.array(boxes2)
192 |
193 |         boxes1_area = boxes1[..., 2] * boxes1[..., 3]
194 |         boxes2_area = boxes2[..., 2] * boxes2[..., 3]
195 |
196 |         boxes1 = np.concatenate([boxes1[..., :2] - boxes1[..., 2:] * 0.5,
197 |                                  boxes1[..., :2] + boxes1[..., 2:] * 0.5], axis=-1)
198 |         boxes2 = np.concatenate([boxes2[..., :2] - boxes2[..., 2:] * 0.5,
199 |                                  boxes2[..., :2] + boxes2[..., 2:] * 0.5], axis=-1)
200 |
201 |         left_up = np.maximum(boxes1[..., :2], boxes2[..., :2])
202 |         right_down = np.minimum(boxes1[..., 2:], boxes2[..., 2:])
203 |
204 |         inter_section = np.maximum(right_down - left_up, 0.0)
205 |         inter_area = inter_section[..., 0] * inter_section[..., 1]
206 |         union_area = boxes1_area + boxes2_area - inter_area
207 |
208 |         return inter_area / (union_area + 1e-6)
209 |         # added 1e-6 to the denominator to avoid generating inf, which may cause nan loss
210 |
211 |
212 |     def preprocess_true_boxes(self, bboxes):
213 |
214 |         label = [np.zeros((self.train_output_sizes[i], self.train_output_sizes[i], self.anchor_per_scale,
215 |                            5 + self.num_classes)) for i in range(3)]
216 |         bboxes_xywh = [np.zeros((self.max_bbox_per_scale, 4)) for _ in range(3)]
217 |         bbox_count = np.zeros((3,))
218 |
219 |         for bbox in bboxes:
220 |             bbox_coor = bbox[:4]
221 |             bbox_class_ind = bbox[4]
222 |
223 |             onehot = np.zeros(self.num_classes, dtype=np.float)
224 |             onehot[bbox_class_ind] = 1.0
225 |             uniform_distribution = np.full(self.num_classes, 1.0 / self.num_classes)
226 |             delta = 0.01
227 |             smooth_onehot = onehot * (1 - delta) + delta * uniform_distribution
228 |
229 |             bbox_xywh = np.concatenate([(bbox_coor[2:] + bbox_coor[:2]) * 0.5, bbox_coor[2:] - bbox_coor[:2]], axis=-1)
230 |             bbox_xywh_scaled = 1.0 * bbox_xywh[np.newaxis, :] / self.strides[:, np.newaxis]
231 |
232 |             iou = []
233 |             exist_positive = False
234 |             for i in range(3):
235 |                 anchors_xywh = np.zeros((self.anchor_per_scale, 4))
236 |                 anchors_xywh[:, 0:2] = np.floor(bbox_xywh_scaled[i, 0:2]).astype(np.int32) + 0.5
237 |                 anchors_xywh[:, 2:4] = self.anchors[i]
238 |
239 |                 iou_scale = self.bbox_iou(bbox_xywh_scaled[i][np.newaxis, :], anchors_xywh)
240 |                 iou.append(iou_scale)
241 |                 iou_mask = iou_scale > 0.3
242 |
243 |                 if np.any(iou_mask):
244 |                     xind, yind = np.floor(bbox_xywh_scaled[i, 0:2]).astype(np.int32)
245 |                     xind = np.clip(xind, 0, self.train_output_sizes[i] - 1)
246 |                     yind = np.clip(yind, 0, self.train_output_sizes[i] - 1)
247 |                     # Clipping mitigates errors when the computed location lies beyond the grid.
248 |                     # e.g. for a 52x52 grid the valid range of xind and yind is [0, 51], but the
249 |                     # computation can occasionally yield 52, which would index past the label
250 |                     # array and throw an error during training.
251 |
252 |                     label[i][yind, xind, iou_mask, :] = 0
253 |                     label[i][yind, xind, iou_mask, 0:4] = bbox_xywh
254 |                     label[i][yind, xind, iou_mask, 4:5] = 1.0
255 |                     label[i][yind, xind, iou_mask, 5:] = smooth_onehot
256 |
257 |                     bbox_ind = int(bbox_count[i] % self.max_bbox_per_scale)
258 |                     bboxes_xywh[i][bbox_ind, :4] = bbox_xywh
259 |                     bbox_count[i] += 1
260 |
261 |                     exist_positive = True
262 |
263 |             if not exist_positive:
264 |                 best_anchor_ind = np.argmax(np.array(iou).reshape(-1), axis=-1)
265 |                 best_detect = int(best_anchor_ind / self.anchor_per_scale)
266 |                 best_anchor = int(best_anchor_ind % self.anchor_per_scale)
267 |                 xind, yind = np.floor(bbox_xywh_scaled[best_detect, 0:2]).astype(np.int32)
268 |                 xind = np.clip(xind, 0, self.train_output_sizes[best_detect] - 1)
269 |                 yind = np.clip(yind, 0, self.train_output_sizes[best_detect] - 1)
270 |                 # Clip to the grid of the chosen scale, as above: for a 52x52 grid the
271 |                 # valid range of xind and yind is [0, 51], but the computation can
272 |                 # occasionally yield 52, which would index past the label array and
273 |                 # throw an error during training.
274 |
275 |                 label[best_detect][yind, xind, best_anchor, :] = 0
276 |                 label[best_detect][yind, xind, best_anchor, 0:4] = bbox_xywh
277 |                 label[best_detect][yind, xind, best_anchor, 4:5] = 1.0
278 |                 label[best_detect][yind, xind, best_anchor, 5:] = smooth_onehot
279 |
280 |                 bbox_ind = int(bbox_count[best_detect] % self.max_bbox_per_scale)
281 |                 bboxes_xywh[best_detect][bbox_ind, :4] = bbox_xywh
282 |                 bbox_count[best_detect] += 1
283 |         label_sbbox, label_mbbox, label_lbbox = label
284 |         sbboxes, mbboxes, lbboxes = bboxes_xywh
285 |         return label_sbbox, label_mbbox, label_lbbox, sbboxes, mbboxes, lbboxes
286 |
287 |     def __len__(self):
288 |         return self.num_batchs
289 |
290 |
291 |
292 |
293 |
--------------------------------------------------------------------------------
/core/utils.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python
2 | # coding=utf-8
3 | #================================================================
4 | # Copyright (C) 2019 * Ltd. All rights reserved.
5 | # 6 | # Editor : VIM 7 | # File name : utils.py 8 | # Author : YunYang1994 9 | # Created date: 2019-02-28 13:14:19 10 | # Description : 11 | # 12 | #================================================================ 13 | 14 | import cv2 15 | import random 16 | import colorsys 17 | import numpy as np 18 | import tensorflow as tf 19 | from core.config import cfg 20 | 21 | def read_class_names(class_file_name): 22 | '''loads class name from a file''' 23 | names = {} 24 | with open(class_file_name, 'r') as data: 25 | for ID, name in enumerate(data): 26 | names[ID] = name.strip('\n') 27 | return names 28 | 29 | 30 | def get_anchors(anchors_path): 31 | '''loads the anchors from a file''' 32 | with open(anchors_path) as f: 33 | anchors = f.readline() 34 | anchors = np.array(anchors.split(','), dtype=np.float32) 35 | return anchors.reshape(3, 3, 2) 36 | 37 | 38 | def image_preporcess(image, target_size, gt_boxes=None): 39 | 40 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32) 41 | 42 | ih, iw = target_size 43 | h, w, _ = image.shape 44 | 45 | scale = min(iw/w, ih/h) 46 | nw, nh = int(scale * w), int(scale * h) 47 | image_resized = cv2.resize(image, (nw, nh)) 48 | 49 | image_paded = np.full(shape=[ih, iw, 3], fill_value=128.0) 50 | dw, dh = (iw - nw) // 2, (ih-nh) // 2 51 | image_paded[dh:nh+dh, dw:nw+dw, :] = image_resized 52 | image_paded = image_paded / 255. 53 | 54 | if gt_boxes is None: 55 | return image_paded 56 | 57 | else: 58 | gt_boxes[:, [0, 2]] = gt_boxes[:, [0, 2]] * scale + dw 59 | gt_boxes[:, [1, 3]] = gt_boxes[:, [1, 3]] * scale + dh 60 | return image_paded, gt_boxes 61 | 62 | 63 | def draw_bbox(image, bboxes, classes=read_class_names(cfg.YOLO.CLASSES), show_label=True): 64 | """ 65 | bboxes: [x_min, y_min, x_max, y_max, probability, cls_id] format coordinates. 66 | """ 67 | 68 | num_classes = len(classes) 69 | image_h, image_w, _ = image.shape 70 | hsv_tuples = [(1.0 * x / num_classes, 1., 1.) 
for x in range(num_classes)] 71 | colors = list(map(lambda x: colorsys.hsv_to_rgb(*x), hsv_tuples)) 72 | colors = list(map(lambda x: (int(x[0] * 255), int(x[1] * 255), int(x[2] * 255)), colors)) 73 | 74 | random.seed(0) 75 | random.shuffle(colors) 76 | random.seed(None) 77 | 78 | for i, bbox in enumerate(bboxes): 79 | coor = np.array(bbox[:4], dtype=np.int32) 80 | fontScale = 0.5 81 | score = bbox[4] 82 | class_ind = int(bbox[5]) 83 | bbox_color = colors[class_ind] 84 | bbox_thick = int(0.6 * (image_h + image_w) / 600) 85 | c1, c2 = (coor[0], coor[1]), (coor[2], coor[3]) 86 | cv2.rectangle(image, c1, c2, bbox_color, bbox_thick) 87 | 88 | if show_label: 89 | bbox_mess = '%s: %.2f' % (classes[class_ind], score) 90 | t_size = cv2.getTextSize(bbox_mess, 0, fontScale, thickness=bbox_thick//2)[0] 91 | cv2.rectangle(image, c1, (c1[0] + t_size[0], c1[1] - t_size[1] - 3), bbox_color, -1) # filled 92 | 93 | cv2.putText(image, bbox_mess, (c1[0], c1[1]-2), cv2.FONT_HERSHEY_SIMPLEX, 94 | fontScale, (0, 0, 0), bbox_thick//2, lineType=cv2.LINE_AA) 95 | 96 | return image 97 | 98 | 99 | 100 | def bboxes_iou(boxes1, boxes2): 101 | 102 | boxes1 = np.array(boxes1) 103 | boxes2 = np.array(boxes2) 104 | 105 | boxes1_area = (boxes1[..., 2] - boxes1[..., 0]) * (boxes1[..., 3] - boxes1[..., 1]) 106 | boxes2_area = (boxes2[..., 2] - boxes2[..., 0]) * (boxes2[..., 3] - boxes2[..., 1]) 107 | 108 | left_up = np.maximum(boxes1[..., :2], boxes2[..., :2]) 109 | right_down = np.minimum(boxes1[..., 2:], boxes2[..., 2:]) 110 | 111 | inter_section = np.maximum(right_down - left_up, 0.0) 112 | inter_area = inter_section[..., 0] * inter_section[..., 1] 113 | union_area = boxes1_area + boxes2_area - inter_area 114 | ious = np.maximum(1.0 * inter_area / union_area, np.finfo(np.float32).eps) 115 | 116 | return ious 117 | 118 | 119 | 120 | def read_pb_return_tensors(graph, pb_file, return_elements): 121 | 122 | with tf.gfile.FastGFile(pb_file, 'rb') as f: 123 | frozen_graph_def = tf.GraphDef() 124 | frozen_graph_def.ParseFromString(f.read()) 125 | 126 | with graph.as_default(): 127 | return_elements = tf.import_graph_def(frozen_graph_def, 128 | return_elements=return_elements) 129 | return return_elements 130 | 131 | 132 | def nms(bboxes, iou_threshold, sigma=0.3, method='nms'): 133 | """ 134 | :param bboxes: (xmin, ymin, xmax, ymax, score, class) 135 | 136 | Note: soft-nms, https://arxiv.org/pdf/1704.04503.pdf 137 | https://github.com/bharatsingh430/soft-nms 138 | """ 139 | classes_in_img = list(set(bboxes[:, 5])) 140 | best_bboxes = [] 141 | 142 | for cls in classes_in_img: 143 | cls_mask = (bboxes[:, 5] == cls) 144 | cls_bboxes = bboxes[cls_mask] 145 | 146 | while len(cls_bboxes) > 0: 147 | max_ind = np.argmax(cls_bboxes[:, 4]) 148 | best_bbox = cls_bboxes[max_ind] 149 | best_bboxes.append(best_bbox) 150 | cls_bboxes = np.concatenate([cls_bboxes[: max_ind], cls_bboxes[max_ind + 1:]]) 151 | iou = bboxes_iou(best_bbox[np.newaxis, :4], cls_bboxes[:, :4]) 152 | weight = np.ones((len(iou),), dtype=np.float32) 153 | 154 | assert method in ['nms', 'soft-nms'] 155 | 156 | if method == 'nms': 157 | iou_mask = iou > iou_threshold 158 | weight[iou_mask] = 0.0 159 | 160 | if method == 'soft-nms': 161 | weight = np.exp(-(1.0 * iou ** 2 / sigma)) 162 | 163 | cls_bboxes[:, 4] = cls_bboxes[:, 4] * weight 164 | score_mask = cls_bboxes[:, 4] > 0. 
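            # keep only boxes whose score is still positive: plain NMS zeroes the
            # scores of neighbours above iou_threshold, while soft-NMS merely decays them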
165 | cls_bboxes = cls_bboxes[score_mask] 166 | 167 | return best_bboxes 168 | 169 | 170 | def postprocess_boxes(pred_bbox, org_img_shape, input_size, score_threshold): 171 | 172 | valid_scale=[0, np.inf] 173 | pred_bbox = np.array(pred_bbox) 174 | 175 | pred_xywh = pred_bbox[:, 0:4] 176 | pred_conf = pred_bbox[:, 4] 177 | pred_prob = pred_bbox[:, 5:] 178 | 179 | # # (1) (x, y, w, h) --> (xmin, ymin, xmax, ymax) 180 | pred_coor = np.concatenate([pred_xywh[:, :2] - pred_xywh[:, 2:] * 0.5, 181 | pred_xywh[:, :2] + pred_xywh[:, 2:] * 0.5], axis=-1) 182 | # # (2) (xmin, ymin, xmax, ymax) -> (xmin_org, ymin_org, xmax_org, ymax_org) 183 | org_h, org_w = org_img_shape 184 | resize_ratio = min(input_size / org_w, input_size / org_h) 185 | 186 | dw = (input_size - resize_ratio * org_w) / 2 187 | dh = (input_size - resize_ratio * org_h) / 2 188 | 189 | pred_coor[:, 0::2] = 1.0 * (pred_coor[:, 0::2] - dw) / resize_ratio 190 | pred_coor[:, 1::2] = 1.0 * (pred_coor[:, 1::2] - dh) / resize_ratio 191 | 192 | # # (3) clip some boxes those are out of range 193 | pred_coor = np.concatenate([np.maximum(pred_coor[:, :2], [0, 0]), 194 | np.minimum(pred_coor[:, 2:], [org_w - 1, org_h - 1])], axis=-1) 195 | invalid_mask = np.logical_or((pred_coor[:, 0] > pred_coor[:, 2]), (pred_coor[:, 1] > pred_coor[:, 3])) 196 | pred_coor[invalid_mask] = 0 197 | 198 | # # (4) discard some invalid boxes 199 | bboxes_scale = np.sqrt(np.multiply.reduce(pred_coor[:, 2:4] - pred_coor[:, 0:2], axis=-1)) 200 | scale_mask = np.logical_and((valid_scale[0] < bboxes_scale), (bboxes_scale < valid_scale[1])) 201 | 202 | # # (5) discard some boxes with low scores 203 | classes = np.argmax(pred_prob, axis=-1) 204 | scores = pred_conf * pred_prob[np.arange(len(pred_coor)), classes] 205 | score_mask = scores > score_threshold 206 | mask = np.logical_and(scale_mask, score_mask) 207 | coors, scores, classes = pred_coor[mask], scores[mask], classes[mask] 208 | 209 | return np.concatenate([coors, scores[:, np.newaxis], classes[:, np.newaxis]], axis=-1) 210 | 211 | 212 | 213 | -------------------------------------------------------------------------------- /core/yolov3.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | # coding=utf-8 3 | #================================================================ 4 | # Copyright (C) 2019 * Ltd. All rights reserved. 
5 | #
6 | # Editor : VIM
7 | # File name : yolov3.py
8 | # Author : YunYang1994
9 | # Created date: 2019-02-28 10:47:03
10 | # Description :
11 | #
12 | #================================================================
13 |
14 | import numpy as np
15 | import tensorflow as tf
16 | import core.utils as utils
17 | import core.common as common
18 | import core.backbone as backbone
19 | from core.config import cfg
20 |
21 |
22 | class YOLOV3(object):
23 |     """Implement TensorFlow YOLOv3 here"""
24 |     def __init__(self, input_data, trainable):
25 |
26 |         self.trainable = trainable
27 |         self.classes = utils.read_class_names(cfg.YOLO.CLASSES)
28 |         self.num_class = len(self.classes)
29 |         self.strides = np.array(cfg.YOLO.STRIDES)
30 |         self.anchors = utils.get_anchors(cfg.YOLO.ANCHORS)
31 |         self.anchor_per_scale = cfg.YOLO.ANCHOR_PER_SCALE
32 |         self.iou_loss_thresh = cfg.YOLO.IOU_LOSS_THRESH
33 |         self.upsample_method = cfg.YOLO.UPSAMPLE_METHOD
34 |
35 |         try:
36 |             self.conv_lbbox, self.conv_mbbox, self.conv_sbbox = self.__build_network(input_data)
37 |         except Exception:
38 |             raise NotImplementedError("Cannot build up the YOLOv3 network!")
39 |
40 |         with tf.variable_scope('pred_sbbox'):
41 |             self.pred_sbbox = self.decode(self.conv_sbbox, self.anchors[0], self.strides[0])
42 |
43 |         with tf.variable_scope('pred_mbbox'):
44 |             self.pred_mbbox = self.decode(self.conv_mbbox, self.anchors[1], self.strides[1])
45 |
46 |         with tf.variable_scope('pred_lbbox'):
47 |             self.pred_lbbox = self.decode(self.conv_lbbox, self.anchors[2], self.strides[2])
48 |
49 |     def __build_network(self, input_data):
50 |
51 |         route_1, route_2, input_data = backbone.darknet53(input_data, self.trainable)
52 |
53 |         input_data = common.convolutional(input_data, (1, 1, 1024, 512), self.trainable, 'conv52')
54 |         input_data = common.convolutional(input_data, (3, 3, 512, 1024), self.trainable, 'conv53')
55 |         input_data = common.convolutional(input_data, (1, 1, 1024, 512), self.trainable, 'conv54')
56 |         input_data = common.convolutional(input_data, (3, 3, 512, 1024), self.trainable, 'conv55')
57 |         input_data = common.convolutional(input_data, (1, 1, 1024, 512), self.trainable, 'conv56')
58 |
59 |         conv_lobj_branch = common.convolutional(input_data, (3, 3, 512, 1024), self.trainable, name='conv_lobj_branch')
60 |         conv_lbbox = common.convolutional(conv_lobj_branch, (1, 1, 1024, 3*(self.num_class + 5)),
61 |                                           trainable=self.trainable, name='conv_lbbox', activate=False, bn=False)
62 |
63 |         input_data = common.convolutional(input_data, (1, 1, 512, 256), self.trainable, 'conv57')
64 |         input_data = common.upsample(input_data, name='upsample0', method=self.upsample_method)
65 |
66 |         with tf.variable_scope('route_1'):
67 |             input_data = tf.concat([input_data, route_2], axis=-1)
68 |
69 |         input_data = common.convolutional(input_data, (1, 1, 768, 256), self.trainable, 'conv58')
70 |         input_data = common.convolutional(input_data, (3, 3, 256, 512), self.trainable, 'conv59')
71 |         input_data = common.convolutional(input_data, (1, 1, 512, 256), self.trainable, 'conv60')
72 |         input_data = common.convolutional(input_data, (3, 3, 256, 512), self.trainable, 'conv61')
73 |         input_data = common.convolutional(input_data, (1, 1, 512, 256), self.trainable, 'conv62')
74 |
75 |         conv_mobj_branch = common.convolutional(input_data, (3, 3, 256, 512), self.trainable, name='conv_mobj_branch')
76 |         conv_mbbox = common.convolutional(conv_mobj_branch, (1, 1, 512, 3*(self.num_class + 5)),
77 |                                           trainable=self.trainable, name='conv_mbbox', activate=False, bn=False)
78 |
79 |         input_data =
common.convolutional(input_data, (1, 1, 256, 128), self.trainable, 'conv63') 80 | input_data = common.upsample(input_data, name='upsample1', method=self.upsample_method) 81 | 82 | with tf.variable_scope('route_2'): 83 | input_data = tf.concat([input_data, route_1], axis=-1) 84 | 85 | input_data = common.convolutional(input_data, (1, 1, 384, 128), self.trainable, 'conv64') 86 | input_data = common.convolutional(input_data, (3, 3, 128, 256), self.trainable, 'conv65') 87 | input_data = common.convolutional(input_data, (1, 1, 256, 128), self.trainable, 'conv66') 88 | input_data = common.convolutional(input_data, (3, 3, 128, 256), self.trainable, 'conv67') 89 | input_data = common.convolutional(input_data, (1, 1, 256, 128), self.trainable, 'conv68') 90 | 91 | conv_sobj_branch = common.convolutional(input_data, (3, 3, 128, 256), self.trainable, name='conv_sobj_branch') 92 | conv_sbbox = common.convolutional(conv_sobj_branch, (1, 1, 256, 3*(self.num_class + 5)), 93 | trainable=self.trainable, name='conv_sbbox', activate=False, bn=False) 94 | 95 | return conv_lbbox, conv_mbbox, conv_sbbox 96 | 97 | def decode(self, conv_output, anchors, stride): 98 | """ 99 | return tensor of shape [batch_size, output_size, output_size, anchor_per_scale, 5 + num_classes] 100 | contains (x, y, w, h, score, probability) 101 | """ 102 | 103 | conv_shape = tf.shape(conv_output) 104 | batch_size = conv_shape[0] 105 | output_size = conv_shape[1] 106 | anchor_per_scale = len(anchors) 107 | 108 | conv_output = tf.reshape(conv_output, (batch_size, output_size, output_size, anchor_per_scale, 5 + self.num_class)) 109 | 110 | conv_raw_dxdy = conv_output[:, :, :, :, 0:2] 111 | conv_raw_dwdh = conv_output[:, :, :, :, 2:4] 112 | conv_raw_conf = conv_output[:, :, :, :, 4:5] 113 | conv_raw_prob = conv_output[:, :, :, :, 5: ] 114 | 115 | y = tf.tile(tf.range(output_size, dtype=tf.int32)[:, tf.newaxis], [1, output_size]) 116 | x = tf.tile(tf.range(output_size, dtype=tf.int32)[tf.newaxis, :], [output_size, 1]) 117 | 118 | xy_grid = tf.concat([x[:, :, tf.newaxis], y[:, :, tf.newaxis]], axis=-1) 119 | xy_grid = tf.tile(xy_grid[tf.newaxis, :, :, tf.newaxis, :], [batch_size, 1, 1, anchor_per_scale, 1]) 120 | xy_grid = tf.cast(xy_grid, tf.float32) 121 | 122 | pred_xy = (tf.sigmoid(conv_raw_dxdy) + xy_grid) * stride 123 | pred_wh = (tf.exp(conv_raw_dwdh) * anchors) * stride 124 | pred_xywh = tf.concat([pred_xy, pred_wh], axis=-1) 125 | 126 | pred_conf = tf.sigmoid(conv_raw_conf) 127 | pred_prob = tf.sigmoid(conv_raw_prob) 128 | 129 | return tf.concat([pred_xywh, pred_conf, pred_prob], axis=-1) 130 | 131 | def focal(self, target, actual, alpha=1, gamma=2): 132 | focal_loss = alpha * tf.pow(tf.abs(target - actual), gamma) 133 | return focal_loss 134 | 135 | def bbox_giou(self, boxes1, boxes2): 136 | 137 | boxes1 = tf.concat([boxes1[..., :2] - boxes1[..., 2:] * 0.5, 138 | boxes1[..., :2] + boxes1[..., 2:] * 0.5], axis=-1) 139 | boxes2 = tf.concat([boxes2[..., :2] - boxes2[..., 2:] * 0.5, 140 | boxes2[..., :2] + boxes2[..., 2:] * 0.5], axis=-1) 141 | 142 | boxes1 = tf.concat([tf.minimum(boxes1[..., :2], boxes1[..., 2:]), 143 | tf.maximum(boxes1[..., :2], boxes1[..., 2:])], axis=-1) 144 | boxes2 = tf.concat([tf.minimum(boxes2[..., :2], boxes2[..., 2:]), 145 | tf.maximum(boxes2[..., :2], boxes2[..., 2:])], axis=-1) 146 | 147 | boxes1_area = (boxes1[..., 2] - boxes1[..., 0]) * (boxes1[..., 3] - boxes1[..., 1]) 148 | boxes2_area = (boxes2[..., 2] - boxes2[..., 0]) * (boxes2[..., 3] - boxes2[..., 1]) 149 | 150 | left_up = tf.maximum(boxes1[..., 
:2], boxes2[..., :2]) 151 | right_down = tf.minimum(boxes1[..., 2:], boxes2[..., 2:]) 152 | 153 | inter_section = tf.maximum(right_down - left_up, 0.0) 154 | inter_area = inter_section[..., 0] * inter_section[..., 1] 155 | union_area = boxes1_area + boxes2_area - inter_area 156 | iou = inter_area / (union_area + 1e-6) 157 | # added 1e-6 in denominator to avoid generation of inf, which may cause nan loss 158 | 159 | enclose_left_up = tf.minimum(boxes1[..., :2], boxes2[..., :2]) 160 | enclose_right_down = tf.maximum(boxes1[..., 2:], boxes2[..., 2:]) 161 | enclose = tf.maximum(enclose_right_down - enclose_left_up, 0.0) 162 | enclose_area = enclose[..., 0] * enclose[..., 1] 163 | giou = iou - 1.0 * (enclose_area - union_area) / (enclose_area + 1e-6) 164 | # added 1e-6 in denominator to avoid generation of inf, which may cause nan loss 165 | 166 | return giou 167 | 168 | def bbox_iou(self, boxes1, boxes2): 169 | 170 | boxes1_area = boxes1[..., 2] * boxes1[..., 3] 171 | boxes2_area = boxes2[..., 2] * boxes2[..., 3] 172 | 173 | boxes1 = tf.concat([boxes1[..., :2] - boxes1[..., 2:] * 0.5, 174 | boxes1[..., :2] + boxes1[..., 2:] * 0.5], axis=-1) 175 | boxes2 = tf.concat([boxes2[..., :2] - boxes2[..., 2:] * 0.5, 176 | boxes2[..., :2] + boxes2[..., 2:] * 0.5], axis=-1) 177 | 178 | left_up = tf.maximum(boxes1[..., :2], boxes2[..., :2]) 179 | right_down = tf.minimum(boxes1[..., 2:], boxes2[..., 2:]) 180 | 181 | inter_section = tf.maximum(right_down - left_up, 0.0) 182 | inter_area = inter_section[..., 0] * inter_section[..., 1] 183 | union_area = boxes1_area + boxes2_area - inter_area 184 | iou = 1.0 * inter_area / union_area 185 | 186 | return iou 187 | 188 | def loss_layer(self, conv, pred, label, bboxes, anchors, stride): 189 | 190 | conv_shape = tf.shape(conv) 191 | batch_size = conv_shape[0] 192 | output_size = conv_shape[1] 193 | input_size = stride * output_size 194 | conv = tf.reshape(conv, (batch_size, output_size, output_size, 195 | self.anchor_per_scale, 5 + self.num_class)) 196 | conv_raw_conf = conv[:, :, :, :, 4:5] 197 | conv_raw_prob = conv[:, :, :, :, 5:] 198 | 199 | pred_xywh = pred[:, :, :, :, 0:4] 200 | pred_conf = pred[:, :, :, :, 4:5] 201 | 202 | label_xywh = label[:, :, :, :, 0:4] 203 | respond_bbox = label[:, :, :, :, 4:5] 204 | label_prob = label[:, :, :, :, 5:] 205 | 206 | giou = tf.expand_dims(self.bbox_giou(pred_xywh, label_xywh), axis=-1) 207 | input_size = tf.cast(input_size, tf.float32) 208 | 209 | bbox_loss_scale = 2.0 - 1.0 * label_xywh[:, :, :, :, 2:3] * label_xywh[:, :, :, :, 3:4] / (input_size ** 2) 210 | giou_loss = respond_bbox * bbox_loss_scale * (1- giou) 211 | 212 | iou = self.bbox_iou(pred_xywh[:, :, :, :, np.newaxis, :], bboxes[:, np.newaxis, np.newaxis, np.newaxis, :, :]) 213 | max_iou = tf.expand_dims(tf.reduce_max(iou, axis=-1), axis=-1) 214 | 215 | respond_bgd = (1.0 - respond_bbox) * tf.cast( max_iou < self.iou_loss_thresh, tf.float32 ) 216 | 217 | conf_focal = self.focal(respond_bbox, pred_conf) 218 | 219 | conf_loss = conf_focal * ( 220 | respond_bbox * tf.nn.sigmoid_cross_entropy_with_logits(labels=respond_bbox, logits=conv_raw_conf) 221 | + 222 | respond_bgd * tf.nn.sigmoid_cross_entropy_with_logits(labels=respond_bbox, logits=conv_raw_conf) 223 | ) 224 | 225 | prob_loss = respond_bbox * tf.nn.sigmoid_cross_entropy_with_logits(labels=label_prob, logits=conv_raw_prob) 226 | 227 | giou_loss = tf.reduce_mean(tf.reduce_sum(giou_loss, axis=[1,2,3,4])) 228 | conf_loss = tf.reduce_mean(tf.reduce_sum(conf_loss, axis=[1,2,3,4])) 229 | prob_loss = 
tf.reduce_mean(tf.reduce_sum(prob_loss, axis=[1,2,3,4])) 230 | 231 | return giou_loss, conf_loss, prob_loss 232 | 233 | 234 | 235 | def compute_loss(self, label_sbbox, label_mbbox, label_lbbox, true_sbbox, true_mbbox, true_lbbox): 236 | 237 | with tf.name_scope('smaller_box_loss'): 238 | loss_sbbox = self.loss_layer(self.conv_sbbox, self.pred_sbbox, label_sbbox, true_sbbox, 239 | anchors = self.anchors[0], stride = self.strides[0]) 240 | 241 | with tf.name_scope('medium_box_loss'): 242 | loss_mbbox = self.loss_layer(self.conv_mbbox, self.pred_mbbox, label_mbbox, true_mbbox, 243 | anchors = self.anchors[1], stride = self.strides[1]) 244 | 245 | with tf.name_scope('bigger_box_loss'): 246 | loss_lbbox = self.loss_layer(self.conv_lbbox, self.pred_lbbox, label_lbbox, true_lbbox, 247 | anchors = self.anchors[2], stride = self.strides[2]) 248 | 249 | with tf.name_scope('giou_loss'): 250 | giou_loss = loss_sbbox[0] + loss_mbbox[0] + loss_lbbox[0] 251 | 252 | with tf.name_scope('conf_loss'): 253 | conf_loss = loss_sbbox[1] + loss_mbbox[1] + loss_lbbox[1] 254 | 255 | with tf.name_scope('prob_loss'): 256 | prob_loss = loss_sbbox[2] + loss_mbbox[2] + loss_lbbox[2] 257 | 258 | return giou_loss, conf_loss, prob_loss 259 | 260 | 261 | -------------------------------------------------------------------------------- /data/anchors/basline_anchors.txt: -------------------------------------------------------------------------------- 1 | 1.25,1.625, 2.0,3.75, 4.125,2.875, 1.875,3.8125, 3.875,2.8125, 3.6875,7.4375, 3.625,2.8125, 4.875,6.1875, 11.65625,10.1875 2 | -------------------------------------------------------------------------------- /data/anchors/coco_anchors.txt: -------------------------------------------------------------------------------- 1 | 10,13, 16,30, 33,23, 30,61, 62,45, 59,119, 116,90, 156,198, 373,326 2 | -------------------------------------------------------------------------------- /data/classes/coco.names: -------------------------------------------------------------------------------- 1 | person 2 | bicycle 3 | car 4 | motorbike 5 | aeroplane 6 | bus 7 | train 8 | truck 9 | boat 10 | traffic light 11 | fire hydrant 12 | stop sign 13 | parking meter 14 | bench 15 | bird 16 | cat 17 | dog 18 | horse 19 | sheep 20 | cow 21 | elephant 22 | bear 23 | zebra 24 | giraffe 25 | backpack 26 | umbrella 27 | handbag 28 | tie 29 | suitcase 30 | frisbee 31 | skis 32 | snowboard 33 | sports ball 34 | kite 35 | baseball bat 36 | baseball glove 37 | skateboard 38 | surfboard 39 | tennis racket 40 | bottle 41 | wine glass 42 | cup 43 | fork 44 | knife 45 | spoon 46 | bowl 47 | banana 48 | apple 49 | sandwich 50 | orange 51 | broccoli 52 | carrot 53 | hot dog 54 | pizza 55 | donut 56 | cake 57 | chair 58 | sofa 59 | pottedplant 60 | bed 61 | diningtable 62 | toilet 63 | tvmonitor 64 | laptop 65 | mouse 66 | remote 67 | keyboard 68 | cell phone 69 | microwave 70 | oven 71 | toaster 72 | sink 73 | refrigerator 74 | book 75 | clock 76 | vase 77 | scissors 78 | teddy bear 79 | hair drier 80 | toothbrush 81 | -------------------------------------------------------------------------------- /data/classes/voc.names: -------------------------------------------------------------------------------- 1 | aeroplane 2 | bicycle 3 | bird 4 | boat 5 | bottle 6 | bus 7 | car 8 | cat 9 | chair 10 | cow 11 | diningtable 12 | dog 13 | horse 14 | motorbike 15 | person 16 | pottedplant 17 | sheep 18 | sofa 19 | train 20 | tvmonitor -------------------------------------------------------------------------------- 
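A note on the two anchor files above: the values in `basline_anchors.txt` are simply the COCO anchors from `coco_anchors.txt` divided by the stride of their scale (8 for small, 16 for medium, 32 for large), i.e. they live in grid-cell units; `decode()` in `core/yolov3.py` multiplies them by the stride again to recover pixels. A quick sketch (not part of the repo) to verify this with `core/utils.py`:

```python
import numpy as np
import core.utils as utils

# get_anchors() reads the 9 anchor pairs and reshapes them to (3 scales, 3 anchors, 2)
anchors = np.array(utils.get_anchors("./data/anchors/basline_anchors.txt"))
strides = np.array([8, 16, 32]).reshape(3, 1, 1)

# multiplying by the per-scale stride recovers the pixel-space COCO anchors:
# 10,13  16,30  33,23 ... 373,326
print(anchors * strides)
```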
/data/dataset/voc_test.txt: -------------------------------------------------------------------------------- 1 | /home/yang/test/VOC/test/VOCdevkit/VOC2007/JPEGImages/000001.jpg 48,240,195,371,11 8,12,352,498,14 2 | /home/yang/test/VOC/test/VOCdevkit/VOC2007/JPEGImages/000002.jpg 139,200,207,301,18 3 | /home/yang/test/VOC/test/VOCdevkit/VOC2007/JPEGImages/000003.jpg 123,155,215,195,17 239,156,307,205,8 4 | /home/yang/test/VOC/test/VOCdevkit/VOC2007/JPEGImages/000004.jpg 13,311,84,362,6 362,330,500,389,6 235,328,334,375,6 175,327,252,364,6 139,320,189,359,6 108,325,150,353,6 84,323,121,350,6 5 | /home/yang/test/VOC/test/VOCdevkit/VOC2007/JPEGImages/000006.jpg 187,135,282,242,15 154,209,369,375,10 255,207,366,375,8 138,211,249,375,8 6 | /home/yang/test/VOC/test/VOCdevkit/VOC2007/JPEGImages/000008.jpg 192,16,364,249,8 7 | /home/yang/test/VOC/test/VOCdevkit/VOC2007/JPEGImages/000010.jpg 87,97,258,427,12 133,72,245,284,14 8 | /home/yang/test/VOC/test/VOCdevkit/VOC2007/JPEGImages/000011.jpg 126,51,330,308,7 9 | /home/yang/test/VOC/test/VOCdevkit/VOC2007/JPEGImages/000013.jpg 299,160,446,252,9 10 | /home/yang/test/VOC/test/VOCdevkit/VOC2007/JPEGImages/000014.jpg 72,163,302,228,5 185,194,500,316,6 416,180,500,222,6 314,8,344,65,14 331,4,361,61,14 357,8,401,61,14 11 | /home/yang/test/VOC/test/VOCdevkit/VOC2007/JPEGImages/000015.jpg 77,136,360,358,1 12 | /home/yang/test/VOC/test/VOCdevkit/VOC2007/JPEGImages/000018.jpg 31,30,358,279,11 13 | /home/yang/test/VOC/test/VOCdevkit/VOC2007/JPEGImages/000022.jpg 68,103,368,283,12 186,44,255,230,14 14 | /home/yang/test/VOC/test/VOCdevkit/VOC2007/JPEGImages/000025.jpg 2,84,59,248,9 68,115,233,279,9 64,173,377,373,9 320,2,496,375,14 221,4,341,374,14 135,14,220,148,14 69,43,156,177,9 58,54,104,139,14 279,1,331,86,14 320,22,344,96,14 15 | /home/yang/test/VOC/test/VOCdevkit/VOC2007/JPEGImages/000027.jpg 174,101,349,351,14 16 | /home/yang/test/VOC/test/VOCdevkit/VOC2007/JPEGImages/000028.jpg 63,18,374,500,7 17 | /home/yang/test/VOC/test/VOCdevkit/VOC2007/JPEGImages/000029.jpg 56,63,284,290,11 18 | /home/yang/test/VOC/test/VOCdevkit/VOC2007/JPEGImages/000031.jpg 41,77,430,255,18 19 | /home/yang/test/VOC/test/VOCdevkit/VOC2007/JPEGImages/000037.jpg 61,96,464,339,11 20 | -------------------------------------------------------------------------------- /data/dataset/voc_train.txt: -------------------------------------------------------------------------------- 1 | /home/yang/test/VOC/train/VOCdevkit/VOC2007/JPEGImages/000005.jpg 263,211,324,339,8 165,264,253,372,8 241,194,295,299,8 2 | /home/yang/test/VOC/train/VOCdevkit/VOC2007/JPEGImages/000007.jpg 141,50,500,330,6 3 | /home/yang/test/VOC/train/VOCdevkit/VOC2007/JPEGImages/000009.jpg 69,172,270,330,12 150,141,229,284,14 285,201,327,331,14 258,198,297,329,14 4 | /home/yang/test/VOC/train/VOCdevkit/VOC2007/JPEGImages/000012.jpg 156,97,351,270,6 5 | /home/yang/test/VOC/train/VOCdevkit/VOC2007/JPEGImages/000016.jpg 92,72,305,473,1 6 | /home/yang/test/VOC/train/VOCdevkit/VOC2007/JPEGImages/000017.jpg 185,62,279,199,14 90,78,403,336,12 7 | /home/yang/test/VOC/train/VOCdevkit/VOC2007/JPEGImages/000019.jpg 231,88,483,256,7 11,113,266,259,7 8 | /home/yang/test/VOC/train/VOCdevkit/VOC2007/JPEGImages/000020.jpg 33,148,371,416,6 9 | /home/yang/test/VOC/train/VOCdevkit/VOC2007/JPEGImages/000021.jpg 1,235,182,388,11 210,36,336,482,14 46,82,170,365,14 11,181,142,419,14 10 | /home/yang/test/VOC/train/VOCdevkit/VOC2007/JPEGImages/000023.jpg 9,230,245,500,1 230,220,334,500,1 2,1,117,369,14 3,2,243,462,14 225,1,334,486,14 11 | 
/home/yang/test/VOC/train/VOCdevkit/VOC2007/JPEGImages/000024.jpg 196,165,489,247,18 12 | /home/yang/test/VOC/train/VOCdevkit/VOC2007/JPEGImages/000026.jpg 90,125,337,212,6 13 | /home/yang/test/VOC/train/VOCdevkit/VOC2007/JPEGImages/000030.jpg 36,205,180,289,1 51,160,150,292,14 295,138,450,290,14 14 | /home/yang/test/VOC/train/VOCdevkit/VOC2007/JPEGImages/000032.jpg 104,78,375,183,0 133,88,197,123,0 195,180,213,229,14 26,189,44,238,14 15 | /home/yang/test/VOC/train/VOCdevkit/VOC2007/JPEGImages/000033.jpg 9,107,499,263,0 421,200,482,226,0 325,188,411,223,0 16 | /home/yang/test/VOC/train/VOCdevkit/VOC2007/JPEGImages/000034.jpg 116,167,360,400,18 141,153,333,229,18 17 | /home/yang/test/VOC/train/VOCdevkit/VOC2007/JPEGImages/000035.jpg 1,96,191,361,14 218,98,465,318,14 18 | /home/yang/test/VOC/train/VOCdevkit/VOC2007/JPEGImages/000036.jpg 27,79,319,344,11 19 | /home/yang/test/VOC/train/VOCdevkit/VOC2007/JPEGImages/000039.jpg 156,89,344,279,19 20 | /home/yang/test/VOC/train/VOCdevkit/VOC2007/JPEGImages/000041.jpg 363,47,432,107,19 216,92,307,302,14 164,148,227,244,14 21 | -------------------------------------------------------------------------------- /docs/images/.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunYang1994/tensorflow-yolov3/03cb272af2e26d598c553f3a2d38024fc6f67a0b/docs/images/.jpg -------------------------------------------------------------------------------- /docs/images/611_result.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunYang1994/tensorflow-yolov3/03cb272af2e26d598c553f3a2d38024fc6f67a0b/docs/images/611_result.jpg -------------------------------------------------------------------------------- /docs/images/road.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunYang1994/tensorflow-yolov3/03cb272af2e26d598c553f3a2d38024fc6f67a0b/docs/images/road.jpeg -------------------------------------------------------------------------------- /docs/images/road.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunYang1994/tensorflow-yolov3/03cb272af2e26d598c553f3a2d38024fc6f67a0b/docs/images/road.mp4 -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.15.1 2 | Pillow==5.3.0 3 | scipy==1.1.0 4 | tensorflow-gpu==1.11.0 5 | wget==3.2 6 | seaborn==0.9.0 7 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | # coding=utf-8 3 | #================================================================ 4 | # Copyright (C) 2019 * Ltd. All rights reserved. 
5 | # 6 | # Editor : VIM 7 | # File name : evaluate.py 8 | # Author : YunYang1994 9 | # Created date: 2019-02-21 15:30:26 10 | # Description : 11 | # 12 | #================================================================ 13 | 14 | import cv2 15 | import os 16 | import shutil 17 | import numpy as np 18 | import tensorflow as tf 19 | import core.utils as utils 20 | from core.config import cfg 21 | from core.yolov3 import YOLOV3 22 | 23 | class YoloTest(object): 24 | def __init__(self): 25 | self.input_size = cfg.TEST.INPUT_SIZE 26 | self.anchor_per_scale = cfg.YOLO.ANCHOR_PER_SCALE 27 | self.classes = utils.read_class_names(cfg.YOLO.CLASSES) 28 | self.num_classes = len(self.classes) 29 | self.anchors = np.array(utils.get_anchors(cfg.YOLO.ANCHORS)) 30 | self.score_threshold = cfg.TEST.SCORE_THRESHOLD 31 | self.iou_threshold = cfg.TEST.IOU_THRESHOLD 32 | self.moving_ave_decay = cfg.YOLO.MOVING_AVE_DECAY 33 | self.annotation_path = cfg.TEST.ANNOT_PATH 34 | self.weight_file = cfg.TEST.WEIGHT_FILE 35 | self.write_image = cfg.TEST.WRITE_IMAGE 36 | self.write_image_path = cfg.TEST.WRITE_IMAGE_PATH 37 | self.show_label = cfg.TEST.SHOW_LABEL 38 | 39 | with tf.name_scope('input'): 40 | self.input_data = tf.placeholder(dtype=tf.float32, name='input_data') 41 | self.trainable = tf.placeholder(dtype=tf.bool, name='trainable') 42 | 43 | model = YOLOV3(self.input_data, self.trainable) 44 | self.pred_sbbox, self.pred_mbbox, self.pred_lbbox = model.pred_sbbox, model.pred_mbbox, model.pred_lbbox 45 | 46 | with tf.name_scope('ema'): 47 | ema_obj = tf.train.ExponentialMovingAverage(self.moving_ave_decay) 48 | 49 | self.sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) 50 | self.saver = tf.train.Saver(ema_obj.variables_to_restore()) 51 | self.saver.restore(self.sess, self.weight_file) 52 | 53 | def predict(self, image): 54 | 55 | org_image = np.copy(image) 56 | org_h, org_w, _ = org_image.shape 57 | 58 | image_data = utils.image_preporcess(image, [self.input_size, self.input_size]) 59 | image_data = image_data[np.newaxis, ...] 
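        # (added note) utils.image_preporcess (defined in core/utils.py, not shown
        # here) resizes the frame to (input_size, input_size); the line above then
        # adds a batch axis, so the placeholder below is fed a tensor of shape
        # (1, input_size, input_size, 3).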
60 | 61 | pred_sbbox, pred_mbbox, pred_lbbox = self.sess.run( 62 | [self.pred_sbbox, self.pred_mbbox, self.pred_lbbox], 63 | feed_dict={ 64 | self.input_data: image_data, 65 | self.trainable: False 66 | } 67 | ) 68 | 69 | pred_bbox = np.concatenate([np.reshape(pred_sbbox, (-1, 5 + self.num_classes)), 70 | np.reshape(pred_mbbox, (-1, 5 + self.num_classes)), 71 | np.reshape(pred_lbbox, (-1, 5 + self.num_classes))], axis=0) 72 | bboxes = utils.postprocess_boxes(pred_bbox, (org_h, org_w), self.input_size, self.score_threshold) 73 | bboxes = utils.nms(bboxes, self.iou_threshold) 74 | 75 | return bboxes 76 | 77 | def evaluate(self): 78 | predicted_dir_path = './mAP/predicted' 79 | ground_truth_dir_path = './mAP/ground-truth' 80 | if os.path.exists(predicted_dir_path): shutil.rmtree(predicted_dir_path) 81 | if os.path.exists(ground_truth_dir_path): shutil.rmtree(ground_truth_dir_path) 82 | if os.path.exists(self.write_image_path): shutil.rmtree(self.write_image_path) 83 | os.mkdir(predicted_dir_path) 84 | os.mkdir(ground_truth_dir_path) 85 | os.mkdir(self.write_image_path) 86 | 87 | with open(self.annotation_path, 'r') as annotation_file: 88 | for num, line in enumerate(annotation_file): 89 | annotation = line.strip().split() 90 | image_path = annotation[0] 91 | image_name = image_path.split('/')[-1] 92 | image = cv2.imread(image_path) 93 | bbox_data_gt = np.array([list(map(int, box.split(','))) for box in annotation[1:]]) 94 | 95 | if len(bbox_data_gt) == 0: 96 | bboxes_gt=[] 97 | classes_gt=[] 98 | else: 99 | bboxes_gt, classes_gt = bbox_data_gt[:, :4], bbox_data_gt[:, 4] 100 | ground_truth_path = os.path.join(ground_truth_dir_path, str(num) + '.txt') 101 | 102 | print('=> ground truth of %s:' % image_name) 103 | num_bbox_gt = len(bboxes_gt) 104 | with open(ground_truth_path, 'w') as f: 105 | for i in range(num_bbox_gt): 106 | class_name = self.classes[classes_gt[i]] 107 | xmin, ymin, xmax, ymax = list(map(str, bboxes_gt[i])) 108 | bbox_mess = ' '.join([class_name, xmin, ymin, xmax, ymax]) + '\n' 109 | f.write(bbox_mess) 110 | print('\t' + str(bbox_mess).strip()) 111 | print('=> predict result of %s:' % image_name) 112 | predict_result_path = os.path.join(predicted_dir_path, str(num) + '.txt') 113 | bboxes_pr = self.predict(image) 114 | 115 | if self.write_image: 116 | image = utils.draw_bbox(image, bboxes_pr, show_label=self.show_label) 117 | cv2.imwrite(self.write_image_path+image_name, image) 118 | 119 | with open(predict_result_path, 'w') as f: 120 | for bbox in bboxes_pr: 121 | coor = np.array(bbox[:4], dtype=np.int32) 122 | score = bbox[4] 123 | class_ind = int(bbox[5]) 124 | class_name = self.classes[class_ind] 125 | score = '%.4f' % score 126 | xmin, ymin, xmax, ymax = list(map(str, coor)) 127 | bbox_mess = ' '.join([class_name, score, xmin, ymin, xmax, ymax]) + '\n' 128 | f.write(bbox_mess) 129 | print('\t' + str(bbox_mess).strip()) 130 | 131 | def voc_2012_test(self, voc2012_test_path): 132 | 133 | img_inds_file = os.path.join(voc2012_test_path, 'ImageSets', 'Main', 'test.txt') 134 | with open(img_inds_file, 'r') as f: 135 | txt = f.readlines() 136 | image_inds = [line.strip() for line in txt] 137 | 138 | results_path = 'results/VOC2012/Main' 139 | if os.path.exists(results_path): 140 | shutil.rmtree(results_path) 141 | os.makedirs(results_path) 142 | 143 | for image_ind in image_inds: 144 | image_path = os.path.join(voc2012_test_path, 'JPEGImages', image_ind + '.jpg') 145 | image = cv2.imread(image_path) 146 | 147 | print('predict result of %s:' % image_ind) 148 | bboxes_pr = 
self.predict(image)
149 |             for bbox in bboxes_pr:
150 |                 coor = np.array(bbox[:4], dtype=np.int32)
151 |                 score = bbox[4]
152 |                 class_ind = int(bbox[5])
153 |                 class_name = self.classes[class_ind]
154 |                 score = '%.4f' % score
155 |                 xmin, ymin, xmax, ymax = list(map(str, coor))
156 |                 bbox_mess = ' '.join([image_ind, score, xmin, ymin, xmax, ymax]) + '\n'
157 |                 with open(os.path.join(results_path, 'comp4_det_test_' + class_name + '.txt'), 'a') as f:
158 |                     f.write(bbox_mess)
159 |                 print('\t' + str(bbox_mess).strip())
160 | 
161 | 
162 | if __name__ == '__main__': YoloTest().evaluate()
163 | 
164 | 
165 | 
166 | 
--------------------------------------------------------------------------------
/freeze_graph.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python
2 | # coding=utf-8
3 | #================================================================
4 | #   Copyright (C) 2019 * Ltd. All rights reserved.
5 | #
6 | #   Editor      : VIM
7 | #   File name   : freeze_graph.py
8 | #   Author      : YunYang1994
9 | #   Created date: 2019-03-20 15:57:33
10 | #   Description :
11 | #
12 | #================================================================
13 | 
14 | 
15 | import tensorflow as tf
16 | from core.yolov3 import YOLOV3
17 | 
18 | pb_file = "./yolov3_coco.pb"
19 | ckpt_file = "./checkpoint/yolov3_coco_demo.ckpt"
20 | output_node_names = ["input/input_data", "pred_sbbox/concat_2", "pred_mbbox/concat_2", "pred_lbbox/concat_2"]
21 | 
22 | with tf.name_scope('input'):
23 |     input_data = tf.placeholder(dtype=tf.float32, name='input_data')
24 | 
25 | model = YOLOV3(input_data, trainable=False)
26 | print(model.conv_sbbox, model.conv_mbbox, model.conv_lbbox)
27 | 
28 | sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
29 | saver = tf.train.Saver()
30 | saver.restore(sess, ckpt_file)
31 | 
32 | converted_graph_def = tf.graph_util.convert_variables_to_constants(sess,
33 |                         input_graph_def = sess.graph.as_graph_def(),
34 |                         output_node_names = output_node_names)
35 | 
36 | with tf.gfile.GFile(pb_file, "wb") as f:
37 |     f.write(converted_graph_def.SerializeToString())
38 | 
39 | 
40 | 
41 | 
42 | 
--------------------------------------------------------------------------------
/from_darknet_weights_to_ckpt.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from core.yolov3 import YOLOV3
3 | import numpy as np
4 | input_size = 416
5 | darknet_weights = ''
6 | ckpt_file = './checkpoint/yolov3_coco.ckpt'
7 | 
8 | def load_weights(var_list, weights_file):
9 |     """
10 |     Loads and converts pre-trained weights.
11 |     :param var_list: list of network variables.
12 |     :param weights_file: name of the binary file.
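    (added note) The .weights binary opens with a five-int32 darknet header,
    skipped below via np.fromfile(fp, dtype=np.int32, count=5); everything
    after it is one flat float32 array consumed sequentially: per conv layer,
    the batch-norm parameters (beta, gamma, mean, variance) or the conv bias
    come first, then the kernel, which is transposed below from darknet's
    (out, in, h, w) layout to TensorFlow's (h, w, in, out).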
13 |     :return: list of assign ops
14 |     """
15 |     with open(weights_file, "rb") as fp:
16 |         _ = np.fromfile(fp, dtype=np.int32, count=5)
17 |         weights = np.fromfile(fp, dtype=np.float32) # np.ndarray
18 |     print('weights_num:', weights.shape[0])
19 |     ptr = 0
20 |     i = 0
21 |     assign_ops = []
22 |     while i < len(var_list) - 1:
23 |         var1 = var_list[i]
24 |         var2 = var_list[i + 1]
25 |         # do something only if we process conv layer
26 |         if 'conv' in var1.name.split('/')[-2]:
27 |             # check type of next layer
28 |             if 'batch_normalization' in var2.name.split('/')[-2]:
29 |                 # load batch norm params
30 |                 gamma, beta, mean, var = var_list[i + 1:i + 5]
31 |                 batch_norm_vars = [beta, gamma, mean, var]
32 |                 for vari in batch_norm_vars:
33 |                     shape = vari.shape.as_list()
34 |                     num_params = np.prod(shape)
35 |                     vari_weights = weights[ptr:ptr + num_params].reshape(shape)
36 |                     ptr += num_params
37 |                     assign_ops.append(
38 |                         tf.assign(vari, vari_weights, validate_shape=True))
39 |                 i += 4
40 |             elif 'conv' in var2.name.split('/')[-2]:
41 |                 # load biases
42 |                 bias = var2
43 |                 bias_shape = bias.shape.as_list()
44 |                 bias_params = np.prod(bias_shape)
45 |                 bias_weights = weights[ptr:ptr +
46 |                                        bias_params].reshape(bias_shape)
47 |                 ptr += bias_params
48 |                 assign_ops.append(
49 |                     tf.assign(bias, bias_weights, validate_shape=True))
50 |                 i += 1
51 |             shape = var1.shape.as_list()
52 |             num_params = np.prod(shape)
53 | 
54 |             var_weights = weights[ptr:ptr + num_params].reshape(
55 |                 (shape[3], shape[2], shape[0], shape[1]))
56 |             # remember to transpose to column-major
57 |             var_weights = np.transpose(var_weights, (2, 3, 1, 0))
58 |             ptr += num_params
59 |             assign_ops.append(
60 |                 tf.assign(var1, var_weights, validate_shape=True))
61 |         i += 1
62 |     print('ptr:', ptr)
63 |     return assign_ops
64 | 
65 | with tf.name_scope('input'):
66 |     input_data = tf.placeholder(dtype=tf.float32, shape=(None, input_size, input_size, 3), name='input_data')
67 | model = YOLOV3(input_data, trainable=False)
68 | load_ops = load_weights(tf.global_variables(), darknet_weights)
69 | 
70 | saver = tf.train.Saver(tf.global_variables())
71 | 
72 | with tf.Session() as sess:
73 |     sess.run(load_ops)
74 |     save_path = saver.save(sess, save_path=ckpt_file)
75 |     print('Model saved in path: {}'.format(save_path))
76 | 
--------------------------------------------------------------------------------
/from_darknet_weights_to_pb.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from core.yolov3 import YOLOV3
3 | from from_darknet_weights_to_ckpt import load_weights
4 | 
5 | input_size = 416
6 | darknet_weights = ''
7 | pb_file = './yolov3.pb'
8 | output_node_names = ["input/input_data", "pred_sbbox/concat_2", "pred_mbbox/concat_2", "pred_lbbox/concat_2"]
9 | 
10 | with tf.name_scope('input'):
11 |     input_data = tf.placeholder(dtype=tf.float32, shape=(None, input_size, input_size, 3), name='input_data')
12 | model = YOLOV3(input_data, trainable=False)
13 | load_ops = load_weights(tf.global_variables(), darknet_weights)
14 | 
15 | with tf.Session() as sess:
16 |     sess.run(load_ops)
17 |     output_graph_def = tf.graph_util.convert_variables_to_constants(
18 |         sess,
19 |         tf.get_default_graph().as_graph_def(),
20 |         output_node_names=output_node_names
21 |     )
22 | 
23 |     with tf.gfile.GFile(pb_file, "wb") as f:
24 |         f.write(output_graph_def.SerializeToString())
25 | 
26 |     print("{} ops written to {}.".format(len(output_graph_def.node), pb_file))
27 | 
--------------------------------------------------------------------------------
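A sanity check worth running after either conversion script above (an illustrative sketch, not a file in this repository; it assumes the YOLOV3 graph has already been built so that tf.global_variables() is populated): the float count stored after the darknet header should exactly match the scalar count of the graph's variables, otherwise load_weights walked the variable list in an order that does not match the file.

    import numpy as np
    import tensorflow as tf

    def darknet_float_count(weights_file):
        # floats stored after the 5-int32 darknet header
        with open(weights_file, "rb") as fp:
            np.fromfile(fp, dtype=np.int32, count=5)
            return np.fromfile(fp, dtype=np.float32).shape[0]

    def graph_param_count():
        # total scalars across all variables created by YOLOV3(...)
        return sum(int(np.prod(v.shape.as_list())) for v in tf.global_variables())

    # after building the model:
    # assert darknet_float_count(darknet_weights) == graph_param_count()

The same comparison can be eyeballed from the 'weights_num' and 'ptr' values that load_weights prints.
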
/image_demo.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python
2 | # coding=utf-8
3 | #================================================================
4 | #   Copyright (C) 2019 * Ltd. All rights reserved.
5 | #
6 | #   Editor      : VIM
7 | #   File name   : image_demo.py
8 | #   Author      : YunYang1994
9 | #   Created date: 2019-01-20 16:06:06
10 | #   Description :
11 | #
12 | #================================================================
13 | 
14 | import cv2
15 | import numpy as np
16 | import core.utils as utils
17 | import tensorflow as tf
18 | from PIL import Image
19 | 
20 | return_elements = ["input/input_data:0", "pred_sbbox/concat_2:0", "pred_mbbox/concat_2:0", "pred_lbbox/concat_2:0"]
21 | pb_file = "./yolov3_coco.pb"
22 | image_path = "./docs/images/road.jpeg"
23 | num_classes = 80
24 | input_size = 416
25 | graph = tf.Graph()
26 | 
27 | original_image = cv2.imread(image_path)
28 | original_image = cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB)
29 | original_image_size = original_image.shape[:2]
30 | image_data = utils.image_preporcess(np.copy(original_image), [input_size, input_size])
31 | image_data = image_data[np.newaxis, ...]
32 | 
33 | return_tensors = utils.read_pb_return_tensors(graph, pb_file, return_elements)
34 | 
35 | 
36 | with tf.Session(graph=graph) as sess:
37 |     pred_sbbox, pred_mbbox, pred_lbbox = sess.run(
38 |         [return_tensors[1], return_tensors[2], return_tensors[3]],
39 |         feed_dict={ return_tensors[0]: image_data})
40 | 
41 | pred_bbox = np.concatenate([np.reshape(pred_sbbox, (-1, 5 + num_classes)),
42 |                             np.reshape(pred_mbbox, (-1, 5 + num_classes)),
43 |                             np.reshape(pred_lbbox, (-1, 5 + num_classes))], axis=0)
44 | 
45 | bboxes = utils.postprocess_boxes(pred_bbox, original_image_size, input_size, 0.3)
46 | bboxes = utils.nms(bboxes, 0.45, method='nms')
47 | image = utils.draw_bbox(original_image, bboxes)
48 | image = Image.fromarray(image)
49 | image.show()
50 | 
51 | 
52 | 
53 | 
54 | 
--------------------------------------------------------------------------------
/mAP/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YunYang1994/tensorflow-yolov3/03cb272af2e26d598c553f3a2d38024fc6f67a0b/mAP/__init__.py
--------------------------------------------------------------------------------
/mAP/extra/README.md:
--------------------------------------------------------------------------------
1 | # Extra
2 | 
3 | ## Ground-Truth:
4 | - ### convert `xml` to our format:
5 | 
6 |     1) Insert ground-truth xml files into **ground-truth/**
7 |     2) Run the python script: `python convert_gt_xml.py`
8 | 
9 | - ### convert YOLO to our format:
10 | 
11 |     1) Add class list to the file `class_list.txt`
12 |     2) Insert ground-truth files into **ground-truth/**
13 |     3) Insert images into **images/**
14 |     4) Run the python script: `python convert_gt_yolo.py`
15 | 
16 | - ### convert keras-yolo3 to our format:
17 | 
18 |     1) Add or update the class list to the file `class_list.txt`
19 |     2) Use the parameter `--gt` to set the **ground-truth** source.
20 |     3) Run the python script: `python3 convert_keras-yolo3.py --gt <annotation_file>`
21 |         1) Supports only python 3.
22 |         2) This code can handle recursive annotation structure. Just use the `-r` parameter.
23 |         3) The converted annotation is placed by default in a new from_kerasyolo3 folder. You can change that with the parameter `-o`.
24 |         4) The format is defined according to github.com/qqwweee/keras-yolo3
25 | 
26 | ## Predicted:
27 | - ### convert darkflow `json` to our format:
28 | 
29 |     1) Insert result json files into **predicted/**
30 |     2) Run the python script: `python convert_pred_darkflow_json.py`
31 | 
32 | - ### convert YOLO to our format:
33 | 
34 |     After running darknet on a list of images, e.g.: `darknet.exe detector test data/voc.data yolo-voc.cfg yolo-voc.weights -dont_show -ext_output < data/test.txt > result.txt`
35 | 
36 |     1) Copy the file `result.txt` to the folder `extra/`
37 |     2) Run the python script: `python convert_pred_yolo.py`
38 | 
39 | - ### convert keras-yolo3 to our format:
40 | 
41 |     1) Add or update the class list to the file `class_list.txt`
42 |     2) Use the parameter `--predicted` to set the **prediction** source.
43 |     3) Run the python script: `python3 convert_keras-yolo3.py --pred <annotation_file>`
44 |         1) Supports only python 3.
45 |         2) This code can handle recursive annotation structure. Just use the `-r` parameter.
46 |         3) The converted annotation is placed by default in a new from_kerasyolo3 folder. You can change that with the parameter `-o`.
47 |         4) The format is defined according to github.com/gustavovaliati/keras-yolo3
48 | 
49 | ## Remove specific char delimiter from files
50 | 
51 | E.g. remove `;` from:
52 | 
53 | `<class_name>;<left>;<top>;<right>;<bottom>`
54 | 
55 | to:
56 | 
57 | `<class_name> <left> <top> <right> <bottom>`
58 | 
59 | In case you have the `--ground-truth` or `--predicted` files in the right format but with a specific char being used as a delimiter (e.g. `";"`), you can remove it by running:
60 | 
61 | `python remove_delimiter_char.py --char ";" --ground-truth`
62 | 
63 | ## Find the files that contain a specific class of objects
64 | 
65 | 1) Run the `find_class.py` script and specify the **class** as argument, e.g.
66 | `python find_class.py chair`
67 | 
68 | ## Remove all the instances of a specific class of objects
69 | 
70 | 1) Run the `remove_class.py` script and specify the **class** as argument, e.g.
71 | `python remove_class.py chair`
72 | 
73 | ## Rename a specific class of objects
74 | 
75 | 1) Run the `rename_class.py` script and specify the `--current-class-name` and `--new-class-name` as arguments, e.g.
76 | 
77 | `python rename_class.py --current-class-name Picture Frame --new-class-name PictureFrame`
78 | 
79 | ## Rename all classes by replacing spaces with delimiters
80 | Use this option instead of the one above when you have a lot of classes with spaces.
81 | It's useful when renaming classes with spaces one by one becomes tedious (because you have a lot of them).
82 | 
83 | 1) Add class list to the file `class_list.txt` (the script will search this file for class names with spaces)
84 | 2) Run the `remove_space.py` script and specify the `--delimiter` (default: "-") and `--yes` if you want to force confirmation on all yes/no queries, e.g.
85 | 
86 | `python remove_space.py --delimiter "-" --yes`
87 | 
88 | ## Intersect ground-truth and predicted files
89 | This script ensures the same number of files in the ground-truth and predicted folders.
90 | When you encounter a file-not-found error, it's usually because you have
91 | mismatched numbers of ground-truth and predicted files.
92 | You can use this script to move ground-truth and predicted files that are
93 | not in the intersection into a backup folder (backup_no_matches_found).
94 | This will retain only files that have the same name in both folders.
95 | 
96 | 1) Prepare `.txt` files in your `ground-truth` and `predicted` folders.
97 | 2) Run the `intersect-gt-and-pred.py` script to move non-intersected files into a backup folder (default: `backup_no_matches_found`). 98 | 99 | `python intersect-gt-and-pred.py` 100 | -------------------------------------------------------------------------------- /mAP/extra/class_list.txt: -------------------------------------------------------------------------------- 1 | bed 2 | person 3 | pictureframe 4 | shirt 5 | lamp 6 | nightstand 7 | clock 8 | heater 9 | windowblind 10 | pillow 11 | robot 12 | cabinetry 13 | door 14 | doorhandle 15 | shelf 16 | pottedplant 17 | chair 18 | diningtable 19 | backpack 20 | whiteboard 21 | cup 22 | tvmonitor 23 | pen 24 | pencil 25 | wardrobe 26 | apple 27 | orange 28 | countertop 29 | tap 30 | banana 31 | bicyclehelmet 32 | book 33 | bookcase 34 | refrigerator 35 | wastecontainer 36 | tincan 37 | handbag 38 | sofa 39 | glasses 40 | vase 41 | coffeetable 42 | bowl 43 | remote 44 | candle 45 | bottle 46 | sink 47 | envelope 48 | doll 49 | -------------------------------------------------------------------------------- /mAP/extra/convert_gt_xml.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import glob 4 | import xml.etree.ElementTree as ET 5 | 6 | 7 | # change directory to the one with the files to be changed 8 | path_to_folder = '../ground-truth' 9 | #print(path_to_folder) 10 | os.chdir(path_to_folder) 11 | 12 | # old files (xml format) will be moved to a "backup" folder 13 | ## create the backup dir if it doesn't exist already 14 | if not os.path.exists("backup"): 15 | os.makedirs("backup") 16 | 17 | # create VOC format files 18 | xml_list = glob.glob('*.xml') 19 | if len(xml_list) == 0: 20 | print("Error: no .xml files found in ground-truth") 21 | sys.exit() 22 | for tmp_file in xml_list: 23 | #print(tmp_file) 24 | # 1. create new file (VOC format) 25 | with open(tmp_file.replace(".xml", ".txt"), "a") as new_f: 26 | root = ET.parse(tmp_file).getroot() 27 | for obj in root.findall('object'): 28 | obj_name = obj.find('name').text 29 | bndbox = obj.find('bndbox') 30 | left = bndbox.find('xmin').text 31 | top = bndbox.find('ymin').text 32 | right = bndbox.find('xmax').text 33 | bottom = bndbox.find('ymax').text 34 | new_f.write(obj_name + " " + left + " " + top + " " + right + " " + bottom + '\n') 35 | # 2. 
move old file (xml format) to backup 36 | os.rename(tmp_file, "backup/" + tmp_file) 37 | print("Conversion completed!") 38 | -------------------------------------------------------------------------------- /mAP/extra/convert_gt_yolo.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import glob 4 | import cv2 5 | 6 | 7 | def convert_yolo_coordinates_to_voc(x_c_n, y_c_n, width_n, height_n, img_width, img_height): 8 | ## remove normalization given the size of the image 9 | x_c = float(x_c_n) * img_width 10 | y_c = float(y_c_n) * img_height 11 | width = float(width_n) * img_width 12 | height = float(height_n) * img_height 13 | ## compute half width and half height 14 | half_width = width / 2 15 | half_height = height / 2 16 | ## compute left, top, right, bottom 17 | ## in the official VOC challenge the top-left pixel in the image has coordinates (1;1) 18 | left = int(x_c - half_width) + 1 19 | top = int(y_c - half_height) + 1 20 | right = int(x_c + half_width) + 1 21 | bottom = int(y_c + half_height) + 1 22 | return left, top, right, bottom 23 | 24 | # read the class_list.txt to a list 25 | with open("class_list.txt") as f: 26 | obj_list = f.readlines() 27 | ## remove whitespace characters like `\n` at the end of each line 28 | obj_list = [x.strip() for x in obj_list] 29 | ## e.g. first object in the list 30 | #print(obj_list[0]) 31 | 32 | # change directory to the one with the files to be changed 33 | path_to_folder = '../ground-truth' 34 | #print(path_to_folder) 35 | os.chdir(path_to_folder) 36 | 37 | # old files (YOLO format) will be moved to a new folder (backup/) 38 | ## create the backup dir if it doesn't exist already 39 | if not os.path.exists("backup"): 40 | os.makedirs("backup") 41 | 42 | # create VOC format files 43 | txt_list = glob.glob('*.txt') 44 | if len(txt_list) == 0: 45 | print("Error: no .txt files found in ground-truth") 46 | sys.exit() 47 | for tmp_file in txt_list: 48 | #print(tmp_file) 49 | # 1. check that there is an image with that name 50 | ## get name before ".txt" 51 | image_name = tmp_file.split(".txt",1)[0] 52 | #print(image_name) 53 | ## check if image exists 54 | for fname in os.listdir('../images'): 55 | if fname.startswith(image_name): 56 | ## image found 57 | #print(fname) 58 | img = cv2.imread('../images/' + fname) 59 | ## get image width and height 60 | img_height, img_width = img.shape[:2] 61 | break 62 | else: 63 | ## image not found 64 | print("Error: image not found, corresponding to " + tmp_file) 65 | sys.exit() 66 | # 2. open txt file lines to a list 67 | with open(tmp_file) as f: 68 | content = f.readlines() 69 | ## remove whitespace characters like `\n` at the end of each line 70 | content = [x.strip() for x in content] 71 | # 3. move old file (YOLO format) to backup 72 | os.rename(tmp_file, "backup/" + tmp_file) 73 | # 4. create new file (VOC format) 74 | with open(tmp_file, "a") as new_f: 75 | for line in content: 76 | ## split a line by spaces. 
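            ## (added worked example) for a 200x100 image, the YOLO line
            ## "0 0.5 0.5 0.2 0.4" denormalizes to x_c=100, y_c=50, width=40,
            ## height=40, so convert_yolo_coordinates_to_voc above returns
            ## (left, top, right, bottom) = (81, 31, 121, 71); with the sample
            ## class_list.txt shown earlier, obj_id 0 maps to "bed", giving the
            ## output line "bed 81 31 121 71".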
77 | ## "c" stands for center and "n" stands for normalized 78 | obj_id, x_c_n, y_c_n, width_n, height_n = line.split() 79 | obj_name = obj_list[int(obj_id)] 80 | left, top, right, bottom = convert_yolo_coordinates_to_voc(x_c_n, y_c_n, width_n, height_n, img_width, img_height) 81 | ## add new line to file 82 | #print(obj_name + " " + str(left) + " " + str(top) + " " + str(right) + " " + str(bottom)) 83 | new_f.write(obj_name + " " + str(left) + " " + str(top) + " " + str(right) + " " + str(bottom) + '\n') 84 | print("Conversion completed!") 85 | -------------------------------------------------------------------------------- /mAP/extra/convert_keras-yolo3.py: -------------------------------------------------------------------------------- 1 | ''' 2 | ABOUT THIS SCRIPT: 3 | Converts ground-truth from the annotation files 4 | according to the https://github.com/qqwweee/keras-yolo3 5 | or https://github.com/gustavovaliati/keras-yolo3 format. 6 | 7 | And converts the predicitons from the annotation files 8 | according to the https://github.com/gustavovaliati/keras-yolo3 format. 9 | ''' 10 | 11 | import argparse 12 | import datetime 13 | import os 14 | 15 | ''' 16 | Each time this script runs, it saves the output in a different path 17 | controlled by the following folder suffix: annotation_version. 18 | ''' 19 | annotation_version = datetime.datetime.now().strftime('%Y%m%d%H%M%S') 20 | 21 | ap = argparse.ArgumentParser() 22 | 23 | ap.add_argument("-o", "--output_path", 24 | required=False, 25 | default='from_kerasyolo3/version_{}'.format(annotation_version), 26 | type=str, 27 | help="The dataset root path location.") 28 | ap.add_argument("-r", "--gen_recursive", 29 | required=False, 30 | default=False, 31 | action="store_true", 32 | help="Define if the output txt files will be placed in a \ 33 | recursive folder tree or to direct txt files.") 34 | group = ap.add_mutually_exclusive_group(required=True) 35 | group.add_argument('--gt', 36 | type=str, 37 | default=None, 38 | help="The annotation file that refers to ground-truth in (keras-yolo3 format)") 39 | group.add_argument('--pred', 40 | type=str, 41 | default=None, 42 | help="The annotation file that refers to predictions in (keras-yolo3 format)") 43 | 44 | ARGS = ap.parse_args() 45 | 46 | with open('class_list.txt', 'r') as class_file: 47 | class_map = class_file.readlines() 48 | print(class_map) 49 | annotation_file = ARGS.gt if ARGS.gt else ARGS.pred 50 | 51 | os.makedirs(ARGS.output_path, exist_ok=True) 52 | 53 | with open(annotation_file, 'r') as annot_f: 54 | for annot in annot_f: 55 | annot = annot.split(' ') 56 | img_path = annot[0].strip() 57 | if ARGS.gen_recursive: 58 | annotation_dir_name = os.path.dirname(img_path) 59 | # remove the root path to enable to path.join. 60 | if annotation_dir_name.startswith('/'): 61 | annotation_dir_name = annotation_dir_name.replace('/', '', 1) 62 | destination_dir = os.path.join(ARGS.output_path, annotation_dir_name) 63 | os.makedirs(destination_dir, exist_ok=True) 64 | # replace .jpg with your image format. 
65 |             file_name = os.path.basename(img_path).replace('.jpg', '.txt')
66 |             output_file_path = os.path.join(destination_dir, file_name)
67 |         else:
68 |             file_name = img_path.replace('.jpg', '.txt').replace('/', '__')
69 |             output_file_path = os.path.join(ARGS.output_path, file_name)
70 | 
71 | 
72 |         with open(output_file_path, 'w') as out_f:
73 |             for bbox in annot[1:]:
74 |                 if ARGS.gt:
75 |                     # Here we are dealing with ground-truth annotations
76 |                     # <x_min>,<y_min>,<x_max>,<y_max>,<class_id>
77 |                     # todo: handle difficulty
78 |                     x_min, y_min, x_max, y_max, class_id = list(map(float, bbox.split(',')))
79 |                     out_box = '{} {} {} {} {}'.format(
80 |                         class_map[int(class_id)].strip(), x_min, y_min, x_max, y_max)
81 |                 else:
82 |                     # Here we are dealing with prediction annotations
83 |                     # <x_min>,<y_min>,<x_max>,<y_max>,<class_id>,<score>
84 |                     x_min, y_min, x_max, y_max, class_id, score = list(map(float, bbox.split(',')))
85 |                     out_box = '{} {} {} {} {} {}'.format(
86 |                         class_map[int(class_id)].strip(), score, x_min, y_min, x_max, y_max)
87 | 
88 |                 out_f.write(out_box + "\n")
89 | 
--------------------------------------------------------------------------------
/mAP/extra/convert_pred_darkflow_json.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | import glob
4 | import json
5 | 
6 | 
7 | # change directory to the one with the files to be changed
8 | path_to_folder = '../predicted'
9 | #print(path_to_folder)
10 | os.chdir(path_to_folder)
11 | 
12 | # old files (darkflow json format) will be moved to a "backup" folder
13 | ## create the backup dir if it doesn't exist already
14 | if not os.path.exists("backup"):
15 |     os.makedirs("backup")
16 | 
17 | # create VOC format files
18 | json_list = glob.glob('*.json')
19 | if len(json_list) == 0:
20 |     print("Error: no .json files found in predicted")
21 |     sys.exit()
22 | for tmp_file in json_list:
23 |     #print(tmp_file)
24 |     # 1. create new file (VOC format)
25 |     with open(tmp_file.replace(".json", ".txt"), "a") as new_f:
26 |         data = json.load(open(tmp_file))
27 |         for obj in data:
28 |             obj_name = obj['label']
29 |             conf = obj['confidence']
30 |             left = obj['topleft']['x']
31 |             top = obj['topleft']['y']
32 |             right = obj['bottomright']['x']
33 |             bottom = obj['bottomright']['y']
34 |             new_f.write(obj_name + " " + str(conf) + " " + str(left) + " " + str(top) + " " + str(right) + " " + str(bottom) + '\n')
35 |     # 2.
move old file (darkflow format) to backup
36 |     os.rename(tmp_file, "backup/" + tmp_file)
37 | print("Conversion completed!")
38 | 
--------------------------------------------------------------------------------
/mAP/extra/convert_pred_yolo.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | 
4 | IN_FILE = 'result.txt'
5 | OUTPUT_DIR = os.path.join('..', 'predicted')
6 | 
7 | SEPARATOR_KEY = 'Enter Image Path:'
8 | IMG_FORMAT = '.jpg'
9 | 
10 | outfile = None
11 | with open(IN_FILE) as infile:
12 |     for line in infile:
13 |         if SEPARATOR_KEY in line:
14 |             if IMG_FORMAT not in line:
15 |                 break
16 |             # get text between two substrings (SEPARATOR_KEY and IMG_FORMAT)
17 |             image_path = re.search(SEPARATOR_KEY + '(.*)' + IMG_FORMAT, line)
18 |             # get the image name (the final component of an image_path)
19 |             # e.g., from 'data/horses_1' to 'horses_1'
20 |             image_name = os.path.basename(image_path.group(1))
21 |             # close the previous file
22 |             if outfile is not None:
23 |                 outfile.close()
24 |             # open a new file
25 |             outfile = open(os.path.join(OUTPUT_DIR, image_name + '.txt'), 'w')
26 |         elif outfile is not None:
27 |             # split line on the first occurrence of the characters ':' and '%'
28 |             class_name, info = line.split(':', 1)
29 |             confidence, bbox = info.split('%', 1)
30 |             # get all the coordinates of the bounding box
31 |             bbox = bbox.replace(')','') # remove the character ')'
32 |             # go through each of the parts of the string and check if it is a digit
33 |             left, top, width, height = [int(s) for s in bbox.split() if s.lstrip('-').isdigit()]
34 |             right = left + width
35 |             bottom = top + height
36 |             outfile.write("{} {} {} {} {} {}\n".format(class_name, float(confidence)/100, left, top, right, bottom))
--------------------------------------------------------------------------------
/mAP/extra/find_class.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | import glob
4 | 
5 | if len(sys.argv) != 2:
6 |     print("Error: wrong format.\nUsage: python find_class.py [class_name]")
7 |     sys.exit(0)
8 | 
9 | searching_class_name = sys.argv[1]
10 | 
11 | def find_class(class_name):
12 |     file_list = glob.glob('*.txt')
13 |     file_list.sort()
14 |     # iterate through the text files
15 |     file_found = False
16 |     for txt_file in file_list:
17 |         # open txt file lines to a list
18 |         with open(txt_file) as f:
19 |             content = f.readlines()
20 |         # remove whitespace characters like `\n` at the end of each line
21 |         content = [x.strip() for x in content]
22 |         # go through each line of each file
23 |         for line in content:
24 |             class_name = line.split()[0]
25 |             if class_name == searching_class_name:
26 |                 print("  " + txt_file)
27 |                 file_found = True
28 |                 break
29 |     if not file_found:
30 |         print("  No file found with that class")
31 | 
32 | print("Ground-Truth folder:")
33 | os.chdir("../ground-truth")
34 | find_class(searching_class_name)
35 | print("\nPredicted folder:")
36 | os.chdir("../predicted")
37 | find_class(searching_class_name)
--------------------------------------------------------------------------------
/mAP/extra/intersect-gt-and-pred.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | import glob
4 | 
5 | ## This script ensures the same number of files in the ground-truth and predicted folders.
6 | ## When you encounter a file-not-found error, it's usually because you have
7 | ## mismatched numbers of ground-truth and predicted files.
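## (added note) mAP/main.py raises exactly this error ("Error. File not
## found: predicted/...") and its message points back to this script, so
## run it once before main.py whenever the two folders may have diverged.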
8 | ## You can use this script to move ground-truth and predicted files that are
9 | ## not in the intersection into a backup folder (backup_no_matches_found).
10 | ## This will retain only files that have the same name in both folders.
11 | 
12 | # change directory to the one with the files to be changed
13 | path_to_gt = '../ground-truth'
14 | path_to_pred = '../predicted'
15 | backup_folder = 'backup_no_matches_found' # must end without slash
16 | 
17 | os.chdir(path_to_gt)
18 | gt_files = glob.glob('*.txt')
19 | if len(gt_files) == 0:
20 |     print("Error: no .txt files found in", path_to_gt)
21 |     sys.exit()
22 | os.chdir(path_to_pred)
23 | pred_files = glob.glob('*.txt')
24 | if len(pred_files) == 0:
25 |     print("Error: no .txt files found in", path_to_pred)
26 |     sys.exit()
27 | 
28 | gt_files = set(gt_files)
29 | pred_files = set(pred_files)
30 | print('total ground-truth files:', len(gt_files))
31 | print('total predicted files:', len(pred_files))
32 | print()
33 | 
34 | gt_backup = gt_files - pred_files
35 | pred_backup = pred_files - gt_files
36 | 
37 | def backup(src_folder, backup_files, backup_folder):
38 |     # non-intersection files (txt format) will be moved to a backup folder
39 |     if not backup_files:
40 |         print('No backup required for', src_folder)
41 |         return
42 |     os.chdir(src_folder)
43 |     ## create the backup dir if it doesn't exist already
44 |     if not os.path.exists(backup_folder):
45 |         os.makedirs(backup_folder)
46 |     for file in backup_files:
47 |         os.rename(file, backup_folder + '/' + file)
48 | 
49 | backup(path_to_gt, gt_backup, backup_folder)
50 | backup(path_to_pred, pred_backup, backup_folder)
51 | if gt_backup:
52 |     print('total ground-truth backup files:', len(gt_backup))
53 | if pred_backup:
54 |     print('total predicted backup files:', len(pred_backup))
55 | 
56 | intersection = gt_files & pred_files
57 | print('total intersected files:', len(intersection))
58 | print("Intersection completed!")
--------------------------------------------------------------------------------
/mAP/extra/remove_class.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | import glob
4 | 
5 | if len(sys.argv) != 2:
6 |     print("Error: wrong format.\nUsage: python remove_class.py [class_name]")
7 |     sys.exit(0)
8 | 
9 | searching_class_name = sys.argv[1]
10 | 
11 | 
12 | def query_yes_no(question, default="yes"):
13 |     """Ask a yes/no question via raw_input() and return their answer.
14 | 
15 |     "question" is a string that is presented to the user.
16 |     "default" is the presumed answer if the user just hits <Enter>.
17 |     It must be "yes" (the default), "no" or None (meaning
18 |     an answer is required of the user).
19 | 
20 |     The "answer" return value is True for "yes" or False for "no".
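    (added example) query_yes_no("Continue?") prompts with " [Y/n] ", and an
    empty reply returns True, i.e. the "yes" default.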
21 | """ 22 | valid = {"yes": True, "y": True, "ye": True, 23 | "no": False, "n": False} 24 | if default is None: 25 | prompt = " [y/n] " 26 | elif default == "yes": 27 | prompt = " [Y/n] " 28 | elif default == "no": 29 | prompt = " [y/N] " 30 | else: 31 | raise ValueError("invalid default answer: '%s'" % default) 32 | 33 | while True: 34 | sys.stdout.write(question + prompt) 35 | if sys.version_info[0] == 3: 36 | choice = input().lower() # if version 3 of Python 37 | else: 38 | choice = raw_input().lower() 39 | if default is not None and choice == '': 40 | return valid[default] 41 | elif choice in valid: 42 | return valid[choice] 43 | else: 44 | sys.stdout.write("Please respond with 'yes' or 'no' " 45 | "(or 'y' or 'n').\n") 46 | 47 | 48 | def remove_class(class_name): 49 | # get list of txt files 50 | file_list = glob.glob('*.txt') 51 | file_list.sort() 52 | # iterate through the txt files 53 | for txt_file in file_list: 54 | class_found = False 55 | # open txt file lines to a list 56 | with open(txt_file) as f: 57 | content = f.readlines() 58 | # remove whitespace characters like `\n` at the end of each line 59 | content = [x.strip() for x in content] 60 | new_content = [] 61 | # go through each line of eache file 62 | for line in content: 63 | class_name = line.split()[0] 64 | if class_name == searching_class_name: 65 | class_found = True 66 | else: 67 | new_content.append(line) 68 | if class_found: 69 | # rewrite file 70 | with open(txt_file, 'w') as new_f: 71 | for line in new_content: 72 | new_f.write("%s\n" % line) 73 | 74 | if query_yes_no("Are you sure you want to remove the class \"" + searching_class_name + "\"?"): 75 | print(" Ground-Truth folder:") 76 | os.chdir("../ground-truth") 77 | remove_class(searching_class_name) 78 | print(" Done!") 79 | print(" Predicted folder:") 80 | os.chdir("../predicted") 81 | remove_class(searching_class_name) 82 | print(" Done!") 83 | -------------------------------------------------------------------------------- /mAP/extra/remove_delimiter_char.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import sys 4 | import argparse 5 | 6 | 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument('-c', '--char', required=True, type=str, help='specific character to be removed (e.g. 
";").') 9 | # mutually exclusive arguments (can't select both) 10 | group = parser.add_mutually_exclusive_group(required=True) 11 | group.add_argument('-g', '--ground-truth', help="if to remove that char from the ground-truth files.", action="store_true") 12 | group.add_argument('-p', '--predicted', help="if to remove that char from the predicted objects files.", action="store_true") 13 | args = parser.parse_args() 14 | 15 | def file_lines_to_list(path): 16 | # open txt file lines to a list 17 | with open(path) as f: 18 | content = f.readlines() 19 | # remove whitespace characters like `\n` at the end of each line 20 | content = [x.strip() for x in content] 21 | return content 22 | 23 | if len(args.char) != 1: 24 | print("Error: Please select a single char to be removed.") 25 | sys.exit(0) 26 | 27 | if args.predicted: 28 | os.chdir("../predicted/") 29 | else: 30 | os.chdir("../ground-truth/") 31 | 32 | ## create the backup dir if it doesn't exist already 33 | backup_path = "backup" 34 | if not os.path.exists(backup_path): 35 | os.makedirs(backup_path) 36 | 37 | # get a list with the predicted files 38 | files_list = glob.glob('*.txt') 39 | files_list.sort() 40 | 41 | for txt_file in files_list: 42 | lines = file_lines_to_list(txt_file) 43 | is_char_present = any(args.char in line for line in lines) 44 | if is_char_present: 45 | # move old file to backup 46 | os.rename(txt_file, backup_path + "/" + txt_file) 47 | # create new file 48 | with open(txt_file, "a") as new_f: 49 | for line in lines: 50 | #print(line) 51 | if args.predicted: 52 | class_name, confidence, left, top, right, bottom = line.split(args.char) 53 | # remove any white space if existent in the class name 54 | class_name = class_name.replace(" ", "") 55 | new_f.write(class_name + " " + confidence + " " + left + " " + top + " " + right + " " + bottom + '\n') 56 | else: 57 | # ground-truth has no "confidence" 58 | class_name, left, top, right, bottom = line.split(args.char) 59 | # remove any white space if existent in the class name 60 | class_name = class_name.replace(" ", "") 61 | new_f.write(class_name + " " + left + " " + top + " " + right + " " + bottom + '\n') 62 | print("Conversion completed!") 63 | -------------------------------------------------------------------------------- /mAP/extra/remove_space.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import glob 4 | import argparse 5 | 6 | # this script will load class_list.txt and find class names with spaces 7 | # then replace spaces with delimiters inside ground-truth/ and predicted/ 8 | 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('-d', '--delimiter', type=str, help="delimiter to replace space (default: '-')", default='-') 11 | parser.add_argument('-y', '--yes', action='store_true', help="force yes confirmation on yes/no query (default: False)", default=False) 12 | args = parser.parse_args() 13 | 14 | def query_yes_no(question, default="yes", bypass=False): 15 | """Ask a yes/no question via raw_input() and return their answer. 16 | 17 | "question" is a string that is presented to the user. 18 | "default" is the presumed answer if the user just hits . 19 | It must be "yes" (the default), "no" or None (meaning 20 | an answer is required of the user). 21 | 22 | The "answer" return value is True for "yes" or False for "no". 
23 | """ 24 | valid = {"yes": True, "y": True, "ye": True, 25 | "no": False, "n": False} 26 | if default is None: 27 | prompt = " [y/n] " 28 | elif default == "yes": 29 | prompt = " [Y/n] " 30 | elif default == "no": 31 | prompt = " [y/N] " 32 | else: 33 | raise ValueError("invalid default answer: '%s'" % default) 34 | 35 | while True: 36 | sys.stdout.write(question + prompt) 37 | if bypass: 38 | break 39 | if sys.version_info[0] == 3: 40 | choice = input().lower() # if version 3 of Python 41 | else: 42 | choice = raw_input().lower() 43 | if default is not None and choice == '': 44 | return valid[default] 45 | elif choice in valid: 46 | return valid[choice] 47 | else: 48 | sys.stdout.write("Please respond with 'yes' or 'no' " 49 | "(or 'y' or 'n').\n") 50 | 51 | 52 | def rename_class(current_class_name, new_class_name): 53 | # get list of txt files 54 | file_list = glob.glob('*.txt') 55 | file_list.sort() 56 | # iterate through the txt files 57 | for txt_file in file_list: 58 | class_found = False 59 | # open txt file lines to a list 60 | with open(txt_file) as f: 61 | content = f.readlines() 62 | # remove whitespace characters like `\n` at the end of each line 63 | content = [x.strip() for x in content] 64 | new_content = [] 65 | # go through each line of eache file 66 | for line in content: 67 | #class_name = line.split()[0] 68 | if current_class_name in line: 69 | class_found = True 70 | line = line.replace(current_class_name, new_class_name) 71 | new_content.append(line) 72 | if class_found: 73 | # rewrite file 74 | with open(txt_file, 'w') as new_f: 75 | for line in new_content: 76 | new_f.write("%s\n" % line) 77 | 78 | with open('class_list.txt') as f: 79 | for line in f: 80 | current_class_name = line.rstrip("\n") 81 | new_class_name = line.replace(' ', args.delimiter).rstrip("\n") 82 | if current_class_name == new_class_name: 83 | continue 84 | y_n_message = ("Are you sure you want " 85 | "to rename the class " 86 | "\"" + current_class_name + "\" " 87 | "into \"" + new_class_name + "\"?" 88 | ) 89 | 90 | if query_yes_no(y_n_message, bypass=args.yes): 91 | os.chdir("../ground-truth") 92 | rename_class(current_class_name, new_class_name) 93 | os.chdir("../predicted") 94 | rename_class(current_class_name, new_class_name) 95 | 96 | print('Done!') 97 | -------------------------------------------------------------------------------- /mAP/extra/rename_class.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import glob 4 | import argparse 5 | 6 | parser = argparse.ArgumentParser() 7 | # argparse current class name to a list (since it can contain more than one word, e.g."dining table") 8 | parser.add_argument('-c', '--current-class-name', nargs='+', type=str, help="current class name e.g.:\"dining table\".", required=True) 9 | # new class name (should be a single string without any spaces, e.g. "diningtable") 10 | parser.add_argument('-n', '--new-class-name', type=str, help="new class name.", required=True) 11 | args = parser.parse_args() 12 | 13 | current_class_name = " ".join(args.current_class_name) # join current name to single string 14 | new_class_name = args.new_class_name 15 | 16 | 17 | def query_yes_no(question, default="yes"): 18 | """Ask a yes/no question via raw_input() and return their answer. 19 | 20 | "question" is a string that is presented to the user. 21 | "default" is the presumed answer if the user just hits . 
22 | It must be "yes" (the default), "no" or None (meaning 23 | an answer is required of the user). 24 | 25 | The "answer" return value is True for "yes" or False for "no". 26 | """ 27 | valid = {"yes": True, "y": True, "ye": True, 28 | "no": False, "n": False} 29 | if default is None: 30 | prompt = " [y/n] " 31 | elif default == "yes": 32 | prompt = " [Y/n] " 33 | elif default == "no": 34 | prompt = " [y/N] " 35 | else: 36 | raise ValueError("invalid default answer: '%s'" % default) 37 | 38 | while True: 39 | sys.stdout.write(question + prompt) 40 | if sys.version_info[0] == 3: 41 | choice = input().lower() # if version 3 of Python 42 | else: 43 | choice = raw_input().lower() 44 | if default is not None and choice == '': 45 | return valid[default] 46 | elif choice in valid: 47 | return valid[choice] 48 | else: 49 | sys.stdout.write("Please respond with 'yes' or 'no' " 50 | "(or 'y' or 'n').\n") 51 | 52 | 53 | def rename_class(current_class_name, new_class_name): 54 | # get list of txt files 55 | file_list = glob.glob('*.txt') 56 | file_list.sort() 57 | # iterate through the txt files 58 | for txt_file in file_list: 59 | class_found = False 60 | # open txt file lines to a list 61 | with open(txt_file) as f: 62 | content = f.readlines() 63 | # remove whitespace characters like `\n` at the end of each line 64 | content = [x.strip() for x in content] 65 | new_content = [] 66 | # go through each line of eache file 67 | for line in content: 68 | #class_name = line.split()[0] 69 | if current_class_name in line: 70 | class_found = True 71 | line = line.replace(current_class_name, new_class_name) 72 | new_content.append(line) 73 | if class_found: 74 | # rewrite file 75 | with open(txt_file, 'w') as new_f: 76 | for line in new_content: 77 | new_f.write("%s\n" % line) 78 | 79 | y_n_message = ("Are you sure you want " 80 | "to rename the class " 81 | "\"" + current_class_name + "\" " 82 | "into \"" + new_class_name + "\"?" 83 | ) 84 | 85 | if query_yes_no(y_n_message): 86 | print(" Ground-Truth folder:") 87 | os.chdir("../ground-truth") 88 | rename_class(current_class_name, new_class_name) 89 | print(" Done!") 90 | print(" Predicted folder:") 91 | os.chdir("../predicted") 92 | rename_class(current_class_name, new_class_name) 93 | print(" Done!") 94 | -------------------------------------------------------------------------------- /mAP/extra/result.txt: -------------------------------------------------------------------------------- 1 | Total BFLOPS 65.864 2 | 3 | seen 64 4 | Enter Image Path: data/horses.jpg: Predicted in 42.076185 seconds. 5 | horse: 88% (left_x: 3 top_y: 185 width: 150 height: 167) 6 | horse: 99% (left_x: 5 top_y: 198 width: 307 height: 214) 7 | horse: 96% (left_x: 236 top_y: 180 width: 215 height: 169) 8 | horse: 99% (left_x: 440 top_y: 209 width: 156 height: 142) 9 | Enter Image Path: data/person.jpg: Predicted in 41.767213 seconds. 
10 | dog: 99% (left_x: 58 top_y: 262 width: 147 height: 89) 11 | person: 100% (left_x: 190 top_y: 95 width: 86 height: 284) 12 | horse: 100% (left_x: 394 top_y: 137 width: 215 height: 206) 13 | Enter Image Path: -------------------------------------------------------------------------------- /mAP/main.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import json 3 | import os 4 | import shutil 5 | import operator 6 | import sys 7 | import argparse 8 | 9 | MINOVERLAP = 0.5 # default value (defined in the PASCAL VOC2012 challenge) 10 | 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument('-na', '--no-animation', help="no animation is shown.", action="store_true") 13 | parser.add_argument('-np', '--no-plot', help="no plot is shown.", action="store_true") 14 | parser.add_argument('-q', '--quiet', help="minimalistic console output.", action="store_true") 15 | # argparse receiving list of classes to be ignored 16 | parser.add_argument('-i', '--ignore', nargs='+', type=str, help="ignore a list of classes.") 17 | # argparse receiving list of classes with specific IoU 18 | parser.add_argument('--set-class-iou', nargs='+', type=str, help="set IoU for a specific class.") 19 | args = parser.parse_args() 20 | 21 | # if there are no classes to ignore then replace None by empty list 22 | if args.ignore is None: 23 | args.ignore = [] 24 | 25 | specific_iou_flagged = False 26 | if args.set_class_iou is not None: 27 | specific_iou_flagged = True 28 | 29 | # if there are no images then no animation can be shown 30 | img_path = 'images' 31 | if os.path.exists(img_path): 32 | for dirpath, dirnames, files in os.walk(img_path): 33 | if not files: 34 | # no image files found 35 | args.no_animation = True 36 | else: 37 | args.no_animation = True 38 | 39 | # try to import OpenCV if the user didn't choose the option --no-animation 40 | show_animation = False 41 | if not args.no_animation: 42 | try: 43 | import cv2 44 | show_animation = True 45 | except ImportError: 46 | print("\"opencv-python\" not found, please install to visualize the results.") 47 | args.no_animation = True 48 | 49 | # try to import Matplotlib if the user didn't choose the option --no-plot 50 | draw_plot = False 51 | if not args.no_plot: 52 | try: 53 | import matplotlib.pyplot as plt 54 | draw_plot = True 55 | except ImportError: 56 | print("\"matplotlib\" not found, please install it to get the resulting plots.") 57 | args.no_plot = True 58 | 59 | """ 60 | throw error and exit 61 | """ 62 | def error(msg): 63 | print(msg) 64 | sys.exit(0) 65 | 66 | """ 67 | check if the number is a float between 0.0 and 1.0 68 | """ 69 | def is_float_between_0_and_1(value): 70 | try: 71 | val = float(value) 72 | if val > 0.0 and val < 1.0: 73 | return True 74 | else: 75 | return False 76 | except ValueError: 77 | return False 78 | 79 | """ 80 | Calculate the AP given the recall and precision array 81 | 1st) We compute a version of the measured precision/recall curve with 82 | precision monotonically decreasing 83 | 2nd) We compute the AP as the area under this curve by numerical integration. 
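(added worked example) for rec = [0.5, 1.0] and prec = [1.0, 0.5], the
padded curves are mrec = [0, 0.5, 1.0, 1.0] and, after the monotonic pass,
mpre = [1.0, 1.0, 0.5, 0.0]; the recall changes at i = 1 and i = 2, so
AP = 0.5*1.0 + 0.5*0.5 = 0.75.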
84 | """ 85 | def voc_ap(rec, prec): 86 | """ 87 | --- Official matlab code VOC2012--- 88 | mrec=[0 ; rec ; 1]; 89 | mpre=[0 ; prec ; 0]; 90 | for i=numel(mpre)-1:-1:1 91 | mpre(i)=max(mpre(i),mpre(i+1)); 92 | end 93 | i=find(mrec(2:end)~=mrec(1:end-1))+1; 94 | ap=sum((mrec(i)-mrec(i-1)).*mpre(i)); 95 | """ 96 | rec.insert(0, 0.0) # insert 0.0 at begining of list 97 | rec.append(1.0) # insert 1.0 at end of list 98 | mrec = rec[:] 99 | prec.insert(0, 0.0) # insert 0.0 at begining of list 100 | prec.append(0.0) # insert 0.0 at end of list 101 | mpre = prec[:] 102 | """ 103 | This part makes the precision monotonically decreasing 104 | (goes from the end to the beginning) 105 | matlab: for i=numel(mpre)-1:-1:1 106 | mpre(i)=max(mpre(i),mpre(i+1)); 107 | """ 108 | # matlab indexes start in 1 but python in 0, so I have to do: 109 | # range(start=(len(mpre) - 2), end=0, step=-1) 110 | # also the python function range excludes the end, resulting in: 111 | # range(start=(len(mpre) - 2), end=-1, step=-1) 112 | for i in range(len(mpre)-2, -1, -1): 113 | mpre[i] = max(mpre[i], mpre[i+1]) 114 | """ 115 | This part creates a list of indexes where the recall changes 116 | matlab: i=find(mrec(2:end)~=mrec(1:end-1))+1; 117 | """ 118 | i_list = [] 119 | for i in range(1, len(mrec)): 120 | if mrec[i] != mrec[i-1]: 121 | i_list.append(i) # if it was matlab would be i + 1 122 | """ 123 | The Average Precision (AP) is the area under the curve 124 | (numerical integration) 125 | matlab: ap=sum((mrec(i)-mrec(i-1)).*mpre(i)); 126 | """ 127 | ap = 0.0 128 | for i in i_list: 129 | ap += ((mrec[i]-mrec[i-1])*mpre[i]) 130 | return ap, mrec, mpre 131 | 132 | 133 | """ 134 | Convert the lines of a file to a list 135 | """ 136 | def file_lines_to_list(path): 137 | # open txt file lines to a list 138 | with open(path) as f: 139 | content = f.readlines() 140 | # remove whitespace characters like `\n` at the end of each line 141 | content = [x.strip() for x in content] 142 | return content 143 | 144 | """ 145 | Draws text in image 146 | """ 147 | def draw_text_in_image(img, text, pos, color, line_width): 148 | font = cv2.FONT_HERSHEY_PLAIN 149 | fontScale = 1 150 | lineType = 1 151 | bottomLeftCornerOfText = pos 152 | cv2.putText(img, text, 153 | bottomLeftCornerOfText, 154 | font, 155 | fontScale, 156 | color, 157 | lineType) 158 | text_width, _ = cv2.getTextSize(text, font, fontScale, lineType)[0] 159 | return img, (line_width + text_width) 160 | 161 | """ 162 | Plot - adjust axes 163 | """ 164 | def adjust_axes(r, t, fig, axes): 165 | # get text width for re-scaling 166 | bb = t.get_window_extent(renderer=r) 167 | text_width_inches = bb.width / fig.dpi 168 | # get axis width in inches 169 | current_fig_width = fig.get_figwidth() 170 | new_fig_width = current_fig_width + text_width_inches 171 | propotion = new_fig_width / current_fig_width 172 | # get axis limit 173 | x_lim = axes.get_xlim() 174 | axes.set_xlim([x_lim[0], x_lim[1]*propotion]) 175 | 176 | """ 177 | Draw plot using Matplotlib 178 | """ 179 | def draw_plot_func(dictionary, n_classes, window_title, plot_title, x_label, output_path, to_show, plot_color, true_p_bar): 180 | # sort the dictionary by decreasing value, into a list of tuples 181 | sorted_dic_by_value = sorted(dictionary.items(), key=operator.itemgetter(1)) 182 | # unpacking the list of tuples into two lists 183 | sorted_keys, sorted_values = zip(*sorted_dic_by_value) 184 | # 185 | if true_p_bar != "": 186 | """ 187 | Special case to draw in (green=true predictions) & (red=false predictions) 188 | 
""" 189 | fp_sorted = [] 190 | tp_sorted = [] 191 | for key in sorted_keys: 192 | fp_sorted.append(dictionary[key] - true_p_bar[key]) 193 | tp_sorted.append(true_p_bar[key]) 194 | plt.barh(range(n_classes), fp_sorted, align='center', color='crimson', label='False Predictions') 195 | plt.barh(range(n_classes), tp_sorted, align='center', color='forestgreen', label='True Predictions', left=fp_sorted) 196 | # add legend 197 | plt.legend(loc='lower right') 198 | """ 199 | Write number on side of bar 200 | """ 201 | fig = plt.gcf() # gcf - get current figure 202 | axes = plt.gca() 203 | r = fig.canvas.get_renderer() 204 | for i, val in enumerate(sorted_values): 205 | fp_val = fp_sorted[i] 206 | tp_val = tp_sorted[i] 207 | fp_str_val = " " + str(fp_val) 208 | tp_str_val = fp_str_val + " " + str(tp_val) 209 | # trick to paint multicolor with offset: 210 | # first paint everything and then repaint the first number 211 | t = plt.text(val, i, tp_str_val, color='forestgreen', va='center', fontweight='bold') 212 | plt.text(val, i, fp_str_val, color='crimson', va='center', fontweight='bold') 213 | if i == (len(sorted_values)-1): # largest bar 214 | adjust_axes(r, t, fig, axes) 215 | else: 216 | plt.barh(range(n_classes), sorted_values, color=plot_color) 217 | """ 218 | Write number on side of bar 219 | """ 220 | fig = plt.gcf() # gcf - get current figure 221 | axes = plt.gca() 222 | r = fig.canvas.get_renderer() 223 | for i, val in enumerate(sorted_values): 224 | str_val = " " + str(val) # add a space before 225 | if val < 1.0: 226 | str_val = " {0:.2f}".format(val) 227 | t = plt.text(val, i, str_val, color=plot_color, va='center', fontweight='bold') 228 | # re-set axes to show number inside the figure 229 | if i == (len(sorted_values)-1): # largest bar 230 | adjust_axes(r, t, fig, axes) 231 | # set window title 232 | fig.canvas.set_window_title(window_title) 233 | # write classes in y axis 234 | tick_font_size = 12 235 | plt.yticks(range(n_classes), sorted_keys, fontsize=tick_font_size) 236 | """ 237 | Re-scale height accordingly 238 | """ 239 | init_height = fig.get_figheight() 240 | # comput the matrix height in points and inches 241 | dpi = fig.dpi 242 | height_pt = n_classes * (tick_font_size * 1.4) # 1.4 (some spacing) 243 | height_in = height_pt / dpi 244 | # compute the required figure height 245 | top_margin = 0.15 # in percentage of the figure height 246 | bottom_margin = 0.05 # in percentage of the figure height 247 | figure_height = height_in / (1 - top_margin - bottom_margin) 248 | # set new height 249 | if figure_height > init_height: 250 | fig.set_figheight(figure_height) 251 | 252 | # set plot title 253 | plt.title(plot_title, fontsize=14) 254 | # set axis titles 255 | # plt.xlabel('classes') 256 | plt.xlabel(x_label, fontsize='large') 257 | # adjust size of window 258 | fig.tight_layout() 259 | # save the plot 260 | fig.savefig(output_path) 261 | # show image 262 | if to_show: 263 | plt.show() 264 | # close the plot 265 | plt.close() 266 | 267 | """ 268 | Create a "tmp_files/" and "results/" directory 269 | """ 270 | tmp_files_path = "tmp_files" 271 | if not os.path.exists(tmp_files_path): # if it doesn't exist already 272 | os.makedirs(tmp_files_path) 273 | results_files_path = "results" 274 | if os.path.exists(results_files_path): # if it exist already 275 | # reset the results directory 276 | shutil.rmtree(results_files_path) 277 | 278 | os.makedirs(results_files_path) 279 | if draw_plot: 280 | os.makedirs(results_files_path + "/classes") 281 | if show_animation: 282 | 
285 | """
286 |  Ground-Truth
287 |      Load each of the ground-truth files into a temporary ".json" file.
288 |      Create a list of all the class names present in the ground-truth (gt_classes).
289 | """
290 | # get a list with the ground-truth files
291 | ground_truth_files_list = glob.glob('ground-truth/*.txt')
292 | if len(ground_truth_files_list) == 0:
293 |     error("Error: No ground-truth files found!")
294 | ground_truth_files_list.sort()
295 | # dictionary with counter per class
296 | gt_counter_per_class = {}
297 | 
298 | for txt_file in ground_truth_files_list:
299 |     #print(txt_file)
300 |     file_id = txt_file.split(".txt",1)[0]
301 |     file_id = os.path.basename(os.path.normpath(file_id))
302 |     # check if there is a correspondent predicted objects file
303 |     if not os.path.exists('predicted/' + file_id + ".txt"):
304 |         error_msg = "Error. File not found: predicted/" + file_id + ".txt\n"
305 |         error_msg += "(You can avoid this error message by running extra/intersect-gt-and-pred.py)"
306 |         error(error_msg)
307 |     lines_list = file_lines_to_list(txt_file)
308 |     # create ground-truth dictionary
309 |     bounding_boxes = []
310 |     is_difficult = False
311 |     for line in lines_list:
312 |         try:
313 |             if "difficult" in line:
314 |                 class_name, left, top, right, bottom, _difficult = line.split()
315 |                 is_difficult = True
316 |             else:
317 |                 class_name, left, top, right, bottom = line.split()
318 |         except ValueError:
319 |             error_msg = "Error: File " + txt_file + " in the wrong format.\n"
320 |             error_msg += " Expected: <class_name> <left> <top> <right> <bottom> ['difficult']\n"
321 |             error_msg += " Received: " + line
322 |             error_msg += "\n\nIf you have a <class_name> with spaces between words you should remove them\n"
323 |             error_msg += "by running the script \"remove_space.py\" or \"rename_class.py\" in the \"extra/\" folder."
324 |             error(error_msg)
325 |         # check if class is in the ignore list, if yes skip
326 |         if class_name in args.ignore:
327 |             continue
328 |         bbox = left + " " + top + " " + right + " " + bottom
329 |         if is_difficult:
330 |             bounding_boxes.append({"class_name":class_name, "bbox":bbox, "used":False, "difficult":True})
331 |             is_difficult = False
332 |         else:
333 |             bounding_boxes.append({"class_name":class_name, "bbox":bbox, "used":False})
334 |             # count that object
335 |             if class_name in gt_counter_per_class:
336 |                 gt_counter_per_class[class_name] += 1
337 |             else:
338 |                 # if class didn't exist yet
339 |                 gt_counter_per_class[class_name] = 1
340 |     # dump bounding_boxes into a ".json" file
341 |     with open(tmp_files_path + "/" + file_id + "_ground_truth.json", 'w') as outfile:
342 |         json.dump(bounding_boxes, outfile)
343 | 
344 | gt_classes = list(gt_counter_per_class.keys())
345 | # let's sort the classes alphabetically
346 | gt_classes = sorted(gt_classes)
347 | n_classes = len(gt_classes)
348 | #print(gt_classes)
349 | #print(gt_counter_per_class)
350 | 
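# For reference, a ground-truth .txt line and the JSON entry it becomes look like
# this (illustrative values):
#   "person 190 95 276 379"        -> {"class_name": "person", "bbox": "190 95 276 379", "used": false}
#   "dog 58 262 205 351 difficult" -> {"class_name": "dog", "bbox": "58 262 205 351", "used": false, "difficult": true}
# Boxes marked "difficult" are stored but excluded from the per-class counters,
# so they do not count against recall.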
351 | """
352 |  Check format of the flag --set-class-iou (if used)
353 |     e.g. check if class exists
354 | """
355 | if specific_iou_flagged:
356 |     n_args = len(args.set_class_iou)
357 |     error_msg = \
358 |         '\n --set-class-iou [class_1] [IoU_1] [class_2] [IoU_2] [...]'
359 |     if n_args % 2 != 0:
360 |         error('Error, missing arguments. Flag usage:' + error_msg)
361 |     # [class_1] [IoU_1] [class_2] [IoU_2]
362 |     # specific_iou_classes = ['class_1', 'class_2']
363 |     specific_iou_classes = args.set_class_iou[::2] # even
364 |     # iou_list = ['IoU_1', 'IoU_2']
365 |     iou_list = args.set_class_iou[1::2] # odd
366 |     if len(specific_iou_classes) != len(iou_list):
367 |         error('Error, missing arguments. Flag usage:' + error_msg)
368 |     for tmp_class in specific_iou_classes:
369 |         if tmp_class not in gt_classes:
370 |             error('Error, unknown class \"' + tmp_class + '\". Flag usage:' + error_msg)
371 |     for num in iou_list:
372 |         if not is_float_between_0_and_1(num):
373 |             error('Error, IoU must be between 0.0 and 1.0. Flag usage:' + error_msg)
374 | 
375 | """
376 |  Predicted
377 |      Load each of the predicted files into a temporary ".json" file.
378 | """
379 | # get a list with the predicted files
380 | predicted_files_list = glob.glob('predicted/*.txt')
381 | predicted_files_list.sort()
382 | 
383 | for class_index, class_name in enumerate(gt_classes):
384 |     bounding_boxes = []
385 |     for txt_file in predicted_files_list:
386 |         #print(txt_file)
387 |         # the first time it checks if all the corresponding ground-truth files exist
388 |         file_id = txt_file.split(".txt",1)[0]
389 |         file_id = os.path.basename(os.path.normpath(file_id))
390 |         if class_index == 0:
391 |             if not os.path.exists('ground-truth/' + file_id + ".txt"):
392 |                 error_msg = "Error. File not found: ground-truth/" + file_id + ".txt\n"
393 |                 error_msg += "(You can avoid this error message by running extra/intersect-gt-and-pred.py)"
394 |                 error(error_msg)
395 |         lines = file_lines_to_list(txt_file)
396 |         for line in lines:
397 |             try:
398 |                 tmp_class_name, confidence, left, top, right, bottom = line.split()
399 |             except ValueError:
400 |                 error_msg = "Error: File " + txt_file + " in the wrong format.\n"
401 |                 error_msg += " Expected: <class_name> <confidence> <left> <top> <right> <bottom>\n"
402 |                 error_msg += " Received: " + line
403 |                 error(error_msg)
404 |             if tmp_class_name == class_name:
405 |                 #print("match")
406 |                 bbox = left + " " + top + " " + right + " " + bottom
407 |                 bounding_boxes.append({"confidence":confidence, "file_id":file_id, "bbox":bbox})
408 |                 #print(bounding_boxes)
409 |     # sort predictions by decreasing confidence
410 |     bounding_boxes.sort(key=lambda x:float(x['confidence']), reverse=True)
411 |     with open(tmp_files_path + "/" + class_name + "_predictions.json", 'w') as outfile:
412 |         json.dump(bounding_boxes, outfile)
413 | 
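# A predicted .txt line carries a confidence score before the box coordinates,
# e.g. (illustrative): "horse 0.998 394 137 609 343". All detections of one
# class are pooled across images and sorted by decreasing confidence before
# matching, which is what makes the precision/recall curve below well-defined.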
error("Error. Image not found with id: " + file_id) 445 | elif len(ground_truth_img) > 1: 446 | error("Error. Multiple image with id: " + file_id) 447 | else: # found image 448 | #print(img_path + "/" + ground_truth_img[0]) 449 | # Load image 450 | img = cv2.imread(img_path + "/" + ground_truth_img[0]) 451 | # load image with draws of multiple detections 452 | img_cumulative_path = results_files_path + "/images/" + ground_truth_img[0] 453 | if os.path.isfile(img_cumulative_path): 454 | img_cumulative = cv2.imread(img_cumulative_path) 455 | else: 456 | img_cumulative = img.copy() 457 | # Add bottom border to image 458 | bottom_border = 60 459 | BLACK = [0, 0, 0] 460 | img = cv2.copyMakeBorder(img, 0, bottom_border, 0, 0, cv2.BORDER_CONSTANT, value=BLACK) 461 | # assign prediction to ground truth object if any 462 | # open ground-truth with that file_id 463 | gt_file = tmp_files_path + "/" + file_id + "_ground_truth.json" 464 | ground_truth_data = json.load(open(gt_file)) 465 | ovmax = -1 466 | gt_match = -1 467 | # load prediction bounding-box 468 | bb = [ float(x) for x in prediction["bbox"].split() ] 469 | for obj in ground_truth_data: 470 | # look for a class_name match 471 | if obj["class_name"] == class_name: 472 | bbgt = [ float(x) for x in obj["bbox"].split() ] 473 | bi = [max(bb[0],bbgt[0]), max(bb[1],bbgt[1]), min(bb[2],bbgt[2]), min(bb[3],bbgt[3])] 474 | iw = bi[2] - bi[0] + 1 475 | ih = bi[3] - bi[1] + 1 476 | if iw > 0 and ih > 0: 477 | # compute overlap (IoU) = area of intersection / area of union 478 | ua = (bb[2] - bb[0] + 1) * (bb[3] - bb[1] + 1) + (bbgt[2] - bbgt[0] 479 | + 1) * (bbgt[3] - bbgt[1] + 1) - iw * ih 480 | ov = iw * ih / ua 481 | if ov > ovmax: 482 | ovmax = ov 483 | gt_match = obj 484 | 485 | # assign prediction as true positive/don't care/false positive 486 | if show_animation: 487 | status = "NO MATCH FOUND!" # status is only used in the animation 488 | # set minimum overlap 489 | min_overlap = MINOVERLAP 490 | if specific_iou_flagged: 491 | if class_name in specific_iou_classes: 492 | index = specific_iou_classes.index(class_name) 493 | min_overlap = float(iou_list[index]) 494 | if ovmax >= min_overlap: 495 | if "difficult" not in gt_match: 496 | if not bool(gt_match["used"]): 497 | # true positive 498 | tp[idx] = 1 499 | gt_match["used"] = True 500 | count_true_positives[class_name] += 1 501 | # update the ".json" file 502 | with open(gt_file, 'w') as f: 503 | f.write(json.dumps(ground_truth_data)) 504 | if show_animation: 505 | status = "MATCH!" 506 | else: 507 | # false positive (multiple detection) 508 | fp[idx] = 1 509 | if show_animation: 510 | status = "REPEATED MATCH!" 
485 |             # assign prediction as true positive/don't care/false positive
486 |             if show_animation:
487 |                 status = "NO MATCH FOUND!" # status is only used in the animation
488 |             # set minimum overlap
489 |             min_overlap = MINOVERLAP
490 |             if specific_iou_flagged:
491 |                 if class_name in specific_iou_classes:
492 |                     index = specific_iou_classes.index(class_name)
493 |                     min_overlap = float(iou_list[index])
494 |             if ovmax >= min_overlap:
495 |                 if "difficult" not in gt_match:
496 |                     if not bool(gt_match["used"]):
497 |                         # true positive
498 |                         tp[idx] = 1
499 |                         gt_match["used"] = True
500 |                         count_true_positives[class_name] += 1
501 |                         # update the ".json" file
502 |                         with open(gt_file, 'w') as f:
503 |                             f.write(json.dumps(ground_truth_data))
504 |                         if show_animation:
505 |                             status = "MATCH!"
506 |                     else:
507 |                         # false positive (multiple detection)
508 |                         fp[idx] = 1
509 |                         if show_animation:
510 |                             status = "REPEATED MATCH!"
511 |             else:
512 |                 # false positive
513 |                 fp[idx] = 1
514 |                 if ovmax > 0:
515 |                     status = "INSUFFICIENT OVERLAP"
516 | 
517 |             """
518 |              Draw image to show animation
519 |             """
520 |             if show_animation:
521 |                 height, width = img.shape[:2]
522 |                 # colors (OpenCV works with BGR)
523 |                 white = (255,255,255)
524 |                 light_blue = (255,200,100)
525 |                 green = (0,255,0)
526 |                 light_red = (30,30,255)
527 |                 # 1st line
528 |                 margin = 10
529 |                 v_pos = int(height - margin - (bottom_border / 2))
530 |                 text = "Image: " + ground_truth_img[0] + " "
531 |                 img, line_width = draw_text_in_image(img, text, (margin, v_pos), white, 0)
532 |                 text = "Class [" + str(class_index) + "/" + str(n_classes) + "]: " + class_name + " "
533 |                 img, line_width = draw_text_in_image(img, text, (margin + line_width, v_pos), light_blue, line_width)
534 |                 if ovmax != -1:
535 |                     color = light_red
536 |                     if status == "INSUFFICIENT OVERLAP":
537 |                         text = "IoU: {0:.2f}% ".format(ovmax*100) + "< {0:.2f}% ".format(min_overlap*100)
538 |                     else:
539 |                         text = "IoU: {0:.2f}% ".format(ovmax*100) + ">= {0:.2f}% ".format(min_overlap*100)
540 |                         color = green
541 |                     img, _ = draw_text_in_image(img, text, (margin + line_width, v_pos), color, line_width)
542 |                 # 2nd line
543 |                 v_pos += int(bottom_border / 2)
544 |                 rank_pos = str(idx+1) # rank position (idx starts at 0)
545 |                 text = "Prediction #rank: " + rank_pos + " confidence: {0:.2f}% ".format(float(prediction["confidence"])*100)
546 |                 img, line_width = draw_text_in_image(img, text, (margin, v_pos), white, 0)
547 |                 color = light_red
548 |                 if status == "MATCH!":
549 |                     color = green
550 |                 text = "Result: " + status + " "
551 |                 img, line_width = draw_text_in_image(img, text, (margin + line_width, v_pos), color, line_width)
552 | 
553 |                 font = cv2.FONT_HERSHEY_SIMPLEX
554 |                 if ovmax > 0: # if there is an intersection between the bounding-boxes
555 |                     bbgt = [ int(x) for x in gt_match["bbox"].split() ]
556 |                     cv2.rectangle(img,(bbgt[0],bbgt[1]),(bbgt[2],bbgt[3]),light_blue,2)
557 |                     cv2.rectangle(img_cumulative,(bbgt[0],bbgt[1]),(bbgt[2],bbgt[3]),light_blue,2)
558 |                     cv2.putText(img_cumulative, class_name, (bbgt[0],bbgt[1] - 5), font, 0.6, light_blue, 1, cv2.LINE_AA)
559 |                 bb = [int(i) for i in bb]
560 |                 cv2.rectangle(img,(bb[0],bb[1]),(bb[2],bb[3]),color,2)
561 |                 cv2.rectangle(img_cumulative,(bb[0],bb[1]),(bb[2],bb[3]),color,2)
562 |                 cv2.putText(img_cumulative, class_name, (bb[0],bb[1] - 5), font, 0.6, color, 1, cv2.LINE_AA)
563 |                 # show image
564 |                 cv2.imshow("Animation", img)
565 |                 cv2.waitKey(20) # show for 20 ms
566 |                 # save image to results
567 |                 output_img_path = results_files_path + "/images/single_predictions/" + class_name + "_prediction" + str(idx) + ".jpg"
568 |                 cv2.imwrite(output_img_path, img)
569 |                 # save the image with all the objects drawn to it
570 |                 cv2.imwrite(img_cumulative_path, img_cumulative)
571 | 
572 |         #print(tp)
573 |         # compute precision/recall
574 |         cumsum = 0
575 |         for idx, val in enumerate(fp):
576 |             fp[idx] += cumsum
577 |             cumsum += val
578 |         cumsum = 0
579 |         for idx, val in enumerate(tp):
580 |             tp[idx] += cumsum
581 |             cumsum += val
582 |         #print(tp)
583 |         rec = tp[:]
584 |         for idx, val in enumerate(tp):
585 |             rec[idx] = float(tp[idx]) / gt_counter_per_class[class_name]
586 |         #print(rec)
587 |         prec = tp[:]
588 |         for idx, val in enumerate(tp):
589 |             prec[idx] = float(tp[idx]) / (fp[idx] + tp[idx])
590 |         #print(prec)
591 | 
592 |         ap, mrec, mprec = voc_ap(rec, prec)
593 |         sum_AP += ap
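# Worked example (illustrative): with per-detection flags tp=[1,0,1] and
# fp=[0,1,0] and 2 ground-truth objects of this class, the cumulative sums give
# tp=[1,1,2] and fp=[0,1,1], hence rec=[0.5, 0.5, 1.0] and
# prec=[1.0, 0.5, 0.667] at each confidence cutoff, which voc_ap() then
# integrates into the class AP.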
{0:.2f}%".format(ap*100) 595 | """ 596 | Write to results.txt 597 | """ 598 | rounded_prec = [ '%.2f' % elem for elem in prec ] 599 | rounded_rec = [ '%.2f' % elem for elem in rec ] 600 | results_file.write(text + "\n Precision: " + str(rounded_prec) + "\n Recall :" + str(rounded_rec) + "\n\n") 601 | if not args.quiet: 602 | print(text) 603 | ap_dictionary[class_name] = ap 604 | 605 | """ 606 | Draw plot 607 | """ 608 | if draw_plot: 609 | plt.plot(rec, prec, '-o') 610 | # add a new penultimate point to the list (mrec[-2], 0.0) 611 | # since the last line segment (and respective area) do not affect the AP value 612 | area_under_curve_x = mrec[:-1] + [mrec[-2]] + [mrec[-1]] 613 | area_under_curve_y = mprec[:-1] + [0.0] + [mprec[-1]] 614 | plt.fill_between(area_under_curve_x, 0, area_under_curve_y, alpha=0.2, edgecolor='r') 615 | # set window title 616 | fig = plt.gcf() # gcf - get current figure 617 | fig.canvas.set_window_title('AP ' + class_name) 618 | # set plot title 619 | plt.title('class: ' + text) 620 | #plt.suptitle('This is a somewhat long figure title', fontsize=16) 621 | # set axis titles 622 | plt.xlabel('Recall') 623 | plt.ylabel('Precision') 624 | # optional - set axes 625 | axes = plt.gca() # gca - get current axes 626 | axes.set_xlim([0.0,1.0]) 627 | axes.set_ylim([0.0,1.05]) # .05 to give some extra space 628 | # Alternative option -> wait for button to be pressed 629 | #while not plt.waitforbuttonpress(): pass # wait for key display 630 | # Alternative option -> normal display 631 | #plt.show() 632 | # save the plot 633 | fig.savefig(results_files_path + "/classes/" + class_name + ".png") 634 | plt.cla() # clear axes for next plot 635 | 636 | if show_animation: 637 | cv2.destroyAllWindows() 638 | 639 | results_file.write("\n# mAP of all classes\n") 640 | mAP = sum_AP / n_classes 641 | text = "mAP = {0:.2f}%".format(mAP*100) 642 | results_file.write(text + "\n") 643 | print(text) 644 | 645 | # remove the tmp_files directory 646 | shutil.rmtree(tmp_files_path) 647 | 648 | """ 649 | Count total of Predictions 650 | """ 651 | # iterate through all the files 652 | pred_counter_per_class = {} 653 | #all_classes_predicted_files = set([]) 654 | for txt_file in predicted_files_list: 655 | # get lines to list 656 | lines_list = file_lines_to_list(txt_file) 657 | for line in lines_list: 658 | class_name = line.split()[0] 659 | # check if class is in the ignore list, if yes skip 660 | if class_name in args.ignore: 661 | continue 662 | # count that object 663 | if class_name in pred_counter_per_class: 664 | pred_counter_per_class[class_name] += 1 665 | else: 666 | # if class didn't exist yet 667 | pred_counter_per_class[class_name] = 1 668 | #print(pred_counter_per_class) 669 | pred_classes = list(pred_counter_per_class.keys()) 670 | 671 | 672 | """ 673 | Plot the total number of occurences of each class in the ground-truth 674 | """ 675 | if draw_plot: 676 | window_title = "Ground-Truth Info" 677 | plot_title = "Ground-Truth\n" 678 | plot_title += "(" + str(len(ground_truth_files_list)) + " files and " + str(n_classes) + " classes)" 679 | x_label = "Number of objects per class" 680 | output_path = results_files_path + "/Ground-Truth Info.png" 681 | to_show = False 682 | plot_color = 'forestgreen' 683 | draw_plot_func( 684 | gt_counter_per_class, 685 | n_classes, 686 | window_title, 687 | plot_title, 688 | x_label, 689 | output_path, 690 | to_show, 691 | plot_color, 692 | '', 693 | ) 694 | 695 | """ 696 | Write number of ground-truth objects per class to results.txt 697 | """ 698 | 
698 | with open(results_files_path + "/results.txt", 'a') as results_file:
699 |     results_file.write("\n# Number of ground-truth objects per class\n")
700 |     for class_name in sorted(gt_counter_per_class):
701 |         results_file.write(class_name + ": " + str(gt_counter_per_class[class_name]) + "\n")
702 | 
703 | """
704 |  Finish counting true positives
705 | """
706 | for class_name in pred_classes:
707 |     # if class exists in predictions but not in ground-truth then there are no true positives in that class
708 |     if class_name not in gt_classes:
709 |         count_true_positives[class_name] = 0
710 | #print(count_true_positives)
711 | 
712 | """
713 |  Plot the total number of occurrences of each class in the "predicted" folder
714 | """
715 | if draw_plot:
716 |     window_title = "Predicted Objects Info"
717 |     # Plot title
718 |     plot_title = "Predicted Objects\n"
719 |     plot_title += "(" + str(len(predicted_files_list)) + " files and "
720 |     count_non_zero_values_in_dictionary = sum(int(x) > 0 for x in list(pred_counter_per_class.values()))
721 |     plot_title += str(count_non_zero_values_in_dictionary) + " detected classes)"
722 |     # end Plot title
723 |     x_label = "Number of objects per class"
724 |     output_path = results_files_path + "/Predicted Objects Info.png"
725 |     to_show = False
726 |     plot_color = 'forestgreen'
727 |     true_p_bar = count_true_positives
728 |     draw_plot_func(
729 |         pred_counter_per_class,
730 |         len(pred_counter_per_class),
731 |         window_title,
732 |         plot_title,
733 |         x_label,
734 |         output_path,
735 |         to_show,
736 |         plot_color,
737 |         true_p_bar
738 |         )
739 | 
740 | """
741 |  Write number of predicted objects per class to results.txt
742 | """
743 | with open(results_files_path + "/results.txt", 'a') as results_file:
744 |     results_file.write("\n# Number of predicted objects per class\n")
745 |     for class_name in sorted(pred_classes):
746 |         n_pred = pred_counter_per_class[class_name]
747 |         text = class_name + ": " + str(n_pred)
748 |         text += " (tp:" + str(count_true_positives[class_name])
749 |         text += ", fp:" + str(n_pred - count_true_positives[class_name]) + ")\n"
750 |         results_file.write(text)
751 | 
752 | """
753 |  Draw mAP plot (Show AP's of all classes in decreasing order)
754 | """
755 | if draw_plot:
756 |     window_title = "mAP"
757 |     plot_title = "mAP = {0:.2f}%".format(mAP*100)
758 |     x_label = "Average Precision"
759 |     output_path = results_files_path + "/mAP.png"
760 |     to_show = True
761 |     plot_color = 'royalblue'
762 |     draw_plot_func(
763 |         ap_dictionary,
764 |         n_classes,
765 |         window_title,
766 |         plot_title,
767 |         x_label,
768 |         output_path,
769 |         to_show,
770 |         plot_color,
771 |         ""
772 |         )
773 | 
--------------------------------------------------------------------------------
/scripts/show_bboxes.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python
2 | # coding=utf-8
3 | #================================================================
4 | #   Copyright (C) 2019 * Ltd. All rights reserved.
5 | #
6 | #   Editor      : VIM
7 | #   File name   : show_bboxes.py
8 | #   Author      : YunYang1994
9 | #   Created date: 2019-05-29 01:18:24
10 | #   Description :
11 | #
12 | #================================================================
13 | 
14 | import cv2
15 | import numpy as np
16 | from PIL import Image
17 | 
18 | ID = 0
19 | label_txt = "../data/dataset/traffic_test.txt" # any annotation file in this format works, e.g. voc_test.txt
20 | image_info = open(label_txt).readlines()[ID].split()
21 | 
22 | image_path = image_info[0]
23 | image = cv2.imread(image_path)
24 | for bbox in image_info[1:]:
25 |     bbox = bbox.split(",")
26 |     image = cv2.rectangle(image,(int(float(bbox[0])),
27 |                                  int(float(bbox[1]))),
28 |                                 (int(float(bbox[2])),
29 |                                  int(float(bbox[3]))), (255,0,0), 2)
30 | 
31 | image = Image.fromarray(np.uint8(image))
32 | image.show()
33 | 
--------------------------------------------------------------------------------
/scripts/voc_annotation.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import xml.etree.ElementTree as ET
4 | 
5 | def convert_voc_annotation(data_path, data_type, anno_path, use_difficult_bbox=True):
6 | 
7 |     classes = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus',
8 |                'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
9 |                'motorbike', 'person', 'pottedplant', 'sheep', 'sofa',
10 |                'train', 'tvmonitor']
11 |     img_inds_file = os.path.join(data_path, 'ImageSets', 'Main', data_type + '.txt')
12 |     with open(img_inds_file, 'r') as f:
13 |         txt = f.readlines()
14 |         image_inds = [line.strip() for line in txt]
15 | 
16 |     with open(anno_path, 'a') as f:
17 |         for image_ind in image_inds:
18 |             image_path = os.path.join(data_path, 'JPEGImages', image_ind + '.jpg')
19 |             annotation = image_path
20 |             label_path = os.path.join(data_path, 'Annotations', image_ind + '.xml')
21 |             root = ET.parse(label_path).getroot()
22 |             objects = root.findall('object')
23 |             for obj in objects:
24 |                 difficult = obj.find('difficult').text.strip()
25 |                 if (not use_difficult_bbox) and (int(difficult) == 1):
26 |                     continue
27 |                 bbox = obj.find('bndbox')
28 |                 class_ind = classes.index(obj.find('name').text.lower().strip())
29 |                 xmin = bbox.find('xmin').text.strip()
30 |                 xmax = bbox.find('xmax').text.strip()
31 |                 ymin = bbox.find('ymin').text.strip()
32 |                 ymax = bbox.find('ymax').text.strip()
33 |                 annotation += ' ' + ','.join([xmin, ymin, xmax, ymax, str(class_ind)])
34 |             print(annotation)
35 |             f.write(annotation + "\n")
36 |     return len(image_inds)
37 | 
38 | 
39 | if __name__ == '__main__':
40 |     parser = argparse.ArgumentParser()
41 |     parser.add_argument("--data_path", default="/home/yang/test/VOC/")
42 |     parser.add_argument("--train_annotation", default="./data/dataset/voc_train.txt")
43 |     parser.add_argument("--test_annotation",  default="./data/dataset/voc_test.txt")
44 |     flags = parser.parse_args()
45 | 
46 |     if os.path.exists(flags.train_annotation): os.remove(flags.train_annotation)
47 |     if os.path.exists(flags.test_annotation): os.remove(flags.test_annotation)
48 | 
49 |     num1 = convert_voc_annotation(os.path.join(flags.data_path, 'train/VOCdevkit/VOC2007'), 'trainval', flags.train_annotation, False)
50 |     num2 = convert_voc_annotation(os.path.join(flags.data_path, 'train/VOCdevkit/VOC2012'), 'trainval', flags.train_annotation, False)
51 |     num3 = convert_voc_annotation(os.path.join(flags.data_path, 'test/VOCdevkit/VOC2007'),  'test', flags.test_annotation, False)
52 |     print('=> The number of images for train is: %d\tThe number of images for test is: %d' % (num1 + num2, num3))
53 | 
54 | 
55 | 
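# Each line written by convert_voc_annotation() is one image path followed by
# its boxes as "xmin,ymin,xmax,ymax,class_ind" groups, e.g. (illustrative):
#   /home/yang/test/VOC/train/VOCdevkit/VOC2007/JPEGImages/000005.jpg 263,211,324,339,8
# show_bboxes.py above reads this same format to draw the boxes back onto the image.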
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python
2 | # coding=utf-8
3 | #================================================================
4 | #   Copyright (C) 2019 * Ltd. All rights reserved.
5 | #
6 | #   Editor      : VIM
7 | #   File name   : train.py
8 | #   Author      : YunYang1994
9 | #   Created date: 2019-02-28 17:50:26
10 | #   Description :
11 | #
12 | #================================================================
13 | 
14 | import os
15 | import time
16 | import shutil
17 | import numpy as np
18 | import tensorflow as tf
19 | import core.utils as utils
20 | from tqdm import tqdm
21 | from core.dataset import Dataset
22 | from core.yolov3 import YOLOV3
23 | from core.config import cfg
24 | 
25 | 
26 | class YoloTrain(object):
27 |     def __init__(self):
28 |         self.anchor_per_scale    = cfg.YOLO.ANCHOR_PER_SCALE
29 |         self.classes             = utils.read_class_names(cfg.YOLO.CLASSES)
30 |         self.num_classes         = len(self.classes)
31 |         self.learn_rate_init     = cfg.TRAIN.LEARN_RATE_INIT
32 |         self.learn_rate_end      = cfg.TRAIN.LEARN_RATE_END
33 |         self.first_stage_epochs  = cfg.TRAIN.FISRT_STAGE_EPOCHS # sic: the key is spelled this way in core/config.py
34 |         self.second_stage_epochs = cfg.TRAIN.SECOND_STAGE_EPOCHS
35 |         self.warmup_periods      = cfg.TRAIN.WARMUP_EPOCHS
36 |         self.initial_weight      = cfg.TRAIN.INITIAL_WEIGHT
37 |         self.time                = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))
38 |         self.moving_ave_decay    = cfg.YOLO.MOVING_AVE_DECAY
39 |         self.max_bbox_per_scale  = 150
40 |         self.train_logdir        = "./data/log/train"
41 |         self.trainset            = Dataset('train')
42 |         self.testset             = Dataset('test')
43 |         self.steps_per_period    = len(self.trainset)
44 |         self.sess                = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
45 | 
46 |         with tf.name_scope('define_input'):
47 |             self.input_data   = tf.placeholder(dtype=tf.float32, name='input_data')
48 |             self.label_sbbox  = tf.placeholder(dtype=tf.float32, name='label_sbbox')
49 |             self.label_mbbox  = tf.placeholder(dtype=tf.float32, name='label_mbbox')
50 |             self.label_lbbox  = tf.placeholder(dtype=tf.float32, name='label_lbbox')
51 |             self.true_sbboxes = tf.placeholder(dtype=tf.float32, name='sbboxes')
52 |             self.true_mbboxes = tf.placeholder(dtype=tf.float32, name='mbboxes')
53 |             self.true_lbboxes = tf.placeholder(dtype=tf.float32, name='lbboxes')
54 |             self.trainable    = tf.placeholder(dtype=tf.bool, name='training')
55 | 
56 |         with tf.name_scope("define_loss"):
57 |             self.model = YOLOV3(self.input_data, self.trainable)
58 |             self.net_var = tf.global_variables()
59 |             self.giou_loss, self.conf_loss, self.prob_loss = self.model.compute_loss(
60 |                 self.label_sbbox, self.label_mbbox, self.label_lbbox,
61 |                 self.true_sbboxes, self.true_mbboxes, self.true_lbboxes)
62 |             self.loss = self.giou_loss + self.conf_loss + self.prob_loss
63 | 
64 |         with tf.name_scope('learn_rate'):
65 |             self.global_step = tf.Variable(1.0, dtype=tf.float64, trainable=False, name='global_step')
66 |             warmup_steps = tf.constant(self.warmup_periods * self.steps_per_period,
67 |                                        dtype=tf.float64, name='warmup_steps')
68 |             train_steps = tf.constant((self.first_stage_epochs + self.second_stage_epochs) * self.steps_per_period,
69 |                                        dtype=tf.float64, name='train_steps')
70 |             self.learn_rate = tf.cond(
71 |                 pred=self.global_step < warmup_steps,
72 |                 true_fn=lambda: self.global_step / warmup_steps * self.learn_rate_init,
73 |                 false_fn=lambda: self.learn_rate_end + 0.5 * (self.learn_rate_init - self.learn_rate_end) *
74 |                                     (1 + tf.cos(
75 |                                         (self.global_step - warmup_steps) / (train_steps - warmup_steps) * np.pi))
76 |             )
77 |             global_step_update = tf.assign_add(self.global_step, 1.0)
78 | 
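# The schedule above warms up linearly from ~0 to learn_rate_init over the
# warmup steps, then follows a half-cosine down to learn_rate_end. For example
# (illustrative, assuming learn_rate_init=1e-4 and learn_rate_end=1e-6): halfway
# through the cosine phase cos(pi/2)=0, so the learning rate is
# 1e-6 + 0.5*(1e-4 - 1e-6) ~= 5.05e-5; at the final step cos(pi)=-1 and it
# bottoms out at learn_rate_end.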
79 |         with tf.name_scope("define_weight_decay"):
80 |             moving_ave = tf.train.ExponentialMovingAverage(self.moving_ave_decay).apply(tf.trainable_variables())
81 | 
82 |         with tf.name_scope("define_first_stage_train"):
83 |             self.first_stage_trainable_var_list = []
84 |             for var in tf.trainable_variables():
85 |                 var_name = var.op.name
86 |                 var_name_mess = str(var_name).split('/')
87 |                 if var_name_mess[0] in ['conv_sbbox', 'conv_mbbox', 'conv_lbbox']:
88 |                     self.first_stage_trainable_var_list.append(var)
89 | 
90 |             first_stage_optimizer = tf.train.AdamOptimizer(self.learn_rate).minimize(self.loss,
91 |                                                   var_list=self.first_stage_trainable_var_list)
92 |             with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
93 |                 with tf.control_dependencies([first_stage_optimizer, global_step_update]):
94 |                     with tf.control_dependencies([moving_ave]):
95 |                         self.train_op_with_frozen_variables = tf.no_op()
96 | 
97 |         with tf.name_scope("define_second_stage_train"):
98 |             second_stage_trainable_var_list = tf.trainable_variables()
99 |             second_stage_optimizer = tf.train.AdamOptimizer(self.learn_rate).minimize(self.loss,
100 |                                                   var_list=second_stage_trainable_var_list)
101 | 
102 |             with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
103 |                 with tf.control_dependencies([second_stage_optimizer, global_step_update]):
104 |                     with tf.control_dependencies([moving_ave]):
105 |                         self.train_op_with_all_variables = tf.no_op()
106 | 
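# Two-stage transfer learning: the first train op updates only the three
# detection heads (conv_sbbox/conv_mbbox/conv_lbbox) while the Darknet-53
# backbone stays frozen; the second train op fine-tunes every trainable
# variable. Through the control dependencies, each step also advances
# global_step and refreshes the exponential moving averages.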
107 |         with tf.name_scope('loader_and_saver'):
108 |             self.loader = tf.train.Saver(self.net_var)
109 |             self.saver  = tf.train.Saver(tf.global_variables(), max_to_keep=10)
110 | 
111 |         with tf.name_scope('summary'):
112 |             tf.summary.scalar("learn_rate", self.learn_rate)
113 |             tf.summary.scalar("giou_loss",  self.giou_loss)
114 |             tf.summary.scalar("conf_loss",  self.conf_loss)
115 |             tf.summary.scalar("prob_loss",  self.prob_loss)
116 |             tf.summary.scalar("total_loss", self.loss)
117 | 
118 |             logdir = "./data/log/"
119 |             if os.path.exists(logdir): shutil.rmtree(logdir)
120 |             os.mkdir(logdir)
121 |             self.write_op = tf.summary.merge_all()
122 |             self.summary_writer = tf.summary.FileWriter(logdir, graph=self.sess.graph)
123 | 
124 | 
125 |     def train(self):
126 |         self.sess.run(tf.global_variables_initializer())
127 |         try:
128 |             print('=> Restoring weights from: %s ... ' % self.initial_weight)
129 |             self.loader.restore(self.sess, self.initial_weight)
130 |         except Exception:
131 |             print('=> %s does not exist !!!' % self.initial_weight)
132 |             print('=> Now it starts to train YOLOV3 from scratch ...')
133 |             self.first_stage_epochs = 0
134 | 
135 |         for epoch in range(1, 1+self.first_stage_epochs+self.second_stage_epochs):
136 |             if epoch <= self.first_stage_epochs:
137 |                 train_op = self.train_op_with_frozen_variables
138 |             else:
139 |                 train_op = self.train_op_with_all_variables
140 | 
141 |             pbar = tqdm(self.trainset)
142 |             train_epoch_loss, test_epoch_loss = [], []
143 | 
144 |             for train_data in pbar:
145 |                 _, summary, train_step_loss, global_step_val = self.sess.run(
146 |                     [train_op, self.write_op, self.loss, self.global_step], feed_dict={
147 |                         self.input_data:   train_data[0],
148 |                         self.label_sbbox:  train_data[1],
149 |                         self.label_mbbox:  train_data[2],
150 |                         self.label_lbbox:  train_data[3],
151 |                         self.true_sbboxes: train_data[4],
152 |                         self.true_mbboxes: train_data[5],
153 |                         self.true_lbboxes: train_data[6],
154 |                         self.trainable:    True,
155 |                 })
156 | 
157 |                 train_epoch_loss.append(train_step_loss)
158 |                 self.summary_writer.add_summary(summary, global_step_val)
159 |                 pbar.set_description("train loss: %.2f" % train_step_loss)
160 | 
161 |             for test_data in self.testset:
162 |                 test_step_loss = self.sess.run(self.loss, feed_dict={
163 |                     self.input_data:   test_data[0],
164 |                     self.label_sbbox:  test_data[1],
165 |                     self.label_mbbox:  test_data[2],
166 |                     self.label_lbbox:  test_data[3],
167 |                     self.true_sbboxes: test_data[4],
168 |                     self.true_mbboxes: test_data[5],
169 |                     self.true_lbboxes: test_data[6],
170 |                     self.trainable:    False,
171 |                 })
172 | 
173 |                 test_epoch_loss.append(test_step_loss)
174 | 
175 |             train_epoch_loss, test_epoch_loss = np.mean(train_epoch_loss), np.mean(test_epoch_loss)
176 |             ckpt_file = "./checkpoint/yolov3_test_loss=%.4f.ckpt" % test_epoch_loss
177 |             log_time = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))
178 |             print("=> Epoch: %2d Time: %s Train loss: %.2f Test loss: %.2f Saving %s ..."
179 |                   % (epoch, log_time, train_epoch_loss, test_epoch_loss, ckpt_file))
180 |             self.saver.save(self.sess, ckpt_file, global_step=epoch)
181 | 
182 | 
183 | 
184 | if __name__ == '__main__': YoloTrain().train()
185 | 
186 | 
187 | 
188 | 
189 | 
--------------------------------------------------------------------------------
/video_demo.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python
2 | # coding=utf-8
3 | #================================================================
4 | #   Copyright (C) 2018 * Ltd. All rights reserved.
5 | # 6 | # Editor : VIM 7 | # File name : video_demo.py 8 | # Author : YunYang1994 9 | # Created date: 2018-11-30 15:56:37 10 | # Description : 11 | # 12 | #================================================================ 13 | 14 | import cv2 15 | import time 16 | import numpy as np 17 | import core.utils as utils 18 | import tensorflow as tf 19 | from PIL import Image 20 | 21 | 22 | return_elements = ["input/input_data:0", "pred_sbbox/concat_2:0", "pred_mbbox/concat_2:0", "pred_lbbox/concat_2:0"] 23 | pb_file = "./yolov3_coco.pb" 24 | video_path = "./docs/images/road.mp4" 25 | # video_path = 0 26 | num_classes = 80 27 | input_size = 416 28 | graph = tf.Graph() 29 | return_tensors = utils.read_pb_return_tensors(graph, pb_file, return_elements) 30 | 31 | with tf.Session(graph=graph) as sess: 32 | vid = cv2.VideoCapture(video_path) 33 | while True: 34 | return_value, frame = vid.read() 35 | if return_value: 36 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 37 | image = Image.fromarray(frame) 38 | else: 39 | raise ValueError("No image!") 40 | frame_size = frame.shape[:2] 41 | image_data = utils.image_preporcess(np.copy(frame), [input_size, input_size]) 42 | image_data = image_data[np.newaxis, ...] 43 | prev_time = time.time() 44 | 45 | pred_sbbox, pred_mbbox, pred_lbbox = sess.run( 46 | [return_tensors[1], return_tensors[2], return_tensors[3]], 47 | feed_dict={ return_tensors[0]: image_data}) 48 | 49 | pred_bbox = np.concatenate([np.reshape(pred_sbbox, (-1, 5 + num_classes)), 50 | np.reshape(pred_mbbox, (-1, 5 + num_classes)), 51 | np.reshape(pred_lbbox, (-1, 5 + num_classes))], axis=0) 52 | 53 | bboxes = utils.postprocess_boxes(pred_bbox, frame_size, input_size, 0.3) 54 | bboxes = utils.nms(bboxes, 0.45, method='nms') 55 | image = utils.draw_bbox(frame, bboxes) 56 | 57 | curr_time = time.time() 58 | exec_time = curr_time - prev_time 59 | result = np.asarray(image) 60 | info = "time: %.2f ms" %(1000*exec_time) 61 | cv2.namedWindow("result", cv2.WINDOW_AUTOSIZE) 62 | result = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) 63 | cv2.imshow("result", result) 64 | if cv2.waitKey(1) & 0xFF == ord('q'): break 65 | 66 | 67 | 68 | 69 | --------------------------------------------------------------------------------