├── .gitignore
├── LICENSE.md
├── README.md
├── cache.py
├── cache
│   ├── __init__.py
│   ├── coco.py
│   └── coco.tsv
├── config.ini
├── config
│   ├── convert_caffe_torch
│   │   └── original_person18_19.tsv
│   ├── convert_tf_torch
│   │   └── model.dnn.inception4.Inception4_down3_4
│   │       ├── Unet1.tsv
│   │       └── Unet2.tsv
│   ├── dataset
│   │   ├── coco.tsv
│   │   ├── coco
│   │   │   └── cache.coco.cache
│   │   ├── hand20.tsv
│   │   ├── hand20
│   │   │   └── cache.hand_nyu.cache
│   │   ├── hand21.tsv
│   │   ├── hand21
│   │   │   └── cache.hand_nyu.cache
│   │   ├── hand_nyu.tsv
│   │   ├── hand_nyu
│   │   │   └── cache.hand_nyu.cache
│   │   ├── mpii.tsv
│   │   ├── mpii.txt
│   │   ├── mpii
│   │   │   └── cache.mpii.cache
│   │   ├── person13_12.tsv
│   │   ├── person13_12.txt
│   │   ├── person13_12
│   │   │   ├── cache.coco.cache
│   │   │   └── cache.mpii.cache
│   │   ├── person14_13.tsv
│   │   ├── person14_13.txt
│   │   ├── person14_13
│   │   │   └── cache.coco.cache
│   │   ├── person18.tsv
│   │   ├── person18.txt
│   │   ├── person18
│   │   │   └── cache.coco.cache
│   │   ├── person18_19.tsv
│   │   ├── person18_19.txt
│   │   └── person18_19
│   │       └── cache.coco.cache
│   ├── inception_unet.ini
│   ├── original_person18_19.ini
│   └── summary
│       └── histogram.txt
├── convert_caffe_torch.py
├── convert_onnx_caffe2.py
├── convert_tf_torch.py
├── convert_torch_onnx.py
├── demo_data.py
├── demo_keypoints.py
├── demo_label.py
├── donate_alipay.jpg
├── donate_mm.jpg
├── estimate.py
├── logging.yml
├── model
│   ├── __init__.py
│   ├── dnn
│   │   ├── __init__.py
│   │   ├── inception4.py
│   │   ├── mobilenet.py
│   │   ├── mobilenet2.py
│   │   ├── resnet.py
│   │   └── vgg.py
│   └── stages
│       ├── __init__.py
│       ├── openpose.py
│       └── unet.py
├── quick_start.sh
├── receptive_field_analyzer.py
├── requirements.txt
├── train.py
├── transform
│   ├── __init__.py
│   ├── augmentation.py
│   ├── image.py
│   └── resize
│       ├── __init__.py
│       ├── image.py
│       └── label.py
└── utils
    ├── __init__.py
    ├── cache.py
    ├── data.py
    ├── train.py
    └── visualize.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .project
2 | .pydevproject
3 | .settings/
4 | .idea/
5 | .cache/
6 |
7 | # Byte-compiled / optimized / DLL files
8 | __pycache__/
9 | *.py[cod]
10 | *$py.class
11 |
12 | # C extensions
13 | *.so
14 |
15 | # Distribution / packaging
16 | .Python
17 | build/
18 | develop-eggs/
19 | dist/
20 | downloads/
21 | eggs/
22 | .eggs/
23 | lib/
24 | lib64/
25 | parts/
26 | sdist/
27 | var/
28 | wheels/
29 | *.egg-info/
30 | .installed.cfg
31 | *.egg
32 | MANIFEST
33 |
34 | # PyInstaller
35 | # Usually these files are written by a python script from a template
36 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
37 | *.manifest
38 | *.spec
39 |
40 | # Installer logs
41 | pip-log.txt
42 | pip-delete-this-directory.txt
43 |
44 | # Unit test / coverage reports
45 | htmlcov/
46 | .tox/
47 | .coverage
48 | .coverage.*
49 | .cache
50 | nosetests.xml
51 | coverage.xml
52 | *.cover
53 | .hypothesis/
54 | .pytest_cache/
55 |
56 | # Translations
57 | *.mo
58 | *.pot
59 |
60 | # Django stuff:
61 | *.log
62 | local_settings.py
63 | db.sqlite3
64 |
65 | # Flask stuff:
66 | instance/
67 | .webassets-cache
68 |
69 | # Scrapy stuff:
70 | .scrapy
71 |
72 | # Sphinx documentation
73 | docs/_build/
74 |
75 | # PyBuilder
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # pyenv
82 | .python-version
83 |
84 | # celery beat schedule file
85 | celerybeat-schedule
86 |
87 | # SageMath parsed files
88 | *.sage.py
89 |
90 | # Environments
91 | .env
92 | .venv
93 | env/
94 | venv/
95 | ENV/
96 | env.bak/
97 | venv.bak/
98 |
99 | # Spyder project settings
100 | .spyderproject
101 | .spyproject
102 |
103 | # Rope project settings
104 | .ropeproject
105 |
106 | # mkdocs documentation
107 | /site
108 |
109 | # mypy
110 | .mypy_cache/
111 |
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | GNU LESSER GENERAL PUBLIC LICENSE
2 | Version 3, 29 June 2007
3 |
4 |  Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
5 | Everyone is permitted to copy and distribute verbatim copies
6 | of this license document, but changing it is not allowed.
7 |
8 |
9 | This version of the GNU Lesser General Public License incorporates
10 | the terms and conditions of version 3 of the GNU General Public
11 | License, supplemented by the additional permissions listed below.
12 |
13 | 0. Additional Definitions.
14 |
15 | As used herein, "this License" refers to version 3 of the GNU Lesser
16 | General Public License, and the "GNU GPL" refers to version 3 of the GNU
17 | General Public License.
18 |
19 | "The Library" refers to a covered work governed by this License,
20 | other than an Application or a Combined Work as defined below.
21 |
22 | An "Application" is any work that makes use of an interface provided
23 | by the Library, but which is not otherwise based on the Library.
24 | Defining a subclass of a class defined by the Library is deemed a mode
25 | of using an interface provided by the Library.
26 |
27 | A "Combined Work" is a work produced by combining or linking an
28 | Application with the Library. The particular version of the Library
29 | with which the Combined Work was made is also called the "Linked
30 | Version".
31 |
32 | The "Minimal Corresponding Source" for a Combined Work means the
33 | Corresponding Source for the Combined Work, excluding any source code
34 | for portions of the Combined Work that, considered in isolation, are
35 | based on the Application, and not on the Linked Version.
36 |
37 | The "Corresponding Application Code" for a Combined Work means the
38 | object code and/or source code for the Application, including any data
39 | and utility programs needed for reproducing the Combined Work from the
40 | Application, but excluding the System Libraries of the Combined Work.
41 |
42 | 1. Exception to Section 3 of the GNU GPL.
43 |
44 | You may convey a covered work under sections 3 and 4 of this License
45 | without being bound by section 3 of the GNU GPL.
46 |
47 | 2. Conveying Modified Versions.
48 |
49 | If you modify a copy of the Library, and, in your modifications, a
50 | facility refers to a function or data to be supplied by an Application
51 | that uses the facility (other than as an argument passed when the
52 | facility is invoked), then you may convey a copy of the modified
53 | version:
54 |
55 | a) under this License, provided that you make a good faith effort to
56 | ensure that, in the event an Application does not supply the
57 | function or data, the facility still operates, and performs
58 | whatever part of its purpose remains meaningful, or
59 |
60 | b) under the GNU GPL, with none of the additional permissions of
61 | this License applicable to that copy.
62 |
63 | 3. Object Code Incorporating Material from Library Header Files.
64 |
65 | The object code form of an Application may incorporate material from
66 | a header file that is part of the Library. You may convey such object
67 | code under terms of your choice, provided that, if the incorporated
68 | material is not limited to numerical parameters, data structure
69 | layouts and accessors, or small macros, inline functions and templates
70 | (ten or fewer lines in length), you do both of the following:
71 |
72 | a) Give prominent notice with each copy of the object code that the
73 | Library is used in it and that the Library and its use are
74 | covered by this License.
75 |
76 | b) Accompany the object code with a copy of the GNU GPL and this license
77 | document.
78 |
79 | 4. Combined Works.
80 |
81 | You may convey a Combined Work under terms of your choice that,
82 | taken together, effectively do not restrict modification of the
83 | portions of the Library contained in the Combined Work and reverse
84 | engineering for debugging such modifications, if you also do each of
85 | the following:
86 |
87 | a) Give prominent notice with each copy of the Combined Work that
88 | the Library is used in it and that the Library and its use are
89 | covered by this License.
90 |
91 | b) Accompany the Combined Work with a copy of the GNU GPL and this license
92 | document.
93 |
94 | c) For a Combined Work that displays copyright notices during
95 | execution, include the copyright notice for the Library among
96 | these notices, as well as a reference directing the user to the
97 | copies of the GNU GPL and this license document.
98 |
99 | d) Do one of the following:
100 |
101 | 0) Convey the Minimal Corresponding Source under the terms of this
102 | License, and the Corresponding Application Code in a form
103 | suitable for, and under terms that permit, the user to
104 | recombine or relink the Application with a modified version of
105 | the Linked Version to produce a modified Combined Work, in the
106 | manner specified by section 6 of the GNU GPL for conveying
107 | Corresponding Source.
108 |
109 | 1) Use a suitable shared library mechanism for linking with the
110 | Library. A suitable mechanism is one that (a) uses at run time
111 | a copy of the Library already present on the user's computer
112 | system, and (b) will operate properly with a modified version
113 | of the Library that is interface-compatible with the Linked
114 | Version.
115 |
116 | e) Provide Installation Information, but only if you would otherwise
117 | be required to provide such information under section 6 of the
118 | GNU GPL, and only to the extent that such information is
119 | necessary to install and execute a modified version of the
120 | Combined Work produced by recombining or relinking the
121 | Application with a modified version of the Linked Version. (If
122 | you use option 4d0, the Installation Information must accompany
123 | the Minimal Corresponding Source and Corresponding Application
124 | Code. If you use option 4d1, you must provide the Installation
125 | Information in the manner specified by section 6 of the GNU GPL
126 | for conveying Corresponding Source.)
127 |
128 | 5. Combined Libraries.
129 |
130 | You may place library facilities that are a work based on the
131 | Library side by side in a single library together with other library
132 | facilities that are not Applications and are not covered by this
133 | License, and convey such a combined library under terms of your
134 | choice, if you do both of the following:
135 |
136 | a) Accompany the combined library with a copy of the same work based
137 | on the Library, uncombined with any other library facilities,
138 | conveyed under the terms of this License.
139 |
140 | b) Give prominent notice with the combined library that part of it
141 | is a work based on the Library, and explaining where to find the
142 | accompanying uncombined form of the same work.
143 |
144 | 6. Revised Versions of the GNU Lesser General Public License.
145 |
146 | The Free Software Foundation may publish revised and/or new versions
147 | of the GNU Lesser General Public License from time to time. Such new
148 | versions will be similar in spirit to the present version, but may
149 | differ in detail to address new problems or concerns.
150 |
151 | Each version is given a distinguishing version number. If the
152 | Library as you received it specifies that a certain numbered version
153 | of the GNU Lesser General Public License "or any later version"
154 | applies to it, you have the option of following the terms and
155 | conditions either of that published version or of any later version
156 | published by the Free Software Foundation. If the Library as you
157 | received it does not specify a version number of the GNU Lesser
158 | General Public License, you may choose any version of the GNU Lesser
159 | General Public License ever published by the Free Software Foundation.
160 |
161 | If the Library as you received it specifies that a proxy can decide
162 | whether future versions of the GNU Lesser General Public License shall
163 | apply, that proxy's public statement of acceptance of any version is
164 | permanent authorization for you to choose that version for the
165 | Library.
166 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # PyTorch implementation of [OpenPose](https://arxiv.org/abs/1611.08050)
2 |
3 | OpenPose is one of the most popular keypoint estimators. It uses two branches of feature maps (trained and refined over multiple stages) to estimate the positions of keypoints (as Gaussian heatmaps) and the relationships between keypoints (the part affinity fields), which are assembled into poses by a [postprocessing procedure](https://github.com/ruiminshen/pyopenpose).
4 | This project adopts [PyTorch](http://pytorch.org/) as the development framework to increase productivity, and utilizes [ONNX](https://github.com/onnx/onnx) to convert models into [Caffe 2](https://caffe2.ai/) for easier engineering deployment.
5 | If you benefit from this project, a donation would be appreciated (via [PayPal](https://www.paypal.me/minimumshen), [WeChat Pay](donate_mm.jpg) or [Alipay](donate_alipay.jpg)).
6 |
7 | ## Designs
8 |
9 | - Flexible configuration design.
10 | Program settings are configurable and can be modified via **config file overlapping** (the -c/--config option) or **command editing** (the -m/--modify option) on the command line; a loading sketch follows this list.
11 |
12 | - Monitoring via [TensorBoard](https://github.com/tensorflow/tensorboard).
13 | Both the loss values and debugging images (such as the IoU heatmap, the ground truth and the predicted bounding boxes) can be monitored.
14 |
15 | - Parallel model training design.
16 | Different models are saved into different directories so that they can be trained simultaneously.
17 |
18 | - Time-based output design.
19 | Running information (such as the model, the summaries produced by TensorBoard, and the evaluation results) is saved periodically at a predefined time interval.
20 |
21 | - Checkpoint management.
22 | The most recent checkpoint files (.pth) are preserved in the model directory and the older ones are deleted.
23 |
24 | - NaN debug.
25 | When a NaN loss is detected, the running environment (the data batch) and the model are exported to help analyze the cause.
26 |
27 | - Unified data cache design.
28 | Various datasets are converted into a unified data cache via a programmable configuration (a series of Python lambda expressions, which means some keypoints can be generated flexibly); a mapper-loading sketch is given at the end of this README.
29 | Some plugins, such as [MS COCO](http://cocodataset.org/), are already implemented.
30 |
31 | - Arbitrarily replaceable model plugin design.
32 | The deep neural network (both the feature extraction network and the stage networks) can be easily replaced via configuration settings (see the resolution sketch after this list).
33 | Multiple models are already provided, such as the original VGG-like network, [Inception v4](https://arxiv.org/abs/1602.07261), [MobileNet v2](https://arxiv.org/abs/1801.04381) and [U-Net](https://arxiv.org/abs/1505.04597).
34 |
35 | - Extendable data preprocess plugin design.
36 | The original images (of different sizes) and labels are processed via a sequence of operations to form a training batch (images of the same size, with padded bounding box lists).
37 | Multiple preprocess plugins are already implemented, such as
38 | augmentation operators that process images and labels simultaneously (such as random rotate and random flip),
39 | operators that resize both images and labels to a fixed size within a batch (such as random crop),
40 | and operators that augment images without labels (such as random blur, random saturation and random brightness).
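
A minimal sketch of the configuration-overlay idea, using only the standard `configparser` (the project's own helpers are `utils.load_config` and `utils.modify_config`, as used in `cache.py` and the other entry scripts):

```
import configparser

# Later files overwrite earlier ones, so a model-specific .ini only needs to
# state the options it changes (e.g. config/original_person18_19.ini).
config = configparser.ConfigParser()
config.read(['config.ini', 'config/original_person18_19.ini'])
print(config.get('model', 'dnn'))  # -> model.dnn.vgg.person18_19
```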
41 |
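Model and data plugins are referenced by dotted names in the configuration (for example `model.dnn.mobilenet2.MobileNet2` or `cache.coco.cache`). The sketch below shows how such a name can be resolved into a Python object, following the `importlib` pattern in `cache.py` (`utils.parse_attr` plays this role in scripts such as `convert_caffe_torch.py`):

```
import importlib

def resolve(name):
    # Split "package.module.attribute" into a module path and an attribute,
    # import the module, and return the attribute (a class or function).
    module, attr = name.rsplit('.', 1)
    return getattr(importlib.import_module(module), attr)

# e.g. resolve('model.dnn.mobilenet2.MobileNet2') returns the configured network class.
```
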
42 | ## Quick Start
43 |
44 | This project uses [Python 3](https://www.python.org/). To install the dependent libraries, make sure [pyopenpose](https://github.com/ruiminshen/pyopenpose) is installed, and type the following command in a terminal.
45 |
46 | ```
47 | sudo pip3 install -r requirements.txt
48 | ```
49 |
50 | `quick_start.sh` contains examples of how to perform detection and evaluation. Run this script.
51 | The COCO dataset is downloaded ([aria2](https://aria2.github.io/) is required) and cached, and the original pose model (18 parts and 19 limbs) is converted into PyTorch's format.
52 | If a webcam is present, the keypoint estimation demo will be shown.
53 | Finally, the training program is started.
54 |
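## Dataset mapper sketch

Each `config/dataset/<parts>/cache.<loader>.cache` file contains one Python lambda per target keypoint; every lambda receives the source dataset's `(num_parts, 3)` keypoint array and returns a single keypoint triple (see how `cache/coco.py` applies the resulting mapper). The loader below is only a hypothetical illustration of that file format; the real mappers are built via `utils.get_dataset_mappers`.

```
import numpy as np

def load_part_mapper(path):
    # Hypothetical helper: evaluate each non-empty line into a function that
    # maps the source keypoint array to one target part, then stack the results.
    with open(path, 'r') as f:
        fns = [eval(line) for line in f if line.strip()]
    return lambda parts: np.array([fn(parts) for fn in fns], dtype=np.float32)

# mapper = load_part_mapper('config/dataset/person14_13/cache.coco.cache')
# points = mapper(np.array(ann['keypoints']).reshape([-1, 3]))  # as in cache/coco.py
```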
--------------------------------------------------------------------------------
/cache.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see <http://www.gnu.org/licenses/>.
16 | """
17 |
18 | import os
19 | import argparse
20 | import configparser
21 | import logging
22 | import logging.config
23 | import importlib
24 | import pickle
25 | import random
26 | import shutil
27 | import yaml
28 |
29 | import utils
30 |
31 |
32 | def main():
33 | args = make_args()
34 | config = configparser.ConfigParser()
35 | utils.load_config(config, args.config)
36 | for cmd in args.modify:
37 | utils.modify_config(config, cmd)
38 | with open(os.path.expanduser(os.path.expandvars(args.logging)), 'r') as f:
39 | logging.config.dictConfig(yaml.load(f))
40 | cache_dir = utils.get_cache_dir(config)
41 | os.makedirs(cache_dir, exist_ok=True)
42 | mappers, _ = utils.get_dataset_mappers(config)
43 | for phase in args.phase:
44 | path = os.path.join(cache_dir, phase) + '.pkl'
45 | logging.info('save cache file: ' + path)
46 | data = []
47 | for dataset in mappers:
48 | logging.info('load %s dataset' % dataset)
49 | module, func = dataset.rsplit('.', 1)
50 | module = importlib.import_module(module)
51 | func = getattr(module, func)
52 | data += func(config, path, mappers[dataset])
53 | if config.getboolean('cache', 'shuffle'):
54 | random.shuffle(data)
55 | with open(path, 'wb') as f:
56 | pickle.dump(data, f)
57 | logging.info('%s data are saved into %s' % (str(args.phase), cache_dir))
58 |
59 |
60 | def make_args():
61 | parser = argparse.ArgumentParser()
62 | parser.add_argument('-c', '--config', nargs='+', default=['config.ini'], help='config file')
63 | parser.add_argument('-m', '--modify', nargs='+', default=[], help='modify config')
64 | parser.add_argument('-p', '--phase', nargs='+', default=['train', 'val', 'test'])
65 | parser.add_argument('--logging', default='logging.yml', help='logging config')
66 | return parser.parse_args()
67 |
68 |
69 | if __name__ == '__main__':
70 | main()
71 |
--------------------------------------------------------------------------------
/cache/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruiminshen/openpose-pytorch/f850084194ddccc6d401d5b11f61facc20ec2b75/cache/__init__.py
--------------------------------------------------------------------------------
/cache/coco.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see <http://www.gnu.org/licenses/>.
16 | """
17 |
18 | import os
19 | import logging
20 | import configparser
21 |
22 | import numpy as np
23 | import pandas as pd
24 | import tqdm
25 | import pycocotools.coco
26 | import pycocotools.mask
27 | from PIL import Image, ImageDraw
28 |
29 | import utils
30 | import utils.cache
31 |
32 |
33 | def draw_mask(segmentation, canvas, draw):
34 | pixels = canvas.load()
35 | if isinstance(segmentation, list):
36 | for polygon in segmentation:
37 | draw.polygon(polygon, fill=0)
38 | else:
39 | if isinstance(segmentation['counts'], list):
40 | rle = pycocotools.mask.frPyObjects([segmentation], canvas.size[1], canvas.size[0])
41 | else:
42 | rle = [segmentation]
43 | m = np.squeeze(pycocotools.mask.decode(rle))
44 | assert m.shape[:2] == canvas.size[::-1]
45 | for y, row in enumerate(m):
46 | for x, v in enumerate(row):
47 | if v:
48 | pixels[x, y] = 0
49 |
50 |
51 | def cache(config, path, mapper):
52 | name = __name__.split('.')[-1]
53 | cachedir = os.path.dirname(path)
54 | phase = os.path.splitext(os.path.basename(path))[0]
55 | phasedir = os.path.join(cachedir, phase)
56 | os.makedirs(phasedir, exist_ok=True)
57 | mask_ext = config.get('cache', 'mask_ext')
58 | data = []
59 | for i, row in pd.read_csv(os.path.splitext(__file__)[0] + '.tsv', sep='\t').iterrows():
60 | logging.info('loading data %d (%s)' % (i, ', '.join([k + '=' + str(v) for k, v in row.items()])))
61 | root = os.path.expanduser(os.path.expandvars(row['root']))
62 | year = str(row['year'])
63 | suffix = phase + year
64 | path = os.path.join(root, 'annotations', 'person_keypoints_%s.json' % suffix)
65 | if not os.path.exists(path):
66 | logging.warning(path + ' not exists')
67 | continue
68 | coco_kp = pycocotools.coco.COCO(path)
69 | skeleton = np.array(coco_kp.loadCats(1)[0]['skeleton']) - 1
70 | np.savetxt(os.path.join(os.path.dirname(cachedir), name + '.tsv'), skeleton, fmt='%d', delimiter='\t')
71 | imgIds = coco_kp.getImgIds()
72 | folder = os.path.join(root, suffix)
73 | imgs = coco_kp.loadImgs(imgIds)
74 | _imgs = list(filter(lambda img: os.path.exists(os.path.join(folder, img['file_name'])), imgs))
75 | if len(imgs) > len(_imgs):
76 | logging.warning('%d of %d images not exists' % (len(imgs) - len(_imgs), len(imgs)))
77 | for img in tqdm.tqdm(_imgs):
78 | # image
79 | path = os.path.join(folder, img['file_name'])
80 | width, height = img['width'], img['height']
81 | try:
82 | if config.getboolean('cache', 'verify'):
83 | if not np.all(np.equal(utils.image_size(path), [width, height])):
84 | logging.error('failed to verify shape of image ' + path)
85 | continue
86 | except configparser.NoOptionError:
87 | pass
88 | # keypoints
89 | annIds = coco_kp.getAnnIds(imgIds=img['id'], iscrowd=None)
90 | anns = coco_kp.loadAnns(annIds)
91 | keypoints = []
92 | bbox = []
93 | keypath = os.path.join(phasedir, __name__.split('.')[-1] + year, os.path.relpath(os.path.splitext(path)[0], root))
94 | os.makedirs(os.path.dirname(keypath), exist_ok=True)
95 | maskpath = keypath + '.mask' + mask_ext
96 | with Image.new('L', (width, height), 255) as canvas:
97 | draw = ImageDraw.Draw(canvas)
98 | for ann in anns:
99 | points = mapper(np.array(ann['keypoints']).reshape([-1, 3]))
100 | if np.any(points[:, 2] > 0):
101 | keypoints.append(points)
102 | bbox.append(ann['bbox'])
103 | else:
104 | draw_mask(ann['segmentation'], canvas, draw)
105 | if len(keypoints) <= 0:
106 | continue
107 | canvas.save(os.path.join(cachedir, maskpath))
108 | keypoints = np.array(keypoints, dtype=np.float32)
109 | keypoints = keypoints[:, :, [1, 0, 2]]
110 | bbox = np.array(bbox, dtype=np.float32)
111 | yx_min = bbox[:, 1::-1]
112 | size = bbox[:, -1:1:-1]
113 | yx_max = yx_min + size
114 | try:
115 | if config.getboolean('cache', 'dump'):
116 | np.save(keypath + '.keypoints.npy', keypoints)
117 | np.save(keypath + '.yx_min.npy', yx_min)
118 | np.save(keypath + '.yx_max.npy', yx_max)
119 | except configparser.NoOptionError:
120 | pass
121 | data.append(dict(
122 | path=path, keypath=keypath,
123 | keypoints=keypoints,
124 | yx_min=yx_min, yx_max=yx_max,
125 | ))
126 | logging.warning('%d of %d images are saved' % (len(data), len(_imgs)))
127 | return data
128 |
--------------------------------------------------------------------------------
/cache/coco.tsv:
--------------------------------------------------------------------------------
1 | root year
2 | ~/data/coco 2014
3 | ~/data/coco 2017
4 |
--------------------------------------------------------------------------------
/config.ini:
--------------------------------------------------------------------------------
1 | [config]
2 | root = ~/model/openpose-pytorch
3 |
4 | [image]
5 | # 368
6 | # 344
7 | # 320
8 | size = 320 320
9 |
10 | [cache]
11 | name = cache
12 | ; config/dataset/person18_19
13 | ; config/dataset/person14_13
14 | dataset = config/dataset/person14_13
15 | shuffle = 1
16 | mask_ext = .jpg
17 |
18 | [model]
19 | name = model
20 | ; model.dnn.vgg.person18_19
21 | ; model.dnn.resnet.resnet18
22 | ; model.dnn.inception3.Inception3
23 | ; model.dnn.inception4.Inception4
24 | ; model.dnn.inception4.Inception4_down3_4
25 | ; model.dnn.mobilenet.MobileNet
26 | ; model.dnn.mobilenet2.MobileNet2
27 | ; model.dnn.densenet.densenet121
28 | dnn = model.dnn.mobilenet2.MobileNet2
29 | # model.stages.openpose.Stage0 model.stages.openpose.Stage model.stages.openpose.Stage model.stages.openpose.Stage model.stages.openpose.Stage model.stages.openpose.Stage
30 | # model.stages.unet.Unet1Sqz3 model.stages.unet.Unet1Sqz3_a
31 | # model.stages.unet.Unet2Sqz3 model.stages.unet.Unet2Sqz3
32 | stages = model.stages.unet.Unet1Sqz3 model.stages.unet.Unet1Sqz3_a
33 | pretrained = 0
34 |
35 | [batch_norm]
36 | enable = 0
37 | gamma = 1
38 | beta = 1
39 |
40 | [inception4]
41 | pretrained = imagenet
42 |
43 | [data]
44 | workers = 3
45 | sizes = 320,320
46 | maintain = 10
47 | shuffle = 0
48 | # fixed rescale
49 | resize = fixed
50 |
51 | [transform]
52 | ; transform.augmentation.RandomRotate transform.augmentation.RandomFlipHorizontally
53 | augmentation = transform.augmentation.RandomRotate transform.augmentation.RandomFlipHorizontally
54 | resize_train = transform.resize.label.RandomCrop
55 | resize_eval = transform.resize.label.Resize
56 | resize_test = transform.resize.image.Resize
57 | ; transform.image.RandomBlur transform.image.BGR2HSV transform.image.RandomHue transform.image.RandomSaturation transform.image.RandomBrightness transform.image.HSV2RGB transform.image.RandomGamma
58 | image_train = transform.image.BGR2RGB
59 | image_test = transform.image.BGR2RGB
60 | ; torchvision.transforms.ToTensor transform.image.Normalize
61 | tensor = torchvision.transforms.ToTensor transform.image.Normalize
62 | normalize = 0.5 1
63 |
64 | [augmentation]
65 | random_rotate = -40 40
66 | random_flip_horizontally = 0.5
67 | random_crop = 1.5 2
68 | random_blur = 5 5
69 | random_hue = 0 25
70 | random_saturation = 0.5 1.5
71 | random_brightness = 0.5 1.5
72 | random_gamma = 0.9 1.5
73 |
74 | [label]
75 | sigma_parts = 7
76 | sigma_limbs = 1
77 |
78 | [train]
79 | ; lambda params, lr: torch.optim.SGD(params, lr, momentum=2)
80 | ; lambda params, lr: torch.optim.Adam(params, lr, betas=(0.9, 0.999), eps=1e-8)
81 | ; lambda params, lr: torch.optim.RMSprop(params, lr, alpha=0.99, eps=1e-8)
82 | optimizer = lambda params, lr: torch.optim.Adam(params, lr, betas=(0.9, 0.999), eps=1e-8)
83 | ; lambda optimizer: torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
84 | ; lambda optimizer: torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[60, 90], gamma=0.1)
85 | scheduler = lambda optimizer: torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[60, 90], gamma=0.1)
86 | phase = train val
87 | cross_entropy = 1
88 | clip_ = 5
89 |
90 | [save]
91 | secs = 600
92 | keep = 5
93 |
94 | [draw_points]
95 | colors = r w
96 |
97 | [summary]
98 | scalar = 10
99 | image = 60
100 | histogram_ = 60
101 |
102 | [summary_scalar]
103 | loss_hparam = 0
104 |
105 | [summary_image]
106 | limit = 2
107 | data_keypoints = 1
108 | data_parts = 1
109 | data_limbs = 1
110 | estimate = 1
111 | output = parts limbs
112 | stage = -1
113 |
114 | [summary_histogram]
115 | parameters = config/summary/histogram.txt
116 |
117 | [hparam]
118 | parts = 1
119 | limbs = 1
120 |
121 | [estimate]
122 | interpolation = cubic
123 |
124 | [nms]
125 | threshold = 0.05
126 |
127 | [integration]
128 | step = 5
129 | step_limits = 5 25
130 | min_score = 0.05
131 | min_count = 9
132 |
133 | [cluster]
134 | min_score = 0.4
135 | min_count = 3
136 |
137 | [eval]
138 | phase = test
139 | secs = 12 * 60 * 60
140 | first = 0
141 | iou = 0.5
142 | db = eval.json
143 | mapper = config/eval.py
144 | debug = 0
145 | sort = timestamp
146 | metric07 = 1
147 |
148 | [graph]
149 | metric = lambda t: np.mean(utils.dense(t))
150 | format = svg
151 |
152 | [digraph_graph_attr]
153 | size = 12, 12
154 |
155 | [digraph_node_attr]
156 | style = filled
157 | shape = box
158 | align = left
159 | fontsize = 12
160 | ranksep = 0.1
161 | height = 0.2
162 |
--------------------------------------------------------------------------------
/config/convert_caffe_torch/original_person18_19.tsv:
--------------------------------------------------------------------------------
1 | dnn.features.0.weight conv1_1 lambda blobs: blobs[0]
2 | dnn.features.0.bias conv1_1 lambda blobs: blobs[1]
3 |
4 | dnn.features.2.weight conv1_2 lambda blobs: blobs[0]
5 | dnn.features.2.bias conv1_2 lambda blobs: blobs[1]
6 |
7 | dnn.features.5.weight conv2_1 lambda blobs: blobs[0]
8 | dnn.features.5.bias conv2_1 lambda blobs: blobs[1]
9 |
10 | dnn.features.7.weight conv2_2 lambda blobs: blobs[0]
11 | dnn.features.7.bias conv2_2 lambda blobs: blobs[1]
12 |
13 | dnn.features.10.weight conv3_1 lambda blobs: blobs[0]
14 | dnn.features.10.bias conv3_1 lambda blobs: blobs[1]
15 |
16 | dnn.features.12.weight conv3_2 lambda blobs: blobs[0]
17 | dnn.features.12.bias conv3_2 lambda blobs: blobs[1]
18 |
19 | dnn.features.14.weight conv3_3 lambda blobs: blobs[0]
20 | dnn.features.14.bias conv3_3 lambda blobs: blobs[1]
21 |
22 | dnn.features.16.weight conv3_4 lambda blobs: blobs[0]
23 | dnn.features.16.bias conv3_4 lambda blobs: blobs[1]
24 |
25 | dnn.features.19.weight conv4_1 lambda blobs: blobs[0]
26 | dnn.features.19.bias conv4_1 lambda blobs: blobs[1]
27 |
28 | dnn.features.21.weight conv4_2 lambda blobs: blobs[0]
29 | dnn.features.21.bias conv4_2 lambda blobs: blobs[1]
30 |
31 | dnn.features.23.weight conv4_3_CPM lambda blobs: blobs[0]
32 | dnn.features.23.bias conv4_3_CPM lambda blobs: blobs[1]
33 |
34 | dnn.features.25.weight conv4_4_CPM lambda blobs: blobs[0]
35 | dnn.features.25.bias conv4_4_CPM lambda blobs: blobs[1]
36 |
37 | stages.0.limbs.0.conv.weight conv5_1_CPM_L1 lambda blobs: blobs[0]
38 | stages.0.limbs.0.conv.bias conv5_1_CPM_L1 lambda blobs: blobs[1]
39 |
40 | stages.0.limbs.1.conv.weight conv5_2_CPM_L1 lambda blobs: blobs[0]
41 | stages.0.limbs.1.conv.bias conv5_2_CPM_L1 lambda blobs: blobs[1]
42 |
43 | stages.0.limbs.2.conv.weight conv5_3_CPM_L1 lambda blobs: blobs[0]
44 | stages.0.limbs.2.conv.bias conv5_3_CPM_L1 lambda blobs: blobs[1]
45 |
46 | stages.0.limbs.3.conv.weight conv5_4_CPM_L1 lambda blobs: blobs[0]
47 | stages.0.limbs.3.conv.bias conv5_4_CPM_L1 lambda blobs: blobs[1]
48 |
49 | stages.0.limbs.4.conv.weight conv5_5_CPM_L1 lambda blobs: blobs[0]
50 | stages.0.limbs.4.conv.bias conv5_5_CPM_L1 lambda blobs: blobs[1]
51 |
52 | stages.0.parts.0.conv.weight conv5_1_CPM_L2 lambda blobs: blobs[0]
53 | stages.0.parts.0.conv.bias conv5_1_CPM_L2 lambda blobs: blobs[1]
54 |
55 | stages.0.parts.1.conv.weight conv5_2_CPM_L2 lambda blobs: blobs[0]
56 | stages.0.parts.1.conv.bias conv5_2_CPM_L2 lambda blobs: blobs[1]
57 |
58 | stages.0.parts.2.conv.weight conv5_3_CPM_L2 lambda blobs: blobs[0]
59 | stages.0.parts.2.conv.bias conv5_3_CPM_L2 lambda blobs: blobs[1]
60 |
61 | stages.0.parts.3.conv.weight conv5_4_CPM_L2 lambda blobs: blobs[0]
62 | stages.0.parts.3.conv.bias conv5_4_CPM_L2 lambda blobs: blobs[1]
63 |
64 | stages.0.parts.4.conv.weight conv5_5_CPM_L2 lambda blobs: blobs[0]
65 | stages.0.parts.4.conv.bias conv5_5_CPM_L2 lambda blobs: blobs[1]
66 |
67 | stages.1.limbs.0.conv.weight Mconv1_stage2_L1 lambda blobs: blobs[0]
68 | stages.1.limbs.0.conv.bias Mconv1_stage2_L1 lambda blobs: blobs[1]
69 |
70 | stages.1.limbs.1.conv.weight Mconv2_stage2_L1 lambda blobs: blobs[0]
71 | stages.1.limbs.1.conv.bias Mconv2_stage2_L1 lambda blobs: blobs[1]
72 |
73 | stages.1.limbs.2.conv.weight Mconv3_stage2_L1 lambda blobs: blobs[0]
74 | stages.1.limbs.2.conv.bias Mconv3_stage2_L1 lambda blobs: blobs[1]
75 |
76 | stages.1.limbs.3.conv.weight Mconv4_stage2_L1 lambda blobs: blobs[0]
77 | stages.1.limbs.3.conv.bias Mconv4_stage2_L1 lambda blobs: blobs[1]
78 |
79 | stages.1.limbs.4.conv.weight Mconv5_stage2_L1 lambda blobs: blobs[0]
80 | stages.1.limbs.4.conv.bias Mconv5_stage2_L1 lambda blobs: blobs[1]
81 |
82 | stages.1.limbs.5.conv.weight Mconv6_stage2_L1 lambda blobs: blobs[0]
83 | stages.1.limbs.5.conv.bias Mconv6_stage2_L1 lambda blobs: blobs[1]
84 |
85 | stages.1.limbs.6.conv.weight Mconv7_stage2_L1 lambda blobs: blobs[0]
86 | stages.1.limbs.6.conv.bias Mconv7_stage2_L1 lambda blobs: blobs[1]
87 |
88 | stages.1.parts.0.conv.weight Mconv1_stage2_L2 lambda blobs: blobs[0]
89 | stages.1.parts.0.conv.bias Mconv1_stage2_L2 lambda blobs: blobs[1]
90 |
91 | stages.1.parts.1.conv.weight Mconv2_stage2_L2 lambda blobs: blobs[0]
92 | stages.1.parts.1.conv.bias Mconv2_stage2_L2 lambda blobs: blobs[1]
93 |
94 | stages.1.parts.2.conv.weight Mconv3_stage2_L2 lambda blobs: blobs[0]
95 | stages.1.parts.2.conv.bias Mconv3_stage2_L2 lambda blobs: blobs[1]
96 |
97 | stages.1.parts.3.conv.weight Mconv4_stage2_L2 lambda blobs: blobs[0]
98 | stages.1.parts.3.conv.bias Mconv4_stage2_L2 lambda blobs: blobs[1]
99 |
100 | stages.1.parts.4.conv.weight Mconv5_stage2_L2 lambda blobs: blobs[0]
101 | stages.1.parts.4.conv.bias Mconv5_stage2_L2 lambda blobs: blobs[1]
102 |
103 | stages.1.parts.5.conv.weight Mconv6_stage2_L2 lambda blobs: blobs[0]
104 | stages.1.parts.5.conv.bias Mconv6_stage2_L2 lambda blobs: blobs[1]
105 |
106 | stages.1.parts.6.conv.weight Mconv7_stage2_L2 lambda blobs: blobs[0]
107 | stages.1.parts.6.conv.bias Mconv7_stage2_L2 lambda blobs: blobs[1]
108 |
109 | stages.2.limbs.0.conv.weight Mconv1_stage3_L1 lambda blobs: blobs[0]
110 | stages.2.limbs.0.conv.bias Mconv1_stage3_L1 lambda blobs: blobs[1]
111 |
112 | stages.2.limbs.1.conv.weight Mconv2_stage3_L1 lambda blobs: blobs[0]
113 | stages.2.limbs.1.conv.bias Mconv2_stage3_L1 lambda blobs: blobs[1]
114 |
115 | stages.2.limbs.2.conv.weight Mconv3_stage3_L1 lambda blobs: blobs[0]
116 | stages.2.limbs.2.conv.bias Mconv3_stage3_L1 lambda blobs: blobs[1]
117 |
118 | stages.2.limbs.3.conv.weight Mconv4_stage3_L1 lambda blobs: blobs[0]
119 | stages.2.limbs.3.conv.bias Mconv4_stage3_L1 lambda blobs: blobs[1]
120 |
121 | stages.2.limbs.4.conv.weight Mconv5_stage3_L1 lambda blobs: blobs[0]
122 | stages.2.limbs.4.conv.bias Mconv5_stage3_L1 lambda blobs: blobs[1]
123 |
124 | stages.2.limbs.5.conv.weight Mconv6_stage3_L1 lambda blobs: blobs[0]
125 | stages.2.limbs.5.conv.bias Mconv6_stage3_L1 lambda blobs: blobs[1]
126 |
127 | stages.2.limbs.6.conv.weight Mconv7_stage3_L1 lambda blobs: blobs[0]
128 | stages.2.limbs.6.conv.bias Mconv7_stage3_L1 lambda blobs: blobs[1]
129 |
130 | stages.2.parts.0.conv.weight Mconv1_stage3_L2 lambda blobs: blobs[0]
131 | stages.2.parts.0.conv.bias Mconv1_stage3_L2 lambda blobs: blobs[1]
132 |
133 | stages.2.parts.1.conv.weight Mconv2_stage3_L2 lambda blobs: blobs[0]
134 | stages.2.parts.1.conv.bias Mconv2_stage3_L2 lambda blobs: blobs[1]
135 |
136 | stages.2.parts.2.conv.weight Mconv3_stage3_L2 lambda blobs: blobs[0]
137 | stages.2.parts.2.conv.bias Mconv3_stage3_L2 lambda blobs: blobs[1]
138 |
139 | stages.2.parts.3.conv.weight Mconv4_stage3_L2 lambda blobs: blobs[0]
140 | stages.2.parts.3.conv.bias Mconv4_stage3_L2 lambda blobs: blobs[1]
141 |
142 | stages.2.parts.4.conv.weight Mconv5_stage3_L2 lambda blobs: blobs[0]
143 | stages.2.parts.4.conv.bias Mconv5_stage3_L2 lambda blobs: blobs[1]
144 |
145 | stages.2.parts.5.conv.weight Mconv6_stage3_L2 lambda blobs: blobs[0]
146 | stages.2.parts.5.conv.bias Mconv6_stage3_L2 lambda blobs: blobs[1]
147 |
148 | stages.2.parts.6.conv.weight Mconv7_stage3_L2 lambda blobs: blobs[0]
149 | stages.2.parts.6.conv.bias Mconv7_stage3_L2 lambda blobs: blobs[1]
150 |
151 | stages.3.limbs.0.conv.weight Mconv1_stage4_L1 lambda blobs: blobs[0]
152 | stages.3.limbs.0.conv.bias Mconv1_stage4_L1 lambda blobs: blobs[1]
153 |
154 | stages.3.limbs.1.conv.weight Mconv2_stage4_L1 lambda blobs: blobs[0]
155 | stages.3.limbs.1.conv.bias Mconv2_stage4_L1 lambda blobs: blobs[1]
156 |
157 | stages.3.limbs.2.conv.weight Mconv3_stage4_L1 lambda blobs: blobs[0]
158 | stages.3.limbs.2.conv.bias Mconv3_stage4_L1 lambda blobs: blobs[1]
159 |
160 | stages.3.limbs.3.conv.weight Mconv4_stage4_L1 lambda blobs: blobs[0]
161 | stages.3.limbs.3.conv.bias Mconv4_stage4_L1 lambda blobs: blobs[1]
162 |
163 | stages.3.limbs.4.conv.weight Mconv5_stage4_L1 lambda blobs: blobs[0]
164 | stages.3.limbs.4.conv.bias Mconv5_stage4_L1 lambda blobs: blobs[1]
165 |
166 | stages.3.limbs.5.conv.weight Mconv6_stage4_L1 lambda blobs: blobs[0]
167 | stages.3.limbs.5.conv.bias Mconv6_stage4_L1 lambda blobs: blobs[1]
168 |
169 | stages.3.limbs.6.conv.weight Mconv7_stage4_L1 lambda blobs: blobs[0]
170 | stages.3.limbs.6.conv.bias Mconv7_stage4_L1 lambda blobs: blobs[1]
171 |
172 | stages.3.parts.0.conv.weight Mconv1_stage4_L2 lambda blobs: blobs[0]
173 | stages.3.parts.0.conv.bias Mconv1_stage4_L2 lambda blobs: blobs[1]
174 |
175 | stages.3.parts.1.conv.weight Mconv2_stage4_L2 lambda blobs: blobs[0]
176 | stages.3.parts.1.conv.bias Mconv2_stage4_L2 lambda blobs: blobs[1]
177 |
178 | stages.3.parts.2.conv.weight Mconv3_stage4_L2 lambda blobs: blobs[0]
179 | stages.3.parts.2.conv.bias Mconv3_stage4_L2 lambda blobs: blobs[1]
180 |
181 | stages.3.parts.3.conv.weight Mconv4_stage4_L2 lambda blobs: blobs[0]
182 | stages.3.parts.3.conv.bias Mconv4_stage4_L2 lambda blobs: blobs[1]
183 |
184 | stages.3.parts.4.conv.weight Mconv5_stage4_L2 lambda blobs: blobs[0]
185 | stages.3.parts.4.conv.bias Mconv5_stage4_L2 lambda blobs: blobs[1]
186 |
187 | stages.3.parts.5.conv.weight Mconv6_stage4_L2 lambda blobs: blobs[0]
188 | stages.3.parts.5.conv.bias Mconv6_stage4_L2 lambda blobs: blobs[1]
189 |
190 | stages.3.parts.6.conv.weight Mconv7_stage4_L2 lambda blobs: blobs[0]
191 | stages.3.parts.6.conv.bias Mconv7_stage4_L2 lambda blobs: blobs[1]
192 |
193 | stages.4.limbs.0.conv.weight Mconv1_stage5_L1 lambda blobs: blobs[0]
194 | stages.4.limbs.0.conv.bias Mconv1_stage5_L1 lambda blobs: blobs[1]
195 |
196 | stages.4.limbs.1.conv.weight Mconv2_stage5_L1 lambda blobs: blobs[0]
197 | stages.4.limbs.1.conv.bias Mconv2_stage5_L1 lambda blobs: blobs[1]
198 |
199 | stages.4.limbs.2.conv.weight Mconv3_stage5_L1 lambda blobs: blobs[0]
200 | stages.4.limbs.2.conv.bias Mconv3_stage5_L1 lambda blobs: blobs[1]
201 |
202 | stages.4.limbs.3.conv.weight Mconv4_stage5_L1 lambda blobs: blobs[0]
203 | stages.4.limbs.3.conv.bias Mconv4_stage5_L1 lambda blobs: blobs[1]
204 |
205 | stages.4.limbs.4.conv.weight Mconv5_stage5_L1 lambda blobs: blobs[0]
206 | stages.4.limbs.4.conv.bias Mconv5_stage5_L1 lambda blobs: blobs[1]
207 |
208 | stages.4.limbs.5.conv.weight Mconv6_stage5_L1 lambda blobs: blobs[0]
209 | stages.4.limbs.5.conv.bias Mconv6_stage5_L1 lambda blobs: blobs[1]
210 |
211 | stages.4.limbs.6.conv.weight Mconv7_stage5_L1 lambda blobs: blobs[0]
212 | stages.4.limbs.6.conv.bias Mconv7_stage5_L1 lambda blobs: blobs[1]
213 |
214 | stages.4.parts.0.conv.weight Mconv1_stage5_L2 lambda blobs: blobs[0]
215 | stages.4.parts.0.conv.bias Mconv1_stage5_L2 lambda blobs: blobs[1]
216 |
217 | stages.4.parts.1.conv.weight Mconv2_stage5_L2 lambda blobs: blobs[0]
218 | stages.4.parts.1.conv.bias Mconv2_stage5_L2 lambda blobs: blobs[1]
219 |
220 | stages.4.parts.2.conv.weight Mconv3_stage5_L2 lambda blobs: blobs[0]
221 | stages.4.parts.2.conv.bias Mconv3_stage5_L2 lambda blobs: blobs[1]
222 |
223 | stages.4.parts.3.conv.weight Mconv4_stage5_L2 lambda blobs: blobs[0]
224 | stages.4.parts.3.conv.bias Mconv4_stage5_L2 lambda blobs: blobs[1]
225 |
226 | stages.4.parts.4.conv.weight Mconv5_stage5_L2 lambda blobs: blobs[0]
227 | stages.4.parts.4.conv.bias Mconv5_stage5_L2 lambda blobs: blobs[1]
228 |
229 | stages.4.parts.5.conv.weight Mconv6_stage5_L2 lambda blobs: blobs[0]
230 | stages.4.parts.5.conv.bias Mconv6_stage5_L2 lambda blobs: blobs[1]
231 |
232 | stages.4.parts.6.conv.weight Mconv7_stage5_L2 lambda blobs: blobs[0]
233 | stages.4.parts.6.conv.bias Mconv7_stage5_L2 lambda blobs: blobs[1]
234 |
235 | stages.5.limbs.0.conv.weight Mconv1_stage6_L1 lambda blobs: blobs[0]
236 | stages.5.limbs.0.conv.bias Mconv1_stage6_L1 lambda blobs: blobs[1]
237 |
238 | stages.5.limbs.1.conv.weight Mconv2_stage6_L1 lambda blobs: blobs[0]
239 | stages.5.limbs.1.conv.bias Mconv2_stage6_L1 lambda blobs: blobs[1]
240 |
241 | stages.5.limbs.2.conv.weight Mconv3_stage6_L1 lambda blobs: blobs[0]
242 | stages.5.limbs.2.conv.bias Mconv3_stage6_L1 lambda blobs: blobs[1]
243 |
244 | stages.5.limbs.3.conv.weight Mconv4_stage6_L1 lambda blobs: blobs[0]
245 | stages.5.limbs.3.conv.bias Mconv4_stage6_L1 lambda blobs: blobs[1]
246 |
247 | stages.5.limbs.4.conv.weight Mconv5_stage6_L1 lambda blobs: blobs[0]
248 | stages.5.limbs.4.conv.bias Mconv5_stage6_L1 lambda blobs: blobs[1]
249 |
250 | stages.5.limbs.5.conv.weight Mconv6_stage6_L1 lambda blobs: blobs[0]
251 | stages.5.limbs.5.conv.bias Mconv6_stage6_L1 lambda blobs: blobs[1]
252 |
253 | stages.5.limbs.6.conv.weight Mconv7_stage6_L1 lambda blobs: blobs[0]
254 | stages.5.limbs.6.conv.bias Mconv7_stage6_L1 lambda blobs: blobs[1]
255 |
256 | stages.5.parts.0.conv.weight Mconv1_stage6_L2 lambda blobs: blobs[0]
257 | stages.5.parts.0.conv.bias Mconv1_stage6_L2 lambda blobs: blobs[1]
258 |
259 | stages.5.parts.1.conv.weight Mconv2_stage6_L2 lambda blobs: blobs[0]
260 | stages.5.parts.1.conv.bias Mconv2_stage6_L2 lambda blobs: blobs[1]
261 |
262 | stages.5.parts.2.conv.weight Mconv3_stage6_L2 lambda blobs: blobs[0]
263 | stages.5.parts.2.conv.bias Mconv3_stage6_L2 lambda blobs: blobs[1]
264 |
265 | stages.5.parts.3.conv.weight Mconv4_stage6_L2 lambda blobs: blobs[0]
266 | stages.5.parts.3.conv.bias Mconv4_stage6_L2 lambda blobs: blobs[1]
267 |
268 | stages.5.parts.4.conv.weight Mconv5_stage6_L2 lambda blobs: blobs[0]
269 | stages.5.parts.4.conv.bias Mconv5_stage6_L2 lambda blobs: blobs[1]
270 |
271 | stages.5.parts.5.conv.weight Mconv6_stage6_L2 lambda blobs: blobs[0]
272 | stages.5.parts.5.conv.bias Mconv6_stage6_L2 lambda blobs: blobs[1]
273 |
274 | stages.5.parts.6.conv.weight Mconv7_stage6_L2 lambda blobs: blobs[0]
275 | stages.5.parts.6.conv.bias Mconv7_stage6_L2 lambda blobs: blobs[1]
276 |
--------------------------------------------------------------------------------
/config/dataset/coco.tsv:
--------------------------------------------------------------------------------
1 | 0 1
2 | 0 2
3 | 1 3
4 | 2 4
5 | 0 5
6 | 0 6
7 | 5 7
8 | 6 8
9 | 7 9
10 | 8 10
11 | 0 11
12 | 0 12
13 | 11 13
14 | 12 14
15 | 13 15
16 | 14 16
17 |
--------------------------------------------------------------------------------
/config/dataset/coco/cache.coco.cache:
--------------------------------------------------------------------------------
1 | lambda parts: parts[0]
2 | lambda parts: parts[1]
3 | lambda parts: parts[2]
4 | lambda parts: parts[3]
5 | lambda parts: parts[4]
6 | lambda parts: parts[5]
7 | lambda parts: parts[6]
8 | lambda parts: parts[7]
9 | lambda parts: parts[8]
10 | lambda parts: parts[9]
11 | lambda parts: parts[10]
12 | lambda parts: parts[11]
13 | lambda parts: parts[12]
14 | lambda parts: parts[13]
15 | lambda parts: parts[14]
16 | lambda parts: parts[15]
17 | lambda parts: parts[16]
18 |
--------------------------------------------------------------------------------
/config/dataset/hand20.tsv:
--------------------------------------------------------------------------------
1 | 0 1
2 | 1 2
3 | 2 3
4 | 0 4
5 | 4 5
6 | 5 6
7 | 6 7
8 | 0 8
9 | 8 9
10 | 9 10
11 | 10 11
12 | 0 12
13 | 12 13
14 | 13 14
15 | 14 15
16 | 0 16
17 | 16 17
18 | 17 18
19 | 18 19
--------------------------------------------------------------------------------
/config/dataset/hand20/cache.hand_nyu.cache:
--------------------------------------------------------------------------------
1 | lambda parts: parts[29]
2 | lambda parts: parts[26]
3 | lambda parts: parts[25]
4 | lambda parts: parts[24]
5 | lambda parts: parts[22]
6 | lambda parts: parts[21]
7 | lambda parts: np.append((parts[20][:2] + parts[19][:2]) / 2, 1) if parts[20][2] > 0 and parts[19][2] > 0 else [0, 0, 0]
8 | lambda parts: parts[18]
9 | lambda parts: parts[16]
10 | lambda parts: parts[15]
11 | lambda parts: np.append((parts[14][:2] + parts[13][:2]) / 2, 1) if parts[14][2] > 0 and parts[13][2] > 0 else [0, 0, 0]
12 | lambda parts: parts[12]
13 | lambda parts: parts[10]
14 | lambda parts: parts[9]
15 | lambda parts: np.append((parts[8][:2] + parts[7][:2]) / 2, 1) if parts[8][2] > 0 and parts[7][2] > 0 else [0, 0, 0]
16 | lambda parts: parts[6]
17 | lambda parts: parts[4]
18 | lambda parts: parts[2]
19 | lambda parts: parts[1]
20 | lambda parts: parts[0]
21 |
--------------------------------------------------------------------------------
/config/dataset/hand21.tsv:
--------------------------------------------------------------------------------
1 | 0 1
2 | 1 2
3 | 2 3
4 | 3 4
5 | 0 5
6 | 5 6
7 | 6 7
8 | 7 8
9 | 0 9
10 | 9 10
11 | 10 11
12 | 11 12
13 | 0 13
14 | 13 14
15 | 14 15
16 | 15 16
17 | 0 17
18 | 17 18
19 | 18 19
20 | 19 20
--------------------------------------------------------------------------------
/config/dataset/hand21/cache.hand_nyu.cache:
--------------------------------------------------------------------------------
1 | lambda parts: parts[29]
2 | lambda parts: parts[28]
3 | lambda parts: parts[26]
4 | lambda parts: parts[25]
5 | lambda parts: parts[24]
6 | lambda parts: parts[22]
7 | lambda parts: parts[21]
8 | lambda parts: np.append((parts[20][:2] + parts[19][:2]) / 2, 1) if parts[20][2] > 0 and parts[19][2] > 0 else [0, 0, 0]
9 | lambda parts: parts[18]
10 | lambda parts: parts[16]
11 | lambda parts: parts[15]
12 | lambda parts: np.append((parts[14][:2] + parts[13][:2]) / 2, 1) if parts[14][2] > 0 and parts[13][2] > 0 else [0, 0, 0]
13 | lambda parts: parts[12]
14 | lambda parts: parts[10]
15 | lambda parts: parts[9]
16 | lambda parts: np.append((parts[8][:2] + parts[7][:2]) / 2, 1) if parts[8][2] > 0 and parts[7][2] > 0 else [0, 0, 0]
17 | lambda parts: parts[6]
18 | lambda parts: parts[4]
19 | lambda parts: parts[2]
20 | lambda parts: parts[1]
21 | lambda parts: parts[0]
22 |
--------------------------------------------------------------------------------
/config/dataset/hand_nyu.tsv:
--------------------------------------------------------------------------------
1 | 29 34
2 | 34 33
3 | 33 5
4 | 5 4
5 | 4 3
6 | 3 2
7 | 2 1
8 | 1 0
9 | 34 32
10 | 32 11
11 | 11 10
12 | 10 9
13 | 9 8
14 | 8 7
15 | 7 6
16 | 32 17
17 | 17 16
18 | 16 15
19 | 15 14
20 | 14 13
21 | 13 12
22 | 34 23
23 | 23 22
24 | 22 21
25 | 21 20
26 | 20 19
27 | 19 18
28 | 29 28
29 | 28 27
30 | 27 26
31 | 26 25
32 | 25 24
33 | 29 30
34 | 29 31
35 | 29 35
36 |
--------------------------------------------------------------------------------
/config/dataset/hand_nyu/cache.hand_nyu.cache:
--------------------------------------------------------------------------------
1 | lambda parts: parts[0]
2 | lambda parts: parts[1]
3 | lambda parts: parts[2]
4 | lambda parts: parts[3]
5 | lambda parts: parts[4]
6 | lambda parts: parts[5]
7 | lambda parts: parts[6]
8 | lambda parts: parts[7]
9 | lambda parts: parts[8]
10 | lambda parts: parts[9]
11 | lambda parts: parts[10]
12 | lambda parts: parts[11]
13 | lambda parts: parts[12]
14 | lambda parts: parts[13]
15 | lambda parts: parts[14]
16 | lambda parts: parts[15]
17 | lambda parts: parts[16]
18 | lambda parts: parts[17]
19 | lambda parts: parts[18]
20 | lambda parts: parts[19]
21 | lambda parts: parts[20]
22 | lambda parts: parts[21]
23 | lambda parts: parts[22]
24 | lambda parts: parts[23]
25 | lambda parts: parts[24]
26 | lambda parts: parts[25]
27 | lambda parts: parts[26]
28 | lambda parts: parts[27]
29 | lambda parts: parts[28]
30 | lambda parts: parts[29]
31 | lambda parts: parts[30]
32 | lambda parts: parts[31]
33 | lambda parts: parts[32]
34 | lambda parts: parts[33]
35 | lambda parts: parts[34]
36 | lambda parts: parts[35]
37 |
--------------------------------------------------------------------------------
/config/dataset/mpii.tsv:
--------------------------------------------------------------------------------
1 | 9 8
2 | 8 7
3 | 7 12
4 | 7 13
5 | 12 11
6 | 13 14
7 | 11 10
8 | 14 15
9 | 7 2
10 | 7 3
11 | 2 1
12 | 3 4
13 | 1 0
14 | 4 5
15 | 2 6
16 | 3 6
17 |
--------------------------------------------------------------------------------
/config/dataset/mpii.txt:
--------------------------------------------------------------------------------
1 | 5
2 | 4
3 | 3
4 | 2
5 | 1
6 | 0
7 |
8 |
9 |
10 |
11 | 15
12 | 14
13 | 13
14 | 12
15 | 11
16 | 10
17 |
--------------------------------------------------------------------------------
/config/dataset/mpii/cache.mpii.cache:
--------------------------------------------------------------------------------
1 | lambda parts: parts[0]
2 | lambda parts: parts[1]
3 | lambda parts: parts[2]
4 | lambda parts: parts[3]
5 | lambda parts: parts[4]
6 | lambda parts: parts[5]
7 | lambda parts: parts[6]
8 | lambda parts: parts[7]
9 | lambda parts: parts[8]
10 | lambda parts: parts[9]
11 | lambda parts: parts[10]
12 | lambda parts: parts[11]
13 | lambda parts: parts[12]
14 | lambda parts: parts[13]
15 | lambda parts: parts[14]
16 | lambda parts: parts[15]
17 |
--------------------------------------------------------------------------------
/config/dataset/person13_12.tsv:
--------------------------------------------------------------------------------
1 | 0 1
2 | 0 2
3 | 1 3
4 | 2 4
5 | 3 5
6 | 4 6
7 | 0 7
8 | 0 8
9 | 7 9
10 | 8 10
11 | 9 11
12 | 10 12
13 |
--------------------------------------------------------------------------------
/config/dataset/person13_12.txt:
--------------------------------------------------------------------------------
1 |
2 | 2
3 | 1
4 | 4
5 | 3
6 | 6
7 | 5
8 | 8
9 | 7
10 | 10
11 | 9
12 | 12
13 | 11
14 |
--------------------------------------------------------------------------------
/config/dataset/person13_12/cache.coco.cache:
--------------------------------------------------------------------------------
1 | lambda parts: np.append((parts[5][:2] + parts[6][:2]) / 2, 1) if parts[5][2] > 0 and parts[6][2] > 0 else [0, 0, 0]
2 | lambda parts: parts[6]
3 | lambda parts: parts[5]
4 | lambda parts: parts[8]
5 | lambda parts: parts[7]
6 | lambda parts: parts[10]
7 | lambda parts: parts[9]
8 | lambda parts: parts[12]
9 | lambda parts: parts[11]
10 | lambda parts: parts[14]
11 | lambda parts: parts[13]
12 | lambda parts: parts[16]
13 | lambda parts: parts[15]
14 |
--------------------------------------------------------------------------------
/config/dataset/person13_12/cache.mpii.cache:
--------------------------------------------------------------------------------
1 | lambda parts: parts[7]
2 | lambda parts: parts[12]
3 | lambda parts: parts[13]
4 | lambda parts: parts[11]
5 | lambda parts: parts[14]
6 | lambda parts: parts[10]
7 | lambda parts: parts[15]
8 | lambda parts: parts[2]
9 | lambda parts: parts[3]
10 | lambda parts: parts[1]
11 | lambda parts: parts[4]
12 | lambda parts: parts[0]
13 | lambda parts: parts[5]
14 |
--------------------------------------------------------------------------------
/config/dataset/person14_13.tsv:
--------------------------------------------------------------------------------
1 | 0 1
2 | 0 2
3 | 1 3
4 | 2 4
5 | 3 5
6 | 4 6
7 | 0 7
8 | 0 8
9 | 7 9
10 | 8 10
11 | 9 11
12 | 10 12
13 | 0 13
14 |
--------------------------------------------------------------------------------
/config/dataset/person14_13.txt:
--------------------------------------------------------------------------------
1 |
2 | 2
3 | 1
4 | 4
5 | 3
6 | 6
7 | 5
8 | 8
9 | 7
10 | 10
11 | 9
12 | 12
13 | 11
14 |
15 |
--------------------------------------------------------------------------------
/config/dataset/person14_13/cache.coco.cache:
--------------------------------------------------------------------------------
1 | lambda parts: np.append((parts[5][:2] + parts[6][:2]) / 2, 1) if parts[5][2] > 0 and parts[6][2] > 0 else [0, 0, 0]
2 | lambda parts: parts[6]
3 | lambda parts: parts[5]
4 | lambda parts: parts[8]
5 | lambda parts: parts[7]
6 | lambda parts: parts[10]
7 | lambda parts: parts[9]
8 | lambda parts: parts[12]
9 | lambda parts: parts[11]
10 | lambda parts: parts[14]
11 | lambda parts: parts[13]
12 | lambda parts: parts[16]
13 | lambda parts: parts[15]
14 | lambda parts: parts[0]
15 |
--------------------------------------------------------------------------------
/config/dataset/person18.tsv:
--------------------------------------------------------------------------------
1 | 0 1
2 | 0 2
3 | 1 2
4 | 1 3
5 | 2 4
6 | 0 7
7 | 7 5
8 | 7 6
9 | 5 8
10 | 6 9
11 | 8 10
12 | 9 11
13 | 7 12
14 | 7 13
15 | 12 14
16 | 13 15
17 | 14 16
18 | 15 17
19 |
--------------------------------------------------------------------------------
/config/dataset/person18.txt:
--------------------------------------------------------------------------------
1 |
2 | 2
3 | 1
4 | 4
5 | 3
6 | 6
7 | 5
8 |
9 | 9
10 | 8
11 | 11
12 | 10
13 | 13
14 | 12
15 | 15
16 | 14
17 | 17
18 | 16
19 |
--------------------------------------------------------------------------------
/config/dataset/person18/cache.coco.cache:
--------------------------------------------------------------------------------
1 | lambda parts: parts[0]
2 | lambda parts: parts[1]
3 | lambda parts: parts[2]
4 | lambda parts: parts[3]
5 | lambda parts: parts[4]
6 | lambda parts: parts[5]
7 | lambda parts: parts[6]
8 | lambda parts: np.append((parts[5][:2] + parts[6][:2]) / 2, 1) if parts[5][2] > 0 and parts[6][2] > 0 else [0, 0, 0]
9 | lambda parts: parts[7]
10 | lambda parts: parts[8]
11 | lambda parts: parts[9]
12 | lambda parts: parts[10]
13 | lambda parts: parts[11]
14 | lambda parts: parts[12]
15 | lambda parts: parts[13]
16 | lambda parts: parts[14]
17 | lambda parts: parts[15]
18 | lambda parts: parts[16]
19 |
--------------------------------------------------------------------------------
/config/dataset/person18_19.tsv:
--------------------------------------------------------------------------------
1 | 1 8
2 | 8 9
3 | 9 10
4 | 1 11
5 | 11 12
6 | 12 13
7 | 1 2
8 | 2 3
9 | 3 4
10 | 2 16
11 | 1 5
12 | 5 6
13 | 6 7
14 | 5 17
15 | 1 0
16 | 0 14
17 | 0 15
18 | 14 16
19 | 15 17
20 |
--------------------------------------------------------------------------------
/config/dataset/person18_19.txt:
--------------------------------------------------------------------------------
1 |
2 |
3 | 5
4 | 4
5 | 3
6 | 2
7 | 3
8 | 3
9 | 11
10 | 12
11 | 13
12 | 8
13 | 9
14 | 10
15 | 15
16 | 14
17 | 17
18 | 16
19 |
--------------------------------------------------------------------------------
/config/dataset/person18_19/cache.coco.cache:
--------------------------------------------------------------------------------
1 | lambda parts: parts[0]
2 | lambda parts: np.append((parts[5][:2] + parts[6][:2]) / 2, 1) if parts[5][2] > 0 and parts[6][2] > 0 else [0, 0, 0]
3 | lambda parts: parts[6]
4 | lambda parts: parts[8]
5 | lambda parts: parts[10]
6 | lambda parts: parts[5]
7 | lambda parts: parts[7]
8 | lambda parts: parts[9]
9 | lambda parts: parts[12]
10 | lambda parts: parts[14]
11 | lambda parts: parts[16]
12 | lambda parts: parts[11]
13 | lambda parts: parts[13]
14 | lambda parts: parts[15]
15 | lambda parts: parts[2]
16 | lambda parts: parts[1]
17 | lambda parts: parts[4]
18 | lambda parts: parts[3]
19 |
--------------------------------------------------------------------------------
/config/inception_unet.ini:
--------------------------------------------------------------------------------
1 | [image]
2 | size = 344 344
3 |
4 | [cache]
5 | dataset = config/dataset/person14_13
6 |
7 | [model]
8 | dnn = model.dnn.inception4.Inception4_down3_4
9 | stages = model.stages.unet.Unet1Sqz3 model.stages.unet.Unet1Sqz3_a
10 |
11 | [data]
12 | sizes = 344,344
13 |
14 | [nms]
15 | threshold = 0.05
16 |
17 | [integration]
18 | step = 5
19 | step_limits = 5 25
20 | min_score = 0.05
21 | min_count = 9
22 |
23 | [cluster]
24 | min_score = 0.4
25 | min_count = 3
26 |
--------------------------------------------------------------------------------
/config/original_person18_19.ini:
--------------------------------------------------------------------------------
1 | [image]
2 | size = 368 368
3 |
4 | [cache]
5 | dataset = config/dataset/person18_19
6 |
7 | [model]
8 | dnn = model.dnn.vgg.person18_19
9 | stages = model.stages.openpose.Stage0 model.stages.openpose.Stage model.stages.openpose.Stage model.stages.openpose.Stage model.stages.openpose.Stage model.stages.openpose.Stage
10 |
11 | [data]
12 | sizes = 368,368
13 |
14 | [nms]
15 | threshold = 0.05
16 |
17 | [integration]
18 | step = 5
19 | step_limits = 5 25
20 | min_score = 0.05
21 | min_count = 9
22 |
23 | [cluster]
24 | min_score = 0.4
25 | min_count = 3
26 |
--------------------------------------------------------------------------------
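
In both .ini files the [model] section consists of dotted attribute paths: dnn names the backbone class and stages is a whitespace-separated list of stage classes that the scripts below chain into an nn.Sequential (see the stages = nn.Sequential(...) lines in convert_caffe_torch.py and estimate.py). A resolver in that spirit, assuming the repository's utils.parse_attr amounts to a plain importlib lookup (its source is not included in this excerpt):

    import importlib

    def parse_attr(name):
        # 'model.dnn.vgg.person18_19' -> the attribute person18_19 of module model.dnn.vgg
        module_name, attr = name.rsplit('.', 1)
        return getattr(importlib.import_module(module_name), attr)
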
/config/summary/histogram.txt:
--------------------------------------------------------------------------------
1 | .+\.bn\.weight$
2 |
--------------------------------------------------------------------------------
/convert_caffe_torch.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see <http://www.gnu.org/licenses/>.
16 | """
17 |
18 | import os
19 | import argparse
20 | import configparser
21 | import logging
22 | import logging.config
23 | import csv
24 | import hashlib
25 | import shutil
26 | import yaml
27 |
28 | import numpy as np
29 | import torch
30 | import torch.nn as nn
31 | import torch.autograd
32 | import caffe
33 |
34 | import utils
35 | import utils.train
36 | import model
37 |
38 |
39 | def load_mapper(path):
40 | with open(path, 'r') as f:
41 | lines = list(csv.reader(f, delimiter='\t'))
42 | mapper = {}
43 | for line in lines:
44 | if len(line) == 3:
45 | dst, src, transform = line
46 | transform = eval(transform)
47 | mapper[dst] = (src, transform)
48 | return mapper
49 |
50 |
51 | def main():
52 | args = make_args()
53 | config = configparser.ConfigParser()
54 | utils.load_config(config, args.config)
55 | for cmd in args.modify:
56 | utils.modify_config(config, cmd)
57 | with open(os.path.expanduser(os.path.expandvars(args.logging)), 'r') as f:
58 | logging.config.dictConfig(yaml.load(f))
59 | torch.manual_seed(args.seed)
60 | mapper = load_mapper(os.path.expandvars(os.path.expanduser(args.mapper)))
61 | model_dir = utils.get_model_dir(config)
62 | _, num_parts = utils.get_dataset_mappers(config)
63 | limbs_index = utils.get_limbs_index(config)
64 | height, width = tuple(map(int, config.get('image', 'size').split()))
65 | tensor = torch.randn(args.batch_size, 3, height, width)
66 | # PyTorch
67 | try:
68 | path, step, epoch = utils.train.load_model(model_dir)
69 | state_dict = torch.load(path, map_location=lambda storage, loc: storage)
70 | except (FileNotFoundError, ValueError):
71 | state_dict = {name: None for name in ('dnn', 'stages')}
72 | config_channels_dnn = model.ConfigChannels(config, state_dict['dnn'])
73 | dnn = utils.parse_attr(config.get('model', 'dnn'))(config_channels_dnn)
74 | config_channels_stages = model.ConfigChannels(config, state_dict['stages'], config_channels_dnn.channels)
75 | channel_dict = model.channel_dict(num_parts, len(limbs_index))
76 | stages = nn.Sequential(*[utils.parse_attr(s)(config_channels_stages, channel_dict, config_channels_dnn.channels, str(i)) for i, s in enumerate(config.get('model', 'stages').split())])
77 | inference = model.Inference(config, dnn, stages)
78 | inference.eval()
79 | state_dict = inference.state_dict()
80 | # Caffe
81 | net = caffe.Net(os.path.expanduser(os.path.expandvars(args.prototxt)), os.path.expanduser(os.path.expandvars(args.caffemodel)), caffe.TEST)
82 | if args.debug:
83 | logging.info('Caffe variables')
84 | for name, blobs in net.params.items():
85 | for i, blob in enumerate(blobs):
86 | val = blob.data
87 | print('\t'.join(map(str, [
88 | '%s/%d' % (name, i),
89 | 'x'.join(map(str, val.shape)),
90 | utils.abs_mean(val), hashlib.md5(val.tostring()).hexdigest(),
91 | ])))
92 | logging.info('Caffe features')
93 | input = net.blobs[args.input]
94 | input.reshape(*tensor.size())
95 | input.data[...] = tensor.numpy()
96 | net.forward()
97 | for name, blob in net.blobs.items():
98 | val = blob.data
99 | print('\t'.join(map(str, [
100 | name,
101 | 'x'.join(map(str, val.shape)),
102 | utils.abs_mean(val), hashlib.md5(val.tostring()).hexdigest(),
103 | ])))
104 | # convert
105 | saver = utils.train.Saver(model_dir, config.getint('save', 'keep'))
106 | try:
107 | for dst in state_dict:
108 | src, transform = mapper[dst]
109 | blobs = [b.data for b in net.params[src]]
110 | blob = transform(blobs)
111 | if isinstance(blob, np.ndarray):
112 | state_dict[dst] = torch.from_numpy(blob)
113 | else:
114 | state_dict[dst].fill_(blob)
115 | val = state_dict[dst].numpy()
116 | logging.info('\t'.join(list(map(str, (dst, src, val.shape, utils.abs_mean(val), hashlib.md5(val.tostring()).hexdigest())))))
117 | inference.load_state_dict(state_dict)
118 | if args.delete:
119 | logging.warning('delete model directory: ' + model_dir)
120 | shutil.rmtree(model_dir, ignore_errors=True)
121 | saver(dict(
122 | dnn=inference.dnn.state_dict(),
123 | stages=inference.stages.state_dict(),
124 | ), 0)
125 | finally:
126 | for stage, output in enumerate(inference(tensor)):
127 | for name, feature in output.items():
128 | val = feature.detach().numpy()
129 | print('\t'.join(map(str, [
130 | 'stage%d/%s' % (stage, name),
131 | 'x'.join(map(str, val.shape)),
132 | utils.abs_mean(val), hashlib.md5(val.tostring()).hexdigest(),
133 | ])))
134 |
135 |
136 | def make_args():
137 | parser = argparse.ArgumentParser()
138 | parser.add_argument('mapper')
139 | parser.add_argument('prototxt')
140 | parser.add_argument('caffemodel')
141 | parser.add_argument('-c', '--config', nargs='+', default=['config.ini'], help='config file')
142 | parser.add_argument('-m', '--modify', nargs='+', default=[], help='modify config')
143 | parser.add_argument('--logging', default='logging.yml', help='logging config')
144 | parser.add_argument('-b', '--batch_size', default=1, type=int, help='batch size')
145 | parser.add_argument('-i', '--input', default='image', help='input tensor name of Caffe')
146 | parser.add_argument('-d', '--delete', action='store_true', help='delete model')
147 | parser.add_argument('-s', '--seed', default=0, type=int, help='a seed to create a random image tensor')
148 | parser.add_argument('--debug', action='store_true')
149 | return parser.parse_args()
150 |
151 | if __name__ == '__main__':
152 | main()
153 |
--------------------------------------------------------------------------------
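
load_mapper above expects a tab-separated file with three columns per row: the destination parameter name in the PyTorch state_dict, the source blob name in the Caffe net, and a Python expression that is eval'd into a callable turning the list of Caffe blobs into the destination array. The row below is purely hypothetical and only illustrates that shape; the real mapping ships as config/convert_caffe_torch/original_person18_19.tsv, whose contents are not reproduced here.

    # hypothetical names, for illustration of the dst / src / transform columns only
    row = ['dnn.features.0.weight', 'conv1_1', 'lambda blobs: blobs[0]']
    dst, src, transform = row
    mapper = {dst: (src, eval(transform))}
    # main() then copies transform(blobs) from the Caffe net into state_dict[dst]
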
/convert_onnx_caffe2.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see <http://www.gnu.org/licenses/>.
16 | """
17 |
18 | import os
19 | import argparse
20 | import configparser
21 | import logging
22 | import logging.config
23 | import yaml
24 |
25 | import onnx
26 | import onnx_caffe2.backend
27 | import onnx_caffe2.helper
28 |
29 | import utils
30 |
31 |
32 | def main():
33 | args = make_args()
34 | config = configparser.ConfigParser()
35 | utils.load_config(config, args.config)
36 | for cmd in args.modify:
37 | utils.modify_config(config, cmd)
38 | with open(os.path.expanduser(os.path.expandvars(args.logging)), 'r') as f:
39 | logging.config.dictConfig(yaml.load(f))
40 | model_dir = utils.get_model_dir(config)
41 | model = onnx.load(model_dir + '.onnx')
42 | onnx.checker.check_model(model)
43 | init_net, predict_net = onnx_caffe2.backend.Caffe2Backend.onnx_graph_to_caffe2_net(model.graph, device='CPU')
44 | onnx_caffe2.helper.save_caffe2_net(init_net, os.path.join(model_dir, 'init_net.pb'))
45 | onnx_caffe2.helper.save_caffe2_net(predict_net, os.path.join(model_dir, 'predict_net.pb'), output_txt=True)
46 | logging.info(model_dir)
47 |
48 |
49 | def make_args():
50 | parser = argparse.ArgumentParser()
51 | parser.add_argument('-c', '--config', nargs='+', default=['config.ini'], help='config file')
52 | parser.add_argument('-m', '--modify', nargs='+', default=[], help='modify config')
53 | parser.add_argument('--logging', default='logging.yml', help='logging config')
54 | return parser.parse_args()
55 |
56 |
57 | if __name__ == '__main__':
58 | main()
59 |
--------------------------------------------------------------------------------
/convert_tf_torch.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see <http://www.gnu.org/licenses/>.
16 | """
17 |
18 | import os
19 | import argparse
20 | import configparser
21 | import logging
22 | import logging.config
23 | import csv
24 | import hashlib
25 | import shutil
26 | import yaml
27 |
28 | import numpy as np
29 | import torch
30 | import torch.nn as nn
31 | import torch.autograd
32 | import tensorflow as tf
33 | from tensorflow.python.framework import ops
34 | from tensorboardX import SummaryWriter
35 |
36 | import utils
37 | import utils.train
38 | import model
39 |
40 |
41 | def load_mapper(path):
42 | with open(os.path.splitext(path)[0] + '.tsv', 'r') as f:
43 | lines = list(csv.reader(f, delimiter='\t'))
44 | mapper = {}
45 | for line in lines:
46 | if line:
47 | if len(line) < 3:
48 | line += [''] * (3 - len(line))
49 | dst, src, _converter = line
50 | converter = eval(_converter) if _converter else lambda val: val
51 | mapper[dst] = (src, converter)
52 | return mapper
53 |
54 |
55 | def main():
56 | args = make_args()
57 | config = configparser.ConfigParser()
58 | utils.load_config(config, args.config)
59 | for cmd in args.modify:
60 | utils.modify_config(config, cmd)
61 | with open(os.path.expanduser(os.path.expandvars(args.logging)), 'r') as f:
62 | logging.config.dictConfig(yaml.load(f))
63 | torch.manual_seed(args.seed)
64 | mapper = load_mapper(os.path.expandvars(os.path.expanduser(args.mapper)))
65 | model_dir = utils.get_model_dir(config)
66 | _, num_parts = utils.get_dataset_mappers(config)
67 | limbs_index = utils.get_limbs_index(config)
68 | height, width = tuple(map(int, config.get('image', 'size').split()))
69 | tensor = torch.randn(args.batch_size, 3, height, width)
70 | # PyTorch
71 | try:
72 | path, step, epoch = utils.train.load_model(model_dir)
73 | state_dict = torch.load(path, map_location=lambda storage, loc: storage)
74 | except (FileNotFoundError, ValueError):
75 | state_dict = {name: None for name in ('dnn', 'stages')}
76 | config_channels_dnn = model.ConfigChannels(config, state_dict['dnn'])
77 | dnn = utils.parse_attr(config.get('model', 'dnn'))(config_channels_dnn)
78 | config_channels_stages = model.ConfigChannels(config, state_dict['stages'], config_channels_dnn.channels)
79 | channel_dict = model.channel_dict(num_parts, len(limbs_index))
80 | stages = nn.Sequential(*[utils.parse_attr(s)(config_channels_stages, channel_dict, config_channels_dnn.channels, str(i)) for i, s in enumerate(config.get('model', 'stages').split())])
81 | inference = model.Inference(config, dnn, stages)
82 | inference.eval()
83 | state_dict = inference.state_dict()
84 | # TensorFlow
85 | with open(os.path.expanduser(os.path.expandvars(args.path)), 'rb') as f:
86 | graph_def = tf.GraphDef()
87 | graph_def.ParseFromString(f.read())
88 | image = ops.convert_to_tensor(np.transpose(tensor.cpu().numpy(), [0, 2, 3, 1]), name='image')
89 | tf.import_graph_def(graph_def, input_map={'image:0': image})
90 | saver = utils.train.Saver(model_dir, config.getint('save', 'keep'))
91 | with tf.Session(config=tf.ConfigProto(
92 | device_count={'CPU': 1, 'GPU': 0},
93 | allow_soft_placement=True,
94 | log_device_placement=False
95 | )) as sess:
96 | try:
97 | for dst in state_dict:
98 | src, converter = mapper[dst]
99 | if src.isdigit():
100 | state_dict[dst].fill_(float(src))
101 | else:
102 | op = sess.graph.get_operation_by_name(src)
103 | t = op.values()[0]
104 | v = sess.run(t)
105 | state_dict[dst] = torch.from_numpy(converter(v))
106 | val = state_dict[dst].numpy()
107 | print('\t'.join(list(map(str, (dst, src, val.shape, utils.abs_mean(val), hashlib.md5(val.tostring()).hexdigest())))))
108 | inference.load_state_dict(state_dict)
109 | if args.delete:
110 | logging.warning('delete model directory: ' + model_dir)
111 | shutil.rmtree(model_dir, ignore_errors=True)
112 | saver(dict(
113 | dnn=inference.dnn.state_dict(),
114 | stages=inference.stages.state_dict(),
115 | ), 0)
116 | finally:
117 | if args.debug:
118 | for op in sess.graph.get_operations():
119 | if op.values():
120 | logging.info(op.values()[0])
121 | for name in args.debug:
122 | t = sess.graph.get_tensor_by_name(name + ':0')
123 | val = sess.run(t)
124 | val = np.transpose(val, [0, 3, 1, 2])
125 | print('\t'.join(map(str, [
126 | name,
127 | 'x'.join(map(str, val.shape)),
128 | utils.abs_mean(val), hashlib.md5(val.tostring()).hexdigest(),
129 | ])))
130 | val = dnn(tensor).detach().numpy()
131 | print('\t'.join(map(str, [
132 | 'x'.join(map(str, val.shape)),
133 | utils.abs_mean(val), hashlib.md5(val.tostring()).hexdigest(),
134 | ])))
135 | for stage, output in enumerate(inference(tensor)):
136 | for name, feature in output.items():
137 | val = feature.detach().numpy()
138 | print('\t'.join(map(str, [
139 | 'stage%d/%s' % (stage, name),
140 | 'x'.join(map(str, val.shape)),
141 | utils.abs_mean(val), hashlib.md5(val.tostring()).hexdigest(),
142 | ])))
143 | forward = inference.forward
144 | inference.forward = lambda self, *x: list(forward(self, *x)[-1].values())
145 | with SummaryWriter(model_dir) as writer:
146 | writer.add_graph(inference, (tensor,))
147 |
148 |
149 | def make_args():
150 | parser = argparse.ArgumentParser()
151 | parser.add_argument('path')
152 | parser.add_argument('mapper')
153 | parser.add_argument('-c', '--config', nargs='+', default=['config.ini'], help='config file')
154 | parser.add_argument('-m', '--modify', nargs='+', default=[], help='modify config')
155 | parser.add_argument('--logging', default='logging.yml', help='logging config')
156 | parser.add_argument('-b', '--batch_size', default=1, type=int, help='batch size')
157 | parser.add_argument('-d', '--delete', action='store_true', help='delete model')
158 | parser.add_argument('-s', '--seed', default=0, type=int, help='a seed to create a random image tensor')
159 | parser.add_argument('--debug', nargs='+')
160 | return parser.parse_args()
161 |
162 |
163 | if __name__ == '__main__':
164 | main()
--------------------------------------------------------------------------------
/convert_torch_onnx.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see <http://www.gnu.org/licenses/>.
16 | """
17 |
18 | import os
19 | import argparse
20 | import configparser
21 | import logging
22 | import logging.config
23 | import yaml
24 |
25 | import torch.nn as nn
26 | import torch.autograd
27 | import torch.cuda
28 | import torch.optim
29 | import torch.utils.data
30 | import torch.onnx
31 | import humanize
32 |
33 | import utils.train
34 | import model
35 |
36 |
37 | def main():
38 | args = make_args()
39 | config = configparser.ConfigParser()
40 | utils.load_config(config, args.config)
41 | for cmd in args.modify:
42 | utils.modify_config(config, cmd)
43 | with open(os.path.expanduser(os.path.expandvars(args.logging)), 'r') as f:
44 | logging.config.dictConfig(yaml.load(f))
45 | height, width = tuple(map(int, config.get('image', 'size').split()))
46 | model_dir = utils.get_model_dir(config)
47 | _, num_parts = utils.get_dataset_mappers(config)
48 | limbs_index = utils.get_limbs_index(config)
49 | path, step, epoch = utils.train.load_model(model_dir)
50 | state_dict = torch.load(path, map_location=lambda storage, loc: storage)
51 | config_channels_dnn = model.ConfigChannels(config, state_dict['dnn'])
52 | dnn = utils.parse_attr(config.get('model', 'dnn'))(config_channels_dnn)
53 | config_channels_stages = model.ConfigChannels(config, state_dict['stages'], config_channels_dnn.channels)
54 | channel_dict = model.channel_dict(num_parts, len(limbs_index))
55 | stages = nn.Sequential(*[utils.parse_attr(s)(config_channels_stages, channel_dict, config_channels_dnn.channels, str(i)) for i, s in enumerate(config.get('model', 'stages').split())])
56 | dnn.load_state_dict(config_channels_dnn.state_dict)
57 | stages.load_state_dict(config_channels_stages.state_dict)
58 | inference = model.Inference(config, dnn, stages)
59 | inference.eval()
60 | logging.info(humanize.naturalsize(sum(var.cpu().numpy().nbytes for var in inference.state_dict().values())))
61 | image = torch.randn(args.batch_size, 3, height, width)
62 | path = model_dir + '.onnx'
63 | logging.info('save ' + path)
64 | forward = inference.forward
65 | inference.forward = lambda self, *x: [[output[name] for name in 'parts, limbs'.split(', ')] for output in forward(self, *x)]
66 | torch.onnx.export(inference, image, path, export_params=True, verbose=args.verbose)
67 |
68 |
69 | def make_args():
70 | parser = argparse.ArgumentParser()
71 | parser.add_argument('-c', '--config', nargs='+', default=['config.ini'], help='config file')
72 | parser.add_argument('-m', '--modify', nargs='+', default=[], help='modify config')
73 | parser.add_argument('-b', '--batch_size', default=1, type=int, help='batch size')
74 | parser.add_argument('-v', '--verbose', action='store_true')
75 | parser.add_argument('--logging', default='logging.yml', help='logging config')
76 | return parser.parse_args()
77 |
78 |
79 | if __name__ == '__main__':
80 | main()
81 |
--------------------------------------------------------------------------------
/demo_data.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see <http://www.gnu.org/licenses/>.
16 | """
17 |
18 | import os
19 | import argparse
20 | import configparser
21 | import logging
22 | import logging.config
23 | import multiprocessing
24 | import yaml
25 |
26 | import numpy as np
27 | import torch.utils.data
28 | import matplotlib.pyplot as plt
29 |
30 | import utils.data
31 | import utils.train
32 | import utils.visualize
33 | import transform.augmentation
34 | import model
35 |
36 |
37 | def main():
38 | args = make_args()
39 | config = configparser.ConfigParser()
40 | utils.load_config(config, args.config)
41 | for cmd in args.modify:
42 | utils.modify_config(config, cmd)
43 | with open(os.path.expanduser(os.path.expandvars(args.logging)), 'r') as f:
44 | logging.config.dictConfig(yaml.load(f))
45 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
46 | cache_dir = utils.get_cache_dir(config)
47 | _, num_parts = utils.get_dataset_mappers(config)
48 | limbs_index = utils.get_limbs_index(config)
49 | dnn = utils.parse_attr(config.get('model', 'dnn'))(model.ConfigChannels(config)).to(device)
50 | draw_points = utils.visualize.DrawPoints(limbs_index, colors=config.get('draw_points', 'colors').split())
51 | _draw_points = utils.visualize.DrawPoints(limbs_index, thickness=1)
52 | draw_bbox = utils.visualize.DrawBBox()
53 | batch_size = args.rows * args.cols
54 | paths = [os.path.join(cache_dir, phase + '.pkl') for phase in args.phase]
55 | dataset = utils.data.Dataset(
56 | config,
57 | utils.data.load_pickles(paths),
58 | transform=transform.augmentation.get_transform(config, config.get('transform', 'augmentation').split()),
59 | shuffle=config.getboolean('data', 'shuffle'),
60 | )
61 | logging.info('num_examples=%d' % len(dataset))
62 | try:
63 | workers = config.getint('data', 'workers')
64 | except configparser.NoOptionError:
65 | workers = multiprocessing.cpu_count()
66 | sizes = utils.train.load_sizes(config)
67 | feature_sizes = [dnn(torch.randn(1, 3, *size).to(device)).size()[-2:] for size in sizes]
68 | collate_fn = utils.data.Collate(
69 | config,
70 | transform.parse_transform(config, config.get('transform', 'resize_train')),
71 | sizes, feature_sizes,
72 | maintain=config.getint('data', 'maintain'),
73 | transform_image=transform.get_transform(config, config.get('transform', 'image_train').split()),
74 | )
75 | loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=workers, collate_fn=collate_fn)
76 | for data in loader:
77 | path, size, image, mask, keypoints, yx_min, yx_max, index = (t.numpy() if hasattr(t, 'numpy') else t for t in (data[key] for key in 'path, size, image, mask, keypoints, yx_min, yx_max, index'.split(', ')))
78 | fig, axes = plt.subplots(args.rows, args.cols)
79 | axes = axes.flat if batch_size > 1 else [axes]
80 | for ax, path, size, image, mask, keypoints, yx_min, yx_max, index in zip(*[axes, path, size, image, mask, keypoints, yx_min, yx_max, index]):
81 | logging.info(path + ': ' + 'x'.join(map(str, size)))
82 | image = utils.visualize.draw_mask(image, mask, 1)
83 | size = yx_max - yx_min
84 | target = np.logical_and(*[np.squeeze(a, -1) > 0 for a in np.split(size, size.shape[-1], -1)])
85 | keypoints, yx_min, yx_max = (a[target] for a in (keypoints, yx_min, yx_max))
86 | for i, points in enumerate(keypoints):
87 | if i == index:
88 | image = draw_points(image, points)
89 | else:
90 | image = _draw_points(image, points)
91 | image = draw_bbox(image, yx_min.astype(np.int), yx_max.astype(np.int))
92 | ax.imshow(image)
93 | ax.set_xticks([])
94 | ax.set_yticks([])
95 | fig.tight_layout()
96 | mng = plt.get_current_fig_manager()
97 | mng.resize(*mng.window.maxsize())
98 | plt.show()
99 |
100 |
101 | def make_args():
102 | parser = argparse.ArgumentParser()
103 | parser.add_argument('-c', '--config', nargs='+', default=['config.ini'], help='config file')
104 | parser.add_argument('-m', '--modify', nargs='+', default=[], help='modify config')
105 | parser.add_argument('-p', '--phase', nargs='+', default=['train', 'val', 'test'])
106 | parser.add_argument('--rows', default=3, type=int)
107 | parser.add_argument('--cols', default=3, type=int)
108 | parser.add_argument('--logging', default='logging.yml', help='logging config')
109 | return parser.parse_args()
110 |
111 |
112 | if __name__ == '__main__':
113 | main()
114 |
--------------------------------------------------------------------------------
/demo_keypoints.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see <http://www.gnu.org/licenses/>.
16 | """
17 |
18 | import os
19 | import argparse
20 | import configparser
21 | import logging
22 | import logging.config
23 | import yaml
24 |
25 | import numpy as np
26 | import scipy.misc
27 | import matplotlib.pyplot as plt
28 |
29 | import utils.data
30 | import utils.visualize
31 |
32 |
33 | def main():
34 | args = make_args()
35 | config = configparser.ConfigParser()
36 | utils.load_config(config, args.config)
37 | for cmd in args.modify:
38 | utils.modify_config(config, cmd)
39 | with open(os.path.expanduser(os.path.expandvars(args.logging)), 'r') as f:
40 | logging.config.dictConfig(yaml.load(f))
41 | cache_dir = utils.get_cache_dir(config)
42 | _, num_parts = utils.get_dataset_mappers(config)
43 | limbs_index = utils.get_limbs_index(config)
44 | mask_ext = config.get('cache', 'mask_ext')
45 | paths = [os.path.join(cache_dir, phase + '.pkl') for phase in args.phase]
46 | dataset = utils.data.Dataset(config, utils.data.load_pickles(paths))
47 | logging.info('num_examples=%d' % len(dataset))
48 | draw_points = utils.visualize.DrawPoints(limbs_index, colors=config.get('draw_points', 'colors').split())
49 | draw_bbox = utils.visualize.DrawBBox(config)
50 | for data in dataset:
51 | path, keypath, keypoints, yx_min, yx_max = (data[key] for key in 'path, keypath, keypoints, yx_min, yx_max'.split(', '))
52 | image = scipy.misc.imread(path, mode='RGB')
53 | fig = plt.figure()
54 | ax = fig.gca()
55 | maskpath = keypath + '.mask' + mask_ext
56 | mask = scipy.misc.imread(maskpath)
57 | image = utils.visualize.draw_mask(image, mask)
58 | for points in keypoints:
59 | image = draw_points(image, points)
60 | image = draw_bbox(image, yx_min.astype(np.int), yx_max.astype(np.int))
61 | ax.imshow(image)
62 | ax.set_xlim([0, image.shape[1] - 1])
63 | ax.set_ylim([image.shape[0] - 1, 0])
64 | ax.set_xticks([])
65 | ax.set_yticks([])
66 | mng = plt.get_current_fig_manager()
67 | mng.resize(*mng.window.maxsize())
68 | plt.show()
69 |
70 |
71 | def make_args():
72 | parser = argparse.ArgumentParser()
73 | parser.add_argument('-c', '--config', nargs='+', default=['config.ini'], help='config file')
74 | parser.add_argument('-m', '--modify', nargs='+', default=[], help='modify config')
75 | parser.add_argument('-p', '--phase', nargs='+', default=['train', 'val', 'test'])
76 | parser.add_argument('--logging', default='logging.yml', help='logging config')
77 | return parser.parse_args()
78 |
79 | if __name__ == '__main__':
80 | main()
81 |
--------------------------------------------------------------------------------
/demo_label.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see <http://www.gnu.org/licenses/>.
16 | """
17 |
18 | import sys
19 | import os
20 | import argparse
21 | import configparser
22 | import logging
23 | import logging.config
24 | import multiprocessing
25 | import yaml
26 |
27 | import numpy as np
28 | import torch.utils.data
29 | from PyQt5 import QtCore, QtWidgets
30 | import matplotlib.pyplot as plt
31 | import matplotlib.backends.backend_qt5agg as qtagg
32 | import humanize
33 | import cv2
34 |
35 | import model
36 | import utils.data
37 | import utils.train
38 | import utils.visualize
39 | import transform.augmentation
40 |
41 |
42 | class Visualizer(QtWidgets.QDialog):
43 | def __init__(self, name, image, feature, alpha=0.5):
44 | super(Visualizer, self).__init__()
45 | self.name = name
46 | self.image = image
47 | self.feature = feature
48 | self.draw_feature = utils.visualize.DrawFeature(alpha)
49 |
50 | layout = QtWidgets.QVBoxLayout(self)
51 | fig = plt.Figure()
52 | self.ax = fig.gca()
53 | self.canvas = qtagg.FigureCanvasQTAgg(fig)
54 | layout.addWidget(self.canvas)
55 | toolbar = qtagg.NavigationToolbar2QT(self.canvas, self)
56 | layout.addWidget(toolbar)
57 | self.slider = QtWidgets.QSlider(QtCore.Qt.Horizontal, self)
58 | self.slider.setRange(0, feature.shape[0] - 1)
59 | layout.addWidget(self.slider)
60 | self.slider.valueChanged[int].connect(self.on_progress)
61 |
62 | self.ax.imshow(self.image)
63 | self.ax.set_xticks([])
64 | self.ax.set_yticks([])
65 | self.on_progress(0)
66 |
67 | def on_progress(self, index):
68 | try:
69 | self.last.remove()
70 | except AttributeError:
71 | pass
72 | image = np.copy(self.image)
73 | feature = self.feature[index, :, :]
74 | image = self.draw_feature(image, feature)
75 | self.last = self.ax.imshow(image)
76 | self.canvas.draw()
77 | plt.draw()
78 | self.setWindowTitle('%s %d/%d' % (self.name, index + 1, self.feature.shape[0]))
79 |
80 |
81 | def main():
82 | args = make_args()
83 | config = configparser.ConfigParser()
84 | utils.load_config(config, args.config)
85 | for cmd in args.modify:
86 | utils.modify_config(config, cmd)
87 | with open(os.path.expanduser(os.path.expandvars(args.logging)), 'r') as f:
88 | logging.config.dictConfig(yaml.load(f))
89 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
90 | cache_dir = utils.get_cache_dir(config)
91 | _, num_parts = utils.get_dataset_mappers(config)
92 | limbs_index = utils.get_limbs_index(config)
93 | dnn = utils.parse_attr(config.get('model', 'dnn'))(model.ConfigChannels(config)).to(device)
94 | logging.info(humanize.naturalsize(sum(var.cpu().numpy().nbytes for var in dnn.state_dict().values())))
95 | size = tuple(map(int, config.get('image', 'size').split()))
96 | draw_points = utils.visualize.DrawPoints(limbs_index, colors=config.get('draw_points', 'colors').split())
97 | _draw_points = utils.visualize.DrawPoints(limbs_index, thickness=1)
98 | draw_bbox = utils.visualize.DrawBBox()
99 | paths = [os.path.join(cache_dir, phase + '.pkl') for phase in args.phase]
100 | dataset = utils.data.Dataset(
101 | config,
102 | utils.data.load_pickles(paths),
103 | transform=transform.augmentation.get_transform(config, config.get('transform', 'augmentation').split()),
104 | shuffle=config.getboolean('data', 'shuffle'),
105 | )
106 | logging.info('num_examples=%d' % len(dataset))
107 | try:
108 | workers = config.getint('data', 'workers')
109 | except configparser.NoOptionError:
110 | workers = multiprocessing.cpu_count()
111 | collate_fn = utils.data.Collate(
112 | config,
113 | transform.parse_transform(config, config.get('transform', 'resize_train')),
114 | [size], [dnn(torch.randn(1, 3, *size).to(device)).size()[-2:]],
115 | maintain=config.getint('data', 'maintain'),
116 | transform_image=transform.get_transform(config, config.get('transform', 'image_train').split()),
117 | )
118 | loader = torch.utils.data.DataLoader(dataset, shuffle=True, num_workers=workers, collate_fn=collate_fn)
119 | for data in loader:
120 | path, size, image, mask, keypoints, yx_min, yx_max, parts, limbs, index = (t.numpy() if hasattr(t, 'numpy') else t for t in (data[key] for key in 'path, size, image, mask, keypoints, yx_min, yx_max, parts, limbs, index'.split(', ')))
121 | for path, size, image, mask, keypoints, yx_min, yx_max, parts, limbs, index in zip(*[path, size, image, mask, keypoints, yx_min, yx_max, parts, limbs, index]):
122 | logging.info(path + ': ' + 'x'.join(map(str, size)))
123 | image = utils.visualize.draw_mask(image, mask, 1)
124 | size = yx_max - yx_min
125 | target = np.logical_and(*[np.squeeze(a, -1) > 0 for a in np.split(size, size.shape[-1], -1)])
126 | keypoints, yx_min, yx_max = (a[target] for a in (keypoints, yx_min, yx_max))
127 | for i, points in enumerate(keypoints):
128 | if i == index:
129 | image = draw_points(image, points)
130 | else:
131 | image = _draw_points(image, points)
132 | image = draw_bbox(image, yx_min.astype(np.int), yx_max.astype(np.int))
133 | dialog = Visualizer('parts', image, parts)
134 | dialog.exec()
135 | dialog = Visualizer('limbs', image, limbs)
136 | dialog.exec()
137 |
138 |
139 | def make_args():
140 | parser = argparse.ArgumentParser()
141 | parser.add_argument('-c', '--config', nargs='+', default=['config.ini'], help='config file')
142 | parser.add_argument('-m', '--modify', nargs='+', default=[], help='modify config')
143 | parser.add_argument('-p', '--phase', nargs='+', default=['train', 'val', 'test'])
144 | parser.add_argument('--logging', default='logging.yml', help='logging config')
145 | return parser.parse_args()
146 |
147 | if __name__ == '__main__':
148 | app = QtWidgets.QApplication(sys.argv)
149 | main()
150 | sys.exit(app.exec_())
151 |
--------------------------------------------------------------------------------
/donate_alipay.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruiminshen/openpose-pytorch/f850084194ddccc6d401d5b11f61facc20ec2b75/donate_alipay.jpg
--------------------------------------------------------------------------------
/donate_mm.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruiminshen/openpose-pytorch/f850084194ddccc6d401d5b11f61facc20ec2b75/donate_mm.jpg
--------------------------------------------------------------------------------
/estimate.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see <http://www.gnu.org/licenses/>.
16 | """
17 |
18 | import argparse
19 | import configparser
20 | import logging
21 | import logging.config
22 | import os
23 | import time
24 | import re
25 | import yaml
26 |
27 | import numpy as np
28 | import torch.autograd
29 | import torch.cuda
30 | import torch.optim
31 | import torch.utils.data
32 | import torch.nn as nn
33 | try:
34 | from caffe2.proto import caffe2_pb2
35 | from caffe2.python import workspace
36 | except ImportError:
37 | pass
38 | import humanize
39 | import pybenchmark
40 | import cv2
41 |
42 | import transform
43 | import model
44 | import utils.train
45 | import utils.visualize
46 | import pyopenpose
47 |
48 |
49 | class Estimate(object):
50 | def __init__(self, args, config):
51 | self.args = args
52 | self.config = config
53 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
54 | self.cache_dir = utils.get_cache_dir(config)
55 | self.model_dir = utils.get_model_dir(config)
56 | _, self.num_parts = utils.get_dataset_mappers(config)
57 | self.limbs_index = utils.get_limbs_index(config)
58 | if args.debug is None:
59 | self.draw_cluster = utils.visualize.DrawCluster(colors=args.colors, thickness=args.thickness)
60 | else:
61 | self.draw_feature = utils.visualize.DrawFeature()
62 | s = re.search('(-?[0-9]+)([a-z]+)(-?[0-9]+)', args.debug)
63 | stage = int(s.group(1))
64 | name = s.group(2)
65 | channel = int(s.group(3))
66 | self.get_feature = lambda outputs: outputs[stage][name][0][channel]
67 | self.height, self.width = tuple(map(int, config.get('image', 'size').split()))
68 | if args.caffe:
69 | init_net = caffe2_pb2.NetDef()
70 | with open(os.path.join(self.model_dir, 'init_net.pb'), 'rb') as f:
71 | init_net.ParseFromString(f.read())
72 | predict_net = caffe2_pb2.NetDef()
73 | with open(os.path.join(self.model_dir, 'predict_net.pb'), 'rb') as f:
74 | predict_net.ParseFromString(f.read())
75 | p = workspace.Predictor(init_net, predict_net)
76 | self.inference = lambda tensor: [{'parts': torch.from_numpy(parts), 'limbs': torch.from_numpy(limbs)} for parts, limbs in zip(*[iter(p.run([tensor.detach().cpu().numpy()]))] * 2)]
77 | else:
78 | self.step, self.epoch, self.dnn, self.stages = self.load()
79 | self.inference = model.Inference(config, self.dnn, self.stages)
80 | self.inference.eval()
81 | if torch.cuda.is_available():
82 | self.inference.cuda()
83 | logging.info(humanize.naturalsize(sum(var.cpu().numpy().nbytes for var in self.inference.state_dict().values())))
84 | self.cap = self.create_cap()
85 | self.keys = set(args.keys)
86 | self.resize = transform.parse_transform(config, config.get('transform', 'resize_test'))
87 | self.transform_image = transform.get_transform(config, config.get('transform', 'image_test').split())
88 | self.transform_tensor = transform.get_transform(config, config.get('transform', 'tensor').split())
89 |
90 | def __del__(self):
91 | cv2.destroyAllWindows()
92 | try:
93 | self.writer.release()
94 | except AttributeError:
95 | pass
96 | self.cap.release()
97 |
98 | def load(self):
99 | path, step, epoch = utils.train.load_model(self.model_dir)
100 | state_dict = torch.load(path, map_location=lambda storage, loc: storage)
101 | config_channels_dnn = model.ConfigChannels(self.config, state_dict['dnn'])
102 | dnn = utils.parse_attr(self.config.get('model', 'dnn'))(config_channels_dnn)
103 | config_channels_stages = model.ConfigChannels(self.config, state_dict['stages'], config_channels_dnn.channels)
104 | channel_dict = model.channel_dict(self.num_parts, len(self.limbs_index))
105 | stages = nn.Sequential(*[utils.parse_attr(s)(config_channels_stages, channel_dict, config_channels_dnn.channels, str(i)) for i, s in enumerate(self.config.get('model', 'stages').split())])
106 | dnn.load_state_dict(config_channels_dnn.state_dict)
107 | stages.load_state_dict(config_channels_stages.state_dict)
108 | return step, epoch, dnn, stages
109 |
110 | def create_cap(self):
111 | try:
112 | cap = int(self.args.input)
113 | except ValueError:
114 | cap = os.path.expanduser(os.path.expandvars(self.args.input))
115 | assert os.path.exists(cap)
116 | return cv2.VideoCapture(cap)
117 |
118 | def create_writer(self, height, width):
119 | fps = self.cap.get(cv2.CAP_PROP_FPS)
120 | logging.info('cap fps=%f' % fps)
121 | path = os.path.expanduser(os.path.expandvars(self.args.output))
122 | if self.args.fourcc:
123 | fourcc = cv2.VideoWriter_fourcc(*self.args.fourcc.upper())
124 | else:
125 | fourcc = int(self.cap.get(cv2.CAP_PROP_FOURCC))
126 | os.makedirs(os.path.dirname(path), exist_ok=True)
127 | return cv2.VideoWriter(path, fourcc, fps, (width, height))
128 |
129 | def get_image(self):
130 | ret, image_bgr = self.cap.read()
131 |         if self.args.crop:  # ymin ymax xmin xmax, assumed here to be fractions of the frame size
132 |             image_bgr = image_bgr[int(self.args.crop[0] * image_bgr.shape[0]):int(self.args.crop[1] * image_bgr.shape[0]), int(self.args.crop[2] * image_bgr.shape[1]):int(self.args.crop[3] * image_bgr.shape[1])]
133 | return image_bgr
134 |
135 | def __call__(self):
136 | image_bgr = self.get_image()
137 | image_resized = self.resize(image_bgr, self.height, self.width)
138 | image = self.transform_image(image_resized)
139 | tensor = self.transform_tensor(image)
140 | tensor = tensor.unsqueeze(0).to(self.device)
141 | outputs = pybenchmark.profile('inference')(self.inference)(tensor)
142 | if hasattr(self, 'draw_cluster'):
143 | output = outputs[-1]
144 | parts, limbs = (output[name][0] for name in 'parts, limbs'.split(', '))
145 | parts = parts[:-1]
146 | parts, limbs = (t.detach().cpu().numpy() for t in (parts, limbs))
147 | try:
148 | interpolation = getattr(cv2, 'INTER_' + self.config.get('estimate', 'interpolation').upper())
149 | parts, limbs = (np.stack([cv2.resize(feature, (self.width, self.height), interpolation=interpolation) for feature in a]) for a in (parts, limbs))
150 | except configparser.NoOptionError:
151 | pass
152 | clusters = pyopenpose.estimate(
153 | parts, limbs,
154 | self.limbs_index,
155 | self.config.getfloat('nms', 'threshold'),
156 | self.config.getfloat('integration', 'step'), tuple(map(int, self.config.get('integration', 'step_limits').split())), self.config.getfloat('integration', 'min_score'), self.config.getint('integration', 'min_count'),
157 | self.config.getfloat('cluster', 'min_score'), self.config.getint('cluster', 'min_count'),
158 | )
159 | scale_y, scale_x = self.resize.scale(parts.shape[-2:], image_bgr.shape[:2])
160 | image_result = image_bgr.copy()
161 | for cluster in clusters:
162 | cluster = [((i1, int(y1 * scale_y), int(x1 * scale_x)), (i2, int(y2 * scale_y), int(x2 * scale_x))) for (i1, y1, x1), (i2, y2, x2) in cluster]
163 | image_result = self.draw_cluster(image_result, cluster)
164 | else:
165 | image_result = image_resized.copy()
166 | feature = self.get_feature(outputs).detach().cpu().numpy()
167 | image_result = self.draw_feature(image_result, feature)
168 | if self.args.output:
169 | if not hasattr(self, 'writer'):
170 | self.writer = self.create_writer(*image_result.shape[:2])
171 | self.writer.write(image_result)
172 | else:
173 | cv2.imshow('estimate', image_result)
174 | if cv2.waitKey(0 if self.args.pause else 1) in self.keys:
175 | root = os.path.join(self.model_dir, 'snapshot')
176 | os.makedirs(root, exist_ok=True)
177 | path = os.path.join(root, time.strftime(self.args.format))
178 | cv2.imwrite(path, image_bgr)
179 | logging.warning('image dumped into ' + path)
180 |
181 |
182 | def main():
183 | args = make_args()
184 | config = configparser.ConfigParser()
185 | utils.load_config(config, args.config)
186 | for cmd in args.modify:
187 | utils.modify_config(config, cmd)
188 | with open(os.path.expanduser(os.path.expandvars(args.logging)), 'r') as f:
189 | logging.config.dictConfig(yaml.load(f))
190 | estimate = Estimate(args, config)
191 | try:
192 | with torch.no_grad():
193 | while estimate.cap.isOpened():
194 | estimate()
195 | except KeyboardInterrupt:
196 | logging.warning('interrupted')
197 | finally:
198 | logging.info(pybenchmark.stats)
199 |
200 |
201 | def make_args():
202 | parser = argparse.ArgumentParser()
203 | parser.add_argument('-c', '--config', nargs='+', default=['config.ini'], help='config file')
204 | parser.add_argument('-m', '--modify', nargs='+', default=[], help='modify config')
205 | parser.add_argument('-i', '--input', default=-1)
206 | parser.add_argument('-k', '--keys', nargs='+', type=int, default=[ord(' ')], help='keys to dump images')
207 | parser.add_argument('-o', '--output', help='output video file')
208 | parser.add_argument('-f', '--format', default='%Y-%m-%d_%H-%M-%S.jpg', help='dump file name format')
209 | parser.add_argument('--crop', nargs='+', type=float, default=[], help='ymin ymax xmin xmax')
210 | parser.add_argument('--pause', action='store_true')
211 | parser.add_argument('--fourcc', default='XVID', help='4-character code of codec used to compress the frames, such as XVID, MJPG')
212 | parser.add_argument('--thickness', default=3, type=int)
213 | parser.add_argument('--colors', nargs='+', default=[])
214 | parser.add_argument('-d', '--debug')
215 | parser.add_argument('--caffe', action='store_true')
216 | parser.add_argument('--logging', default='logging.yml', help='logging config')
217 | return parser.parse_args()
218 |
219 |
220 | if __name__ == '__main__':
221 | main()
222 |
--------------------------------------------------------------------------------
/logging.yml:
--------------------------------------------------------------------------------
1 | version: 1
2 | formatters:
3 | simple:
4 | format: "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
5 | handlers:
6 | console:
7 | class: logging.StreamHandler
8 | level: INFO
9 | formatter: simple
10 | stream: ext://sys.stderr
11 | root:
12 | level: INFO
13 | handlers: [console]
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see <http://www.gnu.org/licenses/>.
16 | """
17 |
18 | import logging
19 | import collections
20 |
21 | import torch
22 | import torch.nn as nn
23 | import torch.autograd
24 |
25 |
26 | class ConfigChannels(object):
27 | def __init__(self, config, state_dict=None, channels=3):
28 | self.config = config
29 | self.state_dict = state_dict
30 | self.channels = channels
31 |
32 | def __call__(self, default, name, fn=lambda var: var.size(0)):
33 | if self.state_dict is None:
34 | self.channels = default
35 | else:
36 | var = self.state_dict[name]
37 | self.channels = fn(var)
38 | if self.channels != default:
39 | logging.warning('%s: change number of output channels from %d to %d' % (name, default, self.channels))
40 | return self.channels
41 |
42 |
43 | def channel_dict(num_parts, num_limbs):
44 | return collections.OrderedDict([
45 | ('parts', num_parts + 1),
46 | ('limbs', num_limbs * 2),
47 | ])
48 |
49 |
50 | class Inference(nn.Module):
51 | def __init__(self, config, dnn, stages):
52 | nn.Module.__init__(self)
53 | self.config = config
54 | self.dnn = dnn
55 | self.stages = stages
56 |
57 | def forward(self, x):
58 | x = self.dnn(x)
59 | outputs = []
60 | output = {}
61 | for stage in self.stages:
62 | output = stage(x, **output)
63 | outputs.append(output)
64 | return outputs
65 |
66 |
67 | class Loss(object):
68 | def __init__(self, config, data, limbs_index, height, width):
69 | self.config = config
70 | self.data = data
71 | self.limbs_index = limbs_index
72 | self.height = height
73 | self.width = width
74 |
75 | def __call__(self, **kwargs):
76 | mask = self.data['mask'].float()
77 | batch_size, rows, cols = mask.size()
78 | mask = mask.view(batch_size, 1, rows, cols)
79 | data = {name: self.data[name] for name in kwargs}
80 | return {name: self.loss(mask, data[name], feature) for name, feature in kwargs.items()}
81 |
82 | def loss(self, mask, label, feature):
83 | return torch.mean(mask * (feature - label) ** 2)
84 |
--------------------------------------------------------------------------------
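
channel_dict above fixes the width of each stage's two output maps: 'parts' gets num_parts + 1 channels (the extra channel is consistent with a background map, which estimate.py discards via parts[:-1] before peak finding) and 'limbs' gets two channels per limb for the x/y part-affinity vectors. Under the person18_19 configuration (18 parts, 19 limbs) the arithmetic works out as below; this repeats only what the file itself defines.

    import collections

    def channel_dict(num_parts, num_limbs):
        return collections.OrderedDict([('parts', num_parts + 1), ('limbs', num_limbs * 2)])

    print(channel_dict(18, 19))  # OrderedDict([('parts', 19), ('limbs', 38)])
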
/model/dnn/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruiminshen/openpose-pytorch/f850084194ddccc6d401d5b11f61facc20ec2b75/model/dnn/__init__.py
--------------------------------------------------------------------------------
/model/dnn/inception4.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see <http://www.gnu.org/licenses/>.
16 | """
17 |
18 | import logging
19 | import configparser
20 | import collections.abc
21 |
22 | import torch
23 | import torch.nn as nn
24 | from pretrainedmodels.models.inceptionv4 import pretrained_settings
25 |
26 |
27 | class Conv2d(nn.Module):
28 | def __init__(self, in_channels, out_channels, kernel_size, padding=0, stride=1, bn=True, act=True):
29 | nn.Module.__init__(self)
30 | if isinstance(padding, bool):
31 | if isinstance(kernel_size, collections.abc.Iterable):
32 | padding = [(kernel_size - 1) // 2 for kernel_size in kernel_size] if padding else 0
33 | else:
34 | padding = (kernel_size - 1) // 2 if padding else 0
35 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding=padding, bias=not bn)
36 | self.bn = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.1, affine=True) if bn else lambda x: x
37 | self.act = nn.ReLU(inplace=True) if act else lambda x: x
38 |
39 | def forward(self, x):
40 | x = self.conv(x)
41 | x = self.bn(x)
42 | x = self.act(x)
43 | return x
44 |
45 |
46 | class Mixed_3a(nn.Module):
47 | def __init__(self, config_channels, prefix, bn=True, ratio=1):
48 | nn.Module.__init__(self)
49 | channels = config_channels.channels
50 | self.maxpool = nn.MaxPool2d(3, stride=2)
51 | self.conv = Conv2d(config_channels.channels, config_channels(int(96 * ratio), '%s.conv.conv.weight' % prefix), kernel_size=3, stride=2, bn=bn)
52 | config_channels.channels = channels + self.conv.conv.weight.size(0)
53 |
54 | def forward(self, x):
55 | x0 = self.maxpool(x)
56 | x1 = self.conv(x)
57 | out = torch.cat((x0, x1), 1)
58 | return out
59 |
60 |
61 | class Mixed_4a(nn.Module):
62 | def __init__(self, config_channels, prefix, bn=True, ratio=1):
63 | nn.Module.__init__(self)
64 | # branch0
65 | channels = config_channels.channels
66 | branch = []
67 | branch.append(Conv2d(config_channels.channels, config_channels(int(64 * ratio), '%s.branch0.%d.conv.weight' % (prefix, len(branch))), kernel_size=1, stride=1, bn=bn))
68 | branch.append(Conv2d(config_channels.channels, config_channels(int(96 * ratio), '%s.branch0.%d.conv.weight' % (prefix, len(branch))), kernel_size=3, stride=1, bn=bn))
69 | self.branch0 = nn.Sequential(*branch)
70 | # branch1
71 | config_channels.channels = channels
72 | branch = []
73 | branch.append(Conv2d(config_channels.channels, config_channels(int(64 * ratio), '%s.branch1.%d.conv.weight' % (prefix, len(branch))), kernel_size=1, stride=1, bn=bn))
74 | branch.append(Conv2d(config_channels.channels, config_channels(int(64 * ratio), '%s.branch1.%d.conv.weight' % (prefix, len(branch))), kernel_size=(1, 7), stride=1, padding=(0, 3), bn=bn))
75 | branch.append(Conv2d(config_channels.channels, config_channels(int(64 * ratio), '%s.branch1.%d.conv.weight' % (prefix, len(branch))), kernel_size=(7, 1), stride=1, padding=(3, 0), bn=bn))
76 | branch.append(Conv2d(config_channels.channels, config_channels(int(96 * ratio), '%s.branch1.%d.conv.weight' % (prefix, len(branch))), kernel_size=(3, 3), stride=1, bn=bn))
77 | self.branch1 = nn.Sequential(*branch)
78 | # output
79 | config_channels.channels = self.branch0[-1].conv.weight.size(0) + self.branch1[-1].conv.weight.size(0)
80 |
81 | def forward(self, x):
82 | x0 = self.branch0(x)
83 | x1 = self.branch1(x)
84 | out = torch.cat((x0, x1), 1)
85 | return out
86 |
87 |
88 | class Mixed_5a(nn.Module):
89 | def __init__(self, config_channels, prefix, bn=True, ratio=1):
90 | nn.Module.__init__(self)
91 | channels = config_channels.channels
92 | self.conv = Conv2d(config_channels.channels, config_channels(int(192 * ratio), '%s.conv.conv.weight' % prefix), kernel_size=3, stride=2, bn=bn)
93 | self.maxpool = nn.MaxPool2d(3, stride=2)
94 | config_channels.channels = self.conv.conv.weight.size(0) + channels
95 |
96 | def forward(self, x):
97 | x0 = self.conv(x)
98 | x1 = self.maxpool(x)
99 | out = torch.cat((x0, x1), 1)
100 | return out
101 |
102 |
103 | class Inception_A(nn.Module):
104 | def __init__(self, config_channels, prefix, bn=True, ratio=1):
105 | nn.Module.__init__(self)
106 | channels = config_channels.channels
107 | self.branch0 = Conv2d(config_channels.channels, config_channels(int(96 * ratio), '%s.branch0.conv.weight' % prefix), kernel_size=1, stride=1, bn=bn)
108 | # branch1
109 | config_channels.channels = channels
110 | branch = []
111 | branch.append(Conv2d(config_channels.channels, config_channels(int(64 * ratio), '%s.branch1.%d.conv.weight' % (prefix, len(branch))), kernel_size=1, stride=1, bn=bn))
112 | branch.append(Conv2d(config_channels.channels, config_channels(int(96 * ratio), '%s.branch1.%d.conv.weight' % (prefix, len(branch))), kernel_size=3, stride=1, padding=1, bn=bn))
113 | self.branch1 = nn.Sequential(*branch)
114 | # branch2
115 | config_channels.channels = channels
116 | branch = []
117 | branch.append(Conv2d(config_channels.channels, config_channels(int(64 * ratio), '%s.branch2.%d.conv.weight' % (prefix, len(branch))), kernel_size=1, stride=1, bn=bn))
118 | branch.append(Conv2d(config_channels.channels, config_channels(int(96 * ratio), '%s.branch2.%d.conv.weight' % (prefix, len(branch))), kernel_size=3, stride=1, padding=1, bn=bn))
119 | branch.append(Conv2d(config_channels.channels, config_channels(int(96 * ratio), '%s.branch2.%d.conv.weight' % (prefix, len(branch))), kernel_size=3, stride=1, padding=1, bn=bn))
120 | self.branch2 = nn.Sequential(*branch)
121 | #branch3
122 | config_channels.channels = channels
123 | branch = []
124 | branch.append(nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False))
125 | branch.append(Conv2d(config_channels.channels, config_channels(int(96 * ratio), '%s.branch3.%d.conv.weight' % (prefix, len(branch))), kernel_size=1, stride=1, bn=bn))
126 | self.branch3 = nn.Sequential(*branch)
127 | # output
128 | config_channels.channels = self.branch0.conv.weight.size(0) + self.branch1[-1].conv.weight.size(0) + self.branch2[-1].conv.weight.size(0) + self.branch3[-1].conv.weight.size(0)
129 |
130 | def forward(self, x):
131 | x0 = self.branch0(x)
132 | x1 = self.branch1(x)
133 | x2 = self.branch2(x)
134 | x3 = self.branch3(x)
135 | out = torch.cat((x0, x1, x2, x3), 1)
136 | return out
137 |
138 |
139 | class Reduction_A(nn.Module):
140 | def __init__(self, config_channels, prefix, bn=True, ratio=1):
141 | nn.Module.__init__(self)
142 | channels = config_channels.channels
143 | self.branch0 = Conv2d(config_channels.channels, config_channels(int(384 * ratio), '%s.branch0.conv.weight' % prefix), kernel_size=3, stride=2, bn=bn)
144 | # branch1
145 | config_channels.channels = channels
146 | branch = []
147 | branch.append(Conv2d(config_channels.channels, config_channels(int(192 * ratio), '%s.branch1.%d.conv.weight' % (prefix, len(branch))), kernel_size=1, stride=1, bn=bn))
148 | branch.append(Conv2d(config_channels.channels, config_channels(int(224 * ratio), '%s.branch1.%d.conv.weight' % (prefix, len(branch))), kernel_size=3, stride=1, padding=1, bn=bn))
149 | branch.append(Conv2d(config_channels.channels, config_channels(int(256 * ratio), '%s.branch1.%d.conv.weight' % (prefix, len(branch))), kernel_size=3, stride=2, bn=bn))
150 | self.branch1 = nn.Sequential(*branch)
151 |
152 | self.branch2 = nn.MaxPool2d(3, stride=2)
153 | # output
154 | config_channels.channels = self.branch0.conv.weight.size(0) + self.branch1[-1].conv.weight.size(0) + channels
155 |
156 | def forward(self, x):
157 | x0 = self.branch0(x)
158 | x1 = self.branch1(x)
159 | x2 = self.branch2(x)
160 | out = torch.cat((x0, x1, x2), 1)
161 | return out
162 |
163 |
164 | class Inception_B(nn.Module):
165 | def __init__(self, config_channels, prefix, bn=True, ratio=1):
166 | nn.Module.__init__(self)
167 | channels = config_channels.channels
168 | self.branch0 = Conv2d(config_channels.channels, config_channels(int(384 * ratio), '%s.branch0.conv.weight' % prefix), kernel_size=1, stride=1, bn=bn)
169 | # branch1
170 | config_channels.channels = channels
171 | branch = []
172 | branch.append(Conv2d(config_channels.channels, config_channels(int(192 * ratio), '%s.branch1.%d.conv.weight' % (prefix, len(branch))), kernel_size=1, stride=1, bn=bn))
173 | branch.append(Conv2d(config_channels.channels, config_channels(int(224 * ratio), '%s.branch1.%d.conv.weight' % (prefix, len(branch))), kernel_size=(1, 7), stride=1, padding=(0, 3), bn=bn))
174 | branch.append(Conv2d(config_channels.channels, config_channels(int(256 * ratio), '%s.branch1.%d.conv.weight' % (prefix, len(branch))), kernel_size=(7, 1), stride=1, padding=(3, 0), bn=bn))
175 | self.branch1 = nn.Sequential(*branch)
176 | # branch2
177 | config_channels.channels = channels
178 | branch = []
179 | branch.append(Conv2d(config_channels.channels, config_channels(int(192 * ratio), '%s.branch2.%d.conv.weight' % (prefix, len(branch))), kernel_size=1, stride=1, bn=bn))
180 | branch.append(Conv2d(config_channels.channels, config_channels(int(192 * ratio), '%s.branch2.%d.conv.weight' % (prefix, len(branch))), kernel_size=(7, 1), stride=1, padding=(3, 0), bn=bn))
181 | branch.append(Conv2d(config_channels.channels, config_channels(int(224 * ratio), '%s.branch2.%d.conv.weight' % (prefix, len(branch))), kernel_size=(1, 7), stride=1, padding=(0, 3), bn=bn))
182 | branch.append(Conv2d(config_channels.channels, config_channels(int(224 * ratio), '%s.branch2.%d.conv.weight' % (prefix, len(branch))), kernel_size=(7, 1), stride=1, padding=(3, 0), bn=bn))
183 | branch.append(Conv2d(config_channels.channels, config_channels(int(256 * ratio), '%s.branch2.%d.conv.weight' % (prefix, len(branch))), kernel_size=(1, 7), stride=1, padding=(0, 3), bn=bn))
184 | self.branch2 = nn.Sequential(*branch)
185 | # branch3
186 | config_channels.channels = channels
187 | branch = []
188 | branch.append(nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False))
189 | branch.append(Conv2d(config_channels.channels, config_channels(int(128 * ratio), '%s.branch3.%d.conv.weight' % (prefix, len(branch))), kernel_size=1, stride=1, bn=bn))
190 | self.branch3 = nn.Sequential(*branch)
191 | # output
192 | config_channels.channels = self.branch0.conv.weight.size(0) + self.branch1[-1].conv.weight.size(0) + self.branch2[-1].conv.weight.size(0) + self.branch3[-1].conv.weight.size(0)
193 |
194 | def forward(self, x):
195 | x0 = self.branch0(x)
196 | x1 = self.branch1(x)
197 | x2 = self.branch2(x)
198 | x3 = self.branch3(x)
199 | out = torch.cat((x0, x1, x2, x3), 1)
200 | return out
201 |
202 |
203 | class Reduction_B(nn.Module):
204 | def __init__(self, config_channels, prefix, bn=True, ratio=1):
205 | nn.Module.__init__(self)
206 | # branch0
207 | channels = config_channels.channels
208 | branch = []
209 | branch.append(Conv2d(config_channels.channels, config_channels(int(192 * ratio), '%s.branch0.%d.conv.weight' % (prefix, len(branch))), kernel_size=1, stride=1, bn=bn))
210 | branch.append(Conv2d(config_channels.channels, config_channels(int(192 * ratio), '%s.branch0.%d.conv.weight' % (prefix, len(branch))), kernel_size=3, stride=2, bn=bn))
211 | self.branch0 = nn.Sequential(*branch)
212 | # branch1
213 | config_channels.channels = channels
214 | branch = []
215 | branch.append(Conv2d(config_channels.channels, config_channels(int(256 * ratio), '%s.branch1.%d.conv.weight' % (prefix, len(branch))), kernel_size=1, stride=1, bn=bn))
216 | branch.append(Conv2d(config_channels.channels, config_channels(int(256 * ratio), '%s.branch1.%d.conv.weight' % (prefix, len(branch))), kernel_size=(1, 7), stride=1, padding=(0, 3), bn=bn))
217 | branch.append(Conv2d(config_channels.channels, config_channels(int(320 * ratio), '%s.branch1.%d.conv.weight' % (prefix, len(branch))), kernel_size=(7, 1), stride=1, padding=(3, 0), bn=bn))
218 | branch.append(Conv2d(config_channels.channels, config_channels(int(320 * ratio), '%s.branch1.%d.conv.weight' % (prefix, len(branch))), kernel_size=3, stride=2, bn=bn))
219 | self.branch1 = nn.Sequential(*branch)
220 | self.branch2 = nn.MaxPool2d(3, stride=2)
221 | # output
222 | config_channels.channels = self.branch0[-1].conv.weight.size(0) + self.branch1[-1].conv.weight.size(0) + channels
223 |
224 | def forward(self, x):
225 | x0 = self.branch0(x)
226 | x1 = self.branch1(x)
227 | x2 = self.branch2(x)
228 | out = torch.cat((x0, x1, x2), 1)
229 | return out
230 |
231 |
232 | class Inception_C(nn.Module):
233 | def __init__(self, config_channels, prefix, bn=True, ratio=1):
234 | nn.Module.__init__(self)
235 | channels = config_channels.channels
236 | self.branch0 = Conv2d(config_channels.channels, config_channels(int(256 * ratio), '%s.branch0.conv.weight' % prefix), kernel_size=1, stride=1, bn=bn)
237 | # branch1
238 | config_channels.channels = channels
239 | self.branch1_0 = Conv2d(config_channels.channels, config_channels(int(384 * ratio), '%s.branch1_0.conv.weight' % prefix), kernel_size=1, stride=1, bn=bn)
240 | _channels = config_channels.channels
241 | self.branch1_1a = Conv2d(_channels, config_channels(int(256 * ratio), '%s.branch1_1a.conv.weight' % prefix), kernel_size=(1, 3), stride=1, padding=(0, 1), bn=bn)
242 | self.branch1_1b = Conv2d(_channels, config_channels(int(256 * ratio), '%s.branch1_1b.conv.weight' % prefix), kernel_size=(3, 1), stride=1, padding=(1, 0), bn=bn)
243 | # branch2
244 | config_channels.channels = channels
245 | self.branch2_0 = Conv2d(config_channels.channels, config_channels(int(384 * ratio), '%s.branch2_0.conv.weight' % prefix), kernel_size=1, stride=1, bn=bn)
246 | self.branch2_1 = Conv2d(config_channels.channels, config_channels(int(448 * ratio), '%s.branch2_1.conv.weight' % prefix), kernel_size=(3, 1), stride=1, padding=(1, 0), bn=bn)
247 | self.branch2_2 = Conv2d(config_channels.channels, config_channels(int(512 * ratio), '%s.branch2_2.conv.weight' % prefix), kernel_size=(1, 3), stride=1, padding=(0, 1), bn=bn)
248 | _channels = config_channels.channels
249 | self.branch2_3a = Conv2d(_channels, config_channels(int(256 * ratio), '%s.branch2_3a.conv.weight' % prefix), kernel_size=(1, 3), stride=1, padding=(0, 1), bn=bn)
250 | self.branch2_3b = Conv2d(_channels, config_channels(int(256 * ratio), '%s.branch2_3b.conv.weight' % prefix), kernel_size=(3, 1), stride=1, padding=(1, 0), bn=bn)
251 | # branch3
252 | config_channels.channels = channels
253 | branch = []
254 | branch.append(nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False))
255 | branch.append(Conv2d(config_channels.channels, config_channels(int(256 * ratio), '%s.branch3.%d.conv.weight' % (prefix, len(branch))), kernel_size=1, stride=1, bn=bn))
256 | self.branch3 = nn.Sequential(*branch)
257 | # output
258 | config_channels.channels = self.branch0.conv.weight.size(0) + self.branch1_1a.conv.weight.size(0) + self.branch1_1b.conv.weight.size(0) + self.branch2_3a.conv.weight.size(0) + self.branch2_3b.conv.weight.size(0) + self.branch3[-1].conv.weight.size(0)
259 |
260 | def forward(self, x):
261 | x0 = self.branch0(x)
262 |
263 | x1_0 = self.branch1_0(x)
264 | x1_1a = self.branch1_1a(x1_0)
265 | x1_1b = self.branch1_1b(x1_0)
266 | x1 = torch.cat((x1_1a, x1_1b), 1)
267 |
268 | x2_0 = self.branch2_0(x)
269 | x2_1 = self.branch2_1(x2_0)
270 | x2_2 = self.branch2_2(x2_1)
271 | x2_3a = self.branch2_3a(x2_2)
272 | x2_3b = self.branch2_3b(x2_2)
273 | x2 = torch.cat((x2_3a, x2_3b), 1)
274 |
275 | x3 = self.branch3(x)
276 |
277 | out = torch.cat((x0, x1, x2, x3), 1)
278 | return out
279 |
280 |
281 | class Inception4(nn.Module):
282 | def __init__(self, config_channels, ratio=1):
283 | nn.Module.__init__(self)
284 | features = []
285 | bn = config_channels.config.getboolean('batch_norm', 'enable')
286 | features.append(Conv2d(config_channels.channels, config_channels(32, 'features.%d.conv.weight' % len(features)), kernel_size=3, stride=2, bn=bn))
287 | features.append(Conv2d(config_channels.channels, config_channels(32, 'features.%d.conv.weight' % len(features)), kernel_size=3, stride=1, bn=bn))
288 | features.append(Conv2d(config_channels.channels, config_channels(64, 'features.%d.conv.weight' % len(features)), kernel_size=3, stride=1, padding=1, bn=bn))
289 | features.append(Mixed_3a(config_channels, 'features.%d' % len(features), bn=bn, ratio=ratio))
290 | features.append(Mixed_4a(config_channels, 'features.%d' % len(features), bn=bn, ratio=ratio))
291 | features.append(Mixed_5a(config_channels, 'features.%d' % len(features), bn=bn, ratio=ratio))
292 | features.append(Inception_A(config_channels, 'features.%d' % len(features), bn=bn, ratio=ratio))
293 | features.append(Inception_A(config_channels, 'features.%d' % len(features), bn=bn, ratio=ratio))
294 | features.append(Inception_A(config_channels, 'features.%d' % len(features), bn=bn, ratio=ratio))
295 | features.append(Inception_A(config_channels, 'features.%d' % len(features), bn=bn, ratio=ratio))
296 | features.append(Reduction_A(config_channels, 'features.%d' % len(features), bn=bn, ratio=ratio)) # Mixed_6a
297 | features.append(Inception_B(config_channels, 'features.%d' % len(features), bn=bn, ratio=ratio))
298 | features.append(Inception_B(config_channels, 'features.%d' % len(features), bn=bn, ratio=ratio))
299 | features.append(Inception_B(config_channels, 'features.%d' % len(features), bn=bn, ratio=ratio))
300 | features.append(Inception_B(config_channels, 'features.%d' % len(features), bn=bn, ratio=ratio))
301 | features.append(Inception_B(config_channels, 'features.%d' % len(features), bn=bn, ratio=ratio))
302 | features.append(Inception_B(config_channels, 'features.%d' % len(features), bn=bn, ratio=ratio))
303 | features.append(Inception_B(config_channels, 'features.%d' % len(features), bn=bn, ratio=ratio))
304 | features.append(Reduction_B(config_channels, 'features.%d' % len(features), bn=bn, ratio=ratio)) # Mixed_7a
305 | features.append(Inception_C(config_channels, 'features.%d' % len(features), bn=bn, ratio=ratio))
306 | features.append(Inception_C(config_channels, 'features.%d' % len(features), bn=bn, ratio=ratio))
307 | features.append(Inception_C(config_channels, 'features.%d' % len(features), bn=bn, ratio=ratio))
308 | self.features = nn.Sequential(*features)
309 | self.init(config_channels)
310 |
311 | def init(self, config_channels):
312 | try:
313 | gamma = config_channels.config.getboolean('batch_norm', 'gamma')
314 | except (configparser.NoSectionError, configparser.NoOptionError):
315 | gamma = True
316 | try:
317 | beta = config_channels.config.getboolean('batch_norm', 'beta')
318 | except (configparser.NoSectionError, configparser.NoOptionError):
319 | beta = True
320 | for m in self.modules():
321 | if isinstance(m, nn.Conv2d):
322 | m.weight = nn.init.kaiming_normal_(m.weight)
323 | elif isinstance(m, nn.BatchNorm2d):
324 | m.weight.data.fill_(1)
325 | m.bias.data.zero_()
326 | m.weight.requires_grad = gamma
327 | m.bias.requires_grad = beta
328 | try:
329 | if config_channels.config.getboolean('model', 'pretrained'):
330 | settings = pretrained_settings['inceptionv4'][config_channels.config.get('inception4', 'pretrained')]
331 | logging.info('use pretrained model: ' + str(settings))
332 | state_dict = self.state_dict()
333 | for key, value in torch.utils.model_zoo.load_url(settings['url']).items():
334 | if key in state_dict:
335 | state_dict[key] = value
336 | self.load_state_dict(state_dict)
337 | except (configparser.NoSectionError, configparser.NoOptionError):
338 | pass
339 |
340 | def forward(self, x):
341 | return self.features(x)
342 |
343 | def scope(self, name):
344 | return '.'.join(name.split('.')[:-2])
345 |
346 |
347 | class Inception4_down3_4(Inception4):
348 | def __init__(self, config_channels, ratio=1 / 4):
349 | nn.Module.__init__(self)
350 | features = []
351 | bn = config_channels.config.getboolean('batch_norm', 'enable')
352 | features.append(Conv2d(config_channels.channels, config_channels(32, 'features.%d.conv.weight' % len(features)), kernel_size=3, stride=2, bn=bn))
353 | features.append(Conv2d(config_channels.channels, config_channels(32, 'features.%d.conv.weight' % len(features)), kernel_size=3, stride=1, bn=bn))
354 | features.append(Conv2d(config_channels.channels, config_channels(64, 'features.%d.conv.weight' % len(features)), kernel_size=3, stride=1, padding=1, bn=bn))
355 | features.append(Mixed_3a(config_channels, 'features.%d' % len(features), bn=bn, ratio=ratio))
356 | features.append(Mixed_4a(config_channels, 'features.%d' % len(features), bn=bn, ratio=ratio))
357 | features.append(Mixed_5a(config_channels, 'features.%d' % len(features), bn=bn, ratio=ratio))
358 | features.append(Inception_A(config_channels, 'features.%d' % len(features), bn=bn, ratio=ratio))
359 | features.append(Inception_A(config_channels, 'features.%d' % len(features), bn=bn, ratio=ratio))
360 | features.append(Inception_A(config_channels, 'features.%d' % len(features), bn=bn, ratio=ratio))
361 | features.append(Inception_A(config_channels, 'features.%d' % len(features), bn=bn, ratio=ratio))
362 | self.features = nn.Sequential(*features)
363 | self.init(config_channels)
364 |
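
The backbone modules in this file never hard-code their widths: every Conv2d asks `config_channels` for its output channels and reads the running count back afterwards. A minimal sketch of that contract, using a stand-in for `model.ConfigChannels` (the real class is assumed to live in model/__init__.py and may additionally derive widths from a stored state_dict); the construction lines are left commented because they depend on the rest of this repository:

import configparser

class ConfigChannelsStub:
    """Mimics only the interface used above: .config, .channels and __call__()."""
    def __init__(self, config, channels=3):
        self.config = config
        self.channels = channels  # running channel count seen by the next layer
    def __call__(self, channels, name, fn=lambda var: var.size(0)):
        # the real implementation may look `name` up in a checkpoint and use fn()
        # to recover a (possibly pruned) width; the stub simply accepts the request
        self.channels = channels
        return channels

config = configparser.ConfigParser()
config.read_dict({'batch_norm': {'enable': 'true'}})
# config_channels = ConfigChannelsStub(config)
# dnn = Inception4_down3_4(config_channels)   # builds the truncated backbone above
# print(config_channels.channels)             # width handed to the following stages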
--------------------------------------------------------------------------------
/model/dnn/mobilenet.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see <http://www.gnu.org/licenses/>.
16 | """
17 |
18 | import collections
19 |
20 | import torch.nn as nn
21 |
22 | import model
23 |
24 |
25 | def conv_bn(in_channels, out_channels, stride):
26 | return nn.Sequential(collections.OrderedDict([
27 | ('conv', nn.Conv2d(in_channels, out_channels, 3, stride, 1, bias=False)),
28 | ('bn', nn.BatchNorm2d(out_channels)),
29 | ('act', nn.ReLU(inplace=True)),
30 | ]))
31 |
32 |
33 | def conv_dw(in_channels, stride):
34 | return nn.Sequential(collections.OrderedDict([
35 | ('conv', nn.Conv2d(in_channels, in_channels, 3, stride, 1, groups=in_channels, bias=False)),
36 | ('bn', nn.BatchNorm2d(in_channels)),
37 | ('act', nn.ReLU(inplace=True)),
38 | ]))
39 |
40 |
41 | def conv_pw(in_channels, out_channels):
42 | return nn.Sequential(collections.OrderedDict([
43 | ('conv', nn.Conv2d(in_channels, out_channels, 1, 1, 0, bias=False)),
44 | ('bn', nn.BatchNorm2d(out_channels)),
45 | ('act', nn.ReLU(inplace=True)),
46 | ]))
47 |
48 |
49 | def conv_unit(in_channels, out_channels, stride):
50 | return nn.Sequential(collections.OrderedDict([
51 | ('dw', conv_dw(in_channels, stride)),
52 | ('pw', conv_pw(in_channels, out_channels)),
53 | ]))
54 |
55 |
56 | class MobileNet(nn.Module):
57 | def __init__(self, config_channels):
58 | nn.Module.__init__(self)
59 | layers = []
60 | layers.append(conv_bn(config_channels.channels, config_channels(32, 'layers.%d.conv.weight' % len(layers)), 2))
61 | layers.append(conv_unit(config_channels.channels, config_channels(64, 'layers.%d.pw.conv.weight' % len(layers)), 1))
62 | layers.append(conv_unit(config_channels.channels, config_channels(128, 'layers.%d.pw.conv.weight' % len(layers)), 2))
63 | layers.append(conv_unit(config_channels.channels, config_channels(128, 'layers.%d.pw.conv.weight' % len(layers)), 1))
64 | layers.append(conv_unit(config_channels.channels, config_channels(256, 'layers.%d.pw.conv.weight' % len(layers)), 2))
65 | layers.append(conv_unit(config_channels.channels, config_channels(256, 'layers.%d.pw.conv.weight' % len(layers)), 1))
66 | layers.append(conv_unit(config_channels.channels, config_channels(512, 'layers.%d.pw.conv.weight' % len(layers)), 2))
67 | layers.append(conv_unit(config_channels.channels, config_channels(512, 'layers.%d.pw.conv.weight' % len(layers)), 1))
68 | layers.append(conv_unit(config_channels.channels, config_channels(512, 'layers.%d.pw.conv.weight' % len(layers)), 1))
69 | layers.append(conv_unit(config_channels.channels, config_channels(512, 'layers.%d.pw.conv.weight' % len(layers)), 1))
70 | layers.append(conv_unit(config_channels.channels, config_channels(512, 'layers.%d.pw.conv.weight' % len(layers)), 1))
71 | layers.append(conv_unit(config_channels.channels, config_channels(512, 'layers.%d.pw.conv.weight' % len(layers)), 1))
72 | layers.append(conv_unit(config_channels.channels, config_channels(1024, 'layers.%d.pw.conv.weight' % len(layers)), 2))
73 | layers.append(conv_unit(config_channels.channels, config_channels(1024, 'layers.%d.pw.conv.weight' % len(layers)), 1))
74 | self.layers = nn.Sequential(*layers)
75 |
76 | for m in self.modules():
77 | if isinstance(m, nn.Conv2d):
78 | m.weight = nn.init.kaiming_normal_(m.weight)
79 | elif isinstance(m, nn.BatchNorm2d):
80 | m.weight.data.fill_(1)
81 | m.bias.data.zero_()
82 |
83 | def forward(self, x):
84 | return self.layers(x)
85 |
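
Each `conv_unit` above is a depthwise-separable block: a per-channel 3x3 (`conv_dw`) followed by a 1x1 pointwise mix (`conv_pw`). A small sketch comparing its parameter count with a dense 3x3 convolution, assuming the helper functions defined above are in scope; the channel sizes are illustrative:

import torch
import torch.nn as nn

dense = nn.Conv2d(128, 256, 3, padding=1, bias=False)
separable = nn.Sequential(conv_dw(128, stride=1), conv_pw(128, 256))  # helpers defined above

count = lambda m: sum(p.numel() for p in m.parameters())
print(count(dense))      # 294912 weights
print(count(separable))  # about 34700, roughly 1/8 of the dense conv (includes the BN affine terms)

x = torch.randn(1, 128, 32, 32)
assert dense(x).shape == separable(x).shape  # both yield (1, 256, 32, 32)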
--------------------------------------------------------------------------------
/model/dnn/mobilenet2.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see <http://www.gnu.org/licenses/>.
16 | """
17 |
18 | import torch.nn as nn
19 | import math
20 |
21 |
22 | def conv_bn(inp, oup, stride, dilation=1):
23 | return nn.Sequential(
24 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False, dilation=dilation),
25 | nn.BatchNorm2d(oup),
26 | nn.ReLU(inplace=True)
27 | )
28 |
29 |
30 | def conv_1x1_bn(inp, oup):
31 | return nn.Sequential(
32 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
33 | nn.BatchNorm2d(oup),
34 | nn.ReLU(inplace=True)
35 | )
36 |
37 |
38 | class InvertedResidual(nn.Module):
39 | def __init__(self, inp, oup, stride, expand_ratio):
40 | super(InvertedResidual, self).__init__()
41 | self.stride = stride
42 | assert stride in [1, 2]
43 |
44 | self.use_res_connect = self.stride == 1 and inp == oup
45 |
46 | self.conv = nn.Sequential(
47 | # pw
48 | nn.Conv2d(inp, inp * expand_ratio, 1, 1, 0, bias=False),
49 | nn.BatchNorm2d(inp * expand_ratio),
50 | nn.ReLU(inplace=True),
51 | # dw
52 | nn.Conv2d(inp * expand_ratio, inp * expand_ratio, 3, stride, 1, groups=inp * expand_ratio, bias=False),
53 | nn.BatchNorm2d(inp * expand_ratio),
54 | nn.ReLU(inplace=True),
55 | # pw-linear
56 | nn.Conv2d(inp * expand_ratio, oup, 1, 1, 0, bias=False),
57 | nn.BatchNorm2d(oup),
58 | )
59 |
60 | def forward(self, x):
61 | if self.use_res_connect:
62 | return x + self.conv(x)
63 | else:
64 | return self.conv(x)
65 |
66 |
67 | class MobileNet2(nn.Module):
68 | def __init__(self, config_channels, input_size=224, last_channel=320, width_mult=1., dilation=1, ratio=1):
69 | nn.Module.__init__(self)
70 | # setting of inverted residual blocks
71 | self.interverted_residual_setting = [
72 | # t, c, n, s
73 | [1, int(16 * ratio), 1, 1],
74 | [6, int(24 * ratio), 2, 2],
75 | [6, int(32 * ratio), 3, 2],
76 | [6, int(64 * ratio), 4, 1], # stride 2->1
77 | [6, int(96 * ratio), 3, 1],
78 | [6, int(160 * ratio), 3, 1], # stride 2->1
79 | [6, int(320 * ratio), 1, 1],
80 | ]
81 |
82 | # building first layer
83 | assert input_size % 32 == 0
84 | input_channel = int(32 * width_mult)
85 | if last_channel is None:
86 | self.last_channel = int(1280 * width_mult) if width_mult > 1.0 else 1280
87 | else:
88 | self.last_channel = int(last_channel * ratio)
89 | self.features = [conv_bn(3, input_channel, 2)]
90 | # building inverted residual blocks
91 | for t, c, n, s in self.interverted_residual_setting:
92 | output_channel = int(c * width_mult)
93 | for i in range(n):
94 | if i == 0:
95 | self.features.append(InvertedResidual(input_channel, output_channel, s, t))
96 | else:
97 | self.features.append(InvertedResidual(input_channel, output_channel, 1, t))
98 | input_channel = output_channel
99 | # building last several layers
100 | self.features.append(conv_bn(input_channel, self.last_channel, 1, dilation=dilation))
101 | #self.features.append(nn.AvgPool2d(input_size/32))
102 | config_channels.channels = self.last_channel # temp
103 |
104 | # make it nn.Sequential
105 | self.features = nn.Sequential(*self.features)
106 |
107 | self._initialize_weights()
108 |
109 | def forward(self, x):
110 | return self.features(x)
111 |
112 | def _initialize_weights(self):
113 | for m in self.modules():
114 | if isinstance(m, nn.Conv2d):
115 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
116 | m.weight.data.normal_(0, math.sqrt(2. / n))
117 | if m.bias is not None:
118 | m.bias.data.zero_()
119 | elif isinstance(m, nn.BatchNorm2d):
120 | m.weight.data.fill_(1)
121 | m.bias.data.zero_()
122 | elif isinstance(m, nn.Linear):
123 | m.weight.data.normal_(0, 0.01)
124 | m.bias.data.zero_()
125 |
126 |
127 | class MobileNet2Dilate2(MobileNet2):
128 | def __init__(self, config_channels):
129 | MobileNet2.__init__(self, config_channels, dilation=2)
130 |
131 |
132 | class MobileNet2Dilate4(MobileNet2):
133 | def __init__(self, config_channels):
134 | MobileNet2.__init__(self, config_channels, dilation=4)
135 |
136 |
137 | class MobileNet2Half(MobileNet2):
138 | def __init__(self, config_channels):
139 | MobileNet2.__init__(self, config_channels, ratio=1 / 2)
140 |
141 |
142 | class MobileNet2Quarter(MobileNet2):
143 | def __init__(self, config_channels):
144 | MobileNet2.__init__(self, config_channels, ratio=1 / 4)
145 |
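
Each row of the inverted-residual table reads (t, c, n, s): expansion factor, output channels, number of repeated blocks, and the stride of the first block in the group; the `stride 2->1` comments mark where the usual MobileNetV2 strides were dropped so the backbone keeps a finer output resolution. A small sketch of how one such row unrolls into blocks (values are illustrative):

import torch

# One (t, c, n, s) row, e.g. [6, 32, 3, 2]: the first block applies the stride,
# and only stride-1 blocks with matching widths get the residual shortcut.
blocks = [InvertedResidual(24, 32, 2, 6),
          InvertedResidual(32, 32, 1, 6),   # use_res_connect is True here
          InvertedResidual(32, 32, 1, 6)]

x = torch.randn(1, 24, 56, 56)
for block in blocks:
    x = block(x)
print(x.shape)  # torch.Size([1, 32, 28, 28]) -- halved once by the stride-2 block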
--------------------------------------------------------------------------------
/model/dnn/resnet.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see <http://www.gnu.org/licenses/>.
16 | """
17 |
18 | import logging
19 | import re
20 |
21 | import torch.nn as nn
22 | import torch.utils.model_zoo as model_zoo
23 | import torchvision.models.resnet as _model
24 | from torchvision.models.resnet import conv3x3
25 |
26 | import model
27 |
28 |
29 | class BasicBlock(nn.Module):
30 | def __init__(self, config_channels, prefix, channels, stride=1):
31 | nn.Module.__init__(self)
32 | channels_in = config_channels.channels
33 | self.conv1 = conv3x3(config_channels.channels, config_channels(channels, '%s.conv1.weight' % prefix), stride)
34 | self.bn1 = nn.BatchNorm2d(config_channels.channels)
35 | self.relu = nn.ReLU(inplace=True)
36 | self.conv2 = conv3x3(config_channels.channels, config_channels(channels, '%s.conv2.weight' % prefix))
37 | self.bn2 = nn.BatchNorm2d(config_channels.channels)
38 | if stride > 1 or channels_in != config_channels.channels:
39 | downsample = []
40 | downsample.append(nn.Conv2d(channels_in, config_channels.channels, kernel_size=1, stride=stride, bias=False))
41 | downsample.append(nn.BatchNorm2d(config_channels.channels))
42 | self.downsample = nn.Sequential(*downsample)
43 | else:
44 | self.downsample = None
45 |
46 | def forward(self, x):
47 | residual = x
48 |
49 | out = self.conv1(x)
50 | out = self.bn1(out)
51 | out = self.relu(out)
52 |
53 | out = self.conv2(out)
54 | out = self.bn2(out)
55 |
56 | if self.downsample is not None:
57 | residual = self.downsample(x)
58 |
59 | out += residual
60 | out = self.relu(out)
61 |
62 | return out
63 |
64 |
65 | class Bottleneck(nn.Module):
66 | def __init__(self, config_channels, prefix, channels, stride=1):
67 | nn.Module.__init__(self)
68 | channels_in = config_channels.channels
69 | self.conv1 = nn.Conv2d(config_channels.channels, config_channels(channels, '%s.conv1.weight' % prefix), kernel_size=1, bias=False)
70 | self.bn1 = nn.BatchNorm2d(config_channels.channels)
71 | self.conv2 = nn.Conv2d(config_channels.channels, config_channels(channels, '%s.conv2.weight' % prefix), kernel_size=3, stride=stride, padding=1, bias=False)
72 | self.bn2 = nn.BatchNorm2d(config_channels.channels)
73 | self.conv3 = nn.Conv2d(config_channels.channels, config_channels(channels * 4, '%s.conv3.weight' % prefix), kernel_size=1, bias=False)
74 | self.bn3 = nn.BatchNorm2d(config_channels.channels)
75 | self.relu = nn.ReLU(inplace=True)
76 | if stride > 1 or channels_in != config_channels.channels:
77 | downsample = []
78 | downsample.append(nn.Conv2d(channels_in, config_channels.channels, kernel_size=1, stride=stride, bias=False))
79 | downsample.append(nn.BatchNorm2d(config_channels.channels))
80 | self.downsample = nn.Sequential(*downsample)
81 | else:
82 | self.downsample = None
83 |
84 | def forward(self, x):
85 | residual = x
86 |
87 | out = self.conv1(x)
88 | out = self.bn1(out)
89 | out = self.relu(out)
90 |
91 | out = self.conv2(out)
92 | out = self.bn2(out)
93 | out = self.relu(out)
94 |
95 | out = self.conv3(out)
96 | out = self.bn3(out)
97 |
98 | if self.downsample is not None:
99 | residual = self.downsample(x)
100 |
101 | out += residual
102 | out = self.relu(out)
103 |
104 | return out
105 |
106 |
107 | class ResNet(_model.ResNet):
108 |     def __init__(self, config_channels, block, layers):
109 | nn.Module.__init__(self)
110 | self.conv1 = nn.Conv2d(config_channels.channels, config_channels(64, 'conv1.weight'), kernel_size=7, stride=2, padding=3, bias=False)
111 | self.bn1 = nn.BatchNorm2d(config_channels.channels)
112 | self.relu = nn.ReLU(inplace=True)
113 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
114 | self.layer1 = self._make_layer(config_channels, 'layer1', block, 64, layers[0])
115 | self.layer2 = self._make_layer(config_channels, 'layer2', block, 128, layers[1], stride=2)
116 | self.layer3 = self._make_layer(config_channels, 'layer3', block, 256, layers[2], stride=2)
117 | self.layer4 = self._make_layer(config_channels, 'layer4', block, 512, layers[3], stride=2)
118 |
119 | for m in self.modules():
120 | if isinstance(m, nn.Conv2d):
121 | m.weight = nn.init.kaiming_normal_(m.weight)
122 | elif isinstance(m, nn.BatchNorm2d):
123 | m.weight.data.fill_(1)
124 | m.bias.data.zero_()
125 |
126 | def _make_layer(self, config_channels, prefix, block, channels, blocks, stride=1):
127 | layers = []
128 | layers.append(block(config_channels, '%s.%d' % (prefix, len(layers)), channels, stride))
129 | for i in range(1, blocks):
130 | layers.append(block(config_channels, '%s.%d' % (prefix, len(layers)), channels))
131 | return nn.Sequential(*layers)
132 |
133 | def forward(self, x):
134 | x = self.conv1(x)
135 | x = self.bn1(x)
136 | x = self.relu(x)
137 | x = self.maxpool(x)
138 |
139 | x = self.layer1(x)
140 | x = self.layer2(x)
141 | x = self.layer3(x)
142 | x = self.layer4(x)
143 |
144 | return x
145 |
146 | def scope(self, name):
147 | comp = name.split('.')[:-1]
148 | try:
149 | comp[-1] = re.search(r'[(conv)|(bn)](\d+)', comp[-1]).group(1)
150 | except AttributeError:
151 | if len(comp) > 1:
152 | if comp[-2] == 'downsample':
153 | comp = comp[:-1]
154 | else:
155 | assert False, name
156 | else:
157 | assert comp[-1] == 'conv', name
158 | return '.'.join(comp)
159 |
160 |
161 | def resnet18(config_channels, **kwargs):
162 | model = ResNet(config_channels, BasicBlock, [2, 2, 2, 2], **kwargs)
163 | if config_channels.config.getboolean('model', 'pretrained'):
164 | url = _model.model_urls['resnet18']
165 | logging.info('use pretrained model: ' + url)
166 | state_dict = model.state_dict()
167 | for key, value in model_zoo.load_url(url).items():
168 | if key in state_dict:
169 | state_dict[key] = value
170 | model.load_state_dict(state_dict)
171 | return model
172 |
173 |
174 | def resnet34(config_channels, **kwargs):
175 | model = ResNet(config_channels, BasicBlock, [3, 4, 6, 3], **kwargs)
176 | if config_channels.config.getboolean('model', 'pretrained'):
177 | url = _model.model_urls['resnet34']
178 | logging.info('use pretrained model: ' + url)
179 | state_dict = model.state_dict()
180 | for key, value in model_zoo.load_url(url).items():
181 | if key in state_dict:
182 | state_dict[key] = value
183 | model.load_state_dict(state_dict)
184 | return model
185 |
186 |
187 | def resnet50(config_channels, **kwargs):
188 | model = ResNet(config_channels, Bottleneck, [3, 4, 6, 3], **kwargs)
189 | if config_channels.config.getboolean('model', 'pretrained'):
190 | url = _model.model_urls['resnet50']
191 | logging.info('use pretrained model: ' + url)
192 | state_dict = model.state_dict()
193 | for key, value in model_zoo.load_url(url).items():
194 | if key in state_dict:
195 | state_dict[key] = value
196 | model.load_state_dict(state_dict)
197 | return model
198 |
199 |
200 | def resnet101(config_channels, **kwargs):
201 | model = ResNet(config_channels, Bottleneck, [3, 4, 23, 3], **kwargs)
202 | if config_channels.config.getboolean('model', 'pretrained'):
203 | url = _model.model_urls['resnet101']
204 | logging.info('use pretrained model: ' + url)
205 | state_dict = model.state_dict()
206 | for key, value in model_zoo.load_url(url).items():
207 | if key in state_dict:
208 | state_dict[key] = value
209 | model.load_state_dict(state_dict)
210 | return model
211 |
212 |
213 | def resnet152(config_channels, **kwargs):
214 | model = ResNet(config_channels, Bottleneck, [3, 8, 36, 3], **kwargs)
215 | if config_channels.config.getboolean('model', 'pretrained'):
216 | url = _model.model_urls['resnet152']
217 | logging.info('use pretrained model: ' + url)
218 | state_dict = model.state_dict()
219 | for key, value in model_zoo.load_url(url).items():
220 | if key in state_dict:
221 | state_dict[key] = value
222 | model.load_state_dict(state_dict)
223 | return model
224 |
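
The resnet* factories above (like the vgg* factories that follow) merge pretrained weights by copying only the checkpoint entries whose keys also exist in the freshly built model, so renamed or resized layers silently keep their random initialization. The same idiom in isolation, as a hypothetical helper (the shape guard is an addition for safety, not part of the original pattern):

def merge_matching(model, checkpoint_state):
    state_dict = model.state_dict()
    for key, value in checkpoint_state.items():
        if key in state_dict and state_dict[key].shape == value.shape:
            state_dict[key] = value  # everything else keeps its fresh initialization
    model.load_state_dict(state_dict)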
--------------------------------------------------------------------------------
/model/dnn/vgg.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see <http://www.gnu.org/licenses/>.
16 | """
17 |
18 | import logging
19 |
20 | import torch.nn as nn
21 | import torch.utils.model_zoo as model_zoo
22 | import torchvision.models.vgg as _model
23 | from torchvision.models.vgg import model_urls, cfg
24 |
25 | import model
26 |
27 |
28 | class VGG(_model.VGG):
29 | def __init__(self, config_channels, features):
30 | nn.Module.__init__(self)
31 | self.features = features
32 | self._initialize_weights()
33 |
34 | def forward(self, x):
35 | return self.features(x)
36 |
37 |
38 | def make_layers(config_channels, cfg, batch_norm=False):
39 | features = []
40 | for v in cfg:
41 | if v == 'M':
42 | features += [nn.MaxPool2d(kernel_size=2, stride=2)]
43 | else:
44 | conv2d = nn.Conv2d(config_channels.channels, config_channels(v, 'features.%d.weight' % len(features)), kernel_size=3, padding=1)
45 | if batch_norm:
46 | features += [conv2d, nn.BatchNorm2d(config_channels.channels), nn.ReLU(inplace=True)]
47 | else:
48 | features += [conv2d, nn.ReLU(inplace=True)]
49 | return nn.Sequential(*features)
50 |
51 |
52 | def vgg11(config_channels):
53 | model = VGG(config_channels, make_layers(config_channels, cfg['A']))
54 | if config_channels.config.getboolean('model', 'pretrained'):
55 | url = model_urls['vgg11']
56 | logging.info('use pretrained model: ' + url)
57 | state_dict = model.state_dict()
58 | for key, value in model_zoo.load_url(url).items():
59 | if key in state_dict:
60 | state_dict[key] = value
61 | model.load_state_dict(state_dict)
62 | return model
63 |
64 |
65 | def vgg11_bn(config_channels):
66 | model = VGG(config_channels, make_layers(config_channels, cfg['A'], batch_norm=True))
67 | if config_channels.config.getboolean('model', 'pretrained'):
68 | url = model_urls['vgg11_bn']
69 | logging.info('use pretrained model: ' + url)
70 | state_dict = model.state_dict()
71 | for key, value in model_zoo.load_url(url).items():
72 | if key in state_dict:
73 | state_dict[key] = value
74 | model.load_state_dict(state_dict)
75 | return model
76 |
77 |
78 | def vgg13(config_channels):
79 | model = VGG(config_channels, make_layers(config_channels, cfg['B']))
80 | if config_channels.config.getboolean('model', 'pretrained'):
81 | url = model_urls['vgg13']
82 | logging.info('use pretrained model: ' + url)
83 | state_dict = model.state_dict()
84 | for key, value in model_zoo.load_url(url).items():
85 | if key in state_dict:
86 | state_dict[key] = value
87 | model.load_state_dict(state_dict)
88 | return model
89 |
90 |
91 | def vgg13_bn(config_channels):
92 | model = VGG(config_channels, make_layers(config_channels, cfg['B'], batch_norm=True))
93 | if config_channels.config.getboolean('model', 'pretrained'):
94 | url = model_urls['vgg13_bn']
95 | logging.info('use pretrained model: ' + url)
96 | state_dict = model.state_dict()
97 | for key, value in model_zoo.load_url(url).items():
98 | if key in state_dict:
99 | state_dict[key] = value
100 | model.load_state_dict(state_dict)
101 | return model
102 |
103 |
104 | def vgg16(config_channels):
105 | model = VGG(config_channels, make_layers(config_channels, cfg['D']))
106 | if config_channels.config.getboolean('model', 'pretrained'):
107 | url = model_urls['vgg16']
108 | logging.info('use pretrained model: ' + url)
109 | state_dict = model.state_dict()
110 | for key, value in model_zoo.load_url(url).items():
111 | if key in state_dict:
112 | state_dict[key] = value
113 | model.load_state_dict(state_dict)
114 | return model
115 |
116 |
117 | def vgg16_bn(config_channels):
118 | model = VGG(config_channels, make_layers(config_channels, cfg['D'], batch_norm=True))
119 | if config_channels.config.getboolean('model', 'pretrained'):
120 | url = model_urls['vgg16_bn']
121 | logging.info('use pretrained model: ' + url)
122 | state_dict = model.state_dict()
123 | for key, value in model_zoo.load_url(url).items():
124 | if key in state_dict:
125 | state_dict[key] = value
126 | model.load_state_dict(state_dict)
127 | return model
128 |
129 |
130 | def vgg19(config_channels):
131 | model = VGG(config_channels, make_layers(config_channels, cfg['E']))
132 | if config_channels.config.getboolean('model', 'pretrained'):
133 | url = model_urls['vgg19']
134 | logging.info('use pretrained model: ' + url)
135 | state_dict = model.state_dict()
136 | for key, value in model_zoo.load_url(url).items():
137 | if key in state_dict:
138 | state_dict[key] = value
139 | model.load_state_dict(state_dict)
140 | return model
141 |
142 |
143 | def vgg19_bn(config_channels):
144 | model = VGG(config_channels, make_layers(config_channels, cfg['E'], batch_norm=True))
145 | if config_channels.config.getboolean('model', 'pretrained'):
146 | url = model_urls['vgg19_bn']
147 | logging.info('use pretrained model: ' + url)
148 | state_dict = model.state_dict()
149 | for key, value in model_zoo.load_url(url).items():
150 | if key in state_dict:
151 | state_dict[key] = value
152 | model.load_state_dict(state_dict)
153 | return model
154 |
155 |
156 | def person18_19(config_channels):
157 | cfg = [
158 | 64, 64, 'M',
159 | 128, 128, 'M',
160 | 256, 256, 256, 256, 'M',
161 | 512, 512,
162 | 256, 128,
163 | ]
164 | return VGG(config_channels, make_layers(config_channels, cfg))
165 |
166 |
167 | def hand21(config_channels):
168 | cfg = [
169 | 64, 64, 'M',
170 | 128, 128, 'M',
171 | 256, 256, 256, 256, 'M',
172 | 512, 512, 512, 512,
173 | 512, 512, 128,
174 | ]
175 | return VGG(config_channels, make_layers(config_channels, cfg))
176 |
--------------------------------------------------------------------------------
/model/stages/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruiminshen/openpose-pytorch/f850084194ddccc6d401d5b11f61facc20ec2b75/model/stages/__init__.py
--------------------------------------------------------------------------------
/model/stages/openpose.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see <http://www.gnu.org/licenses/>.
16 | """
17 |
18 | import collections.abc
19 |
20 | import torch
21 | import torch.nn as nn
22 |
23 |
24 | class Conv2d(nn.Module):
25 | def __init__(self, in_channels, out_channels, kernel_size, padding=True, stride=1, bn=False, act=True):
26 | nn.Module.__init__(self)
27 | if isinstance(padding, bool):
28 | if isinstance(kernel_size, collections.abc.Iterable):
29 | padding = tuple((kernel_size - 1) // 2 for kernel_size in kernel_size) if padding else 0
30 | else:
31 | padding = (kernel_size - 1) // 2 if padding else 0
32 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding=padding, bias=not bn)
33 | self.bn = nn.BatchNorm2d(out_channels, momentum=0.01) if bn else lambda x: x
34 | self.act = nn.ReLU(inplace=True) if act else lambda x: x
35 |
36 | def forward(self, x):
37 | x = self.conv(x)
38 | x = self.bn(x)
39 | x = self.act(x)
40 | return x
41 |
42 |
43 | class Stage0(nn.Module):
44 | def __init__(self, config_channels, channel_dict, channels_dnn, prefix):
45 | nn.Module.__init__(self)
46 | channels_stage = config_channels.channels
47 | for name, channels in channel_dict.items():
48 | config_channels.channels = channels_stage
49 | branch = []
50 | for _ in range(3):
51 | branch.append(Conv2d(config_channels.channels, config_channels(128, '%s.%s.%d.conv.weight' % (prefix, name, len(branch))), 3))
52 | branch.append(Conv2d(config_channels.channels, config_channels(512, '%s.%s.%d.conv.weight' % (prefix, name, len(branch))), 1))
53 | branch.append(Conv2d(config_channels.channels, channels, 1, act=False))
54 | setattr(self, name, nn.Sequential(*branch))
55 | config_channels.channels = channels_dnn + sum(branch[-1].conv.weight.size(0) for branch in self._modules.values())
56 | self.init()
57 |
58 | def init(self):
59 | for m in self.modules():
60 | if isinstance(m, nn.Conv2d):
61 | m.weight = nn.init.xavier_normal_(m.weight)
62 | elif isinstance(m, nn.BatchNorm2d):
63 | m.weight.data.fill_(1)
64 | m.bias.data.zero_()
65 |
66 | def forward(self, x, **kwargs):
67 | return {name: var(x) for name, var in self._modules.items()}
68 |
69 |
70 | class Stage(nn.Module):
71 | def __init__(self, config_channels, channels, channels_dnn, prefix):
72 | nn.Module.__init__(self)
73 | channels_stage = config_channels.channels
74 | for name, _channels in channels.items():
75 | config_channels.channels = channels_stage
76 | branch = []
77 | for _ in range(5):
78 | branch.append(Conv2d(config_channels.channels, config_channels(128, '%s.%s.%d.conv.weight' % (prefix, name, len(branch))), 7))
79 | branch.append(Conv2d(config_channels.channels, config_channels(128, '%s.%s.%d.conv.weight' % (prefix, name, len(branch))), 1))
80 | branch.append(Conv2d(config_channels.channels, _channels, 1, act=False))
81 | setattr(self, name, nn.Sequential(*branch))
82 | config_channels.channels = channels_dnn + sum(branch[-1].conv.weight.size(0) for branch in self._modules.values())
83 | self.init()
84 |
85 | def init(self):
86 | for m in self.modules():
87 | if isinstance(m, nn.Conv2d):
88 | m.weight = nn.init.xavier_normal_(m.weight)
89 | elif isinstance(m, nn.BatchNorm2d):
90 | m.weight.data.fill_(1)
91 | m.bias.data.zero_()
92 |
93 | def forward(self, x, **kwargs):
94 | x = torch.cat([kwargs[name] for name in ('limbs', 'parts')] + [x], 1)
95 | return {name: var(x) for name, var in self._modules.items()}
96 |
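
Stage0 produces a dict of maps keyed by name (typically 'limbs' for part-affinity fields and 'parts' for keypoint heatmaps) from the backbone features alone; every later Stage concatenates the previous stage's two maps with the backbone features before refining them. A sketch of the driving loop, as a hypothetical helper (the real model.Inference wrapper used elsewhere in this repository is assumed to do something equivalent):

def run_stages(dnn, stages, image):
    feature = dnn(image)                   # backbone feature map
    outputs, kwargs = [], {}
    for stage in stages:
        kwargs = stage(feature, **kwargs)  # Stage0 ignores kwargs; later stages concat them with the feature
        outputs.append(kwargs)             # each entry holds the 'limbs' and 'parts' tensors of that stage
    return outputs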
--------------------------------------------------------------------------------
/model/stages/unet.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see <http://www.gnu.org/licenses/>.
16 | """
17 |
18 | import collections.abc
19 |
20 | import torch
21 | import torch.nn as nn
22 |
23 |
24 | class Conv2d(nn.Module):
25 | def __init__(self, in_channels, out_channels, kernel_size, padding=True, stride=1, bn=False, act=True):
26 | nn.Module.__init__(self)
27 | if isinstance(padding, bool):
28 | if isinstance(kernel_size, collections.abc.Iterable):
29 | padding = tuple((kernel_size - 1) // 2 for kernel_size in kernel_size) if padding else 0
30 | else:
31 | padding = (kernel_size - 1) // 2 if padding else 0
32 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding=padding, bias=not bn)
33 | self.bn = nn.BatchNorm2d(out_channels, momentum=0.01) if bn else lambda x: x
34 | self.act = nn.ReLU(inplace=True) if act else lambda x: x
35 |
36 | def forward(self, x):
37 | x = self.conv(x)
38 | x = self.bn(x)
39 | x = self.act(x)
40 | return x
41 |
42 |
43 | class ConvTranspose2d(nn.Module):
44 | def __init__(self, in_channels, out_channels, kernel_size, padding=0, stride=1, bn=False, act=True):
45 | nn.Module.__init__(self)
46 | if isinstance(padding, bool):
47 | if isinstance(kernel_size, collections.abc.Iterable):
48 | padding = tuple((kernel_size - 1) // 2 for kernel_size in kernel_size) if padding else 0
49 | else:
50 | padding = (kernel_size - 1) // 2 if padding else 0
51 | self.conv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding=padding, bias=not bn)
52 | self.bn = nn.BatchNorm2d(out_channels, momentum=0.01) if bn else lambda x: x
53 | self.act = nn.ReLU(inplace=True) if act else lambda x: x
54 |
55 | def forward(self, x):
56 | x = self.conv(x)
57 | x = self.bn(x)
58 | x = self.act(x)
59 | return x
60 |
61 |
62 | class Downsample(nn.Module):
63 | def __init__(self, config_channels, channels, prefix, kernel_sizes, pooling):
64 | nn.Module.__init__(self)
65 | self.seq = nn.Sequential(*[Conv2d(config_channels.channels, config_channels(channels, '%s.seq.%d.conv.weight' % (prefix, index)), kernel_size) for index, kernel_size in enumerate(kernel_sizes)])
66 | self.downsample = nn.MaxPool2d(kernel_size=pooling)
67 |
68 | def forward(self, x):
69 | feature = self.seq(x)
70 | return self.downsample(feature), feature
71 |
72 |
73 | class Upsample(nn.Module):
74 | def __init__(self, config_channels, channels, channels_min, prefix, sample, kernel_sizes, ratio=1):
75 | nn.Module.__init__(self)
76 | self.upsample = ConvTranspose2d(config_channels.channels, config_channels(channels, '%s.upsample.conv.weight' % prefix, fn=lambda var: var.size(1)), kernel_size=sample, stride=sample)
77 | config_channels.channels += channels # concat
78 |
79 | seq = []
80 | if ratio < 1:
81 | seq.append(Conv2d(config_channels.channels, config_channels(max(int(config_channels.channels * ratio), channels_min), '%s.seq.%d.conv.weight' % (prefix, len(seq))), 1))
82 | for kernel_size in kernel_sizes:
83 | seq.append(Conv2d(config_channels.channels, config_channels(channels, '%s.seq.%d.conv.weight' % (prefix, len(seq))), kernel_size))
84 | self.seq = nn.Sequential(*seq)
85 |
86 | def forward(self, x, feature):
87 | x = self.upsample(x)
88 | x = torch.cat([x, feature], 1)
89 | return self.seq(x)
90 |
91 |
92 | class Branch(nn.Module):
93 | def __init__(self, config_channels, channels, prefix, multiply, ratio, kernel_sizes, sample):
94 | nn.Module.__init__(self)
95 | _channels = channels
96 | self.down = []
97 | for index, m in enumerate(multiply):
98 | name = 'down%d' % index
99 | block = Downsample(config_channels, _channels, '%s.%s' % (prefix, name), kernel_sizes, pooling=sample)
100 | setattr(self, name, block)
101 | self.down.append(block)
102 | _channels = int(_channels * m)
103 | self.top = nn.Sequential(*[Conv2d(config_channels.channels, config_channels(_channels, '%s.top.%d.conv.weight' % (prefix, index)), kernel_size) for index, kernel_size in enumerate(kernel_sizes)])
104 |
105 | self.up = []
106 | for index, block in enumerate(self.down[::-1]):
107 | name = 'up%d' % index
108 | block = Upsample(config_channels, block.seq[-1].conv.weight.size(0), channels, '%s.%s' % (prefix, name), sample, kernel_sizes, ratio)
109 | setattr(self, name, block)
110 | self.up.append(block)
111 | self.out = Conv2d(config_channels.channels, channels, 1, act=False)
112 |
113 | def forward(self, x):
114 | features = []
115 | for block in self.down:
116 | x, feature = block(x)
117 | features.append(feature)
118 | x = self.top(x)
119 |
120 | for block, feature in zip(self.up, features[::-1]):
121 | x = block(x, feature)
122 | return self.out(x)
123 |
124 |
125 | class Unet(nn.Module):
126 | def __init__(self, config_channels, channel_dict, channels_dnn, prefix, multiply=[2, 2], ratio=1, kernel_sizes=[3], sample=2):
127 | nn.Module.__init__(self)
128 | channels_stage = config_channels.channels
129 | for name, channels in channel_dict.items():
130 | config_channels.channels = channels_stage
131 | branch = Branch(config_channels, channels, '%s.%s' % (prefix, name), multiply, ratio, kernel_sizes, sample)
132 | setattr(self, name, branch)
133 | config_channels.channels = channels_dnn + sum(branch.out.conv.weight.size(0) for branch in self._modules.values())
134 |
135 | def forward(self, x, **kwargs):
136 | if kwargs:
137 | x = torch.cat([kwargs[name] for name in ('parts', 'limbs') if name in kwargs] + [x], 1)
138 | return {name: branch(x) for name, branch in self._modules.items()}
139 |
140 |
141 | class Unet1Sqz3(Unet):
142 | def __init__(self, config_channels, channel_dict, channels_dnn, prefix):
143 | Unet.__init__(self, config_channels, channel_dict, channels_dnn, prefix, multiply=[2], ratio=1 / 3)
144 |
145 |
146 | class Unet1Sqz3_a(Unet):
147 | def __init__(self, config_channels, channel_dict, channels_dnn, prefix):
148 | Unet.__init__(self, config_channels, channel_dict, channels_dnn, prefix, multiply=[1.5], ratio=1 / 3)
149 |
150 |
151 | class Unet2Sqz3(Unet):
152 | def __init__(self, config_channels, channel_dict, channels_dnn, prefix):
153 | Unet.__init__(self, config_channels, channel_dict, channels_dnn, prefix, multiply=[2, 2], ratio=1 / 3)
154 |
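
Each Branch above is a small U-Net: every Downsample keeps its pre-pooling feature as a skip connection, the `top` convolutions run at the coarsest scale, and each Upsample transposed-convolves back up and concatenates the matching skip before convolving again; `multiply` controls how the width grows per level and `ratio < 1` inserts a 1x1 squeeze after each concatenation. An illustrative shape trace (sizes assumed; the input must be divisible by sample**len(multiply) for the concatenations to line up):

# Branch(channels, multiply=[2, 2], kernel_sizes=[3], sample=2) on a (N, C, 48, 48) feature map:
#   down0: 3x3 convs at 48x48 (width = channels),     skip kept, maxpool -> 24x24
#   down1: 3x3 convs at 24x24 (width = channels * 2), skip kept, maxpool -> 12x12
#   top  : 3x3 convs at 12x12 (width = channels * 4)
#   up0  : 2x2 transposed conv -> 24x24, concat down1's skip, optional 1x1 squeeze, 3x3 convs
#   up1  : 2x2 transposed conv -> 48x48, concat down0's skip, optional 1x1 squeeze, 3x3 convs
#   out  : 1x1 conv back to `channels` maps (this branch's 'limbs' or 'parts' output)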
--------------------------------------------------------------------------------
/quick_start.sh:
--------------------------------------------------------------------------------
1 | echo download COCO dataset
2 | LINKS="
3 | http://images.cocodataset.org/zips/train2014.zip
4 | http://images.cocodataset.org/zips/val2014.zip
5 | http://images.cocodataset.org/annotations/annotations_trainval2014.zip
6 | http://images.cocodataset.org/zips/train2017.zip
7 | http://images.cocodataset.org/zips/val2017.zip
8 | http://images.cocodataset.org/annotations/annotations_trainval2017.zip
9 | "
10 | ROOT=~/data/coco
11 | for LINK in $LINKS
12 | do
13 | aria2c --auto-file-renaming=false -d $ROOT $LINK
14 | unzip -n $ROOT/$(basename $LINK) -d $ROOT
15 | done
16 | rm $ROOT/val2014/COCO_val2014_000000320612.jpg
17 |
18 | echo cache data
19 | python3 cache.py -c config.ini config/original_person18_19.ini -m cache/name=cache_original
20 |
21 | echo download and cache the original model
22 | ROOT=~/model/openpose/pose/coco
23 | aria2c --auto-file-renaming=false -d $ROOT https://raw.githubusercontent.com/CMU-Perceptual-Computing-Lab/openpose/master/models/pose/coco/pose_deploy_linevec.prototxt
24 | aria2c --auto-file-renaming=false -d $ROOT http://posefs1.perception.cs.cmu.edu/OpenPose/models/pose/coco/pose_iter_440000.caffemodel
25 | python3 convert_caffe_torch.py config/convert_caffe_torch/original_person18_19.tsv $ROOT/pose_deploy_linevec.prototxt $ROOT/pose_iter_440000.caffemodel -c config.ini config/original_person18_19.ini -m model/name=model_original -d
26 |
27 | echo demo keypoint estimation via a webcam
28 | python3 estimate.py -c config.ini config/original_person18_19.ini -m model/name=model_original
29 |
30 | echo training
31 | python3 train.py -c config.ini config/original_person18_19.ini -m cache/name=cache_original model/name=model_original
--------------------------------------------------------------------------------
/receptive_field_analyzer.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see <http://www.gnu.org/licenses/>.
16 | """
17 |
18 | import os
19 | import argparse
20 | import configparser
21 | import logging
22 | import logging.config
23 | import multiprocessing
24 | import yaml
25 |
26 | import numpy as np
27 | import scipy.misc
28 | import torch.autograd
29 | import torch.cuda
30 | import torch.optim
31 | import torch.utils.data
32 | import torch.nn as nn
33 | import tqdm
34 | import humanize
35 |
36 | import model
37 | import utils.data
38 | import utils.train
39 | import utils.visualize
40 |
41 |
42 | class Dataset(torch.utils.data.Dataset):
43 | def __init__(self, height, width):
44 | self.points = np.array([(i, j) for i in range(height) for j in range(width)])
45 |
46 | def __len__(self):
47 | return len(self.points)
48 |
49 | def __getitem__(self, index):
50 | return self.points[index]
51 |
52 |
53 | class Analyzer(object):
54 | def __init__(self, args, config):
55 | self.args = args
56 | self.config = config
57 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
58 | self.model_dir = utils.get_model_dir(config)
59 | _, self.num_parts = utils.get_dataset_mappers(config)
60 | self.limbs_index = utils.get_limbs_index(config)
61 | self.step, self.epoch, self.dnn, self.stages = self.load()
62 | self.inference = model.Inference(self.config, self.dnn, self.stages)
63 | self.inference.eval()
64 | logging.info(humanize.naturalsize(sum(var.cpu().numpy().nbytes for var in self.dnn.state_dict().values())))
65 | if torch.cuda.is_available():
66 | self.inference.cuda()
67 | self.height, self.width = tuple(map(int, config.get('image', 'size').split()))
68 | t = torch.zeros(1, 3, self.height, self.width).to(self.device)
69 | output = self.dnn(t)
70 | _, _, self.rows, self.cols = output.size()
71 | self.i, self.j = self.rows // 2, self.cols // 2
72 | self.output = output[:, :, self.i, self.j]
73 | dataset = Dataset(self.height, self.width)
74 | try:
75 | workers = self.config.getint('data', 'workers')
76 | except configparser.NoOptionError:
77 | workers = multiprocessing.cpu_count()
78 | self.loader = torch.utils.data.DataLoader(dataset, batch_size=self.args.batch_size, num_workers=workers)
79 |
80 | def __call__(self):
81 | changed = np.zeros([self.height, self.width], bool)
82 | for yx in tqdm.tqdm(self.loader):
83 | batch_size = yx.size(0)
84 | tensor = torch.zeros(batch_size, 3, self.height, self.width)
85 | for i, _yx in enumerate(torch.unbind(yx)):
86 | y, x = torch.unbind(_yx)
87 | tensor[i, :, y, x] = 1
88 | tensor = tensor.to(self.device)
89 | output = self.dnn(tensor)
90 | output = output[:, :, self.i, self.j]
91 | cmp = output == self.output
92 | cmp = torch.prod(cmp, -1)
93 | for _yx, c in zip(torch.unbind(yx), torch.unbind(cmp)):
94 | y, x = torch.unbind(_yx)
95 | changed[y, x] = c
96 | return changed
97 |
98 | def load(self):
99 | try:
100 | path, step, epoch = utils.train.load_model(self.model_dir)
101 | state_dict = torch.load(path, map_location=lambda storage, loc: storage)
102 | except (FileNotFoundError, ValueError):
103 | step, epoch = 0, 0
104 | state_dict = {name: None for name in ('dnn', 'stages')}
105 | config_channels_dnn = model.ConfigChannels(self.config, state_dict['dnn'])
106 | dnn = utils.parse_attr(self.config.get('model', 'dnn'))(config_channels_dnn)
107 | config_channels_stages = model.ConfigChannels(self.config, state_dict['stages'], config_channels_dnn.channels)
108 | channel_dict = model.channel_dict(self.num_parts, len(self.limbs_index))
109 | stages = nn.Sequential(*[utils.parse_attr(s)(config_channels_stages, channel_dict, config_channels_dnn.channels, str(i)) for i, s in enumerate(self.config.get('model', 'stages').split())])
110 | return step, epoch, dnn, stages
111 |
112 |
113 | def main():
114 | args = make_args()
115 | config = configparser.ConfigParser()
116 | utils.load_config(config, args.config)
117 | for cmd in args.modify:
118 | utils.modify_config(config, cmd)
119 | with open(os.path.expanduser(os.path.expandvars(args.logging)), 'r') as f:
120 | logging.config.dictConfig(yaml.safe_load(f))
121 | analyzer = Analyzer(args, config)
122 | changed = analyzer()
123 | os.makedirs(analyzer.model_dir, exist_ok=True)
124 | path = os.path.join(analyzer.model_dir, args.filename)
125 | scipy.misc.imsave(path, (~changed).astype(np.uint8) * 255)
126 | logging.info(path)
127 |
128 |
129 | def make_args():
130 | parser = argparse.ArgumentParser()
131 | parser.add_argument('-c', '--config', nargs='+', default=['config.ini'], help='config file')
132 | parser.add_argument('-m', '--modify', nargs='+', default=[], help='modify config')
133 | parser.add_argument('-b', '--batch_size', default=16, type=int, help='batch size')
134 | parser.add_argument('-n', '--filename', default='receptive_field.jpg')
135 | parser.add_argument('--logging', default='logging.yml', help='logging config')
136 | return parser.parse_args()
137 |
138 |
139 | if __name__ == '__main__':
140 | main()
141 |
--------------------------------------------------------------------------------
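The analyzer above measures the effective receptive field empirically: every sample from the probing Dataset lights up a single input pixel, the DNN is run, and the centre output cell is compared against the all-zeros baseline; the boolean map records where the centre cell stays equal to the baseline, and the saved image inverts it so pixels inside the receptive field show up white. Below is a minimal, self-contained sketch of the same probing idea, using a throwaway toy network rather than the project's configured model.dnn (CPU only, unbatched, so it is only meant to illustrate the mechanism):

import numpy as np
import torch
import torch.nn as nn


def probe_receptive_field(net, height, width):
    """Mark input pixels whose perturbation changes the centre output cell."""
    net.eval()
    with torch.no_grad():
        baseline = net(torch.zeros(1, 3, height, width))
        _, _, rows, cols = baseline.size()
        i, j = rows // 2, cols // 2                      # centre feature cell
        ref = baseline[:, :, i, j]
        changed = np.zeros((height, width), bool)
        for y in range(height):
            for x in range(width):
                t = torch.zeros(1, 3, height, width)
                t[0, :, y, x] = 1                        # light up a single pixel
                out = net(t)[:, :, i, j]
                changed[y, x] = bool((out != ref).any())
    return changed


toy = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(), nn.Conv2d(8, 8, 5, padding=2))
print(int(probe_receptive_field(toy, 16, 16).sum()), 'pixels influence the centre cell')

The full script does the same thing batched (and on the GPU when available) and writes the inverted mask as receptive_field.jpg inside the model directory.
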
/requirements.txt:
--------------------------------------------------------------------------------
1 | tqdm
2 | pybenchmark
3 | graphviz
4 | torch>=0.4.0
5 | pandas
6 | onnx
7 | onnx_caffe2
8 | pretrainedmodels
9 | torchvision
10 | matplotlib
11 | filelock
12 | scikit_image
13 | inflection
14 | numpy
15 | humanize
16 | Pillow
17 | PyQt5
18 | scipy
19 | skimage
20 | tensorboardX>=1.2
21 | tensorflow
22 | PyYAML
23 | pycocotools
24 |
--------------------------------------------------------------------------------
/transform/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see .
16 | """
17 |
18 | import inspect
19 |
20 | import torchvision
21 |
22 | import utils
23 |
24 |
25 | def parse_transform(config, method):
26 | if isinstance(method, str):
27 | attr = utils.parse_attr(method)
28 | sig = inspect.signature(attr)
29 | if len(sig.parameters) == 1:
30 | return attr(config)
31 | else:
32 | return attr()
33 | else:
34 | return method
35 |
36 |
37 | def get_transform(config, sequence, compose=torchvision.transforms.Compose):
38 | return compose([parse_transform(config, method) for method in sequence])
39 |
--------------------------------------------------------------------------------
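parse_transform resolves each dotted class name with utils.parse_attr and passes the config object only when the constructor takes exactly one argument, so config-driven and argument-free transforms can be mixed freely in one sequence; get_transform then wraps the instances in torchvision's Compose. A hedged sketch of the call pattern follows; the sequence and the option value are illustrative, and the import assumes the repository root and its pyopenpose extension are importable (transform/__init__.py pulls in utils, which imports pyopenpose):

import configparser

import numpy as np

import transform  # this repository's package

config = configparser.ConfigParser()
config.read_string('[augmentation]\nrandom_brightness = 0.7 1.3\n')
sequence = ['transform.image.BGR2HSV', 'transform.image.RandomBrightness', 'transform.image.HSV2RGB']
augment = transform.get_transform(config, sequence)        # a torchvision Compose

image = np.random.randint(0, 256, (64, 64, 3), np.uint8)   # a BGR image as read by cv2
rgb = augment(image)                                       # HSV round trip with jittered brightness
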
/transform/augmentation.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see .
16 | """
17 |
18 | import os
19 | import inspect
20 | import random
21 |
22 | import inflection
23 | import numpy as np
24 | import cv2
25 |
26 | import transform
27 |
28 |
29 | class Rotator(object):
30 | def __init__(self, y, x, height, width, angle):
31 | """
32 |         An efficient tool to rotate multiple images of the same size.
33 |         :author 申瑞珉 (Ruimin Shen)
34 |         :param y: The y coordinate of the rotation point.
35 |         :param x: The x coordinate of the rotation point.
36 |         :param height: Image height.
37 |         :param width: Image width.
38 |         :param angle: The rotation angle in degrees.
39 | """
40 | self._mat = cv2.getRotationMatrix2D((x, y), angle, 1.0)
41 | r = np.abs(self._mat[0, :2])
42 | _height, _width = np.inner(r, [height, width]), np.inner(r, [width, height])
43 | fix_y, fix_x = _height / 2 - y, _width / 2 - x
44 | self._mat[:, 2] += [fix_x, fix_y]
45 | self._size = int(_width), int(_height)
46 |
47 | def __call__(self, image, flags=cv2.INTER_LINEAR, fill=None):
48 | if fill is None:
49 | fill = np.random.rand(3) * 256
50 | return cv2.warpAffine(image, self._mat, self._size, flags=flags, borderMode=cv2.BORDER_CONSTANT, borderValue=fill)
51 |
52 | def _rotate_points(self, points):
53 | _points = np.pad(points, [(0, 0), (0, 1)], 'constant')
54 | _points[:, 2] = 1
55 | _points = np.dot(self._mat, _points.T)
56 | return _points.T.astype(points.dtype)
57 |
58 | def rotate_points(self, points):
59 | return self._rotate_points(points[:, ::-1])[:, ::-1]
60 |
61 |
62 | def random_rotate(config, image, mask, keypoints, yx_min, yx_max, index):
63 | name = inspect.stack()[0][3]
64 | angle = random.uniform(*tuple(map(float, config.get('augmentation', name).split())))
65 | height, width = image.shape[:2]
66 | p1, p2 = np.copy(yx_min), np.copy(yx_max)
67 | p1[:, 0] = yx_max[:, 0]
68 | p2[:, 0] = yx_min[:, 0]
69 | points = np.concatenate([yx_min, yx_max, p1, p2], 0)
70 | rotator = Rotator(*((yx_min[index] + yx_max[index]) / 2), height, width, angle)
71 | image = rotator(image, fill=0)
72 | mask = rotator(mask, fill=0)
73 | keypoints[:, :, :2] = np.reshape(rotator.rotate_points(np.reshape(keypoints, [-1, 3])[:, :2]), [len(keypoints), -1, 2])
74 | points = rotator.rotate_points(points)
75 | bbox_points = np.reshape(points, [4, -1, 2])
76 | yx_min = np.apply_along_axis(lambda points: np.min(points, 0), 0, bbox_points)
77 | yx_max = np.apply_along_axis(lambda points: np.max(points, 0), 0, bbox_points)
78 | return image, mask, keypoints, yx_min, yx_max
79 |
80 |
81 | class RandomRotate(object):
82 | def __init__(self, config):
83 | self.config = config
84 | self.fn = eval(inflection.underscore(type(self).__name__))
85 |
86 | def __call__(self, data):
87 | data['image'], data['mask'], data['keypoints'], data['yx_min'], data['yx_max'] = self.fn(self.config, data['image'], data['mask'], data['keypoints'], data['yx_min'], data['yx_max'], data['index'])
88 | return data
89 |
90 |
91 | def flip_horizontally(image, mask, keypoints, yx_min, yx_max):
92 | assert len(image.shape) == 3
93 | image = cv2.flip(image, 1)
94 | mask = cv2.flip(mask, 1)
95 | width = image.shape[1]
96 | keypoints[:, :, 1] = width - keypoints[:, :, 1]
97 | temp = width - yx_min[:, 1]
98 | yx_min[:, 1] = width - yx_max[:, 1]
99 | yx_max[:, 1] = temp
100 | return image, mask, keypoints, yx_min, yx_max
101 |
102 |
103 | class RandomFlipHorizontally(object):
104 | def __init__(self, config):
105 | self.config = config
106 | name = inflection.underscore(type(self).__name__)
107 | self.prob = config.getfloat('augmentation', name)
108 | with open(os.path.expanduser(os.path.expandvars(config.get('cache', 'dataset'))) + '.txt', 'r') as f:
109 | lines = (line.strip() for line in f)
110 | self.symmetric = [int(line) if line else i for i, line in enumerate(lines)]
111 |
112 | def __call__(self, data):
113 | if random.random() > self.prob:
114 | data['image'], data['mask'], keypoints, data['yx_min'], data['yx_max'] = flip_horizontally(data['image'], data['mask'], data['keypoints'], data['yx_min'], data['yx_max'])
115 | assert keypoints.shape[1] == len(self.symmetric)
116 | keypoints = np.stack([[points[i] for i in self.symmetric] for points in keypoints])
117 | data['keypoints'] = keypoints
118 | return data
119 |
120 |
121 | def get_transform(config, sequence):
122 | return transform.get_transform(config, sequence)
123 |
--------------------------------------------------------------------------------
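Rotator builds one affine matrix and enlarges the output canvas so the rotated content is not clipped; the same matrix is reused for the image, the mask, and (through rotate_points) the (y, x) label coordinates, which keeps all of them geometrically consistent. A small sketch on synthetic data (shapes and the rotation centre are arbitrary; the import assumes the repository root and its pyopenpose extension are importable):

import numpy as np

from transform.augmentation import Rotator

image = np.zeros((100, 200, 3), np.uint8)
mask = np.zeros((100, 200), np.uint8)
points = np.array([[10.0, 20.0], [50.0, 100.0]], np.float32)   # rows of (y, x)

rotator = Rotator(y=50, x=100, height=100, width=200, angle=30)
image_r = rotator(image, fill=0)            # black border instead of the default random fill
mask_r = rotator(mask, fill=0)
points_r = rotator.rotate_points(points)    # the same transform applied to the labels
print(image_r.shape, mask_r.shape, points_r)
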
/transform/image.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see .
16 | """
17 |
18 | import random
19 |
20 | import numpy as np
21 | import torchvision
22 | import inflection
23 | import skimage.exposure
24 | import cv2
25 |
26 |
27 | class BGR2RGB(object):
28 | def __call__(self, image):
29 | return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
30 |
31 |
32 | class BGR2HSV(object):
33 | def __call__(self, image):
34 | return cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
35 |
36 |
37 | class HSV2RGB(object):
38 | def __call__(self, image):
39 | return cv2.cvtColor(image, cv2.COLOR_HSV2RGB)
40 |
41 |
42 | class RandomBlur(object):
43 | def __init__(self, config):
44 | name = inflection.underscore(type(self).__name__)
45 | self.adjust = tuple(map(int, config.get('augmentation', name).split()))
46 |
47 | def __call__(self, image):
48 | adjust = tuple(random.randint(1, adjust) for adjust in self.adjust)
49 | return cv2.blur(image, adjust)
50 |
51 |
52 | class RandomHue(object):
53 | def __init__(self, config):
54 | name = inflection.underscore(type(self).__name__)
55 | self.adjust = tuple(map(int, config.get('augmentation', name).split()))
56 |
57 | def __call__(self, hsv):
58 | h, s, v = cv2.split(hsv)
59 | adjust = random.randint(*self.adjust)
60 | h = h.astype(np.int) + adjust
61 | h = np.clip(h, 0, 179).astype(hsv.dtype)
62 | return cv2.merge((h, s, v))
63 |
64 |
65 | class RandomSaturation(object):
66 | def __init__(self, config):
67 | name = inflection.underscore(type(self).__name__)
68 | self.adjust = tuple(map(float, config.get('augmentation', name).split()))
69 |
70 | def __call__(self, hsv):
71 | h, s, v = cv2.split(hsv)
72 | adjust = random.uniform(*self.adjust)
73 | s = s * adjust
74 | s = np.clip(s, 0, 255).astype(hsv.dtype)
75 | return cv2.merge((h, s, v))
76 |
77 |
78 | class RandomBrightness(object):
79 | def __init__(self, config):
80 | name = inflection.underscore(type(self).__name__)
81 | self.adjust = tuple(map(float, config.get('augmentation', name).split()))
82 |
83 | def __call__(self, hsv):
84 | h, s, v = cv2.split(hsv)
85 | adjust = random.uniform(*self.adjust)
86 | v = v * adjust
87 | v = np.clip(v, 0, 255).astype(hsv.dtype)
88 | return cv2.merge((h, s, v))
89 |
90 |
91 | class RandomGamma(object):
92 | def __init__(self, config):
93 | name = inflection.underscore(type(self).__name__)
94 | self.adjust = tuple(map(float, config.get('augmentation', name).split()))
95 |
96 | def __call__(self, image):
97 | adjust = random.uniform(*self.adjust)
98 | return skimage.exposure.adjust_gamma(image, adjust)
99 |
100 |
101 | class Normalize(torchvision.transforms.Normalize):
102 | def __init__(self, config):
103 | name = inflection.underscore(type(self).__name__)
104 | mean, std = tuple(map(float, config.get('transform', name).split()))
105 | torchvision.transforms.Normalize.__init__(self, (mean, mean, mean), (std, std, std))
106 |
--------------------------------------------------------------------------------
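Each augmentation class looks up its range under [augmentation] using the underscored class name (via inflection.underscore), so adding a transform only requires a config option with the matching name; the HSV-based ones (RandomHue, RandomSaturation, RandomBrightness) expect an HSV input and are therefore normally sandwiched between BGR2HSV and HSV2RGB. A direct-instantiation sketch with illustrative option values (the import assumes the repository root and its pyopenpose extension are importable):

import configparser

import numpy as np

from transform.image import RandomBlur, RandomGamma

config = configparser.ConfigParser()
config.read_string('[augmentation]\nrandom_blur = 5 5\nrandom_gamma = 0.8 1.2\n')

image = np.random.randint(0, 256, (64, 64, 3), np.uint8)   # BGR, as read by cv2
image = RandomBlur(config)(image)      # box blur with a random kernel of up to 5x5
image = RandomGamma(config)(image)     # gamma jitter sampled from [0.8, 1.2]
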
/transform/resize/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruiminshen/openpose-pytorch/f850084194ddccc6d401d5b11f61facc20ec2b75/transform/resize/__init__.py
--------------------------------------------------------------------------------
/transform/resize/image.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see .
16 | """
17 |
18 | import inflection
19 | import numpy as np
20 | import cv2
21 |
22 |
23 | def rescale(image, height, width):
24 | return cv2.resize(image, (width, height))
25 |
26 |
27 | def rescale_scale(size, image_size):
28 | return image_size[0] / size[0], image_size[1] / size[1]
29 |
30 |
31 | class Rescale(object):
32 | def __init__(self):
33 | name = inflection.underscore(type(self).__name__)
34 | self.fn = eval(name)
35 | self.scale = eval(name + '_scale')
36 |
37 | def __call__(self, image, height, width):
38 | return self.fn(image, height, width)
39 |
40 |
41 | def fixed(image, height, width):
42 | _height, _width, _ = image.shape
43 | if _height / _width > height / width:
44 | scale = height / _height
45 | else:
46 | scale = width / _width
47 | m = np.eye(2, 3)
48 | m[0, 0] = scale
49 | m[1, 1] = scale
50 | flags = cv2.INTER_AREA if scale < 1 else cv2.INTER_CUBIC
51 | return cv2.warpAffine(image, m, (width, height), flags=flags)
52 |
53 |
54 | def fixed_scale(size, image_size):
55 | assert len(image_size) == 2
56 | _image_size = max(image_size)
57 | return _image_size / size[0], _image_size / size[1]
58 |
59 |
60 | class Fixed(object):
61 | def __init__(self):
62 | name = inflection.underscore(type(self).__name__)
63 | self.fn = eval(name)
64 | self.scale = eval(name + '_scale')
65 |
66 | def __call__(self, image, height, width):
67 | return self.fn(image, height, width)
68 |
69 |
70 | class Resize(object):
71 | def __init__(self, config):
72 | name = config.get('data', inflection.underscore(type(self).__name__))
73 | self.fn = eval(name)
74 | self.scale = eval(name + '_scale')
75 |
76 | def __call__(self, image, height, width):
77 | return self.fn(image, height, width)
78 |
--------------------------------------------------------------------------------
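Two resize policies are offered: rescale stretches the image to exactly (height, width) and therefore changes the aspect ratio, while fixed scales by the limiting dimension and leaves the remainder as black padding; the *_scale companions return the factors needed to map coordinates back to the source image, and Resize picks one policy by name from the resize option in the [data] section. A quick check on a synthetic image (the import carries the usual assumption that the repository root and its pyopenpose extension are importable):

import numpy as np

from transform.resize.image import rescale, fixed, fixed_scale

image = np.zeros((90, 160, 3), np.uint8)
print(rescale(image, 128, 128).shape)        # (128, 128, 3), aspect ratio changed
print(fixed(image, 128, 128).shape)          # (128, 128, 3), aspect ratio kept, padded
print(fixed_scale((128, 128), (90, 160)))    # per-axis factors back to the source image
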
/transform/resize/label.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see .
16 | """
17 |
18 | import inspect
19 |
20 | import inflection
21 | import numpy as np
22 | import cv2
23 |
24 |
25 | def rescale(image, mask, keypoints, yx_min, yx_max, height, width):
26 | _height, _width = image.shape[:2]
27 | scale = np.array([height / _height, width / _width], np.float32)
28 | image = cv2.resize(image, (width, height))
29 | mask = cv2.resize(mask, (width, height))
30 | keypoints[:, :, :2] *= scale
31 | yx_min *= scale
32 | yx_max *= scale
33 | return image, mask, keypoints, yx_min, yx_max
34 |
35 |
36 | class Rescale(object):
37 | def __init__(self):
38 | self.fn = eval(inflection.underscore(type(self).__name__))
39 |
40 | def __call__(self, data, height, width):
41 | data['image'], data['mask'], data['keypoints'], data['yx_min'], data['yx_max'] = self.fn(data['image'], data['mask'], data['keypoints'], data['yx_min'], data['yx_max'], height, width)
42 | return data
43 |
44 |
45 | def padding(image, mask, keypoints, yx_min, yx_max, height, width):
46 | _height, _width, _ = image.shape
47 | if _height / _width > height / width:
48 | scale = height / _height
49 | else:
50 | scale = width / _width
51 | m = np.eye(2, 3)
52 | m[0, 0] = scale
53 | m[1, 1] = scale
54 | flags = cv2.INTER_AREA if scale < 1 else cv2.INTER_CUBIC
55 | image = cv2.warpAffine(image, m, (width, height), flags=flags)
56 | mask = cv2.warpAffine(mask, m, (width, height), flags=flags)
57 | return image, mask, keypoints, yx_min, yx_max
58 |
59 |
60 | class Padding(object):
61 | def __init__(self):
62 | self.fn = eval(inflection.underscore(type(self).__name__))
63 |
64 | def __call__(self, data, height, width):
65 | data['image'], data['mask'], data['keypoints'], data['yx_min'], data['yx_max'] = self.fn(data['image'], data['mask'], data['keypoints'], data['yx_min'], data['yx_max'], height, width)
66 | return data
67 |
68 |
69 | def resize(config, image, mask, keypoints, yx_min, yx_max, height, width):
70 | fn = eval(config.get('data', inspect.stack()[0][3]))
71 | return fn(image, mask, keypoints, yx_min, yx_max, height, width)
72 |
73 |
74 | class Resize(object):
75 | def __init__(self, config):
76 | self.config = config
77 | self.fn = eval(config.get('data', inflection.underscore(type(self).__name__)))
78 |
79 | def __call__(self, data, height, width):
80 | data['image'], data['yx_min'], data['yx_max'] = self.fn(self.config, data['image'], data['yx_min'], data['yx_max'], height, width)
81 | return data
82 |
83 |
84 | def change_aspect_ratio(range, height_src, width_src, height_dst, width_dst):
85 | assert range >= 0
86 | if width_src < height_src:
87 | width = min(range, width_src)
88 | height = width * height_dst / width_dst
89 | else:
90 | height = min(range, height_src)
91 | width = height * width_dst / height_dst
92 | return height, width
93 |
94 |
95 | def repair(yx_min, yx_max, size):
96 | move = np.clip(yx_max - size, 0, None)
97 | yx_min -= move
98 | yx_max -= move
99 | move = np.clip(-yx_min, 0, None)
100 | yx_min += move
101 | yx_max += move
102 | return yx_min, yx_max
103 |
104 |
105 | def random_crop(config, image, mask, keypoints, yx_min, yx_max, index, height, width):
106 | name = inspect.stack()[0][3]
107 | scale1, scale2 = tuple(map(float, config.get('augmentation', name).split()))
108 | assert 1 <= scale1 <= scale2, (scale1, scale2)
109 | dtype = keypoints.dtype
110 | size = np.array(image.shape[:2], dtype)
111 | _yx_min, _yx_max = yx_min[index], yx_max[index]
112 | _center = (_yx_min + _yx_max) / 2
113 | _size = np.array(change_aspect_ratio(np.max(_yx_max - _yx_min), *size, height, width), dtype)
114 | _size1, _size2 = _size * scale1 / 2, _size * scale2 / 2
115 | yx_min1, yx_max1 = _center - _size1, _center + _size1
116 | yx_min2, yx_max2 = _center - _size2, _center + _size2
117 | yx_min1, yx_max1 = repair(yx_min1, yx_max1, size)
118 | yx_min2, yx_max2 = repair(yx_min2, yx_max2, size)
119 | margin = np.random.rand(4).astype(dtype) * np.concatenate([yx_min1 - yx_min2, yx_max2 - yx_max1], 0)
120 | yx_min_crop = np.clip(yx_min2 + margin[:2], 0, None)
121 | yx_max_crop = np.clip(yx_max2 - margin[2:], None, size)
122 | _ymin, _xmin = tuple(map(int, yx_min_crop))
123 | _ymax, _xmax = tuple(map(int, yx_max_crop))
124 | image = image[_ymin:_ymax, _xmin:_xmax, :]
125 | mask = mask[_ymin:_ymax, _xmin:_xmax]
126 | keypoints[:, :, :2] -= yx_min_crop
127 | yx_min -= yx_min_crop
128 | yx_max -= yx_min_crop
129 | return rescale(image, mask, keypoints, yx_min, yx_max, height, width)
130 |
131 |
132 | class RandomCrop(object):
133 | def __init__(self, config):
134 | self.config = config
135 | self.fn = eval(inflection.underscore(type(self).__name__))
136 |
137 | def __call__(self, data, height, width):
138 | data['image'], data['mask'], data['keypoints'], data['yx_min'], data['yx_max'] = self.fn(self.config, data['image'], data['mask'], data['keypoints'], data['yx_min'], data['yx_max'], data['index'], height, width)
139 | return data
140 |
--------------------------------------------------------------------------------
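random_crop picks a window around the person selected by index: change_aspect_ratio derives a crop extent that matches the destination aspect ratio, repair nudges a candidate window back inside the image bounds, and the final window is sampled between an inner (scale1) box and an outer (scale2) box before everything is rescaled to the target size. The two pure-numpy helpers are easy to exercise on their own; the values below are arbitrary, and the import assumes the repository root and its pyopenpose extension are importable:

import numpy as np

from transform.resize.label import change_aspect_ratio, repair

# A crop window that sticks out of a 100x200 image is shifted back inside.
size = np.array([100.0, 200.0])                              # image (height, width)
yx_min, yx_max = np.array([80.0, -10.0]), np.array([130.0, 40.0])
print(repair(yx_min, yx_max, size))                          # ([50, 0], [100, 50])

# Crop extent (height, width) whose aspect ratio matches a 256x256 destination.
print(change_aspect_ratio(120, 100, 200, 256, 256))          # (100, 100.0)
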
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see .
16 | """
17 |
18 | import os
19 | import re
20 | import configparser
21 | import importlib
22 | import hashlib
23 |
24 | import numpy as np
25 | import pandas as pd
26 | import torch.autograd
27 | from PIL import Image
28 |
29 | import pyopenpose
30 |
31 |
32 | class Compose(object):
33 | def __init__(self, transforms):
34 | self.transforms = transforms
35 |
36 | def __call__(self, img, yx_min, yx_max, cls):
37 | for t in self.transforms:
38 | img, yx_min, yx_max, cls = t(img, yx_min, yx_max, cls)
39 | return img, yx_min, yx_max, cls
40 |
41 |
42 | class RegexList(list):
43 | def __init__(self, l):
44 | for s in l:
45 | prog = re.compile(s)
46 | self.append(prog)
47 |
48 | def __call__(self, s):
49 | for prog in self:
50 | if prog.match(s):
51 | return True
52 | return False
53 |
54 |
55 | class DatasetMapper(object):
56 | def __init__(self, mapper):
57 | self.mapper = mapper
58 |
59 | def __call__(self, parts, dtype=np.int64):
60 | assert len(parts.shape) == 2 and parts.shape[-1] == 3
61 | result = np.zeros([len(self.mapper), 3], dtype=parts.dtype)
62 | for i, func in enumerate(self.mapper):
63 | result[i] = func(parts)
64 | return result
65 |
66 |
67 | def get_dataset_mappers(config):
68 | root = os.path.expanduser(os.path.expandvars(config.get('cache', 'dataset')))
69 | mappers = {}
70 | for dataset in os.listdir(root):
71 | path = os.path.join(root, dataset)
72 | if os.path.isfile(path):
73 | with open(path, 'r') as f:
74 | mapper = [eval(line.rstrip()) for line in f]
75 | mappers[dataset] = mapper
76 | sizes = set(map(lambda mapper: len(mapper), mappers.values()))
77 | assert len(sizes) == 1
78 | for dataset in mappers:
79 | mappers[dataset] = DatasetMapper(mappers[dataset])
80 | return mappers, next(iter(sizes))
81 |
82 |
83 | def get_limbs_index(config):
84 | dataset = os.path.expanduser(os.path.expandvars(config.get('cache', 'dataset')))
85 | limbs_index = np.loadtxt(dataset + '.tsv', dtype=np.int, delimiter='\t', ndmin=2)
86 | if len(limbs_index) > 0:
87 | assert pyopenpose.limbs_points(limbs_index) == get_dataset_mappers(config)[1]
88 | else:
89 | limbs_index = np.reshape(limbs_index, [0, 2])
90 | return limbs_index
91 |
92 |
93 | def get_cache_dir(config):
94 | root = os.path.expanduser(os.path.expandvars(config.get('config', 'root')))
95 | name = config.get('cache', 'name')
96 | dataset = os.path.basename(config.get('cache', 'dataset'))
97 | return os.path.join(root, name, dataset)
98 |
99 |
100 | def get_model_dir(config):
101 | root = os.path.expanduser(os.path.expandvars(config.get('config', 'root')))
102 | name = config.get('model', 'name')
103 | dataset = os.path.basename(config.get('cache', 'dataset'))
104 | dnn = config.get('model', 'dnn')
105 | stages = hashlib.md5(' '.join(config.get('model', 'stages').split()).encode()).hexdigest()
106 | return os.path.join(root, name, dataset, dnn, stages)
107 |
108 |
109 | def get_eval_db(config):
110 | root = os.path.expanduser(os.path.expandvars(config.get('config', 'root')))
111 | db = config.get('eval', 'db')
112 | return os.path.join(root, db)
113 |
114 |
115 | def get_category(config, cache_dir=None):
116 | path = os.path.expanduser(os.path.expandvars(config.get('cache', 'category'))) if cache_dir is None else os.path.join(cache_dir, 'category')
117 | with open(path, 'r') as f:
118 | return [line.strip() for line in f]
119 |
120 |
121 | def get_anchors(config, dtype=np.float32):
122 | path = os.path.expanduser(os.path.expandvars(config.get('model', 'anchors')))
123 | df = pd.read_csv(path, sep='\t', dtype=dtype)
124 | return df[['height', 'width']].values
125 |
126 |
127 | def parse_attr(s):
128 | m, n = s.rsplit('.', 1)
129 | m = importlib.import_module(m)
130 | return getattr(m, n)
131 |
132 |
133 | def load_config(config, paths):
134 | for path in paths:
135 | path = os.path.expanduser(os.path.expandvars(path))
136 | assert os.path.exists(path)
137 | config.read(path)
138 |
139 |
140 | def modify_config(config, cmd):
141 | var, value = cmd.split('=', 1)
142 | section, option = var.split('/')
143 | if value:
144 | config.set(section, option, value)
145 | else:
146 | try:
147 | config.remove_option(section, option)
148 | except (configparser.NoSectionError, configparser.NoOptionError):
149 | pass
150 |
151 |
152 | def dense(var):
153 | return [torch.mean(torch.abs(x)) if torch.is_tensor(x) else np.abs(x) for x in var]
154 |
155 |
156 | def abs_mean(data, dtype=np.float32):
157 | assert isinstance(data, np.ndarray), type(data)
158 | return np.sum(np.abs(data)) / dtype(data.size)
159 |
160 |
161 | def image_size(path):
162 | with Image.open(path) as image:
163 | return image.size
164 |
--------------------------------------------------------------------------------
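parse_attr turns a dotted path from the config into the actual class or function, and modify_config implements the -m section/option=value overrides accepted by the command-line tools (an empty value removes the option instead). Because importing the utils package also pulls in the compiled pyopenpose extension, the pattern is restated standalone below; the option values are purely illustrative:

import configparser
import importlib


def parse_attr(s):
    # same logic as utils.parse_attr: 'pkg.module.Name' -> the Name attribute
    m, n = s.rsplit('.', 1)
    return getattr(importlib.import_module(m), n)


print(parse_attr('collections.OrderedDict'))

config = configparser.ConfigParser()
config.read_string('[model]\ndnn = model.dnn.inception4.Inception4\n')
# what utils.modify_config does for an override such as -m model/dnn=model.dnn.vgg.Vgg
cmd = 'model/dnn=model.dnn.vgg.Vgg'
var, value = cmd.split('=', 1)
section, option = var.split('/')
config.set(section, option, value)
print(config.get('model', 'dnn'))
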
/utils/cache.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see .
16 | """
17 |
18 | import numpy as np
19 |
20 |
21 | def verify_coords(yx_min, yx_max, size):
22 | assert np.all(yx_min <= yx_max), 'yx_min <= yx_max'
23 | assert np.all(0 <= yx_min), '0 <= yx_min'
24 | assert np.all(0 <= yx_max), '0 <= yx_max'
25 | assert np.all(yx_min < size), 'yx_min < size'
26 | assert np.all(yx_max < size), 'yx_max < size'
27 |
28 |
29 | def fix_coords(yx_min, yx_max, size):
30 | assert np.all(yx_min <= yx_max)
31 | assert yx_min.dtype == yx_max.dtype
32 | coord_min = np.zeros([2], dtype=yx_min.dtype)
33 | coord_max = np.array(size, dtype=yx_min.dtype) - 1
34 | yx_min = np.minimum(np.maximum(yx_min, coord_min), coord_max)
35 | yx_max = np.minimum(np.maximum(yx_max, coord_min), coord_max)
36 | return yx_min, yx_max
37 |
--------------------------------------------------------------------------------
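verify_coords asserts that a bounding box lies inside the image, while fix_coords clamps both corners into [0, size - 1] so slightly out-of-range annotations can still be used. A tiny numeric illustration (importing utils.cache assumes the repository root and its pyopenpose extension are importable, since the package __init__ imports it):

import numpy as np

from utils.cache import fix_coords

size = (100, 200)                                  # image (height, width)
yx_min = np.array([-5.0, 10.0])
yx_max = np.array([120.0, 250.0])
print(fix_coords(yx_min, yx_max, size))            # ([0, 10], [99, 199])
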
/utils/data.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see .
16 | """
17 |
18 | import os
19 | import pickle
20 | import random
21 | import copy
22 |
23 | import numpy as np
24 | import torch.utils.data
25 | import cv2
26 |
27 | import utils
28 | import pyopenpose
29 |
30 |
31 | def padding_labels(data, dim, labels='keypoints, yx_min, yx_max'.split(', ')):
32 | """
33 | Padding labels into the same dimension (to form a batch).
34 | :author 申瑞珉 (Ruimin Shen)
35 |     :param data: A dict containing the labels to be padded.
36 | :param dim: The target dimension.
37 | :param labels: The list of label names.
38 | :return: The padded label dict.
39 | """
40 | pad = dim - len(data[labels[0]])
41 | for key in labels:
42 | label = data[key]
43 | data[key] = np.pad(label, [(0, pad)] + [(0, 0)] * (len(label.shape) - 1), 'constant')
44 | return data
45 |
46 |
47 | def load_pickles(paths):
48 | data = []
49 | for path in paths:
50 | with open(path, 'rb') as f:
51 | data += pickle.load(f)
52 | return data
53 |
54 |
55 | class Dataset(torch.utils.data.Dataset):
56 | def __init__(self, config, data, transform=lambda data: data, shuffle=False, dir=None):
57 | """
58 | Load the cached data (.pkl) into memory.
59 | :author 申瑞珉 (Ruimin Shen)
60 |         :param data: A list containing the data samples (dicts).
61 |         :param transform: A function that transforms the labels in a dict (usually by applying a sequence of data augmentation operations).
62 | :param shuffle: Shuffle the loaded dataset.
63 | :param dir: The directory to store the exception data.
64 | """
65 | self.config = config
66 | self.mask_ext = config.get('cache', 'mask_ext')
67 | self.data = data
68 | if shuffle:
69 | random.shuffle(self.data)
70 | self.transform = transform
71 | self.dir = dir
72 |
73 | def __len__(self):
74 | return len(self.data)
75 |
76 | def __getitem__(self, index):
77 | data = copy.deepcopy(self.data[index])
78 | try:
79 | image = cv2.imread(data['path'])
80 | data['image'] = image
81 | data['size'] = np.array(image.shape[:2])
82 | mask = cv2.imread(data['keypath'] + '.mask' + self.mask_ext, cv2.IMREAD_GRAYSCALE)
83 | assert image.shape[:2] == mask.shape, [image.shape[:2], mask.shape]
84 | data['mask'] = mask
85 | data['index'] = random.randint(0, len(data['keypoints']) - 1)
86 | data = self.transform(data)
87 | except:
88 | if self.dir is not None:
89 | os.makedirs(self.dir, exist_ok=True)
90 | name = self.__module__ + '.' + type(self).__name__
91 | with open(os.path.join(self.dir, name + '.pkl'), 'wb') as f:
92 | pickle.dump(data, f)
93 | raise
94 | return data
95 |
96 |
97 | class Collate(object):
98 | def __init__(self, config, resize, sizes, feature_sizes, maintain=1, transform_image=lambda image: image, transform_tensor=None, dir=None):
99 | """
100 |         Unify multiple data samples (e.g., resize images to the same size and pad bounding box labels to the same count) to form a batch.
101 | :author 申瑞珉 (Ruimin Shen)
102 | :param resize: A function to resize the image and labels.
103 |         :param sizes: The image sizes to be randomly chosen.
104 |         :param feature_sizes: The feature sizes related to the image sizes.
105 |         :param maintain: How many times a chosen size is maintained before a new one is sampled.
106 |         :param transform_image: A function to transform the resized image.
107 |         :param transform_tensor: A function to standardize an image into a tensor.
108 | :param dir: The directory to store the exception data.
109 | """
110 | self.config = config
111 | self.resize = resize
112 | assert len(sizes) == len(feature_sizes)
113 | self.sizes = sizes
114 | self.feature_sizes = feature_sizes
115 | assert maintain > 0
116 | self.maintain = maintain
117 | self._maintain = maintain
118 | self.transform_image = transform_image
119 | self.transform_tensor = transform_tensor
120 | self.dir = dir
121 | self.sigma_parts = config.getfloat('label', 'sigma_parts')
122 | self.sigma_limbs = config.getfloat('label', 'sigma_limbs')
123 | self.limbs_index = utils.get_limbs_index(config)
124 |
125 | def __call__(self, batch):
126 | (height, width), (rows, cols) = self.next_size()
127 | dim = max(len(data['keypoints']) for data in batch)
128 | _batch = []
129 | for data in batch:
130 | try:
131 | data = self.resize(data, height, width)
132 | data['image'] = self.transform_image(data['image'])
133 | data = padding_labels(data, dim)
134 | if self.transform_tensor is not None:
135 | data['tensor'] = self.transform_tensor(data['image'])
136 | data['mask'] = (cv2.resize(data['mask'], (cols, rows)) > 127).astype(np.uint8)
137 | data['parts'] = pyopenpose.label_parts(data['keypoints'], self.sigma_parts, height, width, rows, cols)
138 | data['limbs'] = pyopenpose.label_limbs(data['keypoints'], self.limbs_index, self.sigma_limbs, height, width, rows, cols)
139 | _batch.append(data)
140 | except:
141 | if self.dir is not None:
142 | os.makedirs(self.dir, exist_ok=True)
143 | name = self.__module__ + '.' + type(self).__name__
144 | with open(os.path.join(self.dir, name + '.pkl'), 'wb') as f:
145 | pickle.dump(data, f)
146 | raise
147 | return torch.utils.data.dataloader.default_collate(_batch)
148 |
149 | def next_size(self):
150 | if self._maintain < self.maintain:
151 | self._maintain += 1
152 | else:
153 | self._index = random.randint(0, len(self.sizes) - 1)
154 | self._maintain = 0
155 | return self.sizes[self._index], self.feature_sizes[self._index]
156 |
--------------------------------------------------------------------------------
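Images in a batch contain different numbers of people, so Collate first resizes every sample to the batch's common size and then calls padding_labels to zero-pad the per-person label arrays up to the largest count in the batch, which lets default_collate stack them into tensors. The padding logic is restated standalone below (utils/data.py itself also needs cv2 and the pyopenpose extension); the shapes are illustrative:

import numpy as np


def padding_labels(data, dim, labels=('keypoints', 'yx_min', 'yx_max')):
    # same padding as utils.data.padding_labels: zero-pad the first axis up to dim
    pad = dim - len(data[labels[0]])
    for key in labels:
        label = data[key]
        data[key] = np.pad(label, [(0, pad)] + [(0, 0)] * (label.ndim - 1), 'constant')
    return data


sample = {
    'keypoints': np.ones((2, 19, 3), np.float32),   # 2 people, 19 parts, (y, x, visible)
    'yx_min': np.zeros((2, 2), np.float32),
    'yx_max': np.ones((2, 2), np.float32),
}
padded = padding_labels(sample, dim=5)              # 5 = most people in any image of the batch
print(padded['keypoints'].shape, padded['yx_min'].shape)   # (5, 19, 3) (5, 2)
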
/utils/train.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see .
16 | """
17 |
18 | import os
19 | import time
20 | import operator
21 | import logging
22 |
23 | import torch
24 |
25 |
26 | class Timer(object):
27 | def __init__(self, max, first=True):
28 | """
29 |         A simple function object to determine whether a time event has occurred.
30 |         :author 申瑞珉 (Ruimin Shen)
31 |         :param max: Number of seconds required to trigger a time event.
32 |         :param first: Whether a time event should be triggered on the first call.
33 | """
34 | self.start = 0 if first else time.time()
35 | self.max = max
36 |
37 | def __call__(self):
38 | """
39 |         Return a boolean value indicating whether the time event has occurred.
40 | :author 申瑞珉 (Ruimin Shen)
41 | """
42 | t = time.time()
43 | elapsed = t - self.start
44 | if elapsed > self.max:
45 | self.start = t
46 | return True
47 | else:
48 | return False
49 |
50 |
51 | def load_model(model_dir, step=None, ext='.pth', ext_epoch='.epoch', logger=logging.info):
52 | """
53 | Load the latest checkpoint in a model directory.
54 | :author 申瑞珉 (Ruimin Shen)
55 | :param model_dir: The directory to store the model checkpoint files.
56 |     :param step: If an integer value is given, the corresponding checkpoint will be loaded. Otherwise, the latest checkpoint (with the largest step value) will be loaded.
57 | :param ext: The extension of the model file.
58 | :param ext_epoch: The extension of the epoch file.
59 | :return:
60 | """
61 | if step is None:
62 | steps = [(int(n), n) for n, e in map(os.path.splitext, os.listdir(model_dir)) if n.isdigit() and e == ext]
63 | step, name = max(steps, key=operator.itemgetter(0))
64 | else:
65 | name = str(step)
66 | prefix = os.path.join(model_dir, name)
67 | if logger is not None:
68 | logger('load %s.*' % prefix)
69 | try:
70 | with open(prefix + ext_epoch, 'r') as f:
71 | epoch = int(f.read())
72 | except (FileNotFoundError, ValueError):
73 | epoch = None
74 | path = prefix + ext
75 | assert os.path.exists(path), path
76 | return path, step, epoch
77 |
78 |
79 | class Saver(object):
80 | def __init__(self, model_dir, keep, ext='.pth', ext_epoch='.epoch', logger=logging.info):
81 | """
82 | Manage several latest checkpoints (with the largest step values) in a model directory.
83 | :author 申瑞珉 (Ruimin Shen)
84 | :param model_dir: The directory to store the model checkpoint files.
85 |         :param keep: How many of the latest checkpoints to keep.
86 | :param ext: The extension of the model file.
87 | :param ext_epoch: The extension of the epoch file.
88 | """
89 | self.model_dir = model_dir
90 | self.keep = keep
91 | self.ext = ext
92 | self.ext_epoch = ext_epoch
93 | self.logger = (lambda s: s) if logger is None else logger
94 |
95 | def __call__(self, obj, step, epoch=None):
96 | """
97 | Save the PyTorch module.
98 | :author 申瑞珉 (Ruimin Shen)
99 | :param obj: The PyTorch module to be saved.
100 | :param step: Current step.
101 | :param epoch: Current epoch.
102 | """
103 | os.makedirs(self.model_dir, exist_ok=True)
104 | prefix = os.path.join(self.model_dir, str(step))
105 | torch.save(obj, prefix + self.ext)
106 | if epoch is not None:
107 | with open(prefix + self.ext_epoch, 'w') as f:
108 | f.write(str(epoch))
109 | self.logger('model saved into %s.*' % prefix)
110 | self.tidy()
111 | return prefix
112 |
113 | def tidy(self):
114 | steps = [(int(n), n) for n, e in map(os.path.splitext, os.listdir(self.model_dir)) if n.isdigit() and e == self.ext]
115 | if len(steps) > self.keep:
116 | steps = sorted(steps, key=operator.itemgetter(0))
117 | remove = steps[:len(steps) - self.keep]
118 | for _, n in remove:
119 | path = os.path.join(self.model_dir, n)
120 | os.remove(path + self.ext)
121 | path_epoch = path + self.ext_epoch
122 | try:
123 | os.remove(path_epoch)
124 | except FileNotFoundError:
125 | self.logger(path_epoch + ' not found')
126 | logging.debug('tidy ' + path)
127 |
128 |
129 | def load_sizes(config):
130 | sizes = [s.split(',') for s in config.get('data', 'sizes').split()]
131 | return [(int(height), int(width)) for height, width in sizes]
132 |
--------------------------------------------------------------------------------
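Checkpoints are written as <step>.pth with a companion <step>.epoch file; Saver keeps only the newest keep of them, and load_model returns the path, step, and epoch of either a requested step or the latest one. A round-trip sketch in a temporary directory (importing utils.train assumes the repository root and its pyopenpose extension are importable; the saved object here is a plain state_dict):

import tempfile

import torch
import torch.nn as nn

from utils.train import Saver, load_model

model_dir = tempfile.mkdtemp()
saver = Saver(model_dir, keep=3, logger=print)
net = nn.Linear(4, 2)
saver(net.state_dict(), step=100, epoch=1)        # writes 100.pth and 100.epoch
saver(net.state_dict(), step=200, epoch=2)

path, step, epoch = load_model(model_dir, logger=print)   # latest checkpoint -> step 200
state_dict = torch.load(path, map_location=lambda storage, loc: storage)
print(step, epoch, sorted(state_dict))
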
/utils/visualize.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017, 申瑞珉 (Ruimin Shen)
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU Lesser General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see .
16 | """
17 |
18 | import logging
19 | import itertools
20 | import inspect
21 |
22 | import numpy as np
23 | import torch
24 | import matplotlib
25 | import matplotlib.cm
26 | import matplotlib.colors
27 | import matplotlib.pyplot as plt
28 | import humanize
29 | import graphviz
30 | import cv2
31 |
32 |
33 | def draw_mask(image, mask, threshold=128):
34 | _mask = cv2.resize(np.squeeze(mask), image.shape[1::-1], interpolation=cv2.INTER_NEAREST)
35 | return np.expand_dims(_mask >= threshold, -1) * image
36 |
37 |
38 | class DrawPoints(object):
39 | def __init__(self, limbs_index, colors=[], radius=5, thickness=2, line_type=cv2.LINE_8, shift=0, font_face=cv2.FONT_HERSHEY_SIMPLEX, font_scale=0.5, z=1):
40 | self.limbs_index = limbs_index
41 | self.colors = [tuple(map(lambda c: c * 255, matplotlib.colors.colorConverter.to_rgb(c)[::-1])) for c in colors]
42 | self._colors = [tuple(map(lambda c: c * 255, matplotlib.colors.colorConverter.to_rgb(prop['color'])[::-1])) for prop in plt.rcParams['axes.prop_cycle']]
43 | self.radius = radius
44 | self.thickness = thickness
45 | self.line_type = line_type
46 | self.shift = shift
47 | self.font_face = font_face
48 | self.font_scale = font_scale
49 | self.z = z
50 |
51 | def __call__(self, image, points, debug=False):
52 | if len(self.colors) >= 2:
53 | for i, point in enumerate(points):
54 | y, x, v = map(int, point)
55 | assert v >= 0
56 | if v > 0:
57 | text = str(i)
58 | color = self.colors[v - 1]
59 | _color = tuple(map(lambda c: np.float(np.bitwise_not(np.uint8(c))), color))
60 | cv2.putText(image, text, (x - self.z, y), self.font_face, self.font_scale, _color)
61 | cv2.putText(image, text, (x + self.z, y), self.font_face, self.font_scale, _color)
62 | cv2.putText(image, text, (x, y - self.z), self.font_face, self.font_scale, _color)
63 | cv2.putText(image, text, (x, y + self.z), self.font_face, self.font_scale, _color)
64 | cv2.putText(image, text, (x, y), self.font_face, self.font_scale, color)
65 | if len(self.limbs_index) > 0:
66 | for color, (i1, i2) in zip(itertools.cycle(self._colors), self.limbs_index):
67 | y1, x1, v1 = points[i1].T
68 | y2, x2, v2 = points[i2].T
69 | if v1 > 0 and v2 > 0:
70 | cv2.line(image, (x1, y1), (x2, y2), color, thickness=self.thickness)
71 | else:
72 | for color, (y, x, v) in zip(itertools.cycle(self._colors), points):
73 | if v > 0:
74 | cv2.circle(image, (x, y), self.radius, color, thickness=-1)
75 | if debug:
76 | cv2.imshow('', image)
77 | cv2.waitKey(0)
78 | return image
79 |
80 |
81 | class DrawBBox(object):
82 | def __init__(self, category=None, colors=[], thickness=1, line_type=cv2.LINE_8, shift=0, font_face=cv2.FONT_HERSHEY_SIMPLEX, font_scale=1):
83 | self.category = category
84 | if colors:
85 | self.colors = [tuple(map(lambda c: c * 255, matplotlib.colors.colorConverter.to_rgb(c)[::-1])) for c in colors]
86 | else:
87 | self.colors = [tuple(map(lambda c: c * 255, matplotlib.colors.colorConverter.to_rgb(prop['color'])[::-1])) for prop in plt.rcParams['axes.prop_cycle']]
88 | self.thickness = thickness
89 | self.line_type = line_type
90 | self.shift = shift
91 | self.font_face = font_face
92 | self.font_scale = font_scale
93 |
94 | def __call__(self, image, yx_min, yx_max, cls=None, colors=None, debug=False):
95 | colors = self.colors if colors is None else [tuple(map(lambda c: c * 255, matplotlib.colors.colorConverter.to_rgb(c)[::-1])) for c in colors]
96 | if cls is None:
97 | cls = [None] * len(yx_min)
98 | for color, (ymin, xmin), (ymax, xmax), cls in zip(itertools.cycle(colors), yx_min, yx_max, cls):
99 | try:
100 | cv2.rectangle(image, (xmin, ymin), (xmax, ymax), color, thickness=self.thickness, lineType=self.line_type, shift=self.shift)
101 | if self.category is not None and cls is not None:
102 | cv2.putText(image, self.category[cls], (xmin, ymin), self.font_face, self.font_scale, color=color, thickness=self.thickness)
103 | except OverflowError as e:
104 | logging.warning(e, (xmin, ymin), (xmax, ymax))
105 | if debug:
106 | cv2.imshow('', image)
107 | cv2.waitKey(0)
108 | return image
109 |
110 |
111 | class DrawFeature(object):
112 | def __init__(self, alpha=0.5, cmap=None):
113 | self.alpha = alpha
114 | self.cm = matplotlib.cm.get_cmap(cmap)
115 |
116 | def __call__(self, image, feature, debug=False):
117 | _feature = (feature * self.cm.N).astype(np.int)
118 | heatmap = self.cm(_feature)[:, :, :3] * 255
119 | heatmap = cv2.resize(heatmap, image.shape[1::-1], interpolation=cv2.INTER_NEAREST)
120 | canvas = (image * (1 - self.alpha) + heatmap * self.alpha).astype(np.uint8)
121 | if debug:
122 | cv2.imshow('max=%f, sum=%f' % (np.max(feature), np.sum(feature)), canvas)
123 | cv2.waitKey(0)
124 | return canvas
125 |
126 |
127 | class DrawCluster(object):
128 | def __init__(self, colors=[], thickness=2, line_type=cv2.LINE_8, shift=0, font_face=cv2.FONT_HERSHEY_SIMPLEX, font_scale=0.5, z=1):
129 | if colors:
130 | self.colors = [tuple(map(lambda c: c * 255, matplotlib.colors.colorConverter.to_rgb(c)[::-1])) for c in colors]
131 | else:
132 | self.colors = [tuple(map(lambda c: c * 255, matplotlib.colors.colorConverter.to_rgb(prop['color'])[::-1])) for prop in plt.rcParams['axes.prop_cycle']]
133 | self.thickness = thickness
134 | self.line_type = line_type
135 | self.shift = shift
136 | self.font_face = font_face
137 | self.font_scale = font_scale
138 | self.z = z
139 |
140 | def __call__(self, image, cluster, debug=False):
141 | for color, limb in zip(self.colors, cluster):
142 | (i1, y1, x1), (i2, y2, x2) = limb
143 | cv2.line(image, (x1, y1), (x2, y2), color, thickness=self.thickness)
144 | drawn = set()
145 | for (i, y, x) in limb:
146 | if i not in drawn:
147 | drawn.add(i)
148 | text = str(i)
149 | _color = tuple(map(lambda c: np.float(np.bitwise_not(np.uint8(c))), color))
150 | cv2.putText(image, text, (x - self.z, y), self.font_face, self.font_scale, _color)
151 | cv2.putText(image, text, (x + self.z, y), self.font_face, self.font_scale, _color)
152 | cv2.putText(image, text, (x, y - self.z), self.font_face, self.font_scale, _color)
153 | cv2.putText(image, text, (x, y + self.z), self.font_face, self.font_scale, _color)
154 | cv2.putText(image, text, (x, y), self.font_face, self.font_scale, color)
155 | if debug:
156 | cv2.imshow('', image)
157 | cv2.waitKey(0)
158 | return image
159 |
160 |
161 | class Graph(object):
162 | def __init__(self, config, state_dict, cmap=None):
163 | self.dot = graphviz.Digraph(node_attr=dict(config.items('digraph_node_attr')), graph_attr=dict(config.items('digraph_graph_attr')))
164 | self.dot.format = config.get('graph', 'format')
165 | self.state_dict = state_dict
166 | self.var_name = {t._cdata: k for k, t in state_dict.items()}
167 | self.seen = set()
168 | self.index = 0
169 | self.drawn = set()
170 | self.cm = matplotlib.cm.get_cmap(cmap)
171 | self.metric = eval(config.get('graph', 'metric'))
172 | metrics = [self.metric(t) for t in state_dict.values()]
173 | self.minmax = [min(metrics), max(metrics)]
174 |
175 | def __call__(self, node):
176 | if node not in self.seen:
177 | self.traverse_next(node)
178 | self.traverse_tensor(node)
179 | self.seen.add(node)
180 | self.index += 1
181 |
182 | def traverse_next(self, node):
183 | if hasattr(node, 'next_functions'):
184 | for n, _ in node.next_functions:
185 | if n is not None:
186 | self.__call__(n)
187 | self._draw_node_edge(node, n)
188 | self._draw_node(node)
189 |
190 | def traverse_tensor(self, node):
191 | tensors = [t for name, t in inspect.getmembers(node) if torch.is_tensor(t)]
192 | if hasattr(node, 'saved_tensors'):
193 | tensors += node.saved_tensors
194 | for tensor in tensors:
195 | name = self.var_name[tensor._cdata]
196 | self.drawn.add(name)
197 | self._draw_tensor(node, tensor)
198 |
199 | def _draw_node(self, node):
200 | if hasattr(node, 'variable'):
201 | tensor = node.variable.data
202 | name = self.var_name[tensor._cdata]
203 | label = '\n'.join(map(str, [
204 | '%d: %s' % (self.index, name),
205 | list(tensor.size()),
206 | humanize.naturalsize(tensor.numpy().nbytes),
207 | ]))
208 | fillcolor, fontcolor = self._tensor_color(tensor)
209 | self.dot.node(str(id(node)), label, shape='note', fillcolor=fillcolor, fontcolor=fontcolor)
210 | self.drawn.add(name)
211 | else:
212 | self.dot.node(str(id(node)), '%d: %s' % (self.index, type(node).__name__), fillcolor='white')
213 |
214 | def _draw_node_edge(self, node, n):
215 | if hasattr(n, 'variable'):
216 | self.dot.edge(str(id(n)), str(id(node)), arrowhead='none', arrowtail='none')
217 | else:
218 | self.dot.edge(str(id(n)), str(id(node)))
219 |
220 | def _draw_tensor(self, node, tensor):
221 | name = self.var_name[tensor._cdata]
222 | label = '\n'.join(map(str, [
223 | name,
224 | list(tensor.size()),
225 | humanize.naturalsize(tensor.numpy().nbytes),
226 | ]))
227 | fillcolor, fontcolor = self._tensor_color(tensor)
228 | self.dot.node(name, label, style='filled, rounded', fillcolor=fillcolor, fontcolor=fontcolor)
229 | self.dot.edge(name, str(id(node)), style='dashed', arrowhead='none', arrowtail='none')
230 |
231 | def _tensor_color(self, tensor):
232 | level = self._norm(self.metric(tensor))
233 | fillcolor = self.cm(np.int(level * self.cm.N))
234 | fontcolor = self.cm(self.cm.N if level < 0.5 else 0)
235 | return matplotlib.colors.to_hex(fillcolor), matplotlib.colors.to_hex(fontcolor)
236 |
237 | def _norm(self, metric):
238 | min, max = self.minmax
239 | assert min <= metric <= max, (metric, self.minmax)
240 | if min < max:
241 | return (metric - min) / (max - min)
242 | else:
243 | return metric
244 |
--------------------------------------------------------------------------------
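DrawFeature maps a confidence map in [0, 1] through a matplotlib colormap, resizes it to the image resolution, and alpha-blends it over the image, which is convenient for inspecting part heatmaps. A small sketch on synthetic data; it assumes the repository root, its pyopenpose extension, and the dependency versions the code was written against (for example a numpy that still provides np.int) are available:

import numpy as np

from utils.visualize import DrawFeature

image = np.zeros((128, 128, 3), np.uint8)
feature = np.zeros((32, 32), np.float32)
feature[16, 16] = 1.0                        # a single confidence peak
canvas = DrawFeature(alpha=0.5)(image, feature)
print(canvas.shape, canvas.dtype)            # (128, 128, 3) uint8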