├── .gitignore ├── LICENSE.md ├── README.md ├── benchmark_caffe2.py ├── cache.py ├── cache ├── __init__.py ├── coco.py ├── coco.tsv ├── voc.py └── voc.txt ├── checksum_caffe2.py ├── checksum_torch.py ├── config.ini ├── config ├── anchors │ ├── coco.tsv │ ├── tiny-yolo-voc.tsv │ ├── voc.tsv │ ├── yolo-voc.tsv │ └── yolo.tsv ├── category │ ├── 20 │ ├── 80 │ └── person ├── darknet.ini ├── darknet │ ├── tiny-yolo-voc.ini │ ├── yolo-voc.ini │ └── yolo.ini ├── debug.ini ├── eval.py └── summary │ └── histogram.txt ├── convert_darknet_torch.py ├── convert_onnx_caffe2.py ├── convert_torch_onnx.py ├── demo.gif ├── demo_data.py ├── demo_graph.py ├── demo_lr.py ├── detect.py ├── dimension_cluster.py ├── disable_bad_images.py ├── donate_alipay.jpg ├── donate_mm.jpg ├── download_url.py ├── eval.py ├── image.jpg ├── logging.yml ├── model ├── __init__.py ├── densenet.py ├── inception3.py ├── inception4.py ├── mobilenet.py ├── resnet.py ├── vgg.py └── yolo2.py ├── pruner.py ├── quick_start.sh ├── receptive_field_analyzer.py ├── requirements.txt ├── split_data.py ├── train.py ├── transform ├── __init__.py ├── augmentation.py ├── image.py └── resize │ ├── __init__.py │ ├── image.py │ └── label.py ├── utils ├── __init__.py ├── cache.py ├── channel.py ├── data.py ├── iou │ ├── __init__.py │ ├── numpy.py │ └── torch.py ├── postprocess.py ├── train.py └── visualize.py ├── variable_stat.py └── video2image.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | __pycache__ 3 | .project 4 | .pydevproject 5 | .settings/ 6 | .idea/ 7 | .cache/ 8 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | GNU LESSER GENERAL PUBLIC LICENSE 2 | Version 3, 29 June 2007 3 | 4 | Copyright (C) 2007 Free Software Foundation, Inc. 5 | Everyone is permitted to copy and distribute verbatim copies 6 | of this license document, but changing it is not allowed. 7 | 8 | 9 | This version of the GNU Lesser General Public License incorporates 10 | the terms and conditions of version 3 of the GNU General Public 11 | License, supplemented by the additional permissions listed below. 12 | 13 | 0. Additional Definitions. 14 | 15 | As used herein, "this License" refers to version 3 of the GNU Lesser 16 | General Public License, and the "GNU GPL" refers to version 3 of the GNU 17 | General Public License. 18 | 19 | "The Library" refers to a covered work governed by this License, 20 | other than an Application or a Combined Work as defined below. 21 | 22 | An "Application" is any work that makes use of an interface provided 23 | by the Library, but which is not otherwise based on the Library. 24 | Defining a subclass of a class defined by the Library is deemed a mode 25 | of using an interface provided by the Library. 26 | 27 | A "Combined Work" is a work produced by combining or linking an 28 | Application with the Library. The particular version of the Library 29 | with which the Combined Work was made is also called the "Linked 30 | Version". 31 | 32 | The "Minimal Corresponding Source" for a Combined Work means the 33 | Corresponding Source for the Combined Work, excluding any source code 34 | for portions of the Combined Work that, considered in isolation, are 35 | based on the Application, and not on the Linked Version. 
36 | 37 | The "Corresponding Application Code" for a Combined Work means the 38 | object code and/or source code for the Application, including any data 39 | and utility programs needed for reproducing the Combined Work from the 40 | Application, but excluding the System Libraries of the Combined Work. 41 | 42 | 1. Exception to Section 3 of the GNU GPL. 43 | 44 | You may convey a covered work under sections 3 and 4 of this License 45 | without being bound by section 3 of the GNU GPL. 46 | 47 | 2. Conveying Modified Versions. 48 | 49 | If you modify a copy of the Library, and, in your modifications, a 50 | facility refers to a function or data to be supplied by an Application 51 | that uses the facility (other than as an argument passed when the 52 | facility is invoked), then you may convey a copy of the modified 53 | version: 54 | 55 | a) under this License, provided that you make a good faith effort to 56 | ensure that, in the event an Application does not supply the 57 | function or data, the facility still operates, and performs 58 | whatever part of its purpose remains meaningful, or 59 | 60 | b) under the GNU GPL, with none of the additional permissions of 61 | this License applicable to that copy. 62 | 63 | 3. Object Code Incorporating Material from Library Header Files. 64 | 65 | The object code form of an Application may incorporate material from 66 | a header file that is part of the Library. You may convey such object 67 | code under terms of your choice, provided that, if the incorporated 68 | material is not limited to numerical parameters, data structure 69 | layouts and accessors, or small macros, inline functions and templates 70 | (ten or fewer lines in length), you do both of the following: 71 | 72 | a) Give prominent notice with each copy of the object code that the 73 | Library is used in it and that the Library and its use are 74 | covered by this License. 75 | 76 | b) Accompany the object code with a copy of the GNU GPL and this license 77 | document. 78 | 79 | 4. Combined Works. 80 | 81 | You may convey a Combined Work under terms of your choice that, 82 | taken together, effectively do not restrict modification of the 83 | portions of the Library contained in the Combined Work and reverse 84 | engineering for debugging such modifications, if you also do each of 85 | the following: 86 | 87 | a) Give prominent notice with each copy of the Combined Work that 88 | the Library is used in it and that the Library and its use are 89 | covered by this License. 90 | 91 | b) Accompany the Combined Work with a copy of the GNU GPL and this license 92 | document. 93 | 94 | c) For a Combined Work that displays copyright notices during 95 | execution, include the copyright notice for the Library among 96 | these notices, as well as a reference directing the user to the 97 | copies of the GNU GPL and this license document. 98 | 99 | d) Do one of the following: 100 | 101 | 0) Convey the Minimal Corresponding Source under the terms of this 102 | License, and the Corresponding Application Code in a form 103 | suitable for, and under terms that permit, the user to 104 | recombine or relink the Application with a modified version of 105 | the Linked Version to produce a modified Combined Work, in the 106 | manner specified by section 6 of the GNU GPL for conveying 107 | Corresponding Source. 108 | 109 | 1) Use a suitable shared library mechanism for linking with the 110 | Library. 
A suitable mechanism is one that (a) uses at run time 111 | a copy of the Library already present on the user's computer 112 | system, and (b) will operate properly with a modified version 113 | of the Library that is interface-compatible with the Linked 114 | Version. 115 | 116 | e) Provide Installation Information, but only if you would otherwise 117 | be required to provide such information under section 6 of the 118 | GNU GPL, and only to the extent that such information is 119 | necessary to install and execute a modified version of the 120 | Combined Work produced by recombining or relinking the 121 | Application with a modified version of the Linked Version. (If 122 | you use option 4d0, the Installation Information must accompany 123 | the Minimal Corresponding Source and Corresponding Application 124 | Code. If you use option 4d1, you must provide the Installation 125 | Information in the manner specified by section 6 of the GNU GPL 126 | for conveying Corresponding Source.) 127 | 128 | 5. Combined Libraries. 129 | 130 | You may place library facilities that are a work based on the 131 | Library side by side in a single library together with other library 132 | facilities that are not Applications and are not covered by this 133 | License, and convey such a combined library under terms of your 134 | choice, if you do both of the following: 135 | 136 | a) Accompany the combined library with a copy of the same work based 137 | on the Library, uncombined with any other library facilities, 138 | conveyed under the terms of this License. 139 | 140 | b) Give prominent notice with the combined library that part of it 141 | is a work based on the Library, and explaining where to find the 142 | accompanying uncombined form of the same work. 143 | 144 | 6. Revised Versions of the GNU Lesser General Public License. 145 | 146 | The Free Software Foundation may publish revised and/or new versions 147 | of the GNU Lesser General Public License from time to time. Such new 148 | versions will be similar in spirit to the present version, but may 149 | differ in detail to address new problems or concerns. 150 | 151 | Each version is given a distinguishing version number. If the 152 | Library as you received it specifies that a certain numbered version 153 | of the GNU Lesser General Public License "or any later version" 154 | applies to it, you have the option of following the terms and 155 | conditions either of that published version or of any later version 156 | published by the Free Software Foundation. If the Library as you 157 | received it does not specify a version number of the GNU Lesser 158 | General Public License, you may choose any version of the GNU Lesser 159 | General Public License ever published by the Free Software Foundation. 160 | 161 | If the Library as you received it specifies that a proxy can decide 162 | whether future versions of the GNU Lesser General Public License shall 163 | apply, that proxy's public statement of acceptance of any version is 164 | permanent authorization for you to choose that version for the 165 | Library. 166 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PyTorch implementation of [YOLO (You Only Look Once) v2](https://arxiv.org/pdf/1612.08242.pdf) 2 | 3 | YOLOv2 is one of the most popular [one-stage](https://arxiv.org/abs/1708.02002) object detectors.
4 | This project adopts [PyTorch](http://pytorch.org/) as the development framework to increase productivity, and utilizes [ONNX](https://github.com/onnx/onnx) to convert models into [Caffe2](https://caffe2.ai/) to ease engineering deployment. 5 | If you benefit from this project, a donation would be appreciated (via [PayPal](https://www.paypal.me/minimumshen), [WeChat Pay](donate_mm.jpg) or [Alipay](donate_alipay.jpg)). 6 | 7 | ![](demo.gif) 8 | 9 | ## Designs 10 | 11 | - Flexible configuration design. Program settings are configurable and can be modified via **config file overlaying** (the `-c`/`--config` option) or **command-line option editing** (the `-m`/`--modify` option); a usage sketch follows the feature list below. 12 | 13 | 14 | - Monitoring via [TensorBoard](https://github.com/tensorflow/tensorboard). 15 | Monitored values include the loss values and debugging images (such as the IoU heatmap, ground-truth and predicted bounding boxes). 16 | 17 | - Parallel model training design. 18 | Different models are saved into different directories so that they can be trained simultaneously. 19 | 20 | - Using a NoSQL database to store evaluation results with multiple dimensions of information. 21 | This design is useful when analyzing a large number of experiment results. 22 | 23 | - Time-based output design. 24 | Running information (such as the model, the summaries produced by TensorBoard, and the evaluation results) is saved periodically at a predefined interval. 25 | 26 | - Checkpoint management. 27 | The latest several checkpoint files (.pth) are preserved in the model directory, and older ones are deleted. 28 | 29 | - NaN debug. 30 | When a NaN loss is detected, the running environment (the data batch) and the model are exported to analyze the cause. 31 | 32 | - Unified data cache design. 33 | Various datasets are converted into a unified data cache via corresponding cache plugins. 34 | Some plugins are already implemented, such as [PASCAL VOC](http://host.robots.ox.ac.uk/pascal/VOC/) and [MS COCO](http://cocodataset.org/). 35 | 36 | - Arbitrarily replaceable model plugin design. 37 | The main deep neural network (DNN) can be easily replaced via configuration settings. 38 | Multiple models are already provided, such as Darknet, [ResNet](https://arxiv.org/abs/1512.03385), Inception [v3](https://arxiv.org/abs/1512.00567) and [v4](https://arxiv.org/abs/1602.07261), [MobileNet](https://arxiv.org/abs/1704.04861) and [DenseNet](https://arxiv.org/abs/1608.06993). 39 | 40 | - Extendable data preprocessing plugin design. 41 | The original images (in different sizes) and labels are processed via a sequence of operations to form a training batch (images of the same size, with padded bounding box lists). 42 | Multiple preprocessing plugins are already implemented, such as 43 | augmentation operators that process images and labels (such as random rotate and random flip) simultaneously, 44 | operators that resize both images and labels to a fixed size in a batch (such as random crop), 45 | and operators that augment images without labels (such as random blur, random saturation and random brightness). 46 | 47 | ## Features 48 | 49 | - [x] Reproduce the original paper's training results. 50 | - [x] Multi-scale training. 51 | - [x] Dimension clustering. 52 | - [x] [Darknet](http://pjreddie.com) model file (`.weights`) parser. 53 | - [x] Detection from images and camera. 54 | - [x] Video file processing. 55 | - [x] Multi-GPU support. 56 | - [ ] Distributed training. 57 | - [ ] [Focal loss](https://arxiv.org/abs/1708.02002). 58 | - [x] Channel-wise model parameter analyzer. 59 | - [x] Automatic channel number adjustment. 60 | - [x] Receptive field analyzer. 61 |
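As a usage sketch of the configuration design above (this assumes `train.py` follows the same `-c`/`-m` convention as the other scripts in this repository; the exact `-m` expression syntax is implemented in `utils.modify_config`, so the option-editing form below is an assumption, not documented syntax): ``` python3 train.py -c config.ini config/darknet/yolo-voc.ini -m "model/dnn=model.yolo2.Darknet" ```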
62 | ## Quick Start 63 | 64 | This project uses [Python 3](https://www.python.org/). To install the dependent libraries, type the following command in a terminal. 65 | 66 | ``` 67 | sudo pip3 install -r requirements.txt 68 | ``` 69 | 70 | `quick_start.sh` contains examples of how to perform detection and evaluation; run this script. 71 | Multiple datasets and models (in the original Darknet format, which will be converted into PyTorch's format) will be downloaded ([aria2](https://aria2.github.io/) is required). 72 | These datasets are cached into different data profiles, and the models are evaluated over the cached data. 73 | The models are used to detect objects in an example image, and the detection results will be shown. 74 | 75 | ## License 76 | 77 | This project is released as open-source software under the GNU Lesser General Public License version 3 ([LGPL v3](http://www.gnu.org/licenses/lgpl-3.0.html)). 78 | -------------------------------------------------------------------------------- /benchmark_caffe2.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen) 3 | 4 | This program is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU Lesser General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 8 | 9 | This program is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with this program. If not, see http://www.gnu.org/licenses/.
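Usage sketch (the -c/-m/--logging flags are defined in make_args below; this assumes init_net.pb and predict_net.pb were already produced by convert_onnx_caffe2.py in the model directory): python3 benchmark_caffe2.py -c config.ini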
16 | """ 17 | 18 | import os 19 | import argparse 20 | import configparser 21 | import logging 22 | import logging.config 23 | import yaml 24 | 25 | import onnx_caffe2.helper 26 | 27 | import utils 28 | 29 | 30 | def main(): 31 | args = make_args() 32 | config = configparser.ConfigParser() 33 | utils.load_config(config, args.config) 34 | for cmd in args.modify: 35 | utils.modify_config(config, cmd) 36 | with open(os.path.expanduser(os.path.expandvars(args.logging)), 'r') as f: 37 | logging.config.dictConfig(yaml.load(f)) 38 | model_dir = utils.get_model_dir(config) 39 | init_net = onnx_caffe2.helper.load_caffe2_net(os.path.join(model_dir, 'init_net.pb')) 40 | predict_net = onnx_caffe2.helper.load_caffe2_net(os.path.join(model_dir, 'predict_net.pb')) 41 | benchmark = onnx_caffe2.helper.benchmark_caffe2_model(init_net, predict_net) 42 | logging.info('benchmark=%f(milliseconds)' % benchmark) 43 | 44 | 45 | def make_args(): 46 | parser = argparse.ArgumentParser() 47 | parser.add_argument('-c', '--config', nargs='+', default=['config.ini'], help='config file') 48 | parser.add_argument('-m', '--modify', nargs='+', default=[], help='modify config') 49 | parser.add_argument('-b', '--benchmark', action='store_true') 50 | parser.add_argument('--logging', default='logging.yml', help='logging config') 51 | return parser.parse_args() 52 | 53 | 54 | if __name__ == '__main__': 55 | main() 56 | -------------------------------------------------------------------------------- /cache.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen) 3 | 4 | This program is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU Lesser General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 8 | 9 | This program is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with this program. If not, see . 
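Usage sketch (flags as defined in make_args below): python3 cache.py -c config.ini -p train val test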
16 | """ 17 | 18 | import os 19 | import argparse 20 | import configparser 21 | import logging 22 | import logging.config 23 | import importlib 24 | import pickle 25 | import random 26 | import shutil 27 | import yaml 28 | 29 | import utils 30 | 31 | 32 | def main(): 33 | args = make_args() 34 | config = configparser.ConfigParser() 35 | utils.load_config(config, args.config) 36 | for cmd in args.modify: 37 | utils.modify_config(config, cmd) 38 | with open(os.path.expanduser(os.path.expandvars(args.logging)), 'r') as f: 39 | logging.config.dictConfig(yaml.load(f)) 40 | cache_dir = utils.get_cache_dir(config) 41 | os.makedirs(cache_dir, exist_ok=True) 42 | shutil.copyfile(os.path.expanduser(os.path.expandvars(config.get('cache', 'category'))), os.path.join(cache_dir, 'category')) 43 | category = utils.get_category(config) 44 | category_index = dict([(name, i) for i, name in enumerate(category)]) 45 | datasets = config.get('cache', 'datasets').split() 46 | for phase in args.phase: 47 | path = os.path.join(cache_dir, phase) + '.pkl' 48 | logging.info('save cache file: ' + path) 49 | data = [] 50 | for dataset in datasets: 51 | logging.info('load %s dataset' % dataset) 52 | module, func = dataset.rsplit('.', 1) 53 | module = importlib.import_module(module) 54 | func = getattr(module, func) 55 | data += func(config, path, category_index) 56 | if config.getboolean('cache', 'shuffle'): 57 | random.shuffle(data) 58 | with open(path, 'wb') as f: 59 | pickle.dump(data, f) 60 | logging.info('%s data are saved into %s' % (str(args.phase), cache_dir)) 61 | 62 | 63 | def make_args(): 64 | parser = argparse.ArgumentParser() 65 | parser.add_argument('-c', '--config', nargs='+', default=['config.ini'], help='config file') 66 | parser.add_argument('-m', '--modify', nargs='+', default=[], help='modify config') 67 | parser.add_argument('-p', '--phase', nargs='+', default=['train', 'val', 'test']) 68 | parser.add_argument('--logging', default='logging.yml', help='logging config') 69 | return parser.parse_args() 70 | 71 | 72 | if __name__ == '__main__': 73 | main() 74 | -------------------------------------------------------------------------------- /cache/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruiminshen/yolo2-pytorch/146ebdf581677964caa31c69cccd0c86230fb216/cache/__init__.py -------------------------------------------------------------------------------- /cache/coco.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen) 3 | 4 | This program is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU Lesser General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 8 | 9 | This program is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with this program. If not, see . 
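Note: this cache plugin is selected through the datasets option in the [cache] section of config.ini (datasets = cache.coco.cache); cache.py resolves that dotted name and calls the cache() function below with (config, path, category_index).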
16 | """ 17 | 18 | import os 19 | import logging 20 | import configparser 21 | 22 | import numpy as np 23 | import pandas as pd 24 | import tqdm 25 | import pycocotools.coco 26 | import cv2 27 | 28 | import utils.cache 29 | 30 | 31 | def cache(config, path, category_index): 32 | phase = os.path.splitext(os.path.basename(path))[0] 33 | data = [] 34 | for i, row in pd.read_csv(os.path.splitext(__file__)[0] + '.tsv', sep='\t').iterrows(): 35 | logging.info('loading data %d (%s)' % (i, ', '.join([k + '=' + str(v) for k, v in row.items()]))) 36 | root = os.path.expanduser(os.path.expandvars(row['root'])) 37 | year = str(row['year']) 38 | suffix = phase + year 39 | path = os.path.join(root, 'annotations', 'instances_%s.json' % suffix) 40 | if not os.path.exists(path): 41 | logging.warning(path + ' does not exist') 42 | continue 43 | coco = pycocotools.coco.COCO(path) 44 | catIds = coco.getCatIds(catNms=list(category_index.keys())) 45 | cats = coco.loadCats(catIds) 46 | id_index = dict((cat['id'], category_index[cat['name']]) for cat in cats) 47 | imgIds = coco.getImgIds() 48 | path = os.path.join(root, suffix) 49 | imgs = coco.loadImgs(imgIds) 50 | _imgs = list(filter(lambda img: os.path.exists(os.path.join(path, img['file_name'])), imgs)) 51 | if len(imgs) > len(_imgs): 52 | logging.warning('%d of %d images do not exist' % (len(imgs) - len(_imgs), len(imgs))) 53 | for img in tqdm.tqdm(_imgs): 54 | annIds = coco.getAnnIds(imgIds=img['id'], catIds=catIds, iscrowd=None) 55 | anns = coco.loadAnns(annIds) 56 | if len(anns) <= 0: 57 | continue 58 | image_path = os.path.join(path, img['file_name'])  # use a separate variable so the image directory (path) is not clobbered on the next iteration 59 | width, height = img['width'], img['height'] 60 | bbox = np.array([ann['bbox'] for ann in anns], dtype=np.float32) 61 | yx_min = bbox[:, 1::-1] 62 | hw = bbox[:, -1:1:-1] 63 | yx_max = yx_min + hw 64 | cls = np.array([id_index[ann['category_id']] for ann in anns], dtype=np.int) 65 | difficult = np.zeros(cls.shape, dtype=np.uint8) 66 | try: 67 | if config.getboolean('cache', 'verify'): 68 | size = (height, width) 69 | image = cv2.imread(image_path) 70 | assert image is not None 71 | assert image.shape[:2] == size[:2] 72 | utils.cache.verify_coords(yx_min, yx_max, size[:2]) 73 | except configparser.NoOptionError: 74 | pass 75 | assert len(yx_min) == len(cls) 76 | assert yx_min.shape == yx_max.shape 77 | assert len(yx_min.shape) == 2 and yx_min.shape[-1] == 2 78 | data.append(dict(path=image_path, yx_min=yx_min, yx_max=yx_max, cls=cls, difficult=difficult)) 79 | logging.warning('%d of %d images are saved' % (len(data), len(_imgs))) 80 | return data 81 | -------------------------------------------------------------------------------- /cache/coco.tsv: -------------------------------------------------------------------------------- 1 | root year 2 | ~/data/coco 2014 3 | ~/data/coco 2017 4 | -------------------------------------------------------------------------------- /cache/voc.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen) 3 | 4 | This program is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU Lesser General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 8 | 9 | This program is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details.
13 | 14 | You should have received a copy of the GNU General Public License 15 | along with this program. If not, see . 16 | """ 17 | 18 | import os 19 | import logging 20 | import configparser 21 | 22 | import numpy as np 23 | import tqdm 24 | import xml.etree.ElementTree 25 | import cv2 26 | 27 | import utils.cache 28 | 29 | 30 | def load_annotation(path, category_index): 31 | tree = xml.etree.ElementTree.parse(path) 32 | yx_min = [] 33 | yx_max = [] 34 | cls = [] 35 | difficult = [] 36 | for obj in tree.findall('object'): 37 | try: 38 | cls.append(category_index[obj.find('name').text]) 39 | except KeyError: 40 | continue 41 | bbox = obj.find('bndbox') 42 | ymin = float(bbox.find('ymin').text) - 1 43 | xmin = float(bbox.find('xmin').text) - 1 44 | ymax = float(bbox.find('ymax').text) - 1 45 | xmax = float(bbox.find('xmax').text) - 1 46 | assert ymin < ymax 47 | assert xmin < xmax 48 | yx_min.append((ymin, xmin)) 49 | yx_max.append((ymax, xmax)) 50 | difficult.append(int(obj.find('difficult').text)) 51 | size = tree.find('size') 52 | return tree.find('filename').text, (int(size.find('height').text), int(size.find('width').text), int(size.find('depth').text)), yx_min, yx_max, cls, difficult 53 | 54 | 55 | def load_root(): 56 | with open(os.path.splitext(__file__)[0] + '.txt', 'r') as f: 57 | return [line.rstrip() for line in f] 58 | 59 | 60 | def cache(config, path, category_index, root=load_root()): 61 | phase = os.path.splitext(os.path.basename(path))[0] 62 | data = [] 63 | for root in root: 64 | logging.info('loading ' + root) 65 | root = os.path.expanduser(os.path.expandvars(root)) 66 | path = os.path.join(root, 'ImageSets', 'Main', phase) + '.txt' 67 | if not os.path.exists(path): 68 | logging.warning(path + ' not exists') 69 | continue 70 | with open(path, 'r') as f: 71 | filenames = [line.strip() for line in f] 72 | for filename in tqdm.tqdm(filenames): 73 | filename, size, yx_min, yx_max, cls, difficult = load_annotation(os.path.join(root, 'Annotations', filename + '.xml'), category_index) 74 | if len(cls) <= 0: 75 | continue 76 | path = os.path.join(root, 'JPEGImages', filename) 77 | yx_min = np.array(yx_min, dtype=np.float32) 78 | yx_max = np.array(yx_max, dtype=np.float32) 79 | cls = np.array(cls, dtype=np.int) 80 | difficult = np.array(difficult, dtype=np.uint8) 81 | assert len(yx_min) == len(cls) 82 | assert yx_min.shape == yx_max.shape 83 | assert len(yx_min.shape) == 2 and yx_min.shape[-1] == 2 84 | try: 85 | if config.getboolean('cache', 'verify'): 86 | try: 87 | image = cv2.imread(path) 88 | assert image is not None 89 | assert image.shape[:2] == size[:2] 90 | utils.cache.verify_coords(yx_min, yx_max, size[:2]) 91 | except AssertionError as e: 92 | logging.error(path + ': ' + str(e)) 93 | continue 94 | except configparser.NoOptionError: 95 | pass 96 | data.append(dict(path=path, yx_min=yx_min, yx_max=yx_max, cls=cls, difficult=difficult)) 97 | logging.info('%d of %d images are saved' % (len(data), len(filenames))) 98 | return data 99 | -------------------------------------------------------------------------------- /cache/voc.txt: -------------------------------------------------------------------------------- 1 | ~/data/VOCdevkit/VOC2007 2 | ~/data/VOCdevkit/VOC2012 3 | -------------------------------------------------------------------------------- /checksum_caffe2.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen) 3 | 4 | This program is free software: you can redistribute it and/or 
modify 5 | it under the terms of the GNU Lesser General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 8 | 9 | This program is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with this program. If not, see . 16 | """ 17 | 18 | import os 19 | import argparse 20 | import configparser 21 | import logging 22 | import logging.config 23 | import hashlib 24 | import yaml 25 | 26 | import torch 27 | from caffe2.proto import caffe2_pb2 28 | from caffe2.python import workspace 29 | 30 | import utils 31 | 32 | 33 | def main(): 34 | args = make_args() 35 | config = configparser.ConfigParser() 36 | utils.load_config(config, args.config) 37 | for cmd in args.modify: 38 | utils.modify_config(config, cmd) 39 | with open(os.path.expanduser(os.path.expandvars(args.logging)), 'r') as f: 40 | logging.config.dictConfig(yaml.load(f)) 41 | torch.manual_seed(args.seed) 42 | model_dir = utils.get_model_dir(config) 43 | init_net = caffe2_pb2.NetDef() 44 | with open(os.path.join(model_dir, 'init_net.pb'), 'rb') as f: 45 | init_net.ParseFromString(f.read()) 46 | predict_net = caffe2_pb2.NetDef() 47 | with open(os.path.join(model_dir, 'predict_net.pb'), 'rb') as f: 48 | predict_net.ParseFromString(f.read()) 49 | p = workspace.Predictor(init_net, predict_net) 50 | height, width = tuple(map(int, config.get('image', 'size').split())) 51 | tensor = torch.randn(1, 3, height, width) 52 | # Checksum 53 | output = p.run([tensor.numpy()]) 54 | for key, a in [ 55 | ('tensor', tensor.cpu().numpy()), 56 | ('output', output[0]), 57 | ]: 58 | print('\t'.join(map(str, [key, a.shape, utils.abs_mean(a), hashlib.md5(a.tostring()).hexdigest()]))) 59 | 60 | 61 | def make_args(): 62 | parser = argparse.ArgumentParser() 63 | parser.add_argument('-c', '--config', nargs='+', default=['config.ini'], help='config file') 64 | parser.add_argument('-m', '--modify', nargs='+', default=[], help='modify config') 65 | parser.add_argument('--logging', default='logging.yml', help='logging config') 66 | parser.add_argument('-s', '--seed', default=0, type=int, help='a seed to create a random image tensor') 67 | return parser.parse_args() 68 | 69 | 70 | if __name__ == '__main__': 71 | main() 72 | -------------------------------------------------------------------------------- /checksum_torch.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen) 3 | 4 | This program is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU Lesser General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 8 | 9 | This program is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with this program. If not, see . 
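Usage sketch (flags as defined in make_args below; assumes a trained checkpoint exists in the model directory): python3 checksum_torch.py -c config.ini -s 0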
16 | """ 17 | 18 | import os 19 | import argparse 20 | import configparser 21 | import logging 22 | import logging.config 23 | import hashlib 24 | import yaml 25 | 26 | import torch 27 | import torch.autograd 28 | import cv2 29 | 30 | import utils 31 | import utils.train 32 | import model 33 | import transform 34 | 35 | 36 | def main(): 37 | args = make_args() 38 | config = configparser.ConfigParser() 39 | utils.load_config(config, args.config) 40 | for cmd in args.modify: 41 | utils.modify_config(config, cmd) 42 | with open(os.path.expanduser(os.path.expandvars(args.logging)), 'r') as f: 43 | logging.config.dictConfig(yaml.load(f)) 44 | torch.manual_seed(args.seed) 45 | cache_dir = utils.get_cache_dir(config) 46 | model_dir = utils.get_model_dir(config) 47 | category = utils.get_category(config, cache_dir if os.path.exists(cache_dir) else None) 48 | anchors = utils.get_anchors(config) 49 | anchors = torch.from_numpy(anchors).contiguous() 50 | path, step, epoch = utils.train.load_model(model_dir) 51 | state_dict = torch.load(path, map_location=lambda storage, loc: storage) 52 | dnn = utils.parse_attr(config.get('model', 'dnn'))(model.ConfigChannels(config, state_dict), anchors, len(category)) 53 | dnn.load_state_dict(state_dict) 54 | height, width = tuple(map(int, config.get('image', 'size').split())) 55 | tensor = torch.randn(1, 3, height, width) 56 | # Checksum 57 | for key, var in dnn.state_dict().items(): 58 | a = var.cpu().numpy() 59 | print('\t'.join(map(str, [key, a.shape, utils.abs_mean(a), hashlib.md5(a.tostring()).hexdigest()]))) 60 | output = dnn(torch.autograd.Variable(tensor, volatile=True)).data 61 | for key, a in [ 62 | ('tensor', tensor.cpu().numpy()), 63 | ('output', output.cpu().numpy()), 64 | ]: 65 | print('\t'.join(map(str, [key, a.shape, utils.abs_mean(a), hashlib.md5(a.tostring()).hexdigest()]))) 66 | 67 | 68 | def make_args(): 69 | parser = argparse.ArgumentParser() 70 | parser.add_argument('-c', '--config', nargs='+', default=['config.ini'], help='config file') 71 | parser.add_argument('-m', '--modify', nargs='+', default=[], help='modify config') 72 | parser.add_argument('--logging', default='logging.yml', help='logging config') 73 | parser.add_argument('-s', '--seed', default=0, type=int, help='a seed to create a random image tensor') 74 | return parser.parse_args() 75 | 76 | 77 | if __name__ == '__main__': 78 | main() 79 | -------------------------------------------------------------------------------- /config.ini: -------------------------------------------------------------------------------- 1 | [config] 2 | root = ~/model/yolo2-pytorch 3 | 4 | [image] 5 | size = 416 416 6 | 7 | [cache] 8 | name = cache 9 | category = config/category/20 10 | # voc coco 11 | datasets = cache.voc.cache cache.coco.cache 12 | shuffle = 1 13 | 14 | [model] 15 | name = model 16 | anchors = config/anchors/voc.tsv 17 | ; model.yolo2.Darknet 18 | ; model.yolo2.Tiny 19 | ; model.resnet.resnet18 20 | ; model.inception3.Inception3 21 | ; model.inception4.Inception4 22 | ; model.mobilenet.MobileNet 23 | ; model.densenet.densenet121 24 | ; model.vgg.vgg19 25 | dnn = model.yolo2.Tiny 26 | pretrained = 0 27 | threshold = 0.6 28 | 29 | [batch_norm] 30 | enable = 1 31 | gamma = 1 32 | beta = 1 33 | 34 | [inception4] 35 | pretrained = imagenet 36 | 37 | [data] 38 | workers = 3 39 | sizes = 320,320 352,352 384,384 416,416 448,448 480,480 512,512 544,544 576,576 608,608 40 | maintain = 10 41 | shuffle = 0 42 | # rescale padding 43 | resize = rescale 44 | 45 | [transform] 46 | ; 
transform.augmentation.RandomRotate transform.augmentation.RandomFlipHorizontally 47 | augmentation = transform.augmentation.RandomRotate transform.augmentation.RandomFlipHorizontally 48 | resize_train = transform.resize.label.RandomCrop 49 | resize_eval = transform.resize.label.Resize 50 | resize_test = transform.resize.image.Resize 51 | ; transform.image.RandomBlur transform.image.BGR2HSV transform.image.RandomHue transform.image.RandomSaturation transform.image.RandomBrightness transform.image.HSV2RGB transform.image.RandomGamma 52 | image_train = transform.image.BGR2RGB 53 | image_test = transform.image.BGR2RGB 54 | ; torchvision.transforms.ToTensor transform.image.Normalize 55 | tensor = torchvision.transforms.ToTensor transform.image.Normalize 56 | normalize = 0.5 1 57 | 58 | [augmentation] 59 | random_rotate = -5 5 60 | random_flip_horizontally = 0.5 61 | random_crop = 1 62 | random_blur = 5 5 63 | random_hue = 0 25 64 | random_saturation = 0.5 1.5 65 | random_brightness = 0.5 1.5 66 | random_gamma = 0.9 1.5 67 | 68 | [train] 69 | ; lambda params, lr: torch.optim.SGD(params, lr, momentum=2) 70 | ; lambda params, lr: torch.optim.Adam(params, lr, betas=(0.9, 0.999), eps=1e-8) 71 | ; lambda params, lr: torch.optim.RMSprop(params, lr, alpha=0.99, eps=1e-8) 72 | optimizer = lambda params, lr: torch.optim.Adam(params, lr, betas=(0.9, 0.999), eps=1e-8) 73 | ; lambda optimizer: torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1) 74 | ; lambda optimizer: torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[60, 90], gamma=0.1) 75 | scheduler = lambda optimizer: torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[60, 90], gamma=0.1) 76 | phase = train val 77 | cross_entropy = 1 78 | clip_ = 5 79 | 80 | [save] 81 | secs = 600 82 | keep = 5 83 | 84 | [summary] 85 | scalar = 10 86 | image = 60 87 | histogram_ = 60 88 | 89 | [summary_scalar] 90 | loss_hparam = 1 91 | 92 | [summary_image] 93 | limit = 2 94 | bbox = 1 95 | iou = 1 96 | 97 | [summary_histogram] 98 | parameters = config/summary/histogram.txt 99 | 100 | [hparam] 101 | foreground = 5 102 | background = 1 103 | center = 1 104 | size = 1 105 | cls = 1 106 | 107 | [detect] 108 | threshold = 0.3 109 | threshold_cls = 0.005 110 | fix = 0 111 | overlap = 0.45 112 | 113 | [eval] 114 | phase = test 115 | secs = 12 * 60 * 60 116 | first = 0 117 | iou = 0.5 118 | db = eval.json 119 | mapper = config/eval.py 120 | debug = 0 121 | sort = timestamp 122 | metric07 = 1 123 | 124 | [graph] 125 | metric = lambda t: np.mean(utils.dense(t)) 126 | format = svg 127 | 128 | [digraph_graph_attr] 129 | size = 12, 12 130 | 131 | [digraph_node_attr] 132 | style = filled 133 | shape = box 134 | align = left 135 | fontsize = 12 136 | ranksep = 0.1 137 | height = 0.2 138 | -------------------------------------------------------------------------------- /config/anchors/coco.tsv: -------------------------------------------------------------------------------- 1 | width height 2 | 0.57273 0.677385 3 | 1.87446 2.06253 4 | 3.33843 5.47434 5 | 7.88282 3.52778 6 | 9.77052 9.16828 7 | -------------------------------------------------------------------------------- /config/anchors/tiny-yolo-voc.tsv: -------------------------------------------------------------------------------- 1 | width height 2 | 1.08 1.19 3 | 3.42 4.41 4 | 6.63 11.38 5 | 9.42 5.11 6 | 16.62 10.52 7 | -------------------------------------------------------------------------------- /config/anchors/voc.tsv: 
-------------------------------------------------------------------------------- 1 | width height 2 | 1.08 1.19 3 | 3.42 4.41 4 | 6.63 11.38 5 | 9.42 5.11 6 | 16.62 10.52 7 | -------------------------------------------------------------------------------- /config/anchors/yolo-voc.tsv: -------------------------------------------------------------------------------- 1 | width height 2 | 1.3221 1.73145 3 | 3.19275 4.00944 4 | 5.05587 8.09892 5 | 9.47112 4.84053 6 | 11.2364 10.0071 7 | -------------------------------------------------------------------------------- /config/anchors/yolo.tsv: -------------------------------------------------------------------------------- 1 | width height 2 | 0.57273 0.677385 3 | 1.87446 2.06253 4 | 3.33843 5.47434 5 | 7.88282 3.52778 6 | 9.77052 9.16828 7 | -------------------------------------------------------------------------------- /config/category/20: -------------------------------------------------------------------------------- 1 | aeroplane 2 | bicycle 3 | bird 4 | boat 5 | bottle 6 | bus 7 | car 8 | cat 9 | chair 10 | cow 11 | diningtable 12 | dog 13 | horse 14 | motorbike 15 | person 16 | pottedplant 17 | sheep 18 | sofa 19 | train 20 | tvmonitor -------------------------------------------------------------------------------- /config/category/80: -------------------------------------------------------------------------------- 1 | person 2 | bicycle 3 | car 4 | motorbike 5 | aeroplane 6 | bus 7 | train 8 | truck 9 | boat 10 | traffic light 11 | fire hydrant 12 | stop sign 13 | parking meter 14 | bench 15 | bird 16 | cat 17 | dog 18 | horse 19 | sheep 20 | cow 21 | elephant 22 | bear 23 | zebra 24 | giraffe 25 | backpack 26 | umbrella 27 | handbag 28 | tie 29 | suitcase 30 | frisbee 31 | skis 32 | snowboard 33 | sports ball 34 | kite 35 | baseball bat 36 | baseball glove 37 | skateboard 38 | surfboard 39 | tennis racket 40 | bottle 41 | wine glass 42 | cup 43 | fork 44 | knife 45 | spoon 46 | bowl 47 | banana 48 | apple 49 | sandwich 50 | orange 51 | broccoli 52 | carrot 53 | hot dog 54 | pizza 55 | donut 56 | cake 57 | chair 58 | sofa 59 | pottedplant 60 | bed 61 | diningtable 62 | toilet 63 | tvmonitor 64 | laptop 65 | mouse 66 | remote 67 | keyboard 68 | cell phone 69 | microwave 70 | oven 71 | toaster 72 | sink 73 | refrigerator 74 | book 75 | clock 76 | vase 77 | scissors 78 | teddy bear 79 | hair drier 80 | toothbrush -------------------------------------------------------------------------------- /config/category/person: -------------------------------------------------------------------------------- 1 | person -------------------------------------------------------------------------------- /config/darknet.ini: -------------------------------------------------------------------------------- 1 | [data] 2 | sizes = 416,416 448,448 480,480 512,512 3 | maintain = 10 4 | 5 | [transform] 6 | tensor = torchvision.transforms.ToTensor 7 | -------------------------------------------------------------------------------- /config/darknet/tiny-yolo-voc.ini: -------------------------------------------------------------------------------- 1 | [image] 2 | size = 416 416 3 | 4 | [cache] 5 | name = cache_voc 6 | category = config/category/20 7 | datasets = cache.voc.cache 8 | 9 | [transform] 10 | tensor = torchvision.transforms.ToTensor 11 | 12 | [model] 13 | name = model_voc 14 | anchors = config/anchors/tiny-yolo-voc.tsv 15 | dnn = model.yolo2.Tiny 16 | 17 | [detect] 18 | fix = 1 19 | 
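; usage sketch (detect.py's flags are not shown in this dump, so the -c form below is an assumption based on the convention shared by the other scripts): ; python3 detect.py -c config.ini config/darknet.ini config/darknet/tiny-yolo-voc.ini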
-------------------------------------------------------------------------------- /config/darknet/yolo-voc.ini: -------------------------------------------------------------------------------- 1 | [image] 2 | size = 416 416 3 | 4 | [cache] 5 | name = cache_voc 6 | category = config/category/20 7 | datasets = cache.voc.cache cache.coco.cache 8 | 9 | [transform] 10 | tensor = torchvision.transforms.ToTensor 11 | 12 | [model] 13 | name = model_voc 14 | anchors = config/anchors/yolo-voc.tsv 15 | dnn = model.yolo2.Darknet 16 | 17 | [detect] 18 | fix = 1 19 | -------------------------------------------------------------------------------- /config/darknet/yolo.ini: -------------------------------------------------------------------------------- 1 | [image] 2 | size = 416 416 3 | 4 | [cache] 5 | name = cache_coco 6 | category = config/category/80 7 | datasets = cache.coco.cache 8 | 9 | [transform] 10 | tensor = torchvision.transforms.ToTensor 11 | 12 | [model] 13 | name = model_coco 14 | anchors = config/anchors/yolo.tsv 15 | dnn = model.yolo2.Darknet 16 | 17 | [detect] 18 | fix = 1 19 | -------------------------------------------------------------------------------- /config/debug.ini: -------------------------------------------------------------------------------- 1 | [data] 2 | sizes = 416,416 3 | 4 | [transform] 5 | augmentation = 6 | resize_train = transform.resize.label.Resize 7 | image_train = transform.image.BGR2RGB 8 | tensor = torchvision.transforms.ToTensor 9 | -------------------------------------------------------------------------------- /config/eval.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen) 3 | 4 | This program is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU Lesser General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 8 | 9 | This program is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with this program. If not, see . 
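Note: this module is referenced by the mapper option in the [eval] section of config.ini; each class below appears to supply one column of the evaluation results database, and the optional format/get_format methods style the exported spreadsheet.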
16 | """ 17 | 18 | import os 19 | import configparser 20 | 21 | import numpy as np 22 | import humanize 23 | import pybenchmark 24 | 25 | 26 | class Timestamp(object): 27 | def __call__(self, env, **kwargs): 28 | return float(env.now.timestamp()) 29 | 30 | 31 | class Time(object): 32 | def __call__(self, env, **kwargs): 33 | return env.now.strftime('%Y-%m-%d %H:%M:%S') 34 | 35 | def get_format(self, workbook, worksheet): 36 | return workbook.add_format({'num_format': 'yyyy-mm-dd hh:mm:ss'}) 37 | 38 | 39 | class Step(object): 40 | def __call__(self, env, **kwargs): 41 | return env.step 42 | 43 | 44 | class Epoch(object): 45 | def __call__(self, env, **kwargs): 46 | return env.epoch 47 | 48 | 49 | class Model(object): 50 | def __call__(self, env, **kwargs): 51 | return env.config.get('model', 'dnn') 52 | 53 | 54 | class SizeDnn(object): 55 | def __call__(self, env, **kwargs): 56 | return sum(var.cpu().numpy().nbytes for var in env.inference.state_dict().values()) 57 | 58 | def format(self, workbook, worksheet, num, col): 59 | worksheet.conditional_format(1, col, num + 1, col, {'type': 'data_bar', 'bar_color': '#FFC7CE'}) 60 | 61 | 62 | class SizeDnnNature(object): 63 | def __call__(self, env, **kwargs): 64 | return humanize.naturalsize(sum(var.cpu().numpy().nbytes for var in env.inference.state_dict().values())) 65 | 66 | 67 | class TimeInference(object): 68 | def __call__(self, env, **kwargs): 69 | return pybenchmark.stats['inference']['time'] 70 | 71 | def format(self, workbook, worksheet, num, col): 72 | worksheet.conditional_format(1, col, num + 1, col, {'type': 'data_bar', 'bar_color': '#FFC7CE'}) 73 | 74 | 75 | class Root(object): 76 | def __call__(self, env, **kwargs): 77 | return os.path.basename(env.config.get('config', 'root')) 78 | 79 | 80 | class CacheName(object): 81 | def __call__(self, env, **kwargs): 82 | return env.config.get('cache', 'name') 83 | 84 | 85 | class ModelName(object): 86 | def __call__(self, env, **kwargs): 87 | return env.config.get('model', 'name') 88 | 89 | 90 | class Category(object): 91 | def __call__(self, env, **kwargs): 92 | return env.config.get('cache', 'category') 93 | 94 | 95 | class DatasetSize(object): 96 | def __call__(self, env, **kwargs): 97 | return len(env.loader.dataset) 98 | 99 | def format(self, workbook, worksheet, num, col): 100 | worksheet.conditional_format(1, col, num + 1, col, {'type': 'data_bar', 'bar_color': '#FFC7CE'}) 101 | 102 | 103 | class DetectThreshold(object): 104 | def __call__(self, env, **kwargs): 105 | return env.config.getfloat('detect', 'threshold') 106 | 107 | def format(self, workbook, worksheet, num, col): 108 | worksheet.conditional_format(1, col, num + 1, col, {'type': 'data_bar', 'bar_color': '#FFC7CE'}) 109 | 110 | 111 | class DetectThresholdCls(object): 112 | def __call__(self, env, **kwargs): 113 | return env.config.getfloat('detect', 'threshold_cls') 114 | 115 | def format(self, workbook, worksheet, num, col): 116 | worksheet.conditional_format(1, col, num + 1, col, {'type': 'data_bar', 'bar_color': '#FFC7CE'}) 117 | 118 | 119 | class DetectFix(object): 120 | def __call__(self, env, **kwargs): 121 | return env.config.getboolean('detect', 'fix') 122 | 123 | def format(self, workbook, worksheet, num, col): 124 | format_green = workbook.add_format({'bg_color': '#C6EFCE', 'font_color': '#006100'}) 125 | format_red = workbook.add_format({'bg_color': '#FFC7CE', 'font_color': '#9C0006'}) 126 | worksheet.conditional_format(1, col, num + 1, col, {'type': 'cell', 'criteria': '==', 'value': '1', 'format': 
format_green}) 127 | worksheet.conditional_format(1, col, num + 1, col, {'type': 'cell', 'criteria': '<>', 'value': '1', 'format': format_red}) 128 | 129 | 130 | class DetectOverlap(object): 131 | def __call__(self, env, **kwargs): 132 | return env.config.getfloat('detect', 'overlap') 133 | 134 | def format(self, workbook, worksheet, num, col): 135 | worksheet.conditional_format(1, col, num + 1, col, {'type': 'data_bar', 'bar_color': '#FFC7CE'}) 136 | 137 | 138 | class EvalIou(object): 139 | def __call__(self, env, **kwargs): 140 | return env.config.getfloat('eval', 'iou') 141 | 142 | def format(self, workbook, worksheet, num, col): 143 | worksheet.conditional_format(1, col, num + 1, col, {'type': 'data_bar', 'bar_color': '#FFC7CE'}) 144 | 145 | 146 | class EvalMeanAp(object): 147 | def __call__(self, env, **kwargs): 148 | return np.mean(list(kwargs['cls_ap'].values())) 149 | 150 | def format(self, workbook, worksheet, num, col): 151 | worksheet.conditional_format(1, col, num + 1, col, {'type': 'data_bar', 'bar_color': '#FFC7CE'}) 152 | 153 | 154 | class EvalAp(object): 155 | def __call__(self, env, **kwargs): 156 | cls_ap = kwargs['cls_ap'] 157 | return ', '.join(['%s=%f' % (env.category[c], cls_ap[c]) for c in sorted(cls_ap.keys())]) 158 | 159 | 160 | class Hparam(object): 161 | def __call__(self, env, **kwargs): 162 | try: 163 | return ', '.join([option + '=' + value for option, value in env._config.items('hparam')]) 164 | except AttributeError: 165 | return None 166 | 167 | 168 | class Optimizer(object): 169 | def __call__(self, env, **kwargs): 170 | try: 171 | return env._config.get('train', 'optimizer') 172 | except (AttributeError, configparser.NoOptionError): 173 | return None 174 | 175 | 176 | class Scheduler(object): 177 | def __call__(self, env, **kwargs): 178 | try: 179 | return env._config.get('train', 'scheduler') 180 | except (AttributeError, configparser.NoOptionError): 181 | return None 182 | -------------------------------------------------------------------------------- /config/summary/histogram.txt: -------------------------------------------------------------------------------- 1 | .+\.bn\.weight$ 2 | -------------------------------------------------------------------------------- /convert_darknet_torch.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen) 3 | 4 | This program is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU Lesser General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 8 | 9 | This program is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with this program. If not, see . 
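Usage sketch (the positional .weights argument and flags are defined in make_args below; the weights path is a placeholder): python3 convert_darknet_torch.py ~/darknet/yolo-voc.weights -c config.ini config/darknet/yolo-voc.ini -d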
16 | """ 17 | 18 | import os 19 | import argparse 20 | import configparser 21 | import logging 22 | import logging.config 23 | import struct 24 | import collections 25 | import shutil 26 | import hashlib 27 | import yaml 28 | 29 | import numpy as np 30 | import torch 31 | import humanize 32 | 33 | import model 34 | import utils.train 35 | 36 | 37 | def transpose_weight(weight, num_anchors): 38 | _, channels_in, ksize1, ksize2 = weight.size() 39 | weight = weight.view(num_anchors, -1, channels_in, ksize1, ksize2) 40 | x = weight[:, 0:1, :, :, :] 41 | y = weight[:, 1:2, :, :, :] 42 | w = weight[:, 2:3, :, :, :] 43 | h = weight[:, 3:4, :, :, :] 44 | iou = weight[:, 4:5, :, :, :] 45 | cls = weight[:, 5:, :, :, :] 46 | return torch.cat([iou, y, x, h, w, cls], 1).view(-1, channels_in, ksize1, ksize2) 47 | 48 | 49 | def transpose_bias(bias, num_anchors): 50 | bias = bias.view([num_anchors, -1]) 51 | x = bias[:, 0:1] 52 | y = bias[:, 1:2] 53 | w = bias[:, 2:3] 54 | h = bias[:, 3:4] 55 | iou = bias[:, 4:5] 56 | cls = bias[:, 5:] 57 | return torch.cat([iou, y, x, h, w, cls], 1).view(-1) 58 | 59 | 60 | def group_state(state_dict): 61 | grouped_dict = collections.OrderedDict() 62 | for key, var in state_dict.items(): 63 | layer, suffix1, suffix2 = key.rsplit('.', 2) 64 | suffix = suffix1 + '.' + suffix2 65 | if layer in grouped_dict: 66 | grouped_dict[layer][suffix] = var 67 | else: 68 | grouped_dict[layer] = {suffix: var} 69 | return grouped_dict 70 | 71 | 72 | def main(): 73 | args = make_args() 74 | config = configparser.ConfigParser() 75 | utils.load_config(config, args.config) 76 | for cmd in args.modify: 77 | utils.modify_config(config, cmd) 78 | with open(os.path.expanduser(os.path.expandvars(args.logging)), 'r') as f: 79 | logging.config.dictConfig(yaml.load(f)) 80 | cache_dir = utils.get_cache_dir(config) 81 | model_dir = utils.get_model_dir(config) 82 | category = utils.get_category(config, cache_dir if os.path.exists(cache_dir) else None) 83 | anchors = utils.get_anchors(config) 84 | anchors = torch.from_numpy(anchors).contiguous() 85 | dnn = utils.parse_attr(config.get('model', 'dnn'))(model.ConfigChannels(config), anchors, len(category)) 86 | dnn.eval() 87 | logging.info(humanize.naturalsize(sum(var.cpu().numpy().nbytes for var in dnn.state_dict().values()))) 88 | state_dict = dnn.state_dict() 89 | grouped_dict = group_state(state_dict) 90 | try: 91 | layers = [] 92 | with open(os.path.expanduser(os.path.expandvars(args.file)), 'rb') as f: 93 | major, minor, revision, seen = struct.unpack('4i', f.read(16)) 94 | logging.info('major=%d, minor=%d, revision=%d, seen=%d' % (major, minor, revision, seen)) 95 | total = 0 96 | filesize = os.fstat(f.fileno()).st_size 97 | for layer in grouped_dict: 98 | group = grouped_dict[layer] 99 | for suffix in ['conv.bias', 'bn.bias', 'bn.weight', 'bn.running_mean', 'bn.running_var', 'conv.weight']: 100 | if suffix in group: 101 | var = group[suffix] 102 | size = var.size() 103 | cnt = np.multiply.reduce(size) 104 | total += cnt 105 | key = layer + '.' 
+ suffix 106 | val = np.array(struct.unpack('%df' % cnt, f.read(cnt * 4)), np.float32) 107 | val = np.reshape(val, size) 108 | remaining = filesize - f.tell() 109 | logging.info('%s.%s: %s=%f (%s), remaining=%d' % (layer, suffix, 'x'.join(list(map(str, size))), utils.abs_mean(val), hashlib.md5(val.tostring()).hexdigest(), remaining)) 110 | layers.append([key, torch.from_numpy(val)]) 111 | logging.info('%d parameters assigned' % total) 112 | layers[-1][1] = transpose_weight(layers[-1][1], len(anchors)) 113 | layers[-2][1] = transpose_bias(layers[-2][1], len(anchors)) 114 | finally: 115 | if 'remaining' in locals() and remaining > 0:  # remaining is undefined if reading fails before the first parameter block is unpacked 116 | logging.warning('%d bytes remaining' % remaining) 117 | state_dict = collections.OrderedDict(layers) 118 | if args.delete: 119 | logging.warning('delete model directory: ' + model_dir) 120 | shutil.rmtree(model_dir, ignore_errors=True) 121 | saver = utils.train.Saver(model_dir, config.getint('save', 'keep'), logger=None) 122 | path = saver(state_dict, 0, 0) + saver.ext 123 | if args.copy is not None: 124 | _path = os.path.expandvars(os.path.expanduser(args.copy)) 125 | logging.info('copy %s to %s' % (path, _path)) 126 | shutil.copy(path, _path) 127 | 128 | 129 | def make_args(): 130 | parser = argparse.ArgumentParser() 131 | parser.add_argument('file', help='Darknet .weights file') 132 | parser.add_argument('-c', '--config', nargs='+', default=['config.ini'], help='config file') 133 | parser.add_argument('-m', '--modify', nargs='+', default=[], help='modify config') 134 | parser.add_argument('-d', '--delete', action='store_true', help='delete logdir') 135 | parser.add_argument('--copy', help='copy model') 136 | parser.add_argument('--logging', default='logging.yml', help='logging config') 137 | return parser.parse_args() 138 | 139 | 140 | if __name__ == '__main__': 141 | main() 142 | -------------------------------------------------------------------------------- /convert_onnx_caffe2.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen) 3 | 4 | This program is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU Lesser General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 8 | 9 | This program is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with this program. If not, see http://www.gnu.org/licenses/.
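Usage sketch (flags as defined in make_args below; assumes the .onnx file previously produced by convert_torch_onnx.py): python3 convert_onnx_caffe2.py -c config.ini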
16 | """ 17 | 18 | import os 19 | import argparse 20 | import configparser 21 | import logging 22 | import logging.config 23 | import yaml 24 | 25 | import onnx 26 | import onnx_caffe2.backend 27 | import onnx_caffe2.helper 28 | 29 | import utils 30 | 31 | 32 | def main(): 33 | args = make_args() 34 | config = configparser.ConfigParser() 35 | utils.load_config(config, args.config) 36 | for cmd in args.modify: 37 | utils.modify_config(config, cmd) 38 | with open(os.path.expanduser(os.path.expandvars(args.logging)), 'r') as f: 39 | logging.config.dictConfig(yaml.load(f)) 40 | model_dir = utils.get_model_dir(config) 41 | model = onnx.load(model_dir + '.onnx') 42 | onnx.checker.check_model(model) 43 | init_net, predict_net = onnx_caffe2.backend.Caffe2Backend.onnx_graph_to_caffe2_net(model.graph, device='CPU') 44 | onnx_caffe2.helper.save_caffe2_net(init_net, os.path.join(model_dir, 'init_net.pb')) 45 | onnx_caffe2.helper.save_caffe2_net(predict_net, os.path.join(model_dir, 'predict_net.pb'), output_txt=True) 46 | logging.info(model_dir) 47 | 48 | 49 | def make_args(): 50 | parser = argparse.ArgumentParser() 51 | parser.add_argument('-c', '--config', nargs='+', default=['config.ini'], help='config file') 52 | parser.add_argument('-m', '--modify', nargs='+', default=[], help='modify config') 53 | parser.add_argument('--logging', default='logging.yml', help='logging config') 54 | return parser.parse_args() 55 | 56 | 57 | if __name__ == '__main__': 58 | main() 59 | -------------------------------------------------------------------------------- /convert_torch_onnx.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen) 3 | 4 | This program is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU Lesser General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 8 | 9 | This program is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with this program. If not, see . 
16 | """ 17 | 18 | import os 19 | import argparse 20 | import configparser 21 | import logging 22 | import logging.config 23 | import yaml 24 | 25 | import torch.autograd 26 | import torch.cuda 27 | import torch.optim 28 | import torch.utils.data 29 | import torch.onnx 30 | import humanize 31 | 32 | import utils.train 33 | import model 34 | 35 | 36 | def main(): 37 | args = make_args() 38 | config = configparser.ConfigParser() 39 | utils.load_config(config, args.config) 40 | for cmd in args.modify: 41 | utils.modify_config(config, cmd) 42 | with open(os.path.expanduser(os.path.expandvars(args.logging)), 'r') as f: 43 | logging.config.dictConfig(yaml.load(f)) 44 | height, width = tuple(map(int, config.get('image', 'size').split())) 45 | cache_dir = utils.get_cache_dir(config) 46 | model_dir = utils.get_model_dir(config) 47 | category = utils.get_category(config, cache_dir if os.path.exists(cache_dir) else None) 48 | anchors = utils.get_anchors(config) 49 | anchors = torch.from_numpy(anchors).contiguous() 50 | path, step, epoch = utils.train.load_model(model_dir) 51 | state_dict = torch.load(path, map_location=lambda storage, loc: storage) 52 | dnn = utils.parse_attr(config.get('model', 'dnn'))(model.ConfigChannels(config, state_dict), anchors, len(category)) 53 | inference = model.Inference(config, dnn, anchors) 54 | inference.eval() 55 | logging.info(humanize.naturalsize(sum(var.cpu().numpy().nbytes for var in inference.state_dict().values()))) 56 | dnn.load_state_dict(state_dict) 57 | image = torch.autograd.Variable(torch.randn(args.batch_size, 3, height, width), volatile=True) 58 | path = model_dir + '.onnx' 59 | logging.info('save ' + path) 60 | torch.onnx.export(dnn, image, path, export_params=True, verbose=args.verbose) # PyTorch's bug 61 | 62 | 63 | def make_args(): 64 | parser = argparse.ArgumentParser() 65 | parser.add_argument('-c', '--config', nargs='+', default=['config.ini'], help='config file') 66 | parser.add_argument('-m', '--modify', nargs='+', default=[], help='modify config') 67 | parser.add_argument('-b', '--batch_size', default=1, type=int, help='batch size') 68 | parser.add_argument('-v', '--verbose', action='store_true') 69 | parser.add_argument('--logging', default='logging.yml', help='logging config') 70 | return parser.parse_args() 71 | 72 | 73 | if __name__ == '__main__': 74 | main() 75 | -------------------------------------------------------------------------------- /demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruiminshen/yolo2-pytorch/146ebdf581677964caa31c69cccd0c86230fb216/demo.gif -------------------------------------------------------------------------------- /demo_data.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen) 3 | 4 | This program is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU Lesser General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 8 | 9 | This program is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with this program. If not, see . 
16 | """ 17 | 18 | import os 19 | import argparse 20 | import configparser 21 | import logging 22 | import logging.config 23 | import multiprocessing 24 | import yaml 25 | 26 | import numpy as np 27 | import torch.utils.data 28 | import matplotlib.pyplot as plt 29 | 30 | import utils.data 31 | import utils.train 32 | import utils.visualize 33 | import transform.augmentation 34 | 35 | 36 | def main(): 37 | args = make_args() 38 | config = configparser.ConfigParser() 39 | utils.load_config(config, args.config) 40 | for cmd in args.modify: 41 | utils.modify_config(config, cmd) 42 | with open(os.path.expanduser(os.path.expandvars(args.logging)), 'r') as f: 43 | logging.config.dictConfig(yaml.load(f)) 44 | cache_dir = utils.get_cache_dir(config) 45 | category = utils.get_category(config, cache_dir) 46 | draw_bbox = utils.visualize.DrawBBox(category) 47 | batch_size = args.rows * args.cols 48 | paths = [os.path.join(cache_dir, phase + '.pkl') for phase in args.phase] 49 | dataset = utils.data.Dataset( 50 | utils.data.load_pickles(paths), 51 | transform=transform.augmentation.get_transform(config, config.get('transform', 'augmentation').split()), 52 | shuffle=config.getboolean('data', 'shuffle'), 53 | ) 54 | logging.info('num_examples=%d' % len(dataset)) 55 | try: 56 | workers = config.getint('data', 'workers') 57 | except configparser.NoOptionError: 58 | workers = multiprocessing.cpu_count() 59 | collate_fn = utils.data.Collate( 60 | transform.parse_transform(config, config.get('transform', 'resize_train')), 61 | utils.train.load_sizes(config), 62 | maintain=config.getint('data', 'maintain'), 63 | transform_image=transform.get_transform(config, config.get('transform', 'image_train').split()), 64 | ) 65 | loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=workers, collate_fn=collate_fn) 66 | for data in loader: 67 | path, size, image, yx_min, yx_max, cls = (t.numpy() if hasattr(t, 'numpy') else t for t in (data[key] for key in 'path, size, image, yx_min, yx_max, cls'.split(', '))) 68 | fig, axes = plt.subplots(args.rows, args.cols) 69 | axes = axes.flat if batch_size > 1 else [axes] 70 | for ax, path, size, image, yx_min, yx_max, cls in zip(*[axes, path, size, image, yx_min, yx_max, cls]): 71 | logging.info(path + ': ' + 'x'.join(map(str, size))) 72 | size = yx_max - yx_min 73 | target = np.logical_and(*[np.squeeze(a, -1) > 0 for a in np.split(size, size.shape[-1], -1)]) 74 | yx_min, yx_max, cls = (a[target] for a in (yx_min, yx_max, cls)) 75 | image = draw_bbox(image, yx_min.astype(np.int), yx_max.astype(np.int), cls) 76 | ax.imshow(image) 77 | ax.set_title('%d objects' % np.sum(target)) 78 | ax.set_xticks([]) 79 | ax.set_yticks([]) 80 | fig.tight_layout() 81 | mng = plt.get_current_fig_manager() 82 | mng.resize(*mng.window.maxsize()) 83 | plt.show() 84 | 85 | 86 | def make_args(): 87 | parser = argparse.ArgumentParser() 88 | parser.add_argument('-c', '--config', nargs='+', default=['config.ini'], help='config file') 89 | parser.add_argument('-m', '--modify', nargs='+', default=[], help='modify config') 90 | parser.add_argument('-p', '--phase', nargs='+', default=['train', 'val', 'test']) 91 | parser.add_argument('--rows', default=3, type=int) 92 | parser.add_argument('--cols', default=3, type=int) 93 | parser.add_argument('--logging', default='logging.yml', help='logging config') 94 | return parser.parse_args() 95 | 96 | 97 | if __name__ == '__main__': 98 | main() 99 | -------------------------------------------------------------------------------- 
/demo_graph.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen) 3 | 4 | This program is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU Lesser General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 8 | 9 | This program is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with this program. If not, see . 16 | """ 17 | 18 | import os 19 | import argparse 20 | import configparser 21 | import logging 22 | import logging.config 23 | import yaml 24 | 25 | import torch.autograd 26 | import torch.cuda 27 | import torch.optim 28 | import torch.utils.data 29 | import humanize 30 | 31 | import model 32 | import utils 33 | import utils.train 34 | import utils.visualize 35 | 36 | 37 | def main(): 38 | args = make_args() 39 | config = configparser.ConfigParser() 40 | utils.load_config(config, args.config) 41 | for cmd in args.modify: 42 | utils.modify_config(config, cmd) 43 | with open(os.path.expanduser(os.path.expandvars(args.logging)), 'r') as f: 44 | logging.config.dictConfig(yaml.load(f)) 45 | model_dir = utils.get_model_dir(config) 46 | category = utils.get_category(config) 47 | anchors = torch.from_numpy(utils.get_anchors(config)).contiguous() 48 | try: 49 | path, step, epoch = utils.train.load_model(model_dir) 50 | state_dict = torch.load(path, map_location=lambda storage, loc: storage) 51 | except (FileNotFoundError, ValueError): 52 | logging.warning('model cannot be loaded') 53 | state_dict = None 54 | dnn = utils.parse_attr(config.get('model', 'dnn'))(model.ConfigChannels(config, state_dict), anchors, len(category)) 55 | logging.info(humanize.naturalsize(sum(var.cpu().numpy().nbytes for var in dnn.state_dict().values()))) 56 | if state_dict is not None: 57 | dnn.load_state_dict(state_dict) 58 | height, width = tuple(map(int, config.get('image', 'size').split())) 59 | image = torch.autograd.Variable(torch.randn(args.batch_size, 3, height, width)) 60 | output = dnn(image) 61 | state_dict = dnn.state_dict() 62 | graph = utils.visualize.Graph(config, state_dict) 63 | graph(output.grad_fn) 64 | diff = [key for key in state_dict if key not in graph.drawn] 65 | if diff: 66 | logging.warning('variables not shown: ' + str(diff)) 67 | path = graph.dot.view(os.path.basename(model_dir) + '.gv', os.path.dirname(model_dir)) 68 | logging.info(path) 69 | 70 | 71 | def make_args(): 72 | parser = argparse.ArgumentParser() 73 | parser.add_argument('-c', '--config', nargs='+', default=['config.ini'], help='config file') 74 | parser.add_argument('-m', '--modify', nargs='+', default=[], help='modify config') 75 | parser.add_argument('-b', '--batch_size', default=1, type=int, help='batch size') 76 | parser.add_argument('--logging', default='logging.yml', help='logging config') 77 | return parser.parse_args() 78 | 79 | 80 | if __name__ == '__main__': 81 | main() -------------------------------------------------------------------------------- /demo_lr.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen) 3 | 4 | This program is free software: you can 
redistribute it and/or modify 5 | it under the terms of the GNU Lesser General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 8 | 9 | This program is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with this program. If not, see . 16 | """ 17 | 18 | import argparse 19 | import configparser 20 | import logging 21 | import logging.config 22 | import os 23 | import yaml 24 | 25 | import torch.autograd 26 | import torch.cuda 27 | import torch.optim 28 | import torch.utils.data 29 | 30 | import model 31 | import utils.data 32 | import utils.postprocess 33 | import utils.train 34 | import utils.visualize 35 | 36 | 37 | def main(): 38 | args = make_args() 39 | config = configparser.ConfigParser() 40 | utils.load_config(config, args.config) 41 | for cmd in args.modify: 42 | utils.modify_config(config, cmd) 43 | with open(os.path.expanduser(os.path.expandvars(args.logging)), 'r') as f: 44 | logging.config.dictConfig(yaml.load(f)) 45 | category = utils.get_category(config) 46 | anchors = torch.from_numpy(utils.get_anchors(config)).contiguous() 47 | dnn = utils.parse_attr(config.get('model', 'dnn'))(model.ConfigChannels(config), anchors, len(category)) 48 | inference = model.Inference(config, dnn, anchors) 49 | inference.train() 50 | optimizer = eval(config.get('train', 'optimizer'))(filter(lambda p: p.requires_grad, inference.parameters()), args.learning_rate) 51 | scheduler = eval(config.get('train', 'scheduler'))(optimizer) 52 | for epoch in range(args.epoch): 53 | scheduler.step(epoch) 54 | lr = scheduler.get_lr() 55 | print('\t'.join(map(str, [epoch] + lr))) 56 | 57 | 58 | 59 | def make_args(): 60 | parser = argparse.ArgumentParser() 61 | parser.add_argument('epoch', type=int) 62 | parser.add_argument('-c', '--config', nargs='+', default=['config.ini'], help='config file') 63 | parser.add_argument('-m', '--modify', nargs='+', default=[], help='modify config') 64 | parser.add_argument('-o', '--optimizer', default='adam') 65 | parser.add_argument('-lr', '--learning_rate', default=1e-3, type=float, help='learning rate') 66 | parser.add_argument('--logging', default='logging.yml', help='logging config') 67 | return parser.parse_args() 68 | 69 | 70 | if __name__ == '__main__': 71 | main() 72 | -------------------------------------------------------------------------------- /detect.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen) 3 | 4 | This program is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU Lesser General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 8 | 9 | This program is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with this program. If not, see . 
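Because `optimizer` and `scheduler` come straight out of `eval` on config strings, demo_lr.py is a cheap dry run: it steps the schedule for the requested number of epochs and prints the learning rate, without touching any data. Stripped of the config machinery (and using `get_lr` from the 0.x API; recent releases prefer `get_last_lr`):

import torch

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.SGD(params, lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
for epoch in range(30):
    scheduler.step(epoch)                    # 0.x-style explicit epoch argument
    print('\t'.join(map(str, [epoch] + scheduler.get_lr())))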
16 | """ 17 | 18 | import argparse 19 | import configparser 20 | import logging 21 | import logging.config 22 | import os 23 | import time 24 | import yaml 25 | 26 | import numpy as np 27 | import torch.autograd 28 | import torch.cuda 29 | import torch.optim 30 | import torch.utils.data 31 | import torch.nn.functional as F 32 | import humanize 33 | import pybenchmark 34 | import cv2 35 | 36 | import transform 37 | import model 38 | import utils.postprocess 39 | import utils.train 40 | import utils.visualize 41 | 42 | 43 | def get_logits(pred): 44 | if 'logits' in pred: 45 | return pred['logits'].contiguous() 46 | else: 47 | size = pred['iou'].size() 48 | return torch.autograd.Variable(utils.ensure_device(torch.ones(*size, 1))) 49 | 50 | 51 | def filter_visible(config, iou, yx_min, yx_max, prob): 52 | prob_cls, cls = torch.max(prob, -1) 53 | if config.getboolean('detect', 'fix'): 54 | mask = (iou * prob_cls) > config.getfloat('detect', 'threshold_cls') 55 | else: 56 | mask = iou > config.getfloat('detect', 'threshold') 57 | iou, prob_cls, cls = (t[mask].view(-1) for t in (iou, prob_cls, cls)) 58 | _mask = torch.unsqueeze(mask, -1).repeat(1, 2) # PyTorch's bug 59 | yx_min, yx_max = (t[_mask].view(-1, 2) for t in (yx_min, yx_max)) 60 | num = prob.size(-1) 61 | _mask = torch.unsqueeze(mask, -1).repeat(1, num) # PyTorch's bug 62 | prob = prob[_mask].view(-1, num) 63 | return iou, yx_min, yx_max, prob, prob_cls, cls 64 | 65 | 66 | def postprocess(config, iou, yx_min, yx_max, prob): 67 | iou, yx_min, yx_max, prob, prob_cls, cls = filter_visible(config, iou, yx_min, yx_max, prob) 68 | keep = pybenchmark.profile('nms')(utils.postprocess.nms)(iou, yx_min, yx_max, config.getfloat('detect', 'overlap')) 69 | if keep: 70 | keep = utils.ensure_device(torch.LongTensor(keep)) 71 | iou, yx_min, yx_max, prob, prob_cls, cls = (t[keep] for t in (iou, yx_min, yx_max, prob, prob_cls, cls)) 72 | if config.getboolean('detect', 'fix'): 73 | score = torch.unsqueeze(iou, -1) * prob 74 | mask = score > config.getfloat('detect', 'threshold_cls') 75 | indices, cls = torch.unbind(mask.nonzero(), -1) 76 | yx_min, yx_max = (t[indices] for t in (yx_min, yx_max)) 77 | score = score[mask] 78 | else: 79 | score = iou 80 | return iou, yx_min, yx_max, cls, score 81 | 82 | 83 | class Detect(object): 84 | def __init__(self, args, config): 85 | self.args = args 86 | self.config = config 87 | self.cache_dir = utils.get_cache_dir(config) 88 | self.model_dir = utils.get_model_dir(config) 89 | self.category = utils.get_category(config, self.cache_dir if os.path.exists(self.cache_dir) else None) 90 | self.draw_bbox = utils.visualize.DrawBBox(self.category, colors=args.colors, thickness=args.thickness) 91 | self.anchors = torch.from_numpy(utils.get_anchors(config)).contiguous() 92 | self.height, self.width = tuple(map(int, config.get('image', 'size').split())) 93 | self.path, self.step, self.epoch = utils.train.load_model(self.model_dir) 94 | state_dict = torch.load(self.path, map_location=lambda storage, loc: storage) 95 | self.dnn = utils.parse_attr(config.get('model', 'dnn'))(model.ConfigChannels(config, state_dict), self.anchors, len(self.category)) 96 | self.dnn.load_state_dict(state_dict) 97 | self.inference = model.Inference(config, self.dnn, self.anchors) 98 | self.inference.eval() 99 | if torch.cuda.is_available(): 100 | self.inference.cuda() 101 | logging.info(humanize.naturalsize(sum(var.cpu().numpy().nbytes for var in self.inference.state_dict().values()))) 102 | self.cap = self.create_cap() 103 | self.keys = set(args.keys) 
104 | self.resize = transform.parse_transform(config, config.get('transform', 'resize_test')) 105 | self.transform_image = transform.get_transform(config, config.get('transform', 'image_test').split()) 106 | self.transform_tensor = transform.get_transform(config, config.get('transform', 'tensor').split()) 107 | 108 | def __del__(self): 109 | cv2.destroyAllWindows() 110 | try: 111 | self.writer.release() 112 | except AttributeError: 113 | pass 114 | self.cap.release() 115 | 116 | def create_cap(self): 117 | try: 118 | cap = int(self.args.input) 119 | except ValueError: 120 | cap = os.path.expanduser(os.path.expandvars(self.args.input)) 121 | assert os.path.exists(cap) 122 | return cv2.VideoCapture(cap) 123 | 124 | def create_writer(self, height, width): 125 | fps = self.cap.get(cv2.CAP_PROP_FPS) 126 | logging.info('cap fps=%f' % fps) 127 | path = os.path.expanduser(os.path.expandvars(self.args.output)) 128 | if self.args.fourcc: 129 | fourcc = cv2.VideoWriter_fourcc(*self.args.fourcc.upper()) 130 | else: 131 | fourcc = int(self.cap.get(cv2.CAP_PROP_FOURCC)) 132 | os.makedirs(os.path.dirname(path), exist_ok=True) 133 | return cv2.VideoWriter(path, fourcc, fps, (width, height)) 134 | 135 | def get_image(self): 136 | ret, image_bgr = self.cap.read() 137 | if self.args.crop: 138 | image_bgr = image_bgr[self.crop_ymin:self.crop_ymax, self.crop_xmin:self.crop_xmax] 139 | return image_bgr 140 | 141 | def __call__(self): 142 | image_bgr = self.get_image() 143 | image_resized = self.resize(image_bgr, self.height, self.width) 144 | image = self.transform_image(image_resized) 145 | tensor = self.transform_tensor(image) 146 | tensor = utils.ensure_device(tensor.unsqueeze(0)) 147 | pred = pybenchmark.profile('inference')(model._inference)(self.inference, torch.autograd.Variable(tensor, volatile=True)) 148 | rows, cols = pred['feature'].size()[-2:] 149 | iou = pred['iou'].data.contiguous().view(-1) 150 | yx_min, yx_max = (pred[key].data.view(-1, 2) for key in 'yx_min, yx_max'.split(', ')) 151 | logits = get_logits(pred) 152 | prob = F.softmax(logits, -1).data.view(-1, logits.size(-1)) 153 | ret = postprocess(self.config, iou, yx_min, yx_max, prob) 154 | image_result = image_bgr.copy() 155 | if ret is not None: 156 | iou, yx_min, yx_max, cls, score = ret 157 | try: 158 | scale = self.scale 159 | except AttributeError: 160 | scale = utils.ensure_device(torch.from_numpy(np.array(image_result.shape[:2], np.float32) / np.array([rows, cols], np.float32))) 161 | self.scale = scale 162 | yx_min, yx_max = ((t * scale).cpu().numpy().astype(np.int) for t in (yx_min, yx_max)) 163 | image_result = self.draw_bbox(image_result, yx_min, yx_max, cls) 164 | if self.args.output: 165 | if not hasattr(self, 'writer'): 166 | self.writer = self.create_writer(*image_result.shape[:2]) 167 | self.writer.write(image_result) 168 | else: 169 | cv2.imshow('detection', image_result) 170 | if cv2.waitKey(0 if self.args.pause else 1) in self.keys: 171 | root = os.path.join(self.model_dir, 'snapshot') 172 | os.makedirs(root, exist_ok=True) 173 | path = os.path.join(root, time.strftime(self.args.format)) 174 | cv2.imwrite(path, image_bgr) 175 | logging.warning('image dumped into ' + path) 176 | 177 | 178 | def main(): 179 | args = make_args() 180 | config = configparser.ConfigParser() 181 | utils.load_config(config, args.config) 182 | for cmd in args.modify: 183 | utils.modify_config(config, cmd) 184 | with open(os.path.expanduser(os.path.expandvars(args.logging)), 'r') as f: 185 | logging.config.dictConfig(yaml.load(f)) 186 | detect = 
Detect(args, config) 187 | try: 188 | while detect.cap.isOpened(): 189 | detect() 190 | except KeyboardInterrupt: 191 | logging.warning('interrupted') 192 | finally: 193 | logging.info(pybenchmark.stats) 194 | 195 | 196 | def make_args(): 197 | parser = argparse.ArgumentParser() 198 | parser.add_argument('-c', '--config', nargs='+', default=['config.ini'], help='config file') 199 | parser.add_argument('-m', '--modify', nargs='+', default=[], help='modify config') 200 | parser.add_argument('-i', '--input', default=-1) 201 | parser.add_argument('-k', '--keys', nargs='+', type=int, default=[ord(' ')], help='keys to dump images') 202 | parser.add_argument('-o', '--output', help='output video file') 203 | parser.add_argument('-f', '--format', default='%Y-%m-%d_%H-%M-%S.jpg', help='dump file name format') 204 | parser.add_argument('--crop', nargs='+', type=float, default=[], help='ymin ymax xmin xmax') 205 | parser.add_argument('--pause', action='store_true') 206 | parser.add_argument('--fourcc', default='XVID', help='4-character code of codec used to compress the frames, such as XVID, MJPG') 207 | parser.add_argument('--thickness', default=3, type=int) 208 | parser.add_argument('--colors', nargs='+', default=[]) 209 | parser.add_argument('--logging', default='logging.yml', help='logging config') 210 | return parser.parse_args() 211 | 212 | 213 | if __name__ == '__main__': 214 | main() 215 | -------------------------------------------------------------------------------- /dimension_cluster.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen) 3 | 4 | This program is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU Lesser General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 8 | 9 | This program is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with this program. If not, see . 
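detect.py's `postprocess` above hands overlap suppression to `utils.postprocess.nms`, whose implementation is outside this listing. Greedy IoU-based NMS is short enough to sketch; this NumPy stand-in illustrates the contract (scores plus corner boxes in, indices of kept boxes out), not the repository's exact code:

import numpy as np

def nms(score, yx_min, yx_max, overlap):
    # greedily keep the highest-scoring box, drop every remaining box
    # whose IoU with it exceeds `overlap`, and repeat
    area = np.prod(yx_max - yx_min, -1)
    order = np.argsort(-score)
    keep = []
    while order.size > 0:
        i, rest = order[0], order[1:]
        keep.append(int(i))
        tl = np.maximum(yx_min[i], yx_min[rest])
        br = np.minimum(yx_max[i], yx_max[rest])
        inter = np.prod(np.clip(br - tl, 0, None), -1)
        iou = inter / (area[i] + area[rest] - inter)
        order = rest[iou <= overlap]
    return keep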
16 | """ 17 | 18 | import os 19 | import argparse 20 | import configparser 21 | import logging 22 | import logging.config 23 | import yaml 24 | 25 | import numpy as np 26 | import nltk.cluster.kmeans 27 | 28 | import utils.data 29 | import utils.iou.numpy 30 | 31 | 32 | def distance(a, b): 33 | return 1 - utils.iou.numpy.iou(-a, a, -b, b) 34 | 35 | 36 | def get_data(paths): 37 | dataset = utils.data.Dataset(utils.data.load_pickles(paths)) 38 | return np.concatenate([(data['yx_max'] - data['yx_min']) / utils.image_size(data['path']) for data in dataset.dataset]) 39 | 40 | 41 | def main(): 42 | args = make_args() 43 | config = configparser.ConfigParser() 44 | utils.load_config(config, args.config) 45 | with open(os.path.expanduser(os.path.expandvars(args.logging)), 'r') as f: 46 | logging.config.dictConfig(yaml.load(f)) 47 | cache_dir = utils.get_cache_dir(config) 48 | paths = [os.path.join(cache_dir, phase + '.pkl') for phase in args.phase] 49 | data = get_data(paths) 50 | logging.info('num_examples=%d' % len(data)) 51 | clusterer = nltk.cluster.kmeans.KMeansClusterer(args.num, distance, args.repeats) 52 | try: 53 | clusterer.cluster(data) 54 | except KeyboardInterrupt: 55 | logging.warning('interrupted') 56 | for m in clusterer.means(): 57 | print('\t'.join(map(str, m))) 58 | 59 | 60 | def make_args(): 61 | parser = argparse.ArgumentParser() 62 | parser.add_argument('num', type=int) 63 | parser.add_argument('-r', '--repeats', type=int, default=np.iinfo(np.int).max) 64 | parser.add_argument('-c', '--config', nargs='+', default=['config.ini'], help='config file') 65 | parser.add_argument('-p', '--phase', nargs='+', default=['train', 'val', 'test']) 66 | parser.add_argument('--logging', default='logging.yml', help='logging config') 67 | return parser.parse_args() 68 | 69 | 70 | if __name__ == '__main__': 71 | main() 72 | -------------------------------------------------------------------------------- /disable_bad_images.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen) 3 | 4 | This program is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU Lesser General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 8 | 9 | This program is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with this program. If not, see . 16 | """ 17 | 18 | import os 19 | import sys 20 | import argparse 21 | import shutil 22 | import tqdm 23 | 24 | import cv2 25 | 26 | 27 | def main(): 28 | args = make_args() 29 | root = os.path.expanduser(os.path.expandvars(args.root)) 30 | for dirpath, _, filenames in os.walk(root): 31 | for filename in tqdm.tqdm(filenames, desc=dirpath): 32 | if os.path.splitext(filename)[-1].lower() in args.exts and filename[0] != '.': 33 | path = os.path.join(dirpath, filename) 34 | image = cv2.imread(path) 35 | if image is None: 36 | sys.stderr.write('disable bad image %s\n' % path) 37 | _path = os.path.join(os.path.dirname(path), '.' 
+ os.path.basename(path)) 38 | if os.path.exists(_path): 39 | os.remove(_path) 40 | shutil.move(path, _path) 41 | 42 | 43 | def make_args(): 44 | parser = argparse.ArgumentParser() 45 | parser.add_argument('root') 46 | parser.add_argument('-c', '--config', nargs='+', default=['config.ini'], help='config file') 47 | parser.add_argument('-e', '--exts', nargs='+', default=['.jpe', '.jpg', '.jpeg', '.png']) 48 | parser.add_argument('--level', default='info', help='logging level') 49 | return parser.parse_args() 50 | 51 | 52 | if __name__ == '__main__': 53 | main() 54 | -------------------------------------------------------------------------------- /donate_alipay.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruiminshen/yolo2-pytorch/146ebdf581677964caa31c69cccd0c86230fb216/donate_alipay.jpg -------------------------------------------------------------------------------- /donate_mm.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruiminshen/yolo2-pytorch/146ebdf581677964caa31c69cccd0c86230fb216/donate_mm.jpg -------------------------------------------------------------------------------- /download_url.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen) 3 | 4 | This program is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU Lesser General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 8 | 9 | This program is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with this program. If not, see . 
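The "disable" in disable_bad_images.py is nothing more than a rename: an image OpenCV cannot decode is moved to a dot-prefixed name in the same directory, which the script itself skips on later runs (note the `filename[0] != '.'` test above). The core check in isolation:

import os
import shutil

import cv2

def disable_if_bad(path):
    if cv2.imread(path) is None:             # unreadable or undecodable
        hidden = os.path.join(os.path.dirname(path), '.' + os.path.basename(path))
        if os.path.exists(hidden):
            os.remove(hidden)                # replace a stale hidden copy
        shutil.move(path, hidden)
        return True
    return False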
16 | """ 17 | 18 | import os 19 | import sys 20 | import argparse 21 | import threading 22 | 23 | import numpy as np 24 | import tqdm 25 | import wget 26 | 27 | 28 | def _task(url, root, ext): 29 | path = wget.download(url, bar=None) 30 | with open(path + ext, 'w') as f: 31 | f.write(url) 32 | 33 | 34 | def task(urls, root, ext, pbar, lock, f): 35 | for url in urls: 36 | url = url.rstrip() 37 | try: 38 | _task(url, root, ext) 39 | except: 40 | with lock: 41 | f.write(url + '\n') 42 | pbar.update() 43 | 44 | 45 | def main(): 46 | args = make_args() 47 | root = os.path.expandvars(os.path.expanduser(args.root)) 48 | os.makedirs(root, exist_ok=True) 49 | os.chdir(root) 50 | workers = [] 51 | urls = list(set(sys.stdin.readlines())) 52 | lock = threading.Lock() 53 | with tqdm.tqdm(total=len(urls)) as pbar, open(root + args.ext, 'w') as f: 54 | for urls in np.array_split(urls, args.workers): 55 | w = threading.Thread(target=task, args=(urls, root, args.ext, pbar, lock, f)) 56 | w.start() 57 | workers.append(w) 58 | for w in workers: 59 | w.join() 60 | 61 | 62 | def make_args(): 63 | parser = argparse.ArgumentParser() 64 | parser.add_argument('root') 65 | parser.add_argument('-w', '--workers', type=int, default=6) 66 | parser.add_argument('-e', '--ext', default='.url') 67 | return parser.parse_args() 68 | 69 | 70 | if __name__ == '__main__': 71 | main() 72 | -------------------------------------------------------------------------------- /image.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruiminshen/yolo2-pytorch/146ebdf581677964caa31c69cccd0c86230fb216/image.jpg -------------------------------------------------------------------------------- /logging.yml: -------------------------------------------------------------------------------- 1 | version: 1 2 | formatters: 3 | simple: 4 | format: "%(asctime)s - %(name)s - %(levelname)s - %(message)s" 5 | handlers: 6 | console: 7 | class: logging.StreamHandler 8 | level: INFO 9 | formatter: simple 10 | stream: ext://sys.stderr 11 | root: 12 | level: INFO 13 | handlers: [console] -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen) 3 | 4 | This program is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU Lesser General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 8 | 9 | This program is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with this program. If not, see . 
16 | """ 17 | 18 | import logging 19 | 20 | import numpy as np 21 | import torch 22 | import torch.autograd 23 | import torch.nn as nn 24 | import torch.nn.functional as F 25 | 26 | import utils.iou.torch 27 | 28 | 29 | class ConfigChannels(object): 30 | def __init__(self, config, state_dict=None, channels=3): 31 | self.config = config 32 | self.state_dict = state_dict 33 | self.channels = channels 34 | 35 | def __call__(self, default, name, fn=lambda var: var.size(0)): 36 | if self.state_dict is None: 37 | self.channels = default 38 | else: 39 | var = self.state_dict[name] 40 | self.channels = fn(var) 41 | if self.channels != default: 42 | logging.warning('%s: change number of output channels from %d to %d' % (name, default, self.channels)) 43 | return self.channels 44 | 45 | 46 | def output_channels(num_anchors, num_cls): 47 | if num_cls > 1: 48 | return num_anchors * (5 + num_cls) 49 | else: 50 | return num_anchors * 5 51 | 52 | 53 | def meshgrid(rows, cols, swap=False): 54 | i = torch.arange(0, rows).repeat(cols).view(-1, 1) 55 | j = torch.arange(0, cols).view(-1, 1).repeat(1, rows).view(-1, 1) 56 | return torch.cat([i, j], 1) if swap else torch.cat([j, i], 1) 57 | 58 | 59 | def iou_match(yx_min, yx_max, data): 60 | batch_size, cells, num_anchors, _ = yx_min.size() 61 | iou_matrix = utils.iou.torch.batch_iou_matrix(yx_min.view(batch_size, -1, 2), yx_max.view(batch_size, -1, 2), data['yx_min'], data['yx_max']) 62 | iou_matrix = iou_matrix.view(batch_size, cells, num_anchors, -1) 63 | iou, index = iou_matrix.max(-1) 64 | _index = torch.unbind(index.view(batch_size, -1)) 65 | _data = {} 66 | for key in 'yx_min, yx_max, cls'.split(', '): 67 | t = data[key] 68 | if len(t.size()) == 2: 69 | t = torch.stack([d[i] for d, i in zip(torch.unbind(t, 0), _index)]).view(batch_size, cells, num_anchors) 70 | elif len(t.size()) == 3: 71 | t = torch.stack([d[i] for d, i in zip(torch.unbind(t, 0), _index)]).view(batch_size, cells, num_anchors, -1) 72 | _data[key] = t 73 | return iou_matrix, iou, index, _data 74 | 75 | 76 | def fit_positive(rows, cols, yx_min, yx_max, anchors): 77 | device_id = anchors.get_device() if torch.cuda.is_available() else None 78 | batch_size, num, _ = yx_min.size() 79 | num_anchors, _ = anchors.size() 80 | valid = torch.prod(yx_min < yx_max, -1) 81 | center = (yx_min + yx_max) / 2 82 | ij = torch.floor(center) 83 | i, j = torch.unbind(ij.long(), -1) 84 | index = i * cols + j 85 | anchors2 = anchors / 2 86 | iou_matrix = utils.iou.torch.iou_matrix((yx_min - center).view(-1, 2), (yx_max - center).view(-1, 2), -anchors2, anchors2).view(batch_size, -1, num_anchors) 87 | iou, index_anchor = iou_matrix.max(-1) 88 | _positive = [] 89 | cells = rows * cols 90 | for valid, index, index_anchor in zip(torch.unbind(valid), torch.unbind(index), torch.unbind(index_anchor)): 91 | index, index_anchor = (t[valid] for t in (index, index_anchor)) 92 | t = utils.ensure_device(torch.ByteTensor(cells, num_anchors).zero_(), device_id) 93 | t[index, index_anchor] = 1 94 | _positive.append(t) 95 | return torch.stack(_positive) 96 | 97 | 98 | def fill_norm(yx_min, yx_max, anchors): 99 | center = (yx_min + yx_max) / 2 100 | ij = torch.floor(center) 101 | center_offset = center - ij 102 | size = yx_max - yx_min 103 | return center_offset, torch.log(size / anchors.view(1, -1, 2)) 104 | 105 | 106 | def square(t): 107 | return t * t 108 | 109 | 110 | class Inference(nn.Module): 111 | def __init__(self, config, dnn, anchors): 112 | nn.Module.__init__(self) 113 | self.config = config 114 | self.dnn = dnn 115 | 
self.anchors = anchors 116 | 117 | def forward(self, x): 118 | device_id = x.get_device() if torch.cuda.is_available() else None 119 | feature = self.dnn(x) 120 | rows, cols = feature.size()[-2:] 121 | cells = rows * cols 122 | _feature = feature.permute(0, 2, 3, 1).contiguous().view(feature.size(0), cells, self.anchors.size(0), -1) 123 | sigmoid = F.sigmoid(_feature[:, :, :, :3]) 124 | iou = sigmoid[:, :, :, 0] 125 | ij = torch.autograd.Variable(utils.ensure_device(meshgrid(rows, cols).view(1, -1, 1, 2), device_id)) 126 | center_offset = sigmoid[:, :, :, 1:3] 127 | center = ij + center_offset 128 | size_norm = _feature[:, :, :, 3:5] 129 | anchors = torch.autograd.Variable(utils.ensure_device(self.anchors.view(1, 1, -1, 2), device_id)) 130 | size = torch.exp(size_norm) * anchors 131 | size2 = size / 2 132 | yx_min = center - size2 133 | yx_max = center + size2 134 | logits = _feature[:, :, :, 5:] if _feature.size(-1) > 5 else None 135 | return feature, iou, center_offset, size_norm, yx_min, yx_max, logits 136 | 137 | 138 | def loss(anchors, data, pred, threshold): 139 | iou = pred['iou'] 140 | device_id = iou.get_device() if torch.cuda.is_available() else None 141 | rows, cols = pred['feature'].size()[-2:] 142 | iou_matrix, _iou, _, _data = iou_match(pred['yx_min'].data, pred['yx_max'].data, data) 143 | anchors = utils.ensure_device(anchors, device_id) 144 | positive = fit_positive(rows, cols, *(data[key] for key in 'yx_min, yx_max'.split(', ')), anchors) 145 | negative = ~positive & (_iou < threshold) 146 | _center_offset, _size_norm = fill_norm(*(_data[key] for key in 'yx_min, yx_max'.split(', ')), anchors) 147 | positive, negative, _iou, _center_offset, _size_norm, _cls = (torch.autograd.Variable(t) for t in (positive, negative, _iou, _center_offset, _size_norm, _data['cls'])) 148 | _positive = torch.unsqueeze(positive, -1) 149 | loss = {} 150 | # iou 151 | loss['foreground'] = F.mse_loss(iou[positive], _iou[positive], size_average=False) 152 | loss['background'] = torch.sum(square(iou[negative])) 153 | # bbox 154 | loss['center'] = F.mse_loss(pred['center_offset'][_positive], _center_offset[_positive], size_average=False) 155 | loss['size'] = F.mse_loss(pred['size_norm'][_positive], _size_norm[_positive], size_average=False) 156 | # cls 157 | if 'logits' in pred: 158 | logits = pred['logits'] 159 | if len(_cls.size()) > 3: 160 | loss['cls'] = F.mse_loss(F.softmax(logits, -1)[_positive], _cls[_positive], size_average=False) 161 | else: 162 | loss['cls'] = F.cross_entropy(logits[_positive].view(-1, logits.size(-1)), _cls[positive].view(-1)) 163 | # normalize 164 | cnt = float(np.multiply.reduce(positive.size())) 165 | for key in loss: 166 | loss[key] /= cnt 167 | return loss, dict(iou=_iou, data=_data, positive=positive, negative=negative) 168 | 169 | 170 | def _inference(inference, tensor): 171 | feature, iou, center_offset, size_norm, yx_min, yx_max, logits = inference(tensor) 172 | pred = dict( 173 | feature=feature, iou=iou, 174 | center_offset=center_offset, size_norm=size_norm, 175 | yx_min=yx_min, yx_max=yx_max, 176 | ) 177 | if logits is not None: 178 | pred['logits'] = logits.contiguous() 179 | return pred 180 | -------------------------------------------------------------------------------- /model/densenet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen) 3 | 4 | This program is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU Lesser General Public License 
as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 8 | 9 | This program is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with this program. If not, see . 16 | """ 17 | 18 | import logging 19 | from collections import OrderedDict 20 | 21 | import torch.nn as nn 22 | import torch.utils.model_zoo as model_zoo 23 | import torchvision.models.densenet as _model 24 | from torchvision.models.densenet import _DenseBlock, _Transition, model_urls 25 | 26 | import model 27 | 28 | 29 | class DenseNet(_model.DenseNet): 30 | def __init__(self, config_channels, anchors, num_cls, growth_rate=32, block_config=(6, 12, 24, 16), num_init_features=64, bn_size=4, drop_rate=0): 31 | nn.Module.__init__(self) 32 | 33 | # First convolution 34 | self.features = nn.Sequential(OrderedDict([ 35 | ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)), 36 | ('norm0', nn.BatchNorm2d(num_init_features)), 37 | ('relu0', nn.ReLU(inplace=True)), 38 | ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)), 39 | ])) 40 | 41 | # Each denseblock 42 | num_features = num_init_features 43 | for i, num_layers in enumerate(block_config): 44 | block = _DenseBlock(num_layers=num_layers, num_input_features=num_features, bn_size=bn_size, growth_rate=growth_rate, drop_rate=drop_rate) 45 | self.features.add_module('denseblock%d' % (i + 1), block) 46 | num_features = num_features + num_layers * growth_rate 47 | if i != len(block_config) - 1: 48 | trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2) 49 | self.features.add_module('transition%d' % (i + 1), trans) 50 | num_features = num_features // 2 51 | 52 | # Final batch norm 53 | self.features.add_module('norm5', nn.BatchNorm2d(num_features)) 54 | self.features.add_module('conv', nn.Conv2d(num_features, model.output_channels(len(anchors), num_cls), 1)) 55 | 56 | # init 57 | for m in self.modules(): 58 | if isinstance(m, nn.Conv2d): 59 | m.weight = nn.init.kaiming_normal(m.weight) 60 | elif isinstance(m, nn.BatchNorm2d): 61 | m.weight.data.fill_(1) 62 | m.bias.data.zero_() 63 | 64 | def forward(self, x): 65 | return self.features(x) 66 | 67 | 68 | def densenet121(config_channels, anchors, num_cls, **kwargs): 69 | model = DenseNet(config_channels, anchors, num_cls, num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16), **kwargs) 70 | if config_channels.config.getboolean('model', 'pretrained'): 71 | url = model_urls['densenet121'] 72 | logging.info('use pretrained model: ' + url) 73 | state_dict = model.state_dict() 74 | for key, value in model_zoo.load_url(url).items(): 75 | if key in state_dict: 76 | state_dict[key] = value 77 | model.load_state_dict(state_dict) 78 | return model 79 | 80 | 81 | def densenet169(config_channels, anchors, num_cls, **kwargs): 82 | model = DenseNet(config_channels, anchors, num_cls, num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32), **kwargs) 83 | if config_channels.config.getboolean('model', 'pretrained'): 84 | url = model_urls['densenet169'] 85 | logging.info('use pretrained model: ' + url) 86 | state_dict = model.state_dict() 87 | for key, value in model_zoo.load_url(url).items(): 88 | if key in state_dict: 89 
| state_dict[key] = value 90 | model.load_state_dict(state_dict) 91 | return model 92 | 93 | 94 | def densenet201(config_channels, anchors, num_cls, **kwargs): 95 | model = DenseNet(config_channels, anchors, num_cls, num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32), **kwargs) 96 | if config_channels.config.getboolean('model', 'pretrained'): 97 | url = model_urls['densenet201'] 98 | logging.info('use pretrained model: ' + url) 99 | state_dict = model.state_dict() 100 | for key, value in model_zoo.load_url(url).items(): 101 | if key in state_dict: 102 | state_dict[key] = value 103 | model.load_state_dict(state_dict) 104 | return model 105 | 106 | 107 | def densenet161(config_channels, anchors, num_cls, **kwargs): 108 | model = DenseNet(config_channels, anchors, num_cls, num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24), **kwargs) 109 | if config_channels.config.getboolean('model', 'pretrained'): 110 | url = model_urls['densenet161'] 111 | logging.info('use pretrained model: ' + url) 112 | state_dict = model.state_dict() 113 | for key, value in model_zoo.load_url(url).items(): 114 | if key in state_dict: 115 | state_dict[key] = value 116 | model.load_state_dict(state_dict) 117 | return model 118 | -------------------------------------------------------------------------------- /model/inception3.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen) 3 | 4 | This program is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU Lesser General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 8 | 9 | This program is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with this program. If not, see . 
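densenet121 through densenet161 above, like every other backbone in model/, share one idiom for ImageNet initialization: fetch the torchvision checkpoint, copy only the tensors whose names still exist in the re-headed model, and load the merged dict, so the new detection conv keeps its fresh initialization. Factored out, the helper would be:

import torch.utils.model_zoo as model_zoo

def load_pretrained(model, url):
    state_dict = model.state_dict()
    for key, value in model_zoo.load_url(url).items():
        if key in state_dict:                # skip tensors the model no longer
            state_dict[key] = value          # has (e.g. the removed classifier)
    model.load_state_dict(state_dict)
    return model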
16 | """ 17 | 18 | import logging 19 | 20 | import scipy.stats as stats 21 | import torch 22 | import torch.nn as nn 23 | import torch.nn.functional as F 24 | import torch.utils.model_zoo 25 | import torchvision.models.inception as _model 26 | from torchvision.models.inception import BasicConv2d, InceptionA, InceptionB, InceptionC, InceptionD, InceptionE 27 | 28 | import model 29 | 30 | 31 | class Inception3(_model.Inception3): 32 | def __init__(self, config_channels, anchors, num_cls, transform_input=False): 33 | nn.Module.__init__(self) 34 | self.transform_input = transform_input 35 | self.Conv2d_1a_3x3 = BasicConv2d(3, 32, kernel_size=3, stride=2) 36 | self.Conv2d_2a_3x3 = BasicConv2d(32, 32, kernel_size=3) 37 | self.Conv2d_2b_3x3 = BasicConv2d(32, 64, kernel_size=3, padding=1) 38 | self.Conv2d_3b_1x1 = BasicConv2d(64, 80, kernel_size=1) 39 | self.Conv2d_4a_3x3 = BasicConv2d(80, 192, kernel_size=3) 40 | self.Mixed_5b = InceptionA(192, pool_features=32) 41 | self.Mixed_5c = InceptionA(256, pool_features=64) 42 | self.Mixed_5d = InceptionA(288, pool_features=64) 43 | self.Mixed_6a = InceptionB(288) 44 | self.Mixed_6b = InceptionC(768, channels_7x7=128) 45 | self.Mixed_6c = InceptionC(768, channels_7x7=160) 46 | self.Mixed_6d = InceptionC(768, channels_7x7=160) 47 | self.Mixed_6e = InceptionC(768, channels_7x7=192) 48 | # aux_logits 49 | self.Mixed_7a = InceptionD(768) 50 | self.Mixed_7b = InceptionE(1280) 51 | self.Mixed_7c = InceptionE(2048) 52 | self.conv = nn.Conv2d(2048, model.output_channels(len(anchors), num_cls), 1) 53 | 54 | for m in self.modules(): 55 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): 56 | stddev = m.stddev if hasattr(m, 'stddev') else 0.1 57 | X = stats.truncnorm(-2, 2, scale=stddev) 58 | values = torch.Tensor(X.rvs(m.weight.data.numel())) 59 | m.weight.data.copy_(values) 60 | elif isinstance(m, nn.BatchNorm2d): 61 | m.weight.data.fill_(1) 62 | m.bias.data.zero_() 63 | 64 | if config_channels.config.getboolean('model', 'pretrained'): 65 | url = _model.model_urls['inception_v3_google'] 66 | logging.info('use pretrained model: ' + url) 67 | state_dict = self.state_dict() 68 | for key, value in torch.utils.model_zoo.load_url(url).items(): 69 | if key in state_dict: 70 | state_dict[key] = value 71 | self.load_state_dict(state_dict) 72 | 73 | def forward(self, x): 74 | if self.transform_input: 75 | x = x.clone() 76 | x[:, 0] = x[:, 0] * (0.229 / 0.5) + (0.485 - 0.5) / 0.5 77 | x[:, 1] = x[:, 1] * (0.224 / 0.5) + (0.456 - 0.5) / 0.5 78 | x[:, 2] = x[:, 2] * (0.225 / 0.5) + (0.406 - 0.5) / 0.5 79 | # 299 x 299 x 3 80 | x = self.Conv2d_1a_3x3(x) 81 | # 149 x 149 x 32 82 | x = self.Conv2d_2a_3x3(x) 83 | # 147 x 147 x 32 84 | x = self.Conv2d_2b_3x3(x) 85 | # 147 x 147 x 64 86 | x = F.max_pool2d(x, kernel_size=3, stride=2) 87 | # 73 x 73 x 64 88 | x = self.Conv2d_3b_1x1(x) 89 | # 73 x 73 x 80 90 | x = self.Conv2d_4a_3x3(x) 91 | # 71 x 71 x 192 92 | x = F.max_pool2d(x, kernel_size=3, stride=2) 93 | # 35 x 35 x 192 94 | x = self.Mixed_5b(x) 95 | # 35 x 35 x 256 96 | x = self.Mixed_5c(x) 97 | # 35 x 35 x 288 98 | x = self.Mixed_5d(x) 99 | # 35 x 35 x 288 100 | x = self.Mixed_6a(x) 101 | # 17 x 17 x 768 102 | x = self.Mixed_6b(x) 103 | # 17 x 17 x 768 104 | x = self.Mixed_6c(x) 105 | # 17 x 17 x 768 106 | x = self.Mixed_6d(x) 107 | # 17 x 17 x 768 108 | x = self.Mixed_6e(x) 109 | # 17 x 17 x 768 110 | # aux_logits 111 | # 17 x 17 x 768 112 | x = self.Mixed_7a(x) 113 | # 8 x 8 x 1280 114 | x = self.Mixed_7b(x) 115 | # 8 x 8 x 2048 116 | x = self.Mixed_7c(x) 117 | # 8 x 
8 x 2048 118 | return self.conv(x) 119 | -------------------------------------------------------------------------------- /model/mobilenet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen) 3 | 4 | This program is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU Lesser General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 8 | 9 | This program is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with this program. If not, see . 16 | """ 17 | 18 | import collections 19 | 20 | import torch.nn as nn 21 | 22 | import model 23 | 24 | 25 | def conv_bn(in_channels, out_channels, stride): 26 | return nn.Sequential(collections.OrderedDict([ 27 | ('conv', nn.Conv2d(in_channels, out_channels, 3, stride, 1, bias=False)), 28 | ('bn', nn.BatchNorm2d(out_channels)), 29 | ('act', nn.ReLU(inplace=True)), 30 | ])) 31 | 32 | 33 | def conv_dw(in_channels, stride): 34 | return nn.Sequential(collections.OrderedDict([ 35 | ('conv', nn.Conv2d(in_channels, in_channels, 3, stride, 1, groups=in_channels, bias=False)), 36 | ('bn', nn.BatchNorm2d(in_channels)), 37 | ('act', nn.ReLU(inplace=True)), 38 | ])) 39 | 40 | 41 | def conv_pw(in_channels, out_channels): 42 | return nn.Sequential(collections.OrderedDict([ 43 | ('conv', nn.Conv2d(in_channels, out_channels, 1, 1, 0, bias=False)), 44 | ('bn', nn.BatchNorm2d(out_channels)), 45 | ('act', nn.ReLU(inplace=True)), 46 | ])) 47 | 48 | 49 | def conv_unit(in_channels, out_channels, stride): 50 | return nn.Sequential(collections.OrderedDict([ 51 | ('dw', conv_dw(in_channels, stride)), 52 | ('pw', conv_pw(in_channels, out_channels)), 53 | ])) 54 | 55 | 56 | class MobileNet(nn.Module): 57 | def __init__(self, config_channels, anchors, num_cls): 58 | nn.Module.__init__(self) 59 | layers = [] 60 | layers.append(conv_bn(config_channels.channels, config_channels(32, 'layers.%d.conv.weight' % len(layers)), 2)) 61 | layers.append(conv_unit(config_channels.channels, config_channels(64, 'layers.%d.pw.conv.weight' % len(layers)), 1)) 62 | layers.append(conv_unit(config_channels.channels, config_channels(128, 'layers.%d.pw.conv.weight' % len(layers)), 2)) 63 | layers.append(conv_unit(config_channels.channels, config_channels(128, 'layers.%d.pw.conv.weight' % len(layers)), 1)) 64 | layers.append(conv_unit(config_channels.channels, config_channels(256, 'layers.%d.pw.conv.weight' % len(layers)), 2)) 65 | layers.append(conv_unit(config_channels.channels, config_channels(256, 'layers.%d.pw.conv.weight' % len(layers)), 1)) 66 | layers.append(conv_unit(config_channels.channels, config_channels(512, 'layers.%d.pw.conv.weight' % len(layers)), 2)) 67 | layers.append(conv_unit(config_channels.channels, config_channels(512, 'layers.%d.pw.conv.weight' % len(layers)), 1)) 68 | layers.append(conv_unit(config_channels.channels, config_channels(512, 'layers.%d.pw.conv.weight' % len(layers)), 1)) 69 | layers.append(conv_unit(config_channels.channels, config_channels(512, 'layers.%d.pw.conv.weight' % len(layers)), 1)) 70 | layers.append(conv_unit(config_channels.channels, config_channels(512, 'layers.%d.pw.conv.weight' % 
len(layers)), 1)) 71 | layers.append(conv_unit(config_channels.channels, config_channels(512, 'layers.%d.pw.conv.weight' % len(layers)), 1)) 72 | layers.append(conv_unit(config_channels.channels, config_channels(1024, 'layers.%d.pw.conv.weight' % len(layers)), 2)) 73 | layers.append(conv_unit(config_channels.channels, config_channels(1024, 'layers.%d.pw.conv.weight' % len(layers)), 1)) 74 | layers.append(nn.Conv2d(config_channels.channels, model.output_channels(len(anchors), num_cls), 1)) 75 | self.layers = nn.Sequential(*layers) 76 | 77 | for m in self.modules(): 78 | if isinstance(m, nn.Conv2d): 79 | m.weight = nn.init.kaiming_normal(m.weight) 80 | elif isinstance(m, nn.BatchNorm2d): 81 | m.weight.data.fill_(1) 82 | m.bias.data.zero_() 83 | 84 | def forward(self, x): 85 | return self.layers(x) 86 | -------------------------------------------------------------------------------- /model/resnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen) 3 | 4 | This program is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU Lesser General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 8 | 9 | This program is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with this program. If not, see . 16 | """ 17 | 18 | import logging 19 | import re 20 | 21 | import torch.nn as nn 22 | import torch.utils.model_zoo as model_zoo 23 | import torchvision.models.resnet as _model 24 | from torchvision.models.resnet import conv3x3 25 | 26 | import model 27 | 28 | 29 | class BasicBlock(nn.Module): 30 | def __init__(self, config_channels, prefix, channels, stride=1): 31 | nn.Module.__init__(self) 32 | channels_in = config_channels.channels 33 | self.conv1 = conv3x3(config_channels.channels, config_channels(channels, '%s.conv1.weight' % prefix), stride) 34 | self.bn1 = nn.BatchNorm2d(config_channels.channels) 35 | self.relu = nn.ReLU(inplace=True) 36 | self.conv2 = conv3x3(config_channels.channels, config_channels(channels, '%s.conv2.weight' % prefix)) 37 | self.bn2 = nn.BatchNorm2d(config_channels.channels) 38 | if stride > 1 or channels_in != config_channels.channels: 39 | downsample = [] 40 | downsample.append(nn.Conv2d(channels_in, config_channels.channels, kernel_size=1, stride=stride, bias=False)) 41 | downsample.append(nn.BatchNorm2d(config_channels.channels)) 42 | self.downsample = nn.Sequential(*downsample) 43 | else: 44 | self.downsample = None 45 | 46 | def forward(self, x): 47 | residual = x 48 | 49 | out = self.conv1(x) 50 | out = self.bn1(out) 51 | out = self.relu(out) 52 | 53 | out = self.conv2(out) 54 | out = self.bn2(out) 55 | 56 | if self.downsample is not None: 57 | residual = self.downsample(x) 58 | 59 | out += residual 60 | out = self.relu(out) 61 | 62 | return out 63 | 64 | 65 | class Bottleneck(nn.Module): 66 | def __init__(self, config_channels, prefix, channels, stride=1): 67 | nn.Module.__init__(self) 68 | channels_in = config_channels.channels 69 | self.conv1 = nn.Conv2d(config_channels.channels, config_channels(channels, '%s.conv1.weight' % prefix), kernel_size=1, bias=False) 70 | self.bn1 = 
nn.BatchNorm2d(config_channels.channels) 71 | self.conv2 = nn.Conv2d(config_channels.channels, config_channels(channels, '%s.conv2.weight' % prefix), kernel_size=3, stride=stride, padding=1, bias=False) 72 | self.bn2 = nn.BatchNorm2d(config_channels.channels) 73 | self.conv3 = nn.Conv2d(config_channels.channels, config_channels(channels * 4, '%s.conv3.weight' % prefix), kernel_size=1, bias=False) 74 | self.bn3 = nn.BatchNorm2d(config_channels.channels) 75 | self.relu = nn.ReLU(inplace=True) 76 | if stride > 1 or channels_in != config_channels.channels: 77 | downsample = [] 78 | downsample.append(nn.Conv2d(channels_in, config_channels.channels, kernel_size=1, stride=stride, bias=False)) 79 | downsample.append(nn.BatchNorm2d(config_channels.channels)) 80 | self.downsample = nn.Sequential(*downsample) 81 | else: 82 | self.downsample = None 83 | 84 | def forward(self, x): 85 | residual = x 86 | 87 | out = self.conv1(x) 88 | out = self.bn1(out) 89 | out = self.relu(out) 90 | 91 | out = self.conv2(out) 92 | out = self.bn2(out) 93 | out = self.relu(out) 94 | 95 | out = self.conv3(out) 96 | out = self.bn3(out) 97 | 98 | if self.downsample is not None: 99 | residual = self.downsample(x) 100 | 101 | out += residual 102 | out = self.relu(out) 103 | 104 | return out 105 | 106 | 107 | class ResNet(_model.ResNet): 108 | def __init__(self, config_channels, anchors, num_cls, block, layers): 109 | nn.Module.__init__(self) 110 | self.conv1 = nn.Conv2d(config_channels.channels, config_channels(64, 'conv1.weight'), kernel_size=7, stride=2, padding=3, bias=False) 111 | self.bn1 = nn.BatchNorm2d(config_channels.channels) 112 | self.relu = nn.ReLU(inplace=True) 113 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 114 | self.layer1 = self._make_layer(config_channels, 'layer1', block, 64, layers[0]) 115 | self.layer2 = self._make_layer(config_channels, 'layer2', block, 128, layers[1], stride=2) 116 | self.layer3 = self._make_layer(config_channels, 'layer3', block, 256, layers[2], stride=2) 117 | self.layer4 = self._make_layer(config_channels, 'layer4', block, 512, layers[3], stride=2) 118 | self.conv = nn.Conv2d(config_channels.channels, model.output_channels(len(anchors), num_cls), 1) 119 | 120 | for m in self.modules(): 121 | if isinstance(m, nn.Conv2d): 122 | m.weight = nn.init.kaiming_normal(m.weight) 123 | elif isinstance(m, nn.BatchNorm2d): 124 | m.weight.data.fill_(1) 125 | m.bias.data.zero_() 126 | 127 | def _make_layer(self, config_channels, prefix, block, channels, blocks, stride=1): 128 | layers = [] 129 | layers.append(block(config_channels, '%s.%d' % (prefix, len(layers)), channels, stride)) 130 | for i in range(1, blocks): 131 | layers.append(block(config_channels, '%s.%d' % (prefix, len(layers)), channels)) 132 | return nn.Sequential(*layers) 133 | 134 | def forward(self, x): 135 | x = self.conv1(x) 136 | x = self.bn1(x) 137 | x = self.relu(x) 138 | x = self.maxpool(x) 139 | 140 | x = self.layer1(x) 141 | x = self.layer2(x) 142 | x = self.layer3(x) 143 | x = self.layer4(x) 144 | 145 | return self.conv(x) 146 | 147 | def scope(self, name): 148 | comp = name.split('.')[:-1] 149 | try: 150 | comp[-1] = re.search('[(conv)|(bn)](\d+)', comp[-1]).group(1) 151 | except AttributeError: 152 | if len(comp) > 1: 153 | if comp[-2] == 'downsample': 154 | comp = comp[:-1] 155 | else: 156 | assert False, name 157 | else: 158 | assert comp[-1] == 'conv', name 159 | return '.'.join(comp) 160 | 161 | 162 | def resnet18(config_channels, anchors, num_cls, **kwargs): 163 | model = 
ResNet(config_channels, anchors, num_cls, BasicBlock, [2, 2, 2, 2], **kwargs) 164 | if config_channels.config.getboolean('model', 'pretrained'): 165 | url = _model.model_urls['resnet18'] 166 | logging.info('use pretrained model: ' + url) 167 | state_dict = model.state_dict() 168 | for key, value in model_zoo.load_url(url).items(): 169 | if key in state_dict: 170 | state_dict[key] = value 171 | model.load_state_dict(state_dict) 172 | return model 173 | 174 | 175 | def resnet34(config_channels, anchors, num_cls, **kwargs): 176 | model = ResNet(config_channels, anchors, num_cls, BasicBlock, [3, 4, 6, 3], **kwargs) 177 | if config_channels.config.getboolean('model', 'pretrained'): 178 | url = _model.model_urls['resnet34'] 179 | logging.info('use pretrained model: ' + url) 180 | state_dict = model.state_dict() 181 | for key, value in model_zoo.load_url(url).items(): 182 | if key in state_dict: 183 | state_dict[key] = value 184 | model.load_state_dict(state_dict) 185 | return model 186 | 187 | 188 | def resnet50(config_channels, anchors, num_cls, **kwargs): 189 | model = ResNet(config_channels, anchors, num_cls, Bottleneck, [3, 4, 6, 3], **kwargs) 190 | if config_channels.config.getboolean('model', 'pretrained'): 191 | url = _model.model_urls['resnet50'] 192 | logging.info('use pretrained model: ' + url) 193 | state_dict = model.state_dict() 194 | for key, value in model_zoo.load_url(url).items(): 195 | if key in state_dict: 196 | state_dict[key] = value 197 | model.load_state_dict(state_dict) 198 | return model 199 | 200 | 201 | def resnet101(config_channels, anchors, num_cls, **kwargs): 202 | model = ResNet(config_channels, anchors, num_cls, Bottleneck, [3, 4, 23, 3], **kwargs) 203 | if config_channels.config.getboolean('model', 'pretrained'): 204 | url = _model.model_urls['resnet101'] 205 | logging.info('use pretrained model: ' + url) 206 | state_dict = model.state_dict() 207 | for key, value in model_zoo.load_url(url).items(): 208 | if key in state_dict: 209 | state_dict[key] = value 210 | model.load_state_dict(state_dict) 211 | return model 212 | 213 | 214 | def resnet152(config_channels, anchors, num_cls, **kwargs): 215 | model = ResNet(config_channels, anchors, num_cls, Bottleneck, [3, 8, 36, 3], **kwargs) 216 | if config_channels.config.getboolean('model', 'pretrained'): 217 | url = _model.model_urls['resnet152'] 218 | logging.info('use pretrained model: ' + url) 219 | state_dict = model.state_dict() 220 | for key, value in model_zoo.load_url(url).items(): 221 | if key in state_dict: 222 | state_dict[key] = value 223 | model.load_state_dict(state_dict) 224 | return model 225 | -------------------------------------------------------------------------------- /model/vgg.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen) 3 | 4 | This program is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU Lesser General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 8 | 9 | This program is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with this program. If not, see . 
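# A note on the two residual blocks in model/resnet.py above: config_channels
# may assign arbitrary (for instance pruned) widths, and the element-wise
# residual addition is only valid when shapes match, so BasicBlock and
# Bottleneck both build a projection shortcut whenever the stride or the
# channel count changes. A minimal sketch of that shared rule follows;
# make_shortcut is an illustrative helper, not a function in this repository.

import torch.nn as nn

def make_shortcut(channels_in, channels_out, stride):
    # a 1x1 convolution plus BatchNorm reshapes the residual when needed
    if stride > 1 or channels_in != channels_out:
        return nn.Sequential(
            nn.Conv2d(channels_in, channels_out, kernel_size=1, stride=stride, bias=False),
            nn.BatchNorm2d(channels_out),
        )
    return None  # identity shortcut: forward() adds the block input as-is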
16 | """ 17 | 18 | import logging 19 | 20 | import torch.nn as nn 21 | import torch.utils.model_zoo as model_zoo 22 | import torchvision.models.vgg as _model 23 | from torchvision.models.vgg import model_urls, cfg 24 | 25 | import model 26 | 27 | 28 | class VGG(_model.VGG): 29 | def __init__(self, config_channels, anchors, num_cls, features): 30 | nn.Module.__init__(self) 31 | self.features = features 32 | self.conv = nn.Conv2d(config_channels.channels, model.output_channels(len(anchors), num_cls), 1) 33 | self._initialize_weights() 34 | 35 | def forward(self, x): 36 | x = self.features(x) 37 | return self.conv(x) 38 | 39 | 40 | def make_layers(config_channels, cfg, batch_norm=False): 41 | features = [] 42 | for v in cfg: 43 | if v == 'M': 44 | features += [nn.MaxPool2d(kernel_size=2, stride=2)] 45 | else: 46 | conv2d = nn.Conv2d(config_channels.channels, config_channels(v, 'features.%d.weight' % len(features)), kernel_size=3, padding=1) 47 | if batch_norm: 48 | features += [conv2d, nn.BatchNorm2d(config_channels.channels), nn.ReLU(inplace=True)] 49 | else: 50 | features += [conv2d, nn.ReLU(inplace=True)] 51 | return nn.Sequential(*features) 52 | 53 | 54 | def vgg11(config_channels, anchors, num_cls): 55 | model = VGG(config_channels, anchors, num_cls, make_layers(config_channels, cfg['A'])) 56 | if config_channels.config.getboolean('model', 'pretrained'): 57 | url = model_urls['vgg11'] 58 | logging.info('use pretrained model: ' + url) 59 | state_dict = model.state_dict() 60 | for key, value in model_zoo.load_url(url).items(): 61 | if key in state_dict: 62 | state_dict[key] = value 63 | model.load_state_dict(state_dict) 64 | return model 65 | 66 | 67 | def vgg11_bn(config_channels, anchors, num_cls): 68 | model = VGG(config_channels, anchors, num_cls, make_layers(config_channels, cfg['A'], batch_norm=True)) 69 | if config_channels.config.getboolean('model', 'pretrained'): 70 | url = model_urls['vgg11_bn'] 71 | logging.info('use pretrained model: ' + url) 72 | state_dict = model.state_dict() 73 | for key, value in model_zoo.load_url(url).items(): 74 | if key in state_dict: 75 | state_dict[key] = value 76 | model.load_state_dict(state_dict) 77 | return model 78 | 79 | 80 | def vgg13(config_channels, anchors, num_cls): 81 | model = VGG(config_channels, anchors, num_cls, make_layers(config_channels, cfg['B'])) 82 | if config_channels.config.getboolean('model', 'pretrained'): 83 | url = model_urls['vgg13'] 84 | logging.info('use pretrained model: ' + url) 85 | state_dict = model.state_dict() 86 | for key, value in model_zoo.load_url(url).items(): 87 | if key in state_dict: 88 | state_dict[key] = value 89 | model.load_state_dict(state_dict) 90 | return model 91 | 92 | 93 | def vgg13_bn(config_channels, anchors, num_cls): 94 | model = VGG(config_channels, anchors, num_cls, make_layers(config_channels, cfg['B'], batch_norm=True)) 95 | if config_channels.config.getboolean('model', 'pretrained'): 96 | url = model_urls['vgg13_bn'] 97 | logging.info('use pretrained model: ' + url) 98 | state_dict = model.state_dict() 99 | for key, value in model_zoo.load_url(url).items(): 100 | if key in state_dict: 101 | state_dict[key] = value 102 | model.load_state_dict(state_dict) 103 | return model 104 | 105 | 106 | def vgg16(config_channels, anchors, num_cls): 107 | model = VGG(config_channels, anchors, num_cls, make_layers(config_channels, cfg['D'])) 108 | if config_channels.config.getboolean('model', 'pretrained'): 109 | url = model_urls['vgg16'] 110 | logging.info('use pretrained model: ' + url) 111 | 
state_dict = model.state_dict() 112 | for key, value in model_zoo.load_url(url).items(): 113 | if key in state_dict: 114 | state_dict[key] = value 115 | model.load_state_dict(state_dict) 116 | return model 117 | 118 | 119 | def vgg16_bn(config_channels, anchors, num_cls): 120 | model = VGG(config_channels, anchors, num_cls, make_layers(config_channels, cfg['D'], batch_norm=True)) 121 | if config_channels.config.getboolean('model', 'pretrained'): 122 | url = model_urls['vgg16_bn'] 123 | logging.info('use pretrained model: ' + url) 124 | state_dict = model.state_dict() 125 | for key, value in model_zoo.load_url(url).items(): 126 | if key in state_dict: 127 | state_dict[key] = value 128 | model.load_state_dict(state_dict) 129 | return model 130 | 131 | 132 | def vgg19(config_channels, anchors, num_cls): 133 | model = VGG(config_channels, anchors, num_cls, make_layers(config_channels, cfg['E'])) 134 | if config_channels.config.getboolean('model', 'pretrained'): 135 | url = model_urls['vgg19'] 136 | logging.info('use pretrained model: ' + url) 137 | state_dict = model.state_dict() 138 | for key, value in model_zoo.load_url(url).items(): 139 | if key in state_dict: 140 | state_dict[key] = value 141 | model.load_state_dict(state_dict) 142 | return model 143 | 144 | 145 | def vgg19_bn(config_channels, anchors, num_cls): 146 | model = VGG(config_channels, anchors, num_cls, make_layers(config_channels, cfg['E'], batch_norm=True)) 147 | if config_channels.config.getboolean('model', 'pretrained'): 148 | url = model_urls['vgg19_bn'] 149 | logging.info('use pretrained model: ' + url) 150 | state_dict = model.state_dict() 151 | for key, value in model_zoo.load_url(url).items(): 152 | if key in state_dict: 153 | state_dict[key] = value 154 | model.load_state_dict(state_dict) 155 | return model 156 | -------------------------------------------------------------------------------- /model/yolo2.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen) 3 | 4 | This program is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU Lesser General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 8 | 9 | This program is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with this program. If not, see . 
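# Every backbone in this package ends in the same 1x1 detection head sized by
# model.output_channels(len(anchors), num_cls). Assuming the usual YOLOv2
# layout (a plausible reading; the actual definition lives in
# model/__init__.py), that is len(anchors) * (4 + 1 + num_cls): per anchor,
# 4 box offsets, 1 objectness score and num_cls class scores. For example,
# 5 anchors and the 20 VOC classes give 5 * 25 = 125 output channels.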
16 | """ 17 | 18 | import collections.abc 19 | 20 | import numpy as np 21 | import torch 22 | import torch.nn as nn 23 | import torch.autograd 24 | 25 | import model 26 | 27 | 28 | settings = { 29 | 'size': (416, 416), 30 | } 31 | 32 | 33 | def reorg(x, stride_h=2, stride_w=2): 34 | batch_size, channels, height, width = x.size() 35 | _height, _width = height // stride_h, width // stride_w 36 | if 1: 37 | x = x.view(batch_size, channels, _height, stride_h, _width, stride_w).transpose(3, 4).contiguous() 38 | x = x.view(batch_size, channels, _height * _width, stride_h * stride_w).transpose(2, 3).contiguous() 39 | x = x.view(batch_size, channels, stride_h * stride_w, _height, _width).transpose(1, 2).contiguous() 40 | x = x.view(batch_size, -1, _height, _width) 41 | else: 42 | x = x.view(batch_size, channels, _height, stride_h, _width, stride_w) 43 | x = x.permute(0, 1, 3, 5, 2, 4) # batch_size, channels, stride, stride, _height, _width 44 | x = x.contiguous() 45 | x = x.view(batch_size, -1, _height, _width) 46 | return x 47 | 48 | 49 | class Conv2d(nn.Module): 50 | def __init__(self, in_channels, out_channels, kernel_size, padding=0, stride=1, bn=True, act=True): 51 | nn.Module.__init__(self) 52 | if isinstance(padding, bool): 53 | if isinstance(kernel_size, collections.abc.Iterable): 54 | padding = tuple((kernel_size - 1) // 2 for kernel_size in kernel_size) if padding else 0 55 | else: 56 | padding = (kernel_size - 1) // 2 if padding else 0 57 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding=padding, bias=not bn) 58 | self.bn = nn.BatchNorm2d(out_channels, momentum=0.01) if bn else lambda x: x 59 | self.act = nn.LeakyReLU(0.1, inplace=True) if act else lambda x: x 60 | 61 | def forward(self, x): 62 | x = self.conv(x) 63 | x = self.bn(x) 64 | x = self.act(x) 65 | return x 66 | 67 | 68 | class Darknet(nn.Module): 69 | def __init__(self, config_channels, anchors, num_cls, stride=2, ratio=1): 70 | nn.Module.__init__(self) 71 | self.stride = stride 72 | channels = int(32 * ratio) 73 | layers = [] 74 | 75 | bn = config_channels.config.getboolean('batch_norm', 'enable') 76 | # layers1 77 | for _ in range(2): 78 | layers.append(Conv2d(config_channels.channels, config_channels(channels, 'layers1.%d.conv.weight' % len(layers)), 3, bn=bn, padding=True)) 79 | layers.append(nn.MaxPool2d(kernel_size=2)) 80 | channels *= 2 81 | # down 4 82 | for _ in range(2): 83 | layers.append(Conv2d(config_channels.channels, config_channels(channels, 'layers1.%d.conv.weight' % len(layers)), 3, bn=bn, padding=True)) 84 | layers.append(Conv2d(config_channels.channels, config_channels(channels // 2, 'layers1.%d.conv.weight' % len(layers)), 1, bn=bn)) 85 | layers.append(Conv2d(config_channels.channels, config_channels(channels, 'layers1.%d.conv.weight' % len(layers)), 3, bn=bn, padding=True)) 86 | layers.append(nn.MaxPool2d(kernel_size=2)) 87 | channels *= 2 88 | # down 16 89 | for _ in range(2): 90 | layers.append(Conv2d(config_channels.channels, config_channels(channels, 'layers1.%d.conv.weight' % len(layers)), 3, bn=bn, padding=True)) 91 | layers.append(Conv2d(config_channels.channels, config_channels(channels // 2, 'layers1.%d.conv.weight' % len(layers)), 1, bn=bn)) 92 | layers.append(Conv2d(config_channels.channels, config_channels(channels, 'layers1.%d.conv.weight' % len(layers)), 3, bn=bn, padding=True)) 93 | self.layers1 = nn.Sequential(*layers) 94 | 95 | # layers2 96 | layers = [] 97 | layers.append(nn.MaxPool2d(kernel_size=2)) 98 | channels *= 2 99 | # down 32 100 | for _ in 
range(2): 101 | layers.append(Conv2d(config_channels.channels, config_channels(channels, 'layers2.%d.conv.weight' % len(layers)), 3, bn=bn, padding=True)) 102 | layers.append(Conv2d(config_channels.channels, config_channels(channels // 2, 'layers2.%d.conv.weight' % len(layers)), 1, bn=bn)) 103 | for _ in range(3): 104 | layers.append(Conv2d(config_channels.channels, config_channels(channels, 'layers2.%d.conv.weight' % len(layers)), 3, bn=bn, padding=True)) 105 | self.layers2 = nn.Sequential(*layers) 106 | 107 | self.passthrough = Conv2d(self.layers1[-1].conv.weight.size(0), config_channels(int(64 * ratio), 'passthrough.conv.weight'), 1, bn=bn) 108 | 109 | # layers3 110 | layers = [] 111 | layers.append(Conv2d(self.passthrough.conv.weight.size(0) * self.stride * self.stride + self.layers2[-1].conv.weight.size(0), config_channels(int(1024 * ratio), 'layers3.%d.conv.weight' % len(layers)), 3, bn=bn, padding=True)) 112 | layers.append(Conv2d(config_channels.channels, model.output_channels(len(anchors), num_cls), 1, bn=False, act=False)) 113 | self.layers3 = nn.Sequential(*layers) 114 | 115 | self.init() 116 | 117 | def init(self): 118 | for m in self.modules(): 119 | if isinstance(m, nn.Conv2d): 120 | m.weight = nn.init.kaiming_normal(m.weight) 121 | elif isinstance(m, nn.BatchNorm2d): 122 | m.weight.data.fill_(1) 123 | m.bias.data.zero_() 124 | 125 | def forward(self, x): 126 | x = self.layers1(x) 127 | _x = reorg(self.passthrough(x), self.stride) 128 | x = self.layers2(x) 129 | x = torch.cat([_x, x], 1) 130 | return self.layers3(x) 131 | 132 | def scope(self, name): 133 | return '.'.join(name.split('.')[:-2]) 134 | 135 | def get_mapper(self, index): 136 | if index == 94: 137 | return lambda indices, channels: torch.cat([indices + i * channels for i in range(self.stride * self.stride)]) 138 | 139 | 140 | class Tiny(nn.Module): 141 | def __init__(self, config_channels, anchors, num_cls, channels=16): 142 | nn.Module.__init__(self) 143 | layers = [] 144 | 145 | bn = config_channels.config.getboolean('batch_norm', 'enable') 146 | for _ in range(5): 147 | layers.append(Conv2d(config_channels.channels, config_channels(channels, 'layers.%d.conv.weight' % len(layers)), 3, bn=bn, padding=True)) 148 | layers.append(nn.MaxPool2d(kernel_size=2)) 149 | channels *= 2 150 | layers.append(Conv2d(config_channels.channels, config_channels(channels, 'layers.%d.conv.weight' % len(layers)), 3, bn=bn, padding=True)) 151 | layers.append(nn.ConstantPad2d((0, 1, 0, 1), float(np.finfo(np.float32).min))) 152 | layers.append(nn.MaxPool2d(kernel_size=2, stride=1)) 153 | channels *= 2 154 | for _ in range(2): 155 | layers.append(Conv2d(config_channels.channels, config_channels(channels, 'layers.%d.conv.weight' % len(layers)), 3, bn=bn, padding=True)) 156 | layers.append(Conv2d(config_channels.channels, model.output_channels(len(anchors), num_cls), 1, bn=False, act=False)) 157 | self.layers = nn.Sequential(*layers) 158 | 159 | self.init() 160 | 161 | def init(self): 162 | for m in self.modules(): 163 | if isinstance(m, nn.Conv2d): 164 | m.weight = nn.init.xavier_normal(m.weight) 165 | elif isinstance(m, nn.BatchNorm2d): 166 | m.weight.data.fill_(1) 167 | m.bias.data.zero_() 168 | 169 | def forward(self, x): 170 | return self.layers(x) 171 | 172 | def scope(self, name): 173 | return '.'.join(name.split('.')[:-2]) 174 | -------------------------------------------------------------------------------- /pruner.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2017, 
申瑞珉 (Ruimin Shen) 3 | 4 | This program is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU Lesser General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 8 | 9 | This program is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with this program. If not, see . 16 | """ 17 | 18 | import os 19 | import argparse 20 | import configparser 21 | import logging 22 | import logging.config 23 | import yaml 24 | 25 | import numpy as np 26 | import torch.autograd 27 | import torch.cuda 28 | import torch.optim 29 | import torch.utils.data 30 | import humanize 31 | 32 | import model 33 | import utils 34 | import utils.train 35 | import utils.channel 36 | 37 | 38 | def main(): 39 | args = make_args() 40 | config = configparser.ConfigParser() 41 | utils.load_config(config, args.config) 42 | for cmd in args.modify: 43 | utils.modify_config(config, cmd) 44 | with open(os.path.expanduser(os.path.expandvars(args.logging)), 'r') as f: 45 | logging.config.dictConfig(yaml.load(f)) 46 | model_dir = utils.get_model_dir(config) 47 | category = utils.get_category(config) 48 | anchors = torch.from_numpy(utils.get_anchors(config)).contiguous() 49 | path, step, epoch = utils.train.load_model(model_dir) 50 | state_dict = torch.load(path, map_location=lambda storage, loc: storage) 51 | _model = utils.parse_attr(config.get('model', 'dnn')) 52 | dnn = _model(model.ConfigChannels(config, state_dict), anchors, len(category)) 53 | logging.info(humanize.naturalsize(sum(var.cpu().numpy().nbytes for var in dnn.state_dict().values()))) 54 | dnn.load_state_dict(state_dict) 55 | height, width = tuple(map(int, config.get('image', 'size').split())) 56 | image = torch.autograd.Variable(torch.randn(args.batch_size, 3, height, width)) 57 | output = dnn(image) 58 | state_dict = dnn.state_dict() 59 | d = utils.dense(state_dict[args.name]) 60 | keep = torch.LongTensor(np.argsort(d)[:int(len(d) * args.keep)]) 61 | modifier = utils.channel.Modifier( 62 | args.name, state_dict, dnn, 63 | lambda name, var: var[keep], 64 | lambda name, var, mapper: var[mapper(keep, len(d))], 65 | debug=args.debug, 66 | ) 67 | modifier(output.grad_fn) 68 | if args.debug: 69 | path = modifier.dot.view('%s.%s.gv' % (os.path.basename(model_dir), os.path.basename(os.path.splitext(__file__)[0])), os.path.dirname(model_dir)) 70 | logging.info(path) 71 | assert len(keep) == len(state_dict[args.name]) 72 | dnn = _model(model.ConfigChannels(config, state_dict), anchors, len(category)) 73 | dnn.load_state_dict(state_dict) 74 | dnn(image) 75 | if not args.debug: 76 | torch.save(state_dict, path) 77 | 78 | 79 | def make_args(): 80 | parser = argparse.ArgumentParser() 81 | parser.add_argument('name') 82 | parser.add_argument('keep', type=float) 83 | parser.add_argument('-c', '--config', nargs='+', default=['config.ini'], help='config file') 84 | parser.add_argument('-m', '--modify', nargs='+', default=[], help='modify config') 85 | parser.add_argument('-b', '--batch_size', default=1, type=int, help='batch size') 86 | parser.add_argument('-d', '--debug', action='store_true') 87 | parser.add_argument('--logging', default='logging.yml', help='logging config') 88 | return 
parser.parse_args() 89 | 90 | 91 | if __name__ == '__main__': 92 | main() -------------------------------------------------------------------------------- /quick_start.sh: -------------------------------------------------------------------------------- 1 | echo download VOC dataset 2 | LINKS=" 3 | http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar 4 | http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar 5 | http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar 6 | " 7 | ROOT=~/data 8 | for LINK in $LINKS 9 | do 10 | aria2c --auto-file-renaming=false -d $ROOT $LINK 11 | tar -kxvf $ROOT/$(basename $LINK) -C $ROOT 12 | done 13 | 14 | echo download COCO dataset 15 | LINKS=" 16 | http://images.cocodataset.org/zips/train2014.zip 17 | http://images.cocodataset.org/zips/val2014.zip 18 | http://images.cocodataset.org/annotations/annotations_trainval2014.zip 19 | http://images.cocodataset.org/zips/train2017.zip 20 | http://images.cocodataset.org/zips/val2017.zip 21 | http://images.cocodataset.org/annotations/annotations_trainval2017.zip 22 | " 23 | ROOT=~/data/coco 24 | for LINK in $LINKS 25 | do 26 | aria2c --auto-file-renaming=false -d $ROOT $LINK 27 | unzip -n $ROOT/$(basename $LINK) -d $ROOT 28 | done 29 | rm $ROOT/val2014/COCO_val2014_000000320612.jpg 30 | 31 | echo cache data 32 | python3 cache.py -m cache/datasets=cache.voc.cache cache/name=cache_voc cache/category=config/category/20 33 | python3 cache.py -m cache/datasets=cache.coco.cache cache/name=cache_coco cache/category=config/category/80 34 | python3 cache.py -m cache/datasets='cache.voc.cache cache.coco.cache' cache/name=cache_20 cache/category=config/category/20 35 | 36 | ROOT=~/model/darknet 37 | 38 | echo test VOC models 39 | MODELS=" 40 | yolo-voc 41 | tiny-yolo-voc 42 | " 43 | 44 | for MODEL in $MODELS 45 | do 46 | aria2c --auto-file-renaming=false -d $ROOT http://pjreddie.com/media/files/$MODEL.weights 47 | python3 convert_darknet_torch.py ~/model/darknet/$MODEL.weights -c config.ini config/darknet/$MODEL.ini -d 48 | python3 eval.py -c config.ini config/darknet/$MODEL.ini 49 | python3 detect.py -c config.ini config/darknet/$MODEL.ini -i image.jpg --pause 50 | done 51 | 52 | echo test COCO models 53 | MODELS=" 54 | yolo 55 | " 56 | 57 | for MODEL in $MODELS 58 | do 59 | aria2c --auto-file-renaming=false -d $ROOT http://pjreddie.com/media/files/$MODEL.weights 60 | python3 convert_darknet_torch.py ~/model/darknet/$MODEL.weights -c config.ini config/darknet/$MODEL.ini -d 61 | python3 eval.py -c config.ini config/darknet/$MODEL.ini 62 | python3 detect.py -c config.ini config/darknet/$MODEL.ini -i image.jpg --pause 63 | done 64 | 65 | echo convert pretrained Darknet model 66 | aria2c --auto-file-renaming=false -d $ROOT http://pjreddie.com/media/files/darknet19_448.conv.23 67 | python3 convert_darknet_torch.py ~/model/darknet/darknet19_448.conv.23 -m model/name=model_voc model/dnn=model.yolo2.Darknet -d --copy ~/model/darknet/darknet19_448.conv.23.pth 68 | 69 | echo reproduce the training results 70 | export CACHE_NAME=cache_voc MODEL_NAME=model_voc MODEL=model.yolo2.Darknet 71 | python3 train.py -b 64 -lr 1e-3 -e 160 -m cache/name=$CACHE_NAME model/name=$MODEL_NAME model/dnn=$MODEL train/optimizer='lambda params, lr: torch.optim.SGD(params, lr, momentum=0.9)' train/scheduler='lambda optimizer: torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[60, 90], gamma=0.1)' -f ~/model/darknet/darknet19_448.conv.23.pth -d 72 | python3 eval.py -m cache/name=$CACHE_NAME 
model/name=$MODEL_NAME model/dnn=$MODEL 73 | -------------------------------------------------------------------------------- /receptive_field_analyzer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen) 3 | 4 | This program is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU Lesser General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 8 | 9 | This program is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with this program. If not, see . 16 | """ 17 | 18 | import os 19 | import argparse 20 | import configparser 21 | import logging 22 | import logging.config 23 | import multiprocessing 24 | import yaml 25 | 26 | import numpy as np 27 | import scipy.misc 28 | import torch.autograd 29 | import torch.cuda 30 | import torch.optim 31 | import torch.utils.data 32 | import tqdm 33 | import humanize 34 | 35 | import model 36 | import utils.data 37 | import utils.iou.torch 38 | import utils.postprocess 39 | import utils.train 40 | import utils.visualize 41 | 42 | 43 | class Dataset(torch.utils.data.Dataset): 44 | def __init__(self, height, width): 45 | self.points = np.array([(i, j) for i in range(height) for j in range(width)]) 46 | 47 | def __len__(self): 48 | return len(self.points) 49 | 50 | def __getitem__(self, index): 51 | return self.points[index] 52 | 53 | 54 | class Analyzer(object): 55 | def __init__(self, args, config): 56 | self.args = args 57 | self.config = config 58 | self.model_dir = utils.get_model_dir(config) 59 | self.category = utils.get_category(config) 60 | self.anchors = torch.from_numpy(utils.get_anchors(config)).contiguous() 61 | self.dnn = utils.parse_attr(config.get('model', 'dnn'))(model.ConfigChannels(config), self.anchors, len(self.category)) 62 | self.dnn.eval() 63 | logging.info(humanize.naturalsize(sum(var.cpu().numpy().nbytes for var in self.dnn.state_dict().values()))) 64 | if torch.cuda.is_available(): 65 | self.dnn.cuda() 66 | self.height, self.width = tuple(map(int, config.get('image', 'size').split())) 67 | output = self.dnn(torch.autograd.Variable(utils.ensure_device(torch.zeros(1, 3, self.height, self.width)), volatile=True)) 68 | _, _, self.rows, self.cols = output.size() 69 | self.i, self.j = self.rows // 2, self.cols // 2 70 | self.output = output[:, :, self.i, self.j] 71 | dataset = Dataset(self.height, self.width) 72 | try: 73 | workers = self.config.getint('data', 'workers') 74 | except configparser.NoOptionError: 75 | workers = multiprocessing.cpu_count() 76 | self.loader = torch.utils.data.DataLoader(dataset, batch_size=self.args.batch_size, num_workers=workers) 77 | 78 | def __call__(self): 79 | changed = np.zeros([self.height, self.width], np.bool) 80 | for yx in tqdm.tqdm(self.loader): 81 | batch_size = yx.size(0) 82 | tensor = torch.zeros(batch_size, 3, self.height, self.width) 83 | for i, _yx in enumerate(torch.unbind(yx)): 84 | y, x = torch.unbind(_yx) 85 | tensor[i, :, y, x] = 1 86 | tensor = utils.ensure_device(tensor) 87 | output = self.dnn(torch.autograd.Variable(tensor, volatile=True)) 88 | output = output[:, :, self.i, self.j] 89 | cmp = output == self.output 
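# The loop above probes the receptive field empirically: every batch element
# is an all-zero image with a single pixel set to 1, and a pixel belongs to
# the receptive field of the central output cell (self.i, self.j) exactly
# when perturbing it changes that cell's activations. The torch.prod on the
# next line collapses the element-wise equality over the channel dimension
# into one "output unchanged" flag per probe; the array named changed thus
# records equality, which is why main() saves ~changed, rendering the
# receptive field in white.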
90 | cmp = torch.prod(cmp, -1).data 91 | for _yx, c in zip(torch.unbind(yx), torch.unbind(cmp)): 92 | y, x = torch.unbind(_yx) 93 | changed[y, x] = c 94 | return changed 95 | 96 | 97 | def main(): 98 | args = make_args() 99 | config = configparser.ConfigParser() 100 | utils.load_config(config, args.config) 101 | for cmd in args.modify: 102 | utils.modify_config(config, cmd) 103 | with open(os.path.expanduser(os.path.expandvars(args.logging)), 'r') as f: 104 | logging.config.dictConfig(yaml.load(f)) 105 | analyzer = Analyzer(args, config) 106 | changed = analyzer() 107 | os.makedirs(analyzer.model_dir, exist_ok=True) 108 | path = os.path.join(analyzer.model_dir, args.filename) 109 | scipy.misc.imsave(path, (~changed).astype(np.uint8) * 255) 110 | logging.info(path) 111 | 112 | 113 | def make_args(): 114 | parser = argparse.ArgumentParser() 115 | parser.add_argument('-c', '--config', nargs='+', default=['config.ini'], help='config file') 116 | parser.add_argument('-m', '--modify', nargs='+', default=[], help='modify config') 117 | parser.add_argument('-b', '--batch_size', default=16, type=int, help='batch size') 118 | parser.add_argument('-n', '--filename', default='receptive_field.jpg') 119 | parser.add_argument('--logging', default='logging.yml', help='logging config') 120 | return parser.parse_args() 121 | 122 | 123 | if __name__ == '__main__': 124 | main() 125 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | humanize 2 | tqdm 3 | onnx_caffe2 4 | onnx 5 | torch<=0.3.1 6 | torchvision 7 | nltk 8 | pandas 9 | pycocotools 10 | XlsxWriter 11 | filelock 12 | matplotlib 13 | scikit_image 14 | pybenchmark 15 | tinydb 16 | graphviz 17 | pretrainedmodels 18 | inflection 19 | videosequence 20 | pymediainfo 21 | Pillow 22 | scipy 23 | skimage 24 | scikit_learn 25 | tensorboardX 26 | wget 27 | PyYAML 28 | -------------------------------------------------------------------------------- /split_data.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen) 3 | 4 | This program is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU Lesser General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 8 | 9 | This program is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with this program. If not, see . 
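# The script below walks the dataset root, collects image paths, shuffles
# them, and writes train/val/test list files. The default weights are
# --train 7 --val 2 --test 1, so with, say, 1000 images:
#   total = 7 + 2 + 1 = 10
#   nval  = int(1000 * 2 / 10) = 200         -> val  = realpaths[:200]
#   ntest = 200 + int(1000 * 1 / 10) = 300   -> test = realpaths[200:300]
#   train = realpaths[300:]                  (the remaining 700 images)
# Rounding remainders always land in the training split.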
16 | """ 17 | 18 | import os 19 | import argparse 20 | import random 21 | 22 | 23 | def main(): 24 | args = make_args() 25 | root = os.path.expanduser(os.path.expandvars(args.root)) 26 | realpaths = [] 27 | for dirpath, _, filenames in os.walk(root): 28 | for filename in filenames: 29 | if os.path.splitext(filename)[-1].lower() in args.exts and filename[0] != '.': 30 | path = os.path.join(dirpath, filename) 31 | realpath = os.path.relpath(path, root) 32 | realpaths.append(realpath) 33 | random.shuffle(realpaths) 34 | total = args.train + args.val + args.test 35 | nval = int(len(realpaths) * args.val / total) 36 | ntest = nval + int(len(realpaths) * args.test / total) 37 | val = realpaths[:nval] 38 | test = realpaths[nval:ntest] 39 | train = realpaths[ntest:] 40 | print('train=%d, val=%d, test=%d' % (len(train), len(val), len(test))) 41 | with open(os.path.join(root, 'train' + args.ext), 'w') as f: 42 | for path in train: 43 | f.write(path + '\n') 44 | with open(os.path.join(root, 'val' + args.ext), 'w') as f: 45 | for path in val: 46 | f.write(path + '\n') 47 | with open(os.path.join(root, 'test' + args.ext), 'w') as f: 48 | for path in test: 49 | f.write(path + '\n') 50 | 51 | 52 | def make_args(): 53 | parser = argparse.ArgumentParser() 54 | parser.add_argument('root') 55 | parser.add_argument('-e', '--exts', nargs='+', default=['.jpe', '.jpg', '.jpeg', '.png']) 56 | parser.add_argument('--train', type=float, default=7) 57 | parser.add_argument('--val', type=float, default=2) 58 | parser.add_argument('--test', type=float, default=1) 59 | parser.add_argument('--ext', default='.txt') 60 | return parser.parse_args() 61 | 62 | if __name__ == '__main__': 63 | main() 64 | -------------------------------------------------------------------------------- /transform/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen) 3 | 4 | This program is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU Lesser General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 8 | 9 | This program is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with this program. If not, see . 
16 | """ 17 | 18 | import inspect 19 | 20 | import torchvision 21 | 22 | import utils 23 | 24 | 25 | def parse_transform(config, method): 26 | if isinstance(method, str): 27 | attr = utils.parse_attr(method) 28 | sig = inspect.signature(attr) 29 | if len(sig.parameters) == 1: 30 | return attr(config) 31 | else: 32 | return attr() 33 | else: 34 | return method 35 | 36 | 37 | def get_transform(config, sequence, compose=torchvision.transforms.Compose): 38 | return compose([parse_transform(config, method) for method in sequence]) 39 | -------------------------------------------------------------------------------- /transform/augmentation.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen) 3 | 4 | This program is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU Lesser General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 8 | 9 | This program is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with this program. If not, see . 16 | """ 17 | 18 | import inspect 19 | import random 20 | 21 | import inflection 22 | import numpy as np 23 | import cv2 24 | 25 | import transform 26 | 27 | 28 | class Rotator(object): 29 | def __init__(self, y, x, height, width, angle): 30 | """ 31 | A efficient tool to rotate multiple images in the same size. 32 | :author 申瑞珉 (Ruimin Shen) 33 | :param y: The y coordinate of rotation point. 34 | :param x: The x coordinate of rotation point. 35 | :param height: Image height. 36 | :param width: Image width. 37 | :param angle: Rotate angle. 
38 | """ 39 | self._mat = cv2.getRotationMatrix2D((x, y), angle, 1.0) 40 | r = np.abs(self._mat[0, :2]) 41 | _height, _width = np.inner(r, [height, width]), np.inner(r, [width, height]) 42 | fix_y, fix_x = _height / 2 - y, _width / 2 - x 43 | self._mat[:, 2] += [fix_x, fix_y] 44 | self._size = int(_width), int(_height) 45 | 46 | def __call__(self, image, flags=cv2.INTER_LINEAR, fill=None): 47 | if fill is None: 48 | fill = np.random.rand(3) * 256 49 | return cv2.warpAffine(image, self._mat, self._size, flags=flags, borderMode=cv2.BORDER_CONSTANT, borderValue=fill) 50 | 51 | def _rotate_points(self, points): 52 | _points = np.pad(points, [(0, 0), (0, 1)], 'constant') 53 | _points[:, 2] = 1 54 | _points = np.dot(self._mat, _points.T) 55 | return _points.T.astype(points.dtype) 56 | 57 | def rotate_points(self, points): 58 | return self._rotate_points(points[:, ::-1])[:, ::-1] 59 | 60 | 61 | def random_rotate(config, image, yx_min, yx_max): 62 | name = inspect.stack()[0][3] 63 | angle = random.uniform(*tuple(map(float, config.get('augmentation', name).split()))) 64 | height, width = image.shape[:2] 65 | p1, p2 = np.copy(yx_min), np.copy(yx_max) 66 | p1[:, 0] = yx_max[:, 0] 67 | p2[:, 0] = yx_min[:, 0] 68 | points = np.concatenate([yx_min, yx_max, p1, p2], 0) 69 | rotator = Rotator(height / 2, width / 2, height, width, angle) 70 | image = rotator(image, fill=0) 71 | points = rotator.rotate_points(points) 72 | bbox_points = np.reshape(points, [4, -1, 2]) 73 | yx_min = np.apply_along_axis(lambda points: np.min(points, 0), 0, bbox_points) 74 | yx_max = np.apply_along_axis(lambda points: np.max(points, 0), 0, bbox_points) 75 | return image, yx_min, yx_max 76 | 77 | 78 | class RandomRotate(object): 79 | def __init__(self, config): 80 | self.config = config 81 | self.fn = eval(inflection.underscore(type(self).__name__)) 82 | 83 | def __call__(self, data): 84 | data['image'], data['yx_min'], data['yx_max'] = self.fn(self.config, data['image'], data['yx_min'], data['yx_max']) 85 | return data 86 | 87 | 88 | def flip_horizontally(image, yx_min, yx_max): 89 | assert len(image.shape) == 3 90 | image = cv2.flip(image, 1) 91 | width = image.shape[1] 92 | temp = width - yx_min[:, 1] 93 | yx_min[:, 1] = width - yx_max[:, 1] 94 | yx_max[:, 1] = temp 95 | return image, yx_min, yx_max 96 | 97 | 98 | def random_flip_horizontally(config, image, yx_min, yx_max): 99 | name = inspect.stack()[0][3] 100 | if random.random() > config.getfloat('augmentation', name): 101 | return flip_horizontally(image, yx_min, yx_max) 102 | else: 103 | return image, yx_min, yx_max 104 | 105 | 106 | class RandomFlipHorizontally(object): 107 | def __init__(self, config): 108 | self.config = config 109 | self.fn = eval(inflection.underscore(type(self).__name__)) 110 | 111 | def __call__(self, data): 112 | data['image'], data['yx_min'], data['yx_max'] = self.fn(self.config, data['image'], data['yx_min'], data['yx_max']) 113 | return data 114 | 115 | 116 | def get_transform(config, sequence): 117 | return transform.get_transform(config, sequence) 118 | -------------------------------------------------------------------------------- /transform/image.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen) 3 | 4 | This program is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU Lesser General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later 
version. 8 | 9 | This program is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with this program. If not, see . 16 | """ 17 | 18 | import random 19 | 20 | import numpy as np 21 | import torchvision 22 | import inflection 23 | import skimage.exposure 24 | import cv2 25 | 26 | 27 | class BGR2RGB(object): 28 | def __call__(self, image): 29 | return cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 30 | 31 | 32 | class BGR2HSV(object): 33 | def __call__(self, image): 34 | return cv2.cvtColor(image, cv2.COLOR_BGR2HSV) 35 | 36 | 37 | class HSV2RGB(object): 38 | def __call__(self, image): 39 | return cv2.cvtColor(image, cv2.COLOR_HSV2RGB) 40 | 41 | 42 | class RandomBlur(object): 43 | def __init__(self, config): 44 | name = inflection.underscore(type(self).__name__) 45 | self.adjust = tuple(map(int, config.get('augmentation', name).split())) 46 | 47 | def __call__(self, image): 48 | adjust = tuple(random.randint(1, adjust) for adjust in self.adjust) 49 | return cv2.blur(image, adjust) 50 | 51 | 52 | class RandomHue(object): 53 | def __init__(self, config): 54 | name = inflection.underscore(type(self).__name__) 55 | self.adjust = tuple(map(int, config.get('augmentation', name).split())) 56 | 57 | def __call__(self, hsv): 58 | h, s, v = cv2.split(hsv) 59 | adjust = random.randint(*self.adjust) 60 | h = h.astype(np.int) + adjust 61 | h = np.clip(h, 0, 179).astype(hsv.dtype) 62 | return cv2.merge((h, s, v)) 63 | 64 | 65 | class RandomSaturation(object): 66 | def __init__(self, config): 67 | name = inflection.underscore(type(self).__name__) 68 | self.adjust = tuple(map(float, config.get('augmentation', name).split())) 69 | 70 | def __call__(self, hsv): 71 | h, s, v = cv2.split(hsv) 72 | adjust = random.uniform(*self.adjust) 73 | s = s * adjust 74 | s = np.clip(s, 0, 255).astype(hsv.dtype) 75 | return cv2.merge((h, s, v)) 76 | 77 | 78 | class RandomBrightness(object): 79 | def __init__(self, config): 80 | name = inflection.underscore(type(self).__name__) 81 | self.adjust = tuple(map(float, config.get('augmentation', name).split())) 82 | 83 | def __call__(self, hsv): 84 | h, s, v = cv2.split(hsv) 85 | adjust = random.uniform(*self.adjust) 86 | v = v * adjust 87 | v = np.clip(v, 0, 255).astype(hsv.dtype) 88 | return cv2.merge((h, s, v)) 89 | 90 | 91 | class RandomGamma(object): 92 | def __init__(self, config): 93 | name = inflection.underscore(type(self).__name__) 94 | self.adjust = tuple(map(float, config.get('augmentation', name).split())) 95 | 96 | def __call__(self, image): 97 | adjust = random.uniform(*self.adjust) 98 | return skimage.exposure.adjust_gamma(image, adjust) 99 | 100 | 101 | class Normalize(torchvision.transforms.Normalize): 102 | def __init__(self, config): 103 | name = inflection.underscore(type(self).__name__) 104 | mean, std = tuple(map(float, config.get('transform', name).split())) 105 | torchvision.transforms.Normalize.__init__(self, (mean, mean, mean), (std, std, std)) 106 | -------------------------------------------------------------------------------- /transform/resize/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruiminshen/yolo2-pytorch/146ebdf581677964caa31c69cccd0c86230fb216/transform/resize/__init__.py 
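The color-space transforms in transform/image.py above are written to be chained with torchvision's Compose: convert OpenCV's BGR layout to HSV once, jitter hue, saturation and brightness on the split channels, then convert back to RGB at the end. A minimal sketch of such a pipeline; the three augmentation values are illustrative, not taken from the repository's config.ini:

import configparser

import torchvision

import transform.image

config = configparser.ConfigParser()
config['augmentation'] = {
    'random_hue': '-18 18',          # additive hue jitter (OpenCV hue spans 0..179)
    'random_saturation': '0.5 1.5',  # multiplicative saturation jitter
    'random_brightness': '0.5 1.5',  # multiplicative value jitter
}

augment = torchvision.transforms.Compose([
    transform.image.BGR2HSV(),
    transform.image.RandomHue(config),
    transform.image.RandomSaturation(config),
    transform.image.RandomBrightness(config),
    transform.image.HSV2RGB(),
])
# image = augment(cv2.imread('image.jpg'))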
-------------------------------------------------------------------------------- /transform/resize/image.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen) 3 | 4 | This program is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU Lesser General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 8 | 9 | This program is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with this program. If not, see . 16 | """ 17 | 18 | import inflection 19 | import numpy as np 20 | import cv2 21 | 22 | 23 | def rescale(image, height, width): 24 | return cv2.resize(image, (width, height)) 25 | 26 | 27 | class Rescale(object): 28 | def __init__(self): 29 | name = inflection.underscore(type(self).__name__) 30 | self.fn = eval(name) 31 | 32 | def __call__(self, image, height, width): 33 | return self.fn(image, height, width) 34 | 35 | 36 | def fixed(image, height, width): 37 | _height, _width, _ = image.shape 38 | if _height / _width > height / width: 39 | scale = height / _height 40 | else: 41 | scale = width / _width 42 | m = np.eye(2, 3) 43 | m[0, 0] = scale 44 | m[1, 1] = scale 45 | flags = cv2.INTER_AREA if scale < 1 else cv2.INTER_CUBIC 46 | return cv2.warpAffine(image, m, (width, height), flags=flags) 47 | 48 | 49 | class Fixed(object): 50 | def __init__(self): 51 | name = inflection.underscore(type(self).__name__) 52 | self.fn = eval(name) 53 | 54 | def __call__(self, image, height, width): 55 | return self.fn(image, height, width) 56 | 57 | 58 | class Resize(object): 59 | def __init__(self, config): 60 | name = config.get('data', inflection.underscore(type(self).__name__)) 61 | self.fn = eval(name) 62 | 63 | def __call__(self, image, height, width): 64 | return self.fn(image, height, width) 65 | -------------------------------------------------------------------------------- /transform/resize/label.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen) 3 | 4 | This program is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU Lesser General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 8 | 9 | This program is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with this program. If not, see . 
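# The fixed() resize in transform/resize/image.py above is an
# aspect-preserving fit: the branch picks
#   scale = min(height / _height, width / _width)
# and the affine warp places the scaled image in the top-left corner of the
# (width, height) canvas, leaving the rest as black border. Worked example:
# a 480x640 (h x w) image resized to a 416x416 canvas gives
# scale = min(416/480, 416/640) = 0.65, i.e. a 312x416 image with the
# bottom 104 rows left as border.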
16 | """ 17 | 18 | import inspect 19 | 20 | import inflection 21 | import numpy as np 22 | import cv2 23 | 24 | 25 | def rescale(image, yx_min, yx_max, height, width): 26 | _height, _width = image.shape[:2] 27 | scale = np.array([height / _height, width / _width], np.float32) 28 | image = cv2.resize(image, (width, height)) 29 | yx_min *= scale 30 | yx_max *= scale 31 | return image, yx_min, yx_max 32 | 33 | 34 | class Rescale(object): 35 | def __init__(self): 36 | self.fn = eval(inflection.underscore(type(self).__name__)) 37 | 38 | def __call__(self, data, height, width): 39 | data['image'], data['yx_min'], data['yx_max'] = self.fn(data['image'], data['yx_min'], data['yx_max'], height, width) 40 | return data 41 | 42 | 43 | def resize(config, image, yx_min, yx_max, height, width): 44 | fn = eval(config.get('data', inspect.stack()[0][3])) 45 | return fn(image, yx_min, yx_max, height, width) 46 | 47 | 48 | class Resize(object): 49 | def __init__(self, config): 50 | self.config = config 51 | self.fn = eval(config.get('data', inflection.underscore(type(self).__name__))) 52 | 53 | def __call__(self, data, height, width): 54 | data['image'], data['yx_min'], data['yx_max'] = self.fn(data['image'], data['yx_min'], data['yx_max'], height, width) 55 | return data 56 | 57 | 58 | def random_crop(config, image, yx_min, yx_max, height, width): 59 | name = inspect.stack()[0][3] 60 | scale = config.getfloat('augmentation', name) 61 | assert 0 < scale <= 1 62 | _yx_min = np.min(yx_min, 0) 63 | _yx_max = np.max(yx_max, 0) 64 | dtype = yx_min.dtype 65 | size = np.array(image.shape[:2], dtype) 66 | margin = scale * np.random.rand(4).astype(dtype) * np.concatenate([_yx_min, size - _yx_max], 0) 67 | _yx_min = margin[:2] 68 | _yx_max = size - margin[2:] 69 | _ymin, _xmin = _yx_min 70 | _ymax, _xmax = _yx_max 71 | _ymin, _xmin, _ymax, _xmax = tuple(map(int, (_ymin, _xmin, _ymax, _xmax))) 72 | image = image[_ymin:_ymax, _xmin:_xmax, :] 73 | yx_min, yx_max = yx_min - _yx_min, yx_max - _yx_min 74 | return resize(config, image, yx_min, yx_max, height, width) 75 | 76 | 77 | class RandomCrop(object): 78 | def __init__(self, config): 79 | self.config = config 80 | self.fn = eval(inflection.underscore(type(self).__name__)) 81 | 82 | def __call__(self, data, height, width): 83 | data['image'], data['yx_min'], data['yx_max'] = self.fn(self.config, data['image'], data['yx_min'], data['yx_max'], height, width) 84 | return data 85 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen) 3 | 4 | This program is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU Lesser General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 8 | 9 | This program is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with this program. If not, see . 
16 | """ 17 | 18 | import os 19 | import re 20 | import configparser 21 | import importlib 22 | import inspect 23 | 24 | import numpy as np 25 | import pandas as pd 26 | import torch.autograd 27 | from PIL import Image 28 | 29 | 30 | class Compose(object): 31 | def __init__(self, transforms): 32 | self.transforms = transforms 33 | 34 | def __call__(self, img, yx_min, yx_max, cls): 35 | for t in self.transforms: 36 | img, yx_min, yx_max, cls = t(img, yx_min, yx_max, cls) 37 | return img, yx_min, yx_max, cls 38 | 39 | 40 | class RegexList(list): 41 | def __init__(self, l): 42 | for s in l: 43 | prog = re.compile(s) 44 | self.append(prog) 45 | 46 | def __call__(self, s): 47 | for prog in self: 48 | if prog.match(s): 49 | return True 50 | return False 51 | 52 | 53 | def get_cache_dir(config): 54 | root = os.path.expanduser(os.path.expandvars(config.get('config', 'root'))) 55 | name = config.get('cache', 'name') 56 | return os.path.join(root, name) 57 | 58 | 59 | def get_model_dir(config): 60 | root = os.path.expanduser(os.path.expandvars(config.get('config', 'root'))) 61 | name = config.get('model', 'name') 62 | model = config.get('model', 'dnn') 63 | return os.path.join(root, name, model) 64 | 65 | 66 | def get_eval_db(config): 67 | root = os.path.expanduser(os.path.expandvars(config.get('config', 'root'))) 68 | db = config.get('eval', 'db') 69 | return os.path.join(root, db) 70 | 71 | 72 | def get_category(config, cache_dir=None): 73 | path = os.path.expanduser(os.path.expandvars(config.get('cache', 'category'))) if cache_dir is None else os.path.join(cache_dir, 'category') 74 | with open(path, 'r') as f: 75 | return [line.strip() for line in f] 76 | 77 | 78 | def get_anchors(config, dtype=np.float32): 79 | path = os.path.expanduser(os.path.expandvars(config.get('model', 'anchors'))) 80 | df = pd.read_csv(path, sep='\t', dtype=dtype) 81 | return df[['height', 'width']].values 82 | 83 | 84 | def parse_attr(s): 85 | m, n = s.rsplit('.', 1) 86 | m = importlib.import_module(m) 87 | return getattr(m, n) 88 | 89 | 90 | def load_config(config, paths): 91 | for path in paths: 92 | path = os.path.expanduser(os.path.expandvars(path)) 93 | assert os.path.exists(path) 94 | config.read(path) 95 | 96 | 97 | def modify_config(config, cmd): 98 | var, value = cmd.split('=', 1) 99 | section, option = var.split('/') 100 | if value: 101 | config.set(section, option, value) 102 | else: 103 | try: 104 | config.remove_option(section, option) 105 | except (configparser.NoSectionError, configparser.NoOptionError): 106 | pass 107 | 108 | 109 | def ensure_device(t, device_id=None, async=False): 110 | if torch.cuda.is_available(): 111 | t = t.cuda(device_id, async) 112 | return t 113 | 114 | 115 | def dense(var): 116 | return [torch.mean(torch.abs(x)) if torch.is_tensor(x) else np.abs(x) for x in var] 117 | 118 | 119 | def abs_mean(data, dtype=np.float32): 120 | assert isinstance(data, np.ndarray), type(data) 121 | return np.sum(np.abs(data)) / dtype(data.size) 122 | 123 | 124 | def image_size(path): 125 | with Image.open(path) as image: 126 | return image.size 127 | -------------------------------------------------------------------------------- /utils/cache.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen) 3 | 4 | This program is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU Lesser General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | 
(at your option) any later version. 8 | 9 | This program is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with this program. If not, see . 16 | """ 17 | 18 | import numpy as np 19 | 20 | 21 | def verify_coords(yx_min, yx_max, size): 22 | assert np.all(yx_min <= yx_max), 'yx_min <= yx_max' 23 | assert np.all(0 <= yx_min), '0 <= yx_min' 24 | assert np.all(0 <= yx_max), '0 <= yx_max' 25 | assert np.all(yx_min < size), 'yx_min < size' 26 | assert np.all(yx_max < size), 'yx_max < size' 27 | 28 | 29 | def fix_coords(yx_min, yx_max, size): 30 | assert np.all(yx_min <= yx_max) 31 | assert yx_min.dtype == yx_max.dtype 32 | coord_min = np.zeros([2], dtype=yx_min.dtype) 33 | coord_max = np.array(size, dtype=yx_min.dtype) - 1 34 | yx_min = np.minimum(np.maximum(yx_min, coord_min), coord_max) 35 | yx_max = np.minimum(np.maximum(yx_max, coord_min), coord_max) 36 | return yx_min, yx_max 37 | -------------------------------------------------------------------------------- /utils/data.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen) 3 | 4 | This program is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU Lesser General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 8 | 9 | This program is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with this program. If not, see . 16 | """ 17 | 18 | import os 19 | import pickle 20 | import random 21 | import copy 22 | 23 | import numpy as np 24 | import torch.utils.data 25 | import sklearn.preprocessing 26 | import cv2 27 | 28 | 29 | def padding_labels(data, dim, labels='yx_min, yx_max, cls, difficult'.split(', ')): 30 | """ 31 | Pad the labels to the same dimension (to form a batch). 32 | :author 申瑞珉 (Ruimin Shen) 33 | :param data: A dict containing the labels to be padded. 34 | :param dim: The target dimension. 35 | :param labels: The list of label names. 36 | :return: The padded label dict. 37 | """ 38 | pad = dim - len(data[labels[0]]) 39 | for key in labels: 40 | label = data[key] 41 | data[key] = np.pad(label, [(0, pad)] + [(0, 0)] * (len(label.shape) - 1), 'constant') 42 | return data 43 | 44 | 45 | def load_pickles(paths): 46 | data = [] 47 | for path in paths: 48 | with open(path, 'rb') as f: 49 | data += pickle.load(f) 50 | return data 51 | 52 | 53 | class Dataset(torch.utils.data.Dataset): 54 | def __init__(self, data, transform=lambda data: data, one_hot=None, shuffle=False, dir=None): 55 | """ 56 | Load the cached data (.pkl) into memory. 57 | :author 申瑞珉 (Ruimin Shen) 58 | :param data: A list containing the data samples (dicts). 59 | :param transform: A function that transforms the labels in a dict (usually a sequence of data augmentation operations). 60 | :param one_hot: If an int value (total number of classes) is given, the class label (key "cls") will be generated in a one-hot format.
61 | :param shuffle: Shuffle the loaded dataset. 62 | :param dir: The directory in which to dump a sample that raises an exception. 63 | """ 64 | self.data = data 65 | if shuffle: 66 | random.shuffle(self.data) 67 | self.transform = transform 68 | self.one_hot = None if one_hot is None else sklearn.preprocessing.OneHotEncoder(one_hot, dtype=np.float32)  # the first positional argument is n_values in older scikit-learn releases 69 | self.dir = dir 70 | 71 | def __len__(self): 72 | return len(self.data) 73 | 74 | def __getitem__(self, index): 75 | data = copy.deepcopy(self.data[index]) 76 | try: 77 | image = cv2.imread(data['path']) 78 | data['image'] = image 79 | data['size'] = np.array(image.shape[:2]) 80 | data = self.transform(data) 81 | if self.one_hot is not None: 82 | data['cls'] = self.one_hot.fit_transform(np.expand_dims(data['cls'], -1)).todense() 83 | except Exception: 84 | if self.dir is not None: 85 | os.makedirs(self.dir, exist_ok=True) 86 | name = self.__module__ + '.' + type(self).__name__ 87 | with open(os.path.join(self.dir, name + '.pkl'), 'wb') as f: 88 | pickle.dump(data, f)  # dump the offending sample for later inspection 89 | raise 90 | return data 91 | 92 | 93 | class Collate(object): 94 | def __init__(self, resize, sizes, maintain=1, transform_image=lambda image: image, transform_tensor=None, dir=None): 95 | """ 96 | Unify multiple data samples (e.g., resize images to the same size, and pad bounding box labels to the same number) to form a batch. 97 | :author 申瑞珉 (Ruimin Shen) 98 | :param resize: A function to resize the image and labels. 99 | :param sizes: The image sizes to be randomly chosen. 100 | :param maintain: For how many consecutive batches a chosen size is maintained. 101 | :param transform_image: A function to transform the resized image. 102 | :param transform_tensor: A function to standardize an image into a tensor. 103 | :param dir: The directory in which to dump a sample that raises an exception. 104 | """ 105 | self.resize = resize 106 | self.sizes = sizes 107 | assert maintain > 0 108 | self.maintain = maintain 109 | self._maintain = maintain 110 | self.transform_image = transform_image 111 | self.transform_tensor = transform_tensor 112 | self.dir = dir 113 | 114 | def __call__(self, batch): 115 | height, width = self.next_size() 116 | dim = max(len(data['cls']) for data in batch) 117 | _batch = [] 118 | for data in batch: 119 | try: 120 | data = self.resize(data, height, width) 121 | data['image'] = self.transform_image(data['image']) 122 | data = padding_labels(data, dim) 123 | if self.transform_tensor is not None: 124 | data['tensor'] = self.transform_tensor(data['image']) 125 | _batch.append(data) 126 | except Exception: 127 | if self.dir is not None: 128 | os.makedirs(self.dir, exist_ok=True) 129 | name = self.__module__ + '.'
+ type(self).__name__ 130 | with open(os.path.join(self.dir, name + '.pkl'), 'wb') as f: 131 | pickle.dump(data, f) 132 | raise 133 | return torch.utils.data.dataloader.default_collate(_batch) 134 | 135 | def next_size(self): 136 | if self._maintain < self.maintain: 137 | self._maintain += 1 138 | else: 139 | self.size = random.choice(self.sizes) 140 | self._maintain = 0 141 | return self.size 142 | -------------------------------------------------------------------------------- /utils/iou/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruiminshen/yolo2-pytorch/146ebdf581677964caa31c69cccd0c86230fb216/utils/iou/__init__.py -------------------------------------------------------------------------------- /utils/iou/numpy.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen) 3 | 4 | This program is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU Lesser General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 8 | 9 | This program is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with this program. If not, see <http://www.gnu.org/licenses/>. 16 | """ 17 | 18 | import unittest 19 | 20 | import numpy as np 21 | 22 | 23 | def iou(yx_min1, yx_max1, yx_min2, yx_max2, min=None): 24 | """ 25 | Calculates the IoU of two bounding boxes. 26 | :author 申瑞珉 (Ruimin Shen) 27 | :param yx_min1: The top left coordinates (y, x) of the first bounding box. 28 | :param yx_max1: The bottom right coordinates (y, x) of the first bounding box. 29 | :param yx_min2: The top left coordinates (y, x) of the second bounding box. 30 | :param yx_max2: The bottom right coordinates (y, x) of the second bounding box. 31 | :return: The IoU. 32 | """ 33 | assert np.all(yx_min1 <= yx_max1) 34 | assert np.all(yx_min2 <= yx_max2) 35 | if min is None: 36 | min = np.finfo(yx_min1.dtype).eps 37 | yx_min = np.maximum(yx_min1, yx_min2) 38 | yx_max = np.minimum(yx_max1, yx_max2) 39 | intersect_area = np.multiply.reduce(np.maximum(0.0, yx_max - yx_min)) 40 | area1 = np.multiply.reduce(yx_max1 - yx_min1) 41 | area2 = np.multiply.reduce(yx_max2 - yx_min2) 42 | assert np.all(intersect_area >= 0) 43 | assert np.all(intersect_area <= area1) 44 | assert np.all(intersect_area <= area2) 45 | union_area = np.maximum(area1 + area2 - intersect_area, min)  # the lower bound keeps the division well defined for degenerate boxes 46 | return intersect_area / union_area 47 | 48 | 49 | def intersection_area(yx_min1, yx_max1, yx_min2, yx_max2): 50 | """ 51 | Calculates the intersection area of two lists of bounding boxes. 52 | :author 申瑞珉 (Ruimin Shen) 53 | :param yx_min1: The top left coordinates (y, x) of the first list (size [N1, 2]) of bounding boxes. 54 | :param yx_max1: The bottom right coordinates (y, x) of the first list (size [N1, 2]) of bounding boxes. 55 | :param yx_min2: The top left coordinates (y, x) of the second list (size [N2, 2]) of bounding boxes. 56 | :param yx_max2: The bottom right coordinates (y, x) of the second list (size [N2, 2]) of bounding boxes. 57 | :return: The matrix (size [N1, N2]) of the intersection area.
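    Example (an illustrative sketch; the array names are hypothetical):
        >>> a_min = np.array([[0., 0.]], dtype=np.float32)
        >>> a_max = np.array([[2., 2.]], dtype=np.float32)
        >>> b_min = np.array([[1., 1.]], dtype=np.float32)
        >>> b_max = np.array([[3., 3.]], dtype=np.float32)
        >>> intersection_area(a_min, a_max, b_min, b_max)  # the boxes share a 1x1 corner
        array([[1.]], dtype=float32)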
58 | """ 59 | ymin1, xmin1 = yx_min1.T 60 | ymax1, xmax1 = yx_max1.T 61 | ymin2, xmin2 = yx_min2.T 62 | ymax2, xmax2 = yx_max2.T 63 | ymin1, xmin1, ymax1, xmax1, ymin2, xmin2, ymax2, xmax2 = (np.expand_dims(a, -1) for a in (ymin1, xmin1, ymax1, xmax1, ymin2, xmin2, ymax2, xmax2)) 64 | max_ymin = np.maximum(ymin1, np.transpose(ymin2)) 65 | min_ymax = np.minimum(ymax1, np.transpose(ymax2)) 66 | height = np.maximum(0.0, min_ymax - max_ymin) 67 | max_xmin = np.maximum(xmin1, np.transpose(xmin2)) 68 | min_xmax = np.minimum(xmax1, np.transpose(xmax2)) 69 | width = np.maximum(0.0, min_xmax - max_xmin) 70 | return height * width 71 | 72 | 73 | def iou_matrix(yx_min1, yx_max1, yx_min2, yx_max2, min=None): 74 | """ 75 | Calculates the IoU of two lists of bounding boxes. 76 | :author 申瑞珉 (Ruimin Shen) 77 | :param yx_min1: The top left coordinates (y, x) of the first list (size [N1, 2]) of bounding boxes. 78 | :param yx_max1: The bottom right coordinates (y, x) of the first list (size [N1, 2]) of bounding boxes. 79 | :param yx_min2: The top left coordinates (y, x) of the second list (size [N2, 2]) of bounding boxes. 80 | :param yx_max2: The bottom right coordinates (y, x) of the second list (size [N2, 2]) of bounding boxes. 81 | :return: The matrix (size [N1, N2]) of the IoU. 82 | """ 83 | if min is None: 84 | min = np.finfo(yx_min1.dtype).eps 85 | assert np.all(yx_min1 <= yx_max1) 86 | assert np.all(yx_min2 <= yx_max2) 87 | intersect_area = intersection_area(yx_min1, yx_max1, yx_min2, yx_max2) 88 | area1 = np.expand_dims(np.multiply.reduce(yx_max1 - yx_min1, -1), 1) 89 | area2 = np.expand_dims(np.multiply.reduce(yx_max2 - yx_min2, -1), 0) 90 | assert np.all(intersect_area >= 0) 91 | assert np.all(intersect_area <= area1) 92 | assert np.all(intersect_area <= area2) 93 | union_area = np.maximum(area1 + area2 - intersect_area, min) 94 | return intersect_area / union_area 95 | 96 | 97 | class TestIouMatrix(unittest.TestCase): 98 | def _test(self, bbox1, bbox2, ans, dtype=np.float32): 99 | bbox1, bbox2, ans = (np.array(a, dtype) for a in (bbox1, bbox2, ans)) 100 | yx_min1, yx_max1 = np.split(bbox1, 2, -1) 101 | yx_min2, yx_max2 = np.split(bbox2, 2, -1) 102 | assert np.all(yx_min1 <= yx_max1) 103 | assert np.all(yx_min2 <= yx_max2) 104 | assert np.all(ans >= 0) 105 | matrix = iou_matrix(yx_min1, yx_max1, yx_min2, yx_max2) 106 | np.testing.assert_almost_equal(matrix, ans) 107 | 108 | def test0(self): 109 | bbox1 = [ 110 | (1, 1, 2, 2), 111 | ] 112 | bbox2 = [ 113 | (0, 0, 1, 1), 114 | (0, 1, 1, 2), 115 | (0, 2, 1, 3), 116 | (1, 0, 2, 1), 117 | (2, 0, 3, 1), 118 | (1, 2, 2, 3), 119 | (2, 1, 3, 2), 120 | (2, 2, 3, 3), 121 | ] 122 | ans = [ 123 | [0] * len(bbox2), 124 | ] 125 | self._test(bbox1, bbox2, ans) 126 | 127 | def test1(self): 128 | bbox1 = [ 129 | (1, 1, 3, 3), 130 | (0, 0, 4, 4), 131 | ] 132 | bbox2 = [ 133 | (0, 0, 2, 2), 134 | (2, 0, 4, 2), 135 | (0, 2, 2, 4), 136 | (2, 2, 4, 4), 137 | ] 138 | ans = [ 139 | [1 / (4 + 4 - 1)] * len(bbox2), 140 | [4 / 16] * len(bbox2), 141 | ] 142 | self._test(bbox1, bbox2, ans) 143 | -------------------------------------------------------------------------------- /utils/iou/torch.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen) 3 | 4 | This program is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU Lesser General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later 
version. 8 | 9 | This program is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with this program. If not, see . 16 | """ 17 | 18 | import unittest 19 | 20 | import numpy as np 21 | import torch 22 | 23 | 24 | def intersection_area(yx_min1, yx_max1, yx_min2, yx_max2): 25 | """ 26 | Calculates the intersection area of two lists of bounding boxes. 27 | :author 申瑞珉 (Ruimin Shen) 28 | :param yx_min1: The top left coordinates (y, x) of the first list (size [N1, 2]) of bounding boxes. 29 | :param yx_max1: The bottom right coordinates (y, x) of the first list (size [N1, 2]) of bounding boxes. 30 | :param yx_min2: The top left coordinates (y, x) of the second list (size [N2, 2]) of bounding boxes. 31 | :param yx_max2: The bottom right coordinates (y, x) of the second list (size [N2, 2]) of bounding boxes. 32 | :return: The matrix (size [N1, N2]) of the intersection area. 33 | """ 34 | ymin1, xmin1 = torch.split(yx_min1, 1, -1) 35 | ymax1, xmax1 = torch.split(yx_max1, 1, -1) 36 | ymin2, xmin2 = torch.split(yx_min2, 1, -1) 37 | ymax2, xmax2 = torch.split(yx_max2, 1, -1) 38 | max_ymin = torch.max(ymin1.repeat(1, ymin2.size(0)), torch.transpose(ymin2, 0, 1).repeat(ymin1.size(0), 1)) # PyTorch's bug 39 | min_ymax = torch.min(ymax1.repeat(1, ymax2.size(0)), torch.transpose(ymax2, 0, 1).repeat(ymax1.size(0), 1)) # PyTorch's bug 40 | height = torch.clamp(min_ymax - max_ymin, min=0) 41 | max_xmin = torch.max(xmin1.repeat(1, xmin2.size(0)), torch.transpose(xmin2, 0, 1).repeat(xmin1.size(0), 1)) # PyTorch's bug 42 | min_xmax = torch.min(xmax1.repeat(1, xmax2.size(0)), torch.transpose(xmax2, 0, 1).repeat(xmax1.size(0), 1)) # PyTorch's bug 43 | width = torch.clamp(min_xmax - max_xmin, min=0) 44 | return height * width 45 | 46 | 47 | def iou_matrix(yx_min1, yx_max1, yx_min2, yx_max2, min=float(np.finfo(np.float32).eps)): 48 | """ 49 | Calculates the IoU of two lists of bounding boxes. 50 | :author 申瑞珉 (Ruimin Shen) 51 | :param yx_min1: The top left coordinates (y, x) of the first list (size [N1, 2]) of bounding boxes. 52 | :param yx_max1: The bottom right coordinates (y, x) of the first list (size [N1, 2]) of bounding boxes. 53 | :param yx_min2: The top left coordinates (y, x) of the second list (size [N2, 2]) of bounding boxes. 54 | :param yx_max2: The bottom right coordinates (y, x) of the second list (size [N2, 2]) of bounding boxes. 55 | :return: The matrix (size [N1, N2]) of the IoU. 
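    Example (an illustrative sketch with hypothetical tensors; plain tensors are assumed to behave like the Variables used in the tests below):
        >>> a_min, a_max = torch.Tensor([[1., 1.]]), torch.Tensor([[3., 3.]])
        >>> b_min, b_max = torch.Tensor([[0., 0.]]), torch.Tensor([[2., 2.]])
        >>> m = iou_matrix(a_min, a_max, b_min, b_max)  # intersection 1, union 4 + 4 - 1, so the IoU is about 0.143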
56 | """ 57 | intersect_area = intersection_area(yx_min1, yx_max1, yx_min2, yx_max2) 58 | area1 = torch.prod(yx_max1 - yx_min1, -1).unsqueeze(-1) 59 | area2 = torch.prod(yx_max2 - yx_min2, -1).unsqueeze(-2) 60 | union_area = torch.clamp(area1 + area2 - intersect_area, min=min) 61 | return intersect_area / union_area 62 | 63 | 64 | class TestIouMatrix(unittest.TestCase): 65 | def _test(self, bbox1, bbox2, ans, dtype=np.float32): 66 | bbox1, bbox2, ans = (np.array(a, dtype) for a in (bbox1, bbox2, ans)) 67 | yx_min1, yx_max1 = np.split(bbox1, 2, -1) 68 | yx_min2, yx_max2 = np.split(bbox2, 2, -1) 69 | assert np.all(yx_min1 <= yx_max1) 70 | assert np.all(yx_min2 <= yx_max2) 71 | assert np.all(ans >= 0) 72 | yx_min1, yx_max1 = torch.autograd.Variable(torch.from_numpy(yx_min1)), torch.autograd.Variable(torch.from_numpy(yx_max1)) 73 | yx_min2, yx_max2 = torch.autograd.Variable(torch.from_numpy(yx_min2)), torch.autograd.Variable(torch.from_numpy(yx_max2)) 74 | if torch.cuda.is_available(): 75 | yx_min1, yx_max1, yx_min2, yx_max2 = (v.cuda() for v in (yx_min1, yx_max1, yx_min2, yx_max2)) 76 | matrix = iou_matrix(yx_min1, yx_max1, yx_min2, yx_max2).data.cpu().numpy() 77 | np.testing.assert_almost_equal(matrix, ans) 78 | 79 | def test0(self): 80 | bbox1 = [ 81 | (1, 1, 2, 2), 82 | ] 83 | bbox2 = [ 84 | (0, 0, 1, 1), 85 | (0, 1, 1, 2), 86 | (0, 2, 1, 3), 87 | (1, 0, 2, 1), 88 | (2, 0, 3, 1), 89 | (1, 2, 2, 3), 90 | (2, 1, 3, 2), 91 | (2, 2, 3, 3), 92 | ] 93 | ans = [ 94 | [0] * len(bbox2), 95 | ] 96 | self._test(bbox1, bbox2, ans) 97 | 98 | def test1(self): 99 | bbox1 = [ 100 | (1, 1, 3, 3), 101 | (0, 0, 4, 4), 102 | ] 103 | bbox2 = [ 104 | (0, 0, 2, 2), 105 | (2, 0, 4, 2), 106 | (0, 2, 2, 4), 107 | (2, 2, 4, 4), 108 | ] 109 | ans = [ 110 | [1 / (4 + 4 - 1)] * len(bbox2), 111 | [4 / 16] * len(bbox2), 112 | ] 113 | self._test(bbox1, bbox2, ans) 114 | 115 | 116 | def batch_intersection_area(yx_min1, yx_max1, yx_min2, yx_max2): 117 | """ 118 | Calculates the intersection area of two lists of bounding boxes for N independent batches. 119 | :author 申瑞珉 (Ruimin Shen) 120 | :param yx_min1: The top left coordinates (y, x) of the first lists (size [N, N1, 2]) of bounding boxes. 121 | :param yx_max1: The bottom right coordinates (y, x) of the first lists (size [N, N1, 2]) of bounding boxes. 122 | :param yx_min2: The top left coordinates (y, x) of the second lists (size [N, N2, 2]) of bounding boxes. 123 | :param yx_max2: The bottom right coordinates (y, x) of the second lists (size [N, N2, 2]) of bounding boxes. 124 | :return: The matrics (size [N, N1, N2]) of the intersection area. 
125 | """ 126 | ymin1, xmin1 = torch.split(yx_min1, 1, -1) 127 | ymax1, xmax1 = torch.split(yx_max1, 1, -1) 128 | ymin2, xmin2 = torch.split(yx_min2, 1, -1) 129 | ymax2, xmax2 = torch.split(yx_max2, 1, -1) 130 | max_ymin = torch.max(ymin1.repeat(1, 1, ymin2.size(1)), torch.transpose(ymin2, 1, 2).repeat(1, ymin1.size(1), 1)) # PyTorch's bug 131 | min_ymax = torch.min(ymax1.repeat(1, 1, ymax2.size(1)), torch.transpose(ymax2, 1, 2).repeat(1, ymax1.size(1), 1)) # PyTorch's bug 132 | height = torch.clamp(min_ymax - max_ymin, min=0) 133 | max_xmin = torch.max(xmin1.repeat(1, 1, xmin2.size(1)), torch.transpose(xmin2, 1, 2).repeat(1, xmin1.size(1), 1)) # PyTorch's bug 134 | min_xmax = torch.min(xmax1.repeat(1, 1, xmax2.size(1)), torch.transpose(xmax2, 1, 2).repeat(1, xmax1.size(1), 1)) # PyTorch's bug 135 | width = torch.clamp(min_xmax - max_xmin, min=0) 136 | return height * width 137 | 138 | 139 | def batch_iou_matrix(yx_min1, yx_max1, yx_min2, yx_max2, min=float(np.finfo(np.float32).eps)): 140 | """ 141 | Calculates the IoU of two lists of bounding boxes for N independent batches. 142 | :author 申瑞珉 (Ruimin Shen) 143 | :param yx_min1: The top left coordinates (y, x) of the first lists (size [N, N1, 2]) of bounding boxes. 144 | :param yx_max1: The bottom right coordinates (y, x) of the first lists (size [N, N1, 2]) of bounding boxes. 145 | :param yx_min2: The top left coordinates (y, x) of the second lists (size [N, N2, 2]) of bounding boxes. 146 | :param yx_max2: The bottom right coordinates (y, x) of the second lists (size [N, N2, 2]) of bounding boxes. 147 | :return: The matrics (size [N, N1, N2]) of the IoU. 148 | """ 149 | intersect_area = batch_intersection_area(yx_min1, yx_max1, yx_min2, yx_max2) 150 | area1 = torch.prod(yx_max1 - yx_min1, -1).unsqueeze(-1) 151 | area2 = torch.prod(yx_max2 - yx_min2, -1).unsqueeze(-2) 152 | union_area = torch.clamp(area1 + area2 - intersect_area, min=min) 153 | return intersect_area / union_area 154 | 155 | 156 | class TestBatchIouMatrix(unittest.TestCase): 157 | def _test(self, bbox1, bbox2, ans, batch_size=2, dtype=np.float32): 158 | bbox1, bbox2, ans = (np.expand_dims(np.array(a, dtype), 0) for a in (bbox1, bbox2, ans)) 159 | if batch_size > 1: 160 | bbox1, bbox2, ans = (np.tile(a, (batch_size, 1, 1)) for a in (bbox1, bbox2, ans)) 161 | for b in range(batch_size): 162 | indices1 = np.random.permutation(bbox1.shape[1]) 163 | indices2 = np.random.permutation(bbox2.shape[1]) 164 | bbox1[b] = bbox1[b][indices1] 165 | bbox2[b] = bbox2[b][indices2] 166 | ans[b] = ans[b][indices1][:, indices2] 167 | yx_min1, yx_max1 = np.split(bbox1, 2, -1) 168 | yx_min2, yx_max2 = np.split(bbox2, 2, -1) 169 | assert np.all(yx_min1 <= yx_max1) 170 | assert np.all(yx_min2 <= yx_max2) 171 | assert np.all(ans >= 0) 172 | yx_min1, yx_max1 = torch.autograd.Variable(torch.from_numpy(yx_min1)), torch.autograd.Variable(torch.from_numpy(yx_max1)) 173 | yx_min2, yx_max2 = torch.autograd.Variable(torch.from_numpy(yx_min2)), torch.autograd.Variable(torch.from_numpy(yx_max2)) 174 | if torch.cuda.is_available(): 175 | yx_min1, yx_max1, yx_min2, yx_max2 = (v.cuda() for v in (yx_min1, yx_max1, yx_min2, yx_max2)) 176 | matrix = batch_iou_matrix(yx_min1, yx_max1, yx_min2, yx_max2).data.cpu().numpy() 177 | np.testing.assert_almost_equal(matrix, ans) 178 | 179 | def test0(self): 180 | bbox1 = [ 181 | (1, 1, 2, 2), 182 | ] 183 | bbox2 = [ 184 | (0, 0, 1, 1), 185 | (0, 1, 1, 2), 186 | (0, 2, 1, 3), 187 | (1, 0, 2, 1), 188 | (2, 0, 3, 1), 189 | (1, 2, 2, 3), 190 | (2, 1, 3, 2), 191 | (2, 2, 3, 3), 
192 | ] 193 | ans = [ 194 | [0] * len(bbox2), 195 | ] 196 | self._test(bbox1, bbox2, ans) 197 | 198 | def test1(self): 199 | bbox1 = [ 200 | (1, 1, 3, 3), 201 | (0, 0, 4, 4), 202 | ] 203 | bbox2 = [ 204 | (0, 0, 2, 2), 205 | (2, 0, 4, 2), 206 | (0, 2, 2, 4), 207 | (2, 2, 4, 4), 208 | ] 209 | ans = [ 210 | [1 / (4 + 4 - 1)] * len(bbox2), 211 | [4 / 16] * len(bbox2), 212 | ] 213 | self._test(bbox1, bbox2, ans) 214 | 215 | 216 | def batch_iou_pair(yx_min1, yx_max1, yx_min2, yx_max2, min=float(np.finfo(np.float32).eps)): 217 | """ 218 | Pairwisely calculates the IoU of two lists (at the same size M) of bounding boxes for N independent batches. 219 | :author 申瑞珉 (Ruimin Shen) 220 | :param yx_min1: The top left coordinates (y, x) of the first lists (size [N, M, 2]) of bounding boxes. 221 | :param yx_max1: The bottom right coordinates (y, x) of the first lists (size [N, M, 2]) of bounding boxes. 222 | :param yx_min2: The top left coordinates (y, x) of the second lists (size [N, M, 2]) of bounding boxes. 223 | :param yx_max2: The bottom right coordinates (y, x) of the second lists (size [N, M, 2]) of bounding boxes. 224 | :return: The lists (size [N, M]) of the IoU. 225 | """ 226 | yx_min = torch.max(yx_min1, yx_min2) 227 | yx_max = torch.min(yx_max1, yx_max2) 228 | size = torch.clamp(yx_max - yx_min, min=0) 229 | intersect_area = torch.prod(size, -1) 230 | area1 = torch.prod(yx_max1 - yx_min1, -1) 231 | area2 = torch.prod(yx_max2 - yx_min2, -1) 232 | union_area = torch.clamp(area1 + area2 - intersect_area, min=min) 233 | return intersect_area / union_area 234 | 235 | 236 | class TestBatchIouPair(unittest.TestCase): 237 | def _test(self, bbox1, bbox2, ans, dtype=np.float32): 238 | bbox1, bbox2, ans = (np.array(a, dtype) for a in (bbox1, bbox2, ans)) 239 | batch_size = bbox1.shape[0] 240 | cells = bbox2.shape[0] 241 | bbox1 = np.tile(np.reshape(bbox1, [-1, 1, 4]), [1, cells, 1]) 242 | bbox2 = np.tile(np.reshape(bbox2, [1, -1, 4]), [batch_size, 1, 1]) 243 | yx_min1, yx_max1 = np.split(bbox1, 2, -1) 244 | yx_min2, yx_max2 = np.split(bbox2, 2, -1) 245 | assert np.all(yx_min1 <= yx_max1) 246 | assert np.all(yx_min2 <= yx_max2) 247 | assert np.all(ans >= 0) 248 | yx_min1, yx_max1 = torch.autograd.Variable(torch.from_numpy(yx_min1)), torch.autograd.Variable(torch.from_numpy(yx_max1)) 249 | yx_min2, yx_max2 = torch.autograd.Variable(torch.from_numpy(yx_min2)), torch.autograd.Variable(torch.from_numpy(yx_max2)) 250 | if torch.cuda.is_available(): 251 | yx_min1, yx_max1, yx_min2, yx_max2 = (v.cuda() for v in (yx_min1, yx_max1, yx_min2, yx_max2)) 252 | iou = batch_iou_pair(yx_min1, yx_max1, yx_min2, yx_max2).data.cpu().numpy() 253 | np.testing.assert_almost_equal(iou, ans) 254 | 255 | def test0(self): 256 | bbox1 = [ 257 | (1, 1, 2, 2), 258 | ] 259 | bbox2 = [ 260 | (0, 0, 1, 1), 261 | (0, 1, 1, 2), 262 | (0, 2, 1, 3), 263 | (1, 0, 2, 1), 264 | (2, 0, 3, 1), 265 | (1, 2, 2, 3), 266 | (2, 1, 3, 2), 267 | (2, 2, 3, 3), 268 | ] 269 | ans = [ 270 | [0] * len(bbox2), 271 | ] 272 | self._test(bbox1, bbox2, ans) 273 | 274 | def test1(self): 275 | bbox1 = [ 276 | (1, 1, 3, 3), 277 | (0, 0, 4, 4), 278 | ] 279 | bbox2 = [ 280 | (0, 0, 2, 2), 281 | (2, 0, 4, 2), 282 | (0, 2, 2, 4), 283 | (2, 2, 4, 4), 284 | ] 285 | ans = [ 286 | [1 / (4 + 4 - 1)] * len(bbox2), 287 | [4 / 16] * len(bbox2), 288 | ] 289 | self._test(bbox1, bbox2, ans) 290 | 291 | 292 | if __name__ == '__main__': 293 | unittest.main() 294 | -------------------------------------------------------------------------------- /utils/postprocess.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen) 3 | 4 | This program is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU Lesser General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 8 | 9 | This program is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with this program. If not, see <http://www.gnu.org/licenses/>. 16 | """ 17 | 18 | import torch 19 | 20 | import utils.iou.torch 21 | 22 | 23 | def nms(score, yx_min, yx_max, overlap=0.5, limit=200): 24 | """ 25 | Filters the overlapping (IoU > overlap threshold) bounding boxes according to their scores, keeping the higher-scoring boxes first (non-maximum suppression). 26 | :author 申瑞珉 (Ruimin Shen) 27 | :param score: The scores of the list (size [N]) of bounding boxes. 28 | :param yx_min: The top left coordinates (y, x) of the list (size [N, 2]) of bounding boxes. 29 | :param yx_max: The bottom right coordinates (y, x) of the list (size [N, 2]) of bounding boxes. 30 | :param overlap: The IoU threshold. 31 | :param limit: Limits the number of results. 32 | :return: The indices of the selected bounding boxes. 33 | """ 34 | keep = [] 35 | if score.numel() == 0: 36 | return keep 37 | _, index = score.sort(descending=True) 38 | index = index[:limit] 39 | while index.numel() > 0: 40 | i = index[0] 41 | keep.append(i) 42 | if index.size(0) == 1: 43 | break 44 | index = index[1:] 45 | yx_min1, yx_max1 = (torch.unsqueeze(t[i], 0) for t in (yx_min, yx_max)) 46 | yx_min2, yx_max2 = (torch.index_select(t, 0, index) for t in (yx_min, yx_max)) 47 | iou = utils.iou.torch.iou_matrix(yx_min1, yx_max1, yx_min2, yx_max2)[0] 48 | index = index[iou <= overlap]  # drop the candidates that overlap the kept box too much 49 | return keep 50 | -------------------------------------------------------------------------------- /utils/train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen) 3 | 4 | This program is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU Lesser General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 8 | 9 | This program is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with this program. If not, see <http://www.gnu.org/licenses/>. 16 | """ 17 | 18 | import os 19 | import time 20 | import operator 21 | import logging 22 | 23 | import torch 24 | 25 | 26 | class Timer(object): 27 | def __init__(self, max, first=True): 28 | """ 29 | A simple function object to signal periodic time events. 30 | :author 申瑞珉 (Ruimin Shen) 31 | :param max: The number of seconds between time events. 32 | :param first: Whether a time event should be triggered on the first call. 33 | """ 34 | self.start = 0 if first else time.time() 35 | self.max = max 36 | 37 | def __call__(self): 38 | """ 39 | Return a boolean value indicating whether a time event has occurred.
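        A usage sketch (hypothetical; with first=True the first call returns True immediately):
            timer = Timer(600)
            if timer():  # True at most once per 600 seconds
                save_checkpoint()  # a hypothetical callback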
40 | :author 申瑞珉 (Ruimin Shen) 41 | """ 42 | t = time.time() 43 | elapsed = t - self.start 44 | if elapsed > self.max: 45 | self.start = t 46 | return True 47 | else: 48 | return False 49 | 50 | 51 | def load_model(model_dir, step=None, ext='.pth', ext_epoch='.epoch', logger=logging.info): 52 | """ 53 | Load the latest checkpoint in a model directory. 54 | :author 申瑞珉 (Ruimin Shen) 55 | :param model_dir: The directory storing the model checkpoint files. 56 | :param step: If an integer value is given, the corresponding checkpoint will be loaded. Otherwise, the latest checkpoint (with the largest step value) will be loaded. 57 | :param ext: The extension of the model file. 58 | :param ext_epoch: The extension of the epoch file. 59 | :return: The checkpoint path, the step, and the epoch (None if no epoch file is found). 60 | """ 61 | if step is None: 62 | steps = [(int(n), n) for n, e in map(os.path.splitext, os.listdir(model_dir)) if n.isdigit() and e == ext] 63 | step, name = max(steps, key=operator.itemgetter(0)) 64 | else: 65 | name = str(step) 66 | prefix = os.path.join(model_dir, name) 67 | if logger is not None: 68 | logger('load %s.*' % prefix) 69 | try: 70 | with open(prefix + ext_epoch, 'r') as f: 71 | epoch = int(f.read()) 72 | except (FileNotFoundError, ValueError): 73 | epoch = None 74 | path = prefix + ext 75 | assert os.path.exists(path), path 76 | return path, step, epoch 77 | 78 | 79 | class Saver(object): 80 | def __init__(self, model_dir, keep, ext='.pth', ext_epoch='.epoch', logger=logging.info): 81 | """ 82 | Manage the several latest checkpoints (those with the largest step values) in a model directory. 83 | :author 申瑞珉 (Ruimin Shen) 84 | :param model_dir: The directory to store the model checkpoint files. 85 | :param keep: How many of the latest checkpoints to maintain. 86 | :param ext: The extension of the model file. 87 | :param ext_epoch: The extension of the epoch file. 88 | """ 89 | self.model_dir = model_dir 90 | self.keep = keep 91 | self.ext = ext 92 | self.ext_epoch = ext_epoch 93 | self.logger = (lambda s: s) if logger is None else logger 94 | 95 | def __call__(self, obj, step, epoch=None): 96 | """ 97 | Save the PyTorch module. 98 | :author 申瑞珉 (Ruimin Shen) 99 | :param obj: The PyTorch module to be saved. 100 | :param step: Current step. 101 | :param epoch: Current epoch.
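        A usage sketch (the names are hypothetical):
            saver = Saver(model_dir, keep=5)
            prefix = saver(dnn, step, epoch)  # writes <step>.pth (and <step>.epoch), then prunes older checkpoints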
102 | """ 103 | os.makedirs(self.model_dir, exist_ok=True) 104 | prefix = os.path.join(self.model_dir, str(step)) 105 | torch.save(obj, prefix + self.ext) 106 | if epoch is not None: 107 | with open(prefix + self.ext_epoch, 'w') as f: 108 | f.write(str(epoch)) 109 | self.logger('model saved into %s.*' % prefix) 110 | self.tidy() 111 | return prefix 112 | 113 | def tidy(self): 114 | steps = [(int(n), n) for n, e in map(os.path.splitext, os.listdir(self.model_dir)) if n.isdigit() and e == self.ext] 115 | if len(steps) > self.keep: 116 | steps = sorted(steps, key=operator.itemgetter(0)) 117 | remove = steps[:len(steps) - self.keep] 118 | for _, n in remove: 119 | path = os.path.join(self.model_dir, n) 120 | os.remove(path + self.ext) 121 | path_epoch = path + self.ext_epoch 122 | try: 123 | os.remove(path_epoch) 124 | except FileNotFoundError: 125 | self.logger(path_epoch + ' not found') 126 | logging.debug('tidy ' + path) 127 | 128 | 129 | def load_sizes(config): 130 | sizes = [s.split(',') for s in config.get('data', 'sizes').split()] 131 | return [(int(height), int(width)) for height, width in sizes] 132 | -------------------------------------------------------------------------------- /utils/visualize.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen) 3 | 4 | This program is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU Lesser General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 8 | 9 | This program is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with this program. If not, see . 
16 | """ 17 | 18 | import logging 19 | import itertools 20 | import inspect 21 | 22 | import numpy as np 23 | import torch 24 | import matplotlib 25 | import matplotlib.cm 26 | import matplotlib.colors 27 | import matplotlib.pyplot as plt 28 | import humanize 29 | import graphviz 30 | import cv2 31 | 32 | import utils 33 | 34 | 35 | class DrawBBox(object): 36 | def __init__(self, category, colors=[], thickness=3, line_type=cv2.LINE_8, shift=0, font_face=cv2.FONT_HERSHEY_SIMPLEX, font_scale=1): 37 | self.category = category 38 | if colors: 39 | self.colors = [tuple(map(lambda c: c * 255, matplotlib.colors.colorConverter.to_rgb(c)[::-1])) for c in colors] 40 | else: 41 | self.colors = [tuple(map(lambda c: c * 255, matplotlib.colors.colorConverter.to_rgb(prop['color'])[::-1])) for prop in plt.rcParams['axes.prop_cycle']] 42 | self.thickness = thickness 43 | self.line_type = line_type 44 | self.shift = shift 45 | self.font_face = font_face 46 | self.font_scale = font_scale 47 | 48 | def __call__(self, image, yx_min, yx_max, cls=None, colors=None, debug=False): 49 | colors = self.colors if colors is None else [tuple(map(lambda c: c * 255, matplotlib.colors.colorConverter.to_rgb(c)[::-1])) for c in colors] 50 | if cls is None: 51 | cls = [None] * len(yx_min) 52 | for color, (ymin, xmin), (ymax, xmax), cls in zip(itertools.cycle(colors), yx_min, yx_max, cls): 53 | try: 54 | cv2.rectangle(image, (xmin, ymin), (xmax, ymax), color, thickness=self.thickness, lineType=self.line_type, shift=self.shift) 55 | if cls is not None: 56 | cv2.putText(image, self.category[cls], (xmin, ymin), self.font_face, self.font_scale, color=color, thickness=self.thickness) 57 | except OverflowError as e: 58 | logging.warning(e, (xmin, ymin), (xmax, ymax)) 59 | if debug: 60 | cv2.imshow('', image) 61 | cv2.waitKey(0) 62 | return image 63 | 64 | 65 | class DrawFeature(object): 66 | def __init__(self, alpha=0.5, cmap=None): 67 | self.alpha = alpha 68 | self.cm = matplotlib.cm.get_cmap(cmap) 69 | 70 | def __call__(self, image, feature, debug=False): 71 | _feature = (feature * self.cm.N).astype(np.int) 72 | heatmap = self.cm(_feature)[:, :, :3] * 255 73 | heatmap = cv2.resize(heatmap, image.shape[1::-1], interpolation=cv2.INTER_NEAREST) 74 | canvas = (image * (1 - self.alpha) + heatmap * self.alpha).astype(np.uint8) 75 | if debug: 76 | cv2.imshow('max=%f, sum=%f' % (np.max(feature), np.sum(feature)), canvas) 77 | cv2.waitKey(0) 78 | return canvas 79 | 80 | 81 | class Graph(object): 82 | def __init__(self, config, state_dict, cmap=None): 83 | self.dot = graphviz.Digraph(node_attr=dict(config.items('digraph_node_attr')), graph_attr=dict(config.items('digraph_graph_attr'))) 84 | self.dot.format = config.get('graph', 'format') 85 | self.state_dict = state_dict 86 | self.var_name = {t._cdata: k for k, t in state_dict.items()} 87 | self.seen = set() 88 | self.index = 0 89 | self.drawn = set() 90 | self.cm = matplotlib.cm.get_cmap(cmap) 91 | self.metric = eval(config.get('graph', 'metric')) 92 | metrics = [self.metric(t) for t in state_dict.values()] 93 | self.minmax = [min(metrics), max(metrics)] 94 | 95 | def __call__(self, node): 96 | if node not in self.seen: 97 | self.traverse_next(node) 98 | self.traverse_tensor(node) 99 | self.seen.add(node) 100 | self.index += 1 101 | 102 | def traverse_next(self, node): 103 | if hasattr(node, 'next_functions'): 104 | for n, _ in node.next_functions: 105 | if n is not None: 106 | self.__call__(n) 107 | self._draw_node_edge(node, n) 108 | self._draw_node(node) 109 | 110 | def 
traverse_tensor(self, node): 111 | tensors = [t for name, t in inspect.getmembers(node) if torch.is_tensor(t)] 112 | if hasattr(node, 'saved_tensors'): 113 | tensors += node.saved_tensors 114 | for tensor in tensors: 115 | name = self.var_name[tensor._cdata] 116 | self.drawn.add(name) 117 | self._draw_tensor(node, tensor) 118 | 119 | def _draw_node(self, node): 120 | if hasattr(node, 'variable'): 121 | tensor = node.variable.data 122 | name = self.var_name[tensor._cdata] 123 | label = '\n'.join(map(str, [ 124 | '%d: %s' % (self.index, name), 125 | list(tensor.size()), 126 | humanize.naturalsize(tensor.numpy().nbytes), 127 | ])) 128 | fillcolor, fontcolor = self._tensor_color(tensor) 129 | self.dot.node(str(id(node)), label, shape='note', fillcolor=fillcolor, fontcolor=fontcolor) 130 | self.drawn.add(name) 131 | else: 132 | self.dot.node(str(id(node)), '%d: %s' % (self.index, type(node).__name__), fillcolor='white') 133 | 134 | def _draw_node_edge(self, node, n): 135 | if hasattr(n, 'variable'): 136 | self.dot.edge(str(id(n)), str(id(node)), arrowhead='none', arrowtail='none') 137 | else: 138 | self.dot.edge(str(id(n)), str(id(node))) 139 | 140 | def _draw_tensor(self, node, tensor): 141 | name = self.var_name[tensor._cdata] 142 | label = '\n'.join(map(str, [ 143 | name, 144 | list(tensor.size()), 145 | humanize.naturalsize(tensor.numpy().nbytes), 146 | ])) 147 | fillcolor, fontcolor = self._tensor_color(tensor) 148 | self.dot.node(name, label, style='filled, rounded', fillcolor=fillcolor, fontcolor=fontcolor) 149 | self.dot.edge(name, str(id(node)), style='dashed', arrowhead='none', arrowtail='none') 150 | 151 | def _tensor_color(self, tensor): 152 | level = self._norm(self.metric(tensor)) 153 | fillcolor = self.cm(np.int(level * self.cm.N)) 154 | fontcolor = self.cm(self.cm.N if level < 0.5 else 0) 155 | return matplotlib.colors.to_hex(fillcolor), matplotlib.colors.to_hex(fontcolor) 156 | 157 | def _norm(self, metric): 158 | min, max = self.minmax 159 | assert min <= metric <= max, (metric, self.minmax) 160 | if min < max: 161 | return (metric - min) / (max - min) 162 | else: 163 | return metric 164 | -------------------------------------------------------------------------------- /variable_stat.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen) 3 | 4 | This program is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU Lesser General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 8 | 9 | This program is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with this program. If not, see . 
16 | """ 17 | 18 | import os 19 | import argparse 20 | import configparser 21 | import logging 22 | import logging.config 23 | import importlib 24 | import inspect 25 | import inflection 26 | import yaml 27 | 28 | import numpy as np 29 | import torch 30 | import humanize 31 | import xlsxwriter 32 | 33 | import utils 34 | import utils.train 35 | import utils.channel 36 | 37 | 38 | class Name(object): 39 | def __call__(self, name, variable): 40 | return name 41 | 42 | 43 | class Size(object): 44 | def __call__(self, name, variable): 45 | return 'x'.join(map(str, variable.size())) 46 | 47 | 48 | class Bytes(object): 49 | def __call__(self, name, variable): 50 | return variable.numpy().nbytes 51 | 52 | def format(self, workbook, worksheet, num, col): 53 | worksheet.conditional_format(1, col, num, col, {'type': 'data_bar', 'bar_color': '#FFC7CE'}) 54 | 55 | 56 | class BytesNatural(object): 57 | def __call__(self, name, variable): 58 | return humanize.naturalsize(variable.numpy().nbytes) 59 | 60 | 61 | class MeanDense(object): 62 | def __call__(self, name, variable): 63 | return np.mean(utils.channel.dense(variable)) 64 | 65 | def format(self, workbook, worksheet, num, col): 66 | worksheet.conditional_format(1, col, num, col, {'type': 'data_bar', 'bar_color': '#FFC7CE'}) 67 | 68 | 69 | class Rank(object): 70 | def __call__(self, name, variable): 71 | return len(variable.size()) 72 | 73 | def format(self, workbook, worksheet, num, col): 74 | worksheet.conditional_format(1, col, num, col, {'type': 'data_bar', 'bar_color': '#FFC7CE'}) 75 | 76 | 77 | def main(): 78 | args = make_args() 79 | config = configparser.ConfigParser() 80 | utils.load_config(config, args.config) 81 | for cmd in args.modify: 82 | utils.modify_config(config, cmd) 83 | with open(os.path.expanduser(os.path.expandvars(args.logging)), 'r') as f: 84 | logging.config.dictConfig(yaml.load(f)) 85 | model_dir = utils.get_model_dir(config) 86 | path, step, epoch = utils.train.load_model(model_dir) 87 | state_dict = torch.load(path, map_location=lambda storage, loc: storage) 88 | mapper = [(inflection.underscore(name), member()) for name, member in inspect.getmembers(importlib.machinery.SourceFileLoader('', __file__).load_module()) if inspect.isclass(member)] 89 | path = os.path.join(model_dir, os.path.basename(os.path.splitext(__file__)[0])) + '.xlsx' 90 | with xlsxwriter.Workbook(path, {'strings_to_urls': False, 'nan_inf_to_errors': True}) as workbook: 91 | worksheet = workbook.add_worksheet(args.worksheet) 92 | for j, (key, m) in enumerate(mapper): 93 | worksheet.write(0, j, key) 94 | for i, (name, variable) in enumerate(state_dict.items()): 95 | value = m(name, variable) 96 | worksheet.write(1 + i, j, value) 97 | if hasattr(m, 'format'): 98 | m.format(workbook, worksheet, i, j) 99 | worksheet.autofilter(0, 0, i, len(mapper) - 1) 100 | worksheet.freeze_panes(1, 0) 101 | logging.info(path) 102 | 103 | 104 | def make_args(): 105 | parser = argparse.ArgumentParser() 106 | parser.add_argument('-c', '--config', nargs='+', default=['config.ini'], help='config file') 107 | parser.add_argument('-m', '--modify', nargs='+', default=[], help='modify config') 108 | parser.add_argument('--logging', default='logging.yml', help='logging config') 109 | parser.add_argument('--worksheet', default='sheet') 110 | parser.add_argument('--nohead', action='store_true') 111 | return parser.parse_args() 112 | 113 | 114 | if __name__ == '__main__': 115 | main() 116 | -------------------------------------------------------------------------------- 
/video2image.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen) 3 | 4 | This program is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU Lesser General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 8 | 9 | This program is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with this program. If not, see <http://www.gnu.org/licenses/>. 16 | """ 17 | 18 | import os 19 | import sys 20 | import argparse 21 | 22 | import pymediainfo 23 | import tqdm 24 | from contextlib import closing 25 | import videosequence 26 | 27 | 28 | def get_step(frames, video_track, **kwargs): 29 | if 'frames' in kwargs: 30 | step = len(frames) // kwargs['frames'] 31 | elif 'frames_per_sec' in kwargs: 32 | frame_rate = float(video_track.frame_rate) 33 | step = int(frame_rate / kwargs['frames_per_sec'])  # keep every step-th frame to approximate the requested rate 34 | assert step > 0 35 | return step 36 | 37 | 38 | def convert(video_file, image_prefix, **kwargs): 39 | media_info = pymediainfo.MediaInfo.parse(video_file) 40 | video_tracks = [track for track in media_info.tracks if track.track_type == 'Video'] 41 | if len(video_tracks) < 1: 42 | raise videosequence.VideoError() 43 | video_track = video_tracks[0] 44 | _rotation = float(video_track.rotation) 45 | rotation = int(_rotation) 46 | assert rotation - _rotation == 0 47 | with closing(videosequence.VideoSequence(video_file)) as frames: 48 | step = get_step(frames, video_track, **kwargs) 49 | _frames = frames[::step] 50 | for idx, frame in enumerate(tqdm.tqdm(_frames)): 51 | frame = frame.rotate(-rotation, expand=True) 52 | frame.save('%s_%04d.jpg' % (image_prefix, idx)) 53 | 54 | 55 | def main(): 56 | args = make_args() 57 | src = os.path.expanduser(os.path.expandvars(args.src)) 58 | dst = os.path.expanduser(os.path.expandvars(args.dst)) 59 | os.makedirs(dst, exist_ok=True) 60 | kwargs = {} 61 | if args.frames > 0: 62 | kwargs['frames'] = args.frames 63 | elif args.frames_per_sec > 0: 64 | kwargs['frames_per_sec'] = args.frames_per_sec 65 | exts = set() 66 | for dirpath, _, filenames in os.walk(src): 67 | for filename in filenames: 68 | ext = os.path.splitext(filename)[-1].lower() 69 | if ext in args.ext: 70 | path = os.path.join(dirpath, filename) 71 | print(path) 72 | name = os.path.relpath(path, src).replace(os.path.sep, args.replace) 73 | _path = os.path.join(dst, name) 74 | try: 75 | convert(path, _path, **kwargs) 76 | except videosequence.VideoError as e: 77 | sys.stderr.write(str(e) + '\n') 78 | else: 79 | exts.add(ext) 80 | print(exts) 81 | 82 | 83 | def make_args(): 84 | parser = argparse.ArgumentParser() 85 | parser.add_argument('src') 86 | parser.add_argument('dst') 87 | parser.add_argument('-e', '--ext', nargs='+', default=['.mp4', '.mov', '.m4v']) 88 | parser.add_argument('-r', '--replace', default='_', help='replace the path separator with the given character') 89 | parser.add_argument('-f', '--frames', default=0, type=int, help='total output frames for a video') 90 | parser.add_argument('--frames_per_sec', default=0, type=int, help='output frames per second') 91 | return parser.parse_args() 92 | 93 | 94 | if __name__ == '__main__': 95 | main() 96 |
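# Usage sketch (hypothetical paths): sample one frame per second from every video
# found under ~/videos, writing JPEGs into ~/frames:
#   python video2image.py ~/videos ~/frames --frames_per_sec 1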
--------------------------------------------------------------------------------