├── .gitattributes
├── .gitignore
├── LICENSE
├── README.md
├── common.py
├── datum_pb2.py
├── images
│   ├── person2.jpg
│   └── person3.jpg
├── inference.py
├── models
│   └── numpy
│       └── openpose_coco.npy
├── network_base.py
├── network_cmu.py
├── network_dsconv.py
├── network_mobilenet.py
├── pose_augment.py
├── pose_dataset.py
└── requirements.txt
/.gitattributes: -------------------------------------------------------------------------------- 1 | models/numpy/* filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "{}" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright {yyyy} {name of copyright owner} 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # tf-openpose 2 | 3 | OpenPose from CMU, implemented in TensorFlow. It also provides several variants that modify the network structure for **real-time processing on the CPU.** 4 | 5 | Original Repo (Caffe) : https://github.com/CMU-Perceptual-Computing-Lab/openpose 6 | 7 | **Features** 8 | 9 | - [x] CMU's original network architecture and weights. 10 | 11 | - [ ] Post-processing from network output. 12 | 13 | - [ ] Faster network variants using the MobileNet and LCNN architectures. 14 | 15 | - [ ] ROS support. 16 | 17 | ## Install 18 | 19 | You need the dependencies below; a typical install-and-run session is sketched after this list. 20 | 21 | - python3 22 | 23 | - tensorflow 1.3 24 | 25 | - opencv 3 26 | 27 | - protobuf 28 |
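For example, inference on one of the bundled sample images can be run as follows. This is a sketch based on the argument parser in `inference.py`; note that `cmu` is the only model whose weights ship in `models/numpy/`.

```bash
$ pip3 install -r requirements.txt
$ python3 inference.py --model cmu --imgpath ./images/person3.jpg \
    --input-width 320 --input-height 240
```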
29 | ## Models 30 | 31 | ### Inference Time 32 | 33 | | Dataset | Model | Description | Inference Time (1-core CPU) | 34 | |---------|-----------|--------------------------------------------------------------------------------------------|-----------------------------:| 35 | | Coco | cmu | CMU's original version. Same network, same weights. | 3.65s / img | 36 | | Coco | dsconv | Same as the cmu version, but with the **depthwise separable convolutions** of MobileNet. | 0.44s / img | 37 | | Coco | mobilenet | | | 38 | | Coco | lcnn | | | 39 | 40 | 41 | ## Training 42 | 43 | The CMU Perceptual Computing Lab modified Caffe to perform data augmentation during training. 44 | 45 | This includes 46 | 47 | - scale : 0.7 ~ 1.3 48 | 49 | - rotation : -40 ~ 40 degrees 50 | 51 | - flip 52 | 53 | - cropping 54 | 55 | See : https://github.com/CMU-Perceptual-Computing-Lab/caffe_train 56 | 57 | This repository implements similar augmentations in `pose_augment.py`; a sketch of chaining them is shown below.
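A minimal sketch of wiring the augmentations from `pose_augment.py` into a tensorpack dataflow. The dataflow `ds`, assumed to yield datapoints of the form `[CocoMetadata]`, is hypothetical here, since the actual pipeline in `pose_dataset.py` is cut off in this dump:

```python
# Hypothetical wiring of the augmentation steps defined in pose_augment.py.
# Each pose_* function takes and returns a metadata object carrying .img and
# .joint_list, so MapDataComponent applies it to component 0 of a datapoint.
from tensorpack.dataflow.common import MapData
from tensorpack.dataflow.image import MapDataComponent
from pose_augment import (pose_rotation, pose_flip,
                          pose_resize_shortestedge_random,
                          pose_crop_random, pose_to_img)

ds = MapDataComponent(ds, pose_rotation)                    # rotate within [-40, 40] degrees
ds = MapDataComponent(ds, pose_flip)                        # horizontal flip with p = 0.5
ds = MapDataComponent(ds, pose_resize_shortestedge_random)  # random rescale of the shortest edge
ds = MapDataComponent(ds, pose_crop_random)                 # random 368x368 crop
ds = MapData(ds, pose_to_img)                               # -> [image, heatmap, vectormap]
```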
58 | 59 | ## References 60 | 61 | ### OpenPose 62 | 63 | [1] https://github.com/CMU-Perceptual-Computing-Lab/openpose 64 | 65 | [2] Training Codes : https://github.com/ZheC/Realtime_Multi-Person_Pose_Estimation 66 | 67 | [3] Custom Caffe by Openpose : https://github.com/CMU-Perceptual-Computing-Lab/caffe_train 68 | 69 | ### MobileNet 70 | 71 | [1] Pretrained model : https://github.com/tensorflow/models/blob/master/slim/nets/mobilenet_v1.md 72 | 73 | ### Libraries 74 | 75 | [1] Tensorpack : https://github.com/ppwwyyxx/tensorpack -------------------------------------------------------------------------------- /common.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from enum import Enum 3 | import math 4 | import logging 5 | 6 | import numpy as np 7 | import itertools 8 | import tensorflow as tf 9 | from scipy.ndimage.filters import maximum_filter 10 | 11 | logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s') 12 | 13 | 14 | class CocoPart(Enum): 15 | Nose = 0 16 | Neck = 1 17 | RShoulder = 2 18 | RElbow = 3 19 | RWrist = 4 20 | LShoulder = 5 21 | LElbow = 6 22 | LWrist = 7 23 | RHip = 8 24 | RKnee = 9 25 | RAnkle = 10 26 | LHip = 11 27 | LKnee = 12 28 | LAnkle = 13 29 | REye = 14 30 | LEye = 15 31 | REar = 16 32 | LEar = 17 33 | Background = 18 34 | 35 | CocoPairs = [ 36 | (1, 2), (1, 5), (2, 3), (3, 4), (5, 6), (6, 7), (1, 8), (8, 9), (9, 10), (1, 11), 37 | (11, 12), (12, 13), (1, 0), (0, 14), (14, 16), (0, 15), (15, 17), (2, 16), (5, 17) 38 | ] # 19 pairs 39 | CocoPairsRender = CocoPairs[:-2] 40 | CocoPairsNetwork = [ 41 | (12, 13), (20, 21), (14, 15), (16, 17), (22, 23), (24, 25), (0, 1), (2, 3), (4, 5), 42 | (6, 7), (8, 9), (10, 11), (28, 29), (30, 31), (34, 35), (32, 33), (36, 37), (18, 19), (26, 27) 43 | ] # 19 pairs 44 | 45 | NMS_Threshold = 0.05 46 | InterMinAbove_Threshold = 6 47 | Inter_Threshold = 0.05 48 | Min_Subset_Cnt = 3 49 | Min_Subset_Score = 0.4 50 | Max_Human = 96 51 | 52 | 53 | def non_max_suppression_tf(np_input, window_size=3, threshold=NMS_Threshold): 54 | # input: B x W x H x C; NOTE: the reshape below hardcodes a 40x30 feature map 55 | under_threshold_indices = np_input < threshold 56 | np_input[under_threshold_indices] = 0 57 | np_input = np_input.reshape([1, 30, 40, 1]) 58 | pooled = tf.nn.max_pool(np_input, ksize=[1, window_size, window_size, 1], strides=[1, 1, 1, 1], padding='SAME') 59 | output = tf.where(tf.equal(np_input, pooled), np_input, tf.zeros_like(np_input)) 60 | # NOTE: if input has negative values, the suppressed values can be higher than original 61 | return output.eval().reshape([30, 40]) # output: B X W X H x C 62 | 63 | 64 | def non_max_suppression_scipy(np_input, window_size=3, threshold=NMS_Threshold): 65 | under_threshold_indices = np_input < threshold 66 | np_input[under_threshold_indices] = 0 67 | return np_input*(np_input == maximum_filter(np_input, footprint=np.ones((window_size, window_size)))) 68 | 69 | 70 | non_max_suppression = non_max_suppression_scipy 71 | 72 | 73 | def estimate_pose(heatMat, pafMat): 74 | logging.debug('nms') 75 | coords = [] 76 | for plane in heatMat: 77 | nms = non_max_suppression(plane, 3, NMS_Threshold) 78 | coords.append(np.where(nms >= NMS_Threshold)) 79 | 80 | logging.debug('estimate_pose1') 81 | connection_all = [] 82 | for (idx1, idx2), (paf_x_idx, paf_y_idx) in zip(CocoPairs, CocoPairsNetwork): 83 | connection = estimate_pose_pair(coords, idx1, idx2, pafMat[paf_x_idx], pafMat[paf_y_idx]) 84 | connection_all.extend(connection) 85 | 86 | logging.debug('estimate_pose2, connection=%d' % len(connection_all)) 87 | connection_by_human = dict() 88 | for idx, c in enumerate(connection_all): 89 | connection_by_human['human_%d' % idx] = [c] 90 | 91 | no_merge_cache = defaultdict(list) 92 | while True: 93 | is_merged = False 94 | for k1, k2 in itertools.combinations(connection_by_human.keys(), 2): 95 | if k1 == k2: 96 | continue 97 | if k2 in no_merge_cache[k1]: 98 | continue 99 | for c1, c2 in itertools.product(connection_by_human[k1], connection_by_human[k2]): 100 | if len(set(c1['uPartIdx']) & set(c2['uPartIdx'])) > 0: 101 | is_merged = True 102 | connection_by_human[k1].extend(connection_by_human[k2]) 103 | connection_by_human.pop(k2) 104 | break 105 | if is_merged: 106 | no_merge_cache.pop(k1, None) 107 | break 108 | else: 109 | no_merge_cache[k1].append(k2) 110 | 111 | if not is_merged: 112 | break 113 | 114 | logging.debug('estimate_pose3') 115 | 116 | # reject by subset count 117 | connection_by_human = {k: v for (k, v) in connection_by_human.items() if len(v) >= Min_Subset_Cnt} 118 | 119 | # reject by subset max score 120 | connection_by_human = {k: v for (k, v) in connection_by_human.items() if max([ii['score'] for ii in v]) >= Min_Subset_Score} 121 | 122 | logging.debug('estimate_pose4') 123 | return connection_by_human 124 | 125 | 126 | def estimate_pose_pair(coords, partIdx1, partIdx2, pafMatX, pafMatY): 127 | connection_temp = [] 128 | peak_coord1, peak_coord2 = coords[partIdx1], coords[partIdx2] 129 | 130 | cnt = 0 131 | for idx1, (y1, x1) in enumerate(zip(peak_coord1[0], peak_coord1[1])): 132 | for idx2, (y2, x2) in enumerate(zip(peak_coord2[0], peak_coord2[1])): 133 | score, count = get_score(x1, y1, x2, y2, pafMatX, pafMatY) 134 | cnt += 1 135 | if count < InterMinAbove_Threshold or score <= 0.0: 136 | continue 137 | connection_temp.append({ 138 | 'score': score, 139 | 'c1': (x1, y1), 140 | 'c2': (x2, y2), 141 | 'idx': (idx1, idx2), 142 | 'partIdx': (partIdx1, partIdx2), 143 | 'uPartIdx': ('{}-{}-{}'.format(x1, y1, partIdx1), '{}-{}-{}'.format(x2, y2, partIdx2)) 144 | }) 145 | 146 | connection = [] 147 | used_idx1, used_idx2 = [], [] 148 | for candidate in sorted(connection_temp, key=lambda x: x['score'], reverse=True): 149 | # check not connected 150 | if candidate['idx'][0] in used_idx1 or candidate['idx'][1] in used_idx2: 151 | continue 152 | connection.append(candidate) 153 | used_idx1.append(candidate['idx'][0]) 154 | used_idx2.append(candidate['idx'][1]) 155 | 156 | return connection 157 | 158 |
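# Worked example (editorial addition, not called anywhere in this repo):
# get_score() below rates a candidate limb between two peaks p1 = (x1, y1)
# and p2 = (x2, y2). It samples 10 points along the segment p1 -> p2 and dots
# the PAF vectors at those points with the unit direction (vx, vy). Samples
# whose local score exceeds Inter_Threshold are counted, and their scores are
# summed:
#
#     score, count = get_score(10, 20, 30, 40, pafMatX, pafMatY)
#     keep = count >= InterMinAbove_Threshold and score > 0.0
#
# which mirrors the acceptance test in estimate_pose_pair() above.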
159 | def get_score(x1, y1, x2, y2, pafMatX, pafMatY): 160 | __num_inter = 10 161 | __num_inter_f = float(__num_inter) 162 | dx, dy = x2 - x1, y2 - y1 163 | normVec = math.sqrt(dx ** 2 + dy ** 2) 164 | 165 | if normVec < 1e-6: 166 | return 0.0, 0 167 | 168 | vx, vy = dx / normVec, dy / normVec 169 | 170 | # np.arange with a float step can overshoot by one sample, so clip to __num_inter 171 | xs = np.arange(x1, x2, dx / __num_inter_f)[:__num_inter] if x1 != x2 else [x1] * __num_inter 172 | ys = np.arange(y1, y2, dy / __num_inter_f)[:__num_inter] if y1 != y2 else [y1] * __num_inter 173 | pafXs = np.zeros(__num_inter) 174 | pafYs = np.zeros(__num_inter) 175 | for idx, (mx, my) in enumerate(zip(xs, ys)): 176 | # round to the nearest cell and clamp inside the map 177 | mx, my = int(mx + 0.5), int(my + 0.5) 178 | mx = min(max(mx, 0), pafMatX.shape[1] - 1) 179 | my = min(max(my, 0), pafMatX.shape[0] - 1) 180 | 181 | pafXs[idx] = pafMatX[my][mx] 182 | pafYs[idx] = pafMatY[my][mx] 183 | 184 | local_scores = pafXs * vx + pafYs * vy 185 | thidxs = local_scores > Inter_Threshold 186 | 187 | return sum(local_scores*thidxs), sum(thidxs) 188 | -------------------------------------------------------------------------------- /datum_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 2 | # source: datum.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='datum.proto', 20 | package='', 21 | serialized_pb=_b('\n\x0b\x64\x61tum.proto\"\x81\x01\n\x05\x44\x61tum\x12\x10\n\x08\x63hannels\x18\x01 \x01(\x05\x12\x0e\n\x06height\x18\x02 \x01(\x05\x12\r\n\x05width\x18\x03 \x01(\x05\x12\x0c\n\x04\x64\x61ta\x18\x04 \x01(\x0c\x12\r\n\x05label\x18\x05 \x01(\x05\x12\x12\n\nfloat_data\x18\x06 \x03(\x02\x12\x16\n\x07\x65ncoded\x18\x07 \x01(\x08:\x05\x66\x61lse') 22 | ) 23 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 24 | 25 | 26 | 27 | 28 | _DATUM = _descriptor.Descriptor( 29 | name='Datum', 30 | full_name='Datum', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | containing_type=None, 34 | fields=[ 35 | _descriptor.FieldDescriptor( 36 | name='channels', full_name='Datum.channels', index=0, 37 | number=1, type=5, cpp_type=1, label=1, 38 | has_default_value=False, default_value=0, 39 | message_type=None, enum_type=None, containing_type=None, 40 | is_extension=False, extension_scope=None, 41 | options=None), 42 | _descriptor.FieldDescriptor( 43 | name='height', full_name='Datum.height', index=1, 44 | number=2, type=5, cpp_type=1, label=1, 45 | has_default_value=False, default_value=0, 46 | message_type=None, enum_type=None, containing_type=None, 47 | is_extension=False, extension_scope=None, 48 | options=None), 49 | _descriptor.FieldDescriptor( 50 | name='width', full_name='Datum.width', index=2, 51 | number=3, type=5, cpp_type=1, label=1, 52 | has_default_value=False, default_value=0, 53 | message_type=None, enum_type=None, containing_type=None, 54 | is_extension=False, extension_scope=None, 55 | options=None), 56 | _descriptor.FieldDescriptor( 57 | name='data', full_name='Datum.data', index=3, 58 | number=4, type=12, cpp_type=9, label=1, 59 | has_default_value=False, default_value=_b(""), 60 | message_type=None, enum_type=None, containing_type=None, 61 | is_extension=False, extension_scope=None, 62 | options=None), 63 | _descriptor.FieldDescriptor( 64 | name='label', full_name='Datum.label', index=4, 65 | number=5, type=5, cpp_type=1, label=1, 66 | has_default_value=False, default_value=0, 67 | message_type=None, enum_type=None,
containing_type=None, 68 | is_extension=False, extension_scope=None, 69 | options=None), 70 | _descriptor.FieldDescriptor( 71 | name='float_data', full_name='Datum.float_data', index=5, 72 | number=6, type=2, cpp_type=6, label=3, 73 | has_default_value=False, default_value=[], 74 | message_type=None, enum_type=None, containing_type=None, 75 | is_extension=False, extension_scope=None, 76 | options=None), 77 | _descriptor.FieldDescriptor( 78 | name='encoded', full_name='Datum.encoded', index=6, 79 | number=7, type=8, cpp_type=7, label=1, 80 | has_default_value=True, default_value=False, 81 | message_type=None, enum_type=None, containing_type=None, 82 | is_extension=False, extension_scope=None, 83 | options=None), 84 | ], 85 | extensions=[ 86 | ], 87 | nested_types=[], 88 | enum_types=[ 89 | ], 90 | options=None, 91 | is_extendable=False, 92 | extension_ranges=[], 93 | oneofs=[ 94 | ], 95 | serialized_start=16, 96 | serialized_end=145, 97 | ) 98 | 99 | DESCRIPTOR.message_types_by_name['Datum'] = _DATUM 100 | 101 | Datum = _reflection.GeneratedProtocolMessageType('Datum', (_message.Message,), dict( 102 | DESCRIPTOR = _DATUM, 103 | __module__ = 'datum_pb2' 104 | # @@protoc_insertion_point(class_scope:Datum) 105 | )) 106 | _sym_db.RegisterMessage(Datum) 107 | 108 | 109 | # @@protoc_insertion_point(module_scope) 110 | -------------------------------------------------------------------------------- /images/person2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zumbalamambo/tf-openpose/4ddee7f516dd95c949dbd9f7f783d064b7b477e2/images/person2.jpg -------------------------------------------------------------------------------- /images/person3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zumbalamambo/tf-openpose/4ddee7f516dd95c949dbd9f7f783d064b7b477e2/images/person3.jpg -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import tensorflow as tf 3 | import cv2 4 | import numpy as np 5 | import time 6 | import logging 7 | import argparse 8 | 9 | from tensorflow.python.client import timeline 10 | 11 | from network_cmu import CmuNetwork 12 | from network_dsconv import KakaoNetwork 13 | from common import estimate_pose, CocoPairsRender 14 | logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s') 15 | 16 | config = tf.ConfigProto() 17 | config.gpu_options.allocator_type = 'BFC' 18 | config.gpu_options.per_process_gpu_memory_fraction = 0.95 19 | config.gpu_options.allow_growth = True 20 | 21 | 22 | if __name__ == '__main__': 23 | parser = argparse.ArgumentParser(description='Tensorflow Openpose Inference') 24 | parser.add_argument('--imgpath', type=str, default='./images/person3.jpg') 25 | parser.add_argument('--input-width', type=int, default=320) 26 | parser.add_argument('--input-height', type=int, default=240) 27 | parser.add_argument('--stage-level', type=int, default=6) 28 | parser.add_argument('--model', type=str, default='kakao', help='cmu(original) / kakao(faster version)')  # NOTE: only the cmu weights (models/numpy/openpose_coco.npy) ship with this repository 29 | args = parser.parse_args() 30 | 31 | input_node = tf.placeholder(tf.float32, shape=(None, args.input_height, args.input_width, 3), name='image') 32 | 33 | with tf.Session(config=config) as sess: 34 | if args.model == 'kakao': 35 | net = KakaoNetwork({'image': input_node}, trainable=False) 36 |
net.load('./models/numpy/fastopenpose_coco_v170729.npy', sess) 37 | elif args.model == 'cmu': 38 | net = CmuNetwork({'image': input_node}, trainable=False) 39 | net.load('./models/numpy/openpose_coco.npy', sess) 40 | else: 41 | raise Exception('Invalid Mode.') 42 | 43 | logging.debug('read image+') 44 | image = cv2.imread(args.imgpath) 45 | image = cv2.resize(image, (args.input_width, args.input_height)) 46 | image = image.astype(float) 47 | image -= 128.0 48 | image /= 128.0 49 | 50 | vec = sess.run(net.get_output(name='concat_stage7'), feed_dict={'image:0': [image]}) 51 | 52 | logging.debug('inference+') 53 | a = time.time() 54 | pafMat, heatMat = sess.run( 55 | [ 56 | net.get_output(name='Mconv7_stage{}_L1'.format(args.stage_level)), 57 | net.get_output(name='Mconv7_stage{}_L2'.format(args.stage_level)) 58 | ], feed_dict={'image:0': [image]} 59 | ) 60 | logging.info('inference- elapsed_time={}'.format(time.time() - a)) 61 | a = time.time() 62 | pafMat, heatMat = sess.run( 63 | [ 64 | net.get_output(name='Mconv7_stage{}_L1'.format(args.stage_level)), 65 | net.get_output(name='Mconv7_stage{}_L2'.format(args.stage_level)) 66 | ], feed_dict={'image:0': [image]} 67 | ) 68 | logging.info('inference- elapsed_time={}'.format(time.time() - a)) 69 | a = time.time() 70 | 71 | run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE) 72 | run_metadata = tf.RunMetadata() 73 | pafMat, heatMat = sess.run( 74 | [ 75 | net.get_output(name='Mconv7_stage{}_L1'.format(args.stage_level)), 76 | net.get_output(name='Mconv7_stage{}_L2'.format(args.stage_level)) 77 | ], feed_dict={'image:0': [image]}, options=run_options, run_metadata=run_metadata 78 | ) 79 | logging.info('inference- elapsed_time={}'.format(time.time() - a)) 80 | 81 | tl = timeline.Timeline(run_metadata.step_stats) 82 | ctf = tl.generate_chrome_trace_format() 83 | with open('timeline.json', 'w') as f: 84 | f.write(ctf) 85 | 86 | heatMat, pafMat = heatMat[0], pafMat[0] 87 | heatMat = np.rollaxis(heatMat, 2, 0) 88 | pafMat = np.rollaxis(pafMat, 2, 0) 89 | 90 | logging.info('pickle data') 91 | with open('heatmat.pickle', 'wb') as pickle_file: 92 | pickle.dump(heatMat, pickle_file, pickle.HIGHEST_PROTOCOL) 93 | with open('pafmat.pickle', 'wb') as pickle_file: 94 | pickle.dump(pafMat, pickle_file, pickle.HIGHEST_PROTOCOL) 95 | 96 | logging.info('pose+') 97 | a = time.time() 98 | humans = estimate_pose(heatMat, pafMat) 99 | logging.info('pose- elapsed_time={}'.format(time.time() - a)) 100 | 101 | # display 102 | image = cv2.imread(args.imgpath) 103 | image_h, image_w = image.shape[:2] 104 | heat_h, heat_w = heatMat[0].shape[:2] 105 | for _, human in humans.items(): 106 | for part in human: 107 | if part['partIdx'] not in CocoPairsRender: 108 | continue 109 | center1 = (int((part['c1'][0] + 0.5) * image_w / heat_w), int((part['c1'][1] + 0.5) * image_h / heat_h)) 110 | center2 = (int((part['c2'][0] + 0.5) * image_w / heat_w), int((part['c2'][1] + 0.5) * image_h / heat_h)) 111 | cv2.circle(image, center1, 2, (255, 0, 0), thickness=3, lineType=8, shift=0) 112 | cv2.circle(image, center2, 2, (255, 0, 0), thickness=3, lineType=8, shift=0) 113 | image = cv2.line(image, center1, center2, (255, 0, 0), 1) 114 | cv2.imshow('result', image) 115 | cv2.waitKey(0) 116 | -------------------------------------------------------------------------------- /models/numpy/openpose_coco.npy: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid 
sha256:7dc09d77ffa32a3b5a495d2f5bda75fa60941a90847c800600142437ef606a84 3 | size 209255398 4 | -------------------------------------------------------------------------------- /network_base.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import numpy as np 3 | import tensorflow as tf 4 | import tensorflow.contrib.slim as slim 5 | 6 | DEFAULT_PADDING = 'SAME' 7 | 8 | 9 | def layer(op): 10 | ''' 11 | Decorator for composable network layers. 12 | ''' 13 | 14 | def layer_decorated(self, *args, **kwargs): 15 | # Automatically set a name if not provided. 16 | name = kwargs.setdefault('name', self.get_unique_name(op.__name__)) 17 | # Figure out the layer inputs. 18 | if len(self.terminals) == 0: 19 | raise RuntimeError('No input variables found for layer %s.' % name) 20 | elif len(self.terminals) == 1: 21 | layer_input = self.terminals[0] 22 | else: 23 | layer_input = list(self.terminals) 24 | # Perform the operation and get the output. 25 | layer_output = op(self, layer_input, *args, **kwargs) 26 | # Add to layer LUT. 27 | self.layers[name] = layer_output 28 | # This output is now the input for the next layer. 29 | self.feed(layer_output) 30 | # Return self for chained calls. 31 | return self 32 | 33 | return layer_decorated 34 | 35 | 36 | class BaseNetwork(object): 37 | 38 | def __init__(self, inputs, trainable=True): 39 | # The input nodes for this network 40 | self.inputs = inputs 41 | # The current list of terminal nodes 42 | self.terminals = [] 43 | # Mapping from layer names to layers 44 | self.layers = dict(inputs) 45 | self.tensor_before_relu = dict() 46 | # If true, the resulting variables are set as trainable 47 | self.trainable = trainable 48 | # Switch variable for dropout 49 | self.use_dropout = tf.placeholder_with_default(tf.constant(1.0), 50 | shape=[], 51 | name='use_dropout') 52 | self.setup() 53 | 54 | def setup(self): 55 | '''Construct the network. ''' 56 | raise NotImplementedError('Must be implemented by the subclass.') 57 | 58 | def load(self, data_path, session, ignore_missing=False): 59 | ''' 60 | Load network weights. 61 | data_path: The path to the numpy-serialized network weights 62 | session: The current TensorFlow session 63 | ignore_missing: If true, serialized weights for missing layers are ignored. 64 | ''' 65 | data_dict = np.load(data_path, encoding='bytes').item() 66 | for op_name in data_dict: 67 | if isinstance(data_dict[op_name], np.ndarray): 68 | if 'RMSProp' in op_name: 69 | continue 70 | with tf.variable_scope('', reuse=True): 71 | var = tf.get_variable(op_name.replace(':0', '')) 72 | try: 73 | session.run(var.assign(data_dict[op_name])) 74 | except Exception as e: 75 | print(op_name) 76 | print(e) 77 | sys.exit(-1) 78 | else: 79 | with tf.variable_scope(op_name, reuse=True): 80 | for param_name, data in data_dict[op_name].items(): 81 | try: 82 | var = tf.get_variable(param_name.decode("utf-8")) 83 | session.run(var.assign(data)) 84 | except ValueError as e: 85 | print(e) 86 | if not ignore_missing: 87 | raise 88 | 89 | def feed(self, *args): 90 | '''Set the input(s) for the next operation by replacing the terminal nodes. 91 | The arguments can be either layer names or the actual layers. 
92 | ''' 93 | assert len(args) != 0 94 | self.terminals = [] 95 | for fed_layer in args: 96 | try: 97 | is_str = isinstance(fed_layer, basestring) 98 | except NameError: 99 | is_str = isinstance(fed_layer, str) 100 | if is_str: 101 | try: 102 | fed_layer = self.layers[fed_layer] 103 | except KeyError: 104 | raise KeyError('Unknown layer name fed: %s' % fed_layer) 105 | self.terminals.append(fed_layer) 106 | return self 107 | 108 | def get_output(self, name=None): 109 | '''Returns the current network output.''' 110 | if not name: 111 | return self.terminals[-1] 112 | else: 113 | return self.layers[name] 114 | 115 | def get_tensor(self, name): 116 | if 'conv' not in name: 117 | return self.get_output(name) 118 | return self.tensor_before_relu[name] 119 | 120 | def get_unique_name(self, prefix): 121 | '''Returns an index-suffixed unique name for the given prefix. 122 | This is used for auto-generating layer names based on the type-prefix. 123 | ''' 124 | ident = sum(t.startswith(prefix) for t, _ in self.layers.items()) + 1 125 | return '%s_%d' % (prefix, ident) 126 | 127 | def make_var(self, name, shape, trainable=True): 128 | '''Creates a new TensorFlow variable.''' 129 | return tf.get_variable(name, shape, trainable=self.trainable & trainable, initializer=tf.contrib.layers.xavier_initializer()) 130 | 131 | def validate_padding(self, padding): 132 | '''Verifies that the padding is one of the supported ones.''' 133 | assert padding in ('SAME', 'VALID') 134 | 135 | @layer 136 | def separable_conv(self, input, k_h, k_w, c_o, stride, name, relu=True, depth_multiplier=1.0): 137 | # with slim.arg_scope([slim.batch_norm], fused=True): 138 | # skip pointwise by setting num_outputs=None 139 | output = slim.separable_convolution2d(input, 140 | num_outputs=None, 141 | stride=stride, 142 | trainable=self.trainable, 143 | depth_multiplier=depth_multiplier, 144 | kernel_size=[k_h, k_w], 145 | activation_fn=None, 146 | # weights_initializer=tf.truncated_normal_initializer(stddev=0.0001), 147 | weights_initializer=slim.initializers.xavier_initializer(), 148 | biases_initializer=slim.init_ops.zeros_initializer(), 149 | # normalizer_fn=slim.batch_norm, 150 | scope=name + '/dsc') 151 | 152 | # bn = slim.batch_norm(bn, scope=name + '/dw_batch_norm') 153 | output = slim.convolution2d(output, 154 | c_o, 155 | stride=1, 156 | kernel_size=[1, 1], 157 | activation_fn=None, 158 | # activation_fn=tf.nn.relu if relu else None, 159 | # weights_initializer=tf.truncated_normal_initializer(stddev=0.0001), 160 | weights_initializer=slim.initializers.xavier_initializer(), 161 | biases_initializer=slim.init_ops.zeros_initializer(), 162 | # normalizer_fn=slim.batch_norm, 163 | trainable=self.trainable, 164 | scope=name + '/pointwise_conv') 165 | # output = slim.batch_norm(pointwise_conv, scope=name + '/pw_batch_norm') 166 | 167 | self.tensor_before_relu[name] = output 168 | 169 | if relu: 170 | output = tf.nn.relu(output, name=name + '/relu') 171 | return output 172 | 173 | @layer 174 | def conv(self, 175 | input, 176 | k_h, 177 | k_w, 178 | c_o, 179 | s_h, 180 | s_w, 181 | name, 182 | relu=True, 183 | padding=DEFAULT_PADDING, 184 | group=1, 185 | trainable=True, 186 | biased=True): 187 | # Verify that the padding is acceptable 188 | self.validate_padding(padding) 189 | # Get the number of channels in the input 190 | c_i = int(input.get_shape()[-1]) 191 | # Verify that the grouping parameter is valid 192 | assert c_i % group == 0 193 | assert c_o % group == 0 194 | # Convolution for a given input and kernel 195 | convolve = 
lambda i, k: tf.nn.conv2d(i, k, [1, s_h, s_w, 1], padding=padding) 196 | with tf.variable_scope(name) as scope: 197 | kernel = self.make_var('weights', shape=[k_h, k_w, c_i // group, c_o], trainable=self.trainable & trainable) 198 | if group == 1: 199 | # This is the common-case. Convolve the input without any further complications. 200 | output = convolve(input, kernel) 201 | else: 202 | # Split the input into groups and then convolve each of them independently 203 | input_groups = tf.split(input, group, axis=3) 204 | kernel_groups = tf.split(kernel, group, axis=3) 205 | output_groups = [convolve(i, k) for i, k in zip(input_groups, kernel_groups)] 206 | # Concatenate the groups 207 | output = tf.concat(output_groups, axis=3) 208 | # Add the biases 209 | if biased: 210 | biases = self.make_var('biases', [c_o], trainable=self.trainable & trainable) 211 | output = tf.nn.bias_add(output, biases) 212 | 213 | self.tensor_before_relu[name] = output 214 | if relu: 215 | # ReLU non-linearity 216 | output = tf.nn.relu(output, name=scope.name) 217 | return output 218 | 219 | @layer 220 | def relu(self, input, name): 221 | return tf.nn.relu(input, name=name) 222 | 223 | @layer 224 | def max_pool(self, input, k_h, k_w, s_h, s_w, name, padding=DEFAULT_PADDING): 225 | self.validate_padding(padding) 226 | return tf.nn.max_pool(input, 227 | ksize=[1, k_h, k_w, 1], 228 | strides=[1, s_h, s_w, 1], 229 | padding=padding, 230 | name=name) 231 | 232 | @layer 233 | def avg_pool(self, input, k_h, k_w, s_h, s_w, name, padding=DEFAULT_PADDING): 234 | self.validate_padding(padding) 235 | return tf.nn.avg_pool(input, 236 | ksize=[1, k_h, k_w, 1], 237 | strides=[1, s_h, s_w, 1], 238 | padding=padding, 239 | name=name) 240 | 241 | @layer 242 | def lrn(self, input, radius, alpha, beta, name, bias=1.0): 243 | return tf.nn.local_response_normalization(input, 244 | depth_radius=radius, 245 | alpha=alpha, 246 | beta=beta, 247 | bias=bias, 248 | name=name) 249 | 250 | @layer 251 | def concat(self, inputs, axis, name): 252 | return tf.concat(axis=axis, values=inputs, name=name) 253 | 254 | @layer 255 | def add(self, inputs, name): 256 | return tf.add_n(inputs, name=name) 257 | 258 | @layer 259 | def fc(self, input, num_out, name, relu=True): 260 | with tf.variable_scope(name) as scope: 261 | input_shape = input.get_shape() 262 | if input_shape.ndims == 4: 263 | # The input is spatial. Vectorize it first. 264 | dim = 1 265 | for d in input_shape[1:].as_list(): 266 | dim *= d 267 | feed_in = tf.reshape(input, [-1, dim]) 268 | else: 269 | feed_in, dim = (input, input_shape[-1].value) 270 | weights = self.make_var('weights', shape=[dim, num_out]) 271 | biases = self.make_var('biases', [num_out]) 272 | op = tf.nn.relu_layer if relu else tf.nn.xw_plus_b 273 | fc = op(feed_in, weights, biases, name=scope.name) 274 | return fc 275 | 276 | @layer 277 | def softmax(self, input, name): 278 | input_shape = [v.value for v in input.get_shape()]  # list, so len() works on Python 3 279 | if len(input_shape) > 2: 280 | # For certain models (like NiN), the singleton spatial dimensions 281 | # need to be explicitly squeezed, since they're not broadcast-able 282 | # in TensorFlow's NHWC ordering (unlike Caffe's NCHW).
283 | if input_shape[1] == 1 and input_shape[2] == 1: 284 | input = tf.squeeze(input, squeeze_dims=[1, 2]) 285 | else: 286 | raise ValueError('Rank 2 tensor input expected for softmax!') 287 | return tf.nn.softmax(input, name=name) 288 | 289 | @layer 290 | def batch_normalization(self, input, name, scale_offset=True, relu=False): 291 | # NOTE: Currently, only inference is supported 292 | with tf.variable_scope(name) as scope: 293 | shape = [input.get_shape()[-1]] 294 | if scale_offset: 295 | scale = self.make_var('scale', shape=shape) 296 | offset = self.make_var('offset', shape=shape) 297 | else: 298 | scale, offset = (None, None) 299 | output = tf.nn.batch_normalization( 300 | input, 301 | mean=self.make_var('mean', shape=shape), 302 | variance=self.make_var('variance', shape=shape), 303 | offset=offset, 304 | scale=scale, 305 | # TODO: This is the default Caffe batch norm eps 306 | # Get the actual eps from parameters 307 | variance_epsilon=1e-5, 308 | name=name) 309 | if relu: 310 | output = tf.nn.relu(output) 311 | return output 312 | 313 | @layer 314 | def dropout(self, input, keep_prob, name): 315 | keep = 1 - self.use_dropout + (self.use_dropout * keep_prob) 316 | return tf.nn.dropout(input, keep, name=name) 317 | -------------------------------------------------------------------------------- /network_cmu.py: -------------------------------------------------------------------------------- 1 | import network_base 2 | 3 | 4 | class CmuNetwork(network_base.BaseNetwork): 5 | def setup(self): 6 | (self.feed('image') 7 | .conv(3, 3, 64, 1, 1, name='conv1_1') 8 | .conv(3, 3, 64, 1, 1, name='conv1_2') 9 | .max_pool(2, 2, 2, 2, name='pool1_stage1') 10 | .conv(3, 3, 128, 1, 1, name='conv2_1') 11 | .conv(3, 3, 128, 1, 1, name='conv2_2') 12 | .max_pool(2, 2, 2, 2, name='pool2_stage1') 13 | .conv(3, 3, 256, 1, 1, name='conv3_1') 14 | .conv(3, 3, 256, 1, 1, name='conv3_2') 15 | .conv(3, 3, 256, 1, 1, name='conv3_3') 16 | .conv(3, 3, 256, 1, 1, name='conv3_4') 17 | .max_pool(2, 2, 2, 2, name='pool3_stage1') 18 | .conv(3, 3, 512, 1, 1, name='conv4_1') 19 | .conv(3, 3, 512, 1, 1, name='conv4_2') 20 | .conv(3, 3, 256, 1, 1, name='conv4_3_CPM') 21 | .conv(3, 3, 128, 1, 1, name='conv4_4_CPM') # ***** 22 | .conv(3, 3, 128, 1, 1, name='conv5_1_CPM_L1') 23 | .conv(3, 3, 128, 1, 1, name='conv5_2_CPM_L1') 24 | .conv(3, 3, 128, 1, 1, name='conv5_3_CPM_L1') 25 | .conv(1, 1, 512, 1, 1, name='conv5_4_CPM_L1') 26 | .conv(1, 1, 38, 1, 1, relu=False, name='conv5_5_CPM_L1')) 27 | 28 | (self.feed('conv4_4_CPM') 29 | .conv(3, 3, 128, 1, 1, name='conv5_1_CPM_L2') 30 | .conv(3, 3, 128, 1, 1, name='conv5_2_CPM_L2') 31 | .conv(3, 3, 128, 1, 1, name='conv5_3_CPM_L2') 32 | .conv(1, 1, 512, 1, 1, name='conv5_4_CPM_L2') 33 | .conv(1, 1, 19, 1, 1, relu=False, name='conv5_5_CPM_L2')) 34 | 35 | (self.feed('conv5_5_CPM_L1', 36 | 'conv5_5_CPM_L2', 37 | 'conv4_4_CPM') 38 | .concat(3, name='concat_stage2') 39 | .conv(7, 7, 128, 1, 1, name='Mconv1_stage2_L1') 40 | .conv(7, 7, 128, 1, 1, name='Mconv2_stage2_L1') 41 | .conv(7, 7, 128, 1, 1, name='Mconv3_stage2_L1') 42 | .conv(7, 7, 128, 1, 1, name='Mconv4_stage2_L1') 43 | .conv(7, 7, 128, 1, 1, name='Mconv5_stage2_L1') 44 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage2_L1') 45 | .conv(1, 1, 38, 1, 1, relu=False, name='Mconv7_stage2_L1')) 46 | 47 | (self.feed('concat_stage2') 48 | .conv(7, 7, 128, 1, 1, name='Mconv1_stage2_L2') 49 | .conv(7, 7, 128, 1, 1, name='Mconv2_stage2_L2') 50 | .conv(7, 7, 128, 1, 1, name='Mconv3_stage2_L2') 51 | .conv(7, 7, 128, 1, 1, name='Mconv4_stage2_L2') 52 
| .conv(7, 7, 128, 1, 1, name='Mconv5_stage2_L2') 53 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage2_L2') 54 | .conv(1, 1, 19, 1, 1, relu=False, name='Mconv7_stage2_L2')) 55 | 56 | (self.feed('Mconv7_stage2_L1', 57 | 'Mconv7_stage2_L2', 58 | 'conv4_4_CPM') 59 | .concat(3, name='concat_stage3') 60 | .conv(7, 7, 128, 1, 1, name='Mconv1_stage3_L1') 61 | .conv(7, 7, 128, 1, 1, name='Mconv2_stage3_L1') 62 | .conv(7, 7, 128, 1, 1, name='Mconv3_stage3_L1') 63 | .conv(7, 7, 128, 1, 1, name='Mconv4_stage3_L1') 64 | .conv(7, 7, 128, 1, 1, name='Mconv5_stage3_L1') 65 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage3_L1') 66 | .conv(1, 1, 38, 1, 1, relu=False, name='Mconv7_stage3_L1')) 67 | 68 | (self.feed('concat_stage3') 69 | .conv(7, 7, 128, 1, 1, name='Mconv1_stage3_L2') 70 | .conv(7, 7, 128, 1, 1, name='Mconv2_stage3_L2') 71 | .conv(7, 7, 128, 1, 1, name='Mconv3_stage3_L2') 72 | .conv(7, 7, 128, 1, 1, name='Mconv4_stage3_L2') 73 | .conv(7, 7, 128, 1, 1, name='Mconv5_stage3_L2') 74 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage3_L2') 75 | .conv(1, 1, 19, 1, 1, relu=False, name='Mconv7_stage3_L2')) 76 | 77 | (self.feed('Mconv7_stage3_L1', 78 | 'Mconv7_stage3_L2', 79 | 'conv4_4_CPM') 80 | .concat(3, name='concat_stage4') 81 | .conv(7, 7, 128, 1, 1, name='Mconv1_stage4_L1') 82 | .conv(7, 7, 128, 1, 1, name='Mconv2_stage4_L1') 83 | .conv(7, 7, 128, 1, 1, name='Mconv3_stage4_L1') 84 | .conv(7, 7, 128, 1, 1, name='Mconv4_stage4_L1') 85 | .conv(7, 7, 128, 1, 1, name='Mconv5_stage4_L1') 86 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage4_L1') 87 | .conv(1, 1, 38, 1, 1, relu=False, name='Mconv7_stage4_L1')) 88 | 89 | (self.feed('concat_stage4') 90 | .conv(7, 7, 128, 1, 1, name='Mconv1_stage4_L2') 91 | .conv(7, 7, 128, 1, 1, name='Mconv2_stage4_L2') 92 | .conv(7, 7, 128, 1, 1, name='Mconv3_stage4_L2') 93 | .conv(7, 7, 128, 1, 1, name='Mconv4_stage4_L2') 94 | .conv(7, 7, 128, 1, 1, name='Mconv5_stage4_L2') 95 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage4_L2') 96 | .conv(1, 1, 19, 1, 1, relu=False, name='Mconv7_stage4_L2')) 97 | 98 | (self.feed('Mconv7_stage4_L1', 99 | 'Mconv7_stage4_L2', 100 | 'conv4_4_CPM') 101 | .concat(3, name='concat_stage5') 102 | .conv(7, 7, 128, 1, 1, name='Mconv1_stage5_L1') 103 | .conv(7, 7, 128, 1, 1, name='Mconv2_stage5_L1') 104 | .conv(7, 7, 128, 1, 1, name='Mconv3_stage5_L1') 105 | .conv(7, 7, 128, 1, 1, name='Mconv4_stage5_L1') 106 | .conv(7, 7, 128, 1, 1, name='Mconv5_stage5_L1') 107 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage5_L1') 108 | .conv(1, 1, 38, 1, 1, relu=False, name='Mconv7_stage5_L1')) 109 | 110 | (self.feed('concat_stage5') 111 | .conv(7, 7, 128, 1, 1, name='Mconv1_stage5_L2') 112 | .conv(7, 7, 128, 1, 1, name='Mconv2_stage5_L2') 113 | .conv(7, 7, 128, 1, 1, name='Mconv3_stage5_L2') 114 | .conv(7, 7, 128, 1, 1, name='Mconv4_stage5_L2') 115 | .conv(7, 7, 128, 1, 1, name='Mconv5_stage5_L2') 116 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage5_L2') 117 | .conv(1, 1, 19, 1, 1, relu=False, name='Mconv7_stage5_L2')) 118 | 119 | (self.feed('Mconv7_stage5_L1', 120 | 'Mconv7_stage5_L2', 121 | 'conv4_4_CPM') 122 | .concat(3, name='concat_stage6') 123 | .conv(7, 7, 128, 1, 1, name='Mconv1_stage6_L1') 124 | .conv(7, 7, 128, 1, 1, name='Mconv2_stage6_L1') 125 | .conv(7, 7, 128, 1, 1, name='Mconv3_stage6_L1') 126 | .conv(7, 7, 128, 1, 1, name='Mconv4_stage6_L1') 127 | .conv(7, 7, 128, 1, 1, name='Mconv5_stage6_L1') 128 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage6_L1') 129 | .conv(1, 1, 38, 1, 1, relu=False, name='Mconv7_stage6_L1')) 130 | 131 | (self.feed('concat_stage6') 132 | .conv(7, 7, 128, 1, 
1, name='Mconv1_stage6_L2') 133 | .conv(7, 7, 128, 1, 1, name='Mconv2_stage6_L2') 134 | .conv(7, 7, 128, 1, 1, name='Mconv3_stage6_L2') 135 | .conv(7, 7, 128, 1, 1, name='Mconv4_stage6_L2') 136 | .conv(7, 7, 128, 1, 1, name='Mconv5_stage6_L2') 137 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage6_L2') 138 | .conv(1, 1, 19, 1, 1, relu=False, name='Mconv7_stage6_L2')) 139 | 140 | (self.feed('Mconv7_stage6_L2', 141 | 'Mconv7_stage6_L1') 142 | .concat(3, name='concat_stage7')) 143 | -------------------------------------------------------------------------------- /network_dsconv.py: -------------------------------------------------------------------------------- 1 | import network_base 2 | import tensorflow as tf 3 | 4 | 5 | class KakaoNetwork(network_base.BaseNetwork): 6 | def __init__(self, inputs, trainable=True, conv_width=1.0): 7 | self.conv_width = conv_width 8 | network_base.BaseNetwork.__init__(self, inputs, trainable) 9 | 10 | def setup(self): 11 | (self.feed('image') 12 | .conv(3, 3, 64, 1, 1, name='conv1_1', trainable=False) 13 | # .conv(3, 3, 64, 1, 1, name='conv1_2', trainable=True) # TODO 14 | .separable_conv(3, 3, round(self.conv_width * 64), 2, name='conv1_2') 15 | # .max_pool(2, 2, 2, 2, name='pool1_stage1') 16 | .separable_conv(3, 3, round(self.conv_width * 128), 1, name='conv2_1') 17 | .separable_conv(3, 3, round(self.conv_width * 128), 2, name='conv2_2') 18 | # .max_pool(2, 2, 2, 2, name='pool2_stage1') 19 | .separable_conv(3, 3, round(self.conv_width * 256), 1, name='conv3_1') 20 | .separable_conv(3, 3, round(self.conv_width * 256), 1, name='conv3_2') 21 | .separable_conv(3, 3, round(self.conv_width * 256), 1, name='conv3_3') 22 | .separable_conv(3, 3, round(self.conv_width * 256), 2, name='conv3_4') 23 | # .max_pool(2, 2, 2, 2, name='pool3_stage1') 24 | .separable_conv(3, 3, round(self.conv_width * 512), 1, name='conv4_1') 25 | .separable_conv(3, 3, round(self.conv_width * 512), 1, name='conv4_2') 26 | .separable_conv(3, 3, round(self.conv_width * 256), 1, name='conv4_3_CPM') 27 | .separable_conv(3, 3, 128, 1, name='conv4_4_CPM') 28 | .separable_conv(3, 3, round(self.conv_width * 128), 1, name='conv5_1_CPM_L1') 29 | .separable_conv(3, 3, round(self.conv_width * 128), 1, name='conv5_2_CPM_L1') 30 | .separable_conv(3, 3, round(self.conv_width * 128), 1, name='conv5_3_CPM_L1') 31 | .conv(1, 1, 512, 1, 1, name='conv5_4_CPM_L1') 32 | .conv(1, 1, 38, 1, 1, relu=False, name='conv5_5_CPM_L1')) 33 | 34 | (self.feed('conv4_4_CPM') 35 | .separable_conv(3, 3, round(self.conv_width * 128), 1, name='conv5_1_CPM_L2') 36 | .separable_conv(3, 3, round(self.conv_width * 128), 1, name='conv5_2_CPM_L2') 37 | .separable_conv(3, 3, round(self.conv_width * 128), 1, name='conv5_3_CPM_L2') 38 | .conv(1, 1, 512, 1, 1, name='conv5_4_CPM_L2') 39 | .conv(1, 1, 19, 1, 1, relu=False, name='conv5_5_CPM_L2')) 40 | 41 | (self.feed('conv5_5_CPM_L1', 42 | 'conv5_5_CPM_L2', 43 | 'conv4_4_CPM') 44 | .concat(3, name='concat_stage2') 45 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv1_stage2_L1') 46 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv2_stage2_L1') 47 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv3_stage2_L1') 48 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv4_stage2_L1') 49 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv5_stage2_L1') 50 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage2_L1') 51 | .conv(1, 1, 38, 1, 1, relu=False, name='Mconv7_stage2_L1')) 52 | 53 | (self.feed('concat_stage2') 54 | 
.separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv1_stage2_L2') 55 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv2_stage2_L2') 56 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv3_stage2_L2') 57 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv4_stage2_L2') 58 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv5_stage2_L2') 59 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage2_L2') 60 | .conv(1, 1, 19, 1, 1, relu=False, name='Mconv7_stage2_L2')) 61 | 62 | (self.feed('Mconv7_stage2_L1', 63 | 'Mconv7_stage2_L2', 64 | 'conv4_4_CPM') 65 | .concat(3, name='concat_stage3') 66 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv1_stage3_L1') 67 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv2_stage3_L1') 68 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv3_stage3_L1') 69 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv4_stage3_L1') 70 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv5_stage3_L1') 71 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage3_L1') 72 | .conv(1, 1, 38, 1, 1, relu=False, name='Mconv7_stage3_L1')) 73 | 74 | (self.feed('concat_stage3') 75 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv1_stage3_L2') 76 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv2_stage3_L2') 77 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv3_stage3_L2') 78 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv4_stage3_L2') 79 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv5_stage3_L2') 80 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage3_L2') 81 | .conv(1, 1, 19, 1, 1, relu=False, name='Mconv7_stage3_L2')) 82 | 83 | (self.feed('Mconv7_stage3_L1', 84 | 'Mconv7_stage3_L2', 85 | 'conv4_4_CPM') 86 | .concat(3, name='concat_stage4') 87 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv1_stage4_L1') 88 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv2_stage4_L1') 89 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv3_stage4_L1') 90 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv4_stage4_L1') 91 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv5_stage4_L1') 92 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage4_L1') 93 | .conv(1, 1, 38, 1, 1, relu=False, name='Mconv7_stage4_L1')) 94 | 95 | (self.feed('concat_stage4') 96 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv1_stage4_L2') 97 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv2_stage4_L2') 98 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv3_stage4_L2') 99 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv4_stage4_L2') 100 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv5_stage4_L2') 101 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage4_L2') 102 | .conv(1, 1, 19, 1, 1, relu=False, name='Mconv7_stage4_L2')) 103 | 104 | (self.feed('Mconv7_stage4_L1', 105 | 'Mconv7_stage4_L2', 106 | 'conv4_4_CPM') 107 | .concat(3, name='concat_stage5') 108 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv1_stage5_L1') 109 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv2_stage5_L1') 110 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv3_stage5_L1') 111 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv4_stage5_L1') 112 | .separable_conv(7, 7, 
round(self.conv_width * 128), 1, name='Mconv5_stage5_L1') 113 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage5_L1') 114 | .conv(1, 1, 38, 1, 1, relu=False, name='Mconv7_stage5_L1')) 115 | 116 | (self.feed('concat_stage5') 117 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv1_stage5_L2') 118 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv2_stage5_L2') 119 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv3_stage5_L2') 120 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv4_stage5_L2') 121 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv5_stage5_L2') 122 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage5_L2') 123 | .conv(1, 1, 19, 1, 1, relu=False, name='Mconv7_stage5_L2')) 124 | 125 | (self.feed('Mconv7_stage5_L1', 126 | 'Mconv7_stage5_L2', 127 | 'conv4_4_CPM') 128 | .concat(3, name='concat_stage6') 129 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv1_stage6_L1') 130 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv2_stage6_L1') 131 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv3_stage6_L1') 132 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv4_stage6_L1') 133 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv5_stage6_L1') 134 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage6_L1') 135 | .conv(1, 1, 38, 1, 1, relu=False, name='Mconv7_stage6_L1')) 136 | 137 | (self.feed('concat_stage6') 138 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv1_stage6_L2') 139 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv2_stage6_L2') 140 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv3_stage6_L2') 141 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv4_stage6_L2') 142 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv5_stage6_L2') 143 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage6_L2') 144 | .conv(1, 1, 19, 1, 1, relu=False, name='Mconv7_stage6_L2')) 145 | 146 | (self.feed('Mconv7_stage6_L2', 147 | 'Mconv7_stage6_L1') 148 | .concat(3, name='concat_stage7')) 149 | -------------------------------------------------------------------------------- /network_mobilenet.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zumbalamambo/tf-openpose/4ddee7f516dd95c949dbd9f7f783d064b7b477e2/network_mobilenet.py -------------------------------------------------------------------------------- /pose_augment.py: -------------------------------------------------------------------------------- 1 | import random 2 | import math 3 | 4 | import cv2 5 | import numpy as np 6 | 7 | from tensorpack.dataflow.imgaug.geometry import RotationAndCropValid 8 | 9 | from common import CocoPart 10 | 11 | 12 | def pose_resize_shortestedge_fixed(meta): 13 | return pose_resize_shortestedge(meta, 368) 14 | 15 | 16 | def pose_resize_shortestedge_random(meta): 17 | target_size = int(400 * random.uniform(0.5, 1.1)) 18 | return pose_resize_shortestedge(meta, target_size) 19 | 20 | 21 | def pose_resize_shortestedge(meta, target_size): 22 | img = meta.img 23 | 24 | # adjust image 25 | scale = target_size * 1.0 / min(meta.height, meta.width) 26 | if meta.height < meta.width: 27 | newh, neww = target_size, int(scale * meta.width + 0.5) 28 | else: 29 | newh, neww = int(scale * meta.height + 0.5), target_size 30 | 31 | dst = cv2.resize(img, (neww, newh), interpolation=cv2.INTER_AREA) 32 | 33 | pw = ph = 0 34 | if neww < 
--------------------------------------------------------------------------------
/pose_augment.py:
--------------------------------------------------------------------------------
1 | import random
2 | import math
3 | 
4 | import cv2
5 | import numpy as np
6 | 
7 | from tensorpack.dataflow.imgaug.geometry import RotationAndCropValid
8 | 
9 | from common import CocoPart
10 | 
11 | 
12 | def pose_resize_shortestedge_fixed(meta):
13 |     return pose_resize_shortestedge(meta, 368)
14 | 
15 | 
16 | def pose_resize_shortestedge_random(meta):
17 |     target_size = int(400 * random.uniform(0.5, 1.1))
18 |     return pose_resize_shortestedge(meta, target_size)
19 | 
20 | 
21 | def pose_resize_shortestedge(meta, target_size):
22 |     img = meta.img
23 | 
24 |     # adjust image
25 |     scale = target_size * 1.0 / min(meta.height, meta.width)
26 |     if meta.height < meta.width:
27 |         newh, neww = target_size, int(scale * meta.width + 0.5)
28 |     else:
29 |         newh, neww = int(scale * meta.height + 0.5), target_size
30 | 
31 |     dst = cv2.resize(img, (neww, newh), interpolation=cv2.INTER_AREA)
32 | 
33 |     pw = ph = 0
34 |     if neww < 370 or newh < 370:
35 |         pw = max(0, (370 - neww) // 2)
36 |         ph = max(0, (370 - newh) // 2)
37 |         dst = cv2.copyMakeBorder(dst, ph, ph, pw, pw, cv2.BORDER_CONSTANT, value=(255, 255, 255))
38 | 
39 |     # adjust meta data
40 |     adjust_joint_list = []
41 |     for joint in meta.joint_list:
42 |         adjust_joint = []
43 |         for point in joint:
44 |             if point[0] <= 0 or point[1] <= 0:
45 |                 adjust_joint.append((-1, -1))
46 |                 continue
47 |             adjust_joint.append((int(point[0]*scale+0.5) + pw, int(point[1]*scale+0.5) + ph))
48 |         adjust_joint_list.append(adjust_joint)
49 | 
50 |     meta.joint_list = adjust_joint_list
51 |     meta.width, meta.height = neww + pw * 2, newh + ph * 2
52 |     meta.img = dst
53 |     return meta
54 | 
55 | 
56 | def pose_crop_center(meta):
57 |     target_size = (368, 368)
58 |     x = (meta.width - target_size[0]) // 2 if meta.width > target_size[0] else 0
59 |     y = (meta.height - target_size[1]) // 2 if meta.height > target_size[1] else 0
60 | 
61 |     return pose_crop(meta, x, y, target_size[0], target_size[1])
62 | 
63 | 
64 | def pose_crop_random(meta):
65 |     target_size = (368, 368)
66 |     x = random.randrange(0, meta.width - target_size[0]) if meta.width > target_size[0] else 0
67 |     y = random.randrange(0, meta.height - target_size[1]) if meta.height > target_size[1] else 0
68 | 
69 |     return pose_crop(meta, x, y, target_size[0], target_size[1])
70 | 
71 | 
72 | def pose_crop(meta, x, y, w, h):
73 |     # adjust image
74 |     target_size = (w, h)
75 | 
76 |     img = meta.img
77 |     resized = img[y:y+target_size[1], x:x+target_size[0], :]
78 | 
79 |     # adjust meta data
80 |     adjust_joint_list = []
81 |     for joint in meta.joint_list:
82 |         adjust_joint = []
83 |         for point in joint:
84 |             if point[0] <= 0 or point[1] <= 0:
85 |                 adjust_joint.append((-1, -1))
86 |                 continue
87 |             new_x, new_y = point[0] - x, point[1] - y
88 |             if new_x <= 0 or new_y <= 0 or new_x > target_size[0] or new_y > target_size[1]:
89 |                 adjust_joint.append((-1, -1))
90 |                 continue
91 |             adjust_joint.append((new_x, new_y))
92 |         adjust_joint_list.append(adjust_joint)
93 | 
94 |     meta.joint_list = adjust_joint_list
95 |     meta.width, meta.height = target_size
96 |     meta.img = resized
97 |     return meta
98 | 
99 | 
100 | def pose_flip(meta):
101 |     r = random.uniform(0, 1.0)
102 |     if r > 0.5:
103 |         return meta
104 | 
105 |     img = meta.img
106 |     img = cv2.flip(img, 1)
107 | 
108 |     # flip meta
109 |     flip_list = [CocoPart.Nose, CocoPart.Neck, CocoPart.LShoulder, CocoPart.LElbow, CocoPart.LWrist, CocoPart.RShoulder, CocoPart.RElbow, CocoPart.RWrist,
110 |                  CocoPart.LHip, CocoPart.LKnee, CocoPart.LAnkle, CocoPart.RHip, CocoPart.RKnee, CocoPart.RAnkle,
111 |                  CocoPart.LEye, CocoPart.REye, CocoPart.LEar, CocoPart.REar, CocoPart.Background]
112 |     adjust_joint_list = []
113 |     for joint in meta.joint_list:
114 |         adjust_joint = []
115 |         for cocopart in flip_list:
116 |             point = joint[cocopart.value]
117 |             if point[0] <= 0 or point[1] <= 0:
118 |                 adjust_joint.append((-1, -1))
119 |                 continue
120 |             adjust_joint.append((meta.width - point[0], point[1]))
121 |         adjust_joint_list.append(adjust_joint)
122 | 
123 |     meta.joint_list = adjust_joint_list
124 | 
125 |     meta.img = img
126 |     return meta
127 | 
128 | 
129 | def pose_rotation(meta):
130 |     deg = random.uniform(-40.0, 40.0)
131 |     img = meta.img
132 | 
133 |     center = (img.shape[1] * 0.5, img.shape[0] * 0.5)
134 |     rot_m = cv2.getRotationMatrix2D((center[0] - 0.5, center[1] - 0.5), deg, 1)
135 |     ret = cv2.warpAffine(img, rot_m, img.shape[1::-1], flags=cv2.INTER_AREA, borderMode=cv2.BORDER_CONSTANT)
136 |     if img.ndim == 3 and ret.ndim == 2:
137 |         ret = ret[:, :, np.newaxis]
138 |     neww, newh = RotationAndCropValid.largest_rotated_rect(ret.shape[1], ret.shape[0], deg)
139 |     neww = min(neww, ret.shape[1])
140 |     newh = min(newh, ret.shape[0])
141 |     newx = int(center[0] - neww * 0.5)
142 |     newy = int(center[1] - newh * 0.5)
143 |     # print(ret.shape, deg, newx, newy, neww, newh)
144 |     img = ret[newy:newy + newh, newx:newx + neww]
145 | 
146 |     # adjust meta data
147 |     adjust_joint_list = []
148 |     for joint in meta.joint_list:
149 |         adjust_joint = []
150 |         for point in joint:
151 |             if point[0] <= 0 or point[1] <= 0:
152 |                 adjust_joint.append((-1, -1))
153 |                 continue
154 |             x, y = _rotate_coord((meta.width, meta.height), (newx, newy), point, deg)
155 |             adjust_joint.append((x, y))
156 |         adjust_joint_list.append(adjust_joint)
157 | 
158 |     meta.joint_list = adjust_joint_list
159 |     meta.width, meta.height = neww, newh
160 |     meta.img = img
161 | 
162 |     return meta
163 | 
164 | 
165 | def _rotate_coord(shape, newxy, point, angle):
166 |     angle = -1 * angle / 180.0 * math.pi
167 | 
168 |     ox, oy = shape
169 |     px, py = point
170 | 
171 |     ox /= 2
172 |     oy /= 2
173 | 
174 |     qx = math.cos(angle) * (px - ox) - math.sin(angle) * (py - oy)
175 |     qy = math.sin(angle) * (px - ox) + math.cos(angle) * (py - oy)
176 | 
177 |     new_x, new_y = newxy
178 | 
179 |     qx += ox - new_x
180 |     qy += oy - new_y
181 | 
182 |     return int(qx + 0.5), int(qy + 0.5)
183 | 
184 | 
185 | def pose_to_img(meta_l):
186 |     return [meta_l[0].img, meta_l[0].get_heatmap(target_size=(92, 92)), meta_l[0].get_vectormap(target_size=(92, 92))]
187 | 
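Every augmenter above takes and returns the same metadata object, touching only its img, width, height, and joint_list fields, so the geometric pipeline can be smoke-tested in isolation before the LMDB dataflow is wired up. A minimal sketch, assuming a hypothetical DummyMeta stand-in (pose_flip is skipped here only because it additionally indexes joints through the CocoPart enum from common.py):

```python
import numpy as np

from pose_augment import pose_rotation, pose_resize_shortestedge_random, pose_crop_random

class DummyMeta:
    # Hypothetical stand-in exposing just the fields the pose_* functions use.
    def __init__(self, h, w, n_parts=19):
        self.img = np.full((h, w, 3), 255, dtype=np.uint8)
        self.height, self.width = h, w
        # one annotated person: a single labeled joint at the image center
        self.joint_list = [[(w // 2, h // 2)] + [(-1, -1)] * (n_parts - 1)]

meta = DummyMeta(480, 640)
meta = pose_rotation(meta)                    # random rotation in [-40, +40] degrees
meta = pose_resize_shortestedge_random(meta)  # shortest edge -> ~200..440 px, padded up to ~370
meta = pose_crop_random(meta)                 # random 368x368 window
print(meta.img.shape, meta.joint_list[0][0])  # (368, 368, 3); joint moved, or (-1, -1) if cropped out
```

The order matters: pose_resize_shortestedge pads small images with white borders to roughly 370 px per side, which is what guarantees the subsequent 368×368 crop is always valid.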
--------------------------------------------------------------------------------
/pose_dataset.py:
--------------------------------------------------------------------------------
1 | import math
2 | import struct
3 | import cv2
4 | 
5 | import lmdb
6 | import logging
7 | 
8 | import multiprocessing
9 | import numpy as np
10 | import matplotlib.pyplot as plt
11 | 
12 | from tensorpack import imgaug
13 | from tensorpack.dataflow.image import MapDataComponent, AugmentImageComponent
14 | from tensorpack.dataflow.common import BatchData, MapData
15 | from tensorpack.dataflow.prefetch import PrefetchData
16 | from tensorpack.dataflow.base import RNGDataFlow, DataFlowTerminated
17 | 
18 | from datum_pb2 import Datum
19 | from pose_augment import pose_flip, pose_rotation, pose_to_img, pose_crop_random, \
20 |     pose_resize_shortestedge_random, pose_resize_shortestedge_fixed, pose_crop_center
21 | 
22 | logging.basicConfig(level=logging.DEBUG, format='[lmdb_dataset] %(asctime)s %(levelname)s %(message)s')
23 | 
24 | 
25 | class CocoMetadata:
26 |     # __coco_parts = 57
27 |     __coco_parts = 19
28 |     __coco_vecs = list(zip(
29 |         [2, 9, 10, 2, 12, 13, 2, 3, 4, 3, 2, 6, 7, 6, 2, 1, 1, 15, 16],
30 |         [9, 10, 11, 12, 13, 14, 3, 4, 5, 17, 6, 7, 8, 18, 1, 15, 16, 17, 18]
31 |     ))
32 | 
33 |     @staticmethod
34 |     def parse_float(four_np):
35 |         assert len(four_np) == 4
36 |         return struct.unpack('<f', bytes(four_np))[0]
134 |                 if exp > th:
135 |                     continue
136 |                 heatmap[plane_idx][y][x] += math.exp(-exp)
137 |                 heatmap[plane_idx][y][x] = min(heatmap[plane_idx][y][x], 1.0)
138 | 
139 |     def get_vectormap(self, target_size=None):
140 |         vectormap = np.zeros((CocoMetadata.__coco_parts*2, self.height, self.width))
141 |         countmap = np.zeros((CocoMetadata.__coco_parts, self.height, self.width))
142 |         for joints in self.joint_list:
143 |             for plane_idx, (j_idx1, j_idx2) in enumerate(CocoMetadata.__coco_vecs):
144 |                 j_idx1 -= 1
145 |                 j_idx2 -= 1
146 | 
147 |                 center_from = joints[j_idx1]
148 |                 center_to = joints[j_idx2]
149 | 
150 |                 if center_from[0] < 0 or center_from[1] < 0 or center_to[0] < 0 or center_to[1] < 0:
151 |                     continue
152 | 
153 |                 CocoMetadata.put_vectormap(vectormap, countmap, plane_idx, center_from, center_to)
154 | 
155 |         vectormap = vectormap.transpose((1, 2, 0))
156 |         nonzeros = np.nonzero(countmap)
157 |         for p, y, x in zip(nonzeros[0], nonzeros[1], nonzeros[2]):
158 |             if countmap[p][y][x] <= 0:
159 |                 continue
160 |             vectormap[y][x][p*2+0] /= countmap[p][y][x]
161 |             vectormap[y][x][p*2+1] /= countmap[p][y][x]
162 | 
163 |         if target_size:
164 |             vectormap = cv2.resize(vectormap, target_size, interpolation=cv2.INTER_AREA)
165 | 
166 |         return vectormap
167 | 
168 |     @staticmethod
169 |     def put_vectormap(vectormap, countmap, plane_idx, center_from, center_to, threshold=4):
170 |         _, height, width = vectormap.shape[:3]
171 | 
172 |         vec_x = center_to[0] - center_from[0]
173 |         vec_y = center_to[1] - center_from[1]
174 | 
175 |         min_x = max(0, int(min(center_from[0], center_to[0]) - threshold))
176 |         min_y = max(0, int(min(center_from[1], center_to[1]) - threshold))
177 | 
178 |         max_x = min(width, int(max(center_from[0], center_to[0]) + threshold))
179 |         max_y = min(height, int(max(center_from[1], center_to[1]) + threshold))
180 | 
181 |         norm = math.sqrt(vec_x ** 2 + vec_y ** 2)
182 |         vec_x /= norm
183 |         vec_y /= norm
184 | 
185 |         for y in range(min_y, max_y):
186 |             for x in range(min_x, max_x):
187 |                 bec_x = x - center_from[0]
188 |                 bec_y = y - center_from[1]
189 |                 dist = abs(bec_x * vec_y - bec_y * vec_x)
190 | 
191 |                 if dist > threshold:
192 |                     continue
193 | 
194 |                 countmap[plane_idx][y][x] += 1
195 | 
196 |                 vectormap[plane_idx*2+0][y][x] = vec_x
197 |                 vectormap[plane_idx*2+1][y][x] = vec_y
198 | 
199 | 
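The ground-truth PAF encoding above works in two steps: put_vectormap paints each limb as a band of unit direction vectors, writing (vec_x, vec_y) into every pixel whose perpendicular distance from the limb line is at most threshold (4 px by default), and get_vectormap then divides by countmap so pixels covered by the same limb of several people hold the average direction. The distance test is the 2-D cross product |bec × vec|, which equals the point-to-line distance once vec is normalized; a standalone numeric check (not part of the file):

```python
import math

center_from, center_to = (10, 10), (20, 10)   # a horizontal limb segment
vec_x, vec_y = center_to[0] - center_from[0], center_to[1] - center_from[1]
norm = math.sqrt(vec_x ** 2 + vec_y ** 2)
vec_x, vec_y = vec_x / norm, vec_y / norm     # unit direction: (1.0, 0.0)

x, y = 15, 13                                 # candidate pixel, 3 px off the limb axis
bec_x, bec_y = x - center_from[0], y - center_from[1]
dist = abs(bec_x * vec_y - bec_y * vec_x)     # |5*0.0 - 3*1.0| = 3.0
print(dist)                                   # 3.0 <= threshold 4, so this pixel is painted
```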
200 | class CocoPoseLMDB(RNGDataFlow):
201 |     __valid_i = 2745
202 |     __max_key = 121745
203 | 
204 |     @staticmethod
205 |     def display_image(inp, heatmap, vectmap):
206 |         fig = plt.figure()
207 |         a = fig.add_subplot(2, 2, 1)
208 |         a.set_title('Image')
209 |         plt.imshow(CocoPoseLMDB.get_bgimg(inp))
210 | 
211 |         a = fig.add_subplot(2, 2, 2)
212 |         a.set_title('Heatmap')
213 |         plt.imshow(CocoPoseLMDB.get_bgimg(inp, target_size=(heatmap.shape[1], heatmap.shape[0])), alpha=0.5)
214 |         tmp = np.amax(heatmap, axis=2)
215 |         plt.imshow(tmp, cmap=plt.cm.gray, alpha=0.5)
216 |         plt.colorbar()
217 | 
218 |         tmp2 = vectmap.transpose((2, 0, 1))
219 |         tmp2_odd = np.amax(tmp2[::2, :, :], axis=0)
220 |         tmp2_even = np.amax(tmp2[1::2, :, :], axis=0)
221 | 
222 |         a = fig.add_subplot(2, 2, 3)
223 |         a.set_title('Vectormap-x')
224 |         plt.imshow(CocoPoseLMDB.get_bgimg(inp, target_size=(vectmap.shape[1], vectmap.shape[0])), alpha=0.5)
225 |         plt.imshow(tmp2_odd, cmap=plt.cm.gray, alpha=0.5)
226 |         plt.colorbar()
227 | 
228 |         a = fig.add_subplot(2, 2, 4)
229 |         a.set_title('Vectormap-y')
230 |         plt.imshow(CocoPoseLMDB.get_bgimg(inp, target_size=(vectmap.shape[1], vectmap.shape[0])), alpha=0.5)
231 |         plt.imshow(tmp2_even, cmap=plt.cm.gray, alpha=0.5)
232 |         plt.colorbar()
233 | 
234 |         plt.show()
235 | 
236 |     @staticmethod
237 |     def get_bgimg(inp, target_size=None):
238 |         if target_size:
239 |             inp = cv2.resize(inp, target_size, interpolation=cv2.INTER_AREA)
240 |         inp = cv2.cvtColor(inp, cv2.COLOR_BGR2RGB)
241 |         return inp
242 | 
243 |     def __init__(self, path, is_train=True):
244 |         self.is_train = is_train
245 |         self.env = lmdb.open(path, map_size=int(1e12), readonly=True)
246 |         self.txn = self.env.begin(buffers=True)
247 |         pass
248 | 
249 |     def size(self):
250 |         if self.is_train:
251 |             return CocoPoseLMDB.__max_key - CocoPoseLMDB.__valid_i
252 |         else:
253 |             return CocoPoseLMDB.__valid_i
254 | 
255 |     def get_data(self):
256 |         idxs = np.arange(self.size())
257 |         if self.is_train:
258 |             idxs += CocoPoseLMDB.__valid_i
259 |             self.rng.shuffle(idxs)
260 |         else:
261 |             pass
262 | 
263 |         for idx in idxs:
264 |             datum = Datum()
265 |             s = self.txn.get(('%07d' % idx).encode('utf-8'))
266 |             datum.ParseFromString(s)
267 |             data = np.fromstring(datum.data.tobytes(), dtype=np.uint8).reshape(datum.channels, datum.height, datum.width)
268 |             img = data[:3].transpose((1, 2, 0))
269 | 
270 |             meta = CocoMetadata(img, data[3], 4.0)
271 | 
272 |             yield [meta]
273 | 
274 | 
275 | def get_dataflow(is_train):
276 |     ds = CocoPoseLMDB('/data/public/rw/coco-pose-estimation-lmdb/', is_train)
277 |     if is_train:
278 |         ds = MapDataComponent(ds, pose_rotation)
279 |         ds = MapDataComponent(ds, pose_flip)
280 |         ds = MapDataComponent(ds, pose_resize_shortestedge_random)
281 |         ds = MapDataComponent(ds, pose_crop_random)
282 |         ds = MapData(ds, pose_to_img)
283 |         augs = [
284 |             imgaug.RandomApplyAug(imgaug.RandomChooseAug([
285 |                 imgaug.SaltPepperNoise(white_prob=0.01, black_prob=0.01),
286 |                 imgaug.RandomOrderAug([
287 |                     imgaug.BrightnessScale((0.8, 1.2), clip=False),
288 |                     imgaug.Contrast((0.8, 1.2), clip=False),
289 |                     # imgaug.Saturation(0.4, rgb=True),
290 |                 ]),
291 |             ]), 0.7),
292 |         ]
293 |         ds = AugmentImageComponent(ds, augs)
294 |     else:
295 |         ds = MapDataComponent(ds, pose_resize_shortestedge_fixed)
296 |         ds = MapDataComponent(ds, pose_crop_center)
297 |         ds = MapData(ds, pose_to_img)
298 | 
299 |     return ds
300 | 
301 | 
302 | def get_dataflow_batch(is_train, batchsize):
303 |     ds = get_dataflow(is_train)
304 |     ds = PrefetchData(ds, 1000, multiprocessing.cpu_count())
305 |     ds = BatchData(ds, batchsize)
306 |     ds = PrefetchData(ds, 10, 4)
307 | 
308 |     return ds
309 | 
310 | 
311 | if __name__ == '__main__':
312 |     df = get_dataflow(False)
313 | 
314 |     df.reset_state()
315 |     for dp in df.get_data():
316 |         CocoPoseLMDB.display_image(dp[0], dp[1], dp[2])
317 |         pass
318 | 
319 |     logging.info('done')
320 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | argparse
2 | lmdb
3 | https://github.com/ppwwyyxx/tensorpack
--------------------------------------------------------------------------------
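Two practical notes on requirements.txt: argparse has shipped with the standard library since Python 2.7/3.2, and a bare GitHub URL is not a valid pip requirement (pip expects a VCS form such as git+https://github.com/ppwwyyxx/tensorpack.git). Once tensorpack and lmdb are installed and the LMDB path hard-coded in get_dataflow exists, the batched pipeline can be consumed as sketched below; the batch size is arbitrary, and the channel counts assume the 368×368 input and 92×92 / 19-part / 38-channel targets produced by pose_to_img and CocoMetadata above:

```python
from pose_dataset import get_dataflow_batch

df = get_dataflow_batch(is_train=True, batchsize=16)
df.reset_state()  # tensorpack dataflows must be reset before iteration
for images, heatmaps, vectormaps in df.get_data():
    # expected: images (16, 368, 368, 3), heatmaps (16, 92, 92, 19), vectormaps (16, 92, 92, 38)
    print(images.shape, heatmaps.shape, vectormaps.shape)
    break
```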