├── .gitignore ├── LICENSE.md ├── README.md ├── cache.py ├── cache ├── __init__.py ├── coco.py └── coco.tsv ├── config.ini ├── config ├── convert_caffe_torch │ └── original_person18_19.tsv ├── convert_tf_torch │ └── model.dnn.inception4.Inception4_down3_4 │ │ ├── Unet1.tsv │ │ └── Unet2.tsv ├── dataset │ ├── coco.tsv │ ├── coco │ │ └── cache.coco.cache │ ├── hand20.tsv │ ├── hand20 │ │ └── cache.hand_nyu.cache │ ├── hand21.tsv │ ├── hand21 │ │ └── cache.hand_nyu.cache │ ├── hand_nyu.tsv │ ├── hand_nyu │ │ └── cache.hand_nyu.cache │ ├── mpii.tsv │ ├── mpii.txt │ ├── mpii │ │ └── cache.mpii.cache │ ├── person13_12.tsv │ ├── person13_12.txt │ ├── person13_12 │ │ ├── cache.coco.cache │ │ └── cache.mpii.cache │ ├── person14_13.tsv │ ├── person14_13.txt │ ├── person14_13 │ │ └── cache.coco.cache │ ├── person18.tsv │ ├── person18.txt │ ├── person18 │ │ └── cache.coco.cache │ ├── person18_19.tsv │ ├── person18_19.txt │ └── person18_19 │ │ └── cache.coco.cache ├── inception_unet.ini ├── original_person18_19.ini └── summary │ └── histogram.txt ├── convert_caffe_torch.py ├── convert_onnx_caffe2.py ├── convert_tf_torch.py ├── convert_torch_onnx.py ├── demo_data.py ├── demo_keypoints.py ├── demo_label.py ├── donate_alipay.jpg ├── donate_mm.jpg ├── estimate.py ├── logging.yml ├── model ├── __init__.py ├── dnn │ ├── __init__.py │ ├── inception4.py │ ├── mobilenet.py │ ├── mobilenet2.py │ ├── resnet.py │ └── vgg.py └── stages │ ├── __init__.py │ ├── openpose.py │ └── unet.py ├── quick_start.sh ├── receptive_field_analyzer.py ├── requirements.txt ├── train.py ├── transform ├── __init__.py ├── augmentation.py ├── image.py └── resize │ ├── __init__.py │ ├── image.py │ └── label.py └── utils ├── __init__.py ├── cache.py ├── data.py ├── train.py └── visualize.py /.gitignore: -------------------------------------------------------------------------------- 1 | .project 2 | .pydevproject 3 | .settings/ 4 | .idea/ 5 | .cache/ 6 | 7 | # Byte-compiled / optimized / DLL files 8 | __pycache__/ 9 | *.py[cod] 10 | *$py.class 11 | 12 | # C extensions 13 | *.so 14 | 15 | # Distribution / packaging 16 | .Python 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # pyenv 82 | .python-version 83 | 84 | # celery beat schedule file 85 | celerybeat-schedule 86 | 87 | # SageMath parsed files 88 | *.sage.py 89 | 90 | # Environments 91 | .env 92 | .venv 93 | env/ 94 | venv/ 95 | ENV/ 96 | env.bak/ 97 | venv.bak/ 98 | 99 | # Spyder project settings 100 | .spyderproject 101 | .spyproject 102 | 103 | # Rope project settings 104 | .ropeproject 105 | 106 | # mkdocs documentation 107 | /site 108 | 109 | # mypy 110 | .mypy_cache/ 111 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | GNU LESSER GENERAL PUBLIC LICENSE 2 | Version 3, 29 June 2007 3 | 4 | Copyright (C) 2007 Free Software Foundation, Inc. 5 | Everyone is permitted to copy and distribute verbatim copies 6 | of this license document, but changing it is not allowed. 7 | 8 | 9 | This version of the GNU Lesser General Public License incorporates 10 | the terms and conditions of version 3 of the GNU General Public 11 | License, supplemented by the additional permissions listed below. 12 | 13 | 0. Additional Definitions. 14 | 15 | As used herein, "this License" refers to version 3 of the GNU Lesser 16 | General Public License, and the "GNU GPL" refers to version 3 of the GNU 17 | General Public License. 18 | 19 | "The Library" refers to a covered work governed by this License, 20 | other than an Application or a Combined Work as defined below. 21 | 22 | An "Application" is any work that makes use of an interface provided 23 | by the Library, but which is not otherwise based on the Library. 24 | Defining a subclass of a class defined by the Library is deemed a mode 25 | of using an interface provided by the Library. 26 | 27 | A "Combined Work" is a work produced by combining or linking an 28 | Application with the Library. The particular version of the Library 29 | with which the Combined Work was made is also called the "Linked 30 | Version". 31 | 32 | The "Minimal Corresponding Source" for a Combined Work means the 33 | Corresponding Source for the Combined Work, excluding any source code 34 | for portions of the Combined Work that, considered in isolation, are 35 | based on the Application, and not on the Linked Version. 36 | 37 | The "Corresponding Application Code" for a Combined Work means the 38 | object code and/or source code for the Application, including any data 39 | and utility programs needed for reproducing the Combined Work from the 40 | Application, but excluding the System Libraries of the Combined Work. 41 | 42 | 1. Exception to Section 3 of the GNU GPL. 43 | 44 | You may convey a covered work under sections 3 and 4 of this License 45 | without being bound by section 3 of the GNU GPL. 46 | 47 | 2. Conveying Modified Versions. 
48 | 49 | If you modify a copy of the Library, and, in your modifications, a 50 | facility refers to a function or data to be supplied by an Application 51 | that uses the facility (other than as an argument passed when the 52 | facility is invoked), then you may convey a copy of the modified 53 | version: 54 | 55 | a) under this License, provided that you make a good faith effort to 56 | ensure that, in the event an Application does not supply the 57 | function or data, the facility still operates, and performs 58 | whatever part of its purpose remains meaningful, or 59 | 60 | b) under the GNU GPL, with none of the additional permissions of 61 | this License applicable to that copy. 62 | 63 | 3. Object Code Incorporating Material from Library Header Files. 64 | 65 | The object code form of an Application may incorporate material from 66 | a header file that is part of the Library. You may convey such object 67 | code under terms of your choice, provided that, if the incorporated 68 | material is not limited to numerical parameters, data structure 69 | layouts and accessors, or small macros, inline functions and templates 70 | (ten or fewer lines in length), you do both of the following: 71 | 72 | a) Give prominent notice with each copy of the object code that the 73 | Library is used in it and that the Library and its use are 74 | covered by this License. 75 | 76 | b) Accompany the object code with a copy of the GNU GPL and this license 77 | document. 78 | 79 | 4. Combined Works. 80 | 81 | You may convey a Combined Work under terms of your choice that, 82 | taken together, effectively do not restrict modification of the 83 | portions of the Library contained in the Combined Work and reverse 84 | engineering for debugging such modifications, if you also do each of 85 | the following: 86 | 87 | a) Give prominent notice with each copy of the Combined Work that 88 | the Library is used in it and that the Library and its use are 89 | covered by this License. 90 | 91 | b) Accompany the Combined Work with a copy of the GNU GPL and this license 92 | document. 93 | 94 | c) For a Combined Work that displays copyright notices during 95 | execution, include the copyright notice for the Library among 96 | these notices, as well as a reference directing the user to the 97 | copies of the GNU GPL and this license document. 98 | 99 | d) Do one of the following: 100 | 101 | 0) Convey the Minimal Corresponding Source under the terms of this 102 | License, and the Corresponding Application Code in a form 103 | suitable for, and under terms that permit, the user to 104 | recombine or relink the Application with a modified version of 105 | the Linked Version to produce a modified Combined Work, in the 106 | manner specified by section 6 of the GNU GPL for conveying 107 | Corresponding Source. 108 | 109 | 1) Use a suitable shared library mechanism for linking with the 110 | Library. A suitable mechanism is one that (a) uses at run time 111 | a copy of the Library already present on the user's computer 112 | system, and (b) will operate properly with a modified version 113 | of the Library that is interface-compatible with the Linked 114 | Version. 
115 | 116 | e) Provide Installation Information, but only if you would otherwise 117 | be required to provide such information under section 6 of the 118 | GNU GPL, and only to the extent that such information is 119 | necessary to install and execute a modified version of the 120 | Combined Work produced by recombining or relinking the 121 | Application with a modified version of the Linked Version. (If 122 | you use option 4d0, the Installation Information must accompany 123 | the Minimal Corresponding Source and Corresponding Application 124 | Code. If you use option 4d1, you must provide the Installation 125 | Information in the manner specified by section 6 of the GNU GPL 126 | for conveying Corresponding Source.) 127 | 128 | 5. Combined Libraries. 129 | 130 | You may place library facilities that are a work based on the 131 | Library side by side in a single library together with other library 132 | facilities that are not Applications and are not covered by this 133 | License, and convey such a combined library under terms of your 134 | choice, if you do both of the following: 135 | 136 | a) Accompany the combined library with a copy of the same work based 137 | on the Library, uncombined with any other library facilities, 138 | conveyed under the terms of this License. 139 | 140 | b) Give prominent notice with the combined library that part of it 141 | is a work based on the Library, and explaining where to find the 142 | accompanying uncombined form of the same work. 143 | 144 | 6. Revised Versions of the GNU Lesser General Public License. 145 | 146 | The Free Software Foundation may publish revised and/or new versions 147 | of the GNU Lesser General Public License from time to time. Such new 148 | versions will be similar in spirit to the present version, but may 149 | differ in detail to address new problems or concerns. 150 | 151 | Each version is given a distinguishing version number. If the 152 | Library as you received it specifies that a certain numbered version 153 | of the GNU Lesser General Public License "or any later version" 154 | applies to it, you have the option of following the terms and 155 | conditions either of that published version or of any later version 156 | published by the Free Software Foundation. If the Library as you 157 | received it does not specify a version number of the GNU Lesser 158 | General Public License, you may choose any version of the GNU Lesser 159 | General Public License ever published by the Free Software Foundation. 160 | 161 | If the Library as you received it specifies that a proxy can decide 162 | whether future versions of the GNU Lesser General Public License shall 163 | apply, that proxy's public statement of acceptance of any version is 164 | permanent authorization for you to choose that version for the 165 | Library. 166 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PyTorch implementation of the [OpenPose](https://arxiv.org/abs/1611.08050) 2 | 3 | OpenPose is one of the most popular keypoint estimators. It uses two branches of feature maps (trained and refined over multiple stages) to estimate the positions of keypoints (as Gaussian heatmaps) and the relationships between keypoints (the part affinity fields), respectively, which a [postprocessing procedure](https://github.com/ruiminshen/pyopenpose) then assembles into poses.
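To make the two-branch, multi-stage idea concrete, here is a minimal PyTorch sketch. It is an illustration only, with made-up names and channel sizes; the repository's actual networks live under `model/dnn` and `model/stages` and are selected through `config.ini`. Every stage emits a part-heatmap branch and a limb (part affinity field) branch, and each later stage refines the estimate from the backbone features concatenated with the previous stage's outputs.

```
# Hypothetical illustration of a two-branch, multi-stage head in the spirit of
# OpenPose. Names and channel sizes are invented for clarity.
import torch
import torch.nn as nn


class TwoBranchStage(nn.Module):
    def __init__(self, in_channels, num_parts, num_limbs):
        super().__init__()
        self.parts = nn.Sequential(
            nn.Conv2d(in_channels, 128, 3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(128, num_parts, 1),      # Gaussian heatmaps, one per keypoint
        )
        self.limbs = nn.Sequential(
            nn.Conv2d(in_channels, 128, 3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(128, num_limbs * 2, 1),  # part affinity fields (x, y per limb)
        )

    def forward(self, x):
        return self.parts(x), self.limbs(x)


class MultiStageHead(nn.Module):
    def __init__(self, feature_channels, num_parts, num_limbs, num_stages=3):
        super().__init__()
        stages = [TwoBranchStage(feature_channels, num_parts, num_limbs)]
        refine_in = feature_channels + num_parts + num_limbs * 2
        stages += [TwoBranchStage(refine_in, num_parts, num_limbs)
                   for _ in range(num_stages - 1)]
        self.stages = nn.ModuleList(stages)

    def forward(self, features):
        outputs = []
        x = features
        for stage in self.stages:
            parts, limbs = stage(x)
            outputs.append((parts, limbs))
            # later stages see the backbone features plus the previous estimates
            x = torch.cat([features, parts, limbs], dim=1)
        return outputs


if __name__ == '__main__':
    features = torch.randn(1, 256, 40, 40)  # backbone feature map (made up)
    head = MultiStageHead(256, num_parts=18, num_limbs=19)
    for i, (parts, limbs) in enumerate(head(features)):
        print(i, parts.shape, limbs.shape)
```

Losses are typically applied to every stage's outputs, so the intermediate estimates are supervised as well.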
4 | This project adopts [PyTorch](http://pytorch.org/) as the development framework to increase productivity, and uses [ONNX](https://github.com/onnx/onnx) to convert models into [Caffe2](https://caffe2.ai/) for easier engineering deployment. 5 | If you benefit from this project, a donation would be appreciated (via [PayPal](https://www.paypal.me/minimumshen), [WeChat Pay](donate_mm.jpg) or [Alipay](donate_alipay.jpg)). 6 | 7 | ## Designs 8 | 9 | - Flexible configuration design. 10 | Program settings are configurable and can be modified via **config file overlapping** (the -c/--config option) or **command editing** (the -m/--modify option) on the command line. 11 | 12 | - Monitoring via [TensorBoard](https://github.com/tensorflow/tensorboard). 13 | For example, the loss values and debugging images (such as the IoU heatmap, ground-truth and predicted bounding boxes) are logged. 14 | 15 | - Parallel model training design. 16 | Different models are saved into different directories so that they can be trained simultaneously. 17 | 18 | - Time-based output design. 19 | Running information (such as the model, the summaries produced by TensorBoard, and the evaluation results) is saved periodically at a predefined time interval. 20 | 21 | - Checkpoint management. 22 | The latest several checkpoint files (.pth) are preserved in the model directory, and older ones are deleted. 23 | 24 | - NaN debug. 25 | When a NaN loss is detected, the running environment (data batch) and the model are exported so that the cause can be analyzed. 26 | 27 | - Unified data cache design. 28 | Various datasets are converted into a unified data cache via a programmable configuration (a series of Python lambda expressions, so that some keypoints can be generated flexibly; a sketch of this mechanism appears in the Quick Start section below). 29 | Some plugins are already implemented, such as [MS COCO](http://cocodataset.org/). 30 | 31 | - Arbitrarily replaceable model plugin design. 32 | The deep neural network (both the feature extraction network and the stage networks) can be easily replaced via configuration settings. 33 | Multiple models are already provided, such as the original VGG-like network, [Inception v4](https://arxiv.org/abs/1602.07261), [MobileNet v2](https://arxiv.org/abs/1801.04381) and [U-Net](https://arxiv.org/abs/1505.04597). 34 | 35 | - Extendable data preprocessing plugin design. 36 | The original images (of different sizes) and labels are processed through a sequence of operations to form a training batch (images of the same size, with padded bounding-box lists). 37 | Multiple preprocessing plugins are already implemented, such as 38 | augmentation operators that process images and labels simultaneously (such as random rotation and random flip), 39 | operators that resize both images and labels to a fixed size within a batch (such as random crop), 40 | and operators that augment images without labels (such as random blur, random saturation and random brightness). 41 | 42 | ## Quick Start 43 | 44 | This project uses [Python 3](https://www.python.org/). To install the dependent libraries, make sure [pyopenpose](https://github.com/ruiminshen/pyopenpose) is installed, and type the following command in a terminal. 45 | 46 | ``` 47 | sudo pip3 install -r requirements.txt 48 | ``` 49 | 50 | `quick_start.sh` contains examples that perform detection and evaluation; run this script. 51 | The COCO dataset is downloaded ([aria2](https://aria2.github.io/) is required) and cached, and the original pose model (18 parts and 19 limbs) is converted into PyTorch format.
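The cache step remaps each source dataset's keypoints into the project's own keypoint layout through a list of Python lambda expressions — the "programmable configuration" mentioned in the Designs section (the actual lists are the `config/dataset/*/cache.*.cache` files shown later in this repository). The following is a minimal, self-contained sketch of the idea with made-up coordinates; it is an illustration, not the project's code:

```
# Hypothetical sketch of the keypoint-mapping lambdas used by the cache step.
import numpy as np

# COCO order (x, y, visibility): ..., 5 = left shoulder, 6 = right shoulder, ...
mapper = [
    # neck: synthesized as the midpoint of the two shoulders when both are labeled
    lambda parts: np.append((parts[5][:2] + parts[6][:2]) / 2, 1)
    if parts[5][2] > 0 and parts[6][2] > 0 else [0, 0, 0],
    lambda parts: parts[6],  # right shoulder
    lambda parts: parts[5],  # left shoulder
    # ... one lambda per target keypoint
]


def remap(parts, mapper):
    """parts: (17, 3) array of COCO (x, y, visibility) keypoints."""
    return np.array([np.asarray(m(parts)) for m in mapper], dtype=np.float32)


keypoints = np.zeros([17, 3], np.float32)
keypoints[5] = [100, 50, 2]  # left shoulder
keypoints[6] = [140, 52, 2]  # right shoulder
print(remap(keypoints, mapper))
```

Because every target keypoint is an arbitrary function of the source keypoints, points that the source dataset does not annotate (such as a neck) can be synthesized, and indices can be reordered freely.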
52 | If a webcam is present, the keypoint estimation demo will be shown. 53 | Finally, the training program is started. 54 | -------------------------------------------------------------------------------- /cache.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen) 3 | 4 | This program is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU Lesser General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 8 | 9 | This program is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with this program. If not, see . 16 | """ 17 | 18 | import os 19 | import argparse 20 | import configparser 21 | import logging 22 | import logging.config 23 | import importlib 24 | import pickle 25 | import random 26 | import shutil 27 | import yaml 28 | 29 | import utils 30 | 31 | 32 | def main(): 33 | args = make_args() 34 | config = configparser.ConfigParser() 35 | utils.load_config(config, args.config) 36 | for cmd in args.modify: 37 | utils.modify_config(config, cmd) 38 | with open(os.path.expanduser(os.path.expandvars(args.logging)), 'r') as f: 39 | logging.config.dictConfig(yaml.load(f)) 40 | cache_dir = utils.get_cache_dir(config) 41 | os.makedirs(cache_dir, exist_ok=True) 42 | mappers, _ = utils.get_dataset_mappers(config) 43 | for phase in args.phase: 44 | path = os.path.join(cache_dir, phase) + '.pkl' 45 | logging.info('save cache file: ' + path) 46 | data = [] 47 | for dataset in mappers: 48 | logging.info('load %s dataset' % dataset) 49 | module, func = dataset.rsplit('.', 1) 50 | module = importlib.import_module(module) 51 | func = getattr(module, func) 52 | data += func(config, path, mappers[dataset]) 53 | if config.getboolean('cache', 'shuffle'): 54 | random.shuffle(data) 55 | with open(path, 'wb') as f: 56 | pickle.dump(data, f) 57 | logging.info('%s data are saved into %s' % (str(args.phase), cache_dir)) 58 | 59 | 60 | def make_args(): 61 | parser = argparse.ArgumentParser() 62 | parser.add_argument('-c', '--config', nargs='+', default=['config.ini'], help='config file') 63 | parser.add_argument('-m', '--modify', nargs='+', default=[], help='modify config') 64 | parser.add_argument('-p', '--phase', nargs='+', default=['train', 'val', 'test']) 65 | parser.add_argument('--logging', default='logging.yml', help='logging config') 66 | return parser.parse_args() 67 | 68 | 69 | if __name__ == '__main__': 70 | main() 71 | -------------------------------------------------------------------------------- /cache/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruiminshen/openpose-pytorch/f850084194ddccc6d401d5b11f61facc20ec2b75/cache/__init__.py -------------------------------------------------------------------------------- /cache/coco.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen) 3 | 4 | This program is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU Lesser General Public License as published by 6 | the Free Software Foundation, either 
version 3 of the License, or 7 | (at your option) any later version. 8 | 9 | This program is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with this program. If not, see . 16 | """ 17 | 18 | import os 19 | import logging 20 | import configparser 21 | 22 | import numpy as np 23 | import pandas as pd 24 | import tqdm 25 | import pycocotools.coco 26 | import pycocotools.mask 27 | from PIL import Image, ImageDraw 28 | 29 | import utils 30 | import utils.cache 31 | 32 | 33 | def draw_mask(segmentation, canvas, draw): 34 | pixels = canvas.load() 35 | if isinstance(segmentation, list): 36 | for polygon in segmentation: 37 | draw.polygon(polygon, fill=0) 38 | else: 39 | if isinstance(segmentation['counts'], list): 40 | rle = pycocotools.mask.frPyObjects([segmentation], canvas.size[1], canvas.size[0]) 41 | else: 42 | rle = [segmentation] 43 | m = np.squeeze(pycocotools.mask.decode(rle)) 44 | assert m.shape[:2] == canvas.size[::-1] 45 | for y, row in enumerate(m): 46 | for x, v in enumerate(row): 47 | if v: 48 | pixels[x, y] = 0 49 | 50 | 51 | def cache(config, path, mapper): 52 | name = __name__.split('.')[-1] 53 | cachedir = os.path.dirname(path) 54 | phase = os.path.splitext(os.path.basename(path))[0] 55 | phasedir = os.path.join(cachedir, phase) 56 | os.makedirs(phasedir, exist_ok=True) 57 | mask_ext = config.get('cache', 'mask_ext') 58 | data = [] 59 | for i, row in pd.read_csv(os.path.splitext(__file__)[0] + '.tsv', sep='\t').iterrows(): 60 | logging.info('loading data %d (%s)' % (i, ', '.join([k + '=' + str(v) for k, v in row.items()]))) 61 | root = os.path.expanduser(os.path.expandvars(row['root'])) 62 | year = str(row['year']) 63 | suffix = phase + year 64 | path = os.path.join(root, 'annotations', 'person_keypoints_%s.json' % suffix) 65 | if not os.path.exists(path): 66 | logging.warning(path + ' not exists') 67 | continue 68 | coco_kp = pycocotools.coco.COCO(path) 69 | skeleton = np.array(coco_kp.loadCats(1)[0]['skeleton']) - 1 70 | np.savetxt(os.path.join(os.path.dirname(cachedir), name + '.tsv'), skeleton, fmt='%d', delimiter='\t') 71 | imgIds = coco_kp.getImgIds() 72 | folder = os.path.join(root, suffix) 73 | imgs = coco_kp.loadImgs(imgIds) 74 | _imgs = list(filter(lambda img: os.path.exists(os.path.join(folder, img['file_name'])), imgs)) 75 | if len(imgs) > len(_imgs): 76 | logging.warning('%d of %d images not exists' % (len(imgs) - len(_imgs), len(imgs))) 77 | for img in tqdm.tqdm(_imgs): 78 | # image 79 | path = os.path.join(folder, img['file_name']) 80 | width, height = img['width'], img['height'] 81 | try: 82 | if config.getboolean('cache', 'verify'): 83 | if not np.all(np.equal(utils.image_size(path), [width, height])): 84 | logging.error('failed to verify shape of image ' + path) 85 | continue 86 | except configparser.NoOptionError: 87 | pass 88 | # keypoints 89 | annIds = coco_kp.getAnnIds(imgIds=img['id'], iscrowd=None) 90 | anns = coco_kp.loadAnns(annIds) 91 | keypoints = [] 92 | bbox = [] 93 | keypath = os.path.join(phasedir, __name__.split('.')[-1] + year, os.path.relpath(os.path.splitext(path)[0], root)) 94 | os.makedirs(os.path.dirname(keypath), exist_ok=True) 95 | maskpath = keypath + '.mask' + mask_ext 96 | with Image.new('L', (width, height), 255) as canvas: 97 | draw = ImageDraw.Draw(canvas) 98 | for 
ann in anns: 99 | points = mapper(np.array(ann['keypoints']).reshape([-1, 3])) 100 | if np.any(points[:, 2] > 0): 101 | keypoints.append(points) 102 | bbox.append(ann['bbox']) 103 | else: 104 | draw_mask(ann['segmentation'], canvas, draw) 105 | if len(keypoints) <= 0: 106 | continue 107 | canvas.save(os.path.join(cachedir, maskpath)) 108 | keypoints = np.array(keypoints, dtype=np.float32) 109 | keypoints = keypoints[:, :, [1, 0, 2]] 110 | bbox = np.array(bbox, dtype=np.float32) 111 | yx_min = bbox[:, 1::-1] 112 | size = bbox[:, -1:1:-1] 113 | yx_max = yx_min + size 114 | try: 115 | if config.getboolean('cache', 'dump'): 116 | np.save(keypath + '.keypoints.npy', keypoints) 117 | np.save(keypath + '.yx_min.npy', yx_min) 118 | np.save(keypath + '.yx_max.npy', yx_max) 119 | except configparser.NoOptionError: 120 | pass 121 | data.append(dict( 122 | path=path, keypath=keypath, 123 | keypoints=keypoints, 124 | yx_min=yx_min, yx_max=yx_max, 125 | )) 126 | logging.warning('%d of %d images are saved' % (len(data), len(_imgs))) 127 | return data 128 | -------------------------------------------------------------------------------- /cache/coco.tsv: -------------------------------------------------------------------------------- 1 | root year 2 | ~/data/coco 2014 3 | ~/data/coco 2017 4 | -------------------------------------------------------------------------------- /config.ini: -------------------------------------------------------------------------------- 1 | [config] 2 | root = ~/model/openpose-pytorch 3 | 4 | [image] 5 | # 368 6 | # 344 7 | # 320 8 | size = 320 320 9 | 10 | [cache] 11 | name = cache 12 | ; config/dataset/person18_19 13 | ; config/dataset/person14_13 14 | dataset = config/dataset/person14_13 15 | shuffle = 1 16 | mask_ext = .jpg 17 | 18 | [model] 19 | name = model 20 | ; model.dnn.vgg.person18_19 21 | ; model.dnn.resnet.resnet18 22 | ; model.dnn.inception3.Inception3 23 | ; model.dnn.inception4.Inception4 24 | ; model.dnn.inception4.Inception4_down3_4 25 | ; model.dnn.mobilenet.MobileNet 26 | ; model.dnn.mobilenet2.MobileNet2 27 | ; model.dnn.densenet.densenet121 28 | dnn = model.dnn.mobilenet2.MobileNet2 29 | # model.stages.openpose.Stage0 model.stages.openpose.Stage model.stages.openpose.Stage model.stages.openpose.Stage model.stages.openpose.Stage model.stages.openpose.Stage 30 | # model.stages.unet.Unet1Sqz3 model.stages.unet.Unet1Sqz3_a 31 | # model.stages.unet.Unet2Sqz3 model.stages.unet.Unet2Sqz3 32 | stages = model.stages.unet.Unet1Sqz3 model.stages.unet.Unet1Sqz3_a 33 | pretrained = 0 34 | 35 | [batch_norm] 36 | enable = 0 37 | gamma = 1 38 | beta = 1 39 | 40 | [inception4] 41 | pretrained = imagenet 42 | 43 | [data] 44 | workers = 3 45 | sizes = 320,320 46 | maintain = 10 47 | shuffle = 0 48 | # fixed rescale 49 | resize = fixed 50 | 51 | [transform] 52 | ; transform.augmentation.RandomRotate transform.augmentation.RandomFlipHorizontally 53 | augmentation = transform.augmentation.RandomRotate transform.augmentation.RandomFlipHorizontally 54 | resize_train = transform.resize.label.RandomCrop 55 | resize_eval = transform.resize.label.Resize 56 | resize_test = transform.resize.image.Resize 57 | ; transform.image.RandomBlur transform.image.BGR2HSV transform.image.RandomHue transform.image.RandomSaturation transform.image.RandomBrightness transform.image.HSV2RGB transform.image.RandomGamma 58 | image_train = transform.image.BGR2RGB 59 | image_test = transform.image.BGR2RGB 60 | ; torchvision.transforms.ToTensor transform.image.Normalize 61 | tensor = 
torchvision.transforms.ToTensor transform.image.Normalize 62 | normalize = 0.5 1 63 | 64 | [augmentation] 65 | random_rotate = -40 40 66 | random_flip_horizontally = 0.5 67 | random_crop = 1.5 2 68 | random_blur = 5 5 69 | random_hue = 0 25 70 | random_saturation = 0.5 1.5 71 | random_brightness = 0.5 1.5 72 | random_gamma = 0.9 1.5 73 | 74 | [label] 75 | sigma_parts = 7 76 | sigma_limbs = 1 77 | 78 | [train] 79 | ; lambda params, lr: torch.optim.SGD(params, lr, momentum=2) 80 | ; lambda params, lr: torch.optim.Adam(params, lr, betas=(0.9, 0.999), eps=1e-8) 81 | ; lambda params, lr: torch.optim.RMSprop(params, lr, alpha=0.99, eps=1e-8) 82 | optimizer = lambda params, lr: torch.optim.Adam(params, lr, betas=(0.9, 0.999), eps=1e-8) 83 | ; lambda optimizer: torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1) 84 | ; lambda optimizer: torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[60, 90], gamma=0.1) 85 | scheduler = lambda optimizer: torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[60, 90], gamma=0.1) 86 | phase = train val 87 | cross_entropy = 1 88 | clip_ = 5 89 | 90 | [save] 91 | secs = 600 92 | keep = 5 93 | 94 | [draw_points] 95 | colors = r w 96 | 97 | [summary] 98 | scalar = 10 99 | image = 60 100 | histogram_ = 60 101 | 102 | [summary_scalar] 103 | loss_hparam = 0 104 | 105 | [summary_image] 106 | limit = 2 107 | data_keypoints = 1 108 | data_parts = 1 109 | data_limbs = 1 110 | estimate = 1 111 | output = parts limbs 112 | stage = -1 113 | 114 | [summary_histogram] 115 | parameters = config/summary/histogram.txt 116 | 117 | [hparam] 118 | parts = 1 119 | limbs = 1 120 | 121 | [estimate] 122 | interpolation = cubic 123 | 124 | [nms] 125 | threshold = 0.05 126 | 127 | [integration] 128 | step = 5 129 | step_limits = 5 25 130 | min_score = 0.05 131 | min_count = 9 132 | 133 | [cluster] 134 | min_score = 0.4 135 | min_count = 3 136 | 137 | [eval] 138 | phase = test 139 | secs = 12 * 60 * 60 140 | first = 0 141 | iou = 0.5 142 | db = eval.json 143 | mapper = config/eval.py 144 | debug = 0 145 | sort = timestamp 146 | metric07 = 1 147 | 148 | [graph] 149 | metric = lambda t: np.mean(utils.dense(t)) 150 | format = svg 151 | 152 | [digraph_graph_attr] 153 | size = 12, 12 154 | 155 | [digraph_node_attr] 156 | style = filled 157 | shape = box 158 | align = left 159 | fontsize = 12 160 | ranksep = 0.1 161 | height = 0.2 162 | -------------------------------------------------------------------------------- /config/convert_caffe_torch/original_person18_19.tsv: -------------------------------------------------------------------------------- 1 | dnn.features.0.weight conv1_1 lambda blobs: blobs[0] 2 | dnn.features.0.bias conv1_1 lambda blobs: blobs[1] 3 | 4 | dnn.features.2.weight conv1_2 lambda blobs: blobs[0] 5 | dnn.features.2.bias conv1_2 lambda blobs: blobs[1] 6 | 7 | dnn.features.5.weight conv2_1 lambda blobs: blobs[0] 8 | dnn.features.5.bias conv2_1 lambda blobs: blobs[1] 9 | 10 | dnn.features.7.weight conv2_2 lambda blobs: blobs[0] 11 | dnn.features.7.bias conv2_2 lambda blobs: blobs[1] 12 | 13 | dnn.features.10.weight conv3_1 lambda blobs: blobs[0] 14 | dnn.features.10.bias conv3_1 lambda blobs: blobs[1] 15 | 16 | dnn.features.12.weight conv3_2 lambda blobs: blobs[0] 17 | dnn.features.12.bias conv3_2 lambda blobs: blobs[1] 18 | 19 | dnn.features.14.weight conv3_3 lambda blobs: blobs[0] 20 | dnn.features.14.bias conv3_3 lambda blobs: blobs[1] 21 | 22 | dnn.features.16.weight conv3_4 lambda blobs: blobs[0] 23 | dnn.features.16.bias conv3_4 lambda blobs: 
blobs[1] 24 | 25 | dnn.features.19.weight conv4_1 lambda blobs: blobs[0] 26 | dnn.features.19.bias conv4_1 lambda blobs: blobs[1] 27 | 28 | dnn.features.21.weight conv4_2 lambda blobs: blobs[0] 29 | dnn.features.21.bias conv4_2 lambda blobs: blobs[1] 30 | 31 | dnn.features.23.weight conv4_3_CPM lambda blobs: blobs[0] 32 | dnn.features.23.bias conv4_3_CPM lambda blobs: blobs[1] 33 | 34 | dnn.features.25.weight conv4_4_CPM lambda blobs: blobs[0] 35 | dnn.features.25.bias conv4_4_CPM lambda blobs: blobs[1] 36 | 37 | stages.0.limbs.0.conv.weight conv5_1_CPM_L1 lambda blobs: blobs[0] 38 | stages.0.limbs.0.conv.bias conv5_1_CPM_L1 lambda blobs: blobs[1] 39 | 40 | stages.0.limbs.1.conv.weight conv5_2_CPM_L1 lambda blobs: blobs[0] 41 | stages.0.limbs.1.conv.bias conv5_2_CPM_L1 lambda blobs: blobs[1] 42 | 43 | stages.0.limbs.2.conv.weight conv5_3_CPM_L1 lambda blobs: blobs[0] 44 | stages.0.limbs.2.conv.bias conv5_3_CPM_L1 lambda blobs: blobs[1] 45 | 46 | stages.0.limbs.3.conv.weight conv5_4_CPM_L1 lambda blobs: blobs[0] 47 | stages.0.limbs.3.conv.bias conv5_4_CPM_L1 lambda blobs: blobs[1] 48 | 49 | stages.0.limbs.4.conv.weight conv5_5_CPM_L1 lambda blobs: blobs[0] 50 | stages.0.limbs.4.conv.bias conv5_5_CPM_L1 lambda blobs: blobs[1] 51 | 52 | stages.0.parts.0.conv.weight conv5_1_CPM_L2 lambda blobs: blobs[0] 53 | stages.0.parts.0.conv.bias conv5_1_CPM_L2 lambda blobs: blobs[1] 54 | 55 | stages.0.parts.1.conv.weight conv5_2_CPM_L2 lambda blobs: blobs[0] 56 | stages.0.parts.1.conv.bias conv5_2_CPM_L2 lambda blobs: blobs[1] 57 | 58 | stages.0.parts.2.conv.weight conv5_3_CPM_L2 lambda blobs: blobs[0] 59 | stages.0.parts.2.conv.bias conv5_3_CPM_L2 lambda blobs: blobs[1] 60 | 61 | stages.0.parts.3.conv.weight conv5_4_CPM_L2 lambda blobs: blobs[0] 62 | stages.0.parts.3.conv.bias conv5_4_CPM_L2 lambda blobs: blobs[1] 63 | 64 | stages.0.parts.4.conv.weight conv5_5_CPM_L2 lambda blobs: blobs[0] 65 | stages.0.parts.4.conv.bias conv5_5_CPM_L2 lambda blobs: blobs[1] 66 | 67 | stages.1.limbs.0.conv.weight Mconv1_stage2_L1 lambda blobs: blobs[0] 68 | stages.1.limbs.0.conv.bias Mconv1_stage2_L1 lambda blobs: blobs[1] 69 | 70 | stages.1.limbs.1.conv.weight Mconv2_stage2_L1 lambda blobs: blobs[0] 71 | stages.1.limbs.1.conv.bias Mconv2_stage2_L1 lambda blobs: blobs[1] 72 | 73 | stages.1.limbs.2.conv.weight Mconv3_stage2_L1 lambda blobs: blobs[0] 74 | stages.1.limbs.2.conv.bias Mconv3_stage2_L1 lambda blobs: blobs[1] 75 | 76 | stages.1.limbs.3.conv.weight Mconv4_stage2_L1 lambda blobs: blobs[0] 77 | stages.1.limbs.3.conv.bias Mconv4_stage2_L1 lambda blobs: blobs[1] 78 | 79 | stages.1.limbs.4.conv.weight Mconv5_stage2_L1 lambda blobs: blobs[0] 80 | stages.1.limbs.4.conv.bias Mconv5_stage2_L1 lambda blobs: blobs[1] 81 | 82 | stages.1.limbs.5.conv.weight Mconv6_stage2_L1 lambda blobs: blobs[0] 83 | stages.1.limbs.5.conv.bias Mconv6_stage2_L1 lambda blobs: blobs[1] 84 | 85 | stages.1.limbs.6.conv.weight Mconv7_stage2_L1 lambda blobs: blobs[0] 86 | stages.1.limbs.6.conv.bias Mconv7_stage2_L1 lambda blobs: blobs[1] 87 | 88 | stages.1.parts.0.conv.weight Mconv1_stage2_L2 lambda blobs: blobs[0] 89 | stages.1.parts.0.conv.bias Mconv1_stage2_L2 lambda blobs: blobs[1] 90 | 91 | stages.1.parts.1.conv.weight Mconv2_stage2_L2 lambda blobs: blobs[0] 92 | stages.1.parts.1.conv.bias Mconv2_stage2_L2 lambda blobs: blobs[1] 93 | 94 | stages.1.parts.2.conv.weight Mconv3_stage2_L2 lambda blobs: blobs[0] 95 | stages.1.parts.2.conv.bias Mconv3_stage2_L2 lambda blobs: blobs[1] 96 | 97 | stages.1.parts.3.conv.weight Mconv4_stage2_L2 lambda 
blobs: blobs[0] 98 | stages.1.parts.3.conv.bias Mconv4_stage2_L2 lambda blobs: blobs[1] 99 | 100 | stages.1.parts.4.conv.weight Mconv5_stage2_L2 lambda blobs: blobs[0] 101 | stages.1.parts.4.conv.bias Mconv5_stage2_L2 lambda blobs: blobs[1] 102 | 103 | stages.1.parts.5.conv.weight Mconv6_stage2_L2 lambda blobs: blobs[0] 104 | stages.1.parts.5.conv.bias Mconv6_stage2_L2 lambda blobs: blobs[1] 105 | 106 | stages.1.parts.6.conv.weight Mconv7_stage2_L2 lambda blobs: blobs[0] 107 | stages.1.parts.6.conv.bias Mconv7_stage2_L2 lambda blobs: blobs[1] 108 | 109 | stages.2.limbs.0.conv.weight Mconv1_stage3_L1 lambda blobs: blobs[0] 110 | stages.2.limbs.0.conv.bias Mconv1_stage3_L1 lambda blobs: blobs[1] 111 | 112 | stages.2.limbs.1.conv.weight Mconv2_stage3_L1 lambda blobs: blobs[0] 113 | stages.2.limbs.1.conv.bias Mconv2_stage3_L1 lambda blobs: blobs[1] 114 | 115 | stages.2.limbs.2.conv.weight Mconv3_stage3_L1 lambda blobs: blobs[0] 116 | stages.2.limbs.2.conv.bias Mconv3_stage3_L1 lambda blobs: blobs[1] 117 | 118 | stages.2.limbs.3.conv.weight Mconv4_stage3_L1 lambda blobs: blobs[0] 119 | stages.2.limbs.3.conv.bias Mconv4_stage3_L1 lambda blobs: blobs[1] 120 | 121 | stages.2.limbs.4.conv.weight Mconv5_stage3_L1 lambda blobs: blobs[0] 122 | stages.2.limbs.4.conv.bias Mconv5_stage3_L1 lambda blobs: blobs[1] 123 | 124 | stages.2.limbs.5.conv.weight Mconv6_stage3_L1 lambda blobs: blobs[0] 125 | stages.2.limbs.5.conv.bias Mconv6_stage3_L1 lambda blobs: blobs[1] 126 | 127 | stages.2.limbs.6.conv.weight Mconv7_stage3_L1 lambda blobs: blobs[0] 128 | stages.2.limbs.6.conv.bias Mconv7_stage3_L1 lambda blobs: blobs[1] 129 | 130 | stages.2.parts.0.conv.weight Mconv1_stage3_L2 lambda blobs: blobs[0] 131 | stages.2.parts.0.conv.bias Mconv1_stage3_L2 lambda blobs: blobs[1] 132 | 133 | stages.2.parts.1.conv.weight Mconv2_stage3_L2 lambda blobs: blobs[0] 134 | stages.2.parts.1.conv.bias Mconv2_stage3_L2 lambda blobs: blobs[1] 135 | 136 | stages.2.parts.2.conv.weight Mconv3_stage3_L2 lambda blobs: blobs[0] 137 | stages.2.parts.2.conv.bias Mconv3_stage3_L2 lambda blobs: blobs[1] 138 | 139 | stages.2.parts.3.conv.weight Mconv4_stage3_L2 lambda blobs: blobs[0] 140 | stages.2.parts.3.conv.bias Mconv4_stage3_L2 lambda blobs: blobs[1] 141 | 142 | stages.2.parts.4.conv.weight Mconv5_stage3_L2 lambda blobs: blobs[0] 143 | stages.2.parts.4.conv.bias Mconv5_stage3_L2 lambda blobs: blobs[1] 144 | 145 | stages.2.parts.5.conv.weight Mconv6_stage3_L2 lambda blobs: blobs[0] 146 | stages.2.parts.5.conv.bias Mconv6_stage3_L2 lambda blobs: blobs[1] 147 | 148 | stages.2.parts.6.conv.weight Mconv7_stage3_L2 lambda blobs: blobs[0] 149 | stages.2.parts.6.conv.bias Mconv7_stage3_L2 lambda blobs: blobs[1] 150 | 151 | stages.3.limbs.0.conv.weight Mconv1_stage4_L1 lambda blobs: blobs[0] 152 | stages.3.limbs.0.conv.bias Mconv1_stage4_L1 lambda blobs: blobs[1] 153 | 154 | stages.3.limbs.1.conv.weight Mconv2_stage4_L1 lambda blobs: blobs[0] 155 | stages.3.limbs.1.conv.bias Mconv2_stage4_L1 lambda blobs: blobs[1] 156 | 157 | stages.3.limbs.2.conv.weight Mconv3_stage4_L1 lambda blobs: blobs[0] 158 | stages.3.limbs.2.conv.bias Mconv3_stage4_L1 lambda blobs: blobs[1] 159 | 160 | stages.3.limbs.3.conv.weight Mconv4_stage4_L1 lambda blobs: blobs[0] 161 | stages.3.limbs.3.conv.bias Mconv4_stage4_L1 lambda blobs: blobs[1] 162 | 163 | stages.3.limbs.4.conv.weight Mconv5_stage4_L1 lambda blobs: blobs[0] 164 | stages.3.limbs.4.conv.bias Mconv5_stage4_L1 lambda blobs: blobs[1] 165 | 166 | stages.3.limbs.5.conv.weight Mconv6_stage4_L1 lambda blobs: 
blobs[0] 167 | stages.3.limbs.5.conv.bias Mconv6_stage4_L1 lambda blobs: blobs[1] 168 | 169 | stages.3.limbs.6.conv.weight Mconv7_stage4_L1 lambda blobs: blobs[0] 170 | stages.3.limbs.6.conv.bias Mconv7_stage4_L1 lambda blobs: blobs[1] 171 | 172 | stages.3.parts.0.conv.weight Mconv1_stage4_L2 lambda blobs: blobs[0] 173 | stages.3.parts.0.conv.bias Mconv1_stage4_L2 lambda blobs: blobs[1] 174 | 175 | stages.3.parts.1.conv.weight Mconv2_stage4_L2 lambda blobs: blobs[0] 176 | stages.3.parts.1.conv.bias Mconv2_stage4_L2 lambda blobs: blobs[1] 177 | 178 | stages.3.parts.2.conv.weight Mconv3_stage4_L2 lambda blobs: blobs[0] 179 | stages.3.parts.2.conv.bias Mconv3_stage4_L2 lambda blobs: blobs[1] 180 | 181 | stages.3.parts.3.conv.weight Mconv4_stage4_L2 lambda blobs: blobs[0] 182 | stages.3.parts.3.conv.bias Mconv4_stage4_L2 lambda blobs: blobs[1] 183 | 184 | stages.3.parts.4.conv.weight Mconv5_stage4_L2 lambda blobs: blobs[0] 185 | stages.3.parts.4.conv.bias Mconv5_stage4_L2 lambda blobs: blobs[1] 186 | 187 | stages.3.parts.5.conv.weight Mconv6_stage4_L2 lambda blobs: blobs[0] 188 | stages.3.parts.5.conv.bias Mconv6_stage4_L2 lambda blobs: blobs[1] 189 | 190 | stages.3.parts.6.conv.weight Mconv7_stage4_L2 lambda blobs: blobs[0] 191 | stages.3.parts.6.conv.bias Mconv7_stage4_L2 lambda blobs: blobs[1] 192 | 193 | stages.4.limbs.0.conv.weight Mconv1_stage5_L1 lambda blobs: blobs[0] 194 | stages.4.limbs.0.conv.bias Mconv1_stage5_L1 lambda blobs: blobs[1] 195 | 196 | stages.4.limbs.1.conv.weight Mconv2_stage5_L1 lambda blobs: blobs[0] 197 | stages.4.limbs.1.conv.bias Mconv2_stage5_L1 lambda blobs: blobs[1] 198 | 199 | stages.4.limbs.2.conv.weight Mconv3_stage5_L1 lambda blobs: blobs[0] 200 | stages.4.limbs.2.conv.bias Mconv3_stage5_L1 lambda blobs: blobs[1] 201 | 202 | stages.4.limbs.3.conv.weight Mconv4_stage5_L1 lambda blobs: blobs[0] 203 | stages.4.limbs.3.conv.bias Mconv4_stage5_L1 lambda blobs: blobs[1] 204 | 205 | stages.4.limbs.4.conv.weight Mconv5_stage5_L1 lambda blobs: blobs[0] 206 | stages.4.limbs.4.conv.bias Mconv5_stage5_L1 lambda blobs: blobs[1] 207 | 208 | stages.4.limbs.5.conv.weight Mconv6_stage5_L1 lambda blobs: blobs[0] 209 | stages.4.limbs.5.conv.bias Mconv6_stage5_L1 lambda blobs: blobs[1] 210 | 211 | stages.4.limbs.6.conv.weight Mconv7_stage5_L1 lambda blobs: blobs[0] 212 | stages.4.limbs.6.conv.bias Mconv7_stage5_L1 lambda blobs: blobs[1] 213 | 214 | stages.4.parts.0.conv.weight Mconv1_stage5_L2 lambda blobs: blobs[0] 215 | stages.4.parts.0.conv.bias Mconv1_stage5_L2 lambda blobs: blobs[1] 216 | 217 | stages.4.parts.1.conv.weight Mconv2_stage5_L2 lambda blobs: blobs[0] 218 | stages.4.parts.1.conv.bias Mconv2_stage5_L2 lambda blobs: blobs[1] 219 | 220 | stages.4.parts.2.conv.weight Mconv3_stage5_L2 lambda blobs: blobs[0] 221 | stages.4.parts.2.conv.bias Mconv3_stage5_L2 lambda blobs: blobs[1] 222 | 223 | stages.4.parts.3.conv.weight Mconv4_stage5_L2 lambda blobs: blobs[0] 224 | stages.4.parts.3.conv.bias Mconv4_stage5_L2 lambda blobs: blobs[1] 225 | 226 | stages.4.parts.4.conv.weight Mconv5_stage5_L2 lambda blobs: blobs[0] 227 | stages.4.parts.4.conv.bias Mconv5_stage5_L2 lambda blobs: blobs[1] 228 | 229 | stages.4.parts.5.conv.weight Mconv6_stage5_L2 lambda blobs: blobs[0] 230 | stages.4.parts.5.conv.bias Mconv6_stage5_L2 lambda blobs: blobs[1] 231 | 232 | stages.4.parts.6.conv.weight Mconv7_stage5_L2 lambda blobs: blobs[0] 233 | stages.4.parts.6.conv.bias Mconv7_stage5_L2 lambda blobs: blobs[1] 234 | 235 | stages.5.limbs.0.conv.weight Mconv1_stage6_L1 lambda blobs: blobs[0] 236 
| stages.5.limbs.0.conv.bias Mconv1_stage6_L1 lambda blobs: blobs[1] 237 | 238 | stages.5.limbs.1.conv.weight Mconv2_stage6_L1 lambda blobs: blobs[0] 239 | stages.5.limbs.1.conv.bias Mconv2_stage6_L1 lambda blobs: blobs[1] 240 | 241 | stages.5.limbs.2.conv.weight Mconv3_stage6_L1 lambda blobs: blobs[0] 242 | stages.5.limbs.2.conv.bias Mconv3_stage6_L1 lambda blobs: blobs[1] 243 | 244 | stages.5.limbs.3.conv.weight Mconv4_stage6_L1 lambda blobs: blobs[0] 245 | stages.5.limbs.3.conv.bias Mconv4_stage6_L1 lambda blobs: blobs[1] 246 | 247 | stages.5.limbs.4.conv.weight Mconv5_stage6_L1 lambda blobs: blobs[0] 248 | stages.5.limbs.4.conv.bias Mconv5_stage6_L1 lambda blobs: blobs[1] 249 | 250 | stages.5.limbs.5.conv.weight Mconv6_stage6_L1 lambda blobs: blobs[0] 251 | stages.5.limbs.5.conv.bias Mconv6_stage6_L1 lambda blobs: blobs[1] 252 | 253 | stages.5.limbs.6.conv.weight Mconv7_stage6_L1 lambda blobs: blobs[0] 254 | stages.5.limbs.6.conv.bias Mconv7_stage6_L1 lambda blobs: blobs[1] 255 | 256 | stages.5.parts.0.conv.weight Mconv1_stage6_L2 lambda blobs: blobs[0] 257 | stages.5.parts.0.conv.bias Mconv1_stage6_L2 lambda blobs: blobs[1] 258 | 259 | stages.5.parts.1.conv.weight Mconv2_stage6_L2 lambda blobs: blobs[0] 260 | stages.5.parts.1.conv.bias Mconv2_stage6_L2 lambda blobs: blobs[1] 261 | 262 | stages.5.parts.2.conv.weight Mconv3_stage6_L2 lambda blobs: blobs[0] 263 | stages.5.parts.2.conv.bias Mconv3_stage6_L2 lambda blobs: blobs[1] 264 | 265 | stages.5.parts.3.conv.weight Mconv4_stage6_L2 lambda blobs: blobs[0] 266 | stages.5.parts.3.conv.bias Mconv4_stage6_L2 lambda blobs: blobs[1] 267 | 268 | stages.5.parts.4.conv.weight Mconv5_stage6_L2 lambda blobs: blobs[0] 269 | stages.5.parts.4.conv.bias Mconv5_stage6_L2 lambda blobs: blobs[1] 270 | 271 | stages.5.parts.5.conv.weight Mconv6_stage6_L2 lambda blobs: blobs[0] 272 | stages.5.parts.5.conv.bias Mconv6_stage6_L2 lambda blobs: blobs[1] 273 | 274 | stages.5.parts.6.conv.weight Mconv7_stage6_L2 lambda blobs: blobs[0] 275 | stages.5.parts.6.conv.bias Mconv7_stage6_L2 lambda blobs: blobs[1] 276 | -------------------------------------------------------------------------------- /config/dataset/coco.tsv: -------------------------------------------------------------------------------- 1 | 0 1 2 | 0 2 3 | 1 3 4 | 2 4 5 | 0 5 6 | 0 6 7 | 5 7 8 | 6 8 9 | 7 9 10 | 8 10 11 | 0 11 12 | 0 12 13 | 11 13 14 | 12 14 15 | 13 15 16 | 14 16 17 | -------------------------------------------------------------------------------- /config/dataset/coco/cache.coco.cache: -------------------------------------------------------------------------------- 1 | lambda parts: parts[0] 2 | lambda parts: parts[1] 3 | lambda parts: parts[2] 4 | lambda parts: parts[3] 5 | lambda parts: parts[4] 6 | lambda parts: parts[5] 7 | lambda parts: parts[6] 8 | lambda parts: parts[7] 9 | lambda parts: parts[8] 10 | lambda parts: parts[9] 11 | lambda parts: parts[10] 12 | lambda parts: parts[11] 13 | lambda parts: parts[12] 14 | lambda parts: parts[13] 15 | lambda parts: parts[14] 16 | lambda parts: parts[15] 17 | lambda parts: parts[16] 18 | -------------------------------------------------------------------------------- /config/dataset/hand20.tsv: -------------------------------------------------------------------------------- 1 | 0 1 2 | 1 2 3 | 2 3 4 | 0 4 5 | 4 5 6 | 5 6 7 | 6 7 8 | 0 8 9 | 8 9 10 | 9 10 11 | 10 11 12 | 0 12 13 | 12 13 14 | 13 14 15 | 14 15 16 | 0 16 17 | 16 17 18 | 17 18 19 | 18 19 -------------------------------------------------------------------------------- 
/config/dataset/hand20/cache.hand_nyu.cache: -------------------------------------------------------------------------------- 1 | lambda parts: parts[29] 2 | lambda parts: parts[26] 3 | lambda parts: parts[25] 4 | lambda parts: parts[24] 5 | lambda parts: parts[22] 6 | lambda parts: parts[21] 7 | lambda parts: np.append((parts[20][:2] + parts[19][:2]) / 2, 1) if parts[20][2] > 0 and parts[19][2] > 0 else [0, 0, 0] 8 | lambda parts: parts[18] 9 | lambda parts: parts[16] 10 | lambda parts: parts[15] 11 | lambda parts: np.append((parts[14][:2] + parts[13][:2]) / 2, 1) if parts[14][2] > 0 and parts[13][2] > 0 else [0, 0, 0] 12 | lambda parts: parts[12] 13 | lambda parts: parts[10] 14 | lambda parts: parts[9] 15 | lambda parts: np.append((parts[8][:2] + parts[7][:2]) / 2, 1) if parts[8][2] > 0 and parts[7][2] > 0 else [0, 0, 0] 16 | lambda parts: parts[6] 17 | lambda parts: parts[4] 18 | lambda parts: parts[2] 19 | lambda parts: parts[1] 20 | lambda parts: parts[0] 21 | -------------------------------------------------------------------------------- /config/dataset/hand21.tsv: -------------------------------------------------------------------------------- 1 | 0 1 2 | 1 2 3 | 2 3 4 | 3 4 5 | 0 5 6 | 5 6 7 | 6 7 8 | 7 8 9 | 0 9 10 | 9 10 11 | 10 11 12 | 11 12 13 | 0 13 14 | 13 14 15 | 14 15 16 | 15 16 17 | 0 17 18 | 17 18 19 | 18 19 20 | 19 20 -------------------------------------------------------------------------------- /config/dataset/hand21/cache.hand_nyu.cache: -------------------------------------------------------------------------------- 1 | lambda parts: parts[29] 2 | lambda parts: parts[28] 3 | lambda parts: parts[26] 4 | lambda parts: parts[25] 5 | lambda parts: parts[24] 6 | lambda parts: parts[22] 7 | lambda parts: parts[21] 8 | lambda parts: np.append((parts[20][:2] + parts[19][:2]) / 2, 1) if parts[20][2] > 0 and parts[19][2] > 0 else [0, 0, 0] 9 | lambda parts: parts[18] 10 | lambda parts: parts[16] 11 | lambda parts: parts[15] 12 | lambda parts: np.append((parts[14][:2] + parts[13][:2]) / 2, 1) if parts[14][2] > 0 and parts[13][2] > 0 else [0, 0, 0] 13 | lambda parts: parts[12] 14 | lambda parts: parts[10] 15 | lambda parts: parts[9] 16 | lambda parts: np.append((parts[8][:2] + parts[7][:2]) / 2, 1) if parts[8][2] > 0 and parts[7][2] > 0 else [0, 0, 0] 17 | lambda parts: parts[6] 18 | lambda parts: parts[4] 19 | lambda parts: parts[2] 20 | lambda parts: parts[1] 21 | lambda parts: parts[0] 22 | -------------------------------------------------------------------------------- /config/dataset/hand_nyu.tsv: -------------------------------------------------------------------------------- 1 | 29 34 2 | 34 33 3 | 33 5 4 | 5 4 5 | 4 3 6 | 3 2 7 | 2 1 8 | 1 0 9 | 34 32 10 | 32 11 11 | 11 10 12 | 10 9 13 | 9 8 14 | 8 7 15 | 7 6 16 | 32 17 17 | 17 16 18 | 16 15 19 | 15 14 20 | 14 13 21 | 13 12 22 | 34 23 23 | 23 22 24 | 22 21 25 | 21 20 26 | 20 19 27 | 19 18 28 | 29 28 29 | 28 27 30 | 27 26 31 | 26 25 32 | 25 24 33 | 29 30 34 | 29 31 35 | 29 35 36 | -------------------------------------------------------------------------------- /config/dataset/hand_nyu/cache.hand_nyu.cache: -------------------------------------------------------------------------------- 1 | lambda parts: parts[0] 2 | lambda parts: parts[1] 3 | lambda parts: parts[2] 4 | lambda parts: parts[3] 5 | lambda parts: parts[4] 6 | lambda parts: parts[5] 7 | lambda parts: parts[6] 8 | lambda parts: parts[7] 9 | lambda parts: parts[8] 10 | lambda parts: parts[9] 11 | lambda parts: parts[10] 12 | lambda parts: parts[11] 13 | lambda 
parts: parts[12] 14 | lambda parts: parts[13] 15 | lambda parts: parts[14] 16 | lambda parts: parts[15] 17 | lambda parts: parts[16] 18 | lambda parts: parts[17] 19 | lambda parts: parts[18] 20 | lambda parts: parts[19] 21 | lambda parts: parts[20] 22 | lambda parts: parts[21] 23 | lambda parts: parts[22] 24 | lambda parts: parts[23] 25 | lambda parts: parts[24] 26 | lambda parts: parts[25] 27 | lambda parts: parts[26] 28 | lambda parts: parts[27] 29 | lambda parts: parts[28] 30 | lambda parts: parts[29] 31 | lambda parts: parts[30] 32 | lambda parts: parts[31] 33 | lambda parts: parts[32] 34 | lambda parts: parts[33] 35 | lambda parts: parts[34] 36 | lambda parts: parts[35] 37 | -------------------------------------------------------------------------------- /config/dataset/mpii.tsv: -------------------------------------------------------------------------------- 1 | 9 8 2 | 8 7 3 | 7 12 4 | 7 13 5 | 12 11 6 | 13 14 7 | 11 10 8 | 14 15 9 | 7 2 10 | 7 3 11 | 2 1 12 | 3 4 13 | 1 0 14 | 4 5 15 | 2 6 16 | 3 6 17 | -------------------------------------------------------------------------------- /config/dataset/mpii.txt: -------------------------------------------------------------------------------- 1 | 5 2 | 4 3 | 3 4 | 2 5 | 1 6 | 0 7 | 8 | 9 | 10 | 11 | 15 12 | 14 13 | 13 14 | 12 15 | 11 16 | 10 17 | -------------------------------------------------------------------------------- /config/dataset/mpii/cache.mpii.cache: -------------------------------------------------------------------------------- 1 | lambda parts: parts[0] 2 | lambda parts: parts[1] 3 | lambda parts: parts[2] 4 | lambda parts: parts[3] 5 | lambda parts: parts[4] 6 | lambda parts: parts[5] 7 | lambda parts: parts[6] 8 | lambda parts: parts[7] 9 | lambda parts: parts[8] 10 | lambda parts: parts[9] 11 | lambda parts: parts[10] 12 | lambda parts: parts[11] 13 | lambda parts: parts[12] 14 | lambda parts: parts[13] 15 | lambda parts: parts[14] 16 | lambda parts: parts[15] 17 | -------------------------------------------------------------------------------- /config/dataset/person13_12.tsv: -------------------------------------------------------------------------------- 1 | 0 1 2 | 0 2 3 | 1 3 4 | 2 4 5 | 3 5 6 | 4 6 7 | 0 7 8 | 0 8 9 | 7 9 10 | 8 10 11 | 9 11 12 | 10 12 13 | -------------------------------------------------------------------------------- /config/dataset/person13_12.txt: -------------------------------------------------------------------------------- 1 | 2 | 2 3 | 1 4 | 4 5 | 3 6 | 6 7 | 5 8 | 8 9 | 7 10 | 10 11 | 9 12 | 12 13 | 11 14 | -------------------------------------------------------------------------------- /config/dataset/person13_12/cache.coco.cache: -------------------------------------------------------------------------------- 1 | lambda parts: np.append((parts[5][:2] + parts[6][:2]) / 2, 1) if parts[5][2] > 0 and parts[6][2] > 0 else [0, 0, 0] 2 | lambda parts: parts[6] 3 | lambda parts: parts[5] 4 | lambda parts: parts[8] 5 | lambda parts: parts[7] 6 | lambda parts: parts[10] 7 | lambda parts: parts[9] 8 | lambda parts: parts[12] 9 | lambda parts: parts[11] 10 | lambda parts: parts[14] 11 | lambda parts: parts[13] 12 | lambda parts: parts[16] 13 | lambda parts: parts[15] 14 | -------------------------------------------------------------------------------- /config/dataset/person13_12/cache.mpii.cache: -------------------------------------------------------------------------------- 1 | lambda parts: parts[7] 2 | lambda parts: parts[12] 3 | lambda parts: parts[13] 4 | lambda parts: parts[11] 5 | 
lambda parts: parts[14] 6 | lambda parts: parts[10] 7 | lambda parts: parts[15] 8 | lambda parts: parts[2] 9 | lambda parts: parts[3] 10 | lambda parts: parts[1] 11 | lambda parts: parts[4] 12 | lambda parts: parts[0] 13 | lambda parts: parts[5] 14 | -------------------------------------------------------------------------------- /config/dataset/person14_13.tsv: -------------------------------------------------------------------------------- 1 | 0 1 2 | 0 2 3 | 1 3 4 | 2 4 5 | 3 5 6 | 4 6 7 | 0 7 8 | 0 8 9 | 7 9 10 | 8 10 11 | 9 11 12 | 10 12 13 | 0 13 14 | -------------------------------------------------------------------------------- /config/dataset/person14_13.txt: -------------------------------------------------------------------------------- 1 | 2 | 2 3 | 1 4 | 4 5 | 3 6 | 6 7 | 5 8 | 8 9 | 7 10 | 10 11 | 9 12 | 12 13 | 11 14 | 15 | -------------------------------------------------------------------------------- /config/dataset/person14_13/cache.coco.cache: -------------------------------------------------------------------------------- 1 | lambda parts: np.append((parts[5][:2] + parts[6][:2]) / 2, 1) if parts[5][2] > 0 and parts[6][2] > 0 else [0, 0, 0] 2 | lambda parts: parts[6] 3 | lambda parts: parts[5] 4 | lambda parts: parts[8] 5 | lambda parts: parts[7] 6 | lambda parts: parts[10] 7 | lambda parts: parts[9] 8 | lambda parts: parts[12] 9 | lambda parts: parts[11] 10 | lambda parts: parts[14] 11 | lambda parts: parts[13] 12 | lambda parts: parts[16] 13 | lambda parts: parts[15] 14 | lambda parts: parts[0] 15 | -------------------------------------------------------------------------------- /config/dataset/person18.tsv: -------------------------------------------------------------------------------- 1 | 0 1 2 | 0 2 3 | 1 2 4 | 1 3 5 | 2 4 6 | 0 7 7 | 7 5 8 | 7 6 9 | 5 8 10 | 6 9 11 | 8 10 12 | 9 11 13 | 7 12 14 | 7 13 15 | 12 14 16 | 13 15 17 | 14 16 18 | 15 17 19 | -------------------------------------------------------------------------------- /config/dataset/person18.txt: -------------------------------------------------------------------------------- 1 | 2 | 2 3 | 1 4 | 4 5 | 3 6 | 6 7 | 5 8 | 9 | 9 10 | 8 11 | 11 12 | 10 13 | 13 14 | 12 15 | 15 16 | 14 17 | 17 18 | 16 19 | -------------------------------------------------------------------------------- /config/dataset/person18/cache.coco.cache: -------------------------------------------------------------------------------- 1 | lambda parts: parts[0] 2 | lambda parts: parts[1] 3 | lambda parts: parts[2] 4 | lambda parts: parts[3] 5 | lambda parts: parts[4] 6 | lambda parts: parts[5] 7 | lambda parts: parts[6] 8 | lambda parts: np.append((parts[5][:2] + parts[6][:2]) / 2, 1) if parts[5][2] > 0 and parts[6][2] > 0 else [0, 0, 0] 9 | lambda parts: parts[7] 10 | lambda parts: parts[8] 11 | lambda parts: parts[9] 12 | lambda parts: parts[10] 13 | lambda parts: parts[11] 14 | lambda parts: parts[12] 15 | lambda parts: parts[13] 16 | lambda parts: parts[14] 17 | lambda parts: parts[15] 18 | lambda parts: parts[16] 19 | -------------------------------------------------------------------------------- /config/dataset/person18_19.tsv: -------------------------------------------------------------------------------- 1 | 1 8 2 | 8 9 3 | 9 10 4 | 1 11 5 | 11 12 6 | 12 13 7 | 1 2 8 | 2 3 9 | 3 4 10 | 2 16 11 | 1 5 12 | 5 6 13 | 6 7 14 | 5 17 15 | 1 0 16 | 0 14 17 | 0 15 18 | 14 16 19 | 15 17 20 | -------------------------------------------------------------------------------- /config/dataset/person18_19.txt: 
-------------------------------------------------------------------------------- 1 | 2 | 3 | 5 4 | 4 5 | 3 6 | 2 7 | 3 8 | 3 9 | 11 10 | 12 11 | 13 12 | 8 13 | 9 14 | 10 15 | 15 16 | 14 17 | 17 18 | 16 19 | -------------------------------------------------------------------------------- /config/dataset/person18_19/cache.coco.cache: -------------------------------------------------------------------------------- 1 | lambda parts: parts[0] 2 | lambda parts: np.append((parts[5][:2] + parts[6][:2]) / 2, 1) if parts[5][2] > 0 and parts[6][2] > 0 else [0, 0, 0] 3 | lambda parts: parts[6] 4 | lambda parts: parts[8] 5 | lambda parts: parts[10] 6 | lambda parts: parts[5] 7 | lambda parts: parts[7] 8 | lambda parts: parts[9] 9 | lambda parts: parts[12] 10 | lambda parts: parts[14] 11 | lambda parts: parts[16] 12 | lambda parts: parts[11] 13 | lambda parts: parts[13] 14 | lambda parts: parts[15] 15 | lambda parts: parts[2] 16 | lambda parts: parts[1] 17 | lambda parts: parts[4] 18 | lambda parts: parts[3] 19 | -------------------------------------------------------------------------------- /config/inception_unet.ini: -------------------------------------------------------------------------------- 1 | [image] 2 | size = 344 344 3 | 4 | [cache] 5 | dataset = config/dataset/person14_13 6 | 7 | [model] 8 | dnn = model.dnn.inception4.Inception4_down3_4 9 | stages = model.stages.unet.Unet1Sqz3 model.stages.unet.Unet1Sqz3_a 10 | 11 | [data] 12 | sizes = 344,344 13 | 14 | [nms] 15 | threshold = 0.05 16 | 17 | [integration] 18 | step = 5 19 | step_limits = 5 25 20 | min_score = 0.05 21 | min_count = 9 22 | 23 | [cluster] 24 | min_score = 0.4 25 | min_count = 3 26 | -------------------------------------------------------------------------------- /config/original_person18_19.ini: -------------------------------------------------------------------------------- 1 | [image] 2 | size = 368 368 3 | 4 | [cache] 5 | dataset = config/dataset/person18_19 6 | 7 | [model] 8 | dnn = model.dnn.vgg.person18_19 9 | stages = model.stages.openpose.Stage0 model.stages.openpose.Stage model.stages.openpose.Stage model.stages.openpose.Stage model.stages.openpose.Stage model.stages.openpose.Stage 10 | 11 | [data] 12 | sizes = 368,368 13 | 14 | [nms] 15 | threshold = 0.05 16 | 17 | [integration] 18 | step = 5 19 | step_limits = 5 25 20 | min_score = 0.05 21 | min_count = 9 22 | 23 | [cluster] 24 | min_score = 0.4 25 | min_count = 3 26 | -------------------------------------------------------------------------------- /config/summary/histogram.txt: -------------------------------------------------------------------------------- 1 | .+\.bn\.weight$ 2 | -------------------------------------------------------------------------------- /convert_caffe_torch.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen) 3 | 4 | This program is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU Lesser General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 8 | 9 | This program is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with this program. If not, see . 
16 | """ 17 | 18 | import os 19 | import argparse 20 | import configparser 21 | import logging 22 | import logging.config 23 | import csv 24 | import hashlib 25 | import shutil 26 | import yaml 27 | 28 | import numpy as np 29 | import torch 30 | import torch.nn as nn 31 | import torch.autograd 32 | import caffe 33 | 34 | import utils 35 | import utils.train 36 | import model 37 | 38 | 39 | def load_mapper(path): 40 | with open(path, 'r') as f: 41 | lines = list(csv.reader(f, delimiter='\t')) 42 | mapper = {} 43 | for line in lines: 44 | if len(line) == 3: 45 | dst, src, transform = line 46 | transform = eval(transform) 47 | mapper[dst] = (src, transform) 48 | return mapper 49 | 50 | 51 | def main(): 52 | args = make_args() 53 | config = configparser.ConfigParser() 54 | utils.load_config(config, args.config) 55 | for cmd in args.modify: 56 | utils.modify_config(config, cmd) 57 | with open(os.path.expanduser(os.path.expandvars(args.logging)), 'r') as f: 58 | logging.config.dictConfig(yaml.load(f)) 59 | torch.manual_seed(args.seed) 60 | mapper = load_mapper(os.path.expandvars(os.path.expanduser(args.mapper))) 61 | model_dir = utils.get_model_dir(config) 62 | _, num_parts = utils.get_dataset_mappers(config) 63 | limbs_index = utils.get_limbs_index(config) 64 | height, width = tuple(map(int, config.get('image', 'size').split())) 65 | tensor = torch.randn(args.batch_size, 3, height, width) 66 | # PyTorch 67 | try: 68 | path, step, epoch = utils.train.load_model(model_dir) 69 | state_dict = torch.load(path, map_location=lambda storage, loc: storage) 70 | except (FileNotFoundError, ValueError): 71 | state_dict = {name: None for name in ('dnn', 'stages')} 72 | config_channels_dnn = model.ConfigChannels(config, state_dict['dnn']) 73 | dnn = utils.parse_attr(config.get('model', 'dnn'))(config_channels_dnn) 74 | config_channels_stages = model.ConfigChannels(config, state_dict['stages'], config_channels_dnn.channels) 75 | channel_dict = model.channel_dict(num_parts, len(limbs_index)) 76 | stages = nn.Sequential(*[utils.parse_attr(s)(config_channels_stages, channel_dict, config_channels_dnn.channels, str(i)) for i, s in enumerate(config.get('model', 'stages').split())]) 77 | inference = model.Inference(config, dnn, stages) 78 | inference.eval() 79 | state_dict = inference.state_dict() 80 | # Caffe 81 | net = caffe.Net(os.path.expanduser(os.path.expandvars(args.prototxt)), os.path.expanduser(os.path.expandvars(args.caffemodel)), caffe.TEST) 82 | if args.debug: 83 | logging.info('Caffe variables') 84 | for name, blobs in net.params.items(): 85 | for i, blob in enumerate(blobs): 86 | val = blob.data 87 | print('\t'.join(map(str, [ 88 | '%s/%d' % (name, i), 89 | 'x'.join(map(str, val.shape)), 90 | utils.abs_mean(val), hashlib.md5(val.tostring()).hexdigest(), 91 | ]))) 92 | logging.info('Caffe features') 93 | input = net.blobs[args.input] 94 | input.reshape(*tensor.size()) 95 | input.data[...] 
= tensor.numpy() 96 | net.forward() 97 | for name, blob in net.blobs.items(): 98 | val = blob.data 99 | print('\t'.join(map(str, [ 100 | name, 101 | 'x'.join(map(str, val.shape)), 102 | utils.abs_mean(val), hashlib.md5(val.tostring()).hexdigest(), 103 | ]))) 104 | # convert 105 | saver = utils.train.Saver(model_dir, config.getint('save', 'keep')) 106 | try: 107 | for dst in state_dict: 108 | src, transform = mapper[dst] 109 | blobs = [b.data for b in net.params[src]] 110 | blob = transform(blobs) 111 | if isinstance(blob, np.ndarray): 112 | state_dict[dst] = torch.from_numpy(blob) 113 | else: 114 | state_dict[dst].fill_(blob) 115 | val = state_dict[dst].numpy() 116 | logging.info('\t'.join(list(map(str, (dst, src, val.shape, utils.abs_mean(val), hashlib.md5(val.tostring()).hexdigest()))))) 117 | inference.load_state_dict(state_dict) 118 | if args.delete: 119 | logging.warning('delete model directory: ' + model_dir) 120 | shutil.rmtree(model_dir, ignore_errors=True) 121 | saver(dict( 122 | dnn=inference.dnn.state_dict(), 123 | stages=inference.stages.state_dict(), 124 | ), 0) 125 | finally: 126 | for stage, output in enumerate(inference(tensor)): 127 | for name, feature in output.items(): 128 | val = feature.detach().numpy() 129 | print('\t'.join(map(str, [ 130 | 'stage%d/%s' % (stage, name), 131 | 'x'.join(map(str, val.shape)), 132 | utils.abs_mean(val), hashlib.md5(val.tostring()).hexdigest(), 133 | ]))) 134 | 135 | 136 | def make_args(): 137 | parser = argparse.ArgumentParser() 138 | parser.add_argument('mapper') 139 | parser.add_argument('prototxt') 140 | parser.add_argument('caffemodel') 141 | parser.add_argument('-c', '--config', nargs='+', default=['config.ini'], help='config file') 142 | parser.add_argument('-m', '--modify', nargs='+', default=[], help='modify config') 143 | parser.add_argument('--logging', default='logging.yml', help='logging config') 144 | parser.add_argument('-b', '--batch_size', default=1, type=int, help='batch size') 145 | parser.add_argument('-i', '--input', default='image', help='input tensor name of Caffe') 146 | parser.add_argument('-d', '--delete', action='store_true', help='delete model') 147 | parser.add_argument('-s', '--seed', default=0, type=int, help='a seed to create a random image tensor') 148 | parser.add_argument('--debug', action='store_true') 149 | return parser.parse_args() 150 | 151 | if __name__ == '__main__': 152 | main() 153 | -------------------------------------------------------------------------------- /convert_onnx_caffe2.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen) 3 | 4 | This program is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU Lesser General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 8 | 9 | This program is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with this program. If not, see . 
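Note: convert_onnx_caffe2.py below splits the ONNX export into a Caffe2 init_net.pb/predict_net.pb pair inside the model directory; the --caffe branch of estimate.py later consumes them roughly as in this sketch. The model_dir value is a placeholder for whatever utils.get_model_dir(config) returns:

import os
import numpy as np
from caffe2.proto import caffe2_pb2
from caffe2.python import workspace

model_dir = 'model'  # placeholder; the real value comes from utils.get_model_dir(config)
init_net, predict_net = caffe2_pb2.NetDef(), caffe2_pb2.NetDef()
with open(os.path.join(model_dir, 'init_net.pb'), 'rb') as f:
    init_net.ParseFromString(f.read())
with open(os.path.join(model_dir, 'predict_net.pb'), 'rb') as f:
    predict_net.ParseFromString(f.read())
predictor = workspace.Predictor(init_net, predict_net)
# One float32 NCHW image; the returned blobs alternate parts/limbs for each stage.
outputs = predictor.run([np.random.randn(1, 3, 368, 368).astype(np.float32)])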
16 | """ 17 | 18 | import os 19 | import argparse 20 | import configparser 21 | import logging 22 | import logging.config 23 | import yaml 24 | 25 | import onnx 26 | import onnx_caffe2.backend 27 | import onnx_caffe2.helper 28 | 29 | import utils 30 | 31 | 32 | def main(): 33 | args = make_args() 34 | config = configparser.ConfigParser() 35 | utils.load_config(config, args.config) 36 | for cmd in args.modify: 37 | utils.modify_config(config, cmd) 38 | with open(os.path.expanduser(os.path.expandvars(args.logging)), 'r') as f: 39 | logging.config.dictConfig(yaml.load(f)) 40 | model_dir = utils.get_model_dir(config) 41 | model = onnx.load(model_dir + '.onnx') 42 | onnx.checker.check_model(model) 43 | init_net, predict_net = onnx_caffe2.backend.Caffe2Backend.onnx_graph_to_caffe2_net(model.graph, device='CPU') 44 | onnx_caffe2.helper.save_caffe2_net(init_net, os.path.join(model_dir, 'init_net.pb')) 45 | onnx_caffe2.helper.save_caffe2_net(predict_net, os.path.join(model_dir, 'predict_net.pb'), output_txt=True) 46 | logging.info(model_dir) 47 | 48 | 49 | def make_args(): 50 | parser = argparse.ArgumentParser() 51 | parser.add_argument('-c', '--config', nargs='+', default=['config.ini'], help='config file') 52 | parser.add_argument('-m', '--modify', nargs='+', default=[], help='modify config') 53 | parser.add_argument('--logging', default='logging.yml', help='logging config') 54 | return parser.parse_args() 55 | 56 | 57 | if __name__ == '__main__': 58 | main() 59 | -------------------------------------------------------------------------------- /convert_tf_torch.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen) 3 | 4 | This program is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU Lesser General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 8 | 9 | This program is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with this program. If not, see . 
16 | """ 17 | 18 | import os 19 | import argparse 20 | import configparser 21 | import logging 22 | import logging.config 23 | import csv 24 | import hashlib 25 | import shutil 26 | import yaml 27 | 28 | import numpy as np 29 | import torch 30 | import torch.nn as nn 31 | import torch.autograd 32 | import tensorflow as tf 33 | from tensorflow.python.framework import ops 34 | from tensorboardX import SummaryWriter 35 | 36 | import utils 37 | import utils.train 38 | import model 39 | 40 | 41 | def load_mapper(path): 42 | with open(os.path.splitext(path)[0] + '.tsv', 'r') as f: 43 | lines = list(csv.reader(f, delimiter='\t')) 44 | mapper = {} 45 | for line in lines: 46 | if line: 47 | if len(line) < 3: 48 | line += [''] * (3 - len(line)) 49 | dst, src, _converter = line 50 | converter = eval(_converter) if _converter else lambda val: val 51 | mapper[dst] = (src, converter) 52 | return mapper 53 | 54 | 55 | def main(): 56 | args = make_args() 57 | config = configparser.ConfigParser() 58 | utils.load_config(config, args.config) 59 | for cmd in args.modify: 60 | utils.modify_config(config, cmd) 61 | with open(os.path.expanduser(os.path.expandvars(args.logging)), 'r') as f: 62 | logging.config.dictConfig(yaml.load(f)) 63 | torch.manual_seed(args.seed) 64 | mapper = load_mapper(os.path.expandvars(os.path.expanduser(args.mapper))) 65 | model_dir = utils.get_model_dir(config) 66 | _, num_parts = utils.get_dataset_mappers(config) 67 | limbs_index = utils.get_limbs_index(config) 68 | height, width = tuple(map(int, config.get('image', 'size').split())) 69 | tensor = torch.randn(args.batch_size, 3, height, width) 70 | # PyTorch 71 | try: 72 | path, step, epoch = utils.train.load_model(model_dir) 73 | state_dict = torch.load(path, map_location=lambda storage, loc: storage) 74 | except (FileNotFoundError, ValueError): 75 | state_dict = {name: None for name in ('dnn', 'stages')} 76 | config_channels_dnn = model.ConfigChannels(config, state_dict['dnn']) 77 | dnn = utils.parse_attr(config.get('model', 'dnn'))(config_channels_dnn) 78 | config_channels_stages = model.ConfigChannels(config, state_dict['stages'], config_channels_dnn.channels) 79 | channel_dict = model.channel_dict(num_parts, len(limbs_index)) 80 | stages = nn.Sequential(*[utils.parse_attr(s)(config_channels_stages, channel_dict, config_channels_dnn.channels, str(i)) for i, s in enumerate(config.get('model', 'stages').split())]) 81 | inference = model.Inference(config, dnn, stages) 82 | inference.eval() 83 | state_dict = inference.state_dict() 84 | # TensorFlow 85 | with open(os.path.expanduser(os.path.expandvars(args.path)), 'rb') as f: 86 | graph_def = tf.GraphDef() 87 | graph_def.ParseFromString(f.read()) 88 | image = ops.convert_to_tensor(np.transpose(tensor.cpu().numpy(), [0, 2, 3, 1]), name='image') 89 | tf.import_graph_def(graph_def, input_map={'image:0': image}) 90 | saver = utils.train.Saver(model_dir, config.getint('save', 'keep')) 91 | with tf.Session(config=tf.ConfigProto( 92 | device_count={'CPU': 1, 'GPU': 0}, 93 | allow_soft_placement=True, 94 | log_device_placement=False 95 | )) as sess: 96 | try: 97 | for dst in state_dict: 98 | src, converter = mapper[dst] 99 | if src.isdigit(): 100 | state_dict[dst].fill_(float(src)) 101 | else: 102 | op = sess.graph.get_operation_by_name(src) 103 | t = op.values()[0] 104 | v = sess.run(t) 105 | state_dict[dst] = torch.from_numpy(converter(v)) 106 | val = state_dict[dst].numpy() 107 | print('\t'.join(list(map(str, (dst, src, val.shape, utils.abs_mean(val), 
hashlib.md5(val.tostring()).hexdigest()))))) 108 | inference.load_state_dict(state_dict) 109 | if args.delete: 110 | logging.warning('delete model directory: ' + model_dir) 111 | shutil.rmtree(model_dir, ignore_errors=True) 112 | saver(dict( 113 | dnn=inference.dnn.state_dict(), 114 | stages=inference.stages.state_dict(), 115 | ), 0) 116 | finally: 117 | if args.debug: 118 | for op in sess.graph.get_operations(): 119 | if op.values(): 120 | logging.info(op.values()[0]) 121 | for name in args.debug: 122 | t = sess.graph.get_tensor_by_name(name + ':0') 123 | val = sess.run(t) 124 | val = np.transpose(val, [0, 3, 1, 2]) 125 | print('\t'.join(map(str, [ 126 | name, 127 | 'x'.join(map(str, val.shape)), 128 | utils.abs_mean(val), hashlib.md5(val.tostring()).hexdigest(), 129 | ]))) 130 | val = dnn(tensor).detach().numpy() 131 | print('\t'.join(map(str, [ 132 | 'x'.join(map(str, val.shape)), 133 | utils.abs_mean(val), hashlib.md5(val.tostring()).hexdigest(), 134 | ]))) 135 | for stage, output in enumerate(inference(tensor)): 136 | for name, feature in output.items(): 137 | val = feature.detach().numpy() 138 | print('\t'.join(map(str, [ 139 | 'stage%d/%s' % (stage, name), 140 | 'x'.join(map(str, val.shape)), 141 | utils.abs_mean(val), hashlib.md5(val.tostring()).hexdigest(), 142 | ]))) 143 | forward = inference.forward 144 | inference.forward = lambda self, *x: list(forward(self, *x)[-1].values()) 145 | with SummaryWriter(model_dir) as writer: 146 | writer.add_graph(inference, (tensor,)) 147 | 148 | 149 | def make_args(): 150 | parser = argparse.ArgumentParser() 151 | parser.add_argument('path') 152 | parser.add_argument('mapper') 153 | parser.add_argument('-c', '--config', nargs='+', default=['config.ini'], help='config file') 154 | parser.add_argument('-m', '--modify', nargs='+', default=[], help='modify config') 155 | parser.add_argument('--logging', default='logging.yml', help='logging config') 156 | parser.add_argument('-b', '--batch_size', default=1, type=int, help='batch size') 157 | parser.add_argument('-d', '--delete', action='store_true', help='delete model') 158 | parser.add_argument('-s', '--seed', default=0, type=int, help='a seed to create a random image tensor') 159 | parser.add_argument('--debug', nargs='+') 160 | return parser.parse_args() 161 | 162 | 163 | if __name__ == '__main__': 164 | main() -------------------------------------------------------------------------------- /convert_torch_onnx.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen) 3 | 4 | This program is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU Lesser General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 8 | 9 | This program is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with this program. If not, see . 
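Note: in convert_torch_onnx.py below, forward is monkey-patched right before export so each stage's output dict is flattened into a plain [parts, limbs] list, apparently because torch.onnx.export of this vintage traces tensor outputs rather than dicts. A quick way to inspect the result, assuming the same model_dir + '.onnx' path the script writes to (the literal path here is a placeholder):

import onnx

model_proto = onnx.load('model.onnx')  # placeholder for utils.get_model_dir(config) + '.onnx'
onnx.checker.check_model(model_proto)
print([output.name for output in model_proto.graph.output])  # one parts/limbs pair per stage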
16 | """ 17 | 18 | import os 19 | import argparse 20 | import configparser 21 | import logging 22 | import logging.config 23 | import yaml 24 | 25 | import torch.nn as nn 26 | import torch.autograd 27 | import torch.cuda 28 | import torch.optim 29 | import torch.utils.data 30 | import torch.onnx 31 | import humanize 32 | 33 | import utils.train 34 | import model 35 | 36 | 37 | def main(): 38 | args = make_args() 39 | config = configparser.ConfigParser() 40 | utils.load_config(config, args.config) 41 | for cmd in args.modify: 42 | utils.modify_config(config, cmd) 43 | with open(os.path.expanduser(os.path.expandvars(args.logging)), 'r') as f: 44 | logging.config.dictConfig(yaml.load(f)) 45 | height, width = tuple(map(int, config.get('image', 'size').split())) 46 | model_dir = utils.get_model_dir(config) 47 | _, num_parts = utils.get_dataset_mappers(config) 48 | limbs_index = utils.get_limbs_index(config) 49 | path, step, epoch = utils.train.load_model(model_dir) 50 | state_dict = torch.load(path, map_location=lambda storage, loc: storage) 51 | config_channels_dnn = model.ConfigChannels(config, state_dict['dnn']) 52 | dnn = utils.parse_attr(config.get('model', 'dnn'))(config_channels_dnn) 53 | config_channels_stages = model.ConfigChannels(config, state_dict['stages'], config_channels_dnn.channels) 54 | channel_dict = model.channel_dict(num_parts, len(limbs_index)) 55 | stages = nn.Sequential(*[utils.parse_attr(s)(config_channels_stages, channel_dict, config_channels_dnn.channels, str(i)) for i, s in enumerate(config.get('model', 'stages').split())]) 56 | dnn.load_state_dict(config_channels_dnn.state_dict) 57 | stages.load_state_dict(config_channels_stages.state_dict) 58 | inference = model.Inference(config, dnn, stages) 59 | inference.eval() 60 | logging.info(humanize.naturalsize(sum(var.cpu().numpy().nbytes for var in inference.state_dict().values()))) 61 | image = torch.randn(args.batch_size, 3, height, width) 62 | path = model_dir + '.onnx' 63 | logging.info('save ' + path) 64 | forward = inference.forward 65 | inference.forward = lambda self, *x: [[output[name] for name in 'parts, limbs'.split(', ')] for output in forward(self, *x)] 66 | torch.onnx.export(inference, image, path, export_params=True, verbose=args.verbose) 67 | 68 | 69 | def make_args(): 70 | parser = argparse.ArgumentParser() 71 | parser.add_argument('-c', '--config', nargs='+', default=['config.ini'], help='config file') 72 | parser.add_argument('-m', '--modify', nargs='+', default=[], help='modify config') 73 | parser.add_argument('-b', '--batch_size', default=1, type=int, help='batch size') 74 | parser.add_argument('-v', '--verbose', action='store_true') 75 | parser.add_argument('--logging', default='logging.yml', help='logging config') 76 | return parser.parse_args() 77 | 78 | 79 | if __name__ == '__main__': 80 | main() 81 | -------------------------------------------------------------------------------- /demo_data.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen) 3 | 4 | This program is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU Lesser General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 8 | 9 | This program is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with this program. If not, see . 16 | """ 17 | 18 | import os 19 | import argparse 20 | import configparser 21 | import logging 22 | import logging.config 23 | import multiprocessing 24 | import yaml 25 | 26 | import numpy as np 27 | import torch.utils.data 28 | import matplotlib.pyplot as plt 29 | 30 | import utils.data 31 | import utils.train 32 | import utils.visualize 33 | import transform.augmentation 34 | import model 35 | 36 | 37 | def main(): 38 | args = make_args() 39 | config = configparser.ConfigParser() 40 | utils.load_config(config, args.config) 41 | for cmd in args.modify: 42 | utils.modify_config(config, cmd) 43 | with open(os.path.expanduser(os.path.expandvars(args.logging)), 'r') as f: 44 | logging.config.dictConfig(yaml.load(f)) 45 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 46 | cache_dir = utils.get_cache_dir(config) 47 | _, num_parts = utils.get_dataset_mappers(config) 48 | limbs_index = utils.get_limbs_index(config) 49 | dnn = utils.parse_attr(config.get('model', 'dnn'))(model.ConfigChannels(config)).to(device) 50 | draw_points = utils.visualize.DrawPoints(limbs_index, colors=config.get('draw_points', 'colors').split()) 51 | _draw_points = utils.visualize.DrawPoints(limbs_index, thickness=1) 52 | draw_bbox = utils.visualize.DrawBBox() 53 | batch_size = args.rows * args.cols 54 | paths = [os.path.join(cache_dir, phase + '.pkl') for phase in args.phase] 55 | dataset = utils.data.Dataset( 56 | config, 57 | utils.data.load_pickles(paths), 58 | transform=transform.augmentation.get_transform(config, config.get('transform', 'augmentation').split()), 59 | shuffle=config.getboolean('data', 'shuffle'), 60 | ) 61 | logging.info('num_examples=%d' % len(dataset)) 62 | try: 63 | workers = config.getint('data', 'workers') 64 | except configparser.NoOptionError: 65 | workers = multiprocessing.cpu_count() 66 | sizes = utils.train.load_sizes(config) 67 | feature_sizes = [dnn(torch.randn(1, 3, *size).to(device)).size()[-2:] for size in sizes] 68 | collate_fn = utils.data.Collate( 69 | config, 70 | transform.parse_transform(config, config.get('transform', 'resize_train')), 71 | sizes, feature_sizes, 72 | maintain=config.getint('data', 'maintain'), 73 | transform_image=transform.get_transform(config, config.get('transform', 'image_train').split()), 74 | ) 75 | loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=workers, collate_fn=collate_fn) 76 | for data in loader: 77 | path, size, image, mask, keypoints, yx_min, yx_max, index = (t.numpy() if hasattr(t, 'numpy') else t for t in (data[key] for key in 'path, size, image, mask, keypoints, yx_min, yx_max, index'.split(', '))) 78 | fig, axes = plt.subplots(args.rows, args.cols) 79 | axes = axes.flat if batch_size > 1 else [axes] 80 | for ax, path, size, image, mask, keypoints, yx_min, yx_max, index in zip(*[axes, path, size, image, mask, keypoints, yx_min, yx_max, index]): 81 | logging.info(path + ': ' + 'x'.join(map(str, size))) 82 | image = utils.visualize.draw_mask(image, mask, 1) 83 | size = yx_max - yx_min 84 | target = np.logical_and(*[np.squeeze(a, -1) > 0 for a in np.split(size, size.shape[-1], -1)]) 85 | keypoints, yx_min, yx_max = (a[target] for a in (keypoints, yx_min, yx_max)) 86 | for i, points in enumerate(keypoints): 87 | if i == index: 88 | image = draw_points(image, points) 89 | else: 90 | image = 
_draw_points(image, points) 91 | image = draw_bbox(image, yx_min.astype(np.int), yx_max.astype(np.int)) 92 | ax.imshow(image) 93 | ax.set_xticks([]) 94 | ax.set_yticks([]) 95 | fig.tight_layout() 96 | mng = plt.get_current_fig_manager() 97 | mng.resize(*mng.window.maxsize()) 98 | plt.show() 99 | 100 | 101 | def make_args(): 102 | parser = argparse.ArgumentParser() 103 | parser.add_argument('-c', '--config', nargs='+', default=['config.ini'], help='config file') 104 | parser.add_argument('-m', '--modify', nargs='+', default=[], help='modify config') 105 | parser.add_argument('-p', '--phase', nargs='+', default=['train', 'val', 'test']) 106 | parser.add_argument('--rows', default=3, type=int) 107 | parser.add_argument('--cols', default=3, type=int) 108 | parser.add_argument('--logging', default='logging.yml', help='logging config') 109 | return parser.parse_args() 110 | 111 | 112 | if __name__ == '__main__': 113 | main() 114 | -------------------------------------------------------------------------------- /demo_keypoints.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen) 3 | 4 | This program is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU Lesser General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 8 | 9 | This program is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with this program. If not, see . 
16 | """ 17 | 18 | import os 19 | import argparse 20 | import configparser 21 | import logging 22 | import logging.config 23 | import yaml 24 | 25 | import numpy as np 26 | import scipy.misc 27 | import matplotlib.pyplot as plt 28 | 29 | import utils.data 30 | import utils.visualize 31 | 32 | 33 | def main(): 34 | args = make_args() 35 | config = configparser.ConfigParser() 36 | utils.load_config(config, args.config) 37 | for cmd in args.modify: 38 | utils.modify_config(config, cmd) 39 | with open(os.path.expanduser(os.path.expandvars(args.logging)), 'r') as f: 40 | logging.config.dictConfig(yaml.load(f)) 41 | cache_dir = utils.get_cache_dir(config) 42 | _, num_parts = utils.get_dataset_mappers(config) 43 | limbs_index = utils.get_limbs_index(config) 44 | mask_ext = config.get('cache', 'mask_ext') 45 | paths = [os.path.join(cache_dir, phase + '.pkl') for phase in args.phase] 46 | dataset = utils.data.Dataset(config, utils.data.load_pickles(paths)) 47 | logging.info('num_examples=%d' % len(dataset)) 48 | draw_points = utils.visualize.DrawPoints(limbs_index, colors=config.get('draw_points', 'colors').split()) 49 | draw_bbox = utils.visualize.DrawBBox(config) 50 | for data in dataset: 51 | path, keypath, keypoints, yx_min, yx_max = (data[key] for key in 'path, keypath, keypoints, yx_min, yx_max'.split(', ')) 52 | image = scipy.misc.imread(path, mode='RGB') 53 | fig = plt.figure() 54 | ax = fig.gca() 55 | maskpath = keypath + '.mask' + mask_ext 56 | mask = scipy.misc.imread(maskpath) 57 | image = utils.visualize.draw_mask(image, mask) 58 | for points in keypoints: 59 | image = draw_points(image, points) 60 | image = draw_bbox(image, yx_min.astype(np.int), yx_max.astype(np.int)) 61 | ax.imshow(image) 62 | ax.set_xlim([0, image.shape[1] - 1]) 63 | ax.set_ylim([image.shape[0] - 1, 0]) 64 | ax.set_xticks([]) 65 | ax.set_yticks([]) 66 | mng = plt.get_current_fig_manager() 67 | mng.resize(*mng.window.maxsize()) 68 | plt.show() 69 | 70 | 71 | def make_args(): 72 | parser = argparse.ArgumentParser() 73 | parser.add_argument('-c', '--config', nargs='+', default=['config.ini'], help='config file') 74 | parser.add_argument('-m', '--modify', nargs='+', default=[], help='modify config') 75 | parser.add_argument('-p', '--phase', nargs='+', default=['train', 'val', 'test']) 76 | parser.add_argument('--logging', default='logging.yml', help='logging config') 77 | return parser.parse_args() 78 | 79 | if __name__ == '__main__': 80 | main() 81 | -------------------------------------------------------------------------------- /demo_label.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen) 3 | 4 | This program is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU Lesser General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 8 | 9 | This program is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with this program. If not, see . 
16 | """ 17 | 18 | import sys 19 | import os 20 | import argparse 21 | import configparser 22 | import logging 23 | import logging.config 24 | import multiprocessing 25 | import yaml 26 | 27 | import numpy as np 28 | import torch.utils.data 29 | from PyQt5 import QtCore, QtWidgets 30 | import matplotlib.pyplot as plt 31 | import matplotlib.backends.backend_qt5agg as qtagg 32 | import humanize 33 | import cv2 34 | 35 | import model 36 | import utils.data 37 | import utils.train 38 | import utils.visualize 39 | import transform.augmentation 40 | 41 | 42 | class Visualizer(QtWidgets.QDialog): 43 | def __init__(self, name, image, feature, alpha=0.5): 44 | super(Visualizer, self).__init__() 45 | self.name = name 46 | self.image = image 47 | self.feature = feature 48 | self.draw_feature = utils.visualize.DrawFeature(alpha) 49 | 50 | layout = QtWidgets.QVBoxLayout(self) 51 | fig = plt.Figure() 52 | self.ax = fig.gca() 53 | self.canvas = qtagg.FigureCanvasQTAgg(fig) 54 | layout.addWidget(self.canvas) 55 | toolbar = qtagg.NavigationToolbar2QT(self.canvas, self) 56 | layout.addWidget(toolbar) 57 | self.slider = QtWidgets.QSlider(QtCore.Qt.Horizontal, self) 58 | self.slider.setRange(0, feature.shape[0] - 1) 59 | layout.addWidget(self.slider) 60 | self.slider.valueChanged[int].connect(self.on_progress) 61 | 62 | self.ax.imshow(self.image) 63 | self.ax.set_xticks([]) 64 | self.ax.set_yticks([]) 65 | self.on_progress(0) 66 | 67 | def on_progress(self, index): 68 | try: 69 | self.last.remove() 70 | except AttributeError: 71 | pass 72 | image = np.copy(self.image) 73 | feature = self.feature[index, :, :] 74 | image = self.draw_feature(image, feature) 75 | self.last = self.ax.imshow(image) 76 | self.canvas.draw() 77 | plt.draw() 78 | self.setWindowTitle('%s %d/%d' % (self.name, index + 1, self.feature.shape[0])) 79 | 80 | 81 | def main(): 82 | args = make_args() 83 | config = configparser.ConfigParser() 84 | utils.load_config(config, args.config) 85 | for cmd in args.modify: 86 | utils.modify_config(config, cmd) 87 | with open(os.path.expanduser(os.path.expandvars(args.logging)), 'r') as f: 88 | logging.config.dictConfig(yaml.load(f)) 89 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 90 | cache_dir = utils.get_cache_dir(config) 91 | _, num_parts = utils.get_dataset_mappers(config) 92 | limbs_index = utils.get_limbs_index(config) 93 | dnn = utils.parse_attr(config.get('model', 'dnn'))(model.ConfigChannels(config)).to(device) 94 | logging.info(humanize.naturalsize(sum(var.cpu().numpy().nbytes for var in dnn.state_dict().values()))) 95 | size = tuple(map(int, config.get('image', 'size').split())) 96 | draw_points = utils.visualize.DrawPoints(limbs_index, colors=config.get('draw_points', 'colors').split()) 97 | _draw_points = utils.visualize.DrawPoints(limbs_index, thickness=1) 98 | draw_bbox = utils.visualize.DrawBBox() 99 | paths = [os.path.join(cache_dir, phase + '.pkl') for phase in args.phase] 100 | dataset = utils.data.Dataset( 101 | config, 102 | utils.data.load_pickles(paths), 103 | transform=transform.augmentation.get_transform(config, config.get('transform', 'augmentation').split()), 104 | shuffle=config.getboolean('data', 'shuffle'), 105 | ) 106 | logging.info('num_examples=%d' % len(dataset)) 107 | try: 108 | workers = config.getint('data', 'workers') 109 | except configparser.NoOptionError: 110 | workers = multiprocessing.cpu_count() 111 | collate_fn = utils.data.Collate( 112 | config, 113 | transform.parse_transform(config, config.get('transform', 'resize_train')), 114 | 
[size], [dnn(torch.randn(1, 3, *size).to(device)).size()[-2:]], 115 | maintain=config.getint('data', 'maintain'), 116 | transform_image=transform.get_transform(config, config.get('transform', 'image_train').split()), 117 | ) 118 | loader = torch.utils.data.DataLoader(dataset, shuffle=True, num_workers=workers, collate_fn=collate_fn) 119 | for data in loader: 120 | path, size, image, mask, keypoints, yx_min, yx_max, parts, limbs, index = (t.numpy() if hasattr(t, 'numpy') else t for t in (data[key] for key in 'path, size, image, mask, keypoints, yx_min, yx_max, parts, limbs, index'.split(', '))) 121 | for path, size, image, mask, keypoints, yx_min, yx_max, parts, limbs, index in zip(*[path, size, image, mask, keypoints, yx_min, yx_max, parts, limbs, index]): 122 | logging.info(path + ': ' + 'x'.join(map(str, size))) 123 | image = utils.visualize.draw_mask(image, mask, 1) 124 | size = yx_max - yx_min 125 | target = np.logical_and(*[np.squeeze(a, -1) > 0 for a in np.split(size, size.shape[-1], -1)]) 126 | keypoints, yx_min, yx_max = (a[target] for a in (keypoints, yx_min, yx_max)) 127 | for i, points in enumerate(keypoints): 128 | if i == index: 129 | image = draw_points(image, points) 130 | else: 131 | image = _draw_points(image, points) 132 | image = draw_bbox(image, yx_min.astype(np.int), yx_max.astype(np.int)) 133 | dialog = Visualizer('parts', image, parts) 134 | dialog.exec() 135 | dialog = Visualizer('limbs', image, limbs) 136 | dialog.exec() 137 | 138 | 139 | def make_args(): 140 | parser = argparse.ArgumentParser() 141 | parser.add_argument('-c', '--config', nargs='+', default=['config.ini'], help='config file') 142 | parser.add_argument('-m', '--modify', nargs='+', default=[], help='modify config') 143 | parser.add_argument('-p', '--phase', nargs='+', default=['train', 'val', 'test']) 144 | parser.add_argument('--logging', default='logging.yml', help='logging config') 145 | return parser.parse_args() 146 | 147 | if __name__ == '__main__': 148 | app = QtWidgets.QApplication(sys.argv) 149 | main() 150 | sys.exit(app.exec_()) 151 | -------------------------------------------------------------------------------- /donate_alipay.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruiminshen/openpose-pytorch/f850084194ddccc6d401d5b11f61facc20ec2b75/donate_alipay.jpg -------------------------------------------------------------------------------- /donate_mm.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruiminshen/openpose-pytorch/f850084194ddccc6d401d5b11f61facc20ec2b75/donate_mm.jpg -------------------------------------------------------------------------------- /estimate.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen) 3 | 4 | This program is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU Lesser General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 8 | 9 | This program is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with this program. If not, see . 
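Note: estimate.py below wires the [nms], [integration] and [cluster] sections of the .ini files into a single pyopenpose.estimate call. A sketch of that contract with the numbers from config/inception_unet.ini hard-coded; group_keypoints is an illustrative wrapper, and the exact array types expected by the extension are whatever utils.get_limbs_index and the network actually produce:

import pyopenpose  # the C++ extension this repository depends on


def group_keypoints(parts, limbs, limbs_index):
    # parts: confidence maps with the background channel already removed (see parts[:-1] below),
    # limbs: part-affinity fields, two channels per limb; both optionally resized to the input resolution.
    return pyopenpose.estimate(
        parts, limbs, limbs_index,
        0.05,          # [nms] threshold
        5.0, (5, 25),  # [integration] step and step_limits
        0.05, 9,       # [integration] min_score and min_count
        0.4, 3,        # [cluster] min_score and min_count
    )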
16 | """ 17 | 18 | import argparse 19 | import configparser 20 | import logging 21 | import logging.config 22 | import os 23 | import time 24 | import re 25 | import yaml 26 | 27 | import numpy as np 28 | import torch.autograd 29 | import torch.cuda 30 | import torch.optim 31 | import torch.utils.data 32 | import torch.nn as nn 33 | try: 34 | from caffe2.proto import caffe2_pb2 35 | from caffe2.python import workspace 36 | except ImportError: 37 | pass 38 | import humanize 39 | import pybenchmark 40 | import cv2 41 | 42 | import transform 43 | import model 44 | import utils.train 45 | import utils.visualize 46 | import pyopenpose 47 | 48 | 49 | class Estimate(object): 50 | def __init__(self, args, config): 51 | self.args = args 52 | self.config = config 53 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 54 | self.cache_dir = utils.get_cache_dir(config) 55 | self.model_dir = utils.get_model_dir(config) 56 | _, self.num_parts = utils.get_dataset_mappers(config) 57 | self.limbs_index = utils.get_limbs_index(config) 58 | if args.debug is None: 59 | self.draw_cluster = utils.visualize.DrawCluster(colors=args.colors, thickness=args.thickness) 60 | else: 61 | self.draw_feature = utils.visualize.DrawFeature() 62 | s = re.search('(-?[0-9]+)([a-z]+)(-?[0-9]+)', args.debug) 63 | stage = int(s.group(1)) 64 | name = s.group(2) 65 | channel = int(s.group(3)) 66 | self.get_feature = lambda outputs: outputs[stage][name][0][channel] 67 | self.height, self.width = tuple(map(int, config.get('image', 'size').split())) 68 | if args.caffe: 69 | init_net = caffe2_pb2.NetDef() 70 | with open(os.path.join(self.model_dir, 'init_net.pb'), 'rb') as f: 71 | init_net.ParseFromString(f.read()) 72 | predict_net = caffe2_pb2.NetDef() 73 | with open(os.path.join(self.model_dir, 'predict_net.pb'), 'rb') as f: 74 | predict_net.ParseFromString(f.read()) 75 | p = workspace.Predictor(init_net, predict_net) 76 | self.inference = lambda tensor: [{'parts': torch.from_numpy(parts), 'limbs': torch.from_numpy(limbs)} for parts, limbs in zip(*[iter(p.run([tensor.detach().cpu().numpy()]))] * 2)] 77 | else: 78 | self.step, self.epoch, self.dnn, self.stages = self.load() 79 | self.inference = model.Inference(config, self.dnn, self.stages) 80 | self.inference.eval() 81 | if torch.cuda.is_available(): 82 | self.inference.cuda() 83 | logging.info(humanize.naturalsize(sum(var.cpu().numpy().nbytes for var in self.inference.state_dict().values()))) 84 | self.cap = self.create_cap() 85 | self.keys = set(args.keys) 86 | self.resize = transform.parse_transform(config, config.get('transform', 'resize_test')) 87 | self.transform_image = transform.get_transform(config, config.get('transform', 'image_test').split()) 88 | self.transform_tensor = transform.get_transform(config, config.get('transform', 'tensor').split()) 89 | 90 | def __del__(self): 91 | cv2.destroyAllWindows() 92 | try: 93 | self.writer.release() 94 | except AttributeError: 95 | pass 96 | self.cap.release() 97 | 98 | def load(self): 99 | path, step, epoch = utils.train.load_model(self.model_dir) 100 | state_dict = torch.load(path, map_location=lambda storage, loc: storage) 101 | config_channels_dnn = model.ConfigChannels(self.config, state_dict['dnn']) 102 | dnn = utils.parse_attr(self.config.get('model', 'dnn'))(config_channels_dnn) 103 | config_channels_stages = model.ConfigChannels(self.config, state_dict['stages'], config_channels_dnn.channels) 104 | channel_dict = model.channel_dict(self.num_parts, len(self.limbs_index)) 105 | stages = 
nn.Sequential(*[utils.parse_attr(s)(config_channels_stages, channel_dict, config_channels_dnn.channels, str(i)) for i, s in enumerate(self.config.get('model', 'stages').split())]) 106 | dnn.load_state_dict(config_channels_dnn.state_dict) 107 | stages.load_state_dict(config_channels_stages.state_dict) 108 | return step, epoch, dnn, stages 109 | 110 | def create_cap(self): 111 | try: 112 | cap = int(self.args.input) 113 | except ValueError: 114 | cap = os.path.expanduser(os.path.expandvars(self.args.input)) 115 | assert os.path.exists(cap) 116 | return cv2.VideoCapture(cap) 117 | 118 | def create_writer(self, height, width): 119 | fps = self.cap.get(cv2.CAP_PROP_FPS) 120 | logging.info('cap fps=%f' % fps) 121 | path = os.path.expanduser(os.path.expandvars(self.args.output)) 122 | if self.args.fourcc: 123 | fourcc = cv2.VideoWriter_fourcc(*self.args.fourcc.upper()) 124 | else: 125 | fourcc = int(self.cap.get(cv2.CAP_PROP_FOURCC)) 126 | os.makedirs(os.path.dirname(path), exist_ok=True) 127 | return cv2.VideoWriter(path, fourcc, fps, (width, height)) 128 | 129 | def get_image(self): 130 | ret, image_bgr = self.cap.read() 131 | if self.args.crop: 132 | image_bgr = image_bgr[self.crop_ymin:self.crop_ymax, self.crop_xmin:self.crop_xmax] 133 | return image_bgr 134 | 135 | def __call__(self): 136 | image_bgr = self.get_image() 137 | image_resized = self.resize(image_bgr, self.height, self.width) 138 | image = self.transform_image(image_resized) 139 | tensor = self.transform_tensor(image) 140 | tensor = tensor.unsqueeze(0).to(self.device) 141 | outputs = pybenchmark.profile('inference')(self.inference)(tensor) 142 | if hasattr(self, 'draw_cluster'): 143 | output = outputs[-1] 144 | parts, limbs = (output[name][0] for name in 'parts, limbs'.split(', ')) 145 | parts = parts[:-1] 146 | parts, limbs = (t.detach().cpu().numpy() for t in (parts, limbs)) 147 | try: 148 | interpolation = getattr(cv2, 'INTER_' + self.config.get('estimate', 'interpolation').upper()) 149 | parts, limbs = (np.stack([cv2.resize(feature, (self.width, self.height), interpolation=interpolation) for feature in a]) for a in (parts, limbs)) 150 | except configparser.NoOptionError: 151 | pass 152 | clusters = pyopenpose.estimate( 153 | parts, limbs, 154 | self.limbs_index, 155 | self.config.getfloat('nms', 'threshold'), 156 | self.config.getfloat('integration', 'step'), tuple(map(int, self.config.get('integration', 'step_limits').split())), self.config.getfloat('integration', 'min_score'), self.config.getint('integration', 'min_count'), 157 | self.config.getfloat('cluster', 'min_score'), self.config.getint('cluster', 'min_count'), 158 | ) 159 | scale_y, scale_x = self.resize.scale(parts.shape[-2:], image_bgr.shape[:2]) 160 | image_result = image_bgr.copy() 161 | for cluster in clusters: 162 | cluster = [((i1, int(y1 * scale_y), int(x1 * scale_x)), (i2, int(y2 * scale_y), int(x2 * scale_x))) for (i1, y1, x1), (i2, y2, x2) in cluster] 163 | image_result = self.draw_cluster(image_result, cluster) 164 | else: 165 | image_result = image_resized.copy() 166 | feature = self.get_feature(outputs).detach().cpu().numpy() 167 | image_result = self.draw_feature(image_result, feature) 168 | if self.args.output: 169 | if not hasattr(self, 'writer'): 170 | self.writer = self.create_writer(*image_result.shape[:2]) 171 | self.writer.write(image_result) 172 | else: 173 | cv2.imshow('estimate', image_result) 174 | if cv2.waitKey(0 if self.args.pause else 1) in self.keys: 175 | root = os.path.join(self.model_dir, 'snapshot') 176 | os.makedirs(root, 
exist_ok=True) 177 | path = os.path.join(root, time.strftime(self.args.format)) 178 | cv2.imwrite(path, image_bgr) 179 | logging.warning('image dumped into ' + path) 180 | 181 | 182 | def main(): 183 | args = make_args() 184 | config = configparser.ConfigParser() 185 | utils.load_config(config, args.config) 186 | for cmd in args.modify: 187 | utils.modify_config(config, cmd) 188 | with open(os.path.expanduser(os.path.expandvars(args.logging)), 'r') as f: 189 | logging.config.dictConfig(yaml.load(f)) 190 | estimate = Estimate(args, config) 191 | try: 192 | with torch.no_grad(): 193 | while estimate.cap.isOpened(): 194 | estimate() 195 | except KeyboardInterrupt: 196 | logging.warning('interrupted') 197 | finally: 198 | logging.info(pybenchmark.stats) 199 | 200 | 201 | def make_args(): 202 | parser = argparse.ArgumentParser() 203 | parser.add_argument('-c', '--config', nargs='+', default=['config.ini'], help='config file') 204 | parser.add_argument('-m', '--modify', nargs='+', default=[], help='modify config') 205 | parser.add_argument('-i', '--input', default=-1) 206 | parser.add_argument('-k', '--keys', nargs='+', type=int, default=[ord(' ')], help='keys to dump images') 207 | parser.add_argument('-o', '--output', help='output video file') 208 | parser.add_argument('-f', '--format', default='%Y-%m-%d_%H-%M-%S.jpg', help='dump file name format') 209 | parser.add_argument('--crop', nargs='+', type=float, default=[], help='ymin ymax xmin xmax') 210 | parser.add_argument('--pause', action='store_true') 211 | parser.add_argument('--fourcc', default='XVID', help='4-character code of codec used to compress the frames, such as XVID, MJPG') 212 | parser.add_argument('--thickness', default=3, type=int) 213 | parser.add_argument('--colors', nargs='+', default=[]) 214 | parser.add_argument('-d', '--debug') 215 | parser.add_argument('--caffe', action='store_true') 216 | parser.add_argument('--logging', default='logging.yml', help='logging config') 217 | return parser.parse_args() 218 | 219 | 220 | if __name__ == '__main__': 221 | main() 222 | -------------------------------------------------------------------------------- /logging.yml: -------------------------------------------------------------------------------- 1 | version: 1 2 | formatters: 3 | simple: 4 | format: "%(asctime)s - %(name)s - %(levelname)s - %(message)s" 5 | handlers: 6 | console: 7 | class: logging.StreamHandler 8 | level: INFO 9 | formatter: simple 10 | stream: ext://sys.stderr 11 | root: 12 | level: INFO 13 | handlers: [console] -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen) 3 | 4 | This program is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU Lesser General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 8 | 9 | This program is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with this program. If not, see . 
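Note: channel_dict below fixes the per-stage head sizes at num_parts + 1 confidence maps (the extra channel that estimate.py above strips with parts[:-1], presumably a background map) and num_limbs * 2 part-affinity channels (presumably an x and a y field per limb). For the original_person18_19 configuration, 18 parts and 19 limbs, that reproduces the familiar OpenPose 19/38 output:

from model import channel_dict

print(channel_dict(num_parts=18, num_limbs=19))
# OrderedDict([('parts', 19), ('limbs', 38)])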
16 | """ 17 | 18 | import logging 19 | import collections 20 | 21 | import torch 22 | import torch.nn as nn 23 | import torch.autograd 24 | 25 | 26 | class ConfigChannels(object): 27 | def __init__(self, config, state_dict=None, channels=3): 28 | self.config = config 29 | self.state_dict = state_dict 30 | self.channels = channels 31 | 32 | def __call__(self, default, name, fn=lambda var: var.size(0)): 33 | if self.state_dict is None: 34 | self.channels = default 35 | else: 36 | var = self.state_dict[name] 37 | self.channels = fn(var) 38 | if self.channels != default: 39 | logging.warning('%s: change number of output channels from %d to %d' % (name, default, self.channels)) 40 | return self.channels 41 | 42 | 43 | def channel_dict(num_parts, num_limbs): 44 | return collections.OrderedDict([ 45 | ('parts', num_parts + 1), 46 | ('limbs', num_limbs * 2), 47 | ]) 48 | 49 | 50 | class Inference(nn.Module): 51 | def __init__(self, config, dnn, stages): 52 | nn.Module.__init__(self) 53 | self.config = config 54 | self.dnn = dnn 55 | self.stages = stages 56 | 57 | def forward(self, x): 58 | x = self.dnn(x) 59 | outputs = [] 60 | output = {} 61 | for stage in self.stages: 62 | output = stage(x, **output) 63 | outputs.append(output) 64 | return outputs 65 | 66 | 67 | class Loss(object): 68 | def __init__(self, config, data, limbs_index, height, width): 69 | self.config = config 70 | self.data = data 71 | self.limbs_index = limbs_index 72 | self.height = height 73 | self.width = width 74 | 75 | def __call__(self, **kwargs): 76 | mask = self.data['mask'].float() 77 | batch_size, rows, cols = mask.size() 78 | mask = mask.view(batch_size, 1, rows, cols) 79 | data = {name: self.data[name] for name in kwargs} 80 | return {name: self.loss(mask, data[name], feature) for name, feature in kwargs.items()} 81 | 82 | def loss(self, mask, label, feature): 83 | return torch.mean(mask * (feature - label) ** 2) 84 | -------------------------------------------------------------------------------- /model/dnn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruiminshen/openpose-pytorch/f850084194ddccc6d401d5b11f61facc20ec2b75/model/dnn/__init__.py -------------------------------------------------------------------------------- /model/dnn/inception4.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen) 3 | 4 | This program is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU Lesser General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 8 | 9 | This program is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with this program. If not, see . 
16 | """ 17 | 18 | import logging 19 | import configparser 20 | import collections.abc 21 | 22 | import torch 23 | import torch.nn as nn 24 | from pretrainedmodels.models.inceptionv4 import pretrained_settings 25 | 26 | 27 | class Conv2d(nn.Module): 28 | def __init__(self, in_channels, out_channels, kernel_size, padding=0, stride=1, bn=True, act=True): 29 | nn.Module.__init__(self) 30 | if isinstance(padding, bool): 31 | if isinstance(kernel_size, collections.abc.Iterable): 32 | padding = [(kernel_size - 1) // 2 for kernel_size in kernel_size] if padding else 0 33 | else: 34 | padding = (kernel_size - 1) // 2 if padding else 0 35 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding=padding, bias=not bn) 36 | self.bn = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.1, affine=True) if bn else lambda x: x 37 | self.act = nn.ReLU(inplace=True) if act else lambda x: x 38 | 39 | def forward(self, x): 40 | x = self.conv(x) 41 | x = self.bn(x) 42 | x = self.act(x) 43 | return x 44 | 45 | 46 | class Mixed_3a(nn.Module): 47 | def __init__(self, config_channels, prefix, bn=True, ratio=1): 48 | nn.Module.__init__(self) 49 | channels = config_channels.channels 50 | self.maxpool = nn.MaxPool2d(3, stride=2) 51 | self.conv = Conv2d(config_channels.channels, config_channels(int(96 * ratio), '%s.conv.conv.weight' % prefix), kernel_size=3, stride=2, bn=bn) 52 | config_channels.channels = channels + self.conv.conv.weight.size(0) 53 | 54 | def forward(self, x): 55 | x0 = self.maxpool(x) 56 | x1 = self.conv(x) 57 | out = torch.cat((x0, x1), 1) 58 | return out 59 | 60 | 61 | class Mixed_4a(nn.Module): 62 | def __init__(self, config_channels, prefix, bn=True, ratio=1): 63 | nn.Module.__init__(self) 64 | # branch0 65 | channels = config_channels.channels 66 | branch = [] 67 | branch.append(Conv2d(config_channels.channels, config_channels(int(64 * ratio), '%s.branch0.%d.conv.weight' % (prefix, len(branch))), kernel_size=1, stride=1, bn=bn)) 68 | branch.append(Conv2d(config_channels.channels, config_channels(int(96 * ratio), '%s.branch0.%d.conv.weight' % (prefix, len(branch))), kernel_size=3, stride=1, bn=bn)) 69 | self.branch0 = nn.Sequential(*branch) 70 | # branch1 71 | config_channels.channels = channels 72 | branch = [] 73 | branch.append(Conv2d(config_channels.channels, config_channels(int(64 * ratio), '%s.branch1.%d.conv.weight' % (prefix, len(branch))), kernel_size=1, stride=1, bn=bn)) 74 | branch.append(Conv2d(config_channels.channels, config_channels(int(64 * ratio), '%s.branch1.%d.conv.weight' % (prefix, len(branch))), kernel_size=(1, 7), stride=1, padding=(0, 3), bn=bn)) 75 | branch.append(Conv2d(config_channels.channels, config_channels(int(64 * ratio), '%s.branch1.%d.conv.weight' % (prefix, len(branch))), kernel_size=(7, 1), stride=1, padding=(3, 0), bn=bn)) 76 | branch.append(Conv2d(config_channels.channels, config_channels(int(96 * ratio), '%s.branch1.%d.conv.weight' % (prefix, len(branch))), kernel_size=(3, 3), stride=1, bn=bn)) 77 | self.branch1 = nn.Sequential(*branch) 78 | # output 79 | config_channels.channels = self.branch0[-1].conv.weight.size(0) + self.branch1[-1].conv.weight.size(0) 80 | 81 | def forward(self, x): 82 | x0 = self.branch0(x) 83 | x1 = self.branch1(x) 84 | out = torch.cat((x0, x1), 1) 85 | return out 86 | 87 | 88 | class Mixed_5a(nn.Module): 89 | def __init__(self, config_channels, prefix, bn=True, ratio=1): 90 | nn.Module.__init__(self) 91 | channels = config_channels.channels 92 | self.conv = Conv2d(config_channels.channels, 
config_channels(int(192 * ratio), '%s.conv.conv.weight' % prefix), kernel_size=3, stride=2, bn=bn) 93 | self.maxpool = nn.MaxPool2d(3, stride=2) 94 | config_channels.channels = self.conv.conv.weight.size(0) + channels 95 | 96 | def forward(self, x): 97 | x0 = self.conv(x) 98 | x1 = self.maxpool(x) 99 | out = torch.cat((x0, x1), 1) 100 | return out 101 | 102 | 103 | class Inception_A(nn.Module): 104 | def __init__(self, config_channels, prefix, bn=True, ratio=1): 105 | nn.Module.__init__(self) 106 | channels = config_channels.channels 107 | self.branch0 = Conv2d(config_channels.channels, config_channels(int(96 * ratio), '%s.branch0.conv.weight' % prefix), kernel_size=1, stride=1, bn=bn) 108 | # branch1 109 | config_channels.channels = channels 110 | branch = [] 111 | branch.append(Conv2d(config_channels.channels, config_channels(int(64 * ratio), '%s.branch1.%d.conv.weight' % (prefix, len(branch))), kernel_size=1, stride=1, bn=bn)) 112 | branch.append(Conv2d(config_channels.channels, config_channels(int(96 * ratio), '%s.branch1.%d.conv.weight' % (prefix, len(branch))), kernel_size=3, stride=1, padding=1, bn=bn)) 113 | self.branch1 = nn.Sequential(*branch) 114 | # branch2 115 | config_channels.channels = channels 116 | branch = [] 117 | branch.append(Conv2d(config_channels.channels, config_channels(int(64 * ratio), '%s.branch2.%d.conv.weight' % (prefix, len(branch))), kernel_size=1, stride=1, bn=bn)) 118 | branch.append(Conv2d(config_channels.channels, config_channels(int(96 * ratio), '%s.branch2.%d.conv.weight' % (prefix, len(branch))), kernel_size=3, stride=1, padding=1, bn=bn)) 119 | branch.append(Conv2d(config_channels.channels, config_channels(int(96 * ratio), '%s.branch2.%d.conv.weight' % (prefix, len(branch))), kernel_size=3, stride=1, padding=1, bn=bn)) 120 | self.branch2 = nn.Sequential(*branch) 121 | #branch3 122 | config_channels.channels = channels 123 | branch = [] 124 | branch.append(nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)) 125 | branch.append(Conv2d(config_channels.channels, config_channels(int(96 * ratio), '%s.branch3.%d.conv.weight' % (prefix, len(branch))), kernel_size=1, stride=1, bn=bn)) 126 | self.branch3 = nn.Sequential(*branch) 127 | # output 128 | config_channels.channels = self.branch0.conv.weight.size(0) + self.branch1[-1].conv.weight.size(0) + self.branch2[-1].conv.weight.size(0) + self.branch3[-1].conv.weight.size(0) 129 | 130 | def forward(self, x): 131 | x0 = self.branch0(x) 132 | x1 = self.branch1(x) 133 | x2 = self.branch2(x) 134 | x3 = self.branch3(x) 135 | out = torch.cat((x0, x1, x2, x3), 1) 136 | return out 137 | 138 | 139 | class Reduction_A(nn.Module): 140 | def __init__(self, config_channels, prefix, bn=True, ratio=1): 141 | nn.Module.__init__(self) 142 | channels = config_channels.channels 143 | self.branch0 = Conv2d(config_channels.channels, config_channels(int(384 * ratio), '%s.branch0.conv.weight' % prefix), kernel_size=3, stride=2, bn=bn) 144 | # branch1 145 | config_channels.channels = channels 146 | branch = [] 147 | branch.append(Conv2d(config_channels.channels, config_channels(int(192 * ratio), '%s.branch1.%d.conv.weight' % (prefix, len(branch))), kernel_size=1, stride=1, bn=bn)) 148 | branch.append(Conv2d(config_channels.channels, config_channels(int(224 * ratio), '%s.branch1.%d.conv.weight' % (prefix, len(branch))), kernel_size=3, stride=1, padding=1, bn=bn)) 149 | branch.append(Conv2d(config_channels.channels, config_channels(int(256 * ratio), '%s.branch1.%d.conv.weight' % (prefix, len(branch))), kernel_size=3, stride=2, 
bn=bn)) 150 | self.branch1 = nn.Sequential(*branch) 151 | 152 | self.branch2 = nn.MaxPool2d(3, stride=2) 153 | # output 154 | config_channels.channels = self.branch0.conv.weight.size(0) + self.branch1[-1].conv.weight.size(0) + channels 155 | 156 | def forward(self, x): 157 | x0 = self.branch0(x) 158 | x1 = self.branch1(x) 159 | x2 = self.branch2(x) 160 | out = torch.cat((x0, x1, x2), 1) 161 | return out 162 | 163 | 164 | class Inception_B(nn.Module): 165 | def __init__(self, config_channels, prefix, bn=True, ratio=1): 166 | nn.Module.__init__(self) 167 | channels = config_channels.channels 168 | self.branch0 = Conv2d(config_channels.channels, config_channels(int(384 * ratio), '%s.branch0.conv.weight' % prefix), kernel_size=1, stride=1, bn=bn) 169 | # branch1 170 | config_channels.channels = channels 171 | branch = [] 172 | branch.append(Conv2d(config_channels.channels, config_channels(int(192 * ratio), '%s.branch1.%d.conv.weight' % (prefix, len(branch))), kernel_size=1, stride=1, bn=bn)) 173 | branch.append(Conv2d(config_channels.channels, config_channels(int(224 * ratio), '%s.branch1.%d.conv.weight' % (prefix, len(branch))), kernel_size=(1, 7), stride=1, padding=(0, 3), bn=bn)) 174 | branch.append(Conv2d(config_channels.channels, config_channels(int(256 * ratio), '%s.branch1.%d.conv.weight' % (prefix, len(branch))), kernel_size=(7, 1), stride=1, padding=(3, 0), bn=bn)) 175 | self.branch1 = nn.Sequential(*branch) 176 | # branch2 177 | config_channels.channels = channels 178 | branch = [] 179 | branch.append(Conv2d(config_channels.channels, config_channels(int(192 * ratio), '%s.branch2.%d.conv.weight' % (prefix, len(branch))), kernel_size=1, stride=1, bn=bn)) 180 | branch.append(Conv2d(config_channels.channels, config_channels(int(192 * ratio), '%s.branch2.%d.conv.weight' % (prefix, len(branch))), kernel_size=(7, 1), stride=1, padding=(3, 0), bn=bn)) 181 | branch.append(Conv2d(config_channels.channels, config_channels(int(224 * ratio), '%s.branch2.%d.conv.weight' % (prefix, len(branch))), kernel_size=(1, 7), stride=1, padding=(0, 3), bn=bn)) 182 | branch.append(Conv2d(config_channels.channels, config_channels(int(224 * ratio), '%s.branch2.%d.conv.weight' % (prefix, len(branch))), kernel_size=(7, 1), stride=1, padding=(3, 0), bn=bn)) 183 | branch.append(Conv2d(config_channels.channels, config_channels(int(256 * ratio), '%s.branch2.%d.conv.weight' % (prefix, len(branch))), kernel_size=(1, 7), stride=1, padding=(0, 3), bn=bn)) 184 | self.branch2 = nn.Sequential(*branch) 185 | # branch3 186 | config_channels.channels = channels 187 | branch = [] 188 | branch.append(nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)) 189 | branch.append(Conv2d(config_channels.channels, config_channels(int(128 * ratio), '%s.branch3.%d.conv.weight' % (prefix, len(branch))), kernel_size=1, stride=1, bn=bn)) 190 | self.branch3 = nn.Sequential(*branch) 191 | # output 192 | config_channels.channels = self.branch0.conv.weight.size(0) + self.branch1[-1].conv.weight.size(0) + self.branch2[-1].conv.weight.size(0) + self.branch3[-1].conv.weight.size(0) 193 | 194 | def forward(self, x): 195 | x0 = self.branch0(x) 196 | x1 = self.branch1(x) 197 | x2 = self.branch2(x) 198 | x3 = self.branch3(x) 199 | out = torch.cat((x0, x1, x2, x3), 1) 200 | return out 201 | 202 | 203 | class Reduction_B(nn.Module): 204 | def __init__(self, config_channels, prefix, bn=True, ratio=1): 205 | nn.Module.__init__(self) 206 | # branch0 207 | channels = config_channels.channels 208 | branch = [] 209 | 
branch.append(Conv2d(config_channels.channels, config_channels(int(192 * ratio), '%s.branch0.%d.conv.weight' % (prefix, len(branch))), kernel_size=1, stride=1, bn=bn)) 210 | branch.append(Conv2d(config_channels.channels, config_channels(int(192 * ratio), '%s.branch0.%d.conv.weight' % (prefix, len(branch))), kernel_size=3, stride=2, bn=bn)) 211 | self.branch0 = nn.Sequential(*branch) 212 | # branch1 213 | config_channels.channels = channels 214 | branch = [] 215 | branch.append(Conv2d(config_channels.channels, config_channels(int(256 * ratio), '%s.branch1.%d.conv.weight' % (prefix, len(branch))), kernel_size=1, stride=1, bn=bn)) 216 | branch.append(Conv2d(config_channels.channels, config_channels(int(256 * ratio), '%s.branch1.%d.conv.weight' % (prefix, len(branch))), kernel_size=(1, 7), stride=1, padding=(0, 3), bn=bn)) 217 | branch.append(Conv2d(config_channels.channels, config_channels(int(320 * ratio), '%s.branch1.%d.conv.weight' % (prefix, len(branch))), kernel_size=(7, 1), stride=1, padding=(3, 0), bn=bn)) 218 | branch.append(Conv2d(config_channels.channels, config_channels(int(320 * ratio), '%s.branch1.%d.conv.weight' % (prefix, len(branch))), kernel_size=3, stride=2, bn=bn)) 219 | self.branch1 = nn.Sequential(*branch) 220 | self.branch2 = nn.MaxPool2d(3, stride=2) 221 | # output 222 | config_channels.channels = self.branch0[-1].conv.weight.size(0) + self.branch1[-1].conv.weight.size(0) + channels 223 | 224 | def forward(self, x): 225 | x0 = self.branch0(x) 226 | x1 = self.branch1(x) 227 | x2 = self.branch2(x) 228 | out = torch.cat((x0, x1, x2), 1) 229 | return out 230 | 231 | 232 | class Inception_C(nn.Module): 233 | def __init__(self, config_channels, prefix, bn=True, ratio=1): 234 | nn.Module.__init__(self) 235 | channels = config_channels.channels 236 | self.branch0 = Conv2d(config_channels.channels, config_channels(int(256 * ratio), '%s.branch0.conv.weight' % prefix), kernel_size=1, stride=1, bn=bn) 237 | # branch1 238 | config_channels.channels = channels 239 | self.branch1_0 = Conv2d(config_channels.channels, config_channels(int(384 * ratio), '%s.branch1_0.conv.weight' % prefix), kernel_size=1, stride=1, bn=bn) 240 | _channels = config_channels.channels 241 | self.branch1_1a = Conv2d(_channels, config_channels(int(256 * ratio), '%s.branch1_1a.conv.weight' % prefix), kernel_size=(1, 3), stride=1, padding=(0, 1), bn=bn) 242 | self.branch1_1b = Conv2d(_channels, config_channels(int(256 * ratio), '%s.branch1_1b.conv.weight' % prefix), kernel_size=(3, 1), stride=1, padding=(1, 0), bn=bn) 243 | # branch2 244 | config_channels.channels = channels 245 | self.branch2_0 = Conv2d(config_channels.channels, config_channels(int(384 * ratio), '%s.branch2_0.conv.weight' % prefix), kernel_size=1, stride=1, bn=bn) 246 | self.branch2_1 = Conv2d(config_channels.channels, config_channels(int(448 * ratio), '%s.branch2_1.conv.weight' % prefix), kernel_size=(3, 1), stride=1, padding=(1, 0), bn=bn) 247 | self.branch2_2 = Conv2d(config_channels.channels, config_channels(int(512 * ratio), '%s.branch2_2.conv.weight' % prefix), kernel_size=(1, 3), stride=1, padding=(0, 1), bn=bn) 248 | _channels = config_channels.channels 249 | self.branch2_3a = Conv2d(_channels, config_channels(int(256 * ratio), '%s.branch2_3a.conv.weight' % prefix), kernel_size=(1, 3), stride=1, padding=(0, 1), bn=bn) 250 | self.branch2_3b = Conv2d(_channels, config_channels(int(256 * ratio), '%s.branch2_3b.conv.weight' % prefix), kernel_size=(3, 1), stride=1, padding=(1, 0), bn=bn) 251 | # branch3 252 | config_channels.channels = 
channels 253 | branch = [] 254 | branch.append(nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)) 255 | branch.append(Conv2d(config_channels.channels, int(256 * ratio), kernel_size=1, stride=1, bn=bn)) 256 | self.branch3 = nn.Sequential(*branch) 257 | # output 258 | config_channels.channels = self.branch0.conv.weight.size(0) + self.branch1_1a.conv.weight.size(0) + self.branch1_1b.conv.weight.size(0) + self.branch2_3a.conv.weight.size(0) + self.branch2_3b.conv.weight.size(0) + self.branch3[-1].conv.weight.size(0) 259 | 260 | def forward(self, x): 261 | x0 = self.branch0(x) 262 | 263 | x1_0 = self.branch1_0(x) 264 | x1_1a = self.branch1_1a(x1_0) 265 | x1_1b = self.branch1_1b(x1_0) 266 | x1 = torch.cat((x1_1a, x1_1b), 1) 267 | 268 | x2_0 = self.branch2_0(x) 269 | x2_1 = self.branch2_1(x2_0) 270 | x2_2 = self.branch2_2(x2_1) 271 | x2_3a = self.branch2_3a(x2_2) 272 | x2_3b = self.branch2_3b(x2_2) 273 | x2 = torch.cat((x2_3a, x2_3b), 1) 274 | 275 | x3 = self.branch3(x) 276 | 277 | out = torch.cat((x0, x1, x2, x3), 1) 278 | return out 279 | 280 | 281 | class Inception4(nn.Module): 282 | def __init__(self, config_channels, ratio=1): 283 | nn.Module.__init__(self) 284 | features = [] 285 | bn = config_channels.config.getboolean('batch_norm', 'enable') 286 | features.append(Conv2d(config_channels.channels, config_channels(32, 'features.%d.conv.weight' % len(features)), kernel_size=3, stride=2, bn=bn)) 287 | features.append(Conv2d(config_channels.channels, config_channels(32, 'features.%d.conv.weight' % len(features)), kernel_size=3, stride=1, bn=bn)) 288 | features.append(Conv2d(config_channels.channels, config_channels(64, 'features.%d.conv.weight' % len(features)), kernel_size=3, stride=1, padding=1, bn=bn)) 289 | features.append(Mixed_3a(config_channels, 'features.%d' % len(features), bn=bn, ratio=ratio)) 290 | features.append(Mixed_4a(config_channels, 'features.%d' % len(features), bn=bn, ratio=ratio)) 291 | features.append(Mixed_5a(config_channels, 'features.%d' % len(features), bn=bn, ratio=ratio)) 292 | features.append(Inception_A(config_channels, 'features.%d' % len(features), bn=bn, ratio=ratio)) 293 | features.append(Inception_A(config_channels, 'features.%d' % len(features), bn=bn, ratio=ratio)) 294 | features.append(Inception_A(config_channels, 'features.%d' % len(features), bn=bn, ratio=ratio)) 295 | features.append(Inception_A(config_channels, 'features.%d' % len(features), bn=bn, ratio=ratio)) 296 | features.append(Reduction_A(config_channels, 'features.%d' % len(features), bn=bn, ratio=ratio)) # Mixed_6a 297 | features.append(Inception_B(config_channels, 'features.%d' % len(features), bn=bn, ratio=ratio)) 298 | features.append(Inception_B(config_channels, 'features.%d' % len(features), bn=bn, ratio=ratio)) 299 | features.append(Inception_B(config_channels, 'features.%d' % len(features), bn=bn, ratio=ratio)) 300 | features.append(Inception_B(config_channels, 'features.%d' % len(features), bn=bn, ratio=ratio)) 301 | features.append(Inception_B(config_channels, 'features.%d' % len(features), bn=bn, ratio=ratio)) 302 | features.append(Inception_B(config_channels, 'features.%d' % len(features), bn=bn, ratio=ratio)) 303 | features.append(Inception_B(config_channels, 'features.%d' % len(features), bn=bn, ratio=ratio)) 304 | features.append(Reduction_B(config_channels, 'features.%d' % len(features), bn=bn, ratio=ratio)) # Mixed_7a 305 | features.append(Inception_C(config_channels, 'features.%d' % len(features), bn=bn, ratio=ratio)) 306 | 
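# --- Illustrative note (added for clarity; not part of the original file) -----
# Rough stride bookkeeping for the backbone assembled above, assuming Mixed_4a keeps
# stride 1 as in the reference Inception-v4: the stem conv, Mixed_3a, Mixed_5a,
# Reduction_A and Reduction_B each halve the resolution, giving the full Inception4
# an output stride of 32, while the Inception4_down3_4 variant defined further below
# stops after the Inception_A blocks and keeps an output stride of 8.
stride2_stages = ['stem conv', 'Mixed_3a', 'Mixed_5a', 'Reduction_A', 'Reduction_B']
print(2 ** len(stride2_stages[:3]), 2 ** len(stride2_stages))  # 8 32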
features.append(Inception_C(config_channels, 'features.%d' % len(features), bn=bn, ratio=ratio)) 307 | features.append(Inception_C(config_channels, 'features.%d' % len(features), bn=bn, ratio=ratio)) 308 | self.features = nn.Sequential(*features) 309 | self.init(config_channels) 310 | 311 | def init(self, config_channels): 312 | try: 313 | gamma = config_channels.config.getboolean('batch_norm', 'gamma') 314 | except (configparser.NoSectionError, configparser.NoOptionError): 315 | gamma = True 316 | try: 317 | beta = config_channels.config.getboolean('batch_norm', 'beta') 318 | except (configparser.NoSectionError, configparser.NoOptionError): 319 | beta = True 320 | for m in self.modules(): 321 | if isinstance(m, nn.Conv2d): 322 | m.weight = nn.init.kaiming_normal_(m.weight) 323 | elif isinstance(m, nn.BatchNorm2d): 324 | m.weight.fill_(1) 325 | m.bias.zero_() 326 | m.weight.requires_grad = gamma 327 | m.bias.requires_grad = beta 328 | try: 329 | if config_channels.config.getboolean('model', 'pretrained'): 330 | settings = pretrained_settings['inceptionv4'][config_channels.config.get('inception4', 'pretrained')] 331 | logging.info('use pretrained model: ' + str(settings)) 332 | state_dict = self.state_dict() 333 | for key, value in torch.utils.model_zoo.load_url(settings['url']).items(): 334 | if key in state_dict: 335 | state_dict[key] = value 336 | self.load_state_dict(state_dict) 337 | except (configparser.NoSectionError, configparser.NoOptionError): 338 | pass 339 | 340 | def forward(self, x): 341 | return self.features(x) 342 | 343 | def scope(self, name): 344 | return '.'.join(name.split('.')[:-2]) 345 | 346 | 347 | class Inception4_down3_4(Inception4): 348 | def __init__(self, config_channels, ratio=1 / 4): 349 | nn.Module.__init__(self) 350 | features = [] 351 | bn = config_channels.config.getboolean('batch_norm', 'enable') 352 | features.append(Conv2d(config_channels.channels, config_channels(32, 'features.%d.conv.weight' % len(features)), kernel_size=3, stride=2, bn=bn)) 353 | features.append(Conv2d(config_channels.channels, config_channels(32, 'features.%d.conv.weight' % len(features)), kernel_size=3, stride=1, bn=bn)) 354 | features.append(Conv2d(config_channels.channels, config_channels(64, 'features.%d.conv.weight' % len(features)), kernel_size=3, stride=1, padding=1, bn=bn)) 355 | features.append(Mixed_3a(config_channels, 'features.%d' % len(features), bn=bn, ratio=ratio)) 356 | features.append(Mixed_4a(config_channels, 'features.%d' % len(features), bn=bn, ratio=ratio)) 357 | features.append(Mixed_5a(config_channels, 'features.%d' % len(features), bn=bn, ratio=ratio)) 358 | features.append(Inception_A(config_channels, 'features.%d' % len(features), bn=bn, ratio=ratio)) 359 | features.append(Inception_A(config_channels, 'features.%d' % len(features), bn=bn, ratio=ratio)) 360 | features.append(Inception_A(config_channels, 'features.%d' % len(features), bn=bn, ratio=ratio)) 361 | features.append(Inception_A(config_channels, 'features.%d' % len(features), bn=bn, ratio=ratio)) 362 | self.features = nn.Sequential(*features) 363 | self.init(config_channels) 364 | -------------------------------------------------------------------------------- /model/dnn/mobilenet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen) 3 | 4 | This program is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU Lesser General Public License as published by 6 | the Free Software 
Foundation, either version 3 of the License, or 7 | (at your option) any later version. 8 | 9 | This program is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with this program. If not, see . 16 | """ 17 | 18 | import collections 19 | 20 | import torch.nn as nn 21 | 22 | import model 23 | 24 | 25 | def conv_bn(in_channels, out_channels, stride): 26 | return nn.Sequential(collections.OrderedDict([ 27 | ('conv', nn.Conv2d(in_channels, out_channels, 3, stride, 1, bias=False)), 28 | ('bn', nn.BatchNorm2d(out_channels)), 29 | ('act', nn.ReLU(inplace=True)), 30 | ])) 31 | 32 | 33 | def conv_dw(in_channels, stride): 34 | return nn.Sequential(collections.OrderedDict([ 35 | ('conv', nn.Conv2d(in_channels, in_channels, 3, stride, 1, groups=in_channels, bias=False)), 36 | ('bn', nn.BatchNorm2d(in_channels)), 37 | ('act', nn.ReLU(inplace=True)), 38 | ])) 39 | 40 | 41 | def conv_pw(in_channels, out_channels): 42 | return nn.Sequential(collections.OrderedDict([ 43 | ('conv', nn.Conv2d(in_channels, out_channels, 1, 1, 0, bias=False)), 44 | ('bn', nn.BatchNorm2d(out_channels)), 45 | ('act', nn.ReLU(inplace=True)), 46 | ])) 47 | 48 | 49 | def conv_unit(in_channels, out_channels, stride): 50 | return nn.Sequential(collections.OrderedDict([ 51 | ('dw', conv_dw(in_channels, stride)), 52 | ('pw', conv_pw(in_channels, out_channels)), 53 | ])) 54 | 55 | 56 | class MobileNet(nn.Module): 57 | def __init__(self, config_channels): 58 | nn.Module.__init__(self) 59 | layers = [] 60 | layers.append(conv_bn(config_channels.channels, config_channels(32, 'layers.%d.conv.weight' % len(layers)), 2)) 61 | layers.append(conv_unit(config_channels.channels, config_channels(64, 'layers.%d.pw.conv.weight' % len(layers)), 1)) 62 | layers.append(conv_unit(config_channels.channels, config_channels(128, 'layers.%d.pw.conv.weight' % len(layers)), 2)) 63 | layers.append(conv_unit(config_channels.channels, config_channels(128, 'layers.%d.pw.conv.weight' % len(layers)), 1)) 64 | layers.append(conv_unit(config_channels.channels, config_channels(256, 'layers.%d.pw.conv.weight' % len(layers)), 2)) 65 | layers.append(conv_unit(config_channels.channels, config_channels(256, 'layers.%d.pw.conv.weight' % len(layers)), 1)) 66 | layers.append(conv_unit(config_channels.channels, config_channels(512, 'layers.%d.pw.conv.weight' % len(layers)), 2)) 67 | layers.append(conv_unit(config_channels.channels, config_channels(512, 'layers.%d.pw.conv.weight' % len(layers)), 1)) 68 | layers.append(conv_unit(config_channels.channels, config_channels(512, 'layers.%d.pw.conv.weight' % len(layers)), 1)) 69 | layers.append(conv_unit(config_channels.channels, config_channels(512, 'layers.%d.pw.conv.weight' % len(layers)), 1)) 70 | layers.append(conv_unit(config_channels.channels, config_channels(512, 'layers.%d.pw.conv.weight' % len(layers)), 1)) 71 | layers.append(conv_unit(config_channels.channels, config_channels(512, 'layers.%d.pw.conv.weight' % len(layers)), 1)) 72 | layers.append(conv_unit(config_channels.channels, config_channels(1024, 'layers.%d.pw.conv.weight' % len(layers)), 2)) 73 | layers.append(conv_unit(config_channels.channels, config_channels(1024, 'layers.%d.pw.conv.weight' % len(layers)), 1)) 74 | self.layers = nn.Sequential(*layers) 75 | 76 | for m in self.modules(): 77 | if 
isinstance(m, nn.Conv2d): 78 | m.weight = nn.init.kaiming_normal_(m.weight) 79 | elif isinstance(m, nn.BatchNorm2d): 80 | m.weight.fill_(1) 81 | m.bias.zero_() 82 | 83 | def forward(self, x): 84 | return self.layers(x) 85 | -------------------------------------------------------------------------------- /model/dnn/mobilenet2.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen) 3 | 4 | This program is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU Lesser General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 8 | 9 | This program is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with this program. If not, see . 16 | """ 17 | 18 | import torch.nn as nn 19 | import math 20 | 21 | 22 | def conv_bn(inp, oup, stride, dilation=1): 23 | return nn.Sequential( 24 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False, dilation=dilation), 25 | nn.BatchNorm2d(oup), 26 | nn.ReLU(inplace=True) 27 | ) 28 | 29 | 30 | def conv_1x1_bn(inp, oup): 31 | return nn.Sequential( 32 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 33 | nn.BatchNorm2d(oup), 34 | nn.ReLU(inplace=True) 35 | ) 36 | 37 | 38 | class InvertedResidual(nn.Module): 39 | def __init__(self, inp, oup, stride, expand_ratio): 40 | super(InvertedResidual, self).__init__() 41 | self.stride = stride 42 | assert stride in [1, 2] 43 | 44 | self.use_res_connect = self.stride == 1 and inp == oup 45 | 46 | self.conv = nn.Sequential( 47 | # pw 48 | nn.Conv2d(inp, inp * expand_ratio, 1, 1, 0, bias=False), 49 | nn.BatchNorm2d(inp * expand_ratio), 50 | nn.ReLU(inplace=True), 51 | # dw 52 | nn.Conv2d(inp * expand_ratio, inp * expand_ratio, 3, stride, 1, groups=inp * expand_ratio, bias=False), 53 | nn.BatchNorm2d(inp * expand_ratio), 54 | nn.ReLU(inplace=True), 55 | # pw-linear 56 | nn.Conv2d(inp * expand_ratio, oup, 1, 1, 0, bias=False), 57 | nn.BatchNorm2d(oup), 58 | ) 59 | 60 | def forward(self, x): 61 | if self.use_res_connect: 62 | return x + self.conv(x) 63 | else: 64 | return self.conv(x) 65 | 66 | 67 | class MobileNet2(nn.Module): 68 | def __init__(self, config_channels, input_size=224, last_channel=320, width_mult=1., dilation=1, ratio=1): 69 | nn.Module.__init__(self) 70 | # setting of inverted residual blocks 71 | self.interverted_residual_setting = [ 72 | # t, c, n, s 73 | [1, int(16 * ratio), 1, 1], 74 | [6, int(24 * ratio), 2, 2], 75 | [6, int(32 * ratio), 3, 2], 76 | [6, int(64 * ratio), 4, 1], # stride 2->1 77 | [6, int(96 * ratio), 3, 1], 78 | [6, int(160 * ratio), 3, 1], # stride 2->1 79 | [6, int(320 * ratio), 1, 1], 80 | ] 81 | 82 | # building first layer 83 | assert input_size % 32 == 0 84 | input_channel = int(32 * width_mult) 85 | if last_channel is None: 86 | self.last_channel = int(1280 * width_mult) if width_mult > 1.0 else 1280 87 | else: 88 | self.last_channel = int(last_channel * ratio) 89 | self.features = [conv_bn(3, input_channel, 2)] 90 | # building inverted residual blocks 91 | for t, c, n, s in self.interverted_residual_setting: 92 | output_channel = int(c * width_mult) 93 | for i in range(n): 94 | if i == 0: 95 | 
self.features.append(InvertedResidual(input_channel, output_channel, s, t)) 96 | else: 97 | self.features.append(InvertedResidual(input_channel, output_channel, 1, t)) 98 | input_channel = output_channel 99 | # building last several layers 100 | self.features.append(conv_bn(input_channel, self.last_channel, 1, dilation=dilation)) 101 | #self.features.append(nn.AvgPool2d(input_size/32)) 102 | config_channels.channels = self.last_channel # temp 103 | 104 | # make it nn.Sequential 105 | self.features = nn.Sequential(*self.features) 106 | 107 | self._initialize_weights() 108 | 109 | def forward(self, x): 110 | return self.features(x) 111 | 112 | def _initialize_weights(self): 113 | for m in self.modules(): 114 | if isinstance(m, nn.Conv2d): 115 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 116 | m.weight.data.normal_(0, math.sqrt(2. / n)) 117 | if m.bias is not None: 118 | m.bias.data.zero_() 119 | elif isinstance(m, nn.BatchNorm2d): 120 | m.weight.data.fill_(1) # PyTorch's bug 121 | m.bias.data.zero_() # PyTorch's bug 122 | elif isinstance(m, nn.Linear): 123 | m.weight.normal_(0, 0.01) 124 | m.bias.zero_() 125 | 126 | 127 | class MobileNet2Dilate2(MobileNet2): 128 | def __init__(self, config_channels): 129 | MobileNet2.__init__(self, config_channels, dilation=2) 130 | 131 | 132 | class MobileNet2Dilate4(MobileNet2): 133 | def __init__(self, config_channels): 134 | MobileNet2.__init__(self, config_channels, dilation=4) 135 | 136 | 137 | class MobileNet2Half(MobileNet2): 138 | def __init__(self, config_channels): 139 | MobileNet2.__init__(self, config_channels, ratio=1 / 2) 140 | 141 | 142 | class MobileNet2Quarter(MobileNet2): 143 | def __init__(self, config_channels): 144 | MobileNet2.__init__(self, config_channels, ratio=1 / 4) 145 | -------------------------------------------------------------------------------- /model/dnn/resnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen) 3 | 4 | This program is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU Lesser General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 8 | 9 | This program is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with this program. If not, see . 
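# --- Hypothetical smoke test (added for clarity; not part of the original file) ---
# Both backbones above are built from depthwise-separable convolutions: conv_unit in
# mobilenet.py is a stride-s depthwise 3x3 followed by a pointwise 1x1, and
# InvertedResidual in mobilenet2.py wraps the same idea in an expand/project block
# whose skip connection is used only when stride == 1 and the widths match. A quick
# check, assuming torch is installed and the classes above are in scope:
import torch

x = torch.randn(1, 32, 56, 56)
unit = conv_unit(32, 64, stride=1)                      # dw 3x3 on 32 ch, then pw 1x1 -> 64
block_skip = InvertedResidual(32, 32, stride=1, expand_ratio=6)
block_down = InvertedResidual(32, 64, stride=2, expand_ratio=6)
print(unit(x).shape)                                    # (1, 64, 56, 56)
print(block_skip.use_res_connect, block_skip(x).shape)  # True  (1, 32, 56, 56)
print(block_down.use_res_connect, block_down(x).shape)  # False (1, 64, 28, 28)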
16 | """ 17 | 18 | import logging 19 | import re 20 | 21 | import torch.nn as nn 22 | import torch.utils.model_zoo as model_zoo 23 | import torchvision.models.resnet as _model 24 | from torchvision.models.resnet import conv3x3 25 | 26 | import model 27 | 28 | 29 | class BasicBlock(nn.Module): 30 | def __init__(self, config_channels, prefix, channels, stride=1): 31 | nn.Module.__init__(self) 32 | channels_in = config_channels.channels 33 | self.conv1 = conv3x3(config_channels.channels, config_channels(channels, '%s.conv1.weight' % prefix), stride) 34 | self.bn1 = nn.BatchNorm2d(config_channels.channels) 35 | self.relu = nn.ReLU(inplace=True) 36 | self.conv2 = conv3x3(config_channels.channels, config_channels(channels, '%s.conv2.weight' % prefix)) 37 | self.bn2 = nn.BatchNorm2d(config_channels.channels) 38 | if stride > 1 or channels_in != config_channels.channels: 39 | downsample = [] 40 | downsample.append(nn.Conv2d(channels_in, config_channels.channels, kernel_size=1, stride=stride, bias=False)) 41 | downsample.append(nn.BatchNorm2d(config_channels.channels)) 42 | self.downsample = nn.Sequential(*downsample) 43 | else: 44 | self.downsample = None 45 | 46 | def forward(self, x): 47 | residual = x 48 | 49 | out = self.conv1(x) 50 | out = self.bn1(out) 51 | out = self.relu(out) 52 | 53 | out = self.conv2(out) 54 | out = self.bn2(out) 55 | 56 | if self.downsample is not None: 57 | residual = self.downsample(x) 58 | 59 | out += residual 60 | out = self.relu(out) 61 | 62 | return out 63 | 64 | 65 | class Bottleneck(nn.Module): 66 | def __init__(self, config_channels, prefix, channels, stride=1): 67 | nn.Module.__init__(self) 68 | channels_in = config_channels.channels 69 | self.conv1 = nn.Conv2d(config_channels.channels, config_channels(channels, '%s.conv1.weight' % prefix), kernel_size=1, bias=False) 70 | self.bn1 = nn.BatchNorm2d(config_channels.channels) 71 | self.conv2 = nn.Conv2d(config_channels.channels, config_channels(channels, '%s.conv2.weight' % prefix), kernel_size=3, stride=stride, padding=1, bias=False) 72 | self.bn2 = nn.BatchNorm2d(config_channels.channels) 73 | self.conv3 = nn.Conv2d(config_channels.channels, config_channels(channels * 4, '%s.conv3.weight' % prefix), kernel_size=1, bias=False) 74 | self.bn3 = nn.BatchNorm2d(config_channels.channels) 75 | self.relu = nn.ReLU(inplace=True) 76 | if stride > 1 or channels_in != config_channels.channels: 77 | downsample = [] 78 | downsample.append(nn.Conv2d(channels_in, config_channels.channels, kernel_size=1, stride=stride, bias=False)) 79 | downsample.append(nn.BatchNorm2d(config_channels.channels)) 80 | self.downsample = nn.Sequential(*downsample) 81 | else: 82 | self.downsample = None 83 | 84 | def forward(self, x): 85 | residual = x 86 | 87 | out = self.conv1(x) 88 | out = self.bn1(out) 89 | out = self.relu(out) 90 | 91 | out = self.conv2(out) 92 | out = self.bn2(out) 93 | out = self.relu(out) 94 | 95 | out = self.conv3(out) 96 | out = self.bn3(out) 97 | 98 | if self.downsample is not None: 99 | residual = self.downsample(x) 100 | 101 | out += residual 102 | out = self.relu(out) 103 | 104 | return out 105 | 106 | 107 | class ResNet(_model.ResNet): 108 | def __init__(self, config_channels, anchors, num_cls, block, layers): 109 | nn.Module.__init__(self) 110 | self.conv1 = nn.Conv2d(config_channels.channels, config_channels(64, 'conv1.weight'), kernel_size=7, stride=2, padding=3, bias=False) 111 | self.bn1 = nn.BatchNorm2d(config_channels.channels) 112 | self.relu = nn.ReLU(inplace=True) 113 | self.maxpool = 
nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 114 | self.layer1 = self._make_layer(config_channels, 'layer1', block, 64, layers[0]) 115 | self.layer2 = self._make_layer(config_channels, 'layer2', block, 128, layers[1], stride=2) 116 | self.layer3 = self._make_layer(config_channels, 'layer3', block, 256, layers[2], stride=2) 117 | self.layer4 = self._make_layer(config_channels, 'layer4', block, 512, layers[3], stride=2) 118 | 119 | for m in self.modules(): 120 | if isinstance(m, nn.Conv2d): 121 | m.weight = nn.init.kaiming_normal_(m.weight) 122 | elif isinstance(m, nn.BatchNorm2d): 123 | m.weight.fill_(1) 124 | m.bias.zero_() 125 | 126 | def _make_layer(self, config_channels, prefix, block, channels, blocks, stride=1): 127 | layers = [] 128 | layers.append(block(config_channels, '%s.%d' % (prefix, len(layers)), channels, stride)) 129 | for i in range(1, blocks): 130 | layers.append(block(config_channels, '%s.%d' % (prefix, len(layers)), channels)) 131 | return nn.Sequential(*layers) 132 | 133 | def forward(self, x): 134 | x = self.conv1(x) 135 | x = self.bn1(x) 136 | x = self.relu(x) 137 | x = self.maxpool(x) 138 | 139 | x = self.layer1(x) 140 | x = self.layer2(x) 141 | x = self.layer3(x) 142 | x = self.layer4(x) 143 | 144 | return x 145 | 146 | def scope(self, name): 147 | comp = name.split('.')[:-1] 148 | try: 149 | comp[-1] = re.search('[(conv)|(bn)](\d+)', comp[-1]).group(1) 150 | except AttributeError: 151 | if len(comp) > 1: 152 | if comp[-2] == 'downsample': 153 | comp = comp[:-1] 154 | else: 155 | assert False, name 156 | else: 157 | assert comp[-1] == 'conv', name 158 | return '.'.join(comp) 159 | 160 | 161 | def resnet18(config_channels, **kwargs): 162 | model = ResNet(config_channels, BasicBlock, [2, 2, 2, 2], **kwargs) 163 | if config_channels.config.getboolean('model', 'pretrained'): 164 | url = _model.model_urls['resnet18'] 165 | logging.info('use pretrained model: ' + url) 166 | state_dict = model.state_dict() 167 | for key, value in model_zoo.load_url(url).items(): 168 | if key in state_dict: 169 | state_dict[key] = value 170 | model.load_state_dict(state_dict) 171 | return model 172 | 173 | 174 | def resnet34(config_channels, **kwargs): 175 | model = ResNet(config_channels, BasicBlock, [3, 4, 6, 3], **kwargs) 176 | if config_channels.config.getboolean('model', 'pretrained'): 177 | url = _model.model_urls['resnet34'] 178 | logging.info('use pretrained model: ' + url) 179 | state_dict = model.state_dict() 180 | for key, value in model_zoo.load_url(url).items(): 181 | if key in state_dict: 182 | state_dict[key] = value 183 | model.load_state_dict(state_dict) 184 | return model 185 | 186 | 187 | def resnet50(config_channels, **kwargs): 188 | model = ResNet(config_channels, Bottleneck, [3, 4, 6, 3], **kwargs) 189 | if config_channels.config.getboolean('model', 'pretrained'): 190 | url = _model.model_urls['resnet50'] 191 | logging.info('use pretrained model: ' + url) 192 | state_dict = model.state_dict() 193 | for key, value in model_zoo.load_url(url).items(): 194 | if key in state_dict: 195 | state_dict[key] = value 196 | model.load_state_dict(state_dict) 197 | return model 198 | 199 | 200 | def resnet101(config_channels, **kwargs): 201 | model = ResNet(config_channels, Bottleneck, [3, 4, 23, 3], **kwargs) 202 | if config_channels.config.getboolean('model', 'pretrained'): 203 | url = _model.model_urls['resnet101'] 204 | logging.info('use pretrained model: ' + url) 205 | state_dict = model.state_dict() 206 | for key, value in model_zoo.load_url(url).items(): 207 | if key in 
state_dict: 208 | state_dict[key] = value 209 | model.load_state_dict(state_dict) 210 | return model 211 | 212 | 213 | def resnet152(config_channels, **kwargs): 214 | model = ResNet(config_channels, Bottleneck, [3, 8, 36, 3], **kwargs) 215 | if config_channels.config.getboolean('model', 'pretrained'): 216 | url = _model.model_urls['resnet152'] 217 | logging.info('use pretrained model: ' + url) 218 | state_dict = model.state_dict() 219 | for key, value in model_zoo.load_url(url).items(): 220 | if key in state_dict: 221 | state_dict[key] = value 222 | model.load_state_dict(state_dict) 223 | return model 224 | -------------------------------------------------------------------------------- /model/dnn/vgg.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen) 3 | 4 | This program is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU Lesser General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 8 | 9 | This program is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with this program. If not, see . 16 | """ 17 | 18 | import logging 19 | 20 | import torch.nn as nn 21 | import torch.utils.model_zoo as model_zoo 22 | import torchvision.models.vgg as _model 23 | from torchvision.models.vgg import model_urls, cfg 24 | 25 | import model 26 | 27 | 28 | class VGG(_model.VGG): 29 | def __init__(self, config_channels, features): 30 | nn.Module.__init__(self) 31 | self.features = features 32 | self._initialize_weights() 33 | 34 | def forward(self, x): 35 | return self.features(x) 36 | 37 | 38 | def make_layers(config_channels, cfg, batch_norm=False): 39 | features = [] 40 | for v in cfg: 41 | if v == 'M': 42 | features += [nn.MaxPool2d(kernel_size=2, stride=2)] 43 | else: 44 | conv2d = nn.Conv2d(config_channels.channels, config_channels(v, 'features.%d.weight' % len(features)), kernel_size=3, padding=1) 45 | if batch_norm: 46 | features += [conv2d, nn.BatchNorm2d(config_channels.channels), nn.ReLU(inplace=True)] 47 | else: 48 | features += [conv2d, nn.ReLU(inplace=True)] 49 | return nn.Sequential(*features) 50 | 51 | 52 | def vgg11(config_channels): 53 | model = VGG(config_channels, make_layers(config_channels, cfg['A'])) 54 | if config_channels.config.getboolean('model', 'pretrained'): 55 | url = model_urls['vgg11'] 56 | logging.info('use pretrained model: ' + url) 57 | state_dict = model.state_dict() 58 | for key, value in model_zoo.load_url(url).items(): 59 | if key in state_dict: 60 | state_dict[key] = value 61 | model.load_state_dict(state_dict) 62 | return model 63 | 64 | 65 | def vgg11_bn(config_channels): 66 | model = VGG(config_channels, make_layers(config_channels, cfg['A'], batch_norm=True)) 67 | if config_channels.config.getboolean('model', 'pretrained'): 68 | url = model_urls['vgg11_bn'] 69 | logging.info('use pretrained model: ' + url) 70 | state_dict = model.state_dict() 71 | for key, value in model_zoo.load_url(url).items(): 72 | if key in state_dict: 73 | state_dict[key] = value 74 | model.load_state_dict(state_dict) 75 | return model 76 | 77 | 78 | def vgg13(config_channels): 79 | model = VGG(config_channels, 
make_layers(config_channels, cfg['B'])) 80 | if config_channels.config.getboolean('model', 'pretrained'): 81 | url = model_urls['vgg13'] 82 | logging.info('use pretrained model: ' + url) 83 | state_dict = model.state_dict() 84 | for key, value in model_zoo.load_url(url).items(): 85 | if key in state_dict: 86 | state_dict[key] = value 87 | model.load_state_dict(state_dict) 88 | return model 89 | 90 | 91 | def vgg13_bn(config_channels): 92 | model = VGG(config_channels, make_layers(config_channels, cfg['B'], batch_norm=True)) 93 | if config_channels.config.getboolean('model', 'pretrained'): 94 | url = model_urls['vgg13_bn'] 95 | logging.info('use pretrained model: ' + url) 96 | state_dict = model.state_dict() 97 | for key, value in model_zoo.load_url(url).items(): 98 | if key in state_dict: 99 | state_dict[key] = value 100 | model.load_state_dict(state_dict) 101 | return model 102 | 103 | 104 | def vgg16(config_channels): 105 | model = VGG(config_channels, make_layers(config_channels, cfg['D'])) 106 | if config_channels.config.getboolean('model', 'pretrained'): 107 | url = model_urls['vgg16'] 108 | logging.info('use pretrained model: ' + url) 109 | state_dict = model.state_dict() 110 | for key, value in model_zoo.load_url(url).items(): 111 | if key in state_dict: 112 | state_dict[key] = value 113 | model.load_state_dict(state_dict) 114 | return model 115 | 116 | 117 | def vgg16_bn(config_channels): 118 | model = VGG(config_channels, make_layers(config_channels, cfg['D'], batch_norm=True)) 119 | if config_channels.config.getboolean('model', 'pretrained'): 120 | url = model_urls['vgg16_bn'] 121 | logging.info('use pretrained model: ' + url) 122 | state_dict = model.state_dict() 123 | for key, value in model_zoo.load_url(url).items(): 124 | if key in state_dict: 125 | state_dict[key] = value 126 | model.load_state_dict(state_dict) 127 | return model 128 | 129 | 130 | def vgg19(config_channels): 131 | model = VGG(config_channels, make_layers(config_channels, cfg['E'])) 132 | if config_channels.config.getboolean('model', 'pretrained'): 133 | url = model_urls['vgg19'] 134 | logging.info('use pretrained model: ' + url) 135 | state_dict = model.state_dict() 136 | for key, value in model_zoo.load_url(url).items(): 137 | if key in state_dict: 138 | state_dict[key] = value 139 | model.load_state_dict(state_dict) 140 | return model 141 | 142 | 143 | def vgg19_bn(config_channels): 144 | model = VGG(config_channels, make_layers(config_channels, cfg['E'], batch_norm=True)) 145 | if config_channels.config.getboolean('model', 'pretrained'): 146 | url = model_urls['vgg19_bn'] 147 | logging.info('use pretrained model: ' + url) 148 | state_dict = model.state_dict() 149 | for key, value in model_zoo.load_url(url).items(): 150 | if key in state_dict: 151 | state_dict[key] = value 152 | model.load_state_dict(state_dict) 153 | return model 154 | 155 | 156 | def person18_19(config_channels): 157 | cfg = [ 158 | 64, 64, 'M', 159 | 128, 128, 'M', 160 | 256, 256, 256, 256, 'M', 161 | 512, 512, 162 | 256, 128, 163 | ] 164 | return VGG(config_channels, make_layers(config_channels, cfg)) 165 | 166 | 167 | def hand21(config_channels): 168 | cfg = [ 169 | 64, 64, 'M', 170 | 128, 128, 'M', 171 | 256, 256, 256, 256, 'M', 172 | 512, 512, 512, 512, 173 | 512, 512, 128, 174 | ] 175 | return VGG(config_channels, make_layers(config_channels, cfg)) 176 | -------------------------------------------------------------------------------- /model/stages/__init__.py: 
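# --- Worked example (added for clarity; not part of the original file) --------
# The person18_19 cfg above matches the truncated VGG-19 front end used by the
# original OpenPose: three 'M' entries mean the backbone downsamples by 2**3 = 8,
# so the 368x368 crops used by the original OpenPose come out as 46x46 feature maps
# before the stage networks run.
cfg_person18_19 = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 256, 128]
stride = 2 ** cfg_person18_19.count('M')
print(stride, 368 // stride)  # 8 46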
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruiminshen/openpose-pytorch/f850084194ddccc6d401d5b11f61facc20ec2b75/model/stages/__init__.py -------------------------------------------------------------------------------- /model/stages/openpose.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen) 3 | 4 | This program is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU Lesser General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 8 | 9 | This program is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with this program. If not, see . 16 | """ 17 | 18 | import collections.abc 19 | 20 | import torch 21 | import torch.nn as nn 22 | 23 | 24 | class Conv2d(nn.Module): 25 | def __init__(self, in_channels, out_channels, kernel_size, padding=True, stride=1, bn=False, act=True): 26 | nn.Module.__init__(self) 27 | if isinstance(padding, bool): 28 | if isinstance(kernel_size, collections.abc.Iterable): 29 | padding = tuple((kernel_size - 1) // 2 for kernel_size in kernel_size) if padding else 0 30 | else: 31 | padding = (kernel_size - 1) // 2 if padding else 0 32 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding=padding, bias=not bn) 33 | self.bn = nn.BatchNorm2d(out_channels, momentum=0.01) if bn else lambda x: x 34 | self.act = nn.ReLU(inplace=True) if act else lambda x: x 35 | 36 | def forward(self, x): 37 | x = self.conv(x) 38 | x = self.bn(x) 39 | x = self.act(x) 40 | return x 41 | 42 | 43 | class Stage0(nn.Module): 44 | def __init__(self, config_channels, channel_dict, channels_dnn, prefix): 45 | nn.Module.__init__(self) 46 | channels_stage = config_channels.channels 47 | for name, channels in channel_dict.items(): 48 | config_channels.channels = channels_stage 49 | branch = [] 50 | for _ in range(3): 51 | branch.append(Conv2d(config_channels.channels, config_channels(128, '%s.%s.%d.conv.weight' % (prefix, name, len(branch))), 3)) 52 | branch.append(Conv2d(config_channels.channels, config_channels(512, '%s.%s.%d.conv.weight' % (prefix, name, len(branch))), 1)) 53 | branch.append(Conv2d(config_channels.channels, channels, 1, act=False)) 54 | setattr(self, name, nn.Sequential(*branch)) 55 | config_channels.channels = channels_dnn + sum(branch[-1].conv.weight.size(0) for branch in self._modules.values()) 56 | self.init() 57 | 58 | def init(self): 59 | for m in self.modules(): 60 | if isinstance(m, nn.Conv2d): 61 | m.weight = nn.init.xavier_normal_(m.weight) 62 | elif isinstance(m, nn.BatchNorm2d): 63 | m.weight.fill_(1) 64 | m.bias.zero_() 65 | 66 | def forward(self, x, **kwargs): 67 | return {name: var(x) for name, var in self._modules.items()} 68 | 69 | 70 | class Stage(nn.Module): 71 | def __init__(self, config_channels, channels, channels_dnn, prefix): 72 | nn.Module.__init__(self) 73 | channels_stage = config_channels.channels 74 | for name, _channels in channels.items(): 75 | config_channels.channels = channels_stage 76 | branch = [] 77 | for _ in range(5): 78 | branch.append(Conv2d(config_channels.channels, 
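# --- Worked example (added for clarity; not part of the original file) --------
# The Conv2d wrapper above turns padding=True into "same" padding for odd kernel
# sizes: (k - 1) // 2 per dimension. So the 7x7 kernels used by the later stages get
# padding 3, and a (1, 7) kernel would get (0, 3), leaving the spatial size unchanged
# at stride 1.
for k in (3, 7, (1, 7)):
    pad = tuple((s - 1) // 2 for s in k) if isinstance(k, tuple) else (k - 1) // 2
    print(k, '->', pad)  # 3 -> 1, 7 -> 3, (1, 7) -> (0, 3)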
config_channels(128, '%s.%s.%d.conv.weight' % (prefix, name, len(branch))), 7)) 79 | branch.append(Conv2d(config_channels.channels, config_channels(128, '%s.%s.%d.conv.weight' % (prefix, name, len(branch))), 1)) 80 | branch.append(Conv2d(config_channels.channels, _channels, 1, act=False)) 81 | setattr(self, name, nn.Sequential(*branch)) 82 | config_channels.channels = channels_dnn + sum(branch[-1].conv.weight.size(0) for branch in self._modules.values()) 83 | self.init() 84 | 85 | def init(self): 86 | for m in self.modules(): 87 | if isinstance(m, nn.Conv2d): 88 | m.weight = nn.init.xavier_normal_(m.weight) 89 | elif isinstance(m, nn.BatchNorm2d): 90 | m.weight.fill_(1) 91 | m.bias.zero_() 92 | 93 | def forward(self, x, **kwargs): 94 | x = torch.cat([kwargs[name] for name in ('limbs', 'parts')] + [x], 1) 95 | return {name: var(x) for name, var in self._modules.items()} 96 | -------------------------------------------------------------------------------- /model/stages/unet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen) 3 | 4 | This program is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU Lesser General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 8 | 9 | This program is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with this program. If not, see . 16 | """ 17 | 18 | import collections.abc 19 | 20 | import torch 21 | import torch.nn as nn 22 | 23 | 24 | class Conv2d(nn.Module): 25 | def __init__(self, in_channels, out_channels, kernel_size, padding=True, stride=1, bn=False, act=True): 26 | nn.Module.__init__(self) 27 | if isinstance(padding, bool): 28 | if isinstance(kernel_size, collections.abc.Iterable): 29 | padding = tuple((kernel_size - 1) // 2 for kernel_size in kernel_size) if padding else 0 30 | else: 31 | padding = (kernel_size - 1) // 2 if padding else 0 32 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding=padding, bias=not bn) 33 | self.bn = nn.BatchNorm2d(out_channels, momentum=0.01) if bn else lambda x: x 34 | self.act = nn.ReLU(inplace=True) if act else lambda x: x 35 | 36 | def forward(self, x): 37 | x = self.conv(x) 38 | x = self.bn(x) 39 | x = self.act(x) 40 | return x 41 | 42 | 43 | class ConvTranspose2d(nn.Module): 44 | def __init__(self, in_channels, out_channels, kernel_size, padding=0, stride=1, bn=False, act=True): 45 | nn.Module.__init__(self) 46 | if isinstance(padding, bool): 47 | if isinstance(kernel_size, collections.abc.Iterable): 48 | padding = tuple((kernel_size - 1) // 2 for kernel_size in kernel_size) if padding else 0 49 | else: 50 | padding = (kernel_size - 1) // 2 if padding else 0 51 | self.conv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding=padding, bias=not bn) 52 | self.bn = nn.BatchNorm2d(out_channels, momentum=0.01) if bn else lambda x: x 53 | self.act = nn.ReLU(inplace=True) if act else lambda x: x 54 | 55 | def forward(self, x): 56 | x = self.conv(x) 57 | x = self.bn(x) 58 | x = self.act(x) 59 | return x 60 | 61 | 62 | class Downsample(nn.Module): 63 | def __init__(self, config_channels, channels, 
prefix, kernel_sizes, pooling): 64 | nn.Module.__init__(self) 65 | self.seq = nn.Sequential(*[Conv2d(config_channels.channels, config_channels(channels, '%s.seq.%d.conv.weight' % (prefix, index)), kernel_size) for index, kernel_size in enumerate(kernel_sizes)]) 66 | self.downsample = nn.MaxPool2d(kernel_size=pooling) 67 | 68 | def forward(self, x): 69 | feature = self.seq(x) 70 | return self.downsample(feature), feature 71 | 72 | 73 | class Upsample(nn.Module): 74 | def __init__(self, config_channels, channels, channels_min, prefix, sample, kernel_sizes, ratio=1): 75 | nn.Module.__init__(self) 76 | self.upsample = ConvTranspose2d(config_channels.channels, config_channels(channels, '%s.upsample.conv.weight' % prefix, fn=lambda var: var.size(1)), kernel_size=sample, stride=sample) 77 | config_channels.channels += channels # concat 78 | 79 | seq = [] 80 | if ratio < 1: 81 | seq.append(Conv2d(config_channels.channels, config_channels(max(int(config_channels.channels * ratio), channels_min), '%s.seq.%d.conv.weight' % (prefix, len(seq))), 1)) 82 | for kernel_size in kernel_sizes: 83 | seq.append(Conv2d(config_channels.channels, config_channels(channels, '%s.seq.%d.conv.weight' % (prefix, len(seq))), kernel_size)) 84 | self.seq = nn.Sequential(*seq) 85 | 86 | def forward(self, x, feature): 87 | x = self.upsample(x) 88 | x = torch.cat([x, feature], 1) 89 | return self.seq(x) 90 | 91 | 92 | class Branch(nn.Module): 93 | def __init__(self, config_channels, channels, prefix, multiply, ratio, kernel_sizes, sample): 94 | nn.Module.__init__(self) 95 | _channels = channels 96 | self.down = [] 97 | for index, m in enumerate(multiply): 98 | name = 'down%d' % index 99 | block = Downsample(config_channels, _channels, '%s.%s' % (prefix, name), kernel_sizes, pooling=sample) 100 | setattr(self, name, block) 101 | self.down.append(block) 102 | _channels = int(_channels * m) 103 | self.top = nn.Sequential(*[Conv2d(config_channels.channels, config_channels(_channels, '%s.top.%d.conv.weight' % (prefix, index)), kernel_size) for index, kernel_size in enumerate(kernel_sizes)]) 104 | 105 | self.up = [] 106 | for index, block in enumerate(self.down[::-1]): 107 | name = 'up%d' % index 108 | block = Upsample(config_channels, block.seq[-1].conv.weight.size(0), channels, '%s.%s' % (prefix, name), sample, kernel_sizes, ratio) 109 | setattr(self, name, block) 110 | self.up.append(block) 111 | self.out = Conv2d(config_channels.channels, channels, 1, act=False) 112 | 113 | def forward(self, x): 114 | features = [] 115 | for block in self.down: 116 | x, feature = block(x) 117 | features.append(feature) 118 | x = self.top(x) 119 | 120 | for block, feature in zip(self.up, features[::-1]): 121 | x = block(x, feature) 122 | return self.out(x) 123 | 124 | 125 | class Unet(nn.Module): 126 | def __init__(self, config_channels, channel_dict, channels_dnn, prefix, multiply=[2, 2], ratio=1, kernel_sizes=[3], sample=2): 127 | nn.Module.__init__(self) 128 | channels_stage = config_channels.channels 129 | for name, channels in channel_dict.items(): 130 | config_channels.channels = channels_stage 131 | branch = Branch(config_channels, channels, '%s.%s' % (prefix, name), multiply, ratio, kernel_sizes, sample) 132 | setattr(self, name, branch) 133 | config_channels.channels = channels_dnn + sum(branch.out.conv.weight.size(0) for branch in self._modules.values()) 134 | 135 | def forward(self, x, **kwargs): 136 | if kwargs: 137 | x = torch.cat([kwargs[name] for name in ('parts', 'limbs') if name in kwargs] + [x], 1) 138 | return {name: branch(x) 
for name, branch in self._modules.items()} 139 | 140 | 141 | class Unet1Sqz3(Unet): 142 | def __init__(self, config_channels, channel_dict, channels_dnn, prefix): 143 | Unet.__init__(self, config_channels, channel_dict, channels_dnn, prefix, multiply=[2], ratio=1 / 3) 144 | 145 | 146 | class Unet1Sqz3_a(Unet): 147 | def __init__(self, config_channels, channel_dict, channels_dnn, prefix): 148 | Unet.__init__(self, config_channels, channel_dict, channels_dnn, prefix, multiply=[1.5], ratio=1 / 3) 149 | 150 | 151 | class Unet2Sqz3(Unet): 152 | def __init__(self, config_channels, channel_dict, channels_dnn, prefix): 153 | Unet.__init__(self, config_channels, channel_dict, channels_dnn, prefix, multiply=[2, 2], ratio=1 / 3) 154 | -------------------------------------------------------------------------------- /quick_start.sh: -------------------------------------------------------------------------------- 1 | echo download COCO dataset 2 | LINKS=" 3 | http://images.cocodataset.org/zips/train2014.zip 4 | http://images.cocodataset.org/zips/val2014.zip 5 | http://images.cocodataset.org/annotations/annotations_trainval2014.zip 6 | http://images.cocodataset.org/zips/train2017.zip 7 | http://images.cocodataset.org/zips/val2017.zip 8 | http://images.cocodataset.org/annotations/annotations_trainval2017.zip 9 | " 10 | ROOT=~/data/coco 11 | for LINK in $LINKS 12 | do 13 | aria2c --auto-file-renaming=false -d $ROOT $LINK 14 | unzip -n $ROOT/$(basename $LINK) -d $ROOT 15 | done 16 | rm $ROOT/val2014/COCO_val2014_000000320612.jpg 17 | 18 | echo cache data 19 | python3 cache.py -c config.ini config/original_person18_19.ini -m cache/name=cache_original 20 | 21 | echo download and cache the original model 22 | ROOT=~/model/openpose/pose/coco 23 | aria2c --auto-file-renaming=false -d $ROOT https://raw.githubusercontent.com/CMU-Perceptual-Computing-Lab/openpose/master/models/pose/coco/pose_deploy_linevec.prototxt 24 | aria2c --auto-file-renaming=false -d $ROOT http://posefs1.perception.cs.cmu.edu/OpenPose/models/pose/coco/pose_iter_440000.caffemodel 25 | python3 convert_caffe_torch.py config/convert_caffe_torch/original_person18_19.tsv $ROOT/pose_deploy_linevec.prototxt $ROOT/pose_iter_440000.caffemodel -c config.ini config/original_person18_19.ini -m model/name=model_original -d 26 | 27 | echo demo keypoint estimation via a webcam 28 | python3 estimate.py -c config.ini config/original_person18_19.ini -m model/name=model_original 29 | 30 | echo training 31 | python3 train.py -c config.ini config/original_person18_19.ini -m cache/name=cache_original model/name=model_original -------------------------------------------------------------------------------- /receptive_field_analyzer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen) 3 | 4 | This program is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU Lesser General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 8 | 9 | This program is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with this program. If not, see . 
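# --- Hedged sketch (added for clarity; not part of the original file) ---------
# How the stage modules above are meant to be chained (the helper name and loop are
# assumptions for illustration; the repo wires this up inside model.Inference): the
# first stage sees only the backbone feature map, and every later stage sees the
# backbone features concatenated with the previous stage's 'limbs' (part-affinity)
# and 'parts' (keypoint heatmap) predictions, exactly as the forward methods above
# expect.
def run_stages(stages, feature):
    outputs = []
    kwargs = {}
    for stage in stages:
        kwargs = stage(feature, **kwargs)  # dict with 'limbs' and 'parts' tensors
        outputs.append(kwargs)
    return outputs                         # per-stage predictions; the last is final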
16 | """ 17 | 18 | import os 19 | import argparse 20 | import configparser 21 | import logging 22 | import logging.config 23 | import multiprocessing 24 | import yaml 25 | 26 | import numpy as np 27 | import scipy.misc 28 | import torch.autograd 29 | import torch.cuda 30 | import torch.optim 31 | import torch.utils.data 32 | import torch.nn as nn 33 | import tqdm 34 | import humanize 35 | 36 | import model 37 | import utils.data 38 | import utils.train 39 | import utils.visualize 40 | 41 | 42 | class Dataset(torch.utils.data.Dataset): 43 | def __init__(self, height, width): 44 | self.points = np.array([(i, j) for i in range(height) for j in range(width)]) 45 | 46 | def __len__(self): 47 | return len(self.points) 48 | 49 | def __getitem__(self, index): 50 | return self.points[index] 51 | 52 | 53 | class Analyzer(object): 54 | def __init__(self, args, config): 55 | self.args = args 56 | self.config = config 57 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 58 | self.model_dir = utils.get_model_dir(config) 59 | _, self.num_parts = utils.get_dataset_mappers(config) 60 | self.limbs_index = utils.get_limbs_index(config) 61 | self.step, self.epoch, self.dnn, self.stages = self.load() 62 | self.inference = model.Inference(self.config, self.dnn, self.stages) 63 | self.inference.eval() 64 | logging.info(humanize.naturalsize(sum(var.cpu().numpy().nbytes for var in self.dnn.state_dict().values()))) 65 | if torch.cuda.is_available(): 66 | self.inference.cuda() 67 | self.height, self.width = tuple(map(int, config.get('image', 'size').split())) 68 | t = torch.zeros(1, 3, self.height, self.width).to(self.device) 69 | output = self.dnn(t) 70 | _, _, self.rows, self.cols = output.size() 71 | self.i, self.j = self.rows // 2, self.cols // 2 72 | self.output = output[:, :, self.i, self.j] 73 | dataset = Dataset(self.height, self.width) 74 | try: 75 | workers = self.config.getint('data', 'workers') 76 | except configparser.NoOptionError: 77 | workers = multiprocessing.cpu_count() 78 | self.loader = torch.utils.data.DataLoader(dataset, batch_size=self.args.batch_size, num_workers=workers) 79 | 80 | def __call__(self): 81 | changed = np.zeros([self.height, self.width], np.bool) 82 | for yx in tqdm.tqdm(self.loader): 83 | batch_size = yx.size(0) 84 | tensor = torch.zeros(batch_size, 3, self.height, self.width) 85 | for i, _yx in enumerate(torch.unbind(yx)): 86 | y, x = torch.unbind(_yx) 87 | tensor[i, :, y, x] = 1 88 | tensor = tensor.to(self.device) 89 | output = self.dnn(tensor) 90 | output = output[:, :, self.i, self.j] 91 | cmp = output == self.output 92 | cmp = torch.prod(cmp, -1) 93 | for _yx, c in zip(torch.unbind(yx), torch.unbind(cmp)): 94 | y, x = torch.unbind(_yx) 95 | changed[y, x] = c 96 | return changed 97 | 98 | def load(self): 99 | try: 100 | path, step, epoch = utils.train.load_model(self.model_dir) 101 | state_dict = torch.load(path, map_location=lambda storage, loc: storage) 102 | except (FileNotFoundError, ValueError): 103 | step, epoch = 0, 0 104 | state_dict = {name: None for name in ('dnn', 'stages')} 105 | config_channels_dnn = model.ConfigChannels(self.config, state_dict['dnn']) 106 | dnn = utils.parse_attr(self.config.get('model', 'dnn'))(config_channels_dnn) 107 | config_channels_stages = model.ConfigChannels(self.config, state_dict['stages'], config_channels_dnn.channels) 108 | channel_dict = model.channel_dict(self.num_parts, len(self.limbs_index)) 109 | stages = nn.Sequential(*[utils.parse_attr(s)(config_channels_stages, channel_dict, 
config_channels_dnn.channels, str(i)) for i, s in enumerate(self.config.get('model', 'stages').split())]) 110 | return step, epoch, dnn, stages 111 | 112 | 113 | def main(): 114 | args = make_args() 115 | config = configparser.ConfigParser() 116 | utils.load_config(config, args.config) 117 | for cmd in args.modify: 118 | utils.modify_config(config, cmd) 119 | with open(os.path.expanduser(os.path.expandvars(args.logging)), 'r') as f: 120 | logging.config.dictConfig(yaml.load(f)) 121 | analyzer = Analyzer(args, config) 122 | changed = analyzer() 123 | os.makedirs(analyzer.model_dir, exist_ok=True) 124 | path = os.path.join(analyzer.model_dir, args.filename) 125 | scipy.misc.imsave(path, (~changed).astype(np.uint8) * 255) 126 | logging.info(path) 127 | 128 | 129 | def make_args(): 130 | parser = argparse.ArgumentParser() 131 | parser.add_argument('-c', '--config', nargs='+', default=['config.ini'], help='config file') 132 | parser.add_argument('-m', '--modify', nargs='+', default=[], help='modify config') 133 | parser.add_argument('-b', '--batch_size', default=16, type=int, help='batch size') 134 | parser.add_argument('-n', '--filename', default='receptive_field.jpg') 135 | parser.add_argument('--logging', default='logging.yml', help='logging config') 136 | return parser.parse_args() 137 | 138 | 139 | if __name__ == '__main__': 140 | main() 141 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tqdm 2 | pybenchmark 3 | graphviz 4 | torch>=0.4.0 5 | pandas 6 | onnx 7 | onnx_caffe2 8 | pretrainedmodels 9 | torchvision 10 | matplotlib 11 | filelock 12 | scikit_image 13 | inflection 14 | numpy 15 | humanize 16 | Pillow 17 | PyQt5 18 | scipy 19 | skimage 20 | tensorboardX>=1.2 21 | tensorflow 22 | PyYAML 23 | pycocotools 24 | -------------------------------------------------------------------------------- /transform/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen) 3 | 4 | This program is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU Lesser General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 8 | 9 | This program is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with this program. If not, see . 
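# --- Simplified restatement (added for clarity; not part of the original file) ---
# The analyzer above measures the effective receptive field empirically: it records
# the centre output activation for a blank image, then feeds a one-pixel impulse at
# every input location (in batches) and checks whether the centre activation changes;
# pixels that change it lie inside the receptive field. A compact, unbatched sketch
# of the same idea, assuming torch is available and `net` is any fully convolutional
# module:
import torch

def receptive_field_mask(net, height, width):
    net.eval()
    with torch.no_grad():
        base = net(torch.zeros(1, 3, height, width))
        i, j = base.shape[2] // 2, base.shape[3] // 2
        inside = torch.zeros(height, width, dtype=torch.bool)
        for y in range(height):
            for x in range(width):
                probe = torch.zeros(1, 3, height, width)
                probe[0, :, y, x] = 1
                out = net(probe)
                # inside the receptive field iff the impulse moves the centre output
                inside[y, x] = not torch.equal(out[:, :, i, j], base[:, :, i, j])
    return inside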
16 | """ 17 | 18 | import inspect 19 | 20 | import torchvision 21 | 22 | import utils 23 | 24 | 25 | def parse_transform(config, method): 26 | if isinstance(method, str): 27 | attr = utils.parse_attr(method) 28 | sig = inspect.signature(attr) 29 | if len(sig.parameters) == 1: 30 | return attr(config) 31 | else: 32 | return attr() 33 | else: 34 | return method 35 | 36 | 37 | def get_transform(config, sequence, compose=torchvision.transforms.Compose): 38 | return compose([parse_transform(config, method) for method in sequence]) 39 | -------------------------------------------------------------------------------- /transform/augmentation.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen) 3 | 4 | This program is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU Lesser General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 8 | 9 | This program is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with this program. If not, see . 16 | """ 17 | 18 | import os 19 | import inspect 20 | import random 21 | 22 | import inflection 23 | import numpy as np 24 | import cv2 25 | 26 | import transform 27 | 28 | 29 | class Rotator(object): 30 | def __init__(self, y, x, height, width, angle): 31 | """ 32 | A efficient tool to rotate multiple images in the same size. 33 | :author 申瑞珉 (Ruimin Shen) 34 | :param y: The y coordinate of rotation point. 35 | :param x: The x coordinate of rotation point. 36 | :param height: Image height. 37 | :param width: Image width. 38 | :param angle: Rotate angle. 
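# --- Worked example (added for clarity; not part of the original file) --------
# The Rotator class enlarges the output canvas to the bounding box of the rotated
# image: the new height is |cos(a)|*h + |sin(a)|*w, the new width is
# |cos(a)|*w + |sin(a)|*h, and the translation column of the affine matrix is shifted
# so the chosen rotation point lands at the centre of the new canvas. For a 480x640
# image rotated by 30 degrees:
import numpy as np
h, w, angle = 480, 640, 30
c, s = np.abs(np.cos(np.radians(angle))), np.abs(np.sin(np.radians(angle)))
print(int(c * h + s * w), int(c * w + s * h))  # 735 794 (new height, new width)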
39 | """ 40 | self._mat = cv2.getRotationMatrix2D((x, y), angle, 1.0) 41 | r = np.abs(self._mat[0, :2]) 42 | _height, _width = np.inner(r, [height, width]), np.inner(r, [width, height]) 43 | fix_y, fix_x = _height / 2 - y, _width / 2 - x 44 | self._mat[:, 2] += [fix_x, fix_y] 45 | self._size = int(_width), int(_height) 46 | 47 | def __call__(self, image, flags=cv2.INTER_LINEAR, fill=None): 48 | if fill is None: 49 | fill = np.random.rand(3) * 256 50 | return cv2.warpAffine(image, self._mat, self._size, flags=flags, borderMode=cv2.BORDER_CONSTANT, borderValue=fill) 51 | 52 | def _rotate_points(self, points): 53 | _points = np.pad(points, [(0, 0), (0, 1)], 'constant') 54 | _points[:, 2] = 1 55 | _points = np.dot(self._mat, _points.T) 56 | return _points.T.astype(points.dtype) 57 | 58 | def rotate_points(self, points): 59 | return self._rotate_points(points[:, ::-1])[:, ::-1] 60 | 61 | 62 | def random_rotate(config, image, mask, keypoints, yx_min, yx_max, index): 63 | name = inspect.stack()[0][3] 64 | angle = random.uniform(*tuple(map(float, config.get('augmentation', name).split()))) 65 | height, width = image.shape[:2] 66 | p1, p2 = np.copy(yx_min), np.copy(yx_max) 67 | p1[:, 0] = yx_max[:, 0] 68 | p2[:, 0] = yx_min[:, 0] 69 | points = np.concatenate([yx_min, yx_max, p1, p2], 0) 70 | rotator = Rotator(*((yx_min[index] + yx_max[index]) / 2), height, width, angle) 71 | image = rotator(image, fill=0) 72 | mask = rotator(mask, fill=0) 73 | keypoints[:, :, :2] = np.reshape(rotator.rotate_points(np.reshape(keypoints, [-1, 3])[:, :2]), [len(keypoints), -1, 2]) 74 | points = rotator.rotate_points(points) 75 | bbox_points = np.reshape(points, [4, -1, 2]) 76 | yx_min = np.apply_along_axis(lambda points: np.min(points, 0), 0, bbox_points) 77 | yx_max = np.apply_along_axis(lambda points: np.max(points, 0), 0, bbox_points) 78 | return image, mask, keypoints, yx_min, yx_max 79 | 80 | 81 | class RandomRotate(object): 82 | def __init__(self, config): 83 | self.config = config 84 | self.fn = eval(inflection.underscore(type(self).__name__)) 85 | 86 | def __call__(self, data): 87 | data['image'], data['mask'], data['keypoints'], data['yx_min'], data['yx_max'] = self.fn(self.config, data['image'], data['mask'], data['keypoints'], data['yx_min'], data['yx_max'], data['index']) 88 | return data 89 | 90 | 91 | def flip_horizontally(image, mask, keypoints, yx_min, yx_max): 92 | assert len(image.shape) == 3 93 | image = cv2.flip(image, 1) 94 | mask = cv2.flip(mask, 1) 95 | width = image.shape[1] 96 | keypoints[:, :, 1] = width - keypoints[:, :, 1] 97 | temp = width - yx_min[:, 1] 98 | yx_min[:, 1] = width - yx_max[:, 1] 99 | yx_max[:, 1] = temp 100 | return image, mask, keypoints, yx_min, yx_max 101 | 102 | 103 | class RandomFlipHorizontally(object): 104 | def __init__(self, config): 105 | self.config = config 106 | name = inflection.underscore(type(self).__name__) 107 | self.prob = config.getfloat('augmentation', name) 108 | with open(os.path.expanduser(os.path.expandvars(config.get('cache', 'dataset'))) + '.txt', 'r') as f: 109 | lines = (line.strip() for line in f) 110 | self.symmetric = [int(line) if line else i for i, line in enumerate(lines)] 111 | 112 | def __call__(self, data): 113 | if random.random() > self.prob: 114 | data['image'], data['mask'], keypoints, data['yx_min'], data['yx_max'] = flip_horizontally(data['image'], data['mask'], data['keypoints'], data['yx_min'], data['yx_max']) 115 | assert keypoints.shape[1] == len(self.symmetric) 116 | keypoints = np.stack([[points[i] for i in self.symmetric] 
for points in keypoints]) 117 | data['keypoints'] = keypoints 118 | return data 119 | 120 | 121 | def get_transform(config, sequence): 122 | return transform.get_transform(config, sequence) 123 | -------------------------------------------------------------------------------- /transform/image.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen) 3 | 4 | This program is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU Lesser General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 8 | 9 | This program is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with this program. If not, see . 16 | """ 17 | 18 | import random 19 | 20 | import numpy as np 21 | import torchvision 22 | import inflection 23 | import skimage.exposure 24 | import cv2 25 | 26 | 27 | class BGR2RGB(object): 28 | def __call__(self, image): 29 | return cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 30 | 31 | 32 | class BGR2HSV(object): 33 | def __call__(self, image): 34 | return cv2.cvtColor(image, cv2.COLOR_BGR2HSV) 35 | 36 | 37 | class HSV2RGB(object): 38 | def __call__(self, image): 39 | return cv2.cvtColor(image, cv2.COLOR_HSV2RGB) 40 | 41 | 42 | class RandomBlur(object): 43 | def __init__(self, config): 44 | name = inflection.underscore(type(self).__name__) 45 | self.adjust = tuple(map(int, config.get('augmentation', name).split())) 46 | 47 | def __call__(self, image): 48 | adjust = tuple(random.randint(1, adjust) for adjust in self.adjust) 49 | return cv2.blur(image, adjust) 50 | 51 | 52 | class RandomHue(object): 53 | def __init__(self, config): 54 | name = inflection.underscore(type(self).__name__) 55 | self.adjust = tuple(map(int, config.get('augmentation', name).split())) 56 | 57 | def __call__(self, hsv): 58 | h, s, v = cv2.split(hsv) 59 | adjust = random.randint(*self.adjust) 60 | h = h.astype(np.int) + adjust 61 | h = np.clip(h, 0, 179).astype(hsv.dtype) 62 | return cv2.merge((h, s, v)) 63 | 64 | 65 | class RandomSaturation(object): 66 | def __init__(self, config): 67 | name = inflection.underscore(type(self).__name__) 68 | self.adjust = tuple(map(float, config.get('augmentation', name).split())) 69 | 70 | def __call__(self, hsv): 71 | h, s, v = cv2.split(hsv) 72 | adjust = random.uniform(*self.adjust) 73 | s = s * adjust 74 | s = np.clip(s, 0, 255).astype(hsv.dtype) 75 | return cv2.merge((h, s, v)) 76 | 77 | 78 | class RandomBrightness(object): 79 | def __init__(self, config): 80 | name = inflection.underscore(type(self).__name__) 81 | self.adjust = tuple(map(float, config.get('augmentation', name).split())) 82 | 83 | def __call__(self, hsv): 84 | h, s, v = cv2.split(hsv) 85 | adjust = random.uniform(*self.adjust) 86 | v = v * adjust 87 | v = np.clip(v, 0, 255).astype(hsv.dtype) 88 | return cv2.merge((h, s, v)) 89 | 90 | 91 | class RandomGamma(object): 92 | def __init__(self, config): 93 | name = inflection.underscore(type(self).__name__) 94 | self.adjust = tuple(map(float, config.get('augmentation', name).split())) 95 | 96 | def __call__(self, image): 97 | adjust = random.uniform(*self.adjust) 98 | return skimage.exposure.adjust_gamma(image, 
adjust) 99 | 100 | 101 | class Normalize(torchvision.transforms.Normalize): 102 | def __init__(self, config): 103 | name = inflection.underscore(type(self).__name__) 104 | mean, std = tuple(map(float, config.get('transform', name).split())) 105 | torchvision.transforms.Normalize.__init__(self, (mean, mean, mean), (std, std, std)) 106 | -------------------------------------------------------------------------------- /transform/resize/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruiminshen/openpose-pytorch/f850084194ddccc6d401d5b11f61facc20ec2b75/transform/resize/__init__.py -------------------------------------------------------------------------------- /transform/resize/image.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen) 3 | 4 | This program is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU Lesser General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 8 | 9 | This program is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with this program. If not, see . 16 | """ 17 | 18 | import inflection 19 | import numpy as np 20 | import cv2 21 | 22 | 23 | def rescale(image, height, width): 24 | return cv2.resize(image, (width, height)) 25 | 26 | 27 | def rescale_scale(size, image_size): 28 | return image_size[0] / size[0], image_size[1] / size[1] 29 | 30 | 31 | class Rescale(object): 32 | def __init__(self): 33 | name = inflection.underscore(type(self).__name__) 34 | self.fn = eval(name) 35 | self.scale = eval(name + '_scale') 36 | 37 | def __call__(self, image, height, width): 38 | return self.fn(image, height, width) 39 | 40 | 41 | def fixed(image, height, width): 42 | _height, _width, _ = image.shape 43 | if _height / _width > height / width: 44 | scale = height / _height 45 | else: 46 | scale = width / _width 47 | m = np.eye(2, 3) 48 | m[0, 0] = scale 49 | m[1, 1] = scale 50 | flags = cv2.INTER_AREA if scale < 1 else cv2.INTER_CUBIC 51 | return cv2.warpAffine(image, m, (width, height), flags=flags) 52 | 53 | 54 | def fixed_scale(size, image_size): 55 | assert len(image_size) == 2 56 | _image_size = max(image_size) 57 | return _image_size / size[0], _image_size / size[1] 58 | 59 | 60 | class Fixed(object): 61 | def __init__(self): 62 | name = inflection.underscore(type(self).__name__) 63 | self.fn = eval(name) 64 | self.scale = eval(name + '_scale') 65 | 66 | def __call__(self, image, height, width): 67 | return self.fn(image, height, width) 68 | 69 | 70 | class Resize(object): 71 | def __init__(self, config): 72 | name = config.get('data', inflection.underscore(type(self).__name__)) 73 | self.fn = eval(name) 74 | self.scale = eval(name + '_scale') 75 | 76 | def __call__(self, image, height, width): 77 | return self.fn(image, height, width) 78 | -------------------------------------------------------------------------------- /transform/resize/label.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen) 3 | 4 | This program is free software: you can 
redistribute it and/or modify 5 | it under the terms of the GNU Lesser General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 8 | 9 | This program is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with this program. If not, see . 16 | """ 17 | 18 | import inspect 19 | 20 | import inflection 21 | import numpy as np 22 | import cv2 23 | 24 | 25 | def rescale(image, mask, keypoints, yx_min, yx_max, height, width): 26 | _height, _width = image.shape[:2] 27 | scale = np.array([height / _height, width / _width], np.float32) 28 | image = cv2.resize(image, (width, height)) 29 | mask = cv2.resize(mask, (width, height)) 30 | keypoints[:, :, :2] *= scale 31 | yx_min *= scale 32 | yx_max *= scale 33 | return image, mask, keypoints, yx_min, yx_max 34 | 35 | 36 | class Rescale(object): 37 | def __init__(self): 38 | self.fn = eval(inflection.underscore(type(self).__name__)) 39 | 40 | def __call__(self, data, height, width): 41 | data['image'], data['mask'], data['keypoints'], data['yx_min'], data['yx_max'] = self.fn(data['image'], data['mask'], data['keypoints'], data['yx_min'], data['yx_max'], height, width) 42 | return data 43 | 44 | 45 | def padding(image, mask, keypoints, yx_min, yx_max, height, width): 46 | _height, _width, _ = image.shape 47 | if _height / _width > height / width: 48 | scale = height / _height 49 | else: 50 | scale = width / _width 51 | m = np.eye(2, 3) 52 | m[0, 0] = scale 53 | m[1, 1] = scale 54 | flags = cv2.INTER_AREA if scale < 1 else cv2.INTER_CUBIC 55 | image = cv2.warpAffine(image, m, (width, height), flags=flags) 56 | mask = cv2.warpAffine(mask, m, (width, height), flags=flags) 57 | return image, mask, keypoints, yx_min, yx_max 58 | 59 | 60 | class Padding(object): 61 | def __init__(self): 62 | self.fn = eval(inflection.underscore(type(self).__name__)) 63 | 64 | def __call__(self, data, height, width): 65 | data['image'], data['mask'], data['keypoints'], data['yx_min'], data['yx_max'] = self.fn(data['image'], data['mask'], data['keypoints'], data['yx_min'], data['yx_max'], height, width) 66 | return data 67 | 68 | 69 | def resize(config, image, mask, keypoints, yx_min, yx_max, height, width): 70 | fn = eval(config.get('data', inspect.stack()[0][3])) 71 | return fn(image, mask, keypoints, yx_min, yx_max, height, width) 72 | 73 | 74 | class Resize(object): 75 | def __init__(self, config): 76 | self.config = config 77 | self.fn = eval(config.get('data', inflection.underscore(type(self).__name__))) 78 | 79 | def __call__(self, data, height, width): 80 | data['image'], data['yx_min'], data['yx_max'] = self.fn(self.config, data['image'], data['yx_min'], data['yx_max'], height, width) 81 | return data 82 | 83 | 84 | def change_aspect_ratio(range, height_src, width_src, height_dst, width_dst): 85 | assert range >= 0 86 | if width_src < height_src: 87 | width = min(range, width_src) 88 | height = width * height_dst / width_dst 89 | else: 90 | height = min(range, height_src) 91 | width = height * width_dst / height_dst 92 | return height, width 93 | 94 | 95 | def repair(yx_min, yx_max, size): 96 | move = np.clip(yx_max - size, 0, None) 97 | yx_min -= move 98 | yx_max -= move 99 | move = np.clip(-yx_min, 0, None) 100 | 
yx_min += move 101 | yx_max += move 102 | return yx_min, yx_max 103 | 104 | 105 | def random_crop(config, image, mask, keypoints, yx_min, yx_max, index, height, width): 106 | name = inspect.stack()[0][3] 107 | scale1, scale2 = tuple(map(float, config.get('augmentation', name).split())) 108 | assert 1 <= scale1 <= scale2, (scale1, scale2) 109 | dtype = keypoints.dtype 110 | size = np.array(image.shape[:2], dtype) 111 | _yx_min, _yx_max = yx_min[index], yx_max[index] 112 | _center = (_yx_min + _yx_max) / 2 113 | _size = np.array(change_aspect_ratio(np.max(_yx_max - _yx_min), *size, height, width), dtype) 114 | _size1, _size2 = _size * scale1 / 2, _size * scale2 / 2 115 | yx_min1, yx_max1 = _center - _size1, _center + _size1 116 | yx_min2, yx_max2 = _center - _size2, _center + _size2 117 | yx_min1, yx_max1 = repair(yx_min1, yx_max1, size) 118 | yx_min2, yx_max2 = repair(yx_min2, yx_max2, size) 119 | margin = np.random.rand(4).astype(dtype) * np.concatenate([yx_min1 - yx_min2, yx_max2 - yx_max1], 0) 120 | yx_min_crop = np.clip(yx_min2 + margin[:2], 0, None) 121 | yx_max_crop = np.clip(yx_max2 - margin[2:], None, size) 122 | _ymin, _xmin = tuple(map(int, yx_min_crop)) 123 | _ymax, _xmax = tuple(map(int, yx_max_crop)) 124 | image = image[_ymin:_ymax, _xmin:_xmax, :] 125 | mask = mask[_ymin:_ymax, _xmin:_xmax] 126 | keypoints[:, :, :2] -= yx_min_crop 127 | yx_min -= yx_min_crop 128 | yx_max -= yx_min_crop 129 | return rescale(image, mask, keypoints, yx_min, yx_max, height, width) 130 | 131 | 132 | class RandomCrop(object): 133 | def __init__(self, config): 134 | self.config = config 135 | self.fn = eval(inflection.underscore(type(self).__name__)) 136 | 137 | def __call__(self, data, height, width): 138 | data['image'], data['mask'], data['keypoints'], data['yx_min'], data['yx_max'] = self.fn(self.config, data['image'], data['mask'], data['keypoints'], data['yx_min'], data['yx_max'], data['index'], height, width) 139 | return data 140 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen) 3 | 4 | This program is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU Lesser General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 8 | 9 | This program is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with this program. If not, see . 
16 | """ 17 | 18 | import os 19 | import re 20 | import configparser 21 | import importlib 22 | import hashlib 23 | 24 | import numpy as np 25 | import pandas as pd 26 | import torch.autograd 27 | from PIL import Image 28 | 29 | import pyopenpose 30 | 31 | 32 | class Compose(object): 33 | def __init__(self, transforms): 34 | self.transforms = transforms 35 | 36 | def __call__(self, img, yx_min, yx_max, cls): 37 | for t in self.transforms: 38 | img, yx_min, yx_max, cls = t(img, yx_min, yx_max, cls) 39 | return img, yx_min, yx_max, cls 40 | 41 | 42 | class RegexList(list): 43 | def __init__(self, l): 44 | for s in l: 45 | prog = re.compile(s) 46 | self.append(prog) 47 | 48 | def __call__(self, s): 49 | for prog in self: 50 | if prog.match(s): 51 | return True 52 | return False 53 | 54 | 55 | class DatasetMapper(object): 56 | def __init__(self, mapper): 57 | self.mapper = mapper 58 | 59 | def __call__(self, parts, dtype=np.int64): 60 | assert len(parts.shape) == 2 and parts.shape[-1] == 3 61 | result = np.zeros([len(self.mapper), 3], dtype=parts.dtype) 62 | for i, func in enumerate(self.mapper): 63 | result[i] = func(parts) 64 | return result 65 | 66 | 67 | def get_dataset_mappers(config): 68 | root = os.path.expanduser(os.path.expandvars(config.get('cache', 'dataset'))) 69 | mappers = {} 70 | for dataset in os.listdir(root): 71 | path = os.path.join(root, dataset) 72 | if os.path.isfile(path): 73 | with open(path, 'r') as f: 74 | mapper = [eval(line.rstrip()) for line in f] 75 | mappers[dataset] = mapper 76 | sizes = set(map(lambda mapper: len(mapper), mappers.values())) 77 | assert len(sizes) == 1 78 | for dataset in mappers: 79 | mappers[dataset] = DatasetMapper(mappers[dataset]) 80 | return mappers, next(iter(sizes)) 81 | 82 | 83 | def get_limbs_index(config): 84 | dataset = os.path.expanduser(os.path.expandvars(config.get('cache', 'dataset'))) 85 | limbs_index = np.loadtxt(dataset + '.tsv', dtype=np.int, delimiter='\t', ndmin=2) 86 | if len(limbs_index) > 0: 87 | assert pyopenpose.limbs_points(limbs_index) == get_dataset_mappers(config)[1] 88 | else: 89 | limbs_index = np.reshape(limbs_index, [0, 2]) 90 | return limbs_index 91 | 92 | 93 | def get_cache_dir(config): 94 | root = os.path.expanduser(os.path.expandvars(config.get('config', 'root'))) 95 | name = config.get('cache', 'name') 96 | dataset = os.path.basename(config.get('cache', 'dataset')) 97 | return os.path.join(root, name, dataset) 98 | 99 | 100 | def get_model_dir(config): 101 | root = os.path.expanduser(os.path.expandvars(config.get('config', 'root'))) 102 | name = config.get('model', 'name') 103 | dataset = os.path.basename(config.get('cache', 'dataset')) 104 | dnn = config.get('model', 'dnn') 105 | stages = hashlib.md5(' '.join(config.get('model', 'stages').split()).encode()).hexdigest() 106 | return os.path.join(root, name, dataset, dnn, stages) 107 | 108 | 109 | def get_eval_db(config): 110 | root = os.path.expanduser(os.path.expandvars(config.get('config', 'root'))) 111 | db = config.get('eval', 'db') 112 | return os.path.join(root, db) 113 | 114 | 115 | def get_category(config, cache_dir=None): 116 | path = os.path.expanduser(os.path.expandvars(config.get('cache', 'category'))) if cache_dir is None else os.path.join(cache_dir, 'category') 117 | with open(path, 'r') as f: 118 | return [line.strip() for line in f] 119 | 120 | 121 | def get_anchors(config, dtype=np.float32): 122 | path = os.path.expanduser(os.path.expandvars(config.get('model', 'anchors'))) 123 | df = pd.read_csv(path, sep='\t', dtype=dtype) 124 | return 
df[['height', 'width']].values 125 | 126 | 127 | def parse_attr(s): 128 | m, n = s.rsplit('.', 1) 129 | m = importlib.import_module(m) 130 | return getattr(m, n) 131 | 132 | 133 | def load_config(config, paths): 134 | for path in paths: 135 | path = os.path.expanduser(os.path.expandvars(path)) 136 | assert os.path.exists(path) 137 | config.read(path) 138 | 139 | 140 | def modify_config(config, cmd): 141 | var, value = cmd.split('=', 1) 142 | section, option = var.split('/') 143 | if value: 144 | config.set(section, option, value) 145 | else: 146 | try: 147 | config.remove_option(section, option) 148 | except (configparser.NoSectionError, configparser.NoOptionError): 149 | pass 150 | 151 | 152 | def dense(var): 153 | return [torch.mean(torch.abs(x)) if torch.is_tensor(x) else np.abs(x) for x in var] 154 | 155 | 156 | def abs_mean(data, dtype=np.float32): 157 | assert isinstance(data, np.ndarray), type(data) 158 | return np.sum(np.abs(data)) / dtype(data.size) 159 | 160 | 161 | def image_size(path): 162 | with Image.open(path) as image: 163 | return image.size 164 | -------------------------------------------------------------------------------- /utils/cache.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen) 3 | 4 | This program is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU Lesser General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 8 | 9 | This program is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with this program. If not, see . 16 | """ 17 | 18 | import numpy as np 19 | 20 | 21 | def verify_coords(yx_min, yx_max, size): 22 | assert np.all(yx_min <= yx_max), 'yx_min <= yx_max' 23 | assert np.all(0 <= yx_min), '0 <= yx_min' 24 | assert np.all(0 <= yx_max), '0 <= yx_max' 25 | assert np.all(yx_min < size), 'yx_min < size' 26 | assert np.all(yx_max < size), 'yx_max < size' 27 | 28 | 29 | def fix_coords(yx_min, yx_max, size): 30 | assert np.all(yx_min <= yx_max) 31 | assert yx_min.dtype == yx_max.dtype 32 | coord_min = np.zeros([2], dtype=yx_min.dtype) 33 | coord_max = np.array(size, dtype=yx_min.dtype) - 1 34 | yx_min = np.minimum(np.maximum(yx_min, coord_min), coord_max) 35 | yx_max = np.minimum(np.maximum(yx_max, coord_min), coord_max) 36 | return yx_min, yx_max 37 | -------------------------------------------------------------------------------- /utils/data.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen) 3 | 4 | This program is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU Lesser General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 8 | 9 | This program is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 
13 | 14 | You should have received a copy of the GNU General Public License 15 | along with this program. If not, see . 16 | """ 17 | 18 | import os 19 | import pickle 20 | import random 21 | import copy 22 | 23 | import numpy as np 24 | import torch.utils.data 25 | import cv2 26 | 27 | import utils 28 | import pyopenpose 29 | 30 | 31 | def padding_labels(data, dim, labels='keypoints, yx_min, yx_max'.split(', ')): 32 | """ 33 | Padding labels into the same dimension (to form a batch). 34 | :author 申瑞珉 (Ruimin Shen) 35 | :param data: A dict contains the labels to be padded. 36 | :param dim: The target dimension. 37 | :param labels: The list of label names. 38 | :return: The padded label dict. 39 | """ 40 | pad = dim - len(data[labels[0]]) 41 | for key in labels: 42 | label = data[key] 43 | data[key] = np.pad(label, [(0, pad)] + [(0, 0)] * (len(label.shape) - 1), 'constant') 44 | return data 45 | 46 | 47 | def load_pickles(paths): 48 | data = [] 49 | for path in paths: 50 | with open(path, 'rb') as f: 51 | data += pickle.load(f) 52 | return data 53 | 54 | 55 | class Dataset(torch.utils.data.Dataset): 56 | def __init__(self, config, data, transform=lambda data: data, shuffle=False, dir=None): 57 | """ 58 | Load the cached data (.pkl) into memory. 59 | :author 申瑞珉 (Ruimin Shen) 60 | :param data: A list contains the data samples (dict). 61 | :param transform: A function transforms (usually performs a sequence of data augmentation operations) the labels in a dict. 62 | :param shuffle: Shuffle the loaded dataset. 63 | :param dir: The directory to store the exception data. 64 | """ 65 | self.config = config 66 | self.mask_ext = config.get('cache', 'mask_ext') 67 | self.data = data 68 | if shuffle: 69 | random.shuffle(self.data) 70 | self.transform = transform 71 | self.dir = dir 72 | 73 | def __len__(self): 74 | return len(self.data) 75 | 76 | def __getitem__(self, index): 77 | data = copy.deepcopy(self.data[index]) 78 | try: 79 | image = cv2.imread(data['path']) 80 | data['image'] = image 81 | data['size'] = np.array(image.shape[:2]) 82 | mask = cv2.imread(data['keypath'] + '.mask' + self.mask_ext, cv2.IMREAD_GRAYSCALE) 83 | assert image.shape[:2] == mask.shape, [image.shape[:2], mask.shape] 84 | data['mask'] = mask 85 | data['index'] = random.randint(0, len(data['keypoints']) - 1) 86 | data = self.transform(data) 87 | except: 88 | if self.dir is not None: 89 | os.makedirs(self.dir, exist_ok=True) 90 | name = self.__module__ + '.' + type(self).__name__ 91 | with open(os.path.join(self.dir, name + '.pkl'), 'wb') as f: 92 | pickle.dump(data, f) 93 | raise 94 | return data 95 | 96 | 97 | class Collate(object): 98 | def __init__(self, config, resize, sizes, feature_sizes, maintain=1, transform_image=lambda image: image, transform_tensor=None, dir=None): 99 | """ 100 | Unify multiple data samples (e.g., resize images into the same size, and padding bounding box labels into the same number) to form a batch. 101 | :author 申瑞珉 (Ruimin Shen) 102 | :param resize: A function to resize the image and labels. 103 | :param sizes: The image sizes to be randomly choosed. 104 | :param feature_sizes: The feature sizes related to the image sizes. 105 | :param maintain: How many times a size to be maintained. 106 | :param transform_image: A function to transform the resized image. 107 | :param transform_tensor: A function to standardize a image into a tensor. 108 | :param dir: The directory to store the exception data. 
109 | """ 110 | self.config = config 111 | self.resize = resize 112 | assert len(sizes) == len(feature_sizes) 113 | self.sizes = sizes 114 | self.feature_sizes = feature_sizes 115 | assert maintain > 0 116 | self.maintain = maintain 117 | self._maintain = maintain 118 | self.transform_image = transform_image 119 | self.transform_tensor = transform_tensor 120 | self.dir = dir 121 | self.sigma_parts = config.getfloat('label', 'sigma_parts') 122 | self.sigma_limbs = config.getfloat('label', 'sigma_limbs') 123 | self.limbs_index = utils.get_limbs_index(config) 124 | 125 | def __call__(self, batch): 126 | (height, width), (rows, cols) = self.next_size() 127 | dim = max(len(data['keypoints']) for data in batch) 128 | _batch = [] 129 | for data in batch: 130 | try: 131 | data = self.resize(data, height, width) 132 | data['image'] = self.transform_image(data['image']) 133 | data = padding_labels(data, dim) 134 | if self.transform_tensor is not None: 135 | data['tensor'] = self.transform_tensor(data['image']) 136 | data['mask'] = (cv2.resize(data['mask'], (cols, rows)) > 127).astype(np.uint8) 137 | data['parts'] = pyopenpose.label_parts(data['keypoints'], self.sigma_parts, height, width, rows, cols) 138 | data['limbs'] = pyopenpose.label_limbs(data['keypoints'], self.limbs_index, self.sigma_limbs, height, width, rows, cols) 139 | _batch.append(data) 140 | except: 141 | if self.dir is not None: 142 | os.makedirs(self.dir, exist_ok=True) 143 | name = self.__module__ + '.' + type(self).__name__ 144 | with open(os.path.join(self.dir, name + '.pkl'), 'wb') as f: 145 | pickle.dump(data, f) 146 | raise 147 | return torch.utils.data.dataloader.default_collate(_batch) 148 | 149 | def next_size(self): 150 | if self._maintain < self.maintain: 151 | self._maintain += 1 152 | else: 153 | self._index = random.randint(0, len(self.sizes) - 1) 154 | self._maintain = 0 155 | return self.sizes[self._index], self.feature_sizes[self._index] 156 | -------------------------------------------------------------------------------- /utils/train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen) 3 | 4 | This program is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU Lesser General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 8 | 9 | This program is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with this program. If not, see . 16 | """ 17 | 18 | import os 19 | import time 20 | import operator 21 | import logging 22 | 23 | import torch 24 | 25 | 26 | class Timer(object): 27 | def __init__(self, max, first=True): 28 | """ 29 | A simple function object to determine time event. 30 | :author 申瑞珉 (Ruimin Shen) 31 | :param max: Number of seconds to trigger a time event. 32 | :param first: Should a time event to be triggered at the first time. 33 | """ 34 | self.start = 0 if first else time.time() 35 | self.max = max 36 | 37 | def __call__(self): 38 | """ 39 | Return a boolean value to indicate if the time event is occurred. 
40 | :author 申瑞珉 (Ruimin Shen) 41 | """ 42 | t = time.time() 43 | elapsed = t - self.start 44 | if elapsed > self.max: 45 | self.start = t 46 | return True 47 | else: 48 | return False 49 | 50 | 51 | def load_model(model_dir, step=None, ext='.pth', ext_epoch='.epoch', logger=logging.info): 52 | """ 53 | Load the latest checkpoint in a model directory. 54 | :author 申瑞珉 (Ruimin Shen) 55 | :param model_dir: The directory to store the model checkpoint files. 56 | :param step: If a integer value is given, the corresponding checkpoint will be loaded. Otherwise, the latest checkpoint (with the largest step value) will be loaded. 57 | :param ext: The extension of the model file. 58 | :param ext_epoch: The extension of the epoch file. 59 | :return: 60 | """ 61 | if step is None: 62 | steps = [(int(n), n) for n, e in map(os.path.splitext, os.listdir(model_dir)) if n.isdigit() and e == ext] 63 | step, name = max(steps, key=operator.itemgetter(0)) 64 | else: 65 | name = str(step) 66 | prefix = os.path.join(model_dir, name) 67 | if logger is not None: 68 | logger('load %s.*' % prefix) 69 | try: 70 | with open(prefix + ext_epoch, 'r') as f: 71 | epoch = int(f.read()) 72 | except (FileNotFoundError, ValueError): 73 | epoch = None 74 | path = prefix + ext 75 | assert os.path.exists(path), path 76 | return path, step, epoch 77 | 78 | 79 | class Saver(object): 80 | def __init__(self, model_dir, keep, ext='.pth', ext_epoch='.epoch', logger=logging.info): 81 | """ 82 | Manage several latest checkpoints (with the largest step values) in a model directory. 83 | :author 申瑞珉 (Ruimin Shen) 84 | :param model_dir: The directory to store the model checkpoint files. 85 | :param keep: How many latest checkpoints to be maintained. 86 | :param ext: The extension of the model file. 87 | :param ext_epoch: The extension of the epoch file. 88 | """ 89 | self.model_dir = model_dir 90 | self.keep = keep 91 | self.ext = ext 92 | self.ext_epoch = ext_epoch 93 | self.logger = (lambda s: s) if logger is None else logger 94 | 95 | def __call__(self, obj, step, epoch=None): 96 | """ 97 | Save the PyTorch module. 98 | :author 申瑞珉 (Ruimin Shen) 99 | :param obj: The PyTorch module to be saved. 100 | :param step: Current step. 101 | :param epoch: Current epoch. 
102 | """ 103 | os.makedirs(self.model_dir, exist_ok=True) 104 | prefix = os.path.join(self.model_dir, str(step)) 105 | torch.save(obj, prefix + self.ext) 106 | if epoch is not None: 107 | with open(prefix + self.ext_epoch, 'w') as f: 108 | f.write(str(epoch)) 109 | self.logger('model saved into %s.*' % prefix) 110 | self.tidy() 111 | return prefix 112 | 113 | def tidy(self): 114 | steps = [(int(n), n) for n, e in map(os.path.splitext, os.listdir(self.model_dir)) if n.isdigit() and e == self.ext] 115 | if len(steps) > self.keep: 116 | steps = sorted(steps, key=operator.itemgetter(0)) 117 | remove = steps[:len(steps) - self.keep] 118 | for _, n in remove: 119 | path = os.path.join(self.model_dir, n) 120 | os.remove(path + self.ext) 121 | path_epoch = path + self.ext_epoch 122 | try: 123 | os.remove(path_epoch) 124 | except FileNotFoundError: 125 | self.logger(path_epoch + ' not found') 126 | logging.debug('tidy ' + path) 127 | 128 | 129 | def load_sizes(config): 130 | sizes = [s.split(',') for s in config.get('data', 'sizes').split()] 131 | return [(int(height), int(width)) for height, width in sizes] 132 | -------------------------------------------------------------------------------- /utils/visualize.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen) 3 | 4 | This program is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU Lesser General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 8 | 9 | This program is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with this program. If not, see . 
16 | """ 17 | 18 | import logging 19 | import itertools 20 | import inspect 21 | 22 | import numpy as np 23 | import torch 24 | import matplotlib 25 | import matplotlib.cm 26 | import matplotlib.colors 27 | import matplotlib.pyplot as plt 28 | import humanize 29 | import graphviz 30 | import cv2 31 | 32 | 33 | def draw_mask(image, mask, threshold=128): 34 | _mask = cv2.resize(np.squeeze(mask), image.shape[1::-1], interpolation=cv2.INTER_NEAREST) 35 | return np.expand_dims(_mask >= threshold, -1) * image 36 | 37 | 38 | class DrawPoints(object): 39 | def __init__(self, limbs_index, colors=[], radius=5, thickness=2, line_type=cv2.LINE_8, shift=0, font_face=cv2.FONT_HERSHEY_SIMPLEX, font_scale=0.5, z=1): 40 | self.limbs_index = limbs_index 41 | self.colors = [tuple(map(lambda c: c * 255, matplotlib.colors.colorConverter.to_rgb(c)[::-1])) for c in colors] 42 | self._colors = [tuple(map(lambda c: c * 255, matplotlib.colors.colorConverter.to_rgb(prop['color'])[::-1])) for prop in plt.rcParams['axes.prop_cycle']] 43 | self.radius = radius 44 | self.thickness = thickness 45 | self.line_type = line_type 46 | self.shift = shift 47 | self.font_face = font_face 48 | self.font_scale = font_scale 49 | self.z = z 50 | 51 | def __call__(self, image, points, debug=False): 52 | if len(self.colors) >= 2: 53 | for i, point in enumerate(points): 54 | y, x, v = map(int, point) 55 | assert v >= 0 56 | if v > 0: 57 | text = str(i) 58 | color = self.colors[v - 1] 59 | _color = tuple(map(lambda c: np.float(np.bitwise_not(np.uint8(c))), color)) 60 | cv2.putText(image, text, (x - self.z, y), self.font_face, self.font_scale, _color) 61 | cv2.putText(image, text, (x + self.z, y), self.font_face, self.font_scale, _color) 62 | cv2.putText(image, text, (x, y - self.z), self.font_face, self.font_scale, _color) 63 | cv2.putText(image, text, (x, y + self.z), self.font_face, self.font_scale, _color) 64 | cv2.putText(image, text, (x, y), self.font_face, self.font_scale, color) 65 | if len(self.limbs_index) > 0: 66 | for color, (i1, i2) in zip(itertools.cycle(self._colors), self.limbs_index): 67 | y1, x1, v1 = points[i1].T 68 | y2, x2, v2 = points[i2].T 69 | if v1 > 0 and v2 > 0: 70 | cv2.line(image, (x1, y1), (x2, y2), color, thickness=self.thickness) 71 | else: 72 | for color, (y, x, v) in zip(itertools.cycle(self._colors), points): 73 | if v > 0: 74 | cv2.circle(image, (x, y), self.radius, color, thickness=-1) 75 | if debug: 76 | cv2.imshow('', image) 77 | cv2.waitKey(0) 78 | return image 79 | 80 | 81 | class DrawBBox(object): 82 | def __init__(self, category=None, colors=[], thickness=1, line_type=cv2.LINE_8, shift=0, font_face=cv2.FONT_HERSHEY_SIMPLEX, font_scale=1): 83 | self.category = category 84 | if colors: 85 | self.colors = [tuple(map(lambda c: c * 255, matplotlib.colors.colorConverter.to_rgb(c)[::-1])) for c in colors] 86 | else: 87 | self.colors = [tuple(map(lambda c: c * 255, matplotlib.colors.colorConverter.to_rgb(prop['color'])[::-1])) for prop in plt.rcParams['axes.prop_cycle']] 88 | self.thickness = thickness 89 | self.line_type = line_type 90 | self.shift = shift 91 | self.font_face = font_face 92 | self.font_scale = font_scale 93 | 94 | def __call__(self, image, yx_min, yx_max, cls=None, colors=None, debug=False): 95 | colors = self.colors if colors is None else [tuple(map(lambda c: c * 255, matplotlib.colors.colorConverter.to_rgb(c)[::-1])) for c in colors] 96 | if cls is None: 97 | cls = [None] * len(yx_min) 98 | for color, (ymin, xmin), (ymax, xmax), cls in zip(itertools.cycle(colors), yx_min, yx_max, 
cls): 99 | try: 100 | cv2.rectangle(image, (xmin, ymin), (xmax, ymax), color, thickness=self.thickness, lineType=self.line_type, shift=self.shift) 101 | if self.category is not None and cls is not None: 102 | cv2.putText(image, self.category[cls], (xmin, ymin), self.font_face, self.font_scale, color=color, thickness=self.thickness) 103 | except OverflowError as e: 104 | logging.warning(e, (xmin, ymin), (xmax, ymax)) 105 | if debug: 106 | cv2.imshow('', image) 107 | cv2.waitKey(0) 108 | return image 109 | 110 | 111 | class DrawFeature(object): 112 | def __init__(self, alpha=0.5, cmap=None): 113 | self.alpha = alpha 114 | self.cm = matplotlib.cm.get_cmap(cmap) 115 | 116 | def __call__(self, image, feature, debug=False): 117 | _feature = (feature * self.cm.N).astype(np.int) 118 | heatmap = self.cm(_feature)[:, :, :3] * 255 119 | heatmap = cv2.resize(heatmap, image.shape[1::-1], interpolation=cv2.INTER_NEAREST) 120 | canvas = (image * (1 - self.alpha) + heatmap * self.alpha).astype(np.uint8) 121 | if debug: 122 | cv2.imshow('max=%f, sum=%f' % (np.max(feature), np.sum(feature)), canvas) 123 | cv2.waitKey(0) 124 | return canvas 125 | 126 | 127 | class DrawCluster(object): 128 | def __init__(self, colors=[], thickness=2, line_type=cv2.LINE_8, shift=0, font_face=cv2.FONT_HERSHEY_SIMPLEX, font_scale=0.5, z=1): 129 | if colors: 130 | self.colors = [tuple(map(lambda c: c * 255, matplotlib.colors.colorConverter.to_rgb(c)[::-1])) for c in colors] 131 | else: 132 | self.colors = [tuple(map(lambda c: c * 255, matplotlib.colors.colorConverter.to_rgb(prop['color'])[::-1])) for prop in plt.rcParams['axes.prop_cycle']] 133 | self.thickness = thickness 134 | self.line_type = line_type 135 | self.shift = shift 136 | self.font_face = font_face 137 | self.font_scale = font_scale 138 | self.z = z 139 | 140 | def __call__(self, image, cluster, debug=False): 141 | for color, limb in zip(self.colors, cluster): 142 | (i1, y1, x1), (i2, y2, x2) = limb 143 | cv2.line(image, (x1, y1), (x2, y2), color, thickness=self.thickness) 144 | drawn = set() 145 | for (i, y, x) in limb: 146 | if i not in drawn: 147 | drawn.add(i) 148 | text = str(i) 149 | _color = tuple(map(lambda c: np.float(np.bitwise_not(np.uint8(c))), color)) 150 | cv2.putText(image, text, (x - self.z, y), self.font_face, self.font_scale, _color) 151 | cv2.putText(image, text, (x + self.z, y), self.font_face, self.font_scale, _color) 152 | cv2.putText(image, text, (x, y - self.z), self.font_face, self.font_scale, _color) 153 | cv2.putText(image, text, (x, y + self.z), self.font_face, self.font_scale, _color) 154 | cv2.putText(image, text, (x, y), self.font_face, self.font_scale, color) 155 | if debug: 156 | cv2.imshow('', image) 157 | cv2.waitKey(0) 158 | return image 159 | 160 | 161 | class Graph(object): 162 | def __init__(self, config, state_dict, cmap=None): 163 | self.dot = graphviz.Digraph(node_attr=dict(config.items('digraph_node_attr')), graph_attr=dict(config.items('digraph_graph_attr'))) 164 | self.dot.format = config.get('graph', 'format') 165 | self.state_dict = state_dict 166 | self.var_name = {t._cdata: k for k, t in state_dict.items()} 167 | self.seen = set() 168 | self.index = 0 169 | self.drawn = set() 170 | self.cm = matplotlib.cm.get_cmap(cmap) 171 | self.metric = eval(config.get('graph', 'metric')) 172 | metrics = [self.metric(t) for t in state_dict.values()] 173 | self.minmax = [min(metrics), max(metrics)] 174 | 175 | def __call__(self, node): 176 | if node not in self.seen: 177 | self.traverse_next(node) 178 | self.traverse_tensor(node) 
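# Once next_functions and the attached tensors of this autograd node have been drawn, the node is recorded in self.seen so shared subgraphs are traversed only once, and self.index, which is stamped into the node labels, is advanced.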
179 | self.seen.add(node) 180 | self.index += 1 181 | 182 | def traverse_next(self, node): 183 | if hasattr(node, 'next_functions'): 184 | for n, _ in node.next_functions: 185 | if n is not None: 186 | self.__call__(n) 187 | self._draw_node_edge(node, n) 188 | self._draw_node(node) 189 | 190 | def traverse_tensor(self, node): 191 | tensors = [t for name, t in inspect.getmembers(node) if torch.is_tensor(t)] 192 | if hasattr(node, 'saved_tensors'): 193 | tensors += node.saved_tensors 194 | for tensor in tensors: 195 | name = self.var_name[tensor._cdata] 196 | self.drawn.add(name) 197 | self._draw_tensor(node, tensor) 198 | 199 | def _draw_node(self, node): 200 | if hasattr(node, 'variable'): 201 | tensor = node.variable.data 202 | name = self.var_name[tensor._cdata] 203 | label = '\n'.join(map(str, [ 204 | '%d: %s' % (self.index, name), 205 | list(tensor.size()), 206 | humanize.naturalsize(tensor.numpy().nbytes), 207 | ])) 208 | fillcolor, fontcolor = self._tensor_color(tensor) 209 | self.dot.node(str(id(node)), label, shape='note', fillcolor=fillcolor, fontcolor=fontcolor) 210 | self.drawn.add(name) 211 | else: 212 | self.dot.node(str(id(node)), '%d: %s' % (self.index, type(node).__name__), fillcolor='white') 213 | 214 | def _draw_node_edge(self, node, n): 215 | if hasattr(n, 'variable'): 216 | self.dot.edge(str(id(n)), str(id(node)), arrowhead='none', arrowtail='none') 217 | else: 218 | self.dot.edge(str(id(n)), str(id(node))) 219 | 220 | def _draw_tensor(self, node, tensor): 221 | name = self.var_name[tensor._cdata] 222 | label = '\n'.join(map(str, [ 223 | name, 224 | list(tensor.size()), 225 | humanize.naturalsize(tensor.numpy().nbytes), 226 | ])) 227 | fillcolor, fontcolor = self._tensor_color(tensor) 228 | self.dot.node(name, label, style='filled, rounded', fillcolor=fillcolor, fontcolor=fontcolor) 229 | self.dot.edge(name, str(id(node)), style='dashed', arrowhead='none', arrowtail='none') 230 | 231 | def _tensor_color(self, tensor): 232 | level = self._norm(self.metric(tensor)) 233 | fillcolor = self.cm(np.int(level * self.cm.N)) 234 | fontcolor = self.cm(self.cm.N if level < 0.5 else 0) 235 | return matplotlib.colors.to_hex(fillcolor), matplotlib.colors.to_hex(fontcolor) 236 | 237 | def _norm(self, metric): 238 | min, max = self.minmax 239 | assert min <= metric <= max, (metric, self.minmax) 240 | if min < max: 241 | return (metric - min) / (max - min) 242 | else: 243 | return metric 244 | --------------------------------------------------------------------------------
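Usage sketch (illustrative, not part of the repository): the snippet below exercises transform.augmentation.Rotator on synthetic data to show how the enlarged rotation canvas and the (y, x) keypoint coordinates stay aligned. It assumes the repository and its dependencies are importable (importing transform pulls in utils and the pyopenpose extension); the image size, angle, and sample points are arbitrary example values.

import numpy as np

from transform.augmentation import Rotator

height, width = 240, 320
image = np.zeros([height, width, 3], np.uint8)

# Rotate about the image center by 30 degrees; the constructor takes (y, x, height, width, angle).
rotator = Rotator(height / 2, width / 2, height, width, 30)
rotated = rotator(image, fill=0)  # the output canvas is enlarged so the rotated content is not clipped
print(rotated.shape)

# Keypoints are (y, x) rows; rotate_points maps them with the same affine matrix as the image.
points = np.array([[60.0, 80.0], [120.0, 160.0]], np.float32)
print(rotator.rotate_points(points))

The same Rotator instance can be applied to the image, its mask, and every keypoint or box-corner array of a sample, which is how random_rotate uses it, so the labels remain consistent with the augmented pixels.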