├── .gitignore ├── LICENSE ├── README.md ├── cache.py ├── config.ini ├── config ├── cache │ ├── coco.tsv │ └── voc.tsv ├── names │ ├── 20 │ └── 80 ├── yolo │ ├── darknet-20.ini │ ├── darknet-80.ini │ ├── tiny-20.ini │ └── tiny-80.ini └── yolo2 │ ├── anchors │ ├── coco.tsv │ └── voc.tsv │ ├── darknet-20.ini │ ├── darknet-80.ini │ ├── tiny-20.ini │ └── tiny-80.ini ├── demo_data_augmentation.py ├── demo_detect.py ├── detect.py ├── detect_camera.py ├── model ├── __init__.py ├── yolo │ ├── __init__.py │ ├── function.py │ └── inference.py └── yolo2 │ ├── __init__.py │ ├── function.py │ └── inference.py ├── parse_darknet_yolo2.py ├── train.py └── utils ├── __init__.py ├── data ├── __init__.py ├── cache.py └── voc.py ├── postprocess.py ├── preprocess.py ├── verify.py └── visualize.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | __pycache__ 3 | .project 4 | .pydevproject 5 | .settings/ 6 | .idea/ 7 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | GNU LESSER GENERAL PUBLIC LICENSE 2 | Version 3, 29 June 2007 3 | 4 | Copyright (C) 2007 Free Software Foundation, Inc. 5 | Everyone is permitted to copy and distribute verbatim copies 6 | of this license document, but changing it is not allowed. 7 | 8 | 9 | This version of the GNU Lesser General Public License incorporates 10 | the terms and conditions of version 3 of the GNU General Public 11 | License, supplemented by the additional permissions listed below. 12 | 13 | 0. Additional Definitions. 14 | 15 | As used herein, "this License" refers to version 3 of the GNU Lesser 16 | General Public License, and the "GNU GPL" refers to version 3 of the GNU 17 | General Public License. 18 | 19 | "The Library" refers to a covered work governed by this License, 20 | other than an Application or a Combined Work as defined below. 
21 | 22 | An "Application" is any work that makes use of an interface provided 23 | by the Library, but which is not otherwise based on the Library. 24 | Defining a subclass of a class defined by the Library is deemed a mode 25 | of using an interface provided by the Library. 26 | 27 | A "Combined Work" is a work produced by combining or linking an 28 | Application with the Library. The particular version of the Library 29 | with which the Combined Work was made is also called the "Linked 30 | Version". 31 | 32 | The "Minimal Corresponding Source" for a Combined Work means the 33 | Corresponding Source for the Combined Work, excluding any source code 34 | for portions of the Combined Work that, considered in isolation, are 35 | based on the Application, and not on the Linked Version. 36 | 37 | The "Corresponding Application Code" for a Combined Work means the 38 | object code and/or source code for the Application, including any data 39 | and utility programs needed for reproducing the Combined Work from the 40 | Application, but excluding the System Libraries of the Combined Work. 41 | 42 | 1. Exception to Section 3 of the GNU GPL. 43 | 44 | You may convey a covered work under sections 3 and 4 of this License 45 | without being bound by section 3 of the GNU GPL. 46 | 47 | 2. Conveying Modified Versions. 
48 | 49 | If you modify a copy of the Library, and, in your modifications, a 50 | facility refers to a function or data to be supplied by an Application 51 | that uses the facility (other than as an argument passed when the 52 | facility is invoked), then you may convey a copy of the modified 53 | version: 54 | 55 | a) under this License, provided that you make a good faith effort to 56 | ensure that, in the event an Application does not supply the 57 | function or data, the facility still operates, and performs 58 | whatever part of its purpose remains meaningful, or 59 | 60 | b) under the GNU GPL, with none of the additional permissions of 61 | this License applicable to that copy. 62 | 63 | 3. Object Code Incorporating Material from Library Header Files. 64 | 65 | The object code form of an Application may incorporate material from 66 | a header file that is part of the Library. You may convey such object 67 | code under terms of your choice, provided that, if the incorporated 68 | material is not limited to numerical parameters, data structure 69 | layouts and accessors, or small macros, inline functions and templates 70 | (ten or fewer lines in length), you do both of the following: 71 | 72 | a) Give prominent notice with each copy of the object code that the 73 | Library is used in it and that the Library and its use are 74 | covered by this License. 75 | 76 | b) Accompany the object code with a copy of the GNU GPL and this license 77 | document. 78 | 79 | 4. Combined Works. 80 | 81 | You may convey a Combined Work under terms of your choice that, 82 | taken together, effectively do not restrict modification of the 83 | portions of the Library contained in the Combined Work and reverse 84 | engineering for debugging such modifications, if you also do each of 85 | the following: 86 | 87 | a) Give prominent notice with each copy of the Combined Work that 88 | the Library is used in it and that the Library and its use are 89 | covered by this License. 
90 | 91 | b) Accompany the Combined Work with a copy of the GNU GPL and this license 92 | document. 93 | 94 | c) For a Combined Work that displays copyright notices during 95 | execution, include the copyright notice for the Library among 96 | these notices, as well as a reference directing the user to the 97 | copies of the GNU GPL and this license document. 98 | 99 | d) Do one of the following: 100 | 101 | 0) Convey the Minimal Corresponding Source under the terms of this 102 | License, and the Corresponding Application Code in a form 103 | suitable for, and under terms that permit, the user to 104 | recombine or relink the Application with a modified version of 105 | the Linked Version to produce a modified Combined Work, in the 106 | manner specified by section 6 of the GNU GPL for conveying 107 | Corresponding Source. 108 | 109 | 1) Use a suitable shared library mechanism for linking with the 110 | Library. A suitable mechanism is one that (a) uses at run time 111 | a copy of the Library already present on the user's computer 112 | system, and (b) will operate properly with a modified version 113 | of the Library that is interface-compatible with the Linked 114 | Version. 115 | 116 | e) Provide Installation Information, but only if you would otherwise 117 | be required to provide such information under section 6 of the 118 | GNU GPL, and only to the extent that such information is 119 | necessary to install and execute a modified version of the 120 | Combined Work produced by recombining or relinking the 121 | Application with a modified version of the Linked Version. (If 122 | you use option 4d0, the Installation Information must accompany 123 | the Minimal Corresponding Source and Corresponding Application 124 | Code. If you use option 4d1, you must provide the Installation 125 | Information in the manner specified by section 6 of the GNU GPL 126 | for conveying Corresponding Source.) 127 | 128 | 5. Combined Libraries. 
129 | 130 | You may place library facilities that are a work based on the 131 | Library side by side in a single library together with other library 132 | facilities that are not Applications and are not covered by this 133 | License, and convey such a combined library under terms of your 134 | choice, if you do both of the following: 135 | 136 | a) Accompany the combined library with a copy of the same work based 137 | on the Library, uncombined with any other library facilities, 138 | conveyed under the terms of this License. 139 | 140 | b) Give prominent notice with the combined library that part of it 141 | is a work based on the Library, and explaining where to find the 142 | accompanying uncombined form of the same work. 143 | 144 | 6. Revised Versions of the GNU Lesser General Public License. 145 | 146 | The Free Software Foundation may publish revised and/or new versions 147 | of the GNU Lesser General Public License from time to time. Such new 148 | versions will be similar in spirit to the present version, but may 149 | differ in detail to address new problems or concerns. 150 | 151 | Each version is given a distinguishing version number. If the 152 | Library as you received it specifies that a certain numbered version 153 | of the GNU Lesser General Public License "or any later version" 154 | applies to it, you have the option of following the terms and 155 | conditions either of that published version or of any later version 156 | published by the Free Software Foundation. If the Library as you 157 | received it does not specify a version number of the GNU Lesser 158 | General Public License, you may choose any version of the GNU Lesser 159 | General Public License ever published by the Free Software Foundation. 
160 | 161 | If the Library as you received it specifies that a proxy can decide 162 | whether future versions of the GNU Lesser General Public License shall 163 | apply, that proxy's public statement of acceptance of any version is 164 | permanent authorization for you to choose that version for the 165 | Library. 166 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | This project is deprecated. Please see [yolo2-pytorch](https://github.com/ruiminshen/yolo2-pytorch) 2 | 3 | # TensorFlow implementation of the [YOLO (You Only Look Once)](https://arxiv.org/pdf/1506.02640.pdf) and [YOLOv2](https://arxiv.org/pdf/1612.08242.pdf) 4 | 5 | ## Dependencies 6 | 7 | * [Python 3](https://www.python.org/) 8 | * [TensorFlow 1.0](https://www.tensorflow.org/) 9 | * [NumPy](https://www.numpy.org/) 10 | * [SciPy](https://www.scipy.org/) 11 | * [Pandas](https://pandas.pydata.org/) 12 | * [Matplotlib](https://matplotlib.org/) 13 | * [BeautifulSoup4](https://www.crummy.com/software/BeautifulSoup/) 14 | * [OpenCV](https://github.com/opencv/opencv) 15 | * [PIL](http://www.pythonware.com/products/pil/) 16 | * [tqdm](https://github.com/tqdm/tqdm) 17 | * [COCO](https://github.com/pdollar/coco) (optional) 18 | 19 | ## Configuration 20 | 21 | Configurations are mainly defined in the "config.ini" file, including the detection model (config/model), the base directory (config/basedir, which identifies the cache files (.tfrecord), the model data files (.ckpt), and the summary data for TensorBoard), and the inference function ([model]/inference). *Notably, the configurations can be extended using the "-c" command-line argument* (e.g. `-c config.ini config/yolo2/darknet-20.ini`).
22 | 23 | ## Basic Usage 24 | 25 | - Download the [PASCAL VOC](http://host.robots.ox.ac.uk/pascal/VOC/) 2007 ([training, validation](http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar) and [test](http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar)) and 2012 ([training and validation](http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar)) datasets. Extract these tar files into one directory (such as "~/Documents/Database/"). 26 | 27 | - Download the [COCO](http://mscoco.org/) 2014 ([training](http://msvocds.blob.core.windows.net/coco2014/train2014.zip), [validation](http://msvocds.blob.core.windows.net/coco2014/val2014.zip), and [test](http://msvocds.blob.core.windows.net/coco2014/test2014.zip)) dataset. Extract these zip files into one directory (such as "~/Documents/Database/coco/"). 28 | 29 | - Run "cache.py" to create the cache files for the training program. **The verify command-line argument "-v" is strongly recommended: it checks the training data and drops corrupted examples**, such as the image "COCO_val2014_000000320612.jpg" in the COCO dataset. 30 | 31 | - Run "train.py" to start the training process (previously saved model data will be loaded if it exists). Several command-line arguments control the training process, such as the batch size, the learning rate, the optimization algorithm and the maximum number of steps. 32 | 33 | - Run "detect.py" to detect objects in an image. Run "export CUDA_VISIBLE_DEVICES=" first to avoid out-of-GPU-memory errors while the training process is running. 34 | 35 | ## Examples 36 | 37 | ### Training a 20-class Darknet YOLOv2 model from a pretrained 80-class model 38 | 39 | - Cache the 20-class data using the customized config file argument. Cache files (.tfrecord) will be created in "~/Documents/Database/yolo-tf/cache/20".
40 | 41 | ``` 42 | python3 cache.py -c config.ini config/yolo2/darknet-20.ini -v 43 | ``` 44 | 45 | - Download an 80-class Darknet YOLOv2 model (the original file name is "yolo.weights"; a [version](https://drive.google.com/drive/folders/0B1tW_VtY7onidEwyQ2FtQVplWEU) from Darkflow is recommended). In this tutorial I put it in "~/Downloads/yolo.weights". 46 | 47 | - Parse the 80-class Darknet YOLOv2 model into TensorFlow format (~/Documents/Database/yolo-tf/yolo2/darknet/80/model.ckpt). A warning like "xxx bytes remaining" indicates that the file "yolo.weights" is not compatible with the original Darknet YOLOv2 model (defined in the function `model.yolo2.inference.darknet`). **Make sure the 80-class data is cached before parsing**. 48 | 49 | ``` 50 | python3 parse_darknet_yolo2.py ~/Downloads/yolo.weights -c config.ini config/yolo2/darknet-80.ini -d 51 | ``` 52 | 53 | - Transfer the 80-class Darknet YOLOv2 model into a 20-class model (~/Documents/Database/yolo-tf/yolo2/darknet/20), excluding the final convolutional layer. **Beware: the "-d" command-line argument deletes the model files and should be used only once, when initializing the model**. 54 | 55 | ``` 56 | python3 train.py -c config.ini config/yolo2/darknet-20.ini -t ~/Documents/Database/yolo-tf/yolo2/darknet/80/model.ckpt -e yolo2_darknet/conv -d 57 | ``` 58 | 59 | - Use the following command in another terminal and open the address "localhost:6006" in a web browser to monitor the training process. 60 | 61 | ``` 62 | tensorboard --logdir ~/Documents/Database/yolo-tf/yolo2/darknet/20 63 | ``` 64 | 65 | - Once you think your model has stabilized, press Ctrl+C to cancel and restart the training with a larger batch size. 66 | 67 | ``` 68 | python3 train.py -c config.ini config/yolo2/darknet-20.ini -b 16 69 | ``` 70 | 71 | - Detect objects in an image file. 72 | 73 | ``` 74 | python3 detect.py $IMAGE_FILE -c config.ini config/yolo2/darknet-20.ini 75 | ``` 76 | 77 | - Detect objects with a camera.
78 | 79 | ``` 80 | python3 detect_camera.py -c config.ini config/yolo2/darknet-20.ini 81 | ``` 82 | 83 | ## Checklist 84 | 85 | - [x] Batch normalization 86 | - [x] Passthrough layer 87 | - [ ] Multi-scale training 88 | - [ ] Dimension cluster 89 | - [x] Extendable configuration (via "-c" command-line argument) 90 | - [x] PASCAL VOC dataset supporting 91 | - [x] MS COCO dataset supporting 92 | - [x] Data augmentation: random crop 93 | - [x] Data augmentation: random flip horizontally 94 | - [x] Multi-thread data batch queue 95 | - [x] Darknet model file (.weights) parser 96 | - [x] Partial model transferring before training 97 | - [x] Detection from image 98 | - [x] Detection from camera 99 | - [ ] Multi-GPU supporting 100 | - [ ] Faster NMS using C/C++ or GPU 101 | - [ ] Performance evaluation 102 | 103 | ## License 104 | 105 | This project is released as the open source software with the GNU Lesser General Public License version 3 ([LGPL v3](http://www.gnu.org/licenses/lgpl-3.0.html)). 106 | 107 | # Acknowledgements 108 | 109 | This project is mainly inspired by the following projects: 110 | 111 | * [YOLO (Darknet)](https://pjreddie.com/darknet/yolo/). 112 | * [Darkflow](https://github.com/thtrieu/darkflow). 113 | -------------------------------------------------------------------------------- /cache.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen) 3 | 4 | This program is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU Lesser General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 8 | 9 | This program is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 
13 | 14 | You should have received a copy of the GNU General Public License 15 | along with this program. If not, see . 16 | """ 17 | 18 | import os 19 | import argparse 20 | import configparser 21 | import shutil 22 | import importlib 23 | import pandas as pd 24 | import tensorflow as tf 25 | import utils 26 | 27 | 28 | def main(): 29 | cachedir = utils.get_cachedir(config) 30 | os.makedirs(cachedir, exist_ok=True) 31 | path = os.path.join(cachedir, 'names') 32 | shutil.copyfile(os.path.expanduser(os.path.expandvars(config.get('cache', 'names'))), path) 33 | with open(path, 'r') as f: 34 | names = [line.strip() for line in f] 35 | name_index = dict([(name, i) for i, name in enumerate(names)]) 36 | datasets = [(os.path.basename(os.path.splitext(path)[0]), pd.read_csv(os.path.expanduser(os.path.expandvars(path)), sep='\t')) for path in config.get('cache', 'datasets').split(':')] 37 | module = importlib.import_module('utils.data.cache') 38 | for profile in args.profile: 39 | path = os.path.join(cachedir, profile + '.tfrecord') 40 | tf.logging.info('write tfrecords file: ' + path) 41 | with tf.python_io.TFRecordWriter(path) as writer: 42 | for name, dataset in datasets: 43 | tf.logging.info('loading %s %s dataset' % (name, profile)) 44 | func = getattr(module, name) 45 | for i, row in dataset.iterrows(): 46 | tf.logging.info('loading data %d (%s)' % (i, ', '.join([k + '=' + str(v) for k, v in row.items()]))) 47 | func(writer, name_index, profile, row, args.verify) 48 | tf.logging.info('%s data are saved into %s' % (str(args.profile), cachedir)) 49 | 50 | 51 | def make_args(): 52 | parser = argparse.ArgumentParser() 53 | parser.add_argument('-c', '--config', nargs='+', default=['config.ini'], help='config file') 54 | parser.add_argument('-p', '--profile', nargs='+', default=['train', 'val', 'test']) 55 | parser.add_argument('-v', '--verify', action='store_true') 56 | parser.add_argument('--level', default='info', help='logging level') 57 | return parser.parse_args() 58 
| 59 | if __name__ == '__main__': 60 | args = make_args() 61 | config = configparser.ConfigParser() 62 | utils.load_config(config, args.config) 63 | if args.level: 64 | tf.logging.set_verbosity(args.level.upper()) 65 | with tf.Session() as sess: 66 | main() 67 | -------------------------------------------------------------------------------- /config.ini: -------------------------------------------------------------------------------- 1 | [config] 2 | ; yolo yolo2 3 | model = yolo2 4 | basedir = ~/Documents/Database/yolo-tf 5 | 6 | [queue] 7 | capacity = 320 8 | min_after_dequeue=160 9 | 10 | [cache] 11 | names = config/names/80 12 | datasets = config/cache/coco.tsv:config/cache/voc.tsv 13 | 14 | [data_augmentation_full] 15 | enable = 1 16 | enable_probability = 0.5 17 | random_crop = 0.9 18 | 19 | [data_augmentation_resized] 20 | enable = 1 21 | enable_probability = 0.5 22 | random_flip_horizontally = 1 23 | random_brightness = 1 24 | random_contrast = 1 25 | random_saturation = 1 26 | random_hue = 1 27 | noise = 1 28 | grayscale_probability = 0.05 29 | 30 | [exponential_decay] 31 | decay_steps = 100000 32 | decay_rate = 0.96 33 | staircase = 1 34 | 35 | [optimizer_adam] 36 | beta1 = 0.9 37 | beta2 = 0.999 38 | epsilon = 1e-8 39 | 40 | [optimizer_adadelta] 41 | rho = 0.95 42 | epsilon = 1e-8 43 | 44 | [optimizer_adagrad] 45 | initial_accumulator_value = 0.1 46 | 47 | [optimizer_momentum] 48 | momentum = 0.9 49 | 50 | [optimizer_rmsprop] 51 | decay = 0.9 52 | momentum = 0 53 | epsilon = 1e-10 54 | 55 | [optimizer_ftrl] 56 | learning_rate_power = -0.5 57 | initial_accumulator_value = 0.1 58 | l1_regularization_strength = 0 59 | l2_regularization_strength = 0 60 | 61 | [summary] 62 | ; (total_loss\/objectives\/(iou_best|iou_normal|coords|prob)|total_loss)$ 63 | scalar = (total_loss\/objectives\/(iou_best|iou_normal|coords|prob)|total_loss)$ 64 | scalar_reduce = tf.reduce_mean 65 | 66 | ; [_\w\d]+\/(input|conv\d*\/(convolution|leaky_relu\/data))$ 67 | ; 
[_\w\d]+\/(passthrough|reorg)$ 68 | image_ = [_\w\d]+\/(input|conv\d*\/(convolution|leaky_relu\/data))$ 69 | image_max = 1 70 | 71 | ; [_\w\d]+\/(conv|fc)\d*\/(weights|biases)$ 72 | ; [_\w\d]+\/(conv|fc)\d*\/BatchNorm\/(gamma|beta)$ 73 | ; [_\w\d]+\/(conv|fc)\d*\/BatchNorm\/moments\/normalize\/(mean|variance)$ 74 | ; [_\w\d]+\/(conv|fc)\d*\/BatchNorm\/(moving_mean|moving_variance)$ 75 | ; [_\w\d]+\/(conv|fc)\d*\/(convolution|leaky_relu\/data)$ 76 | ; [_\w\d]+\/(input|conv0\/convolution)$ 77 | histogram_ = [_\w\d]+\/(input|conv0\/convolution)$ 78 | gradients = 0 79 | 80 | [yolo] 81 | inference = tiny 82 | width = 448 83 | height = 448 84 | boxes_per_cell = 2 85 | 86 | [yolo_hparam] 87 | prob = 1 88 | iou_best = 1 89 | iou_normal = .5 90 | coords = 5 91 | 92 | [yolo2] 93 | inference = darknet 94 | width = 416 95 | height = 416 96 | anchors = config/yolo2/anchors/coco.tsv 97 | 98 | [yolo2_hparam] 99 | prob = 1 100 | iou_best = 5 101 | iou_normal = 1 102 | coords = 1 103 | -------------------------------------------------------------------------------- /config/cache/coco.tsv: -------------------------------------------------------------------------------- 1 | root year 2 | ~/Documents/Database/coco 2014 3 | -------------------------------------------------------------------------------- /config/cache/voc.tsv: -------------------------------------------------------------------------------- 1 | root 2 | ~/Documents/Database/VOCdevkit/VOC2007 3 | ~/Documents/Database/VOCdevkit/VOC2012 4 | -------------------------------------------------------------------------------- /config/names/20: -------------------------------------------------------------------------------- 1 | aeroplane 2 | bicycle 3 | bird 4 | boat 5 | bottle 6 | bus 7 | car 8 | cat 9 | chair 10 | cow 11 | diningtable 12 | dog 13 | horse 14 | motorbike 15 | person 16 | pottedplant 17 | sheep 18 | sofa 19 | train 20 | tvmonitor -------------------------------------------------------------------------------- 
/config/names/80: -------------------------------------------------------------------------------- 1 | person 2 | bicycle 3 | car 4 | motorbike 5 | aeroplane 6 | bus 7 | train 8 | truck 9 | boat 10 | traffic light 11 | fire hydrant 12 | stop sign 13 | parking meter 14 | bench 15 | bird 16 | cat 17 | dog 18 | horse 19 | sheep 20 | cow 21 | elephant 22 | bear 23 | zebra 24 | giraffe 25 | backpack 26 | umbrella 27 | handbag 28 | tie 29 | suitcase 30 | frisbee 31 | skis 32 | snowboard 33 | sports ball 34 | kite 35 | baseball bat 36 | baseball glove 37 | skateboard 38 | surfboard 39 | tennis racket 40 | bottle 41 | wine glass 42 | cup 43 | fork 44 | knife 45 | spoon 46 | bowl 47 | banana 48 | apple 49 | sandwich 50 | orange 51 | broccoli 52 | carrot 53 | hot dog 54 | pizza 55 | donut 56 | cake 57 | chair 58 | sofa 59 | pottedplant 60 | bed 61 | diningtable 62 | toilet 63 | tvmonitor 64 | laptop 65 | mouse 66 | remote 67 | keyboard 68 | cell phone 69 | microwave 70 | oven 71 | toaster 72 | sink 73 | refrigerator 74 | book 75 | clock 76 | vase 77 | scissors 78 | teddy bear 79 | hair drier 80 | toothbrush -------------------------------------------------------------------------------- /config/yolo/darknet-20.ini: -------------------------------------------------------------------------------- 1 | [config] 2 | model = yolo 3 | 4 | [cache] 5 | names = config/names/20 6 | datasets = config/cache/voc.tsv 7 | 8 | [yolo] 9 | inference = darknet 10 | width = 448 11 | height = 448 12 | boxes_per_cell = 2 13 | hparam = 5e-4 14 | 15 | [yolo_hparam] 16 | prob = 1 17 | iou_best = 1 18 | iou_normal = .5 19 | coords = 5 20 | -------------------------------------------------------------------------------- /config/yolo/darknet-80.ini: -------------------------------------------------------------------------------- 1 | [config] 2 | model = yolo 3 | 4 | [cache] 5 | names = config/names/80 6 | datasets = config/cache/coco.tsv 7 | 8 | [yolo] 9 | inference = darknet 10 | width = 448 11 | 
height = 448 12 | boxes_per_cell = 2 13 | hparam = 5e-4 14 | 15 | [yolo_hparam] 16 | prob = 1 17 | iou_best = 1 18 | iou_normal = .5 19 | coords = 5 20 | -------------------------------------------------------------------------------- /config/yolo/tiny-20.ini: -------------------------------------------------------------------------------- 1 | [config] 2 | model = yolo 3 | 4 | [cache] 5 | names = config/names/20 6 | datasets = config/cache/voc.tsv 7 | 8 | [yolo] 9 | inference = tiny 10 | width = 448 11 | height = 448 12 | boxes_per_cell = 2 13 | hparam = 5e-4 14 | 15 | [yolo_hparam] 16 | prob = 1 17 | iou_best = 1 18 | iou_normal = .5 19 | coords = 5 20 | -------------------------------------------------------------------------------- /config/yolo/tiny-80.ini: -------------------------------------------------------------------------------- 1 | [config] 2 | model = yolo 3 | 4 | [cache] 5 | names = config/names/80 6 | datasets = config/cache/coco.tsv 7 | 8 | [yolo] 9 | inference = tiny 10 | width = 448 11 | height = 448 12 | boxes_per_cell = 2 13 | hparam = 5e-4 14 | 15 | [yolo_hparam] 16 | prob = 1 17 | iou_best = 1 18 | iou_normal = .5 19 | coords = 5 20 | -------------------------------------------------------------------------------- /config/yolo2/anchors/coco.tsv: -------------------------------------------------------------------------------- 1 | w h 2 | 0.738768 0.874946 3 | 2.42204 2.65704 4 | 4.30971 7.04493 5 | 10.246 4.59428 6 | 12.6868 11.8741 7 | -------------------------------------------------------------------------------- /config/yolo2/anchors/voc.tsv: -------------------------------------------------------------------------------- 1 | w h 2 | 1.08 1.19 3 | 3.42 4.41 4 | 6.63 11.38 5 | 9.42 5.11 6 | 16.62 10.52 7 | -------------------------------------------------------------------------------- /config/yolo2/darknet-20.ini: -------------------------------------------------------------------------------- 1 | [config] 2 | model = yolo2 3 | 4 | [cache] 
5 | names = config/names/20 6 | 7 | [yolo2] 8 | inference = darknet 9 | width = 416 10 | height = 416 11 | anchors = config/yolo2/anchors/voc.tsv 12 | 13 | [yolo2_hparam] 14 | prob = 1 15 | iou_best = 5 16 | iou_normal = 1 17 | coords = 1 18 | -------------------------------------------------------------------------------- /config/yolo2/darknet-80.ini: -------------------------------------------------------------------------------- 1 | [config] 2 | model = yolo2 3 | 4 | [cache] 5 | names = config/names/80 6 | 7 | [yolo2] 8 | inference = darknet 9 | width = 416 10 | height = 416 11 | anchors = config/yolo2/anchors/coco.tsv 12 | 13 | [yolo2_hparam] 14 | prob = 1 15 | iou_best = 5 16 | iou_normal = 1 17 | coords = 1 18 | -------------------------------------------------------------------------------- /config/yolo2/tiny-20.ini: -------------------------------------------------------------------------------- 1 | [config] 2 | model = yolo2 3 | 4 | [cache] 5 | names = config/names/20 6 | 7 | [yolo2] 8 | inference = tiny 9 | width = 416 10 | height = 416 11 | anchors = config/yolo2/anchors/voc.tsv 12 | 13 | [yolo2_hparam] 14 | prob = 1 15 | iou_best = 5 16 | iou_normal = 1 17 | coords = 1 18 | -------------------------------------------------------------------------------- /config/yolo2/tiny-80.ini: -------------------------------------------------------------------------------- 1 | [config] 2 | model = yolo2 3 | 4 | [cache] 5 | names = config/names/80 6 | 7 | [yolo2] 8 | inference = tiny 9 | width = 416 10 | height = 416 11 | anchors = config/yolo2/anchors/coco.tsv 12 | 13 | [yolo2_hparam] 14 | prob = 1 15 | iou_best = 5 16 | iou_normal = 1 17 | coords = 1 18 | -------------------------------------------------------------------------------- /demo_data_augmentation.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen) 3 | 4 | This program is free software: you can redistribute it and/or modify 
5 | it under the terms of the GNU Lesser General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 8 | 9 | This program is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with this program. If not, see . 16 | """ 17 | 18 | import os 19 | import argparse 20 | import configparser 21 | import multiprocessing 22 | import numpy as np 23 | import matplotlib.pyplot as plt 24 | import tensorflow as tf 25 | import utils.data 26 | import utils.visualize 27 | 28 | 29 | def main(): 30 | model = config.get('config', 'model') 31 | cachedir = utils.get_cachedir(config) 32 | with open(os.path.join(cachedir, 'names'), 'r') as f: 33 | names = [line.strip() for line in f] 34 | width = config.getint(model, 'width') 35 | height = config.getint(model, 'height') 36 | cell_width, cell_height = utils.calc_cell_width_height(config, width, height) 37 | tf.logging.info('(width, height)=(%d, %d), (cell_width, cell_height)=(%d, %d)' % (width, height, cell_width, cell_height)) 38 | batch_size = args.rows * args.cols 39 | paths = [os.path.join(cachedir, profile + '.tfrecord') for profile in args.profile] 40 | num_examples = sum(sum(1 for _ in tf.python_io.tf_record_iterator(path)) for path in paths) 41 | tf.logging.warn('num_examples=%d' % num_examples) 42 | with tf.Session() as sess: 43 | with tf.name_scope('batch'): 44 | image_rgb, labels = utils.data.load_image_labels(paths, len(names), width, height, cell_width, cell_height, config) 45 | batch = tf.train.shuffle_batch((tf.cast(image_rgb, tf.uint8),) + labels, batch_size=batch_size, 46 | capacity=config.getint('queue', 'capacity'), min_after_dequeue=config.getint('queue', 
'min_after_dequeue'), num_threads=multiprocessing.cpu_count() 47 | ) 48 | tf.global_variables_initializer().run() 49 | coord = tf.train.Coordinator() 50 | threads = tf.train.start_queue_runners(sess, coord) 51 | batch_image, batch_labels = sess.run([batch[0], batch[1:]]) 52 | coord.request_stop() 53 | coord.join(threads) 54 | batch_image = batch_image.astype(np.uint8) 55 | fig, axes = plt.subplots(args.rows, args.cols) 56 | for b, (ax, image) in enumerate(zip(axes.flat, batch_image)): 57 | ax.imshow(image) 58 | utils.visualize.draw_labels(ax, names, width, height, cell_width, cell_height, *[l[b] for l in batch_labels]) 59 | if args.grid: 60 | ax.set_xticks(np.arange(0, width, width / cell_width)) 61 | ax.set_yticks(np.arange(0, height, height / cell_height)) 62 | ax.grid(which='both') 63 | ax.tick_params(labelbottom='off', labelleft='off') 64 | else: 65 | ax.set_xticks([]) 66 | ax.set_yticks([]) 67 | fig.tight_layout() 68 | plt.show() 69 | 70 | 71 | def make_args(): 72 | parser = argparse.ArgumentParser() 73 | parser.add_argument('-c', '--config', nargs='+', default=['config.ini'], help='config file') 74 | parser.add_argument('-p', '--profile', nargs='+', default=['train', 'val']) 75 | parser.add_argument('-g', '--grid', action='store_true') 76 | parser.add_argument('--rows', default=5, type=int) 77 | parser.add_argument('--cols', default=5, type=int) 78 | parser.add_argument('--level', default='info', help='logging level') 79 | return parser.parse_args() 80 | 81 | if __name__ == '__main__': 82 | args = make_args() 83 | config = configparser.ConfigParser() 84 | utils.load_config(config, args.config) 85 | if args.level: 86 | tf.logging.set_verbosity(args.level.upper()) 87 | main() 88 | -------------------------------------------------------------------------------- /demo_detect.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen) 3 | 4 | This program is free software: you can 
redistribute it and/or modify 5 | it under the terms of the GNU Lesser General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 8 | 9 | This program is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with this program. If not, see . 16 | """ 17 | 18 | import os 19 | import argparse 20 | import configparser 21 | import importlib 22 | import itertools 23 | import numpy as np 24 | import matplotlib.pyplot as plt 25 | import matplotlib.patches as patches 26 | import tensorflow as tf 27 | import tensorflow.contrib.slim as slim 28 | import utils.data 29 | import utils.visualize 30 | 31 | 32 | class Drawer(object): 33 | def __init__(self, sess, names, cell_width, cell_height, image, labels, model, feed_dict): 34 | self.sess = sess 35 | self.names = names 36 | self.cell_width, self.cell_height = cell_width, cell_height 37 | self.image, self.labels = image, labels 38 | self.model = model 39 | self.feed_dict = feed_dict 40 | self.fig = plt.figure() 41 | self.ax = self.fig.gca() 42 | height, width, _ = image.shape 43 | self.scale = [width / self.cell_width, height / self.cell_height] 44 | self.ax.imshow(image) 45 | self.plots = utils.visualize.draw_labels(self.ax, names, width, height, cell_width, cell_height, *labels) 46 | self.ax.set_xticks(np.arange(0, width, width / cell_width)) 47 | self.ax.set_yticks(np.arange(0, height, height / cell_height)) 48 | self.ax.grid(which='both') 49 | self.ax.tick_params(labelbottom='off', labelleft='off') 50 | self.fig.canvas.mpl_connect('button_press_event', self.onclick) 51 | self.colors = [prop['color'] for _, prop in zip(names, itertools.cycle(plt.rcParams['axes.prop_cycle']))] 52 
| 53 | def onclick(self, event): 54 | for p in self.plots: 55 | p.remove() 56 | self.plots = [] 57 | height, width, _ = self.image.shape 58 | ix = int(event.xdata * self.cell_width / width) 59 | iy = int(event.ydata * self.cell_height / height) 60 | self.plots.append(self.ax.add_patch(patches.Rectangle((ix * width / self.cell_width, iy * height / self.cell_height), width / self.cell_width, height / self.cell_height, linewidth=0, facecolor='black', alpha=.2))) 61 | index = iy * self.cell_width + ix 62 | prob, iou, xy_min, wh = self.sess.run([self.model.prob[0][index], self.model.iou[0][index], self.model.xy_min[0][index], self.model.wh[0][index]], feed_dict=self.feed_dict) 63 | xy_min = xy_min * self.scale 64 | wh = wh * self.scale 65 | for _prob, _iou, (x, y), (w, h), color in zip(prob, iou, xy_min, wh, self.colors): 66 | index = np.argmax(_prob) 67 | name = self.names[index] 68 | _prob = _prob[index] 69 | _conf = _prob * _iou 70 | linewidth = min(_conf * 10, 3) 71 | self.plots.append(self.ax.add_patch(patches.Rectangle((x, y), w, h, linewidth=linewidth, edgecolor=color, facecolor='none'))) 72 | self.plots.append(self.ax.annotate(name + ' (%.1f%%, %.1f%%)' % (_iou * 100, _prob * 100), (x, y), color=color)) 73 | self.fig.canvas.draw() 74 | 75 | 76 | def main(): 77 | model = config.get('config', 'model') 78 | cachedir = utils.get_cachedir(config) 79 | with open(os.path.join(cachedir, 'names'), 'r') as f: 80 | names = [line.strip() for line in f] 81 | width = config.getint(model, 'width') 82 | height = config.getint(model, 'height') 83 | yolo = importlib.import_module('model.' 
+ model) 84 | cell_width, cell_height = utils.calc_cell_width_height(config, width, height) 85 | tf.logging.info('(width, height)=(%d, %d), (cell_width, cell_height)=(%d, %d)' % (width, height, cell_width, cell_height)) 86 | with tf.Session() as sess: 87 | paths = [os.path.join(cachedir, profile + '.tfrecord') for profile in args.profile] 88 | num_examples = sum(sum(1 for _ in tf.python_io.tf_record_iterator(path)) for path in paths) 89 | tf.logging.warn('num_examples=%d' % num_examples) 90 | image_rgb, labels = utils.data.load_image_labels(paths, len(names), width, height, cell_width, cell_height, config) 91 | image_std = tf.image.per_image_standardization(image_rgb) 92 | image_rgb = tf.cast(image_rgb, tf.uint8) 93 | ph_image = tf.placeholder(image_std.dtype, [1] + image_std.get_shape().as_list(), name='ph_image') 94 | global_step = tf.contrib.framework.get_or_create_global_step() 95 | builder = yolo.Builder(args, config) 96 | builder(ph_image) 97 | variables_to_restore = slim.get_variables_to_restore() 98 | ph_labels = [tf.placeholder(l.dtype, [1] + l.get_shape().as_list(), name='ph_' + l.op.name) for l in labels] 99 | with tf.name_scope('total_loss') as name: 100 | builder.create_objectives(ph_labels) 101 | total_loss = tf.losses.get_total_loss(name=name) 102 | tf.global_variables_initializer().run() 103 | coord = tf.train.Coordinator() 104 | threads = tf.train.start_queue_runners(sess, coord) 105 | _image_rgb, _image_std, _labels = sess.run([image_rgb, image_std, labels]) 106 | coord.request_stop() 107 | coord.join(threads) 108 | feed_dict = dict([(ph, np.expand_dims(d, 0)) for ph, d in zip(ph_labels, _labels)]) 109 | feed_dict[ph_image] = np.expand_dims(_image_std, 0) 110 | logdir = utils.get_logdir(config) 111 | assert os.path.exists(logdir) 112 | model_path = tf.train.latest_checkpoint(logdir) 113 | tf.logging.info('load ' + model_path) 114 | slim.assign_from_checkpoint_fn(model_path, variables_to_restore)(sess) 115 | tf.logging.info('global_step=%d' % 
sess.run(global_step)) 116 | tf.logging.info('total_loss=%f' % sess.run(total_loss, feed_dict)) 117 | _ = Drawer(sess, names, builder.model.cell_width, builder.model.cell_height, _image_rgb, _labels, builder.model, feed_dict) 118 | plt.show() 119 | 120 | 121 | def make_args(): 122 | parser = argparse.ArgumentParser() 123 | parser.add_argument('-c', '--config', nargs='+', default=['config.ini'], help='config file') 124 | parser.add_argument('-p', '--profile', nargs='+', default=['train']) 125 | parser.add_argument('--level', default='info', help='logging level') 126 | return parser.parse_args() 127 | 128 | if __name__ == '__main__': 129 | args = make_args() 130 | config = configparser.ConfigParser() 131 | utils.load_config(config, args.config) 132 | if args.level: 133 | tf.logging.set_verbosity(args.level.upper()) 134 | main() 135 | -------------------------------------------------------------------------------- /detect.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen) 3 | 4 | This program is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU Lesser General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 8 | 9 | This program is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with this program. If not, see . 
16 | """ 17 | 18 | import os 19 | import argparse 20 | import configparser 21 | import importlib 22 | import itertools 23 | from PIL import Image, ExifTags 24 | import numpy as np 25 | import matplotlib.pyplot as plt 26 | import matplotlib.patches as patches 27 | import tensorflow as tf 28 | import tensorflow.contrib.slim as slim 29 | import utils.preprocess 30 | import utils.postprocess 31 | 32 | 33 | def std(image): 34 | return utils.preprocess.per_image_standardization(image) 35 | 36 | 37 | def darknet(image): 38 | return image / 255. 39 | 40 | 41 | def read_image(path): 42 | image = Image.open(path) 43 | for key in ExifTags.TAGS.keys(): 44 | if ExifTags.TAGS[key] == 'Orientation': 45 | break 46 | try: 47 | exif = dict(image._getexif().items()) 48 | except AttributeError: 49 | return image 50 | if exif[key] == 3: 51 | image = image.rotate(180, expand=True) 52 | elif exif[key] == 6: 53 | image = image.rotate(270, expand=True) 54 | elif exif[key] == 8: 55 | image = image.rotate(90, expand=True) 56 | return image 57 | 58 | 59 | def detect(sess, model, names, image, path): 60 | preprocess = eval(args.preprocess) 61 | _, height, width, _ = image.get_shape().as_list() 62 | _image = read_image(path) 63 | image_original = np.array(np.uint8(_image)) 64 | if len(image_original.shape) == 2: 65 | image_original = np.repeat(np.expand_dims(image_original, -1), 3, 2) 66 | image_height, image_width, _ = image_original.shape 67 | image_std = preprocess(np.array(np.uint8(_image.resize((width, height)))).astype(np.float32)) 68 | feed_dict = {image: np.expand_dims(image_std, 0)} 69 | tensors = [model.conf, model.xy_min, model.xy_max] 70 | conf, xy_min, xy_max = sess.run([tf.check_numerics(t, t.op.name) for t in tensors], feed_dict=feed_dict) 71 | boxes = utils.postprocess.non_max_suppress(conf[0], xy_min[0], xy_max[0], args.threshold, args.threshold_iou) 72 | scale = [image_width / model.cell_width, image_height / model.cell_height] 73 | fig = plt.figure() 74 | ax = fig.gca() 75 | 
ax.imshow(image_original) 76 | colors = [prop['color'] for _, prop in zip(names, itertools.cycle(plt.rcParams['axes.prop_cycle']))] 77 | cnt = 0 78 | for _conf, _xy_min, _xy_max in boxes: 79 | index = np.argmax(_conf) 80 | if _conf[index] > args.threshold: 81 | wh = _xy_max - _xy_min 82 | _xy_min = _xy_min * scale 83 | _wh = wh * scale 84 | linewidth = min(_conf[index] * 10, 3) 85 | ax.add_patch(patches.Rectangle(_xy_min, _wh[0], _wh[1], linewidth=linewidth, edgecolor=colors[index], facecolor='none')) 86 | ax.annotate(names[index] + ' (%.1f%%)' % (_conf[index] * 100), _xy_min, color=colors[index]) 87 | cnt += 1 88 | fig.canvas.set_window_title('%d objects detected' % cnt) 89 | ax.set_xticks([]) 90 | ax.set_yticks([]) 91 | return fig 92 | 93 | 94 | def main(): 95 | model = config.get('config', 'model') 96 | yolo = importlib.import_module('model.' + model) 97 | width = config.getint(model, 'width') 98 | height = config.getint(model, 'height') 99 | with tf.Session() as sess: 100 | image = tf.placeholder(tf.float32, [1, height, width, 3], name='image') 101 | builder = yolo.Builder(args, config) 102 | builder(image) 103 | global_step = tf.contrib.framework.get_or_create_global_step() 104 | model_path = tf.train.latest_checkpoint(utils.get_logdir(config)) 105 | tf.logging.info('load ' + model_path) 106 | slim.assign_from_checkpoint_fn(model_path, tf.global_variables())(sess) 107 | tf.logging.info('global_step=%d' % sess.run(global_step)) 108 | path = os.path.expanduser(os.path.expandvars(args.path)) 109 | if os.path.isfile(path): 110 | detect(sess, builder.model, builder.names, image, path) 111 | plt.show() 112 | else: 113 | for dirpath, _, filenames in os.walk(path): 114 | for filename in filenames: 115 | if os.path.splitext(filename)[-1].lower() in args.exts: 116 | _path = os.path.join(dirpath, filename) 117 | print(_path) 118 | detect(sess, builder.model, builder.names, image, _path) 119 | plt.show() 120 | 121 | 122 | def make_args(): 123 | parser = 
argparse.ArgumentParser() 124 | parser.add_argument('path', help='input image path') 125 | parser.add_argument('-c', '--config', nargs='+', default=['config.ini'], help='config file') 126 | parser.add_argument('-p', '--preprocess', default='std', help='the preprocess function') 127 | parser.add_argument('-t', '--threshold', type=float, default=0.3) 128 | parser.add_argument('--threshold_iou', type=float, default=0.4, help='IoU threshold') 129 | parser.add_argument('-e', '--exts', nargs='+', default=['.jpg', '.png']) 130 | parser.add_argument('--level', default='info', help='logging level') 131 | return parser.parse_args() 132 | 133 | if __name__ == '__main__': 134 | args = make_args() 135 | config = configparser.ConfigParser() 136 | utils.load_config(config, args.config) 137 | if args.level: 138 | tf.logging.set_verbosity(args.level.upper()) 139 | main() 140 | -------------------------------------------------------------------------------- /detect_camera.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen) 3 | 4 | This program is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU Lesser General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 8 | 9 | This program is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with this program. If not, see . 
16 | """ 17 | 18 | import argparse 19 | import configparser 20 | import importlib 21 | import cv2 22 | import numpy as np 23 | import tensorflow as tf 24 | import tensorflow.contrib.slim as slim 25 | import utils.postprocess 26 | 27 | 28 | def main(): 29 | model = config.get('config', 'model') 30 | yolo = importlib.import_module('model.' + model) 31 | width = config.getint(model, 'width') 32 | height = config.getint(model, 'height') 33 | preprocess = getattr(importlib.import_module('detect'), args.preprocess) 34 | with tf.Session() as sess: 35 | ph_image = tf.placeholder(tf.float32, [1, height, width, 3], name='ph_image') 36 | builder = yolo.Builder(args, config) 37 | builder(ph_image) 38 | global_step = tf.contrib.framework.get_or_create_global_step() 39 | model_path = tf.train.latest_checkpoint(utils.get_logdir(config)) 40 | tf.logging.info('load ' + model_path) 41 | slim.assign_from_checkpoint_fn(model_path, tf.global_variables())(sess) 42 | tf.logging.info('global_step=%d' % sess.run(global_step)) 43 | tensors = [builder.model.conf, builder.model.xy_min, builder.model.xy_max] 44 | tensors = [tf.check_numerics(t, t.op.name) for t in tensors] 45 | cap = cv2.VideoCapture(0) 46 | try: 47 | while True: 48 | ret, image_bgr = cap.read() 49 | assert ret 50 | image_height, image_width, _ = image_bgr.shape 51 | scale = [image_width / builder.model.cell_width, image_height / builder.model.cell_height] 52 | image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB) 53 | image_std = np.expand_dims(preprocess(cv2.resize(image_rgb, (width, height))).astype(np.float32), 0) 54 | feed_dict = {ph_image: image_std} 55 | conf, xy_min, xy_max = sess.run(tensors, feed_dict) 56 | boxes = utils.postprocess.non_max_suppress(conf[0], xy_min[0], xy_max[0], args.threshold, args.threshold_iou) 57 | for _conf, _xy_min, _xy_max in boxes: 58 | index = np.argmax(_conf) 59 | if _conf[index] > args.threshold: 60 | _xy_min = (_xy_min * scale).astype(np.int) 61 | _xy_max = (_xy_max * 
scale).astype(np.int) 62 | cv2.rectangle(image_bgr, tuple(_xy_min), tuple(_xy_max), (255, 0, 255), 3) 63 | cv2.putText(image_bgr, builder.names[index] + ' (%.1f%%)' % (_conf[index] * 100), tuple(_xy_min), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2) 64 | cv2.imshow('detection', image_bgr) 65 | cv2.waitKey(1) 66 | finally: 67 | cv2.destroyAllWindows() 68 | cap.release() 69 | 70 | 71 | def make_args(): 72 | parser = argparse.ArgumentParser() 73 | parser.add_argument('-c', '--config', nargs='+', default=['config.ini'], help='config file') 74 | parser.add_argument('-p', '--preprocess', default='std', help='the preprocess function') 75 | parser.add_argument('-t', '--threshold', type=float, default=0.3) 76 | parser.add_argument('--threshold_iou', type=float, default=0.4, help='IoU threshold') 77 | parser.add_argument('--level', default='info', help='logging level') 78 | return parser.parse_args() 79 | 80 | if __name__ == '__main__': 81 | args = make_args() 82 | config = configparser.ConfigParser() 83 | utils.load_config(config, args.config) 84 | if args.level: 85 | tf.logging.set_verbosity(args.level.upper()) 86 | main() 87 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruiminshen/yolo-tf/eae65c8071fe5069f5e3bb1e26f19a761b1b68bc/model/__init__.py -------------------------------------------------------------------------------- /model/yolo/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen) 3 | 4 | This program is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU Lesser General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
"""

