├── .gitignore
├── LICENSE
├── README.md
├── analysis.py
├── config.py
├── data.py
├── demo.py
├── layers.py
├── model.py
├── test.py
└── train.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 |
3 | .vscode
4 |
5 | dataset
6 | checkpoints/
7 | checkpoints*
8 | backup
9 |
10 | *.jpg
11 | *.png
12 | *.svg
13 | *.pdf
14 | *.tflite
15 |
16 | # training records
17 | train_record.json
18 | train_record_pretrain.json
19 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # A TensorFlow Implementation of BlazePose
2 |
3 | This is a third-party TensorFlow implementation of BlazePose.
4 |
5 | The original paper is "BlazePose: On-device Real-time Body Pose tracking" by Valentin Bazarevsky, Ivan Grishchenko, Karthik Raveendran, Tyler Zhu, Fan Zhang, and Matthias Grundmann. Available on [arXiv](https://arxiv.org/abs/2006.10204).
6 |
7 | Since I do not have the full training settings used by the original authors, some details may differ from the original paper. Corrections are welcome.
8 |
9 | Work is still in progress; the current version does not yet provide the full functionality.
10 |
11 | ## Requirements
12 |
13 | It is highly recommended to run this code on Ubuntu 20.04 with an Anaconda environment. Python 3.7.9 and 3.8.5 have been tested, as have CUDA `10.1` and `11.1`.
14 |
15 | ```
16 | tensorflow >= 2.3
17 | numpy
18 | matplotlib
19 | scipy
20 | ```
21 |
22 | ## Train (from random initialization)
23 |
24 | 1. Download the LSP dataset. (If you already have it, skip this step.)
25 |
26 | The first time you run this code on Linux, the LSP dataset is downloaded automatically.
27 |
28 | On Microsoft Windows 10, please download and unzip the dataset manually.
29 |
30 | 2. Pre-train the heatmap branch.
31 |
32 | Edit the training settings in `config.py`: set `train_mode = 0` and `continue_train = 0`.
33 |
34 | Then run `python3 train.py`.
35 |
36 | 3. Fine-tune the joint regression branch.
37 |
38 | Set `train_mode = 1`, `continue_train = 0`, and `best_pre_train` to the epoch at which the training loss is still dropping while the validation accuracy is at its best.
39 |
40 | Then run `python3 train.py`.
41 |
42 | ## Continue training
43 |
44 | If training was interrupted (for example, by an unexpected power-off or by pressing `Ctrl + C`) and you want to resume, follow these steps:
45 |
46 | 1. Edit `config.py` and set `continue_train` to the epoch you want to resume from.
47 |
48 | To resume pre-training, setting `continue_train` is all that is needed.
49 |
50 | To resume fine-tuning, also set `train_mode = 1`.
51 |
52 | 2. Run `python3 train.py`.
53 |
54 | 3. If you have just finished pre-training, set `train_mode = 1`, `continue_train = 0`, and `best_pre_train` to the epoch at which the training loss is still dropping while the validation accuracy is at its best, then run `python3 train.py`.
55 |
56 | ## Test
57 |
58 | 1. Edit `config.py`.
59 |
60 | To see the visualized heatmaps, set `train_mode = 0`.
61 |
62 | For skeleton joint results, set `train_mode = 1`.
63 |
64 | 2. Set `epoch_to_test` to the epoch you would like to test.
65 |
66 | 3. If you set `train_mode = 0`, also set `vis_img_id` to select an image.
67 |
68 | 4. For `train_mode = 1`, the evaluation mode must also be set.
69 |
70 | Set `eval_mode = 1` to compute the PCKh@0.5 score, or `eval_mode = 0` to write out the result images.
71 |
72 | 5. Before running with `train_mode = 1` and `eval_mode = 0` for the first time, create the output directory in a terminal:
73 |
74 | ```bash
75 | mkdir result
76 | ```
77 |
78 | 6. Run `python3 test.py`.
79 |
80 | For `train_mode = 0`, you will see the heatmaps.
81 |
82 | For `train_mode = 1` and `eval_mode = 0`, the result images are written to the `result` directory.
83 |
84 | For `train_mode = 1` and `eval_mode = 1`, the PCKh@0.5 score of each joint and the average score are printed.
85 |
86 | ## Online camera demo
87 |
88 | 1. Install the YOLOv4 package with `pip3 install yolov4`, and download the trained YOLOv4 model weights.
89 |
90 | 2. Finish training on your dataset.
91 |
92 | 3. Set `train_mode = 1` and connect a USB camera.
93 |
94 | 4. Run `python3 demo.py`. One or more people can stand in front of the camera.
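To summarize the training, resuming, and testing settings described above, here is an illustrative sketch of the relevant `config.py` values (the epoch numbers are only examples, not recommendations):

```python
# Stage 1 -- pre-train the heatmap branch from scratch
train_mode = 0
continue_train = 0

# Stage 2 -- fine-tune the regression branch from the best pre-train checkpoint
train_mode = 1
continue_train = 0
best_pre_train = 434   # illustrative: the best pre-train epoch you observed

# Resuming an interrupted run (either stage): point continue_train at the
# last saved checkpoint epoch instead of 0
continue_train = 199   # illustrative

# Testing
epoch_to_test = 199    # illustrative: checkpoint epoch to evaluate
eval_mode = 1          # 1: PCKh@0.5 score, 0: write result images
```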
95 |
96 | ## TODOs
97 |
98 | - [x] Basic code for the BlazePose network model.
99 |
100 | - [x] Implementation of the Channel Attention layer.
101 |
102 | - [x] Functions
103 |
104 |     - [x] Two-stage training (pre-train and fine-tune).
105 |
106 |     - [x] Continue training from the checkpoint of a custom epoch.
107 |
108 |     - [x] Save the training record (loss and accuracy for the training and validation sets) to a JSON file.
109 |
110 |     - [x] More explicit training settings (for fine-tuning and continued training).
111 |
112 |     - [x] Calculate PCKh@0.5 scores.
113 |
114 | - [ ] Dataset and preprocessing.
115 |
116 |     - [x] LSP dataset train and validation.
117 |
118 |     - [ ] LSPET dataset.
119 |
120 |     - [ ] Custom dataset.
121 |
122 | - [ ] Implementation of pose tracking on video.
123 |
124 | - [x] Online camera demo.
125 |
126 | ## Reference
127 |
128 | If the original paper helps your research, you can cite it in your LaTeX file with:
129 |
130 | ```tex
131 | @article{Bazarevsky2020BlazePoseOR,
132 |   title={BlazePose: On-device Real-time Body Pose tracking},
133 |   author={Valentin Bazarevsky and I. Grishchenko and K. Raveendran and Tyler Lixuan Zhu and Fangfang Zhang and M. Grundmann},
134 |   journal={ArXiv},
135 |   year={2020},
136 |   volume={abs/2006.10204}
137 | }
138 | ```
139 |
140 | ## Comments
141 |
142 | Please feel free to [submit an issue](https://github.com/jiang-du/BlazePose-tensorflow/issues) or [open a pull request](https://github.com/jiang-du/BlazePose-tensorflow/pulls).
143 |
--------------------------------------------------------------------------------
/analysis.py:
--------------------------------------------------------------------------------
1 | # save and load the training record
2 | import json
3 | import numpy as np
4 | from config import json_name
5 |
6 | def save_record(train_loss_results, train_accuracy_results, val_accuracy_results):
7 |     train_record = dict()
8 |     train_record["train_loss"] = list(np.float64(train_loss_results))
9 |     train_record["train_accuracy"] = list(np.float64(train_accuracy_results))
10 |     train_record["val_accuracy"] = list(np.float64(val_accuracy_results))
11 |     with open(json_name, 'w') as f:
12 |         json.dump(train_record, f)
13 |     return 0
14 |
15 | def load_record():
16 |     with open(json_name, 'r') as f:
17 |         train_record = json.load(f)
18 |     # convert the lists to numpy arrays
19 |     train_loss_results = np.float64(train_record["train_loss"])
20 |     train_accuracy_results = np.float64(train_record["train_accuracy"])
21 |     val_accuracy_results = np.float64(train_record["val_accuracy"])
22 |     return train_loss_results, train_accuracy_results, val_accuracy_results
23 |
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | num_joints = 14  # LSP dataset
2 |
3 | batch_size = 256  # 256 works best on an RTX 3090 (per GPU)
4 | total_epoch = 500
5 | gpu_dynamic_memory = 0
6 | gaussian_sigma = 4
7 |
8 | # Train mode: 0 - pre-train, 1 - fine-tune
9 | train_mode = 1
10 |
11 | # Evaluation mode: 0 - get result images, 1 - get PCK score only
12 | eval_mode = 1
13 |
14 | show_batch_loss = 0
15 | continue_train = 0  # 0 to initialize randomly, >0 to resume from that epoch
16 |
17 | if train_mode:
18 |     best_pre_train = 434  # the epoch at which training loss still drops but validation accuracy is at its best
19 |
20 | # for test only
21 | epoch_to_test = 199
22 | # for testing the heatmap only
23 | vis_img_id = 1797
24 |
25 | json_name = "train_record.json" if train_mode else "train_record_pretrain.json"
26 |
--------------------------------------------------------------------------------
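The records written by `analysis.py` are plain JSON, so training curves can be inspected with a short script. A minimal sketch, assuming a record file already exists (the x-axis stride of 5 for the validation curve matches the validation interval in `train.py`):

```python
import matplotlib.pyplot as plt
from analysis import load_record

# load_record() returns (train_loss, train_accuracy, val_accuracy) arrays
train_loss, train_acc, val_acc = load_record()

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
ax1.plot(train_loss)
ax1.set_xlabel("epoch")
ax1.set_ylabel("training loss")
# validation runs every 5 epochs (epochs 4, 9, 14, ...), so space its x values accordingly
val_epochs = [5 * (i + 1) - 1 for i in range(len(val_acc))]
ax2.plot(train_acc, label="train")
ax2.plot(val_epochs, val_acc, label="validation")
ax2.set_xlabel("epoch")
ax2.set_ylabel("mean squared error")
ax2.legend()
plt.show()
```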
"train_record.json" if train_mode else "train_record_pretrain.json" 26 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | #!~/miniconda3/envs/tf2/bin/python 2 | import os 3 | import platform 4 | import numpy as np 5 | import tensorflow as tf 6 | from scipy.io import loadmat 7 | from config import num_joints, batch_size, gaussian_sigma, gpu_dynamic_memory 8 | 9 | DataURL = "https://sam.johnson.io/research/lsp_dataset.zip" 10 | 11 | # guassian generation 12 | def getGaussianMap(joint = (16, 16), heat_size = 128, sigma = 2): 13 | # by default, the function returns a gaussian map with range [0, 1] of typr float32 14 | heatmap = np.zeros((heat_size, heat_size),dtype=np.float32) 15 | tmp_size = sigma * 3 16 | ul = [int(joint[0] - tmp_size), int(joint[1] - tmp_size)] 17 | br = [int(joint[0] + tmp_size + 1), int(joint[1] + tmp_size + 1)] 18 | size = 2 * tmp_size + 1 19 | x = np.arange(0, size, 1, np.float32) 20 | y = x[:, np.newaxis] 21 | x0 = y0 = size // 2 22 | g = np.exp(-((x - x0) ** 2 + (y - y0) ** 2) / (2 * (sigma ** 2))) 23 | g.shape 24 | # usable gaussian range 25 | g_x = max(0, -ul[0]), min(br[0], heat_size) - ul[0] 26 | g_y = max(0, -ul[1]), min(br[1], heat_size) - ul[1] 27 | # image range 28 | img_x = max(0, ul[0]), min(br[0], heat_size) 29 | img_y = max(0, ul[1]), min(br[1], heat_size) 30 | heatmap[img_y[0]:img_y[1], img_x[0]:img_x[1]] = g[g_y[0]:g_y[1], g_x[0]:g_x[1]] 31 | """ 32 | heatmap *= 255 33 | heatmap = heatmap.astype(np.uint8) 34 | cv2.imshow("debug", heatmap) 35 | cv2.waitKey(0) 36 | cv2.destroyAllWindows() 37 | """ 38 | return heatmap 39 | 40 | if gpu_dynamic_memory: 41 | # Limit GPU memory usage if necessary 42 | gpus = tf.config.experimental.list_physical_devices('GPU') 43 | if gpus: 44 | try: 45 | # Currently, memory growth needs to be the same across GPUs 46 | for gpu in gpus: 47 | tf.config.experimental.set_memory_growth(gpu, True) 48 | logical_gpus = tf.config.experimental.list_logical_devices('GPU') 49 | print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs") 50 | except RuntimeError as e: 51 | # Memory growth must be set before GPUs have been initialized 52 | print(e) 53 | 54 | if not os.path.exists("./dataset"): 55 | os.system("mkdir dataset") 56 | 57 | if os.path.exists("./dataset/lsp/joints.mat"): 58 | print("Found lsp dataset.") 59 | else: 60 | if not os.path.isfile("./dataset/lsp_dataset.zip"): 61 | # try to download 62 | if os.system("wget " + DataURL): 63 | # abnormal when download with "wget" 64 | if platform.system() == 'Linux': 65 | # 下载失败,没有网,训练个锤子啊 66 | raise Exception("No Internet. Would you like to train with your hammer?") 67 | elif platform.system() == 'Windows': 68 | raise Exception('You should firstly install "wget" if you run on Windows.') 69 | else: 70 | raise Exception("Unsupported platform. Please run on Ubuntu 20.04.") 71 | else: 72 | print("Finish download lsp dataset.") 73 | 74 | # try to uncompress 75 | if platform.system() == 'Linux': 76 | if os.system("unzip dataset/lsp_dataset.zip -d dataset/lsp/"): 77 | raise Exception("Unzip Runtime Error. You can try 'rm ./dataset/lsp_dataset.zip' and run again.") 78 | else: 79 | print("Finish uncompress lsp dataset.") 80 | elif platform.system() == 'Windows': 81 | raise Exception('Please unzip files manually on windows.') 82 | else: 83 | raise Exception("Unsupported platform. 
Please run on Ubuntu 20.04.") 84 | print("Finish uncompress lsp dataset.") 85 | 86 | # read annotations 87 | annotations = loadmat("./dataset/lsp/joints.mat") 88 | label = annotations["joints"].swapaxes(0, 2) # shape (3, 14, 2000) -> (2000, 14, 3) 89 | 90 | # read images 91 | data = np.zeros([2000, 256, 256, 3]) 92 | heatmap_set = np.zeros((2000, 128, 128, num_joints), dtype=np.float32) 93 | print("Reading dataset...") 94 | for i in range(2000): 95 | FileName = "./dataset/lsp/images/im%04d.jpg" % (i + 1) 96 | img = tf.io.read_file(FileName) 97 | img = tf.image.decode_image(img) 98 | img_shape = img.shape 99 | # Attention here img_shape[0] is height and [1] is width 100 | label[i, :, 0] *= (256 / img_shape[1]) 101 | label[i, :, 1] *= (256 / img_shape[0]) 102 | data[i] = tf.image.resize(img, [256, 256]) 103 | # generate heatmap set 104 | for j in range(num_joints): 105 | _joint = (label[i, j, 0:2] // 2).astype(np.uint16) 106 | # print(_joint) 107 | heatmap_set[i, :, :, j] = getGaussianMap(joint = _joint, heat_size = 128, sigma = gaussian_sigma) 108 | # print status 109 | if not i%(2000//80): 110 | print(">", end='') 111 | 112 | # dataset 113 | print("\nGenerating training and testing data batches...") 114 | train_dataset = tf.data.Dataset.from_tensor_slices((data[0:1000], heatmap_set[0:1000])) 115 | test_dataset = tf.data.Dataset.from_tensor_slices((data[1000:-1], heatmap_set[1000:-1])) 116 | 117 | SHUFFLE_BUFFER_SIZE = 1000 118 | train_dataset = train_dataset.shuffle(SHUFFLE_BUFFER_SIZE).batch(batch_size) 119 | test_dataset = test_dataset.batch(batch_size) 120 | 121 | # Finetune 122 | finetune_train = tf.data.Dataset.from_tensor_slices((data[0:1000], label[0:1000])) 123 | finetune_validation = tf.data.Dataset.from_tensor_slices((data[1000:-1], label[1000:-1])) 124 | 125 | finetune_train = finetune_train.shuffle(SHUFFLE_BUFFER_SIZE).batch(batch_size) 126 | finetune_validation = finetune_validation.batch(batch_size) 127 | 128 | print("Done.") 129 | -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tensorflow as tf 3 | from model import BlazePose 4 | from config import total_epoch, train_mode 5 | 6 | model = BlazePose() 7 | 8 | checkpoint_path = "checkpoints/cp-{epoch:04d}.ckpt" 9 | checkpoint_dir = os.path.dirname(checkpoint_path) 10 | 11 | # model.evaluate(test_dataset) 12 | model.load_weights(checkpoint_path.format(epoch=199)) 13 | 14 | """ 15 | 0-Right ankle 16 | 1-Right knee 17 | 2-Right hip 18 | 3-Left hip 19 | 4-Left knee 20 | 5-Left ankle 21 | 6-Right wrist 22 | 7-Right elbow 23 | 8-Right shoulder 24 | 9-Left shoulder 25 | 10-Left elbow 26 | 11-Left wrist 27 | 12-Neck 28 | 13-Head top 29 | """ 30 | 31 | assert train_mode 32 | 33 | import cv2 34 | import numpy as np 35 | from yolov4.tf import YOLOv4 36 | yolo = YOLOv4() 37 | yolo.classes = "./../yolov4/coco.names" 38 | yolo.make_model() 39 | yolo.load_weights("./../yolov4/yolov4.weights", weights_type="yolo") 40 | cap = cv2.VideoCapture(0) 41 | while(1): 42 | ret, frame = cap.read() 43 | if not ret: 44 | break 45 | # ------ YOLO detection for the boxes ------ 46 | d = yolo.predict(frame) 47 | # len(d): num of objects 48 | # for d[i], dims 0 to 3 is position ranged [0, 1] -- center_x, center_y, w, h; dim 4 is class (0 for person); dim 5 is score. 
49 |     img_size = frame.shape  # (480, 640, 3)
50 |
51 |     # ------ get boxes ------
52 |     for bbox in d:
53 |         if (bbox[5] >= 0.4) and (bbox[4] == 0):
54 |             # a person detected with high confidence
55 |             for i in (1, 3):
56 |                 bbox[i] *= img_size[0]
57 |             for i in (0, 2):
58 |                 bbox[i] *= img_size[1]
59 |             # box position (assumed to lie inside the frame)
60 |             c_x = int(bbox[0])
61 |             c_y = int(bbox[1])
62 |             half_w = int(bbox[2] / 2)
63 |             half_h = int(bbox[3] / 2)
64 |             top_left = (c_x - half_w, c_y - half_h)
65 |             bottom_right = (c_x + half_w, c_y + half_h)
66 |
67 |             # ------ draw the box ------
68 |             cv2.rectangle(frame, top_left, bottom_right, (255, 0, 0), 1)
69 |
70 |             # ------ skeleton detection ------
71 |             img = frame[top_left[1]:bottom_right[1], top_left[0]:bottom_right[0]]  # crop rows (y) and columns (x) in one index
72 |             img = cv2.resize(img, (256, 256))  # resize the crop, not the full frame
73 |             y = model(tf.convert_to_tensor(np.expand_dims(img, axis=0), dtype=tf.float32))
74 |             skeleton = y[0].numpy()  # x, y in the 256 x 256 crop coordinates
75 |             # map the pose back to the original frame
76 |             skeleton[:, 0] = skeleton[:, 0] * half_w / 128 + top_left[0]
77 |             skeleton[:, 1] = skeleton[:, 1] * half_h / 128 + top_left[1]
78 |             skeleton = skeleton.astype(np.int16)
79 |             # draw the joints
80 |             for i in range(14):
81 |                 cv2.circle(frame, center=tuple(skeleton[i][0:2]), radius=2, color=(0, 255, 0), thickness=2)
82 |             # draw the limbs
83 |             for j in ((13, 12), (12, 8), (12, 9), (8, 7), (7, 6), (9, 10), (10, 11), (2, 3), (2, 1), (1, 0), (3, 4), (4, 5)):
84 |                 cv2.line(frame, tuple(skeleton[j[0]][0:2]), tuple(skeleton[j[1]][0:2]), color=(0, 0, 255), thickness=1)
85 |             # connect the neck to the midpoint of the hips
86 |             cv2.line(frame, tuple(skeleton[12][0:2]), tuple(skeleton[2][0:2] // 2 + skeleton[3][0:2] // 2), color=(0, 0, 255), thickness=1)
87 |
88 |     frame = cv2.resize(frame, (1280, 960))
89 |     cv2.imshow("Demo", frame)
90 |     if cv2.waitKey(1) & 0xFF == ord('q'):
91 |         break
92 |     pass
93 |
--------------------------------------------------------------------------------
/layers.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 | class ChannelPadding(tf.keras.layers.Layer):
4 |     def __init__(self, channels):
5 |         super(ChannelPadding, self).__init__()
6 |         self.channels = channels
7 |
8 |     def build(self, input_shapes):
9 |         self.pad_shape = tf.constant([[0, 0], [0, 0], [0, 0], [0, self.channels - input_shapes[-1]]])
10 |
11 |     def call(self, input):
12 |         return tf.pad(input, self.pad_shape)
13 |
14 | class BlazeBlock(tf.keras.Model):
15 |     def __init__(self, block_num = 3, channel = 48, channel_padding = 1):
16 |         super(BlazeBlock, self).__init__()
17 |         # <----- downsample ----->
18 |         self.downsample_a = tf.keras.models.Sequential([
19 |             tf.keras.layers.DepthwiseConv2D(kernel_size=3, strides=(2, 2), padding='same', activation=None),
20 |             tf.keras.layers.Conv2D(filters=channel, kernel_size=1, activation=None)
21 |         ])
22 |         if channel_padding:
23 |             self.downsample_b = tf.keras.models.Sequential([
24 |                 tf.keras.layers.MaxPool2D(pool_size=(2, 2)),
25 |                 # # At first I could not work out how to implement channel padding, so a 1x1 convolution stood in for it:
26 |                 # tf.keras.layers.Conv2D(filters=channel, kernel_size=1, activation=None)
27 |                 # Update: in the end I managed to write it myself -- see ChannelPadding above.
28 |                 ChannelPadding(channels=channel)
29 |             ])
30 |         else:
31 |             # channel number invariance
32 |             self.downsample_b = tf.keras.layers.MaxPool2D(pool_size=(2, 2))
33 |         # <----- separable convolution ----->
34 |         self.conv = list()
35 |         for i in range(block_num):
36 |             self.conv.append(tf.keras.models.Sequential([
37 |                 tf.keras.layers.DepthwiseConv2D(kernel_size=3, padding='same', activation=None),
38 |                 tf.keras.layers.Conv2D(filters=channel, kernel_size=1, activation=None)
39 |             ]))
40 |
41 |     def call(self, x):
42 |         x = tf.keras.activations.relu(self.downsample_a(x) + self.downsample_b(x))
43 |         for i in range(len(self.conv)):
44 |             x = tf.keras.activations.relu(x + self.conv[i](x))
45 |         return x
--------------------------------------------------------------------------------
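A quick shape check illustrates what these layers do: `BlazeBlock` halves the spatial resolution and widens the channel dimension, with `ChannelPadding` zero-filling the extra channels on the shortcut path (`channel_padding = 0` keeps the channel count fixed instead). A sketch:

```python
import tensorflow as tf
from layers import BlazeBlock, ChannelPadding

x = tf.random.normal((1, 128, 128, 24))

block = BlazeBlock(block_num=3, channel=48)
print(block(x).shape)  # (1, 64, 64, 48): resolution halved, channels padded 24 -> 48

pad = ChannelPadding(channels=48)
print(pad(x).shape)    # (1, 128, 128, 48): zero-filled channels appended
```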
/model.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from layers import BlazeBlock
3 | from config import num_joints, train_mode
4 |
5 | class BlazePose(tf.keras.Model):
6 |     def __init__(self):
7 |         super(BlazePose, self).__init__()
8 |         self.conv1 = tf.keras.layers.Conv2D(
9 |             filters=24, kernel_size=3, strides=(2, 2), padding='same', activation='relu'
10 |         )
11 |
12 |         # separable convolution (MobileNet)
13 |         self.conv2_1 = tf.keras.models.Sequential([
14 |             tf.keras.layers.DepthwiseConv2D(kernel_size=3, padding='same', activation=None),
15 |             tf.keras.layers.Conv2D(filters=24, kernel_size=1, activation=None)
16 |         ])
17 |         self.conv2_2 = tf.keras.models.Sequential([
18 |             tf.keras.layers.DepthwiseConv2D(kernel_size=3, padding='same', activation=None),
19 |             tf.keras.layers.Conv2D(filters=24, kernel_size=1, activation=None)
20 |         ])
21 |
22 |         # ---------- Heatmap branch ----------
23 |         self.conv3 = BlazeBlock(block_num = 3, channel = 48)    # input res: 128
24 |         self.conv4 = BlazeBlock(block_num = 4, channel = 96)    # input res: 64
25 |         self.conv5 = BlazeBlock(block_num = 5, channel = 192)   # input res: 32
26 |         self.conv6 = BlazeBlock(block_num = 6, channel = 288)   # input res: 16
27 |
28 |         self.conv7a = tf.keras.models.Sequential([
29 |             tf.keras.layers.DepthwiseConv2D(kernel_size=3, padding="same", activation=None),
30 |             tf.keras.layers.Conv2D(filters=48, kernel_size=1, activation="relu"),
31 |             tf.keras.layers.UpSampling2D(size=(2, 2), interpolation="bilinear")
32 |         ])
33 |         self.conv7b = tf.keras.models.Sequential([
34 |             tf.keras.layers.DepthwiseConv2D(kernel_size=3, padding="same", activation=None),
35 |             tf.keras.layers.Conv2D(filters=48, kernel_size=1, activation="relu")
36 |         ])
37 |
38 |         self.conv8a = tf.keras.layers.UpSampling2D(size=(2, 2), interpolation="bilinear")
39 |         self.conv8b = tf.keras.models.Sequential([
40 |             tf.keras.layers.DepthwiseConv2D(kernel_size=3, padding="same", activation=None),
41 |             tf.keras.layers.Conv2D(filters=48, kernel_size=1, activation="relu")
42 |         ])
43 |
44 |         self.conv9a = tf.keras.layers.UpSampling2D(size=(2, 2), interpolation="bilinear")
45 |         self.conv9b = tf.keras.models.Sequential([
46 |             tf.keras.layers.DepthwiseConv2D(kernel_size=3, padding="same", activation=None),
47 |             tf.keras.layers.Conv2D(filters=48, kernel_size=1, activation="relu")
48 |         ])
49 |
50 |         self.conv10a = tf.keras.models.Sequential([
51 |             tf.keras.layers.DepthwiseConv2D(kernel_size=3, padding="same", activation=None),
52 |             tf.keras.layers.Conv2D(filters=8, kernel_size=1, activation="relu"),
53 |             tf.keras.layers.UpSampling2D(size=(2, 2), interpolation="bilinear")
54 |         ])
55 |         self.conv10b = tf.keras.models.Sequential([
56 |             tf.keras.layers.DepthwiseConv2D(kernel_size=3, padding="same", activation=None),
57 |             tf.keras.layers.Conv2D(filters=8, kernel_size=1, activation="relu")
58 |         ])
59 |
60 |         # the output layer for the heatmap and offset
61 |         self.conv11 = tf.keras.models.Sequential([
62 |             tf.keras.layers.DepthwiseConv2D(kernel_size=3, padding="same", activation=None),
63 |             tf.keras.layers.Conv2D(filters=8, kernel_size=1, activation="relu"),
64 |             # heatmap
65 |             tf.keras.layers.Conv2D(filters=num_joints, kernel_size=3, padding="same", activation=None)
66 |         ])
67 |
68 |         # ---------- Regression branch ----------
69 |         # input shape = (1, 64, 64, 48)
70 |         self.conv12a = BlazeBlock(block_num = 4, channel = 96)    # input res: 64
71 |         self.conv12b = tf.keras.models.Sequential([
72 |             tf.keras.layers.DepthwiseConv2D(kernel_size=3, padding="same", activation=None),
73 |             tf.keras.layers.Conv2D(filters=96, kernel_size=1, activation="relu")
74 |         ])
75 |
76 |         self.conv13a = BlazeBlock(block_num = 5, channel = 192)   # input res: 32
77 |         self.conv13b = tf.keras.models.Sequential([
78 |             tf.keras.layers.DepthwiseConv2D(kernel_size=3, padding="same", activation=None),
79 |             tf.keras.layers.Conv2D(filters=192, kernel_size=1, activation="relu")
80 |         ])
81 |
82 |         self.conv14a = BlazeBlock(block_num = 6, channel = 288)   # input res: 16
83 |         self.conv14b = tf.keras.models.Sequential([
84 |             tf.keras.layers.DepthwiseConv2D(kernel_size=3, padding="same", activation=None),
85 |             tf.keras.layers.Conv2D(filters=288, kernel_size=1, activation="relu")
86 |         ])
87 |
88 |         self.conv15 = tf.keras.models.Sequential([
89 |             BlazeBlock(block_num = 7, channel = 288, channel_padding = 0),
90 |             BlazeBlock(block_num = 7, channel = 288, channel_padding = 0)
91 |         ])
92 |
93 |         self.conv16 = tf.keras.models.Sequential([
94 |             tf.keras.layers.GlobalAveragePooling2D(),
95 |             # shape = (1, 288) after global average pooling
96 |             tf.keras.layers.Dense(units=3*num_joints, activation=None),
97 |             tf.keras.layers.Reshape((num_joints, 3))
98 |         ])
99 |
100 |     def call(self, x):
101 |         # shape = (1, 256, 256, 3)
102 |         x = self.conv1(x)
103 |         # shape = (1, 128, 128, 24)
104 |         x = x + self.conv2_1(x)  # <-- skip connection
105 |         x = tf.keras.activations.relu(x)
106 |         # Note: as in residual networks, the ReLU is applied after the skip connection.
107 |         x = x + self.conv2_2(x)
108 |         y0 = tf.keras.activations.relu(x)
109 |
110 |         # shape = (1, 128, 128, 24)
111 |         y1 = self.conv3(y0)
112 |         y2 = self.conv4(y1)
113 |         y3 = self.conv5(y2)
114 |         y4 = self.conv6(y3)
115 |         # shape = (1, 8, 8, 288)
116 |
117 |         x = self.conv7a(y4) + self.conv7b(y3)
118 |         x = self.conv8a(x) + self.conv8b(y2)
119 |         # shape = (1, 32, 32, 48)
120 |         x = self.conv9a(x) + self.conv9b(y1)
121 |         # shape = (1, 64, 64, 48)
122 |         y = self.conv10a(x) + self.conv10b(y0)
123 |         # shape = (1, 128, 128, 8)
124 |         heatmap = tf.keras.activations.sigmoid(self.conv11(y))
125 |
126 |         # ---------- regression branch ----------
127 |         x = self.conv12a(x) + self.conv12b(y2)
128 |         # shape = (1, 32, 32, 96)
129 |         x = self.conv13a(x) + self.conv13b(y3)
130 |         # shape = (1, 16, 16, 192)
131 |         x = self.conv14a(x) + self.conv14b(y4)
132 |         # shape = (1, 8, 8, 288)
133 |         x = self.conv15(x)
134 |         # shape = (1, 2, 2, 288)
135 |         joints = self.conv16(x)
136 |         result = [heatmap, joints]
137 |         return result[train_mode]  # 0: heatmap, 1: joints
138 |
--------------------------------------------------------------------------------
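Note that `BlazePose.call` computes both heads but returns only one of them, selected by the `train_mode` flag imported from `config.py` when the module is loaded. A sketch of the resulting shapes:

```python
import tensorflow as tf
from model import BlazePose

model = BlazePose()
y = model(tf.zeros((1, 256, 256, 3)))
# train_mode == 0 -> heatmaps of shape (1, 128, 128, 14)
# train_mode == 1 -> joints   of shape (1, 14, 3)
print(y.shape)
```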
kernel_size=3, padding="same", activation=None) 66 | ]) 67 | 68 | # ---------- Regression branch ---------- 69 | # shape = (1, 64, 64, 48) 70 | self.conv12a = BlazeBlock(block_num = 4, channel = 96) # input res: 64 71 | self.conv12b = tf.keras.models.Sequential([ 72 | tf.keras.layers.DepthwiseConv2D(kernel_size=3, padding="same", activation=None), 73 | tf.keras.layers.Conv2D(filters=96, kernel_size=1, activation="relu") 74 | ]) 75 | 76 | self.conv13a = BlazeBlock(block_num = 5, channel = 192) # input res: 32 77 | self.conv13b = tf.keras.models.Sequential([ 78 | tf.keras.layers.DepthwiseConv2D(kernel_size=3, padding="same", activation=None), 79 | tf.keras.layers.Conv2D(filters=192, kernel_size=1, activation="relu") 80 | ]) 81 | 82 | self.conv14a = BlazeBlock(block_num = 6, channel = 288) # input res: 16 83 | self.conv14b = tf.keras.models.Sequential([ 84 | tf.keras.layers.DepthwiseConv2D(kernel_size=3, padding="same", activation=None), 85 | tf.keras.layers.Conv2D(filters=288, kernel_size=1, activation="relu") 86 | ]) 87 | 88 | self.conv15 = tf.keras.models.Sequential([ 89 | BlazeBlock(block_num = 7, channel = 288, channel_padding = 0), 90 | BlazeBlock(block_num = 7, channel = 288, channel_padding = 0) 91 | ]) 92 | 93 | self.conv16 = tf.keras.models.Sequential([ 94 | tf.keras.layers.GlobalAveragePooling2D(), 95 | # shape = (1, 1, 1, 288) 96 | tf.keras.layers.Dense(units=3*num_joints, activation=None), 97 | tf.keras.layers.Reshape((num_joints, 3)) 98 | ]) 99 | 100 | def call(self, x): 101 | # shape = (1, 256, 256, 3) 102 | x = self.conv1(x) 103 | # shape = (1, 128, 128, 24) 104 | x = x + self.conv2_1(x) # <-- skip connection 105 | x = tf.keras.activations.relu(x) 106 | # --> I don't know why the relu layer is put after skip connection? 107 | x = x + self.conv2_2(x) 108 | y0 = tf.keras.activations.relu(x) 109 | 110 | # shape = (1, 128, 128, 24) 111 | y1 = self.conv3(y0) 112 | y2 = self.conv4(y1) 113 | y3 = self.conv5(y2) 114 | y4 = self.conv6(y3) 115 | # shape = (1, 8, 8, 288) 116 | 117 | x = self.conv7a(y4) + self.conv7b(y3) 118 | x = self.conv8a(x) + self.conv8b(y2) 119 | # shape = (1, 32, 32, 96) 120 | x = self.conv9a(x) + self.conv9b(y1) 121 | # shape = (1, 64, 64, 48) 122 | y = self.conv10a(x) + self.conv10b(y0) 123 | # shape = (1, 128, 128, 8) 124 | heatmap = tf.keras.activations.sigmoid(self.conv11(y)) 125 | 126 | # ---------- regression branch ---------- 127 | x = self.conv12a(x) + self.conv12b(y2) 128 | # shape = (1, 32, 32, 96) 129 | x = self.conv13a(x) + self.conv13b(y3) 130 | # shape = (1, 16, 16, 192) 131 | x = self.conv14a(x) + self.conv14b(y4) 132 | # shape = (1, 8, 8, 288) 133 | x = self.conv15(x) 134 | # shape = (1, 2, 2, 288) 135 | joints = self.conv16(x) 136 | result = [heatmap, joints] 137 | return result[train_mode] # heatmap, joints 138 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tensorflow as tf 3 | import numpy as np 4 | from model import BlazePose 5 | from config import total_epoch, train_mode, eval_mode, epoch_to_test 6 | from data import test_dataset, label, data 7 | 8 | def Eclidian2(a, b): 9 | # Calculate the square of Eclidian distance 10 | assert len(a)==len(b) 11 | summer = 0 12 | for i in range(len(a)): 13 | summer += (a[i] - b[i]) ** 2 14 | return summer 15 | 16 | model = BlazePose() 17 | # optimizer = tf.keras.optimizers.Adam() 18 | model.compile(optimizer=tf.keras.optimizers.Adam(), 19 | 
21 |
22 | checkpoint_path = "checkpoints/cp-{epoch:04d}.ckpt"
23 | checkpoint_dir = os.path.dirname(checkpoint_path)
24 |
25 | # model.evaluate(test_dataset)
26 | model.load_weights(checkpoint_path.format(epoch=epoch_to_test))
27 |
28 | """
29 | 0-Right ankle
30 | 1-Right knee
31 | 2-Right hip
32 | 3-Left hip
33 | 4-Left knee
34 | 5-Left ankle
35 | 6-Right wrist
36 | 7-Right elbow
37 | 8-Right shoulder
38 | 9-Left shoulder
39 | 10-Left elbow
40 | 11-Left wrist
41 | 12-Neck
42 | 13-Head top
43 | """
44 |
45 | if train_mode:
46 |     y = np.zeros((2000, 14, 3)).astype(np.uint8)  # joint coordinates are stored as integer pixels
47 |     if 1:  # for a low-profile GPU
48 |         batch_size = 20
49 |         for i in range(0, 2000, batch_size):
50 |             if i + batch_size >= 2000:
51 |                 # last batch
52 |                 y[i : 2000] = model(data[i : i + batch_size]).numpy()
53 |             else:
54 |                 # other batches
55 |                 y[i : i + batch_size] = model(data[i : i + batch_size]).numpy()
56 |             print("=", end="")
57 |         print(">")
58 |     else:  # for an RTX 3090
59 |         print("Start inference.")
60 |         y[0:1000] = model(data[0:1000]).numpy()
61 |         print("Half.")
62 |         y[1000:2000] = model(data[1000:2000]).numpy()
63 |         print("Complete.")
64 |
65 |     if eval_mode:
66 |         # calculate the PCKh score
67 |         # print(label.shape)  # (2000, 14, 3)
68 |         y = y[:, :, 0:2].astype(float)
69 |         label = label[:, :, 0:2].astype(float)
70 |         score_j = np.zeros(14)
71 |         pck_metric = 0.5
72 |         for i in range(1000, 2000):
73 |             # validation split
74 |             pck_h = Eclidian2(label[i][12], label[i][13])
75 |             for j in range(14):
76 |                 pck_j = Eclidian2(y[i][j], label[i][j])
77 |                 # Eclidian2 returns squared distances, so the 0.5 threshold on the head segment must be squared as well
78 |                 if pck_j <= pck_h * pck_metric ** 2:
79 |                     # correct estimation
80 |                     score_j[j] += 1
81 |         # convert the counts over 1000 validation images to percentages
82 |         score_j = score_j * 0.1
83 |         score_avg = sum(score_j) / 14
84 |         print(score_j)
85 |         print("Average = %f%%" % score_avg)
86 |     else:
87 |         # show result images
88 |         import cv2
89 |         # generate result images
90 |         for t in range(2000):
91 |             skeleton = y[t]
92 |             print(skeleton)
93 |             img = data[t].astype(np.uint8)
94 |             # draw the joints
95 |             for i in range(14):
96 |                 cv2.circle(img, center=tuple(skeleton[i][0:2]), radius=2, color=(0, 255, 0), thickness=2)
97 |             # draw the limbs
98 |             for j in ((13, 12), (12, 8), (12, 9), (8, 7), (7, 6), (9, 10), (10, 11), (2, 3), (2, 1), (1, 0), (3, 4), (4, 5)):
99 |                 cv2.line(img, tuple(skeleton[j[0]][0:2]), tuple(skeleton[j[1]][0:2]), color=(0, 0, 255), thickness=1)
100 |             # connect the neck to the midpoint of the hips
101 |             cv2.line(img, tuple(skeleton[12][0:2]), tuple(skeleton[2][0:2] // 2 + skeleton[3][0:2] // 2), color=(0, 0, 255), thickness=1)
102 |
103 |             cv2.imwrite("./result/lsp_%d.jpg" % t, img)
104 |             cv2.imshow("test", img)
105 |             cv2.waitKey(1)
106 | else:
107 |     # visualize the heatmaps on the test set
108 |     model.evaluate(test_dataset)
109 |
110 |     # select an image to visualize
111 |     from config import vis_img_id
112 |     y = model.predict(data[vis_img_id : vis_img_id+1])
113 |
114 |     import matplotlib.pyplot as plt
115 |     title_set = ["Right ankle", "Right knee", "Right hip", "Left hip", "Left knee", "Left ankle", "Right wrist", "Right elbow", "Right shoulder", "Left shoulder", "Left elbow", "Left wrist", "Neck", "Head top"]
116 |     for t in range(1):
117 |         plt.figure(figsize=(8, 8), dpi=150)
118 |         for i in range(14):
119 |             plt.subplot(4, 4, i+1)
120 |             plt.imshow(y[t, :, :, i])
121 |             plt.title(title_set[i])
122 |         plt.subplot(4, 4, 15)
123 |         plt.imshow(data[vis_img_id].astype(np.uint8))
124 |         # plt.savefig("demo.png")
125 |         plt.show()
126 |     pass
127 |
--------------------------------------------------------------------------------
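For reference, PCKh@0.5 counts a joint as correct when its distance to the ground truth is at most half the head segment length (here, the distance between joints 12 and 13, neck and head top). A compact vectorized sketch equivalent to the evaluation loop above:

```python
import numpy as np

def pckh(pred, gt, alpha=0.5):
    """pred, gt: (N, 14, 2) joint coordinates; returns per-joint scores in percent."""
    head = np.linalg.norm(gt[:, 12] - gt[:, 13], axis=-1)  # neck to head top
    dist = np.linalg.norm(pred - gt, axis=-1)              # (N, 14)
    correct = dist <= alpha * head[:, None]
    return 100.0 * correct.mean(axis=0)
```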
plt.savefig("demo.png") 125 | plt.show() 126 | pass 127 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #!~/miniconda3/envs/tf2/bin/python 2 | import os 3 | import tensorflow as tf 4 | import time 5 | from model import BlazePose 6 | from config import total_epoch, train_mode, continue_train, show_batch_loss 7 | from analysis import save_record, load_record 8 | 9 | if train_mode: 10 | from data import finetune_train as train_dataset 11 | from data import finetune_validation as test_dataset 12 | loss_func = tf.keras.losses.MeanSquaredError() 13 | else: 14 | from data import train_dataset, test_dataset 15 | loss_func = tf.keras.losses.BinaryCrossentropy() 16 | 17 | model = BlazePose() 18 | 19 | checkpoint_path = "checkpoints/cp-{epoch:04d}.ckpt" 20 | optimizer=tf.keras.optimizers.Adam(learning_rate=0.001) 21 | 22 | def grad(model, inputs, targets): 23 | with tf.GradientTape() as tape: 24 | loss_value = loss_func(y_true=targets, y_pred=model(inputs)) 25 | return loss_value, tape.gradient(loss_value, model.trainable_variables) 26 | 27 | # continue train 28 | if continue_train > 0: 29 | model.load_weights(checkpoint_path.format(epoch=continue_train)) 30 | # continue recording 31 | train_loss_results, train_accuracy_results, val_accuracy_results = load_record() 32 | else: 33 | if train_mode: 34 | # start fine-tune 35 | from config import best_pre_train 36 | model.load_weights(checkpoint_path.format(epoch=best_pre_train)) 37 | 38 | # start from epoch 0 39 | # Initial for record of the training process 40 | train_loss_results = [] 41 | train_accuracy_results = [] 42 | val_accuracy_results = [] 43 | 44 | if train_mode: 45 | # finetune 46 | for layer in model.layers[0:16]: 47 | print(layer) 48 | layer.trainable = False 49 | else: 50 | # pre-train 51 | for layer in model.layers[16:24]: 52 | print(layer) 53 | layer.trainable = False 54 | 55 | print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), end=" Start train.\n") 56 | # validata initial loaded model 57 | val_accuracy = tf.keras.metrics.MeanSquaredError() 58 | for x, y in test_dataset: 59 | val_accuracy(y, model(x)) 60 | print("Initial Validation accuracy: {:.5%}".format(val_accuracy.result())) 61 | 62 | # make sure continue has any epoch to train 63 | assert(continue_train < total_epoch) 64 | 65 | for epoch in range(continue_train, total_epoch): 66 | epoch_loss_avg = tf.keras.metrics.Mean() 67 | epoch_accuracy = tf.keras.metrics.MeanSquaredError() 68 | val_accuracy = tf.keras.metrics.MeanSquaredError() 69 | 70 | # Training loop 71 | if show_batch_loss: 72 | batch_index = 0 73 | for x, y in train_dataset: 74 | # Optimize 75 | loss_value, grads = grad(model, x, y) 76 | optimizer.apply_gradients(zip(grads, model.trainable_variables)) 77 | 78 | # Add current batch loss 79 | epoch_loss_avg(loss_value) 80 | # Calculate error from Ground truth 81 | epoch_accuracy(y, model(x)) 82 | 83 | if show_batch_loss: 84 | print("Epoch {:03d}, Batch {:03d}: Train Loss: {:.3f}".format(epoch, 85 | batch_index, 86 | loss_value 87 | )) 88 | batch_index += 1 89 | 90 | # Record loss and accuracy 91 | train_loss_results.append(epoch_loss_avg.result()) 92 | train_accuracy_results.append(epoch_accuracy.result()) 93 | 94 | # Train loss at epoch 95 | print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())) 96 | print("Epoch {:03d}: Train Loss: {:.3f}, Accuracy: {:.5%}".format( 97 | epoch, 98 | epoch_loss_avg.result(), 99 | 
101 |
102 |     if not ((epoch + 1) % 5):
103 |         # validate and save the weights every 5 epochs
104 |         for x, y in test_dataset:
105 |             val_accuracy(y, model(x))
106 |         print("Epoch {:03d}, Validation accuracy: {:.5%}".format(epoch, val_accuracy.result()))
107 |         model.save_weights(checkpoint_path.format(epoch=epoch))
108 |         val_accuracy_results.append(val_accuracy.result())
109 |
110 |         # save the training record at every validation epoch
111 |         save_record(train_loss_results, train_accuracy_results, val_accuracy_results)
112 |
113 | model.summary()
114 |
115 | print("Finish training.")
116 |
--------------------------------------------------------------------------------
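Finally, since the `.gitignore` already anticipates `*.tflite` artifacts, a conversion sketch may be useful for on-device deployment. This workflow is an assumption rather than part of the repository, and the checkpoint epoch is illustrative:

```python
import tensorflow as tf
from model import BlazePose

model = BlazePose()
model.load_weights("checkpoints/cp-0199.ckpt")  # illustrative checkpoint epoch

# trace the model with a fixed input signature so the converter can freeze it
concrete_fn = tf.function(model).get_concrete_function(
    tf.TensorSpec([1, 256, 256, 3], tf.float32))
converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_fn])
tflite_model = converter.convert()

with open("blazepose.tflite", "wb") as f:
    f.write(tflite_model)
```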