├── .gitignore
├── LICENSE
├── README.md
├── cache.py
├── config.ini
├── config
├── cache
│ ├── coco.tsv
│ └── voc.tsv
├── names
│ ├── 20
│ └── 80
├── yolo
│ ├── darknet-20.ini
│ ├── darknet-80.ini
│ ├── tiny-20.ini
│ └── tiny-80.ini
└── yolo2
│ ├── anchors
│ ├── coco.tsv
│ └── voc.tsv
│ ├── darknet-20.ini
│ ├── darknet-80.ini
│ ├── tiny-20.ini
│ └── tiny-80.ini
├── demo_data_augmentation.py
├── demo_detect.py
├── detect.py
├── detect_camera.py
├── model
├── __init__.py
├── yolo
│ ├── __init__.py
│ ├── function.py
│ └── inference.py
└── yolo2
│ ├── __init__.py
│ ├── function.py
│ └── inference.py
├── parse_darknet_yolo2.py
├── train.py
└── utils
├── __init__.py
├── data
├── __init__.py
├── cache.py
└── voc.py
├── postprocess.py
├── preprocess.py
├── verify.py
└── visualize.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | __pycache__
3 | .project
4 | .pydevproject
5 | .settings/
6 | .idea/
7 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | GNU LESSER GENERAL PUBLIC LICENSE
2 | Version 3, 29 June 2007
3 |
4 | Copyright (C) 2007 Free Software Foundation, Inc.
5 | Everyone is permitted to copy and distribute verbatim copies
6 | of this license document, but changing it is not allowed.
7 |
8 |
9 | This version of the GNU Lesser General Public License incorporates
10 | the terms and conditions of version 3 of the GNU General Public
11 | License, supplemented by the additional permissions listed below.
12 |
13 | 0. Additional Definitions.
14 |
15 | As used herein, "this License" refers to version 3 of the GNU Lesser
16 | General Public License, and the "GNU GPL" refers to version 3 of the GNU
17 | General Public License.
18 |
19 | "The Library" refers to a covered work governed by this License,
20 | other than an Application or a Combined Work as defined below.
21 |
22 | An "Application" is any work that makes use of an interface provided
23 | by the Library, but which is not otherwise based on the Library.
24 | Defining a subclass of a class defined by the Library is deemed a mode
25 | of using an interface provided by the Library.
26 |
27 | A "Combined Work" is a work produced by combining or linking an
28 | Application with the Library. The particular version of the Library
29 | with which the Combined Work was made is also called the "Linked
30 | Version".
31 |
32 | The "Minimal Corresponding Source" for a Combined Work means the
33 | Corresponding Source for the Combined Work, excluding any source code
34 | for portions of the Combined Work that, considered in isolation, are
35 | based on the Application, and not on the Linked Version.
36 |
37 | The "Corresponding Application Code" for a Combined Work means the
38 | object code and/or source code for the Application, including any data
39 | and utility programs needed for reproducing the Combined Work from the
40 | Application, but excluding the System Libraries of the Combined Work.
41 |
42 | 1. Exception to Section 3 of the GNU GPL.
43 |
44 | You may convey a covered work under sections 3 and 4 of this License
45 | without being bound by section 3 of the GNU GPL.
46 |
47 | 2. Conveying Modified Versions.
48 |
49 | If you modify a copy of the Library, and, in your modifications, a
50 | facility refers to a function or data to be supplied by an Application
51 | that uses the facility (other than as an argument passed when the
52 | facility is invoked), then you may convey a copy of the modified
53 | version:
54 |
55 | a) under this License, provided that you make a good faith effort to
56 | ensure that, in the event an Application does not supply the
57 | function or data, the facility still operates, and performs
58 | whatever part of its purpose remains meaningful, or
59 |
60 | b) under the GNU GPL, with none of the additional permissions of
61 | this License applicable to that copy.
62 |
63 | 3. Object Code Incorporating Material from Library Header Files.
64 |
65 | The object code form of an Application may incorporate material from
66 | a header file that is part of the Library. You may convey such object
67 | code under terms of your choice, provided that, if the incorporated
68 | material is not limited to numerical parameters, data structure
69 | layouts and accessors, or small macros, inline functions and templates
70 | (ten or fewer lines in length), you do both of the following:
71 |
72 | a) Give prominent notice with each copy of the object code that the
73 | Library is used in it and that the Library and its use are
74 | covered by this License.
75 |
76 | b) Accompany the object code with a copy of the GNU GPL and this license
77 | document.
78 |
79 | 4. Combined Works.
80 |
81 | You may convey a Combined Work under terms of your choice that,
82 | taken together, effectively do not restrict modification of the
83 | portions of the Library contained in the Combined Work and reverse
84 | engineering for debugging such modifications, if you also do each of
85 | the following:
86 |
87 | a) Give prominent notice with each copy of the Combined Work that
88 | the Library is used in it and that the Library and its use are
89 | covered by this License.
90 |
91 | b) Accompany the Combined Work with a copy of the GNU GPL and this license
92 | document.
93 |
94 | c) For a Combined Work that displays copyright notices during
95 | execution, include the copyright notice for the Library among
96 | these notices, as well as a reference directing the user to the
97 | copies of the GNU GPL and this license document.
98 |
99 | d) Do one of the following:
100 |
101 | 0) Convey the Minimal Corresponding Source under the terms of this
102 | License, and the Corresponding Application Code in a form
103 | suitable for, and under terms that permit, the user to
104 | recombine or relink the Application with a modified version of
105 | the Linked Version to produce a modified Combined Work, in the
106 | manner specified by section 6 of the GNU GPL for conveying
107 | Corresponding Source.
108 |
109 | 1) Use a suitable shared library mechanism for linking with the
110 | Library. A suitable mechanism is one that (a) uses at run time
111 | a copy of the Library already present on the user's computer
112 | system, and (b) will operate properly with a modified version
113 | of the Library that is interface-compatible with the Linked
114 | Version.
115 |
116 | e) Provide Installation Information, but only if you would otherwise
117 | be required to provide such information under section 6 of the
118 | GNU GPL, and only to the extent that such information is
119 | necessary to install and execute a modified version of the
120 | Combined Work produced by recombining or relinking the
121 | Application with a modified version of the Linked Version. (If
122 | you use option 4d0, the Installation Information must accompany
123 | the Minimal Corresponding Source and Corresponding Application
124 | Code. If you use option 4d1, you must provide the Installation
125 | Information in the manner specified by section 6 of the GNU GPL
126 | for conveying Corresponding Source.)
127 |
128 | 5. Combined Libraries.
129 |
130 | You may place library facilities that are a work based on the
131 | Library side by side in a single library together with other library
132 | facilities that are not Applications and are not covered by this
133 | License, and convey such a combined library under terms of your
134 | choice, if you do both of the following:
135 |
136 | a) Accompany the combined library with a copy of the same work based
137 | on the Library, uncombined with any other library facilities,
138 | conveyed under the terms of this License.
139 |
140 | b) Give prominent notice with the combined library that part of it
141 | is a work based on the Library, and explaining where to find the
142 | accompanying uncombined form of the same work.
143 |
144 | 6. Revised Versions of the GNU Lesser General Public License.
145 |
146 | The Free Software Foundation may publish revised and/or new versions
147 | of the GNU Lesser General Public License from time to time. Such new
148 | versions will be similar in spirit to the present version, but may
149 | differ in detail to address new problems or concerns.
150 |
151 | Each version is given a distinguishing version number. If the
152 | Library as you received it specifies that a certain numbered version
153 | of the GNU Lesser General Public License "or any later version"
154 | applies to it, you have the option of following the terms and
155 | conditions either of that published version or of any later version
156 | published by the Free Software Foundation. If the Library as you
157 | received it does not specify a version number of the GNU Lesser
158 | General Public License, you may choose any version of the GNU Lesser
159 | General Public License ever published by the Free Software Foundation.
160 |
161 | If the Library as you received it specifies that a proxy can decide
162 | whether future versions of the GNU Lesser General Public License shall
163 | apply, that proxy's public statement of acceptance of any version is
164 | permanent authorization for you to choose that version for the
165 | Library.
166 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | This project is deprecated. Please see [yolo2-pytorch](https://github.com/ruiminshen/yolo2-pytorch)
2 |
3 | # TensorFlow implementation of the [YOLO (You Only Look Once)](https://arxiv.org/pdf/1506.02640.pdf) and [YOLOv2](https://arxiv.org/pdf/1612.08242.pdf)
4 |
5 | ## Dependencies
6 |
7 | * [Python 3](https://www.python.org/)
8 | * [TensorFlow 1.0](https://www.tensorflow.org/)
9 | * [NumPy](https://www.numpy.org/)
10 | * [SciPy](https://www.scipy.org/)
11 | * [Pandas](https://pandas.pydata.org/)
12 | * [Matplotlib](https://matplotlib.org/)
13 | * [BeautifulSoup4](https://www.crummy.com/software/BeautifulSoup/)
14 | * [OpenCV](https://github.com/opencv/opencv)
15 | * [PIL](http://www.pythonware.com/products/pil/)
16 | * [tqdm](https://github.com/tqdm/tqdm)
17 | * [COCO](https://github.com/pdollar/coco) (optional)
18 |
19 | ## Configuration
20 |
21 | Configurations are mainly defined in the "config.ini" file, such as the detection model (config/model), the base directory (config/basedir, which locates the cache files (.tfrecord), the model data files (.ckpt), and the summary data for TensorBoard), and the inference function ([model]/inference). *Notably, the configurations can be extended using the "-c" command-line argument*.
22 |
23 | ## Basic Usage
24 |
25 | - Download the [PASCAL VOC](http://host.robots.ox.ac.uk/pascal/VOC/) 2007 ([training, validation](http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar) and [test](http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar)) and 2012 ([training and validation](http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar)) dataset. Extract these tars into one directory (such as "~/Documents/Database/").
26 |
27 | - Download the [COCO](http://mscoco.org/) 2014 ([training](http://msvocds.blob.core.windows.net/coco2014/train2014.zip), [validation](http://msvocds.blob.core.windows.net/coco2014/val2014.zip), and [test](http://msvocds.blob.core.windows.net/coco2014/test2014.zip)) dataset. Extract these zip files into one directory (such as "~/Documents/Database/coco/").
28 |
29 | - Run "cache.py" to create the cache file for the training program. **A verify command-line argument "-v" is strongly recommended to check the training data and drop the corrupted examples**, such as the image "COCO_val2014_000000320612.jpg" in the COCO dataset.
30 |
31 | - Run "train.py" to start the training process (the model data saved previously will be loaded if it exists). Multiple command-line arguments can be defined to control the training process, such as the batch size, the learning rate, the optimization algorithm and the maximum number of steps.
32 |
33 | - Run "detect.py" to detect objects in an image. Run "export CUDA_VISIBLE_DEVICES=" first to avoid running out of GPU memory while the training process is running.
34 |
35 | ## Examples
36 |
37 | ### Training a 20 classes Darknet YOLOv2 model from a pretrained 80 classes model
38 |
39 | - Cache the 20 classes data using the customized config file argument. Cache files (.tfrecord) in "~/Documents/Database/yolo-tf/cache/20" will be created.
40 |
41 | ```
42 | python3 cache.py -c config.ini config/yolo2/darknet-20.ini -v
43 | ```
44 |
45 | - Download an 80-class Darknet YOLOv2 model (the original file name is "yolo.weights"; a [version](https://drive.google.com/drive/folders/0B1tW_VtY7onidEwyQ2FtQVplWEU) from Darkflow is recommended). In this tutorial it is placed at "~/Downloads/yolo.weights".
46 |
47 | - Parse the 80-class Darknet YOLOv2 model into TensorFlow format (~/Documents/Database/yolo-tf/yolo2/darknet/80/model.ckpt). A warning like "xxx bytes remaining" indicates that the file "yolo.weights" is not compatible with the original Darknet YOLOv2 model (defined in the function `model.yolo2.inference.darknet`). **Make sure the 80-class data is cached before parsing**.
48 |
49 | ```
50 | python3 parse_darknet_yolo2.py ~/Downloads/yolo.weights -c config.ini config/yolo2/darknet-80.ini -d
51 | ```
52 |
53 | - Transfer the 80-class Darknet YOLOv2 model into a 20-class model (~/Documents/Database/yolo-tf/yolo2/darknet/20), except for the final convolutional layer. **Beware: the "-d" command-line argument deletes the model files and should be used only once, when initializing the model**.
54 |
55 | ```
56 | python3 train.py -c config.ini config/yolo2/darknet-20.ini -t ~/Documents/Database/yolo-tf/yolo2/darknet/80/model.ckpt -e yolo2_darknet/conv -d
57 | ```
58 |
59 | - Using the following command in another terminal and opening the address "localhost:6006" in a web browser to monitor the training process.
60 |
61 | ```
62 | tensorboard --logdir ~/Documents/Database/yolo-tf/yolo2/darknet/20
63 | ```
64 |
65 | - Once the model has stabilized, press Ctrl+C to stop training and restart it with a larger batch size.
66 |
67 | ```
68 | python3 train.py -c config.ini config/yolo2/darknet-20.ini -b 16
69 | ```
70 |
71 | - Detect objects from an image file.
72 |
73 | ```
74 | python3 detect.py $IMAGE_FILE -c config.ini config/yolo2/darknet-20.ini
75 | ```
76 |
77 | - Detect objects with a camera.
78 |
79 | ```
80 | python3 detect_camera.py -c config.ini config/yolo2/darknet-20.ini
81 | ```
82 |
83 | ## Checklist
84 |
85 | - [x] Batch normalization
86 | - [x] Passthrough layer
87 | - [ ] Multi-scale training
88 | - [ ] Dimension cluster
89 | - [x] Extendable configuration (via "-c" command-line argument)
90 | - [x] PASCAL VOC dataset supporting
91 | - [x] MS COCO dataset supporting
92 | - [x] Data augmentation: random crop
93 | - [x] Data augmentation: random flip horizontally
94 | - [x] Multi-thread data batch queue
95 | - [x] Darknet model file (.weights) parser
96 | - [x] Partial model transferring before training
97 | - [x] Detection from image
98 | - [x] Detection from camera
99 | - [ ] Multi-GPU supporting
100 | - [ ] Faster NMS using C/C++ or GPU
101 | - [ ] Performance evaluation
102 |
103 | ## License
104 |
105 | This project is released as the open source software with the GNU Lesser General Public License version 3 ([LGPL v3](http://www.gnu.org/licenses/lgpl-3.0.html)).
106 |
107 | ## Acknowledgements
108 |
109 | This project is mainly inspired by the following projects:
110 |
111 | * [YOLO (Darknet)](https://pjreddie.com/darknet/yolo/).
112 | * [Darkflow](https://github.com/thtrieu/darkflow).
113 |
--------------------------------------------------------------------------------
/cache.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see <http://www.gnu.org/licenses/>.
16 | """
17 |
18 | import os
19 | import argparse
20 | import configparser
21 | import shutil
22 | import importlib
23 | import pandas as pd
24 | import tensorflow as tf
25 | import utils
26 |
27 |
def main():
    """Convert every configured dataset into per-profile TFRecord cache files.

    Uses the module-level ``config`` (a ``configparser.ConfigParser``) and ``args``
    (the namespace returned by ``make_args``), both bound in the ``__main__`` block.

    Steps:
      1. Create the cache directory (``utils.get_cachedir(config)``) and copy the
         class-names file (``[cache] names``) into it as ``names``.
      2. Build a name -> class-index mapping from that file (one name per line).
      3. Read every dataset description listed in ``[cache] datasets``
         (colon-separated TSV paths). The TSV's base name (e.g. ``coco``, ``voc``)
         selects the loader function of the same name in ``utils.data.cache``.
      4. For each requested profile, write ``<cachedir>/<profile>.tfrecord`` by
         calling the loader once per TSV row.
    """
    def expand(p):
        # Configured paths may contain "~" and environment variables.
        return os.path.expanduser(os.path.expandvars(p))

    cachedir = utils.get_cachedir(config)
    os.makedirs(cachedir, exist_ok=True)
    names_path = os.path.join(cachedir, 'names')
    shutil.copyfile(expand(config.get('cache', 'names')), names_path)
    with open(names_path, 'r') as f:
        names = [line.strip() for line in f]
    name_index = {name: i for i, name in enumerate(names)}

    module = importlib.import_module('utils.data.cache')
    # Resolve each dataset's loader up front so that a missing loader fails
    # before any (partial) tfrecord file has been written.
    datasets = []
    for dataset_path in config.get('cache', 'datasets').split(':'):
        name = os.path.basename(os.path.splitext(dataset_path)[0])
        loader = getattr(module, name)
        datasets.append((name, loader, pd.read_csv(expand(dataset_path), sep='\t')))

    for profile in args.profile:
        tfrecord_path = os.path.join(cachedir, profile + '.tfrecord')
        tf.logging.info('write tfrecords file: ' + tfrecord_path)
        with tf.python_io.TFRecordWriter(tfrecord_path) as writer:
            for name, loader, dataset in datasets:
                tf.logging.info('loading %s %s dataset' % (name, profile))
                for i, row in dataset.iterrows():
                    tf.logging.info('loading data %d (%s)' % (i, ', '.join([k + '=' + str(v) for k, v in row.items()])))
                    loader(writer, name_index, profile, row, args.verify)
    tf.logging.info('%s data are saved into %s' % (str(args.profile), cachedir))
49 |
50 |
def make_args(argv=None):
    """Define and parse the command line of the cache tool.

    :param argv: list of argument strings to parse; ``None`` (the default)
        parses ``sys.argv[1:]`` exactly as before.
    :return: ``argparse.Namespace`` with ``config``, ``profile``, ``verify`` and ``level``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--config', nargs='+', default=['config.ini'], help='config file')
    parser.add_argument('-p', '--profile', nargs='+', default=['train', 'val', 'test'],
                        help='dataset profiles to cache (one <profile>.tfrecord file each)')
    parser.add_argument('-v', '--verify', action='store_true',
                        help='verify examples while caching (flag is passed to the dataset loaders)')
    parser.add_argument('--level', default='info', help='logging level')
    return parser.parse_args(argv)
58 |
if __name__ == '__main__':
    # main() reads `config` and `args` as module globals, so both are bound here.
    config = configparser.ConfigParser()
    args = make_args()
    utils.load_config(config, args.config)
    level = args.level
    if level:
        tf.logging.set_verbosity(level.upper())
    with tf.Session() as sess:
        main()
67 |
--------------------------------------------------------------------------------
/config.ini:
--------------------------------------------------------------------------------
1 | [config]
2 | ; yolo yolo2
3 | model = yolo2
4 | basedir = ~/Documents/Database/yolo-tf
5 |
6 | [queue]
7 | capacity = 320
8 | min_after_dequeue=160
9 |
10 | [cache]
11 | names = config/names/80
12 | datasets = config/cache/coco.tsv:config/cache/voc.tsv
13 |
14 | [data_augmentation_full]
15 | enable = 1
16 | enable_probability = 0.5
17 | random_crop = 0.9
18 |
19 | [data_augmentation_resized]
20 | enable = 1
21 | enable_probability = 0.5
22 | random_flip_horizontally = 1
23 | random_brightness = 1
24 | random_contrast = 1
25 | random_saturation = 1
26 | random_hue = 1
27 | noise = 1
28 | grayscale_probability = 0.05
29 |
30 | [exponential_decay]
31 | decay_steps = 100000
32 | decay_rate = 0.96
33 | staircase = 1
34 |
35 | [optimizer_adam]
36 | beta1 = 0.9
37 | beta2 = 0.999
38 | epsilon = 1e-8
39 |
40 | [optimizer_adadelta]
41 | rho = 0.95
42 | epsilon = 1e-8
43 |
44 | [optimizer_adagrad]
45 | initial_accumulator_value = 0.1
46 |
47 | [optimizer_momentum]
48 | momentum = 0.9
49 |
50 | [optimizer_rmsprop]
51 | decay = 0.9
52 | momentum = 0
53 | epsilon = 1e-10
54 |
55 | [optimizer_ftrl]
56 | learning_rate_power = -0.5
57 | initial_accumulator_value = 0.1
58 | l1_regularization_strength = 0
59 | l2_regularization_strength = 0
60 |
61 | [summary]
62 | ; (total_loss\/objectives\/(iou_best|iou_normal|coords|prob)|total_loss)$
63 | scalar = (total_loss\/objectives\/(iou_best|iou_normal|coords|prob)|total_loss)$
64 | scalar_reduce = tf.reduce_mean
65 |
66 | ; [_\w\d]+\/(input|conv\d*\/(convolution|leaky_relu\/data))$
67 | ; [_\w\d]+\/(passthrough|reorg)$
68 | image_ = [_\w\d]+\/(input|conv\d*\/(convolution|leaky_relu\/data))$
69 | image_max = 1
70 |
71 | ; [_\w\d]+\/(conv|fc)\d*\/(weights|biases)$
72 | ; [_\w\d]+\/(conv|fc)\d*\/BatchNorm\/(gamma|beta)$
73 | ; [_\w\d]+\/(conv|fc)\d*\/BatchNorm\/moments\/normalize\/(mean|variance)$
74 | ; [_\w\d]+\/(conv|fc)\d*\/BatchNorm\/(moving_mean|moving_variance)$
75 | ; [_\w\d]+\/(conv|fc)\d*\/(convolution|leaky_relu\/data)$
76 | ; [_\w\d]+\/(input|conv0\/convolution)$
77 | histogram_ = [_\w\d]+\/(input|conv0\/convolution)$
78 | gradients = 0
79 |
80 | [yolo]
81 | inference = tiny
82 | width = 448
83 | height = 448
84 | boxes_per_cell = 2
85 |
86 | [yolo_hparam]
87 | prob = 1
88 | iou_best = 1
89 | iou_normal = .5
90 | coords = 5
91 |
92 | [yolo2]
93 | inference = darknet
94 | width = 416
95 | height = 416
96 | anchors = config/yolo2/anchors/coco.tsv
97 |
98 | [yolo2_hparam]
99 | prob = 1
100 | iou_best = 5
101 | iou_normal = 1
102 | coords = 1
103 |
--------------------------------------------------------------------------------
/config/cache/coco.tsv:
--------------------------------------------------------------------------------
1 | root year
2 | ~/Documents/Database/coco 2014
3 |
--------------------------------------------------------------------------------
/config/cache/voc.tsv:
--------------------------------------------------------------------------------
1 | root
2 | ~/Documents/Database/VOCdevkit/VOC2007
3 | ~/Documents/Database/VOCdevkit/VOC2012
4 |
--------------------------------------------------------------------------------
/config/names/20:
--------------------------------------------------------------------------------
1 | aeroplane
2 | bicycle
3 | bird
4 | boat
5 | bottle
6 | bus
7 | car
8 | cat
9 | chair
10 | cow
11 | diningtable
12 | dog
13 | horse
14 | motorbike
15 | person
16 | pottedplant
17 | sheep
18 | sofa
19 | train
20 | tvmonitor
--------------------------------------------------------------------------------
/config/names/80:
--------------------------------------------------------------------------------
1 | person
2 | bicycle
3 | car
4 | motorbike
5 | aeroplane
6 | bus
7 | train
8 | truck
9 | boat
10 | traffic light
11 | fire hydrant
12 | stop sign
13 | parking meter
14 | bench
15 | bird
16 | cat
17 | dog
18 | horse
19 | sheep
20 | cow
21 | elephant
22 | bear
23 | zebra
24 | giraffe
25 | backpack
26 | umbrella
27 | handbag
28 | tie
29 | suitcase
30 | frisbee
31 | skis
32 | snowboard
33 | sports ball
34 | kite
35 | baseball bat
36 | baseball glove
37 | skateboard
38 | surfboard
39 | tennis racket
40 | bottle
41 | wine glass
42 | cup
43 | fork
44 | knife
45 | spoon
46 | bowl
47 | banana
48 | apple
49 | sandwich
50 | orange
51 | broccoli
52 | carrot
53 | hot dog
54 | pizza
55 | donut
56 | cake
57 | chair
58 | sofa
59 | pottedplant
60 | bed
61 | diningtable
62 | toilet
63 | tvmonitor
64 | laptop
65 | mouse
66 | remote
67 | keyboard
68 | cell phone
69 | microwave
70 | oven
71 | toaster
72 | sink
73 | refrigerator
74 | book
75 | clock
76 | vase
77 | scissors
78 | teddy bear
79 | hair drier
80 | toothbrush
--------------------------------------------------------------------------------
/config/yolo/darknet-20.ini:
--------------------------------------------------------------------------------
1 | [config]
2 | model = yolo
3 |
4 | [cache]
5 | names = config/names/20
6 | datasets = config/cache/voc.tsv
7 |
8 | [yolo]
9 | inference = darknet
10 | width = 448
11 | height = 448
12 | boxes_per_cell = 2
13 | hparam = 5e-4
14 |
15 | [yolo_hparam]
16 | prob = 1
17 | iou_best = 1
18 | iou_normal = .5
19 | coords = 5
20 |
--------------------------------------------------------------------------------
/config/yolo/darknet-80.ini:
--------------------------------------------------------------------------------
1 | [config]
2 | model = yolo
3 |
4 | [cache]
5 | names = config/names/80
6 | datasets = config/cache/coco.tsv
7 |
8 | [yolo]
9 | inference = darknet
10 | width = 448
11 | height = 448
12 | boxes_per_cell = 2
13 | hparam = 5e-4
14 |
15 | [yolo_hparam]
16 | prob = 1
17 | iou_best = 1
18 | iou_normal = .5
19 | coords = 5
20 |
--------------------------------------------------------------------------------
/config/yolo/tiny-20.ini:
--------------------------------------------------------------------------------
1 | [config]
2 | model = yolo
3 |
4 | [cache]
5 | names = config/names/20
6 | datasets = config/cache/voc.tsv
7 |
8 | [yolo]
9 | inference = tiny
10 | width = 448
11 | height = 448
12 | boxes_per_cell = 2
13 | hparam = 5e-4
14 |
15 | [yolo_hparam]
16 | prob = 1
17 | iou_best = 1
18 | iou_normal = .5
19 | coords = 5
20 |
--------------------------------------------------------------------------------
/config/yolo/tiny-80.ini:
--------------------------------------------------------------------------------
1 | [config]
2 | model = yolo
3 |
4 | [cache]
5 | names = config/names/80
6 | datasets = config/cache/coco.tsv
7 |
8 | [yolo]
9 | inference = tiny
10 | width = 448
11 | height = 448
12 | boxes_per_cell = 2
13 | hparam = 5e-4
14 |
15 | [yolo_hparam]
16 | prob = 1
17 | iou_best = 1
18 | iou_normal = .5
19 | coords = 5
20 |
--------------------------------------------------------------------------------
/config/yolo2/anchors/coco.tsv:
--------------------------------------------------------------------------------
1 | w h
2 | 0.738768 0.874946
3 | 2.42204 2.65704
4 | 4.30971 7.04493
5 | 10.246 4.59428
6 | 12.6868 11.8741
7 |
--------------------------------------------------------------------------------
/config/yolo2/anchors/voc.tsv:
--------------------------------------------------------------------------------
1 | w h
2 | 1.08 1.19
3 | 3.42 4.41
4 | 6.63 11.38
5 | 9.42 5.11
6 | 16.62 10.52
7 |
--------------------------------------------------------------------------------
/config/yolo2/darknet-20.ini:
--------------------------------------------------------------------------------
1 | [config]
2 | model = yolo2
3 |
4 | [cache]
5 | names = config/names/20
6 |
7 | [yolo2]
8 | inference = darknet
9 | width = 416
10 | height = 416
11 | anchors = config/yolo2/anchors/voc.tsv
12 |
13 | [yolo2_hparam]
14 | prob = 1
15 | iou_best = 5
16 | iou_normal = 1
17 | coords = 1
18 |
--------------------------------------------------------------------------------
/config/yolo2/darknet-80.ini:
--------------------------------------------------------------------------------
1 | [config]
2 | model = yolo2
3 |
4 | [cache]
5 | names = config/names/80
6 |
7 | [yolo2]
8 | inference = darknet
9 | width = 416
10 | height = 416
11 | anchors = config/yolo2/anchors/coco.tsv
12 |
13 | [yolo2_hparam]
14 | prob = 1
15 | iou_best = 5
16 | iou_normal = 1
17 | coords = 1
18 |
--------------------------------------------------------------------------------
/config/yolo2/tiny-20.ini:
--------------------------------------------------------------------------------
1 | [config]
2 | model = yolo2
3 |
4 | [cache]
5 | names = config/names/20
6 |
7 | [yolo2]
8 | inference = tiny
9 | width = 416
10 | height = 416
11 | anchors = config/yolo2/anchors/voc.tsv
12 |
13 | [yolo2_hparam]
14 | prob = 1
15 | iou_best = 5
16 | iou_normal = 1
17 | coords = 1
18 |
--------------------------------------------------------------------------------
/config/yolo2/tiny-80.ini:
--------------------------------------------------------------------------------
1 | [config]
2 | model = yolo2
3 |
4 | [cache]
5 | names = config/names/80
6 |
7 | [yolo2]
8 | inference = tiny
9 | width = 416
10 | height = 416
11 | anchors = config/yolo2/anchors/coco.tsv
12 |
13 | [yolo2_hparam]
14 | prob = 1
15 | iou_best = 5
16 | iou_normal = 1
17 | coords = 1
18 |
--------------------------------------------------------------------------------
/demo_data_augmentation.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see <http://www.gnu.org/licenses/>.
16 | """
17 |
18 | import os
19 | import argparse
20 | import configparser
21 | import multiprocessing
22 | import numpy as np
23 | import matplotlib.pyplot as plt
24 | import tensorflow as tf
25 | import utils.data
26 | import utils.visualize
27 |
28 |
def main():
    """Display one shuffled batch of training examples together with their labels.

    Uses the module-level ``config`` and ``args`` bound in the ``__main__`` block.

    Steps:
      1. Read the class names from ``<cachedir>/names`` and the input size of the
         configured model.
      2. Build the input pipeline over the ``<profile>.tfrecord`` files via
         ``utils.data.load_image_labels``. It presumably applies the configured
         data augmentation; confirm in ``utils.data``.
      3. Fetch one ``rows * cols`` batch and draw each image with its labels,
         optionally overlaying the cell grid.
    """
    model = config.get('config', 'model')
    cachedir = utils.get_cachedir(config)
    with open(os.path.join(cachedir, 'names'), 'r') as f:
        names = [line.strip() for line in f]
    width = config.getint(model, 'width')
    height = config.getint(model, 'height')
    cell_width, cell_height = utils.calc_cell_width_height(config, width, height)
    tf.logging.info('(width, height)=(%d, %d), (cell_width, cell_height)=(%d, %d)' % (width, height, cell_width, cell_height))
    batch_size = args.rows * args.cols
    paths = [os.path.join(cachedir, profile + '.tfrecord') for profile in args.profile]
    num_examples = sum(sum(1 for _ in tf.python_io.tf_record_iterator(path)) for path in paths)
    tf.logging.warn('num_examples=%d' % num_examples)
    with tf.Session() as sess:
        with tf.name_scope('batch'):
            image_rgb, labels = utils.data.load_image_labels(paths, len(names), width, height, cell_width, cell_height, config)
            batch = tf.train.shuffle_batch(
                (tf.cast(image_rgb, tf.uint8),) + labels, batch_size=batch_size,
                capacity=config.getint('queue', 'capacity'),
                min_after_dequeue=config.getint('queue', 'min_after_dequeue'),
                num_threads=multiprocessing.cpu_count(),
            )
        tf.global_variables_initializer().run()
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess, coord)
        try:
            batch_image, batch_labels = sess.run([batch[0], batch[1:]])
        finally:
            # Always stop and join the queue-runner threads, even if the fetch fails.
            coord.request_stop()
            coord.join(threads)
    batch_image = batch_image.astype(np.uint8)
    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid, where the
    # default would return a single Axes object that has no `.flat`.
    fig, axes = plt.subplots(args.rows, args.cols, squeeze=False)
    for b, (ax, image) in enumerate(zip(axes.flat, batch_image)):
        ax.imshow(image)
        utils.visualize.draw_labels(ax, names, width, height, cell_width, cell_height, *[label[b] for label in batch_labels])
        if args.grid:
            # One tick per cell boundary so the grid lines outline the cells.
            ax.set_xticks(np.arange(0, width, width / cell_width))
            ax.set_yticks(np.arange(0, height, height / cell_height))
            ax.grid(which='both')
            # Booleans rather than 'off': matplotlib deprecated and later removed the 'on'/'off' strings.
            ax.tick_params(labelbottom=False, labelleft=False)
        else:
            ax.set_xticks([])
            ax.set_yticks([])
    fig.tight_layout()
    plt.show()
69 |
70 |
def make_args(argv=None):
    """Define and parse the command line of the data-augmentation demo.

    :param argv: list of argument strings to parse; ``None`` (the default)
        parses ``sys.argv[1:]`` exactly as before.
    :return: ``argparse.Namespace`` with ``config``, ``profile``, ``grid``,
        ``rows``, ``cols`` and ``level``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--config', nargs='+', default=['config.ini'], help='config file')
    parser.add_argument('-p', '--profile', nargs='+', default=['train', 'val'],
                        help='cached profiles (<profile>.tfrecord) to sample from')
    parser.add_argument('-g', '--grid', action='store_true', help='overlay the cell grid on each image')
    parser.add_argument('--rows', default=5, type=int, help='number of image rows in the figure')
    parser.add_argument('--cols', default=5, type=int, help='number of image columns in the figure')
    parser.add_argument('--level', default='info', help='logging level')
    return parser.parse_args(argv)
80 |
if __name__ == '__main__':
    # Script entry point: create the config, fill it from the files named on
    # the command line, then run main(), which reads the module-level
    # ``args`` and ``config``.
    config = configparser.ConfigParser()
    args = make_args()
    utils.load_config(config, args.config)
    if args.level:
        tf.logging.set_verbosity(args.level.upper())  # e.g. 'info' -> 'INFO'
    main()
88 |
--------------------------------------------------------------------------------
/demo_detect.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see http://www.gnu.org/licenses/.
16 | """
17 |
18 | import os
19 | import argparse
20 | import configparser
21 | import importlib
22 | import itertools
23 | import numpy as np
24 | import matplotlib.pyplot as plt
25 | import matplotlib.patches as patches
26 | import tensorflow as tf
27 | import tensorflow.contrib.slim as slim
28 | import utils.data
29 | import utils.visualize
30 |
31 |
class Drawer(object):
    """Interactive matplotlib viewer for the raw per-cell predictions of a model.

    The image is shown with its ground-truth labels and a grid of
    ``cell_width`` x ``cell_height`` cells. Clicking a cell replaces everything
    drawn so far with a shaded cell plus every box predicted by that cell,
    each annotated with its predicted IoU and best class probability.
    """

    def __init__(self, sess, names, cell_width, cell_height, image, labels, model, feed_dict):
        """
        Args:
            sess: TensorFlow session used to evaluate ``model`` on each click.
            names: class names, indexed by the argmax of the predicted probabilities.
            cell_width: number of grid cells along x.
            cell_height: number of grid cells along y.
            image: (height, width, channels) image array to display.
            labels: ground-truth label arrays, passed to ``utils.visualize.draw_labels``.
            model: object exposing ``prob``, ``iou``, ``xy_min`` and ``wh``
                tensors indexed as ``[batch][cell]``.
            feed_dict: feed dict used when evaluating ``model``.
        """
        self.sess = sess
        self.names = names
        self.cell_width, self.cell_height = cell_width, cell_height
        self.image, self.labels = image, labels
        self.model = model
        self.feed_dict = feed_dict
        self.fig = plt.figure()
        self.ax = self.fig.gca()
        height, width, _ = image.shape
        # Pixels per grid cell; converts model coordinates (cell units) to pixels.
        self.scale = [width / self.cell_width, height / self.cell_height]
        self.ax.imshow(image)
        self.plots = utils.visualize.draw_labels(self.ax, names, width, height, cell_width, cell_height, *labels)
        self.ax.set_xticks(np.arange(0, width, width / cell_width))
        self.ax.set_yticks(np.arange(0, height, height / cell_height))
        self.ax.grid(which='both')
        # Booleans rather than 'off': newer matplotlib dropped the 'on'/'off' string form.
        self.ax.tick_params(labelbottom=False, labelleft=False)
        self.fig.canvas.mpl_connect('button_press_event', self.onclick)
        # Colors cycled from the default property cycle, one per entry of
        # ``names``; onclick pairs them with the boxes of the clicked cell.
        self.colors = [prop['color'] for _, prop in zip(names, itertools.cycle(plt.rcParams['axes.prop_cycle']))]

    def onclick(self, event):
        """Draw every box predicted by the grid cell under a mouse click.

        Clicks outside the image axes, or outside the cell grid (possible after
        panning/zooming), are ignored and leave the current drawing untouched.
        """
        # Clicks outside any axes carry no data coordinates.
        if event.inaxes is not self.ax or event.xdata is None or event.ydata is None:
            return
        height, width, _ = self.image.shape
        ix = int(event.xdata * self.cell_width / width)
        iy = int(event.ydata * self.cell_height / height)
        if not (0 <= ix < self.cell_width and 0 <= iy < self.cell_height):
            return
        for p in self.plots:
            p.remove()
        self.plots = []
        cell_w, cell_h = width / self.cell_width, height / self.cell_height
        self.plots.append(self.ax.add_patch(patches.Rectangle((ix * cell_w, iy * cell_h), cell_w, cell_h, linewidth=0, facecolor='black', alpha=.2)))
        cell = iy * self.cell_width + ix
        prob, iou, xy_min, wh = self.sess.run([self.model.prob[0][cell], self.model.iou[0][cell], self.model.xy_min[0][cell], self.model.wh[0][cell]], feed_dict=self.feed_dict)
        # A model may predict one class distribution per cell (box axis of
        # size 1); broadcast it so every box of the cell is drawn, not only the first.
        prob = np.broadcast_to(prob, (len(iou), prob.shape[-1]))
        xy_min = xy_min * self.scale
        wh = wh * self.scale
        for _prob, _iou, (x, y), (w, h), color in zip(prob, iou, xy_min, wh, self.colors):
            best = np.argmax(_prob)
            name = self.names[best]
            _prob = _prob[best]
            _conf = _prob * _iou
            # Edge width grows with confidence, bounded to [0, 3].
            linewidth = min(max(_conf * 10, 0), 3)
            self.plots.append(self.ax.add_patch(patches.Rectangle((x, y), w, h, linewidth=linewidth, edgecolor=color, facecolor='none')))
            self.plots.append(self.ax.annotate(name + ' (%.1f%%, %.1f%%)' % (_iou * 100, _prob * 100), (x, y), color=color))
        self.fig.canvas.draw()
74 |
75 |
def main():
    """Restore the latest checkpoint and inspect its predictions on one example.

    Reads the module-level ``config`` and ``args``. One example is drawn from
    the ``<cachedir>/<profile>.tfrecord`` files, the total loss on it is
    logged, and an interactive :class:`Drawer` is opened on it.

    Raises:
        FileNotFoundError: if the log directory is missing or holds no checkpoint.
    """
    model = config.get('config', 'model')
    cachedir = utils.get_cachedir(config)
    with open(os.path.join(cachedir, 'names'), 'r') as f:
        names = [line.strip() for line in f]
    width = config.getint(model, 'width')
    height = config.getint(model, 'height')
    yolo = importlib.import_module('model.' + model)
    cell_width, cell_height = utils.calc_cell_width_height(config, width, height)
    tf.logging.info('(width, height)=(%d, %d), (cell_width, cell_height)=(%d, %d)' % (width, height, cell_width, cell_height))
    # Locate the checkpoint up front: an explicit error survives `python -O`
    # (unlike assert) and fails before the input queues are started.
    logdir = utils.get_logdir(config)
    if not os.path.exists(logdir):
        raise FileNotFoundError('log directory not found: %s' % logdir)
    model_path = tf.train.latest_checkpoint(logdir)
    if model_path is None:
        raise FileNotFoundError('no checkpoint found in %s' % logdir)
    with tf.Session() as sess:
        paths = [os.path.join(cachedir, profile + '.tfrecord') for profile in args.profile]
        num_examples = sum(sum(1 for _ in tf.python_io.tf_record_iterator(path)) for path in paths)
        tf.logging.warn('num_examples=%d' % num_examples)
        image_rgb, labels = utils.data.load_image_labels(paths, len(names), width, height, cell_width, cell_height, config)
        image_std = tf.image.per_image_standardization(image_rgb)
        image_rgb = tf.cast(image_rgb, tf.uint8)
        # Batch-of-one placeholders fed with the example fetched below.
        ph_image = tf.placeholder(image_std.dtype, [1] + image_std.get_shape().as_list(), name='ph_image')
        global_step = tf.contrib.framework.get_or_create_global_step()
        builder = yolo.Builder(args, config)
        builder(ph_image)
        variables_to_restore = slim.get_variables_to_restore()
        ph_labels = [tf.placeholder(label.dtype, [1] + label.get_shape().as_list(), name='ph_' + label.op.name) for label in labels]
        with tf.name_scope('total_loss') as name:
            builder.create_objectives(ph_labels)
            total_loss = tf.losses.get_total_loss(name=name)
        tf.global_variables_initializer().run()
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess, coord)
        _image_rgb, _image_std, _labels = sess.run([image_rgb, image_std, labels])
        coord.request_stop()
        coord.join(threads)
        feed_dict = {ph: np.expand_dims(d, 0) for ph, d in zip(ph_labels, _labels)}
        feed_dict[ph_image] = np.expand_dims(_image_std, 0)
        tf.logging.info('load ' + model_path)
        slim.assign_from_checkpoint_fn(model_path, variables_to_restore)(sess)
        tf.logging.info('global_step=%d' % sess.run(global_step))
        tf.logging.info('total_loss=%f' % sess.run(total_loss, feed_dict))
        # Keep a reference so the Drawer (and its click callback) stays alive during plt.show().
        _ = Drawer(sess, names, builder.model.cell_width, builder.model.cell_height, _image_rgb, _labels, builder.model, feed_dict)
        plt.show()
119 |
120 |
def make_args():
    """Parse the command line of the interactive detection demo.

    Returns:
        argparse.Namespace with ``config``, ``profile`` and ``level``.
    """
    # (flags, keyword arguments) for each option, in help-output order.
    options = [
        (('-c', '--config'), dict(nargs='+', default=['config.ini'], help='config file')),
        (('-p', '--profile'), dict(nargs='+', default=['train'])),
        (('--level',), dict(default='info', help='logging level')),
    ]
    parser = argparse.ArgumentParser()
    for flags, kwargs in options:
        parser.add_argument(*flags, **kwargs)
    return parser.parse_args()
127 |
if __name__ == '__main__':
    # Script entry point: create the config, fill it from the files named on
    # the command line, then run main(), which reads the module-level
    # ``args`` and ``config``.
    config = configparser.ConfigParser()
    args = make_args()
    utils.load_config(config, args.config)
    if args.level:
        tf.logging.set_verbosity(args.level.upper())  # e.g. 'info' -> 'INFO'
    main()
135 |
--------------------------------------------------------------------------------
/detect.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see http://www.gnu.org/licenses/.
16 | """
17 |
18 | import os
19 | import argparse
20 | import configparser
21 | import importlib
22 | import itertools
23 | from PIL import Image, ExifTags
24 | import numpy as np
25 | import matplotlib.pyplot as plt
26 | import matplotlib.patches as patches
27 | import tensorflow as tf
28 | import tensorflow.contrib.slim as slim
29 | import utils.preprocess
30 | import utils.postprocess
31 |
32 |
def std(image):
    """Standardize ``image`` via ``utils.preprocess.per_image_standardization``.

    Selected with ``--preprocess std`` (the default).
    """
    standardize = utils.preprocess.per_image_standardization
    return standardize(image)
35 |
36 |
def darknet(image):
    """Divide pixel values by 255, mapping the 0..255 range onto 0..1.

    Selected with ``--preprocess darknet``.
    """
    divisor = 255.
    return image / divisor
39 |
40 |
def read_image(path):
    """Open the image at ``path`` and undo its EXIF rotation, if any.

    Only the pure-rotation EXIF orientations are handled (3: rotate 180
    degrees, 6: rotate 270 degrees counter-clockwise, 8: rotate 90 degrees
    counter-clockwise). Images without EXIF data, without an Orientation tag,
    or with any other orientation value are returned unchanged.
    """
    image = Image.open(path)
    # Numeric EXIF tag id whose name is 'Orientation'.
    orientation = next((key for key, tag in ExifTags.TAGS.items() if tag == 'Orientation'), None)
    try:
        exif = image._getexif()
    except AttributeError:
        # Formats without EXIF support have no _getexif().
        return image
    if not exif or orientation is None:
        return image
    # .get(): EXIF may be present without an Orientation entry (previously a KeyError).
    angle = {3: 180, 6: 270, 8: 90}.get(exif.get(orientation))
    if angle is not None:
        image = image.rotate(angle, expand=True)
    return image
57 |
58 |
def detect(sess, model, names, image, path):
    """Run the detector on the image file at ``path`` and plot the detections.

    Reads ``args.preprocess``, ``args.threshold`` and ``args.threshold_iou``
    from the module-level ``args``.

    Args:
        sess: session holding the restored model.
        model: model exposing ``conf``, ``xy_min``, ``xy_max``,
            ``cell_width`` and ``cell_height``.
        names: class names indexed by class id.
        image: input placeholder; its static shape (1, height, width, 3)
            fixes the network input size.
        path: image file to read.

    Returns:
        The matplotlib figure with the detections drawn on the original image.

    Raises:
        ValueError: if ``args.preprocess`` names no function of this module.
    """
    # Look the preprocessing function (e.g. ``std`` or ``darknet``) up by name
    # instead of eval()-ing a command-line string.
    try:
        preprocess = globals()[args.preprocess]
    except KeyError:
        raise ValueError('unknown preprocess function: %s' % args.preprocess) from None
    _, height, width, _ = image.get_shape().as_list()
    # Force 3-channel RGB so grayscale, palette and RGBA files match the
    # (1, height, width, 3) input; previously they failed when fed.
    _image = read_image(path).convert('RGB')
    image_original = np.array(_image, dtype=np.uint8)
    image_height, image_width, _ = image_original.shape
    image_std = preprocess(np.array(_image.resize((width, height)), dtype=np.uint8).astype(np.float32))
    feed_dict = {image: np.expand_dims(image_std, 0)}
    tensors = [model.conf, model.xy_min, model.xy_max]
    conf, xy_min, xy_max = sess.run([tf.check_numerics(t, t.op.name) for t in tensors], feed_dict=feed_dict)
    boxes = utils.postprocess.non_max_suppress(conf[0], xy_min[0], xy_max[0], args.threshold, args.threshold_iou)
    # Model coordinates are in grid-cell units; scale them to original-image pixels.
    scale = [image_width / model.cell_width, image_height / model.cell_height]
    fig = plt.figure()
    ax = fig.gca()
    ax.imshow(image_original)
    colors = [prop['color'] for _, prop in zip(names, itertools.cycle(plt.rcParams['axes.prop_cycle']))]
    cnt = 0
    for _conf, _xy_min, _xy_max in boxes:
        index = np.argmax(_conf)
        if _conf[index] > args.threshold:
            wh = _xy_max - _xy_min
            _xy_min = _xy_min * scale
            _wh = wh * scale
            linewidth = min(_conf[index] * 10, 3)
            ax.add_patch(patches.Rectangle(_xy_min, _wh[0], _wh[1], linewidth=linewidth, edgecolor=colors[index], facecolor='none'))
            ax.annotate(names[index] + ' (%.1f%%)' % (_conf[index] * 100), _xy_min, color=colors[index])
            cnt += 1
    # FigureCanvasBase.set_window_title was deprecated (matplotlib 3.4) and
    # later removed; the figure manager provides it in old and new versions.
    manager = fig.canvas.manager
    if manager is not None:
        manager.set_window_title('%d objects detected' % cnt)
    ax.set_xticks([])
    ax.set_yticks([])
    return fig
92 |
93 |
def main():
    """Restore the latest checkpoint and run :func:`detect` on ``args.path``.

    ``args.path`` may be a single image file or a directory. A directory is
    walked recursively in sorted order, and every file whose lower-cased
    extension is in ``args.exts`` is shown in its own blocking ``plt.show()``.

    Raises:
        FileNotFoundError: if ``args.path`` does not exist or no checkpoint is found.
    """
    model = config.get('config', 'model')
    yolo = importlib.import_module('model.' + model)
    width = config.getint(model, 'width')
    height = config.getint(model, 'height')
    path = os.path.expanduser(os.path.expandvars(args.path))
    if not os.path.exists(path):
        raise FileNotFoundError('input path not found: %s' % path)
    # latest_checkpoint returns None when nothing is found; fail with a clear
    # message instead of a TypeError in the log call below.
    logdir = utils.get_logdir(config)
    model_path = tf.train.latest_checkpoint(logdir)
    if model_path is None:
        raise FileNotFoundError('no checkpoint found in %s' % logdir)
    # Compare extensions case-insensitively on both sides (e.g. '-e .JPG').
    exts = {ext.lower() for ext in args.exts}
    with tf.Session() as sess:
        image = tf.placeholder(tf.float32, [1, height, width, 3], name='image')
        builder = yolo.Builder(args, config)
        builder(image)
        global_step = tf.contrib.framework.get_or_create_global_step()
        tf.logging.info('load ' + model_path)
        slim.assign_from_checkpoint_fn(model_path, tf.global_variables())(sess)
        tf.logging.info('global_step=%d' % sess.run(global_step))
        if os.path.isfile(path):
            detect(sess, builder.model, builder.names, image, path)
            plt.show()
        else:
            for dirpath, dirnames, filenames in os.walk(path):
                dirnames.sort()  # visit sub-directories in a deterministic order
                for filename in sorted(filenames):
                    if os.path.splitext(filename)[-1].lower() in exts:
                        _path = os.path.join(dirpath, filename)
                        print(_path)
                        detect(sess, builder.model, builder.names, image, _path)
                        plt.show()
120 |
121 |
def make_args():
    """Parse the command line of the image / directory detector.

    Returns:
        argparse.Namespace with ``path``, ``config``, ``preprocess``,
        ``threshold``, ``threshold_iou``, ``exts`` and ``level``.
    """
    # (flags, keyword arguments) for each argument, in help-output order.
    options = [
        (('path',), dict(help='input image path')),
        (('-c', '--config'), dict(nargs='+', default=['config.ini'], help='config file')),
        (('-p', '--preprocess'), dict(default='std', help='the preprocess function')),
        (('-t', '--threshold'), dict(type=float, default=0.3)),
        (('--threshold_iou',), dict(type=float, default=0.4, help='IoU threshold')),
        (('-e', '--exts'), dict(nargs='+', default=['.jpg', '.png'])),
        (('--level',), dict(default='info', help='logging level')),
    ]
    parser = argparse.ArgumentParser()
    for flags, kwargs in options:
        parser.add_argument(*flags, **kwargs)
    return parser.parse_args()
132 |
if __name__ == '__main__':
    # Script entry point: create the config, fill it from the files named on
    # the command line, then run main(); main() and detect() read the
    # module-level ``args`` and ``config``.
    config = configparser.ConfigParser()
    args = make_args()
    utils.load_config(config, args.config)
    if args.level:
        tf.logging.set_verbosity(args.level.upper())  # e.g. 'info' -> 'INFO'
    main()
140 |
--------------------------------------------------------------------------------
/detect_camera.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see http://www.gnu.org/licenses/.
16 | """
17 |
18 | import argparse
19 | import configparser
20 | import importlib
21 | import cv2
22 | import numpy as np
23 | import tensorflow as tf
24 | import tensorflow.contrib.slim as slim
25 | import utils.postprocess
26 |
27 |
def main():
    """Run the detector live on frames from camera device 0.

    Reads the module-level ``args`` and ``config``. Each frame is resized to
    the network input size and preprocessed with the function of the
    ``detect`` module named by ``--preprocess``. Detections above
    ``--threshold`` are drawn on the original frame. Press ``q`` or ``Esc``
    in the window to stop.

    Raises:
        FileNotFoundError: if the log directory holds no checkpoint.
        RuntimeError: if the camera cannot be opened or a frame cannot be read.
    """
    model = config.get('config', 'model')
    yolo = importlib.import_module('model.' + model)
    width = config.getint(model, 'width')
    height = config.getint(model, 'height')
    preprocess = getattr(importlib.import_module('detect'), args.preprocess)
    # latest_checkpoint returns None when nothing is found; fail clearly and
    # before the graph is built.
    logdir = utils.get_logdir(config)
    model_path = tf.train.latest_checkpoint(logdir)
    if model_path is None:
        raise FileNotFoundError('no checkpoint found in %s' % logdir)
    with tf.Session() as sess:
        ph_image = tf.placeholder(tf.float32, [1, height, width, 3], name='ph_image')
        builder = yolo.Builder(args, config)
        builder(ph_image)
        global_step = tf.contrib.framework.get_or_create_global_step()
        tf.logging.info('load ' + model_path)
        slim.assign_from_checkpoint_fn(model_path, tf.global_variables())(sess)
        tf.logging.info('global_step=%d' % sess.run(global_step))
        tensors = [builder.model.conf, builder.model.xy_min, builder.model.xy_max]
        tensors = [tf.check_numerics(t, t.op.name) for t in tensors]
        cap = cv2.VideoCapture(0)
        try:
            if not cap.isOpened():
                raise RuntimeError('cannot open camera 0')
            while True:
                ret, image_bgr = cap.read()
                # Explicit check instead of assert, which `python -O` strips.
                if not ret:
                    raise RuntimeError('failed to read a frame from camera 0')
                image_height, image_width, _ = image_bgr.shape
                # Pixels per grid cell: model coordinates are in cell units.
                scale = [image_width / builder.model.cell_width, image_height / builder.model.cell_height]
                image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
                image_std = np.expand_dims(preprocess(cv2.resize(image_rgb, (width, height))).astype(np.float32), 0)
                feed_dict = {ph_image: image_std}
                conf, xy_min, xy_max = sess.run(tensors, feed_dict)
                boxes = utils.postprocess.non_max_suppress(conf[0], xy_min[0], xy_max[0], args.threshold, args.threshold_iou)
                for _conf, _xy_min, _xy_max in boxes:
                    index = np.argmax(_conf)
                    if _conf[index] > args.threshold:
                        # Plain Python ints (truncated toward zero): np.int was
                        # removed in NumPy 1.24 and OpenCV points expect ints.
                        pt_min = tuple(int(v) for v in _xy_min * scale)
                        pt_max = tuple(int(v) for v in _xy_max * scale)
                        cv2.rectangle(image_bgr, pt_min, pt_max, (255, 0, 255), 3)
                        cv2.putText(image_bgr, builder.names[index] + ' (%.1f%%)' % (_conf[index] * 100), pt_min, cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
                cv2.imshow('detection', image_bgr)
                # 'q' or Esc ends the loop; the window is closed in `finally`.
                if (cv2.waitKey(1) & 0xFF) in (ord('q'), 27):
                    break
        finally:
            cv2.destroyAllWindows()
            cap.release()
69 |
70 |
def make_args():
    """Parse the command line of the live camera detector.

    Returns:
        argparse.Namespace with ``config``, ``preprocess``, ``threshold``,
        ``threshold_iou`` and ``level``.
    """
    # (flags, keyword arguments) for each option, in help-output order.
    options = [
        (('-c', '--config'), dict(nargs='+', default=['config.ini'], help='config file')),
        (('-p', '--preprocess'), dict(default='std', help='the preprocess function')),
        (('-t', '--threshold'), dict(type=float, default=0.3)),
        (('--threshold_iou',), dict(type=float, default=0.4, help='IoU threshold')),
        (('--level',), dict(default='info', help='logging level')),
    ]
    parser = argparse.ArgumentParser()
    for flags, kwargs in options:
        parser.add_argument(*flags, **kwargs)
    return parser.parse_args()
79 |
if __name__ == '__main__':
    # Script entry point: create the config, fill it from the files named on
    # the command line, then run main(), which reads the module-level
    # ``args`` and ``config``.
    config = configparser.ConfigParser()
    args = make_args()
    utils.load_config(config, args.config)
    if args.level:
        tf.logging.set_verbosity(args.level.upper())  # e.g. 'info' -> 'INFO'
    main()
87 |
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruiminshen/yolo-tf/eae65c8071fe5069f5e3bb1e26f19a761b1b68bc/model/__init__.py
--------------------------------------------------------------------------------
/model/yolo/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see http://www.gnu.org/licenses/.
16 | """
17 |
18 | import configparser
19 | import os
20 | import re
21 | import math
22 | import numpy as np
23 | import pandas as pd
24 | import tensorflow as tf
25 | import utils
26 | from . import inference
27 |
28 |
def calc_cell_xy(cell_height, cell_width, dtype=np.float32):
    """Return the (x, y) index of every grid cell.

    Args:
        cell_height: number of cell rows.
        cell_width: number of cell columns.
        dtype: dtype of the result.

    Returns:
        Array of shape (cell_height, cell_width, 2) where ``[y, x]`` holds ``[x, y]``.
    """
    rows, cols = np.indices((cell_height, cell_width), dtype=dtype)
    return np.stack((cols, rows), axis=-1)
35 |
36 |
class Model(object):
    """Decode the flat output of a YOLO (v1) network into named tensors.

    NOTE(review): indentation was lost in this source listing; the nesting of
    the ``regress``/``inputs`` name scopes below is a reconstruction and only
    affects op names -- confirm against the upstream file.
    """

    def __init__(self, net, scope, classes, boxes_per_cell, training=False):
        """
        Args:
            net: flat network output. The first ``cells * classes`` columns are
                class scores; the rest hold 5 values for each of the
                ``boxes_per_cell`` boxes of every cell.
            scope: name prefix of the network; the grid size is read from the
                static shape of the tensor named ``scope + '/conv:0'``.
            classes: number of classes.
            boxes_per_cell: boxes predicted per grid cell.
            training: when False, also builds the detection tensors
                ``xy``, ``xy_min``, ``xy_max`` and ``conf``.
        """
        # Grid size (in cells) from the static shape of the last conv feature map.
        _, self.cell_height, self.cell_width, _ = tf.get_default_graph().get_tensor_by_name(scope + '/conv:0').get_shape().as_list()
        cells = self.cell_height * self.cell_width
        with tf.name_scope('regress'):
            with tf.name_scope('inputs'):
                end = cells * classes
                # One class distribution per cell, shared by its boxes (box axis of size 1).
                self.prob = tf.reshape(net[:, :end], [-1, cells, 1, classes], name='prob')
                # Per box: [iou, offset_x, offset_y, wh01_sqrt_base (2 values)].
                inputs_remaining = tf.reshape(net[:, end:], [-1, cells, boxes_per_cell, 5], name='inputs_remaining')
                self.iou = tf.identity(inputs_remaining[:, :, :, 0], name='iou')
                self.offset_xy = tf.identity(inputs_remaining[:, :, :, 1:3], name='offset_xy')
                wh01_sqrt_base = tf.identity(inputs_remaining[:, :, :, 3:], name='wh01_sqrt_base')
            # Squaring keeps the size non-negative; abs() is the matching square
            # root that enters ``coords``.
            wh01 = tf.square(wh01_sqrt_base, name='wh01')
            wh01_sqrt = tf.abs(wh01_sqrt_base, name='wh01_sqrt')
            self.coords = tf.concat([self.offset_xy, wh01_sqrt], -1, name='coords')
            # Box size in cell units (wh01 presumably a fraction of the image -- TODO confirm).
            self.wh = tf.identity(wh01 * [self.cell_width, self.cell_height], name='wh')
            _wh = self.wh / 2
            # Box corners relative to the cell, in cell units.
            self.offset_xy_min = tf.identity(self.offset_xy - _wh, name='offset_xy_min')
            self.offset_xy_max = tf.identity(self.offset_xy + _wh, name='offset_xy_max')
            self.areas = tf.reduce_prod(self.wh, -1, name='areas')
        if not training:
            with tf.name_scope('detection'):
                # Adding each cell's integer (x, y) index turns in-cell offsets
                # into grid coordinates.
                cell_xy = calc_cell_xy(self.cell_height, self.cell_width).reshape([1, cells, 1, 2])
                self.xy = tf.identity(cell_xy + self.offset_xy, name='xy')
                self.xy_min = tf.identity(cell_xy + self.offset_xy_min, name='xy_min')
                self.xy_max = tf.identity(cell_xy + self.offset_xy_max, name='xy_max')
                # Class-specific confidence: predicted IoU times class score.
                self.conf = tf.identity(tf.expand_dims(self.iou, -1) * self.prob, name='conf')
        self.inputs = net
        self.classes = classes
        self.boxes_per_cell = boxes_per_cell
67 |
68 |
class Objectives(dict):
    """YOLO (v1) loss terms, stored as ``name -> scalar tensor`` entries.

    Keys: ``iou_best``, ``iou_normal``, ``coords`` and ``prob``. Each is a
    masked sum of squared errors divided by ``cnt``, the element count of the
    per-box IoU tensor.
    """

    def __init__(self, model, mask, prob, coords, offset_xy_min, offset_xy_max, areas):
        """
        Args:
            model: a :class:`Model` providing the predicted tensors.
            mask, prob, coords, offset_xy_min, offset_xy_max, areas: ground-truth
                tensors broadcast-compatible with the ``model`` attributes of
                the same names (layout presumably set by the data pipeline --
                TODO confirm).
        """
        self.model = model
        with tf.name_scope('true'):
            self.mask = tf.identity(mask, name='mask')
            self.prob = tf.identity(prob, name='prob')
            self.coords = tf.identity(coords, name='coords')
            self.offset_xy_min = tf.identity(offset_xy_min, name='offset_xy_min')
            self.offset_xy_max = tf.identity(offset_xy_max, name='offset_xy_max')
            self.areas = tf.identity(areas, name='areas')
        # IoU between each predicted box and the ground-truth box.
        with tf.name_scope('iou') as name:
            # Intersection rectangle, clamped to zero size when the boxes do not overlap.
            _offset_xy_min = tf.maximum(model.offset_xy_min, self.offset_xy_min, name='_offset_xy_min')
            _offset_xy_max = tf.minimum(model.offset_xy_max, self.offset_xy_max, name='_offset_xy_max')
            _wh = tf.maximum(_offset_xy_max - _offset_xy_min, 0.0, name='_wh')
            _areas = tf.reduce_prod(_wh, -1, name='_areas')
            # Union area, floored at 1e-10 to avoid division by zero.
            areas = tf.maximum(self.areas + model.areas - _areas, 1e-10, name='areas')
            iou = tf.truediv(_areas, areas, name=name)
        with tf.name_scope('mask'):
            # Within each cell (axis 2 = box axis), the box(es) with the highest
            # IoU are responsible; tf.equal flags every tied box.
            best_box_iou = tf.reduce_max(iou, 2, True, name='best_box_iou')
            best_box = tf.to_float(tf.equal(iou, best_box_iou), name='best_box')
            mask_best = tf.identity(self.mask * best_box, name='mask_best')
            mask_normal = tf.identity(1 - mask_best, name='mask_normal')
        with tf.name_scope('dist'):
            # Confidence target: 1 for responsible boxes, 0 for all others.
            iou_dist = tf.square(model.iou - mask_best, name='iou_dist')
            coords_dist = tf.square(model.coords - self.coords, name='coords_dist')
            prob_dist = tf.square(model.prob - self.prob, name='prob_dist')
        with tf.name_scope('objectives'):
            # Static element count; requires a fully defined iou_dist shape.
            cnt = np.multiply.reduce(iou_dist.get_shape().as_list())
            self['iou_best'] = tf.identity(tf.reduce_sum(mask_best * iou_dist) / cnt, name='iou_best')
            self['iou_normal'] = tf.identity(tf.reduce_sum(mask_normal * iou_dist) / cnt, name='iou_normal')
            self['coords'] = tf.identity(tf.reduce_sum(tf.expand_dims(mask_best, -1) * coords_dist) / cnt, name='coords')
            # NOTE(review): class error is weighted by self.mask (not mask_best);
            # if mask has a box axis, it is counted once per masked box -- confirm intended.
            self['prob'] = tf.identity(tf.reduce_sum(tf.expand_dims(self.mask, -1) * prob_dist) / cnt, name='prob')
101 |
102 |
class Builder(object):
    """Build a YOLO (v1) network, its :class:`Model` and its weighted losses.

    Settings come from the config section named after this package's module
    (the last component of ``__name__``): ``boxes_per_cell`` and
    ``inference``, the name of the network function in ``inference``.
    """

    def __init__(self, args, config):
        """Read class names and settings; no graph is built yet."""
        section = __name__.rpartition('.')[2]
        self.args = args
        self.config = config
        names_path = os.path.join(utils.get_cachedir(config), 'names')
        with open(names_path, 'r') as f:
            self.names = list(map(str.strip, f))
        self.boxes_per_cell = config.getint(section, 'boxes_per_cell')
        self.func = getattr(inference, config.get(section, 'inference'))

    def __call__(self, data, training=False):
        """Build the network on ``data`` and decode its output into ``self.model``.

        NOTE(review): ``training`` reaches the network function but is not
        passed to :class:`Model`, so the detection tensors are always built.
        """
        scope_name, self.output = self.func(data, len(self.names), self.boxes_per_cell, training=training)
        with tf.name_scope(__name__.rpartition('.')[2]):
            self.model = Model(self.output, scope_name, len(self.names), self.boxes_per_cell)

    def create_objectives(self, labels):
        """Create :class:`Objectives` from ``labels`` and add each term to the
        LOSSES collection, scaled by its weight from the ``section + '_hparam'``
        config section."""
        section = __name__.rpartition('.')[2]
        self.objectives = Objectives(self.model, *labels)
        with tf.name_scope('weighted_objectives'):
            for key, objective in self.objectives.items():
                weight = self.config.getfloat(section + '_hparam', key)
                tf.add_to_collection(tf.GraphKeys.LOSSES, tf.multiply(objective, weight, name='weighted_' + key))
124 |
--------------------------------------------------------------------------------
/model/yolo/function.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see http://www.gnu.org/licenses/.
16 | """
17 |
18 | import tensorflow as tf
19 |
20 |
def leaky_relu(inputs, alpha=.1):
    """Leaky ReLU: ``max(x, alpha * x)``, built inside a ``leaky_relu`` name scope.

    Args:
        inputs: input tensor.
        alpha: slope applied to negative values.
    """
    with tf.name_scope('leaky_relu') as scope_name:
        x = tf.identity(inputs, name='data')
        scaled = alpha * x
        return tf.maximum(x, scaled, name=scope_name)
25 |
--------------------------------------------------------------------------------
/model/yolo/inference.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see http://www.gnu.org/licenses/.
16 | """
17 |
18 | import inspect
19 | import tensorflow as tf
20 | import tensorflow.contrib.slim as slim
21 | from .function import leaky_relu
22 |
23 |
def tiny(net, classes, boxes_per_cell, training=False):
    """Build the tiny YOLO (v1) network on ``net``.

    The network has nine 3x3 convolutions; the first six are each followed by
    a 2x2 max-pool. They feed two fully connected layers with dropout and a
    final linear layer with ``cell_width * cell_height * (classes +
    boxes_per_cell * 5)`` outputs.

    Args:
        net: input image batch.
        classes: number of classes.
        boxes_per_cell: boxes predicted per grid cell.
        training: passed to dropout as ``is_training`` (keep_prob 0.5).

    Returns:
        ``(scope, output)``: the name prefix shared by every layer and the
        flat prediction tensor.

    NOTE(review): indentation was lost in this source listing. The final fully
    connected layer is reconstructed as outside the fully-connected arg_scope,
    so it gets no l2 weight regularizer -- confirm against upstream.
    """
    # Parent package name + this function's name, e.g. 'yolo_tiny'.
    scope = __name__.split('.')[-2] + '_' + inspect.currentframe().f_code.co_name
    net = tf.identity(net, name='%s/input' % scope)
    # (filters, followed by a max-pool?) for conv0 .. conv8.
    conv_layers = [(16, True), (32, True), (64, True), (128, True), (256, True), (512, True), (512, False), (1024, False), (256, False)]
    with slim.arg_scope([slim.layers.conv2d], kernel_size=[3, 3], activation_fn=leaky_relu):
        with slim.arg_scope([slim.layers.max_pool2d], kernel_size=[2, 2], padding='SAME'):
            for index, (filters, pooled) in enumerate(conv_layers):
                net = slim.layers.conv2d(net, filters, scope='%s/conv%d' % (scope, index))
                if pooled:
                    net = slim.layers.max_pool2d(net, scope='%s/max_pool%d' % (scope, index))
    net = tf.identity(net, name='%s/conv' % scope)
    # The feature-map size defines the detection grid.
    _, cell_height, cell_width, _ = net.get_shape().as_list()
    net = slim.layers.flatten(net, scope='%s/flatten' % scope)
    fc_units = [256, 4096]
    with slim.arg_scope([slim.layers.fully_connected], activation_fn=leaky_relu, weights_regularizer=slim.l2_regularizer(0.001)):
        with slim.arg_scope([slim.layers.dropout], keep_prob=.5, is_training=training):
            for index, units in enumerate(fc_units):
                net = slim.layers.fully_connected(net, units, scope='%s/fc%d' % (scope, index))
                net = slim.layers.dropout(net, scope='%s/dropout%d' % (scope, index))
    outputs = cell_width * cell_height * (classes + boxes_per_cell * 5)
    net = slim.layers.fully_connected(net, outputs, activation_fn=None, scope='%s/fc' % scope)
    net = tf.identity(net, name='%s/output' % scope)
    return scope, net

# Total input/feature-map size ratio (height, width) of `tiny`: six stride-2 max-pools, 2 ** 6.
TINY_DOWNSAMPLING = (64, 64)
67 |
--------------------------------------------------------------------------------
/model/yolo2/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see http://www.gnu.org/licenses/.
16 | """
17 |
18 | import configparser
19 | import os
20 | import numpy as np
21 | import pandas as pd
22 | import tensorflow as tf
23 | import utils
24 | from . import inference
25 | from .. import yolo
26 |
27 |
class Model(object):
    """Decode a YOLOv2 output feature map into named tensors.

    NOTE(review): indentation was lost in this source listing; the nesting of
    the ``regress``/``inputs`` name scopes below is a reconstruction and only
    affects op names -- confirm against the upstream file.
    """

    def __init__(self, net, classes, anchors, training=False):
        """
        Args:
            net: network output with static shape
                (batch, cell_height, cell_width, len(anchors) * (5 + classes)).
            classes: number of classes.
            anchors: anchor sizes, one row per anchor; multiplied into the
                predicted width/height.
            training: when False, also builds the detection tensors
                ``xy``, ``xy_min``, ``xy_max`` and ``conf``.
        """
        _, self.cell_height, self.cell_width, _ = net.get_shape().as_list()
        cells = self.cell_height * self.cell_width
        # Per cell and anchor: [iou, offset_x, offset_y, w, h, class logits...].
        inputs = tf.reshape(net, [-1, cells, len(anchors), 5 + classes], name='inputs')
        with tf.name_scope('regress'):
            with tf.name_scope('inputs'):
                # Sigmoid bounds the IoU and the in-cell offsets to (0, 1).
                with tf.name_scope('inputs_sigmoid') as name:
                    inputs_sigmoid = tf.nn.sigmoid(inputs[:, :, :, :3], name=name)
                self.iou = tf.identity(inputs_sigmoid[:, :, :, 0], name='iou')
                self.offset_xy = tf.identity(inputs_sigmoid[:, :, :, 1:3], name='offset_xy')
                # Box size = exp(prediction) * anchor size.
                with tf.name_scope('wh') as name:
                    self.wh = tf.identity(tf.exp(inputs[:, :, :, 3:5]) * np.reshape(anchors, [1, 1, len(anchors), -1]), name=name)
                # Softmax over the class logits of each anchor.
                with tf.name_scope('prob') as name:
                    self.prob = tf.identity(tf.nn.softmax(inputs[:, :, :, 5:]), name=name)
            self.areas = tf.reduce_prod(self.wh, -1, name='areas')
            _wh = self.wh / 2
            # Box corners relative to the cell.
            self.offset_xy_min = tf.identity(self.offset_xy - _wh, name='offset_xy_min')
            self.offset_xy_max = tf.identity(self.offset_xy + _wh, name='offset_xy_max')
            # Size as a fraction of the grid, and its square root, for ``coords``.
            self.wh01 = tf.identity(self.wh / np.reshape([self.cell_width, self.cell_height], [1, 1, 1, 2]), name='wh01')
            self.wh01_sqrt = tf.sqrt(self.wh01, name='wh01_sqrt')
            self.coords = tf.concat([self.offset_xy, self.wh01_sqrt], -1, name='coords')
        if not training:
            with tf.name_scope('detection'):
                # Adding each cell's integer (x, y) index turns in-cell offsets
                # into grid coordinates.
                cell_xy = yolo.calc_cell_xy(self.cell_height, self.cell_width).reshape([1, cells, 1, 2])
                self.xy = tf.identity(cell_xy + self.offset_xy, name='xy')
                self.xy_min = tf.identity(cell_xy + self.offset_xy_min, name='xy_min')
                self.xy_max = tf.identity(cell_xy + self.offset_xy_max, name='xy_max')
                # Class-specific confidence: predicted IoU times class probability.
                self.conf = tf.identity(tf.expand_dims(self.iou, -1) * self.prob, name='conf')
        self.inputs = net
        self.classes = classes
        self.anchors = anchors
60 |
61 |
class Objectives(dict):
    """YOLOv2 loss terms, stored as ``name -> scalar tensor`` entries.

    Keys: ``iou_best``, ``iou_normal``, ``coords`` and ``prob``. Each is a
    masked sum of squared errors divided by ``cnt``, the element count of the
    per-box IoU tensor. Unlike the v1 objectives, the class term is weighted
    by the responsible-box mask.
    """

    def __init__(self, model, mask, prob, coords, offset_xy_min, offset_xy_max, areas):
        """
        Args:
            model: a :class:`Model` providing the predicted tensors.
            mask, prob, coords, offset_xy_min, offset_xy_max, areas: ground-truth
                tensors broadcast-compatible with the ``model`` attributes of
                the same names (layout presumably set by the data pipeline --
                TODO confirm).
        """
        self.model = model
        with tf.name_scope('true'):
            self.mask = tf.identity(mask, name='mask')
            self.prob = tf.identity(prob, name='prob')
            self.coords = tf.identity(coords, name='coords')
            self.offset_xy_min = tf.identity(offset_xy_min, name='offset_xy_min')
            self.offset_xy_max = tf.identity(offset_xy_max, name='offset_xy_max')
            self.areas = tf.identity(areas, name='areas')
        # IoU between each predicted box and the ground-truth box.
        with tf.name_scope('iou') as name:
            # Intersection rectangle, clamped to zero size when the boxes do not overlap.
            _offset_xy_min = tf.maximum(model.offset_xy_min, self.offset_xy_min, name='_offset_xy_min')
            _offset_xy_max = tf.minimum(model.offset_xy_max, self.offset_xy_max, name='_offset_xy_max')
            _wh = tf.maximum(_offset_xy_max - _offset_xy_min, 0.0, name='_wh')
            _areas = tf.reduce_prod(_wh, -1, name='_areas')
            # Union area, floored at 1e-10 to avoid division by zero.
            areas = tf.maximum(self.areas + model.areas - _areas, 1e-10, name='areas')
            iou = tf.truediv(_areas, areas, name=name)
        with tf.name_scope('mask'):
            # Within each cell (axis 2 = anchor axis), the box(es) with the
            # highest IoU are responsible; tf.equal flags every tied box.
            best_box_iou = tf.reduce_max(iou, 2, True, name='best_box_iou')
            best_box = tf.to_float(tf.equal(iou, best_box_iou), name='best_box')
            mask_best = tf.identity(self.mask * best_box, name='mask_best')
            mask_normal = tf.identity(1 - mask_best, name='mask_normal')
        with tf.name_scope('dist'):
            # Confidence target: 1 for responsible boxes, 0 for all others.
            iou_dist = tf.square(model.iou - mask_best, name='iou_dist')
            coords_dist = tf.square(model.coords - self.coords, name='coords_dist')
            prob_dist = tf.square(model.prob - self.prob, name='prob_dist')
        with tf.name_scope('objectives'):
            # Static element count; requires a fully defined iou_dist shape.
            cnt = np.multiply.reduce(iou_dist.get_shape().as_list())
            self['iou_best'] = tf.identity(tf.reduce_sum(mask_best * iou_dist) / cnt, name='iou_best')
            self['iou_normal'] = tf.identity(tf.reduce_sum(mask_normal * iou_dist) / cnt, name='iou_normal')
            # Trailing axis so the mask broadcasts over coordinate / class components.
            _mask_best = tf.expand_dims(mask_best, -1)
            self['coords'] = tf.identity(tf.reduce_sum(_mask_best * coords_dist) / cnt, name='coords')
            self['prob'] = tf.identity(tf.reduce_sum(_mask_best * prob_dist) / cnt, name='prob')
95 |
96 |
class Builder(yolo.Builder):
    """Graph builder for the YOLOv2 model.

    Settings are read from the config section named after this module
    (the last component of __name__).
    """

    def __init__(self, args, config):
        """Load class names, input size, anchors and the inference function.

        Anchors are read with pandas' default header handling, so the first row of
        the tab-separated anchors file is treated as a header.
        """
        section = __name__.rpartition('.')[2]
        self.args = args
        self.config = config
        names_path = os.path.join(utils.get_cachedir(config), 'names')
        with open(names_path, 'r') as lines:
            self.names = list(map(str.strip, lines))
        self.width = config.getint(section, 'width')
        self.height = config.getint(section, 'height')
        anchors_path = os.path.expanduser(os.path.expandvars(config.get(section, 'anchors')))
        self.anchors = pd.read_csv(anchors_path, sep='\t').values
        self.func = getattr(inference, config.get(section, 'inference'))

    def __call__(self, data, training=False):
        """Build the network on data and wrap its raw output in a Model."""
        num_classes, num_anchors = len(self.names), len(self.anchors)
        _, self.output = self.func(data, num_classes, num_anchors, training=training)
        scope_name = __name__.rpartition('.')[2]
        with tf.name_scope(scope_name):
            self.model = Model(self.output, len(self.names), self.anchors, training=training)

    def create_objectives(self, labels):
        """Create the objectives and add each, weighted by [<section>_hparam], to LOSSES."""
        section = __name__.rpartition('.')[2]
        self.objectives = Objectives(self.model, *labels)
        hparam_section = section + '_hparam'
        with tf.name_scope('weighted_objectives'):
            for key, objective in self.objectives.items():
                weight = self.config.getfloat(hparam_section, key)
                weighted = tf.multiply(objective, weight, name='weighted_' + key)
                tf.add_to_collection(tf.GraphKeys.LOSSES, weighted)
120 |
--------------------------------------------------------------------------------
/model/yolo2/function.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see .
16 | """
17 |
18 | import numpy as np
19 | import tensorflow as tf
20 |
21 |
def reorg(net, stride=2, name='reorg'):
    """Space-to-depth rearrangement used by the YOLOv2 passthrough layer.

    Every stride x stride spatial patch is moved into the channel axis:
    [batch, height, width, channels] becomes
    [batch, height // stride, width // stride, channels * stride * stride].
    For patch offset (dy, dx) and input channel c, the output channel index is
    (dy * stride + dx) * channels + c.

    Args:
        net: 4-D NHWC tensor with statically known height, width and channels,
            where height and width are divisible by stride. The batch dimension
            may be unknown (None).
        stride: patch size.
        name: name scope; the output tensor is named after this scope.

    Returns:
        The rearranged tensor.
    """
    _, height, width, channels = net.get_shape().as_list()
    _height, _width, _channels = height // stride, width // stride, channels * stride * stride
    with tf.name_scope(name) as name:
        # Use -1 for the batch axis so a dynamic (None) batch size still reshapes.
        net = tf.reshape(net, [-1, _height, stride, _width, stride, channels])
        net = tf.transpose(net, [0, 1, 3, 2, 4, 5])  # batch, _height, _width, dy, dx, channels
        net = tf.reshape(net, [-1, _height, _width, _channels], name=name)
    return net
30 |
31 |
def main():
    """Self-test for reorg.

    Builds a 1x4x4x1 image whose value at (y, x) is 2 * (y % 2) + (x % 2), runs it
    through reorg with stride 2, and checks that output channel i is constant and
    equal to i.
    """
    row_even = (0, 1, 0, 1)
    row_odd = (2, 3, 2, 3)
    image = np.array([row_even, row_odd, row_even, row_odd])[np.newaxis, :, :, np.newaxis]
    with tf.Session() as sess:
        ph_image = tf.placeholder(tf.uint8, image.shape)
        output = sess.run(reorg(ph_image), feed_dict={ph_image: image})
    # Channel-first view of the single output image: (channels, height, width).
    planes = np.moveaxis(output[0], -1, 0)
    for expected, plane in enumerate(planes):
        values = np.unique(plane)
        assert len(values) == 1
        assert values[0] == expected
48 |
if __name__ == '__main__':
    # Running this module directly executes the reorg self-test.
    main()
51 |
--------------------------------------------------------------------------------
/model/yolo2/inference.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see .
16 | """
17 |
18 | import inspect
19 | import tensorflow as tf
20 | import tensorflow.contrib.slim as slim
21 | from ..yolo.function import leaky_relu
22 | from .function import reorg
23 |
24 |
def tiny(net, classes, num_anchors, training=False, center=True):
    """Build the Tiny YOLOv2 network.

    Args:
        net: input image tensor (NHWC assumed, slim's conv2d default; TODO confirm).
        classes: number of object classes.
        num_anchors: number of anchor boxes per output cell.
        training: forwarded to batch normalization as is_training.
        center: if True, batch norm learns its own beta offset. If False, a separate
            zero-initialised 'biases' variable is added after batch norm instead.

    Returns:
        (scope, output): the name prefix used for every layer, and the raw prediction
        tensor with num_anchors * (5 + classes) channels.
    """
    def batch_norm(net):
        net = slim.batch_norm(net, center=center, scale=True, epsilon=1e-5, is_training=training)
        if not center:
            # Variables need a fully-defined static shape, so tf.shape() (a dynamic
            # tensor) cannot be used. The conv output's channel count is always static.
            depth = net.get_shape().as_list()[-1]
            net = tf.nn.bias_add(net, slim.variable('biases', shape=[depth], initializer=tf.zeros_initializer()))
        return net
    scope = __name__.split('.')[-2] + '_' + inspect.stack()[0][3]
    net = tf.identity(net, name='%s/input' % scope)
    with slim.arg_scope([slim.layers.conv2d], kernel_size=[3, 3], weights_initializer=tf.truncated_normal_initializer(stddev=0.1), normalizer_fn=batch_norm, activation_fn=leaky_relu), slim.arg_scope([slim.layers.max_pool2d], kernel_size=[2, 2], padding='SAME'):
        index = 0
        channels = 16
        # Five conv + stride-2 max-pool stages (16 .. 256 channels).
        for _ in range(5):
            net = slim.layers.conv2d(net, channels, scope='%s/conv%d' % (scope, index))
            net = slim.layers.max_pool2d(net, scope='%s/max_pool%d' % (scope, index))
            index += 1
            channels *= 2
        net = slim.layers.conv2d(net, channels, scope='%s/conv%d' % (scope, index))
        # Stride-1 pool with SAME padding keeps the spatial size.
        net = slim.layers.max_pool2d(net, stride=1, scope='%s/max_pool%d' % (scope, index))
        index += 1
        channels *= 2
        net = slim.layers.conv2d(net, channels, scope='%s/conv%d' % (scope, index))
        index += 1
        net = slim.layers.conv2d(net, channels, scope='%s/conv%d' % (scope, index))
        # Detection head: a linear 1x1 conv with plain biases. normalizer_fn=None
        # overrides the arg_scope, which would otherwise batch-normalise the raw
        # predictions and give this layer no biases.
        net = slim.layers.conv2d(net, num_anchors * (5 + classes), kernel_size=[1, 1], activation_fn=None, normalizer_fn=None, scope='%s/conv' % scope)
    net = tf.identity(net, name='%s/output' % scope)
    return scope, net
51 |
TINY_DOWNSAMPLING = (32, 32)  # (width, height) factor: five stride-2 max-pools, 2 ** 5
53 |
54 |
def _tiny(net, classes, num_anchors, training=False):
    """Tiny YOLOv2 built with center=False batch norm plus a separate biases variable."""
    return tiny(net, classes, num_anchors, training=training, center=False)
57 |
_TINY_DOWNSAMPLING = (32, 32)  # same topology as tiny: five stride-2 max-pools
59 |
60 |
def darknet(net, classes, num_anchors, training=False, center=True):
    """Build the YOLOv2 detection network on a Darknet-19 style backbone.

    A passthrough branch taken before the last stride-2 pool is reorganised with
    reorg and concatenated with the deep features before the detection head.

    Args:
        net: input image tensor (NHWC assumed, slim's conv2d default; TODO confirm).
        classes: number of object classes.
        num_anchors: number of anchor boxes per output cell.
        training: forwarded to batch normalization as is_training.
        center: if True, batch norm learns its own beta offset. If False, a separate
            zero-initialised 'biases' variable is added after batch norm instead.

    Returns:
        (scope, output): the name prefix used for every layer, and the raw prediction
        tensor with num_anchors * (5 + classes) channels.
    """
    def batch_norm(net):
        net = slim.batch_norm(net, center=center, scale=True, epsilon=1e-5, is_training=training)
        if not center:
            # Variables need a fully-defined static shape, so tf.shape() (a dynamic
            # tensor) cannot be used. The conv output's channel count is always static.
            depth = net.get_shape().as_list()[-1]
            net = tf.nn.bias_add(net, slim.variable('biases', shape=[depth], initializer=tf.zeros_initializer()))
        return net
    scope = __name__.split('.')[-2] + '_' + inspect.stack()[0][3]
    net = tf.identity(net, name='%s/input' % scope)
    with slim.arg_scope([slim.layers.conv2d], kernel_size=[3, 3], normalizer_fn=batch_norm, activation_fn=leaky_relu), slim.arg_scope([slim.layers.max_pool2d], kernel_size=[2, 2], padding='SAME'):
        index = 0
        channels = 32
        for _ in range(2):
            net = slim.layers.conv2d(net, channels, scope='%s/conv%d' % (scope, index))
            net = slim.layers.max_pool2d(net, scope='%s/max_pool%d' % (scope, index))
            index += 1
            channels *= 2
        # Integer division throughout: conv2d's num_outputs is a filter count.
        for _ in range(2):
            net = slim.layers.conv2d(net, channels, scope='%s/conv%d' % (scope, index))
            index += 1
            net = slim.layers.conv2d(net, channels // 2, kernel_size=[1, 1], scope='%s/conv%d' % (scope, index))
            index += 1
            net = slim.layers.conv2d(net, channels, scope='%s/conv%d' % (scope, index))
            net = slim.layers.max_pool2d(net, scope='%s/max_pool%d' % (scope, index))
            index += 1
            channels *= 2
        net = slim.layers.conv2d(net, channels, scope='%s/conv%d' % (scope, index))
        index += 1
        net = slim.layers.conv2d(net, channels // 2, kernel_size=[1, 1], scope='%s/conv%d' % (scope, index))
        index += 1
        net = slim.layers.conv2d(net, channels, scope='%s/conv%d' % (scope, index))
        index += 1
        net = slim.layers.conv2d(net, channels // 2, kernel_size=[1, 1], scope='%s/conv%d' % (scope, index))
        index += 1
        net = slim.layers.conv2d(net, channels, scope='%s/conv%d' % (scope, index))
        # Higher-resolution features kept for the passthrough (reorg) branch.
        passthrough = tf.identity(net, name=scope + '/passthrough')
        net = slim.layers.max_pool2d(net, scope='%s/max_pool%d' % (scope, index))
        index += 1
        channels *= 2
        # downsampling finished
        net = slim.layers.conv2d(net, channels, scope='%s/conv%d' % (scope, index))
        index += 1
        net = slim.layers.conv2d(net, channels // 2, kernel_size=[1, 1], scope='%s/conv%d' % (scope, index))
        index += 1
        net = slim.layers.conv2d(net, channels, scope='%s/conv%d' % (scope, index))
        index += 1
        net = slim.layers.conv2d(net, channels // 2, kernel_size=[1, 1], scope='%s/conv%d' % (scope, index))
        index += 1
        net = slim.layers.conv2d(net, channels, scope='%s/conv%d' % (scope, index))
        index += 1
        net = slim.layers.conv2d(net, channels, scope='%s/conv%d' % (scope, index))
        index += 1
        net = slim.layers.conv2d(net, channels, scope='%s/conv%d' % (scope, index))
        index += 1
        with tf.name_scope(scope):
            _net = reorg(passthrough)
        net = tf.concat([_net, net], 3, name='%s/concat%d' % (scope, index))
        net = slim.layers.conv2d(net, channels, scope='%s/conv%d' % (scope, index))
        # Detection head: a linear 1x1 conv with plain biases. normalizer_fn=None
        # overrides the arg_scope, which would otherwise batch-normalise the raw
        # predictions and give this layer no biases.
        net = slim.layers.conv2d(net, num_anchors * (5 + classes), kernel_size=[1, 1], activation_fn=None, normalizer_fn=None, scope='%s/conv' % scope)
    net = tf.identity(net, name='%s/output' % scope)
    return scope, net
121 |
DARKNET_DOWNSAMPLING = (32, 32)  # (width, height) factor: five stride-2 max-pools, 2 ** 5
123 |
124 |
def _darknet(net, classes, num_anchors, training=False):
    """YOLOv2 (Darknet-19 backbone) built with center=False batch norm plus a biases variable."""
    return darknet(net, classes, num_anchors, training=training, center=False)
127 |
_DARKNET_DOWNSAMPLING = (32, 32)  # same topology as darknet: five stride-2 max-pools
129 |
--------------------------------------------------------------------------------
/parse_darknet_yolo2.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see .
16 | """
17 |
18 | import os
19 | import re
20 | import time
21 | import shutil
22 | import argparse
23 | import configparser
24 | import operator
25 | import itertools
26 | import struct
27 | import numpy as np
28 | import pandas as pd
29 | import tensorflow as tf
30 | import model.yolo2.inference as inference
31 | import utils
32 |
33 |
def transpose_weights(weights, num_anchors):
    """Reorder the output conv kernel's channels from Darknet's per-anchor layout.

    Per anchor, the incoming channels are (x, y, w, h, iou, classes...). The result is
    (iou, x, y, w, h, classes...), i.e. the IoU channel moves in front of the four
    coordinate channels.

    Args:
        weights: 4-D kernel (ksize1, ksize2, channels_in, num_anchors * depth).
        num_anchors: number of anchors packed in the output-channel axis.
    """
    ksize1, ksize2, channels_in, _ = weights.shape
    per_anchor = weights.reshape([ksize1, ksize2, channels_in, num_anchors, -1])
    order = [4, 0, 1, 2, 3] + list(range(5, per_anchor.shape[-1]))
    return per_anchor[..., order].reshape([ksize1, ksize2, channels_in, -1])
41 |
42 |
def transpose_biases(biases, num_anchors):
    """Reorder the output layer's biases from (x, y, w, h, iou, classes...) per anchor
    to (iou, x, y, w, h, classes...) and return them flattened."""
    per_anchor = biases.reshape([num_anchors, -1])
    # Index with a list (not a 4:5 slice) so a too-short layout still raises IndexError.
    iou = per_anchor[:, [4]]
    return np.hstack([iou, per_anchor[:, :4], per_anchor[:, 5:]]).ravel()
49 |
50 |
def transpose(sess, layer, num_anchors):
    """Re-assign the layer's 'weights' and 'biases' variables in the model's channel order.

    Raises StopIteration if the layer has no variable ending in 'weights' or 'biases'.
    """
    def find(suffix):
        return next(var for var in layer if var.op.name.endswith(suffix))

    kernel = find('weights')
    sess.run(kernel.assign(transpose_weights(sess.run(kernel), num_anchors)))
    bias = find('biases')
    sess.run(bias.assign(transpose_biases(sess.run(bias), num_anchors)))
56 |
57 |
def _read_darknet_header(f):
    """Read a Darknet .weights header from the binary file f.

    The header is three native int32 fields (major, minor, revision) followed by the
    'seen' counter. Darknet writes 'seen' as a 64-bit size_t when
    major * 10 + minor >= 2 (with both below 1000), and as a 32-bit int otherwise,
    so the header is 20 or 16 bytes long.

    Returns:
        (major, minor, revision, seen), with f positioned at the first weight.
    """
    major, minor, revision = struct.unpack('3i', f.read(12))
    if major * 10 + minor >= 2 and major < 1000 and minor < 1000:
        seen, = struct.unpack('Q', f.read(8))
    else:
        seen, = struct.unpack('i', f.read(4))
    return major, minor, revision, seen


def main():
    """Convert a Darknet YOLOv2 .weights file into a TensorFlow checkpoint.

    Steps:
    1. Build the configured inference network.
    2. Copy the Darknet parameters layer by layer.
    3. Reorder the output layer's channels with transpose.
    4. Save a checkpoint (and optionally a TensorBoard graph) into the log directory.

    Reads the module-level args and config created in the __main__ guard.
    """
    model = config.get('config', 'model')
    cachedir = utils.get_cachedir(config)
    with open(os.path.join(cachedir, 'names'), 'r') as f:
        names = [line.strip() for line in f]
    # Parameter shapes do not depend on the input size; 13 cells per side is used here.
    width, height = np.array(utils.get_downsampling(config)) * 13
    anchors = pd.read_csv(os.path.expanduser(os.path.expandvars(config.get(model, 'anchors'))), sep='\t').values
    func = getattr(inference, config.get(model, 'inference'))
    with tf.Session() as sess:
        image = tf.placeholder(tf.float32, [1, height, width, 3], name='image')
        func(image, len(names), len(anchors))
        tf.contrib.framework.get_or_create_global_step()
        tf.global_variables_initializer().run()
        # Group conv variables by layer index; the index-less output 'conv' layer gets key -1.
        prog = re.compile(r'[_\w\d]+\/conv(\d*)\/(weights|biases|(BatchNorm\/(gamma|beta|moving_mean|moving_variance)))$')
        variables = [(prog.match(v.op.name).group(1), v) for v in tf.global_variables() if prog.match(v.op.name)]
        # groupby only merges adjacent items. This relies on each layer's variables being
        # contiguous in tf.global_variables() (creation order).
        variables = sorted([[int(k) if k else -1, [v for _, v in g]] for k, g in itertools.groupby(variables, operator.itemgetter(0))], key=operator.itemgetter(0))
        assert variables[0][0] == -1
        # The output layer comes last in the Darknet file: move it to the end with the next index.
        variables[0][0] = len(variables) - 1
        variables.append(variables.pop(0))
        with tf.name_scope('assign'):
            with open(os.path.expanduser(os.path.expandvars(args.file)), 'rb') as f:
                major, minor, revision, seen = _read_darknet_header(f)
                tf.logging.info('major=%d, minor=%d, revision=%d, seen=%d' % (major, minor, revision, seen))
                for i, layer in variables:
                    tf.logging.info('processing layer %d' % i)
                    total = 0
                    # The suffix order follows the per-layer order of the Darknet file.
                    for suffix in ['biases', 'beta', 'gamma', 'moving_mean', 'moving_variance', 'weights']:
                        try:
                            v = next(filter(lambda v: v.op.name.endswith(suffix), layer))
                        except StopIteration:
                            continue
                        shape = v.get_shape().as_list()
                        cnt = np.multiply.reduce(shape)
                        total += cnt
                        tf.logging.info('%s: %s=%d' % (v.op.name, str(shape), cnt))
                        p = struct.unpack('%df' % cnt, f.read(4 * cnt))
                        if suffix == 'weights':
                            ksize1, ksize2, channels_in, channels_out = shape
                            p = np.reshape(p, [channels_out, channels_in, ksize1, ksize2])  # Darknet format
                            p = np.transpose(p, [2, 3, 1, 0])  # TensorFlow format (ksize1, ksize2, channels_in, channels_out)
                        sess.run(v.assign(p))
                    tf.logging.info('%d parameters assigned' % total)
                remaining = os.fstat(f.fileno()).st_size - f.tell()
            # `layer` is the last layer visited, i.e. the output layer: reorder its channels.
            transpose(sess, layer, len(anchors))
        saver = tf.train.Saver()
        logdir = utils.get_logdir(config)
        if args.delete:
            tf.logging.warn('delete logging directory: ' + logdir)
            shutil.rmtree(logdir, ignore_errors=True)
        os.makedirs(logdir, exist_ok=True)
        model_path = os.path.join(logdir, 'model.ckpt')
        tf.logging.info('save model into ' + model_path)
        saver.save(sess, model_path)
        if args.summary:
            path = os.path.join(logdir, args.logname)
            summary_writer = tf.summary.FileWriter(path)
            summary_writer.add_graph(sess.graph)
            tf.logging.info('tensorboard --logdir ' + logdir)
        if remaining > 0:
            tf.logging.warn('%d bytes remaining' % remaining)
118 |
119 |
def make_args():
    """Parse the command-line options of the Darknet-to-TensorFlow converter."""
    options = (
        (('file',), {'help': 'Darknet .weights file'}),
        (('-c', '--config'), {'nargs': '+', 'default': ['config.ini'], 'help': 'config file'}),
        (('-d', '--delete'), {'action': 'store_true', 'help': 'delete logdir'}),
        (('-s', '--summary'), {'action': 'store_true'}),
        (('--logname',), {'default': time.strftime('%Y-%m-%d_%H-%M-%S'), 'help': 'the name of TensorBoard log'}),
        (('--level',), {'default': 'info', 'help': 'logging level'}),
    )
    parser = argparse.ArgumentParser()
    for flags, settings in options:
        parser.add_argument(*flags, **settings)
    return parser.parse_args()
129 |
if __name__ == '__main__':
    # Parse options and load the config files into module-level globals; main() reads
    # both `args` and `config` from here.
    args = make_args()
    config = configparser.ConfigParser()
    utils.load_config(config, args.config)
    if args.level:
        tf.logging.set_verbosity(args.level.upper())
    main()
137 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see .
16 | """
17 |
18 | import os
19 | import argparse
20 | import configparser
21 | import importlib
22 | import shutil
23 | import time
24 | import inspect
25 | import multiprocessing
26 | import tensorflow as tf
27 | import tensorflow.contrib.slim as slim
28 | import utils.data
29 |
30 |
def summary_scalar(config):
    """Add scalar summaries for every tensor matching the [summary] scalar pattern.

    Non-scalar tensors are first reduced with the callable named by
    [summary] scalar_reduce. A missing section or option disables this summary
    with a warning.
    """
    try:
        # NOTE(review): eval() runs arbitrary code from the config file; only use trusted configs.
        reducer = eval(config.get('summary', 'scalar_reduce'))
        for tensor in utils.match_tensor(config.get('summary', 'scalar')):
            tag = tensor.op.name
            if len(tensor.get_shape()) > 0:
                tensor = reducer(tensor)
                tf.logging.warn(tag + ' is not a scalar tensor, reducing by ' + reducer.__name__)
            tf.summary.scalar(tag, tensor)
    except (configparser.NoSectionError, configparser.NoOptionError):
        tf.logging.warn('summary_scalar disabled')
42 |
43 |
def summary_image(config):
    """Add image summaries for every tensor matching the [summary] image pattern.

    Tensors whose last dimension is not 1, 3 or 4 are summed over that axis into a
    single-channel image first. A missing section or option disables this summary
    with a warning.
    """
    try:
        for tensor in utils.match_tensor(config.get('summary', 'image')):
            tag = tensor.op.name
            if tensor.get_shape()[-1].value not in (1, 3, 4):
                tensor = tf.expand_dims(tf.reduce_sum(tensor, -1), -1)
            tf.summary.image(tag, tensor, config.getint('summary', 'image_max'))
    except (configparser.NoSectionError, configparser.NoOptionError):
        tf.logging.warn('summary_image disabled')
54 |
55 |
def summary_histogram(config):
    """Add histogram summaries for every tensor matching the [summary] histogram pattern.

    A missing section or option disables this summary with a warning.
    """
    try:
        pattern = config.get('summary', 'histogram')
    except (configparser.NoSectionError, configparser.NoOptionError):
        tf.logging.warn('summary_histogram disabled')
        return
    for tensor in utils.match_tensor(pattern):
        tf.summary.histogram(tensor.op.name, tensor)
62 |
63 |
def summary(config):
    """Register the scalar, image and histogram summaries configured in [summary]."""
    for add_summaries in (summary_scalar, summary_image, summary_histogram):
        add_summaries(config)
68 |
69 |
def get_optimizer(config, name):
    """Return a factory learning_rate -> tf.train optimizer for the optimizer called name.

    Hyper-parameters are read from the [optimizer_<name>] section only when the
    factory is called. Unknown names raise KeyError.
    """
    section = 'optimizer_' + name

    def hp(key):
        return config.getfloat(section, key)

    factories = {
        'adam': lambda learning_rate: tf.train.AdamOptimizer(learning_rate, hp('beta1'), hp('beta2'), hp('epsilon')),
        'adadelta': lambda learning_rate: tf.train.AdadeltaOptimizer(learning_rate, hp('rho'), hp('epsilon')),
        'adagrad': lambda learning_rate: tf.train.AdagradOptimizer(learning_rate, hp('initial_accumulator_value')),
        'momentum': lambda learning_rate: tf.train.MomentumOptimizer(learning_rate, hp('momentum')),
        'rmsprop': lambda learning_rate: tf.train.RMSPropOptimizer(learning_rate, hp('decay'), hp('momentum'), hp('epsilon')),
        'ftrl': lambda learning_rate: tf.train.FtrlOptimizer(learning_rate, hp('learning_rate_power'), hp('initial_accumulator_value'), hp('l1_regularization_strength'), hp('l2_regularization_strength')),
        'gd': lambda learning_rate: tf.train.GradientDescentOptimizer(learning_rate),
    }
    return factories[name]
81 |
82 |
def main():
    """Build the YOLO training graph and run slim.learning.train.

    Reads the module-level args (from make_args) and config globals created in the
    __main__ guard.
    """
    model = config.get('config', 'model')
    logdir = utils.get_logdir(config)
    if args.delete:
        tf.logging.warn('delete logging directory: ' + logdir)
        shutil.rmtree(logdir, ignore_errors=True)
    if args.seed is not None:
        # Apply --seed before any op exists. With a graph-level seed set, TensorFlow
        # derives deterministic op-level seeds for the random ops created below.
        tf.set_random_seed(args.seed)
    cachedir = utils.get_cachedir(config)
    with open(os.path.join(cachedir, 'names'), 'r') as f:
        names = [line.strip() for line in f]
    width = config.getint(model, 'width')
    height = config.getint(model, 'height')
    cell_width, cell_height = utils.calc_cell_width_height(config, width, height)
    tf.logging.warn('(width, height)=(%d, %d), (cell_width, cell_height)=(%d, %d)' % (width, height, cell_width, cell_height))
    yolo = importlib.import_module('model.' + model)
    paths = [os.path.join(cachedir, profile + '.tfrecord') for profile in args.profile]
    # Counting needs one full pass over every record file (logging only).
    num_examples = sum(sum(1 for _ in tf.python_io.tf_record_iterator(path)) for path in paths)
    tf.logging.warn('num_examples=%d' % num_examples)
    with tf.name_scope('batch'):
        image_rgb, labels = utils.data.load_image_labels(paths, len(names), width, height, cell_width, cell_height, config)
        with tf.name_scope('per_image_standardization'):
            image_std = tf.image.per_image_standardization(image_rgb)
        batch = tf.train.shuffle_batch((image_std,) + labels, batch_size=args.batch_size,
            capacity=config.getint('queue', 'capacity'), min_after_dequeue=config.getint('queue', 'min_after_dequeue'),
            num_threads=multiprocessing.cpu_count()
        )
    global_step = tf.contrib.framework.get_or_create_global_step()
    builder = yolo.Builder(args, config)
    builder(batch[0], training=True)
    with tf.name_scope('total_loss') as name:
        builder.create_objectives(batch[1:])
        total_loss = tf.losses.get_total_loss(name=name)
    # Collected before the optimizer exists, so optimizer slot variables are not restored.
    variables_to_restore = slim.get_variables_to_restore(exclude=args.exclude)
    with tf.name_scope('optimizer'):
        try:
            decay_steps = config.getint('exponential_decay', 'decay_steps')
            decay_rate = config.getfloat('exponential_decay', 'decay_rate')
            staircase = config.getboolean('exponential_decay', 'staircase')
            learning_rate = tf.train.exponential_decay(args.learning_rate, global_step, decay_steps, decay_rate, staircase=staircase)
            tf.logging.warn('using a learning rate start from %f with exponential decay (decay_steps=%d, decay_rate=%f, staircase=%d)' % (args.learning_rate, decay_steps, decay_rate, staircase))
        except (configparser.NoSectionError, configparser.NoOptionError):
            learning_rate = args.learning_rate
            tf.logging.warn('using a stationary learning rate %f' % args.learning_rate)
        optimizer = get_optimizer(config, args.optimizer)(learning_rate)
        tf.logging.warn('optimizer=' + args.optimizer)
        # Gradient summaries are optional, like the other [summary] settings.
        train_op = slim.learning.create_train_op(total_loss, optimizer, global_step,
            clip_gradient_norm=args.gradient_clip, summarize_gradients=config.getboolean('summary', 'gradients', fallback=False),
        )
    if args.transfer:
        path = os.path.expanduser(os.path.expandvars(args.transfer))
        tf.logging.warn('transferring from ' + path)
        init_assign_op, init_feed_dict = slim.assign_from_checkpoint(path, variables_to_restore)
        def init_fn(sess):
            sess.run(init_assign_op, init_feed_dict)
            tf.logging.warn('transferring from global_step=%d, learning_rate=%f' % sess.run((global_step, learning_rate)))
    else:
        init_fn = lambda sess: tf.logging.warn('global_step=%d, learning_rate=%f' % sess.run((global_step, learning_rate)))
    summary(config)
    tf.logging.warn('tensorboard --logdir ' + logdir)
    slim.learning.train(train_op, logdir, master=args.master, is_chief=(args.task == 0),
        global_step=global_step, number_of_steps=args.steps, init_fn=init_fn,
        summary_writer=tf.summary.FileWriter(os.path.join(logdir, args.logname)),
        save_summaries_secs=args.summary_secs, save_interval_secs=args.save_secs
    )
146 |
147 |
def make_args():
    """Parse the command-line options of the training script."""
    options = (
        (('-c', '--config'), {'nargs': '+', 'default': ['config.ini'], 'help': 'config file'}),
        (('-t', '--transfer'), {'help': 'transferring model from a .ckpt file'}),
        (('-e', '--exclude'), {'nargs': '+', 'help': 'exclude variables while transferring'}),
        (('-p', '--profile'), {'nargs': '+', 'default': ['train', 'val']}),
        (('-s', '--steps'), {'type': int, 'default': None, 'help': 'max number of steps'}),
        (('-d', '--delete'), {'action': 'store_true', 'help': 'delete logdir'}),
        (('-b', '--batch_size'), {'default': 8, 'type': int, 'help': 'batch size'}),
        (('-o', '--optimizer'), {'default': 'adam'}),
        (('-n', '--logname'), {'default': time.strftime('%Y-%m-%d_%H-%M-%S'), 'help': 'the name for TensorBoard'}),
        (('-g', '--gradient_clip'), {'default': 0, 'type': float, 'help': 'gradient clip'}),
        (('-lr', '--learning_rate'), {'default': 1e-6, 'type': float, 'help': 'learning rate'}),
        (('--seed',), {'type': int, 'default': None}),
        (('--summary_secs',), {'default': 30, 'type': int, 'help': 'seconds to save summaries'}),
        (('--save_secs',), {'default': 600, 'type': int, 'help': 'seconds to save model'}),
        (('--level',), {'help': 'logging level'}),
        (('--master',), {'default': '', 'help': 'master address'}),
        (('--task',), {'type': int, 'default': 0, 'help': 'task ID'}),
    )
    parser = argparse.ArgumentParser()
    for flags, settings in options:
        parser.add_argument(*flags, **settings)
    return parser.parse_args()
168 |
if __name__ == '__main__':
    # Parse options and load the config files into module-level globals; main() and
    # its helpers read `args` and `config` from here.
    args = make_args()
    config = configparser.ConfigParser()
    utils.load_config(config, args.config)
    if args.level:
        tf.logging.set_verbosity(args.level.upper())
    main()
176 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see .
16 | """
17 |
18 | import os
19 | import re
20 | import importlib
21 | import inspect
22 | import numpy as np
23 | import matplotlib.patches as patches
24 | import tensorflow as tf
25 | from tensorflow.python.client import device_lib
26 |
27 |
def get_cachedir(config):
    """Return <basedir>/cache/<basename of [cache] names>.

    basedir is read from [config] basedir, with $VARs and ~ expanded.
    """
    raw_basedir = config.get('config', 'basedir')
    names_file = config.get('cache', 'names')
    basedir = os.path.expanduser(os.path.expandvars(raw_basedir))
    return os.path.join(basedir, 'cache', os.path.basename(names_file))
32 |
33 |
def get_logdir(config):
    """Return <basedir>/<model>/<inference>/<basename of [cache] names>.

    basedir is read from [config] basedir, with $VARs and ~ expanded.
    """
    basedir = os.path.expanduser(os.path.expandvars(config.get('config', 'basedir')))
    model = config.get('config', 'model')
    parts = (model, config.get(model, 'inference'), os.path.basename(config.get('cache', 'names')))
    return os.path.join(basedir, *parts)
40 |
41 |
def get_inference(config):
    """Return the network-building function named by [<model>] inference in model.<model>.inference."""
    model = config.get('config', 'model')
    module = importlib.import_module('model.%s.inference' % model)
    return getattr(module, config.get(model, 'inference'))
45 |
46 |
def get_downsampling(config):
    """Return the <INFERENCE>_DOWNSAMPLING constant paired with the configured inference function."""
    model = config.get('config', 'model')
    module = importlib.import_module('model.%s.inference' % model)
    constant_name = config.get(model, 'inference').upper() + '_DOWNSAMPLING'
    return getattr(module, constant_name)
50 |
51 |
def calc_cell_width_height(config, width, height):
    """Return the output grid size (cell_width, cell_height) for an input of width x height.

    Raises:
        ValueError: if width or height is not a multiple of the network's downsampling
            factor. An explicit raise is used because ``assert`` is stripped under
            ``python -O``, which would silently floor the grid size.
    """
    downsampling_width, downsampling_height = get_downsampling(config)
    if width % downsampling_width or height % downsampling_height:
        raise ValueError('input size (%d, %d) must be divisible by the downsampling factor (%d, %d)' % (width, height, downsampling_width, downsampling_height))
    return width // downsampling_width, height // downsampling_height
57 |
58 |
def match_trainable_variables(pattern):
    """Return the trainable variables whose op name matches the regex pattern at its start."""
    matcher = re.compile(pattern).match
    return [var for var in tf.trainable_variables() if matcher(var.op.name)]
62 |
63 |
def match_tensor(pattern):
    """Return the first output tensor of every graph op whose name matches the regex pattern.

    The pattern is matched at the start of the op name; ops without outputs are skipped.
    """
    prog = re.compile(pattern)
    matched = []
    for op in tf.get_default_graph().get_operations():
        outputs = op.values()
        if outputs and prog.match(op.name):
            matched.append(outputs[0])
    return matched
67 |
68 |
def load_config(config, paths):
    """Read each config file in paths, in order, into the ConfigParser config.

    Values from later files override earlier ones. $VARs and ~ are expanded in each path.

    Raises:
        FileNotFoundError: if a path does not exist. An explicit raise is used because
            ``assert`` disappears under ``python -O``, and ConfigParser.read silently
            skips files it cannot open.
    """
    for path in paths:
        path = os.path.expanduser(os.path.expandvars(path))
        if not os.path.exists(path):
            raise FileNotFoundError(path)
        config.read(path)
74 |
def get_available_gpus():
    """Return the names of the local devices whose device_type is 'GPU'."""
    return [device.name for device in device_lib.list_local_devices() if device.device_type == 'GPU']
78 |
--------------------------------------------------------------------------------
/utils/data/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see .
16 | """
17 |
18 | import os
19 | import re
20 | import importlib
21 | import inspect
22 | import numpy as np
23 | import matplotlib.patches as patches
24 | import tensorflow as tf
25 | from .. import preprocess
26 |
27 |
def decode_image_objects(paths):
    """Read one serialized example from the TFRecord files in paths and decode it.

    Returns:
        (image, imageshape, objects_class, objects_coord):
        - image: the JPEG at 'imagepath', decoded with 3 channels.
        - imageshape: an int64 vector of length 3.
        - objects_class: int64 class ids.
        - objects_coord: float32 with shape [-1, 4].
    """
    with tf.name_scope('decode_image_objects'):
        with tf.name_scope('parse_example'):
            _, serialized = tf.TFRecordReader().read(tf.train.string_input_producer(paths))
            features = {
                'imagepath': tf.FixedLenFeature([], tf.string),
                'imageshape': tf.FixedLenFeature([3], tf.int64),
                'objects': tf.FixedLenFeature([2], tf.string),
            }
            example = tf.parse_single_example(serialized, features=features)
            imagepath = example['imagepath']
            objects = example['objects']
        with tf.name_scope('decode_objects'):
            # objects[0]: raw int64 class ids; objects[1]: raw float32 box coordinates.
            objects_class = tf.decode_raw(objects[0], tf.int64, name='objects_class')
            objects_coord = tf.reshape(tf.decode_raw(objects[1], tf.float32), [-1, 4], name='objects_coord')
        with tf.name_scope('load_image'):
            image = tf.image.decode_jpeg(tf.read_file(imagepath), channels=3)
        return image, example['imageshape'], objects_class, objects_coord
48 |
49 |
def data_augmentation_full(image, objects_coord, width_height, config):
    """Optionally random-crop the full-resolution image around its objects.

    Reads the config section named after this function: ``random_crop``
    (crop scale passed to ``preprocess.random_crop``; a value <= 0 disables
    cropping) and ``enable_probability`` (chance of cropping a given sample).

    :return: ``(image, objects_coord, width_height)``, cropped or untouched.
    """
    section = inspect.stack()[0][3]
    with tf.name_scope(section):
        crop_scale = config.getfloat(section, 'random_crop')
        if crop_scale <= 0:
            return image, objects_coord, width_height
        apply_crop = tf.random_uniform([]) < config.getfloat(section, 'enable_probability')
        image, objects_coord, width_height = tf.cond(
            apply_crop,
            lambda: preprocess.random_crop(image, objects_coord, width_height, crop_scale),
            lambda: (image, objects_coord, width_height),
        )
    return image, objects_coord, width_height
61 |
62 |
def resize_image_objects(image, objects_coord, width_height, width, height):
    """Resize ``image`` to ``height`` x ``width`` and rescale its boxes to match.

    :param width_height: current ``[width, height]`` of ``image``.
    :param objects_coord: boxes ``(xmin, ymin, xmax, ymax)``; each x is scaled
        by ``width / current_width`` and each y by ``height / current_height``.
    :return: ``(resized_image, rescaled_coords)``.
    """
    with tf.name_scope(inspect.stack()[0][3]):
        resized = tf.image.resize_images(image, [height, width])
        scale_xy = [width, height] / width_height
        rescaled = objects_coord * tf.tile(scale_xy, [2])
    return resized, rescaled
69 |
70 |
def data_augmentation_resized(image, objects_coord, width, height, config):
    """Apply random flips and photometric jitter to an already-resized image.

    Every switch lives in the config section named after this function.
    ``random_flip_horizontally`` flips image and boxes. Each of
    ``random_brightness``, ``random_saturation``, ``random_hue``,
    ``random_contrast`` and ``noise`` is applied independently with
    probability ``enable_probability``. ``grayscale_probability`` > 0 enables
    random grayscale conversion.

    :return: ``(image, objects_coord)``.
    """
    section = inspect.stack()[0][3]

    def _maybe(transform, img):
        # Apply ``transform`` to ``img`` with probability ``enable_probability``.
        return tf.cond(
            tf.random_uniform([]) < config.getfloat(section, 'enable_probability'),
            lambda: transform(img),
            lambda: img,
        )

    # (config switch, transform) pairs, applied in this order when enabled.
    jitters = [
        ('random_brightness', lambda img: tf.image.random_brightness(img, max_delta=63)),
        ('random_saturation', lambda img: tf.image.random_saturation(img, lower=0.5, upper=1.5)),
        ('random_hue', lambda img: tf.image.random_hue(img, max_delta=0.032)),
        ('random_contrast', lambda img: tf.image.random_contrast(img, lower=0.5, upper=1.5)),
        ('noise', lambda img: img + tf.truncated_normal(tf.shape(img)) * tf.random_uniform([], 5, 15)),
    ]
    with tf.name_scope(section):
        if config.getboolean(section, 'random_flip_horizontally'):
            image, objects_coord = preprocess.random_flip_horizontally(image, objects_coord, width)
        for option, transform in jitters:
            if config.getboolean(section, option):
                image = _maybe(transform, image)
        grayscale_probability = config.getfloat(section, 'grayscale_probability')
        if grayscale_probability > 0:
            image = preprocess.random_grayscale(image, grayscale_probability)
    return image, objects_coord
110 |
111 |
def transform_labels(objects_class, objects_coord, classes, cell_width, cell_height, dtype=np.float32):
    """Rasterize ground-truth boxes onto a ``cell_width`` x ``cell_height`` grid.

    :param objects_class: integer class index per object, length ``n``.
    :param objects_coord: array ``(n, 4)`` of ``(xmin, ymin, xmax, ymax)``;
        presumably normalized to [0, 1] (``load_image_labels`` divides by the
        image size before calling this through ``decode_labels``).
    :param classes: number of classes.
    :param cell_width: number of grid columns.
    :param cell_height: number of grid rows.
    :param dtype: dtype of the zero-initialized output arrays.
    :return: ``(mask, prob, coords, offset_xy_min, offset_xy_max, areas)``
        with shapes ``(cells, 1)``, ``(cells, 1, classes)``, ``(cells, 1, 4)``,
        ``(cells, 1, 2)``, ``(cells, 1, 2)`` and ``(cells, 1)``, where
        ``cells = cell_width * cell_height``. ``coords`` holds the centre
        offset inside its cell and the square roots of width/height. When
        several objects share a cell only one of them is kept in ``coords``
        and the offsets, while ``prob`` marks every class present.
    """
    cells = cell_height * cell_width
    mask = np.zeros([cells, 1], dtype=dtype)
    prob = np.zeros([cells, 1, classes], dtype=dtype)
    coords = np.zeros([cells, 1, 4], dtype=dtype)
    offset_xy_min = np.zeros([cells, 1, 2], dtype=dtype)
    offset_xy_max = np.zeros([cells, 1, 2], dtype=dtype)
    assert len(objects_class) == len(objects_coord)
    xmin, ymin, xmax, ymax = objects_coord.T
    # object centres in grid units
    x = cell_width * (xmin + xmax) / 2
    y = cell_height * (ymin + ymax) / 2
    # A centre exactly on the right/bottom border (x == cell_width) would map
    # to a cell that does not exist and raise IndexError; keep it in the last cell.
    ix = np.clip(np.floor(x), 0, cell_width - 1)
    iy = np.clip(np.floor(y), 0, cell_height - 1)
    offset_x = x - ix
    offset_y = y - iy
    w = xmax - xmin
    h = ymax - ymin
    # np.int was removed in NumPy 1.24; use an explicit integer type.
    index = (iy * cell_width + ix).astype(np.int64)
    mask[index, :] = 1
    # pairs index[k] with objects_class[k] (fancy indices broadcast together)
    prob[index, :, objects_class] = 1
    coords[index, 0, 0] = offset_x
    coords[index, 0, 1] = offset_y
    coords[index, 0, 2] = np.sqrt(w)
    coords[index, 0, 3] = np.sqrt(h)
    # half extents converted to grid units, relative to the cell origin
    _w = w / 2 * cell_width
    _h = h / 2 * cell_height
    offset_xy_min[index, 0, 0] = offset_x - _w
    offset_xy_min[index, 0, 1] = offset_y - _h
    offset_xy_max[index, 0, 0] = offset_x + _w
    offset_xy_max[index, 0, 1] = offset_y + _h
    wh = offset_xy_max - offset_xy_min
    assert np.all(wh >= 0)
    areas = np.multiply.reduce(wh, -1)
    return mask, prob, coords, offset_xy_min, offset_xy_max, areas
146 |
147 |
def decode_labels(objects_class, objects_coord, classes, cell_width, cell_height):
    """Run ``transform_labels`` through ``tf.py_func`` and restore static shapes.

    ``tf.py_func`` outputs have unknown static shape, so each of the six
    float32 outputs is reshaped to its fixed grid layout and named.

    :return: ``(mask, prob, coords, offset_xy_min, offset_xy_max, areas)``.
    """
    cells = cell_height * cell_width
    # output name and target shape, in transform_labels' return order
    layout = [
        ('mask', [cells, 1]),
        ('prob', [cells, 1, classes]),
        ('coords', [cells, 1, 4]),
        ('offset_xy_min', [cells, 1, 2]),
        ('offset_xy_max', [cells, 1, 2]),
        ('areas', [cells, 1]),
    ]
    with tf.name_scope(inspect.stack()[0][3]):
        raw = tf.py_func(transform_labels, [objects_class, objects_coord, classes, cell_width, cell_height], [tf.float32] * len(layout))
        with tf.name_scope('reshape_labels'):
            return tuple(tf.reshape(tensor, shape, name=name) for tensor, (name, shape) in zip(raw, layout))
160 |
161 |
def load_image_labels(paths, classes, width, height, cell_width, cell_height, config):
    """Build the per-example training pipeline.

    The steps are:
    1. Decode one record.
    2. Optionally augment at full resolution (config section
       ``data_augmentation_full``).
    3. Resize to ``width`` x ``height``.
    4. Optionally augment the resized image (``data_augmentation_resized``).
    5. Clip pixels to [0, 255].
    6. Normalize the boxes by the target size.
    7. Rasterize the boxes onto the ``cell_width`` x ``cell_height`` grid.

    :return: ``(image, labels)``; ``labels`` is the tuple from ``decode_labels``.
    """
    with tf.name_scope('batch'):
        raw_image, shape, cls, boxes = decode_image_objects(paths)
        img = tf.cast(raw_image, tf.float32)
        # imageshape is stored as (height, width, depth); [1::-1] -> (width, height)
        size = tf.cast(shape[1::-1], tf.float32)
        if config.getboolean('data_augmentation_full', 'enable'):
            img, boxes, size = data_augmentation_full(img, boxes, size, config)
        img, boxes = resize_image_objects(img, boxes, size, width, height)
        if config.getboolean('data_augmentation_resized', 'enable'):
            img, boxes = data_augmentation_resized(img, boxes, width, height, config)
        img = tf.clip_by_value(img, 0, 255)
        boxes = boxes / [width, height, width, height]
        labels = decode_labels(cls, boxes, classes, cell_width, cell_height)
    return img, labels
176 |
--------------------------------------------------------------------------------
/utils/data/cache.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see <http://www.gnu.org/licenses/>.
16 | """
17 |
18 | import os
19 | import inspect
20 | from PIL import Image
21 | import tqdm
22 | import numpy as np
23 | import tensorflow as tf
24 | import utils.data.voc
25 |
26 |
def verify_imageshape(imagepath, imageshape):
    """Check the image file's real size against the recorded shape.

    :param imagepath: image file, opened with PIL.
    :param imageshape: ``(height, width, ...)`` as stored in the dataset.
    :return: True when PIL's ``(width, height)`` equals ``imageshape[1::-1]``.
    """
    with Image.open(imagepath) as img:
        actual_size = img.size
    expected_size = imageshape[1::-1]
    return np.all(np.equal(actual_size, expected_size))
30 |
31 |
def verify_image_jpeg(imagepath, imageshape):
    """Decode ``imagepath`` as JPEG in the default session and check its size.

    The decoding subgraph (a ``path`` placeholder feeding ``read_file`` and
    ``decode_jpeg``) is created once per default graph and looked up by name
    on subsequent calls.

    :param imagepath: path of the image file.
    :param imageshape: ``(height, width, ...)`` as stored in the dataset.
    :return: False if TensorFlow fails to read or decode the file, otherwise
        whether the decoded ``(height, width)`` equals ``imageshape[:2]``.
    :raises: errors other than ``tf.errors.OpError`` (e.g. no default session,
        KeyboardInterrupt) propagate instead of being reported as a bad image.
    """
    scope = inspect.stack()[0][3]
    try:
        graph = tf.get_default_graph()
        path = graph.get_tensor_by_name(scope + '/path:0')
        decode = graph.get_tensor_by_name(scope + '/decode_jpeg:0')
    except KeyError:
        tf.logging.debug('creating decode_jpeg tensor')
        path = tf.placeholder(tf.string, name=scope + '/path')
        imagefile = tf.read_file(path, name=scope + '/read_file')
        decode = tf.image.decode_jpeg(imagefile, channels=3, name=scope + '/decode_jpeg')
    try:
        image = tf.get_default_session().run(decode, {path: imagepath})
    except tf.errors.OpError:
        # Runtime read/decode failure (missing, corrupt or non-JPEG file).
        # A bare ``except:`` here used to also hide Ctrl-C and setup bugs.
        return False
    return np.all(np.equal(image.shape[:2], imageshape[:2]))
48 |
49 |
def check_coords(objects_coord):
    """Return True when every box has ``xmin <= xmax`` and ``ymin <= ymax``.

    :param objects_coord: array ``(n, 4)`` of ``(xmin, ymin, xmax, ymax)``.
    """
    mins = objects_coord[:, 0:2]
    maxs = objects_coord[:, 2:4]
    return np.all(mins <= maxs)
52 |
53 |
def verify_coords(objects_coord, imageshape):
    """Check that every box is well-formed and lies inside the image.

    :param objects_coord: array ``(n, 4)`` of ``(xmin, ymin, xmax, ymax)``.
    :param imageshape: ``(height, width, ...)``.
    :return: False if any box has a min greater than its max, or extends
        outside ``[0, width] x [0, height]``; True otherwise.
    """
    # An inverted box is bad data just like an out-of-range one: report it so
    # callers can skip the sample, rather than asserting (which aborted the
    # whole cache build, and silently vanished under ``python -O``).
    if not check_coords(objects_coord):
        return False
    upper = np.tile(imageshape[1::-1], [2])
    return np.all(objects_coord >= 0) and np.all(objects_coord <= upper)
57 |
58 |
def fix_coords(objects_coord, imageshape):
    """Clamp box coordinates into the image rectangle.

    :param objects_coord: array ``(n, 4)`` of ``(xmin, ymin, xmax, ymax)``;
        every box must already satisfy min <= max (asserted).
    :param imageshape: ``(height, width, ...)``.
    :return: a new array of the same dtype with x clamped to ``[0, width]``
        and y clamped to ``[0, height]``.
    """
    assert check_coords(objects_coord)
    dtype = objects_coord.dtype
    lower = np.zeros([4], dtype=dtype)
    upper = np.tile(np.asanyarray(imageshape[1::-1], dtype), [2])
    return np.minimum(np.maximum(objects_coord, lower), upper)
64 |
65 |
def voc(writer, name_index, profile, row, verify=False):
    """Serialize one PASCAL VOC image set into ``tf.train.Example`` records.

    Reads ``<root>/ImageSets/Main/<profile>.txt`` and parses each listed
    annotation with ``utils.data.voc.load_dataset``. For every image that has
    at least one object of a known class, writes an example with
    ``imagepath``, ``imageshape`` and ``objects`` (raw int64 class ids and
    raw float32 boxes).

    :param writer: object whose ``write`` receives each serialized example
        (presumably a TFRecord writer; confirm against the caller).
    :param name_index: mapping from class name to class index.
    :param profile: image-set name (file stem under ``ImageSets/Main``).
    :param row: mapping with a ``root`` entry; env vars and ``~`` are expanded.
    :param verify: when True, skip images whose boxes fall outside the image
        or that TensorFlow cannot decode.
    :return: False if the image-set file is missing, True otherwise.
    """
    root = os.path.expanduser(os.path.expandvars(row['root']))
    path = os.path.join(root, 'ImageSets', 'Main', profile) + '.txt'
    if not os.path.exists(path):
        tf.logging.warn(path + ' not exists')
        return False
    with open(path, 'r') as f:
        # Ignore blank lines (e.g. a trailing empty line) so they are not
        # reported as missing images.
        filenames = [name for name in (line.strip() for line in f) if name]
    annotations = [os.path.join(root, 'Annotations', filename + '.xml') for filename in filenames]
    _annotations = list(filter(os.path.exists, annotations))
    if len(annotations) > len(_annotations):
        tf.logging.warn('%d of %d images not exists' % (len(annotations) - len(_annotations), len(annotations)))
    cnt_noobj = 0
    for annotation in tqdm.tqdm(_annotations):
        imagename, imageshape, objects_class, objects_coord = utils.data.voc.load_dataset(annotation, name_index)
        if len(objects_class) <= 0:
            cnt_noobj += 1
            continue
        objects_class = np.array(objects_class, dtype=np.int64)
        objects_coord = np.array(objects_coord, dtype=np.float32)
        imagepath = os.path.join(root, 'JPEGImages', imagename)
        if verify:
            if not verify_coords(objects_coord, imageshape):
                tf.logging.error('failed to verify coordinates of ' + imagepath)
                continue
            if not verify_image_jpeg(imagepath, imageshape):
                tf.logging.error('failed to decode ' + imagepath)
                continue
        assert len(objects_class) == len(objects_coord)
        # tobytes() is the byte-identical replacement for the deprecated tostring()
        example = tf.train.Example(features=tf.train.Features(feature={
            'imagepath': tf.train.Feature(bytes_list=tf.train.BytesList(value=[tf.compat.as_bytes(imagepath)])),
            'imageshape': tf.train.Feature(int64_list=tf.train.Int64List(value=imageshape)),
            'objects': tf.train.Feature(bytes_list=tf.train.BytesList(value=[objects_class.tobytes(), objects_coord.tobytes()])),
        }))
        writer.write(example.SerializeToString())
    if cnt_noobj > 0:
        # Relative to the annotations actually processed, consistent with ``coco``.
        tf.logging.warn('%d of %d images have no object' % (cnt_noobj, len(_annotations)))
    return True
104 |
105 |
def coco(writer, name_index, profile, row, verify=False):
    """Serialize one MS COCO split into ``tf.train.Example`` records.

    Loads ``<root>/annotations/instances_<profile><year>.json`` with
    pycocotools. Only categories whose names appear in ``name_index`` are
    kept. For every image present under ``<root>/<profile><year>`` that has
    at least one such annotation, writes an example with ``imagepath``,
    ``imageshape`` and ``objects`` (raw int64 class ids and raw float32
    ``(xmin, ymin, xmax, ymax)`` boxes converted from COCO's ``(x, y, w, h)``).

    :param writer: object whose ``write`` receives each serialized example
        (presumably a TFRecord writer; confirm against the caller).
    :param name_index: mapping from class name to class index.
    :param profile: split prefix, combined with ``row['year']``.
    :param row: mapping with ``root`` and ``year`` entries.
    :param verify: when True, skip images whose boxes fall outside the image
        or that TensorFlow cannot decode.
    :return: False if the annotation file is missing, True otherwise.
    """
    root = os.path.expanduser(os.path.expandvars(row['root']))
    year = str(row['year'])
    name = profile + year
    path = os.path.join(root, 'annotations', 'instances_%s.json' % name)
    if not os.path.exists(path):
        tf.logging.warn(path + ' not exists')
        return False
    import pycocotools.coco
    # named ``api`` so it does not shadow this function's own name
    api = pycocotools.coco.COCO(path)
    catIds = api.getCatIds(catNms=list(name_index.keys()))
    cats = api.loadCats(catIds)
    # COCO category id -> our class index
    id_index = {cat['id']: name_index[cat['name']] for cat in cats}
    imgIds = api.getImgIds()
    path = os.path.join(root, name)
    imgs = api.loadImgs(imgIds)
    _imgs = [img for img in imgs if os.path.exists(os.path.join(path, img['file_name']))]
    if len(imgs) > len(_imgs):
        tf.logging.warn('%d of %d images not exists' % (len(imgs) - len(_imgs), len(imgs)))
    cnt_noobj = 0
    for img in tqdm.tqdm(_imgs):
        annIds = api.getAnnIds(imgIds=img['id'], catIds=catIds, iscrowd=None)
        anns = api.loadAnns(annIds)
        if len(anns) <= 0:
            cnt_noobj += 1
            continue
        imagepath = os.path.join(path, img['file_name'])
        width, height = img['width'], img['height']
        imageshape = [height, width, 3]
        objects_class = np.array([id_index[ann['category_id']] for ann in anns], dtype=np.int64)
        # COCO boxes are (x, y, w, h); convert to corner form
        objects_coord = [(x, y, x + w, y + h) for x, y, w, h in (ann['bbox'] for ann in anns)]
        objects_coord = np.array(objects_coord, dtype=np.float32)
        if verify:
            if not verify_coords(objects_coord, imageshape):
                tf.logging.error('failed to verify coordinates of ' + imagepath)
                continue
            if not verify_image_jpeg(imagepath, imageshape):
                tf.logging.error('failed to decode ' + imagepath)
                continue
        assert len(objects_class) == len(objects_coord)
        # tobytes() is the byte-identical replacement for the deprecated tostring()
        example = tf.train.Example(features=tf.train.Features(feature={
            'imagepath': tf.train.Feature(bytes_list=tf.train.BytesList(value=[tf.compat.as_bytes(imagepath)])),
            'imageshape': tf.train.Feature(int64_list=tf.train.Int64List(value=imageshape)),
            'objects': tf.train.Feature(bytes_list=tf.train.BytesList(value=[objects_class.tobytes(), objects_coord.tobytes()])),
        }))
        writer.write(example.SerializeToString())
    if cnt_noobj > 0:
        tf.logging.warn('%d of %d images have no object' % (cnt_noobj, len(_imgs)))
    return True
156 |
--------------------------------------------------------------------------------
/utils/data/voc.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see <http://www.gnu.org/licenses/>.
16 | """
17 |
18 | import sys
19 | import bs4
20 |
21 |
def load_dataset(path, name_index):
    """Parse one PASCAL VOC XML annotation file.

    :param path: path of the ``.xml`` annotation.
    :param name_index: mapping from class name to class index. Objects whose
        (whitespace-stripped) name is not a key are skipped, with a message
        written to stderr.
    :return: ``(filename, (height, width, depth), objects_class, objects_coord)``
        where ``objects_coord`` is a list of ``(xmin, ymin, xmax, ymax)``
        floats, each shifted by -1 (presumably converting VOC's 1-based pixel
        indices to 0-based).
    """
    # Read bytes so the XML parser honours the file's own encoding declaration
    # instead of the platform's default text encoding.
    with open(path, 'rb') as f:
        anno = bs4.BeautifulSoup(f.read(), 'xml').find('annotation')
    objects_class = []
    objects_coord = []
    for obj in anno.find_all('object', recursive=False):
        for bndbox, name in zip(obj.find_all('bndbox', recursive=False), obj.find_all('name', recursive=False)):
            # hand-written annotations often pad the class name with whitespace
            label = name.text.strip()
            if label in name_index:
                objects_class.append(name_index[label])
                xmin = float(bndbox.find('xmin').text) - 1
                ymin = float(bndbox.find('ymin').text) - 1
                xmax = float(bndbox.find('xmax').text) - 1
                ymax = float(bndbox.find('ymax').text) - 1
                objects_coord.append((xmin, ymin, xmax, ymax))
            else:
                sys.stderr.write(label + ' not in names\n')
    size = anno.find('size')
    return anno.find('filename').text, (int(size.find('height').text), int(size.find('width').text), int(size.find('depth').text)), objects_class, objects_coord
40 |
--------------------------------------------------------------------------------
/utils/postprocess.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see <http://www.gnu.org/licenses/>.
16 | """
17 |
18 | import numpy as np
19 |
20 |
def iou(xy_min1, xy_max1, xy_min2, xy_max2):
    """Intersection over union of axis-aligned boxes given by their corners.

    The last axis holds the ``(x, y)`` coordinates of a corner. Any leading
    axes are broadcast, so one pair of boxes gives a scalar and stacked boxes
    give one IoU per pair. 1-D inputs behave exactly as before.

    :return: IoU in ``[0, 1]``. The union is floored at ``1e-10``, so two
        zero-area boxes yield 0 rather than dividing by zero.
    :raises AssertionError: on NaN input, or when a box's min exceeds its max.
    """
    for corner in (xy_min1, xy_max1, xy_min2, xy_max2):
        assert not np.isnan(corner).any()
    assert np.all(xy_min1 <= xy_max1)
    assert np.all(xy_min2 <= xy_max2)
    # reduce over the coordinate axis only, so leading batch axes survive
    areas1 = np.multiply.reduce(xy_max1 - xy_min1, -1)
    areas2 = np.multiply.reduce(xy_max2 - xy_min2, -1)
    _xy_min = np.maximum(xy_min1, xy_min2)
    _xy_max = np.minimum(xy_max1, xy_max2)
    _wh = np.maximum(_xy_max - _xy_min, 0)
    _areas = np.multiply.reduce(_wh, -1)
    assert np.all(_areas <= areas1)
    assert np.all(_areas <= areas2)
    return _areas / np.maximum(areas1 + areas2 - _areas, 1e-10)
37 |
38 |
def non_max_suppress(conf, xy_min, xy_max, threshold, threshold_iou):
    """Greedy per-class non-maximum suppression.

    :param conf: 3-D array whose last axis is the class; flattened to
        ``(-1, classes)``. Suppressed scores are zeroed *in place* on the rows
        of ``conf.reshape(-1, classes)``, so the caller's ``conf`` is modified
        whenever that reshape is a view.
    :param xy_min: box minimum corners, flattened to ``(-1, 2)``.
    :param xy_max: box maximum corners, flattened to ``(-1, 2)``.
    :param threshold: a box whose score for a class is ``<= threshold`` does
        not suppress other boxes for that class.
    :param threshold_iou: a lower-ranked box whose IoU with a suppressing box
        is ``>= threshold_iou`` gets its score for that class set to 0.
    :return: list of ``(scores_row, xy_min, xy_max)`` tuples, ordered by the
        sort performed for the last class.
    """
    _, _, classes = conf.shape
    boxes = list(zip(conf.reshape(-1, classes), xy_min.reshape(-1, 2), xy_max.reshape(-1, 2)))
    for c in range(classes):
        boxes.sort(key=lambda box: box[0][c], reverse=True)
        for rank, (scores, box_min, box_max) in enumerate(boxes[:-1]):
            # scores is read live: it may have been zeroed by a stronger box
            if scores[c] <= threshold:
                continue
            for other_scores, other_min, other_max in boxes[rank + 1:]:
                if iou(box_min, box_max, other_min, other_max) >= threshold_iou:
                    other_scores[c] = 0
    return boxes
52 |
--------------------------------------------------------------------------------
/utils/preprocess.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see <http://www.gnu.org/licenses/>.
16 | """
17 |
18 | import inspect
19 | import numpy as np
20 | import tensorflow as tf
21 |
22 |
def per_image_standardization(image):
    """Zero-mean, unit-variance normalization of a NumPy image.

    Same formula as ``tf.image.per_image_standardization``. The standard
    deviation is floored at ``1 / sqrt(number of elements)`` so a uniform
    image does not cause a division by zero.
    """
    min_stddev = 1.0 / np.sqrt(np.multiply.reduce(image.shape))
    centered = image - np.mean(image)
    return centered / max(np.std(image), min_stddev)
26 |
27 |
def random_crop(image, objects_coord, width_height, scale=1):
    """Randomly crop ``image`` without cutting into the union of its boxes.

    Each side is moved inward by a uniform random fraction (at most
    ``scale``) of the free margin between that image border and the union of
    all boxes. The crop is then snapped to whole pixels.

    :param image: image tensor; ``crop_to_bounding_box`` is called with
        ``(y, x, h, w)``, so presumably ``(height, width, channels)``.
    :param objects_coord: float tensor ``(n, 4)`` of ``(xmin, ymin, xmax, ymax)``
        in pixels; presumably non-empty (reduce_min/max over zero boxes
        yields +/-inf).
    :param width_height: float tensor ``[width, height]`` of ``image``.
    :param scale: upper bound in ``(0, 1]`` on the fraction of each margin removed.
    :return: ``(image, objects_coord, width_height)`` for the cropped image,
        with the boxes expressed in the crop's pixel frame and
        ``width_height`` equal to the actual crop size.
    """
    assert 0 < scale <= 1
    section = inspect.stack()[0][3]
    with tf.name_scope(section):
        xy_min = tf.reduce_min(objects_coord[:, :2], 0)
        xy_max = tf.reduce_max(objects_coord[:, 2:], 0)
        margin = width_height - xy_max
        # Clamp to 0 so boxes lying slightly outside the image cannot produce
        # a negative crop offset or a crop larger than the image.
        room = tf.maximum(tf.concat([xy_min, margin], 0), 0)
        shrink = tf.random_uniform([4], maxval=scale) * room
        # Snap to whole pixels *before* shifting the boxes. Rounding the near
        # corner down and the far corner up means:
        # - the boxes are shifted by exactly the pixel offset used for the image;
        # - boxes inside the image stay fully inside the crop;
        # - the returned size equals the real crop size.
        crop_min = tf.floor(shrink[:2])
        crop_max = tf.minimum(tf.ceil(width_height - shrink[2:]), width_height)
        _wh = crop_max - crop_min
        objects_coord = objects_coord - tf.tile(crop_min, [2])
        _xy_min_ = tf.cast(crop_min, tf.int32)
        _wh_ = tf.cast(_wh, tf.int32)
        image = tf.image.crop_to_bounding_box(image, _xy_min_[1], _xy_min_[0], _wh_[1], _wh_[0])
    return image, objects_coord, _wh
43 |
44 |
def flip_horizontally(image, objects_coord, width):
    """Mirror ``image`` left-right and reflect its boxes about the vertical axis.

    :param objects_coord: tensor ``(n, 4)`` of ``(xmin, ymin, xmax, ymax)``.
    :param width: image width, in the same units as the box coordinates.
    :return: ``(flipped_image, flipped_coords)``; the new ``xmin`` is
        ``width - xmax`` and the new ``xmax`` is ``width - xmin``.
    """
    with tf.name_scope(inspect.stack()[0][3]):
        mirrored = tf.image.flip_left_right(image)
        # keep each column 2-D (n, 1) so they can be concatenated back
        xmin, ymin, xmax, ymax = [objects_coord[:, k:k + 1] for k in range(4)]
        reflected = tf.concat([width - xmax, ymin, width - xmin, ymax], 1)
        return mirrored, reflected
52 |
53 |
def random_flip_horizontally(image, objects_coord, width, probability=0.5):
    """With chance ``probability``, apply ``flip_horizontally``.

    :return: ``(image, objects_coord)``, flipped or untouched.
    """
    with tf.name_scope(inspect.stack()[0][3]):
        do_flip = tf.random_uniform([]) < probability
        return tf.cond(
            do_flip,
            lambda: flip_horizontally(image, objects_coord, width),
            lambda: (image, objects_coord),
        )
61 |
62 |
def random_grayscale(image, probability=0.5):
    """With chance ``probability``, replace ``image`` by its grayscale version.

    The single channel from ``tf.image.rgb_to_grayscale`` is tiled back to 3
    channels along the last axis, so the output shape matches the input.
    When ``probability <= 0`` the input is returned as-is and no ops are
    created.
    """
    if probability <= 0:
        return image
    section = inspect.stack()[0][3]

    def _to_gray():
        gray = tf.image.rgb_to_grayscale(image)
        multiples = [1] * (len(image.get_shape()) - 1) + [3]
        return tf.tile(gray, multiples)

    with tf.name_scope(section):
        pick_gray = tf.random_uniform([]) < probability
        return tf.cond(pick_gray, _to_gray, lambda: image)
72 |
--------------------------------------------------------------------------------
/utils/verify.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see <http://www.gnu.org/licenses/>.
16 | """
17 |
18 | import numpy as np
19 |
20 |
def abs_mean(data):
    """Mean absolute value over all elements of the ndarray ``data``.

    The element count is converted to ``np.float32`` before dividing.
    """
    total = np.abs(data).sum()
    return total / np.float32(data.size)
23 |
--------------------------------------------------------------------------------
/utils/visualize.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see <http://www.gnu.org/licenses/>.
16 | """
17 |
18 | import itertools
19 | import numpy as np
20 | import matplotlib.pyplot as plt
21 | import matplotlib.patches as patches
22 |
23 |
def draw_labels(ax, names, width, height, cell_width, cell_height, mask, prob, coords, xy_min, xy_max, areas, rtol=1e-3, atol=1e-5):
    """Draw grid labels on a Matplotlib axes and sanity-check their encoding.

    For every cell whose mask is positive, this:
    - shades the cell;
    - draws the box decoded from ``coords`` (offset inside the cell, followed
      by the square roots of the normalized width/height) with its class name;
    - asserts that ``xy_min``/``xy_max`` agree with ``coords``.

    :param ax: Matplotlib axes to draw on.
    :param names: class names, indexed by the argmax of each cell's ``prob``.
    :param width: image width in pixels.
    :param height: image height in pixels.
    :param cell_width: number of grid columns.
    :param cell_height: number of grid rows.
    :param mask: per-cell object mask, one value per cell.
    :param prob: per-cell class probabilities.
    :param coords: per-cell ``(offset_x, offset_y, sqrt(w), sqrt(h))``.
    :param xy_min: per-cell box minimum corner in cell units.
    :param xy_max: per-cell box maximum corner in cell units.
    :param areas: per-cell areas. Not drawn; only zipped, so it can bound the
        number of cells visited.
    :param rtol: relative tolerance of the consistency checks.
    :param atol: absolute tolerance of the consistency checks. An offset can
        be exactly 0 (box centre on a grid line), where a purely relative
        tolerance rejects harmless rounding errors.
    :return: list of the created Matplotlib artists.
    :raises AssertionError: if a cell's label encoding is inconsistent.
    """
    colors = [prop['color'] for _, prop in zip(names, itertools.cycle(plt.rcParams['axes.prop_cycle']))]
    plots = []
    for i, (_mask, _prob, _coords, _xy_min, _xy_max, _) in enumerate(zip(mask, prob, coords, xy_min, xy_max, areas)):
        _coords = _coords.reshape([-1])
        # only cells responsible for an object are drawn
        if not _mask.reshape([]) > 0:
            continue
        index = np.argmax(_prob)
        iy = i // cell_width
        ix = i % cell_width
        plots.append(ax.add_patch(patches.Rectangle((ix * width / cell_width, iy * height / cell_height), width / cell_width, height / cell_height, linewidth=0, facecolor=colors[index], alpha=.2)))
        # check coords: decode the box back to pixels and draw it
        offset_x, offset_y, _w_sqrt, _h_sqrt = _coords
        cell_x, cell_y = ix + offset_x, iy + offset_y
        x, y = cell_x * width / cell_width, cell_y * height / cell_height
        _w, _h = _w_sqrt * _w_sqrt, _h_sqrt * _h_sqrt
        w, h = _w * width, _h * height
        x_min, y_min = x - w / 2, y - h / 2
        plots.append(ax.add_patch(patches.Rectangle((x_min, y_min), w, h, linewidth=1, edgecolor=colors[index], facecolor='none')))
        plots.append(ax.annotate(names[index], (x_min, y_min), color=colors[index]))
        # check offset_xy_min and offset_xy_max against coords
        wh = _xy_max - _xy_min
        assert np.all(wh >= 0)
        np.testing.assert_allclose(wh / [cell_width, cell_height], [[_w, _h]], rtol=rtol, atol=atol)
        np.testing.assert_allclose(_xy_min + wh / 2, [[offset_x, offset_y]], rtol=rtol, atol=atol)
    return plots
50 |
--------------------------------------------------------------------------------