import configparser
import os
import re
import math
import numpy as np
import pandas as pd
import tensorflow as tf
import utils
from . import inference


def calc_cell_xy(cell_height, cell_width, dtype=np.float32):
    """Return the integer grid coordinates of every cell.

    Args:
        cell_height: number of cell rows.
        cell_width: number of cell columns.
        dtype: NumPy dtype of the result.

    Returns:
        Array of shape ``(cell_height, cell_width, 2)`` whose ``[y, x]`` entry
        is ``[x, y]`` (x first). ``Model`` adds it to per-cell offsets to get
        coordinates in grid-wide cell units.
    """
    cell_base = np.zeros([cell_height, cell_width, 2], dtype=dtype)
    for y in range(cell_height):
        for x in range(cell_width):
            cell_base[y, x, :] = [x, y]
    return cell_base


class Model(object):
    """Decode the flat output of the ``yolo`` inference head into predictions.

    ``net`` is sliced as ``net[:, :end]`` (rank 2). The first
    ``cells * classes`` values per example are class probabilities, shared by
    all boxes of a cell. The rest is ``cells * boxes_per_cell`` groups of 5:
    ``[iou, offset_x, offset_y, sqrt_w, sqrt_h]``.

    Args:
        net: flat output tensor of the inference function.
        scope: scope name of the inference network; the grid size is read
            from the tensor named ``scope + '/conv:0'``.
        classes: number of classes.
        boxes_per_cell: number of boxes predicted per cell.
        training: when False, also builds the ``detection`` tensors
            (``xy``, ``xy_min``, ``xy_max``, ``conf``).
    """

    def __init__(self, net, scope, classes, boxes_per_cell, training=False):
        _, self.cell_height, self.cell_width, _ = tf.get_default_graph().get_tensor_by_name(scope + '/conv:0').get_shape().as_list()
        cells = self.cell_height * self.cell_width
        with tf.name_scope('regress'):
            with tf.name_scope('inputs'):
                end = cells * classes
                # Class probabilities: one set per cell (box axis has size 1).
                self.prob = tf.reshape(net[:, :end], [-1, cells, 1, classes], name='prob')
                inputs_remaining = tf.reshape(net[:, end:], [-1, cells, boxes_per_cell, 5], name='inputs_remaining')
                self.iou = tf.identity(inputs_remaining[:, :, :, 0], name='iou')
                self.offset_xy = tf.identity(inputs_remaining[:, :, :, 1:3], name='offset_xy')
                wh01_sqrt_base = tf.identity(inputs_remaining[:, :, :, 3:], name='wh01_sqrt_base')
            # The network regresses sqrt(w), sqrt(h); squaring recovers the
            # size and abs() keeps the regression target non-negative.
            wh01 = tf.square(wh01_sqrt_base, name='wh01')
            wh01_sqrt = tf.abs(wh01_sqrt_base, name='wh01_sqrt')
            self.coords = tf.concat([self.offset_xy, wh01_sqrt], -1, name='coords')
            # Box size in cell units.
            self.wh = tf.identity(wh01 * [self.cell_width, self.cell_height], name='wh')
            _wh = self.wh / 2
            self.offset_xy_min = tf.identity(self.offset_xy - _wh, name='offset_xy_min')
            self.offset_xy_max = tf.identity(self.offset_xy + _wh, name='offset_xy_max')
            self.areas = tf.reduce_prod(self.wh, -1, name='areas')
        if not training:
            with tf.name_scope('detection'):
                # Shift per-cell offsets by each cell's grid position.
                cell_xy = calc_cell_xy(self.cell_height, self.cell_width).reshape([1, cells, 1, 2])
                self.xy = tf.identity(cell_xy + self.offset_xy, name='xy')
                self.xy_min = tf.identity(cell_xy + self.offset_xy_min, name='xy_min')
                self.xy_max = tf.identity(cell_xy + self.offset_xy_max, name='xy_max')
                # Per-box, per-class confidence = predicted IoU * class probability.
                self.conf = tf.identity(tf.expand_dims(self.iou, -1) * self.prob, name='conf')
        self.inputs = net
        self.classes = classes
        self.boxes_per_cell = boxes_per_cell


