├── .gitattributes
├── .gitignore
├── LICENSE
├── README.md
├── common.py
├── datum_pb2.py
├── images
│   ├── person2.jpg
│   └── person3.jpg
├── inference.py
├── models
│   └── numpy
│       └── openpose_coco.npy
├── network_base.py
├── network_cmu.py
├── network_dsconv.py
├── network_mobilenet.py
├── pose_augment.py
├── pose_dataset.py
└── requirements.txt
/.gitattributes:
--------------------------------------------------------------------------------
1 | models/numpy/* filter=lfs diff=lfs merge=lfs -text
2 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 |
49 | # Translations
50 | *.mo
51 | *.pot
52 |
53 | # Django stuff:
54 | *.log
55 | local_settings.py
56 |
57 | # Flask stuff:
58 | instance/
59 | .webassets-cache
60 |
61 | # Scrapy stuff:
62 | .scrapy
63 |
64 | # Sphinx documentation
65 | docs/_build/
66 |
67 | # PyBuilder
68 | target/
69 |
70 | # Jupyter Notebook
71 | .ipynb_checkpoints
72 |
73 | # pyenv
74 | .python-version
75 |
76 | # celery beat schedule file
77 | celerybeat-schedule
78 |
79 | # SageMath parsed files
80 | *.sage.py
81 |
82 | # dotenv
83 | .env
84 |
85 | # virtualenv
86 | .venv
87 | venv/
88 | ENV/
89 |
90 | # Spyder project settings
91 | .spyderproject
92 | .spyproject
93 |
94 | # Rope project settings
95 | .ropeproject
96 |
97 | # mkdocs documentation
98 | /site
99 |
100 | # mypy
101 | .mypy_cache/
102 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "{}"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright {yyyy} {name of copyright owner}
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # tf-openpose
2 |
3 | Openpose from CMU implemented using Tensorflow. It also provides several variants that have made some changes to the network structure for **real-time processing on the CPU.**
4 |
5 | Original Repo(Caffe) : https://github.com/CMU-Perceptual-Computing-Lab/openpose
6 |
7 | **Features**
8 |
9 | - [x] CMU's original network architecture and weights.
10 |
11 | - [ ] Post processing from network output.
12 |
13 | - [ ] Faster network variants using mobilenet, lcnn architecture.
14 |
15 | - [ ] ROS Support.
16 |
17 | ## Install
18 |
19 | You need dependencies below.
20 |
21 | - python3
22 |
23 | - tensorflow 1.3
24 |
25 | - opencv 3
26 |
27 | - protobuf
28 |
29 | ## Models
30 |
31 | ### Inference Time
32 |
33 | | Dataset | Model | Description | Inference Time<br/>1 core cpu |
34 | |---------|--------------------------|------------------------------------------------------------------------------------------|---------------:|
35 | | Coco | cmu | CMU's original version. Same network, same weights. | 3.65s / img |
36 | | Coco | dsconv | Same as the cmu version except for the **depthwise separable convolution** of mobilenet. | 0.44s / img |
37 | | Coco | mobilenet | | | |
38 | | Coco | lcnn | | | |
39 |
40 |
41 | ## Training
42 |
43 | CMU Perceptual Computing Lab has modified Caffe to provide data augmentation.
44 |
45 | This includes
46 |
47 | - scale : 0.7 ~ 1.3
48 |
49 | - rotation : -40 ~ 40 degrees
50 |
51 | - flip
52 |
53 | - cropping
54 |
55 | See : https://github.com/CMU-Perceptual-Computing-Lab/caffe_train
56 |
57 |
58 |
59 | ## References
60 |
61 | ### OpenPose
62 |
63 | [1] https://github.com/CMU-Perceptual-Computing-Lab/openpose
64 |
65 | [2] Training Codes : https://github.com/ZheC/Realtime_Multi-Person_Pose_Estimation
66 |
67 | [3] Custom Caffe by Openpose : https://github.com/CMU-Perceptual-Computing-Lab/caffe_train
68 |
69 | ### Mobilenet
70 |
71 | [1] Pretrained model : https://github.com/tensorflow/models/blob/master/slim/nets/mobilenet_v1.md
72 |
73 | ### Libraries
74 |
75 | [1] Tensorpack : https://github.com/ppwwyyxx/tensorpack
--------------------------------------------------------------------------------
/common.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | from enum import Enum
3 | import math
4 | import logging
5 |
6 | import numpy as np
7 | import itertools
8 | import tensorflow as tf
9 | from scipy.ndimage.filters import maximum_filter
10 |
11 | logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s')
12 |
13 |
14 | class CocoPart(Enum):
15 | Nose = 0
16 | Neck = 1
17 | RShoulder = 2
18 | RElbow = 3
19 | RWrist = 4
20 | LShoulder = 5
21 | LElbow = 6
22 | LWrist = 7
23 | RHip = 8
24 | RKnee = 9
25 | RAnkle = 10
26 | LHip = 11
27 | LKnee = 12
28 | LAnkle = 13
29 | REye = 14
30 | LEye = 15
31 | REar = 16
32 | LEar = 17
33 | Background = 18
34 |
35 | CocoPairs = [
36 | (1, 2), (1, 5), (2, 3), (3, 4), (5, 6), (6, 7), (1, 8), (8, 9), (9, 10), (1, 11),
37 | (11, 12), (12, 13), (1, 0), (0, 14), (14, 16), (0, 15), (15, 17), (2, 16), (5, 17)
38 | ] # = 19
39 | CocoPairsRender = CocoPairs[:-2]
40 | CocoPairsNetwork = [
41 | (12, 13), (20, 21), (14, 15), (16, 17), (22, 23), (24, 25), (0, 1), (2, 3), (4, 5),
42 | (6, 7), (8, 9), (10, 11), (28, 29), (30, 31), (34, 35), (32, 33), (36, 37), (18, 19), (26, 27)
43 | ] # = 19
44 |
45 | NMS_Threshold = 0.05
46 | InterMinAbove_Threshold = 6
47 | Inter_Threashold = 0.05
48 | Min_Subset_Cnt = 3
49 | Min_Subset_Score = 0.4
50 | Max_Human = 96
51 |
52 |
53 | def non_max_suppression_tf(np_input, window_size=3, threshold=NMS_Threshold):
54 | # input: B x W x H x C
55 | under_threshold_indices = np_input < threshold
56 | np_input[under_threshold_indices] = 0
57 | np_input = np_input.reshape([1, 30, 40, 1])
58 | pooled = tf.nn.max_pool(np_input, ksize=[1, window_size, window_size, 1], strides=[1, 1, 1, 1], padding='SAME')
59 | output = tf.where(tf.equal(np_input, pooled), np_input, tf.zeros_like(np_input))
60 | # NOTE: if input has negative values, the suppressed values can be higher than original
61 | return output.eval().reshape([30, 40]) # output: B X W X H x C
62 |
63 |
64 | def non_max_suppression_scipy(np_input, window_size=3, threshold=NMS_Threshold):
65 | under_threshold_indices = np_input < threshold
66 | np_input[under_threshold_indices] = 0
67 | return np_input*(np_input == maximum_filter(np_input, footprint=np.ones((window_size, window_size))))
68 |
69 |
70 | non_max_suppression = non_max_suppression_scipy
71 |
72 |
73 | def estimate_pose(heatMat, pafMat):
74 | logging.debug('nms')
75 | coords = []
76 | for plain in heatMat:
77 | nms = non_max_suppression(plain, 3, NMS_Threshold)
78 | coords.append(np.where(nms >= NMS_Threshold))
79 |
80 | logging.debug('estimate_pose1')
81 | connection_all = []
82 | for (idx1, idx2), (paf_x_idx, paf_y_idx) in zip(CocoPairs, CocoPairsNetwork):
83 | connection = estimate_pose_pair(coords, idx1, idx2, pafMat[paf_x_idx], pafMat[paf_y_idx])
84 | connection_all.extend(connection)
85 |
86 | logging.debug('estimate_pose2, connection=%d' % len(connection_all))
87 | connection_by_human = dict()
88 | for idx, c in enumerate(connection_all):
89 | connection_by_human['human_%d' % idx] = [c]
90 |
91 | no_merge_cache = defaultdict(list)
92 | while True:
93 | is_merged = False
94 | for k1, k2 in itertools.combinations(connection_by_human.keys(), 2):
95 | if k1 == k2:
96 | continue
97 | if k2 in no_merge_cache[k1]:
98 | continue
99 | for c1, c2 in itertools.product(connection_by_human[k1], connection_by_human[k2]):
100 | if len(set(c1['uPartIdx']) & set(c2['uPartIdx'])) > 0:
101 | is_merged = True
102 | connection_by_human[k1].extend(connection_by_human[k2])
103 | connection_by_human.pop(k2)
104 | break
105 | if is_merged:
106 | no_merge_cache.pop(k1, None)
107 | break
108 | else:
109 | no_merge_cache[k1].append(k2)
110 |
111 | if not is_merged:
112 | break
113 |
114 | logging.debug('estimate_pose3')
115 |
116 | # reject by subset count
117 | connection_by_human = {k: v for (k, v) in connection_by_human.items() if len(v) >= Min_Subset_Cnt}
118 |
119 | # reject by subset max score
120 | connection_by_human = {k: v for (k, v) in connection_by_human.items() if max([ii['score'] for ii in v]) >= Min_Subset_Score}
121 |
122 | logging.debug('estimate_pose4')
123 | return connection_by_human
124 |
125 |
126 | def estimate_pose_pair(coords, partIdx1, partIdx2, pafMatX, pafMatY):
127 | connection_temp = []
128 | peak_coord1, peak_coord2 = coords[partIdx1], coords[partIdx2]
129 |
130 | cnt = 0
131 | for idx1, (y1, x1) in enumerate(zip(peak_coord1[0], peak_coord1[1])):
132 | for idx2, (y2, x2) in enumerate(zip(peak_coord2[0], peak_coord2[1])):
133 | score, count = get_score(x1, y1, x2, y2, pafMatX, pafMatY)
134 | cnt += 1
135 | if count < InterMinAbove_Threshold or score <= 0.0:
136 | continue
137 | connection_temp.append({
138 | 'score': score,
139 | 'c1': (x1, y1),
140 | 'c2': (x2, y2),
141 | 'idx': (idx1, idx2),
142 | 'partIdx': (partIdx1, partIdx2),
143 | 'uPartIdx': ('{}-{}-{}'.format(x1, y1, partIdx1), '{}-{}-{}'.format(x2, y2, partIdx2))
144 | })
145 |
146 | connection = []
147 | used_idx1, used_idx2 = [], []
148 | for candidate in sorted(connection_temp, key=lambda x: x['score'], reverse=True):
149 | # check not connected
150 | if candidate['idx'][0] in used_idx1 or candidate['idx'][1] in used_idx2:
151 | continue
152 | connection.append(candidate)
153 | used_idx1.append(candidate['idx'][0])
154 | used_idx2.append(candidate['idx'][1])
155 |
156 | return connection
157 |
158 |
159 | def get_score(x1, y1, x2, y2, pafMatX, pafMatY):
160 | __num_inter = 10
161 | __num_inter_f = float(__num_inter)
162 | dx, dy = x2 - x1, y2 - y1
163 | normVec = math.sqrt(dx ** 2 + dy ** 2)
164 |
165 | if normVec < 1e-6:
166 | return 0.0, 0
167 |
168 | vx, vy = dx / normVec, dy / normVec
169 |
170 | xs = np.arange(x1, x2, dx / __num_inter_f) if x1 != x2 else [x1] * __num_inter
171 | ys = np.arange(y1, y2, dy / __num_inter_f) if y1 != y2 else [y1] * __num_inter
172 | pafXs = np.zeros(__num_inter)
173 | pafYs = np.zeros(__num_inter)
174 | for idx, (mx, my) in enumerate(zip(xs, ys)):
175 | mx, my = int(mx + 0.5), int(my + 0.5)
176 | mx, my = max(mx, 0), max(my, 0)
177 |
178 | pafXs[idx] = pafMatX[my][mx]
179 | pafYs[idx] = pafMatY[my][mx]
180 |
181 | local_scores = pafXs * vx + pafYs * vy
182 | thidxs = local_scores > Inter_Threashold
183 |
184 | return sum(local_scores*thidxs), sum(thidxs)
185 |
--------------------------------------------------------------------------------
/datum_pb2.py:
--------------------------------------------------------------------------------
1 | # Generated by the protocol buffer compiler. DO NOT EDIT!
2 | # source: datum.proto
3 |
4 | import sys
5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
6 | from google.protobuf import descriptor as _descriptor
7 | from google.protobuf import message as _message
8 | from google.protobuf import reflection as _reflection
9 | from google.protobuf import symbol_database as _symbol_database
10 | from google.protobuf import descriptor_pb2
11 | # @@protoc_insertion_point(imports)
12 |
13 | _sym_db = _symbol_database.Default()
14 |
15 |
16 |
17 |
18 | DESCRIPTOR = _descriptor.FileDescriptor(
19 | name='datum.proto',
20 | package='',
21 | serialized_pb=_b('\n\x0b\x64\x61tum.proto\"\x81\x01\n\x05\x44\x61tum\x12\x10\n\x08\x63hannels\x18\x01 \x01(\x05\x12\x0e\n\x06height\x18\x02 \x01(\x05\x12\r\n\x05width\x18\x03 \x01(\x05\x12\x0c\n\x04\x64\x61ta\x18\x04 \x01(\x0c\x12\r\n\x05label\x18\x05 \x01(\x05\x12\x12\n\nfloat_data\x18\x06 \x03(\x02\x12\x16\n\x07\x65ncoded\x18\x07 \x01(\x08:\x05\x66\x61lse')
22 | )
23 | _sym_db.RegisterFileDescriptor(DESCRIPTOR)
24 |
25 |
26 |
27 |
28 | _DATUM = _descriptor.Descriptor(
29 | name='Datum',
30 | full_name='Datum',
31 | filename=None,
32 | file=DESCRIPTOR,
33 | containing_type=None,
34 | fields=[
35 | _descriptor.FieldDescriptor(
36 | name='channels', full_name='Datum.channels', index=0,
37 | number=1, type=5, cpp_type=1, label=1,
38 | has_default_value=False, default_value=0,
39 | message_type=None, enum_type=None, containing_type=None,
40 | is_extension=False, extension_scope=None,
41 | options=None),
42 | _descriptor.FieldDescriptor(
43 | name='height', full_name='Datum.height', index=1,
44 | number=2, type=5, cpp_type=1, label=1,
45 | has_default_value=False, default_value=0,
46 | message_type=None, enum_type=None, containing_type=None,
47 | is_extension=False, extension_scope=None,
48 | options=None),
49 | _descriptor.FieldDescriptor(
50 | name='width', full_name='Datum.width', index=2,
51 | number=3, type=5, cpp_type=1, label=1,
52 | has_default_value=False, default_value=0,
53 | message_type=None, enum_type=None, containing_type=None,
54 | is_extension=False, extension_scope=None,
55 | options=None),
56 | _descriptor.FieldDescriptor(
57 | name='data', full_name='Datum.data', index=3,
58 | number=4, type=12, cpp_type=9, label=1,
59 | has_default_value=False, default_value=_b(""),
60 | message_type=None, enum_type=None, containing_type=None,
61 | is_extension=False, extension_scope=None,
62 | options=None),
63 | _descriptor.FieldDescriptor(
64 | name='label', full_name='Datum.label', index=4,
65 | number=5, type=5, cpp_type=1, label=1,
66 | has_default_value=False, default_value=0,
67 | message_type=None, enum_type=None, containing_type=None,
68 | is_extension=False, extension_scope=None,
69 | options=None),
70 | _descriptor.FieldDescriptor(
71 | name='float_data', full_name='Datum.float_data', index=5,
72 | number=6, type=2, cpp_type=6, label=3,
73 | has_default_value=False, default_value=[],
74 | message_type=None, enum_type=None, containing_type=None,
75 | is_extension=False, extension_scope=None,
76 | options=None),
77 | _descriptor.FieldDescriptor(
78 | name='encoded', full_name='Datum.encoded', index=6,
79 | number=7, type=8, cpp_type=7, label=1,
80 | has_default_value=True, default_value=False,
81 | message_type=None, enum_type=None, containing_type=None,
82 | is_extension=False, extension_scope=None,
83 | options=None),
84 | ],
85 | extensions=[
86 | ],
87 | nested_types=[],
88 | enum_types=[
89 | ],
90 | options=None,
91 | is_extendable=False,
92 | extension_ranges=[],
93 | oneofs=[
94 | ],
95 | serialized_start=16,
96 | serialized_end=145,
97 | )
98 |
99 | DESCRIPTOR.message_types_by_name['Datum'] = _DATUM
100 |
101 | Datum = _reflection.GeneratedProtocolMessageType('Datum', (_message.Message,), dict(
102 | DESCRIPTOR = _DATUM,
103 | __module__ = 'datum_pb2'
104 | # @@protoc_insertion_point(class_scope:Datum)
105 | ))
106 | _sym_db.RegisterMessage(Datum)
107 |
108 |
109 | # @@protoc_insertion_point(module_scope)
110 |
--------------------------------------------------------------------------------
/images/person2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zumbalamambo/tf-openpose/4ddee7f516dd95c949dbd9f7f783d064b7b477e2/images/person2.jpg
--------------------------------------------------------------------------------
/images/person3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zumbalamambo/tf-openpose/4ddee7f516dd95c949dbd9f7f783d064b7b477e2/images/person3.jpg
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import tensorflow as tf
3 | import cv2
4 | import numpy as np
5 | import time
6 | import logging
7 | import argparse
8 |
9 | from tensorflow.python.client import timeline
10 |
11 | from network_cmu import CmuNetwork
12 | from common import estimate_pose, CocoPairsRender
13 |
14 | logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s')
15 |
16 | config = tf.ConfigProto()
17 | config.gpu_options.allocator_type = 'BFC'
18 | config.gpu_options.per_process_gpu_memory_fraction = 0.95
19 | config.gpu_options.allow_growth = True
20 |
21 |
22 | if __name__ == '__main__':
23 | parser = argparse.ArgumentParser(description='Tensorflow Openpose Inference')
24 | parser.add_argument('--imgpath', type=str, default='./images/person3.jpg')
25 | parser.add_argument('--input-width', type=int, default=320)
26 | parser.add_argument('--input-height', type=int, default=240)
27 | parser.add_argument('--stage-level', type=int, default=6)
28 | parser.add_argument('--model', type=str, default='kakao', help='cmu(original) / kakao(faster version)')
29 | args = parser.parse_args()
30 |
31 | input_node = tf.placeholder(tf.float32, shape=(None, args.input_height, args.input_width, 3), name='image')
32 |
33 | with tf.Session(config=config) as sess:
34 | if args.model == 'kakao':
35 | net = KakaoNetwork({'image': input_node}, trainable=False)
36 | net.load('./models/numpy/fastopenpose_coco_v170729.npy', sess)
37 | elif args.model == 'cmu':
38 | net = CmuNetwork({'image': input_node}, trainable=False)
39 | net.load('./models/numpy/openpose_coco.npy', sess)
40 | else:
41 | raise Exception('Invalid Mode.')
42 |
43 | logging.debug('read image+')
44 | image = cv2.imread(args.imgpath)
45 | image = cv2.resize(image, (args.input_width, args.input_height))
46 | image = image.astype(float)
47 | image -= 128.0
48 | image /= 128.0
49 |
50 | vec = sess.run(net.get_output(name='concat_stage7'), feed_dict={'image:0': [image]})
51 |
52 | logging.debug('inference+')
53 | a = time.time()
54 | pafMat, heatMat = sess.run(
55 | [
56 | net.get_output(name='Mconv7_stage{}_L1'.format(args.stage_level)),
57 | net.get_output(name='Mconv7_stage{}_L2'.format(args.stage_level))
58 | ], feed_dict={'image:0': [image]}
59 | )
60 | logging.info('inference- elapsed_time={}'.format(time.time() - a))
61 | a = time.time()
62 | pafMat, heatMat = sess.run(
63 | [
64 | net.get_output(name='Mconv7_stage{}_L1'.format(args.stage_level)),
65 | net.get_output(name='Mconv7_stage{}_L2'.format(args.stage_level))
66 | ], feed_dict={'image:0': [image]}
67 | )
68 | logging.info('inference- elapsed_time={}'.format(time.time() - a))
69 | a = time.time()
70 |
71 | run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
72 | run_metadata = tf.RunMetadata()
73 | pafMat, heatMat = sess.run(
74 | [
75 | net.get_output(name='Mconv7_stage{}_L1'.format(args.stage_level)),
76 | net.get_output(name='Mconv7_stage{}_L2'.format(args.stage_level))
77 | ], feed_dict={'image:0': [image]}, options=run_options, run_metadata=run_metadata
78 | )
79 | logging.info('inference- elapsed_time={}'.format(time.time() - a))
80 |
81 | tl = timeline.Timeline(run_metadata.step_stats)
82 | ctf = tl.generate_chrome_trace_format()
83 | with open('timeline.json', 'w') as f:
84 | f.write(ctf)
85 |
86 | heatMat, pafMat = heatMat[0], pafMat[0]
87 | heatMat = np.rollaxis(heatMat, 2, 0)
88 | pafMat = np.rollaxis(pafMat, 2, 0)
89 |
90 | logging.info('pickle data')
91 | with open('heatmat.pickle', 'wb') as pickle_file:
92 | pickle.dump(heatMat, pickle_file, pickle.HIGHEST_PROTOCOL)
93 | with open('pafmat.pickle', 'wb') as pickle_file:
94 | pickle.dump(pafMat, pickle_file, pickle.HIGHEST_PROTOCOL)
95 |
96 | logging.info('pose+')
97 | a = time.time()
98 | humans = estimate_pose(heatMat, pafMat)
99 | logging.info('pose- elapsed_time={}'.format(time.time() - a))
100 |
101 | # display
102 | image = cv2.imread(args.imgpath)
103 | image_h, image_w = image.shape[:2]
104 | heat_h, heat_w = heatMat[0].shape[:2]
105 | for _, human in humans.items():
106 | for part in human:
107 | if part['partIdx'] not in CocoPairsRender:
108 | continue
109 | center1 = (int((part['c1'][0] + 0.5) * image_w / heat_w), int((part['c1'][1] + 0.5) * image_h / heat_h))
110 | center2 = (int((part['c2'][0] + 0.5) * image_w / heat_w), int((part['c2'][1] + 0.5) * image_h / heat_h))
111 | cv2.circle(image, center1, 2, (255, 0, 0), thickness=3, lineType=8, shift=0)
112 | cv2.circle(image, center2, 2, (255, 0, 0), thickness=3, lineType=8, shift=0)
113 | image = cv2.line(image, center1, center2, (255, 0, 0), 1)
114 | cv2.imshow('result', image)
115 | cv2.waitKey(0)
116 |
--------------------------------------------------------------------------------
/models/numpy/openpose_coco.npy:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:7dc09d77ffa32a3b5a495d2f5bda75fa60941a90847c800600142437ef606a84
3 | size 209255398
4 |
--------------------------------------------------------------------------------
/network_base.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import numpy as np
3 | import tensorflow as tf
4 | import tensorflow.contrib.slim as slim
5 |
6 | DEFAULT_PADDING = 'SAME'
7 |
8 |
def layer(op):
    '''
    Decorator for composable network layers.

    Wraps a layer op so that it (a) auto-names itself, (b) pulls its input
    from the network's current terminal node(s), (c) registers its output in
    the layer LUT, and (d) returns the network object for call chaining.
    '''

    def layer_decorated(self, *args, **kwargs):
        # Auto-generate a name unless the caller supplied one.
        name = kwargs.setdefault('name', self.get_unique_name(op.__name__))
        # A layer must have at least one input terminal.
        if not self.terminals:
            raise RuntimeError('No input variables found for layer %s.' % name)
        # Single terminal is passed as-is; multiple terminals as a list.
        layer_input = self.terminals[0] if len(self.terminals) == 1 else list(self.terminals)
        # Run the op, register its output, and make it the next layer's input.
        layer_output = op(self, layer_input, *args, **kwargs)
        self.layers[name] = layer_output
        self.feed(layer_output)
        # Returning self enables chained calls.
        return self

    return layer_decorated
34 |
35 |
class BaseNetwork(object):
    """Chainable TensorFlow graph builder (caffe-tensorflow style).

    Subclasses implement setup() and describe the graph with chained
    feed()/conv()/... calls; every @layer op registers its output in
    self.layers and becomes the input of the next op in the chain.
    """

    def __init__(self, inputs, trainable=True):
        # The input nodes for this network (name -> tensor).
        self.inputs = inputs
        # The current list of terminal nodes (inputs for the next op).
        self.terminals = []
        # Mapping from layer names to layer output tensors.
        self.layers = dict(inputs)
        # Pre-activation outputs of conv-type layers (see get_tensor()).
        self.tensor_before_relu = dict()
        # If true, the resulting variables are set as trainable.
        self.trainable = trainable
        # Switch variable for dropout (defaults to 1.0, i.e. inference mode).
        self.use_dropout = tf.placeholder_with_default(tf.constant(1.0),
                                                       shape=[],
                                                       name='use_dropout')
        self.setup()

    def setup(self):
        '''Construct the network. '''
        raise NotImplementedError('Must be implemented by the subclass.')

    def load(self, data_path, session, ignore_missing=False):
        '''
        Load network weights.
        data_path: The path to the numpy-serialized network weights
        session: The current TensorFlow session
        ignore_missing: If true, serialized weights for missing layers are ignored.
        '''
        # allow_pickle=True: the checkpoint is a pickled dict stored in a .npy
        # file, which numpy >= 1.16.3 refuses to load by default.
        data_dict = np.load(data_path, encoding='bytes', allow_pickle=True).item()
        for op_name in data_dict:
            if isinstance(data_dict[op_name], np.ndarray):
                # Flat layout: op_name is a full variable name (e.g. '.../weights:0').
                if 'RMSProp' in op_name:
                    # Optimizer slot variables are not model weights.
                    continue
                with tf.variable_scope('', reuse=True):
                    var = tf.get_variable(op_name.replace(':0', ''))
                    try:
                        session.run(var.assign(data_dict[op_name]))
                    except Exception as e:
                        print(op_name)
                        print(e)
                        sys.exit(-1)
            else:
                # Nested layout: op_name is a scope; values are keyed by
                # byte-string parameter names ('weights', 'biases', ...).
                with tf.variable_scope(op_name, reuse=True):
                    for param_name, data in data_dict[op_name].items():
                        try:
                            var = tf.get_variable(param_name.decode("utf-8"))
                            session.run(var.assign(data))
                        except ValueError as e:
                            print(e)
                            if not ignore_missing:
                                raise

    def feed(self, *args):
        '''Set the input(s) for the next operation by replacing the terminal nodes.
        The arguments can be either layer names or the actual layers.
        '''
        assert len(args) != 0
        self.terminals = []
        for fed_layer in args:
            try:
                is_str = isinstance(fed_layer, basestring)  # Python 2
            except NameError:
                is_str = isinstance(fed_layer, str)  # Python 3
            if is_str:
                try:
                    fed_layer = self.layers[fed_layer]
                except KeyError:
                    raise KeyError('Unknown layer name fed: %s' % fed_layer)
            self.terminals.append(fed_layer)
        return self

    def get_output(self, name=None):
        '''Returns the current network output, or the named layer's output.'''
        if not name:
            return self.terminals[-1]
        else:
            return self.layers[name]

    def get_tensor(self, name):
        '''Return the pre-activation tensor for conv-type layers, otherwise
        the layer output.'''
        if 'conv' not in name:
            return self.get_output(name)
        return self.tensor_before_relu[name]

    def get_unique_name(self, prefix):
        '''Returns an index-suffixed unique name for the given prefix.
        This is used for auto-generating layer names based on the type-prefix.
        '''
        # Only keys are needed; no reason to iterate items().
        ident = sum(t.startswith(prefix) for t in self.layers) + 1
        return '%s_%d' % (prefix, ident)

    def make_var(self, name, shape, trainable=True):
        '''Creates a new TensorFlow variable; trainable only if both the
        network and this specific variable are marked trainable.'''
        # 'and' (not bitwise '&') is the correct boolean combination here.
        return tf.get_variable(name, shape, trainable=self.trainable and trainable,
                               initializer=tf.contrib.layers.xavier_initializer())

    def validate_padding(self, padding):
        '''Verifies that the padding is one of the supported ones.'''
        assert padding in ('SAME', 'VALID')

    @layer
    def separable_conv(self, input, k_h, k_w, c_o, stride, name, relu=True, depth_multiplier=1.0):
        '''Depthwise-separable convolution: depthwise (no pointwise) followed
        by a 1x1 pointwise conv, with optional ReLU.'''
        # num_outputs=None skips the pointwise step inside slim; it is done
        # explicitly below so the pre-activation tensor can be recorded.
        output = slim.separable_convolution2d(input,
                                              num_outputs=None,
                                              stride=stride,
                                              trainable=self.trainable,
                                              depth_multiplier=depth_multiplier,
                                              kernel_size=[k_h, k_w],
                                              activation_fn=None,
                                              weights_initializer=slim.initializers.xavier_initializer(),
                                              biases_initializer=slim.init_ops.zeros_initializer(),
                                              scope=name + '/dsc')

        output = slim.convolution2d(output,
                                    c_o,
                                    stride=1,
                                    kernel_size=[1, 1],
                                    activation_fn=None,
                                    weights_initializer=slim.initializers.xavier_initializer(),
                                    biases_initializer=slim.init_ops.zeros_initializer(),
                                    trainable=self.trainable,
                                    scope=name + '/pointwise_conv')

        # Record the pre-activation tensor for get_tensor().
        self.tensor_before_relu[name] = output

        if relu:
            output = tf.nn.relu(output, name=name + '/relu')
        return output

    @layer
    def conv(self,
             input,
             k_h,
             k_w,
             c_o,
             s_h,
             s_w,
             name,
             relu=True,
             padding=DEFAULT_PADDING,
             group=1,
             trainable=True,
             biased=True):
        '''2D convolution (optionally grouped) with optional bias and ReLU.'''
        # Verify that the padding is acceptable
        self.validate_padding(padding)
        # Get the number of channels in the input
        c_i = int(input.get_shape()[-1])
        # Verify that the grouping parameter is valid
        assert c_i % group == 0
        assert c_o % group == 0
        # Convolution for a given input and kernel
        convolve = lambda i, k: tf.nn.conv2d(i, k, [1, s_h, s_w, 1], padding=padding)
        with tf.variable_scope(name) as scope:
            # Integer division: on Python 3, '/' yields a float, which TF
            # rejects as a variable-shape dimension.
            kernel = self.make_var('weights', shape=[k_h, k_w, c_i // group, c_o],
                                   trainable=self.trainable and trainable)
            if group == 1:
                # This is the common-case. Convolve the input without any further complications.
                output = convolve(input, kernel)
            else:
                # Split the input into groups and then convolve each of them independently
                input_groups = tf.split(3, group, input)
                kernel_groups = tf.split(3, group, kernel)
                output_groups = [convolve(i, k) for i, k in zip(input_groups, kernel_groups)]
                # Concatenate the groups
                output = tf.concat(3, output_groups)
            # Add the biases
            if biased:
                biases = self.make_var('biases', [c_o], trainable=self.trainable and trainable)
                output = tf.nn.bias_add(output, biases)

            # Record the pre-activation tensor for get_tensor().
            self.tensor_before_relu[name] = output
            if relu:
                # ReLU non-linearity
                output = tf.nn.relu(output, name=scope.name)
            return output

    @layer
    def relu(self, input, name):
        '''ReLU activation as a standalone layer.'''
        return tf.nn.relu(input, name=name)

    @layer
    def max_pool(self, input, k_h, k_w, s_h, s_w, name, padding=DEFAULT_PADDING):
        '''Max pooling with kernel (k_h, k_w) and stride (s_h, s_w).'''
        self.validate_padding(padding)
        return tf.nn.max_pool(input,
                              ksize=[1, k_h, k_w, 1],
                              strides=[1, s_h, s_w, 1],
                              padding=padding,
                              name=name)

    @layer
    def avg_pool(self, input, k_h, k_w, s_h, s_w, name, padding=DEFAULT_PADDING):
        '''Average pooling with kernel (k_h, k_w) and stride (s_h, s_w).'''
        self.validate_padding(padding)
        return tf.nn.avg_pool(input,
                              ksize=[1, k_h, k_w, 1],
                              strides=[1, s_h, s_w, 1],
                              padding=padding,
                              name=name)

    @layer
    def lrn(self, input, radius, alpha, beta, name, bias=1.0):
        '''Local response normalization.'''
        return tf.nn.local_response_normalization(input,
                                                  depth_radius=radius,
                                                  alpha=alpha,
                                                  beta=beta,
                                                  bias=bias,
                                                  name=name)

    @layer
    def concat(self, inputs, axis, name):
        '''Concatenate the fed inputs along the given axis.'''
        return tf.concat(axis=axis, values=inputs, name=name)

    @layer
    def add(self, inputs, name):
        '''Element-wise sum of the fed inputs.'''
        return tf.add_n(inputs, name=name)

    @layer
    def fc(self, input, num_out, name, relu=True):
        '''Fully-connected layer; spatial inputs are flattened first.'''
        with tf.variable_scope(name) as scope:
            input_shape = input.get_shape()
            if input_shape.ndims == 4:
                # The input is spatial. Vectorize it first.
                dim = 1
                for d in input_shape[1:].as_list():
                    dim *= d
                feed_in = tf.reshape(input, [-1, dim])
            else:
                feed_in, dim = (input, input_shape[-1].value)
            weights = self.make_var('weights', shape=[dim, num_out])
            biases = self.make_var('biases', [num_out])
            op = tf.nn.relu_layer if relu else tf.nn.xw_plus_b
            fc = op(feed_in, weights, biases, name=scope.name)
            return fc

    @layer
    def softmax(self, input, name):
        '''Softmax over the last dimension; squeezes 1x1 spatial dims first.'''
        # Materialize the shape as a list: map() is lazy on Python 3, so the
        # original len(map(...)) raised TypeError.
        input_shape = [v.value for v in input.get_shape()]
        if len(input_shape) > 2:
            # For certain models (like NiN), the singleton spatial dimensions
            # need to be explicitly squeezed, since they're not broadcast-able
            # in TensorFlow's NHWC ordering (unlike Caffe's NCHW).
            if input_shape[1] == 1 and input_shape[2] == 1:
                input = tf.squeeze(input, squeeze_dims=[1, 2])
            else:
                raise ValueError('Rank 2 tensor input expected for softmax!')
        return tf.nn.softmax(input, name=name)

    @layer
    def batch_normalization(self, input, name, scale_offset=True, relu=False):
        # NOTE: Currently, only inference is supported
        with tf.variable_scope(name) as scope:
            shape = [input.get_shape()[-1]]
            if scale_offset:
                scale = self.make_var('scale', shape=shape)
                offset = self.make_var('offset', shape=shape)
            else:
                scale, offset = (None, None)
            output = tf.nn.batch_normalization(
                input,
                mean=self.make_var('mean', shape=shape),
                variance=self.make_var('variance', shape=shape),
                offset=offset,
                scale=scale,
                # TODO: This is the default Caffe batch norm eps
                # Get the actual eps from parameters
                variance_epsilon=1e-5,
                name=name)
            if relu:
                output = tf.nn.relu(output)
            return output

    @layer
    def dropout(self, input, keep_prob, name):
        '''Dropout; use_dropout=1.0 applies keep_prob, 0.0 disables dropout.'''
        keep = 1 - self.use_dropout + (self.use_dropout * keep_prob)
        return tf.nn.dropout(input, keep, name=name)
--------------------------------------------------------------------------------
/network_cmu.py:
--------------------------------------------------------------------------------
1 | import network_base
2 |
3 |
class CmuNetwork(network_base.BaseNetwork):
    """Original CMU OpenPose graph: a VGG-style feature extractor followed by
    six refinement stages, each producing PAFs (L1, 38ch) and part heatmaps
    (L2, 19ch)."""

    def setup(self):
        # Feature extractor plus the stage-1 L1 (PAF) branch.
        (self.feed('image')
         .conv(3, 3, 64, 1, 1, name='conv1_1')
         .conv(3, 3, 64, 1, 1, name='conv1_2')
         .max_pool(2, 2, 2, 2, name='pool1_stage1')
         .conv(3, 3, 128, 1, 1, name='conv2_1')
         .conv(3, 3, 128, 1, 1, name='conv2_2')
         .max_pool(2, 2, 2, 2, name='pool2_stage1')
         .conv(3, 3, 256, 1, 1, name='conv3_1')
         .conv(3, 3, 256, 1, 1, name='conv3_2')
         .conv(3, 3, 256, 1, 1, name='conv3_3')
         .conv(3, 3, 256, 1, 1, name='conv3_4')
         .max_pool(2, 2, 2, 2, name='pool3_stage1')
         .conv(3, 3, 512, 1, 1, name='conv4_1')
         .conv(3, 3, 512, 1, 1, name='conv4_2')
         .conv(3, 3, 256, 1, 1, name='conv4_3_CPM')
         .conv(3, 3, 128, 1, 1, name='conv4_4_CPM')  # feature map shared by every stage
         .conv(3, 3, 128, 1, 1, name='conv5_1_CPM_L1')
         .conv(3, 3, 128, 1, 1, name='conv5_2_CPM_L1')
         .conv(3, 3, 128, 1, 1, name='conv5_3_CPM_L1')
         .conv(1, 1, 512, 1, 1, name='conv5_4_CPM_L1')
         .conv(1, 1, 38, 1, 1, relu=False, name='conv5_5_CPM_L1'))

        # Stage-1 L2 (heatmap) branch, re-rooted at the shared feature map.
        (self.feed('conv4_4_CPM')
         .conv(3, 3, 128, 1, 1, name='conv5_1_CPM_L2')
         .conv(3, 3, 128, 1, 1, name='conv5_2_CPM_L2')
         .conv(3, 3, 128, 1, 1, name='conv5_3_CPM_L2')
         .conv(1, 1, 512, 1, 1, name='conv5_4_CPM_L2')
         .conv(1, 1, 19, 1, 1, relu=False, name='conv5_5_CPM_L2'))

        # Refinement stages 2-6 share the same topology; build them in a loop.
        prev_l1, prev_l2 = 'conv5_5_CPM_L1', 'conv5_5_CPM_L2'
        for stage in range(2, 7):
            concat_name = 'concat_stage%d' % stage
            (self.feed(prev_l1, prev_l2, 'conv4_4_CPM')
             .concat(3, name=concat_name))
            for branch, out_ch in (('L1', 38), ('L2', 19)):
                net = self.feed(concat_name)
                for i in range(1, 6):
                    net = net.conv(7, 7, 128, 1, 1,
                                   name='Mconv%d_stage%d_%s' % (i, stage, branch))
                (net.conv(1, 1, 128, 1, 1, name='Mconv6_stage%d_%s' % (stage, branch))
                 .conv(1, 1, out_ch, 1, 1, relu=False,
                       name='Mconv7_stage%d_%s' % (stage, branch)))
            prev_l1 = 'Mconv7_stage%d_L1' % stage
            prev_l2 = 'Mconv7_stage%d_L2' % stage

        # Final combined output: heatmaps first, then PAFs.
        (self.feed('Mconv7_stage6_L2',
                   'Mconv7_stage6_L1')
         .concat(3, name='concat_stage7'))
--------------------------------------------------------------------------------
/network_dsconv.py:
--------------------------------------------------------------------------------
1 | import network_base
2 | import tensorflow as tf
3 |
4 |
class KakaoNetwork(network_base.BaseNetwork):
    """OpenPose variant built mostly from depthwise-separable convolutions,
    with most channel counts scaled by a width multiplier (conv_width).
    Strided separable convs replace the max-pool layers of the CMU backbone."""

    def __init__(self, inputs, trainable=True, conv_width=1.0):
        # Width multiplier applied to most intermediate channel counts.
        self.conv_width = conv_width
        network_base.BaseNetwork.__init__(self, inputs, trainable)

    def _depth(self, channels):
        # Scale a nominal channel count by the width multiplier.
        return round(self.conv_width * channels)

    def setup(self):
        d = self._depth

        # Feature extractor plus the stage-1 L1 (PAF) branch.
        (self.feed('image')
         .conv(3, 3, 64, 1, 1, name='conv1_1', trainable=False)
         .separable_conv(3, 3, d(64), 2, name='conv1_2')
         .separable_conv(3, 3, d(128), 1, name='conv2_1')
         .separable_conv(3, 3, d(128), 2, name='conv2_2')
         .separable_conv(3, 3, d(256), 1, name='conv3_1')
         .separable_conv(3, 3, d(256), 1, name='conv3_2')
         .separable_conv(3, 3, d(256), 1, name='conv3_3')
         .separable_conv(3, 3, d(256), 2, name='conv3_4')
         .separable_conv(3, 3, d(512), 1, name='conv4_1')
         .separable_conv(3, 3, d(512), 1, name='conv4_2')
         .separable_conv(3, 3, d(256), 1, name='conv4_3_CPM')
         .separable_conv(3, 3, 128, 1, name='conv4_4_CPM')  # fixed width: shared by every stage
         .separable_conv(3, 3, d(128), 1, name='conv5_1_CPM_L1')
         .separable_conv(3, 3, d(128), 1, name='conv5_2_CPM_L1')
         .separable_conv(3, 3, d(128), 1, name='conv5_3_CPM_L1')
         .conv(1, 1, 512, 1, 1, name='conv5_4_CPM_L1')
         .conv(1, 1, 38, 1, 1, relu=False, name='conv5_5_CPM_L1'))

        # Stage-1 L2 (heatmap) branch, re-rooted at the shared feature map.
        (self.feed('conv4_4_CPM')
         .separable_conv(3, 3, d(128), 1, name='conv5_1_CPM_L2')
         .separable_conv(3, 3, d(128), 1, name='conv5_2_CPM_L2')
         .separable_conv(3, 3, d(128), 1, name='conv5_3_CPM_L2')
         .conv(1, 1, 512, 1, 1, name='conv5_4_CPM_L2')
         .conv(1, 1, 19, 1, 1, relu=False, name='conv5_5_CPM_L2'))

        # Refinement stages 2-6 share the same topology; build them in a loop.
        prev_l1, prev_l2 = 'conv5_5_CPM_L1', 'conv5_5_CPM_L2'
        for stage in range(2, 7):
            concat_name = 'concat_stage%d' % stage
            (self.feed(prev_l1, prev_l2, 'conv4_4_CPM')
             .concat(3, name=concat_name))
            for branch, out_ch in (('L1', 38), ('L2', 19)):
                net = self.feed(concat_name)
                for i in range(1, 6):
                    net = net.separable_conv(7, 7, d(128), 1,
                                             name='Mconv%d_stage%d_%s' % (i, stage, branch))
                (net.conv(1, 1, 128, 1, 1, name='Mconv6_stage%d_%s' % (stage, branch))
                 .conv(1, 1, out_ch, 1, 1, relu=False,
                       name='Mconv7_stage%d_%s' % (stage, branch)))
            prev_l1 = 'Mconv7_stage%d_L1' % stage
            prev_l2 = 'Mconv7_stage%d_L2' % stage

        # Final combined output: heatmaps first, then PAFs.
        (self.feed('Mconv7_stage6_L2',
                   'Mconv7_stage6_L1')
         .concat(3, name='concat_stage7'))
--------------------------------------------------------------------------------
/network_mobilenet.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zumbalamambo/tf-openpose/4ddee7f516dd95c949dbd9f7f783d064b7b477e2/network_mobilenet.py
--------------------------------------------------------------------------------
/pose_augment.py:
--------------------------------------------------------------------------------
1 | import random
2 | import math
3 |
4 | import cv2
5 | import numpy as np
6 |
7 | from tensorpack.dataflow.imgaug.geometry import RotationAndCropValid
8 |
9 | from common import CocoPart
10 |
11 |
def pose_resize_shortestedge_fixed(meta):
    """Resize *meta* so its shorter image edge equals the fixed network input size."""
    fixed_size = 368
    return pose_resize_shortestedge(meta, fixed_size)
14 |
15 |
def pose_resize_shortestedge_random(meta):
    """Resize *meta* so its shorter edge is a random size, 400 * U(0.5, 1.1) px."""
    scale = random.uniform(0.5, 1.1)
    return pose_resize_shortestedge(meta, int(400 * scale))
19 |
20 |
def pose_resize_shortestedge(meta, target_size):
    """Scale meta.img so its shorter edge equals *target_size*, pad any edge
    shorter than 370px with white borders, and remap the joint coordinates.

    Joints marked invalid (<= 0 coordinate) stay (-1, -1).
    """
    scale = target_size * 1.0 / min(meta.height, meta.width)
    if meta.height < meta.width:
        new_h, new_w = target_size, int(scale * meta.width + 0.5)
    else:
        new_h, new_w = int(scale * meta.height + 0.5), target_size

    resized = cv2.resize(meta.img, (new_w, new_h), interpolation=cv2.INTER_AREA)

    # Pad symmetrically so both edges reach at least 370px (network input + margin).
    pad_w = pad_h = 0
    if new_w < 370 or new_h < 370:
        pad_w = max(0, (370 - new_w) // 2)
        pad_h = max(0, (370 - new_h) // 2)
        resized = cv2.copyMakeBorder(resized, pad_h, pad_h, pad_w, pad_w,
                                     cv2.BORDER_CONSTANT, value=(255, 255, 255))

    def _remap(point):
        # Invalid joints pass through as the (-1, -1) sentinel.
        if point[0] <= 0 or point[1] <= 0:
            return (-1, -1)
        return (int(point[0] * scale + 0.5) + pad_w, int(point[1] * scale + 0.5) + pad_h)

    meta.joint_list = [[_remap(p) for p in joint] for joint in meta.joint_list]
    meta.width, meta.height = new_w + pad_w * 2, new_h + pad_h * 2
    meta.img = resized
    return meta
54 |
55 |
def pose_crop_center(meta):
    """Crop a centered 368x368 window; the origin is clamped to (0, 0) when
    the image is smaller than the target."""
    crop_w, crop_h = 368, 368
    x = max(0, (meta.width - crop_w) // 2)
    y = max(0, (meta.height - crop_h) // 2)
    return pose_crop(meta, x, y, crop_w, crop_h)
62 |
63 |
def pose_crop_random(meta):
    """Crop a randomly positioned 368x368 window; the origin falls back to
    (0, 0) when the image is not larger than the target."""
    crop_w, crop_h = 368, 368
    x = random.randrange(meta.width - crop_w) if meta.width > crop_w else 0
    y = random.randrange(meta.height - crop_h) if meta.height > crop_h else 0
    return pose_crop(meta, x, y, crop_w, crop_h)
70 |
71 |
def pose_crop(meta, x, y, w, h):
    """Crop the (x, y, w, h) window out of meta.img and shift joints into it.

    Joints that were already invalid (<= 0 coordinate) or that fall outside
    the cropped window become (-1, -1).
    """
    meta.img = meta.img[y:y + h, x:x + w, :]

    cropped_joints = []
    for joint in meta.joint_list:
        shifted = []
        for px, py in joint:
            if px <= 0 or py <= 0:
                # Preserve the invalid-joint sentinel.
                shifted.append((-1, -1))
            else:
                nx, ny = px - x, py - y
                if 0 < nx <= w and 0 < ny <= h:
                    shifted.append((nx, ny))
                else:
                    # Joint left the cropped window.
                    shifted.append((-1, -1))
        cropped_joints.append(shifted)

    meta.joint_list = cropped_joints
    meta.width, meta.height = w, h
    return meta
98 |
99 |
def pose_flip(meta):
    """With probability ~0.5, mirror the image horizontally and swap the
    left/right joints so labels still match body sides."""
    if random.uniform(0, 1.0) > 0.5:
        # Leave the sample unflipped.
        return meta

    meta.img = cv2.flip(meta.img, 1)

    # Part order after mirroring: every L* keypoint trades places with its R*
    # counterpart.
    mirrored_order = [CocoPart.Nose, CocoPart.Neck,
                      CocoPart.LShoulder, CocoPart.LElbow, CocoPart.LWrist,
                      CocoPart.RShoulder, CocoPart.RElbow, CocoPart.RWrist,
                      CocoPart.LHip, CocoPart.LKnee, CocoPart.LAnkle,
                      CocoPart.RHip, CocoPart.RKnee, CocoPart.RAnkle,
                      CocoPart.LEye, CocoPart.REye, CocoPart.LEar, CocoPart.REar,
                      CocoPart.Background]

    flipped_joints = []
    for joint in meta.joint_list:
        new_joint = []
        for part in mirrored_order:
            px, py = joint[part.value]
            if px <= 0 or py <= 0:
                # Preserve the invalid-joint sentinel.
                new_joint.append((-1, -1))
            else:
                new_joint.append((meta.width - px, py))
        flipped_joints.append(new_joint)

    meta.joint_list = flipped_joints
    return meta
127 |
128 |
def pose_rotation(meta):
    """Rotate meta.img by a random angle in [-40, 40] degrees, crop to the
    largest upright rectangle that fits inside the rotated image, and remap
    the joint coordinates into the cropped frame.
    """
    deg = random.uniform(-40.0, 40.0)
    img = meta.img

    center = (img.shape[1] * 0.5, img.shape[0] * 0.5)
    # Rotate about the pixel-grid centre (hence the -0.5 offset), keeping the
    # original canvas size.
    rot_m = cv2.getRotationMatrix2D((center[0] - 0.5, center[1] - 0.5), deg, 1)
    ret = cv2.warpAffine(img, rot_m, img.shape[1::-1], flags=cv2.INTER_AREA, borderMode=cv2.BORDER_CONSTANT)
    if img.ndim == 3 and ret.ndim == 2:
        # warpAffine can drop a singleton channel axis; restore it.
        ret = ret[:, :, np.newaxis]
    # Largest axis-aligned rectangle with no border pixels, clamped to canvas.
    neww, newh = RotationAndCropValid.largest_rotated_rect(ret.shape[1], ret.shape[0], deg)
    neww = min(neww, ret.shape[1])
    newh = min(newh, ret.shape[0])
    newx = int(center[0] - neww * 0.5)
    newy = int(center[1] - newh * 0.5)
    # print(ret.shape, deg, newx, newy, neww, newh)
    img = ret[newy:newy + newh, newx:newx + neww]

    # adjust meta data
    adjust_joint_list = []
    for joint in meta.joint_list:
        adjust_joint = []
        for point in joint:
            # (-1, -1) marks joints that are missing/invalid; pass them through.
            if point[0] <= 0 or point[1] <= 0:
                adjust_joint.append((-1, -1))
                continue
            x, y = _rotate_coord((meta.width, meta.height), (newx, newy), point, deg)
            adjust_joint.append((x, y))
        adjust_joint_list.append(adjust_joint)

    meta.joint_list = adjust_joint_list
    meta.width, meta.height = neww, newh
    meta.img = img

    return meta
163 |
164 |
165 | def _rotate_coord(shape, newxy, point, angle):
166 | angle = -1 * angle / 180.0 * math.pi
167 |
168 | ox, oy = shape
169 | px, py = point
170 |
171 | ox /= 2
172 | oy /= 2
173 |
174 | qx = math.cos(angle) * (px - ox) - math.sin(angle) * (py - oy)
175 | qy = math.sin(angle) * (px - ox) + math.cos(angle) * (py - oy)
176 |
177 | new_x, new_y = newxy
178 |
179 | qx += ox - new_x
180 | qy += oy - new_y
181 |
182 | return int(qx + 0.5), int(qy + 0.5)
183 |
184 |
def pose_to_img(meta_l, target_size=(92, 92)):
    """Convert the first metadata entry into network-ready arrays.

    Args:
        meta_l: list whose first element provides ``img``, ``get_heatmap``
            and ``get_vectormap``.
        target_size: (width, height) of the generated heatmap/vectormap.
            Defaults to (92, 92), the value previously hard-coded here, so
            existing callers are unaffected.

    Returns:
        [image, heatmap, vectormap] for the first metadata entry.
    """
    meta = meta_l[0]
    return [meta.img,
            meta.get_heatmap(target_size=target_size),
            meta.get_vectormap(target_size=target_size)]
187 |
--------------------------------------------------------------------------------
/pose_dataset.py:
--------------------------------------------------------------------------------
1 | import math
2 | import struct
3 | import cv2
4 |
5 | import lmdb
6 | import logging
7 |
8 | import multiprocessing
9 | import numpy as np
10 | import matplotlib.pyplot as plt
11 |
12 | from tensorpack import imgaug
13 | from tensorpack.dataflow.image import MapDataComponent, AugmentImageComponent
14 | from tensorpack.dataflow.common import BatchData, MapData
15 | from tensorpack.dataflow.prefetch import PrefetchData
16 | from tensorpack.dataflow.base import RNGDataFlow, DataFlowTerminated
17 |
18 | from datum_pb2 import Datum
19 | from pose_augment import pose_flip, pose_rotation, pose_to_img, pose_crop_random, \
20 | pose_resize_shortestedge_random, pose_resize_shortestedge_fixed, pose_crop_center
21 |
22 | logging.basicConfig(level=logging.DEBUG, format='[lmdb_dataset] %(asctime)s %(levelname)s %(message)s')
23 |
24 |
25 | class CocoMetadata:
26 | # __coco_parts = 57
27 | __coco_parts = 19
28 | __coco_vecs = list(zip(
29 | [2, 9, 10, 2, 12, 13, 2, 3, 4, 3, 2, 6, 7, 6, 2, 1, 1, 15, 16],
30 | [9, 10, 11, 12, 13, 14, 3, 4, 5, 17, 6, 7, 8, 18, 1, 15, 16, 17, 18]
31 | ))
32 |
33 | @staticmethod
34 | def parse_float(four_np):
35 | assert len(four_np) == 4
36 | return struct.unpack(' th:
135 | continue
136 | heatmap[plane_idx][y][x] += math.exp(-exp)
137 | heatmap[plane_idx][y][x] = min(heatmap[plane_idx][y][x], 1.0)
138 |
139 | def get_vectormap(self, target_size=None):
140 | vectormap = np.zeros((CocoMetadata.__coco_parts*2, self.height, self.width))
141 | countmap = np.zeros((CocoMetadata.__coco_parts, self.height, self.width))
142 | for joints in self.joint_list:
143 | for plane_idx, (j_idx1, j_idx2) in enumerate(CocoMetadata.__coco_vecs):
144 | j_idx1 -= 1
145 | j_idx2 -= 1
146 |
147 | center_from = joints[j_idx1]
148 | center_to = joints[j_idx2]
149 |
150 | if center_from[0] < 0 or center_from[1] < 0 or center_to[0] < 0 or center_to[1] < 0:
151 | continue
152 |
153 | CocoMetadata.put_vectormap(vectormap, countmap, plane_idx, center_from, center_to)
154 |
155 | vectormap = vectormap.transpose((1, 2, 0))
156 | nonzeros = np.nonzero(countmap)
157 | for p, y, x in zip(nonzeros[0], nonzeros[1], nonzeros[2]):
158 | if countmap[p][y][x] <= 0:
159 | continue
160 | vectormap[y][x][p*2+0] /= countmap[p][y][x]
161 | vectormap[y][x][p*2+1] /= countmap[p][y][x]
162 |
163 | if target_size:
164 | vectormap = cv2.resize(vectormap, target_size, interpolation=cv2.INTER_AREA)
165 |
166 | return vectormap
167 |
168 | @staticmethod
169 | def put_vectormap(vectormap, countmap, plane_idx, center_from, center_to, threshold=4):
170 | _, height, width = vectormap.shape[:3]
171 |
172 | vec_x = center_to[0] - center_from[0]
173 | vec_y = center_to[1] - center_from[1]
174 |
175 | min_x = max(0, int(min(center_from[0], center_to[0]) - threshold))
176 | min_y = max(0, int(min(center_from[1], center_to[1]) - threshold))
177 |
178 | max_x = min(width, int(max(center_from[0], center_to[0]) + threshold))
179 | max_y = min(height, int(max(center_from[1], center_to[1]) + threshold))
180 |
181 | norm = math.sqrt(vec_x ** 2 + vec_y ** 2)
182 | vec_x /= norm
183 | vec_y /= norm
184 |
185 | for y in range(min_y, max_y):
186 | for x in range(min_x, max_x):
187 | bec_x = x - center_from[0]
188 | bec_y = y - center_from[1]
189 | dist = abs(bec_x * vec_y - bec_y * vec_x)
190 |
191 | if dist > threshold:
192 | continue
193 |
194 | countmap[plane_idx][y][x] += 1
195 |
196 | vectormap[plane_idx*2+0][y][x] = vec_x
197 | vectormap[plane_idx*2+1][y][x] = vec_y
198 |
199 |
class CocoPoseLMDB(RNGDataFlow):
    """DataFlow reading COCO pose records (Caffe Datum protobufs) from LMDB.

    Keys 0..__valid_i-1 are the validation split; the remaining keys up to
    __max_key-1 are the training split.
    """
    __valid_i = 2745      # number of validation samples at the front of the DB
    __max_key = 121745    # total number of samples stored in the LMDB

    @staticmethod
    def display_image(inp, heatmap, vectmap):
        """Show the image with heatmap and vectormap overlays (debug helper)."""
        fig = plt.figure()
        a = fig.add_subplot(2, 2, 1)
        a.set_title('Image')
        plt.imshow(CocoPoseLMDB.get_bgimg(inp))

        a = fig.add_subplot(2, 2, 2)
        a.set_title('Heatmap')
        plt.imshow(CocoPoseLMDB.get_bgimg(inp, target_size=(heatmap.shape[1], heatmap.shape[0])), alpha=0.5)
        tmp = np.amax(heatmap, axis=2)
        plt.imshow(tmp, cmap=plt.cm.gray, alpha=0.5)
        plt.colorbar()

        # even channels hold x components, odd channels hold y components
        tmp2 = vectmap.transpose((2, 0, 1))
        tmp2_odd = np.amax(tmp2[::2, :, :], axis=0)
        tmp2_even = np.amax(tmp2[1::2, :, :], axis=0)

        a = fig.add_subplot(2, 2, 3)
        a.set_title('Vectormap-x')
        plt.imshow(CocoPoseLMDB.get_bgimg(inp, target_size=(vectmap.shape[1], vectmap.shape[0])), alpha=0.5)
        plt.imshow(tmp2_odd, cmap=plt.cm.gray, alpha=0.5)
        plt.colorbar()

        a = fig.add_subplot(2, 2, 4)
        a.set_title('Vectormap-y')
        plt.imshow(CocoPoseLMDB.get_bgimg(inp, target_size=(vectmap.shape[1], vectmap.shape[0])), alpha=0.5)
        plt.imshow(tmp2_even, cmap=plt.cm.gray, alpha=0.5)
        plt.colorbar()

        plt.show()

    @staticmethod
    def get_bgimg(inp, target_size=None):
        """Return *inp* converted BGR->RGB, optionally resized to target_size."""
        if target_size:
            inp = cv2.resize(inp, target_size, interpolation=cv2.INTER_AREA)
        inp = cv2.cvtColor(inp, cv2.COLOR_BGR2RGB)
        return inp

    def __init__(self, path, is_train=True):
        """Open the LMDB at *path* read-only and keep a reusable transaction."""
        self.is_train = is_train
        self.env = lmdb.open(path, map_size=int(1e12), readonly=True)
        self.txn = self.env.begin(buffers=True)

    def size(self):
        """Number of samples in the selected (train or validation) split."""
        if self.is_train:
            return CocoPoseLMDB.__max_key - CocoPoseLMDB.__valid_i
        else:
            return CocoPoseLMDB.__valid_i

    def get_data(self):
        """Yield [CocoMetadata] per sample; indices are shuffled when training."""
        idxs = np.arange(self.size())
        if self.is_train:
            # training indices start right after the validation block
            idxs += CocoPoseLMDB.__valid_i
            self.rng.shuffle(idxs)

        for idx in idxs:
            datum = Datum()
            s = self.txn.get(('%07d' % idx).encode('utf-8'))
            datum.ParseFromString(s)
            # np.frombuffer replaces the deprecated np.fromstring for binary
            # input; copy() keeps the array writable like fromstring was.
            data = np.frombuffer(datum.data.tobytes(), dtype=np.uint8).reshape(
                datum.channels, datum.height, datum.width).copy()
            img = data[:3].transpose((1, 2, 0))

            meta = CocoMetadata(img, data[3], 4.0)

            yield [meta]
273 |
274 |
def get_dataflow(is_train, path='/data/public/rw/coco-pose-estimation-lmdb/'):
    """Build the (un-batched) data pipeline over the pose LMDB.

    Args:
        is_train: if True, apply random geometric augmentation (rotation,
            flip, random resize/crop) plus pixel-level noise/brightness/
            contrast jitter; if False, use the deterministic resize +
            center-crop validation pipeline.
        path: LMDB directory. Previously hard-coded; the default preserves
            the original behavior for existing callers.

    Returns:
        A tensorpack DataFlow yielding [img, heatmap, vectormap].
    """
    ds = CocoPoseLMDB(path, is_train)
    if is_train:
        ds = MapDataComponent(ds, pose_rotation)
        ds = MapDataComponent(ds, pose_flip)
        ds = MapDataComponent(ds, pose_resize_shortestedge_random)
        ds = MapDataComponent(ds, pose_crop_random)
        ds = MapData(ds, pose_to_img)
        # pixel-level augmentation applied to 70% of training samples
        augs = [
            imgaug.RandomApplyAug(imgaug.RandomChooseAug([
                imgaug.SaltPepperNoise(white_prob=0.01, black_prob=0.01),
                imgaug.RandomOrderAug([
                    imgaug.BrightnessScale((0.8, 1.2), clip=False),
                    imgaug.Contrast((0.8, 1.2), clip=False),
                    # imgaug.Saturation(0.4, rgb=True),
                ]),
            ]), 0.7),
        ]
        ds = AugmentImageComponent(ds, augs)
    else:
        ds = MapDataComponent(ds, pose_resize_shortestedge_fixed)
        ds = MapDataComponent(ds, pose_crop_center)
        ds = MapData(ds, pose_to_img)

    return ds
300 |
301 |
def get_dataflow_batch(is_train, batchsize):
    """Wrap get_dataflow() with prefetching workers and batching."""
    dataflow = get_dataflow(is_train)
    # per-sample prefetch across all cores, then batch, then a small
    # prefetch of ready batches
    dataflow = PrefetchData(dataflow, 1000, multiprocessing.cpu_count())
    dataflow = BatchData(dataflow, batchsize)
    return PrefetchData(dataflow, 10, 4)
309 |
310 |
if __name__ == '__main__':
    # Visual smoke test: iterate the validation pipeline and display samples.
    dataflow = get_dataflow(False)
    dataflow.reset_state()
    for sample in dataflow.get_data():
        CocoPoseLMDB.display_image(sample[0], sample[1], sample[2])

    logging.info('done')
320 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | argparse
2 | lmdb
3 | git+https://github.com/ppwwyyxx/tensorpack.git#egg=tensorpack
--------------------------------------------------------------------------------