├── .gitignore ├── LICENSE ├── README.md ├── args.py ├── convert_weight.py ├── data ├── coco.names ├── darknet_weights │ └── readme ├── demo_data │ ├── dog.jpg │ ├── kite.jpg │ ├── messi.jpg │ └── results │ │ ├── dog.jpg │ │ ├── kite.jpg │ │ └── messi.jpg ├── logs │ └── readme ├── my_data │ └── readme └── yolo_anchors.txt ├── docs ├── backbone.png └── yolo_v3_architecture.png ├── eval.py ├── get_kmeans.py ├── misc ├── experiments_on_voc │ ├── args_voc.py │ ├── eval_voc.py │ ├── train.txt │ ├── val.txt │ └── voc.names ├── parse_voc_xml.py └── remove_optimizers_params_in_ckpt.py ├── model.py ├── test_single_image.py ├── train.py ├── utils ├── __init__.py ├── data_aug.py ├── data_utils.py ├── eval_utils.py ├── layer_utils.py ├── misc_utils.py ├── nms_utils.py └── plot_utils.py └── video_test.py /.gitignore: -------------------------------------------------------------------------------- 1 | # latex 2 | ## Core latex/pdflatex auxiliary files: 3 | *.aux 4 | *.lof 5 | *.log 6 | *.lot 7 | *.fls 8 | *.out 9 | *.toc 10 | *.fmt 11 | *.fot 12 | *.cb 13 | *.cb2 14 | .*.lb 15 | *.bbl 16 | *.blg 17 | *.synctex.gz 18 | ## Intermediate documents: 19 | *.dvi 20 | *.xdv 21 | *-converted-to.* 22 | # these rules might exclude image files for figures etc. 23 | # *.ps 24 | # *.eps 25 | # *.pdf 26 | 27 | # python 28 | # Byte-compiled / optimized / DLL files 29 | __pycache__/ 30 | *.py[cod] 31 | *$py.class 32 | # C extensions 33 | *.so 34 | 35 | # c++ 36 | # Prerequisites 37 | *.d 38 | # Compiled Object files 39 | *.slo 40 | *.lo 41 | *.o 42 | *.obj 43 | # Precompiled Headers 44 | *.gch 45 | *.pch 46 | # Compiled Dynamic libraries 47 | *.so 48 | *.dylib 49 | *.dll 50 | 51 | 52 | # tfrecord 53 | *.tfrecord 54 | 55 | # macOS 56 | .DS_Store 57 | 58 | # folders 59 | *.vscode/* 60 | *.texpadtmp/* 61 | *.idea/* 62 | 63 | # darknet 64 | *.weights -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2020 Wizyoung 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # YOLOv3_TensorFlow 2 | 3 | **NOTE:** This repo is no longer maintained (actually I dropped the support for a long time) as I have switched to PyTorch for one year. Life is short, I use PyTorch. 4 | 5 | 6 | -------- 7 | 8 | ### 1. Introduction 9 | 10 | This is my implementation of [YOLOv3](https://pjreddie.com/media/files/papers/YOLOv3.pdf) in pure TensorFlow. It contains the full pipeline of training and evaluation on your own dataset. The key features of this repo are: 11 | 12 | - Efficient tf.data pipeline 13 | - Weights converter (converting pretrained darknet weights on COCO dataset to TensorFlow checkpoint.) 14 | - Extremely fast GPU non maximum supression. 15 | - Full training and evaluation pipeline. 16 | - Kmeans algorithm to select prior anchor boxes. 17 | 18 | ### 2. Requirements 19 | 20 | Python version: 2 or 3 21 | 22 | Packages: 23 | 24 | - tensorflow >= 1.8.0 (theoretically any version that supports tf.data is ok) 25 | - opencv-python 26 | - tqdm 27 | 28 | ### 3. Weights convertion 29 | 30 | The pretrained darknet weights file can be downloaded [here](https://pjreddie.com/media/files/yolov3.weights). Place this weights file under directory `./data/darknet_weights/` and then run: 31 | 32 | ```shell 33 | python convert_weight.py 34 | ``` 35 | 36 | Then the converted TensorFlow checkpoint file will be saved to `./data/darknet_weights/` directory. 37 | 38 | You can also download the converted TensorFlow checkpoint file by me via [[Google Drive link](https://drive.google.com/drive/folders/1mXbNgNxyXPi7JNsnBaxEv1-nWr7SVoQt?usp=sharing)] or [[Github Release](https://github.com/wizyoung/YOLOv3_TensorFlow/releases/)] and then place it to the same directory. 39 | 40 | ### 4. Running demos 41 | 42 | There are some demo images and videos under the `./data/demo_data/`. You can run the demo by: 43 | 44 | Single image test demo: 45 | 46 | ```shell 47 | python test_single_image.py ./data/demo_data/messi.jpg 48 | ``` 49 | 50 | Video test demo: 51 | 52 | ```shell 53 | python video_test.py ./data/demo_data/video.mp4 54 | ``` 55 | 56 | Some results: 57 | 58 | ![](https://github.com/wizyoung/YOLOv3_TensorFlow/blob/master/data/demo_data/results/dog.jpg?raw=true) 59 | 60 | ![](https://github.com/wizyoung/YOLOv3_TensorFlow/blob/master/data/demo_data/results/messi.jpg?raw=true) 61 | 62 | ![](https://github.com/wizyoung/YOLOv3_TensorFlow/blob/master/data/demo_data/results/kite.jpg?raw=true) 63 | 64 | Compare the kite detection results with TensorFlow's offical API result [here](https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/img/kites_detections_output.jpg). 65 | 66 | (The kite detection result is under input image resolution 1344x896) 67 | 68 | ### 5. Inference speed 69 | 70 | How fast is the inference speed? With images scaled to 416*416: 71 | 72 | 73 | | Backbone | GPU | Time(ms) | 74 | | :-------------------- | :------: | :------: | 75 | | Darknet-53 (paper) | Titan X | 29 | 76 | | Darknet-53 (my impl.) | Titan XP | ~23 | 77 | 78 | why is it so fast? Check the ImageNet classification result comparision from the paper: 79 | 80 | ![](https://github.com/wizyoung/YOLOv3_TensorFlow/blob/master/docs/backbone.png?raw=true) 81 | 82 | ### 6. Model architecture 83 | 84 | For better understanding of the model architecture, you can refer to the following picture. With great thanks to [Levio](https://blog.csdn.net/leviopku/article/details/82660381) for your excellent work! 85 | 86 | ![](https://github.com/wizyoung/YOLOv3_TensorFlow/blob/master/docs/yolo_v3_architecture.png?raw=true) 87 | 88 | ### 7. Training 89 | 90 | #### 7.1 Data preparation 91 | 92 | (1) annotation file 93 | 94 | Generate `train.txt/val.txt/test.txt` files under `./data/my_data/` directory. One line for one image, in the format like `image_index image_absolute_path img_width img_height box_1 box_2 ... box_n`. Box_x format: `label_index x_min y_min x_max y_max`. (The origin of coordinates is at the left top corner, left top => (xmin, ymin), right bottom => (xmax, ymax).) `image_index` is the line index which starts from zero. `label_index` is in range [0, class_num - 1]. 95 | 96 | For example: 97 | 98 | ``` 99 | 0 xxx/xxx/a.jpg 1920 1080 0 453 369 473 391 1 588 245 608 268 100 | 1 xxx/xxx/b.jpg 1920 1080 1 466 403 485 422 2 793 300 809 320 101 | ... 102 | ``` 103 | 104 | Since so many users report to use tools like LabelImg to generate xml format annotations, I add one demo script on VOC dataset to do the convertion. Check the `misc/parse_voc_xml.py` file for more details. 105 | 106 | (2) class_names file: 107 | 108 | Generate the `data.names` file under `./data/my_data/` directory. Each line represents a class name. 109 | 110 | For example: 111 | 112 | ``` 113 | bird 114 | person 115 | bike 116 | ... 117 | ``` 118 | 119 | The COCO dataset class names file is placed at `./data/coco.names`. 120 | 121 | (3) prior anchor file: 122 | 123 | Using the kmeans algorithm to get the prior anchors: 124 | 125 | ``` 126 | python get_kmeans.py 127 | ``` 128 | 129 | Then you will get 9 anchors and the average IoU. Save the anchors to a txt file. 130 | 131 | The COCO dataset anchors offered by YOLO's author is placed at `./data/yolo_anchors.txt`, you can use that one too. 132 | 133 | The yolo anchors computed by the kmeans script is on the resized image scale. The default resize method is the letterbox resize, i.e., keep the original aspect ratio in the resized image. 134 | 135 | #### 7.2 Training 136 | 137 | Using `train.py`. The hyper-parameters and the corresponding annotations can be found in `args.py`: 138 | 139 | ```shell 140 | CUDA_VISIBLE_DEVICES=GPU_ID python train.py 141 | ``` 142 | 143 | Check the `args.py` for more details. You should set the parameters yourself in your own specific task. 144 | 145 | ### 8. Evaluation 146 | 147 | Using `eval.py` to evaluate the validation or test dataset. The parameters are as following: 148 | 149 | ```shell 150 | $ python eval.py -h 151 | usage: eval.py [-h] [--eval_file EVAL_FILE] 152 | [--restore_path RESTORE_PATH] 153 | [--anchor_path ANCHOR_PATH] 154 | [--class_name_path CLASS_NAME_PATH] 155 | [--batch_size BATCH_SIZE] 156 | [--img_size [IMG_SIZE [IMG_SIZE ...]]] 157 | [--num_threads NUM_THREADS] 158 | [--prefetech_buffer PREFETECH_BUFFER] 159 | [--nms_threshold NMS_THRESHOLD] 160 | [--score_threshold SCORE_THRESHOLD] 161 | [--nms_topk NMS_TOPK] 162 | ``` 163 | 164 | Check the `eval.py` for more details. You should set the parameters yourself. 165 | 166 | You will get the loss, recall, precision, average precision and mAP metrics results. 167 | 168 | For higher mAP, you should set score_threshold to a small number. 169 | 170 | ### 9. Some tricks 171 | 172 | Here are some training tricks in my experiment: 173 | 174 | (1) Apply the two-stage training strategy or the one-stage training strategy: 175 | 176 | Two-stage training: 177 | 178 | First stage: Restore `darknet53_body` part weights from COCO checkpoints, train the `yolov3_head` with big learning rate like 1e-3 until the loss reaches to a low level. 179 | 180 | Second stage: Restore the weights from the first stage, then train the whole model with small learning rate like 1e-4 or smaller. At this stage remember to restore the optimizer parameters if you use optimizers like adam. 181 | 182 | One-stage training: 183 | 184 | Just restore the whole weight file except the last three convolution layers (Conv_6, Conv_14, Conv_22). In this condition, be careful about the possible nan loss value. 185 | 186 | (2) I've included many useful training strategies in `args.py`: 187 | 188 | - Cosine decay of lr (SGDR) 189 | - Multi-scale training 190 | - Label smoothing 191 | - Mix up data augmentation 192 | - Focal loss 193 | 194 | These are all good strategies but it does **not** mean they will definitely improve the performance. You should choose the appropriate strategies for your own task. 195 | 196 | This [paper](https://arxiv.org/abs/1902.04103) from gluon-cv has proved that data augmentation is critical to YOLO v3, which is completely in consistent with my own experiments. Some data augmentation strategies that seems reasonable may lead to poor performance. For example, after introducing random color jittering, the mAP on my own dataset drops heavily. Thus I hope you pay extra attention to the data augmentation. 197 | 198 | (4) Loss nan? Setting a bigger warm_up_epoch number or smaller learning rate and try several more times. If you fine-tune the whole model, using adam may cause nan value sometimes. You can try choosing momentum optimizer. 199 | 200 | ### 10. Fine-tune on VOC dataset 201 | 202 | I did a quick train on the VOC dataset. The params I used in my experiments are included under `misc/experiments_on_voc/` folder for your reference. The train dataset is the VOC 2007 + 2012 trainval set, and the test dataset is the VOC 2007 test set. 203 | 204 | Finally with the 416\*416 input image, I got a 87.54% test mAP (not using the 07 metric). No hard-try fine-tuning. You should get the similar or better results. 205 | 206 | My pretrained weights on VOC dataset can be downloaded [here](https://drive.google.com/drive/folders/1ICKcJPozQOVRQnE1_vMn90nr7dejg0yW?usp=sharing). 207 | 208 | ### 11. TODO 209 | 210 | [ ] Multi-GPUs with sync batch norm. 211 | 212 | [ ] Maybe tf 2.0 ? 213 | 214 | ------- 215 | 216 | ### Credits: 217 | 218 | I referred to many fantastic repos during the implementation: 219 | 220 | [YunYang1994/tensorflow-yolov3](https://github.com/YunYang1994/tensorflow-yolov3) 221 | 222 | [qqwweee/keras-yolo3](https://github.com/qqwweee/keras-yolo3) 223 | 224 | [eriklindernoren/PyTorch-YOLOv3](https://github.com/eriklindernoren/PyTorch-YOLOv3) 225 | 226 | [pjreddie/darknet](https://github.com/pjreddie/darknet) 227 | 228 | [dmlc/gluon-cv](https://github.com/dmlc/gluon-cv/tree/master/scripts/detection/yolo) 229 | 230 | -------------------------------------------------------------------------------- /args.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # This file contains the parameter used in train.py 3 | 4 | from __future__ import division, print_function 5 | 6 | from utils.misc_utils import parse_anchors, read_class_names 7 | import math 8 | 9 | ### Some paths 10 | train_file = './data/my_data/train.txt' # The path of the training txt file. 11 | val_file = './data/my_data/val.txt' # The path of the validation txt file. 12 | restore_path = './data/darknet_weights/yolov3.ckpt' # The path of the weights to restore. 13 | save_dir = './checkpoint/' # The directory of the weights to save. 14 | log_dir = './data/logs/' # The directory to store the tensorboard log files. 15 | progress_log_path = './data/progress.log' # The path to record the training progress. 16 | anchor_path = './data/yolo_anchors.txt' # The path of the anchor txt file. 17 | class_name_path = './data/coco.names' # The path of the class names. 18 | 19 | ### Training releated numbers 20 | batch_size = 6 21 | img_size = [416, 416] # Images will be resized to `img_size` and fed to the network, size format: [width, height] 22 | letterbox_resize = True # Whether to use the letterbox resize, i.e., keep the original aspect ratio in the resized image. 23 | total_epoches = 100 24 | train_evaluation_step = 100 # Evaluate on the training batch after some steps. 25 | val_evaluation_epoch = 2 # Evaluate on the whole validation dataset after some epochs. Set to None to evaluate every epoch. 26 | save_epoch = 10 # Save the model after some epochs. 27 | batch_norm_decay = 0.99 # decay in bn ops 28 | weight_decay = 5e-4 # l2 weight decay 29 | global_step = 0 # used when resuming training 30 | 31 | ### tf.data parameters 32 | num_threads = 10 # Number of threads for image processing used in tf.data pipeline. 33 | prefetech_buffer = 5 # Prefetech_buffer used in tf.data pipeline. 34 | 35 | ### Learning rate and optimizer 36 | optimizer_name = 'momentum' # Chosen from [sgd, momentum, adam, rmsprop] 37 | save_optimizer = True # Whether to save the optimizer parameters into the checkpoint file. 38 | learning_rate_init = 1e-4 39 | lr_type = 'piecewise' # Chosen from [fixed, exponential, cosine_decay, cosine_decay_restart, piecewise] 40 | lr_decay_epoch = 5 # Epochs after which learning rate decays. Int or float. Used when chosen `exponential` and `cosine_decay_restart` lr_type. 41 | lr_decay_factor = 0.96 # The learning rate decay factor. Used when chosen `exponential` lr_type. 42 | lr_lower_bound = 1e-6 # The minimum learning rate. 43 | # only used in piecewise lr type 44 | pw_boundaries = [30, 50] # epoch based boundaries 45 | pw_values = [learning_rate_init, 3e-5, 1e-5] 46 | 47 | ### Load and finetune 48 | # Choose the parts you want to restore the weights. List form. 49 | # restore_include: None, restore_exclude: None => restore the whole model 50 | # restore_include: None, restore_exclude: scope => restore the whole model except `scope` 51 | # restore_include: scope1, restore_exclude: scope2 => if scope1 contains scope2, restore scope1 and not restore scope2 (scope1 - scope2) 52 | # choise 1: only restore the darknet body 53 | # restore_include = ['yolov3/darknet53_body'] 54 | # restore_exclude = None 55 | # choise 2: restore all layers except the last 3 conv2d layers in 3 scale 56 | restore_include = None 57 | restore_exclude = ['yolov3/yolov3_head/Conv_14', 'yolov3/yolov3_head/Conv_6', 'yolov3/yolov3_head/Conv_22'] 58 | # Choose the parts you want to finetune. List form. 59 | # Set to None to train the whole model. 60 | update_part = ['yolov3/yolov3_head'] 61 | 62 | ### other training strategies 63 | multi_scale_train = True # Whether to apply multi-scale training strategy. Image size varies from [320, 320] to [640, 640] by default. 64 | use_label_smooth = True # Whether to use class label smoothing strategy. 65 | use_focal_loss = True # Whether to apply focal loss on the conf loss. 66 | use_mix_up = True # Whether to use mix up data augmentation strategy. 67 | use_warm_up = True # whether to use warm up strategy to prevent from gradient exploding. 68 | warm_up_epoch = 3 # Warm up training epoches. Set to a larger value if gradient explodes. 69 | 70 | ### some constants in validation 71 | # nms 72 | nms_threshold = 0.45 # iou threshold in nms operation 73 | score_threshold = 0.01 # threshold of the probability of the classes in nms operation, i.e. score = pred_confs * pred_probs. set lower for higher recall. 74 | nms_topk = 150 # keep at most nms_topk outputs after nms 75 | # mAP eval 76 | eval_threshold = 0.5 # the iou threshold applied in mAP evaluation 77 | use_voc_07_metric = False # whether to use voc 2007 evaluation metric, i.e. the 11-point metric 78 | 79 | ### parse some params 80 | anchors = parse_anchors(anchor_path) 81 | classes = read_class_names(class_name_path) 82 | class_num = len(classes) 83 | train_img_cnt = len(open(train_file, 'r').readlines()) 84 | val_img_cnt = len(open(val_file, 'r').readlines()) 85 | train_batch_num = int(math.ceil(float(train_img_cnt) / batch_size)) 86 | 87 | lr_decay_freq = int(train_batch_num * lr_decay_epoch) 88 | pw_boundaries = [float(i) * train_batch_num + global_step for i in pw_boundaries] -------------------------------------------------------------------------------- /convert_weight.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # for more details about the yolo darknet weights file, refer to 3 | # https://itnext.io/implementing-yolo-v3-in-tensorflow-tf-slim-c3c55ff59dbe 4 | 5 | from __future__ import division, print_function 6 | 7 | import os 8 | import sys 9 | import tensorflow as tf 10 | import numpy as np 11 | 12 | from model import yolov3 13 | from utils.misc_utils import parse_anchors, load_weights 14 | 15 | num_class = 80 16 | img_size = 416 17 | weight_path = './data/darknet_weights/yolov3.weights' 18 | save_path = './data/darknet_weights/yolov3.ckpt' 19 | anchors = parse_anchors('./data/yolo_anchors.txt') 20 | 21 | model = yolov3(80, anchors) 22 | with tf.Session() as sess: 23 | inputs = tf.placeholder(tf.float32, [1, img_size, img_size, 3]) 24 | 25 | with tf.variable_scope('yolov3'): 26 | feature_map = model.forward(inputs) 27 | 28 | saver = tf.train.Saver(var_list=tf.global_variables(scope='yolov3')) 29 | 30 | load_ops = load_weights(tf.global_variables(scope='yolov3'), weight_path) 31 | sess.run(load_ops) 32 | saver.save(sess, save_path=save_path) 33 | print('TensorFlow model checkpoint has been saved to {}'.format(save_path)) 34 | 35 | 36 | 37 | -------------------------------------------------------------------------------- /data/coco.names: -------------------------------------------------------------------------------- 1 | person 2 | bicycle 3 | car 4 | motorbike 5 | aeroplane 6 | bus 7 | train 8 | truck 9 | boat 10 | traffic light 11 | fire hydrant 12 | stop sign 13 | parking meter 14 | bench 15 | bird 16 | cat 17 | dog 18 | horse 19 | sheep 20 | cow 21 | elephant 22 | bear 23 | zebra 24 | giraffe 25 | backpack 26 | umbrella 27 | handbag 28 | tie 29 | suitcase 30 | frisbee 31 | skis 32 | snowboard 33 | sports ball 34 | kite 35 | baseball bat 36 | baseball glove 37 | skateboard 38 | surfboard 39 | tennis racket 40 | bottle 41 | wine glass 42 | cup 43 | fork 44 | knife 45 | spoon 46 | bowl 47 | banana 48 | apple 49 | sandwich 50 | orange 51 | broccoli 52 | carrot 53 | hot dog 54 | pizza 55 | donut 56 | cake 57 | chair 58 | sofa 59 | pottedplant 60 | bed 61 | diningtable 62 | toilet 63 | tvmonitor 64 | laptop 65 | mouse 66 | remote 67 | keyboard 68 | cell phone 69 | microwave 70 | oven 71 | toaster 72 | sink 73 | refrigerator 74 | book 75 | clock 76 | vase 77 | scissors 78 | teddy bear 79 | hair drier 80 | toothbrush 81 | -------------------------------------------------------------------------------- /data/darknet_weights/readme: -------------------------------------------------------------------------------- 1 | place pretrained weights on COCO dataset here. -------------------------------------------------------------------------------- /data/demo_data/dog.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wizyoung/YOLOv3_TensorFlow/8776cf7b2531cae83f5fc730f3c70ae97919bfd6/data/demo_data/dog.jpg -------------------------------------------------------------------------------- /data/demo_data/kite.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wizyoung/YOLOv3_TensorFlow/8776cf7b2531cae83f5fc730f3c70ae97919bfd6/data/demo_data/kite.jpg -------------------------------------------------------------------------------- /data/demo_data/messi.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wizyoung/YOLOv3_TensorFlow/8776cf7b2531cae83f5fc730f3c70ae97919bfd6/data/demo_data/messi.jpg -------------------------------------------------------------------------------- /data/demo_data/results/dog.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wizyoung/YOLOv3_TensorFlow/8776cf7b2531cae83f5fc730f3c70ae97919bfd6/data/demo_data/results/dog.jpg -------------------------------------------------------------------------------- /data/demo_data/results/kite.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wizyoung/YOLOv3_TensorFlow/8776cf7b2531cae83f5fc730f3c70ae97919bfd6/data/demo_data/results/kite.jpg -------------------------------------------------------------------------------- /data/demo_data/results/messi.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wizyoung/YOLOv3_TensorFlow/8776cf7b2531cae83f5fc730f3c70ae97919bfd6/data/demo_data/results/messi.jpg -------------------------------------------------------------------------------- /data/logs/readme: -------------------------------------------------------------------------------- 1 | tensorboard event files will be here. -------------------------------------------------------------------------------- /data/my_data/readme: -------------------------------------------------------------------------------- 1 | place your data files here. -------------------------------------------------------------------------------- /data/yolo_anchors.txt: -------------------------------------------------------------------------------- 1 | 10,13, 16,30, 33,23, 30,61, 62,45, 59,119, 116,90, 156,198, 373,326 -------------------------------------------------------------------------------- /docs/backbone.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wizyoung/YOLOv3_TensorFlow/8776cf7b2531cae83f5fc730f3c70ae97919bfd6/docs/backbone.png -------------------------------------------------------------------------------- /docs/yolo_v3_architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wizyoung/YOLOv3_TensorFlow/8776cf7b2531cae83f5fc730f3c70ae97919bfd6/docs/yolo_v3_architecture.png -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from __future__ import division, print_function 4 | 5 | import tensorflow as tf 6 | import numpy as np 7 | import argparse 8 | from tqdm import trange 9 | 10 | from utils.data_utils import get_batch_data 11 | from utils.misc_utils import parse_anchors, read_class_names, AverageMeter 12 | from utils.eval_utils import evaluate_on_cpu, evaluate_on_gpu, get_preds_gpu, voc_eval, parse_gt_rec 13 | from utils.nms_utils import gpu_nms 14 | 15 | from model import yolov3 16 | 17 | ################# 18 | # ArgumentParser 19 | ################# 20 | parser = argparse.ArgumentParser(description="YOLO-V3 eval procedure.") 21 | # some paths 22 | parser.add_argument("--eval_file", type=str, default="./data/my_data/val.txt", 23 | help="The path of the validation or test txt file.") 24 | 25 | parser.add_argument("--restore_path", type=str, default="./data/darknet_weights/yolov3.ckpt", 26 | help="The path of the weights to restore.") 27 | 28 | parser.add_argument("--anchor_path", type=str, default="./data/yolo_anchors.txt", 29 | help="The path of the anchor txt file.") 30 | 31 | parser.add_argument("--class_name_path", type=str, default="./data/coco.names", 32 | help="The path of the class names.") 33 | 34 | # some numbers 35 | parser.add_argument("--img_size", nargs='*', type=int, default=[416, 416], 36 | help="Resize the input image to `img_size`, size format: [width, height]") 37 | 38 | parser.add_argument("--letterbox_resize", type=lambda x: (str(x).lower() == 'true'), default=False, 39 | help="Whether to use the letterbox resize, i.e., keep the original image aspect ratio.") 40 | 41 | parser.add_argument("--num_threads", type=int, default=10, 42 | help="Number of threads for image processing used in tf.data pipeline.") 43 | 44 | parser.add_argument("--prefetech_buffer", type=int, default=5, 45 | help="Prefetech_buffer used in tf.data pipeline.") 46 | 47 | parser.add_argument("--nms_threshold", type=float, default=0.45, 48 | help="IOU threshold in nms operation.") 49 | 50 | parser.add_argument("--score_threshold", type=float, default=0.01, 51 | help="Threshold of the probability of the classes in nms operation.") 52 | 53 | parser.add_argument("--nms_topk", type=int, default=400, 54 | help="Keep at most nms_topk outputs after nms.") 55 | 56 | parser.add_argument("--use_voc_07_metric", type=lambda x: (str(x).lower() == 'true'), default=False, 57 | help="Whether to use the voc 2007 mAP metrics.") 58 | 59 | args = parser.parse_args() 60 | 61 | # args params 62 | args.anchors = parse_anchors(args.anchor_path) 63 | args.classes = read_class_names(args.class_name_path) 64 | args.class_num = len(args.classes) 65 | args.img_cnt = len(open(args.eval_file, 'r').readlines()) 66 | 67 | # setting placeholders 68 | is_training = tf.placeholder(dtype=tf.bool, name="phase_train") 69 | handle_flag = tf.placeholder(tf.string, [], name='iterator_handle_flag') 70 | pred_boxes_flag = tf.placeholder(tf.float32, [1, None, None]) 71 | pred_scores_flag = tf.placeholder(tf.float32, [1, None, None]) 72 | gpu_nms_op = gpu_nms(pred_boxes_flag, pred_scores_flag, args.class_num, args.nms_topk, args.score_threshold, args.nms_threshold) 73 | 74 | ################## 75 | # tf.data pipeline 76 | ################## 77 | val_dataset = tf.data.TextLineDataset(args.eval_file) 78 | val_dataset = val_dataset.batch(1) 79 | val_dataset = val_dataset.map( 80 | lambda x: tf.py_func(get_batch_data, [x, args.class_num, args.img_size, args.anchors, 'val', False, False, args.letterbox_resize], [tf.int64, tf.float32, tf.float32, tf.float32, tf.float32]), 81 | num_parallel_calls=args.num_threads 82 | ) 83 | val_dataset.prefetch(args.prefetech_buffer) 84 | iterator = val_dataset.make_one_shot_iterator() 85 | 86 | image_ids, image, y_true_13, y_true_26, y_true_52 = iterator.get_next() 87 | image_ids.set_shape([None]) 88 | y_true = [y_true_13, y_true_26, y_true_52] 89 | image.set_shape([None, args.img_size[1], args.img_size[0], 3]) 90 | for y in y_true: 91 | y.set_shape([None, None, None, None, None]) 92 | 93 | ################## 94 | # Model definition 95 | ################## 96 | yolo_model = yolov3(args.class_num, args.anchors) 97 | with tf.variable_scope('yolov3'): 98 | pred_feature_maps = yolo_model.forward(image, is_training=is_training) 99 | loss = yolo_model.compute_loss(pred_feature_maps, y_true) 100 | y_pred = yolo_model.predict(pred_feature_maps) 101 | 102 | saver_to_restore = tf.train.Saver() 103 | 104 | with tf.Session() as sess: 105 | sess.run([tf.global_variables_initializer()]) 106 | saver_to_restore.restore(sess, args.restore_path) 107 | 108 | print('\n----------- start to eval -----------\n') 109 | 110 | val_loss_total, val_loss_xy, val_loss_wh, val_loss_conf, val_loss_class = \ 111 | AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter() 112 | val_preds = [] 113 | 114 | for j in trange(args.img_cnt): 115 | __image_ids, __y_pred, __loss = sess.run([image_ids, y_pred, loss], feed_dict={is_training: False}) 116 | pred_content = get_preds_gpu(sess, gpu_nms_op, pred_boxes_flag, pred_scores_flag, __image_ids, __y_pred) 117 | 118 | val_preds.extend(pred_content) 119 | val_loss_total.update(__loss[0]) 120 | val_loss_xy.update(__loss[1]) 121 | val_loss_wh.update(__loss[2]) 122 | val_loss_conf.update(__loss[3]) 123 | val_loss_class.update(__loss[4]) 124 | 125 | rec_total, prec_total, ap_total = AverageMeter(), AverageMeter(), AverageMeter() 126 | gt_dict = parse_gt_rec(args.eval_file, args.img_size, args.letterbox_resize) 127 | print('mAP eval:') 128 | for ii in range(args.class_num): 129 | npos, nd, rec, prec, ap = voc_eval(gt_dict, val_preds, ii, iou_thres=0.5, use_07_metric=args.use_voc_07_metric) 130 | rec_total.update(rec, npos) 131 | prec_total.update(prec, nd) 132 | ap_total.update(ap, 1) 133 | print('Class {}: Recall: {:.4f}, Precision: {:.4f}, AP: {:.4f}'.format(ii, rec, prec, ap)) 134 | 135 | mAP = ap_total.average 136 | print('final mAP: {:.4f}'.format(mAP)) 137 | print("recall: {:.3f}, precision: {:.3f}".format(rec_total.average, prec_total.average)) 138 | print("total_loss: {:.3f}, loss_xy: {:.3f}, loss_wh: {:.3f}, loss_conf: {:.3f}, loss_class: {:.3f}".format( 139 | val_loss_total.average, val_loss_xy.average, val_loss_wh.average, val_loss_conf.average, val_loss_class.average 140 | )) 141 | -------------------------------------------------------------------------------- /get_kmeans.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # This script is modified from https://github.com/lars76/kmeans-anchor-boxes 3 | 4 | from __future__ import division, print_function 5 | 6 | import numpy as np 7 | 8 | def iou(box, clusters): 9 | """ 10 | Calculates the Intersection over Union (IoU) between a box and k clusters. 11 | param: 12 | box: tuple or array, shifted to the origin (i. e. width and height) 13 | clusters: numpy array of shape (k, 2) where k is the number of clusters 14 | return: 15 | numpy array of shape (k, 0) where k is the number of clusters 16 | """ 17 | x = np.minimum(clusters[:, 0], box[0]) 18 | y = np.minimum(clusters[:, 1], box[1]) 19 | if np.count_nonzero(x == 0) > 0 or np.count_nonzero(y == 0) > 0: 20 | raise ValueError("Box has no area") 21 | 22 | intersection = x * y 23 | box_area = box[0] * box[1] 24 | cluster_area = clusters[:, 0] * clusters[:, 1] 25 | 26 | iou_ = np.true_divide(intersection, box_area + cluster_area - intersection + 1e-10) 27 | # iou_ = intersection / (box_area + cluster_area - intersection + 1e-10) 28 | 29 | return iou_ 30 | 31 | 32 | def avg_iou(boxes, clusters): 33 | """ 34 | Calculates the average Intersection over Union (IoU) between a numpy array of boxes and k clusters. 35 | param: 36 | boxes: numpy array of shape (r, 2), where r is the number of rows 37 | clusters: numpy array of shape (k, 2) where k is the number of clusters 38 | return: 39 | average IoU as a single float 40 | """ 41 | return np.mean([np.max(iou(boxes[i], clusters)) for i in range(boxes.shape[0])]) 42 | 43 | 44 | def translate_boxes(boxes): 45 | """ 46 | Translates all the boxes to the origin. 47 | param: 48 | boxes: numpy array of shape (r, 4) 49 | return: 50 | numpy array of shape (r, 2) 51 | """ 52 | new_boxes = boxes.copy() 53 | for row in range(new_boxes.shape[0]): 54 | new_boxes[row][2] = np.abs(new_boxes[row][2] - new_boxes[row][0]) 55 | new_boxes[row][3] = np.abs(new_boxes[row][3] - new_boxes[row][1]) 56 | return np.delete(new_boxes, [0, 1], axis=1) 57 | 58 | 59 | def kmeans(boxes, k, dist=np.median): 60 | """ 61 | Calculates k-means clustering with the Intersection over Union (IoU) metric. 62 | param: 63 | boxes: numpy array of shape (r, 2), where r is the number of rows 64 | k: number of clusters 65 | dist: distance function 66 | return: 67 | numpy array of shape (k, 2) 68 | """ 69 | rows = boxes.shape[0] 70 | 71 | distances = np.empty((rows, k)) 72 | last_clusters = np.zeros((rows,)) 73 | 74 | np.random.seed() 75 | 76 | # the Forgy method will fail if the whole array contains the same rows 77 | clusters = boxes[np.random.choice(rows, k, replace=False)] 78 | 79 | while True: 80 | for row in range(rows): 81 | distances[row] = 1 - iou(boxes[row], clusters) 82 | 83 | nearest_clusters = np.argmin(distances, axis=1) 84 | 85 | if (last_clusters == nearest_clusters).all(): 86 | break 87 | 88 | for cluster in range(k): 89 | clusters[cluster] = dist(boxes[nearest_clusters == cluster], axis=0) 90 | 91 | last_clusters = nearest_clusters 92 | 93 | return clusters 94 | 95 | 96 | def parse_anno(annotation_path, target_size=None): 97 | anno = open(annotation_path, 'r') 98 | result = [] 99 | for line in anno: 100 | s = line.strip().split(' ') 101 | img_w = int(s[2]) 102 | img_h = int(s[3]) 103 | s = s[4:] 104 | box_cnt = len(s) // 5 105 | for i in range(box_cnt): 106 | x_min, y_min, x_max, y_max = float(s[i*5+1]), float(s[i*5+2]), float(s[i*5+3]), float(s[i*5+4]) 107 | width = x_max - x_min 108 | height = y_max - y_min 109 | assert width > 0 110 | assert height > 0 111 | # use letterbox resize, i.e. keep the original aspect ratio 112 | # get k-means anchors on the resized target image size 113 | if target_size is not None: 114 | resize_ratio = min(target_size[0] / img_w, target_size[1] / img_h) 115 | width *= resize_ratio 116 | height *= resize_ratio 117 | result.append([width, height]) 118 | # get k-means anchors on the original image size 119 | else: 120 | result.append([width, height]) 121 | result = np.asarray(result) 122 | return result 123 | 124 | 125 | def get_kmeans(anno, cluster_num=9): 126 | 127 | anchors = kmeans(anno, cluster_num) 128 | ave_iou = avg_iou(anno, anchors) 129 | 130 | anchors = anchors.astype('int').tolist() 131 | 132 | anchors = sorted(anchors, key=lambda x: x[0] * x[1]) 133 | 134 | return anchors, ave_iou 135 | 136 | 137 | if __name__ == '__main__': 138 | # target resize format: [width, height] 139 | # if target_resize is speficied, the anchors are on the resized image scale 140 | # if target_resize is set to None, the anchors are on the original image scale 141 | target_size = [416, 416] 142 | annotation_path = "train.txt" 143 | anno_result = parse_anno(annotation_path, target_size=target_size) 144 | anchors, ave_iou = get_kmeans(anno_result, 9) 145 | 146 | anchor_string = '' 147 | for anchor in anchors: 148 | anchor_string += '{},{}, '.format(anchor[0], anchor[1]) 149 | anchor_string = anchor_string[:-2] 150 | 151 | print('anchors are:') 152 | print(anchor_string) 153 | print('the average iou is:') 154 | print(ave_iou) 155 | 156 | -------------------------------------------------------------------------------- /misc/experiments_on_voc/args_voc.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # This file contains the parameter used in train.py 3 | 4 | from __future__ import division, print_function 5 | 6 | from utils.misc_utils import parse_anchors, read_class_names 7 | import math 8 | 9 | ### Some paths 10 | train_file = './data/my_data/train.txt' # The path of the training txt file. 11 | val_file = './data/my_data/val.txt' # The path of the validation txt file. 12 | restore_path = './data/darknet_weights/yolov3.ckpt' # The path of the weights to restore. 13 | save_dir = './checkpoint/' # The directory of the weights to save. 14 | log_dir = './data/logs/' # The directory to store the tensorboard log files. 15 | progress_log_path = './data/progress.log' # The path to record the training progress. 16 | anchor_path = './data/yolo_anchors.txt' # The path of the anchor txt file. 17 | class_name_path = './data/voc.names' # The path of the class names. 18 | 19 | ### Training releated numbers 20 | batch_size = 6 21 | img_size = [416, 416] # Images will be resized to `img_size` and fed to the network, size format: [width, height] 22 | letterbox_resize = False # Whether to use the letterbox resize, i.e., keep the original aspect ratio in the resized image. 23 | total_epoches = 100 24 | train_evaluation_step = 100 # Evaluate on the training batch after some steps. 25 | val_evaluation_epoch = 1 # Evaluate on the whole validation dataset after some steps. Set to None to evaluate every epoch. 26 | save_epoch = 10 # Save the model after some epochs. 27 | batch_norm_decay = 0.99 # decay in bn ops 28 | weight_decay = 5e-4 # l2 weight decay 29 | global_step = 0 # used when resuming training 30 | 31 | ### tf.data parameters 32 | num_threads = 10 # Number of threads for image processing used in tf.data pipeline. 33 | prefetech_buffer = 5 # Prefetech_buffer used in tf.data pipeline. 34 | 35 | ### Learning rate and optimizer 36 | optimizer_name = 'momentum' # Chosen from [sgd, momentum, adam, rmsprop] 37 | save_optimizer = False # Whether to save the optimizer parameters into the checkpoint file. 38 | learning_rate_init = 1e-4 39 | lr_type = 'piecewise' # Chosen from [fixed, exponential, cosine_decay, cosine_decay_restart, piecewise] 40 | lr_decay_epoch = 5 # Epochs after which learning rate decays. Int or float. Used when chosen `exponential` and `cosine_decay_restart` lr_type. 41 | lr_decay_factor = 0.96 # The learning rate decay factor. Used when chosen `exponential` lr_type. 42 | lr_lower_bound = 1e-6 # The minimum learning rate. 43 | # piecewise params 44 | pw_boundaries = [25, 40] # epoch based boundaries 45 | pw_values = [learning_rate_init, 3e-5, 1e-4] 46 | 47 | ### Load and finetune 48 | # Choose the parts you want to restore the weights. List form. 49 | # restore_include: None, restore_exclude: None => restore the whole model 50 | # restore_include: None, restore_exclude: scope => restore the whole model except `scope` 51 | # restore_include: scope1, restore_exclude: scope2 => if scope1 contains scope2, restore scope1 and not restore scope2 (scope1 - scope2) 52 | # choise 1: only restore the darknet body 53 | # restore_include = ['yolov3/darknet53_body'] 54 | # restore_exclude = None 55 | # choise 2: restore all layers except the last 3 conv2d layers in 3 scale 56 | restore_include = None 57 | restore_exclude = ['yolov3/yolov3_head/Conv_14', 'yolov3/yolov3_head/Conv_6', 'yolov3/yolov3_head/Conv_22'] 58 | # Choose the parts you want to finetune. List form. 59 | # Set to None to train the whole model. 60 | update_part = None 61 | 62 | ### other training strategies 63 | multi_scale_train = True # Whether to apply multi-scale training strategy. Image size varies from [320, 320] to [640, 640] by default. 64 | use_label_smooth = True # Whether to use class label smoothing strategy. 65 | use_focal_loss = True # Whether to apply focal loss on the conf loss. 66 | use_mix_up = True # Whether to use mix up data augmentation strategy. 67 | use_warm_up = True # whether to use warm up strategy to prevent from gradient exploding. 68 | warm_up_epoch = 3 # Warm up training epoches. Set to a larger value if gradient explodes. 69 | 70 | ### some constants in validation 71 | # nms 72 | nms_threshold = 0.45 # iou threshold in nms operation 73 | score_threshold = 0.01 # threshold of the probability of the classes in nms operation, i.e. score = pred_confs * pred_probs. set lower for higher recall. 74 | nms_topk = 150 # keep at most nms_topk outputs after nms 75 | # mAP eval 76 | eval_threshold = 0.5 # the iou threshold applied in mAP evaluation 77 | use_voc_07_metric = False # whether to use voc 2007 evaluation metric, i.e. the 11-point metric 78 | 79 | ### parse some params 80 | anchors = parse_anchors(anchor_path) 81 | classes = read_class_names(class_name_path) 82 | class_num = len(classes) 83 | train_img_cnt = len(open(train_file, 'r').readlines()) 84 | val_img_cnt = len(open(val_file, 'r').readlines()) 85 | train_batch_num = int(math.ceil(float(train_img_cnt) / batch_size)) 86 | 87 | lr_decay_freq = int(train_batch_num * lr_decay_epoch) 88 | pw_boundaries = [float(i) * train_batch_num + global_step for i in pw_boundaries] -------------------------------------------------------------------------------- /misc/experiments_on_voc/eval_voc.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from __future__ import division, print_function 4 | 5 | import tensorflow as tf 6 | import numpy as np 7 | import argparse 8 | from tqdm import trange 9 | 10 | from utils.data_utils import get_batch_data 11 | from utils.misc_utils import parse_anchors, read_class_names, AverageMeter 12 | from utils.eval_utils import evaluate_on_cpu, evaluate_on_gpu, get_preds_gpu, voc_eval, parse_gt_rec 13 | from utils.nms_utils import gpu_nms 14 | 15 | from model import yolov3 16 | 17 | ################# 18 | # ArgumentParser 19 | ################# 20 | parser = argparse.ArgumentParser(description="YOLO-V3 eval procedure.") 21 | # some paths 22 | parser.add_argument("--eval_file", type=str, default="./data/my_data/val.txt", 23 | help="The path of the validation or test txt file.") 24 | 25 | parser.add_argument("--restore_path", type=str, default="./data/checkpoint_whole_finetune_no_letterbox/best_model_Epoch_32_step_91046_mAP_0.8754_loss_2.2147_lr_3e-05", 26 | help="The path of the weights to restore.") 27 | 28 | parser.add_argument("--anchor_path", type=str, default="./data/yolo_anchors.txt", 29 | help="The path of the anchor txt file.") 30 | 31 | parser.add_argument("--class_name_path", type=str, default="./data/voc.names", 32 | help="The path of the class names.") 33 | 34 | # some numbers 35 | parser.add_argument("--img_size", nargs='*', type=int, default=[416, 416], 36 | help="Resize the input image to `img_size`, size format: [width, height]") 37 | 38 | parser.add_argument("--letterbox_resize", type=lambda x: (str(x).lower() == 'true'), default=False, 39 | help="Whether to use the letterbox resize.") 40 | 41 | parser.add_argument("--num_threads", type=int, default=10, 42 | help="Number of threads for image processing used in tf.data pipeline.") 43 | 44 | parser.add_argument("--prefetech_buffer", type=int, default=5, 45 | help="Prefetech_buffer used in tf.data pipeline.") 46 | 47 | parser.add_argument("--nms_threshold", type=float, default=0.45, 48 | help="IOU threshold in nms operation.") 49 | 50 | parser.add_argument("--score_threshold", type=float, default=0.01, 51 | help="Threshold of the probability of the classes in nms operation.") 52 | 53 | parser.add_argument("--nms_topk", type=int, default=150, 54 | help="Keep at most nms_topk outputs after nms.") 55 | 56 | parser.add_argument("--use_voc_07_metric", type=lambda x: (str(x).lower() == 'true'), default=False, 57 | help="Whether to use the voc 2007 mAP metrics.") 58 | 59 | args = parser.parse_args() 60 | 61 | # args params 62 | args.anchors = parse_anchors(args.anchor_path) 63 | args.classes = read_class_names(args.class_name_path) 64 | args.class_num = len(args.classes) 65 | args.img_cnt = len(open(args.eval_file, 'r').readlines()) 66 | 67 | # setting placeholders 68 | is_training = tf.placeholder(dtype=tf.bool, name="phase_train") 69 | handle_flag = tf.placeholder(tf.string, [], name='iterator_handle_flag') 70 | pred_boxes_flag = tf.placeholder(tf.float32, [1, None, None]) 71 | pred_scores_flag = tf.placeholder(tf.float32, [1, None, None]) 72 | gpu_nms_op = gpu_nms(pred_boxes_flag, pred_scores_flag, args.class_num, args.nms_topk, args.score_threshold, args.nms_threshold) 73 | 74 | ################## 75 | # tf.data pipeline 76 | ################## 77 | val_dataset = tf.data.TextLineDataset(args.eval_file) 78 | val_dataset = val_dataset.batch(1) 79 | val_dataset = val_dataset.map( 80 | lambda x: tf.py_func(get_batch_data, [x, args.class_num, args.img_size, args.anchors, 'val', False, False, args.letterbox_resize], [tf.int64, tf.float32, tf.float32, tf.float32, tf.float32]), 81 | num_parallel_calls=args.num_threads 82 | ) 83 | val_dataset.prefetch(args.prefetech_buffer) 84 | iterator = val_dataset.make_one_shot_iterator() 85 | 86 | image_ids, image, y_true_13, y_true_26, y_true_52 = iterator.get_next() 87 | image_ids.set_shape([None]) 88 | y_true = [y_true_13, y_true_26, y_true_52] 89 | image.set_shape([None, args.img_size[1], args.img_size[0], 3]) 90 | for y in y_true: 91 | y.set_shape([None, None, None, None, None]) 92 | 93 | ################## 94 | # Model definition 95 | ################## 96 | yolo_model = yolov3(args.class_num, args.anchors) 97 | with tf.variable_scope('yolov3'): 98 | pred_feature_maps = yolo_model.forward(image, is_training=is_training) 99 | loss = yolo_model.compute_loss(pred_feature_maps, y_true) 100 | y_pred = yolo_model.predict(pred_feature_maps) 101 | 102 | saver_to_restore = tf.train.Saver() 103 | 104 | with tf.Session() as sess: 105 | sess.run([tf.global_variables_initializer()]) 106 | saver_to_restore.restore(sess, args.restore_path) 107 | 108 | print('\n----------- start to eval -----------\n') 109 | 110 | val_loss_total, val_loss_xy, val_loss_wh, val_loss_conf, val_loss_class = \ 111 | AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter() 112 | val_preds = [] 113 | 114 | for j in trange(args.img_cnt): 115 | __image_ids, __y_pred, __loss = sess.run([image_ids, y_pred, loss], feed_dict={is_training: False}) 116 | pred_content = get_preds_gpu(sess, gpu_nms_op, pred_boxes_flag, pred_scores_flag, __image_ids, __y_pred) 117 | 118 | val_preds.extend(pred_content) 119 | val_loss_total.update(__loss[0]) 120 | val_loss_xy.update(__loss[1]) 121 | val_loss_wh.update(__loss[2]) 122 | val_loss_conf.update(__loss[3]) 123 | val_loss_class.update(__loss[4]) 124 | 125 | rec_total, prec_total, ap_total = AverageMeter(), AverageMeter(), AverageMeter() 126 | gt_dict = parse_gt_rec(args.eval_file, args.img_size, args.letterbox_resize) 127 | print('mAP eval:') 128 | for ii in range(args.class_num): 129 | npos, nd, rec, prec, ap = voc_eval(gt_dict, val_preds, ii, iou_thres=0.5, use_07_metric=args.use_voc_07_metric) 130 | rec_total.update(rec, npos) 131 | prec_total.update(prec, nd) 132 | ap_total.update(ap, 1) 133 | print('Class {}: Recall: {:.4f}, Precision: {:.4f}, AP: {:.4f}'.format(ii, rec, prec, ap)) 134 | 135 | mAP = ap_total.average 136 | print('final mAP: {:.4f}'.format(mAP)) 137 | print("recall: {:.3f}, precision: {:.3f}".format(rec_total.average, prec_total.average)) 138 | print("total_loss: {:.3f}, loss_xy: {:.3f}, loss_wh: {:.3f}, loss_conf: {:.3f}, loss_class: {:.3f}".format( 139 | val_loss_total.average, val_loss_xy.average, val_loss_wh.average, val_loss_conf.average, val_loss_class.average 140 | )) 141 | -------------------------------------------------------------------------------- /misc/experiments_on_voc/voc.names: -------------------------------------------------------------------------------- 1 | aeroplane 2 | bicycle 3 | bird 4 | boat 5 | bottle 6 | bus 7 | car 8 | cat 9 | chair 10 | cow 11 | diningtable 12 | dog 13 | horse 14 | motorbike 15 | person 16 | pottedplant 17 | sheep 18 | sofa 19 | train 20 | tvmonitor -------------------------------------------------------------------------------- /misc/parse_voc_xml.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | import xml.etree.ElementTree as ET 4 | import os 5 | 6 | names_dict = {} 7 | cnt = 0 8 | f = open('./voc_names.txt', 'r').readlines() 9 | for line in f: 10 | line = line.strip() 11 | names_dict[line] = cnt 12 | cnt += 1 13 | 14 | voc_07 = '/data/VOCdevkit/VOC2007' 15 | voc_12 = '/data/VOCdevkit/VOC2012' 16 | 17 | anno_path = [os.path.join(voc_07, 'Annotations'), os.path.join(voc_12, 'Annotations')] 18 | img_path = [os.path.join(voc_07, 'JPEGImages'), os.path.join(voc_12, 'JPEGImages')] 19 | 20 | trainval_path = [os.path.join(voc_07, 'ImageSets/Main/trainval.txt'), 21 | os.path.join(voc_12, 'ImageSets/Main/trainval.txt')] 22 | test_path = [os.path.join(voc_07, 'ImageSets/Main/test.txt')] 23 | 24 | 25 | def parse_xml(path): 26 | tree = ET.parse(path) 27 | img_name = path.split('/')[-1][:-4] 28 | 29 | height = tree.findtext("./size/height") 30 | width = tree.findtext("./size/width") 31 | 32 | objects = [img_name, width, height] 33 | 34 | for obj in tree.findall('object'): 35 | difficult = obj.find('difficult').text 36 | if difficult == '1': 37 | continue 38 | name = obj.find('name').text 39 | bbox = obj.find('bndbox') 40 | xmin = bbox.find('xmin').text 41 | ymin = bbox.find('ymin').text 42 | xmax = bbox.find('xmax').text 43 | ymax = bbox.find('ymax').text 44 | 45 | name = str(names_dict[name]) 46 | objects.extend([name, xmin, ymin, xmax, ymax]) 47 | if len(objects) > 1: 48 | return objects 49 | else: 50 | return None 51 | 52 | test_cnt = 0 53 | def gen_test_txt(txt_path): 54 | global test_cnt 55 | f = open(txt_path, 'w') 56 | 57 | for i, path in enumerate(test_path): 58 | img_names = open(path, 'r').readlines() 59 | for img_name in img_names: 60 | img_name = img_name.strip() 61 | xml_path = anno_path[i] + '/' + img_name + '.xml' 62 | objects = parse_xml(xml_path) 63 | if objects: 64 | objects[0] = img_path[i] + '/' + img_name + '.jpg' 65 | if os.path.exists(objects[0]): 66 | objects.insert(0, str(test_cnt)) 67 | test_cnt += 1 68 | objects = ' '.join(objects) + '\n' 69 | f.write(objects) 70 | f.close() 71 | 72 | 73 | train_cnt = 0 74 | def gen_train_txt(txt_path): 75 | global train_cnt 76 | f = open(txt_path, 'w') 77 | 78 | for i, path in enumerate(trainval_path): 79 | img_names = open(path, 'r').readlines() 80 | for img_name in img_names: 81 | img_name = img_name.strip() 82 | xml_path = anno_path[i] + '/' + img_name + '.xml' 83 | objects = parse_xml(xml_path) 84 | if objects: 85 | objects[0] = img_path[i] + '/' + img_name + '.jpg' 86 | if os.path.exists(objects[0]): 87 | objects.insert(0, str(train_cnt)) 88 | train_cnt += 1 89 | objects = ' '.join(objects) + '\n' 90 | f.write(objects) 91 | f.close() 92 | 93 | 94 | gen_train_txt('train.txt') 95 | gen_test_txt('val.txt') 96 | 97 | -------------------------------------------------------------------------------- /misc/remove_optimizers_params_in_ckpt.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | # This script is used to remove the optimizer parameters in the saved checkpoint files. 4 | # These parameters are useless in the forward process. 5 | # Removing them will shrink the checkpoint size a lot. 6 | 7 | import sys 8 | sys.path.append('..') 9 | 10 | import os 11 | import tensorflow as tf 12 | from model import yolov3 13 | 14 | # params 15 | ckpt_path = '' 16 | class_num = 20 17 | save_dir = 'shrinked_ckpt' 18 | if not os.path.exists(save_dir): 19 | os.makedirs(save_dir) 20 | 21 | image = tf.placeholder(tf.float32, [1, 416, 416, 3]) 22 | yolo_model = yolov3(class_num, None) 23 | with tf.variable_scope('yolov3'): 24 | pred_feature_maps = yolo_model.forward(image) 25 | 26 | saver_to_restore = tf.train.Saver() 27 | saver_to_save = tf.train.Saver() 28 | 29 | with tf.Session() as sess: 30 | sess.run(tf.global_variables_initializer()) 31 | saver_to_restore.restore(sess, ckpt_path) 32 | saver_to_save.save(sess, save_dir + '/shrinked') -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # for better understanding about yolov3 architecture, refer to this website (in Chinese): 3 | # https://blog.csdn.net/leviopku/article/details/82660381 4 | 5 | from __future__ import division, print_function 6 | 7 | import tensorflow as tf 8 | slim = tf.contrib.slim 9 | 10 | from utils.layer_utils import conv2d, darknet53_body, yolo_block, upsample_layer 11 | 12 | class yolov3(object): 13 | 14 | def __init__(self, class_num, anchors, use_label_smooth=False, use_focal_loss=False, batch_norm_decay=0.999, weight_decay=5e-4, use_static_shape=True): 15 | 16 | # self.anchors = [[10, 13], [16, 30], [33, 23], 17 | # [30, 61], [62, 45], [59, 119], 18 | # [116, 90], [156, 198], [373,326]] 19 | self.class_num = class_num 20 | self.anchors = anchors 21 | self.batch_norm_decay = batch_norm_decay 22 | self.use_label_smooth = use_label_smooth 23 | self.use_focal_loss = use_focal_loss 24 | self.weight_decay = weight_decay 25 | # inference speed optimization 26 | # if `use_static_shape` is True, use tensor.get_shape(), otherwise use tf.shape(tensor) 27 | # static_shape is slightly faster 28 | self.use_static_shape = use_static_shape 29 | 30 | def forward(self, inputs, is_training=False, reuse=False): 31 | # the input img_size, form: [height, weight] 32 | # it will be used later 33 | self.img_size = tf.shape(inputs)[1:3] 34 | # set batch norm params 35 | batch_norm_params = { 36 | 'decay': self.batch_norm_decay, 37 | 'epsilon': 1e-05, 38 | 'scale': True, 39 | 'is_training': is_training, 40 | 'fused': None, # Use fused batch norm if possible. 41 | } 42 | 43 | with slim.arg_scope([slim.conv2d, slim.batch_norm], reuse=reuse): 44 | with slim.arg_scope([slim.conv2d], 45 | normalizer_fn=slim.batch_norm, 46 | normalizer_params=batch_norm_params, 47 | biases_initializer=None, 48 | activation_fn=lambda x: tf.nn.leaky_relu(x, alpha=0.1), 49 | weights_regularizer=slim.l2_regularizer(self.weight_decay)): 50 | with tf.variable_scope('darknet53_body'): 51 | route_1, route_2, route_3 = darknet53_body(inputs) 52 | 53 | with tf.variable_scope('yolov3_head'): 54 | inter1, net = yolo_block(route_3, 512) 55 | feature_map_1 = slim.conv2d(net, 3 * (5 + self.class_num), 1, 56 | stride=1, normalizer_fn=None, 57 | activation_fn=None, biases_initializer=tf.zeros_initializer()) 58 | feature_map_1 = tf.identity(feature_map_1, name='feature_map_1') 59 | 60 | inter1 = conv2d(inter1, 256, 1) 61 | inter1 = upsample_layer(inter1, route_2.get_shape().as_list() if self.use_static_shape else tf.shape(route_2)) 62 | concat1 = tf.concat([inter1, route_2], axis=3) 63 | 64 | inter2, net = yolo_block(concat1, 256) 65 | feature_map_2 = slim.conv2d(net, 3 * (5 + self.class_num), 1, 66 | stride=1, normalizer_fn=None, 67 | activation_fn=None, biases_initializer=tf.zeros_initializer()) 68 | feature_map_2 = tf.identity(feature_map_2, name='feature_map_2') 69 | 70 | inter2 = conv2d(inter2, 128, 1) 71 | inter2 = upsample_layer(inter2, route_1.get_shape().as_list() if self.use_static_shape else tf.shape(route_1)) 72 | concat2 = tf.concat([inter2, route_1], axis=3) 73 | 74 | _, feature_map_3 = yolo_block(concat2, 128) 75 | feature_map_3 = slim.conv2d(feature_map_3, 3 * (5 + self.class_num), 1, 76 | stride=1, normalizer_fn=None, 77 | activation_fn=None, biases_initializer=tf.zeros_initializer()) 78 | feature_map_3 = tf.identity(feature_map_3, name='feature_map_3') 79 | 80 | return feature_map_1, feature_map_2, feature_map_3 81 | 82 | def reorg_layer(self, feature_map, anchors): 83 | ''' 84 | feature_map: a feature_map from [feature_map_1, feature_map_2, feature_map_3] returned 85 | from `forward` function 86 | anchors: shape: [3, 2] 87 | ''' 88 | # NOTE: size in [h, w] format! don't get messed up! 89 | grid_size = feature_map.get_shape().as_list()[1:3] if self.use_static_shape else tf.shape(feature_map)[1:3] # [13, 13] 90 | # the downscale ratio in height and weight 91 | ratio = tf.cast(self.img_size / grid_size, tf.float32) 92 | # rescale the anchors to the feature_map 93 | # NOTE: the anchor is in [w, h] format! 94 | rescaled_anchors = [(anchor[0] / ratio[1], anchor[1] / ratio[0]) for anchor in anchors] 95 | 96 | feature_map = tf.reshape(feature_map, [-1, grid_size[0], grid_size[1], 3, 5 + self.class_num]) 97 | 98 | # split the feature_map along the last dimension 99 | # shape info: take 416x416 input image and the 13*13 feature_map for example: 100 | # box_centers: [N, 13, 13, 3, 2] last_dimension: [center_x, center_y] 101 | # box_sizes: [N, 13, 13, 3, 2] last_dimension: [width, height] 102 | # conf_logits: [N, 13, 13, 3, 1] 103 | # prob_logits: [N, 13, 13, 3, class_num] 104 | box_centers, box_sizes, conf_logits, prob_logits = tf.split(feature_map, [2, 2, 1, self.class_num], axis=-1) 105 | box_centers = tf.nn.sigmoid(box_centers) 106 | 107 | # use some broadcast tricks to get the mesh coordinates 108 | grid_x = tf.range(grid_size[1], dtype=tf.int32) 109 | grid_y = tf.range(grid_size[0], dtype=tf.int32) 110 | grid_x, grid_y = tf.meshgrid(grid_x, grid_y) 111 | x_offset = tf.reshape(grid_x, (-1, 1)) 112 | y_offset = tf.reshape(grid_y, (-1, 1)) 113 | x_y_offset = tf.concat([x_offset, y_offset], axis=-1) 114 | # shape: [13, 13, 1, 2] 115 | x_y_offset = tf.cast(tf.reshape(x_y_offset, [grid_size[0], grid_size[1], 1, 2]), tf.float32) 116 | 117 | # get the absolute box coordinates on the feature_map 118 | box_centers = box_centers + x_y_offset 119 | # rescale to the original image scale 120 | box_centers = box_centers * ratio[::-1] 121 | 122 | # avoid getting possible nan value with tf.clip_by_value 123 | box_sizes = tf.exp(box_sizes) * rescaled_anchors 124 | # box_sizes = tf.clip_by_value(tf.exp(box_sizes), 1e-9, 100) * rescaled_anchors 125 | # rescale to the original image scale 126 | box_sizes = box_sizes * ratio[::-1] 127 | 128 | # shape: [N, 13, 13, 3, 4] 129 | # last dimension: (center_x, center_y, w, h) 130 | boxes = tf.concat([box_centers, box_sizes], axis=-1) 131 | 132 | # shape: 133 | # x_y_offset: [13, 13, 1, 2] 134 | # boxes: [N, 13, 13, 3, 4], rescaled to the original image scale 135 | # conf_logits: [N, 13, 13, 3, 1] 136 | # prob_logits: [N, 13, 13, 3, class_num] 137 | return x_y_offset, boxes, conf_logits, prob_logits 138 | 139 | 140 | def predict(self, feature_maps): 141 | ''' 142 | Receive the returned feature_maps from `forward` function, 143 | the produce the output predictions at the test stage. 144 | ''' 145 | feature_map_1, feature_map_2, feature_map_3 = feature_maps 146 | 147 | feature_map_anchors = [(feature_map_1, self.anchors[6:9]), 148 | (feature_map_2, self.anchors[3:6]), 149 | (feature_map_3, self.anchors[0:3])] 150 | reorg_results = [self.reorg_layer(feature_map, anchors) for (feature_map, anchors) in feature_map_anchors] 151 | 152 | def _reshape(result): 153 | x_y_offset, boxes, conf_logits, prob_logits = result 154 | grid_size = x_y_offset.get_shape().as_list()[:2] if self.use_static_shape else tf.shape(x_y_offset)[:2] 155 | boxes = tf.reshape(boxes, [-1, grid_size[0] * grid_size[1] * 3, 4]) 156 | conf_logits = tf.reshape(conf_logits, [-1, grid_size[0] * grid_size[1] * 3, 1]) 157 | prob_logits = tf.reshape(prob_logits, [-1, grid_size[0] * grid_size[1] * 3, self.class_num]) 158 | # shape: (take 416*416 input image and feature_map_1 for example) 159 | # boxes: [N, 13*13*3, 4] 160 | # conf_logits: [N, 13*13*3, 1] 161 | # prob_logits: [N, 13*13*3, class_num] 162 | return boxes, conf_logits, prob_logits 163 | 164 | boxes_list, confs_list, probs_list = [], [], [] 165 | for result in reorg_results: 166 | boxes, conf_logits, prob_logits = _reshape(result) 167 | confs = tf.sigmoid(conf_logits) 168 | probs = tf.sigmoid(prob_logits) 169 | boxes_list.append(boxes) 170 | confs_list.append(confs) 171 | probs_list.append(probs) 172 | 173 | # collect results on three scales 174 | # take 416*416 input image for example: 175 | # shape: [N, (13*13+26*26+52*52)*3, 4] 176 | boxes = tf.concat(boxes_list, axis=1) 177 | # shape: [N, (13*13+26*26+52*52)*3, 1] 178 | confs = tf.concat(confs_list, axis=1) 179 | # shape: [N, (13*13+26*26+52*52)*3, class_num] 180 | probs = tf.concat(probs_list, axis=1) 181 | 182 | center_x, center_y, width, height = tf.split(boxes, [1, 1, 1, 1], axis=-1) 183 | x_min = center_x - width / 2 184 | y_min = center_y - height / 2 185 | x_max = center_x + width / 2 186 | y_max = center_y + height / 2 187 | 188 | boxes = tf.concat([x_min, y_min, x_max, y_max], axis=-1) 189 | 190 | return boxes, confs, probs 191 | 192 | def loss_layer(self, feature_map_i, y_true, anchors): 193 | ''' 194 | calc loss function from a certain scale 195 | input: 196 | feature_map_i: feature maps of a certain scale. shape: [N, 13, 13, 3*(5 + num_class)] etc. 197 | y_true: y_ture from a certain scale. shape: [N, 13, 13, 3, 5 + num_class + 1] etc. 198 | anchors: shape [9, 2] 199 | ''' 200 | 201 | # size in [h, w] format! don't get messed up! 202 | grid_size = tf.shape(feature_map_i)[1:3] 203 | # the downscale ratio in height and weight 204 | ratio = tf.cast(self.img_size / grid_size, tf.float32) 205 | # N: batch_size 206 | N = tf.cast(tf.shape(feature_map_i)[0], tf.float32) 207 | 208 | x_y_offset, pred_boxes, pred_conf_logits, pred_prob_logits = self.reorg_layer(feature_map_i, anchors) 209 | 210 | ########### 211 | # get mask 212 | ########### 213 | 214 | # shape: take 416x416 input image and 13*13 feature_map for example: 215 | # [N, 13, 13, 3, 1] 216 | object_mask = y_true[..., 4:5] 217 | 218 | # the calculation of ignore mask if referred from 219 | # https://github.com/pjreddie/darknet/blob/master/src/yolo_layer.c#L179 220 | ignore_mask = tf.TensorArray(tf.float32, size=0, dynamic_size=True) 221 | def loop_cond(idx, ignore_mask): 222 | return tf.less(idx, tf.cast(N, tf.int32)) 223 | def loop_body(idx, ignore_mask): 224 | # shape: [13, 13, 3, 4] & [13, 13, 3] ==> [V, 4] 225 | # V: num of true gt box of each image in a batch 226 | valid_true_boxes = tf.boolean_mask(y_true[idx, ..., 0:4], tf.cast(object_mask[idx, ..., 0], 'bool')) 227 | # shape: [13, 13, 3, 4] & [V, 4] ==> [13, 13, 3, V] 228 | iou = self.box_iou(pred_boxes[idx], valid_true_boxes) 229 | # shape: [13, 13, 3] 230 | best_iou = tf.reduce_max(iou, axis=-1) 231 | # shape: [13, 13, 3] 232 | ignore_mask_tmp = tf.cast(best_iou < 0.5, tf.float32) 233 | # finally will be shape: [N, 13, 13, 3] 234 | ignore_mask = ignore_mask.write(idx, ignore_mask_tmp) 235 | return idx + 1, ignore_mask 236 | _, ignore_mask = tf.while_loop(cond=loop_cond, body=loop_body, loop_vars=[0, ignore_mask]) 237 | ignore_mask = ignore_mask.stack() 238 | # shape: [N, 13, 13, 3, 1] 239 | ignore_mask = tf.expand_dims(ignore_mask, -1) 240 | 241 | # shape: [N, 13, 13, 3, 2] 242 | pred_box_xy = pred_boxes[..., 0:2] 243 | pred_box_wh = pred_boxes[..., 2:4] 244 | 245 | # get xy coordinates in one cell from the feature_map 246 | # numerical range: 0 ~ 1 247 | # shape: [N, 13, 13, 3, 2] 248 | true_xy = y_true[..., 0:2] / ratio[::-1] - x_y_offset 249 | pred_xy = pred_box_xy / ratio[::-1] - x_y_offset 250 | 251 | # get_tw_th 252 | # numerical range: 0 ~ 1 253 | # shape: [N, 13, 13, 3, 2] 254 | true_tw_th = y_true[..., 2:4] / anchors 255 | pred_tw_th = pred_box_wh / anchors 256 | # for numerical stability 257 | true_tw_th = tf.where(condition=tf.equal(true_tw_th, 0), 258 | x=tf.ones_like(true_tw_th), y=true_tw_th) 259 | pred_tw_th = tf.where(condition=tf.equal(pred_tw_th, 0), 260 | x=tf.ones_like(pred_tw_th), y=pred_tw_th) 261 | true_tw_th = tf.log(tf.clip_by_value(true_tw_th, 1e-9, 1e9)) 262 | pred_tw_th = tf.log(tf.clip_by_value(pred_tw_th, 1e-9, 1e9)) 263 | 264 | # box size punishment: 265 | # box with smaller area has bigger weight. This is taken from the yolo darknet C source code. 266 | # shape: [N, 13, 13, 3, 1] 267 | box_loss_scale = 2. - (y_true[..., 2:3] / tf.cast(self.img_size[1], tf.float32)) * (y_true[..., 3:4] / tf.cast(self.img_size[0], tf.float32)) 268 | 269 | ############ 270 | # loss_part 271 | ############ 272 | # mix_up weight 273 | # [N, 13, 13, 3, 1] 274 | mix_w = y_true[..., -1:] 275 | # shape: [N, 13, 13, 3, 1] 276 | xy_loss = tf.reduce_sum(tf.square(true_xy - pred_xy) * object_mask * box_loss_scale * mix_w) / N 277 | wh_loss = tf.reduce_sum(tf.square(true_tw_th - pred_tw_th) * object_mask * box_loss_scale * mix_w) / N 278 | 279 | # shape: [N, 13, 13, 3, 1] 280 | conf_pos_mask = object_mask 281 | conf_neg_mask = (1 - object_mask) * ignore_mask 282 | conf_loss_pos = conf_pos_mask * tf.nn.sigmoid_cross_entropy_with_logits(labels=object_mask, logits=pred_conf_logits) 283 | conf_loss_neg = conf_neg_mask * tf.nn.sigmoid_cross_entropy_with_logits(labels=object_mask, logits=pred_conf_logits) 284 | # TODO: may need to balance the pos-neg by multiplying some weights 285 | conf_loss = conf_loss_pos + conf_loss_neg 286 | if self.use_focal_loss: 287 | alpha = 1.0 288 | gamma = 2.0 289 | # TODO: alpha should be a mask array if needed 290 | focal_mask = alpha * tf.pow(tf.abs(object_mask - tf.sigmoid(pred_conf_logits)), gamma) 291 | conf_loss *= focal_mask 292 | conf_loss = tf.reduce_sum(conf_loss * mix_w) / N 293 | 294 | # shape: [N, 13, 13, 3, 1] 295 | # whether to use label smooth 296 | if self.use_label_smooth: 297 | delta = 0.01 298 | label_target = (1 - delta) * y_true[..., 5:-1] + delta * 1. / self.class_num 299 | else: 300 | label_target = y_true[..., 5:-1] 301 | class_loss = object_mask * tf.nn.sigmoid_cross_entropy_with_logits(labels=label_target, logits=pred_prob_logits) * mix_w 302 | class_loss = tf.reduce_sum(class_loss) / N 303 | 304 | return xy_loss, wh_loss, conf_loss, class_loss 305 | 306 | 307 | def box_iou(self, pred_boxes, valid_true_boxes): 308 | ''' 309 | param: 310 | pred_boxes: [13, 13, 3, 4], (center_x, center_y, w, h) 311 | valid_true: [V, 4] 312 | ''' 313 | 314 | # [13, 13, 3, 2] 315 | pred_box_xy = pred_boxes[..., 0:2] 316 | pred_box_wh = pred_boxes[..., 2:4] 317 | 318 | # shape: [13, 13, 3, 1, 2] 319 | pred_box_xy = tf.expand_dims(pred_box_xy, -2) 320 | pred_box_wh = tf.expand_dims(pred_box_wh, -2) 321 | 322 | # [V, 2] 323 | true_box_xy = valid_true_boxes[:, 0:2] 324 | true_box_wh = valid_true_boxes[:, 2:4] 325 | 326 | # [13, 13, 3, 1, 2] & [V, 2] ==> [13, 13, 3, V, 2] 327 | intersect_mins = tf.maximum(pred_box_xy - pred_box_wh / 2., 328 | true_box_xy - true_box_wh / 2.) 329 | intersect_maxs = tf.minimum(pred_box_xy + pred_box_wh / 2., 330 | true_box_xy + true_box_wh / 2.) 331 | intersect_wh = tf.maximum(intersect_maxs - intersect_mins, 0.) 332 | 333 | # shape: [13, 13, 3, V] 334 | intersect_area = intersect_wh[..., 0] * intersect_wh[..., 1] 335 | # shape: [13, 13, 3, 1] 336 | pred_box_area = pred_box_wh[..., 0] * pred_box_wh[..., 1] 337 | # shape: [V] 338 | true_box_area = true_box_wh[..., 0] * true_box_wh[..., 1] 339 | # shape: [1, V] 340 | true_box_area = tf.expand_dims(true_box_area, axis=0) 341 | 342 | # [13, 13, 3, V] 343 | iou = intersect_area / (pred_box_area + true_box_area - intersect_area + 1e-10) 344 | 345 | return iou 346 | 347 | 348 | def compute_loss(self, y_pred, y_true): 349 | ''' 350 | param: 351 | y_pred: returned feature_map list by `forward` function: [feature_map_1, feature_map_2, feature_map_3] 352 | y_true: input y_true by the tf.data pipeline 353 | ''' 354 | loss_xy, loss_wh, loss_conf, loss_class = 0., 0., 0., 0. 355 | anchor_group = [self.anchors[6:9], self.anchors[3:6], self.anchors[0:3]] 356 | 357 | # calc loss in 3 scales 358 | for i in range(len(y_pred)): 359 | result = self.loss_layer(y_pred[i], y_true[i], anchor_group[i]) 360 | loss_xy += result[0] 361 | loss_wh += result[1] 362 | loss_conf += result[2] 363 | loss_class += result[3] 364 | total_loss = loss_xy + loss_wh + loss_conf + loss_class 365 | return [total_loss, loss_xy, loss_wh, loss_conf, loss_class] 366 | -------------------------------------------------------------------------------- /test_single_image.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from __future__ import division, print_function 4 | 5 | import tensorflow as tf 6 | import numpy as np 7 | import argparse 8 | import cv2 9 | 10 | from utils.misc_utils import parse_anchors, read_class_names 11 | from utils.nms_utils import gpu_nms 12 | from utils.plot_utils import get_color_table, plot_one_box 13 | from utils.data_aug import letterbox_resize 14 | 15 | from model import yolov3 16 | 17 | parser = argparse.ArgumentParser(description="YOLO-V3 test single image test procedure.") 18 | parser.add_argument("input_image", type=str, 19 | help="The path of the input image.") 20 | parser.add_argument("--anchor_path", type=str, default="./data/yolo_anchors.txt", 21 | help="The path of the anchor txt file.") 22 | parser.add_argument("--new_size", nargs='*', type=int, default=[416, 416], 23 | help="Resize the input image with `new_size`, size format: [width, height]") 24 | parser.add_argument("--letterbox_resize", type=lambda x: (str(x).lower() == 'true'), default=True, 25 | help="Whether to use the letterbox resize.") 26 | parser.add_argument("--class_name_path", type=str, default="./data/coco.names", 27 | help="The path of the class names.") 28 | parser.add_argument("--restore_path", type=str, default="./data/darknet_weights/yolov3.ckpt", 29 | help="The path of the weights to restore.") 30 | args = parser.parse_args() 31 | 32 | args.anchors = parse_anchors(args.anchor_path) 33 | args.classes = read_class_names(args.class_name_path) 34 | args.num_class = len(args.classes) 35 | 36 | color_table = get_color_table(args.num_class) 37 | 38 | img_ori = cv2.imread(args.input_image) 39 | if args.letterbox_resize: 40 | img, resize_ratio, dw, dh = letterbox_resize(img_ori, args.new_size[0], args.new_size[1]) 41 | else: 42 | height_ori, width_ori = img_ori.shape[:2] 43 | img = cv2.resize(img_ori, tuple(args.new_size)) 44 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 45 | img = np.asarray(img, np.float32) 46 | img = img[np.newaxis, :] / 255. 47 | 48 | with tf.Session() as sess: 49 | input_data = tf.placeholder(tf.float32, [1, args.new_size[1], args.new_size[0], 3], name='input_data') 50 | yolo_model = yolov3(args.num_class, args.anchors) 51 | with tf.variable_scope('yolov3'): 52 | pred_feature_maps = yolo_model.forward(input_data, False) 53 | pred_boxes, pred_confs, pred_probs = yolo_model.predict(pred_feature_maps) 54 | 55 | pred_scores = pred_confs * pred_probs 56 | 57 | boxes, scores, labels = gpu_nms(pred_boxes, pred_scores, args.num_class, max_boxes=200, score_thresh=0.3, nms_thresh=0.45) 58 | 59 | saver = tf.train.Saver() 60 | saver.restore(sess, args.restore_path) 61 | 62 | boxes_, scores_, labels_ = sess.run([boxes, scores, labels], feed_dict={input_data: img}) 63 | 64 | # rescale the coordinates to the original image 65 | if args.letterbox_resize: 66 | boxes_[:, [0, 2]] = (boxes_[:, [0, 2]] - dw) / resize_ratio 67 | boxes_[:, [1, 3]] = (boxes_[:, [1, 3]] - dh) / resize_ratio 68 | else: 69 | boxes_[:, [0, 2]] *= (width_ori/float(args.new_size[0])) 70 | boxes_[:, [1, 3]] *= (height_ori/float(args.new_size[1])) 71 | 72 | print("box coords:") 73 | print(boxes_) 74 | print('*' * 30) 75 | print("scores:") 76 | print(scores_) 77 | print('*' * 30) 78 | print("labels:") 79 | print(labels_) 80 | 81 | for i in range(len(boxes_)): 82 | x0, y0, x1, y1 = boxes_[i] 83 | plot_one_box(img_ori, [x0, y0, x1, y1], label=args.classes[labels_[i]] + ', {:.2f}%'.format(scores_[i] * 100), color=color_table[labels_[i]]) 84 | cv2.imshow('Detection result', img_ori) 85 | cv2.imwrite('detection_result.jpg', img_ori) 86 | cv2.waitKey(0) 87 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from __future__ import division, print_function 4 | 5 | import tensorflow as tf 6 | import numpy as np 7 | import logging 8 | from tqdm import trange 9 | 10 | import args 11 | 12 | from utils.data_utils import get_batch_data 13 | from utils.misc_utils import shuffle_and_overwrite, make_summary, config_learning_rate, config_optimizer, AverageMeter 14 | from utils.eval_utils import evaluate_on_cpu, evaluate_on_gpu, get_preds_gpu, voc_eval, parse_gt_rec 15 | from utils.nms_utils import gpu_nms 16 | 17 | from model import yolov3 18 | 19 | # setting loggers 20 | logging.basicConfig(level=logging.DEBUG, format='%(asctime)s %(levelname)s %(message)s', 21 | datefmt='%a, %d %b %Y %H:%M:%S', filename=args.progress_log_path, filemode='w') 22 | 23 | # setting placeholders 24 | is_training = tf.placeholder(tf.bool, name="phase_train") 25 | handle_flag = tf.placeholder(tf.string, [], name='iterator_handle_flag') 26 | # register the gpu nms operation here for the following evaluation scheme 27 | pred_boxes_flag = tf.placeholder(tf.float32, [1, None, None]) 28 | pred_scores_flag = tf.placeholder(tf.float32, [1, None, None]) 29 | gpu_nms_op = gpu_nms(pred_boxes_flag, pred_scores_flag, args.class_num, args.nms_topk, args.score_threshold, args.nms_threshold) 30 | 31 | ################## 32 | # tf.data pipeline 33 | ################## 34 | train_dataset = tf.data.TextLineDataset(args.train_file) 35 | train_dataset = train_dataset.shuffle(args.train_img_cnt) 36 | train_dataset = train_dataset.batch(args.batch_size) 37 | train_dataset = train_dataset.map( 38 | lambda x: tf.py_func(get_batch_data, 39 | inp=[x, args.class_num, args.img_size, args.anchors, 'train', args.multi_scale_train, args.use_mix_up, args.letterbox_resize], 40 | Tout=[tf.int64, tf.float32, tf.float32, tf.float32, tf.float32]), 41 | num_parallel_calls=args.num_threads 42 | ) 43 | train_dataset = train_dataset.prefetch(args.prefetech_buffer) 44 | 45 | val_dataset = tf.data.TextLineDataset(args.val_file) 46 | val_dataset = val_dataset.batch(1) 47 | val_dataset = val_dataset.map( 48 | lambda x: tf.py_func(get_batch_data, 49 | inp=[x, args.class_num, args.img_size, args.anchors, 'val', False, False, args.letterbox_resize], 50 | Tout=[tf.int64, tf.float32, tf.float32, tf.float32, tf.float32]), 51 | num_parallel_calls=args.num_threads 52 | ) 53 | val_dataset.prefetch(args.prefetech_buffer) 54 | 55 | iterator = tf.data.Iterator.from_structure(train_dataset.output_types, train_dataset.output_shapes) 56 | train_init_op = iterator.make_initializer(train_dataset) 57 | val_init_op = iterator.make_initializer(val_dataset) 58 | 59 | # get an element from the chosen dataset iterator 60 | image_ids, image, y_true_13, y_true_26, y_true_52 = iterator.get_next() 61 | y_true = [y_true_13, y_true_26, y_true_52] 62 | 63 | # tf.data pipeline will lose the data `static` shape, so we need to set it manually 64 | image_ids.set_shape([None]) 65 | image.set_shape([None, None, None, 3]) 66 | for y in y_true: 67 | y.set_shape([None, None, None, None, None]) 68 | 69 | ################## 70 | # Model definition 71 | ################## 72 | yolo_model = yolov3(args.class_num, args.anchors, args.use_label_smooth, args.use_focal_loss, args.batch_norm_decay, args.weight_decay, use_static_shape=False) 73 | with tf.variable_scope('yolov3'): 74 | pred_feature_maps = yolo_model.forward(image, is_training=is_training) 75 | loss = yolo_model.compute_loss(pred_feature_maps, y_true) 76 | y_pred = yolo_model.predict(pred_feature_maps) 77 | 78 | l2_loss = tf.losses.get_regularization_loss() 79 | 80 | # setting restore parts and vars to update 81 | saver_to_restore = tf.train.Saver(var_list=tf.contrib.framework.get_variables_to_restore(include=args.restore_include, exclude=args.restore_exclude)) 82 | update_vars = tf.contrib.framework.get_variables_to_restore(include=args.update_part) 83 | 84 | tf.summary.scalar('train_batch_statistics/total_loss', loss[0]) 85 | tf.summary.scalar('train_batch_statistics/loss_xy', loss[1]) 86 | tf.summary.scalar('train_batch_statistics/loss_wh', loss[2]) 87 | tf.summary.scalar('train_batch_statistics/loss_conf', loss[3]) 88 | tf.summary.scalar('train_batch_statistics/loss_class', loss[4]) 89 | tf.summary.scalar('train_batch_statistics/loss_l2', l2_loss) 90 | tf.summary.scalar('train_batch_statistics/loss_ratio', l2_loss / loss[0]) 91 | 92 | global_step = tf.Variable(float(args.global_step), trainable=False, collections=[tf.GraphKeys.LOCAL_VARIABLES]) 93 | if args.use_warm_up: 94 | learning_rate = tf.cond(tf.less(global_step, args.train_batch_num * args.warm_up_epoch), 95 | lambda: args.learning_rate_init * global_step / (args.train_batch_num * args.warm_up_epoch), 96 | lambda: config_learning_rate(args, global_step - args.train_batch_num * args.warm_up_epoch)) 97 | else: 98 | learning_rate = config_learning_rate(args, global_step) 99 | tf.summary.scalar('learning_rate', learning_rate) 100 | 101 | if not args.save_optimizer: 102 | saver_to_save = tf.train.Saver() 103 | saver_best = tf.train.Saver() 104 | 105 | optimizer = config_optimizer(args.optimizer_name, learning_rate) 106 | 107 | # set dependencies for BN ops 108 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 109 | with tf.control_dependencies(update_ops): 110 | # train_op = optimizer.minimize(loss[0] + l2_loss, var_list=update_vars, global_step=global_step) 111 | # apply gradient clip to avoid gradient exploding 112 | gvs = optimizer.compute_gradients(loss[0] + l2_loss, var_list=update_vars) 113 | clip_grad_var = [gv if gv[0] is None else [ 114 | tf.clip_by_norm(gv[0], 100.), gv[1]] for gv in gvs] 115 | train_op = optimizer.apply_gradients(clip_grad_var, global_step=global_step) 116 | 117 | if args.save_optimizer: 118 | print('Saving optimizer parameters to checkpoint! Remember to restore the global_step in the fine-tuning afterwards.') 119 | saver_to_save = tf.train.Saver() 120 | saver_best = tf.train.Saver() 121 | 122 | with tf.Session() as sess: 123 | sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()]) 124 | saver_to_restore.restore(sess, args.restore_path) 125 | merged = tf.summary.merge_all() 126 | writer = tf.summary.FileWriter(args.log_dir, sess.graph) 127 | 128 | print('\n----------- start to train -----------\n') 129 | 130 | best_mAP = -np.Inf 131 | 132 | for epoch in range(args.total_epoches): 133 | 134 | sess.run(train_init_op) 135 | loss_total, loss_xy, loss_wh, loss_conf, loss_class = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter() 136 | 137 | for i in trange(args.train_batch_num): 138 | _, summary, __y_pred, __y_true, __loss, __global_step, __lr = sess.run( 139 | [train_op, merged, y_pred, y_true, loss, global_step, learning_rate], 140 | feed_dict={is_training: True}) 141 | 142 | writer.add_summary(summary, global_step=__global_step) 143 | 144 | loss_total.update(__loss[0], len(__y_pred[0])) 145 | loss_xy.update(__loss[1], len(__y_pred[0])) 146 | loss_wh.update(__loss[2], len(__y_pred[0])) 147 | loss_conf.update(__loss[3], len(__y_pred[0])) 148 | loss_class.update(__loss[4], len(__y_pred[0])) 149 | 150 | if __global_step % args.train_evaluation_step == 0 and __global_step > 0: 151 | # recall, precision = evaluate_on_cpu(__y_pred, __y_true, args.class_num, args.nms_topk, args.score_threshold, args.nms_threshold) 152 | recall, precision = evaluate_on_gpu(sess, gpu_nms_op, pred_boxes_flag, pred_scores_flag, __y_pred, __y_true, args.class_num, args.nms_threshold) 153 | 154 | info = "Epoch: {}, global_step: {} | loss: total: {:.2f}, xy: {:.2f}, wh: {:.2f}, conf: {:.2f}, class: {:.2f} | ".format( 155 | epoch, int(__global_step), loss_total.average, loss_xy.average, loss_wh.average, loss_conf.average, loss_class.average) 156 | info += 'Last batch: rec: {:.3f}, prec: {:.3f} | lr: {:.5g}'.format(recall, precision, __lr) 157 | print(info) 158 | logging.info(info) 159 | 160 | writer.add_summary(make_summary('evaluation/train_batch_recall', recall), global_step=__global_step) 161 | writer.add_summary(make_summary('evaluation/train_batch_precision', precision), global_step=__global_step) 162 | 163 | if np.isnan(loss_total.average): 164 | print('****' * 10) 165 | raise ArithmeticError( 166 | 'Gradient exploded! Please train again and you may need modify some parameters.') 167 | 168 | # NOTE: this is just demo. You can set the conditions when to save the weights. 169 | if epoch % args.save_epoch == 0 and epoch > 0: 170 | if loss_total.average <= 2.: 171 | saver_to_save.save(sess, args.save_dir + 'model-epoch_{}_step_{}_loss_{:.4f}_lr_{:.5g}'.format(epoch, int(__global_step), loss_total.average, __lr)) 172 | 173 | # switch to validation dataset for evaluation 174 | if epoch % args.val_evaluation_epoch == 0 and epoch >= args.warm_up_epoch: 175 | sess.run(val_init_op) 176 | 177 | val_loss_total, val_loss_xy, val_loss_wh, val_loss_conf, val_loss_class = \ 178 | AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter() 179 | 180 | val_preds = [] 181 | 182 | for j in trange(args.val_img_cnt): 183 | __image_ids, __y_pred, __loss = sess.run([image_ids, y_pred, loss], 184 | feed_dict={is_training: False}) 185 | pred_content = get_preds_gpu(sess, gpu_nms_op, pred_boxes_flag, pred_scores_flag, __image_ids, __y_pred) 186 | val_preds.extend(pred_content) 187 | val_loss_total.update(__loss[0]) 188 | val_loss_xy.update(__loss[1]) 189 | val_loss_wh.update(__loss[2]) 190 | val_loss_conf.update(__loss[3]) 191 | val_loss_class.update(__loss[4]) 192 | 193 | # calc mAP 194 | rec_total, prec_total, ap_total = AverageMeter(), AverageMeter(), AverageMeter() 195 | gt_dict = parse_gt_rec(args.val_file, args.img_size, args.letterbox_resize) 196 | 197 | info = '======> Epoch: {}, global_step: {}, lr: {:.6g} <======\n'.format(epoch, __global_step, __lr) 198 | 199 | for ii in range(args.class_num): 200 | npos, nd, rec, prec, ap = voc_eval(gt_dict, val_preds, ii, iou_thres=args.eval_threshold, use_07_metric=args.use_voc_07_metric) 201 | info += 'EVAL: Class {}: Recall: {:.4f}, Precision: {:.4f}, AP: {:.4f}\n'.format(ii, rec, prec, ap) 202 | rec_total.update(rec, npos) 203 | prec_total.update(prec, nd) 204 | ap_total.update(ap, 1) 205 | 206 | mAP = ap_total.average 207 | info += 'EVAL: Recall: {:.4f}, Precison: {:.4f}, mAP: {:.4f}\n'.format(rec_total.average, prec_total.average, mAP) 208 | info += 'EVAL: loss: total: {:.2f}, xy: {:.2f}, wh: {:.2f}, conf: {:.2f}, class: {:.2f}\n'.format( 209 | val_loss_total.average, val_loss_xy.average, val_loss_wh.average, val_loss_conf.average, val_loss_class.average) 210 | print(info) 211 | logging.info(info) 212 | 213 | if mAP > best_mAP: 214 | best_mAP = mAP 215 | saver_best.save(sess, args.save_dir + 'best_model_Epoch_{}_step_{}_mAP_{:.4f}_loss_{:.4f}_lr_{:.7g}'.format( 216 | epoch, int(__global_step), best_mAP, val_loss_total.average, __lr)) 217 | 218 | writer.add_summary(make_summary('evaluation/val_mAP', mAP), global_step=epoch) 219 | writer.add_summary(make_summary('evaluation/val_recall', rec_total.average), global_step=epoch) 220 | writer.add_summary(make_summary('evaluation/val_precision', prec_total.average), global_step=epoch) 221 | writer.add_summary(make_summary('validation_statistics/total_loss', val_loss_total.average), global_step=epoch) 222 | writer.add_summary(make_summary('validation_statistics/loss_xy', val_loss_xy.average), global_step=epoch) 223 | writer.add_summary(make_summary('validation_statistics/loss_wh', val_loss_wh.average), global_step=epoch) 224 | writer.add_summary(make_summary('validation_statistics/loss_conf', val_loss_conf.average), global_step=epoch) 225 | writer.add_summary(make_summary('validation_statistics/loss_class', val_loss_class.average), global_step=epoch) 226 | 227 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wizyoung/YOLOv3_TensorFlow/8776cf7b2531cae83f5fc730f3c70ae97919bfd6/utils/__init__.py -------------------------------------------------------------------------------- /utils/data_aug.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # part of this is take from Gluon's repo: 3 | # https://github.com/dmlc/gluon-cv/blob/master/gluoncv/data/transforms/presets/yolo.py 4 | 5 | from __future__ import division, print_function 6 | 7 | import random 8 | import numpy as np 9 | import cv2 10 | 11 | 12 | def mix_up(img1, img2, bbox1, bbox2): 13 | ''' 14 | return: 15 | mix_img: HWC format mix up image 16 | mix_bbox: [N, 5] shape mix up bbox, i.e. `x_min, y_min, x_max, y_mix, mixup_weight`. 17 | ''' 18 | height = max(img1.shape[0], img2.shape[0]) 19 | width = max(img1.shape[1], img2.shape[1]) 20 | 21 | mix_img = np.zeros(shape=(height, width, 3), dtype='float32') 22 | 23 | # rand_num = np.random.random() 24 | rand_num = np.random.beta(1.5, 1.5) 25 | rand_num = max(0, min(1, rand_num)) 26 | mix_img[:img1.shape[0], :img1.shape[1], :] = img1.astype('float32') * rand_num 27 | mix_img[:img2.shape[0], :img2.shape[1], :] += img2.astype('float32') * (1. - rand_num) 28 | 29 | mix_img = mix_img.astype('uint8') 30 | 31 | # the last element of the 2nd dimention is the mix up weight 32 | bbox1 = np.concatenate((bbox1, np.full(shape=(bbox1.shape[0], 1), fill_value=rand_num)), axis=-1) 33 | bbox2 = np.concatenate((bbox2, np.full(shape=(bbox2.shape[0], 1), fill_value=1. - rand_num)), axis=-1) 34 | mix_bbox = np.concatenate((bbox1, bbox2), axis=0) 35 | 36 | return mix_img, mix_bbox 37 | 38 | 39 | def bbox_crop(bbox, crop_box=None, allow_outside_center=True): 40 | """Crop bounding boxes according to slice area. 41 | This method is mainly used with image cropping to ensure bonding boxes fit 42 | within the cropped image. 43 | Parameters 44 | ---------- 45 | bbox : numpy.ndarray 46 | Numpy.ndarray with shape (N, 4+) where N is the number of bounding boxes. 47 | The second axis represents attributes of the bounding box. 48 | Specifically, these are :math:`(x_{min}, y_{min}, x_{max}, y_{max})`, 49 | we allow additional attributes other than coordinates, which stay intact 50 | during bounding box transformations. 51 | crop_box : tuple 52 | Tuple of length 4. :math:`(x_{min}, y_{min}, width, height)` 53 | allow_outside_center : bool 54 | If `False`, remove bounding boxes which have centers outside cropping area. 55 | Returns 56 | ------- 57 | numpy.ndarray 58 | Cropped bounding boxes with shape (M, 4+) where M <= N. 59 | """ 60 | bbox = bbox.copy() 61 | if crop_box is None: 62 | return bbox 63 | if not len(crop_box) == 4: 64 | raise ValueError( 65 | "Invalid crop_box parameter, requires length 4, given {}".format(str(crop_box))) 66 | if sum([int(c is None) for c in crop_box]) == 4: 67 | return bbox 68 | 69 | l, t, w, h = crop_box 70 | 71 | left = l if l else 0 72 | top = t if t else 0 73 | right = left + (w if w else np.inf) 74 | bottom = top + (h if h else np.inf) 75 | crop_bbox = np.array((left, top, right, bottom)) 76 | 77 | if allow_outside_center: 78 | mask = np.ones(bbox.shape[0], dtype=bool) 79 | else: 80 | centers = (bbox[:, :2] + bbox[:, 2:4]) / 2 81 | mask = np.logical_and(crop_bbox[:2] <= centers, centers < crop_bbox[2:]).all(axis=1) 82 | 83 | # transform borders 84 | bbox[:, :2] = np.maximum(bbox[:, :2], crop_bbox[:2]) 85 | bbox[:, 2:4] = np.minimum(bbox[:, 2:4], crop_bbox[2:4]) 86 | bbox[:, :2] -= crop_bbox[:2] 87 | bbox[:, 2:4] -= crop_bbox[:2] 88 | 89 | mask = np.logical_and(mask, (bbox[:, :2] < bbox[:, 2:4]).all(axis=1)) 90 | bbox = bbox[mask] 91 | return bbox 92 | 93 | def bbox_iou(bbox_a, bbox_b, offset=0): 94 | """Calculate Intersection-Over-Union(IOU) of two bounding boxes. 95 | Parameters 96 | ---------- 97 | bbox_a : numpy.ndarray 98 | An ndarray with shape :math:`(N, 4)`. 99 | bbox_b : numpy.ndarray 100 | An ndarray with shape :math:`(M, 4)`. 101 | offset : float or int, default is 0 102 | The ``offset`` is used to control the whether the width(or height) is computed as 103 | (right - left + ``offset``). 104 | Note that the offset must be 0 for normalized bboxes, whose ranges are in ``[0, 1]``. 105 | Returns 106 | ------- 107 | numpy.ndarray 108 | An ndarray with shape :math:`(N, M)` indicates IOU between each pairs of 109 | bounding boxes in `bbox_a` and `bbox_b`. 110 | """ 111 | if bbox_a.shape[1] < 4 or bbox_b.shape[1] < 4: 112 | raise IndexError("Bounding boxes axis 1 must have at least length 4") 113 | 114 | tl = np.maximum(bbox_a[:, None, :2], bbox_b[:, :2]) 115 | br = np.minimum(bbox_a[:, None, 2:4], bbox_b[:, 2:4]) 116 | 117 | area_i = np.prod(br - tl + offset, axis=2) * (tl < br).all(axis=2) 118 | area_a = np.prod(bbox_a[:, 2:4] - bbox_a[:, :2] + offset, axis=1) 119 | area_b = np.prod(bbox_b[:, 2:4] - bbox_b[:, :2] + offset, axis=1) 120 | return area_i / (area_a[:, None] + area_b - area_i) 121 | 122 | 123 | def random_crop_with_constraints(bbox, size, min_scale=0.3, max_scale=1, 124 | max_aspect_ratio=2, constraints=None, 125 | max_trial=50): 126 | """Crop an image randomly with bounding box constraints. 127 | This data augmentation is used in training of 128 | Single Shot Multibox Detector [#]_. More details can be found in 129 | data augmentation section of the original paper. 130 | .. [#] Wei Liu, Dragomir Anguelov, Dumitru Erhan, Christian Szegedy, 131 | Scott Reed, Cheng-Yang Fu, Alexander C. Berg. 132 | SSD: Single Shot MultiBox Detector. ECCV 2016. 133 | Parameters 134 | ---------- 135 | bbox : numpy.ndarray 136 | Numpy.ndarray with shape (N, 4+) where N is the number of bounding boxes. 137 | The second axis represents attributes of the bounding box. 138 | Specifically, these are :math:`(x_{min}, y_{min}, x_{max}, y_{max})`, 139 | we allow additional attributes other than coordinates, which stay intact 140 | during bounding box transformations. 141 | size : tuple 142 | Tuple of length 2 of image shape as (width, height). 143 | min_scale : float 144 | The minimum ratio between a cropped region and the original image. 145 | The default value is :obj:`0.3`. 146 | max_scale : float 147 | The maximum ratio between a cropped region and the original image. 148 | The default value is :obj:`1`. 149 | max_aspect_ratio : float 150 | The maximum aspect ratio of cropped region. 151 | The default value is :obj:`2`. 152 | constraints : iterable of tuples 153 | An iterable of constraints. 154 | Each constraint should be :obj:`(min_iou, max_iou)` format. 155 | If means no constraint if set :obj:`min_iou` or :obj:`max_iou` to :obj:`None`. 156 | If this argument defaults to :obj:`None`, :obj:`((0.1, None), (0.3, None), 157 | (0.5, None), (0.7, None), (0.9, None), (None, 1))` will be used. 158 | max_trial : int 159 | Maximum number of trials for each constraint before exit no matter what. 160 | Returns 161 | ------- 162 | numpy.ndarray 163 | Cropped bounding boxes with shape :obj:`(M, 4+)` where M <= N. 164 | tuple 165 | Tuple of length 4 as (x_offset, y_offset, new_width, new_height). 166 | """ 167 | # default params in paper 168 | if constraints is None: 169 | constraints = ( 170 | (0.1, None), 171 | (0.3, None), 172 | (0.5, None), 173 | (0.7, None), 174 | (0.9, None), 175 | (None, 1), 176 | ) 177 | 178 | w, h = size 179 | 180 | candidates = [(0, 0, w, h)] 181 | for min_iou, max_iou in constraints: 182 | min_iou = -np.inf if min_iou is None else min_iou 183 | max_iou = np.inf if max_iou is None else max_iou 184 | 185 | for _ in range(max_trial): 186 | scale = random.uniform(min_scale, max_scale) 187 | aspect_ratio = random.uniform( 188 | max(1 / max_aspect_ratio, scale * scale), 189 | min(max_aspect_ratio, 1 / (scale * scale))) 190 | crop_h = int(h * scale / np.sqrt(aspect_ratio)) 191 | crop_w = int(w * scale * np.sqrt(aspect_ratio)) 192 | 193 | crop_t = random.randrange(h - crop_h) 194 | crop_l = random.randrange(w - crop_w) 195 | crop_bb = np.array((crop_l, crop_t, crop_l + crop_w, crop_t + crop_h)) 196 | 197 | if len(bbox) == 0: 198 | top, bottom = crop_t, crop_t + crop_h 199 | left, right = crop_l, crop_l + crop_w 200 | return bbox, (left, top, right-left, bottom-top) 201 | 202 | iou = bbox_iou(bbox, crop_bb[np.newaxis]) 203 | if min_iou <= iou.min() and iou.max() <= max_iou: 204 | top, bottom = crop_t, crop_t + crop_h 205 | left, right = crop_l, crop_l + crop_w 206 | candidates.append((left, top, right-left, bottom-top)) 207 | break 208 | 209 | # random select one 210 | while candidates: 211 | crop = candidates.pop(np.random.randint(0, len(candidates))) 212 | new_bbox = bbox_crop(bbox, crop, allow_outside_center=False) 213 | if new_bbox.size < 1: 214 | continue 215 | new_crop = (crop[0], crop[1], crop[2], crop[3]) 216 | return new_bbox, new_crop 217 | return bbox, (0, 0, w, h) 218 | 219 | 220 | def random_color_distort(img, brightness_delta=32, hue_vari=18, sat_vari=0.5, val_vari=0.5): 221 | ''' 222 | randomly distort image color. Adjust brightness, hue, saturation, value. 223 | param: 224 | img: a BGR uint8 format OpenCV image. HWC format. 225 | ''' 226 | 227 | def random_hue(img_hsv, hue_vari, p=0.5): 228 | if np.random.uniform(0, 1) > p: 229 | hue_delta = np.random.randint(-hue_vari, hue_vari) 230 | img_hsv[:, :, 0] = (img_hsv[:, :, 0] + hue_delta) % 180 231 | return img_hsv 232 | 233 | def random_saturation(img_hsv, sat_vari, p=0.5): 234 | if np.random.uniform(0, 1) > p: 235 | sat_mult = 1 + np.random.uniform(-sat_vari, sat_vari) 236 | img_hsv[:, :, 1] *= sat_mult 237 | return img_hsv 238 | 239 | def random_value(img_hsv, val_vari, p=0.5): 240 | if np.random.uniform(0, 1) > p: 241 | val_mult = 1 + np.random.uniform(-val_vari, val_vari) 242 | img_hsv[:, :, 2] *= val_mult 243 | return img_hsv 244 | 245 | def random_brightness(img, brightness_delta, p=0.5): 246 | if np.random.uniform(0, 1) > p: 247 | img = img.astype(np.float32) 248 | brightness_delta = int(np.random.uniform(-brightness_delta, brightness_delta)) 249 | img = img + brightness_delta 250 | return np.clip(img, 0, 255) 251 | 252 | # brightness 253 | img = random_brightness(img, brightness_delta) 254 | img = img.astype(np.uint8) 255 | 256 | # color jitter 257 | img_hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV).astype(np.float32) 258 | 259 | if np.random.randint(0, 2): 260 | img_hsv = random_value(img_hsv, val_vari) 261 | img_hsv = random_saturation(img_hsv, sat_vari) 262 | img_hsv = random_hue(img_hsv, hue_vari) 263 | else: 264 | img_hsv = random_saturation(img_hsv, sat_vari) 265 | img_hsv = random_hue(img_hsv, hue_vari) 266 | img_hsv = random_value(img_hsv, val_vari) 267 | 268 | img_hsv = np.clip(img_hsv, 0, 255) 269 | img = cv2.cvtColor(img_hsv.astype(np.uint8), cv2.COLOR_HSV2BGR) 270 | 271 | return img 272 | 273 | 274 | def letterbox_resize(img, new_width, new_height, interp=0): 275 | ''' 276 | Letterbox resize. keep the original aspect ratio in the resized image. 277 | ''' 278 | ori_height, ori_width = img.shape[:2] 279 | 280 | resize_ratio = min(new_width / ori_width, new_height / ori_height) 281 | 282 | resize_w = int(resize_ratio * ori_width) 283 | resize_h = int(resize_ratio * ori_height) 284 | 285 | img = cv2.resize(img, (resize_w, resize_h), interpolation=interp) 286 | image_padded = np.full((new_height, new_width, 3), 128, np.uint8) 287 | 288 | dw = int((new_width - resize_w) / 2) 289 | dh = int((new_height - resize_h) / 2) 290 | 291 | image_padded[dh: resize_h + dh, dw: resize_w + dw, :] = img 292 | 293 | return image_padded, resize_ratio, dw, dh 294 | 295 | 296 | def resize_with_bbox(img, bbox, new_width, new_height, interp=0, letterbox=False): 297 | ''' 298 | Resize the image and correct the bbox accordingly. 299 | ''' 300 | 301 | if letterbox: 302 | image_padded, resize_ratio, dw, dh = letterbox_resize(img, new_width, new_height, interp) 303 | 304 | # xmin, xmax 305 | bbox[:, [0, 2]] = bbox[:, [0, 2]] * resize_ratio + dw 306 | # ymin, ymax 307 | bbox[:, [1, 3]] = bbox[:, [1, 3]] * resize_ratio + dh 308 | 309 | return image_padded, bbox 310 | else: 311 | ori_height, ori_width = img.shape[:2] 312 | 313 | img = cv2.resize(img, (new_width, new_height), interpolation=interp) 314 | 315 | # xmin, xmax 316 | bbox[:, [0, 2]] = bbox[:, [0, 2]] / ori_width * new_width 317 | # ymin, ymax 318 | bbox[:, [1, 3]] = bbox[:, [1, 3]] / ori_height * new_height 319 | 320 | return img, bbox 321 | 322 | 323 | def random_flip(img, bbox, px=0, py=0): 324 | ''' 325 | Randomly flip the image and correct the bbox. 326 | param: 327 | px: 328 | the probability of horizontal flip 329 | py: 330 | the probability of vertical flip 331 | ''' 332 | height, width = img.shape[:2] 333 | if np.random.uniform(0, 1) < px: 334 | img = cv2.flip(img, 1) 335 | xmax = width - bbox[:, 0] 336 | xmin = width - bbox[:, 2] 337 | bbox[:, 0] = xmin 338 | bbox[:, 2] = xmax 339 | 340 | if np.random.uniform(0, 1) < py: 341 | img = cv2.flip(img, 0) 342 | ymax = height - bbox[:, 1] 343 | ymin = height - bbox[:, 3] 344 | bbox[:, 1] = ymin 345 | bbox[:, 3] = ymax 346 | return img, bbox 347 | 348 | 349 | def random_expand(img, bbox, max_ratio=4, fill=0, keep_ratio=True): 350 | ''' 351 | Random expand original image with borders, this is identical to placing 352 | the original image on a larger canvas. 353 | param: 354 | max_ratio : 355 | Maximum ratio of the output image on both direction(vertical and horizontal) 356 | fill : 357 | The value(s) for padded borders. 358 | keep_ratio : bool 359 | If `True`, will keep output image the same aspect ratio as input. 360 | ''' 361 | h, w, c = img.shape 362 | ratio_x = random.uniform(1, max_ratio) 363 | if keep_ratio: 364 | ratio_y = ratio_x 365 | else: 366 | ratio_y = random.uniform(1, max_ratio) 367 | 368 | oh, ow = int(h * ratio_y), int(w * ratio_x) 369 | off_y = random.randint(0, oh - h) 370 | off_x = random.randint(0, ow - w) 371 | 372 | dst = np.full(shape=(oh, ow, c), fill_value=fill, dtype=img.dtype) 373 | 374 | dst[off_y:off_y + h, off_x:off_x + w, :] = img 375 | 376 | # correct bbox 377 | bbox[:, :2] += (off_x, off_y) 378 | bbox[:, 2:4] += (off_x, off_y) 379 | 380 | return dst, bbox 381 | -------------------------------------------------------------------------------- /utils/data_utils.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from __future__ import division, print_function 4 | 5 | import numpy as np 6 | import cv2 7 | import sys 8 | from utils.data_aug import * 9 | import random 10 | 11 | PY_VERSION = sys.version_info[0] 12 | iter_cnt = 0 13 | 14 | 15 | def parse_line(line): 16 | ''' 17 | Given a line from the training/test txt file, return parsed info. 18 | line format: line_index, img_path, img_width, img_height, [box_info_1 (5 number)], ... 19 | return: 20 | line_idx: int64 21 | pic_path: string. 22 | boxes: shape [N, 4], N is the ground truth count, elements in the second 23 | dimension are [x_min, y_min, x_max, y_max] 24 | labels: shape [N]. class index. 25 | img_width: int. 26 | img_height: int 27 | ''' 28 | if 'str' not in str(type(line)): 29 | line = line.decode() 30 | s = line.strip().split(' ') 31 | assert len(s) > 8, 'Annotation error! Please check your annotation file. Make sure there is at least one target object in each image.' 32 | line_idx = int(s[0]) 33 | pic_path = s[1] 34 | img_width = int(s[2]) 35 | img_height = int(s[3]) 36 | s = s[4:] 37 | assert len(s) % 5 == 0, 'Annotation error! Please check your annotation file. Maybe partially missing some coordinates?' 38 | box_cnt = len(s) // 5 39 | boxes = [] 40 | labels = [] 41 | for i in range(box_cnt): 42 | label, x_min, y_min, x_max, y_max = int(s[i * 5]), float(s[i * 5 + 1]), float(s[i * 5 + 2]), float( 43 | s[i * 5 + 3]), float(s[i * 5 + 4]) 44 | boxes.append([x_min, y_min, x_max, y_max]) 45 | labels.append(label) 46 | boxes = np.asarray(boxes, np.float32) 47 | labels = np.asarray(labels, np.int64) 48 | return line_idx, pic_path, boxes, labels, img_width, img_height 49 | 50 | 51 | def process_box(boxes, labels, img_size, class_num, anchors): 52 | ''' 53 | Generate the y_true label, i.e. the ground truth feature_maps in 3 different scales. 54 | params: 55 | boxes: [N, 5] shape, float32 dtype. `x_min, y_min, x_max, y_mix, mixup_weight`. 56 | labels: [N] shape, int64 dtype. 57 | class_num: int64 num. 58 | anchors: [9, 4] shape, float32 dtype. 59 | ''' 60 | anchors_mask = [[6, 7, 8], [3, 4, 5], [0, 1, 2]] 61 | 62 | # convert boxes form: 63 | # shape: [N, 2] 64 | # (x_center, y_center) 65 | box_centers = (boxes[:, 0:2] + boxes[:, 2:4]) / 2 66 | # (width, height) 67 | box_sizes = boxes[:, 2:4] - boxes[:, 0:2] 68 | 69 | # [13, 13, 3, 5+num_class+1] `5` means coords and labels. `1` means mix up weight. 70 | y_true_13 = np.zeros((img_size[1] // 32, img_size[0] // 32, 3, 6 + class_num), np.float32) 71 | y_true_26 = np.zeros((img_size[1] // 16, img_size[0] // 16, 3, 6 + class_num), np.float32) 72 | y_true_52 = np.zeros((img_size[1] // 8, img_size[0] // 8, 3, 6 + class_num), np.float32) 73 | 74 | # mix up weight default to 1. 75 | y_true_13[..., -1] = 1. 76 | y_true_26[..., -1] = 1. 77 | y_true_52[..., -1] = 1. 78 | 79 | y_true = [y_true_13, y_true_26, y_true_52] 80 | 81 | # [N, 1, 2] 82 | box_sizes = np.expand_dims(box_sizes, 1) 83 | # broadcast tricks 84 | # [N, 1, 2] & [9, 2] ==> [N, 9, 2] 85 | mins = np.maximum(- box_sizes / 2, - anchors / 2) 86 | maxs = np.minimum(box_sizes / 2, anchors / 2) 87 | # [N, 9, 2] 88 | whs = maxs - mins 89 | 90 | # [N, 9] 91 | iou = (whs[:, :, 0] * whs[:, :, 1]) / ( 92 | box_sizes[:, :, 0] * box_sizes[:, :, 1] + anchors[:, 0] * anchors[:, 1] - whs[:, :, 0] * whs[:, :, 93 | 1] + 1e-10) 94 | # [N] 95 | best_match_idx = np.argmax(iou, axis=1) 96 | 97 | ratio_dict = {1.: 8., 2.: 16., 3.: 32.} 98 | for i, idx in enumerate(best_match_idx): 99 | # idx: 0,1,2 ==> 2; 3,4,5 ==> 1; 6,7,8 ==> 0 100 | feature_map_group = 2 - idx // 3 101 | # scale ratio: 0,1,2 ==> 8; 3,4,5 ==> 16; 6,7,8 ==> 32 102 | ratio = ratio_dict[np.ceil((idx + 1) / 3.)] 103 | x = int(np.floor(box_centers[i, 0] / ratio)) 104 | y = int(np.floor(box_centers[i, 1] / ratio)) 105 | k = anchors_mask[feature_map_group].index(idx) 106 | c = labels[i] 107 | # print(feature_map_group, '|', y,x,k,c) 108 | 109 | y_true[feature_map_group][y, x, k, :2] = box_centers[i] 110 | y_true[feature_map_group][y, x, k, 2:4] = box_sizes[i] 111 | y_true[feature_map_group][y, x, k, 4] = 1. 112 | y_true[feature_map_group][y, x, k, 5 + c] = 1. 113 | y_true[feature_map_group][y, x, k, -1] = boxes[i, -1] 114 | 115 | return y_true_13, y_true_26, y_true_52 116 | 117 | 118 | def parse_data(line, class_num, img_size, anchors, mode, letterbox_resize): 119 | ''' 120 | param: 121 | line: a line from the training/test txt file 122 | class_num: totol class nums. 123 | img_size: the size of image to be resized to. [width, height] format. 124 | anchors: anchors. 125 | mode: 'train' or 'val'. When set to 'train', data_augmentation will be applied. 126 | letterbox_resize: whether to use the letterbox resize, i.e., keep the original aspect ratio in the resized image. 127 | ''' 128 | if not isinstance(line, list): 129 | img_idx, pic_path, boxes, labels, _, _ = parse_line(line) 130 | img = cv2.imread(pic_path) 131 | # expand the 2nd dimension, mix up weight default to 1. 132 | boxes = np.concatenate((boxes, np.full(shape=(boxes.shape[0], 1), fill_value=1., dtype=np.float32)), axis=-1) 133 | else: 134 | # the mix up case 135 | _, pic_path1, boxes1, labels1, _, _ = parse_line(line[0]) 136 | img1 = cv2.imread(pic_path1) 137 | img_idx, pic_path2, boxes2, labels2, _, _ = parse_line(line[1]) 138 | img2 = cv2.imread(pic_path2) 139 | 140 | img, boxes = mix_up(img1, img2, boxes1, boxes2) 141 | labels = np.concatenate((labels1, labels2)) 142 | 143 | if str(mode) == 'train': 144 | # random color jittering 145 | # NOTE: applying color distort may lead to bad performance sometimes 146 | img = random_color_distort(img) 147 | 148 | # random expansion with prob 0.5 149 | if np.random.uniform(0, 1) > 0.5: 150 | img, boxes = random_expand(img, boxes, 4) 151 | 152 | # random cropping 153 | h, w, _ = img.shape 154 | boxes, crop = random_crop_with_constraints(boxes, (w, h)) 155 | x0, y0, w, h = crop 156 | img = img[y0: y0+h, x0: x0+w] 157 | 158 | # resize with random interpolation 159 | h, w, _ = img.shape 160 | interp = np.random.randint(0, 5) 161 | img, boxes = resize_with_bbox(img, boxes, img_size[0], img_size[1], interp=interp, letterbox=letterbox_resize) 162 | 163 | # random horizontal flip 164 | h, w, _ = img.shape 165 | img, boxes = random_flip(img, boxes, px=0.5) 166 | else: 167 | img, boxes = resize_with_bbox(img, boxes, img_size[0], img_size[1], interp=1, letterbox=letterbox_resize) 168 | 169 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB).astype(np.float32) 170 | 171 | # the input of yolo_v3 should be in range 0~1 172 | img = img / 255. 173 | 174 | y_true_13, y_true_26, y_true_52 = process_box(boxes, labels, img_size, class_num, anchors) 175 | 176 | return img_idx, img, y_true_13, y_true_26, y_true_52 177 | 178 | 179 | def get_batch_data(batch_line, class_num, img_size, anchors, mode, multi_scale=False, mix_up=False, letterbox_resize=True, interval=10): 180 | ''' 181 | generate a batch of imgs and labels 182 | param: 183 | batch_line: a batch of lines from train/val.txt files 184 | class_num: num of total classes. 185 | img_size: the image size to be resized to. format: [width, height]. 186 | anchors: anchors. shape: [9, 2]. 187 | mode: 'train' or 'val'. if set to 'train', data augmentation will be applied. 188 | multi_scale: whether to use multi_scale training, img_size varies from [320, 320] to [640, 640] by default. Note that it will take effect only when mode is set to 'train'. 189 | letterbox_resize: whether to use the letterbox resize, i.e., keep the original aspect ratio in the resized image. 190 | interval: change the scale of image every interval batches. Note that it's indeterministic because of the multi threading. 191 | ''' 192 | global iter_cnt 193 | # multi_scale training 194 | if multi_scale and mode == 'train': 195 | random.seed(iter_cnt // interval) 196 | random_img_size = [[x * 32, x * 32] for x in range(10, 20)] 197 | img_size = random.sample(random_img_size, 1)[0] 198 | iter_cnt += 1 199 | 200 | img_idx_batch, img_batch, y_true_13_batch, y_true_26_batch, y_true_52_batch = [], [], [], [], [] 201 | 202 | # mix up strategy 203 | if mix_up and mode == 'train': 204 | mix_lines = [] 205 | batch_line = batch_line.tolist() 206 | for idx, line in enumerate(batch_line): 207 | if np.random.uniform(0, 1) < 0.5: 208 | mix_lines.append([line, random.sample(batch_line[:idx] + batch_line[idx+1:], 1)[0]]) 209 | else: 210 | mix_lines.append(line) 211 | batch_line = mix_lines 212 | 213 | for line in batch_line: 214 | img_idx, img, y_true_13, y_true_26, y_true_52 = parse_data(line, class_num, img_size, anchors, mode, letterbox_resize) 215 | 216 | img_idx_batch.append(img_idx) 217 | img_batch.append(img) 218 | y_true_13_batch.append(y_true_13) 219 | y_true_26_batch.append(y_true_26) 220 | y_true_52_batch.append(y_true_52) 221 | 222 | img_idx_batch, img_batch, y_true_13_batch, y_true_26_batch, y_true_52_batch = np.asarray(img_idx_batch, np.int64), np.asarray(img_batch), np.asarray(y_true_13_batch), np.asarray(y_true_26_batch), np.asarray(y_true_52_batch) 223 | 224 | return img_idx_batch, img_batch, y_true_13_batch, y_true_26_batch, y_true_52_batch 225 | -------------------------------------------------------------------------------- /utils/eval_utils.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from __future__ import division, print_function 4 | 5 | import numpy as np 6 | import cv2 7 | from collections import Counter 8 | 9 | from utils.nms_utils import cpu_nms, gpu_nms 10 | from utils.data_utils import parse_line 11 | 12 | 13 | def calc_iou(pred_boxes, true_boxes): 14 | ''' 15 | Maintain an efficient way to calculate the ios matrix using the numpy broadcast tricks. 16 | shape_info: pred_boxes: [N, 4] 17 | true_boxes: [V, 4] 18 | return: IoU matrix: shape: [N, V] 19 | ''' 20 | 21 | # [N, 1, 4] 22 | pred_boxes = np.expand_dims(pred_boxes, -2) 23 | # [1, V, 4] 24 | true_boxes = np.expand_dims(true_boxes, 0) 25 | 26 | # [N, 1, 2] & [1, V, 2] ==> [N, V, 2] 27 | intersect_mins = np.maximum(pred_boxes[..., :2], true_boxes[..., :2]) 28 | intersect_maxs = np.minimum(pred_boxes[..., 2:], true_boxes[..., 2:]) 29 | intersect_wh = np.maximum(intersect_maxs - intersect_mins, 0.) 30 | 31 | # shape: [N, V] 32 | intersect_area = intersect_wh[..., 0] * intersect_wh[..., 1] 33 | # shape: [N, 1, 2] 34 | pred_box_wh = pred_boxes[..., 2:] - pred_boxes[..., :2] 35 | # shape: [N, 1] 36 | pred_box_area = pred_box_wh[..., 0] * pred_box_wh[..., 1] 37 | # [1, V, 2] 38 | true_boxes_wh = true_boxes[..., 2:] - true_boxes[..., :2] 39 | # [1, V] 40 | true_boxes_area = true_boxes_wh[..., 0] * true_boxes_wh[..., 1] 41 | 42 | # shape: [N, V] 43 | iou = intersect_area / (pred_box_area + true_boxes_area - intersect_area + 1e-10) 44 | 45 | return iou 46 | 47 | 48 | def evaluate_on_cpu(y_pred, y_true, num_classes, calc_now=True, max_boxes=50, score_thresh=0.5, iou_thresh=0.5): 49 | ''' 50 | Given y_pred and y_true of a batch of data, get the recall and precision of the current batch. 51 | ''' 52 | 53 | num_images = y_true[0].shape[0] 54 | true_labels_dict = {i: 0 for i in range(num_classes)} # {class: count} 55 | pred_labels_dict = {i: 0 for i in range(num_classes)} 56 | true_positive_dict = {i: 0 for i in range(num_classes)} 57 | 58 | for i in range(num_images): 59 | true_labels_list, true_boxes_list = [], [] 60 | for j in range(3): # three feature maps 61 | # shape: [13, 13, 3, 80] 62 | true_probs_temp = y_true[j][i][..., 5:-1] 63 | # shape: [13, 13, 3, 4] (x_center, y_center, w, h) 64 | true_boxes_temp = y_true[j][i][..., 0:4] 65 | 66 | # [13, 13, 3] 67 | object_mask = true_probs_temp.sum(axis=-1) > 0 68 | 69 | # [V, 3] V: Ground truth number of the current image 70 | true_probs_temp = true_probs_temp[object_mask] 71 | # [V, 4] 72 | true_boxes_temp = true_boxes_temp[object_mask] 73 | 74 | # [V], labels 75 | true_labels_list += np.argmax(true_probs_temp, axis=-1).tolist() 76 | # [V, 4] (x_center, y_center, w, h) 77 | true_boxes_list += true_boxes_temp.tolist() 78 | 79 | if len(true_labels_list) != 0: 80 | for cls, count in Counter(true_labels_list).items(): 81 | true_labels_dict[cls] += count 82 | 83 | # [V, 4] (xmin, ymin, xmax, ymax) 84 | true_boxes = np.array(true_boxes_list) 85 | box_centers, box_sizes = true_boxes[:, 0:2], true_boxes[:, 2:4] 86 | true_boxes[:, 0:2] = box_centers - box_sizes / 2. 87 | true_boxes[:, 2:4] = true_boxes[:, 0:2] + box_sizes 88 | 89 | # [1, xxx, 4] 90 | pred_boxes = y_pred[0][i:i + 1] 91 | pred_confs = y_pred[1][i:i + 1] 92 | pred_probs = y_pred[2][i:i + 1] 93 | 94 | # pred_boxes: [N, 4] 95 | # pred_confs: [N] 96 | # pred_labels: [N] 97 | # N: Detected box number of the current image 98 | pred_boxes, pred_confs, pred_labels = cpu_nms(pred_boxes, pred_confs * pred_probs, num_classes, 99 | max_boxes=max_boxes, score_thresh=score_thresh, iou_thresh=iou_thresh) 100 | 101 | # len: N 102 | pred_labels_list = [] if pred_labels is None else pred_labels.tolist() 103 | if pred_labels_list == []: 104 | continue 105 | 106 | # calc iou 107 | # [N, V] 108 | iou_matrix = calc_iou(pred_boxes, true_boxes) 109 | # [N] 110 | max_iou_idx = np.argmax(iou_matrix, axis=-1) 111 | 112 | correct_idx = [] 113 | correct_conf = [] 114 | for k in range(max_iou_idx.shape[0]): 115 | pred_labels_dict[pred_labels_list[k]] += 1 116 | match_idx = max_iou_idx[k] # V level 117 | if iou_matrix[k, match_idx] > iou_thresh and true_labels_list[match_idx] == pred_labels_list[k]: 118 | if match_idx not in correct_idx: 119 | correct_idx.append(match_idx) 120 | correct_conf.append(pred_confs[k]) 121 | else: 122 | same_idx = correct_idx.index(match_idx) 123 | if pred_confs[k] > correct_conf[same_idx]: 124 | correct_idx.pop(same_idx) 125 | correct_conf.pop(same_idx) 126 | correct_idx.append(match_idx) 127 | correct_conf.append(pred_confs[k]) 128 | 129 | for t in correct_idx: 130 | true_positive_dict[true_labels_list[t]] += 1 131 | 132 | if calc_now: 133 | # avoid divided by 0 134 | recall = sum(true_positive_dict.values()) / (sum(true_labels_dict.values()) + 1e-6) 135 | precision = sum(true_positive_dict.values()) / (sum(pred_labels_dict.values()) + 1e-6) 136 | 137 | return recall, precision 138 | else: 139 | return true_positive_dict, true_labels_dict, pred_labels_dict 140 | 141 | 142 | def evaluate_on_gpu(sess, gpu_nms_op, pred_boxes_flag, pred_scores_flag, y_pred, y_true, num_classes, iou_thresh=0.5, calc_now=True): 143 | ''' 144 | Given y_pred and y_true of a batch of data, get the recall and precision of the current batch. 145 | This function will perform gpu operation on the GPU. 146 | ''' 147 | 148 | num_images = y_true[0].shape[0] 149 | true_labels_dict = {i: 0 for i in range(num_classes)} # {class: count} 150 | pred_labels_dict = {i: 0 for i in range(num_classes)} 151 | true_positive_dict = {i: 0 for i in range(num_classes)} 152 | 153 | for i in range(num_images): 154 | true_labels_list, true_boxes_list = [], [] 155 | for j in range(3): # three feature maps 156 | # shape: [13, 13, 3, 80] 157 | true_probs_temp = y_true[j][i][..., 5:-1] 158 | # shape: [13, 13, 3, 4] (x_center, y_center, w, h) 159 | true_boxes_temp = y_true[j][i][..., 0:4] 160 | 161 | # [13, 13, 3] 162 | object_mask = true_probs_temp.sum(axis=-1) > 0 163 | 164 | # [V, 80] V: Ground truth number of the current image 165 | true_probs_temp = true_probs_temp[object_mask] 166 | # [V, 4] 167 | true_boxes_temp = true_boxes_temp[object_mask] 168 | 169 | # [V], labels, each from 0 to 79 170 | true_labels_list += np.argmax(true_probs_temp, axis=-1).tolist() 171 | # [V, 4] (x_center, y_center, w, h) 172 | true_boxes_list += true_boxes_temp.tolist() 173 | 174 | if len(true_labels_list) != 0: 175 | for cls, count in Counter(true_labels_list).items(): 176 | true_labels_dict[cls] += count 177 | 178 | # [V, 4] (xmin, ymin, xmax, ymax) 179 | true_boxes = np.array(true_boxes_list) 180 | box_centers, box_sizes = true_boxes[:, 0:2], true_boxes[:, 2:4] 181 | true_boxes[:, 0:2] = box_centers - box_sizes / 2. 182 | true_boxes[:, 2:4] = true_boxes[:, 0:2] + box_sizes 183 | 184 | # [1, xxx, 4] 185 | pred_boxes = y_pred[0][i:i + 1] 186 | pred_confs = y_pred[1][i:i + 1] 187 | pred_probs = y_pred[2][i:i + 1] 188 | 189 | # pred_boxes: [N, 4] 190 | # pred_confs: [N] 191 | # pred_labels: [N] 192 | # N: Detected box number of the current image 193 | pred_boxes, pred_confs, pred_labels = sess.run(gpu_nms_op, 194 | feed_dict={pred_boxes_flag: pred_boxes, 195 | pred_scores_flag: pred_confs * pred_probs}) 196 | # len: N 197 | pred_labels_list = [] if pred_labels is None else pred_labels.tolist() 198 | if pred_labels_list == []: 199 | continue 200 | 201 | # calc iou 202 | # [N, V] 203 | iou_matrix = calc_iou(pred_boxes, true_boxes) 204 | # [N] 205 | max_iou_idx = np.argmax(iou_matrix, axis=-1) 206 | 207 | correct_idx = [] 208 | correct_conf = [] 209 | for k in range(max_iou_idx.shape[0]): 210 | pred_labels_dict[pred_labels_list[k]] += 1 211 | match_idx = max_iou_idx[k] # V level 212 | if iou_matrix[k, match_idx] > iou_thresh and true_labels_list[match_idx] == pred_labels_list[k]: 213 | if match_idx not in correct_idx: 214 | correct_idx.append(match_idx) 215 | correct_conf.append(pred_confs[k]) 216 | else: 217 | same_idx = correct_idx.index(match_idx) 218 | if pred_confs[k] > correct_conf[same_idx]: 219 | correct_idx.pop(same_idx) 220 | correct_conf.pop(same_idx) 221 | correct_idx.append(match_idx) 222 | correct_conf.append(pred_confs[k]) 223 | 224 | for t in correct_idx: 225 | true_positive_dict[true_labels_list[t]] += 1 226 | 227 | if calc_now: 228 | # avoid divided by 0 229 | recall = sum(true_positive_dict.values()) / (sum(true_labels_dict.values()) + 1e-6) 230 | precision = sum(true_positive_dict.values()) / (sum(pred_labels_dict.values()) + 1e-6) 231 | 232 | return recall, precision 233 | else: 234 | return true_positive_dict, true_labels_dict, pred_labels_dict 235 | 236 | 237 | def get_preds_gpu(sess, gpu_nms_op, pred_boxes_flag, pred_scores_flag, image_ids, y_pred): 238 | ''' 239 | Given the y_pred of an input image, get the predicted bbox and label info. 240 | return: 241 | pred_content: 2d list. 242 | ''' 243 | image_id = image_ids[0] 244 | 245 | # keep the first dimension 1 246 | pred_boxes = y_pred[0][0:1] 247 | pred_confs = y_pred[1][0:1] 248 | pred_probs = y_pred[2][0:1] 249 | 250 | boxes, scores, labels = sess.run(gpu_nms_op, 251 | feed_dict={pred_boxes_flag: pred_boxes, 252 | pred_scores_flag: pred_confs * pred_probs}) 253 | 254 | pred_content = [] 255 | for i in range(len(labels)): 256 | x_min, y_min, x_max, y_max = boxes[i] 257 | score = scores[i] 258 | label = labels[i] 259 | pred_content.append([image_id, x_min, y_min, x_max, y_max, score, label]) 260 | 261 | return pred_content 262 | 263 | 264 | gt_dict = {} # key: img_id, value: gt object list 265 | def parse_gt_rec(gt_filename, target_img_size, letterbox_resize=True): 266 | ''' 267 | parse and re-organize the gt info. 268 | return: 269 | gt_dict: dict. Each key is a img_id, the value is the gt bboxes in the corresponding img. 270 | ''' 271 | 272 | global gt_dict 273 | 274 | if not gt_dict: 275 | new_width, new_height = target_img_size 276 | with open(gt_filename, 'r') as f: 277 | for line in f: 278 | img_id, pic_path, boxes, labels, ori_width, ori_height = parse_line(line) 279 | 280 | objects = [] 281 | for i in range(len(labels)): 282 | x_min, y_min, x_max, y_max = boxes[i] 283 | label = labels[i] 284 | 285 | if letterbox_resize: 286 | resize_ratio = min(new_width / ori_width, new_height / ori_height) 287 | 288 | resize_w = int(resize_ratio * ori_width) 289 | resize_h = int(resize_ratio * ori_height) 290 | 291 | dw = int((new_width - resize_w) / 2) 292 | dh = int((new_height - resize_h) / 2) 293 | 294 | objects.append([x_min * resize_ratio + dw, 295 | y_min * resize_ratio + dh, 296 | x_max * resize_ratio + dw, 297 | y_max * resize_ratio + dh, 298 | label]) 299 | else: 300 | objects.append([x_min * new_width / ori_width, 301 | y_min * new_height / ori_height, 302 | x_max * new_width / ori_width, 303 | y_max * new_height / ori_height, 304 | label]) 305 | gt_dict[img_id] = objects 306 | return gt_dict 307 | 308 | 309 | # The following two functions are modified from FAIR's Detectron repo to calculate mAP: 310 | # https://github.com/facebookresearch/Detectron/blob/master/detectron/datasets/voc_eval.py 311 | def voc_ap(rec, prec, use_07_metric=False): 312 | """Compute VOC AP given precision and recall. If use_07_metric is true, uses 313 | the VOC 07 11-point method (default:False). 314 | """ 315 | if use_07_metric: 316 | # 11 point metric 317 | ap = 0. 318 | for t in np.arange(0., 1.1, 0.1): 319 | if np.sum(rec >= t) == 0: 320 | p = 0 321 | else: 322 | p = np.max(prec[rec >= t]) 323 | ap = ap + p / 11. 324 | else: 325 | # correct AP calculation 326 | # first append sentinel values at the end 327 | mrec = np.concatenate(([0.], rec, [1.])) 328 | mpre = np.concatenate(([0.], prec, [0.])) 329 | 330 | # compute the precision envelope 331 | for i in range(mpre.size - 1, 0, -1): 332 | mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i]) 333 | 334 | # to calculate area under PR curve, look for points 335 | # where X axis (recall) changes value 336 | i = np.where(mrec[1:] != mrec[:-1])[0] 337 | 338 | # and sum (\Delta recall) * prec 339 | ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) 340 | return ap 341 | 342 | 343 | def voc_eval(gt_dict, val_preds, classidx, iou_thres=0.5, use_07_metric=False): 344 | ''' 345 | Top level function that does the PASCAL VOC evaluation. 346 | ''' 347 | # 1.obtain gt: extract all gt objects for this class 348 | class_recs = {} 349 | npos = 0 350 | for img_id in gt_dict: 351 | R = [obj for obj in gt_dict[img_id] if obj[-1] == classidx] 352 | bbox = np.array([x[:4] for x in R]) 353 | det = [False] * len(R) 354 | npos += len(R) 355 | class_recs[img_id] = {'bbox': bbox, 'det': det} 356 | 357 | # 2. obtain pred results 358 | pred = [x for x in val_preds if x[-1] == classidx] 359 | img_ids = [x[0] for x in pred] 360 | confidence = np.array([x[-2] for x in pred]) 361 | BB = np.array([[x[1], x[2], x[3], x[4]] for x in pred]) 362 | 363 | # 3. sort by confidence 364 | sorted_ind = np.argsort(-confidence) 365 | try: 366 | BB = BB[sorted_ind, :] 367 | except: 368 | print('no box, ignore') 369 | return 1e-6, 1e-6, 0, 0, 0 370 | img_ids = [img_ids[x] for x in sorted_ind] 371 | 372 | # 4. mark TPs and FPs 373 | nd = len(img_ids) 374 | tp = np.zeros(nd) 375 | fp = np.zeros(nd) 376 | 377 | for d in range(nd): 378 | # all the gt info in some image 379 | R = class_recs[img_ids[d]] 380 | bb = BB[d, :] 381 | ovmax = -np.Inf 382 | BBGT = R['bbox'] 383 | 384 | if BBGT.size > 0: 385 | # calc iou 386 | # intersection 387 | ixmin = np.maximum(BBGT[:, 0], bb[0]) 388 | iymin = np.maximum(BBGT[:, 1], bb[1]) 389 | ixmax = np.minimum(BBGT[:, 2], bb[2]) 390 | iymax = np.minimum(BBGT[:, 3], bb[3]) 391 | iw = np.maximum(ixmax - ixmin + 1., 0.) 392 | ih = np.maximum(iymax - iymin + 1., 0.) 393 | inters = iw * ih 394 | 395 | # union 396 | uni = ((bb[2] - bb[0] + 1.) * (bb[3] - bb[1] + 1.) + (BBGT[:, 2] - BBGT[:, 0] + 1.) * ( 397 | BBGT[:, 3] - BBGT[:, 1] + 1.) - inters) 398 | 399 | overlaps = inters / uni 400 | ovmax = np.max(overlaps) 401 | jmax = np.argmax(overlaps) 402 | 403 | if ovmax > iou_thres: 404 | # gt not matched yet 405 | if not R['det'][jmax]: 406 | tp[d] = 1. 407 | R['det'][jmax] = 1 408 | else: 409 | fp[d] = 1. 410 | else: 411 | fp[d] = 1. 412 | 413 | # compute precision recall 414 | fp = np.cumsum(fp) 415 | tp = np.cumsum(tp) 416 | rec = tp / float(npos) 417 | # avoid divide by zero in case the first detection matches a difficult 418 | # ground truth 419 | prec = tp / np.maximum(tp + fp, np.finfo(np.float64).eps) 420 | ap = voc_ap(rec, prec, use_07_metric) 421 | 422 | # return rec, prec, ap 423 | return npos, nd, tp[-1] / float(npos), tp[-1] / float(nd), ap 424 | -------------------------------------------------------------------------------- /utils/layer_utils.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from __future__ import division, print_function 4 | 5 | import numpy as np 6 | import tensorflow as tf 7 | slim = tf.contrib.slim 8 | 9 | def conv2d(inputs, filters, kernel_size, strides=1): 10 | def _fixed_padding(inputs, kernel_size): 11 | pad_total = kernel_size - 1 12 | pad_beg = pad_total // 2 13 | pad_end = pad_total - pad_beg 14 | 15 | padded_inputs = tf.pad(inputs, [[0, 0], [pad_beg, pad_end], 16 | [pad_beg, pad_end], [0, 0]], mode='CONSTANT') 17 | return padded_inputs 18 | if strides > 1: 19 | inputs = _fixed_padding(inputs, kernel_size) 20 | inputs = slim.conv2d(inputs, filters, kernel_size, stride=strides, 21 | padding=('SAME' if strides == 1 else 'VALID')) 22 | return inputs 23 | 24 | def darknet53_body(inputs): 25 | def res_block(inputs, filters): 26 | shortcut = inputs 27 | net = conv2d(inputs, filters * 1, 1) 28 | net = conv2d(net, filters * 2, 3) 29 | 30 | net = net + shortcut 31 | 32 | return net 33 | 34 | # first two conv2d layers 35 | net = conv2d(inputs, 32, 3, strides=1) 36 | net = conv2d(net, 64, 3, strides=2) 37 | 38 | # res_block * 1 39 | net = res_block(net, 32) 40 | 41 | net = conv2d(net, 128, 3, strides=2) 42 | 43 | # res_block * 2 44 | for i in range(2): 45 | net = res_block(net, 64) 46 | 47 | net = conv2d(net, 256, 3, strides=2) 48 | 49 | # res_block * 8 50 | for i in range(8): 51 | net = res_block(net, 128) 52 | 53 | route_1 = net 54 | net = conv2d(net, 512, 3, strides=2) 55 | 56 | # res_block * 8 57 | for i in range(8): 58 | net = res_block(net, 256) 59 | 60 | route_2 = net 61 | net = conv2d(net, 1024, 3, strides=2) 62 | 63 | # res_block * 4 64 | for i in range(4): 65 | net = res_block(net, 512) 66 | route_3 = net 67 | 68 | return route_1, route_2, route_3 69 | 70 | 71 | def yolo_block(inputs, filters): 72 | net = conv2d(inputs, filters * 1, 1) 73 | net = conv2d(net, filters * 2, 3) 74 | net = conv2d(net, filters * 1, 1) 75 | net = conv2d(net, filters * 2, 3) 76 | net = conv2d(net, filters * 1, 1) 77 | route = net 78 | net = conv2d(net, filters * 2, 3) 79 | return route, net 80 | 81 | 82 | def upsample_layer(inputs, out_shape): 83 | new_height, new_width = out_shape[1], out_shape[2] 84 | # NOTE: here height is the first 85 | # TODO: Do we need to set `align_corners` as True? 86 | inputs = tf.image.resize_nearest_neighbor(inputs, (new_height, new_width), name='upsampled') 87 | return inputs 88 | 89 | 90 | -------------------------------------------------------------------------------- /utils/misc_utils.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | import numpy as np 4 | import tensorflow as tf 5 | import random 6 | 7 | from tensorflow.core.framework import summary_pb2 8 | 9 | 10 | def make_summary(name, val): 11 | return summary_pb2.Summary(value=[summary_pb2.Summary.Value(tag=name, simple_value=val)]) 12 | 13 | 14 | class AverageMeter(object): 15 | def __init__(self): 16 | self.reset() 17 | 18 | def reset(self): 19 | self.val = 0 20 | self.average = 0 21 | self.sum = 0 22 | self.count = 0 23 | 24 | def update(self, val, n=1): 25 | self.val = val 26 | self.sum += val * n 27 | self.count += n 28 | self.average = self.sum / float(self.count) 29 | 30 | 31 | def parse_anchors(anchor_path): 32 | ''' 33 | parse anchors. 34 | returned data: shape [N, 2], dtype float32 35 | ''' 36 | anchors = np.reshape(np.asarray(open(anchor_path, 'r').read().split(','), np.float32), [-1, 2]) 37 | return anchors 38 | 39 | 40 | def read_class_names(class_name_path): 41 | names = {} 42 | with open(class_name_path, 'r') as data: 43 | for ID, name in enumerate(data): 44 | names[ID] = name.strip('\n') 45 | return names 46 | 47 | 48 | def shuffle_and_overwrite(file_name): 49 | content = open(file_name, 'r').readlines() 50 | random.shuffle(content) 51 | with open(file_name, 'w') as f: 52 | for line in content: 53 | f.write(line) 54 | 55 | 56 | def update_dict(ori_dict, new_dict): 57 | if not ori_dict: 58 | return new_dict 59 | for key in ori_dict: 60 | ori_dict[key] += new_dict[key] 61 | return ori_dict 62 | 63 | 64 | def list_add(ori_list, new_list): 65 | for i in range(len(ori_list)): 66 | ori_list[i] += new_list[i] 67 | return ori_list 68 | 69 | 70 | def load_weights(var_list, weights_file): 71 | """ 72 | Loads and converts pre-trained weights. 73 | param: 74 | var_list: list of network variables. 75 | weights_file: name of the binary file. 76 | """ 77 | with open(weights_file, "rb") as fp: 78 | np.fromfile(fp, dtype=np.int32, count=5) 79 | weights = np.fromfile(fp, dtype=np.float32) 80 | 81 | ptr = 0 82 | i = 0 83 | assign_ops = [] 84 | while i < len(var_list) - 1: 85 | var1 = var_list[i] 86 | var2 = var_list[i + 1] 87 | # do something only if we process conv layer 88 | if 'Conv' in var1.name.split('/')[-2]: 89 | # check type of next layer 90 | if 'BatchNorm' in var2.name.split('/')[-2]: 91 | # load batch norm params 92 | gamma, beta, mean, var = var_list[i + 1:i + 5] 93 | batch_norm_vars = [beta, gamma, mean, var] 94 | for var in batch_norm_vars: 95 | shape = var.shape.as_list() 96 | num_params = np.prod(shape) 97 | var_weights = weights[ptr:ptr + num_params].reshape(shape) 98 | ptr += num_params 99 | assign_ops.append(tf.assign(var, var_weights, validate_shape=True)) 100 | # we move the pointer by 4, because we loaded 4 variables 101 | i += 4 102 | elif 'Conv' in var2.name.split('/')[-2]: 103 | # load biases 104 | bias = var2 105 | bias_shape = bias.shape.as_list() 106 | bias_params = np.prod(bias_shape) 107 | bias_weights = weights[ptr:ptr + 108 | bias_params].reshape(bias_shape) 109 | ptr += bias_params 110 | assign_ops.append(tf.assign(bias, bias_weights, validate_shape=True)) 111 | # we loaded 1 variable 112 | i += 1 113 | # we can load weights of conv layer 114 | shape = var1.shape.as_list() 115 | num_params = np.prod(shape) 116 | 117 | var_weights = weights[ptr:ptr + num_params].reshape( 118 | (shape[3], shape[2], shape[0], shape[1])) 119 | # remember to transpose to column-major 120 | var_weights = np.transpose(var_weights, (2, 3, 1, 0)) 121 | ptr += num_params 122 | assign_ops.append( 123 | tf.assign(var1, var_weights, validate_shape=True)) 124 | i += 1 125 | 126 | return assign_ops 127 | 128 | 129 | def config_learning_rate(args, global_step): 130 | if args.lr_type == 'exponential': 131 | lr_tmp = tf.train.exponential_decay(args.learning_rate_init, global_step, args.lr_decay_freq, 132 | args.lr_decay_factor, staircase=True, name='exponential_learning_rate') 133 | return tf.maximum(lr_tmp, args.lr_lower_bound) 134 | elif args.lr_type == 'cosine_decay': 135 | train_steps = (args.total_epoches - float(args.use_warm_up) * args.warm_up_epoch) * args.train_batch_num 136 | return args.lr_lower_bound + 0.5 * (args.learning_rate_init - args.lr_lower_bound) * \ 137 | (1 + tf.cos(global_step / train_steps * np.pi)) 138 | elif args.lr_type == 'cosine_decay_restart': 139 | return tf.train.cosine_decay_restarts(args.learning_rate_init, global_step, 140 | args.lr_decay_freq, t_mul=2.0, m_mul=1.0, 141 | name='cosine_decay_learning_rate_restart') 142 | elif args.lr_type == 'fixed': 143 | return tf.convert_to_tensor(args.learning_rate_init, name='fixed_learning_rate') 144 | elif args.lr_type == 'piecewise': 145 | return tf.train.piecewise_constant(global_step, boundaries=args.pw_boundaries, values=args.pw_values, 146 | name='piecewise_learning_rate') 147 | else: 148 | raise ValueError('Unsupported learning rate type!') 149 | 150 | 151 | def config_optimizer(optimizer_name, learning_rate, decay=0.9, momentum=0.9): 152 | if optimizer_name == 'momentum': 153 | return tf.train.MomentumOptimizer(learning_rate, momentum=momentum) 154 | elif optimizer_name == 'rmsprop': 155 | return tf.train.RMSPropOptimizer(learning_rate, decay=decay, momentum=momentum) 156 | elif optimizer_name == 'adam': 157 | return tf.train.AdamOptimizer(learning_rate) 158 | elif optimizer_name == 'sgd': 159 | return tf.train.GradientDescentOptimizer(learning_rate) 160 | else: 161 | raise ValueError('Unsupported optimizer type!') -------------------------------------------------------------------------------- /utils/nms_utils.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from __future__ import division, print_function 4 | 5 | import numpy as np 6 | import tensorflow as tf 7 | 8 | def gpu_nms(boxes, scores, num_classes, max_boxes=50, score_thresh=0.5, nms_thresh=0.5): 9 | """ 10 | Perform NMS on GPU using TensorFlow. 11 | 12 | params: 13 | boxes: tensor of shape [1, 10647, 4] # 10647=(13*13+26*26+52*52)*3, for input 416*416 image 14 | scores: tensor of shape [1, 10647, num_classes], score=conf*prob 15 | num_classes: total number of classes 16 | max_boxes: integer, maximum number of predicted boxes you'd like, default is 50 17 | score_thresh: if [ highest class probability score < score_threshold] 18 | then get rid of the corresponding box 19 | nms_thresh: real value, "intersection over union" threshold used for NMS filtering 20 | """ 21 | 22 | boxes_list, label_list, score_list = [], [], [] 23 | max_boxes = tf.constant(max_boxes, dtype='int32') 24 | 25 | # since we do nms for single image, then reshape it 26 | boxes = tf.reshape(boxes, [-1, 4]) # '-1' means we don't konw the exact number of boxes 27 | score = tf.reshape(scores, [-1, num_classes]) 28 | 29 | # Step 1: Create a filtering mask based on "box_class_scores" by using "threshold". 30 | mask = tf.greater_equal(score, tf.constant(score_thresh)) 31 | # Step 2: Do non_max_suppression for each class 32 | for i in range(num_classes): 33 | # Step 3: Apply the mask to scores, boxes and pick them out 34 | filter_boxes = tf.boolean_mask(boxes, mask[:,i]) 35 | filter_score = tf.boolean_mask(score[:,i], mask[:,i]) 36 | nms_indices = tf.image.non_max_suppression(boxes=filter_boxes, 37 | scores=filter_score, 38 | max_output_size=max_boxes, 39 | iou_threshold=nms_thresh, name='nms_indices') 40 | label_list.append(tf.ones_like(tf.gather(filter_score, nms_indices), 'int32')*i) 41 | boxes_list.append(tf.gather(filter_boxes, nms_indices)) 42 | score_list.append(tf.gather(filter_score, nms_indices)) 43 | 44 | boxes = tf.concat(boxes_list, axis=0) 45 | score = tf.concat(score_list, axis=0) 46 | label = tf.concat(label_list, axis=0) 47 | 48 | return boxes, score, label 49 | 50 | 51 | def py_nms(boxes, scores, max_boxes=50, iou_thresh=0.5): 52 | """ 53 | Pure Python NMS baseline. 54 | 55 | Arguments: boxes: shape of [-1, 4], the value of '-1' means that dont know the 56 | exact number of boxes 57 | scores: shape of [-1,] 58 | max_boxes: representing the maximum of boxes to be selected by non_max_suppression 59 | iou_thresh: representing iou_threshold for deciding to keep boxes 60 | """ 61 | assert boxes.shape[1] == 4 and len(scores.shape) == 1 62 | 63 | x1 = boxes[:, 0] 64 | y1 = boxes[:, 1] 65 | x2 = boxes[:, 2] 66 | y2 = boxes[:, 3] 67 | 68 | areas = (x2 - x1) * (y2 - y1) 69 | order = scores.argsort()[::-1] 70 | 71 | keep = [] 72 | while order.size > 0: 73 | i = order[0] 74 | keep.append(i) 75 | xx1 = np.maximum(x1[i], x1[order[1:]]) 76 | yy1 = np.maximum(y1[i], y1[order[1:]]) 77 | xx2 = np.minimum(x2[i], x2[order[1:]]) 78 | yy2 = np.minimum(y2[i], y2[order[1:]]) 79 | 80 | w = np.maximum(0.0, xx2 - xx1 + 1) 81 | h = np.maximum(0.0, yy2 - yy1 + 1) 82 | inter = w * h 83 | ovr = inter / (areas[i] + areas[order[1:]] - inter) 84 | 85 | inds = np.where(ovr <= iou_thresh)[0] 86 | order = order[inds + 1] 87 | 88 | return keep[:max_boxes] 89 | 90 | 91 | def cpu_nms(boxes, scores, num_classes, max_boxes=50, score_thresh=0.5, iou_thresh=0.5): 92 | """ 93 | Perform NMS on CPU. 94 | Arguments: 95 | boxes: shape [1, 10647, 4] 96 | scores: shape [1, 10647, num_classes] 97 | """ 98 | 99 | boxes = boxes.reshape(-1, 4) 100 | scores = scores.reshape(-1, num_classes) 101 | # Picked bounding boxes 102 | picked_boxes, picked_score, picked_label = [], [], [] 103 | 104 | for i in range(num_classes): 105 | indices = np.where(scores[:,i] >= score_thresh) 106 | filter_boxes = boxes[indices] 107 | filter_scores = scores[:,i][indices] 108 | if len(filter_boxes) == 0: 109 | continue 110 | # do non_max_suppression on the cpu 111 | indices = py_nms(filter_boxes, filter_scores, 112 | max_boxes=max_boxes, iou_thresh=iou_thresh) 113 | picked_boxes.append(filter_boxes[indices]) 114 | picked_score.append(filter_scores[indices]) 115 | picked_label.append(np.ones(len(indices), dtype='int32')*i) 116 | if len(picked_boxes) == 0: 117 | return None, None, None 118 | 119 | boxes = np.concatenate(picked_boxes, axis=0) 120 | score = np.concatenate(picked_score, axis=0) 121 | label = np.concatenate(picked_label, axis=0) 122 | 123 | return boxes, score, label -------------------------------------------------------------------------------- /utils/plot_utils.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from __future__ import division, print_function 4 | 5 | import cv2 6 | import random 7 | 8 | 9 | def get_color_table(class_num, seed=2): 10 | random.seed(seed) 11 | color_table = {} 12 | for i in range(class_num): 13 | color_table[i] = [random.randint(0, 255) for _ in range(3)] 14 | return color_table 15 | 16 | 17 | def plot_one_box(img, coord, label=None, color=None, line_thickness=None): 18 | ''' 19 | coord: [x_min, y_min, x_max, y_max] format coordinates. 20 | img: img to plot on. 21 | label: str. The label name. 22 | color: int. color index. 23 | line_thickness: int. rectangle line thickness. 24 | ''' 25 | tl = line_thickness or int(round(0.002 * max(img.shape[0:2]))) # line thickness 26 | color = color or [random.randint(0, 255) for _ in range(3)] 27 | c1, c2 = (int(coord[0]), int(coord[1])), (int(coord[2]), int(coord[3])) 28 | cv2.rectangle(img, c1, c2, color, thickness=tl) 29 | if label: 30 | tf = max(tl - 1, 1) # font thickness 31 | t_size = cv2.getTextSize(label, 0, fontScale=float(tl) / 3, thickness=tf)[0] 32 | c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3 33 | cv2.rectangle(img, c1, c2, color, -1) # filled 34 | cv2.putText(img, label, (c1[0], c1[1] - 2), 0, float(tl) / 3, [0, 0, 0], thickness=tf, lineType=cv2.LINE_AA) 35 | 36 | -------------------------------------------------------------------------------- /video_test.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from __future__ import division, print_function 4 | 5 | import tensorflow as tf 6 | import numpy as np 7 | import argparse 8 | import cv2 9 | import time 10 | 11 | from utils.misc_utils import parse_anchors, read_class_names 12 | from utils.nms_utils import gpu_nms 13 | from utils.plot_utils import get_color_table, plot_one_box 14 | from utils.data_aug import letterbox_resize 15 | 16 | from model import yolov3 17 | 18 | parser = argparse.ArgumentParser(description="YOLO-V3 video test procedure.") 19 | parser.add_argument("input_video", type=str, 20 | help="The path of the input video.") 21 | parser.add_argument("--anchor_path", type=str, default="./data/yolo_anchors.txt", 22 | help="The path of the anchor txt file.") 23 | parser.add_argument("--new_size", nargs='*', type=int, default=[416, 416], 24 | help="Resize the input image with `new_size`, size format: [width, height]") 25 | parser.add_argument("--letterbox_resize", type=lambda x: (str(x).lower() == 'true'), default=True, 26 | help="Whether to use the letterbox resize.") 27 | parser.add_argument("--class_name_path", type=str, default="./data/coco.names", 28 | help="The path of the class names.") 29 | parser.add_argument("--restore_path", type=str, default="./data/darknet_weights/yolov3.ckpt", 30 | help="The path of the weights to restore.") 31 | parser.add_argument("--save_video", type=lambda x: (str(x).lower() == 'true'), default=False, 32 | help="Whether to save the video detection results.") 33 | args = parser.parse_args() 34 | 35 | args.anchors = parse_anchors(args.anchor_path) 36 | args.classes = read_class_names(args.class_name_path) 37 | args.num_class = len(args.classes) 38 | 39 | color_table = get_color_table(args.num_class) 40 | 41 | vid = cv2.VideoCapture(args.input_video) 42 | video_frame_cnt = int(vid.get(7)) 43 | video_width = int(vid.get(3)) 44 | video_height = int(vid.get(4)) 45 | video_fps = int(vid.get(5)) 46 | 47 | if args.save_video: 48 | fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v') 49 | videoWriter = cv2.VideoWriter('video_result.mp4', fourcc, video_fps, (video_width, video_height)) 50 | 51 | with tf.Session() as sess: 52 | input_data = tf.placeholder(tf.float32, [1, args.new_size[1], args.new_size[0], 3], name='input_data') 53 | yolo_model = yolov3(args.num_class, args.anchors) 54 | with tf.variable_scope('yolov3'): 55 | pred_feature_maps = yolo_model.forward(input_data, False) 56 | pred_boxes, pred_confs, pred_probs = yolo_model.predict(pred_feature_maps) 57 | 58 | pred_scores = pred_confs * pred_probs 59 | 60 | boxes, scores, labels = gpu_nms(pred_boxes, pred_scores, args.num_class, max_boxes=200, score_thresh=0.3, nms_thresh=0.45) 61 | 62 | saver = tf.train.Saver() 63 | saver.restore(sess, args.restore_path) 64 | 65 | for i in range(video_frame_cnt): 66 | ret, img_ori = vid.read() 67 | if args.letterbox_resize: 68 | img, resize_ratio, dw, dh = letterbox_resize(img_ori, args.new_size[0], args.new_size[1]) 69 | else: 70 | height_ori, width_ori = img_ori.shape[:2] 71 | img = cv2.resize(img_ori, tuple(args.new_size)) 72 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 73 | img = np.asarray(img, np.float32) 74 | img = img[np.newaxis, :] / 255. 75 | 76 | start_time = time.time() 77 | boxes_, scores_, labels_ = sess.run([boxes, scores, labels], feed_dict={input_data: img}) 78 | end_time = time.time() 79 | 80 | # rescale the coordinates to the original image 81 | if args.letterbox_resize: 82 | boxes_[:, [0, 2]] = (boxes_[:, [0, 2]] - dw) / resize_ratio 83 | boxes_[:, [1, 3]] = (boxes_[:, [1, 3]] - dh) / resize_ratio 84 | else: 85 | boxes_[:, [0, 2]] *= (width_ori/float(args.new_size[0])) 86 | boxes_[:, [1, 3]] *= (height_ori/float(args.new_size[1])) 87 | 88 | 89 | for i in range(len(boxes_)): 90 | x0, y0, x1, y1 = boxes_[i] 91 | plot_one_box(img_ori, [x0, y0, x1, y1], label=args.classes[labels_[i]] + ', {:.2f}%'.format(scores_[i] * 100), color=color_table[labels_[i]]) 92 | cv2.putText(img_ori, '{:.2f}ms'.format((end_time - start_time) * 1000), (40, 40), 0, 93 | fontScale=1, color=(0, 255, 0), thickness=2) 94 | cv2.imshow('image', img_ori) 95 | if args.save_video: 96 | videoWriter.write(img_ori) 97 | if cv2.waitKey(1) & 0xFF == ord('q'): 98 | break 99 | 100 | vid.release() 101 | if args.save_video: 102 | videoWriter.release() 103 | --------------------------------------------------------------------------------