class Objectives(dict):
    """Loss terms comparing a ``Model`` with ground-truth labels.

    The positional label arguments must arrive in the order
    ``(mask, prob, coords, offset_xy_min, offset_xy_max, areas)``.
    The resulting dict maps ``'iou_best'``, ``'iou_normal'``, ``'coords'``
    and ``'prob'`` to scalar squared-error sums divided by ``cnt``.
    """

    def __init__(self, model, mask, prob, coords, offset_xy_min, offset_xy_max, areas):
        self.model = model
        with tf.name_scope('true'):
            self.mask = tf.identity(mask, name='mask')
            self.prob = tf.identity(prob, name='prob')
            self.coords = tf.identity(coords, name='coords')
            self.offset_xy_min = tf.identity(offset_xy_min, name='offset_xy_min')
            self.offset_xy_max = tf.identity(offset_xy_max, name='offset_xy_max')
            self.areas = tf.identity(areas, name='areas')
        with tf.name_scope('iou') as name:
            # IoU of every predicted box with the ground truth of its cell.
            _offset_xy_min = tf.maximum(model.offset_xy_min, self.offset_xy_min, name='_offset_xy_min')
            _offset_xy_max = tf.minimum(model.offset_xy_max, self.offset_xy_max, name='_offset_xy_max')
            _wh = tf.maximum(_offset_xy_max - _offset_xy_min, 0.0, name='_wh')
            _areas = tf.reduce_prod(_wh, -1, name='_areas')
            # Union area, floored to avoid division by zero.
            areas = tf.maximum(self.areas + model.areas - _areas, 1e-10, name='areas')
            iou = tf.truediv(_areas, areas, name=name)
        with tf.name_scope('mask'):
            # Within each cell, the box(es) with the highest IoU (axis 2) are
            # responsible for the object; ties all count.
            best_box_iou = tf.reduce_max(iou, 2, True, name='best_box_iou')
            best_box = tf.to_float(tf.equal(iou, best_box_iou), name='best_box')
            mask_best = tf.identity(self.mask * best_box, name='mask_best')
            mask_normal = tf.identity(1 - mask_best, name='mask_normal')
        with tf.name_scope('dist'):
            iou_dist = tf.square(model.iou - mask_best, name='iou_dist')
            coords_dist = tf.square(model.coords - self.coords, name='coords_dist')
            prob_dist = tf.square(model.prob - self.prob, name='prob_dist')
        with tf.name_scope('objectives'):
            # Element count of iou_dist from its static shape.
            # NOTE(review): needs a fully defined static shape (including the
            # batch dimension); a None entry would make this reduce fail.
            cnt = np.multiply.reduce(iou_dist.get_shape().as_list())
            self['iou_best'] = tf.identity(tf.reduce_sum(mask_best * iou_dist) / cnt, name='iou_best')
            self['iou_normal'] = tf.identity(tf.reduce_sum(mask_normal * iou_dist) / cnt, name='iou_normal')
            self['coords'] = tf.identity(tf.reduce_sum(tf.expand_dims(mask_best, -1) * coords_dist) / cnt, name='coords')
            self['prob'] = tf.identity(tf.reduce_sum(tf.expand_dims(self.mask, -1) * prob_dist) / cnt, name='prob')


