├── .gitignore
├── LICENSE.md
├── README.md
├── benchmark_caffe2.py
├── cache.py
├── cache
│   ├── __init__.py
│   ├── coco.py
│   ├── coco.tsv
│   ├── voc.py
│   └── voc.txt
├── checksum_caffe2.py
├── checksum_torch.py
├── config.ini
├── config
│   ├── anchors
│   │   ├── coco.tsv
│   │   ├── tiny-yolo-voc.tsv
│   │   ├── voc.tsv
│   │   ├── yolo-voc.tsv
│   │   └── yolo.tsv
│   ├── category
│   │   ├── 20
│   │   ├── 80
│   │   └── person
│   ├── darknet.ini
│   ├── darknet
│   │   ├── tiny-yolo-voc.ini
│   │   ├── yolo-voc.ini
│   │   └── yolo.ini
│   ├── debug.ini
│   ├── eval.py
│   └── summary
│       └── histogram.txt
├── convert_darknet_torch.py
├── convert_onnx_caffe2.py
├── convert_torch_onnx.py
├── demo.gif
├── demo_data.py
├── demo_graph.py
├── demo_lr.py
├── detect.py
├── dimension_cluster.py
├── disable_bad_images.py
├── donate_alipay.jpg
├── donate_mm.jpg
├── download_url.py
├── eval.py
├── image.jpg
├── logging.yml
├── model
│   ├── __init__.py
│   ├── densenet.py
│   ├── inception3.py
│   ├── inception4.py
│   ├── mobilenet.py
│   ├── resnet.py
│   ├── vgg.py
│   └── yolo2.py
├── pruner.py
├── quick_start.sh
├── receptive_field_analyzer.py
├── requirements.txt
├── split_data.py
├── train.py
├── transform
│   ├── __init__.py
│   ├── augmentation.py
│   ├── image.py
│   └── resize
│       ├── __init__.py
│       ├── image.py
│       └── label.py
├── utils
│   ├── __init__.py
│   ├── cache.py
│   ├── channel.py
│   ├── data.py
│   ├── iou
│   │   ├── __init__.py
│   │   ├── numpy.py
│   │   └── torch.py
│   ├── postprocess.py
│   ├── train.py
│   └── visualize.py
├── variable_stat.py
└── video2image.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | __pycache__
3 | .project
4 | .pydevproject
5 | .settings/
6 | .idea/
7 | .cache/
8 |
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | GNU LESSER GENERAL PUBLIC LICENSE
2 | Version 3, 29 June 2007
3 |
4 | Copyright (C) 2007 Free Software Foundation, Inc. <http://fsf.org/>
5 | Everyone is permitted to copy and distribute verbatim copies
6 | of this license document, but changing it is not allowed.
7 |
8 |
9 | This version of the GNU Lesser General Public License incorporates
10 | the terms and conditions of version 3 of the GNU General Public
11 | License, supplemented by the additional permissions listed below.
12 |
13 | 0. Additional Definitions.
14 |
15 | As used herein, "this License" refers to version 3 of the GNU Lesser
16 | General Public License, and the "GNU GPL" refers to version 3 of the GNU
17 | General Public License.
18 |
19 | "The Library" refers to a covered work governed by this License,
20 | other than an Application or a Combined Work as defined below.
21 |
22 | An "Application" is any work that makes use of an interface provided
23 | by the Library, but which is not otherwise based on the Library.
24 | Defining a subclass of a class defined by the Library is deemed a mode
25 | of using an interface provided by the Library.
26 |
27 | A "Combined Work" is a work produced by combining or linking an
28 | Application with the Library. The particular version of the Library
29 | with which the Combined Work was made is also called the "Linked
30 | Version".
31 |
32 | The "Minimal Corresponding Source" for a Combined Work means the
33 | Corresponding Source for the Combined Work, excluding any source code
34 | for portions of the Combined Work that, considered in isolation, are
35 | based on the Application, and not on the Linked Version.
36 |
37 | The "Corresponding Application Code" for a Combined Work means the
38 | object code and/or source code for the Application, including any data
39 | and utility programs needed for reproducing the Combined Work from the
40 | Application, but excluding the System Libraries of the Combined Work.
41 |
42 | 1. Exception to Section 3 of the GNU GPL.
43 |
44 | You may convey a covered work under sections 3 and 4 of this License
45 | without being bound by section 3 of the GNU GPL.
46 |
47 | 2. Conveying Modified Versions.
48 |
49 | If you modify a copy of the Library, and, in your modifications, a
50 | facility refers to a function or data to be supplied by an Application
51 | that uses the facility (other than as an argument passed when the
52 | facility is invoked), then you may convey a copy of the modified
53 | version:
54 |
55 | a) under this License, provided that you make a good faith effort to
56 | ensure that, in the event an Application does not supply the
57 | function or data, the facility still operates, and performs
58 | whatever part of its purpose remains meaningful, or
59 |
60 | b) under the GNU GPL, with none of the additional permissions of
61 | this License applicable to that copy.
62 |
63 | 3. Object Code Incorporating Material from Library Header Files.
64 |
65 | The object code form of an Application may incorporate material from
66 | a header file that is part of the Library. You may convey such object
67 | code under terms of your choice, provided that, if the incorporated
68 | material is not limited to numerical parameters, data structure
69 | layouts and accessors, or small macros, inline functions and templates
70 | (ten or fewer lines in length), you do both of the following:
71 |
72 | a) Give prominent notice with each copy of the object code that the
73 | Library is used in it and that the Library and its use are
74 | covered by this License.
75 |
76 | b) Accompany the object code with a copy of the GNU GPL and this license
77 | document.
78 |
79 | 4. Combined Works.
80 |
81 | You may convey a Combined Work under terms of your choice that,
82 | taken together, effectively do not restrict modification of the
83 | portions of the Library contained in the Combined Work and reverse
84 | engineering for debugging such modifications, if you also do each of
85 | the following:
86 |
87 | a) Give prominent notice with each copy of the Combined Work that
88 | the Library is used in it and that the Library and its use are
89 | covered by this License.
90 |
91 | b) Accompany the Combined Work with a copy of the GNU GPL and this license
92 | document.
93 |
94 | c) For a Combined Work that displays copyright notices during
95 | execution, include the copyright notice for the Library among
96 | these notices, as well as a reference directing the user to the
97 | copies of the GNU GPL and this license document.
98 |
99 | d) Do one of the following:
100 |
101 | 0) Convey the Minimal Corresponding Source under the terms of this
102 | License, and the Corresponding Application Code in a form
103 | suitable for, and under terms that permit, the user to
104 | recombine or relink the Application with a modified version of
105 | the Linked Version to produce a modified Combined Work, in the
106 | manner specified by section 6 of the GNU GPL for conveying
107 | Corresponding Source.
108 |
109 | 1) Use a suitable shared library mechanism for linking with the
110 | Library. A suitable mechanism is one that (a) uses at run time
111 | a copy of the Library already present on the user's computer
112 | system, and (b) will operate properly with a modified version
113 | of the Library that is interface-compatible with the Linked
114 | Version.
115 |
116 | e) Provide Installation Information, but only if you would otherwise
117 | be required to provide such information under section 6 of the
118 | GNU GPL, and only to the extent that such information is
119 | necessary to install and execute a modified version of the
120 | Combined Work produced by recombining or relinking the
121 | Application with a modified version of the Linked Version. (If
122 | you use option 4d0, the Installation Information must accompany
123 | the Minimal Corresponding Source and Corresponding Application
124 | Code. If you use option 4d1, you must provide the Installation
125 | Information in the manner specified by section 6 of the GNU GPL
126 | for conveying Corresponding Source.)
127 |
128 | 5. Combined Libraries.
129 |
130 | You may place library facilities that are a work based on the
131 | Library side by side in a single library together with other library
132 | facilities that are not Applications and are not covered by this
133 | License, and convey such a combined library under terms of your
134 | choice, if you do both of the following:
135 |
136 | a) Accompany the combined library with a copy of the same work based
137 | on the Library, uncombined with any other library facilities,
138 | conveyed under the terms of this License.
139 |
140 | b) Give prominent notice with the combined library that part of it
141 | is a work based on the Library, and explaining where to find the
142 | accompanying uncombined form of the same work.
143 |
144 | 6. Revised Versions of the GNU Lesser General Public License.
145 |
146 | The Free Software Foundation may publish revised and/or new versions
147 | of the GNU Lesser General Public License from time to time. Such new
148 | versions will be similar in spirit to the present version, but may
149 | differ in detail to address new problems or concerns.
150 |
151 | Each version is given a distinguishing version number. If the
152 | Library as you received it specifies that a certain numbered version
153 | of the GNU Lesser General Public License "or any later version"
154 | applies to it, you have the option of following the terms and
155 | conditions either of that published version or of any later version
156 | published by the Free Software Foundation. If the Library as you
157 | received it does not specify a version number of the GNU Lesser
158 | General Public License, you may choose any version of the GNU Lesser
159 | General Public License ever published by the Free Software Foundation.
160 |
161 | If the Library as you received it specifies that a proxy can decide
162 | whether future versions of the GNU Lesser General Public License shall
163 | apply, that proxy's public statement of acceptance of any version is
164 | permanent authorization for you to choose that version for the
165 | Library.
166 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # PyTorch implementation of [YOLO (You Only Look Once) v2](https://arxiv.org/pdf/1612.08242.pdf)
2 |
3 | YOLOv2 is one of the most popular [one-stage](https://arxiv.org/abs/1708.02002) object detectors.
4 | This project adopts [PyTorch](http://pytorch.org/) as the development framework to increase productivity, and utilizes [ONNX](https://github.com/onnx/onnx) to convert models into [Caffe2](https://caffe2.ai/) for engineering deployment.
5 | If you benefit from this project, a donation would be appreciated (via [PayPal](https://www.paypal.me/minimumshen), [WeChat Pay](donate_mm.jpg) or [Alipay](donate_alipay.jpg)).
6 |
7 | ![](demo.gif)
8 |
9 | ## Designs
10 |
11 | - Flexible configuration design.
12 | Program settings are configurable and can be modified via **config file overlapping** (the -c/--config option) or **command editing** (the -m/--modify option) on the command line (see the example after this list).
13 |
14 | - Monitoring via [TensorBoard](https://github.com/tensorflow/tensorboard).
15 | Recorded data includes the loss values and debugging images (such as the IoU heatmap, and the ground-truth and predicted bounding boxes).
16 |
17 | - Parallel model training design.
18 | Different models are saved into different directories so that they can be trained simultaneously.
19 |
20 | - Using a NoSQL database to store evaluation results with multiple dimensions of information.
21 | This design is useful when analyzing a large number of experiment results.
22 |
23 | - Time-based output design.
24 | Running information (such as the model, the summaries produced by TensorBoard, and the evaluation results) is saved periodically at a predefined time interval.
25 |
26 | - Checkpoint management.
27 | The latest several checkpoint files (.pth) are preserved in the model directory, and the older ones are deleted.
28 |
29 | - NaN debug.
30 | When a NaN loss is detected, the running environment (the data batch) and the model are exported so that the cause can be analyzed.
31 |
32 | - Unified data cache design.
33 | Various datasets are converted into a unified data cache via the corresponding cache plugins.
34 | Several plugins are already implemented, such as [PASCAL VOC](http://host.robots.ox.ac.uk/pascal/VOC/) and [MS COCO](http://cocodataset.org/).
35 |
36 | - Arbitrarily replaceable model plugin design.
37 | The main deep neural network (DNN) can be easily replaced via configuration settings.
38 | Multiple models are already provided, such as Darknet, [ResNet](https://arxiv.org/abs/1512.03385), Inception [v3](https://arxiv.org/abs/1512.00567) and [v4](https://arxiv.org/abs/1602.07261), [MobileNet](https://arxiv.org/abs/1704.04861) and [DenseNet](https://arxiv.org/abs/1608.06993).
39 |
40 | - Extendable data preprocessing plugin design.
41 | The original images (of various sizes) and their labels are processed by a sequence of operations to form a training batch (images of the same size, with the bounding box lists padded).
42 | Multiple preprocessing plugins are already implemented, such as
43 | augmentation operators that process images and labels (such as random rotate and random flip) simultaneously,
44 | operators that resize both images and labels to a fixed size within a batch (such as random crop),
45 | and operators that augment images without labels (such as random blur, random saturation and random brightness).
46 |
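For example, program settings could be combined like this (a sketch: the scripts in this repository all accept `-c`/`-m` via the same `make_args()` pattern, but the exact `-m` key/value syntax is implemented in `utils.modify_config`, which is not shown here, so treat the option strings below as illustrative assumptions):

```
# overlay the base config with a Darknet profile (later files override earlier ones)
python3 detect.py -c config.ini config/darknet/yolo.ini
# additionally edit a single option on the command line
python3 detect.py -c config.ini -m "model/dnn=model.yolo2.Darknet"
```
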
47 | ## Features
48 |
49 | - [x] Reproduce the original paper's training results.
50 | - [x] Multi-scale training.
51 | - [x] Dimension cluster.
52 | - [x] [Darknet](http://pjreddie.com) model file (`.weights`) parser.
53 | - [x] Detection from image and camera.
54 | - [x] Video file processing.
55 | - [x] Multi-GPU support.
56 | - [ ] Distributed training.
57 | - [ ] [Focal loss](https://arxiv.org/abs/1708.02002).
58 | - [x] Channel-wise model parameter analyzer.
59 | - [x] Automatically change the number of channels.
60 | - [x] Receptive field analyzer.
61 |
62 | ## Quick Start
63 |
64 | This project uses [Python 3](https://www.python.org/). To install the dependent libraries, type the following command in a terminal.
65 |
66 | ```
67 | sudo pip3 install -r requirements.txt
68 | ```
69 |
70 | `quick_start.sh` contains examples of performing detection and evaluation; run this script.
71 | Multiple datasets and models (in the original Darknet format, which will be converted into PyTorch's format) will be downloaded ([aria2](https://aria2.github.io/) is required).
72 | These datasets are cached into different data profiles, and the models are evaluated over the cached data.
73 | The models are then used to detect objects in an example image, and the detection results are shown.
74 |
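The individual steps behind `quick_start.sh` look roughly like the following (a sketch; `train.py` and `detect.py` are assumed to follow the same `-c` convention as the argument parsers visible in this repository):

```
python3 cache.py -c config.ini   # build the unified data cache
python3 train.py -c config.ini   # train the configured model
python3 eval.py -c config.ini    # evaluate and record results
python3 detect.py -c config.ini  # detect objects in an example image
```
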
75 | ## License
76 |
77 | This project is released as open-source software under the GNU Lesser General Public License version 3 ([LGPL v3](http://www.gnu.org/licenses/lgpl-3.0.html)).
78 |
--------------------------------------------------------------------------------
/benchmark_caffe2.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see <http://www.gnu.org/licenses/>.
16 | """
17 |
18 | import os
19 | import argparse
20 | import configparser
21 | import logging
22 | import logging.config
23 | import yaml
24 |
25 | import onnx_caffe2.helper
26 |
27 | import utils
28 |
29 |
30 | def main():
31 | args = make_args()
32 | config = configparser.ConfigParser()
33 | utils.load_config(config, args.config)
34 | for cmd in args.modify:
35 | utils.modify_config(config, cmd)
36 | with open(os.path.expanduser(os.path.expandvars(args.logging)), 'r') as f:
37 | logging.config.dictConfig(yaml.load(f))
38 | model_dir = utils.get_model_dir(config)
39 | init_net = onnx_caffe2.helper.load_caffe2_net(os.path.join(model_dir, 'init_net.pb'))
40 | predict_net = onnx_caffe2.helper.load_caffe2_net(os.path.join(model_dir, 'predict_net.pb'))
41 | benchmark = onnx_caffe2.helper.benchmark_caffe2_model(init_net, predict_net)
42 | logging.info('benchmark=%f (milliseconds)' % benchmark)
43 |
44 |
45 | def make_args():
46 | parser = argparse.ArgumentParser()
47 | parser.add_argument('-c', '--config', nargs='+', default=['config.ini'], help='config file')
48 | parser.add_argument('-m', '--modify', nargs='+', default=[], help='modify config')
49 | parser.add_argument('-b', '--benchmark', action='store_true')
50 | parser.add_argument('--logging', default='logging.yml', help='logging config')
51 | return parser.parse_args()
52 |
53 |
54 | if __name__ == '__main__':
55 | main()
56 |
--------------------------------------------------------------------------------
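
Usage sketch for the benchmark script above, assuming `convert_onnx_caffe2.py` has already produced `init_net.pb` and `predict_net.pb` in the model directory:

```
python3 benchmark_caffe2.py -c config.ini config/darknet/yolo.ini
```
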
/cache.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see <http://www.gnu.org/licenses/>.
16 | """
17 |
18 | import os
19 | import argparse
20 | import configparser
21 | import logging
22 | import logging.config
23 | import importlib
24 | import pickle
25 | import random
26 | import shutil
27 | import yaml
28 |
29 | import utils
30 |
31 |
32 | def main():
33 | args = make_args()
34 | config = configparser.ConfigParser()
35 | utils.load_config(config, args.config)
36 | for cmd in args.modify:
37 | utils.modify_config(config, cmd)
38 | with open(os.path.expanduser(os.path.expandvars(args.logging)), 'r') as f:
39 | logging.config.dictConfig(yaml.load(f))
40 | cache_dir = utils.get_cache_dir(config)
41 | os.makedirs(cache_dir, exist_ok=True)
42 | shutil.copyfile(os.path.expanduser(os.path.expandvars(config.get('cache', 'category'))), os.path.join(cache_dir, 'category'))
43 | category = utils.get_category(config)
44 | category_index = dict([(name, i) for i, name in enumerate(category)])
45 | datasets = config.get('cache', 'datasets').split()
46 | for phase in args.phase:
47 | path = os.path.join(cache_dir, phase) + '.pkl'
48 | logging.info('save cache file: ' + path)
49 | data = []
50 | for dataset in datasets:
51 | logging.info('load %s dataset' % dataset)
52 | module, func = dataset.rsplit('.', 1)
53 | module = importlib.import_module(module)
54 | func = getattr(module, func)
55 | data += func(config, path, category_index)
56 | if config.getboolean('cache', 'shuffle'):
57 | random.shuffle(data)
58 | with open(path, 'wb') as f:
59 | pickle.dump(data, f)
60 | logging.info('%s data are saved into %s' % (str(args.phase), cache_dir))
61 |
62 |
63 | def make_args():
64 | parser = argparse.ArgumentParser()
65 | parser.add_argument('-c', '--config', nargs='+', default=['config.ini'], help='config file')
66 | parser.add_argument('-m', '--modify', nargs='+', default=[], help='modify config')
67 | parser.add_argument('-p', '--phase', nargs='+', default=['train', 'val', 'test'])
68 | parser.add_argument('--logging', default='logging.yml', help='logging config')
69 | return parser.parse_args()
70 |
71 |
72 | if __name__ == '__main__':
73 | main()
74 |
--------------------------------------------------------------------------------
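
As `cache.py` above shows, each dataset plugin named in `[cache] datasets` is resolved as `module.function` and called as `func(config, path, category_index)`; it must return a list of per-image dicts whose keys follow `cache/coco.py` and `cache/voc.py` below. A minimal sketch of a custom plugin (the image path and box values are made up for illustration):

```python
import numpy as np


def cache(config, path, category_index):
    """Hypothetical plugin yielding one image with a single 'person' box."""
    yx_min = np.array([[10.0, 20.0]], dtype=np.float32)    # per-box top-left (y, x)
    yx_max = np.array([[110.0, 120.0]], dtype=np.float32)  # per-box bottom-right (y, x)
    cls = np.array([category_index['person']], dtype=int)  # the repository itself uses the older np.int alias
    difficult = np.zeros(cls.shape, dtype=np.uint8)
    return [dict(path='/tmp/example.jpg', yx_min=yx_min, yx_max=yx_max, cls=cls, difficult=difficult)]
```
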
/cache/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruiminshen/yolo2-pytorch/146ebdf581677964caa31c69cccd0c86230fb216/cache/__init__.py
--------------------------------------------------------------------------------
/cache/coco.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see <http://www.gnu.org/licenses/>.
16 | """
17 |
18 | import os
19 | import logging
20 | import configparser
21 |
22 | import numpy as np
23 | import pandas as pd
24 | import tqdm
25 | import pycocotools.coco
26 | import cv2
27 |
28 | import utils.cache
29 |
30 |
31 | def cache(config, path, category_index):
32 | phase = os.path.splitext(os.path.basename(path))[0]
33 | data = []
34 | for i, row in pd.read_csv(os.path.splitext(__file__)[0] + '.tsv', sep='\t').iterrows():
35 | logging.info('loading data %d (%s)' % (i, ', '.join([k + '=' + str(v) for k, v in row.items()])))
36 | root = os.path.expanduser(os.path.expandvars(row['root']))
37 | year = str(row['year'])
38 | suffix = phase + year
39 | path = os.path.join(root, 'annotations', 'instances_%s.json' % suffix)
40 | if not os.path.exists(path):
41 | logging.warning(path + ' not exists')
42 | continue
43 | coco = pycocotools.coco.COCO(path)
44 | catIds = coco.getCatIds(catNms=list(category_index.keys()))
45 | cats = coco.loadCats(catIds)
46 | id_index = dict((cat['id'], category_index[cat['name']]) for cat in cats)
47 | imgIds = coco.getImgIds()
48 | path = os.path.join(root, suffix)
49 | imgs = coco.loadImgs(imgIds)
50 | _imgs = list(filter(lambda img: os.path.exists(os.path.join(path, img['file_name'])), imgs))
51 | if len(imgs) > len(_imgs):
52 | logging.warning('%d of %d images not exists' % (len(imgs) - len(_imgs), len(imgs)))
53 | for img in tqdm.tqdm(_imgs):
54 | annIds = coco.getAnnIds(imgIds=img['id'], catIds=catIds, iscrowd=None)
55 | anns = coco.loadAnns(annIds)
56 | if len(anns) <= 0:
57 | continue
58 | _path = os.path.join(path, img['file_name'])  # keep 'path' pointing at the image directory for later iterations
59 | width, height = img['width'], img['height']
60 | bbox = np.array([ann['bbox'] for ann in anns], dtype=np.float32)  # COCO boxes are [x, y, width, height]
61 | yx_min = bbox[:, 1::-1]  # (y, x) of the top-left corner
62 | hw = bbox[:, -1:1:-1]  # (height, width)
63 | yx_max = yx_min + hw  # (y, x) of the bottom-right corner
64 | cls = np.array([id_index[ann['category_id']] for ann in anns], dtype=np.int)
65 | difficult = np.zeros(cls.shape, dtype=np.uint8)
66 | try:
67 | if config.getboolean('cache', 'verify'):
68 | size = (height, width)
69 | image = cv2.imread(_path)
70 | assert image is not None
71 | assert image.shape[:2] == size[:2]
72 | utils.cache.verify_coords(yx_min, yx_max, size[:2])
73 | except configparser.NoOptionError:
74 | pass
75 | assert len(yx_min) == len(cls)
76 | assert yx_min.shape == yx_max.shape
77 | assert len(yx_min.shape) == 2 and yx_min.shape[-1] == 2
78 | data.append(dict(path=_path, yx_min=yx_min, yx_max=yx_max, cls=cls, difficult=difficult))
79 | logging.info('%d of %d images are saved' % (len(data), len(_imgs)))
80 | return data
81 |
--------------------------------------------------------------------------------
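
A quick worked example of the box conversion above: COCO stores a box as [x, y, width, height], and the two reversed slices turn it into the (y, x) corner pair this project uses.

```python
import numpy as np

bbox = np.array([[30.0, 10.0, 50.0, 20.0]], dtype=np.float32)  # one COCO box: [x, y, w, h]
yx_min = bbox[:, 1::-1]  # -> [[10., 30.]]  (y, x) of the top-left corner
hw = bbox[:, -1:1:-1]    # -> [[20., 50.]]  (h, w)
yx_max = yx_min + hw     # -> [[30., 80.]]  (y, x) of the bottom-right corner
print(yx_min, yx_max)
```
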
/cache/coco.tsv:
--------------------------------------------------------------------------------
1 | root year
2 | ~/data/coco 2014
3 | ~/data/coco 2017
4 |
--------------------------------------------------------------------------------
/cache/voc.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see <http://www.gnu.org/licenses/>.
16 | """
17 |
18 | import os
19 | import logging
20 | import configparser
21 |
22 | import numpy as np
23 | import tqdm
24 | import xml.etree.ElementTree
25 | import cv2
26 |
27 | import utils.cache
28 |
29 |
30 | def load_annotation(path, category_index):
31 | tree = xml.etree.ElementTree.parse(path)
32 | yx_min = []
33 | yx_max = []
34 | cls = []
35 | difficult = []
36 | for obj in tree.findall('object'):
37 | try:
38 | cls.append(category_index[obj.find('name').text])
39 | except KeyError:
40 | continue
41 | bbox = obj.find('bndbox')
42 | ymin = float(bbox.find('ymin').text) - 1
43 | xmin = float(bbox.find('xmin').text) - 1
44 | ymax = float(bbox.find('ymax').text) - 1
45 | xmax = float(bbox.find('xmax').text) - 1
46 | assert ymin < ymax
47 | assert xmin < xmax
48 | yx_min.append((ymin, xmin))
49 | yx_max.append((ymax, xmax))
50 | difficult.append(int(obj.find('difficult').text))
51 | size = tree.find('size')
52 | return tree.find('filename').text, (int(size.find('height').text), int(size.find('width').text), int(size.find('depth').text)), yx_min, yx_max, cls, difficult
53 |
54 |
55 | def load_root():
56 | with open(os.path.splitext(__file__)[0] + '.txt', 'r') as f:
57 | return [line.rstrip() for line in f]
58 |
59 |
60 | def cache(config, path, category_index, root=load_root()):
61 | phase = os.path.splitext(os.path.basename(path))[0]
62 | data = []
63 | for root in root:
64 | logging.info('loading ' + root)
65 | root = os.path.expanduser(os.path.expandvars(root))
66 | path = os.path.join(root, 'ImageSets', 'Main', phase) + '.txt'
67 | if not os.path.exists(path):
68 | logging.warning(path + ' not exists')
69 | continue
70 | with open(path, 'r') as f:
71 | filenames = [line.strip() for line in f]
72 | for filename in tqdm.tqdm(filenames):
73 | filename, size, yx_min, yx_max, cls, difficult = load_annotation(os.path.join(root, 'Annotations', filename + '.xml'), category_index)
74 | if len(cls) <= 0:
75 | continue
76 | path = os.path.join(root, 'JPEGImages', filename)
77 | yx_min = np.array(yx_min, dtype=np.float32)
78 | yx_max = np.array(yx_max, dtype=np.float32)
79 | cls = np.array(cls, dtype=np.int)
80 | difficult = np.array(difficult, dtype=np.uint8)
81 | assert len(yx_min) == len(cls)
82 | assert yx_min.shape == yx_max.shape
83 | assert len(yx_min.shape) == 2 and yx_min.shape[-1] == 2
84 | try:
85 | if config.getboolean('cache', 'verify'):
86 | try:
87 | image = cv2.imread(path)
88 | assert image is not None
89 | assert image.shape[:2] == size[:2]
90 | utils.cache.verify_coords(yx_min, yx_max, size[:2])
91 | except AssertionError as e:
92 | logging.error(path + ': ' + str(e))
93 | continue
94 | except configparser.NoOptionError:
95 | pass
96 | data.append(dict(path=path, yx_min=yx_min, yx_max=yx_max, cls=cls, difficult=difficult))
97 | logging.info('%d of %d images are saved' % (len(data), len(filenames)))
98 | return data
99 |
--------------------------------------------------------------------------------
/cache/voc.txt:
--------------------------------------------------------------------------------
1 | ~/data/VOCdevkit/VOC2007
2 | ~/data/VOCdevkit/VOC2012
3 |
--------------------------------------------------------------------------------
/checksum_caffe2.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see <http://www.gnu.org/licenses/>.
16 | """
17 |
18 | import os
19 | import argparse
20 | import configparser
21 | import logging
22 | import logging.config
23 | import hashlib
24 | import yaml
25 |
26 | import torch
27 | from caffe2.proto import caffe2_pb2
28 | from caffe2.python import workspace
29 |
30 | import utils
31 |
32 |
33 | def main():
34 | args = make_args()
35 | config = configparser.ConfigParser()
36 | utils.load_config(config, args.config)
37 | for cmd in args.modify:
38 | utils.modify_config(config, cmd)
39 | with open(os.path.expanduser(os.path.expandvars(args.logging)), 'r') as f:
40 | logging.config.dictConfig(yaml.load(f))
41 | torch.manual_seed(args.seed)
42 | model_dir = utils.get_model_dir(config)
43 | init_net = caffe2_pb2.NetDef()
44 | with open(os.path.join(model_dir, 'init_net.pb'), 'rb') as f:
45 | init_net.ParseFromString(f.read())
46 | predict_net = caffe2_pb2.NetDef()
47 | with open(os.path.join(model_dir, 'predict_net.pb'), 'rb') as f:
48 | predict_net.ParseFromString(f.read())
49 | p = workspace.Predictor(init_net, predict_net)
50 | height, width = tuple(map(int, config.get('image', 'size').split()))
51 | tensor = torch.randn(1, 3, height, width)
52 | # Checksum
53 | output = p.run([tensor.numpy()])
54 | for key, a in [
55 | ('tensor', tensor.cpu().numpy()),
56 | ('output', output[0]),
57 | ]:
58 | print('\t'.join(map(str, [key, a.shape, utils.abs_mean(a), hashlib.md5(a.tostring()).hexdigest()])))
59 |
60 |
61 | def make_args():
62 | parser = argparse.ArgumentParser()
63 | parser.add_argument('-c', '--config', nargs='+', default=['config.ini'], help='config file')
64 | parser.add_argument('-m', '--modify', nargs='+', default=[], help='modify config')
65 | parser.add_argument('--logging', default='logging.yml', help='logging config')
66 | parser.add_argument('-s', '--seed', default=0, type=int, help='a seed to create a random image tensor')
67 | return parser.parse_args()
68 |
69 |
70 | if __name__ == '__main__':
71 | main()
72 |
--------------------------------------------------------------------------------
/checksum_torch.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see <http://www.gnu.org/licenses/>.
16 | """
17 |
18 | import os
19 | import argparse
20 | import configparser
21 | import logging
22 | import logging.config
23 | import hashlib
24 | import yaml
25 |
26 | import torch
27 | import torch.autograd
28 | import cv2
29 |
30 | import utils
31 | import utils.train
32 | import model
33 | import transform
34 |
35 |
36 | def main():
37 | args = make_args()
38 | config = configparser.ConfigParser()
39 | utils.load_config(config, args.config)
40 | for cmd in args.modify:
41 | utils.modify_config(config, cmd)
42 | with open(os.path.expanduser(os.path.expandvars(args.logging)), 'r') as f:
43 | logging.config.dictConfig(yaml.load(f))
44 | torch.manual_seed(args.seed)
45 | cache_dir = utils.get_cache_dir(config)
46 | model_dir = utils.get_model_dir(config)
47 | category = utils.get_category(config, cache_dir if os.path.exists(cache_dir) else None)
48 | anchors = utils.get_anchors(config)
49 | anchors = torch.from_numpy(anchors).contiguous()
50 | path, step, epoch = utils.train.load_model(model_dir)
51 | state_dict = torch.load(path, map_location=lambda storage, loc: storage)
52 | dnn = utils.parse_attr(config.get('model', 'dnn'))(model.ConfigChannels(config, state_dict), anchors, len(category))
53 | dnn.load_state_dict(state_dict)
54 | height, width = tuple(map(int, config.get('image', 'size').split()))
55 | tensor = torch.randn(1, 3, height, width)
56 | # Checksum
57 | for key, var in dnn.state_dict().items():
58 | a = var.cpu().numpy()
59 | print('\t'.join(map(str, [key, a.shape, utils.abs_mean(a), hashlib.md5(a.tostring()).hexdigest()])))
60 | output = dnn(torch.autograd.Variable(tensor, volatile=True)).data  # pre-0.4 PyTorch: volatile marks pure inference
61 | for key, a in [
62 | ('tensor', tensor.cpu().numpy()),
63 | ('output', output.cpu().numpy()),
64 | ]:
65 | print('\t'.join(map(str, [key, a.shape, utils.abs_mean(a), hashlib.md5(a.tostring()).hexdigest()])))
66 |
67 |
68 | def make_args():
69 | parser = argparse.ArgumentParser()
70 | parser.add_argument('-c', '--config', nargs='+', default=['config.ini'], help='config file')
71 | parser.add_argument('-m', '--modify', nargs='+', default=[], help='modify config')
72 | parser.add_argument('--logging', default='logging.yml', help='logging config')
73 | parser.add_argument('-s', '--seed', default=0, type=int, help='a seed to create a random image tensor')
74 | return parser.parse_args()
75 |
76 |
77 | if __name__ == '__main__':
78 | main()
79 |
--------------------------------------------------------------------------------
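
Both checksum scripts print the same fingerprint per tensor: name, shape, mean absolute value, and the MD5 digest of the raw bytes. For the same `--seed`, the `tensor` lines of `checksum_torch.py` and `checksum_caffe2.py` should match exactly; the `output` digests match only when the two backends compute bitwise-identical results, so the mean column is the more forgiving comparison. A standalone sketch of the fingerprint (assuming `utils.abs_mean` is a plain mean of absolute values; the scripts use the older `tostring()` alias of `tobytes()`):

```python
import hashlib

import numpy as np


def fingerprint(key, a):
    # name, shape, mean(|a|) and an MD5 digest of the raw array bytes
    return '\t'.join(map(str, [key, a.shape, np.abs(a).mean(), hashlib.md5(a.tobytes()).hexdigest()]))


a = np.random.RandomState(0).randn(1, 3, 416, 416).astype(np.float32)
print(fingerprint('tensor', a))
```
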
/config.ini:
--------------------------------------------------------------------------------
1 | [config]
2 | root = ~/model/yolo2-pytorch
3 |
4 | [image]
5 | size = 416 416
6 |
7 | [cache]
8 | name = cache
9 | category = config/category/20
10 | # voc coco
11 | datasets = cache.voc.cache cache.coco.cache
12 | shuffle = 1
13 |
14 | [model]
15 | name = model
16 | anchors = config/anchors/voc.tsv
17 | ; model.yolo2.Darknet
18 | ; model.yolo2.Tiny
19 | ; model.resnet.resnet18
20 | ; model.inception3.Inception3
21 | ; model.inception4.Inception4
22 | ; model.mobilenet.MobileNet
23 | ; model.densenet.densenet121
24 | ; model.vgg.vgg19
25 | dnn = model.yolo2.Tiny
26 | pretrained = 0
27 | threshold = 0.6
28 |
29 | [batch_norm]
30 | enable = 1
31 | gamma = 1
32 | beta = 1
33 |
34 | [inception4]
35 | pretrained = imagenet
36 |
37 | [data]
38 | workers = 3
39 | sizes = 320,320 352,352 384,384 416,416 448,448 480,480 512,512 544,544 576,576 608,608
40 | maintain = 10
41 | shuffle = 0
42 | # rescale padding
43 | resize = rescale
44 |
45 | [transform]
46 | ; transform.augmentation.RandomRotate transform.augmentation.RandomFlipHorizontally
47 | augmentation = transform.augmentation.RandomRotate transform.augmentation.RandomFlipHorizontally
48 | resize_train = transform.resize.label.RandomCrop
49 | resize_eval = transform.resize.label.Resize
50 | resize_test = transform.resize.image.Resize
51 | ; transform.image.RandomBlur transform.image.BGR2HSV transform.image.RandomHue transform.image.RandomSaturation transform.image.RandomBrightness transform.image.HSV2RGB transform.image.RandomGamma
52 | image_train = transform.image.BGR2RGB
53 | image_test = transform.image.BGR2RGB
54 | ; torchvision.transforms.ToTensor transform.image.Normalize
55 | tensor = torchvision.transforms.ToTensor transform.image.Normalize
56 | normalize = 0.5 1
57 |
58 | [augmentation]
59 | random_rotate = -5 5
60 | random_flip_horizontally = 0.5
61 | random_crop = 1
62 | random_blur = 5 5
63 | random_hue = 0 25
64 | random_saturation = 0.5 1.5
65 | random_brightness = 0.5 1.5
66 | random_gamma = 0.9 1.5
67 |
68 | [train]
69 | ; lambda params, lr: torch.optim.SGD(params, lr, momentum=2)
70 | ; lambda params, lr: torch.optim.Adam(params, lr, betas=(0.9, 0.999), eps=1e-8)
71 | ; lambda params, lr: torch.optim.RMSprop(params, lr, alpha=0.99, eps=1e-8)
72 | optimizer = lambda params, lr: torch.optim.Adam(params, lr, betas=(0.9, 0.999), eps=1e-8)
73 | ; lambda optimizer: torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
74 | ; lambda optimizer: torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[60, 90], gamma=0.1)
75 | scheduler = lambda optimizer: torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[60, 90], gamma=0.1)
76 | phase = train val
77 | cross_entropy = 1
78 | clip_ = 5
79 |
80 | [save]
81 | secs = 600
82 | keep = 5
83 |
84 | [summary]
85 | scalar = 10
86 | image = 60
87 | histogram_ = 60
88 |
89 | [summary_scalar]
90 | loss_hparam = 1
91 |
92 | [summary_image]
93 | limit = 2
94 | bbox = 1
95 | iou = 1
96 |
97 | [summary_histogram]
98 | parameters = config/summary/histogram.txt
99 |
100 | [hparam]
101 | foreground = 5
102 | background = 1
103 | center = 1
104 | size = 1
105 | cls = 1
106 |
107 | [detect]
108 | threshold = 0.3
109 | threshold_cls = 0.005
110 | fix = 0
111 | overlap = 0.45
112 |
113 | [eval]
114 | phase = test
115 | secs = 12 * 60 * 60
116 | first = 0
117 | iou = 0.5
118 | db = eval.json
119 | mapper = config/eval.py
120 | debug = 0
121 | sort = timestamp
122 | metric07 = 1
123 |
124 | [graph]
125 | metric = lambda t: np.mean(utils.dense(t))
126 | format = svg
127 |
128 | [digraph_graph_attr]
129 | size = 12, 12
130 |
131 | [digraph_node_attr]
132 | style = filled
133 | shape = box
134 | align = left
135 | fontsize = 12
136 | ranksep = 0.1
137 | height = 0.2
138 |
--------------------------------------------------------------------------------
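
The profile files under `config/` work by overlaying `config.ini`: later files passed to `-c` override matching options section by section. `utils.load_config` is not included in this listing, but plain `configparser` reproduces the overlay semantics:

```python
import configparser

config = configparser.ConfigParser()
# later files win: the Darknet profile overrides [cache], [model] and [detect] options
config.read(['config.ini', 'config/darknet/yolo.ini'])
print(config.get('model', 'dnn'))      # model.yolo2.Darknet (from the profile)
print(config.get('model', 'anchors'))  # config/anchors/yolo.tsv (from the profile)
print(config.get('train', 'phase'))    # train val (still from config.ini)
```
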
/config/anchors/coco.tsv:
--------------------------------------------------------------------------------
1 | width height
2 | 0.57273 0.677385
3 | 1.87446 2.06253
4 | 3.33843 5.47434
5 | 7.88282 3.52778
6 | 9.77052 9.16828
7 |
--------------------------------------------------------------------------------
/config/anchors/tiny-yolo-voc.tsv:
--------------------------------------------------------------------------------
1 | width height
2 | 1.08 1.19
3 | 3.42 4.41
4 | 6.63 11.38
5 | 9.42 5.11
6 | 16.62 10.52
7 |
--------------------------------------------------------------------------------
/config/anchors/voc.tsv:
--------------------------------------------------------------------------------
1 | width height
2 | 1.08 1.19
3 | 3.42 4.41
4 | 6.63 11.38
5 | 9.42 5.11
6 | 16.62 10.52
7 |
--------------------------------------------------------------------------------
/config/anchors/yolo-voc.tsv:
--------------------------------------------------------------------------------
1 | width height
2 | 1.3221 1.73145
3 | 3.19275 4.00944
4 | 5.05587 8.09892
5 | 9.47112 4.84053
6 | 11.2364 10.0071
7 |
--------------------------------------------------------------------------------
/config/anchors/yolo.tsv:
--------------------------------------------------------------------------------
1 | width height
2 | 0.57273 0.677385
3 | 1.87446 2.06253
4 | 3.33843 5.47434
5 | 7.88282 3.52778
6 | 9.77052 9.16828
7 |
--------------------------------------------------------------------------------
/config/category/20:
--------------------------------------------------------------------------------
1 | aeroplane
2 | bicycle
3 | bird
4 | boat
5 | bottle
6 | bus
7 | car
8 | cat
9 | chair
10 | cow
11 | diningtable
12 | dog
13 | horse
14 | motorbike
15 | person
16 | pottedplant
17 | sheep
18 | sofa
19 | train
20 | tvmonitor
--------------------------------------------------------------------------------
/config/category/80:
--------------------------------------------------------------------------------
1 | person
2 | bicycle
3 | car
4 | motorbike
5 | aeroplane
6 | bus
7 | train
8 | truck
9 | boat
10 | traffic light
11 | fire hydrant
12 | stop sign
13 | parking meter
14 | bench
15 | bird
16 | cat
17 | dog
18 | horse
19 | sheep
20 | cow
21 | elephant
22 | bear
23 | zebra
24 | giraffe
25 | backpack
26 | umbrella
27 | handbag
28 | tie
29 | suitcase
30 | frisbee
31 | skis
32 | snowboard
33 | sports ball
34 | kite
35 | baseball bat
36 | baseball glove
37 | skateboard
38 | surfboard
39 | tennis racket
40 | bottle
41 | wine glass
42 | cup
43 | fork
44 | knife
45 | spoon
46 | bowl
47 | banana
48 | apple
49 | sandwich
50 | orange
51 | broccoli
52 | carrot
53 | hot dog
54 | pizza
55 | donut
56 | cake
57 | chair
58 | sofa
59 | pottedplant
60 | bed
61 | diningtable
62 | toilet
63 | tvmonitor
64 | laptop
65 | mouse
66 | remote
67 | keyboard
68 | cell phone
69 | microwave
70 | oven
71 | toaster
72 | sink
73 | refrigerator
74 | book
75 | clock
76 | vase
77 | scissors
78 | teddy bear
79 | hair drier
80 | toothbrush
--------------------------------------------------------------------------------
/config/category/person:
--------------------------------------------------------------------------------
1 | person
--------------------------------------------------------------------------------
/config/darknet.ini:
--------------------------------------------------------------------------------
1 | [data]
2 | sizes = 416,416 448,448 480,480 512,512
3 | maintain = 10
4 |
5 | [transform]
6 | tensor = torchvision.transforms.ToTensor
7 |
--------------------------------------------------------------------------------
/config/darknet/tiny-yolo-voc.ini:
--------------------------------------------------------------------------------
1 | [image]
2 | size = 416 416
3 |
4 | [cache]
5 | name = cache_voc
6 | category = config/category/20
7 | datasets = cache.voc.cache
8 |
9 | [transform]
10 | tensor = torchvision.transforms.ToTensor
11 |
12 | [model]
13 | name = model_voc
14 | anchors = config/anchors/tiny-yolo-voc.tsv
15 | dnn = model.yolo2.Tiny
16 |
17 | [detect]
18 | fix = 1
19 |
--------------------------------------------------------------------------------
/config/darknet/yolo-voc.ini:
--------------------------------------------------------------------------------
1 | [image]
2 | size = 416 416
3 |
4 | [cache]
5 | name = cache_voc
6 | category = config/category/20
7 | datasets = cache.voc.cache cache.coco.cache
8 |
9 | [transform]
10 | tensor = torchvision.transforms.ToTensor
11 |
12 | [model]
13 | name = model_voc
14 | anchors = config/anchors/yolo-voc.tsv
15 | dnn = model.yolo2.Darknet
16 |
17 | [detect]
18 | fix = 1
19 |
--------------------------------------------------------------------------------
/config/darknet/yolo.ini:
--------------------------------------------------------------------------------
1 | [image]
2 | size = 416 416
3 |
4 | [cache]
5 | name = cache_coco
6 | category = config/category/80
7 | datasets = cache.coco.cache
8 |
9 | [transform]
10 | tensor = torchvision.transforms.ToTensor
11 |
12 | [model]
13 | name = model_coco
14 | anchors = config/anchors/yolo.tsv
15 | dnn = model.yolo2.Darknet
16 |
17 | [detect]
18 | fix = 1
19 |
--------------------------------------------------------------------------------
/config/debug.ini:
--------------------------------------------------------------------------------
1 | [data]
2 | sizes = 416,416
3 |
4 | [transform]
5 | augmentation =
6 | resize_train = transform.resize.label.Resize
7 | image_train = transform.image.BGR2RGB
8 | tensor = torchvision.transforms.ToTensor
9 |
--------------------------------------------------------------------------------
/config/eval.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see <http://www.gnu.org/licenses/>.
16 | """
17 |
18 | import os
19 | import configparser
20 |
21 | import numpy as np
22 | import humanize
23 | import pybenchmark
24 |
25 |
26 | class Timestamp(object):
27 | def __call__(self, env, **kwargs):
28 | return float(env.now.timestamp())
29 |
30 |
31 | class Time(object):
32 | def __call__(self, env, **kwargs):
33 | return env.now.strftime('%Y-%m-%d %H:%M:%S')
34 |
35 | def get_format(self, workbook, worksheet):
36 | return workbook.add_format({'num_format': 'yyyy-mm-dd hh:mm:ss'})
37 |
38 |
39 | class Step(object):
40 | def __call__(self, env, **kwargs):
41 | return env.step
42 |
43 |
44 | class Epoch(object):
45 | def __call__(self, env, **kwargs):
46 | return env.epoch
47 |
48 |
49 | class Model(object):
50 | def __call__(self, env, **kwargs):
51 | return env.config.get('model', 'dnn')
52 |
53 |
54 | class SizeDnn(object):
55 | def __call__(self, env, **kwargs):
56 | return sum(var.cpu().numpy().nbytes for var in env.inference.state_dict().values())
57 |
58 | def format(self, workbook, worksheet, num, col):
59 | worksheet.conditional_format(1, col, num + 1, col, {'type': 'data_bar', 'bar_color': '#FFC7CE'})
60 |
61 |
62 | class SizeDnnNature(object):
63 | def __call__(self, env, **kwargs):
64 | return humanize.naturalsize(sum(var.cpu().numpy().nbytes for var in env.inference.state_dict().values()))
65 |
66 |
67 | class TimeInference(object):
68 | def __call__(self, env, **kwargs):
69 | return pybenchmark.stats['inference']['time']
70 |
71 | def format(self, workbook, worksheet, num, col):
72 | worksheet.conditional_format(1, col, num + 1, col, {'type': 'data_bar', 'bar_color': '#FFC7CE'})
73 |
74 |
75 | class Root(object):
76 | def __call__(self, env, **kwargs):
77 | return os.path.basename(env.config.get('config', 'root'))
78 |
79 |
80 | class CacheName(object):
81 | def __call__(self, env, **kwargs):
82 | return env.config.get('cache', 'name')
83 |
84 |
85 | class ModelName(object):
86 | def __call__(self, env, **kwargs):
87 | return env.config.get('model', 'name')
88 |
89 |
90 | class Category(object):
91 | def __call__(self, env, **kwargs):
92 | return env.config.get('cache', 'category')
93 |
94 |
95 | class DatasetSize(object):
96 | def __call__(self, env, **kwargs):
97 | return len(env.loader.dataset)
98 |
99 | def format(self, workbook, worksheet, num, col):
100 | worksheet.conditional_format(1, col, num + 1, col, {'type': 'data_bar', 'bar_color': '#FFC7CE'})
101 |
102 |
103 | class DetectThreshold(object):
104 | def __call__(self, env, **kwargs):
105 | return env.config.getfloat('detect', 'threshold')
106 |
107 | def format(self, workbook, worksheet, num, col):
108 | worksheet.conditional_format(1, col, num + 1, col, {'type': 'data_bar', 'bar_color': '#FFC7CE'})
109 |
110 |
111 | class DetectThresholdCls(object):
112 | def __call__(self, env, **kwargs):
113 | return env.config.getfloat('detect', 'threshold_cls')
114 |
115 | def format(self, workbook, worksheet, num, col):
116 | worksheet.conditional_format(1, col, num + 1, col, {'type': 'data_bar', 'bar_color': '#FFC7CE'})
117 |
118 |
119 | class DetectFix(object):
120 | def __call__(self, env, **kwargs):
121 | return env.config.getboolean('detect', 'fix')
122 |
123 | def format(self, workbook, worksheet, num, col):
124 | format_green = workbook.add_format({'bg_color': '#C6EFCE', 'font_color': '#006100'})
125 | format_red = workbook.add_format({'bg_color': '#FFC7CE', 'font_color': '#9C0006'})
126 | worksheet.conditional_format(1, col, num + 1, col, {'type': 'cell', 'criteria': '==', 'value': '1', 'format': format_green})
127 | worksheet.conditional_format(1, col, num + 1, col, {'type': 'cell', 'criteria': '<>', 'value': '1', 'format': format_red})
128 |
129 |
130 | class DetectOverlap(object):
131 | def __call__(self, env, **kwargs):
132 | return env.config.getfloat('detect', 'overlap')
133 |
134 | def format(self, workbook, worksheet, num, col):
135 | worksheet.conditional_format(1, col, num + 1, col, {'type': 'data_bar', 'bar_color': '#FFC7CE'})
136 |
137 |
138 | class EvalIou(object):
139 | def __call__(self, env, **kwargs):
140 | return env.config.getfloat('eval', 'iou')
141 |
142 | def format(self, workbook, worksheet, num, col):
143 | worksheet.conditional_format(1, col, num + 1, col, {'type': 'data_bar', 'bar_color': '#FFC7CE'})
144 |
145 |
146 | class EvalMeanAp(object):
147 | def __call__(self, env, **kwargs):
148 | return np.mean(list(kwargs['cls_ap'].values()))
149 |
150 | def format(self, workbook, worksheet, num, col):
151 | worksheet.conditional_format(1, col, num + 1, col, {'type': 'data_bar', 'bar_color': '#FFC7CE'})
152 |
153 |
154 | class EvalAp(object):
155 | def __call__(self, env, **kwargs):
156 | cls_ap = kwargs['cls_ap']
157 | return ', '.join(['%s=%f' % (env.category[c], cls_ap[c]) for c in sorted(cls_ap.keys())])
158 |
159 |
160 | class Hparam(object):
161 | def __call__(self, env, **kwargs):
162 | try:
163 | return ', '.join([option + '=' + value for option, value in env._config.items('hparam')])
164 | except AttributeError:
165 | return None
166 |
167 |
168 | class Optimizer(object):
169 | def __call__(self, env, **kwargs):
170 | try:
171 | return env._config.get('train', 'optimizer')
172 | except (AttributeError, configparser.NoOptionError):
173 | return None
174 |
175 |
176 | class Scheduler(object):
177 | def __call__(self, env, **kwargs):
178 | try:
179 | return env._config.get('train', 'scheduler')
180 | except (AttributeError, configparser.NoOptionError):
181 | return None
182 |
--------------------------------------------------------------------------------
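
Each class in `config/eval.py` acts as one column of the evaluation database: `eval.py` calls it with the running environment plus keyword extras (for example `cls_ap`), and the optional `format`/`get_format` hooks appear to style an exported spreadsheet (the calls follow the xlsxwriter API). Adding a column is therefore just another callable class; a minimal sketch (the class name and the option it reads are illustrative):

```python
class DataResize(object):
    """Hypothetical extra column recording the [data] resize strategy."""

    def __call__(self, env, **kwargs):
        return env.config.get('data', 'resize')
```
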
/config/summary/histogram.txt:
--------------------------------------------------------------------------------
1 | .+\.bn\.weight$
2 |
--------------------------------------------------------------------------------
/convert_darknet_torch.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see <http://www.gnu.org/licenses/>.
16 | """
17 |
18 | import os
19 | import argparse
20 | import configparser
21 | import logging
22 | import logging.config
23 | import struct
24 | import collections
25 | import shutil
26 | import hashlib
27 | import yaml
28 |
29 | import numpy as np
30 | import torch
31 | import humanize
32 |
33 | import model
34 | import utils.train
35 |
36 |
37 | def transpose_weight(weight, num_anchors):
38 | _, channels_in, ksize1, ksize2 = weight.size()
39 | weight = weight.view(num_anchors, -1, channels_in, ksize1, ksize2)  # split the output channels per anchor
40 | x = weight[:, 0:1, :, :, :]
41 | y = weight[:, 1:2, :, :, :]
42 | w = weight[:, 2:3, :, :, :]
43 | h = weight[:, 3:4, :, :, :]
44 | iou = weight[:, 4:5, :, :, :]
45 | cls = weight[:, 5:, :, :, :]
46 | return torch.cat([iou, y, x, h, w, cls], 1).view(-1, channels_in, ksize1, ksize2)  # reorder Darknet's per-anchor (x, y, w, h, iou, cls) into this project's (iou, y, x, h, w, cls)
47 |
48 |
49 | def transpose_bias(bias, num_anchors):
50 | bias = bias.view([num_anchors, -1])
51 | x = bias[:, 0:1]
52 | y = bias[:, 1:2]
53 | w = bias[:, 2:3]
54 | h = bias[:, 3:4]
55 | iou = bias[:, 4:5]
56 | cls = bias[:, 5:]
57 | return torch.cat([iou, y, x, h, w, cls], 1).view(-1)  # same reordering for the bias
58 |
59 |
60 | def group_state(state_dict):
61 | grouped_dict = collections.OrderedDict()
62 | for key, var in state_dict.items():
63 | layer, suffix1, suffix2 = key.rsplit('.', 2)
64 | suffix = suffix1 + '.' + suffix2
65 | if layer in grouped_dict:
66 | grouped_dict[layer][suffix] = var
67 | else:
68 | grouped_dict[layer] = {suffix: var}
69 | return grouped_dict
70 |
71 |
72 | def main():
73 | args = make_args()
74 | config = configparser.ConfigParser()
75 | utils.load_config(config, args.config)
76 | for cmd in args.modify:
77 | utils.modify_config(config, cmd)
78 | with open(os.path.expanduser(os.path.expandvars(args.logging)), 'r') as f:
79 | logging.config.dictConfig(yaml.load(f))
80 | cache_dir = utils.get_cache_dir(config)
81 | model_dir = utils.get_model_dir(config)
82 | category = utils.get_category(config, cache_dir if os.path.exists(cache_dir) else None)
83 | anchors = utils.get_anchors(config)
84 | anchors = torch.from_numpy(anchors).contiguous()
85 | dnn = utils.parse_attr(config.get('model', 'dnn'))(model.ConfigChannels(config), anchors, len(category))
86 | dnn.eval()
87 | logging.info(humanize.naturalsize(sum(var.cpu().numpy().nbytes for var in dnn.state_dict().values())))
88 | state_dict = dnn.state_dict()
89 | grouped_dict = group_state(state_dict)
90 | try:
91 | layers = []
92 | with open(os.path.expanduser(os.path.expandvars(args.file)), 'rb') as f:
93 | major, minor, revision, seen = struct.unpack('4i', f.read(16))
94 | logging.info('major=%d, minor=%d, revision=%d, seen=%d' % (major, minor, revision, seen))
95 | total = 0
96 | filesize = os.fstat(f.fileno()).st_size
97 | for layer in grouped_dict:
98 | group = grouped_dict[layer]
99 | for suffix in ['conv.bias', 'bn.bias', 'bn.weight', 'bn.running_mean', 'bn.running_var', 'conv.weight']:
100 | if suffix in group:
101 | var = group[suffix]
102 | size = var.size()
103 | cnt = np.multiply.reduce(size)
104 | total += cnt
105 | key = layer + '.' + suffix
106 | val = np.array(struct.unpack('%df' % cnt, f.read(cnt * 4)), np.float32)
107 | val = np.reshape(val, size)
108 | remaining = filesize - f.tell()
109 | logging.info('%s.%s: %s=%f (%s), remaining=%d' % (layer, suffix, 'x'.join(list(map(str, size))), utils.abs_mean(val), hashlib.md5(val.tostring()).hexdigest(), remaining))
110 | layers.append([key, torch.from_numpy(val)])
111 | logging.info('%d parameters assigned' % total)
112 | layers[-1][1] = transpose_weight(layers[-1][1], len(anchors))
113 | layers[-2][1] = transpose_bias(layers[-2][1], len(anchors))
114 | finally:
115 | if 'remaining' in locals() and remaining > 0:  # 'remaining' is unset if reading fails before the first parameter
116 | logging.warning('%d bytes remaining' % remaining)
117 | state_dict = collections.OrderedDict(layers)
118 | if args.delete:
119 | logging.warning('delete model directory: ' + model_dir)
120 | shutil.rmtree(model_dir, ignore_errors=True)
121 | saver = utils.train.Saver(model_dir, config.getint('save', 'keep'), logger=None)
122 | path = saver(state_dict, 0, 0) + saver.ext
123 | if args.copy is not None:
124 | _path = os.path.expandvars(os.path.expanduser(args.copy))
125 | logging.info('copy %s to %s' % (path, _path))
126 | shutil.copy(path, _path)
127 |
128 |
129 | def make_args():
130 | parser = argparse.ArgumentParser()
131 | parser.add_argument('file', help='Darknet .weights file')
132 | parser.add_argument('-c', '--config', nargs='+', default=['config.ini'], help='config file')
133 | parser.add_argument('-m', '--modify', nargs='+', default=[], help='modify config')
134 | parser.add_argument('-d', '--delete', action='store_true', help='delete logdir')
135 | parser.add_argument('--copy', help='copy model')
136 | parser.add_argument('--logging', default='logging.yml', help='logging config')
137 | return parser.parse_args()
138 |
139 |
140 | if __name__ == '__main__':
141 | main()
142 |
--------------------------------------------------------------------------------
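
A shape-only sanity check of the reordering helpers above (a sketch; it assumes this file is importable as `convert_darknet_torch`):

```python
import torch

from convert_darknet_torch import transpose_weight, transpose_bias

num_anchors, num_cls = 5, 20
channels_out = num_anchors * (5 + num_cls)  # x, y, w, h, iou plus class scores per anchor
weight = torch.randn(channels_out, 1024, 1, 1)
bias = torch.randn(channels_out)
assert transpose_weight(weight, num_anchors).size() == weight.size()
assert transpose_bias(bias, num_anchors).size() == bias.size()
```
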
/convert_onnx_caffe2.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see <http://www.gnu.org/licenses/>.
16 | """
17 |
18 | import os
19 | import argparse
20 | import configparser
21 | import logging
22 | import logging.config
23 | import yaml
24 |
25 | import onnx
26 | import onnx_caffe2.backend
27 | import onnx_caffe2.helper
28 |
29 | import utils
30 |
31 |
32 | def main():
33 | args = make_args()
34 | config = configparser.ConfigParser()
35 | utils.load_config(config, args.config)
36 | for cmd in args.modify:
37 | utils.modify_config(config, cmd)
38 | with open(os.path.expanduser(os.path.expandvars(args.logging)), 'r') as f:
39 | logging.config.dictConfig(yaml.load(f))
40 | model_dir = utils.get_model_dir(config)
41 | model = onnx.load(model_dir + '.onnx')
42 | onnx.checker.check_model(model)
43 | init_net, predict_net = onnx_caffe2.backend.Caffe2Backend.onnx_graph_to_caffe2_net(model.graph, device='CPU')
44 | onnx_caffe2.helper.save_caffe2_net(init_net, os.path.join(model_dir, 'init_net.pb'))
45 | onnx_caffe2.helper.save_caffe2_net(predict_net, os.path.join(model_dir, 'predict_net.pb'), output_txt=True)
46 | logging.info(model_dir)
47 |
48 |
49 | def make_args():
50 | parser = argparse.ArgumentParser()
51 | parser.add_argument('-c', '--config', nargs='+', default=['config.ini'], help='config file')
52 | parser.add_argument('-m', '--modify', nargs='+', default=[], help='modify config')
53 | parser.add_argument('--logging', default='logging.yml', help='logging config')
54 | return parser.parse_args()
55 |
56 |
57 | if __name__ == '__main__':
58 | main()
59 |
--------------------------------------------------------------------------------
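The converted pair (init_net.pb, predict_net.pb) can be sanity-checked without leaving Python. A sketch assuming onnx_caffe2 follows the standard ONNX backend interface (prepare()/run()); 'model.onnx' stands in for get_model_dir(config) + '.onnx', and the input resolution is an assumption:

import numpy as np
import onnx
import onnx_caffe2.backend

model = onnx.load('model.onnx')  # hypothetical path
onnx.checker.check_model(model)
rep = onnx_caffe2.backend.prepare(model, device='CPU')  # standard ONNX backend API
image = np.random.randn(1, 3, 416, 416).astype(np.float32)  # assumed input size
outputs = rep.run([image])
print([o.shape for o in outputs])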
/convert_torch_onnx.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see <http://www.gnu.org/licenses/>.
16 | """
17 |
18 | import os
19 | import argparse
20 | import configparser
21 | import logging
22 | import logging.config
23 | import yaml
24 |
25 | import torch.autograd
26 | import torch.cuda
27 | import torch.optim
28 | import torch.utils.data
29 | import torch.onnx
30 | import humanize
31 |
32 | import utils.train
33 | import model
34 |
35 |
36 | def main():
37 | args = make_args()
38 | config = configparser.ConfigParser()
39 | utils.load_config(config, args.config)
40 | for cmd in args.modify:
41 | utils.modify_config(config, cmd)
42 | with open(os.path.expanduser(os.path.expandvars(args.logging)), 'r') as f:
43 | logging.config.dictConfig(yaml.load(f))
44 | height, width = tuple(map(int, config.get('image', 'size').split()))
45 | cache_dir = utils.get_cache_dir(config)
46 | model_dir = utils.get_model_dir(config)
47 | category = utils.get_category(config, cache_dir if os.path.exists(cache_dir) else None)
48 | anchors = utils.get_anchors(config)
49 | anchors = torch.from_numpy(anchors).contiguous()
50 | path, step, epoch = utils.train.load_model(model_dir)
51 | state_dict = torch.load(path, map_location=lambda storage, loc: storage)
52 | dnn = utils.parse_attr(config.get('model', 'dnn'))(model.ConfigChannels(config, state_dict), anchors, len(category))
53 | inference = model.Inference(config, dnn, anchors)
54 | inference.eval()
55 | logging.info(humanize.naturalsize(sum(var.cpu().numpy().nbytes for var in inference.state_dict().values())))
56 | dnn.load_state_dict(state_dict)
57 | image = torch.autograd.Variable(torch.randn(args.batch_size, 3, height, width), volatile=True)
58 | path = model_dir + '.onnx'
59 | logging.info('save ' + path)
60 | torch.onnx.export(dnn, image, path, export_params=True, verbose=args.verbose)  # exports the raw dnn (the Inference decode stays outside the graph) as a workaround for a PyTorch ONNX bug
61 |
62 |
63 | def make_args():
64 | parser = argparse.ArgumentParser()
65 | parser.add_argument('-c', '--config', nargs='+', default=['config.ini'], help='config file')
66 | parser.add_argument('-m', '--modify', nargs='+', default=[], help='modify config')
67 | parser.add_argument('-b', '--batch_size', default=1, type=int, help='batch size')
68 | parser.add_argument('-v', '--verbose', action='store_true')
69 | parser.add_argument('--logging', default='logging.yml', help='logging config')
70 | return parser.parse_args()
71 |
72 |
73 | if __name__ == '__main__':
74 | main()
75 |
--------------------------------------------------------------------------------
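After the export it can be useful to confirm what actually landed in the graph, since only dnn is exported and the Inference decode is not. A short check; 'model.onnx' is a stand-in for the model_dir + '.onnx' path written above:

import onnx

model = onnx.load('model.onnx')  # stand-in for model_dir + '.onnx'
onnx.checker.check_model(model)
print(onnx.helper.printable_graph(model.graph))  # lists the exported ops and tensors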
/demo.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruiminshen/yolo2-pytorch/146ebdf581677964caa31c69cccd0c86230fb216/demo.gif
--------------------------------------------------------------------------------
/demo_data.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see <http://www.gnu.org/licenses/>.
16 | """
17 |
18 | import os
19 | import argparse
20 | import configparser
21 | import logging
22 | import logging.config
23 | import multiprocessing
24 | import yaml
25 |
26 | import numpy as np
27 | import torch.utils.data
28 | import matplotlib.pyplot as plt
29 |
30 | import utils.data
31 | import utils.train
32 | import utils.visualize
33 | import transform.augmentation
34 |
35 |
36 | def main():
37 | args = make_args()
38 | config = configparser.ConfigParser()
39 | utils.load_config(config, args.config)
40 | for cmd in args.modify:
41 | utils.modify_config(config, cmd)
42 | with open(os.path.expanduser(os.path.expandvars(args.logging)), 'r') as f:
43 | logging.config.dictConfig(yaml.load(f))
44 | cache_dir = utils.get_cache_dir(config)
45 | category = utils.get_category(config, cache_dir)
46 | draw_bbox = utils.visualize.DrawBBox(category)
47 | batch_size = args.rows * args.cols
48 | paths = [os.path.join(cache_dir, phase + '.pkl') for phase in args.phase]
49 | dataset = utils.data.Dataset(
50 | utils.data.load_pickles(paths),
51 | transform=transform.augmentation.get_transform(config, config.get('transform', 'augmentation').split()),
52 | shuffle=config.getboolean('data', 'shuffle'),
53 | )
54 | logging.info('num_examples=%d' % len(dataset))
55 | try:
56 | workers = config.getint('data', 'workers')
57 | except configparser.NoOptionError:
58 | workers = multiprocessing.cpu_count()
59 | collate_fn = utils.data.Collate(
60 | transform.parse_transform(config, config.get('transform', 'resize_train')),
61 | utils.train.load_sizes(config),
62 | maintain=config.getint('data', 'maintain'),
63 | transform_image=transform.get_transform(config, config.get('transform', 'image_train').split()),
64 | )
65 | loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=workers, collate_fn=collate_fn)
66 | for data in loader:
67 | path, size, image, yx_min, yx_max, cls = (t.numpy() if hasattr(t, 'numpy') else t for t in (data[key] for key in 'path, size, image, yx_min, yx_max, cls'.split(', ')))
68 | fig, axes = plt.subplots(args.rows, args.cols)
69 | axes = axes.flat if batch_size > 1 else [axes]
70 | for ax, path, size, image, yx_min, yx_max, cls in zip(*[axes, path, size, image, yx_min, yx_max, cls]):
71 | logging.info(path + ': ' + 'x'.join(map(str, size)))
72 | size = yx_max - yx_min
73 | target = np.logical_and(*[np.squeeze(a, -1) > 0 for a in np.split(size, size.shape[-1], -1)])
74 | yx_min, yx_max, cls = (a[target] for a in (yx_min, yx_max, cls))
75 | image = draw_bbox(image, yx_min.astype(np.int), yx_max.astype(np.int), cls)
76 | ax.imshow(image)
77 | ax.set_title('%d objects' % np.sum(target))
78 | ax.set_xticks([])
79 | ax.set_yticks([])
80 | fig.tight_layout()
81 | mng = plt.get_current_fig_manager()
82 | mng.resize(*mng.window.maxsize())
83 | plt.show()
84 |
85 |
86 | def make_args():
87 | parser = argparse.ArgumentParser()
88 | parser.add_argument('-c', '--config', nargs='+', default=['config.ini'], help='config file')
89 | parser.add_argument('-m', '--modify', nargs='+', default=[], help='modify config')
90 | parser.add_argument('-p', '--phase', nargs='+', default=['train', 'val', 'test'])
91 | parser.add_argument('--rows', default=3, type=int)
92 | parser.add_argument('--cols', default=3, type=int)
93 | parser.add_argument('--logging', default='logging.yml', help='logging config')
94 | return parser.parse_args()
95 |
96 |
97 | if __name__ == '__main__':
98 | main()
99 |
--------------------------------------------------------------------------------
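utils.visualize.DrawBBox does the drawing above; for orientation only, a minimal stand-in with OpenCV, assuming BGR images, (y, x) corner arrays as used throughout this codebase, and a category list indexed by class:

import cv2

def draw_bbox(image, yx_min, yx_max, cls, category, color=(0, 255, 0)):
    """Minimal stand-in for utils.visualize.DrawBBox (illustration only)."""
    for (ymin, xmin), (ymax, xmax), c in zip(yx_min, yx_max, cls):
        cv2.rectangle(image, (int(xmin), int(ymin)), (int(xmax), int(ymax)), color, 2)
        cv2.putText(image, category[int(c)], (int(xmin), int(ymin) - 4),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)
    return image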
/demo_graph.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see <http://www.gnu.org/licenses/>.
16 | """
17 |
18 | import os
19 | import argparse
20 | import configparser
21 | import logging
22 | import logging.config
23 | import yaml
24 |
25 | import torch.autograd
26 | import torch.cuda
27 | import torch.optim
28 | import torch.utils.data
29 | import humanize
30 |
31 | import model
32 | import utils
33 | import utils.train
34 | import utils.visualize
35 |
36 |
37 | def main():
38 | args = make_args()
39 | config = configparser.ConfigParser()
40 | utils.load_config(config, args.config)
41 | for cmd in args.modify:
42 | utils.modify_config(config, cmd)
43 | with open(os.path.expanduser(os.path.expandvars(args.logging)), 'r') as f:
44 | logging.config.dictConfig(yaml.load(f))
45 | model_dir = utils.get_model_dir(config)
46 | category = utils.get_category(config)
47 | anchors = torch.from_numpy(utils.get_anchors(config)).contiguous()
48 | try:
49 | path, step, epoch = utils.train.load_model(model_dir)
50 | state_dict = torch.load(path, map_location=lambda storage, loc: storage)
51 | except (FileNotFoundError, ValueError):
52 | logging.warning('model cannot be loaded')
53 | state_dict = None
54 | dnn = utils.parse_attr(config.get('model', 'dnn'))(model.ConfigChannels(config, state_dict), anchors, len(category))
55 | logging.info(humanize.naturalsize(sum(var.cpu().numpy().nbytes for var in dnn.state_dict().values())))
56 | if state_dict is not None:
57 | dnn.load_state_dict(state_dict)
58 | height, width = tuple(map(int, config.get('image', 'size').split()))
59 | image = torch.autograd.Variable(torch.randn(args.batch_size, 3, height, width))
60 | output = dnn(image)
61 | state_dict = dnn.state_dict()
62 | graph = utils.visualize.Graph(config, state_dict)
63 | graph(output.grad_fn)
64 | diff = [key for key in state_dict if key not in graph.drawn]
65 | if diff:
66 | logging.warning('variables not shown: ' + str(diff))
67 | path = graph.dot.view(os.path.basename(model_dir) + '.gv', os.path.dirname(model_dir))
68 | logging.info(path)
69 |
70 |
71 | def make_args():
72 | parser = argparse.ArgumentParser()
73 | parser.add_argument('-c', '--config', nargs='+', default=['config.ini'], help='config file')
74 | parser.add_argument('-m', '--modify', nargs='+', default=[], help='modify config')
75 | parser.add_argument('-b', '--batch_size', default=1, type=int, help='batch size')
76 | parser.add_argument('--logging', default='logging.yml', help='logging config')
77 | return parser.parse_args()
78 |
79 |
80 | if __name__ == '__main__':
81 | main()
--------------------------------------------------------------------------------
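utils.visualize.Graph renders the network by walking output.grad_fn; that traversal is plain autograd metadata. A self-contained sketch of the same walk using only public grad_fn attributes:

import torch

def walk(grad_fn, depth=0, seen=None):
    """Print the autograd graph below grad_fn (illustration only)."""
    if grad_fn is None:
        return
    seen = set() if seen is None else seen
    if id(grad_fn) in seen:
        return
    seen.add(id(grad_fn))
    print('  ' * depth + type(grad_fn).__name__)
    for fn, _ in getattr(grad_fn, 'next_functions', ()):
        walk(fn, depth + 1, seen)

x = torch.randn(1, 3, requires_grad=True)
y = (x * 2).sum()
walk(y.grad_fn)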
/demo_lr.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see <http://www.gnu.org/licenses/>.
16 | """
17 |
18 | import argparse
19 | import configparser
20 | import logging
21 | import logging.config
22 | import os
23 | import yaml
24 |
25 | import torch.autograd
26 | import torch.cuda
27 | import torch.optim
28 | import torch.utils.data
29 |
30 | import model
31 | import utils.data
32 | import utils.postprocess
33 | import utils.train
34 | import utils.visualize
35 |
36 |
37 | def main():
38 | args = make_args()
39 | config = configparser.ConfigParser()
40 | utils.load_config(config, args.config)
41 | for cmd in args.modify:
42 | utils.modify_config(config, cmd)
43 | with open(os.path.expanduser(os.path.expandvars(args.logging)), 'r') as f:
44 | logging.config.dictConfig(yaml.load(f))
45 | category = utils.get_category(config)
46 | anchors = torch.from_numpy(utils.get_anchors(config)).contiguous()
47 | dnn = utils.parse_attr(config.get('model', 'dnn'))(model.ConfigChannels(config), anchors, len(category))
48 | inference = model.Inference(config, dnn, anchors)
49 | inference.train()
50 | optimizer = eval(config.get('train', 'optimizer'))(filter(lambda p: p.requires_grad, inference.parameters()), args.learning_rate)
51 | scheduler = eval(config.get('train', 'scheduler'))(optimizer)
52 | for epoch in range(args.epoch):
53 | scheduler.step(epoch)
54 | lr = scheduler.get_lr()
55 | print('\t'.join(map(str, [epoch] + lr)))
56 |
57 |
58 |
59 | def make_args():
60 | parser = argparse.ArgumentParser()
61 | parser.add_argument('epoch', type=int)
62 | parser.add_argument('-c', '--config', nargs='+', default=['config.ini'], help='config file')
63 | parser.add_argument('-m', '--modify', nargs='+', default=[], help='modify config')
64 | parser.add_argument('-o', '--optimizer', default='adam', help='unused; the optimizer is read from config [train] optimizer')
65 | parser.add_argument('-lr', '--learning_rate', default=1e-3, type=float, help='learning rate')
66 | parser.add_argument('--logging', default='logging.yml', help='logging config')
67 | return parser.parse_args()
68 |
69 |
70 | if __name__ == '__main__':
71 | main()
72 |
--------------------------------------------------------------------------------
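The optimizer and scheduler strings are eval()'d from config.ini, so they must evaluate to a callable taking (params, lr) and a callable taking (optimizer), respectively. Hypothetical values (not the shipped config) and what the loop above then prints:

import torch
import torch.optim

optimizer_factory = torch.optim.Adam  # e.g. config: optimizer = torch.optim.Adam
def scheduler_factory(optimizer):  # e.g. config: scheduler = lambda optimizer: torch.optim.lr_scheduler.StepLR(optimizer, 30)
    return torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = optimizer_factory(params, 1e-3)
scheduler = scheduler_factory(optimizer)
for epoch in range(90):
    scheduler.step(epoch)
    print('\t'.join(map(str, [epoch] + scheduler.get_lr())))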
/detect.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see <http://www.gnu.org/licenses/>.
16 | """
17 |
18 | import argparse
19 | import configparser
20 | import logging
21 | import logging.config
22 | import os
23 | import time
24 | import yaml
25 |
26 | import numpy as np
27 | import torch.autograd
28 | import torch.cuda
29 | import torch.optim
30 | import torch.utils.data
31 | import torch.nn.functional as F
32 | import humanize
33 | import pybenchmark
34 | import cv2
35 |
36 | import transform
37 | import model
38 | import utils.postprocess
39 | import utils.train
40 | import utils.visualize
41 |
42 |
43 | def get_logits(pred):
44 | if 'logits' in pred:
45 | return pred['logits'].contiguous()
46 | else:
47 | size = pred['iou'].size()
48 | return torch.autograd.Variable(utils.ensure_device(torch.ones(*size, 1)))
49 |
50 |
51 | def filter_visible(config, iou, yx_min, yx_max, prob):
52 | prob_cls, cls = torch.max(prob, -1)
53 | if config.getboolean('detect', 'fix'):
54 | mask = (iou * prob_cls) > config.getfloat('detect', 'threshold_cls')
55 | else:
56 | mask = iou > config.getfloat('detect', 'threshold')
57 | iou, prob_cls, cls = (t[mask].view(-1) for t in (iou, prob_cls, cls))
58 | _mask = torch.unsqueeze(mask, -1).repeat(1, 2) # PyTorch's bug
59 | yx_min, yx_max = (t[_mask].view(-1, 2) for t in (yx_min, yx_max))
60 | num = prob.size(-1)
61 | _mask = torch.unsqueeze(mask, -1).repeat(1, num) # PyTorch's bug
62 | prob = prob[_mask].view(-1, num)
63 | return iou, yx_min, yx_max, prob, prob_cls, cls
64 |
65 |
66 | def postprocess(config, iou, yx_min, yx_max, prob):
67 | iou, yx_min, yx_max, prob, prob_cls, cls = filter_visible(config, iou, yx_min, yx_max, prob)
68 | keep = pybenchmark.profile('nms')(utils.postprocess.nms)(iou, yx_min, yx_max, config.getfloat('detect', 'overlap'))
69 | if keep:
70 | keep = utils.ensure_device(torch.LongTensor(keep))
71 | iou, yx_min, yx_max, prob, prob_cls, cls = (t[keep] for t in (iou, yx_min, yx_max, prob, prob_cls, cls))
72 | if config.getboolean('detect', 'fix'):
73 | score = torch.unsqueeze(iou, -1) * prob
74 | mask = score > config.getfloat('detect', 'threshold_cls')
75 | indices, cls = torch.unbind(mask.nonzero(), -1)
76 | yx_min, yx_max = (t[indices] for t in (yx_min, yx_max))
77 | score = score[mask]
78 | else:
79 | score = iou
80 | return iou, yx_min, yx_max, cls, score
81 |
82 |
83 | class Detect(object):
84 | def __init__(self, args, config):
85 | self.args = args
86 | self.config = config
87 | self.cache_dir = utils.get_cache_dir(config)
88 | self.model_dir = utils.get_model_dir(config)
89 | self.category = utils.get_category(config, self.cache_dir if os.path.exists(self.cache_dir) else None)
90 | self.draw_bbox = utils.visualize.DrawBBox(self.category, colors=args.colors, thickness=args.thickness)
91 | self.anchors = torch.from_numpy(utils.get_anchors(config)).contiguous()
92 | self.height, self.width = tuple(map(int, config.get('image', 'size').split()))
93 | self.path, self.step, self.epoch = utils.train.load_model(self.model_dir)
94 | state_dict = torch.load(self.path, map_location=lambda storage, loc: storage)
95 | self.dnn = utils.parse_attr(config.get('model', 'dnn'))(model.ConfigChannels(config, state_dict), self.anchors, len(self.category))
96 | self.dnn.load_state_dict(state_dict)
97 | self.inference = model.Inference(config, self.dnn, self.anchors)
98 | self.inference.eval()
99 | if torch.cuda.is_available():
100 | self.inference.cuda()
101 | logging.info(humanize.naturalsize(sum(var.cpu().numpy().nbytes for var in self.inference.state_dict().values())))
102 | self.cap = self.create_cap()
103 | self.keys = set(args.keys)
104 | self.resize = transform.parse_transform(config, config.get('transform', 'resize_test'))
105 | self.transform_image = transform.get_transform(config, config.get('transform', 'image_test').split())
106 | self.transform_tensor = transform.get_transform(config, config.get('transform', 'tensor').split())
107 |
108 | def __del__(self):
109 | cv2.destroyAllWindows()
110 | try:
111 | self.writer.release()
112 | except AttributeError:
113 | pass
114 | self.cap.release()
115 |
116 | def create_cap(self):
117 | try:
118 | cap = int(self.args.input)
119 | except ValueError:
120 | cap = os.path.expanduser(os.path.expandvars(self.args.input))
121 | assert os.path.exists(cap)
122 | return cv2.VideoCapture(cap)
123 |
124 | def create_writer(self, height, width):
125 | fps = self.cap.get(cv2.CAP_PROP_FPS)
126 | logging.info('cap fps=%f' % fps)
127 | path = os.path.expanduser(os.path.expandvars(self.args.output))
128 | if self.args.fourcc:
129 | fourcc = cv2.VideoWriter_fourcc(*self.args.fourcc.upper())
130 | else:
131 | fourcc = int(self.cap.get(cv2.CAP_PROP_FOURCC))
132 | os.makedirs(os.path.dirname(path), exist_ok=True)
133 | return cv2.VideoWriter(path, fourcc, fps, (width, height))
134 |
135 | def get_image(self):
136 | ret, image_bgr = self.cap.read()
137 | if self.args.crop:
138 | image_bgr = image_bgr[self.crop_ymin:self.crop_ymax, self.crop_xmin:self.crop_xmax]  # note: the crop_* pixel bounds must be derived from args.crop (ymin ymax xmin xmax) before use
139 | return image_bgr
140 |
141 | def __call__(self):
142 | image_bgr = self.get_image()
143 | image_resized = self.resize(image_bgr, self.height, self.width)
144 | image = self.transform_image(image_resized)
145 | tensor = self.transform_tensor(image)
146 | tensor = utils.ensure_device(tensor.unsqueeze(0))
147 | pred = pybenchmark.profile('inference')(model._inference)(self.inference, torch.autograd.Variable(tensor, volatile=True))
148 | rows, cols = pred['feature'].size()[-2:]
149 | iou = pred['iou'].data.contiguous().view(-1)
150 | yx_min, yx_max = (pred[key].data.view(-1, 2) for key in 'yx_min, yx_max'.split(', '))
151 | logits = get_logits(pred)
152 | prob = F.softmax(logits, -1).data.view(-1, logits.size(-1))
153 | ret = postprocess(self.config, iou, yx_min, yx_max, prob)
154 | image_result = image_bgr.copy()
155 | if ret is not None:
156 | iou, yx_min, yx_max, cls, score = ret
157 | try:
158 | scale = self.scale
159 | except AttributeError:
160 | scale = utils.ensure_device(torch.from_numpy(np.array(image_result.shape[:2], np.float32) / np.array([rows, cols], np.float32)))
161 | self.scale = scale
162 | yx_min, yx_max = ((t * scale).cpu().numpy().astype(np.int) for t in (yx_min, yx_max))
163 | image_result = self.draw_bbox(image_result, yx_min, yx_max, cls)
164 | if self.args.output:
165 | if not hasattr(self, 'writer'):
166 | self.writer = self.create_writer(*image_result.shape[:2])
167 | self.writer.write(image_result)
168 | else:
169 | cv2.imshow('detection', image_result)
170 | if cv2.waitKey(0 if self.args.pause else 1) in self.keys:
171 | root = os.path.join(self.model_dir, 'snapshot')
172 | os.makedirs(root, exist_ok=True)
173 | path = os.path.join(root, time.strftime(self.args.format))
174 | cv2.imwrite(path, image_bgr)
175 | logging.warning('image dumped into ' + path)
176 |
177 |
178 | def main():
179 | args = make_args()
180 | config = configparser.ConfigParser()
181 | utils.load_config(config, args.config)
182 | for cmd in args.modify:
183 | utils.modify_config(config, cmd)
184 | with open(os.path.expanduser(os.path.expandvars(args.logging)), 'r') as f:
185 | logging.config.dictConfig(yaml.load(f))
186 | detect = Detect(args, config)
187 | try:
188 | while detect.cap.isOpened():
189 | detect()
190 | except KeyboardInterrupt:
191 | logging.warning('interrupted')
192 | finally:
193 | logging.info(pybenchmark.stats)
194 |
195 |
196 | def make_args():
197 | parser = argparse.ArgumentParser()
198 | parser.add_argument('-c', '--config', nargs='+', default=['config.ini'], help='config file')
199 | parser.add_argument('-m', '--modify', nargs='+', default=[], help='modify config')
200 | parser.add_argument('-i', '--input', default=-1)
201 | parser.add_argument('-k', '--keys', nargs='+', type=int, default=[ord(' ')], help='keys to dump images')
202 | parser.add_argument('-o', '--output', help='output video file')
203 | parser.add_argument('-f', '--format', default='%Y-%m-%d_%H-%M-%S.jpg', help='dump file name format')
204 | parser.add_argument('--crop', nargs='+', type=float, default=[], help='ymin ymax xmin xmax')
205 | parser.add_argument('--pause', action='store_true')
206 | parser.add_argument('--fourcc', default='XVID', help='4-character code of codec used to compress the frames, such as XVID, MJPG')
207 | parser.add_argument('--thickness', default=3, type=int)
208 | parser.add_argument('--colors', nargs='+', default=[])
209 | parser.add_argument('--logging', default='logging.yml', help='logging config')
210 | return parser.parse_args()
211 |
212 |
213 | if __name__ == '__main__':
214 | main()
215 |
--------------------------------------------------------------------------------
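postprocess() above delegates to utils.postprocess.nms. For reference, the greedy IoU suppression it performs can be written in a few lines of numpy; this is a sketch of the textbook algorithm, not the repo's exact implementation:

import numpy as np

def nms(score, yx_min, yx_max, overlap):
    """Greedy NMS: keep the highest-scoring boxes, drop any box whose IoU
    with an already-kept box exceeds `overlap`. Returns kept indices."""
    order = np.argsort(-score)
    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(int(i))
        # IoU of the top box against the remaining candidates
        ymin = np.maximum(yx_min[i, 0], yx_min[order[1:], 0])
        xmin = np.maximum(yx_min[i, 1], yx_min[order[1:], 1])
        ymax = np.minimum(yx_max[i, 0], yx_max[order[1:], 0])
        xmax = np.minimum(yx_max[i, 1], yx_max[order[1:], 1])
        inter = np.maximum(0, ymax - ymin) * np.maximum(0, xmax - xmin)
        area_i = np.prod(yx_max[i] - yx_min[i])
        area_rest = np.prod(yx_max[order[1:]] - yx_min[order[1:]], -1)
        iou = inter / (area_i + area_rest - inter)
        order = order[1:][iou <= overlap]
    return keep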
/dimension_cluster.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see <http://www.gnu.org/licenses/>.
16 | """
17 |
18 | import os
19 | import argparse
20 | import configparser
21 | import logging
22 | import logging.config
23 | import yaml
24 |
25 | import numpy as np
26 | import nltk.cluster.kmeans
27 |
28 | import utils.data
29 | import utils.iou.numpy
30 |
31 |
32 | def distance(a, b):
33 | return 1 - utils.iou.numpy.iou(-a, a, -b, b)
34 |
35 |
36 | def get_data(paths):
37 | dataset = utils.data.Dataset(utils.data.load_pickles(paths))
38 | return np.concatenate([(data['yx_max'] - data['yx_min']) / utils.image_size(data['path']) for data in dataset.dataset])
39 |
40 |
41 | def main():
42 | args = make_args()
43 | config = configparser.ConfigParser()
44 | utils.load_config(config, args.config)
45 | with open(os.path.expanduser(os.path.expandvars(args.logging)), 'r') as f:
46 | logging.config.dictConfig(yaml.load(f))
47 | cache_dir = utils.get_cache_dir(config)
48 | paths = [os.path.join(cache_dir, phase + '.pkl') for phase in args.phase]
49 | data = get_data(paths)
50 | logging.info('num_examples=%d' % len(data))
51 | clusterer = nltk.cluster.kmeans.KMeansClusterer(args.num, distance, args.repeats)
52 | try:
53 | clusterer.cluster(data)
54 | except KeyboardInterrupt:
55 | logging.warning('interrupted')
56 | for m in clusterer.means():
57 | print('\t'.join(map(str, m)))
58 |
59 |
60 | def make_args():
61 | parser = argparse.ArgumentParser()
62 | parser.add_argument('num', type=int)
63 | parser.add_argument('-r', '--repeats', type=int, default=np.iinfo(np.int).max)
64 | parser.add_argument('-c', '--config', nargs='+', default=['config.ini'], help='config file')
65 | parser.add_argument('-p', '--phase', nargs='+', default=['train', 'val', 'test'])
66 | parser.add_argument('--logging', default='logging.yml', help='logging config')
67 | return parser.parse_args()
68 |
69 |
70 | if __name__ == '__main__':
71 | main()
72 |
--------------------------------------------------------------------------------
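distance() above scores shapes, not positions: each normalized (height, width) is treated as a box centered at the origin (hence the -a, a arguments), so 1 - IoU depends only on the box dimensions. This is the YOLOv2 dimension-clustering trick. For origin-centered boxes the IoU reduces to an overlap of sizes; a numpy equivalent of a single distance evaluation:

import numpy as np

def centered_iou(size_a, size_b):
    """IoU of two boxes with sizes (h, w) sharing the same center."""
    inter = np.prod(np.minimum(size_a, size_b))
    union = np.prod(size_a) + np.prod(size_b) - inter
    return inter / union

a = np.array([0.2, 0.4])
b = np.array([0.3, 0.3])
print(1 - centered_iou(a, b))  # the clustering distance for this pair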
/disable_bad_images.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see <http://www.gnu.org/licenses/>.
16 | """
17 |
18 | import os
19 | import sys
20 | import argparse
21 | import shutil
22 | import tqdm
23 |
24 | import cv2
25 |
26 |
27 | def main():
28 | args = make_args()
29 | root = os.path.expanduser(os.path.expandvars(args.root))
30 | for dirpath, _, filenames in os.walk(root):
31 | for filename in tqdm.tqdm(filenames, desc=dirpath):
32 | if os.path.splitext(filename)[-1].lower() in args.exts and filename[0] != '.':
33 | path = os.path.join(dirpath, filename)
34 | image = cv2.imread(path)
35 | if image is None:
36 | sys.stderr.write('disable bad image %s\n' % path)
37 | _path = os.path.join(os.path.dirname(path), '.' + os.path.basename(path))
38 | if os.path.exists(_path):
39 | os.remove(_path)
40 | shutil.move(path, _path)
41 |
42 |
43 | def make_args():
44 | parser = argparse.ArgumentParser()
45 | parser.add_argument('root')
46 | parser.add_argument('-c', '--config', nargs='+', default=['config.ini'], help='config file')
47 | parser.add_argument('-e', '--exts', nargs='+', default=['.jpe', '.jpg', '.jpeg', '.png'])
48 | parser.add_argument('--level', default='info', help='logging level')
49 | return parser.parse_args()
50 |
51 |
52 | if __name__ == '__main__':
53 | main()
54 |
--------------------------------------------------------------------------------
/donate_alipay.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruiminshen/yolo2-pytorch/146ebdf581677964caa31c69cccd0c86230fb216/donate_alipay.jpg
--------------------------------------------------------------------------------
/donate_mm.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruiminshen/yolo2-pytorch/146ebdf581677964caa31c69cccd0c86230fb216/donate_mm.jpg
--------------------------------------------------------------------------------
/download_url.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see <http://www.gnu.org/licenses/>.
16 | """
17 |
18 | import os
19 | import sys
20 | import argparse
21 | import threading
22 |
23 | import numpy as np
24 | import tqdm
25 | import wget
26 |
27 |
28 | def _task(url, root, ext):
29 | path = wget.download(url, bar=None)
30 | with open(path + ext, 'w') as f:
31 | f.write(url)
32 |
33 |
34 | def task(urls, root, ext, pbar, lock, f):
35 | for url in urls:
36 | url = url.rstrip()
37 | try:
38 | _task(url, root, ext)
39 | except Exception:  # a bare except would also swallow KeyboardInterrupt
40 | with lock:
41 | f.write(url + '\n')
42 | pbar.update()
43 |
44 |
45 | def main():
46 | args = make_args()
47 | root = os.path.expandvars(os.path.expanduser(args.root))
48 | os.makedirs(root, exist_ok=True)
49 | os.chdir(root)
50 | workers = []
51 | urls = list(set(sys.stdin.readlines()))
52 | lock = threading.Lock()
53 | with tqdm.tqdm(total=len(urls)) as pbar, open(root + args.ext, 'w') as f:
54 | for urls in np.array_split(urls, args.workers):
55 | w = threading.Thread(target=task, args=(urls, root, args.ext, pbar, lock, f))
56 | w.start()
57 | workers.append(w)
58 | for w in workers:
59 | w.join()
60 |
61 |
62 | def make_args():
63 | parser = argparse.ArgumentParser()
64 | parser.add_argument('root')
65 | parser.add_argument('-w', '--workers', type=int, default=6)
66 | parser.add_argument('-e', '--ext', default='.url')
67 | return parser.parse_args()
68 |
69 |
70 | if __name__ == '__main__':
71 | main()
72 |
--------------------------------------------------------------------------------
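The same fan-out could be written with a thread pool instead of hand-managed threads; a hedged alternative sketch with the same behavior (download each URL from stdin, collect the failures), minus the progress bar:

import concurrent.futures
import sys

import wget

def fetch(url):
    try:
        wget.download(url.rstrip(), bar=None)
        return None
    except Exception:
        return url  # report the failing URL

urls = list(set(sys.stdin.readlines()))
with concurrent.futures.ThreadPoolExecutor(max_workers=6) as pool:
    failed = [url for url in pool.map(fetch, urls) if url is not None]
for url in failed:
    sys.stderr.write(url)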
/image.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruiminshen/yolo2-pytorch/146ebdf581677964caa31c69cccd0c86230fb216/image.jpg
--------------------------------------------------------------------------------
/logging.yml:
--------------------------------------------------------------------------------
1 | version: 1
2 | formatters:
3 | simple:
4 | format: "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
5 | handlers:
6 | console:
7 | class: logging.StreamHandler
8 | level: INFO
9 | formatter: simple
10 | stream: ext://sys.stderr
11 | root:
12 | level: INFO
13 | handlers: [console]
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see <http://www.gnu.org/licenses/>.
16 | """
17 |
18 | import logging
19 |
20 | import numpy as np
21 | import torch
22 | import torch.autograd
23 | import torch.nn as nn
24 | import torch.nn.functional as F
25 |
26 | import utils.iou.torch
27 |
28 |
29 | class ConfigChannels(object):
30 | def __init__(self, config, state_dict=None, channels=3):
31 | self.config = config
32 | self.state_dict = state_dict
33 | self.channels = channels
34 |
35 | def __call__(self, default, name, fn=lambda var: var.size(0)):
36 | if self.state_dict is None:
37 | self.channels = default
38 | else:
39 | var = self.state_dict[name]
40 | self.channels = fn(var)
41 | if self.channels != default:
42 | logging.warning('%s: change number of output channels from %d to %d' % (name, default, self.channels))
43 | return self.channels
44 |
45 |
46 | def output_channels(num_anchors, num_cls):
47 | if num_cls > 1:
48 | return num_anchors * (5 + num_cls)
49 | else:
50 | return num_anchors * 5
51 |
52 |
53 | def meshgrid(rows, cols, swap=False):
54 | i = torch.arange(0, rows).repeat(cols).view(-1, 1)
55 | j = torch.arange(0, cols).view(-1, 1).repeat(1, rows).view(-1, 1)
56 | return torch.cat([i, j], 1) if swap else torch.cat([j, i], 1)
57 |
58 |
59 | def iou_match(yx_min, yx_max, data):
60 | batch_size, cells, num_anchors, _ = yx_min.size()
61 | iou_matrix = utils.iou.torch.batch_iou_matrix(yx_min.view(batch_size, -1, 2), yx_max.view(batch_size, -1, 2), data['yx_min'], data['yx_max'])
62 | iou_matrix = iou_matrix.view(batch_size, cells, num_anchors, -1)
63 | iou, index = iou_matrix.max(-1)
64 | _index = torch.unbind(index.view(batch_size, -1))
65 | _data = {}
66 | for key in 'yx_min, yx_max, cls'.split(', '):
67 | t = data[key]
68 | if len(t.size()) == 2:
69 | t = torch.stack([d[i] for d, i in zip(torch.unbind(t, 0), _index)]).view(batch_size, cells, num_anchors)
70 | elif len(t.size()) == 3:
71 | t = torch.stack([d[i] for d, i in zip(torch.unbind(t, 0), _index)]).view(batch_size, cells, num_anchors, -1)
72 | _data[key] = t
73 | return iou_matrix, iou, index, _data
74 |
75 |
76 | def fit_positive(rows, cols, yx_min, yx_max, anchors):
77 | device_id = anchors.get_device() if torch.cuda.is_available() else None
78 | batch_size, num, _ = yx_min.size()
79 | num_anchors, _ = anchors.size()
80 | valid = torch.prod(yx_min < yx_max, -1)
81 | center = (yx_min + yx_max) / 2
82 | ij = torch.floor(center)
83 | i, j = torch.unbind(ij.long(), -1)
84 | index = i * cols + j
85 | anchors2 = anchors / 2
86 | iou_matrix = utils.iou.torch.iou_matrix((yx_min - center).view(-1, 2), (yx_max - center).view(-1, 2), -anchors2, anchors2).view(batch_size, -1, num_anchors)
87 | iou, index_anchor = iou_matrix.max(-1)
88 | _positive = []
89 | cells = rows * cols
90 | for valid, index, index_anchor in zip(torch.unbind(valid), torch.unbind(index), torch.unbind(index_anchor)):
91 | index, index_anchor = (t[valid] for t in (index, index_anchor))
92 | t = utils.ensure_device(torch.ByteTensor(cells, num_anchors).zero_(), device_id)
93 | t[index, index_anchor] = 1
94 | _positive.append(t)
95 | return torch.stack(_positive)
96 |
97 |
98 | def fill_norm(yx_min, yx_max, anchors):
99 | center = (yx_min + yx_max) / 2
100 | ij = torch.floor(center)
101 | center_offset = center - ij
102 | size = yx_max - yx_min
103 | return center_offset, torch.log(size / anchors.view(1, -1, 2))
104 |
105 |
106 | def square(t):
107 | return t * t
108 |
109 |
110 | class Inference(nn.Module):
111 | def __init__(self, config, dnn, anchors):
112 | nn.Module.__init__(self)
113 | self.config = config
114 | self.dnn = dnn
115 | self.anchors = anchors
116 |
117 | def forward(self, x):
118 | device_id = x.get_device() if torch.cuda.is_available() else None
119 | feature = self.dnn(x)
120 | rows, cols = feature.size()[-2:]
121 | cells = rows * cols
122 | _feature = feature.permute(0, 2, 3, 1).contiguous().view(feature.size(0), cells, self.anchors.size(0), -1)
123 | sigmoid = F.sigmoid(_feature[:, :, :, :3])
124 | iou = sigmoid[:, :, :, 0]
125 | ij = torch.autograd.Variable(utils.ensure_device(meshgrid(rows, cols).view(1, -1, 1, 2), device_id))
126 | center_offset = sigmoid[:, :, :, 1:3]
127 | center = ij + center_offset
128 | size_norm = _feature[:, :, :, 3:5]
129 | anchors = torch.autograd.Variable(utils.ensure_device(self.anchors.view(1, 1, -1, 2), device_id))
130 | size = torch.exp(size_norm) * anchors
131 | size2 = size / 2
132 | yx_min = center - size2
133 | yx_max = center + size2
134 | logits = _feature[:, :, :, 5:] if _feature.size(-1) > 5 else None
135 | return feature, iou, center_offset, size_norm, yx_min, yx_max, logits
136 |
137 |
138 | def loss(anchors, data, pred, threshold):
139 | iou = pred['iou']
140 | device_id = iou.get_device() if torch.cuda.is_available() else None
141 | rows, cols = pred['feature'].size()[-2:]
142 | iou_matrix, _iou, _, _data = iou_match(pred['yx_min'].data, pred['yx_max'].data, data)
143 | anchors = utils.ensure_device(anchors, device_id)
144 | positive = fit_positive(rows, cols, *(data[key] for key in 'yx_min, yx_max'.split(', ')), anchors)
145 | negative = ~positive & (_iou < threshold)
146 | _center_offset, _size_norm = fill_norm(*(_data[key] for key in 'yx_min, yx_max'.split(', ')), anchors)
147 | positive, negative, _iou, _center_offset, _size_norm, _cls = (torch.autograd.Variable(t) for t in (positive, negative, _iou, _center_offset, _size_norm, _data['cls']))
148 | _positive = torch.unsqueeze(positive, -1)
149 | loss = {}
150 | # iou
151 | loss['foreground'] = F.mse_loss(iou[positive], _iou[positive], size_average=False)
152 | loss['background'] = torch.sum(square(iou[negative]))
153 | # bbox
154 | loss['center'] = F.mse_loss(pred['center_offset'][_positive], _center_offset[_positive], size_average=False)
155 | loss['size'] = F.mse_loss(pred['size_norm'][_positive], _size_norm[_positive], size_average=False)
156 | # cls
157 | if 'logits' in pred:
158 | logits = pred['logits']
159 | if len(_cls.size()) > 3:
160 | loss['cls'] = F.mse_loss(F.softmax(logits, -1)[_positive], _cls[_positive], size_average=False)
161 | else:
162 | loss['cls'] = F.cross_entropy(logits[_positive].view(-1, logits.size(-1)), _cls[positive].view(-1))
163 | # normalize
164 | cnt = float(np.multiply.reduce(positive.size()))
165 | for key in loss:
166 | loss[key] /= cnt
167 | return loss, dict(iou=_iou, data=_data, positive=positive, negative=negative)
168 |
169 |
170 | def _inference(inference, tensor):
171 | feature, iou, center_offset, size_norm, yx_min, yx_max, logits = inference(tensor)
172 | pred = dict(
173 | feature=feature, iou=iou,
174 | center_offset=center_offset, size_norm=size_norm,
175 | yx_min=yx_min, yx_max=yx_max,
176 | )
177 | if logits is not None:
178 | pred['logits'] = logits.contiguous()
179 | return pred
180 |
--------------------------------------------------------------------------------
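Inference.forward implements the YOLOv2 box parameterization: per cell and anchor, sigmoid of the first output is the predicted IoU, the center is the cell coordinate ij plus a sigmoid offset, and the size is the anchor scaled by exp of the raw size outputs; with 5 anchors and 20 classes, output_channels() gives 5 * (5 + 20) = 125 feature channels, the familiar VOC head width. A worked numeric sketch for one cell and one anchor, in grid units:

import math

# one grid cell at (i, j) = (5, 7), one anchor of size (h, w) = (2.0, 3.0),
# raw network outputs t for that cell/anchor (made-up values):
t_o, t_y, t_x, t_h, t_w = 0.5, 0.1, -0.2, 0.3, -0.1
sigmoid = lambda v: 1 / (1 + math.exp(-v))

iou = sigmoid(t_o)
center = (5 + sigmoid(t_y), 7 + sigmoid(t_x))      # in grid units
size = (2.0 * math.exp(t_h), 3.0 * math.exp(t_w))  # in grid units
yx_min = (center[0] - size[0] / 2, center[1] - size[1] / 2)
yx_max = (center[0] + size[0] / 2, center[1] + size[1] / 2)
print(iou, yx_min, yx_max)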
/model/densenet.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see <http://www.gnu.org/licenses/>.
16 | """
17 |
18 | import logging
19 | from collections import OrderedDict
20 |
21 | import torch.nn as nn
22 | import torch.utils.model_zoo as model_zoo
23 | import torchvision.models.densenet as _model
24 | from torchvision.models.densenet import _DenseBlock, _Transition, model_urls
25 |
26 | import model
27 |
28 |
29 | class DenseNet(_model.DenseNet):
30 | def __init__(self, config_channels, anchors, num_cls, growth_rate=32, block_config=(6, 12, 24, 16), num_init_features=64, bn_size=4, drop_rate=0):
31 | nn.Module.__init__(self)
32 |
33 | # First convolution
34 | self.features = nn.Sequential(OrderedDict([
35 | ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),
36 | ('norm0', nn.BatchNorm2d(num_init_features)),
37 | ('relu0', nn.ReLU(inplace=True)),
38 | ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
39 | ]))
40 |
41 | # Each denseblock
42 | num_features = num_init_features
43 | for i, num_layers in enumerate(block_config):
44 | block = _DenseBlock(num_layers=num_layers, num_input_features=num_features, bn_size=bn_size, growth_rate=growth_rate, drop_rate=drop_rate)
45 | self.features.add_module('denseblock%d' % (i + 1), block)
46 | num_features = num_features + num_layers * growth_rate
47 | if i != len(block_config) - 1:
48 | trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2)
49 | self.features.add_module('transition%d' % (i + 1), trans)
50 | num_features = num_features // 2
51 |
52 | # Final batch norm
53 | self.features.add_module('norm5', nn.BatchNorm2d(num_features))
54 | self.features.add_module('conv', nn.Conv2d(num_features, model.output_channels(len(anchors), num_cls), 1))
55 |
56 | # init
57 | for m in self.modules():
58 | if isinstance(m, nn.Conv2d):
59 | m.weight = nn.init.kaiming_normal(m.weight)
60 | elif isinstance(m, nn.BatchNorm2d):
61 | m.weight.data.fill_(1)
62 | m.bias.data.zero_()
63 |
64 | def forward(self, x):
65 | return self.features(x)
66 |
67 |
68 | def densenet121(config_channels, anchors, num_cls, **kwargs):
69 | model = DenseNet(config_channels, anchors, num_cls, num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16), **kwargs)
70 | if config_channels.config.getboolean('model', 'pretrained'):
71 | url = model_urls['densenet121']
72 | logging.info('use pretrained model: ' + url)
73 | state_dict = model.state_dict()
74 | for key, value in model_zoo.load_url(url).items():
75 | if key in state_dict:
76 | state_dict[key] = value
77 | model.load_state_dict(state_dict)
78 | return model
79 |
80 |
81 | def densenet169(config_channels, anchors, num_cls, **kwargs):
82 | model = DenseNet(config_channels, anchors, num_cls, num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32), **kwargs)
83 | if config_channels.config.getboolean('model', 'pretrained'):
84 | url = model_urls['densenet169']
85 | logging.info('use pretrained model: ' + url)
86 | state_dict = model.state_dict()
87 | for key, value in model_zoo.load_url(url).items():
88 | if key in state_dict:
89 | state_dict[key] = value
90 | model.load_state_dict(state_dict)
91 | return model
92 |
93 |
94 | def densenet201(config_channels, anchors, num_cls, **kwargs):
95 | model = DenseNet(config_channels, anchors, num_cls, num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32), **kwargs)
96 | if config_channels.config.getboolean('model', 'pretrained'):
97 | url = model_urls['densenet201']
98 | logging.info('use pretrained model: ' + url)
99 | state_dict = model.state_dict()
100 | for key, value in model_zoo.load_url(url).items():
101 | if key in state_dict:
102 | state_dict[key] = value
103 | model.load_state_dict(state_dict)
104 | return model
105 |
106 |
107 | def densenet161(config_channels, anchors, num_cls, **kwargs):
108 | model = DenseNet(config_channels, anchors, num_cls, num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24), **kwargs)
109 | if config_channels.config.getboolean('model', 'pretrained'):
110 | url = model_urls['densenet161']
111 | logging.info('use pretrained model: ' + url)
112 | state_dict = model.state_dict()
113 | for key, value in model_zoo.load_url(url).items():
114 | if key in state_dict:
115 | state_dict[key] = value
116 | model.load_state_dict(state_dict)
117 | return model
118 |
--------------------------------------------------------------------------------
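The four constructors above repeat one pattern: take the model's own state_dict, overwrite only the keys the model-zoo checkpoint shares (the new detection head 'conv' keeps its fresh initialization), and load the merged dict back. Factored once, as a sketch:

import logging

import torch.utils.model_zoo as model_zoo

def load_pretrained(model, url):
    """Copy matching keys from a model-zoo checkpoint; leave the rest
    (e.g. the detection head) at their fresh initialization."""
    logging.info('use pretrained model: ' + url)
    state_dict = model.state_dict()
    for key, value in model_zoo.load_url(url).items():
        if key in state_dict:
            state_dict[key] = value
    model.load_state_dict(state_dict)
    return model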
/model/inception3.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see <http://www.gnu.org/licenses/>.
16 | """
17 |
18 | import logging
19 |
20 | import scipy.stats as stats
21 | import torch
22 | import torch.nn as nn
23 | import torch.nn.functional as F
24 | import torch.utils.model_zoo
25 | import torchvision.models.inception as _model
26 | from torchvision.models.inception import BasicConv2d, InceptionA, InceptionB, InceptionC, InceptionD, InceptionE
27 |
28 | import model
29 |
30 |
31 | class Inception3(_model.Inception3):
32 | def __init__(self, config_channels, anchors, num_cls, transform_input=False):
33 | nn.Module.__init__(self)
34 | self.transform_input = transform_input
35 | self.Conv2d_1a_3x3 = BasicConv2d(3, 32, kernel_size=3, stride=2)
36 | self.Conv2d_2a_3x3 = BasicConv2d(32, 32, kernel_size=3)
37 | self.Conv2d_2b_3x3 = BasicConv2d(32, 64, kernel_size=3, padding=1)
38 | self.Conv2d_3b_1x1 = BasicConv2d(64, 80, kernel_size=1)
39 | self.Conv2d_4a_3x3 = BasicConv2d(80, 192, kernel_size=3)
40 | self.Mixed_5b = InceptionA(192, pool_features=32)
41 | self.Mixed_5c = InceptionA(256, pool_features=64)
42 | self.Mixed_5d = InceptionA(288, pool_features=64)
43 | self.Mixed_6a = InceptionB(288)
44 | self.Mixed_6b = InceptionC(768, channels_7x7=128)
45 | self.Mixed_6c = InceptionC(768, channels_7x7=160)
46 | self.Mixed_6d = InceptionC(768, channels_7x7=160)
47 | self.Mixed_6e = InceptionC(768, channels_7x7=192)
48 | # aux_logits
49 | self.Mixed_7a = InceptionD(768)
50 | self.Mixed_7b = InceptionE(1280)
51 | self.Mixed_7c = InceptionE(2048)
52 | self.conv = nn.Conv2d(2048, model.output_channels(len(anchors), num_cls), 1)
53 |
54 | for m in self.modules():
55 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
56 | stddev = m.stddev if hasattr(m, 'stddev') else 0.1
57 | X = stats.truncnorm(-2, 2, scale=stddev)
58 | values = torch.Tensor(X.rvs(m.weight.data.numel()))
59 | m.weight.data.copy_(values)
60 | elif isinstance(m, nn.BatchNorm2d):
61 | m.weight.data.fill_(1)
62 | m.bias.data.zero_()
63 |
64 | if config_channels.config.getboolean('model', 'pretrained'):
65 | url = _model.model_urls['inception_v3_google']
66 | logging.info('use pretrained model: ' + url)
67 | state_dict = self.state_dict()
68 | for key, value in torch.utils.model_zoo.load_url(url).items():
69 | if key in state_dict:
70 | state_dict[key] = value
71 | self.load_state_dict(state_dict)
72 |
73 | def forward(self, x):
74 | if self.transform_input:
75 | x = x.clone()
76 | x[:, 0] = x[:, 0] * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
77 | x[:, 1] = x[:, 1] * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
78 | x[:, 2] = x[:, 2] * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
79 | # 299 x 299 x 3
80 | x = self.Conv2d_1a_3x3(x)
81 | # 149 x 149 x 32
82 | x = self.Conv2d_2a_3x3(x)
83 | # 147 x 147 x 32
84 | x = self.Conv2d_2b_3x3(x)
85 | # 147 x 147 x 64
86 | x = F.max_pool2d(x, kernel_size=3, stride=2)
87 | # 73 x 73 x 64
88 | x = self.Conv2d_3b_1x1(x)
89 | # 73 x 73 x 80
90 | x = self.Conv2d_4a_3x3(x)
91 | # 71 x 71 x 192
92 | x = F.max_pool2d(x, kernel_size=3, stride=2)
93 | # 35 x 35 x 192
94 | x = self.Mixed_5b(x)
95 | # 35 x 35 x 256
96 | x = self.Mixed_5c(x)
97 | # 35 x 35 x 288
98 | x = self.Mixed_5d(x)
99 | # 35 x 35 x 288
100 | x = self.Mixed_6a(x)
101 | # 17 x 17 x 768
102 | x = self.Mixed_6b(x)
103 | # 17 x 17 x 768
104 | x = self.Mixed_6c(x)
105 | # 17 x 17 x 768
106 | x = self.Mixed_6d(x)
107 | # 17 x 17 x 768
108 | x = self.Mixed_6e(x)
109 | # 17 x 17 x 768
110 | # aux_logits
111 | # 17 x 17 x 768
112 | x = self.Mixed_7a(x)
113 | # 8 x 8 x 1280
114 | x = self.Mixed_7b(x)
115 | # 8 x 8 x 2048
116 | x = self.Mixed_7c(x)
117 | # 8 x 8 x 2048
118 | return self.conv(x)
119 |
--------------------------------------------------------------------------------
/model/mobilenet.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see <http://www.gnu.org/licenses/>.
16 | """
17 |
18 | import collections
19 |
20 | import torch.nn as nn
21 |
22 | import model
23 |
24 |
25 | def conv_bn(in_channels, out_channels, stride):
26 | return nn.Sequential(collections.OrderedDict([
27 | ('conv', nn.Conv2d(in_channels, out_channels, 3, stride, 1, bias=False)),
28 | ('bn', nn.BatchNorm2d(out_channels)),
29 | ('act', nn.ReLU(inplace=True)),
30 | ]))
31 |
32 |
33 | def conv_dw(in_channels, stride):
34 | return nn.Sequential(collections.OrderedDict([
35 | ('conv', nn.Conv2d(in_channels, in_channels, 3, stride, 1, groups=in_channels, bias=False)),
36 | ('bn', nn.BatchNorm2d(in_channels)),
37 | ('act', nn.ReLU(inplace=True)),
38 | ]))
39 |
40 |
41 | def conv_pw(in_channels, out_channels):
42 | return nn.Sequential(collections.OrderedDict([
43 | ('conv', nn.Conv2d(in_channels, out_channels, 1, 1, 0, bias=False)),
44 | ('bn', nn.BatchNorm2d(out_channels)),
45 | ('act', nn.ReLU(inplace=True)),
46 | ]))
47 |
48 |
49 | def conv_unit(in_channels, out_channels, stride):
50 | return nn.Sequential(collections.OrderedDict([
51 | ('dw', conv_dw(in_channels, stride)),
52 | ('pw', conv_pw(in_channels, out_channels)),
53 | ]))
54 |
55 |
56 | class MobileNet(nn.Module):
57 | def __init__(self, config_channels, anchors, num_cls):
58 | nn.Module.__init__(self)
59 | layers = []
60 | layers.append(conv_bn(config_channels.channels, config_channels(32, 'layers.%d.conv.weight' % len(layers)), 2))
61 | layers.append(conv_unit(config_channels.channels, config_channels(64, 'layers.%d.pw.conv.weight' % len(layers)), 1))
62 | layers.append(conv_unit(config_channels.channels, config_channels(128, 'layers.%d.pw.conv.weight' % len(layers)), 2))
63 | layers.append(conv_unit(config_channels.channels, config_channels(128, 'layers.%d.pw.conv.weight' % len(layers)), 1))
64 | layers.append(conv_unit(config_channels.channels, config_channels(256, 'layers.%d.pw.conv.weight' % len(layers)), 2))
65 | layers.append(conv_unit(config_channels.channels, config_channels(256, 'layers.%d.pw.conv.weight' % len(layers)), 1))
66 | layers.append(conv_unit(config_channels.channels, config_channels(512, 'layers.%d.pw.conv.weight' % len(layers)), 2))
67 | layers.append(conv_unit(config_channels.channels, config_channels(512, 'layers.%d.pw.conv.weight' % len(layers)), 1))
68 | layers.append(conv_unit(config_channels.channels, config_channels(512, 'layers.%d.pw.conv.weight' % len(layers)), 1))
69 | layers.append(conv_unit(config_channels.channels, config_channels(512, 'layers.%d.pw.conv.weight' % len(layers)), 1))
70 | layers.append(conv_unit(config_channels.channels, config_channels(512, 'layers.%d.pw.conv.weight' % len(layers)), 1))
71 | layers.append(conv_unit(config_channels.channels, config_channels(512, 'layers.%d.pw.conv.weight' % len(layers)), 1))
72 | layers.append(conv_unit(config_channels.channels, config_channels(1024, 'layers.%d.pw.conv.weight' % len(layers)), 2))
73 | layers.append(conv_unit(config_channels.channels, config_channels(1024, 'layers.%d.pw.conv.weight' % len(layers)), 1))
74 | layers.append(nn.Conv2d(config_channels.channels, model.output_channels(len(anchors), num_cls), 1))
75 | self.layers = nn.Sequential(*layers)
76 |
77 | for m in self.modules():
78 | if isinstance(m, nn.Conv2d):
79 | m.weight = nn.init.kaiming_normal(m.weight)
80 | elif isinstance(m, nn.BatchNorm2d):
81 | m.weight.data.fill_(1)
82 | m.bias.data.zero_()
83 |
84 | def forward(self, x):
85 | return self.layers(x)
86 |
--------------------------------------------------------------------------------
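conv_unit is the depthwise-separable building block: a 3x3 depthwise convolution (groups equal to channels) followed by a 1x1 pointwise convolution. At these widths that cuts parameters roughly 8-9x versus a full 3x3 convolution; a quick count:

def params_standard(c_in, c_out, k=3):
    return c_in * c_out * k * k

def params_separable(c_in, c_out, k=3):
    return c_in * k * k + c_in * c_out  # depthwise + pointwise

print(params_standard(256, 256))   # 589824
print(params_separable(256, 256))  # 67840, ~8.7x fewer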
/model/resnet.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see <http://www.gnu.org/licenses/>.
16 | """
17 |
18 | import logging
19 | import re
20 |
21 | import torch.nn as nn
22 | import torch.utils.model_zoo as model_zoo
23 | import torchvision.models.resnet as _model
24 | from torchvision.models.resnet import conv3x3
25 |
26 | import model
27 |
28 |
29 | class BasicBlock(nn.Module):
30 | def __init__(self, config_channels, prefix, channels, stride=1):
31 | nn.Module.__init__(self)
32 | channels_in = config_channels.channels
33 | self.conv1 = conv3x3(config_channels.channels, config_channels(channels, '%s.conv1.weight' % prefix), stride)
34 | self.bn1 = nn.BatchNorm2d(config_channels.channels)
35 | self.relu = nn.ReLU(inplace=True)
36 | self.conv2 = conv3x3(config_channels.channels, config_channels(channels, '%s.conv2.weight' % prefix))
37 | self.bn2 = nn.BatchNorm2d(config_channels.channels)
38 | if stride > 1 or channels_in != config_channels.channels:
39 | downsample = []
40 | downsample.append(nn.Conv2d(channels_in, config_channels.channels, kernel_size=1, stride=stride, bias=False))
41 | downsample.append(nn.BatchNorm2d(config_channels.channels))
42 | self.downsample = nn.Sequential(*downsample)
43 | else:
44 | self.downsample = None
45 |
46 | def forward(self, x):
47 | residual = x
48 |
49 | out = self.conv1(x)
50 | out = self.bn1(out)
51 | out = self.relu(out)
52 |
53 | out = self.conv2(out)
54 | out = self.bn2(out)
55 |
56 | if self.downsample is not None:
57 | residual = self.downsample(x)
58 |
59 | out += residual
60 | out = self.relu(out)
61 |
62 | return out
63 |
64 |
65 | class Bottleneck(nn.Module):
66 | def __init__(self, config_channels, prefix, channels, stride=1):
67 | nn.Module.__init__(self)
68 | channels_in = config_channels.channels
69 | self.conv1 = nn.Conv2d(config_channels.channels, config_channels(channels, '%s.conv1.weight' % prefix), kernel_size=1, bias=False)
70 | self.bn1 = nn.BatchNorm2d(config_channels.channels)
71 | self.conv2 = nn.Conv2d(config_channels.channels, config_channels(channels, '%s.conv2.weight' % prefix), kernel_size=3, stride=stride, padding=1, bias=False)
72 | self.bn2 = nn.BatchNorm2d(config_channels.channels)
73 | self.conv3 = nn.Conv2d(config_channels.channels, config_channels(channels * 4, '%s.conv3.weight' % prefix), kernel_size=1, bias=False)
74 | self.bn3 = nn.BatchNorm2d(config_channels.channels)
75 | self.relu = nn.ReLU(inplace=True)
76 | if stride > 1 or channels_in != config_channels.channels:
77 | downsample = []
78 | downsample.append(nn.Conv2d(channels_in, config_channels.channels, kernel_size=1, stride=stride, bias=False))
79 | downsample.append(nn.BatchNorm2d(config_channels.channels))
80 | self.downsample = nn.Sequential(*downsample)
81 | else:
82 | self.downsample = None
83 |
84 | def forward(self, x):
85 | residual = x
86 |
87 | out = self.conv1(x)
88 | out = self.bn1(out)
89 | out = self.relu(out)
90 |
91 | out = self.conv2(out)
92 | out = self.bn2(out)
93 | out = self.relu(out)
94 |
95 | out = self.conv3(out)
96 | out = self.bn3(out)
97 |
98 | if self.downsample is not None:
99 | residual = self.downsample(x)
100 |
101 | out += residual
102 | out = self.relu(out)
103 |
104 | return out
105 |
106 |
107 | class ResNet(_model.ResNet):
108 | def __init__(self, config_channels, anchors, num_cls, block, layers):
109 | nn.Module.__init__(self)
110 | self.conv1 = nn.Conv2d(config_channels.channels, config_channels(64, 'conv1.weight'), kernel_size=7, stride=2, padding=3, bias=False)
111 | self.bn1 = nn.BatchNorm2d(config_channels.channels)
112 | self.relu = nn.ReLU(inplace=True)
113 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
114 | self.layer1 = self._make_layer(config_channels, 'layer1', block, 64, layers[0])
115 | self.layer2 = self._make_layer(config_channels, 'layer2', block, 128, layers[1], stride=2)
116 | self.layer3 = self._make_layer(config_channels, 'layer3', block, 256, layers[2], stride=2)
117 | self.layer4 = self._make_layer(config_channels, 'layer4', block, 512, layers[3], stride=2)
118 | self.conv = nn.Conv2d(config_channels.channels, model.output_channels(len(anchors), num_cls), 1)
119 |
120 | for m in self.modules():
121 | if isinstance(m, nn.Conv2d):
122 | m.weight = nn.init.kaiming_normal(m.weight)
123 | elif isinstance(m, nn.BatchNorm2d):
124 | m.weight.data.fill_(1)
125 | m.bias.data.zero_()
126 |
127 | def _make_layer(self, config_channels, prefix, block, channels, blocks, stride=1):
128 | layers = []
129 | layers.append(block(config_channels, '%s.%d' % (prefix, len(layers)), channels, stride))
130 | for i in range(1, blocks):
131 | layers.append(block(config_channels, '%s.%d' % (prefix, len(layers)), channels))
132 | return nn.Sequential(*layers)
133 |
134 | def forward(self, x):
135 | x = self.conv1(x)
136 | x = self.bn1(x)
137 | x = self.relu(x)
138 | x = self.maxpool(x)
139 |
140 | x = self.layer1(x)
141 | x = self.layer2(x)
142 | x = self.layer3(x)
143 | x = self.layer4(x)
144 |
145 | return self.conv(x)
146 |
147 | def scope(self, name):
148 | comp = name.split('.')[:-1]
149 | try:
150 |             comp[-1] = re.search(r'(?:conv|bn)(\d+)', comp[-1]).group(1)
151 | except AttributeError:
152 | if len(comp) > 1:
153 | if comp[-2] == 'downsample':
154 | comp = comp[:-1]
155 | else:
156 | assert False, name
157 | else:
158 | assert comp[-1] == 'conv', name
159 | return '.'.join(comp)
160 |
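For reference, scope() maps a parameter name to a scope string in which a conv and its batch norm coincide; a few illustrative inputs (parameter names following torchvision's ResNet layout):

    # scope('layer1.0.conv1.weight')        -> 'layer1.0.1'  (conv1 and bn1 share a scope)
    # scope('layer1.0.downsample.0.weight') -> 'layer1.0.downsample'
    # scope('conv.weight')                  -> 'conv'        (the detection head)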
161 |
162 | def resnet18(config_channels, anchors, num_cls, **kwargs):
163 | model = ResNet(config_channels, anchors, num_cls, BasicBlock, [2, 2, 2, 2], **kwargs)
164 | if config_channels.config.getboolean('model', 'pretrained'):
165 | url = _model.model_urls['resnet18']
166 | logging.info('use pretrained model: ' + url)
167 | state_dict = model.state_dict()
168 | for key, value in model_zoo.load_url(url).items():
169 | if key in state_dict:
170 | state_dict[key] = value
171 | model.load_state_dict(state_dict)
172 | return model
173 |
174 |
175 | def resnet34(config_channels, anchors, num_cls, **kwargs):
176 | model = ResNet(config_channels, anchors, num_cls, BasicBlock, [3, 4, 6, 3], **kwargs)
177 | if config_channels.config.getboolean('model', 'pretrained'):
178 | url = _model.model_urls['resnet34']
179 | logging.info('use pretrained model: ' + url)
180 | state_dict = model.state_dict()
181 | for key, value in model_zoo.load_url(url).items():
182 | if key in state_dict:
183 | state_dict[key] = value
184 | model.load_state_dict(state_dict)
185 | return model
186 |
187 |
188 | def resnet50(config_channels, anchors, num_cls, **kwargs):
189 | model = ResNet(config_channels, anchors, num_cls, Bottleneck, [3, 4, 6, 3], **kwargs)
190 | if config_channels.config.getboolean('model', 'pretrained'):
191 | url = _model.model_urls['resnet50']
192 | logging.info('use pretrained model: ' + url)
193 | state_dict = model.state_dict()
194 | for key, value in model_zoo.load_url(url).items():
195 | if key in state_dict:
196 | state_dict[key] = value
197 | model.load_state_dict(state_dict)
198 | return model
199 |
200 |
201 | def resnet101(config_channels, anchors, num_cls, **kwargs):
202 | model = ResNet(config_channels, anchors, num_cls, Bottleneck, [3, 4, 23, 3], **kwargs)
203 | if config_channels.config.getboolean('model', 'pretrained'):
204 | url = _model.model_urls['resnet101']
205 | logging.info('use pretrained model: ' + url)
206 | state_dict = model.state_dict()
207 | for key, value in model_zoo.load_url(url).items():
208 | if key in state_dict:
209 | state_dict[key] = value
210 | model.load_state_dict(state_dict)
211 | return model
212 |
213 |
214 | def resnet152(config_channels, anchors, num_cls, **kwargs):
215 | model = ResNet(config_channels, anchors, num_cls, Bottleneck, [3, 8, 36, 3], **kwargs)
216 | if config_channels.config.getboolean('model', 'pretrained'):
217 | url = _model.model_urls['resnet152']
218 | logging.info('use pretrained model: ' + url)
219 | state_dict = model.state_dict()
220 | for key, value in model_zoo.load_url(url).items():
221 | if key in state_dict:
222 | state_dict[key] = value
223 | model.load_state_dict(state_dict)
224 | return model
225 |
--------------------------------------------------------------------------------
/model/vgg.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see <http://www.gnu.org/licenses/>.
16 | """
17 |
18 | import logging
19 |
20 | import torch.nn as nn
21 | import torch.utils.model_zoo as model_zoo
22 | import torchvision.models.vgg as _model
23 | from torchvision.models.vgg import model_urls, cfg
24 |
25 | import model
26 |
27 |
28 | class VGG(_model.VGG):
29 | def __init__(self, config_channels, anchors, num_cls, features):
30 | nn.Module.__init__(self)
31 | self.features = features
32 | self.conv = nn.Conv2d(config_channels.channels, model.output_channels(len(anchors), num_cls), 1)
33 | self._initialize_weights()
34 |
35 | def forward(self, x):
36 | x = self.features(x)
37 | return self.conv(x)
38 |
39 |
40 | def make_layers(config_channels, cfg, batch_norm=False):
41 | features = []
42 | for v in cfg:
43 | if v == 'M':
44 | features += [nn.MaxPool2d(kernel_size=2, stride=2)]
45 | else:
46 | conv2d = nn.Conv2d(config_channels.channels, config_channels(v, 'features.%d.weight' % len(features)), kernel_size=3, padding=1)
47 | if batch_norm:
48 | features += [conv2d, nn.BatchNorm2d(config_channels.channels), nn.ReLU(inplace=True)]
49 | else:
50 | features += [conv2d, nn.ReLU(inplace=True)]
51 | return nn.Sequential(*features)
52 |
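For orientation: the cfg lists imported from torchvision describe an architecture as a flat sequence in which an integer is a 3x3 convolution with that many output channels and 'M' is a 2x2 max-pool. For example, cfg['A'] (VGG-11) is

    # [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M']

so make_layers builds eight convolutions and five pools for it.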
53 |
54 | def vgg11(config_channels, anchors, num_cls):
55 | model = VGG(config_channels, anchors, num_cls, make_layers(config_channels, cfg['A']))
56 | if config_channels.config.getboolean('model', 'pretrained'):
57 | url = model_urls['vgg11']
58 | logging.info('use pretrained model: ' + url)
59 | state_dict = model.state_dict()
60 | for key, value in model_zoo.load_url(url).items():
61 | if key in state_dict:
62 | state_dict[key] = value
63 | model.load_state_dict(state_dict)
64 | return model
65 |
66 |
67 | def vgg11_bn(config_channels, anchors, num_cls):
68 | model = VGG(config_channels, anchors, num_cls, make_layers(config_channels, cfg['A'], batch_norm=True))
69 | if config_channels.config.getboolean('model', 'pretrained'):
70 | url = model_urls['vgg11_bn']
71 | logging.info('use pretrained model: ' + url)
72 | state_dict = model.state_dict()
73 | for key, value in model_zoo.load_url(url).items():
74 | if key in state_dict:
75 | state_dict[key] = value
76 | model.load_state_dict(state_dict)
77 | return model
78 |
79 |
80 | def vgg13(config_channels, anchors, num_cls):
81 | model = VGG(config_channels, anchors, num_cls, make_layers(config_channels, cfg['B']))
82 | if config_channels.config.getboolean('model', 'pretrained'):
83 | url = model_urls['vgg13']
84 | logging.info('use pretrained model: ' + url)
85 | state_dict = model.state_dict()
86 | for key, value in model_zoo.load_url(url).items():
87 | if key in state_dict:
88 | state_dict[key] = value
89 | model.load_state_dict(state_dict)
90 | return model
91 |
92 |
93 | def vgg13_bn(config_channels, anchors, num_cls):
94 | model = VGG(config_channels, anchors, num_cls, make_layers(config_channels, cfg['B'], batch_norm=True))
95 | if config_channels.config.getboolean('model', 'pretrained'):
96 | url = model_urls['vgg13_bn']
97 | logging.info('use pretrained model: ' + url)
98 | state_dict = model.state_dict()
99 | for key, value in model_zoo.load_url(url).items():
100 | if key in state_dict:
101 | state_dict[key] = value
102 | model.load_state_dict(state_dict)
103 | return model
104 |
105 |
106 | def vgg16(config_channels, anchors, num_cls):
107 | model = VGG(config_channels, anchors, num_cls, make_layers(config_channels, cfg['D']))
108 | if config_channels.config.getboolean('model', 'pretrained'):
109 | url = model_urls['vgg16']
110 | logging.info('use pretrained model: ' + url)
111 | state_dict = model.state_dict()
112 | for key, value in model_zoo.load_url(url).items():
113 | if key in state_dict:
114 | state_dict[key] = value
115 | model.load_state_dict(state_dict)
116 | return model
117 |
118 |
119 | def vgg16_bn(config_channels, anchors, num_cls):
120 | model = VGG(config_channels, anchors, num_cls, make_layers(config_channels, cfg['D'], batch_norm=True))
121 | if config_channels.config.getboolean('model', 'pretrained'):
122 | url = model_urls['vgg16_bn']
123 | logging.info('use pretrained model: ' + url)
124 | state_dict = model.state_dict()
125 | for key, value in model_zoo.load_url(url).items():
126 | if key in state_dict:
127 | state_dict[key] = value
128 | model.load_state_dict(state_dict)
129 | return model
130 |
131 |
132 | def vgg19(config_channels, anchors, num_cls):
133 | model = VGG(config_channels, anchors, num_cls, make_layers(config_channels, cfg['E']))
134 | if config_channels.config.getboolean('model', 'pretrained'):
135 | url = model_urls['vgg19']
136 | logging.info('use pretrained model: ' + url)
137 | state_dict = model.state_dict()
138 | for key, value in model_zoo.load_url(url).items():
139 | if key in state_dict:
140 | state_dict[key] = value
141 | model.load_state_dict(state_dict)
142 | return model
143 |
144 |
145 | def vgg19_bn(config_channels, anchors, num_cls):
146 | model = VGG(config_channels, anchors, num_cls, make_layers(config_channels, cfg['E'], batch_norm=True))
147 | if config_channels.config.getboolean('model', 'pretrained'):
148 | url = model_urls['vgg19_bn']
149 | logging.info('use pretrained model: ' + url)
150 | state_dict = model.state_dict()
151 | for key, value in model_zoo.load_url(url).items():
152 | if key in state_dict:
153 | state_dict[key] = value
154 | model.load_state_dict(state_dict)
155 | return model
156 |
--------------------------------------------------------------------------------
/model/yolo2.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see <http://www.gnu.org/licenses/>.
16 | """
17 |
18 | import collections.abc
19 |
20 | import numpy as np
21 | import torch
22 | import torch.nn as nn
23 | import torch.autograd
24 |
25 | import model
26 |
27 |
28 | settings = {
29 | 'size': (416, 416),
30 | }
31 |
32 |
33 | def reorg(x, stride_h=2, stride_w=2):
34 | batch_size, channels, height, width = x.size()
35 | _height, _width = height // stride_h, width // stride_w
36 |     if True:  # transpose-based implementation
37 | x = x.view(batch_size, channels, _height, stride_h, _width, stride_w).transpose(3, 4).contiguous()
38 | x = x.view(batch_size, channels, _height * _width, stride_h * stride_w).transpose(2, 3).contiguous()
39 | x = x.view(batch_size, channels, stride_h * stride_w, _height, _width).transpose(1, 2).contiguous()
40 | x = x.view(batch_size, -1, _height, _width)
41 |     else:  # equivalent 6-D permute implementation (kept for reference, unreachable)
42 | x = x.view(batch_size, channels, _height, stride_h, _width, stride_w)
43 | x = x.permute(0, 1, 3, 5, 2, 4) # batch_size, channels, stride, stride, _height, _width
44 | x = x.contiguous()
45 | x = x.view(batch_size, -1, _height, _width)
46 | return x
47 |
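A quick sanity check of reorg's space-to-depth behaviour (the shapes are illustrative):

    import torch
    from model.yolo2 import reorg

    x = torch.randn(1, 4, 6, 6)                    # (N, C, H, W)
    assert reorg(x, 2, 2).size() == (1, 16, 3, 3)  # (N, C*4, H/2, W/2)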
48 |
49 | class Conv2d(nn.Module):
50 | def __init__(self, in_channels, out_channels, kernel_size, padding=0, stride=1, bn=True, act=True):
51 | nn.Module.__init__(self)
52 | if isinstance(padding, bool):
53 | if isinstance(kernel_size, collections.abc.Iterable):
54 | padding = tuple((kernel_size - 1) // 2 for kernel_size in kernel_size) if padding else 0
55 | else:
56 | padding = (kernel_size - 1) // 2 if padding else 0
57 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding=padding, bias=not bn)
58 | self.bn = nn.BatchNorm2d(out_channels, momentum=0.01) if bn else lambda x: x
59 | self.act = nn.LeakyReLU(0.1, inplace=True) if act else lambda x: x
60 |
61 | def forward(self, x):
62 | x = self.conv(x)
63 | x = self.bn(x)
64 | x = self.act(x)
65 | return x
66 |
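A minimal sketch of the padding=True convenience (illustrative shapes): passing a boolean derives "same" padding from the kernel size, so a stride-1 convolution preserves resolution:

    import torch
    import torch.autograd
    from model.yolo2 import Conv2d

    m = Conv2d(3, 16, 3, padding=True)  # padding becomes (3 - 1) // 2 = 1
    x = torch.autograd.Variable(torch.randn(1, 3, 32, 32))
    assert m(x).size() == (1, 16, 32, 32)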
67 |
68 | class Darknet(nn.Module):
69 | def __init__(self, config_channels, anchors, num_cls, stride=2, ratio=1):
70 | nn.Module.__init__(self)
71 | self.stride = stride
72 | channels = int(32 * ratio)
73 | layers = []
74 |
75 | bn = config_channels.config.getboolean('batch_norm', 'enable')
76 | # layers1
77 | for _ in range(2):
78 | layers.append(Conv2d(config_channels.channels, config_channels(channels, 'layers1.%d.conv.weight' % len(layers)), 3, bn=bn, padding=True))
79 | layers.append(nn.MaxPool2d(kernel_size=2))
80 | channels *= 2
81 | # down 4
82 | for _ in range(2):
83 | layers.append(Conv2d(config_channels.channels, config_channels(channels, 'layers1.%d.conv.weight' % len(layers)), 3, bn=bn, padding=True))
84 | layers.append(Conv2d(config_channels.channels, config_channels(channels // 2, 'layers1.%d.conv.weight' % len(layers)), 1, bn=bn))
85 | layers.append(Conv2d(config_channels.channels, config_channels(channels, 'layers1.%d.conv.weight' % len(layers)), 3, bn=bn, padding=True))
86 | layers.append(nn.MaxPool2d(kernel_size=2))
87 | channels *= 2
88 | # down 16
89 | for _ in range(2):
90 | layers.append(Conv2d(config_channels.channels, config_channels(channels, 'layers1.%d.conv.weight' % len(layers)), 3, bn=bn, padding=True))
91 | layers.append(Conv2d(config_channels.channels, config_channels(channels // 2, 'layers1.%d.conv.weight' % len(layers)), 1, bn=bn))
92 | layers.append(Conv2d(config_channels.channels, config_channels(channels, 'layers1.%d.conv.weight' % len(layers)), 3, bn=bn, padding=True))
93 | self.layers1 = nn.Sequential(*layers)
94 |
95 | # layers2
96 | layers = []
97 | layers.append(nn.MaxPool2d(kernel_size=2))
98 | channels *= 2
99 | # down 32
100 | for _ in range(2):
101 | layers.append(Conv2d(config_channels.channels, config_channels(channels, 'layers2.%d.conv.weight' % len(layers)), 3, bn=bn, padding=True))
102 | layers.append(Conv2d(config_channels.channels, config_channels(channels // 2, 'layers2.%d.conv.weight' % len(layers)), 1, bn=bn))
103 | for _ in range(3):
104 | layers.append(Conv2d(config_channels.channels, config_channels(channels, 'layers2.%d.conv.weight' % len(layers)), 3, bn=bn, padding=True))
105 | self.layers2 = nn.Sequential(*layers)
106 |
107 | self.passthrough = Conv2d(self.layers1[-1].conv.weight.size(0), config_channels(int(64 * ratio), 'passthrough.conv.weight'), 1, bn=bn)
108 |
109 | # layers3
110 | layers = []
111 | layers.append(Conv2d(self.passthrough.conv.weight.size(0) * self.stride * self.stride + self.layers2[-1].conv.weight.size(0), config_channels(int(1024 * ratio), 'layers3.%d.conv.weight' % len(layers)), 3, bn=bn, padding=True))
112 | layers.append(Conv2d(config_channels.channels, model.output_channels(len(anchors), num_cls), 1, bn=False, act=False))
113 | self.layers3 = nn.Sequential(*layers)
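        # Note: the passthrough branch carries int(64 * ratio) channels, and
        # reorg with stride 2 multiplies channels by stride * stride = 4, so
        # the first conv of layers3 sees 64 * 4 + 1024 = 1280 input channels
        # at the default ratio of 1 (exactly what the in_channels expression
        # for layers3's first conv computes).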
114 |
115 | self.init()
116 |
117 | def init(self):
118 | for m in self.modules():
119 | if isinstance(m, nn.Conv2d):
120 | m.weight = nn.init.kaiming_normal(m.weight)
121 | elif isinstance(m, nn.BatchNorm2d):
122 | m.weight.data.fill_(1)
123 | m.bias.data.zero_()
124 |
125 | def forward(self, x):
126 | x = self.layers1(x)
127 |         _x = reorg(self.passthrough(x), self.stride, self.stride)
128 | x = self.layers2(x)
129 | x = torch.cat([_x, x], 1)
130 | return self.layers3(x)
131 |
132 | def scope(self, name):
133 | return '.'.join(name.split('.')[:-2])
134 |
135 | def get_mapper(self, index):
136 | if index == 94:
137 | return lambda indices, channels: torch.cat([indices + i * channels for i in range(self.stride * self.stride)])
138 |
139 |
140 | class Tiny(nn.Module):
141 | def __init__(self, config_channels, anchors, num_cls, channels=16):
142 | nn.Module.__init__(self)
143 | layers = []
144 |
145 | bn = config_channels.config.getboolean('batch_norm', 'enable')
146 | for _ in range(5):
147 | layers.append(Conv2d(config_channels.channels, config_channels(channels, 'layers.%d.conv.weight' % len(layers)), 3, bn=bn, padding=True))
148 | layers.append(nn.MaxPool2d(kernel_size=2))
149 | channels *= 2
150 | layers.append(Conv2d(config_channels.channels, config_channels(channels, 'layers.%d.conv.weight' % len(layers)), 3, bn=bn, padding=True))
151 | layers.append(nn.ConstantPad2d((0, 1, 0, 1), float(np.finfo(np.float32).min)))
152 | layers.append(nn.MaxPool2d(kernel_size=2, stride=1))
153 | channels *= 2
154 | for _ in range(2):
155 | layers.append(Conv2d(config_channels.channels, config_channels(channels, 'layers.%d.conv.weight' % len(layers)), 3, bn=bn, padding=True))
156 | layers.append(Conv2d(config_channels.channels, model.output_channels(len(anchors), num_cls), 1, bn=False, act=False))
157 | self.layers = nn.Sequential(*layers)
158 |
159 | self.init()
160 |
161 | def init(self):
162 | for m in self.modules():
163 | if isinstance(m, nn.Conv2d):
164 | m.weight = nn.init.xavier_normal(m.weight)
165 | elif isinstance(m, nn.BatchNorm2d):
166 | m.weight.data.fill_(1)
167 | m.bias.data.zero_()
168 |
169 | def forward(self, x):
170 | return self.layers(x)
171 |
172 | def scope(self, name):
173 | return '.'.join(name.split('.')[:-2])
174 |
--------------------------------------------------------------------------------
/pruner.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see <http://www.gnu.org/licenses/>.
16 | """
17 |
18 | import os
19 | import argparse
20 | import configparser
21 | import logging
22 | import logging.config
23 | import yaml
24 |
25 | import numpy as np
26 | import torch.autograd
27 | import torch.cuda
28 | import torch.optim
29 | import torch.utils.data
30 | import humanize
31 |
32 | import model
33 | import utils
34 | import utils.train
35 | import utils.channel
36 |
37 |
38 | def main():
39 | args = make_args()
40 | config = configparser.ConfigParser()
41 | utils.load_config(config, args.config)
42 | for cmd in args.modify:
43 | utils.modify_config(config, cmd)
44 | with open(os.path.expanduser(os.path.expandvars(args.logging)), 'r') as f:
45 | logging.config.dictConfig(yaml.load(f))
46 | model_dir = utils.get_model_dir(config)
47 | category = utils.get_category(config)
48 | anchors = torch.from_numpy(utils.get_anchors(config)).contiguous()
49 | path, step, epoch = utils.train.load_model(model_dir)
50 | state_dict = torch.load(path, map_location=lambda storage, loc: storage)
51 | _model = utils.parse_attr(config.get('model', 'dnn'))
52 | dnn = _model(model.ConfigChannels(config, state_dict), anchors, len(category))
53 | logging.info(humanize.naturalsize(sum(var.cpu().numpy().nbytes for var in dnn.state_dict().values())))
54 | dnn.load_state_dict(state_dict)
55 | height, width = tuple(map(int, config.get('image', 'size').split()))
56 | image = torch.autograd.Variable(torch.randn(args.batch_size, 3, height, width))
57 | output = dnn(image)
58 | state_dict = dnn.state_dict()
59 | d = utils.dense(state_dict[args.name])
60 | keep = torch.LongTensor(np.argsort(d)[:int(len(d) * args.keep)])
61 | modifier = utils.channel.Modifier(
62 | args.name, state_dict, dnn,
63 | lambda name, var: var[keep],
64 | lambda name, var, mapper: var[mapper(keep, len(d))],
65 | debug=args.debug,
66 | )
67 | modifier(output.grad_fn)
68 | if args.debug:
69 | path = modifier.dot.view('%s.%s.gv' % (os.path.basename(model_dir), os.path.basename(os.path.splitext(__file__)[0])), os.path.dirname(model_dir))
70 | logging.info(path)
71 | assert len(keep) == len(state_dict[args.name])
72 | dnn = _model(model.ConfigChannels(config, state_dict), anchors, len(category))
73 | dnn.load_state_dict(state_dict)
74 | dnn(image)
75 | if not args.debug:
76 | torch.save(state_dict, path)
77 |
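# Example invocation (the layer name is hypothetical; any conv weight key from
# the model's state_dict works):
#   python3 pruner.py layers3.0.conv.weight 0.5
# keeps 50% of that conv's output channels and propagates the change through
# dependent layers; run with -d first to preview the dependency graph without
# overwriting the checkpoint.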
78 |
79 | def make_args():
80 | parser = argparse.ArgumentParser()
81 | parser.add_argument('name')
82 | parser.add_argument('keep', type=float)
83 | parser.add_argument('-c', '--config', nargs='+', default=['config.ini'], help='config file')
84 | parser.add_argument('-m', '--modify', nargs='+', default=[], help='modify config')
85 | parser.add_argument('-b', '--batch_size', default=1, type=int, help='batch size')
86 | parser.add_argument('-d', '--debug', action='store_true')
87 | parser.add_argument('--logging', default='logging.yml', help='logging config')
88 | return parser.parse_args()
89 |
90 |
91 | if __name__ == '__main__':
92 | main()
--------------------------------------------------------------------------------
/quick_start.sh:
--------------------------------------------------------------------------------
1 | echo download VOC dataset
2 | LINKS="
3 | http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar
4 | http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar
5 | http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar
6 | "
7 | ROOT=~/data
8 | for LINK in $LINKS
9 | do
10 | aria2c --auto-file-renaming=false -d $ROOT $LINK
11 | tar -kxvf $ROOT/$(basename $LINK) -C $ROOT
12 | done
13 |
14 | echo download COCO dataset
15 | LINKS="
16 | http://images.cocodataset.org/zips/train2014.zip
17 | http://images.cocodataset.org/zips/val2014.zip
18 | http://images.cocodataset.org/annotations/annotations_trainval2014.zip
19 | http://images.cocodataset.org/zips/train2017.zip
20 | http://images.cocodataset.org/zips/val2017.zip
21 | http://images.cocodataset.org/annotations/annotations_trainval2017.zip
22 | "
23 | ROOT=~/data/coco
24 | for LINK in $LINKS
25 | do
26 | aria2c --auto-file-renaming=false -d $ROOT $LINK
27 | unzip -n $ROOT/$(basename $LINK) -d $ROOT
28 | done
29 | rm $ROOT/val2014/COCO_val2014_000000320612.jpg
30 |
31 | echo cache data
32 | python3 cache.py -m cache/datasets=cache.voc.cache cache/name=cache_voc cache/category=config/category/20
33 | python3 cache.py -m cache/datasets=cache.coco.cache cache/name=cache_coco cache/category=config/category/80
34 | python3 cache.py -m cache/datasets='cache.voc.cache cache.coco.cache' cache/name=cache_20 cache/category=config/category/20
35 |
36 | ROOT=~/model/darknet
37 |
38 | echo test VOC models
39 | MODELS="
40 | yolo-voc
41 | tiny-yolo-voc
42 | "
43 |
44 | for MODEL in $MODELS
45 | do
46 | aria2c --auto-file-renaming=false -d $ROOT http://pjreddie.com/media/files/$MODEL.weights
47 | python3 convert_darknet_torch.py ~/model/darknet/$MODEL.weights -c config.ini config/darknet/$MODEL.ini -d
48 | python3 eval.py -c config.ini config/darknet/$MODEL.ini
49 | python3 detect.py -c config.ini config/darknet/$MODEL.ini -i image.jpg --pause
50 | done
51 |
52 | echo test COCO models
53 | MODELS="
54 | yolo
55 | "
56 |
57 | for MODEL in $MODELS
58 | do
59 | aria2c --auto-file-renaming=false -d $ROOT http://pjreddie.com/media/files/$MODEL.weights
60 | python3 convert_darknet_torch.py ~/model/darknet/$MODEL.weights -c config.ini config/darknet/$MODEL.ini -d
61 | python3 eval.py -c config.ini config/darknet/$MODEL.ini
62 | python3 detect.py -c config.ini config/darknet/$MODEL.ini -i image.jpg --pause
63 | done
64 |
65 | echo convert pretrained Darknet model
66 | aria2c --auto-file-renaming=false -d $ROOT http://pjreddie.com/media/files/darknet19_448.conv.23
67 | python3 convert_darknet_torch.py ~/model/darknet/darknet19_448.conv.23 -m model/name=model_voc model/dnn=model.yolo2.Darknet -d --copy ~/model/darknet/darknet19_448.conv.23.pth
68 |
69 | echo reproduce the training results
70 | export CACHE_NAME=cache_voc MODEL_NAME=model_voc MODEL=model.yolo2.Darknet
71 | python3 train.py -b 64 -lr 1e-3 -e 160 -m cache/name=$CACHE_NAME model/name=$MODEL_NAME model/dnn=$MODEL train/optimizer='lambda params, lr: torch.optim.SGD(params, lr, momentum=0.9)' train/scheduler='lambda optimizer: torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[60, 90], gamma=0.1)' -f ~/model/darknet/darknet19_448.conv.23.pth -d
72 | python3 eval.py -m cache/name=$CACHE_NAME model/name=$MODEL_NAME model/dnn=$MODEL
73 |
--------------------------------------------------------------------------------
/receptive_field_analyzer.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see <http://www.gnu.org/licenses/>.
16 | """
17 |
18 | import os
19 | import argparse
20 | import configparser
21 | import logging
22 | import logging.config
23 | import multiprocessing
24 | import yaml
25 |
26 | import numpy as np
27 | import scipy.misc
28 | import torch.autograd
29 | import torch.cuda
30 | import torch.optim
31 | import torch.utils.data
32 | import tqdm
33 | import humanize
34 |
35 | import model
36 | import utils.data
37 | import utils.iou.torch
38 | import utils.postprocess
39 | import utils.train
40 | import utils.visualize
41 |
42 |
43 | class Dataset(torch.utils.data.Dataset):
44 | def __init__(self, height, width):
45 | self.points = np.array([(i, j) for i in range(height) for j in range(width)])
46 |
47 | def __len__(self):
48 | return len(self.points)
49 |
50 | def __getitem__(self, index):
51 | return self.points[index]
52 |
53 |
54 | class Analyzer(object):
55 | def __init__(self, args, config):
56 | self.args = args
57 | self.config = config
58 | self.model_dir = utils.get_model_dir(config)
59 | self.category = utils.get_category(config)
60 | self.anchors = torch.from_numpy(utils.get_anchors(config)).contiguous()
61 | self.dnn = utils.parse_attr(config.get('model', 'dnn'))(model.ConfigChannels(config), self.anchors, len(self.category))
62 | self.dnn.eval()
63 | logging.info(humanize.naturalsize(sum(var.cpu().numpy().nbytes for var in self.dnn.state_dict().values())))
64 | if torch.cuda.is_available():
65 | self.dnn.cuda()
66 | self.height, self.width = tuple(map(int, config.get('image', 'size').split()))
67 | output = self.dnn(torch.autograd.Variable(utils.ensure_device(torch.zeros(1, 3, self.height, self.width)), volatile=True))
68 | _, _, self.rows, self.cols = output.size()
69 | self.i, self.j = self.rows // 2, self.cols // 2
70 | self.output = output[:, :, self.i, self.j]
71 | dataset = Dataset(self.height, self.width)
72 | try:
73 | workers = self.config.getint('data', 'workers')
74 | except configparser.NoOptionError:
75 | workers = multiprocessing.cpu_count()
76 | self.loader = torch.utils.data.DataLoader(dataset, batch_size=self.args.batch_size, num_workers=workers)
77 |
78 | def __call__(self):
79 | changed = np.zeros([self.height, self.width], np.bool)
80 | for yx in tqdm.tqdm(self.loader):
81 | batch_size = yx.size(0)
82 | tensor = torch.zeros(batch_size, 3, self.height, self.width)
83 | for i, _yx in enumerate(torch.unbind(yx)):
84 | y, x = torch.unbind(_yx)
85 | tensor[i, :, y, x] = 1
86 | tensor = utils.ensure_device(tensor)
87 | output = self.dnn(torch.autograd.Variable(tensor, volatile=True))
88 | output = output[:, :, self.i, self.j]
89 | cmp = output == self.output
90 | cmp = torch.prod(cmp, -1).data
91 | for _yx, c in zip(torch.unbind(yx), torch.unbind(cmp)):
92 | y, x = torch.unbind(_yx)
93 | changed[y, x] = c
94 | return changed
95 |
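# Note on the probe: __init__ records the centre output cell for an all-zero
# input; __call__ then sets one input pixel to 1 at a time and compares that
# cell against the baseline. changed[y, x] is True where the cell is
# *unchanged*, so main() saves ~changed, which is white exactly over the
# pixels that influence the centre cell, i.e. the receptive field.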
96 |
97 | def main():
98 | args = make_args()
99 | config = configparser.ConfigParser()
100 | utils.load_config(config, args.config)
101 | for cmd in args.modify:
102 | utils.modify_config(config, cmd)
103 | with open(os.path.expanduser(os.path.expandvars(args.logging)), 'r') as f:
104 | logging.config.dictConfig(yaml.load(f))
105 | analyzer = Analyzer(args, config)
106 | changed = analyzer()
107 | os.makedirs(analyzer.model_dir, exist_ok=True)
108 | path = os.path.join(analyzer.model_dir, args.filename)
109 | scipy.misc.imsave(path, (~changed).astype(np.uint8) * 255)
110 | logging.info(path)
111 |
112 |
113 | def make_args():
114 | parser = argparse.ArgumentParser()
115 | parser.add_argument('-c', '--config', nargs='+', default=['config.ini'], help='config file')
116 | parser.add_argument('-m', '--modify', nargs='+', default=[], help='modify config')
117 | parser.add_argument('-b', '--batch_size', default=16, type=int, help='batch size')
118 | parser.add_argument('-n', '--filename', default='receptive_field.jpg')
119 | parser.add_argument('--logging', default='logging.yml', help='logging config')
120 | return parser.parse_args()
121 |
122 |
123 | if __name__ == '__main__':
124 | main()
125 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | humanize
2 | tqdm
3 | onnx_caffe2
4 | onnx
5 | torch<=0.3.1
6 | torchvision
7 | nltk
8 | pandas
9 | pycocotools
10 | XlsxWriter
11 | filelock
12 | matplotlib
13 | scikit_image
14 | pybenchmark
15 | tinydb
16 | graphviz
17 | pretrainedmodels
18 | inflection
19 | videosequence
20 | pymediainfo
21 | Pillow
22 | scipy
23 | # skimage removed: scikit_image above already provides the skimage module
24 | scikit_learn
25 | tensorboardX
26 | wget
27 | PyYAML
28 |
--------------------------------------------------------------------------------
/split_data.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see <http://www.gnu.org/licenses/>.
16 | """
17 |
18 | import os
19 | import argparse
20 | import random
21 |
22 |
23 | def main():
24 | args = make_args()
25 | root = os.path.expanduser(os.path.expandvars(args.root))
26 |     relpaths = []
27 |     for dirpath, _, filenames in os.walk(root):
28 |         for filename in filenames:
29 |             if os.path.splitext(filename)[-1].lower() in args.exts and filename[0] != '.':
30 |                 path = os.path.join(dirpath, filename)
31 |                 relpath = os.path.relpath(path, root)
32 |                 relpaths.append(relpath)
33 |     random.shuffle(relpaths)
34 |     total = args.train + args.val + args.test
35 |     nval = int(len(relpaths) * args.val / total)
36 |     ntest = nval + int(len(relpaths) * args.test / total)
37 |     val = relpaths[:nval]
38 |     test = relpaths[nval:ntest]
39 |     train = relpaths[ntest:]
40 | print('train=%d, val=%d, test=%d' % (len(train), len(val), len(test)))
41 | with open(os.path.join(root, 'train' + args.ext), 'w') as f:
42 | for path in train:
43 | f.write(path + '\n')
44 | with open(os.path.join(root, 'val' + args.ext), 'w') as f:
45 | for path in val:
46 | f.write(path + '\n')
47 | with open(os.path.join(root, 'test' + args.ext), 'w') as f:
48 | for path in test:
49 | f.write(path + '\n')
50 |
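# Worked example with the default 7/2/1 weights and 100 images:
# nval = int(100 * 2 / 10) = 20 and ntest = 20 + int(100 * 1 / 10) = 30, so the
# shuffled list splits into 20 val, 10 test, and 70 train paths.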
51 |
52 | def make_args():
53 | parser = argparse.ArgumentParser()
54 | parser.add_argument('root')
55 | parser.add_argument('-e', '--exts', nargs='+', default=['.jpe', '.jpg', '.jpeg', '.png'])
56 | parser.add_argument('--train', type=float, default=7)
57 | parser.add_argument('--val', type=float, default=2)
58 | parser.add_argument('--test', type=float, default=1)
59 | parser.add_argument('--ext', default='.txt')
60 | return parser.parse_args()
61 |
62 | if __name__ == '__main__':
63 | main()
64 |
--------------------------------------------------------------------------------
/transform/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see <http://www.gnu.org/licenses/>.
16 | """
17 |
18 | import inspect
19 |
20 | import torchvision
21 |
22 | import utils
23 |
24 |
25 | def parse_transform(config, method):
26 | if isinstance(method, str):
27 | attr = utils.parse_attr(method)
28 | sig = inspect.signature(attr)
29 | if len(sig.parameters) == 1:
30 | return attr(config)
31 | else:
32 | return attr()
33 | else:
34 | return method
35 |
36 |
37 | def get_transform(config, sequence, compose=torchvision.transforms.Compose):
38 | return compose([parse_transform(config, method) for method in sequence])
39 |
--------------------------------------------------------------------------------
/transform/augmentation.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see <http://www.gnu.org/licenses/>.
16 | """
17 |
18 | import inspect
19 | import random
20 |
21 | import inflection
22 | import numpy as np
23 | import cv2
24 |
25 | import transform
26 |
27 |
28 | class Rotator(object):
29 | def __init__(self, y, x, height, width, angle):
30 | """
31 |         An efficient tool to rotate multiple images of the same size.
32 |         :author 申瑞珉 (Ruimin Shen)
33 |         :param y: The y coordinate of the rotation point.
34 |         :param x: The x coordinate of the rotation point.
35 |         :param height: Image height.
36 |         :param width: Image width.
37 |         :param angle: Rotation angle in degrees (positive is counterclockwise).
38 | """
39 | self._mat = cv2.getRotationMatrix2D((x, y), angle, 1.0)
40 | r = np.abs(self._mat[0, :2])
41 | _height, _width = np.inner(r, [height, width]), np.inner(r, [width, height])
42 | fix_y, fix_x = _height / 2 - y, _width / 2 - x
43 | self._mat[:, 2] += [fix_x, fix_y]
44 | self._size = int(_width), int(_height)
45 |
46 | def __call__(self, image, flags=cv2.INTER_LINEAR, fill=None):
47 | if fill is None:
48 | fill = np.random.rand(3) * 256
49 | return cv2.warpAffine(image, self._mat, self._size, flags=flags, borderMode=cv2.BORDER_CONSTANT, borderValue=fill)
50 |
51 | def _rotate_points(self, points):
52 | _points = np.pad(points, [(0, 0), (0, 1)], 'constant')
53 | _points[:, 2] = 1
54 | _points = np.dot(self._mat, _points.T)
55 | return _points.T.astype(points.dtype)
56 |
57 | def rotate_points(self, points):
58 | return self._rotate_points(points[:, ::-1])[:, ::-1]
59 |
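A minimal usage sketch (sizes are illustrative): rotating a 100x200 image by 90 degrees about its centre grows the canvas to fit, and the same matrix maps label points:

    import numpy as np
    from transform.augmentation import Rotator

    image = np.zeros([100, 200, 3], np.uint8)
    rotator = Rotator(50, 100, 100, 200, 90)   # pivot (y=50, x=100), 90 degrees
    rotated = rotator(image, fill=0)
    assert rotated.shape == (200, 100, 3)      # height and width swap at 90 degrees
    points = rotator.rotate_points(np.array([[50., 100.]]))  # pivot -> new centre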
60 |
61 | def random_rotate(config, image, yx_min, yx_max):
62 | name = inspect.stack()[0][3]
63 | angle = random.uniform(*tuple(map(float, config.get('augmentation', name).split())))
64 | height, width = image.shape[:2]
65 | p1, p2 = np.copy(yx_min), np.copy(yx_max)
66 | p1[:, 0] = yx_max[:, 0]
67 | p2[:, 0] = yx_min[:, 0]
68 | points = np.concatenate([yx_min, yx_max, p1, p2], 0)
69 | rotator = Rotator(height / 2, width / 2, height, width, angle)
70 | image = rotator(image, fill=0)
71 | points = rotator.rotate_points(points)
72 | bbox_points = np.reshape(points, [4, -1, 2])
73 | yx_min = np.apply_along_axis(lambda points: np.min(points, 0), 0, bbox_points)
74 | yx_max = np.apply_along_axis(lambda points: np.max(points, 0), 0, bbox_points)
75 | return image, yx_min, yx_max
76 |
77 |
78 | class RandomRotate(object):
79 | def __init__(self, config):
80 | self.config = config
81 | self.fn = eval(inflection.underscore(type(self).__name__))
82 |
83 | def __call__(self, data):
84 | data['image'], data['yx_min'], data['yx_max'] = self.fn(self.config, data['image'], data['yx_min'], data['yx_max'])
85 | return data
86 |
87 |
88 | def flip_horizontally(image, yx_min, yx_max):
89 | assert len(image.shape) == 3
90 | image = cv2.flip(image, 1)
91 | width = image.shape[1]
92 | temp = width - yx_min[:, 1]
93 | yx_min[:, 1] = width - yx_max[:, 1]
94 | yx_max[:, 1] = temp
95 | return image, yx_min, yx_max
96 |
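A quick check of the coordinate mirroring (the boxes are illustrative):

    import numpy as np
    from transform.augmentation import flip_horizontally

    image = np.zeros([50, 100, 3], np.uint8)
    yx_min, yx_max = np.array([[5., 10.]]), np.array([[20., 30.]])
    _, yx_min, yx_max = flip_horizontally(image, yx_min, yx_max)
    # the x-extent [10, 30] mirrors to [100 - 30, 100 - 10] = [70, 90]; y is untouched
    assert (yx_min == [[5., 70.]]).all() and (yx_max == [[20., 90.]]).all()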
97 |
98 | def random_flip_horizontally(config, image, yx_min, yx_max):
99 | name = inspect.stack()[0][3]
100 | if random.random() > config.getfloat('augmentation', name):
101 | return flip_horizontally(image, yx_min, yx_max)
102 | else:
103 | return image, yx_min, yx_max
104 |
105 |
106 | class RandomFlipHorizontally(object):
107 | def __init__(self, config):
108 | self.config = config
109 | self.fn = eval(inflection.underscore(type(self).__name__))
110 |
111 | def __call__(self, data):
112 | data['image'], data['yx_min'], data['yx_max'] = self.fn(self.config, data['image'], data['yx_min'], data['yx_max'])
113 | return data
114 |
115 |
116 | def get_transform(config, sequence):
117 | return transform.get_transform(config, sequence)
118 |
--------------------------------------------------------------------------------
/transform/image.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see <http://www.gnu.org/licenses/>.
16 | """
17 |
18 | import random
19 |
20 | import numpy as np
21 | import torchvision
22 | import inflection
23 | import skimage.exposure
24 | import cv2
25 |
26 |
27 | class BGR2RGB(object):
28 | def __call__(self, image):
29 | return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
30 |
31 |
32 | class BGR2HSV(object):
33 | def __call__(self, image):
34 | return cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
35 |
36 |
37 | class HSV2RGB(object):
38 | def __call__(self, image):
39 | return cv2.cvtColor(image, cv2.COLOR_HSV2RGB)
40 |
41 |
42 | class RandomBlur(object):
43 | def __init__(self, config):
44 | name = inflection.underscore(type(self).__name__)
45 | self.adjust = tuple(map(int, config.get('augmentation', name).split()))
46 |
47 | def __call__(self, image):
48 | adjust = tuple(random.randint(1, adjust) for adjust in self.adjust)
49 | return cv2.blur(image, adjust)
50 |
51 |
52 | class RandomHue(object):
53 | def __init__(self, config):
54 | name = inflection.underscore(type(self).__name__)
55 | self.adjust = tuple(map(int, config.get('augmentation', name).split()))
56 |
57 | def __call__(self, hsv):
58 | h, s, v = cv2.split(hsv)
59 | adjust = random.randint(*self.adjust)
60 | h = h.astype(np.int) + adjust
61 | h = np.clip(h, 0, 179).astype(hsv.dtype)
62 | return cv2.merge((h, s, v))
63 |
64 |
65 | class RandomSaturation(object):
66 | def __init__(self, config):
67 | name = inflection.underscore(type(self).__name__)
68 | self.adjust = tuple(map(float, config.get('augmentation', name).split()))
69 |
70 | def __call__(self, hsv):
71 | h, s, v = cv2.split(hsv)
72 | adjust = random.uniform(*self.adjust)
73 | s = s * adjust
74 | s = np.clip(s, 0, 255).astype(hsv.dtype)
75 | return cv2.merge((h, s, v))
76 |
77 |
78 | class RandomBrightness(object):
79 | def __init__(self, config):
80 | name = inflection.underscore(type(self).__name__)
81 | self.adjust = tuple(map(float, config.get('augmentation', name).split()))
82 |
83 | def __call__(self, hsv):
84 | h, s, v = cv2.split(hsv)
85 | adjust = random.uniform(*self.adjust)
86 | v = v * adjust
87 | v = np.clip(v, 0, 255).astype(hsv.dtype)
88 | return cv2.merge((h, s, v))
89 |
90 |
91 | class RandomGamma(object):
92 | def __init__(self, config):
93 | name = inflection.underscore(type(self).__name__)
94 | self.adjust = tuple(map(float, config.get('augmentation', name).split()))
95 |
96 | def __call__(self, image):
97 | adjust = random.uniform(*self.adjust)
98 | return skimage.exposure.adjust_gamma(image, adjust)
99 |
100 |
101 | class Normalize(torchvision.transforms.Normalize):
102 | def __init__(self, config):
103 | name = inflection.underscore(type(self).__name__)
104 | mean, std = tuple(map(float, config.get('transform', name).split()))
105 | torchvision.transforms.Normalize.__init__(self, (mean, mean, mean), (std, std, std))
106 |
--------------------------------------------------------------------------------
/transform/resize/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruiminshen/yolo2-pytorch/146ebdf581677964caa31c69cccd0c86230fb216/transform/resize/__init__.py
--------------------------------------------------------------------------------
/transform/resize/image.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see <http://www.gnu.org/licenses/>.
16 | """
17 |
18 | import inflection
19 | import numpy as np
20 | import cv2
21 |
22 |
23 | def rescale(image, height, width):
24 | return cv2.resize(image, (width, height))
25 |
26 |
27 | class Rescale(object):
28 | def __init__(self):
29 | name = inflection.underscore(type(self).__name__)
30 | self.fn = eval(name)
31 |
32 | def __call__(self, image, height, width):
33 | return self.fn(image, height, width)
34 |
35 |
36 | def fixed(image, height, width):
37 | _height, _width, _ = image.shape
38 | if _height / _width > height / width:
39 | scale = height / _height
40 | else:
41 | scale = width / _width
42 | m = np.eye(2, 3)
43 | m[0, 0] = scale
44 | m[1, 1] = scale
45 | flags = cv2.INTER_AREA if scale < 1 else cv2.INTER_CUBIC
46 | return cv2.warpAffine(image, m, (width, height), flags=flags)
47 |
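A minimal sketch of the aspect-preserving behaviour (sizes are illustrative): fixed() scales by the limiting dimension and leaves the remaining canvas zero-filled instead of distorting the image:

    import numpy as np
    from transform.resize.image import fixed

    image = np.full([200, 100, 3], 255, np.uint8)  # tall input
    out = fixed(image, 416, 416)
    assert out.shape == (416, 416, 3)  # content occupies 416x208; the rest stays black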
48 |
49 | class Fixed(object):
50 | def __init__(self):
51 | name = inflection.underscore(type(self).__name__)
52 | self.fn = eval(name)
53 |
54 | def __call__(self, image, height, width):
55 | return self.fn(image, height, width)
56 |
57 |
58 | class Resize(object):
59 | def __init__(self, config):
60 | name = config.get('data', inflection.underscore(type(self).__name__))
61 | self.fn = eval(name)
62 |
63 | def __call__(self, image, height, width):
64 | return self.fn(image, height, width)
65 |
--------------------------------------------------------------------------------
/transform/resize/label.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see <http://www.gnu.org/licenses/>.
16 | """
17 |
18 | import inspect
19 |
20 | import inflection
21 | import numpy as np
22 | import cv2
23 |
24 |
25 | def rescale(image, yx_min, yx_max, height, width):
26 | _height, _width = image.shape[:2]
27 | scale = np.array([height / _height, width / _width], np.float32)
28 | image = cv2.resize(image, (width, height))
29 | yx_min *= scale
30 | yx_max *= scale
31 | return image, yx_min, yx_max
32 |
33 |
34 | class Rescale(object):
35 | def __init__(self):
36 | self.fn = eval(inflection.underscore(type(self).__name__))
37 |
38 | def __call__(self, data, height, width):
39 | data['image'], data['yx_min'], data['yx_max'] = self.fn(data['image'], data['yx_min'], data['yx_max'], height, width)
40 | return data
41 |
42 |
43 | def resize(config, image, yx_min, yx_max, height, width):
44 | fn = eval(config.get('data', inspect.stack()[0][3]))
45 | return fn(image, yx_min, yx_max, height, width)
46 |
47 |
48 | class Resize(object):
49 | def __init__(self, config):
50 | self.config = config
51 | self.fn = eval(config.get('data', inflection.underscore(type(self).__name__)))
52 |
53 | def __call__(self, data, height, width):
54 | data['image'], data['yx_min'], data['yx_max'] = self.fn(data['image'], data['yx_min'], data['yx_max'], height, width)
55 | return data
56 |
57 |
58 | def random_crop(config, image, yx_min, yx_max, height, width):
59 | name = inspect.stack()[0][3]
60 | scale = config.getfloat('augmentation', name)
61 | assert 0 < scale <= 1
62 | _yx_min = np.min(yx_min, 0)
63 | _yx_max = np.max(yx_max, 0)
64 | dtype = yx_min.dtype
65 | size = np.array(image.shape[:2], dtype)
66 | margin = scale * np.random.rand(4).astype(dtype) * np.concatenate([_yx_min, size - _yx_max], 0)
67 | _yx_min = margin[:2]
68 | _yx_max = size - margin[2:]
69 | _ymin, _xmin = _yx_min
70 | _ymax, _xmax = _yx_max
71 | _ymin, _xmin, _ymax, _xmax = tuple(map(int, (_ymin, _xmin, _ymax, _xmax)))
72 | image = image[_ymin:_ymax, _xmin:_xmax, :]
73 | yx_min, yx_max = yx_min - _yx_min, yx_max - _yx_min
74 | return resize(config, image, yx_min, yx_max, height, width)
75 |
76 |
77 | class RandomCrop(object):
78 | def __init__(self, config):
79 | self.config = config
80 | self.fn = eval(inflection.underscore(type(self).__name__))
81 |
82 | def __call__(self, data, height, width):
83 | data['image'], data['yx_min'], data['yx_max'] = self.fn(self.config, data['image'], data['yx_min'], data['yx_max'], height, width)
84 | return data
85 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see <http://www.gnu.org/licenses/>.
16 | """
17 |
18 | import os
19 | import re
20 | import configparser
21 | import importlib
22 | import inspect
23 |
24 | import numpy as np
25 | import pandas as pd
26 | import torch.autograd
27 | from PIL import Image
28 |
29 |
30 | class Compose(object):
31 | def __init__(self, transforms):
32 | self.transforms = transforms
33 |
34 | def __call__(self, img, yx_min, yx_max, cls):
35 | for t in self.transforms:
36 | img, yx_min, yx_max, cls = t(img, yx_min, yx_max, cls)
37 | return img, yx_min, yx_max, cls
38 |
39 |
40 | class RegexList(list):
41 | def __init__(self, l):
42 | for s in l:
43 | prog = re.compile(s)
44 | self.append(prog)
45 |
46 | def __call__(self, s):
47 | for prog in self:
48 | if prog.match(s):
49 | return True
50 | return False
51 |
52 |
53 | def get_cache_dir(config):
54 | root = os.path.expanduser(os.path.expandvars(config.get('config', 'root')))
55 | name = config.get('cache', 'name')
56 | return os.path.join(root, name)
57 |
58 |
59 | def get_model_dir(config):
60 | root = os.path.expanduser(os.path.expandvars(config.get('config', 'root')))
61 | name = config.get('model', 'name')
62 | model = config.get('model', 'dnn')
63 | return os.path.join(root, name, model)
64 |
65 |
66 | def get_eval_db(config):
67 | root = os.path.expanduser(os.path.expandvars(config.get('config', 'root')))
68 | db = config.get('eval', 'db')
69 | return os.path.join(root, db)
70 |
71 |
72 | def get_category(config, cache_dir=None):
73 | path = os.path.expanduser(os.path.expandvars(config.get('cache', 'category'))) if cache_dir is None else os.path.join(cache_dir, 'category')
74 | with open(path, 'r') as f:
75 | return [line.strip() for line in f]
76 |
77 |
78 | def get_anchors(config, dtype=np.float32):
79 | path = os.path.expanduser(os.path.expandvars(config.get('model', 'anchors')))
80 | df = pd.read_csv(path, sep='\t', dtype=dtype)
81 | return df[['height', 'width']].values
82 |
83 |
84 | def parse_attr(s):
85 | m, n = s.rsplit('.', 1)
86 | m = importlib.import_module(m)
87 | return getattr(m, n)
88 |
89 |
90 | def load_config(config, paths):
91 | for path in paths:
92 | path = os.path.expanduser(os.path.expandvars(path))
93 | assert os.path.exists(path)
94 | config.read(path)
95 |
96 |
97 | def modify_config(config, cmd):
98 | var, value = cmd.split('=', 1)
99 | section, option = var.split('/')
100 | if value:
101 | config.set(section, option, value)
102 | else:
103 | try:
104 | config.remove_option(section, option)
105 | except (configparser.NoSectionError, configparser.NoOptionError):
106 | pass
107 |
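# Examples of the section/option=value syntax consumed by the -m command line
# flag (assuming the section already exists in the loaded config):
#   modify_config(config, 'cache/name=cache_voc')  # set option "name" in section [cache]
#   modify_config(config, 'cache/name=')           # an empty value removes the option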
108 |
109 | def ensure_device(t, device_id=None, async=False):
110 | if torch.cuda.is_available():
111 | t = t.cuda(device_id, async)
112 | return t
113 |
114 |
115 | def dense(var):
116 | return [torch.mean(torch.abs(x)) if torch.is_tensor(x) else np.abs(x) for x in var]
117 |
118 |
119 | def abs_mean(data, dtype=np.float32):
120 | assert isinstance(data, np.ndarray), type(data)
121 | return np.sum(np.abs(data)) / dtype(data.size)
122 |
123 |
124 | def image_size(path):
125 | with Image.open(path) as image:
126 | return image.size
127 |
--------------------------------------------------------------------------------
/utils/cache.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see <http://www.gnu.org/licenses/>.
16 | """
17 |
18 | import numpy as np
19 |
20 |
21 | def verify_coords(yx_min, yx_max, size):
22 | assert np.all(yx_min <= yx_max), 'yx_min <= yx_max'
23 | assert np.all(0 <= yx_min), '0 <= yx_min'
24 | assert np.all(0 <= yx_max), '0 <= yx_max'
25 | assert np.all(yx_min < size), 'yx_min < size'
26 | assert np.all(yx_max < size), 'yx_max < size'
27 |
28 |
29 | def fix_coords(yx_min, yx_max, size):
30 | assert np.all(yx_min <= yx_max)
31 | assert yx_min.dtype == yx_max.dtype
32 | coord_min = np.zeros([2], dtype=yx_min.dtype)
33 | coord_max = np.array(size, dtype=yx_min.dtype) - 1
34 | yx_min = np.minimum(np.maximum(yx_min, coord_min), coord_max)
35 | yx_max = np.minimum(np.maximum(yx_max, coord_min), coord_max)
36 | return yx_min, yx_max
37 |
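verify_coords and fix_coords enforce a single invariant: box corners must lie inside the image, with size exclusive (valid coordinates run from 0 to size - 1). A quick numeric check of the clamping, on made-up coordinates:

    import numpy as np

    from utils.cache import fix_coords, verify_coords

    size = (8, 20)  # height, width
    yx_min = np.array([-5, 2], dtype=np.float32)   # spills above the top edge
    yx_max = np.array([10, 30], dtype=np.float32)  # spills past the bottom right corner
    yx_min, yx_max = fix_coords(yx_min, yx_max, size)
    print(yx_min, yx_max)  # [0. 2.] [ 7. 19.]
    verify_coords(yx_min, yx_max, size)  # passes after clamping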
--------------------------------------------------------------------------------
/utils/data.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see <http://www.gnu.org/licenses/>.
16 | """
17 |
18 | import os
19 | import pickle
20 | import random
21 | import copy
22 |
23 | import numpy as np
24 | import torch.utils.data
25 | import sklearn.preprocessing
26 | import cv2
27 |
28 |
29 | def padding_labels(data, dim, labels='yx_min, yx_max, cls, difficult'.split(', ')):
30 | """
31 |     Pad labels to the same dimension (to form a batch).
32 |     :author 申瑞珉 (Ruimin Shen)
33 |     :param data: A dict containing the labels to be padded.
34 | :param dim: The target dimension.
35 | :param labels: The list of label names.
36 | :return: The padded label dict.
37 | """
38 | pad = dim - len(data[labels[0]])
39 | for key in labels:
40 | label = data[key]
41 | data[key] = np.pad(label, [(0, pad)] + [(0, 0)] * (len(label.shape) - 1), 'constant')
42 | return data
43 |
44 |
45 | def load_pickles(paths):
46 | data = []
47 | for path in paths:
48 | with open(path, 'rb') as f:
49 | data += pickle.load(f)
50 | return data
51 |
52 |
53 | class Dataset(torch.utils.data.Dataset):
54 | def __init__(self, data, transform=lambda data: data, one_hot=None, shuffle=False, dir=None):
55 | """
56 | Load the cached data (.pkl) into memory.
57 | :author 申瑞珉 (Ruimin Shen)
58 |         :param data: A list containing the data samples (dicts).
59 |         :param transform: A function that transforms the labels in a dict (usually performing a sequence of data augmentation operations).
60 |         :param one_hot: If an int value (the total number of classes) is given, the class label (key "cls") will be generated in one-hot format.
61 |         :param shuffle: Whether to shuffle the loaded dataset.
62 |         :param dir: The directory in which to dump a sample that raises an exception.
63 | """
64 | self.data = data
65 | if shuffle:
66 | random.shuffle(self.data)
67 | self.transform = transform
68 | self.one_hot = None if one_hot is None else sklearn.preprocessing.OneHotEncoder(one_hot, dtype=np.float32)
69 | self.dir = dir
70 |
71 | def __len__(self):
72 | return len(self.data)
73 |
74 | def __getitem__(self, index):
75 | data = copy.deepcopy(self.data[index])
76 | try:
77 | image = cv2.imread(data['path'])
78 | data['image'] = image
79 | data['size'] = np.array(image.shape[:2])
80 | data = self.transform(data)
81 | if self.one_hot is not None:
82 | data['cls'] = self.one_hot.fit_transform(np.expand_dims(data['cls'], -1)).todense()
83 | except:
84 | if self.dir is not None:
85 | os.makedirs(self.dir, exist_ok=True)
86 | name = self.__module__ + '.' + type(self).__name__
87 | with open(os.path.join(self.dir, name + '.pkl'), 'wb') as f:
88 | pickle.dump(data, f)
89 | raise
90 | return data
91 |
92 |
93 | class Collate(object):
94 | def __init__(self, resize, sizes, maintain=1, transform_image=lambda image: image, transform_tensor=None, dir=None):
95 | """
96 |         Unify multiple data samples to form a batch (e.g., resize images to the same size, and pad bounding box labels to the same count).
97 |         :author 申瑞珉 (Ruimin Shen)
98 |         :param resize: A function to resize the image and labels.
99 |         :param sizes: The image sizes to be randomly chosen from.
100 |         :param maintain: How many consecutive batches each chosen size is kept for.
101 |         :param transform_image: A function to transform the resized image.
102 |         :param transform_tensor: A function to standardize an image into a tensor.
103 |         :param dir: The directory in which to dump a sample that raises an exception.
104 | """
105 | self.resize = resize
106 | self.sizes = sizes
107 | assert maintain > 0
108 | self.maintain = maintain
109 | self._maintain = maintain
110 | self.transform_image = transform_image
111 | self.transform_tensor = transform_tensor
112 | self.dir = dir
113 |
114 | def __call__(self, batch):
115 | height, width = self.next_size()
116 | dim = max(len(data['cls']) for data in batch)
117 | _batch = []
118 | for data in batch:
119 | try:
120 | data = self.resize(data, height, width)
121 | data['image'] = self.transform_image(data['image'])
122 | data = padding_labels(data, dim)
123 | if self.transform_tensor is not None:
124 | data['tensor'] = self.transform_tensor(data['image'])
125 | _batch.append(data)
126 | except:
127 | if self.dir is not None:
128 | os.makedirs(self.dir, exist_ok=True)
129 | name = self.__module__ + '.' + type(self).__name__
130 | with open(os.path.join(self.dir, name + '.pkl'), 'wb') as f:
131 | pickle.dump(data, f)
132 | raise
133 | return torch.utils.data.dataloader.default_collate(_batch)
134 |
135 | def next_size(self):
136 | if self._maintain < self.maintain:
137 | self._maintain += 1
138 | else:
139 | self.size = random.choice(self.sizes)
140 | self._maintain = 0
141 | return self.size
142 |
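Dataset and Collate are meant to be plugged into a torch.utils.data.DataLoader: the dataset yields variable-sized samples, and the collate function picks one (height, width) per batch, resizes every image to it, and pads the labels to a common count via padding_labels so default_collate can stack them. A wiring sketch under stated assumptions: the cache path is hypothetical, and the identity resize below merely stands in for the real functions under transform/resize/ (an actual resize must give every image in a batch the same shape, or default_collate will fail):

    import torch.utils.data

    import utils.data

    data = utils.data.load_pickles(['cache/voc.pkl'])  # hypothetical cache path

    def resize(data, height, width):  # stand-in; see transform/resize/ for real ones
        return data

    dataset = utils.data.Dataset(data, shuffle=True)
    collate = utils.data.Collate(resize, sizes=[(416, 416), (480, 480)], maintain=10)
    loader = torch.utils.data.DataLoader(dataset, batch_size=16, collate_fn=collate)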
--------------------------------------------------------------------------------
/utils/iou/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruiminshen/yolo2-pytorch/146ebdf581677964caa31c69cccd0c86230fb216/utils/iou/__init__.py
--------------------------------------------------------------------------------
/utils/iou/numpy.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see <http://www.gnu.org/licenses/>.
16 | """
17 |
18 | import unittest
19 |
20 | import numpy as np
21 |
22 |
23 | def iou(yx_min1, yx_max1, yx_min2, yx_max2, min=None):
24 | """
25 | Calculates the IoU of two bounding boxes.
26 | :author 申瑞珉 (Ruimin Shen)
27 |     :param yx_min1: The top left coordinates (y, x) of the first bounding box.
28 |     :param yx_max1: The bottom right coordinates (y, x) of the first bounding box.
29 |     :param yx_min2: The top left coordinates (y, x) of the second bounding box.
30 |     :param yx_max2: The bottom right coordinates (y, x) of the second bounding box.
31 | :return: The IoU.
32 | """
33 | assert np.all(yx_min1 <= yx_max1)
34 | assert np.all(yx_min2 <= yx_max2)
35 | if min is None:
36 | min = np.finfo(yx_min1.dtype).eps
37 | yx_min = np.maximum(yx_min1, yx_min2)
38 | yx_max = np.minimum(yx_max1, yx_max2)
39 | intersect_area = np.multiply.reduce(np.maximum(0.0, yx_max - yx_min))
40 | area1 = np.multiply.reduce(yx_max1 - yx_min1)
41 | area2 = np.multiply.reduce(yx_max2 - yx_min2)
42 | assert np.all(intersect_area >= 0)
43 | assert np.all(intersect_area <= area1)
44 | assert np.all(intersect_area <= area2)
45 | union_area = np.maximum(area1 + area2 - intersect_area, min)
46 | return intersect_area / union_area
47 |
48 |
49 | def intersection_area(yx_min1, yx_max1, yx_min2, yx_max2):
50 | """
51 | Calculates the intersection area of two lists of bounding boxes.
52 | :author 申瑞珉 (Ruimin Shen)
53 | :param yx_min1: The top left coordinates (y, x) of the first list (size [N1, 2]) of bounding boxes.
54 | :param yx_max1: The bottom right coordinates (y, x) of the first list (size [N1, 2]) of bounding boxes.
55 | :param yx_min2: The top left coordinates (y, x) of the second list (size [N2, 2]) of bounding boxes.
56 | :param yx_max2: The bottom right coordinates (y, x) of the second list (size [N2, 2]) of bounding boxes.
57 | :return: The matrix (size [N1, N2]) of the intersection area.
58 | """
59 | ymin1, xmin1 = yx_min1.T
60 | ymax1, xmax1 = yx_max1.T
61 | ymin2, xmin2 = yx_min2.T
62 | ymax2, xmax2 = yx_max2.T
63 | ymin1, xmin1, ymax1, xmax1, ymin2, xmin2, ymax2, xmax2 = (np.expand_dims(a, -1) for a in (ymin1, xmin1, ymax1, xmax1, ymin2, xmin2, ymax2, xmax2))
64 | max_ymin = np.maximum(ymin1, np.transpose(ymin2))
65 | min_ymax = np.minimum(ymax1, np.transpose(ymax2))
66 | height = np.maximum(0.0, min_ymax - max_ymin)
67 | max_xmin = np.maximum(xmin1, np.transpose(xmin2))
68 | min_xmax = np.minimum(xmax1, np.transpose(xmax2))
69 | width = np.maximum(0.0, min_xmax - max_xmin)
70 | return height * width
71 |
72 |
73 | def iou_matrix(yx_min1, yx_max1, yx_min2, yx_max2, min=None):
74 | """
75 | Calculates the IoU of two lists of bounding boxes.
76 | :author 申瑞珉 (Ruimin Shen)
77 | :param yx_min1: The top left coordinates (y, x) of the first list (size [N1, 2]) of bounding boxes.
78 | :param yx_max1: The bottom right coordinates (y, x) of the first list (size [N1, 2]) of bounding boxes.
79 | :param yx_min2: The top left coordinates (y, x) of the second list (size [N2, 2]) of bounding boxes.
80 | :param yx_max2: The bottom right coordinates (y, x) of the second list (size [N2, 2]) of bounding boxes.
81 | :return: The matrix (size [N1, N2]) of the IoU.
82 | """
83 | if min is None:
84 | min = np.finfo(yx_min1.dtype).eps
85 | assert np.all(yx_min1 <= yx_max1)
86 | assert np.all(yx_min2 <= yx_max2)
87 | intersect_area = intersection_area(yx_min1, yx_max1, yx_min2, yx_max2)
88 | area1 = np.expand_dims(np.multiply.reduce(yx_max1 - yx_min1, -1), 1)
89 | area2 = np.expand_dims(np.multiply.reduce(yx_max2 - yx_min2, -1), 0)
90 | assert np.all(intersect_area >= 0)
91 | assert np.all(intersect_area <= area1)
92 | assert np.all(intersect_area <= area2)
93 | union_area = np.maximum(area1 + area2 - intersect_area, min)
94 | return intersect_area / union_area
95 |
96 |
97 | class TestIouMatrix(unittest.TestCase):
98 | def _test(self, bbox1, bbox2, ans, dtype=np.float32):
99 | bbox1, bbox2, ans = (np.array(a, dtype) for a in (bbox1, bbox2, ans))
100 | yx_min1, yx_max1 = np.split(bbox1, 2, -1)
101 | yx_min2, yx_max2 = np.split(bbox2, 2, -1)
102 | assert np.all(yx_min1 <= yx_max1)
103 | assert np.all(yx_min2 <= yx_max2)
104 | assert np.all(ans >= 0)
105 | matrix = iou_matrix(yx_min1, yx_max1, yx_min2, yx_max2)
106 | np.testing.assert_almost_equal(matrix, ans)
107 |
108 | def test0(self):
109 | bbox1 = [
110 | (1, 1, 2, 2),
111 | ]
112 | bbox2 = [
113 | (0, 0, 1, 1),
114 | (0, 1, 1, 2),
115 | (0, 2, 1, 3),
116 | (1, 0, 2, 1),
117 | (2, 0, 3, 1),
118 | (1, 2, 2, 3),
119 | (2, 1, 3, 2),
120 | (2, 2, 3, 3),
121 | ]
122 | ans = [
123 | [0] * len(bbox2),
124 | ]
125 | self._test(bbox1, bbox2, ans)
126 |
127 | def test1(self):
128 | bbox1 = [
129 | (1, 1, 3, 3),
130 | (0, 0, 4, 4),
131 | ]
132 | bbox2 = [
133 | (0, 0, 2, 2),
134 | (2, 0, 4, 2),
135 | (0, 2, 2, 4),
136 | (2, 2, 4, 4),
137 | ]
138 | ans = [
139 | [1 / (4 + 4 - 1)] * len(bbox2),
140 | [4 / 16] * len(bbox2),
141 | ]
142 | self._test(bbox1, bbox2, ans)
143 |
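The arithmetic in test1 is worth spelling out once: boxes (1, 1, 3, 3) and (0, 0, 2, 2) overlap in a 1x1 square, each box has area 4, so IoU = 1 / (4 + 4 - 1) = 1/7. The single-box iou above reproduces this directly:

    import numpy as np

    from utils.iou.numpy import iou

    yx_min1 = np.array([1, 1], dtype=np.float32)
    yx_max1 = np.array([3, 3], dtype=np.float32)
    yx_min2 = np.array([0, 0], dtype=np.float32)
    yx_max2 = np.array([2, 2], dtype=np.float32)
    print(iou(yx_min1, yx_max1, yx_min2, yx_max2))  # 0.14285715 == 1 / 7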
--------------------------------------------------------------------------------
/utils/iou/torch.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see <http://www.gnu.org/licenses/>.
16 | """
17 |
18 | import unittest
19 |
20 | import numpy as np
21 | import torch
22 |
23 |
24 | def intersection_area(yx_min1, yx_max1, yx_min2, yx_max2):
25 | """
26 | Calculates the intersection area of two lists of bounding boxes.
27 | :author 申瑞珉 (Ruimin Shen)
28 | :param yx_min1: The top left coordinates (y, x) of the first list (size [N1, 2]) of bounding boxes.
29 | :param yx_max1: The bottom right coordinates (y, x) of the first list (size [N1, 2]) of bounding boxes.
30 | :param yx_min2: The top left coordinates (y, x) of the second list (size [N2, 2]) of bounding boxes.
31 | :param yx_max2: The bottom right coordinates (y, x) of the second list (size [N2, 2]) of bounding boxes.
32 | :return: The matrix (size [N1, N2]) of the intersection area.
33 | """
34 | ymin1, xmin1 = torch.split(yx_min1, 1, -1)
35 | ymax1, xmax1 = torch.split(yx_max1, 1, -1)
36 | ymin2, xmin2 = torch.split(yx_min2, 1, -1)
37 | ymax2, xmax2 = torch.split(yx_max2, 1, -1)
38 |     max_ymin = torch.max(ymin1.repeat(1, ymin2.size(0)), torch.transpose(ymin2, 0, 1).repeat(ymin1.size(0), 1))  # manual broadcast (PyTorch bug workaround)
39 |     min_ymax = torch.min(ymax1.repeat(1, ymax2.size(0)), torch.transpose(ymax2, 0, 1).repeat(ymax1.size(0), 1))  # manual broadcast (PyTorch bug workaround)
40 |     height = torch.clamp(min_ymax - max_ymin, min=0)
41 |     max_xmin = torch.max(xmin1.repeat(1, xmin2.size(0)), torch.transpose(xmin2, 0, 1).repeat(xmin1.size(0), 1))  # manual broadcast (PyTorch bug workaround)
42 |     min_xmax = torch.min(xmax1.repeat(1, xmax2.size(0)), torch.transpose(xmax2, 0, 1).repeat(xmax1.size(0), 1))  # manual broadcast (PyTorch bug workaround)
43 | width = torch.clamp(min_xmax - max_xmin, min=0)
44 | return height * width
45 |
46 |
47 | def iou_matrix(yx_min1, yx_max1, yx_min2, yx_max2, min=float(np.finfo(np.float32).eps)):
48 | """
49 | Calculates the IoU of two lists of bounding boxes.
50 | :author 申瑞珉 (Ruimin Shen)
51 | :param yx_min1: The top left coordinates (y, x) of the first list (size [N1, 2]) of bounding boxes.
52 | :param yx_max1: The bottom right coordinates (y, x) of the first list (size [N1, 2]) of bounding boxes.
53 | :param yx_min2: The top left coordinates (y, x) of the second list (size [N2, 2]) of bounding boxes.
54 | :param yx_max2: The bottom right coordinates (y, x) of the second list (size [N2, 2]) of bounding boxes.
55 | :return: The matrix (size [N1, N2]) of the IoU.
56 | """
57 | intersect_area = intersection_area(yx_min1, yx_max1, yx_min2, yx_max2)
58 | area1 = torch.prod(yx_max1 - yx_min1, -1).unsqueeze(-1)
59 | area2 = torch.prod(yx_max2 - yx_min2, -1).unsqueeze(-2)
60 | union_area = torch.clamp(area1 + area2 - intersect_area, min=min)
61 | return intersect_area / union_area
62 |
63 |
64 | class TestIouMatrix(unittest.TestCase):
65 | def _test(self, bbox1, bbox2, ans, dtype=np.float32):
66 | bbox1, bbox2, ans = (np.array(a, dtype) for a in (bbox1, bbox2, ans))
67 | yx_min1, yx_max1 = np.split(bbox1, 2, -1)
68 | yx_min2, yx_max2 = np.split(bbox2, 2, -1)
69 | assert np.all(yx_min1 <= yx_max1)
70 | assert np.all(yx_min2 <= yx_max2)
71 | assert np.all(ans >= 0)
72 | yx_min1, yx_max1 = torch.autograd.Variable(torch.from_numpy(yx_min1)), torch.autograd.Variable(torch.from_numpy(yx_max1))
73 | yx_min2, yx_max2 = torch.autograd.Variable(torch.from_numpy(yx_min2)), torch.autograd.Variable(torch.from_numpy(yx_max2))
74 | if torch.cuda.is_available():
75 | yx_min1, yx_max1, yx_min2, yx_max2 = (v.cuda() for v in (yx_min1, yx_max1, yx_min2, yx_max2))
76 | matrix = iou_matrix(yx_min1, yx_max1, yx_min2, yx_max2).data.cpu().numpy()
77 | np.testing.assert_almost_equal(matrix, ans)
78 |
79 | def test0(self):
80 | bbox1 = [
81 | (1, 1, 2, 2),
82 | ]
83 | bbox2 = [
84 | (0, 0, 1, 1),
85 | (0, 1, 1, 2),
86 | (0, 2, 1, 3),
87 | (1, 0, 2, 1),
88 | (2, 0, 3, 1),
89 | (1, 2, 2, 3),
90 | (2, 1, 3, 2),
91 | (2, 2, 3, 3),
92 | ]
93 | ans = [
94 | [0] * len(bbox2),
95 | ]
96 | self._test(bbox1, bbox2, ans)
97 |
98 | def test1(self):
99 | bbox1 = [
100 | (1, 1, 3, 3),
101 | (0, 0, 4, 4),
102 | ]
103 | bbox2 = [
104 | (0, 0, 2, 2),
105 | (2, 0, 4, 2),
106 | (0, 2, 2, 4),
107 | (2, 2, 4, 4),
108 | ]
109 | ans = [
110 | [1 / (4 + 4 - 1)] * len(bbox2),
111 | [4 / 16] * len(bbox2),
112 | ]
113 | self._test(bbox1, bbox2, ans)
114 |
115 |
116 | def batch_intersection_area(yx_min1, yx_max1, yx_min2, yx_max2):
117 | """
118 | Calculates the intersection area of two lists of bounding boxes for N independent batches.
119 | :author 申瑞珉 (Ruimin Shen)
120 | :param yx_min1: The top left coordinates (y, x) of the first lists (size [N, N1, 2]) of bounding boxes.
121 | :param yx_max1: The bottom right coordinates (y, x) of the first lists (size [N, N1, 2]) of bounding boxes.
122 | :param yx_min2: The top left coordinates (y, x) of the second lists (size [N, N2, 2]) of bounding boxes.
123 | :param yx_max2: The bottom right coordinates (y, x) of the second lists (size [N, N2, 2]) of bounding boxes.
124 |     :return: The matrices (size [N, N1, N2]) of the intersection area.
125 | """
126 | ymin1, xmin1 = torch.split(yx_min1, 1, -1)
127 | ymax1, xmax1 = torch.split(yx_max1, 1, -1)
128 | ymin2, xmin2 = torch.split(yx_min2, 1, -1)
129 | ymax2, xmax2 = torch.split(yx_max2, 1, -1)
130 |     max_ymin = torch.max(ymin1.repeat(1, 1, ymin2.size(1)), torch.transpose(ymin2, 1, 2).repeat(1, ymin1.size(1), 1))  # manual broadcast (PyTorch bug workaround)
131 |     min_ymax = torch.min(ymax1.repeat(1, 1, ymax2.size(1)), torch.transpose(ymax2, 1, 2).repeat(1, ymax1.size(1), 1))  # manual broadcast (PyTorch bug workaround)
132 |     height = torch.clamp(min_ymax - max_ymin, min=0)
133 |     max_xmin = torch.max(xmin1.repeat(1, 1, xmin2.size(1)), torch.transpose(xmin2, 1, 2).repeat(1, xmin1.size(1), 1))  # manual broadcast (PyTorch bug workaround)
134 |     min_xmax = torch.min(xmax1.repeat(1, 1, xmax2.size(1)), torch.transpose(xmax2, 1, 2).repeat(1, xmax1.size(1), 1))  # manual broadcast (PyTorch bug workaround)
135 | width = torch.clamp(min_xmax - max_xmin, min=0)
136 | return height * width
137 |
138 |
139 | def batch_iou_matrix(yx_min1, yx_max1, yx_min2, yx_max2, min=float(np.finfo(np.float32).eps)):
140 | """
141 | Calculates the IoU of two lists of bounding boxes for N independent batches.
142 | :author 申瑞珉 (Ruimin Shen)
143 | :param yx_min1: The top left coordinates (y, x) of the first lists (size [N, N1, 2]) of bounding boxes.
144 | :param yx_max1: The bottom right coordinates (y, x) of the first lists (size [N, N1, 2]) of bounding boxes.
145 | :param yx_min2: The top left coordinates (y, x) of the second lists (size [N, N2, 2]) of bounding boxes.
146 | :param yx_max2: The bottom right coordinates (y, x) of the second lists (size [N, N2, 2]) of bounding boxes.
147 |     :return: The matrices (size [N, N1, N2]) of the IoU.
148 | """
149 | intersect_area = batch_intersection_area(yx_min1, yx_max1, yx_min2, yx_max2)
150 | area1 = torch.prod(yx_max1 - yx_min1, -1).unsqueeze(-1)
151 | area2 = torch.prod(yx_max2 - yx_min2, -1).unsqueeze(-2)
152 | union_area = torch.clamp(area1 + area2 - intersect_area, min=min)
153 | return intersect_area / union_area
154 |
155 |
156 | class TestBatchIouMatrix(unittest.TestCase):
157 | def _test(self, bbox1, bbox2, ans, batch_size=2, dtype=np.float32):
158 | bbox1, bbox2, ans = (np.expand_dims(np.array(a, dtype), 0) for a in (bbox1, bbox2, ans))
159 | if batch_size > 1:
160 | bbox1, bbox2, ans = (np.tile(a, (batch_size, 1, 1)) for a in (bbox1, bbox2, ans))
161 | for b in range(batch_size):
162 | indices1 = np.random.permutation(bbox1.shape[1])
163 | indices2 = np.random.permutation(bbox2.shape[1])
164 | bbox1[b] = bbox1[b][indices1]
165 | bbox2[b] = bbox2[b][indices2]
166 | ans[b] = ans[b][indices1][:, indices2]
167 | yx_min1, yx_max1 = np.split(bbox1, 2, -1)
168 | yx_min2, yx_max2 = np.split(bbox2, 2, -1)
169 | assert np.all(yx_min1 <= yx_max1)
170 | assert np.all(yx_min2 <= yx_max2)
171 | assert np.all(ans >= 0)
172 | yx_min1, yx_max1 = torch.autograd.Variable(torch.from_numpy(yx_min1)), torch.autograd.Variable(torch.from_numpy(yx_max1))
173 | yx_min2, yx_max2 = torch.autograd.Variable(torch.from_numpy(yx_min2)), torch.autograd.Variable(torch.from_numpy(yx_max2))
174 | if torch.cuda.is_available():
175 | yx_min1, yx_max1, yx_min2, yx_max2 = (v.cuda() for v in (yx_min1, yx_max1, yx_min2, yx_max2))
176 | matrix = batch_iou_matrix(yx_min1, yx_max1, yx_min2, yx_max2).data.cpu().numpy()
177 | np.testing.assert_almost_equal(matrix, ans)
178 |
179 | def test0(self):
180 | bbox1 = [
181 | (1, 1, 2, 2),
182 | ]
183 | bbox2 = [
184 | (0, 0, 1, 1),
185 | (0, 1, 1, 2),
186 | (0, 2, 1, 3),
187 | (1, 0, 2, 1),
188 | (2, 0, 3, 1),
189 | (1, 2, 2, 3),
190 | (2, 1, 3, 2),
191 | (2, 2, 3, 3),
192 | ]
193 | ans = [
194 | [0] * len(bbox2),
195 | ]
196 | self._test(bbox1, bbox2, ans)
197 |
198 | def test1(self):
199 | bbox1 = [
200 | (1, 1, 3, 3),
201 | (0, 0, 4, 4),
202 | ]
203 | bbox2 = [
204 | (0, 0, 2, 2),
205 | (2, 0, 4, 2),
206 | (0, 2, 2, 4),
207 | (2, 2, 4, 4),
208 | ]
209 | ans = [
210 | [1 / (4 + 4 - 1)] * len(bbox2),
211 | [4 / 16] * len(bbox2),
212 | ]
213 | self._test(bbox1, bbox2, ans)
214 |
215 |
216 | def batch_iou_pair(yx_min1, yx_max1, yx_min2, yx_max2, min=float(np.finfo(np.float32).eps)):
217 | """
218 |     Calculates the pairwise IoU of two lists (of the same size M) of bounding boxes for N independent batches.
219 | :author 申瑞珉 (Ruimin Shen)
220 | :param yx_min1: The top left coordinates (y, x) of the first lists (size [N, M, 2]) of bounding boxes.
221 | :param yx_max1: The bottom right coordinates (y, x) of the first lists (size [N, M, 2]) of bounding boxes.
222 | :param yx_min2: The top left coordinates (y, x) of the second lists (size [N, M, 2]) of bounding boxes.
223 | :param yx_max2: The bottom right coordinates (y, x) of the second lists (size [N, M, 2]) of bounding boxes.
224 | :return: The lists (size [N, M]) of the IoU.
225 | """
226 | yx_min = torch.max(yx_min1, yx_min2)
227 | yx_max = torch.min(yx_max1, yx_max2)
228 | size = torch.clamp(yx_max - yx_min, min=0)
229 | intersect_area = torch.prod(size, -1)
230 | area1 = torch.prod(yx_max1 - yx_min1, -1)
231 | area2 = torch.prod(yx_max2 - yx_min2, -1)
232 | union_area = torch.clamp(area1 + area2 - intersect_area, min=min)
233 | return intersect_area / union_area
234 |
235 |
236 | class TestBatchIouPair(unittest.TestCase):
237 | def _test(self, bbox1, bbox2, ans, dtype=np.float32):
238 | bbox1, bbox2, ans = (np.array(a, dtype) for a in (bbox1, bbox2, ans))
239 | batch_size = bbox1.shape[0]
240 | cells = bbox2.shape[0]
241 | bbox1 = np.tile(np.reshape(bbox1, [-1, 1, 4]), [1, cells, 1])
242 | bbox2 = np.tile(np.reshape(bbox2, [1, -1, 4]), [batch_size, 1, 1])
243 | yx_min1, yx_max1 = np.split(bbox1, 2, -1)
244 | yx_min2, yx_max2 = np.split(bbox2, 2, -1)
245 | assert np.all(yx_min1 <= yx_max1)
246 | assert np.all(yx_min2 <= yx_max2)
247 | assert np.all(ans >= 0)
248 | yx_min1, yx_max1 = torch.autograd.Variable(torch.from_numpy(yx_min1)), torch.autograd.Variable(torch.from_numpy(yx_max1))
249 | yx_min2, yx_max2 = torch.autograd.Variable(torch.from_numpy(yx_min2)), torch.autograd.Variable(torch.from_numpy(yx_max2))
250 | if torch.cuda.is_available():
251 | yx_min1, yx_max1, yx_min2, yx_max2 = (v.cuda() for v in (yx_min1, yx_max1, yx_min2, yx_max2))
252 | iou = batch_iou_pair(yx_min1, yx_max1, yx_min2, yx_max2).data.cpu().numpy()
253 | np.testing.assert_almost_equal(iou, ans)
254 |
255 | def test0(self):
256 | bbox1 = [
257 | (1, 1, 2, 2),
258 | ]
259 | bbox2 = [
260 | (0, 0, 1, 1),
261 | (0, 1, 1, 2),
262 | (0, 2, 1, 3),
263 | (1, 0, 2, 1),
264 | (2, 0, 3, 1),
265 | (1, 2, 2, 3),
266 | (2, 1, 3, 2),
267 | (2, 2, 3, 3),
268 | ]
269 | ans = [
270 | [0] * len(bbox2),
271 | ]
272 | self._test(bbox1, bbox2, ans)
273 |
274 | def test1(self):
275 | bbox1 = [
276 | (1, 1, 3, 3),
277 | (0, 0, 4, 4),
278 | ]
279 | bbox2 = [
280 | (0, 0, 2, 2),
281 | (2, 0, 4, 2),
282 | (0, 2, 2, 4),
283 | (2, 2, 4, 4),
284 | ]
285 | ans = [
286 | [1 / (4 + 4 - 1)] * len(bbox2),
287 | [4 / 16] * len(bbox2),
288 | ]
289 | self._test(bbox1, bbox2, ans)
290 |
291 |
292 | if __name__ == '__main__':
293 | unittest.main()
294 |
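The repeat(...) calls above (the "manual broadcast" comments) expand both operands to a common [N1, N2] shape by hand, working around elementwise max/min limitations in the PyTorch version this code targets. On a recent PyTorch the same intersection matrix falls out of ordinary broadcasting; a sketch under that assumption, not a drop-in replacement for the code above:

    import torch

    def intersection_area_broadcast(yx_min1, yx_max1, yx_min2, yx_max2):
        """Intersection areas of [N1, 2] vs [N2, 2] boxes as an [N1, N2] matrix."""
        # [N1, 1, 2] against [1, N2, 2] broadcasts to [N1, N2, 2]
        yx_min = torch.max(yx_min1.unsqueeze(1), yx_min2.unsqueeze(0))
        yx_max = torch.min(yx_max1.unsqueeze(1), yx_max2.unsqueeze(0))
        size = torch.clamp(yx_max - yx_min, min=0)
        return torch.prod(size, -1)  # height * width per box pair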
--------------------------------------------------------------------------------
/utils/postprocess.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see <http://www.gnu.org/licenses/>.
16 | """
17 |
18 | import torch
19 |
20 | import utils.iou.torch
21 |
22 |
23 | def nms(score, yx_min, yx_max, overlap=0.5, limit=200):
24 | """
25 |     Filters the overlapping (IoU > overlap threshold) bounding boxes, visiting them in descending order of score (non-maximum suppression).
26 | :author 申瑞珉 (Ruimin Shen)
27 | :param score: The scores of the list (size [N]) of bounding boxes.
28 | :param yx_min: The top left coordinates (y, x) of the list (size [N, 2]) of bounding boxes.
29 | :param yx_max: The bottom right coordinates (y, x) of the list (size [N, 2]) of bounding boxes.
30 | :param overlap: The IoU threshold.
31 | :param limit: Limits the number of results.
32 | :return: The indices of the selected bounding boxes.
33 | """
34 | keep = []
35 | if score.numel() == 0:
36 | return keep
37 | _, index = score.sort(descending=True)
38 | index = index[:limit]
39 | while index.numel() > 0:
40 | i = index[0]
41 | keep.append(i)
42 | if index.size(0) == 1:
43 | break
44 | index = index[1:]
45 | yx_min1, yx_max1 = (torch.unsqueeze(t[i], 0) for t in (yx_min, yx_max))
46 | yx_min2, yx_max2 = (torch.index_select(t, 0, index) for t in (yx_min, yx_max))
47 | iou = utils.iou.torch.iou_matrix(yx_min1, yx_max1, yx_min2, yx_max2)[0]
48 | index = index[iou <= overlap]
49 | return keep
50 |
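A small concrete run of nms: an exact duplicate of the best-scoring box (IoU 1) is suppressed, while a neighbour overlapping it with IoU 1/3 survives the default 0.5 threshold. The box values here are made up for illustration:

    import torch

    from utils.postprocess import nms

    score = torch.Tensor([0.9, 0.7, 0.8])
    yx_min = torch.Tensor([[0, 0], [0, 0], [0, 1]])
    yx_max = torch.Tensor([[2, 2], [2, 2], [2, 3]])
    # box 1 duplicates box 0 (IoU 1 > 0.5) and is dropped;
    # box 2 overlaps box 0 with IoU 1/3 <= 0.5 and is kept
    keep = nms(score, yx_min, yx_max, overlap=0.5)
    print([int(i) for i in keep])  # [0, 2]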
--------------------------------------------------------------------------------
/utils/train.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see <http://www.gnu.org/licenses/>.
16 | """
17 |
18 | import os
19 | import time
20 | import operator
21 | import logging
22 |
23 | import torch
24 |
25 |
26 | class Timer(object):
27 | def __init__(self, max, first=True):
28 | """
29 |         A simple function object that signals periodic time events.
30 |         :author 申瑞珉 (Ruimin Shen)
31 |         :param max: Number of seconds between time events.
32 |         :param first: Whether a time event should be triggered on the first call.
33 | """
34 | self.start = 0 if first else time.time()
35 | self.max = max
36 |
37 | def __call__(self):
38 | """
39 |         Return a boolean value indicating whether a time event has occurred.
40 | :author 申瑞珉 (Ruimin Shen)
41 | """
42 | t = time.time()
43 | elapsed = t - self.start
44 | if elapsed > self.max:
45 | self.start = t
46 | return True
47 | else:
48 | return False
49 |
50 |
51 | def load_model(model_dir, step=None, ext='.pth', ext_epoch='.epoch', logger=logging.info):
52 | """
53 | Load the latest checkpoint in a model directory.
54 | :author 申瑞珉 (Ruimin Shen)
55 | :param model_dir: The directory to store the model checkpoint files.
56 |     :param step: If an integer value is given, the corresponding checkpoint will be loaded. Otherwise, the latest checkpoint (with the largest step value) will be loaded.
57 | :param ext: The extension of the model file.
58 | :param ext_epoch: The extension of the epoch file.
59 |     :return: The checkpoint path, the step, and the epoch (None if no epoch file is found).
60 | """
61 | if step is None:
62 | steps = [(int(n), n) for n, e in map(os.path.splitext, os.listdir(model_dir)) if n.isdigit() and e == ext]
63 | step, name = max(steps, key=operator.itemgetter(0))
64 | else:
65 | name = str(step)
66 | prefix = os.path.join(model_dir, name)
67 | if logger is not None:
68 | logger('load %s.*' % prefix)
69 | try:
70 | with open(prefix + ext_epoch, 'r') as f:
71 | epoch = int(f.read())
72 | except (FileNotFoundError, ValueError):
73 | epoch = None
74 | path = prefix + ext
75 | assert os.path.exists(path), path
76 | return path, step, epoch
77 |
78 |
79 | class Saver(object):
80 | def __init__(self, model_dir, keep, ext='.pth', ext_epoch='.epoch', logger=logging.info):
81 | """
82 | Manage several latest checkpoints (with the largest step values) in a model directory.
83 | :author 申瑞珉 (Ruimin Shen)
84 | :param model_dir: The directory to store the model checkpoint files.
85 |         :param keep: How many of the latest checkpoints to keep.
86 | :param ext: The extension of the model file.
87 | :param ext_epoch: The extension of the epoch file.
88 | """
89 | self.model_dir = model_dir
90 | self.keep = keep
91 | self.ext = ext
92 | self.ext_epoch = ext_epoch
93 | self.logger = (lambda s: s) if logger is None else logger
94 |
95 | def __call__(self, obj, step, epoch=None):
96 | """
97 | Save the PyTorch module.
98 | :author 申瑞珉 (Ruimin Shen)
99 | :param obj: The PyTorch module to be saved.
100 | :param step: Current step.
101 | :param epoch: Current epoch.
102 | """
103 | os.makedirs(self.model_dir, exist_ok=True)
104 | prefix = os.path.join(self.model_dir, str(step))
105 | torch.save(obj, prefix + self.ext)
106 | if epoch is not None:
107 | with open(prefix + self.ext_epoch, 'w') as f:
108 | f.write(str(epoch))
109 | self.logger('model saved into %s.*' % prefix)
110 | self.tidy()
111 | return prefix
112 |
113 | def tidy(self):
114 | steps = [(int(n), n) for n, e in map(os.path.splitext, os.listdir(self.model_dir)) if n.isdigit() and e == self.ext]
115 | if len(steps) > self.keep:
116 | steps = sorted(steps, key=operator.itemgetter(0))
117 | remove = steps[:len(steps) - self.keep]
118 | for _, n in remove:
119 | path = os.path.join(self.model_dir, n)
120 | os.remove(path + self.ext)
121 | path_epoch = path + self.ext_epoch
122 | try:
123 | os.remove(path_epoch)
124 | except FileNotFoundError:
125 | self.logger(path_epoch + ' not found')
126 | logging.debug('tidy ' + path)
127 |
128 |
129 | def load_sizes(config):
130 | sizes = [s.split(',') for s in config.get('data', 'sizes').split()]
131 | return [(int(height), int(width)) for height, width in sizes]
132 |
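These three pieces cooperate in a training script: a Timer rate-limits checkpointing, Saver writes <step>.pth (plus <step>.epoch) and prunes all but the newest keep checkpoints, and load_model locates the latest checkpoint after a restart. A condensed sketch with illustrative values; the model directory would normally come from utils.get_model_dir(config):

    import utils.train

    model_dir = 'model/voc/yolo2'  # hypothetical; usually utils.get_model_dir(config)
    saver = utils.train.Saver(model_dir, keep=5)
    timer = utils.train.Timer(10 * 60)  # at most one checkpoint every 10 minutes

    step = 0
    for epoch in range(10):
        for batch in []:  # stand-in for the real DataLoader
            # ... forward, backward, optimizer step ...
            step += 1
            if timer():
                saver(dict(step=step), step, epoch)

    # after a restart, resume from the newest checkpoint:
    # path, step, epoch = utils.train.load_model(model_dir)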
--------------------------------------------------------------------------------
/utils/visualize.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see <http://www.gnu.org/licenses/>.
16 | """
17 |
18 | import logging
19 | import itertools
20 | import inspect
21 |
22 | import numpy as np
23 | import torch
24 | import matplotlib
25 | import matplotlib.cm
26 | import matplotlib.colors
27 | import matplotlib.pyplot as plt
28 | import humanize
29 | import graphviz
30 | import cv2
31 |
32 | import utils
33 |
34 |
35 | class DrawBBox(object):
36 | def __init__(self, category, colors=[], thickness=3, line_type=cv2.LINE_8, shift=0, font_face=cv2.FONT_HERSHEY_SIMPLEX, font_scale=1):
37 | self.category = category
38 | if colors:
39 | self.colors = [tuple(map(lambda c: c * 255, matplotlib.colors.colorConverter.to_rgb(c)[::-1])) for c in colors]
40 | else:
41 | self.colors = [tuple(map(lambda c: c * 255, matplotlib.colors.colorConverter.to_rgb(prop['color'])[::-1])) for prop in plt.rcParams['axes.prop_cycle']]
42 | self.thickness = thickness
43 | self.line_type = line_type
44 | self.shift = shift
45 | self.font_face = font_face
46 | self.font_scale = font_scale
47 |
48 | def __call__(self, image, yx_min, yx_max, cls=None, colors=None, debug=False):
49 | colors = self.colors if colors is None else [tuple(map(lambda c: c * 255, matplotlib.colors.colorConverter.to_rgb(c)[::-1])) for c in colors]
50 | if cls is None:
51 | cls = [None] * len(yx_min)
52 |         for color, (ymin, xmin), (ymax, xmax), c in zip(itertools.cycle(colors), yx_min, yx_max, cls):
53 |             try:
54 |                 cv2.rectangle(image, (xmin, ymin), (xmax, ymax), color, thickness=self.thickness, lineType=self.line_type, shift=self.shift)
55 |                 if c is not None:
56 |                     cv2.putText(image, self.category[c], (xmin, ymin), self.font_face, self.font_scale, color=color, thickness=self.thickness)
57 |             except OverflowError as e:
58 |                 logging.warning('%s: %s %s' % (e, (xmin, ymin), (xmax, ymax)))
59 | if debug:
60 | cv2.imshow('', image)
61 | cv2.waitKey(0)
62 | return image
63 |
64 |
65 | class DrawFeature(object):
66 | def __init__(self, alpha=0.5, cmap=None):
67 | self.alpha = alpha
68 | self.cm = matplotlib.cm.get_cmap(cmap)
69 |
70 | def __call__(self, image, feature, debug=False):
71 |         _feature = (feature * self.cm.N).astype(int)
72 | heatmap = self.cm(_feature)[:, :, :3] * 255
73 | heatmap = cv2.resize(heatmap, image.shape[1::-1], interpolation=cv2.INTER_NEAREST)
74 | canvas = (image * (1 - self.alpha) + heatmap * self.alpha).astype(np.uint8)
75 | if debug:
76 | cv2.imshow('max=%f, sum=%f' % (np.max(feature), np.sum(feature)), canvas)
77 | cv2.waitKey(0)
78 | return canvas
79 |
80 |
81 | class Graph(object):
82 | def __init__(self, config, state_dict, cmap=None):
83 | self.dot = graphviz.Digraph(node_attr=dict(config.items('digraph_node_attr')), graph_attr=dict(config.items('digraph_graph_attr')))
84 | self.dot.format = config.get('graph', 'format')
85 | self.state_dict = state_dict
86 | self.var_name = {t._cdata: k for k, t in state_dict.items()}
87 | self.seen = set()
88 | self.index = 0
89 | self.drawn = set()
90 | self.cm = matplotlib.cm.get_cmap(cmap)
91 | self.metric = eval(config.get('graph', 'metric'))
92 | metrics = [self.metric(t) for t in state_dict.values()]
93 | self.minmax = [min(metrics), max(metrics)]
94 |
95 | def __call__(self, node):
96 | if node not in self.seen:
97 | self.traverse_next(node)
98 | self.traverse_tensor(node)
99 | self.seen.add(node)
100 | self.index += 1
101 |
102 | def traverse_next(self, node):
103 | if hasattr(node, 'next_functions'):
104 | for n, _ in node.next_functions:
105 | if n is not None:
106 | self.__call__(n)
107 | self._draw_node_edge(node, n)
108 | self._draw_node(node)
109 |
110 | def traverse_tensor(self, node):
111 | tensors = [t for name, t in inspect.getmembers(node) if torch.is_tensor(t)]
112 | if hasattr(node, 'saved_tensors'):
113 | tensors += node.saved_tensors
114 | for tensor in tensors:
115 | name = self.var_name[tensor._cdata]
116 | self.drawn.add(name)
117 | self._draw_tensor(node, tensor)
118 |
119 | def _draw_node(self, node):
120 | if hasattr(node, 'variable'):
121 | tensor = node.variable.data
122 | name = self.var_name[tensor._cdata]
123 | label = '\n'.join(map(str, [
124 | '%d: %s' % (self.index, name),
125 | list(tensor.size()),
126 | humanize.naturalsize(tensor.numpy().nbytes),
127 | ]))
128 | fillcolor, fontcolor = self._tensor_color(tensor)
129 | self.dot.node(str(id(node)), label, shape='note', fillcolor=fillcolor, fontcolor=fontcolor)
130 | self.drawn.add(name)
131 | else:
132 | self.dot.node(str(id(node)), '%d: %s' % (self.index, type(node).__name__), fillcolor='white')
133 |
134 | def _draw_node_edge(self, node, n):
135 | if hasattr(n, 'variable'):
136 | self.dot.edge(str(id(n)), str(id(node)), arrowhead='none', arrowtail='none')
137 | else:
138 | self.dot.edge(str(id(n)), str(id(node)))
139 |
140 | def _draw_tensor(self, node, tensor):
141 | name = self.var_name[tensor._cdata]
142 | label = '\n'.join(map(str, [
143 | name,
144 | list(tensor.size()),
145 | humanize.naturalsize(tensor.numpy().nbytes),
146 | ]))
147 | fillcolor, fontcolor = self._tensor_color(tensor)
148 | self.dot.node(name, label, style='filled, rounded', fillcolor=fillcolor, fontcolor=fontcolor)
149 | self.dot.edge(name, str(id(node)), style='dashed', arrowhead='none', arrowtail='none')
150 |
151 | def _tensor_color(self, tensor):
152 | level = self._norm(self.metric(tensor))
153 |         fillcolor = self.cm(int(level * self.cm.N))
154 | fontcolor = self.cm(self.cm.N if level < 0.5 else 0)
155 | return matplotlib.colors.to_hex(fillcolor), matplotlib.colors.to_hex(fontcolor)
156 |
157 | def _norm(self, metric):
158 | min, max = self.minmax
159 | assert min <= metric <= max, (metric, self.minmax)
160 | if min < max:
161 | return (metric - min) / (max - min)
162 | else:
163 | return metric
164 |
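DrawBBox cycles colors from the matplotlib prop cycle (or an explicit color list), draws each rectangle, and writes the category name at the box's top-left corner. A minimal sketch on a synthetic image; the two-class category list is made up:

    import numpy as np
    import cv2

    from utils.visualize import DrawBBox

    draw = DrawBBox(['person', 'dog'])
    image = np.zeros([240, 320, 3], dtype=np.uint8)
    yx_min = np.array([[20, 30], [100, 150]])
    yx_max = np.array([[120, 130], [200, 300]])
    canvas = draw(image, yx_min, yx_max, cls=[0, 1])
    cv2.imwrite('boxes.jpg', canvas)  # or pass debug=True to preview in a window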
--------------------------------------------------------------------------------
/variable_stat.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see <http://www.gnu.org/licenses/>.
16 | """
17 |
18 | import os
19 | import argparse
20 | import configparser
21 | import logging
22 | import logging.config
23 | import importlib
24 | import inspect
25 | import inflection
26 | import yaml
27 |
28 | import numpy as np
29 | import torch
30 | import humanize
31 | import xlsxwriter
32 |
33 | import utils
34 | import utils.train
35 | import utils.channel
36 |
37 |
38 | class Name(object):
39 | def __call__(self, name, variable):
40 | return name
41 |
42 |
43 | class Size(object):
44 | def __call__(self, name, variable):
45 | return 'x'.join(map(str, variable.size()))
46 |
47 |
48 | class Bytes(object):
49 | def __call__(self, name, variable):
50 | return variable.numpy().nbytes
51 |
52 | def format(self, workbook, worksheet, num, col):
53 | worksheet.conditional_format(1, col, num, col, {'type': 'data_bar', 'bar_color': '#FFC7CE'})
54 |
55 |
56 | class BytesNatural(object):
57 | def __call__(self, name, variable):
58 | return humanize.naturalsize(variable.numpy().nbytes)
59 |
60 |
61 | class MeanDense(object):
62 | def __call__(self, name, variable):
63 | return np.mean(utils.channel.dense(variable))
64 |
65 | def format(self, workbook, worksheet, num, col):
66 | worksheet.conditional_format(1, col, num, col, {'type': 'data_bar', 'bar_color': '#FFC7CE'})
67 |
68 |
69 | class Rank(object):
70 | def __call__(self, name, variable):
71 | return len(variable.size())
72 |
73 | def format(self, workbook, worksheet, num, col):
74 | worksheet.conditional_format(1, col, num, col, {'type': 'data_bar', 'bar_color': '#FFC7CE'})
75 |
76 |
77 | def main():
78 | args = make_args()
79 | config = configparser.ConfigParser()
80 | utils.load_config(config, args.config)
81 | for cmd in args.modify:
82 | utils.modify_config(config, cmd)
83 | with open(os.path.expanduser(os.path.expandvars(args.logging)), 'r') as f:
84 |         logging.config.dictConfig(yaml.safe_load(f))
85 | model_dir = utils.get_model_dir(config)
86 | path, step, epoch = utils.train.load_model(model_dir)
87 | state_dict = torch.load(path, map_location=lambda storage, loc: storage)
88 | mapper = [(inflection.underscore(name), member()) for name, member in inspect.getmembers(importlib.machinery.SourceFileLoader('', __file__).load_module()) if inspect.isclass(member)]
89 | path = os.path.join(model_dir, os.path.basename(os.path.splitext(__file__)[0])) + '.xlsx'
90 | with xlsxwriter.Workbook(path, {'strings_to_urls': False, 'nan_inf_to_errors': True}) as workbook:
91 | worksheet = workbook.add_worksheet(args.worksheet)
92 | for j, (key, m) in enumerate(mapper):
93 | worksheet.write(0, j, key)
94 | for i, (name, variable) in enumerate(state_dict.items()):
95 | value = m(name, variable)
96 | worksheet.write(1 + i, j, value)
97 | if hasattr(m, 'format'):
98 |                     m.format(workbook, worksheet, 1 + i, j)
99 |         worksheet.autofilter(0, 0, 1 + i, len(mapper) - 1)
100 | worksheet.freeze_panes(1, 0)
101 | logging.info(path)
102 |
103 |
104 | def make_args():
105 | parser = argparse.ArgumentParser()
106 | parser.add_argument('-c', '--config', nargs='+', default=['config.ini'], help='config file')
107 | parser.add_argument('-m', '--modify', nargs='+', default=[], help='modify config')
108 | parser.add_argument('--logging', default='logging.yml', help='logging config')
109 | parser.add_argument('--worksheet', default='sheet')
110 | parser.add_argument('--nohead', action='store_true')
111 | return parser.parse_args()
112 |
113 |
114 | if __name__ == '__main__':
115 | main()
116 |
--------------------------------------------------------------------------------
/video2image.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see <http://www.gnu.org/licenses/>.
16 | """
17 |
18 | import os
19 | import sys
20 | import argparse
21 |
22 | import pymediainfo
23 | import tqdm
24 | from contextlib import closing
25 | import videosequence
26 |
27 |
28 | def get_step(frames, video_track, **kwargs):
29 | if 'frames' in kwargs:
30 | step = len(frames) // kwargs['frames']
31 |     elif 'frames_per_sec' in kwargs:
32 | frame_rate = float(video_track.frame_rate)
33 | step = int(frame_rate / kwargs['frames_per_sec'])
34 | assert step > 0
35 | return step
36 |
37 |
38 | def convert(video_file, image_prefix, **kwargs):
39 | media_info = pymediainfo.MediaInfo.parse(video_file)
40 | video_tracks = [track for track in media_info.tracks if track.track_type == 'Video']
41 | if len(video_tracks) < 1:
42 | raise videosequence.VideoError()
43 | video_track = video_tracks[0]
44 | _rotation = float(video_track.rotation)
45 | rotation = int(_rotation)
46 | assert rotation - _rotation == 0
47 | with closing(videosequence.VideoSequence(video_file)) as frames:
48 | step = get_step(frames, video_track, **kwargs)
49 | _frames = frames[::step]
50 | for idx, frame in enumerate(tqdm.tqdm(_frames)):
51 | frame = frame.rotate(-rotation, expand=True)
52 | frame.save('%s_%04d.jpg' % (image_prefix, idx))
53 |
54 |
55 | def main():
56 | args = make_args()
57 | src = os.path.expanduser(os.path.expandvars(args.src))
58 | dst = os.path.expanduser(os.path.expandvars(args.dst))
59 | os.makedirs(dst, exist_ok=True)
60 | kwargs = {}
61 | if args.frames > 0:
62 | kwargs['frames'] = args.frames
63 | elif args.frames_per_sec > 0:
64 | kwargs['frames_per_sec'] = args.frames_per_sec
65 | exts = set()
66 | for dirpath, _, filenames in os.walk(src):
67 | for filename in filenames:
68 | ext = os.path.splitext(filename)[-1].lower()
69 | if ext in args.ext:
70 | path = os.path.join(dirpath, filename)
71 | print(path)
72 | name = os.path.relpath(path, src).replace(os.path.sep, args.replace)
73 | _path = os.path.join(dst, name)
74 | try:
75 | convert(path, _path, **kwargs)
76 | except videosequence.VideoError as e:
77 | sys.stderr.write(str(e) + '\n')
78 | else:
79 | exts.add(ext)
80 | print(exts)
81 |
82 |
83 | def make_args():
84 | parser = argparse.ArgumentParser()
85 | parser.add_argument('src')
86 | parser.add_argument('dst')
87 | parser.add_argument('-e', '--ext', nargs='+', default=['.mp4', '.mov', '.m4v'])
88 | parser.add_argument('-r', '--replace', default='_', help='replace the path separator into the given character')
89 | parser.add_argument('-f', '--frames', default=0, type=int, help='total output frames in a video')
90 | parser.add_argument('--frames_per_sec', default=0, type=int, help='output frames in a second')
91 | return parser.parse_args()
92 |
93 |
94 | if __name__ == '__main__':
95 | main()
96 |
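The sampling rule in get_step is plain arithmetic: --frames N keeps every len(frames) // N-th frame, while --frames_per_sec R keeps every int(frame_rate / R)-th frame. A standalone check with made-up numbers (a 30 fps clip of 900 frames):

    total_frames, frame_rate = 900, 30.0

    step = total_frames // 100                 # --frames 100 -> step 9
    assert step == 9
    print(len(range(0, total_frames, step)))   # 100 frames written

    step = int(frame_rate / 5)                 # --frames_per_sec 5 -> step 6
    assert step == 6                           # 5 images per second of video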
--------------------------------------------------------------------------------