class Builder(object):
    """Build the ``yolo`` network, its ``Model`` and its weighted losses.

    Reads class names from ``<cachedir>/names``. It also reads the
    ``boxes_per_cell`` and ``inference`` options from the config section
    named after this package (the last component of ``__name__``).
    """

    def __init__(self, args, config):
        section = __name__.split('.')[-1]
        self.args = args
        self.config = config
        with open(os.path.join(utils.get_cachedir(config), 'names'), 'r') as f:
            self.names = [line.strip() for line in f]
        self.boxes_per_cell = config.getint(section, 'boxes_per_cell')
        # Name of a function in the inference module, e.g. `tiny`.
        self.func = getattr(inference, config.get(section, 'inference'))

    def __call__(self, data, training=False):
        """Build the network on *data* and wrap its output in ``self.model``."""
        _scope, self.output = self.func(data, len(self.names), self.boxes_per_cell, training=training)
        with tf.name_scope(__name__.split('.')[-1]):
            # NOTE(review): `training` is forwarded to the inference function
            # but not to Model, so the detection tensors are always built.
            self.model = Model(self.output, _scope, len(self.names), self.boxes_per_cell)

    def create_objectives(self, labels):
        """Create ``Objectives`` and add each term, weighted, to the LOSSES collection.

        Each weight is read from the ``<section>_hparam`` config section under
        the objective's key.
        """
        section = __name__.split('.')[-1]
        self.objectives = Objectives(self.model, *labels)
        with tf.name_scope('weighted_objectives'):
            for key in self.objectives:
                tf.add_to_collection(tf.GraphKeys.LOSSES, tf.multiply(self.objectives[key], self.config.getfloat(section + '_hparam', key), name='weighted_' + key))
-------------------------------------------------------------------------------- /model/yolo/function.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen) 3 | 4 | This program is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU Lesser General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 8 | 9 | This program is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with this program. If not, see . 16 | """ 17 | 18 | import tensorflow as tf 19 | 20 | 21 | def leaky_relu(inputs, alpha=.1): 22 | with tf.name_scope('leaky_relu') as name: 23 | data = tf.identity(inputs, name='data') 24 | return tf.maximum(data, alpha * data, name=name) 25 | -------------------------------------------------------------------------------- /model/yolo/inference.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen) 3 | 4 | This program is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU Lesser General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 8 | 9 | This program is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with this program. 
If not, see . 16 | """ 17 | 18 | import inspect 19 | import tensorflow as tf 20 | import tensorflow.contrib.slim as slim 21 | from .function import leaky_relu 22 | 23 | 24 | def tiny(net, classes, boxes_per_cell, training=False): 25 | scope = __name__.split('.')[-2] + '_' + inspect.stack()[0][3] 26 | net = tf.identity(net, name='%s/input' % scope) 27 | with slim.arg_scope([slim.layers.conv2d], kernel_size=[3, 3], activation_fn=leaky_relu), slim.arg_scope([slim.layers.max_pool2d], kernel_size=[2, 2], padding='SAME'): 28 | index = 0 29 | net = slim.layers.conv2d(net, 16, scope='%s/conv%d' % (scope, index)) 30 | net = slim.layers.max_pool2d(net, scope='%s/max_pool%d' % (scope, index)) 31 | index += 1 32 | net = slim.layers.conv2d(net, 32, scope='%s/conv%d' % (scope, index)) 33 | net = slim.layers.max_pool2d(net, scope='%s/max_pool%d' % (scope, index)) 34 | index += 1 35 | net = slim.layers.conv2d(net, 64, scope='%s/conv%d' % (scope, index)) 36 | net = slim.layers.max_pool2d(net, scope='%s/max_pool%d' % (scope, index)) 37 | index += 1 38 | net = slim.layers.conv2d(net, 128, scope='%s/conv%d' % (scope, index)) 39 | net = slim.layers.max_pool2d(net, scope='%s/max_pool%d' % (scope, index)) 40 | index += 1 41 | net = slim.layers.conv2d(net, 256, scope='%s/conv%d' % (scope, index)) 42 | net = slim.layers.max_pool2d(net, scope='%s/max_pool%d' % (scope, index)) 43 | index += 1 44 | net = slim.layers.conv2d(net, 512, scope='%s/conv%d' % (scope, index)) 45 | net = slim.layers.max_pool2d(net, scope='%s/max_pool%d' % (scope, index)) 46 | index += 1 47 | net = slim.layers.conv2d(net, 512, scope='%s/conv%d' % (scope, index)) 48 | index += 1 49 | net = slim.layers.conv2d(net, 1024, scope='%s/conv%d' % (scope, index)) 50 | index += 1 51 | net = slim.layers.conv2d(net, 256, scope='%s/conv%d' % (scope, index)) 52 | net = tf.identity(net, name='%s/conv' % scope) 53 | _, cell_height, cell_width, _ = net.get_shape().as_list() 54 | net = slim.layers.flatten(net, scope='%s/flatten' % scope) 
55 | with slim.arg_scope([slim.layers.fully_connected], activation_fn=leaky_relu, weights_regularizer=slim.l2_regularizer(0.001)), slim.arg_scope([slim.layers.dropout], keep_prob=.5, is_training=training): 56 | index = 0 57 | net = slim.layers.fully_connected(net, 256, scope='%s/fc%d' % (scope, index)) 58 | net = slim.layers.dropout(net, scope='%s/dropout%d' % (scope, index)) 59 | index += 1 60 | net = slim.layers.fully_connected(net, 4096, scope='%s/fc%d' % (scope, index)) 61 | net = slim.layers.dropout(net, scope='%s/dropout%d' % (scope, index)) 62 | net = slim.layers.fully_connected(net, cell_width * cell_height * (classes + boxes_per_cell * 5), activation_fn=None, scope='%s/fc' % scope) 63 | net = tf.identity(net, name='%s/output' % scope) 64 | return scope, net 65 | 66 | TINY_DOWNSAMPLING = (2 ** 6, 2 ** 6) 67 | -------------------------------------------------------------------------------- /model/yolo2/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen) 3 | 4 | This program is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU Lesser General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 8 | 9 | This program is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with this program. If not, see . 16 | """ 17 | 18 | import configparser 19 | import os 20 | import numpy as np 21 | import pandas as pd 22 | import tensorflow as tf 23 | import utils 24 | from . import inference 25 | from .. 
import yolo  # tail of the dump line "25 | from .. import yolo", which is split across source lines


class Model(object):
    """Decode the raw YOLO v2 feature map into per-anchor box and score tensors.

    ``net`` is reshaped to ``(batch, cells, anchors, 5 + classes)``. Along the last
    axis, entry 0 is the objectness logit, 1:3 are the x/y offset logits, 3:5 are
    the log-scale of width/height relative to the anchor, and 5: are the class logits.

    Args:
        net: 4-D output of an ``inference`` function. Its static shape is read as
            ``(batch, cell_height, cell_width, channels)``.
        classes: number of object classes.
        anchors: array of anchor sizes, reshaped to ``(1, 1, len(anchors), -1)``.
            Presumably ``(width, height)`` pairs in grid-cell units; confirm
            against the anchors file.
        training: when False, also builds the absolute box and confidence tensors
            under the ``detection`` name scope.
    """
    def __init__(self, net, classes, anchors, training=False):
        _, self.cell_height, self.cell_width, _ = net.get_shape().as_list()
        cells = self.cell_height * self.cell_width
        inputs = tf.reshape(net, [-1, cells, len(anchors), 5 + classes], name='inputs')
        with tf.name_scope('regress'):
            with tf.name_scope('inputs'):
                # one sigmoid over objectness and the x/y offsets (channels 0..2)
                with tf.name_scope('inputs_sigmoid') as name:
                    inputs_sigmoid = tf.nn.sigmoid(inputs[:, :, :, :3], name=name)
                self.iou = tf.identity(inputs_sigmoid[:, :, :, 0], name='iou')  # predicted objectness in (0, 1)
                self.offset_xy = tf.identity(inputs_sigmoid[:, :, :, 1:3], name='offset_xy')  # box centre within its cell, in (0, 1)
                with tf.name_scope('wh') as name:
                    # exp(raw) scales each anchor's size
                    self.wh = tf.identity(tf.exp(inputs[:, :, :, 3:5]) * np.reshape(anchors, [1, 1, len(anchors), -1]), name=name)
                with tf.name_scope('prob') as name:
                    self.prob = tf.identity(tf.nn.softmax(inputs[:, :, :, 5:]), name=name)
            self.areas = tf.reduce_prod(self.wh, -1, name='areas')
            _wh = self.wh / 2
            # box corners relative to the cell origin
            self.offset_xy_min = tf.identity(self.offset_xy - _wh, name='offset_xy_min')
            self.offset_xy_max = tf.identity(self.offset_xy + _wh, name='offset_xy_max')
            # wh divided by the grid size (a fraction of the image if wh is in cell units)
            self.wh01 = tf.identity(self.wh / np.reshape([self.cell_width, self.cell_height], [1, 1, 1, 2]), name='wh01')
            self.wh01_sqrt = tf.sqrt(self.wh01, name='wh01_sqrt')
            # regression coordinates: (offset_x, offset_y, sqrt(w01), sqrt(h01))
            self.coords = tf.concat([self.offset_xy, self.wh01_sqrt], -1, name='coords')
        if not training:
            with tf.name_scope('detection'):
                # cell_xy comes from yolo.calc_cell_xy; presumably each cell's grid coordinate
                cell_xy = yolo.calc_cell_xy(self.cell_height, self.cell_width).reshape([1, cells, 1, 2])
                self.xy = tf.identity(cell_xy + self.offset_xy, name='xy')
                self.xy_min = tf.identity(cell_xy + self.offset_xy_min, name='xy_min')
                self.xy_max = tf.identity(cell_xy + self.offset_xy_max, name='xy_max')
                # class-specific confidence = objectness * class probability
                self.conf = tf.identity(tf.expand_dims(self.iou, -1) * self.prob, name='conf')
        self.inputs = net
        self.classes = classes
        self.anchors = anchors


class Objectives(dict):
    """YOLO v2 training loss terms, stored as the dict entries
    ``iou_best``, ``iou_normal``, ``coords`` and ``prob``.

    Each term is a masked sum of squared errors divided by ``cnt``, the number of
    elements in the per-anchor objectness tensor. That tensor's static shape must
    therefore be fully known, including the batch size.

    Args:
        model: the ``Model`` whose predictions are scored.
        mask: ground-truth object mask, multiplied with the best-anchor indicator.
        prob, coords, offset_xy_min, offset_xy_max, areas: ground-truth tensors
            combined element-wise with the ``Model`` attributes of the same name.
            Their shapes are assumed to broadcast against those attributes;
            confirm against the label pipeline in utils.data.
    """
    def __init__(self, model, mask, prob, coords, offset_xy_min, offset_xy_max, areas):
        self.model = model
        with tf.name_scope('true'):
            self.mask = tf.identity(mask, name='mask')
            self.prob = tf.identity(prob, name='prob')
            self.coords = tf.identity(coords, name='coords')
            self.offset_xy_min = tf.identity(offset_xy_min, name='offset_xy_min')
            self.offset_xy_max = tf.identity(offset_xy_max, name='offset_xy_max')
            self.areas = tf.identity(areas, name='areas')
        # IoU between every predicted box and the ground-truth box
        with tf.name_scope('iou') as name:
            _offset_xy_min = tf.maximum(model.offset_xy_min, self.offset_xy_min, name='_offset_xy_min')
            _offset_xy_max = tf.minimum(model.offset_xy_max, self.offset_xy_max, name='_offset_xy_max')
            _wh = tf.maximum(_offset_xy_max - _offset_xy_min, 0.0, name='_wh')  # empty intersections clamp to 0
            _areas = tf.reduce_prod(_wh, -1, name='_areas')
            areas = tf.maximum(self.areas + model.areas - _areas, 1e-10, name='areas')  # union, kept > 0 to avoid division by zero
            iou = tf.truediv(_areas, areas, name=name)
        with tf.name_scope('mask'):
            # the anchor(s) with the highest IoU in each cell (ties all count) are responsible for the object
            best_box_iou = tf.reduce_max(iou, 2, True, name='best_box_iou')
            best_box = tf.to_float(tf.equal(iou, best_box_iou), name='best_box')
            mask_best = tf.identity(self.mask * best_box, name='mask_best')
            mask_normal = tf.identity(1 - mask_best, name='mask_normal')
        with tf.name_scope('dist'):
            # the objectness target is the 0/1 responsibility mask, not the IoU value itself
            iou_dist = tf.square(model.iou - mask_best, name='iou_dist')
            coords_dist = tf.square(model.coords - self.coords, name='coords_dist')
            prob_dist = tf.square(model.prob - self.prob, name='prob_dist')
        with tf.name_scope('objectives'):
            cnt = np.multiply.reduce(iou_dist.get_shape().as_list())
            self['iou_best'] = tf.identity(tf.reduce_sum(mask_best * iou_dist) / cnt, name='iou_best')
            self['iou_normal'] = tf.identity(tf.reduce_sum(mask_normal * iou_dist) / cnt, name='iou_normal')
            _mask_best = tf.expand_dims(mask_best, -1)  # broadcast the mask over coordinate / class channels
            self['coords'] = tf.identity(tf.reduce_sum(_mask_best * coords_dist) / cnt, name='coords')
            self['prob'] = tf.identity(tf.reduce_sum(_mask_best * prob_dist) / cnt, name='prob')


class Builder(yolo.Builder):
    """Build the YOLO v2 network, its decoder and the weighted losses from configuration.

    The config section is ``__name__.split('.')[-1]`` (``yolo2`` when this package is
    ``model.yolo2``). It provides ``width``, ``height``, ``anchors`` (path of a
    tab-separated anchors file) and ``inference`` (name of a function in ``inference``).
    """
    def __init__(self, args, config):
        section = __name__.split('.')[-1]
        self.args = args
        self.config = config
        with open(os.path.join(utils.get_cachedir(config), 'names'), 'r') as f:
            self.names = [line.strip() for line in f]  # one class name per line
        self.width = config.getint(section, 'width')
        self.height = config.getint(section, 'height')
        self.anchors = pd.read_csv(os.path.expanduser(os.path.expandvars(config.get(section, 'anchors'))), sep='\t').values
        self.func = getattr(inference, config.get(section, 'inference'))

    def __call__(self, data, training=False):
        """Build the configured inference network on ``data`` and decode it into ``self.model``."""
        _, self.output = self.func(data, len(self.names), len(self.anchors), training=training)
        with tf.name_scope(__name__.split('.')[-1]):
            self.model = Model(self.output, len(self.names), self.anchors, training=training)

    def create_objectives(self, labels):
        """Create the loss terms and add each, scaled by ``[<section>_hparam] <term>``, to ``tf.GraphKeys.LOSSES``.

        ``labels`` is unpacked positionally into ``Objectives`` after the model:
        (mask, prob, coords, offset_xy_min, offset_xy_max, areas).
        """
        section = __name__.split('.')[-1]
        self.objectives = Objectives(self.model, *labels)
        with tf.name_scope('weighted_objectives'):
            for key in self.objectives:
                tf.add_to_collection(tf.GraphKeys.LOSSES, tf.multiply(self.objectives[key], self.config.getfloat(section + '_hparam', key), name='weighted_' + key))

# [repository-dump text that follows on the same source line (next file), kept verbatim]
# 120 |
# -------------------------------------------------------------------------------- /model/yolo2/function.py: --------------------------------------------------------------------------------
# 1 | """
# 2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
# 3 |
# 4 | This program is free software: you can redistribute it and/or modify
# 5 | it under the terms of the GNU Lesser General Public License as published by
# 6 | the Free Software Foundation, either version 3 of the License, or
# 7 | (at your option) any later version.
# [repository-dump text continuing the module docstring begun on the previous source line, kept verbatim]
# 8 |
# 9 | This program is distributed in the hope that it will be useful,
# 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
# 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# 12 | GNU General Public License for more details.
# 13 |
# 14 | You should have received a copy of the GNU General Public License
# 15 | along with this program. If not, see .
# 16 | """
import numpy as np
import tensorflow as tf


def reorg(net, stride=2, name='reorg'):
    """Space-to-depth rearrangement used by the YOLO v2 passthrough layer.

    Moves every ``stride x stride`` spatial patch of ``net`` into the channel axis:
    ``(batch, height, width, channels)`` becomes
    ``(batch, height // stride, width // stride, channels * stride * stride)``.
    Within each output pixel, the channels are ordered (patch row, patch column,
    input channel).

    Args:
        net: 4-D NHWC tensor whose height, width and channel count are statically
            known. The batch dimension may be unknown (``None``).
        stride: patch size along both spatial axes.
        name: name scope of the operation, also used as the output tensor's name.

    Returns:
        The rearranged tensor.

    Raises:
        ValueError: if height or width is not divisible by ``stride``.
    """
    _, height, width, channels = net.get_shape().as_list()
    if height % stride or width % stride:
        raise ValueError('reorg: spatial size (%d, %d) is not divisible by stride %d' % (height, width, stride))
    _height, _width, _channel = height // stride, width // stride, channels * stride * stride
    with tf.name_scope(name) as name:
        # -1 rather than the static batch size, so graphs with an unknown batch dimension still build
        net = tf.reshape(net, [-1, _height, stride, _width, stride, channels])
        net = tf.transpose(net, [0, 1, 3, 2, 4, 5])  # batch_size, _height, _width, stride, stride, channels
        net = tf.reshape(net, [-1, _height, _width, _channel], name=name)
    return net


def main():
    """Self-test of ``reorg`` on a 4x4 single-channel image.

    Every 2x2 patch of the test image holds the values 0, 1, 2, 3, so after
    ``reorg`` output channel ``i`` must contain only the value ``i``. It also
    checks that an input with an unknown batch dimension is accepted.
    """
    image = [
        (0, 1, 0, 1),
        (2, 3, 2, 3),
        (0, 1, 0, 1),
        (2, 3, 2, 3),
    ]
    image = np.expand_dims(image, 0)
    image = np.expand_dims(image, -1)
    with tf.Session() as sess:
        ph_image = tf.placeholder(tf.uint8, image.shape)
        images = sess.run(reorg(ph_image), feed_dict={ph_image: image})
        for i, channel in enumerate(np.transpose(images[0], [2, 0, 1])):
            data = np.unique(channel)
            assert len(data) == 1
            assert data[0] == i
        ph_batch = tf.placeholder(tf.uint8, [None] + list(image.shape[1:]))
        assert reorg(ph_batch).get_shape().as_list() == [None, 2, 2, 4]


if __name__ == '__main__':
    main()

# [repository-dump text that follows on the same source line (next file), kept verbatim]
# 51 |
# -------------------------------------------------------------------------------- /model/yolo2/inference.py: --------------------------------------------------------------------------------
# 1 | """
# 2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
# 3 |
# 4 | This program is free software: you can redistribute it and/or modify
# 5 | it under the terms of the GNU Lesser General Public License as published by
# 6 | the Free Software Foundation, either version 3 of the License, or
License, or 7 | (at your option) any later version. 8 | 9 | This program is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with this program. If not, see . 16 | """ 17 | 18 | import inspect 19 | import tensorflow as tf 20 | import tensorflow.contrib.slim as slim 21 | from ..yolo.function import leaky_relu 22 | from .function import reorg 23 | 24 | 25 | def tiny(net, classes, num_anchors, training=False, center=True): 26 | def batch_norm(net): 27 | net = slim.batch_norm(net, center=center, scale=True, epsilon=1e-5, is_training=training) 28 | if not center: 29 | net = tf.nn.bias_add(net, slim.variable('biases', shape=[tf.shape(net)[-1]], initializer=tf.zeros_initializer())) 30 | return net 31 | scope = __name__.split('.')[-2] + '_' + inspect.stack()[0][3] 32 | net = tf.identity(net, name='%s/input' % scope) 33 | with slim.arg_scope([slim.layers.conv2d], kernel_size=[3, 3], weights_initializer=tf.truncated_normal_initializer(stddev=0.1), normalizer_fn=batch_norm, activation_fn=leaky_relu), slim.arg_scope([slim.layers.max_pool2d], kernel_size=[2, 2], padding='SAME'): 34 | index = 0 35 | channels = 16 36 | for _ in range(5): 37 | net = slim.layers.conv2d(net, channels, scope='%s/conv%d' % (scope, index)) 38 | net = slim.layers.max_pool2d(net, scope='%s/max_pool%d' % (scope, index)) 39 | index += 1 40 | channels *= 2 41 | net = slim.layers.conv2d(net, channels, scope='%s/conv%d' % (scope, index)) 42 | net = slim.layers.max_pool2d(net, stride=1, scope='%s/max_pool%d' % (scope, index)) 43 | index += 1 44 | channels *= 2 45 | net = slim.layers.conv2d(net, channels, scope='%s/conv%d' % (scope, index)) 46 | index += 1 47 | net = slim.layers.conv2d(net, channels, scope='%s/conv%d' % (scope, index)) 48 | net = 
slim.layers.conv2d(net, num_anchors * (5 + classes), kernel_size=[1, 1], activation_fn=None, scope='%s/conv' % scope) 49 | net = tf.identity(net, name='%s/output' % scope) 50 | return scope, net 51 | 52 | TINY_DOWNSAMPLING = (2 ** 5, 2 ** 5) 53 | 54 | 55 | def _tiny(net, classes, num_anchors, training=False): 56 | return tiny(net, classes, num_anchors, training, False) 57 | 58 | _TINY_DOWNSAMPLING = (2 ** 5, 2 ** 5) 59 | 60 | 61 | def darknet(net, classes, num_anchors, training=False, center=True): 62 | def batch_norm(net): 63 | net = slim.batch_norm(net, center=center, scale=True, epsilon=1e-5, is_training=training) 64 | if not center: 65 | net = tf.nn.bias_add(net, slim.variable('biases', shape=[tf.shape(net)[-1]], initializer=tf.zeros_initializer())) 66 | return net 67 | scope = __name__.split('.')[-2] + '_' + inspect.stack()[0][3] 68 | net = tf.identity(net, name='%s/input' % scope) 69 | with slim.arg_scope([slim.layers.conv2d], kernel_size=[3, 3], normalizer_fn=batch_norm, activation_fn=leaky_relu), slim.arg_scope([slim.layers.max_pool2d], kernel_size=[2, 2], padding='SAME'): 70 | index = 0 71 | channels = 32 72 | for _ in range(2): 73 | net = slim.layers.conv2d(net, channels, scope='%s/conv%d' % (scope, index)) 74 | net = slim.layers.max_pool2d(net, scope='%s/max_pool%d' % (scope, index)) 75 | index += 1 76 | channels *= 2 77 | for _ in range(2): 78 | net = slim.layers.conv2d(net, channels, scope='%s/conv%d' % (scope, index)) 79 | index += 1 80 | net = slim.layers.conv2d(net, channels / 2, kernel_size=[1, 1], scope='%s/conv%d' % (scope, index)) 81 | index += 1 82 | net = slim.layers.conv2d(net, channels, scope='%s/conv%d' % (scope, index)) 83 | net = slim.layers.max_pool2d(net, scope='%s/max_pool%d' % (scope, index)) 84 | index += 1 85 | channels *= 2 86 | net = slim.layers.conv2d(net, channels, scope='%s/conv%d' % (scope, index)) 87 | index += 1 88 | net = slim.layers.conv2d(net, channels / 2, kernel_size=[1, 1], scope='%s/conv%d' % (scope, index)) 89 | 
index += 1 90 | net = slim.layers.conv2d(net, channels, scope='%s/conv%d' % (scope, index)) 91 | index += 1 92 | net = slim.layers.conv2d(net, channels / 2, kernel_size=[1, 1], scope='%s/conv%d' % (scope, index)) 93 | index += 1 94 | net = slim.layers.conv2d(net, channels, scope='%s/conv%d' % (scope, index)) 95 | passthrough = tf.identity(net, name=scope + '/passthrough') 96 | net = slim.layers.max_pool2d(net, scope='%s/max_pool%d' % (scope, index)) 97 | index += 1 98 | channels *= 2 99 | # downsampling finished 100 | net = slim.layers.conv2d(net, channels, scope='%s/conv%d' % (scope, index)) 101 | index += 1 102 | net = slim.layers.conv2d(net, channels / 2, kernel_size=[1, 1], scope='%s/conv%d' % (scope, index)) 103 | index += 1 104 | net = slim.layers.conv2d(net, channels, scope='%s/conv%d' % (scope, index)) 105 | index += 1 106 | net = slim.layers.conv2d(net, channels / 2, kernel_size=[1, 1], scope='%s/conv%d' % (scope, index)) 107 | index += 1 108 | net = slim.layers.conv2d(net, channels, scope='%s/conv%d' % (scope, index)) 109 | index += 1 110 | net = slim.layers.conv2d(net, channels, scope='%s/conv%d' % (scope, index)) 111 | index += 1 112 | net = slim.layers.conv2d(net, channels, scope='%s/conv%d' % (scope, index)) 113 | index += 1 114 | with tf.name_scope(scope): 115 | _net = reorg(passthrough) 116 | net = tf.concat([_net, net], 3, name='%s/concat%d' % (scope, index)) 117 | net = slim.layers.conv2d(net, channels, scope='%s/conv%d' % (scope, index)) 118 | net = slim.layers.conv2d(net, num_anchors * (5 + classes), kernel_size=[1, 1], activation_fn=None, scope='%s/conv' % scope) 119 | net = tf.identity(net, name='%s/output' % scope) 120 | return scope, net 121 | 122 | DARKNET_DOWNSAMPLING = (2 ** 5, 2 ** 5) 123 | 124 | 125 | def _darknet(net, classes, num_anchors, training=False): 126 | return darknet(net, classes, num_anchors, training, False) 127 | 128 | _DARKNET_DOWNSAMPLING = (2 ** 5, 2 ** 5) 129 | 
# -------------------------------------------------------------------------------- /parse_darknet_yolo2.py: --------------------------------------------------------------------------------
"""
Copyright (C) 2017, 申瑞珉 (Ruimin Shen)

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Lesser General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
"""

import os
import re
import time
import shutil
import argparse
import configparser
import operator
import itertools
import struct
import numpy as np
import pandas as pd
import tensorflow as tf
import model.yolo2.inference as inference
import utils


def transpose_weights(weights, num_anchors):
    """Reorder the output channels of the final conv kernel from Darknet's layout to this model's layout.

    Darknet stores each anchor's outputs as (x, y, w, h, objectness, classes...).
    ``model.yolo2.Model`` reads (objectness, x, y, w, h, classes...).

    Args:
        weights: kernel of shape ``(ksize1, ksize2, channels_in, num_anchors * (5 + classes))``.
        num_anchors: number of anchors per cell.

    Returns:
        A kernel of the same shape with the per-anchor channels reordered.
    """
    ksize1, ksize2, channels_in, _ = weights.shape
    weights = weights.reshape([ksize1, ksize2, channels_in, num_anchors, -1])
    coords = weights[..., 0:4]
    iou = weights[..., 4:5]
    classes = weights[..., 5:]
    return np.concatenate([iou, coords, classes], -1).reshape([ksize1, ksize2, channels_in, -1])


def transpose_biases(biases, num_anchors):
    """Apply the reordering of ``transpose_weights`` to a 1-D bias vector."""
    biases = biases.reshape([num_anchors, -1])
    coords = biases[:, 0:4]
    iou = biases[:, 4:5]
    classes = biases[:, 5:]
    return np.concatenate([iou, coords, classes], -1).reshape([-1])


def transpose(sess, layer, num_anchors):
    """Reorder, in the session, the ``weights`` and ``biases`` variables of the output layer.

    Args:
        sess: session holding the already-assigned variables.
        layer: list of the layer's variables.
        num_anchors: number of anchors per cell.
    """
    v = next(filter(lambda v: v.op.name.endswith('weights'), layer))
    sess.run(v.assign(transpose_weights(sess.run(v), num_anchors)))
    v = next(filter(lambda v: v.op.name.endswith('biases'), layer))
    sess.run(v.assign(transpose_biases(sess.run(v), num_anchors)))


def _read_header(f):
    """Read a Darknet ``.weights`` header and return ``(major, minor, revision, seen)``.

    Three little-endian int32 values (major, minor, revision) are followed by the
    ``seen`` counter. Darknet writes ``seen`` as a 64-bit ``size_t`` when
    ``major * 10 + minor >= 2`` (and both are < 1000), and as an int32 otherwise.
    The header is therefore 20 or 16 bytes; always reading 16 bytes would shift
    every parameter that follows in newer files.
    """
    major, minor, revision = struct.unpack('<3i', f.read(12))
    if major * 10 + minor >= 2 and major < 1000 and minor < 1000:
        seen, = struct.unpack('<Q', f.read(8))
    else:
        seen, = struct.unpack('<i', f.read(4))
    return major, minor, revision, seen


def _read_floats(f, count):
    """Read ``count`` little-endian float32 values from ``f`` as a 1-D array.

    Raises:
        EOFError: if the file ends before ``count`` values have been read.
    """
    data = f.read(4 * count)
    if len(data) != 4 * count:
        raise EOFError('weights file truncated: needed %d bytes, got %d' % (4 * count, len(data)))
    return np.frombuffer(data, dtype='<f4')


def main():
    """Convert a Darknet YOLO v2 ``.weights`` file into a TensorFlow checkpoint.

    The steps are:

    1. Build the configured ``model.yolo2.inference`` network on a 13x13-cell input.
    2. Assign the Darknet parameters layer by layer, in file order.
    3. Reorder the output layer's channels with ``transpose``.
    4. Save ``model.ckpt`` into the model's log directory.

    It uses the module-level ``args`` and ``config`` created in the ``__main__`` block.
    """
    model = config.get('config', 'model')
    cachedir = utils.get_cachedir(config)
    with open(os.path.join(cachedir, 'names'), 'r') as f:
        names = [line.strip() for line in f]
    width, height = np.array(utils.get_downsampling(config)) * 13
    anchors = pd.read_csv(os.path.expanduser(os.path.expandvars(config.get(model, 'anchors'))), sep='\t').values
    func = getattr(inference, config.get(model, 'inference'))
    with tf.Session() as sess:
        image = tf.placeholder(tf.float32, [1, height, width, 3], name='image')
        func(image, len(names), len(anchors))
        tf.contrib.framework.get_or_create_global_step()
        tf.global_variables_initializer().run()
        # group the conv variables by layer index; the un-numbered output layer "conv" gets key -1
        prog = re.compile(r'[_\w\d]+\/conv(\d*)\/(weights|biases|(BatchNorm\/(gamma|beta|moving_mean|moving_variance)))$')
        matches = [(prog.match(v.op.name), v) for v in tf.global_variables()]
        variables = [(m.group(1), v) for m, v in matches if m]
        variables = sorted([[int(k) if k else -1, [v for _, v in g]] for k, g in itertools.groupby(variables, operator.itemgetter(0))], key=operator.itemgetter(0))
        assert variables[0][0] == -1
        # the output layer is stored last in the Darknet file: give it the last index and move it to the end
        variables[0][0] = len(variables) - 1
        variables.append(variables.pop(0))
        with tf.name_scope('assign'):
            with open(os.path.expanduser(os.path.expandvars(args.file)), 'rb') as f:
                major, minor, revision, seen = _read_header(f)
                tf.logging.info('major=%d, minor=%d, revision=%d, seen=%d' % (major, minor, revision, seen))
                for i, layer in variables:
                    tf.logging.info('processing layer %d' % i)
                    total = 0
                    # Darknet's per-layer storage order; absent variables are skipped
                    for suffix in ['biases', 'beta', 'gamma', 'moving_mean', 'moving_variance', 'weights']:
                        try:
                            v = next(filter(lambda v: v.op.name.endswith(suffix), layer))
                        except StopIteration:
                            continue
                        shape = v.get_shape().as_list()
                        cnt = int(np.prod(shape))
                        total += cnt
                        tf.logging.info('%s: %s=%d' % (v.op.name, str(shape), cnt))
                        p = _read_floats(f, cnt)
                        if suffix == 'weights':
                            ksize1, ksize2, channels_in, channels_out = shape
                            p = np.reshape(p, [channels_out, channels_in, ksize1, ksize2])  # Darknet format
                            p = np.transpose(p, [2, 3, 1, 0])  # TensorFlow format (ksize1, ksize2, channels_in, channels_out)
                        sess.run(v.assign(p))
                    tf.logging.info('%d parameters assigned' % total)
                remaining = os.fstat(f.fileno()).st_size - f.tell()
            # after the loop, `layer` is the output layer
            transpose(sess, layer, len(anchors))
        saver = tf.train.Saver()
        logdir = utils.get_logdir(config)
        if args.delete:
            tf.logging.warn('delete logging directory: ' + logdir)
            shutil.rmtree(logdir, ignore_errors=True)
        os.makedirs(logdir, exist_ok=True)
        model_path = os.path.join(logdir, 'model.ckpt')
        tf.logging.info('save model into ' + model_path)
        saver.save(sess, model_path)
        if args.summary:
            path = os.path.join(logdir, args.logname)
            summary_writer = tf.summary.FileWriter(path)
            summary_writer.add_graph(sess.graph)
            tf.logging.info('tensorboard --logdir ' + logdir)
    if remaining > 0:
        tf.logging.warn('%d bytes remaining' % remaining)


def make_args():
    """Parse the command-line arguments of the converter."""
    parser = argparse.ArgumentParser()
    parser.add_argument('file', help='Darknet .weights file')
    parser.add_argument('-c', '--config', nargs='+', default=['config.ini'], help='config file')
    parser.add_argument('-d', '--delete', action='store_true', help='delete logdir')
    parser.add_argument('-s', '--summary', action='store_true')
    parser.add_argument('--logname', default=time.strftime('%Y-%m-%d_%H-%M-%S'), help='the name of TensorBoard log')
    parser.add_argument('--level', default='info', help='logging level')
    return parser.parse_args()


if __name__ == '__main__':
    args = make_args()
    config = configparser.ConfigParser()
    utils.load_config(config, args.config)
    if args.level:
        tf.logging.set_verbosity(args.level.upper())
    main()

# [repository-dump text that follows on the same source line (next file), kept verbatim]
# 137 |
# -------------------------------------------------------------------------------- /train.py: --------------------------------------------------------------------------------
# 1 | """
# 2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
# 3 |
# 4 | This program is free software: you can redistribute it and/or modify
# 5 | it under the terms of the GNU Lesser General Public License as published by
# 6 | the Free Software Foundation, either version 3 of the License, or
# 7 | (at your option) any later version.
# 8 |
# 9 | This program is distributed in the hope that it will be useful,
# 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
# 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# 12 | GNU General Public License for more details.
# 13 |
# 14 | You should have received a copy of the GNU General Public License
# 15 | along with this program.
16 | """ 17 | 18 | import os 19 | import argparse 20 | import configparser 21 | import importlib 22 | import shutil 23 | import time 24 | import inspect 25 | import multiprocessing 26 | import tensorflow as tf 27 | import tensorflow.contrib.slim as slim 28 | import utils.data 29 | 30 | 31 | def summary_scalar(config): 32 | try: 33 | reduce = eval(config.get('summary', 'scalar_reduce')) 34 | for t in utils.match_tensor(config.get('summary', 'scalar')): 35 | name = t.op.name 36 | if len(t.get_shape()) > 0: 37 | t = reduce(t) 38 | tf.logging.warn(name + ' is not a scalar tensor, reducing by ' + reduce.__name__) 39 | tf.summary.scalar(name, t) 40 | except (configparser.NoSectionError, configparser.NoOptionError): 41 | tf.logging.warn(inspect.stack()[0][3] + ' disabled') 42 | 43 | 44 | def summary_image(config): 45 | try: 46 | for t in utils.match_tensor(config.get('summary', 'image')): 47 | name = t.op.name 48 | channels = t.get_shape()[-1].value 49 | if channels not in (1, 3, 4): 50 | t = tf.expand_dims(tf.reduce_sum(t, -1), -1) 51 | tf.summary.image(name, t, config.getint('summary', 'image_max')) 52 | except (configparser.NoSectionError, configparser.NoOptionError): 53 | tf.logging.warn(inspect.stack()[0][3] + ' disabled') 54 | 55 | 56 | def summary_histogram(config): 57 | try: 58 | for t in utils.match_tensor(config.get('summary', 'histogram')): 59 | tf.summary.histogram(t.op.name, t) 60 | except (configparser.NoSectionError, configparser.NoOptionError): 61 | tf.logging.warn(inspect.stack()[0][3] + ' disabled') 62 | 63 | 64 | def summary(config): 65 | summary_scalar(config) 66 | summary_image(config) 67 | summary_histogram(config) 68 | 69 | 70 | def get_optimizer(config, name): 71 | section = 'optimizer_' + name 72 | return { 73 | 'adam': lambda learning_rate: tf.train.AdamOptimizer(learning_rate, config.getfloat(section, 'beta1'), config.getfloat(section, 'beta2'), config.getfloat(section, 'epsilon')), 74 | 'adadelta': lambda learning_rate: 
tf.train.AdadeltaOptimizer(learning_rate, config.getfloat(section, 'rho'), config.getfloat(section, 'epsilon')), 75 | 'adagrad': lambda learning_rate: tf.train.AdagradOptimizer(learning_rate, config.getfloat(section, 'initial_accumulator_value')), 76 | 'momentum': lambda learning_rate: tf.train.MomentumOptimizer(learning_rate, config.getfloat(section, 'momentum')), 77 | 'rmsprop': lambda learning_rate: tf.train.RMSPropOptimizer(learning_rate, config.getfloat(section, 'decay'), config.getfloat(section, 'momentum'), config.getfloat(section, 'epsilon')), 78 | 'ftrl': lambda learning_rate: tf.train.FtrlOptimizer(learning_rate, config.getfloat(section, 'learning_rate_power'), config.getfloat(section, 'initial_accumulator_value'), config.getfloat(section, 'l1_regularization_strength'), config.getfloat(section, 'l2_regularization_strength')), 79 | 'gd': lambda learning_rate: tf.train.GradientDescentOptimizer(learning_rate), 80 | }[name] 81 | 82 | 83 | def main(): 84 | model = config.get('config', 'model') 85 | logdir = utils.get_logdir(config) 86 | if args.delete: 87 | tf.logging.warn('delete logging directory: ' + logdir) 88 | shutil.rmtree(logdir, ignore_errors=True) 89 | cachedir = utils.get_cachedir(config) 90 | with open(os.path.join(cachedir, 'names'), 'r') as f: 91 | names = [line.strip() for line in f] 92 | width = config.getint(model, 'width') 93 | height = config.getint(model, 'height') 94 | cell_width, cell_height = utils.calc_cell_width_height(config, width, height) 95 | tf.logging.warn('(width, height)=(%d, %d), (cell_width, cell_height)=(%d, %d)' % (width, height, cell_width, cell_height)) 96 | yolo = importlib.import_module('model.' 
+ model) 97 | paths = [os.path.join(cachedir, profile + '.tfrecord') for profile in args.profile] 98 | num_examples = sum(sum(1 for _ in tf.python_io.tf_record_iterator(path)) for path in paths) 99 | tf.logging.warn('num_examples=%d' % num_examples) 100 | with tf.name_scope('batch'): 101 | image_rgb, labels = utils.data.load_image_labels(paths, len(names), width, height, cell_width, cell_height, config) 102 | with tf.name_scope('per_image_standardization'): 103 | image_std = tf.image.per_image_standardization(image_rgb) 104 | batch = tf.train.shuffle_batch((image_std,) + labels, batch_size=args.batch_size, 105 | capacity=config.getint('queue', 'capacity'), min_after_dequeue=config.getint('queue', 'min_after_dequeue'), 106 | num_threads=multiprocessing.cpu_count() 107 | ) 108 | global_step = tf.contrib.framework.get_or_create_global_step() 109 | builder = yolo.Builder(args, config) 110 | builder(batch[0], training=True) 111 | with tf.name_scope('total_loss') as name: 112 | builder.create_objectives(batch[1:]) 113 | total_loss = tf.losses.get_total_loss(name=name) 114 | variables_to_restore = slim.get_variables_to_restore(exclude=args.exclude) 115 | with tf.name_scope('optimizer'): 116 | try: 117 | decay_steps = config.getint('exponential_decay', 'decay_steps') 118 | decay_rate = config.getfloat('exponential_decay', 'decay_rate') 119 | staircase = config.getboolean('exponential_decay', 'staircase') 120 | learning_rate = tf.train.exponential_decay(args.learning_rate, global_step, decay_steps, decay_rate, staircase=staircase) 121 | tf.logging.warn('using a learning rate start from %f with exponential decay (decay_steps=%d, decay_rate=%f, staircase=%d)' % (args.learning_rate, decay_steps, decay_rate, staircase)) 122 | except (configparser.NoSectionError, configparser.NoOptionError): 123 | learning_rate = args.learning_rate 124 | tf.logging.warn('using a staionary learning rate %f' % args.learning_rate) 125 | optimizer = get_optimizer(config, 
args.optimizer)(learning_rate) 126 | tf.logging.warn('optimizer=' + args.optimizer) 127 | train_op = slim.learning.create_train_op(total_loss, optimizer, global_step, 128 | clip_gradient_norm=args.gradient_clip, summarize_gradients=config.getboolean('summary', 'gradients'), 129 | ) 130 | if args.transfer: 131 | path = os.path.expanduser(os.path.expandvars(args.transfer)) 132 | tf.logging.warn('transferring from ' + path) 133 | init_assign_op, init_feed_dict = slim.assign_from_checkpoint(path, variables_to_restore) 134 | def init_fn(sess): 135 | sess.run(init_assign_op, init_feed_dict) 136 | tf.logging.warn('transferring from global_step=%d, learning_rate=%f' % sess.run((global_step, learning_rate))) 137 | else: 138 | init_fn = lambda sess: tf.logging.warn('global_step=%d, learning_rate=%f' % sess.run((global_step, learning_rate))) 139 | summary(config) 140 | tf.logging.warn('tensorboard --logdir ' + logdir) 141 | slim.learning.train(train_op, logdir, master=args.master, is_chief=(args.task == 0), 142 | global_step=global_step, number_of_steps=args.steps, init_fn=init_fn, 143 | summary_writer=tf.summary.FileWriter(os.path.join(logdir, args.logname)), 144 | save_summaries_secs=args.summary_secs, save_interval_secs=args.save_secs 145 | ) 146 | 147 | 148 | def make_args(): 149 | parser = argparse.ArgumentParser() 150 | parser.add_argument('-c', '--config', nargs='+', default=['config.ini'], help='config file') 151 | parser.add_argument('-t', '--transfer', help='transferring model from a .ckpt file') 152 | parser.add_argument('-e', '--exclude', nargs='+', help='exclude variables while transferring') 153 | parser.add_argument('-p', '--profile', nargs='+', default=['train', 'val']) 154 | parser.add_argument('-s', '--steps', type=int, default=None, help='max number of steps') 155 | parser.add_argument('-d', '--delete', action='store_true', help='delete logdir') 156 | parser.add_argument('-b', '--batch_size', default=8, type=int, help='batch size') 157 | 
parser.add_argument('-o', '--optimizer', default='adam') 158 | parser.add_argument('-n', '--logname', default=time.strftime('%Y-%m-%d_%H-%M-%S'), help='the name for TensorBoard') 159 | parser.add_argument('-g', '--gradient_clip', default=0, type=float, help='gradient clip') 160 | parser.add_argument('-lr', '--learning_rate', default=1e-6, type=float, help='learning rate') 161 | parser.add_argument('--seed', type=int, default=None) 162 | parser.add_argument('--summary_secs', default=30, type=int, help='seconds to save summaries') 163 | parser.add_argument('--save_secs', default=600, type=int, help='seconds to save model') 164 | parser.add_argument('--level', help='logging level') 165 | parser.add_argument('--master', default='', help='master address') 166 | parser.add_argument('--task', type=int, default=0, help='task ID') 167 | return parser.parse_args() 168 | 169 | if __name__ == '__main__': 170 | args = make_args() 171 | config = configparser.ConfigParser() 172 | utils.load_config(config, args.config) 173 | if args.level: 174 | tf.logging.set_verbosity(args.level.upper()) 175 | main() 176 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen) 3 | 4 | This program is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU Lesser General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 8 | 9 | This program is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with this program. If not, see . 
16 | """ 17 | 18 | import os 19 | import re 20 | import importlib 21 | import inspect 22 | import numpy as np 23 | import matplotlib.patches as patches 24 | import tensorflow as tf 25 | from tensorflow.python.client import device_lib 26 | 27 | 28 | def get_cachedir(config): 29 | basedir = os.path.expanduser(os.path.expandvars(config.get('config', 'basedir'))) 30 | name = os.path.basename(config.get('cache', 'names')) 31 | return os.path.join(basedir, 'cache', name) 32 | 33 | 34 | def get_logdir(config): 35 | basedir = os.path.expanduser(os.path.expandvars(config.get('config', 'basedir'))) 36 | model = config.get('config', 'model') 37 | inference = config.get(model, 'inference') 38 | name = os.path.basename(config.get('cache', 'names')) 39 | return os.path.join(basedir, model, inference, name) 40 | 41 | 42 | def get_inference(config): 43 | model = config.get('config', 'model') 44 | return getattr(importlib.import_module('.'.join(['model', model, 'inference'])), config.get(model, 'inference')) 45 | 46 | 47 | def get_downsampling(config): 48 | model = config.get('config', 'model') 49 | return getattr(importlib.import_module('.'.join(['model', model, 'inference'])), config.get(model, 'inference').upper() + '_DOWNSAMPLING') 50 | 51 | 52 | def calc_cell_width_height(config, width, height): 53 | downsampling_width, downsampling_height = get_downsampling(config) 54 | assert width % downsampling_width == 0 55 | assert height % downsampling_height == 0 56 | return width // downsampling_width, height // downsampling_height 57 | 58 | 59 | def match_trainable_variables(pattern): 60 | prog = re.compile(pattern) 61 | return [v for v in tf.trainable_variables() if prog.match(v.op.name)] 62 | 63 | 64 | def match_tensor(pattern): 65 | prog = re.compile(pattern) 66 | return [op.values()[0] for op in tf.get_default_graph().get_operations() if op.values() and prog.match(op.name)] 67 | 68 | 69 | def load_config(config, paths): 70 | for path in paths: 71 | path = 
os.path.expanduser(os.path.expandvars(path)) 72 | assert os.path.exists(path) 73 | config.read(path) 74 | 75 | def get_available_gpus(): 76 | local_device_protos = device_lib.list_local_devices() 77 | return [x.name for x in local_device_protos if x.device_type == 'GPU'] 78 | -------------------------------------------------------------------------------- /utils/data/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen) 3 | 4 | This program is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU Lesser General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 8 | 9 | This program is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with this program. If not, see . 16 | """ 17 | 18 | import os 19 | import re 20 | import importlib 21 | import inspect 22 | import numpy as np 23 | import matplotlib.patches as patches 24 | import tensorflow as tf 25 | from .. 
import preprocess 26 | 27 | 28 | def decode_image_objects(paths): 29 | with tf.name_scope(inspect.stack()[0][3]): 30 | with tf.name_scope('parse_example'): 31 | reader = tf.TFRecordReader() 32 | _, serialized = reader.read(tf.train.string_input_producer(paths)) 33 | example = tf.parse_single_example(serialized, features={ 34 | 'imagepath': tf.FixedLenFeature([], tf.string), 35 | 'imageshape': tf.FixedLenFeature([3], tf.int64), 36 | 'objects': tf.FixedLenFeature([2], tf.string), 37 | }) 38 | imagepath = example['imagepath'] 39 | objects = example['objects'] 40 | with tf.name_scope('decode_objects'): 41 | objects_class = tf.decode_raw(objects[0], tf.int64, name='objects_class') 42 | objects_coord = tf.decode_raw(objects[1], tf.float32) 43 | objects_coord = tf.reshape(objects_coord, [-1, 4], name='objects_coord') 44 | with tf.name_scope('load_image'): 45 | imagefile = tf.read_file(imagepath) 46 | image = tf.image.decode_jpeg(imagefile, channels=3) 47 | return image, example['imageshape'], objects_class, objects_coord 48 | 49 | 50 | def data_augmentation_full(image, objects_coord, width_height, config): 51 | section = inspect.stack()[0][3] 52 | with tf.name_scope(section): 53 | random_crop = config.getfloat(section, 'random_crop') 54 | if random_crop > 0: 55 | image, objects_coord, width_height = tf.cond( 56 | tf.random_uniform([]) < config.getfloat(section, 'enable_probability'), 57 | lambda: preprocess.random_crop(image, objects_coord, width_height, random_crop), 58 | lambda: (image, objects_coord, width_height) 59 | ) 60 | return image, objects_coord, width_height 61 | 62 | 63 | def resize_image_objects(image, objects_coord, width_height, width, height): 64 | with tf.name_scope(inspect.stack()[0][3]): 65 | image = tf.image.resize_images(image, [height, width]) 66 | factor = [width, height] / width_height 67 | objects_coord = objects_coord * tf.tile(factor, [2]) 68 | return image, objects_coord 69 | 70 | 71 | def data_augmentation_resized(image, objects_coord, width, 
height, config): 72 | section = inspect.stack()[0][3] 73 | with tf.name_scope(section): 74 | if config.getboolean(section, 'random_flip_horizontally'): 75 | image, objects_coord = preprocess.random_flip_horizontally(image, objects_coord, width) 76 | if config.getboolean(section, 'random_brightness'): 77 | image = tf.cond( 78 | tf.random_uniform([]) < config.getfloat(section, 'enable_probability'), 79 | lambda: tf.image.random_brightness(image, max_delta=63), 80 | lambda: image 81 | ) 82 | if config.getboolean(section, 'random_saturation'): 83 | image = tf.cond( 84 | tf.random_uniform([]) < config.getfloat(section, 'enable_probability'), 85 | lambda: tf.image.random_saturation(image, lower=0.5, upper=1.5), 86 | lambda: image 87 | ) 88 | if config.getboolean(section, 'random_hue'): 89 | image = tf.cond( 90 | tf.random_uniform([]) < config.getfloat(section, 'enable_probability'), 91 | lambda: tf.image.random_hue(image, max_delta=0.032), 92 | lambda: image 93 | ) 94 | if config.getboolean(section, 'random_contrast'): 95 | image = tf.cond( 96 | tf.random_uniform([]) < config.getfloat(section, 'enable_probability'), 97 | lambda: tf.image.random_contrast(image, lower=0.5, upper=1.5), 98 | lambda: image 99 | ) 100 | if config.getboolean(section, 'noise'): 101 | image = tf.cond( 102 | tf.random_uniform([]) < config.getfloat(section, 'enable_probability'), 103 | lambda: image + tf.truncated_normal(tf.shape(image)) * tf.random_uniform([], 5, 15), 104 | lambda: image 105 | ) 106 | grayscale_probability = config.getfloat(section, 'grayscale_probability') 107 | if grayscale_probability > 0: 108 | image = preprocess.random_grayscale(image, grayscale_probability) 109 | return image, objects_coord 110 | 111 | 112 | def transform_labels(objects_class, objects_coord, classes, cell_width, cell_height, dtype=np.float32): 113 | cells = cell_height * cell_width 114 | mask = np.zeros([cells, 1], dtype=dtype) 115 | prob = np.zeros([cells, 1, classes], dtype=dtype) 116 | coords = 
np.zeros([cells, 1, 4], dtype=dtype) 117 | offset_xy_min = np.zeros([cells, 1, 2], dtype=dtype) 118 | offset_xy_max = np.zeros([cells, 1, 2], dtype=dtype) 119 | assert len(objects_class) == len(objects_coord) 120 | xmin, ymin, xmax, ymax = objects_coord.T 121 | x = cell_width * (xmin + xmax) / 2 122 | y = cell_height * (ymin + ymax) / 2 123 | ix = np.floor(x) 124 | iy = np.floor(y) 125 | offset_x = x - ix 126 | offset_y = y - iy 127 | w = xmax - xmin 128 | h = ymax - ymin 129 | index = (iy * cell_width + ix).astype(np.int) 130 | mask[index, :] = 1 131 | prob[index, :, objects_class] = 1 132 | coords[index, 0, 0] = offset_x 133 | coords[index, 0, 1] = offset_y 134 | coords[index, 0, 2] = np.sqrt(w) 135 | coords[index, 0, 3] = np.sqrt(h) 136 | _w = w / 2 * cell_width 137 | _h = h / 2 * cell_height 138 | offset_xy_min[index, 0, 0] = offset_x - _w 139 | offset_xy_min[index, 0, 1] = offset_y - _h 140 | offset_xy_max[index, 0, 0] = offset_x + _w 141 | offset_xy_max[index, 0, 1] = offset_y + _h 142 | wh = offset_xy_max - offset_xy_min 143 | assert np.all(wh >= 0) 144 | areas = np.multiply.reduce(wh, -1) 145 | return mask, prob, coords, offset_xy_min, offset_xy_max, areas 146 | 147 | 148 | def decode_labels(objects_class, objects_coord, classes, cell_width, cell_height): 149 | with tf.name_scope(inspect.stack()[0][3]): 150 | mask, prob, coords, offset_xy_min, offset_xy_max, areas = tf.py_func(transform_labels, [objects_class, objects_coord, classes, cell_width, cell_height], [tf.float32] * 6) 151 | cells = cell_height * cell_width 152 | with tf.name_scope('reshape_labels'): 153 | mask = tf.reshape(mask, [cells, 1], name='mask') 154 | prob = tf.reshape(prob, [cells, 1, classes], name='prob') 155 | coords = tf.reshape(coords, [cells, 1, 4], name='coords') 156 | offset_xy_min = tf.reshape(offset_xy_min, [cells, 1, 2], name='offset_xy_min') 157 | offset_xy_max = tf.reshape(offset_xy_max, [cells, 1, 2], name='offset_xy_max') 158 | areas = tf.reshape(areas, [cells, 1], 
name='areas') 159 | return mask, prob, coords, offset_xy_min, offset_xy_max, areas 160 | 161 | 162 | def load_image_labels(paths, classes, width, height, cell_width, cell_height, config): 163 | with tf.name_scope('batch'): 164 | image, imageshape, objects_class, objects_coord = decode_image_objects(paths) 165 | image = tf.cast(image, tf.float32) 166 | width_height = tf.cast(imageshape[1::-1], tf.float32) 167 | if config.getboolean('data_augmentation_full', 'enable'): 168 | image, objects_coord, width_height = data_augmentation_full(image, objects_coord, width_height, config) 169 | image, objects_coord = resize_image_objects(image, objects_coord, width_height, width, height) 170 | if config.getboolean('data_augmentation_resized', 'enable'): 171 | image, objects_coord = data_augmentation_resized(image, objects_coord, width, height, config) 172 | image = tf.clip_by_value(image, 0, 255) 173 | objects_coord = objects_coord / [width, height, width, height] 174 | labels = decode_labels(objects_class, objects_coord, classes, cell_width, cell_height) 175 | return image, labels 176 | -------------------------------------------------------------------------------- /utils/data/cache.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen) 3 | 4 | This program is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU Lesser General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 8 | 9 | This program is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with this program. If not, see . 
16 | """ 17 | 18 | import os 19 | import inspect 20 | from PIL import Image 21 | import tqdm 22 | import numpy as np 23 | import tensorflow as tf 24 | import utils.data.voc 25 | 26 | 27 | def verify_imageshape(imagepath, imageshape): 28 | with Image.open(imagepath) as image: 29 | return np.all(np.equal(image.size, imageshape[1::-1])) 30 | 31 | 32 | def verify_image_jpeg(imagepath, imageshape): 33 | scope = inspect.stack()[0][3] 34 | try: 35 | graph = tf.get_default_graph() 36 | path = graph.get_tensor_by_name(scope + '/path:0') 37 | decode = graph.get_tensor_by_name(scope + '/decode_jpeg:0') 38 | except KeyError: 39 | tf.logging.debug('creating decode_jpeg tensor') 40 | path = tf.placeholder(tf.string, name=scope + '/path') 41 | imagefile = tf.read_file(path, name=scope + '/read_file') 42 | decode = tf.image.decode_jpeg(imagefile, channels=3, name=scope + '/decode_jpeg') 43 | try: 44 | image = tf.get_default_session().run(decode, {path: imagepath}) 45 | except: 46 | return False 47 | return np.all(np.equal(image.shape[:2], imageshape[:2])) 48 | 49 | 50 | def check_coords(objects_coord): 51 | return np.all(objects_coord[:, 0] <= objects_coord[:, 2]) and np.all(objects_coord[:, 1] <= objects_coord[:, 3]) 52 | 53 | 54 | def verify_coords(objects_coord, imageshape): 55 | assert check_coords(objects_coord) 56 | return np.all(objects_coord >= 0) and np.all(objects_coord <= np.tile(imageshape[1::-1], [2])) 57 | 58 | 59 | def fix_coords(objects_coord, imageshape): 60 | assert check_coords(objects_coord) 61 | objects_coord = np.maximum(objects_coord, np.zeros([4], dtype=objects_coord.dtype)) 62 | objects_coord = np.minimum(objects_coord, np.tile(np.asanyarray(imageshape[1::-1], objects_coord.dtype), [2])) 63 | return objects_coord 64 | 65 | 66 | def voc(writer, name_index, profile, row, verify=False): 67 | root = os.path.expanduser(os.path.expandvars(row['root'])) 68 | path = os.path.join(root, 'ImageSets', 'Main', profile) + '.txt' 69 | if not os.path.exists(path): 70 | 
tf.logging.warn(path + ' not exists') 71 | return False 72 | with open(path, 'r') as f: 73 | filenames = [line.strip() for line in f] 74 | annotations = [os.path.join(root, 'Annotations', filename + '.xml') for filename in filenames] 75 | _annotations = list(filter(os.path.exists, annotations)) 76 | if len(annotations) > len(_annotations): 77 | tf.logging.warn('%d of %d images not exists' % (len(annotations) - len(_annotations), len(annotations))) 78 | cnt_noobj = 0 79 | for path in tqdm.tqdm(_annotations): 80 | imagename, imageshape, objects_class, objects_coord = utils.data.voc.load_dataset(path, name_index) 81 | if len(objects_class) <= 0: 82 | cnt_noobj += 1 83 | continue 84 | objects_class = np.array(objects_class, dtype=np.int64) 85 | objects_coord = np.array(objects_coord, dtype=np.float32) 86 | imagepath = os.path.join(root, 'JPEGImages', imagename) 87 | if verify: 88 | if not verify_coords(objects_coord, imageshape): 89 | tf.logging.error('failed to verify coordinates of ' + imagepath) 90 | continue 91 | if not verify_image_jpeg(imagepath, imageshape): 92 | tf.logging.error('failed to decode ' + imagepath) 93 | continue 94 | assert len(objects_class) == len(objects_coord) 95 | example = tf.train.Example(features=tf.train.Features(feature={ 96 | 'imagepath': tf.train.Feature(bytes_list=tf.train.BytesList(value=[tf.compat.as_bytes(imagepath)])), 97 | 'imageshape': tf.train.Feature(int64_list=tf.train.Int64List(value=imageshape)), 98 | 'objects': tf.train.Feature(bytes_list=tf.train.BytesList(value=[objects_class.tostring(), objects_coord.tostring()])), 99 | })) 100 | writer.write(example.SerializeToString()) 101 | if cnt_noobj > 0: 102 | tf.logging.warn('%d of %d images have no object' % (cnt_noobj, len(filenames))) 103 | return True 104 | 105 | 106 | def coco(writer, name_index, profile, row, verify=False): 107 | root = os.path.expanduser(os.path.expandvars(row['root'])) 108 | year = str(row['year']) 109 | name = profile + year 110 | path = 
os.path.join(root, 'annotations', 'instances_%s.json' % name) 111 | if not os.path.exists(path): 112 | tf.logging.warn(path + ' not exists') 113 | return False 114 | import pycocotools.coco 115 | coco = pycocotools.coco.COCO(path) 116 | catIds = coco.getCatIds(catNms=list(name_index.keys())) 117 | cats = coco.loadCats(catIds) 118 | id_index = dict((cat['id'], name_index[cat['name']]) for cat in cats) 119 | imgIds = coco.getImgIds() 120 | path = os.path.join(root, name) 121 | imgs = coco.loadImgs(imgIds) 122 | _imgs = list(filter(lambda img: os.path.exists(os.path.join(path, img['file_name'])), imgs)) 123 | if len(imgs) > len(_imgs): 124 | tf.logging.warn('%d of %d images not exists' % (len(imgs) - len(_imgs), len(imgs))) 125 | cnt_noobj = 0 126 | for img in tqdm.tqdm(_imgs): 127 | annIds = coco.getAnnIds(imgIds=img['id'], catIds=catIds, iscrowd=None) 128 | anns = coco.loadAnns(annIds) 129 | if len(anns) <= 0: 130 | cnt_noobj += 1 131 | continue 132 | imagepath = os.path.join(path, img['file_name']) 133 | width, height = img['width'], img['height'] 134 | imageshape = [height, width, 3] 135 | objects_class = np.array([id_index[ann['category_id']] for ann in anns], dtype=np.int64) 136 | objects_coord = [ann['bbox'] for ann in anns] 137 | objects_coord = [(x, y, x + w, y + h) for x, y, w, h in objects_coord] 138 | objects_coord = np.array(objects_coord, dtype=np.float32) 139 | if verify: 140 | if not verify_coords(objects_coord, imageshape): 141 | tf.logging.error('failed to verify coordinates of ' + imagepath) 142 | continue 143 | if not verify_image_jpeg(imagepath, imageshape): 144 | tf.logging.error('failed to decode ' + imagepath) 145 | continue 146 | assert len(objects_class) == len(objects_coord) 147 | example = tf.train.Example(features=tf.train.Features(feature={ 148 | 'imagepath': tf.train.Feature(bytes_list=tf.train.BytesList(value=[tf.compat.as_bytes(imagepath)])), 149 | 'imageshape': tf.train.Feature(int64_list=tf.train.Int64List(value=imageshape)), 150 | 
'objects': tf.train.Feature(bytes_list=tf.train.BytesList(value=[objects_class.tostring(), objects_coord.tostring()])), 151 | })) 152 | writer.write(example.SerializeToString()) 153 | if cnt_noobj > 0: 154 | tf.logging.warn('%d of %d images have no object' % (cnt_noobj, len(_imgs))) 155 | return True 156 | -------------------------------------------------------------------------------- /utils/data/voc.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen) 3 | 4 | This program is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU Lesser General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 8 | 9 | This program is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with this program. If not, see . 
16 | """ 17 | 18 | import sys 19 | import bs4 20 | 21 | 22 | def load_dataset(path, name_index): 23 | with open(path, 'r') as f: 24 | anno = bs4.BeautifulSoup(f.read(), 'xml').find('annotation') 25 | objects_class = [] 26 | objects_coord = [] 27 | for obj in anno.find_all('object', recursive=False): 28 | for bndbox, name in zip(obj.find_all('bndbox', recursive=False), obj.find_all('name', recursive=False)): 29 | if name.text in name_index: 30 | objects_class.append(name_index[name.text]) 31 | xmin = float(bndbox.find('xmin').text) - 1 32 | ymin = float(bndbox.find('ymin').text) - 1 33 | xmax = float(bndbox.find('xmax').text) - 1 34 | ymax = float(bndbox.find('ymax').text) - 1 35 | objects_coord.append((xmin, ymin, xmax, ymax)) 36 | else: 37 | sys.stderr.write(name.text + ' not in names\n') 38 | size = anno.find('size') 39 | return anno.find('filename').text, (int(size.find('height').text), int(size.find('width').text), int(size.find('depth').text)), objects_class, objects_coord 40 | -------------------------------------------------------------------------------- /utils/postprocess.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen) 3 | 4 | This program is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU Lesser General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 8 | 9 | This program is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with this program. If not, see . 
16 | """ 17 | 18 | import numpy as np 19 | 20 | 21 | def iou(xy_min1, xy_max1, xy_min2, xy_max2): 22 | assert(not np.isnan(xy_min1).any()) 23 | assert(not np.isnan(xy_max1).any()) 24 | assert(not np.isnan(xy_min2).any()) 25 | assert(not np.isnan(xy_max2).any()) 26 | assert np.all(xy_min1 <= xy_max1) 27 | assert np.all(xy_min2 <= xy_max2) 28 | areas1 = np.multiply.reduce(xy_max1 - xy_min1) 29 | areas2 = np.multiply.reduce(xy_max2 - xy_min2) 30 | _xy_min = np.maximum(xy_min1, xy_min2) 31 | _xy_max = np.minimum(xy_max1, xy_max2) 32 | _wh = np.maximum(_xy_max - _xy_min, 0) 33 | _areas = np.multiply.reduce(_wh) 34 | assert _areas <= areas1 35 | assert _areas <= areas2 36 | return _areas / np.maximum(areas1 + areas2 - _areas, 1e-10) 37 | 38 | 39 | def non_max_suppress(conf, xy_min, xy_max, threshold, threshold_iou): 40 | _, _, classes = conf.shape 41 | boxes = [(_conf, _xy_min, _xy_max) for _conf, _xy_min, _xy_max in zip(conf.reshape(-1, classes), xy_min.reshape(-1, 2), xy_max.reshape(-1, 2))] 42 | for c in range(classes): 43 | boxes.sort(key=lambda box: box[0][c], reverse=True) 44 | for i in range(len(boxes) - 1): 45 | box = boxes[i] 46 | if box[0][c] <= threshold: 47 | continue 48 | for _box in boxes[i + 1:]: 49 | if iou(box[1], box[2], _box[1], _box[2]) >= threshold_iou: 50 | _box[0][c] = 0 51 | return boxes 52 | -------------------------------------------------------------------------------- /utils/preprocess.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen) 3 | 4 | This program is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU Lesser General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 
8 | 9 | This program is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with this program. If not, see . 16 | """ 17 | 18 | import inspect 19 | import numpy as np 20 | import tensorflow as tf 21 | 22 | 23 | def per_image_standardization(image): 24 | stddev = np.std(image) 25 | return (image - np.mean(image)) / max(stddev, 1.0 / np.sqrt(np.multiply.reduce(image.shape))) 26 | 27 | 28 | def random_crop(image, objects_coord, width_height, scale=1): 29 | assert 0 < scale <= 1 30 | section = inspect.stack()[0][3] 31 | with tf.name_scope(section): 32 | xy_min = tf.reduce_min(objects_coord[:, :2], 0) 33 | xy_max = tf.reduce_max(objects_coord[:, 2:], 0) 34 | margin = width_height - xy_max 35 | shrink = tf.random_uniform([4], maxval=scale) * tf.concat([xy_min, margin], 0) 36 | _xy_min = shrink[:2] 37 | _wh = width_height - shrink[2:] - _xy_min 38 | objects_coord = objects_coord - tf.tile(_xy_min, [2]) 39 | _xy_min_ = tf.cast(_xy_min, tf.int32) 40 | _wh_ = tf.cast(_wh, tf.int32) 41 | image = tf.image.crop_to_bounding_box(image, _xy_min_[1], _xy_min_[0], _wh_[1], _wh_[0]) 42 | return image, objects_coord, _wh 43 | 44 | 45 | def flip_horizontally(image, objects_coord, width): 46 | section = inspect.stack()[0][3] 47 | with tf.name_scope(section): 48 | image = tf.image.flip_left_right(image) 49 | xmin, ymin, xmax, ymax = objects_coord[:, 0:1], objects_coord[:, 1:2], objects_coord[:, 2:3], objects_coord[:, 3:4] 50 | objects_coord = tf.concat([width - xmax, ymin, width - xmin, ymax], 1) 51 | return image, objects_coord 52 | 53 | 54 | def random_flip_horizontally(image, objects_coord, width, probability=0.5): 55 | section = inspect.stack()[0][3] 56 | with tf.name_scope(section): 57 | pred = tf.random_uniform([]) < 
probability 58 | fn1 = lambda: flip_horizontally(image, objects_coord, width) 59 | fn2 = lambda: (image, objects_coord) 60 | return tf.cond(pred, fn1, fn2) 61 | 62 | 63 | def random_grayscale(image, probability=0.5): 64 | if probability <= 0: 65 | return image 66 | section = inspect.stack()[0][3] 67 | with tf.name_scope(section): 68 | pred = tf.random_uniform([]) < probability 69 | fn1 = lambda: tf.tile(tf.image.rgb_to_grayscale(image), [1] * (len(image.get_shape()) - 1) + [3]) 70 | fn2 = lambda: image 71 | return tf.cond(pred, fn1, fn2) 72 | -------------------------------------------------------------------------------- /utils/verify.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen) 3 | 4 | This program is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU Lesser General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 8 | 9 | This program is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with this program. If not, see . 
16 | """ 17 | 18 | import numpy as np 19 | 20 | 21 | def abs_mean(data): 22 | return np.sum(np.abs(data)) / np.float32(data.size) 23 | -------------------------------------------------------------------------------- /utils/visualize.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen) 3 | 4 | This program is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU Lesser General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 8 | 9 | This program is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with this program. If not, see . 
16 | """ 17 | 18 | import itertools 19 | import numpy as np 20 | import matplotlib.pyplot as plt 21 | import matplotlib.patches as patches 22 | 23 | 24 | def draw_labels(ax, names, width, height, cell_width, cell_height, mask, prob, coords, xy_min, xy_max, areas, rtol=1e-3): 25 | colors = [prop['color'] for _, prop in zip(names, itertools.cycle(plt.rcParams['axes.prop_cycle']))] 26 | plots = [] 27 | for i, (_mask, _prob, _coords, _xy_min, _xy_max, _areas) in enumerate(zip(mask, prob, coords, xy_min, xy_max, areas)): 28 | _mask = _mask.reshape([]) 29 | _coords = _coords.reshape([-1]) 30 | if np.any(_mask) > 0: 31 | index = np.argmax(_prob) 32 | iy = i // cell_width 33 | ix = i % cell_width 34 | plots.append(ax.add_patch(patches.Rectangle((ix * width / cell_width, iy * height / cell_height), width / cell_width, height / cell_height, linewidth=0, facecolor=colors[index], alpha=.2))) 35 | #check coords 36 | offset_x, offset_y, _w_sqrt, _h_sqrt = _coords 37 | cell_x, cell_y = ix + offset_x, iy + offset_y 38 | x, y = cell_x * width / cell_width, cell_y * height / cell_height 39 | _w, _h = _w_sqrt * _w_sqrt, _h_sqrt * _h_sqrt 40 | w, h = _w * width, _h * height 41 | x_min, y_min = x - w / 2, y - h / 2 42 | plots.append(ax.add_patch(patches.Rectangle((x_min, y_min), w, h, linewidth=1, edgecolor=colors[index], facecolor='none'))) 43 | plots.append(ax.annotate(names[index], (x_min, y_min), color=colors[index])) 44 | #check offset_xy_min and xy_max 45 | wh = _xy_max - _xy_min 46 | assert np.all(wh >= 0) 47 | np.testing.assert_allclose(wh / [cell_width, cell_height], [[_w, _h]], rtol=rtol) 48 | np.testing.assert_allclose(_xy_min + wh / 2, [[offset_x, offset_y]], rtol=rtol) 49 | return plots 50 | --------------------------------------------------------------------------------