├── .idea ├── .gitignore ├── Minimal-Hand-pytorch.iml ├── inspectionProfiles │ ├── Project_Default.xml │ └── profiles_settings.xml ├── misc.xml ├── modules.xml ├── other.xml └── vcs.xml ├── LICENSE ├── README.md ├── ShapeNet.md ├── __pycache__ └── config.cpython-37.pyc ├── aik_pose.py ├── assets ├── DEMO2.gif ├── demo.gif └── results.png ├── config.py ├── create_data.py ├── datasets ├── SIK1M.py ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-35.pyc │ ├── __init__.cpython-37.pyc │ ├── dexter_object.cpython-37.pyc │ ├── egodexter.cpython-37.pyc │ ├── ganerated_hands.cpython-37.pyc │ ├── hand143_panopticdb.cpython-37.pyc │ ├── hand_labels.cpython-37.pyc │ ├── handataset.cpython-37.pyc │ ├── rhd.cpython-37.pyc │ └── stb.cpython-37.pyc ├── dexter_object.py ├── egodexter.py ├── ganerated_hands.py ├── hand143_panopticdb.py ├── hand_labels.py ├── handataset.py ├── rhd.py └── stb.py ├── demo.py ├── demo_dl.py ├── dl_shape_estimate.py ├── environment.yml ├── losses ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-37.pyc │ └── detloss.cpython-37.pyc ├── detloss.py └── shape_loss.py ├── manopth └── rotproj.py ├── model ├── detnet │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ └── detnet.cpython-37.pyc │ └── detnet.py ├── helper │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ └── resnet_helper.cpython-37.pyc │ └── resnet_helper.py └── shape_net.py ├── op_pso.py ├── optimize_shape.py ├── plot.py ├── train_detnet.py ├── train_shape_net.py └── utils ├── AIK.py ├── LM.py ├── LM_new.py ├── __init__.py ├── __pycache__ ├── AIK.cpython-37.pyc ├── LM.cpython-37.pyc ├── __init__.cpython-37.pyc ├── align.cpython-37.pyc ├── bone.cpython-37.pyc ├── func.cpython-37.pyc ├── handutils.cpython-37.pyc ├── heatmaputils.cpython-37.pyc ├── imgutils.cpython-37.pyc ├── misc.cpython-37.pyc └── vis.cpython-37.pyc ├── align.py ├── bone.py ├── eval ├── __pycache__ │ ├── evalutils.cpython-37.pyc │ └── zimeval.cpython-37.pyc ├── evalutils.py 
└── zimeval.py ├── func.py ├── handutils.py ├── heatmaputils.py ├── imgutils.py ├── misc.py ├── smoother.py └── vis.py /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | # Datasource local storage ignored files 5 | /dataSources/ 6 | /dataSources.local.xml 7 | # Editor-based HTTP Client requests 8 | /httpRequests/ 9 | -------------------------------------------------------------------------------- /.idea/Minimal-Hand-pytorch.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 12 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/Project_Default.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 29 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/other.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 6 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Hao Meng 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Minimal Hand Pytorch 2 | 3 | **Unofficial** PyTorch reimplementation of [minimal-hand](https://calciferzh.github.io/files/zhou2020monocular.pdf) (CVPR2020). 4 | 5 | ![demo](assets/demo.gif) 6 | ![demo](assets/DEMO2.gif) 7 | 8 | you can also find in youtube or bilibili 9 | 10 | 14 | 15 | 16 | 17 | 18 | This project reimplement following components : 19 | 1. Training (DetNet) and Evaluation Code 20 | 1. Shape Estimation 21 | 1. 
Pose Estimation: Instead of the IKNet used in the original paper, an [analytical inverse kinematics method](https://arxiv.org/abs/2011.14672) is used. 22 | 23 | 24 | 25 | Official project link: 26 | [\[minimal-hand\]](https://github.com/CalciferZh/minimal-hand) 27 | 28 | ## Update 29 | ### 30 | * 2021/08/22 Many users may get errors when creating the environment from the .yaml file; you may refer to [here](https://github.com/MengHao666/Minimal-Hand-pytorch/issues/29#issue-976328196) 31 | * 2021/03/09 Updated `utils/LM.py`; **time cost dropped from 12s/item to 1.57s/item** 32 | 33 | * 2021/03/12 Updated `utils/LM.py`; **time cost dropped from 1.57s/item to 0.27s/item** 34 | 35 | * 2021/03/17 Real-time performance is achieved when using PSO to estimate shape; coming soon 36 | 37 | * 2021/03/20 Added PSO to estimate shape. ~~AUC is decreased by about 0.01 on STB and RHD datasets, and increased a little on EO and DO datasets.~~ Modified utils/vis.py to improve real-time performance 38 | 39 | * 2021/03/24 Fixed some errors in calculating AUC. Updated the 3D PCK AUC Difference table. 40 | 41 | * 2021/06/14 A new method to estimate shape parameters using a fully connected neural network has been added. This was done by @maitetsu as part of his undergraduate graduation project. Please refer to [ShapeNet.md](./ShapeNet.md) for details. Thanks to @kishan1823 and @EEWenbinWu for pointing out the mistake. There are a few differences between the manopth I used and the official manopth. For more details, see [issue 11](https://github.com/MengHao666/Minimal-Hand-pytorch/issues/11). manopth/rotproj.py is the modified rotproj.py. 
**This could achieve much faster real-time performance!** 42 | 43 | 44 | ## Usage 45 | 46 | - Retrieve the code 47 | ```sh 48 | git clone https://github.com/MengHao666/Minimal-Hand-pytorch 49 | cd Minimal-Hand-pytorch 50 | ``` 51 | 52 | - Create and activate the virtual environment with python dependencies 53 | ``` 54 | conda env create --file=environment.yml 55 | conda activate minimal-hand-torch 56 | ``` 57 | 58 | ### Prepare MANO hand model 59 | 1. Download MANO model from [here](https://mano.is.tue.mpg.de/) and unzip it. 60 | 1. Create an account by clicking *Sign Up* and provide your information 61 | 1. Download Models and Code (the downloaded file should have the format mano_v*_*.zip). Note that all code and data from this download falls under the [MANO license](http://mano.is.tue.mpg.de/license). 62 | 1. unzip and copy the content of the *models* folder into the `mano` folder 63 | 64 | 1. Your structure should look like this: 65 | 66 | ``` 67 | Minimal-Hand-pytorch/ 68 | mano/ 69 | models/ 70 | webuser/ 71 | ``` 72 | 73 | ## Download and Prepare datasets 74 | 75 | ### Training dataset 76 | * CMU HandDB [part1](http://domedb.perception.cs.cmu.edu/panopticDB/hands/hand143_panopticdb.tar) ; [part2](http://domedb.perception.cs.cmu.edu/panopticDB/hands/hand_labels.zip) 77 | * [Rendered Handpose Dataset](https://lmb.informatik.uni-freiburg.de/resources/datasets/RenderedHandposeDataset.en.html) 78 | * [GANerated Hands Dataset](https://handtracker.mpi-inf.mpg.de/projects/GANeratedHands/GANeratedDataset.htm) 79 | 80 | ### Evaluation dataset 81 | * [STB Dataset](https://github.com/zhjwustc/icip17_stereo_hand_pose_dataset),or u can find it[here](https://bhpan.buaa.edu.cn:443/link/55321872BA66E9205C91BA30D9FADC8F): 82 | 83 | STB_supp: for license reason, download link could be found in [bihand](https://github.com/lixiny/bihand ) 84 | 85 | * [DexterObjectDataset](https://handtracker.mpi-inf.mpg.de/projects/RealtimeHO/dexter+object.htm) ; 86 | 87 | DO_supp: [Google 
Drive](https://drive.google.com/file/d/1uhqJGfVJs_-Yviqj9Y2Ngo7NPt5hUihl/view?usp=sharing) or 88 | [Baidu Pan](https://pan.baidu.com/s/1ckfLnaBQUfZJG3IDvMo09Q) (`s892`) 89 | * [EgoDexterDataset](http://handtracker.mpi-inf.mpg.de/projects/OccludedHands/EgoDexter.htm) 90 | 91 | EO_supp: [Google Drive](https://drive.google.com/file/d/1WRHVTp7ZmryE41xN2Yhp-qet0ddeOim4/view?usp=sharing) or 92 | [Baidu Pan](https://pan.baidu.com/s/1sK4Nfvs6og-eXJGwDQCHlQ) (`axkm`) 93 | 94 | 95 | 96 | ### Processing 97 | - Create a data directory, extract all above datasets or additional materials in it 98 | 99 | Now your `data` folder structure should like this: 100 | ``` 101 | data/ 102 | 103 | CMU/ 104 | hand143_panopticdb/ 105 | datasets/ 106 | ... 107 | hand_labels/ 108 | datasets/ 109 | ... 110 | 111 | RHD/ 112 | RHD_published_v2/ 113 | evaluation/ 114 | training/ 115 | view_sample.py 116 | ... 117 | 118 | GANeratedHands_Release/ 119 | data/ 120 | ... 121 | 122 | STB/ 123 | images/ 124 | B1Counting/ 125 | SK_color_0.png 126 | SK_depth_0.png 127 | SK_depth_seg_0.png <-- merged from STB_supp 128 | ... 129 | ... 130 | labels/ 131 | B1Counting_BB.mat 132 | ... 133 | 134 | dexter+object/ 135 | calibration/ 136 | bbox_dexter+object.csv 137 | DO_pred_2d.npy 138 | data/ 139 | Grasp1/ 140 | annotations/ 141 | Grasp13D.txt 142 | my_Grasp13D.txt 143 | ... 144 | ... 145 | Grasp2/ 146 | annotations/ 147 | Grasp23D.txt 148 | my_Grasp23D.txt 149 | ... 150 | ... 151 | Occlusion/ 152 | annotations/ 153 | Occlusion3D.txt 154 | my_Occlusion3D.txt 155 | ... 156 | ... 157 | Pinch/ 158 | annotations/ 159 | Pinch3D.txt 160 | my_Pinch3D.txt 161 | ... 162 | ... 163 | Rigid/ 164 | annotations/ 165 | Rigid3D.txt 166 | my_Rigid3D.txt 167 | ... 168 | ... 169 | Rotate/ 170 | annotations/ 171 | Rotate3D.txt 172 | my_Rotate3D.txt 173 | ... 174 | ... 175 | 176 | 177 | EgoDexter/ 178 | preview/ 179 | data/ 180 | Desk/ 181 | annotation.txt_3D.txt 182 | my_annotation.txt_3D.txt 183 | ... 
184 | Fruits/ 185 | annotation.txt_3D.txt 186 | my_annotation.txt_3D.txt 187 | ... 188 | Kitchen/ 189 | annotation.txt_3D.txt 190 | my_annotation.txt_3D.txt 191 | ... 192 | Rotunda/ 193 | annotation.txt_3D.txt 194 | my_annotation.txt_3D.txt 195 | ... 196 | 197 | ``` 198 | 199 | ### Note 200 | - **All code and data from these download falls under their own licenses.** 201 | - DO represents "dexter+object" dataset; EO represents "EgoDexter" dataset 202 | - `DO_supp` and `EO_supp` are modified from original ones. 203 | - DO_pred_2d.npy are 2D predictions from 2D part of DetNet. 204 | - some labels of DO and EO is obviously wrong (u could find some examples with original labels from [dexter_object.py](datasets/dexter_object.py) or [egodexter.py](datasets/egodexter.py)), when projected into image plane, thus should be omitted. 205 | Here come `my_{}3D.txt` and `my_annotation.txt_3D.txt`. 206 | 207 | ## Download my Results 208 | 209 | - my_results: [Google Drive](https://drive.google.com/file/d/1e6aG4ZSOB6Ri_1TjXI9N-1r7MtwmjA6w/view?usp=sharing) or 210 | [Baidu Pan](https://pan.baidu.com/s/1Hh0ZU8p04prFVSp9bQm_IA) (`2rv7`) 211 | - extract it in project folder 212 | - **The parameters used in the real-time demo can be found [google_drive](https://drive.google.com/file/d/1fug29PBMo1Cb2DwAtX7f2E_yLHjDBmiM/view?usp=sharing) or [baidu](https://pan.baidu.com/s/1gr3xSkLuvsveSQ7nW1taSA) (un06). It is trained with loss of [Hand-BMC-pytorch](https://github.com/MengHao666/Hand-BMC-pytorch) together!!!** 213 | 214 | 215 |

realtime demo with PSO-based shape estimation

216 | 217 | ``` 218 | python demo.py 219 | ``` 220 | 221 |

realtime demo with learning-based shape estimation

222 | 223 | ``` 224 | python demo_dl.py 225 | ``` 226 | 227 | ## DetNet Training and Evaluation 228 | 229 | Run the training code 230 | ``` 231 | python train_detnet.py --data_root data/ 232 | ``` 233 | 234 | 235 | Run the evaluation code 236 | ``` 237 | python train_detnet.py --data_root data/ --datasets_test testset_name_to_test --evaluate --evaluate_id checkpoints_id_to_load 238 | ``` 239 | or use my results 240 | ``` 241 | python train_detnet.py --checkpoint my_results/checkpoints --datasets_test "rhd" --evaluate --evaluate_id 106 242 | 243 | python train_detnet.py --checkpoint my_results/checkpoints --datasets_test "stb" --evaluate --evaluate_id 71 244 | 245 | python train_detnet.py --checkpoint my_results/checkpoints --datasets_test "do" --evaluate --evaluate_id 68 246 | 247 | python train_detnet.py --checkpoint my_results/checkpoints --datasets_test "eo" --evaluate --evaluate_id 101 248 | ``` 249 | 250 | ## Shape Estimation with LM algorithm 251 | 252 | Run the shape optimization code. This can be very time consuming when the weight parameter is quite small. 253 | ``` 254 | python optimize_shape.py --weight 1e-5 255 | ``` 256 | or use my results 257 | ``` 258 | python optimize_shape.py --path my_results/out_testset/ 259 | ``` 260 | 261 | ## Pose Estimation 262 | 263 | Run the following code which uses a analytical inverse kinematics method. 
264 | ``` 265 | python aik_pose.py 266 | ``` 267 | or use my results 268 | ``` 269 | python aik_pose.py --path my_results/out_testset/ 270 | ``` 271 | 272 | 273 | ### DetNet training and evaluation curves 274 | Run the following code to see my results 275 | ``` 276 | python plot.py --path my_results/out_loss_auc 277 | ``` 278 | 279 | (AUC means 3D PCK, and ACC_HM means 2D PCK) 280 | ![teaser](assets/results.png) 281 | 282 | ### 3D PCK AUC Difference 283 | 284 | \* denotes this project 285 | 286 | | Dataset | DetNet(paper) | DetNet(*) | DetNet+IKNet(paper) | DetNet+LM+AIK(*) | DetNet+PSO+AIK(*) | DetNet+DL+AIK(*) | 287 | | :-----: | :-----------: | :-------: | :-----------------: | :--------------: | :-------------: | :-------------: | 288 | | **RHD** | - | 0.9339 | 0.856 | 0.9301 | 0.9310 | 0.9272 | 289 | | **STB** | 0.891 | 0.8744 | 0.898 | 0.8647 | 0.8671 | 0.8624 | 290 | | **DO** | 0.923 | 0.9378 | 0.948 | 0.9392 | 0.9342 | 0.9400 | 291 | | **EO** | 0.804 | 0.9270 | 0.811 | 0.9288 | 0.9277 | 0.9365 | 292 | 293 | 294 | 295 | ### Note 296 | 297 | - Carefully tuning the training parameters, training for longer, using a more complex network, or adding **[Biomechanical Constraint Losses](https://github.com/MengHao666/Hand-BMC-pytorch)** could further boost accuracy. 298 | - Since the original paper has no official open-source implementation, the comparison above is only approximate. 299 | 300 | ## Citation 301 | 302 | This is the **unofficial** PyTorch reimplementation of the paper "Monocular Real-time Hand Shape and Motion Capture using Multi-modal Data" (CVPR 2020). 
303 | 304 | 305 | 306 | If you find the project helpful, please star this project and cite them: 307 | ``` 308 | @inproceedings{zhou2020monocular, 309 | title={Monocular Real-time Hand Shape and Motion Capture using Multi-modal Data}, 310 | author={Zhou, Yuxiao and Habermann, Marc and Xu, Weipeng and Habibie, Ikhsanul and Theobalt, Christian and Xu, Feng}, 311 | booktitle={Proceedings of the IEEE International Conference on Computer Vision}, 312 | pages={0--0}, 313 | year={2020} 314 | } 315 | ``` 316 | 317 | ## Acknowledgement 318 | 319 | - Code of Mano Pytorch Layer was adapted from [manopth](https://github.com/hassony2/manopth). 320 | 321 | - Code for evaluating the hand PCK and AUC in `utils/eval/zimeval.py` was adapted from [hand3d](https://github.com/lmb-freiburg/hand3d). 322 | 323 | - Part code of data augmentation, dataset parsing and utils were adapted from [bihand](https://github.com/lixiny/bihand) and [3D-Hand-Pose-Estimation](https://github.com/OlgaChernytska/3D-Hand-Pose-Estimation). 324 | 325 | - Code of network model was adapted from [Minimal-Hand](https://github.com/lingtengqiu/Minimal-Hand). 326 | 327 | - @Mrsirovo for the starter code of the `utils/LM.py` , @maitetsu update it later. 328 | 329 | - @maitetsu for the starter code of the `utils/AIK.py`,the implementation of PSO and deep-learing method for shape estimation. 330 | -------------------------------------------------------------------------------- /ShapeNet.md: -------------------------------------------------------------------------------- 1 | ## ShapeNet 2 | 3 | #### train ShapeNet 4 | 5 | ShapeNet is a model that uses fully connected neural network to estimate shape parameters. It's faster than pso. Code of ShapeNet are adapted from [bihand](https://github.com/lixiny/bihand). If you want train your ShapeNet, just run these. 
6 | 7 | ```python 8 | # create training set 9 | python create_data.py 10 | # train the model 11 | python train_shape_net.py 12 | ``` 13 | As for the pre-trained model or the training set generated by me, you can download from [google drive](https://drive.google.com/drive/folders/1OEM8Q8SIJjXkembz3wlp6UiXFfpzAsLu?usp=sharing) or [bhpan](https://bhpan.buaa.edu.cn:443/link/88CD866C9EB3F30906D57571678FFE6D). Put the data_bone.npy and data_shape.npy in ROOT_DIR/data. Put the ckp_siknet_synth_41.pth.tar in ROOT_DIR/checkpoints. 14 | 15 | ``` 16 | ROOT_DIR 17 | ... 18 | |--data 19 | |--data_bone.npy 20 | |--data_shape.npy 21 | ... 22 | |--checkpoints 23 | |--ckp_siknet_synth_41.pth.tar 24 | ... 25 | ``` 26 | 27 | 28 | 29 | #### training set 30 | 31 | The training set is generated by MANO. More details see create_data.py. 32 | 33 | 1. First sample shape parameters from normal distribution N(0,3) 34 | 35 | 2. Calculate the relative bone length corresponding to the shape parameter. 36 | 37 | #### loss 38 | 39 | The loss of ShapeNet consists of two parts, one is the error of relative bone length, the other is the regularization loss of shape parameters. 40 | 41 | $$ 42 | L = \lambda_1 ||\hat{x} - x||_2^2 + \lambda_2 || \hat{y}||_2^2 \\ 43 | $$ 44 | 45 | ​ $x$ is the input relative bone length. 46 | ​ $\hat{y}$ is the output shape parameters of ShapeNet. 47 | ​ $\hat{x}$ is the relative bone length corresponding to the shape parameter. 48 | 49 | ​ In my opinion, shape parameters include not only relative bone length information, but also absolute bone length information. It is impossible to guarantee that a relative bone length only corresponds to one shape parameter, which is necessary for neural networks. Therefore, the loss function does not directly calculate the error between the shape parameter and the label of it. 50 | 51 | #### AUC 52 | 53 | AUC of ShapeNet can refer [README.md](./README.md). 
You can get higher AUC, if you change " beta = torch.tanh(beta) " which is the line 85 of model/shape_net.py to "beta = 3*torch.tanh(beta) ". This will make the output range of ShapeNet bigger and get higher AUC. According to the experiment, the finger will be thinner. 54 | 55 | I didn't adjust the parameters carefully. Maybe you can get better results if you adjust parameters. 56 | 57 | -------------------------------------------------------------------------------- /__pycache__/config.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/handopen/Minimal-Hand-pytorch/c078433136b20c1e4fb3f353254edd409ea65f1b/__pycache__/config.cpython-37.pyc -------------------------------------------------------------------------------- /aik_pose.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import numpy as np 4 | import torch 5 | from manopth import demo 6 | from manopth import manolayer 7 | from tqdm import tqdm 8 | 9 | from utils import AIK, align, vis 10 | from utils.eval.zimeval import EvalUtil 11 | 12 | 13 | def recon_eval(op_shapes, pre_j3ds, gt_j3ds, visual, key): 14 | pose0 = torch.eye(3).repeat(1, 16, 1, 1) 15 | mano = manolayer.ManoLayer(flat_hand_mean=True, 16 | side="right", 17 | mano_root='mano/models', 18 | use_pca=False, 19 | root_rot_mode='rotmat', 20 | joint_rot_mode='rotmat') 21 | 22 | j3d_recons = [] 23 | evaluator = EvalUtil() 24 | for i in tqdm(range(pre_j3ds.shape[0])): 25 | j3d_pre = pre_j3ds[i] 26 | 27 | op_shape = torch.tensor(op_shapes[i]).float().unsqueeze(0) 28 | _, j3d_p0_ops = mano(pose0, op_shape) 29 | template = j3d_p0_ops.cpu().numpy().squeeze() / 1000.0 # template, m 30 | 31 | ratio = np.linalg.norm(template[9] - template[0]) / np.linalg.norm(j3d_pre[9] - j3d_pre[0]) 32 | j3d_pre_process = j3d_pre * ratio # template, m 33 | j3d_pre_process = j3d_pre_process - j3d_pre_process[0] + template[0] 34 | 35 | 
pose_R = AIK.adaptive_IK(template, j3d_pre_process) 36 | pose_R = torch.from_numpy(pose_R).float() 37 | 38 | # reconstruction 39 | hand_verts, j3d_recon = mano(pose_R, op_shape.float()) 40 | 41 | # visualization 42 | if visual: 43 | demo.display_hand({ 44 | 'verts': hand_verts.cpu(), 45 | 'joints': j3d_recon.cpu() 46 | }, 47 | mano_faces=mano.th_faces) 48 | 49 | j3d_recon = j3d_recon.cpu().numpy().squeeze() / 1000. 50 | j3d_recons.append(j3d_recon) 51 | 52 | # visualization 53 | if visual: 54 | vis.multi_plot3d([j3d_recon, j3d_pre_process], title=["recon", "pre"]) 55 | j3d_recons = np.array(j3d_recons) 56 | gt_joint, j3d_recon_align_gt = align.global_align(gt_j3ds, j3d_recons, key=key) 57 | 58 | for targj, predj_a in zip(gt_joint, j3d_recon_align_gt): 59 | evaluator.feed(targj * 1000.0, predj_a * 1000.0) 60 | 61 | ( 62 | _1, _2, _3, 63 | auc_all, 64 | pck_curve_all, 65 | thresholds 66 | ) = evaluator.get_measures( 67 | 20, 50, 15 68 | ) 69 | print("Reconstruction AUC all of {}_test_set is : {}".format(key, auc_all)) 70 | 71 | 72 | def main(args): 73 | path = args.path 74 | for key_i in args.dataset: 75 | print("load {}'s joint 3D".format(key_i)) 76 | _path = "{}/{}_dl.npy".format(path, key_i) 77 | print('load {}'.format(_path)) 78 | op_shapes = np.load(_path) 79 | pre_j3ds = np.load("{}/{}_pre_joints.npy".format(path, key_i)) 80 | gt_j3ds = np.load("{}/{}_gt_joints.npy".format(path, key_i)) 81 | recon_eval(op_shapes, pre_j3ds, gt_j3ds, args.visualize, key_i) 82 | 83 | 84 | if __name__ == '__main__': 85 | parser = argparse.ArgumentParser( 86 | description=' get pose params. 
of mano model ') 87 | 88 | parser.add_argument( 89 | '-ds', 90 | "--dataset", 91 | nargs="+", 92 | default=['rhd', 'stb', 'do', 'eo'], 93 | type=list, 94 | help="sub datasets, should be listed in: [stb|rhd|do|eo]" 95 | ) 96 | 97 | parser.add_argument( 98 | '-p', 99 | "--path", 100 | default="out_testset", 101 | type=str, 102 | help="path" 103 | ) 104 | 105 | parser.add_argument( 106 | '-vis', 107 | '--visualize', 108 | action='store_true', 109 | help='visualize reconstruction result', 110 | default=False 111 | ) 112 | 113 | main(parser.parse_args()) 114 | -------------------------------------------------------------------------------- /assets/DEMO2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/handopen/Minimal-Hand-pytorch/c078433136b20c1e4fb3f353254edd409ea65f1b/assets/DEMO2.gif -------------------------------------------------------------------------------- /assets/demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/handopen/Minimal-Hand-pytorch/c078433136b20c1e4fb3f353254edd409ea65f1b/assets/demo.gif -------------------------------------------------------------------------------- /assets/results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/handopen/Minimal-Hand-pytorch/c078433136b20c1e4fb3f353254edd409ea65f1b/assets/results.png -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | DEPTH_RANGE = 3.0 2 | DEPTH_MIN = -1.5 3 | 4 | stb_joints = [ 5 | 'loc_bn_palm_L', 6 | 'loc_bn_pinky_L_01', 7 | 'loc_bn_pinky_L_02', 8 | 'loc_bn_pinky_L_03', 9 | 'loc_bn_pinky_L_04', 10 | 'loc_bn_ring_L_01', 11 | 'loc_bn_ring_L_02', 12 | 'loc_bn_ring_L_03', 13 | 'loc_bn_ring_L_04', 14 | 'loc_bn_mid_L_01', 15 | 'loc_bn_mid_L_02', 16 | 
'loc_bn_mid_L_03', 17 | 'loc_bn_mid_L_04', 18 | 'loc_bn_index_L_01', 19 | 'loc_bn_index_L_02', 20 | 'loc_bn_index_L_03', 21 | 'loc_bn_index_L_04', 22 | 'loc_bn_thumb_L_01', 23 | 'loc_bn_thumb_L_02', 24 | 'loc_bn_thumb_L_03', 25 | 'loc_bn_thumb_L_04', 26 | ] 27 | 28 | rhd_joints = [ 29 | 'loc_bn_palm_L', 30 | 'loc_bn_thumb_L_04', 31 | 'loc_bn_thumb_L_03', 32 | 'loc_bn_thumb_L_02', 33 | 'loc_bn_thumb_L_01', 34 | 'loc_bn_index_L_04', 35 | 'loc_bn_index_L_03', 36 | 'loc_bn_index_L_02', 37 | 'loc_bn_index_L_01', 38 | 'loc_bn_mid_L_04', 39 | 'loc_bn_mid_L_03', 40 | 'loc_bn_mid_L_02', 41 | 'loc_bn_mid_L_01', 42 | 'loc_bn_ring_L_04', 43 | 'loc_bn_ring_L_03', 44 | 'loc_bn_ring_L_02', 45 | 'loc_bn_ring_L_01', 46 | 'loc_bn_pinky_L_04', 47 | 'loc_bn_pinky_L_03', 48 | 'loc_bn_pinky_L_02', 49 | 'loc_bn_pinky_L_01' 50 | ] 51 | 52 | snap_joint_names = [ 53 | 'loc_bn_palm_L', 54 | 'loc_bn_thumb_L_01', 55 | 'loc_bn_thumb_L_02', 56 | 'loc_bn_thumb_L_03', 57 | 'loc_bn_thumb_L_04', 58 | 'loc_bn_index_L_01', 59 | 'loc_bn_index_L_02', 60 | 'loc_bn_index_L_03', 61 | 'loc_bn_index_L_04', 62 | 'loc_bn_mid_L_01', 63 | 'loc_bn_mid_L_02', 64 | 'loc_bn_mid_L_03', 65 | 'loc_bn_mid_L_04', 66 | 'loc_bn_ring_L_01', 67 | 'loc_bn_ring_L_02', 68 | 'loc_bn_ring_L_03', 69 | 'loc_bn_ring_L_04', 70 | 'loc_bn_pinky_L_01', 71 | 'loc_bn_pinky_L_02', 72 | 'loc_bn_pinky_L_03', 73 | 'loc_bn_pinky_L_04' 74 | ] 75 | 76 | SNAP_BONES = [ 77 | (0, 1, 2, 3, 4), 78 | (0, 5, 6, 7, 8), 79 | (0, 9, 10, 11, 12), 80 | (0, 13, 14, 15, 16), 81 | (0, 17, 18, 19, 20) 82 | ] 83 | 84 | SNAP_PARENT = [ 85 | 0, # 0's parent 86 | 0, # 1's parent 87 | 1, 88 | 2, 89 | 3, 90 | 0, # 5's parent 91 | 5, 92 | 6, 93 | 7, 94 | 0, # 9's parent 95 | 9, 96 | 10, 97 | 11, 98 | 0, # 13's parent 99 | 13, 100 | 14, 101 | 15, 102 | 0, # 17's parent 103 | 17, 104 | 18, 105 | 19, 106 | ] 107 | 108 | JOINT_COLORS = ( 109 | (216, 31, 53), 110 | (214, 208, 0), 111 | (136, 72, 152), 112 | (126, 199, 216), 113 | (0, 0, 230), 114 | ) 115 | 116 | 
DEFAULT_CACHE_DIR = 'datasets/data/.cache' 117 | 118 | USEFUL_BONE = [1, 2, 3, 119 | 5, 6, 7, 120 | 9, 10, 11, 121 | 13, 14, 15, 122 | 17, 18, 19] 123 | 124 | kinematic_tree = [2, 3, 4, 6, 7, 8, 10, 11, 12, 14, 15, 16, 18, 19, 20] 125 | 126 | ID2ROT = { 127 | 2: 13, 3: 14, 4: 15, 128 | 6: 1, 7: 2, 8: 3, 129 | 10: 4, 11: 5, 12: 6, 130 | 14: 10, 15: 11, 16: 12, 131 | 18: 7, 19: 8, 20: 9, 132 | } -------------------------------------------------------------------------------- /create_data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from manopth import manolayer 3 | import os 4 | 5 | 6 | class DataSet: 7 | def __init__(self, device=torch.device('cpu'), _mano_root='mano/models'): 8 | args = {'flat_hand_mean': True, 'root_rot_mode': 'axisang', 9 | 'ncomps': 45, 'mano_root': _mano_root, 10 | 'no_pca': True, 'joint_rot_mode': 'axisang', 'side': 'right'} 11 | self.mano = manolayer.ManoLayer(flat_hand_mean=args['flat_hand_mean'], 12 | side=args['side'], 13 | mano_root=args['mano_root'], 14 | ncomps=args['ncomps'], 15 | use_pca=not args['no_pca'], 16 | root_rot_mode=args['root_rot_mode'], 17 | joint_rot_mode=args['joint_rot_mode'] 18 | ).to(device) 19 | self.device = device 20 | 21 | def new_cal_ref_bone(self, _shape): 22 | parent_index = [0, 23 | 0, 1, 2, 24 | 0, 4, 5, 25 | 0, 7, 8, 26 | 0, 10, 11, 27 | 0, 13, 14 28 | ] 29 | index = [0, 30 | 1, 2, 3, # index 31 | 4, 5, 6, # middle 32 | 7, 8, 9, # pinky 33 | 10, 11, 12, # ring 34 | 13, 14, 15] # thumb 35 | reoder_index = [ 36 | 13, 14, 15, 37 | 1, 2, 3, 38 | 4, 5, 6, 39 | 10, 11, 12, 40 | 7, 8, 9] 41 | shape = _shape.clone().detach() 42 | th_v_shaped = torch.matmul(self.mano.th_shapedirs, 43 | shape.transpose(1, 0)).permute(2, 0, 1) \ 44 | + self.mano.th_v_template 45 | th_j = torch.matmul(self.mano.th_J_regressor, th_v_shaped) 46 | temp1 = th_j.clone().detach() 47 | temp2 = th_j.clone().detach()[:, parent_index, :] 48 | result = temp1 - temp2 49 | ref_len = 
th_j[:, [4], :] - th_j[:, [0], :] 50 | ref_len = torch.norm(ref_len, dim=-1, keepdim=True) 51 | result = torch.norm(result, dim=-1, keepdim=True) 52 | result = result / ref_len 53 | return torch.squeeze(result, dim=-1)[:, reoder_index] 54 | 55 | def sample(self): 56 | shape = 3 * torch.randn((1, 10)) 57 | result = self.new_cal_ref_bone(shape) 58 | return (result, shape) 59 | 60 | def batch_sample(self, batch_size): 61 | shape = 3 * torch.randn((batch_size, 10)) 62 | result = self.new_cal_ref_bone(shape) 63 | return (result, shape) 64 | 65 | @staticmethod 66 | def cal_ref_bone(_Jtr): 67 | parent_index = [0, 68 | 0, 1, 2, 3, 69 | 0, 5, 6, 7, 70 | 0, 9, 10, 8, 71 | 0, 13, 14, 15, 72 | 0, 17, 18, 19 73 | ] 74 | index = [1, 2, 3, 75 | 5, 6, 7, 76 | 9, 10, 11, 77 | 13, 14, 15, 78 | 17, 18, 19] 79 | temp1 = _Jtr.clone().detach() 80 | temp2 = _Jtr.clone().detach()[:, parent_index, :] 81 | result = temp1 - temp2 82 | result = result[:, index, :] 83 | ref_len = _Jtr[:, [9], :] - _Jtr[:, [0], :] 84 | ref_len = torch.norm(ref_len, dim=-1, keepdim=True) 85 | result = torch.norm(result, dim=-1, keepdim=True) 86 | # result = result / ref_len 87 | return torch.squeeze(result, dim=-1) 88 | 89 | 90 | if __name__ == '__main__': 91 | dataset = DataSet() 92 | import numpy as np 93 | import tqdm 94 | 95 | Total_Num = 1000000 96 | NUM = 10000 97 | data_bone = np.zeros((Total_Num, 15)) 98 | data_shape = np.zeros((Total_Num, 10)) 99 | for i in tqdm.tqdm(range(Total_Num // NUM)): 100 | t1 = i * NUM 101 | t2 = t1 + NUM 102 | temp_1, temp_2 = dataset.batch_sample(NUM) 103 | data_bone[t1:t2] = temp_1 104 | data_shape[t1:t2] = temp_2 105 | print(t1, t2) 106 | 107 | save_dir = 'data' 108 | if os.path.exists(save_dir): 109 | pass 110 | else: 111 | os.mkdir(save_dir) 112 | np.save(os.path.join(save_dir, 'data_bone.npy'), data_bone) 113 | np.save(os.path.join(save_dir, 'data_shape.npy'), data_shape) 114 | print('*' * 10, 'test', '*' * 10) 115 | data_bone = np.load(os.path.join(save_dir, 
'data_bone.npy')) 116 | data_shape = np.load(os.path.join(save_dir, 'data_shape.npy')) 117 | test_flag = 1 118 | for i in tqdm.tqdm(range(Total_Num // NUM)): 119 | t1 = i * NUM 120 | t2 = t1 + NUM 121 | test_shape = data_shape[t1:t2] 122 | test_shape = torch.tensor(test_shape, dtype=torch.float) 123 | test_bone = data_bone[t1:t2] 124 | temp_1 = dataset.new_cal_ref_bone(test_shape) 125 | flag = np.allclose(temp_1, test_bone) 126 | flag = int(flag) 127 | test_flag = test_flag * flag 128 | print(test_flag) 129 | -------------------------------------------------------------------------------- /datasets/SIK1M.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import pickle 3 | import torch 4 | import os 5 | from torch.utils import data 6 | from termcolor import colored, cprint 7 | import numpy as np 8 | 9 | sik1m_inst = 0 10 | 11 | 12 | class _SIK1M(data.Dataset): 13 | """ 14 | The Loader for joints so3 and quat 15 | """ 16 | 17 | def __init__( 18 | self, 19 | data_root="data", 20 | data_source=None 21 | ): 22 | print("Initialize _SIK1M instance") 23 | bone_len_path = os.path.join(data_root, 'data_bone.npy') 24 | shape_path = os.path.join(data_root, 'data_shape.npy') 25 | self.bone_len = np.load(bone_len_path) 26 | self.shape = np.load(shape_path) 27 | 28 | def __len__(self): 29 | return self.shape.shape[0] 30 | 31 | def __getitem__(self, index): 32 | temp_bone_len = self.bone_len[index] 33 | temp_shape = self.shape[index] 34 | 35 | metas = { 36 | 'rel_bone_len': temp_bone_len, 37 | 'shape': temp_shape 38 | } 39 | return metas 40 | 41 | 42 | class SIK1M(data.Dataset): 43 | def __init__( 44 | self, 45 | data_split="train", 46 | data_root="data", 47 | split_ratio=0.8 48 | ): 49 | global sik1m_inst 50 | if not sik1m_inst: 51 | sik1m_inst = _SIK1M(data_root=data_root) 52 | self.sik1m = sik1m_inst 53 | self.permu = list(range(len(self.sik1m))) 54 | self.alllen = len(self.sik1m) 55 | self.data_split = data_split 56 | # add 
the 0.1* the std of Relative bone length as noise, you can change it or not add 57 | self.noise = np.array([0.02906406, 0.02663224, 0.01769793, 0.0274501, 0.02573783, 0.0222863, 58 | 0., 0.02855567, 0.02330295, 0.0253288, 0.0266308, 0.02495683, 0.03685857, 0.02430637, 59 | 0.02349446]) 60 | self.noise = self.noise / 10.0 61 | if data_split == "train": 62 | self.vislen = int(len(self.sik1m) * split_ratio) 63 | self.sub_permu = self.permu[:self.vislen] 64 | elif data_split in ["val", "test"]: 65 | self.vislen = self.alllen - int(len(self.sik1m) * split_ratio) 66 | self.sub_permu = self.permu[(self.alllen - self.vislen):] 67 | else: 68 | self.vislen = len(self.sik1m) 69 | self.sub_permu = self.permu[:self.vislen] 70 | 71 | def __len__(self): 72 | return self.vislen 73 | 74 | def __getitem__(self, index): 75 | item = self.sik1m[self.sub_permu[index]] 76 | temp = np.random.randn(15, ) 77 | temp = np.multiply(self.noise, temp) 78 | item['rel_bone_len'] += temp 79 | return item 80 | 81 | 82 | def main(): 83 | sik1m_train = SIK1M( 84 | data_split="train", 85 | data_root="data" 86 | ) 87 | sik1m_test = SIK1M( 88 | data_split="test" 89 | ) 90 | 91 | metas = sik1m_train[2] 92 | print(metas) 93 | metas = sik1m_train[2] 94 | print(metas) 95 | 96 | 97 | if __name__ == "__main__": 98 | main() 99 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/handopen/Minimal-Hand-pytorch/c078433136b20c1e4fb3f353254edd409ea65f1b/datasets/__init__.py -------------------------------------------------------------------------------- /datasets/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/handopen/Minimal-Hand-pytorch/c078433136b20c1e4fb3f353254edd409ea65f1b/datasets/__pycache__/__init__.cpython-35.pyc 
-------------------------------------------------------------------------------- /datasets/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/handopen/Minimal-Hand-pytorch/c078433136b20c1e4fb3f353254edd409ea65f1b/datasets/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/dexter_object.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/handopen/Minimal-Hand-pytorch/c078433136b20c1e4fb3f353254edd409ea65f1b/datasets/__pycache__/dexter_object.cpython-37.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/egodexter.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/handopen/Minimal-Hand-pytorch/c078433136b20c1e4fb3f353254edd409ea65f1b/datasets/__pycache__/egodexter.cpython-37.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/ganerated_hands.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/handopen/Minimal-Hand-pytorch/c078433136b20c1e4fb3f353254edd409ea65f1b/datasets/__pycache__/ganerated_hands.cpython-37.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/hand143_panopticdb.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/handopen/Minimal-Hand-pytorch/c078433136b20c1e4fb3f353254edd409ea65f1b/datasets/__pycache__/hand143_panopticdb.cpython-37.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/hand_labels.cpython-37.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/handopen/Minimal-Hand-pytorch/c078433136b20c1e4fb3f353254edd409ea65f1b/datasets/__pycache__/hand_labels.cpython-37.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/handataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/handopen/Minimal-Hand-pytorch/c078433136b20c1e4fb3f353254edd409ea65f1b/datasets/__pycache__/handataset.cpython-37.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/rhd.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/handopen/Minimal-Hand-pytorch/c078433136b20c1e4fb3f353254edd409ea65f1b/datasets/__pycache__/rhd.cpython-37.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/stb.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/handopen/Minimal-Hand-pytorch/c078433136b20c1e4fb3f353254edd409ea65f1b/datasets/__pycache__/stb.cpython-37.pyc -------------------------------------------------------------------------------- /datasets/ganerated_hands.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Hao Meng. All Rights Reserved. 
2 | r""" 3 | GANeratedDataset 4 | GANerated Hands for Real-Time 3D Hand Tracking from Monocular RGB, CVPR 2018 5 | Link to dataset: https://handtracker.mpi-inf.mpg.de/projects/GANeratedHands/GANeratedDataset.htm 6 | """ 7 | 8 | import os 9 | import pickle 10 | from builtins import print 11 | 12 | import PIL 13 | import matplotlib.pyplot as plt 14 | import numpy as np 15 | import torch 16 | from PIL import Image 17 | from termcolor import colored 18 | from tqdm import tqdm 19 | 20 | import config as cfg 21 | import utils.handutils as handutils 22 | 23 | CACHE_HOME = os.path.expanduser(cfg.DEFAULT_CACHE_DIR) 24 | snap_joint_name2id = {w: i for i, w in enumerate(cfg.snap_joint_names)} 25 | 26 | 27 | class GANeratedDataset(torch.utils.data.Dataset): # GANerated Hands: image paths, 2D/3D keypoints, crop centers/scales; train split only, annotations pickle-cached 28 | 29 | def __init__(self, 30 | data_root, 31 | data_split='train', 32 | hand_side='right', 33 | njoints=21, 34 | use_cache=True, 35 | vis=False): 36 | if not os.path.exists(data_root): 37 | raise ValueError("data_root: %s not exist" % data_root) 38 | self.name = 'GANeratedHands Dataset' 39 | self.data_split = data_split 40 | self.hand_side = hand_side 41 | self.clr_paths = [] 42 | self.kp2ds = [] 43 | self.joints = [] 44 | self.centers = [] 45 | self.my_scales = [] 46 | self.njoints = njoints 47 | self.reslu = [256, 256] 48 | 49 | self.vis = vis 50 | 51 | self.root_id = snap_joint_name2id['loc_bn_palm_L'] # 0 52 | self.mid_mcp_id = snap_joint_name2id['loc_bn_mid_L_01'] # 9 53 | 54 | self.intr = np.array([ # one fixed 3x3 intrinsic matrix, returned with every sample 55 | [617.173, 0, 315.453], 56 | [0, 617.173, 242.259], 57 | [0, 0, 1]]) 58 | 59 | # [train|test|val|train_val|all] 60 | if data_split == 'train': 61 | self.sequence = ['training', ] 62 | else: 63 | print("GANeratedDataset only has train set!") 64 | return None # unsupported split: the dataset stays empty (len 0) 65 | 66 | self.cache_folder = 
os.path.join(CACHE_HOME, "my-train", "GANeratedHands") 67 | os.makedirs(self.cache_folder, exist_ok=True) 68 | cache_path = os.path.join( 69 | self.cache_folder, "{}.pkl".format(self.data_split) 70 | ) 71 | 72 | if
os.path.exists(cache_path) and use_cache: 73 | with open(cache_path, "rb") as fid: 74 | annotations = pickle.load(fid) # local cache written by this class; never load untrusted pickles 75 | self.clr_paths = annotations["clr_paths"] 76 | self.kp2ds = annotations["kp2ds"] 77 | self.joints = annotations["joints"] 78 | self.centers = annotations["centers"] 79 | self.my_scales = annotations["my_scales"] 80 | print("GANeratedHands {} gt loaded from {}".format(self.data_split, cache_path)) 81 | return 82 | 83 | print("init GANeratedHands {}, It will take a while at first time".format(data_split)) 84 | 85 | for img_type in ['noObject/', 'withObject/']: 86 | folders = os.listdir(data_root + img_type) # plain string concatenation: data_root must end with '/' 87 | folders = sorted(folders) 88 | folders = [img_type + x + '/' for x in folders if len(x) == 4] # keep only 4-character sub-folder names 89 | 90 | for folder in folders: 91 | images = os.listdir(os.path.join(data_root + folder)) 92 | images = [data_root + folder + x for x in images if x.find('.png') > 0] 93 | images = sorted(images) 94 | 95 | self.clr_paths.extend(images) 96 | 97 | for idx in tqdm(range(len(self.clr_paths))): 98 | img_name = self.clr_paths[idx] 99 | 100 | fn_2d_keypoints = img_name.replace('color_composed.png', 'joint2D.txt') # annotation files are derived from the image path 101 | arr_2d_keypoints = np.loadtxt(fn_2d_keypoints, delimiter=',') 102 | arr_2d_keypoints = arr_2d_keypoints.reshape([-1, 2]) # (J, 2) 103 | 104 | center = handutils.get_annot_center(arr_2d_keypoints) 105 | self.centers.append(center[np.newaxis, :]) 106 | 107 | my_scale = handutils.get_ori_crop_scale(mask=None, mask_flag=False, side=None, kp2d=arr_2d_keypoints, 108 | ) 109 | my_scale = (np.atleast_1d(my_scale))[np.newaxis, :] 110 | self.my_scales.append(my_scale) 111 | 112 | arr_2d_keypoints = arr_2d_keypoints[np.newaxis, :, :] 113 | self.kp2ds.append(arr_2d_keypoints) 114 | 115 | fn_3d_keypoints = img_name.replace('color_composed.png', 
'joint_pos_global.txt') 116 | arr_3d_keypoints = np.loadtxt(fn_3d_keypoints, delimiter=',') 117 | arr_3d_keypoints = arr_3d_keypoints.reshape([-1, 3]) # (J, 3) 118 | arr_3d_keypoints = arr_3d_keypoints[np.newaxis, :, :] 119 |
self.joints.append(arr_3d_keypoints) 120 | 121 | self.joints = np.concatenate(self.joints, axis=0).astype(np.float32) # (N, J, 3) 122 | self.kp2ds = np.concatenate(self.kp2ds, axis=0).astype(np.float32) # (N, 21, 2) 123 | self.centers = np.concatenate(self.centers, axis=0).astype(np.float32) # (N, 2) 124 | self.my_scales = np.concatenate(self.my_scales, axis=0).astype(np.float32) # (N, 1) when get_ori_crop_scale returns a scalar 125 | 126 | if use_cache: # persist the parsed annotations so later runs skip the per-file text parsing 127 | full_info = { 128 | "clr_paths": self.clr_paths, 129 | "joints": self.joints, 130 | "kp2ds": self.kp2ds, 131 | "centers": self.centers, 132 | "my_scales": self.my_scales, 133 | } 134 | with open(cache_path, "wb") as fid: 135 | pickle.dump(full_info, fid) 136 | print("Wrote cache for dataset GANeratedDataset {} to {}".format( 137 | self.data_split, cache_path 138 | )) 139 | return 140 | 141 | def __len__(self): 142 | """for GANeratedHands Dataset total (1,500 * 2) * 2 * 6 = 36,000 samples 143 | """ 144 | return len(self.clr_paths) 145 | 146 | def __str__(self): 147 | info = "GANeratedHands {} set.
lenth {}".format( 148 | self.data_split, len(self.clr_paths) 149 | ) 150 | return colored(info, 'blue', attrs=['bold']) 151 | 152 | def _is_valid(self, clr, index): 153 | valid_data = isinstance(clr, (np.ndarray, PIL.Image.Image)) 154 | 155 | if not valid_data: 156 | raise Exception("Encountered error processing GAN[{}]".format(index)) 157 | return valid_data 158 | 159 | def get_sample(self, index): # build one sample dict; mirrored horizontally unless hand_side == "left" 160 | flip = True if self.hand_side != "left" else False # presumably the stored annotations are left hands (joint names end in _L) - confirm 161 | 162 | intr = self.intr # NOTE(review): not adjusted when the sample is mirrored below 163 | 164 | # prepare color image 165 | clr = Image.open(self.clr_paths[index]).convert("RGB") 166 | self._is_valid(clr, index) 167 | 168 | # prepare kp2d 169 | kp2d = self.kp2ds[index].copy() 170 | 171 | # prepare joint 172 | joint = self.joints[index].copy() # (21, 3) 173 | center = self.centers[index].copy() 174 | my_scale = self.my_scales[index].copy() 175 | if flip: 176 | clr = clr.transpose(Image.FLIP_LEFT_RIGHT) 177 | center[0] = clr.size[0] - center[0] # clr.size[0] is the image width 178 | kp2d[:, 0] = clr.size[0] - kp2d[:, 0] 179 | joint[:, 0] = -joint[:, 0] # mirror the 3D x axis to match the flipped image 180 | 181 | sample = { 182 | 'index': index, 183 | 'clr': clr, 184 | 'kp2d': kp2d, 185 | 'center': center, 186 | 'my_scale': my_scale, 187 | 'joint': joint, 188 | 'intr': intr, 189 | } 190 | 191 | # visualization 192 | if self.vis: # debug plots: image, 2D keypoints, 3D skeleton 193 | fig = plt.figure(figsize=(20, 20)) 194 | clr_ = np.array(clr) 195 | 196 | plt.subplot(1, 3, 1) 197 | clr1 = clr_.copy() 198 | plt.imshow(clr1) 199 | 200 | plt.subplot(1, 3, 2) 201 | clr2 = clr_.copy() 202 | plt.imshow(clr2) 203 | 204 | for p in range(kp2d.shape[0]): 205 | plt.plot(kp2d[p][0], kp2d[p][1], 'r.') 206 | plt.text(kp2d[p][0], kp2d[p][1], '{0}'.format(p), fontsize=5) 207 | 208 | ax = 
fig.add_subplot(133, projection='3d') 209 | plt.plot(joint[:, 0], joint[:, 1], joint[:, 2], 'yo', label='keypoint') 210 | plt.plot(joint[:5, 0], joint[:5, 1], 211 | joint[:5, 2], 212 | 'r', 213 | label='thumb') 214 | plt.plot(joint[[0, 5, 6, 7, 8, ], 0], joint[[0, 5, 6, 7, 8, ], 1], 215 | joint[[0, 5, 6, 7, 8, ], 2], 216 | 'b', 217
| label='index') 218 | plt.plot(joint[[0, 9, 10, 11, 12, ], 0], joint[[0, 9, 10, 11, 12], 1], 219 | joint[[0, 9, 10, 11, 12], 2], 220 | 'b', 221 | label='middle') 222 | plt.plot(joint[[0, 13, 14, 15, 16], 0], joint[[0, 13, 14, 15, 16], 1], 223 | joint[[0, 13, 14, 15, 16], 2], 224 | 'b', 225 | label='ring') 226 | plt.plot(joint[[0, 17, 18, 19, 20], 0], joint[[0, 17, 18, 19, 20], 1], 227 | joint[[0, 17, 18, 19, 20], 2], 228 | 'b', 229 | label='pinky') 230 | # snap convention 231 | plt.plot(joint[4][0], joint[4][1], joint[4][2], 'rD', label='thumb') 232 | plt.plot(joint[8][0], joint[8][1], joint[8][2], 'ro', label='index') 233 | plt.plot(joint[12][0], joint[12][1], joint[12][2], 'ro', label='middle') 234 | plt.plot(joint[16][0], joint[16][1], joint[16][2], 'ro', label='ring') 235 | plt.plot(joint[20][0], joint[20][1], joint[20][2], 'ro', label='pinky') 236 | 237 | plt.title('3D annotations') 238 | ax.set_xlabel('x') 239 | ax.set_ylabel('y') 240 | ax.set_zlabel('z') 241 | plt.legend() 242 | ax.view_init(-90, -90) 243 | plt.show() 244 | 245 | return sample 246 | 247 | 248 | if __name__ == '__main__': 249 | data_split = 'train' 250 | gan = GANeratedDataset( 251 | data_root='/home/chen/datasets/GANeratedHands_Release/data/', # author's local path; edit before running 252 | data_split=data_split, 253 | hand_side='right', 254 | njoints=21, 255 | use_cache=False, 256 | vis=True) 257 | print("len(gan)=", len(gan)) 258 | for i in range(len(gan)): 259 | print("i=", i) 260 | data = gan.get_sample(i) 261 | -------------------------------------------------------------------------------- /datasets/hand143_panopticdb.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Hao Meng. All Rights Reserved.
2 | r""" 3 | Hands from Panoptic Studio by Multiview Bootstrapping (14817 annotations) 4 | Hand Keypoint Detection in Single Images using Multiview Bootstrapping, CVPR 2017 5 | Link to dataset: http://domedb.perception.cs.cmu.edu/handdb.html 6 | Download:http://domedb.perception.cs.cmu.edu/panopticDB/hands/hand143_panopticdb.tar 7 | """ 8 | 9 | import json 10 | import os 11 | import pickle 12 | 13 | import PIL 14 | import matplotlib.pyplot as plt 15 | import numpy as np 16 | import torch.utils.data 17 | from PIL import Image 18 | from termcolor import colored 19 | from tqdm import tqdm 20 | 21 | import config as cfg 22 | import utils.handutils as handutils 23 | 24 | CACHE_HOME = os.path.expanduser(cfg.DEFAULT_CACHE_DIR) 25 | 26 | snap_joint_name2id = {w: i for i, w in enumerate(cfg.snap_joint_names)} 27 | 28 | 29 | class Hand143_panopticdb(torch.utils.data.Dataset): 30 | def __init__( 31 | self, 32 | data_root="/home/chen/datasets/CMU/hand143_panopticdb", 33 | data_split='train', 34 | hand_side='right', 35 | njoints=21, 36 | use_cache=True, 37 | vis=False 38 | ): 39 | 40 | if not os.path.exists(data_root): 41 | raise ValueError("data_root: %s not exist" % data_root) 42 | 43 | self.name = 'hand143_panopticdb' 44 | self.data_split = data_split 45 | self.hand_side = hand_side 46 | self.clr_paths = [] 47 | self.kp2ds = [] 48 | self.centers = [] 49 | self.my_scales = [] 50 | self.njoints = njoints 51 | self.reslu = [1920, 1080] 52 | self.vis = vis 53 | 54 | self.root_id = snap_joint_name2id['loc_bn_palm_L'] # 0 55 | self.mid_mcp_id = snap_joint_name2id['loc_bn_mid_L_01'] # 9 56 | 57 | # [train|test|val|train_val|all] 58 | if data_split == 'train': 59 | self.sequence = ['training', ] 60 | else: 61 | print("hand143_panopticdb only has train_set!") 62 | return 63 | 64 | self.cache_folder = os.path.join(CACHE_HOME, "my-train", "hand143_panopticdb") 65 | os.makedirs(self.cache_folder, exist_ok=True) 66 | cache_path = os.path.join( 67 | self.cache_folder, 
"{}.pkl".format(self.data_split) 68 | ) 69 | 70 | if os.path.exists(cache_path) and use_cache: 71 | with open(cache_path, "rb") as fid: 72 | annotations = pickle.load(fid) 73 | self.clr_paths = annotations["clr_paths"] 74 | self.kp2ds = annotations["kp2ds"] 75 | self.centers = annotations["centers"] 76 | self.my_scales = annotations["my_scales"] 77 | print("hand143_panopticdb {} gt loaded from {}".format(self.data_split, cache_path)) 78 | return 79 | 80 | self.clr_root_list = [ 81 | os.path.join(data_root, "imgs") 82 | ] 83 | 84 | self.ann_list = [ 85 | os.path.join( 86 | data_root, 87 | "hands_v143_14817.json" 88 | ) 89 | ] 90 | 91 | for clr_root, ann in zip(self.clr_root_list, self.ann_list): 92 | 93 | jsonPath = os.path.join(ann) 94 | with open(jsonPath, 'r') as fid: 95 | dat_all = json.load(fid) 96 | dat_all = dat_all['root'] 97 | 98 | for i in tqdm(range(len(dat_all))): 99 | clrpth = os.path.join(clr_root, '%.8d.jpg' % i) 100 | self.clr_paths.append(clrpth) 101 | 102 | dat = dat_all[i] 103 | kp2d = np.array(dat['joint_self'])[:, : 2] # kp 2d left & right hand 104 | center = handutils.get_annot_center(kp2d) 105 | my_scale = handutils.get_ori_crop_scale(mask=None, side=None, mask_flag=False, kp2d=kp2d) 106 | 107 | kp2d = kp2d[np.newaxis, :, :] 108 | self.kp2ds.append(kp2d) 109 | 110 | center = center[np.newaxis, :] 111 | self.centers.append(center) 112 | 113 | my_scale = (np.atleast_1d(my_scale))[np.newaxis, :] 114 | self.my_scales.append(my_scale) 115 | 116 | self.kp2ds = np.concatenate(self.kp2ds, axis=0).astype(np.float32) # (N, 21, 2) 117 | self.centers = np.concatenate(self.centers, axis=0).astype(np.float32) # (N, 1) 118 | self.my_scales = np.concatenate(self.my_scales, axis=0).astype(np.float32) # (N, 1) 119 | 120 | if use_cache: 121 | full_info = { 122 | "clr_paths": self.clr_paths, 123 | "kp2ds": self.kp2ds, 124 | "centers": self.centers, 125 | "my_scales": self.my_scales, 126 | } 127 | with open(cache_path, "wb") as fid: 128 | pickle.dump(full_info, 
fid) 129 | print("Wrote cache for dataset hand143_panopticdb {} to {}".format( 130 | self.data_split, cache_path 131 | )) 132 | return 133 | 134 | def _is_valid(self, clr, index): 135 | valid_data = isinstance(clr, (np.ndarray, PIL.Image.Image)) 136 | 137 | if not valid_data: 138 | raise Exception("Encountered error processing cmu_1_[{}]".format(index)) 139 | return valid_data 140 | 141 | def __len__(self): 142 | return len(self.clr_paths) 143 | 144 | def __str__(self): 145 | info = "Hand143_panopticdb {} set. lenth {}".format( 146 | self.data_split, len(self.clr_paths) 147 | ) 148 | return colored(info, 'blue', attrs=['bold']) 149 | 150 | def get_sample(self, index): 151 | flip = True if self.hand_side != 'right' else False 152 | 153 | # prepare color image 154 | clr = Image.open(self.clr_paths[index]).convert("RGB") 155 | self._is_valid(clr, index) 156 | 157 | # prepare kp2d 158 | kp2d = self.kp2ds[index].copy() 159 | center = self.centers[index].copy() 160 | my_scale = self.my_scales[index].copy() 161 | if flip: 162 | clr = clr.transpose(Image.FLIP_LEFT_RIGHT) 163 | center[0] = clr.size[0] - center[0] 164 | kp2d[:, 0] = clr.size[0] - kp2d[:, 0] 165 | 166 | sample = { 167 | 'index': index, 168 | 'clr': clr, 169 | 'kp2d': kp2d, 170 | 'center': center, 171 | 'my_scale': my_scale, 172 | } 173 | 174 | # visualization 175 | if self.vis: 176 | plt.figure(figsize=(20, 20)) 177 | clr_ = np.array(clr) 178 | 179 | plt.subplot(1, 2, 1) 180 | clr1 = clr_.copy() 181 | plt.imshow(clr1) 182 | plt.title('color image') 183 | 184 | plt.subplot(1, 2, 2) 185 | clr2 = clr_.copy() 186 | plt.imshow(clr2) 187 | plt.plot(200, 100, 'r.', linewidth=10) # opencv convention 188 | for p in range(kp2d.shape[0]): 189 | plt.plot(kp2d[p][0], kp2d[p][1], 'r.') 190 | plt.text(kp2d[p][0], kp2d[p][1], '{0}'.format(p), fontsize=5) 191 | plt.title('2D annotations') 192 | 193 | plt.show() 194 | 195 | return sample 196 | 197 | 198 | def main(): 199 | hand143_panopticdb = Hand143_panopticdb( 200 | 
data_root="/home/chen/datasets/CMU/hand143_panopticdb", 201 | data_split='train', 202 | hand_side='right', 203 | njoints=21, 204 | use_cache=True, 205 | vis=True 206 | ) 207 | print("len(hand143_panopticdb)=", len(hand143_panopticdb)) 208 | 209 | for i in tqdm(range(len(hand143_panopticdb))): 210 | print("i=", i) 211 | hand143_panopticdb.get_sample(i) 212 | 213 | 214 | if __name__ == "__main__": 215 | main() 216 | -------------------------------------------------------------------------------- /datasets/hand_labels.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Hao Meng. All Rights Reserved. 2 | r""" 3 | Hands with Manual Keypoint Annotations (Training: 1912 annotations, Testing: 846 annotations) 4 | Hand Keypoint Detection in Single Images using Multiview Bootstrapping, CVPR 2017 5 | Link to dataset: http://domedb.perception.cs.cmu.edu/handdb.html 6 | Download:http://domedb.perception.cs.cmu.edu/panopticDB/hands/hand_labels.zip 7 | """ 8 | 9 | import json 10 | import os 11 | import pickle 12 | 13 | import PIL 14 | import matplotlib.pyplot as plt 15 | import numpy as np 16 | import torch.utils.data 17 | from PIL import Image 18 | from termcolor import colored 19 | from tqdm import tqdm 20 | 21 | import config as cfg 22 | import utils.handutils as handutils 23 | 24 | CACHE_HOME = os.path.expanduser(cfg.DEFAULT_CACHE_DIR) 25 | 26 | snap_joint_name2id = {w: i for i, w in enumerate(cfg.snap_joint_names)} 27 | 28 | 29 | class Hand_labels(torch.utils.data.Dataset): 30 | def __init__( 31 | self, 32 | data_root="/home/chen/datasets/CMU/hand_labels", 33 | data_split='train', 34 | hand_side='right', 35 | njoints=21, 36 | use_cache=True, 37 | vis=False 38 | ): 39 | 40 | if not os.path.exists(data_root): 41 | raise ValueError("data_root: %s not exist" % data_root) 42 | 43 | self.vis = vis 44 | self.name = 'CMU:hand_labels' 45 | self.data_split = data_split 46 | self.hand_side = hand_side 47 | self.clr_paths = [] 48 | 
self.kp2ds = [] 49 | self.centers = [] 50 | self.sides = [] 51 | self.my_scales = [] 52 | self.njoints = njoints 53 | self.reslu = [1920, 1080] 54 | 55 | self.root_id = snap_joint_name2id['loc_bn_palm_L'] # 0 56 | self.mid_mcp_id = snap_joint_name2id['loc_bn_mid_L_01'] # 9 57 | 58 | # [train|test|val|train_val|all] 59 | if data_split == 'train': 60 | self.sequence = ['manual_train', ] 61 | elif data_split == 'test': 62 | self.sequence = ['manual_test', ] 63 | elif data_split == 'val': 64 | self.sequence = ['manual_test', ] 65 | elif data_split == 'train_val': 66 | self.sequence = ['manual_train', ] 67 | elif data_split == 'all': 68 | self.sequence = ['manual_train', 'manual_test'] 69 | else: 70 | raise ValueError("hand_labels only has train_set!") 71 | 72 | self.cache_folder = os.path.join(CACHE_HOME, "my-train", "hand_labels") 73 | os.makedirs(self.cache_folder, exist_ok=True) 74 | cache_path = os.path.join( 75 | self.cache_folder, "{}.pkl".format(self.data_split) 76 | ) 77 | 78 | if os.path.exists(cache_path) and use_cache: 79 | with open(cache_path, "rb") as fid: 80 | annotations = pickle.load(fid) 81 | self.sides = annotations["sides"] 82 | self.clr_paths = annotations["clr_paths"] 83 | self.kp2ds = annotations["kp2ds"] 84 | self.centers = annotations["centers"] 85 | self.my_scales = annotations["my_scales"] 86 | print("hand_labels {} gt loaded from {}".format(self.data_split, cache_path)) 87 | return 88 | 89 | datapath_list = [ 90 | os.path.join(data_root, seq) for seq in self.sequence 91 | ] 92 | 93 | for datapath in datapath_list: 94 | files = sorted([f for f in os.listdir(datapath) if f.endswith('.json')]) 95 | 96 | for idx in tqdm(range(len(files))): 97 | f = files[idx] 98 | with open(os.path.join(datapath, f), 'r') as fid: 99 | dat = json.load(fid) 100 | 101 | kp2d = np.array(dat['hand_pts'])[:, : 2] 102 | is_left = dat['is_left'] 103 | self.sides.append("left" if is_left else "right") 104 | 105 | clr_pth = os.path.join(datapath, f[0:-5] + '.jpg') 106 | 
self.clr_paths.append(clr_pth) 107 | center = handutils.get_annot_center(kp2d) 108 | my_scale = handutils.get_ori_crop_scale(mask=False, mask_flag=False, side=None, kp2d=kp2d) 109 | 110 | kp2d = kp2d[np.newaxis, :, :] 111 | self.kp2ds.append(kp2d) 112 | 113 | center = center[np.newaxis, :] 114 | self.centers.append(center) 115 | 116 | my_scale = (np.atleast_1d(my_scale))[np.newaxis, :] 117 | self.my_scales.append(my_scale) 118 | 119 | self.kp2ds = np.concatenate(self.kp2ds, axis=0).astype(np.float32) # (N, 21, 2) 120 | self.centers = np.concatenate(self.centers, axis=0).astype(np.float32) # (N, 1) 121 | self.my_scales = np.concatenate(self.my_scales, axis=0).astype(np.float32) # (N, 1) 122 | 123 | if use_cache: 124 | full_info = { 125 | "sides": self.sides, 126 | "clr_paths": self.clr_paths, 127 | "kp2ds": self.kp2ds, 128 | "centers": self.centers, 129 | "my_scales": self.my_scales, 130 | } 131 | with open(cache_path, "wb") as fid: 132 | pickle.dump(full_info, fid) 133 | print("Wrote cache for dataset hand_labels {} to {}".format( 134 | self.data_split, cache_path 135 | )) 136 | return 137 | 138 | def _is_valid(self, clr, index): 139 | valid_data = isinstance(clr, (np.ndarray, PIL.Image.Image)) 140 | 141 | if not valid_data: 142 | raise Exception("Encountered error processing CMU_2_[{}]".format(index)) 143 | return valid_data 144 | 145 | def __len__(self): 146 | return len(self.clr_paths) 147 | 148 | def __str__(self): 149 | info = "hand_labels {} set. 
lenth {}".format( 150 | self.data_split, len(self.clr_paths) 151 | ) 152 | return colored(info, 'blue', attrs=['bold']) 153 | 154 | def get_sample(self, index): 155 | flip = True if self.hand_side != self.sides[index] else False 156 | 157 | # prepare color image 158 | clr = Image.open(self.clr_paths[index]).convert("RGB") 159 | self._is_valid(clr, index) 160 | 161 | # prepare kp2d 162 | kp2d = self.kp2ds[index].copy() 163 | center = self.centers[index].copy() 164 | my_scale = self.my_scales[index].copy() 165 | if flip: 166 | clr = clr.transpose(Image.FLIP_LEFT_RIGHT) 167 | center[0] = clr.size[0] - center[0] 168 | kp2d[:, 0] = clr.size[0] - kp2d[:, 0] 169 | 170 | sample = { 171 | 'index': index, 172 | 'clr': clr, 173 | 'kp2d': kp2d, 174 | 'center': center, 175 | 'my_scale': my_scale, 176 | } 177 | 178 | # visualization 179 | if self.vis: 180 | clr_ = np.array(clr) 181 | plt.figure(figsize=(20, 20)) 182 | plt.subplot(1, 2, 1) 183 | clr1 = clr_.copy() 184 | plt.imshow(clr1) 185 | plt.title('color image') 186 | 187 | plt.subplot(1, 2, 2) 188 | clr2 = clr_.copy() 189 | plt.imshow(clr2) 190 | plt.plot(200, 100, 'r.', linewidth=10) # opencv convention 191 | for p in range(kp2d.shape[0]): 192 | plt.plot(kp2d[p][0], kp2d[p][1], 'r.') 193 | plt.text(kp2d[p][0], kp2d[p][1], '{0}'.format(p), fontsize=5) 194 | plt.title('2D annotations') 195 | 196 | plt.show() 197 | 198 | return sample 199 | 200 | 201 | def main(): 202 | hand_labels = Hand_labels( 203 | data_root="/home/chen/datasets/CMU/hand_labels", 204 | data_split='train', 205 | hand_side='right', 206 | njoints=21, 207 | use_cache=True, 208 | vis=True 209 | ) 210 | print("len(hand_labels)=", len(hand_labels)) 211 | 212 | for i in tqdm(range(len(hand_labels))): 213 | print("i=", i) 214 | data = hand_labels.get_sample(i) 215 | 216 | 217 | if __name__ == "__main__": 218 | main() 219 | -------------------------------------------------------------------------------- /datasets/handataset.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) Hao Meng. All Rights Reserved. 2 | r""" 3 | Hand dataset controll all sub dataset 4 | """ 5 | 6 | import os 7 | import random 8 | 9 | import matplotlib.pyplot as plt 10 | import numpy as np 11 | import torch.utils.data 12 | from PIL import Image, ImageFilter 13 | from termcolor import colored 14 | from tqdm import tqdm 15 | 16 | import config as cfg 17 | import utils.func as func 18 | import utils.handutils as handutils 19 | import utils.heatmaputils as hmutils 20 | import utils.imgutils as imutils 21 | from datasets.dexter_object import DexterObjectDataset 22 | from datasets.ganerated_hands import GANeratedDataset 23 | from datasets.hand143_panopticdb import Hand143_panopticdb 24 | from datasets.hand_labels import Hand_labels 25 | from datasets.rhd import RHDDataset 26 | from datasets.stb import STBDataset 27 | 28 | snap_joint_name2id = {w: i for i, w in enumerate(cfg.snap_joint_names)} 29 | 30 | 31 | class HandDataset(torch.utils.data.Dataset): 32 | def __init__( 33 | self, 34 | data_split='train', 35 | data_root="/disk1/data", 36 | subset_name=['rhd', 'stb'], 37 | hand_side='right', 38 | sigma=1.0, 39 | inp_res=128, 40 | hm_res=32, 41 | njoints=21, 42 | train=True, 43 | scale_jittering=0.1, 44 | center_jettering=0.1, 45 | max_rot=np.pi, 46 | hue=0.15, 47 | saturation=0.5, 48 | contrast=0.5, 49 | brightness=0.5, 50 | blur_radius=0.5, vis=False 51 | ): 52 | 53 | self.inp_res = inp_res # 128 # network input resolution 54 | self.hm_res = hm_res # 32 # out_testset hm resolution 55 | self.njoints = njoints 56 | self.sigma = sigma 57 | self.max_rot = max_rot 58 | 59 | # Training attributes 60 | self.train = train 61 | self.scale_jittering = scale_jittering 62 | self.center_jittering = center_jettering 63 | 64 | # Color jitter attributes 65 | self.hue = hue 66 | self.contrast = contrast 67 | self.brightness = brightness 68 | self.saturation = saturation 69 | 
self.blur_radius = blur_radius 70 | 71 | self.datasets = [] 72 | self.ref_bone_link = (0, 9) # mid mcp 73 | self.joint_root_idx = 9 # root 74 | 75 | self.vis = vis 76 | 77 | if 'stb' in subset_name: 78 | self.stb = STBDataset( 79 | data_root=os.path.join(data_root, 'STB'), 80 | data_split=data_split, 81 | hand_side=hand_side, 82 | njoints=njoints, 83 | ) 84 | print(self.stb) 85 | self.datasets.append(self.stb) 86 | 87 | if 'rhd' in subset_name: 88 | self.rhd = RHDDataset( 89 | data_root=os.path.join(data_root, 'RHD/RHD_published_v2'), 90 | data_split=data_split, 91 | hand_side=hand_side, 92 | njoints=njoints, 93 | ) 94 | print(self.rhd) 95 | self.datasets.append(self.rhd) 96 | 97 | if 'cmu' in subset_name: 98 | self.hand143_panopticdb = Hand143_panopticdb( 99 | data_root=os.path.join(data_root, 'CMU/hand143_panopticdb'), 100 | data_split=data_split, 101 | hand_side=hand_side, 102 | njoints=njoints, 103 | ) 104 | print(self.hand143_panopticdb) 105 | self.datasets.append(self.hand143_panopticdb) 106 | 107 | self.hand_labels = Hand_labels( 108 | data_root=os.path.join(data_root, 'CMU/hand_labels'), 109 | data_split=data_split, 110 | hand_side=hand_side, 111 | njoints=njoints, 112 | ) 113 | print(self.hand_labels) 114 | self.datasets.append(self.hand_labels) 115 | 116 | info = "CMU {} set. 
lenth {}".format( 117 | data_split, len(self.hand_labels) + len(self.hand143_panopticdb) 118 | ) 119 | print(colored(info, 'yellow', attrs=['bold'])) 120 | 121 | if 'gan' in subset_name: 122 | self.gan = GANeratedDataset( 123 | data_root=os.path.join(data_root, 'GANeratedHands_Release/data/'), 124 | data_split=data_split, 125 | hand_side=hand_side, 126 | njoints=njoints, 127 | ) 128 | print(self.gan) 129 | self.datasets.append(self.gan) 130 | 131 | if 'do' in subset_name: 132 | self.do = DexterObjectDataset( 133 | data_root=os.path.join(data_root, 'dexter+object'), 134 | data_split=data_split, 135 | hand_side=hand_side, 136 | njoints=njoints, 137 | ) 138 | print(self.do) 139 | self.datasets.append(self.do) 140 | 141 | self.total_data = 0 142 | for ds in self.datasets: 143 | self.total_data += len(ds) 144 | 145 | def __getitem__(self, index): 146 | rng = np.random.RandomState(seed=random.randint(0, 1024)) 147 | try: 148 | sample, ds = self._get_sample(index) 149 | except Exception: 150 | index = np.random.randint(0, len(self)) 151 | sample, ds = self._get_sample(index) 152 | 153 | clr = sample['clr'] 154 | my_clr1 = clr.copy() 155 | center = sample['center'] 156 | scale = sample['my_scale'] 157 | if 'intr' in sample.keys(): 158 | intr = sample['intr'] 159 | 160 | # Data augmentation 161 | if self.train: 162 | center_offsets = ( 163 | self.center_jittering 164 | * scale 165 | * rng.uniform(low=-1, high=1, size=2) 166 | ) 167 | center = center + center_offsets.astype(int) 168 | 169 | # Scale jittering 170 | scale_jittering = self.scale_jittering * rng.randn() + 1 171 | scale_jittering = np.clip( 172 | scale_jittering, 173 | 1 - self.scale_jittering, 174 | 1 + self.scale_jittering, 175 | ) 176 | scale = scale * scale_jittering 177 | rot = rng.uniform(low=-self.max_rot, high=self.max_rot) 178 | else: 179 | rot = 0 180 | 181 | rot_mat = np.array([ 182 | [np.cos(rot), -np.sin(rot), 0], 183 | [np.sin(rot), np.cos(rot), 0], 184 | [0, 0, 1], 185 | ]).astype(np.float32) 186 | 
187 | if 'intr' in sample.keys(): 188 | affinetrans, post_rot_trans = handutils.get_affine_transform( 189 | center=center, 190 | scale=scale, 191 | optical_center=[intr[0, 2], intr[1, 2]], 192 | out_res=[self.inp_res, self.inp_res], 193 | rot=rot 194 | ) 195 | else: 196 | affinetrans, post_rot_trans = handutils.get_affine_transform_test( 197 | center, scale, [self.inp_res, self.inp_res], rot=rot 198 | ) 199 | 200 | ''' prepare kp2d ''' 201 | kp2d = sample['kp2d'] 202 | kp2d_ori = kp2d.copy() 203 | kp2d = handutils.transform_coords(kp2d, affinetrans) 204 | 205 | ''' Generate GT Gussian hm and hm veil ''' 206 | hm = np.zeros( 207 | (self.njoints, self.hm_res, self.hm_res), 208 | dtype='float32' 209 | ) # (CHW) 210 | hm_veil = np.ones(self.njoints, dtype='float32') 211 | for i in range(self.njoints): 212 | kp = ( 213 | (kp2d[i] / self.inp_res) * self.hm_res 214 | ).astype(np.int32) # kp uv: [0~256] -> [0~64] 215 | hm[i], aval = hmutils.gen_heatmap(hm[i], kp, self.sigma) 216 | hm_veil[i] *= aval 217 | 218 | joint = np.zeros([21, 3]) 219 | delta_map = np.zeros([21, 3, 32, 32]) 220 | location_map = np.zeros([21, 3, 32, 32]) 221 | flag = 0 222 | 223 | if 'joint' in sample.keys(): 224 | 225 | flag = 1 226 | ''' prepare joint ''' 227 | joint = sample['joint'] 228 | if self.train: 229 | joint = rot_mat.dot( 230 | joint.transpose(1, 0) 231 | ).transpose() 232 | 233 | joint_bone = 0 234 | for jid, nextjid in zip(self.ref_bone_link[:-1], self.ref_bone_link[1:]): 235 | joint_bone += np.linalg.norm(joint[nextjid] - joint[jid]) 236 | joint_root = joint[self.joint_root_idx] 237 | joint_bone = np.atleast_1d(joint_bone) 238 | 239 | '''prepare location maps L''' 240 | jointR = joint - joint_root[np.newaxis, :] # root relative 241 | jointRS = jointR / joint_bone # scale invariant 242 | # '''jointRS.shape= (21, 3) to locationmap(21,3,32,32)''' 243 | location_map = jointRS[:, :, np.newaxis, np.newaxis].repeat(32, axis=-2).repeat(32, axis=-1) 244 | 245 | '''prepare delta maps D''' 246 | 
kin_chain = [ 247 | jointRS[i] - jointRS[cfg.SNAP_PARENT[i]] 248 | for i in range(21) 249 | ] 250 | kin_chain = np.array(kin_chain) # id 0's parent is itself #21*3 251 | kin_len = np.linalg.norm( 252 | kin_chain, ord=2, axis=-1, keepdims=True # 21*1 253 | ) 254 | kin_chain[1:] = kin_chain[1:] / kin_len[1:] 255 | # '''kin_chain(21, 3) to delta_map(21,3,32,32)''' 256 | delta_map = kin_chain[:, :, np.newaxis, np.newaxis].repeat(32, axis=-2).repeat(32, axis=-1) 257 | 258 | if 'tip' in sample.keys(): 259 | joint = sample['tip'] 260 | if self.train: 261 | joint = rot_mat.dot( 262 | joint.transpose(1, 0) 263 | ).transpose() 264 | 265 | ''' prepare clr image ''' 266 | if self.train: 267 | blur_radius = random.random() * self.blur_radius 268 | clr = clr.filter(ImageFilter.GaussianBlur(blur_radius)) 269 | clr = imutils.color_jitter( 270 | clr, 271 | brightness=self.brightness, 272 | saturation=self.saturation, 273 | hue=self.hue, 274 | contrast=self.contrast, 275 | ) 276 | 277 | # Transform and crop 278 | clr = handutils.transform_img( 279 | clr, affinetrans, [self.inp_res, self.inp_res] 280 | ) 281 | clr = clr.crop((0, 0, self.inp_res, self.inp_res)) 282 | my_clr2 = clr.copy() 283 | 284 | ''' implicit HWC -> CHW, 255 -> 1 ''' 285 | clr = func.to_tensor(clr).float() 286 | ''' 0-mean, 1 std, [0,1] -> [-0.5, 0.5] ''' 287 | clr = func.normalize(clr, [0.5, 0.5, 0.5], [1, 1, 1]) 288 | 289 | # visualization 290 | if self.vis: 291 | 292 | clr1 = my_clr1.copy() 293 | 294 | fig = plt.figure(figsize=(20, 10)) 295 | plt.subplot(1, 4, 1) 296 | plt.imshow(np.asarray(clr1)) 297 | plt.title('ori_Color+2D annotations') 298 | plt.plot(kp2d_ori[0, 0], kp2d_ori[0, 1], 'ro', markersize=5) 299 | plt.text(kp2d_ori[0][0], kp2d_ori[0][1], '0', color="w", fontsize=7.5) 300 | for p in range(1, kp2d_ori.shape[0]): 301 | plt.plot(kp2d_ori[p][0], kp2d_ori[p][1], 'bo', markersize=5) 302 | plt.text(kp2d_ori[p][0], kp2d_ori[p][1], '{0}'.format(p), color="w", fontsize=5) 303 | 304 | plt.subplot(1, 4, 2) 305 
| clr2 = np.array(my_clr2.copy()) 306 | plt.imshow(clr2) 307 | plt.plot(kp2d[0, 0], kp2d[0, 1], 'ro', markersize=5) 308 | plt.text(kp2d[0][0], kp2d[0][1], '0', color="w", fontsize=7.5) 309 | for p in range(1, kp2d.shape[0]): 310 | plt.plot(kp2d[p][0], kp2d[p][1], 'bo', markersize=5) 311 | plt.text(kp2d[p][0], kp2d[p][1], '{0}'.format(p), color="w", fontsize=5) 312 | plt.title('cropped_Color+2D annotations') 313 | 314 | plt.subplot(1, 4, 3) 315 | clr3 = my_clr2.copy().resize((self.hm_res, self.hm_res), Image.ANTIALIAS) 316 | tmp = clr3.convert('L') 317 | tmp = np.array(tmp) 318 | for k in range(hm.shape[0]): 319 | tmp = tmp + hm[k] * 64 320 | plt.imshow(tmp) 321 | plt.title('heatmap') 322 | 323 | if 'joint' in sample.keys(): 324 | ax = fig.add_subplot(144, projection='3d') 325 | 326 | plt.plot(joint[:, 0], joint[:, 1], joint[:, 2], 'yo', label='keypoint') 327 | 328 | plt.plot(joint[:5, 0], joint[:5, 1], 329 | joint[:5, 2], 330 | 'r', 331 | label='thumb') 332 | 333 | plt.plot(joint[[0, 5, 6, 7, 8, ], 0], joint[[0, 5, 6, 7, 8, ], 1], 334 | joint[[0, 5, 6, 7, 8, ], 2], 335 | 'b', 336 | label='index') 337 | plt.plot(joint[[0, 9, 10, 11, 12, ], 0], joint[[0, 9, 10, 11, 12], 1], 338 | joint[[0, 9, 10, 11, 12], 2], 339 | 'b', 340 | label='middle') 341 | plt.plot(joint[[0, 13, 14, 15, 16], 0], joint[[0, 13, 14, 15, 16], 1], 342 | joint[[0, 13, 14, 15, 16], 2], 343 | 'b', 344 | label='ring') 345 | plt.plot(joint[[0, 17, 18, 19, 20], 0], joint[[0, 17, 18, 19, 20], 1], 346 | joint[[0, 17, 18, 19, 20], 2], 347 | 'b', 348 | label='pinky') 349 | # snap convention 350 | plt.plot(joint[4][0], joint[4][1], joint[4][2], 'rD', label='thumb') 351 | plt.plot(joint[8][0], joint[8][1], joint[8][2], 'r*', label='index') 352 | plt.plot(joint[12][0], joint[12][1], joint[12][2], 'rs', label='middle') 353 | plt.plot(joint[16][0], joint[16][1], joint[16][2], 'ro', label='ring') 354 | plt.plot(joint[20][0], joint[20][1], joint[20][2], 'rv', label='pinky') 355 | 356 | plt.title('3D annotations') 
357 | ax.set_xlabel('x') 358 | ax.set_ylabel('y') 359 | ax.set_zlabel('z') 360 | plt.legend() 361 | ax.view_init(-90, -90) 362 | 363 | plt.show() 364 | 365 | ## to torch tensor 366 | clr = clr 367 | hm = torch.from_numpy(hm).float() 368 | hm_veil = torch.from_numpy(hm_veil).float() 369 | joint = torch.from_numpy(joint).float() 370 | location_map = torch.from_numpy(location_map).float() 371 | delta_map = torch.from_numpy(delta_map).float() 372 | 373 | metas = { 374 | 'index': index, 375 | 'clr': clr, 376 | 'hm': hm, 377 | 'hm_veil': hm_veil, 378 | 'location_map': location_map, 379 | 'delta_map': delta_map, 380 | 'flag_3d': flag, 381 | "joint": joint 382 | } 383 | 384 | return metas 385 | 386 | def _get_sample(self, index): 387 | base = 0 388 | dataset = None 389 | for ds in self.datasets: 390 | if index < base + len(ds): 391 | sample = ds.get_sample(index - base) 392 | dataset = ds 393 | break 394 | else: 395 | base += len(ds) 396 | return sample, dataset 397 | 398 | def __len__(self): 399 | return self.total_data 400 | 401 | 402 | if __name__ == '__main__': 403 | test_set = HandDataset( 404 | data_split='test', 405 | train=False, 406 | scale_jittering=0.1, 407 | center_jettering=0.1, 408 | max_rot=0.5 * np.pi, 409 | subset_name=["rhd", "stb", "do", "eo"], 410 | data_root="/home/chen/datasets/", vis=True 411 | ) 412 | 413 | for id in tqdm(range(0, len(test_set), 10)): 414 | print("id=", id) 415 | data = test_set[id] 416 | -------------------------------------------------------------------------------- /datasets/rhd.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Lixin YANG. All Rights Reserved. 
2 | r""" 3 | Randered dataset 4 | Learning to Estimate 3D Hand joint from Single RGB Images, ICCV 2017 5 | """ 6 | 7 | import os 8 | import pickle 9 | 10 | import PIL 11 | import matplotlib.pyplot as plt 12 | import numpy as np 13 | import torch.utils.data 14 | from PIL import Image 15 | from progress.bar import Bar 16 | from termcolor import colored 17 | from tqdm import tqdm 18 | 19 | import config as cfg 20 | import utils.handutils as handutils 21 | 22 | CACHE_HOME = os.path.expanduser(cfg.DEFAULT_CACHE_DIR) 23 | 24 | snap_joint_name2id = {w: i for i, w in enumerate(cfg.snap_joint_names)} 25 | rhd_joint_name2id = {w: i for i, w in enumerate(cfg.rhd_joints)} 26 | rhd_to_snap_id = [snap_joint_name2id[joint_name] for joint_name in cfg.rhd_joints] 27 | 28 | 29 | class RHDDataset(torch.utils.data.Dataset): 30 | def __init__( 31 | self, 32 | data_root="/disk1/data/RHD/RHD_published_v2", 33 | data_split='train', 34 | hand_side='right', 35 | njoints=21, 36 | use_cache=True, 37 | visual=False 38 | ): 39 | 40 | if not os.path.exists(data_root): 41 | raise ValueError("data_root: %s not exist" % data_root) 42 | 43 | self.name = 'rhd' 44 | self.data_split = data_split 45 | self.hand_side = hand_side 46 | self.clr_paths = [] 47 | self.mask_paths = [] 48 | self.joints = [] 49 | self.kp2ds = [] 50 | self.centers = [] 51 | self.my_scales = [] 52 | self.sides = [] 53 | self.intrs = [] 54 | self.njoints = njoints # total 21 hand parts 55 | self.reslu = [320, 320] 56 | 57 | self.visual = visual 58 | 59 | self.root_id = snap_joint_name2id['loc_bn_palm_L'] # 0 60 | self.mid_mcp_id = snap_joint_name2id['loc_bn_mid_L_01'] # 9 61 | 62 | # [train|test|val|train_val|all] 63 | if data_split == 'train': 64 | self.sequence = ['training', ] 65 | elif data_split == 'test': 66 | self.sequence = ['evaluation', ] 67 | elif data_split == 'val': 68 | self.sequence = ['evaluation', ] 69 | elif data_split == 'train_val': 70 | self.sequence = ['training', ] 71 | elif data_split == 'all': 72 | 
self.sequence = ['training', 'evaluation'] 73 | else: 74 | raise ValueError("split {} not in [train|test|val|train_val|all]".format(data_split)) 75 | 76 | self.cache_folder = os.path.join(CACHE_HOME, "my-{}".format(data_split), "rhd") 77 | os.makedirs(self.cache_folder, exist_ok=True) 78 | cache_path = os.path.join( 79 | self.cache_folder, "{}.pkl".format(self.data_split) 80 | ) 81 | if os.path.exists(cache_path) and use_cache: 82 | with open(cache_path, "rb") as fid: 83 | annotations = pickle.load(fid) 84 | self.sides = annotations["sides"] 85 | self.clr_paths = annotations["clr_paths"] 86 | self.mask_paths = annotations["mask_paths"] 87 | self.joints = annotations["joints"] 88 | self.kp2ds = annotations["kp2ds"] 89 | self.intrs = annotations["intrs"] 90 | self.centers = annotations["centers"] 91 | self.my_scales = annotations["my_scales"] 92 | print("rhd {} gt loaded from {}".format(self.data_split, cache_path)) 93 | return 94 | 95 | datapath_list = [ 96 | os.path.join(data_root, seq) for seq in self.sequence 97 | ] 98 | annoname_list = [ 99 | "anno_{}.pickle".format(seq) for seq in self.sequence 100 | ] 101 | anno_list = [ 102 | os.path.join(datapath, annoname) \ 103 | for datapath, annoname in zip(datapath_list, annoname_list) 104 | ] 105 | clr_root_list = [ 106 | os.path.join(datapath, "color") for datapath in datapath_list 107 | ] 108 | dep_root_list = [ 109 | os.path.join(datapath, "depth") for datapath in datapath_list 110 | ] 111 | mask_root_list = [ 112 | os.path.join(datapath, "mask") for datapath in datapath_list 113 | ] 114 | 115 | print("init RHD {}, It will take a while at first time".format(data_split)) 116 | for anno, clr_root, dep_root, mask_root \ 117 | in zip( 118 | anno_list, 119 | clr_root_list, 120 | dep_root_list, 121 | mask_root_list 122 | ): 123 | 124 | with open(anno, 'rb') as fi: 125 | rawdatas = pickle.load(fi) 126 | fi.close() 127 | 128 | bar = Bar('RHD', max=len(rawdatas)) 129 | for i in tqdm(range(len(rawdatas))): 130 | 131 | raw = 
rawdatas[i] 132 | rawkp2d = raw['uv_vis'][:, : 2] # kp 2d left & right hand 133 | rawvis = raw['uv_vis'][:, 2] 134 | 135 | rawjoint = raw['xyz'] # x, y, z coordinates of the keypoints, in meters 136 | rawintr = raw['K'] 137 | 138 | ''' "both" means left, right''' 139 | kp2dboth = [ 140 | rawkp2d[:21][rhd_to_snap_id, :], 141 | rawkp2d[21:][rhd_to_snap_id, :] 142 | ] 143 | visboth = [ 144 | rawvis[:21][rhd_to_snap_id], 145 | rawvis[21:][rhd_to_snap_id] 146 | ] 147 | jointboth = [ 148 | rawjoint[:21][rhd_to_snap_id, :], 149 | rawjoint[21:][rhd_to_snap_id, :] 150 | ] 151 | 152 | intrboth = [rawintr, rawintr] 153 | sideboth = ['l', 'r'] 154 | 155 | l_kp_count = np.sum(raw['uv_vis'][:21, 2] == 1) 156 | r_kp_count = np.sum(raw['uv_vis'][21:, 2] == 1) 157 | vis_side = 'l' if l_kp_count > r_kp_count else 'r' 158 | 159 | for kp2d, vis, joint, side, intr \ 160 | in zip(kp2dboth, visboth, jointboth, sideboth, intrboth): 161 | if side != vis_side: 162 | continue 163 | 164 | clrpth = os.path.join(clr_root, '%.5d.png' % i) 165 | maskpth = os.path.join(mask_root, '%.5d.png' % i) 166 | self.clr_paths.append(clrpth) 167 | self.mask_paths.append(maskpth) 168 | self.sides.append(side) 169 | 170 | joint = joint[np.newaxis, :, :] 171 | self.joints.append(joint) 172 | 173 | center = handutils.get_annot_center(kp2d) 174 | kp2d = kp2d[np.newaxis, :, :] 175 | self.kp2ds.append(kp2d) 176 | 177 | center = center[np.newaxis, :] 178 | self.centers.append(center) 179 | 180 | mask = Image.open(maskpth).convert("RGB") 181 | mask = np.array(mask)[:, :, 2:] 182 | my_scale = handutils.get_ori_crop_scale(mask, side, kp2d.squeeze(0)) 183 | my_scale = (np.atleast_1d(my_scale))[np.newaxis, :] 184 | self.my_scales.append(my_scale) 185 | 186 | intr = intr[np.newaxis, :] 187 | self.intrs.append(intr) 188 | 189 | bar.suffix = ('({n}/{all}), total:{t:}s, eta:{eta:}s').format( 190 | n=i + 1, all=len(rawdatas), t=bar.elapsed_td, eta=bar.eta_td) 191 | bar.next() 192 | 193 | bar.finish() 194 | self.joints = 
np.concatenate(self.joints, axis=0).astype(np.float32) # (N, 21, 3) 195 | 196 | self.kp2ds = np.concatenate(self.kp2ds, axis=0).astype(np.float32) # (N, 21, 2) 197 | self.centers = np.concatenate(self.centers, axis=0).astype(np.float32) # (N, 1) 198 | self.my_scales = np.concatenate(self.my_scales, axis=0).astype(np.float32) # (N, 1) 199 | self.intrs = np.concatenate(self.intrs, axis=0).astype(np.float32) # (N, 3,3) 200 | 201 | if use_cache: 202 | full_info = { 203 | "sides": self.sides, 204 | "clr_paths": self.clr_paths, 205 | "mask_paths": self.mask_paths, 206 | "joints": self.joints, 207 | "kp2ds": self.kp2ds, 208 | "intrs": self.intrs, 209 | "centers": self.centers, 210 | "my_scales": self.my_scales, 211 | } 212 | with open(cache_path, "wb") as fid: 213 | pickle.dump(full_info, fid) 214 | print("Wrote cache for dataset rhd {} to {}".format( 215 | self.data_split, cache_path 216 | )) 217 | return 218 | 219 | def get_sample(self, index): 220 | side = self.sides[index] 221 | """ 'r' in 'left' / 'l' in 'right' """ 222 | flip = True if (side not in self.hand_side) else False 223 | 224 | clr = Image.open(self.clr_paths[index]).convert("RGB") 225 | self._is_valid(clr, index) 226 | mask = Image.open(self.mask_paths[index]).convert("RGB") 227 | self._is_valid(mask, index) 228 | 229 | # prepare jont 230 | joint = self.joints[index].copy() 231 | 232 | # prepare kp2d 233 | kp2d = self.kp2ds[index].copy() 234 | 235 | center = self.centers[index].copy() 236 | # scale = self.scales[index].copy() 237 | 238 | my_scale = self.my_scales[index].copy() 239 | 240 | if flip: 241 | clr = clr.transpose(Image.FLIP_LEFT_RIGHT) 242 | center[0] = clr.size[0] - center[0] # clr.size[0] represents width of image 243 | kp2d[:, 0] = clr.size[0] - kp2d[:, 0] 244 | joint[:, 0] = -joint[:, 0] 245 | 246 | sample = { 247 | 'index': index, 248 | 'clr': clr, 249 | 'kp2d': kp2d, 250 | 'center': center, 251 | 'my_scale': my_scale, 252 | 'joint': joint, 253 | 'intr': self.intrs[index], 254 | } 255 | 256 
| if self.visual: 257 | fig = plt.figure(figsize=(20, 20)) 258 | plt.subplot(1, 3, 1) 259 | plt.imshow(clr.copy()) 260 | plt.title('Color') 261 | 262 | plt.subplot(1, 3, 2) 263 | plt.imshow(clr.copy()) 264 | plt.plot(kp2d[:, :1], kp2d[:, 1:], 'ro') 265 | plt.title('Color+2D annotations') 266 | 267 | ax = fig.add_subplot(133, projection='3d') 268 | plt.plot(joint[:, 0], joint[:, 1], joint[:, 2], 'yo', label='keypoint') 269 | plt.plot(joint[:5, 0], joint[:5, 1], 270 | joint[:5, 2], 271 | 'r', 272 | label='thumb') 273 | plt.plot(joint[[0, 5, 6, 7, 8, ], 0], joint[[0, 5, 6, 7, 8, ], 1], 274 | joint[[0, 5, 6, 7, 8, ], 2], 275 | 'b', 276 | label='index') 277 | plt.plot(joint[[0, 9, 10, 11, 12, ], 0], joint[[0, 9, 10, 11, 12], 1], 278 | joint[[0, 9, 10, 11, 12], 2], 279 | 'b', 280 | label='middle') 281 | plt.plot(joint[[0, 13, 14, 15, 16], 0], joint[[0, 13, 14, 15, 16], 1], 282 | joint[[0, 13, 14, 15, 16], 2], 283 | 'b', 284 | label='ring') 285 | plt.plot(joint[[0, 17, 18, 19, 20], 0], joint[[0, 17, 18, 19, 20], 1], 286 | joint[[0, 17, 18, 19, 20], 2], 287 | 'b', 288 | label='pinky') 289 | # snap convention 290 | plt.plot(joint[4][0], joint[4][1], joint[4][2], 'rD', label='thumb') 291 | plt.plot(joint[8][0], joint[8][1], joint[8][2], 'ro', label='index') 292 | plt.plot(joint[12][0], joint[12][1], joint[12][2], 'ro', label='middle') 293 | plt.plot(joint[16][0], joint[16][1], joint[16][2], 'ro', label='ring') 294 | plt.plot(joint[20][0], joint[20][1], joint[20][2], 'ro', label='pinky') 295 | # plt.plot(joint [1:, 0], joint [1:, 1], joint [1:, 2], 'o') 296 | 297 | plt.title('3D annotations') 298 | ax.set_xlabel('x') 299 | ax.set_ylabel('y') 300 | ax.set_zlabel('z') 301 | plt.legend() 302 | ax.view_init(-90, -90) 303 | plt.show() 304 | 305 | return sample 306 | 307 | def _apply_mask(self, dep, mask, side): 308 | ''' follow the label rules in RHD datasets ''' 309 | if side is 'l': 310 | valid_mask_id = [i for i in range(2, 18)] 311 | else: 312 | valid_mask_id = [i for i in 
range(18, 34)] 313 | 314 | mask = np.array(mask)[:, :, 2:] 315 | dep = np.array(dep) 316 | ll = valid_mask_id[0] 317 | uu = valid_mask_id[-1] 318 | mask[mask < ll] = 0 319 | mask[mask > uu] = 0 320 | mask[mask > 0] = 1 321 | if mask.dtype != np.uint8: 322 | mask = mask.astype(np.uint8) 323 | dep = np.multiply(dep, mask) 324 | dep = Image.fromarray(dep, mode="RGB") 325 | return dep 326 | 327 | def __len__(self): 328 | return len(self.clr_paths) 329 | 330 | def __str__(self): 331 | info = "RHD {} set. lenth {}".format( 332 | self.data_split, len(self.clr_paths) 333 | ) 334 | return colored(info, 'yellow', attrs=['bold']) 335 | 336 | def norm_dep_img(self, dep_): 337 | """RHD depthmap to depth image 338 | """ 339 | if isinstance(dep_, PIL.Image.Image): 340 | dep_ = np.array(dep_) 341 | assert (dep_.shape[-1] == 3) # used to be "RGB" 342 | 343 | ''' Converts a RGB-coded depth into float valued depth. ''' 344 | dep = (dep_[:, :, 0] * 2 ** 8 + dep_[:, :, 1]).astype('float32') 345 | dep /= float(2 ** 16 - 1) 346 | dep *= 5.0 ## depth in meter ! 
347 | 348 | return dep 349 | 350 | def _is_valid(self, img, index): 351 | valid_data = isinstance(img, (np.ndarray, PIL.Image.Image)) 352 | if not valid_data: 353 | raise Exception("Encountered error processing rhd[{}]".format(index)) 354 | return valid_data 355 | 356 | 357 | def main(): 358 | data_split = 'test' 359 | rhd = RHDDataset( 360 | data_root="/home/chen/datasets/RHD/RHD_published_v2", 361 | data_split=data_split, 362 | hand_side='right', 363 | njoints=21, 364 | use_cache=False, 365 | visual=True 366 | ) 367 | print("len(rhd)=", len(rhd)) 368 | 369 | for i in tqdm(range(len(rhd))): 370 | print("id=", id) 371 | data = rhd.get_sample(i) 372 | 373 | 374 | if __name__ == "__main__": 375 | main() 376 | -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import torch 3 | from manopth import manolayer 4 | from model.detnet import detnet 5 | from utils import func, bone, AIK, smoother 6 | import numpy as np 7 | import matplotlib.pyplot as plt 8 | from utils import vis 9 | from op_pso import PSO 10 | import open3d 11 | 12 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 13 | _mano_root = 'mano/models' 14 | 15 | module = detnet().to(device) 16 | print('load model start') 17 | check_point = torch.load('new_check_point/ckp_detnet_83.pth', map_location=device) 18 | model_state = module.state_dict() 19 | state = {} 20 | for k, v in check_point.items(): 21 | if k in model_state: 22 | state[k] = v 23 | else: 24 | print(k, ' is NOT in current model') 25 | model_state.update(state) 26 | module.load_state_dict(model_state) 27 | print('load model finished') 28 | pose, shape = func.initiate("zero") 29 | pre_useful_bone_len = np.zeros((1, 15)) 30 | pose0 = torch.eye(3).repeat(1, 16, 1, 1) 31 | 32 | mano = manolayer.ManoLayer(flat_hand_mean=True, 33 | side="right", 34 | mano_root=_mano_root, 35 | use_pca=False, 36 | 
root_rot_mode='rotmat', 37 | joint_rot_mode='rotmat') 38 | print('start opencv') 39 | point_fliter = smoother.OneEuroFilter(4.0, 0.0) 40 | mesh_fliter = smoother.OneEuroFilter(4.0, 0.0) 41 | shape_fliter = smoother.OneEuroFilter(4.0, 0.0) 42 | cap = cv2.VideoCapture(0) 43 | print('opencv finished') 44 | flag = 1 45 | plt.ion() 46 | f = plt.figure() 47 | 48 | fliter_ax = f.add_subplot(111, projection='3d') 49 | plt.show() 50 | view_mat = np.array([[1.0, 0.0, 0.0], 51 | [0.0, -1.0, 0], 52 | [0.0, 0, -1.0]]) 53 | mesh = open3d.geometry.TriangleMesh() 54 | hand_verts, j3d_recon = mano(pose0, shape.float()) 55 | mesh.triangles = open3d.utility.Vector3iVector(mano.th_faces) 56 | hand_verts = hand_verts.clone().detach().cpu().numpy()[0] 57 | mesh.vertices = open3d.utility.Vector3dVector(hand_verts) 58 | viewer = open3d.visualization.Visualizer() 59 | viewer.create_window(width=480, height=480, window_name='mesh') 60 | viewer.add_geometry(mesh) 61 | viewer.update_renderer() 62 | 63 | print('start pose estimate') 64 | 65 | pre_uv = None 66 | shape_time = 0 67 | opt_shape = None 68 | shape_flag = True 69 | while (cap.isOpened()): 70 | ret_flag, img = cap.read() 71 | input = np.flip(img.copy(), -1) 72 | k = cv2.waitKey(1) & 0xFF 73 | if input.shape[0] > input.shape[1]: 74 | margin = (input.shape[0] - input.shape[1]) // 2 75 | input = input[margin:-margin] 76 | else: 77 | margin = (input.shape[1] - input.shape[0]) // 2 78 | input = input[:, margin:-margin] 79 | img = input.copy() 80 | img = np.flip(img, -1) 81 | cv2.imshow("Capture_Test", img) 82 | input = cv2.resize(input, (128, 128)) 83 | input = torch.tensor(input.transpose([2, 0, 1]), dtype=torch.float, device=device) # hwc -> chw 84 | input = func.normalize(input, [0.5, 0.5, 0.5], [1, 1, 1]) 85 | result = module(input.unsqueeze(0)) 86 | 87 | pre_joints = result['xyz'].squeeze(0) 88 | now_uv = result['uv'].clone().detach().cpu().numpy()[0, 0] 89 | now_uv = now_uv.astype(np.float) 90 | trans = np.zeros((1, 3)) 91 | trans[0, 
0:2] = now_uv - 16.0 92 | trans = trans / 16.0 93 | new_tran = np.array([[trans[0, 1], trans[0, 0], trans[0, 2]]]) 94 | pre_joints = pre_joints.clone().detach().cpu().numpy() 95 | 96 | flited_joints = point_fliter.process(pre_joints) 97 | 98 | fliter_ax.cla() 99 | 100 | filted_ax = vis.plot3d(flited_joints + new_tran, fliter_ax) 101 | pre_useful_bone_len = bone.caculate_length(pre_joints, label="useful") 102 | 103 | NGEN = 100 104 | popsize = 100 105 | low = np.zeros((1, 10)) - 3.0 106 | up = np.zeros((1, 10)) + 3.0 107 | parameters = [NGEN, popsize, low, up] 108 | pso = PSO(parameters, pre_useful_bone_len.reshape((1, 15)),_mano_root) 109 | pso.main() 110 | opt_shape = pso.ng_best 111 | opt_shape = shape_fliter.process(opt_shape) 112 | 113 | opt_tensor_shape = torch.tensor(opt_shape, dtype=torch.float) 114 | _, j3d_p0_ops = mano(pose0, opt_tensor_shape) 115 | template = j3d_p0_ops.cpu().numpy().squeeze(0) / 1000.0 # template, m 21*3 116 | ratio = np.linalg.norm(template[9] - template[0]) / np.linalg.norm(pre_joints[9] - pre_joints[0]) 117 | j3d_pre_process = pre_joints * ratio # template, m 118 | j3d_pre_process = j3d_pre_process - j3d_pre_process[0] + template[0] 119 | pose_R = AIK.adaptive_IK(template, j3d_pre_process) 120 | pose_R = torch.from_numpy(pose_R).float() 121 | # reconstruction 122 | hand_verts, j3d_recon = mano(pose_R, opt_tensor_shape.float()) 123 | mesh.triangles = open3d.utility.Vector3iVector(mano.th_faces) 124 | hand_verts = hand_verts.clone().detach().cpu().numpy()[0] 125 | hand_verts = mesh_fliter.process(hand_verts) 126 | hand_verts = np.matmul(view_mat, hand_verts.T).T 127 | hand_verts[:, 0] = hand_verts[:, 0] - 50 128 | hand_verts[:, 1] = hand_verts[:, 1] - 50 129 | mesh_tran = np.array([[-new_tran[0, 0], new_tran[0, 1], new_tran[0, 2]]]) 130 | hand_verts = hand_verts - 100 * mesh_tran 131 | 132 | mesh.vertices = open3d.utility.Vector3dVector(hand_verts) 133 | mesh.paint_uniform_color([228 / 255, 178 / 255, 148 / 255]) 134 | 
mesh.compute_triangle_normals() 135 | mesh.compute_vertex_normals() 136 | viewer.update_geometry(mesh) 137 | viewer.poll_events() 138 | if k == ord('q'): 139 | break 140 | cap.release() 141 | cv2.destroyAllWindows() 142 | -------------------------------------------------------------------------------- /demo_dl.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import torch 3 | from manopth import manolayer 4 | from model.detnet import detnet 5 | from utils import func, bone, AIK, smoother 6 | import numpy as np 7 | import matplotlib.pyplot as plt 8 | from utils import vis 9 | from op_pso import PSO 10 | import open3d 11 | from model import shape_net 12 | import os 13 | 14 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 15 | _mano_root = 'mano/models' 16 | 17 | module = detnet().to(device) 18 | print('load model start') 19 | check_point = torch.load('new_check_point/ckp_detnet_83.pth', map_location=device) 20 | model_state = module.state_dict() 21 | state = {} 22 | for k, v in check_point.items(): 23 | if k in model_state: 24 | state[k] = v 25 | else: 26 | print(k, ' is NOT in current model') 27 | model_state.update(state) 28 | module.load_state_dict(model_state) 29 | print('load model finished') 30 | 31 | shape_model = shape_net.ShapeNet() 32 | shape_net.load_checkpoint( 33 | shape_model, os.path.join('checkpoints', 'ckp_siknet_synth_41.pth.tar') 34 | ) 35 | for params in shape_model.parameters(): 36 | params.requires_grad = False 37 | 38 | pose, shape = func.initiate("zero") 39 | pre_useful_bone_len = np.zeros((1, 15)) 40 | pose0 = torch.eye(3).repeat(1, 16, 1, 1) 41 | 42 | mano = manolayer.ManoLayer(flat_hand_mean=True, 43 | side="right", 44 | mano_root=_mano_root, 45 | use_pca=False, 46 | root_rot_mode='rotmat', 47 | joint_rot_mode='rotmat') 48 | print('start opencv') 49 | point_fliter = smoother.OneEuroFilter(4.0, 0.0) 50 | mesh_fliter = smoother.OneEuroFilter(4.0, 0.0) 51 | 
shape_fliter = smoother.OneEuroFilter(4.0, 0.0) 52 | cap = cv2.VideoCapture(0) 53 | print('opencv finished') 54 | flag = 1 55 | plt.ion() 56 | f = plt.figure() 57 | 58 | fliter_ax = f.add_subplot(111, projection='3d') 59 | plt.show() 60 | view_mat = np.array([[1.0, 0.0, 0.0], 61 | [0.0, -1.0, 0], 62 | [0.0, 0, -1.0]]) 63 | mesh = open3d.geometry.TriangleMesh() 64 | hand_verts, j3d_recon = mano(pose0, shape.float()) 65 | mesh.triangles = open3d.utility.Vector3iVector(mano.th_faces) 66 | hand_verts = hand_verts.clone().detach().cpu().numpy()[0] 67 | mesh.vertices = open3d.utility.Vector3dVector(hand_verts) 68 | viewer = open3d.visualization.Visualizer() 69 | viewer.create_window(width=480, height=480, window_name='mesh') 70 | viewer.add_geometry(mesh) 71 | viewer.update_renderer() 72 | 73 | print('start pose estimate') 74 | 75 | pre_uv = None 76 | shape_time = 0 77 | opt_shape = None 78 | shape_flag = True 79 | while (cap.isOpened()): 80 | ret_flag, img = cap.read() 81 | input = np.flip(img.copy(), -1) 82 | k = cv2.waitKey(1) & 0xFF 83 | if input.shape[0] > input.shape[1]: 84 | margin = (input.shape[0] - input.shape[1]) // 2 85 | input = input[margin:-margin] 86 | else: 87 | margin = (input.shape[1] - input.shape[0]) // 2 88 | input = input[:, margin:-margin] 89 | img = input.copy() 90 | img = np.flip(img, -1) 91 | cv2.imshow("Capture_Test", img) 92 | input = cv2.resize(input, (128, 128)) 93 | input = torch.tensor(input.transpose([2, 0, 1]), dtype=torch.float, device=device) # hwc -> chw 94 | input = func.normalize(input, [0.5, 0.5, 0.5], [1, 1, 1]) 95 | result = module(input.unsqueeze(0)) 96 | 97 | pre_joints = result['xyz'].squeeze(0) 98 | now_uv = result['uv'].clone().detach().cpu().numpy()[0, 0] 99 | now_uv = now_uv.astype(np.float) 100 | trans = np.zeros((1, 3)) 101 | trans[0, 0:2] = now_uv - 16.0 102 | trans = trans / 16.0 103 | new_tran = np.array([[trans[0, 1], trans[0, 0], trans[0, 2]]]) 104 | pre_joints = pre_joints.clone().detach().cpu().numpy() 105 | 106 
| flited_joints = point_fliter.process(pre_joints) 107 | 108 | fliter_ax.cla() 109 | 110 | filted_ax = vis.plot3d(flited_joints + new_tran, fliter_ax) 111 | pre_useful_bone_len = bone.caculate_length(pre_joints, label="useful") 112 | 113 | shape_model_input = torch.tensor(pre_useful_bone_len, dtype=torch.float) 114 | shape_model_input = shape_model_input.reshape((1, 15)) 115 | dl_shape = shape_model(shape_model_input) 116 | dl_shape = dl_shape['beta'].numpy() 117 | dl_shape = shape_fliter.process(dl_shape) 118 | opt_tensor_shape = torch.tensor(dl_shape, dtype=torch.float) 119 | _, j3d_p0_ops = mano(pose0, opt_tensor_shape) 120 | template = j3d_p0_ops.cpu().numpy().squeeze(0) / 1000.0 # template, m 21*3 121 | ratio = np.linalg.norm(template[9] - template[0]) / np.linalg.norm(pre_joints[9] - pre_joints[0]) 122 | j3d_pre_process = pre_joints * ratio # template, m 123 | j3d_pre_process = j3d_pre_process - j3d_pre_process[0] + template[0] 124 | pose_R = AIK.adaptive_IK(template, j3d_pre_process) 125 | pose_R = torch.from_numpy(pose_R).float() 126 | # reconstruction 127 | hand_verts, j3d_recon = mano(pose_R, opt_tensor_shape.float()) 128 | mesh.triangles = open3d.utility.Vector3iVector(mano.th_faces) 129 | hand_verts = hand_verts.clone().detach().cpu().numpy()[0] 130 | hand_verts = mesh_fliter.process(hand_verts) 131 | hand_verts = np.matmul(view_mat, hand_verts.T).T 132 | hand_verts[:, 0] = hand_verts[:, 0] - 50 133 | hand_verts[:, 1] = hand_verts[:, 1] - 50 134 | mesh_tran = np.array([[-new_tran[0, 0], new_tran[0, 1], new_tran[0, 2]]]) 135 | hand_verts = hand_verts - 100 * mesh_tran 136 | 137 | mesh.vertices = open3d.utility.Vector3dVector(hand_verts) 138 | mesh.paint_uniform_color([228 / 255, 178 / 255, 148 / 255]) 139 | mesh.compute_triangle_normals() 140 | mesh.compute_vertex_normals() 141 | viewer.update_geometry(mesh) 142 | viewer.poll_events() 143 | if k == ord('q'): 144 | break 145 | cap.release() 146 | cv2.destroyAllWindows() 147 | 
-------------------------------------------------------------------------------- /dl_shape_estimate.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | 5 | import create_data 6 | from model import shape_net 7 | 8 | import numpy as np 9 | 10 | 11 | 12 | def align_bone_len(opt_, pre_): 13 | opt = opt_.copy() 14 | pre = pre_.copy() 15 | 16 | opt_align = opt.copy() 17 | for i in range(opt.shape[0]): 18 | ratio = pre[i][6] / opt[i][6] 19 | opt_align[i] = ratio * opt_align[i] 20 | 21 | err = np.abs(opt_align - pre).mean(0) 22 | 23 | return err 24 | 25 | def fun(_shape, _label, data_loader): 26 | # 计算相对骨骼长度 27 | shape = _shape.clone().detach() 28 | label = _label.detach().clone() 29 | # 根据shape计算相对骨骼长度 30 | X = data_loader.new_cal_ref_bone(shape) 31 | err = align_bone_len(X.cpu().numpy(), label.cpu().numpy()) 32 | return err.sum() 33 | 34 | checkpoint = 'checkpoints' 35 | 36 | model = shape_net.ShapeNet() 37 | shape_net.load_checkpoint( 38 | model, os.path.join(checkpoint, 'ckp_siknet_synth_41.pth.tar') 39 | ) 40 | for params in model.parameters(): 41 | params.requires_grad = False 42 | 43 | data_set = ['rhd', 'stb', 'do', 'eo'] 44 | temp_data = create_data.DataSet(_mano_root='mano/models') 45 | for data in data_set: 46 | print('*' * 20) 47 | print('加载' + data + '数据集') 48 | print('*' * 20) 49 | # 加载预测 50 | pre_path = os.path.join('out_testset/', data + '_pre_joints.npy') 51 | temp = np.load(pre_path) 52 | temp = torch.Tensor(temp) 53 | _x = temp_data.cal_ref_bone(temp) 54 | # 模型回归shape 55 | Y = model(_x) 56 | Y = Y['beta'] 57 | np.save('out_testset/' + data + '_dl.npy', Y.clone().detach().cpu().numpy()) 58 | dl_err = fun(Y, _x, temp_data) 59 | print('回归误差:{}'.format(dl_err)) 60 | 61 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: minimal-hand-torch 2 | 
channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1=main 7 | - blas=1.0=mkl 8 | - bzip2=1.0.8=h7b6447c_0 9 | - ca-certificates=2021.1.19=h06a4308_0 10 | - cairo=1.14.12=h8948797_3 11 | - certifi=2020.12.5=py37h06a4308_0 12 | - cffi=1.14.5=py37h261ae71_0 13 | - cudatoolkit=10.0.130=0 14 | - cycler=0.10.0=py37_0 15 | - dbus=1.13.18=hb2f20db_0 16 | - expat=2.2.10=he6710b0_2 17 | - ffmpeg=4.0=hcdf2ecd_0 18 | - fontconfig=2.13.1=h6c09931_0 19 | - freeglut=3.0.0=hf484d3e_5 20 | - freetype=2.10.4=h5ab3b9f_0 21 | - glib=2.67.4=h36276a3_1 22 | - graphite2=1.3.14=h23475e2_0 23 | - gst-plugins-base=1.14.0=h8213a91_2 24 | - gstreamer=1.14.0=h28cd5cc_2 25 | - harfbuzz=1.8.8=hffaf4a1_0 26 | - hdf5=1.10.2=hba1933b_1 27 | - icu=58.2=he6710b0_3 28 | - intel-openmp=2020.2=254 29 | - jasper=2.0.14=h07fcdf6_1 30 | - jpeg=9b=h024ee3a_2 31 | - kiwisolver=1.3.1=py37h2531618_0 32 | - lcms2=2.11=h396b838_0 33 | - ld_impl_linux-64=2.33.1=h53a641e_7 34 | - libedit=3.1.20191231=h14c3975_1 35 | - libffi=3.3=he6710b0_2 36 | - libgcc-ng=9.1.0=hdf63c60_0 37 | - libgfortran-ng=7.3.0=hdf63c60_0 38 | - libglu=9.0.0=hf484d3e_1 39 | - libopencv=3.4.2=hb342d67_1 40 | - libopus=1.3.1=h7b6447c_0 41 | - libpng=1.6.37=hbc83047_0 42 | - libstdcxx-ng=9.1.0=hdf63c60_0 43 | - libtiff=4.1.0=h2733197_1 44 | - libuuid=1.0.3=h1bed415_2 45 | - libuv=1.40.0=h7b6447c_0 46 | - libvpx=1.7.0=h439df22_0 47 | - libxcb=1.14=h7b6447c_0 48 | - libxml2=2.9.10=hb55368b_3 49 | - lz4-c=1.9.3=h2531618_0 50 | - matplotlib=3.3.4=py37h06a4308_0 51 | - matplotlib-base=3.3.4=py37h62a2d02_0 52 | - mkl=2020.2=256 53 | - mkl-service=2.3.0=py37he8ac12f_0 54 | - mkl_fft=1.3.0=py37h54f3939_0 55 | - mkl_random=1.1.1=py37h0573a6f_0 56 | - ncurses=6.2=he6710b0_1 57 | - ninja=1.10.2=py37hff7bd54_0 58 | - numpy=1.19.2=py37h54aff64_0 59 | - numpy-base=1.19.2=py37hfa32c7d_0 60 | - olefile=0.46=py_0 61 | - opencv=3.4.2=py37h6fd60c2_1 62 | - openssl=1.1.1j=h27cfd23_0 63 | - pandas=1.2.2=py37ha9443f7_0 64 | - 
pcre=8.44=he6710b0_0 65 | - pillow=8.1.0=py37he98fc37_0 66 | - pip=21.0.1=py37h06a4308_0 67 | - pixman=0.40.0=h7b6447c_0 68 | - py-opencv=3.4.2=py37hb342d67_1 69 | - pycparser=2.20=py_2 70 | - pyparsing=2.4.7=pyhd3eb1b0_0 71 | - pyqt=5.9.2=py37h05f1152_2 72 | - python=3.7.10=hdb3f193_0 73 | - python-dateutil=2.8.1=pyhd3eb1b0_0 74 | - pytz=2021.1=pyhd3eb1b0_0 75 | - qt=5.9.7=h5867ecd_1 76 | - readline=8.1=h27cfd23_0 77 | - scipy=1.6.1=py37h91f5cce_0 78 | - setuptools=52.0.0=py37h06a4308_0 79 | - sip=4.19.8=py37hf484d3e_0 80 | - six=1.15.0=pyhd3eb1b0_0 81 | - sqlite=3.33.0=h62c20be_0 82 | - termcolor=1.1.0=py37h06a4308_1 83 | - tk=8.6.10=hbc83047_0 84 | - tornado=6.1=py37h27cfd23_0 85 | - tqdm=4.56.0=pyhd3eb1b0_0 86 | - typing_extensions=3.7.4.3=pyha847dfd_0 87 | - wheel=0.36.2=pyhd3eb1b0_0 88 | - xz=5.2.5=h7b6447c_0 89 | - zlib=1.2.11=h7b6447c_3 90 | - zstd=1.4.5=h9ceee32_0 91 | - pytorch=1.2.0=py3.7_cuda10.0.130_cudnn7.6.2_0 92 | - torchvision=0.4.0=py37_cu100 93 | - pip: 94 | - art==3.7 95 | - chumpy==0.70 96 | - coverage==4.5.3 97 | - einops==0.3.0 98 | - manopth==0.0.1 99 | - mxnet==1.6.0 100 | - progress==1.5 101 | - torch==1.2.0 102 | - transforms3d==0.3.1 103 | - git+https://github.com/hassony2/manopth.git 104 | 105 | -------------------------------------------------------------------------------- /losses/__init__.py: -------------------------------------------------------------------------------- 1 | from .detloss import * 2 | -------------------------------------------------------------------------------- /losses/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/handopen/Minimal-Hand-pytorch/c078433136b20c1e4fb3f353254edd409ea65f1b/losses/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /losses/__pycache__/detloss.cpython-37.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/handopen/Minimal-Hand-pytorch/c078433136b20c1e4fb3f353254edd409ea65f1b/losses/__pycache__/detloss.cpython-37.pyc -------------------------------------------------------------------------------- /losses/detloss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as torch_f 3 | 4 | # Weighted sum of the heatmap ('hm'), 'dm' and 'lm' losses; the dm/lm terms only use samples whose targs['flag_3d'] is set. 5 | class DetLoss: 6 | def __init__( 7 | self, 8 | lambda_hm=100, 9 | lambda_dm=1.0, 10 | lambda_lm=1.0, 11 | 12 | ): 13 | self.lambda_hm = lambda_hm 14 | self.lambda_dm = lambda_dm 15 | self.lambda_lm = lambda_lm 16 | 17 | def compute_loss(self, preds, targs, infos): 18 | # Returns (final_loss, dict of per-term losses, number of 3D-labelled samples in the batch). 19 | hm_veil = infos['hm_veil'] 20 | batch_size = infos['batch_size'] 21 | flag = targs['flag_3d'] 22 | batch_3d_size = flag.sum() 23 | 24 | flag = flag.bool() 25 | 26 | final_loss = torch.Tensor([0]).cuda() 27 | det_losses = {} 28 | 29 | pred_hm = preds['h_map'] 30 | pred_dm = preds['d_map'][flag] 31 | pred_lm = preds['l_map'][flag] 32 | 33 | targ_hm = targs['hm'] # B*21*32*32 34 | 35 | targ_hm_tile = \ 36 | targ_hm.unsqueeze(2).expand(targ_hm.size(0), targ_hm.size(1), 3, targ_hm.size(2), targ_hm.size(3), 37 | )[flag] # B'*21*3*32*32 38 | targ_dm = targs['dm'][flag] 39 | targ_lm = targs['lm'][flag] 40 | 41 | # compute hmloss anyway 42 | hm_loss = torch.Tensor([0]).cuda() 43 | if self.lambda_hm: 44 | hm_veil = hm_veil.unsqueeze(-1) 45 | njoints = pred_hm.size(1) 46 | pred_hm = pred_hm.reshape((batch_size, njoints, -1)).split(1, 1) 47 | targ_hm = targ_hm.reshape((batch_size, njoints, -1)).split(1, 1) 48 | for idx in range(njoints): 49 | pred_hmapi = pred_hm[idx].squeeze() # (B, 1, 1024)->(B, 1024) 50 | targ_hmi = targ_hm[idx].squeeze() 51 | hm_loss += 0.5 * torch_f.mse_loss( 52 | pred_hmapi.mul(hm_veil[:, idx]), # (B, 1024) mul (B, 1) 53 | targ_hmi.mul(hm_veil[:, idx]) 54 | ) # mse_loss (mean reduction) averages over the batch AND the 32*32 pixels, i.e. it
returns minibatch_loss/(B*32*32), not a per-sample loss 55 | final_loss += self.lambda_hm * hm_loss 56 | det_losses["det_hm"] = hm_loss 57 | 58 | # compute dm loss 59 | loss_dm = torch.Tensor([0]).cuda() 60 | if self.lambda_dm: 61 | loss_dm = torch.norm( 62 | (pred_dm - targ_dm) * targ_hm_tile) / batch_3d_size # norm of the heatmap-weighted residual over all 3D samples, divided by their count 63 | final_loss += self.lambda_dm * loss_dm 64 | det_losses["det_dm"] = loss_dm 65 | 66 | # compute lm loss 67 | loss_lm = torch.Tensor([0]).cuda() 68 | if self.lambda_lm: 69 | loss_lm = torch.norm( 70 | (pred_lm - targ_lm) * targ_hm_tile) / batch_3d_size # norm of the heatmap-weighted residual over all 3D samples, divided by their count 71 | final_loss += self.lambda_lm * loss_lm 72 | det_losses["det_lm"] = loss_lm 73 | 74 | det_losses["det_total"] = final_loss 75 | 76 | return final_loss, det_losses, batch_3d_size 77 | -------------------------------------------------------------------------------- /losses/shape_loss.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | import torch.nn.functional as torch_f 5 | 6 | 7 | # ShapeNet loss: optional joint MSE term plus (squared bone-length error + 1e-3 * ||beta||^2) shape term. 8 | class SIKLoss: 9 | def __init__( 10 | self, 11 | lambda_joint=1.0, 12 | lambda_shape=1.0 13 | ): 14 | self.lambda_joint = lambda_joint 15 | self.lambda_shape = lambda_shape 16 | 17 | def compute_loss(self, preds, targs): 18 | batch_size = targs['batch_size'] 19 | final_loss = torch.Tensor([0]).cuda() 20 | invk_losses = {} 21 | 22 | if self.lambda_joint: 23 | joint_loss = torch_f.mse_loss( 24 | 1000 * preds['jointRS'] * targs['joint_bone'].unsqueeze(1), 25 | 1000 * targs['jointRS'] * targs['joint_bone'].unsqueeze(1) 26 | ) 27 | final_loss += self.lambda_joint * joint_loss 28 | else: 29 | joint_loss = None 30 | invk_losses["joint"] = joint_loss 31 | 32 | if self.lambda_shape: 33 | # shape_reg_loss = 10.0 * torch_f.mse_loss( 34 | # preds["beta"], 35 | # torch.zeros_like(preds["beta"]) 36 | # ) 37 | shape_reg_loss = torch.norm(preds['beta'], dim=-1, keepdim=True) 38 | shape_reg_loss =
torch.pow(shape_reg_loss, 2.0) 39 | shape_reg_loss = torch.mean(shape_reg_loss) 40 | # shape_reg_loss is now the batch mean of the squared L2 norm of beta 41 | pred_rel_len = preds['bone_len_hat'] 42 | 43 | # kin_len_loss = torch_f.mse_loss( 44 | # pred_rel_len.reshape(batch_size, -1), 45 | # targs['rel_bone_len'].reshape(batch_size, -1) 46 | # ) 47 | kin_len_loss = torch.norm(pred_rel_len - 48 | targs['rel_bone_len'].reshape(batch_size, -1), 49 | dim=-1, keepdim=True) 50 | kin_len_loss = torch.pow(kin_len_loss, 2.0) 51 | kin_len_loss = torch.mean(kin_len_loss) 52 | shape_total_loss = kin_len_loss + 1e-3 * shape_reg_loss 53 | final_loss += self.lambda_shape * shape_total_loss 54 | else: 55 | shape_reg_loss, kin_len_loss = None, None 56 | invk_losses['shape_reg'] = shape_reg_loss 57 | invk_losses['bone_len'] = kin_len_loss 58 | 59 | return final_loss, invk_losses 60 | -------------------------------------------------------------------------------- /manopth/rotproj.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | # Orthogonalise every matrix via SVD (U @ V^T); if the result is a reflection (det < 0), negate its third column. 4 | def batch_rotprojs(batches_rotmats): 5 | proj_rotmats = [] 6 | for batch_idx, batch_rotmats in enumerate(batches_rotmats): 7 | proj_batch_rotmats = [] 8 | for rot_idx, rotmat in enumerate(batch_rotmats): 9 | # GPU implementation of svd is VERY slow 10 | # ~ 2 10^-3 per hit vs 5 10^-5 on cpu 11 | _device = rotmat.device 12 | U, S, V = rotmat.cpu().svd() 13 | rotmat = torch.matmul(U, V.transpose(0, 1)) 14 | orth_det = rotmat.det() 15 | # Remove reflection 16 | if orth_det < 0: 17 | rotmat[:, 2] = -1 * rotmat[:, 2] 18 | 19 | rotmat = rotmat.to(_device) 20 | proj_batch_rotmats.append(rotmat) 21 | proj_rotmats.append(torch.stack(proj_batch_rotmats)) 22 | return torch.stack(proj_rotmats) 23 | -------------------------------------------------------------------------------- /model/detnet/__init__.py: -------------------------------------------------------------------------------- 1 | from .detnet import detnet 2 | __all__ =['detnet']
-------------------------------------------------------------------------------- /model/detnet/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/handopen/Minimal-Hand-pytorch/c078433136b20c1e4fb3f353254edd409ea65f1b/model/detnet/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /model/detnet/__pycache__/detnet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/handopen/Minimal-Hand-pytorch/c078433136b20c1e4fb3f353254edd409ea65f1b/model/detnet/__pycache__/detnet.cpython-37.pyc -------------------------------------------------------------------------------- /model/detnet/detnet.py: -------------------------------------------------------------------------------- 1 | ''' 2 | detnet based on PyTorch 3 | this is modified from https://github.com/lingtengqiu/Minimal-Hand 4 | ''' 5 | import sys 6 | 7 | import torch 8 | 9 | sys.path.append("./") 10 | from torch import nn 11 | from einops import rearrange, repeat 12 | from model.helper import resnet50, conv3x3 13 | import numpy as np 14 | 15 | 16 | # my modification 17 | def get_pose_tile_torch(N): 18 | pos_tile = np.expand_dims( 19 | np.stack( 20 | [ 21 | np.tile(np.linspace(-1, 1, 32).reshape([1, 32]), [32, 1]), 22 | np.tile(np.linspace(-1, 1, 32).reshape([32, 1]), [1, 32]) 23 | ], -1 24 | ), 0 25 | ) 26 | pos_tile = np.tile(pos_tile, (N, 1, 1, 1)) 27 | retv = torch.from_numpy(pos_tile).float() 28 | return rearrange(retv, 'b h w c -> b c h w') 29 | 30 | 31 | class net_2d(nn.Module): 32 | def __init__(self, input_features, output_features, stride, joints=21): 33 | super().__init__() 34 | self.project = nn.Sequential(conv3x3(input_features, output_features, stride), nn.BatchNorm2d(output_features), 35 | nn.ReLU()) 36 | 37 | self.prediction = nn.Conv2d(output_features, joints, 1, 1, 
0) 38 | 39 | def forward(self, x): 40 | x = self.project(x) 41 | x = self.prediction(x).sigmoid() 42 | return x 43 | 44 | 45 | class net_3d(nn.Module): 46 | def __init__(self, input_features, output_features, stride, joints=21, need_norm=False): 47 | super().__init__() 48 | self.need_norm = need_norm 49 | self.project = nn.Sequential(conv3x3(input_features, output_features, stride), nn.BatchNorm2d(output_features), 50 | nn.ReLU()) 51 | self.prediction = nn.Conv2d(output_features, joints * 3, 1, 1, 0) 52 | 53 | def forward(self, x): 54 | x = self.prediction(self.project(x)) 55 | 56 | dmap = rearrange(x, 'b (j l) h w -> b j l h w', l=3) 57 | 58 | return dmap 59 | 60 | 61 | class detnet(nn.Module): 62 | def __init__(self, stacks=1): 63 | super().__init__() 64 | self.resnet50 = resnet50() 65 | 66 | self.hmap_0 = net_2d(258, 256, 1) 67 | self.dmap_0 = net_3d(279, 256, 1) 68 | self.lmap_0 = net_3d(342, 256, 1) 69 | self.stacks = stacks 70 | 71 | def forward(self, x): 72 | features = self.resnet50(x) 73 | 74 | device = x.device 75 | pos_tile = get_pose_tile_torch(features.shape[0]).to(device) 76 | 77 | x = torch.cat([features, pos_tile], dim=1) 78 | 79 | hmaps = [] 80 | dmaps = [] 81 | lmaps = [] 82 | 83 | for _ in range(self.stacks): 84 | heat_map = self.hmap_0(x) 85 | hmaps.append(heat_map) 86 | x = torch.cat([x, heat_map], dim=1) 87 | 88 | dmap = self.dmap_0(x) 89 | dmaps.append(dmap) 90 | 91 | x = torch.cat([x, rearrange(dmap, 'b j l h w -> b (j l) h w')], dim=1) 92 | 93 | lmap = self.lmap_0(x) 94 | lmaps.append(lmap) 95 | hmap, dmap, lmap = hmaps[-1], dmaps[-1], lmaps[-1] 96 | 97 | uv, argmax = self.map_to_uv(hmap) 98 | 99 | delta = self.dmap_to_delta(dmap, argmax) 100 | xyz = self.lmap_to_xyz(lmap, argmax) 101 | 102 | det_result = { 103 | "h_map": hmap, 104 | "d_map": dmap, 105 | "l_map": lmap, 106 | "delta": delta, 107 | "xyz": xyz, 108 | "uv": uv 109 | } 110 | 111 | return det_result 112 | 113 | @property 114 | def pos(self): 115 | return self.__pos_tile 116 | 117 
| @staticmethod 118 | def map_to_uv(hmap): 119 | b, j, h, w = hmap.shape 120 | hmap = rearrange(hmap, 'b j h w -> b j (h w)') 121 | argmax = torch.argmax(hmap, -1, keepdim=True) 122 | u = argmax // w 123 | v = argmax % w 124 | uv = torch.cat([u, v], dim=-1) 125 | 126 | return uv, argmax 127 | 128 | @staticmethod 129 | def dmap_to_delta(dmap, argmax): 130 | return detnet.lmap_to_xyz(dmap, argmax) 131 | 132 | @staticmethod 133 | def lmap_to_xyz(lmap, argmax): 134 | lmap = rearrange(lmap, 'b j l h w -> b j (h w) l') 135 | index = repeat(argmax, 'b j i -> b j i c', c=3) 136 | xyz = torch.gather(lmap, dim=2, index=index).squeeze(2) 137 | return xyz 138 | 139 | 140 | if __name__ == '__main__': 141 | mydet = detnet() 142 | img_crop = torch.randn(10, 3, 128, 128) 143 | res = mydet(img_crop) 144 | 145 | hmap = res["h_map"] 146 | dmap = res["d_map"] 147 | lmap = res["l_map"] 148 | delta = res["delta"] 149 | xyz = res["xyz"] 150 | uv = res["uv"] 151 | 152 | print("hmap.shape=", hmap.shape) 153 | print("dmap.shape=", dmap.shape) 154 | print("lmap.shape=", lmap.shape) 155 | print("delta.shape=", delta.shape) 156 | print("xyz.shape=", xyz.shape) 157 | print("uv.shape=", uv.shape) 158 | -------------------------------------------------------------------------------- /model/helper/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnet_helper import * 2 | __all__=['resnet50','conv3x3','conv1x1'] -------------------------------------------------------------------------------- /model/helper/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/handopen/Minimal-Hand-pytorch/c078433136b20c1e4fb3f353254edd409ea65f1b/model/helper/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /model/helper/__pycache__/resnet_helper.cpython-37.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/handopen/Minimal-Hand-pytorch/c078433136b20c1e4fb3f353254edd409ea65f1b/model/helper/__pycache__/resnet_helper.cpython-37.pyc -------------------------------------------------------------------------------- /model/shape_net.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Lixin YANG. All Rights Reserved. 2 | import os 3 | import shutil 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as torch_f 9 | from manopth.manolayer import ManoLayer 10 | 11 | 12 | class ShapeNet(nn.Module): 13 | def __init__( 14 | self, 15 | dropout=0, 16 | _mano_root='mano/models' 17 | ): 18 | super(ShapeNet, self).__init__() 19 | 20 | ''' shape ''' 21 | hidden_neurons = [128, 256, 512, 256, 128] 22 | in_neurons = 15 23 | out_neurons = 10 24 | neurons = [in_neurons] + hidden_neurons 25 | 26 | shapereg_layers = [] 27 | for layer_idx, (inps, outs) in enumerate( 28 | zip(neurons[:-1], neurons[1:]) 29 | ): 30 | if dropout: 31 | shapereg_layers.append(nn.Dropout(p=dropout)) 32 | shapereg_layers.append(nn.Linear(inps, outs)) 33 | shapereg_layers.append(nn.ReLU()) 34 | 35 | shapereg_layers.append(nn.Linear(neurons[-1], out_neurons)) 36 | self.shapereg_layers = nn.Sequential(*shapereg_layers) 37 | args = {'flat_hand_mean': True, 'root_rot_mode': 'axisang', 38 | 'ncomps': 45, 'mano_root': _mano_root, 39 | 'no_pca': True, 'joint_rot_mode': 'axisang', 'side': 'right'} 40 | self.mano_layer = ManoLayer(flat_hand_mean=args['flat_hand_mean'], 41 | side=args['side'], 42 | mano_root=args['mano_root'], 43 | ncomps=args['ncomps'], 44 | use_pca=not args['no_pca'], 45 | root_rot_mode=args['root_rot_mode'], 46 | joint_rot_mode=args['joint_rot_mode'] 47 | ) 48 | 49 | def new_cal_ref_bone(self, _shape): 50 | parent_index = [0, 51 | 0, 1, 2, 52 | 0, 4, 5, 53 | 0, 7, 8, 54 | 0, 10, 11, 55 | 0, 13, 14 
56 | ] 57 | index = [0, 58 | 1, 2, 3, # index 59 | 4, 5, 6, # middle 60 | 7, 8, 9, # pinky 61 | 10, 11, 12, # ring 62 | 13, 14, 15] # thumb 63 | reoder_index = [ 64 | 13, 14, 15, 65 | 1, 2, 3, 66 | 4, 5, 6, 67 | 10, 11, 12, 68 | 7, 8, 9] 69 | shape = _shape 70 | th_v_shaped = torch.matmul(self.mano_layer.th_shapedirs, 71 | shape.transpose(1, 0)).permute(2, 0, 1) \ 72 | + self.mano_layer.th_v_template 73 | th_j = torch.matmul(self.mano_layer.th_J_regressor, th_v_shaped) 74 | temp1 = th_j 75 | temp2 = th_j[:, parent_index, :] 76 | result = temp1 - temp2 77 | ref_len = th_j[:, [4], :] - th_j[:, [0], :] 78 | ref_len = torch.norm(ref_len, dim=-1, keepdim=True) 79 | result = torch.norm(result, dim=-1, keepdim=True) 80 | result = result / ref_len 81 | return torch.squeeze(result, dim=-1)[:, reoder_index] 82 | 83 | def forward(self, bone_len): 84 | beta = self.shapereg_layers(bone_len) 85 | beta = torch.tanh(beta) 86 | bone_len_hat = self.new_cal_ref_bone(beta) 87 | 88 | results = { 89 | 'beta': beta, 90 | 'bone_len_hat': bone_len_hat 91 | } 92 | return results 93 | 94 | def save_checkpoint( 95 | state, 96 | checkpoint='checkpoint', 97 | filename='checkpoint.pth.tar', 98 | snapshot=None, 99 | is_best=False 100 | ): 101 | # preds = to_numpy(preds) 102 | filepath = os.path.join(checkpoint, filename) 103 | fileprefix = filename.split('.')[0] 104 | torch.save(state, filepath) 105 | 106 | if snapshot and state['epoch'] % snapshot == 0: 107 | shutil.copyfile( 108 | filepath, 109 | os.path.join( 110 | checkpoint, 111 | '{}_{}.pth.tar'.format(fileprefix, state['epoch']) 112 | ) 113 | ) 114 | 115 | if is_best: 116 | shutil.copyfile( 117 | filepath, 118 | os.path.join( 119 | checkpoint, 120 | '{}_best.pth.tar'.format(fileprefix) 121 | ) 122 | ) 123 | 124 | def load_checkpoint(model, checkpoint): 125 | name = checkpoint 126 | checkpoint = torch.load(name) 127 | pretrain_dict = clean_state_dict(checkpoint['state_dict']) 128 | model_state = model.state_dict() 129 | state = {} 130 | for 
k, v in pretrain_dict.items(): 131 | if k in model_state: 132 | state[k] = v 133 | else: 134 | print(k, ' is NOT in current model') 135 | model_state.update(state) 136 | model.load_state_dict(model_state) 137 | 138 | def clean_state_dict(state_dict): 139 | """save a cleaned version of model without dict and DataParallel 140 | 141 | Arguments: 142 | state_dict {collections.OrderedDict} -- [description] 143 | 144 | Returns: 145 | clean_model {collections.OrderedDict} -- [description] 146 | """ 147 | 148 | clean_model = state_dict 149 | # create new OrderedDict that does not contain `module.` 150 | from collections import OrderedDict 151 | clean_model = OrderedDict() 152 | if any(key.startswith('module') for key in state_dict): 153 | for k, v in state_dict.items(): 154 | name = k[7:] # remove `module.` 155 | clean_model[name] = v 156 | else: 157 | return state_dict 158 | 159 | return clean_model 160 | 161 | if __name__ == '__main__': 162 | input = torch.rand((10, 15)) 163 | model = ShapeNet() 164 | out_put = model(input) 165 | loss = torch.mean(out_put) 166 | loss.backward() 167 | -------------------------------------------------------------------------------- /op_pso.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | import torch 5 | from manopth.manolayer import ManoLayer 6 | from tqdm import tqdm 7 | 8 | from optimize_shape import align_bone_len 9 | from utils import bone 10 | import matplotlib.pyplot as plt 11 | 12 | from utils.LM_new import LM_Solver 13 | 14 | 15 | class PSO: 16 | def __init__(self, parameters, target, _mano_root='mano/models'): 17 | """ 18 | particle swarm optimization 19 | parameter: a list type, like [NGEN, pop_size, var_num_min, var_num_max] 20 | """ 21 | self.mano_layer = ManoLayer(side="right", 22 | mano_root=_mano_root, use_pca=False, flat_hand_mean=True) 23 | 24 | # 初始化 25 | self.NGEN = parameters[0] 26 | self.pop_size = parameters[1] 27 | self.var_num = 
parameters[2].shape[1] 28 | self.bound = [] 29 | self.bound.append(parameters[2]) 30 | self.bound.append(parameters[3]) 31 | self.set_target(target) 32 | 33 | def set_target(self, target): 34 | self.target = target.copy() 35 | self.pop_x = np.random.randn(self.pop_size, self.var_num) 36 | self.pop_v = np.random.random((self.pop_size, self.var_num)) 37 | self.p_best = self.pop_x.copy() 38 | self.p_best_fit = self.batch_new_get_loss(self.pop_x) 39 | g_best_index = np.argmin(self.p_best_fit, axis=0) 40 | if g_best_index.shape[0] > 1: 41 | g_best_index = g_best_index[[0]] 42 | self.g_best = self.p_best[g_best_index].copy() 43 | 44 | def new_cal_ref_bone(self, _shape): 45 | parent_index = [0, 46 | 0, 1, 2, 47 | 0, 4, 5, 48 | 0, 7, 8, 49 | 0, 10, 11, 50 | 0, 13, 14 51 | ] 52 | index = [0, 53 | 1, 2, 3, # index 54 | 4, 5, 6, # middle 55 | 7, 8, 9, # pinky 56 | 10, 11, 12, # ring 57 | 13, 14, 15] # thumb 58 | reoder_index = [ 59 | 13, 14, 15, 60 | 1, 2, 3, 61 | 4, 5, 6, 62 | 10, 11, 12, 63 | 7, 8, 9] 64 | shape = torch.Tensor(_shape.reshape((-1, 10))) 65 | th_v_shaped = torch.matmul(self.mano_layer.th_shapedirs, 66 | shape.transpose(1, 0)).permute(2, 0, 1) \ 67 | + self.mano_layer.th_v_template 68 | th_j = torch.matmul(self.mano_layer.th_J_regressor, th_v_shaped) 69 | temp1 = th_j.clone().detach() 70 | temp2 = th_j.clone().detach()[:, parent_index, :] 71 | result = temp1 - temp2 72 | result = torch.norm(result, dim=-1, keepdim=True) 73 | ref_len = result[:, [4]] 74 | result = result / ref_len 75 | return torch.squeeze(result, dim=-1)[:, reoder_index].cpu().numpy() 76 | 77 | def batch_new_get_loss(self, beta_): 78 | weight = 1e-3 79 | beta = beta_.copy() 80 | temp = self.new_cal_ref_bone(beta) 81 | loss = np.linalg.norm(temp - self.target, axis=-1, keepdims=True) ** 2 + \ 82 | weight * np.linalg.norm(beta, axis=-1, keepdims=True) 83 | return loss 84 | 85 | def update_operator(self, pop_size): 86 | 87 | c1 = 2 88 | c2 = 2 89 | w = 0.4 90 | 91 | self.pop_v = w * self.pop_v \ 
92 | + c1 * np.multiply(np.random.rand(pop_size, 1), (self.p_best - self.pop_x)) \ 93 | + c2 * np.multiply(np.random.rand(pop_size, 1), (self.g_best - self.pop_x)) 94 | self.pop_x = self.pop_x + self.pop_v 95 | low_flag = self.pop_x < self.bound[0] 96 | up_flag = self.pop_x > self.bound[1] 97 | self.pop_x[low_flag] = -3.0 98 | self.pop_x[up_flag] = 3.0 99 | temp = self.batch_new_get_loss(self.pop_x) 100 | p_best_flag = temp < self.p_best_fit 101 | p_best_flag = p_best_flag.reshape((pop_size,)) 102 | self.p_best[p_best_flag] = self.pop_x[p_best_flag] 103 | self.p_best_fit[p_best_flag] = temp[p_best_flag] 104 | g_best_index = np.argmin(self.p_best_fit, axis=0) 105 | if g_best_index.shape[0] > 1: 106 | g_best_index = g_best_index[[0]] 107 | self.g_best = self.pop_x[g_best_index] 108 | self.g_best_fit = self.p_best_fit[g_best_index][0][0] 109 | 110 | def main(self, slover=None, return_err=False): 111 | best_fit = [] 112 | self.ng_best = np.zeros((1, self.var_num)) 113 | self.ng_best_fit = self.batch_new_get_loss(self.ng_best)[0][0] 114 | for gen in range(self.NGEN): 115 | self.update_operator(self.pop_size) 116 | # print('############ Generation {} ############'.format(str(gen + 1))) 117 | dis = self.g_best_fit - self.ng_best_fit 118 | if self.g_best_fit < self.ng_best_fit: 119 | self.ng_best = self.g_best.copy() 120 | self.ng_best_fit = self.g_best_fit 121 | if abs(dis) < 1e-6: 122 | break 123 | # print(':{}'.format(self.ng_best)) 124 | # print(':{}'.format(self.ng_best_fit)) 125 | # best_fit.append(self.ng_best_fit) 126 | # print("---- End of (successful) Searching ----") 127 | # 128 | # plt.figure() 129 | # plt.title("Figure1") 130 | # plt.xlabel("iterators", size=14) 131 | # plt.ylabel("fitness", size=14) 132 | # t = [t for t in range(self.NGEN)] 133 | # plt.plot(t, best_fit, color='b', linewidth=2) 134 | # plt.show() 135 | if return_err: 136 | err = solver.new_cal_ref_bone(self.ng_best) 137 | err = align_bone_len(err, self.target) 138 | return err 139 | 140 | 141 
| if __name__ == '__main__': 142 | import time 143 | 144 | data_set = ['rhd', 'stb', 'do', 'eo'] 145 | for data in data_set: 146 | solver = LM_Solver(num_Iter=500, th_beta=torch.zeros((1, 10)), th_pose=torch.zeros((1, 48)), 147 | lb_target=np.zeros((15, 1)), 148 | weight=1e-5) 149 | 150 | NGEN = 100 151 | popsize = 100 152 | low = np.zeros((1, 10)) - 3.0 153 | up = np.zeros((1, 10)) + 3.0 154 | parameters = [NGEN, popsize, low, up] 155 | err = np.zeros((1, 15)) 156 | path = 'out_testset/' + data + '_pre_joints.npy' 157 | print('load:{}'.format(path)) 158 | target = np.load(path) 159 | pso_shape = np.zeros((target.shape[0], 10)) 160 | for i in tqdm(range(target.shape[0])): 161 | _target = target[[0]] 162 | _target = bone.caculate_length(_target, label='useful') 163 | _target = _target.reshape((1, 15)) 164 | pso = PSO(parameters, _target) 165 | err += pso.main(slover=solver, return_err=True) 166 | pso_shape[[i]] = pso.ng_best 167 | print(err.sum() / target.shape[0]) 168 | save_path = 'out_testset/' + data + '_pso.npy' 169 | print('save:{}'.format(save_path)) 170 | np.save(save_path, pso_shape) 171 | -------------------------------------------------------------------------------- /optimize_shape.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import numpy as np 4 | from tqdm import tqdm 5 | 6 | from utils import func, bone 7 | from utils.LM import LM_Solver 8 | 9 | 10 | def align_bone_len(opt_, pre_): 11 | opt = opt_.copy() 12 | pre = pre_.copy() 13 | 14 | opt_align = opt.copy() 15 | for i in range(opt.shape[0]): 16 | ratio = pre[i][6] / opt[i][6] 17 | opt_align[i] = ratio * opt_align[i] 18 | 19 | err = np.abs(opt_align - pre).mean(0) 20 | 21 | return err 22 | 23 | 24 | def main(args): 25 | path=args.path 26 | for dataset in args.dataset: 27 | # load predictions (N*21*3) 28 | print("load {}'s joint 3D".format(dataset)) 29 | pred_j3d = np.load("{}/{}_pre_joints.npy".format(path, 
dataset),allow_pickle=True) 30 | 31 | opt_shapes = [] 32 | opt_bone_lens = [] 33 | pre_useful_bone_lens = [] 34 | 35 | # loop 36 | for pred in tqdm(pred_j3d): 37 | # 0 initialization 38 | pose, shape = func.initiate("zero") 39 | 40 | pre_useful_bone_len = bone.caculate_length(pred, label="useful") 41 | pre_useful_bone_lens.append(pre_useful_bone_len) 42 | 43 | # optimize here! 44 | solver = LM_Solver(num_Iter=500, th_beta=shape, th_pose=pose, lb_target=pre_useful_bone_len, 45 | weight=args.weight) 46 | opt_shape = solver.LM() 47 | opt_shapes.append(opt_shape) 48 | 49 | opt_bone_len = solver.get_bones(opt_shape) 50 | opt_bone_lens.append(opt_bone_len) 51 | 52 | # plt.plot(solver.get_result(), 'r') 53 | # plt.show() 54 | 55 | # break 56 | 57 | opt_shapes = np.array(opt_shapes).reshape(-1, 10) 58 | opt_bone_lens = np.array(opt_bone_lens).reshape(-1, 15) 59 | pre_useful_bone_lens = np.array(pre_useful_bone_lens).reshape(-1, 15) 60 | 61 | np.save("{}/{}_shapes.npy".format(path, dataset, args.weight), opt_shapes) 62 | 63 | error = align_bone_len(opt_bone_lens, pre_useful_bone_lens) 64 | 65 | print("dataset:{} weight:{} ERR sum: {}".format(dataset, args.weight, error.sum())) 66 | 67 | 68 | if __name__ == '__main__': 69 | parser = argparse.ArgumentParser( 70 | description='optimize shape params. 
of mano model ') 71 | 72 | # Dataset setting 73 | parser.add_argument( 74 | '-ds', 75 | "--dataset", 76 | nargs="+", 77 | default=['rhd', 'stb', 'do', 'eo'], 78 | type=str, 79 | help="sub datasets, should be listed in: [stb|rhd|do|eo]" 80 | ) 81 | parser.add_argument( 82 | '-wt', '--weight', 83 | default=1e-5, 84 | type=float, 85 | metavar='weight', 86 | help='weight of L2 regularizer ' 87 | ) 88 | 89 | parser.add_argument( 90 | '-p', 91 | '--path', 92 | default='out_testset', 93 | type=str, 94 | metavar='data_root', 95 | help='directory') 96 | 97 | main(parser.parse_args()) 98 | -------------------------------------------------------------------------------- /plot.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | 6 | 7 | def main(args): 8 | path=args.out_path 9 | 10 | # losses in train 11 | lossD = np.load(os.path.join(path, "lossD.npy")) 12 | lossH = np.load(os.path.join(path, "lossH.npy")) 13 | lossL = np.load(os.path.join(path, "lossL.npy")) 14 | 15 | auc_all = np.load(os.path.join(path, "auc_all.npy"), allow_pickle=True).item() 16 | acc_hm_all = np.load(os.path.join(path, "acc_hm_all.npy"), allow_pickle=True).item() 17 | 18 | # rhd 19 | auc_all_rhd = np.array(auc_all['rhd']) 20 | acc_hm_rhd = np.array(acc_hm_all["rhd"]) 21 | 22 | # stb 23 | auc_all_stb = np.array(auc_all['stb']) 24 | acc_hm_stb = np.array(acc_hm_all["stb"]) 25 | 26 | # do 27 | auc_all_do = np.array(auc_all['do']) 28 | 29 | # eo 30 | auc_all_eo = np.array(auc_all['eo']) 31 | 32 | plt.figure(figsize=[50, 50]) 33 | 34 | plt.subplot(2, 4, 1) 35 | plt.plot(lossH[:, :1], lossH[:, 1:], marker='o', label='lossH') 36 | plt.plot(lossD[:, :1], lossD[:, 1:], marker='*', label='lossD') 37 | plt.plot(lossL[:, :1], lossL[:, 1:], marker='h', label='lossL') 38 | plt.title("LOSSES") 39 | plt.legend(title='Losses Category:') 40 | 41 | # rhd 42 | plt.subplot(2, 4, 2) 43 | 
plt.plot(auc_all_rhd[:, :1], auc_all_rhd[:, 1:], marker='d') 44 | plt.title( 45 | "{}_test || (EPOCH={} , AUC={:0.4f})".format("RHD", np.argmax(auc_all_rhd[:, 1:]) + 1, 46 | np.max(auc_all_rhd[:, 1:]))) 47 | 48 | plt.subplot(2, 4, 3) 49 | plt.plot(acc_hm_rhd[:, :1], acc_hm_rhd[:, 1:], marker='d') 50 | plt.title( 51 | "{}_test || (EPOCH={} , ACC_HM={:0.4f})".format("RHD", np.argmax(acc_hm_rhd[:, 1:]) + 1, 52 | np.max(acc_hm_rhd[:, 1:]))) 53 | 54 | # stb 55 | plt.subplot(2, 4, 4) 56 | plt.plot(auc_all_stb[:, :1], auc_all_stb[:, 1:], marker='d') 57 | plt.title( 58 | "{}_test || (EPOCH={} , AUC={:0.4f})".format("STB", np.argmax(auc_all_stb[:, 1:]) + 1, 59 | np.max(auc_all_stb[:, 1:]))) 60 | 61 | plt.subplot(2, 4, 5) 62 | plt.plot(acc_hm_stb[:, :1], acc_hm_stb[:, 1:], marker='d') 63 | plt.title( 64 | "{}_test || (EPOCH={} , ACC_HM={:0.4f})".format("STB", np.argmax(acc_hm_stb[:, 1:]) + 1, 65 | np.max(acc_hm_stb[:, 1:]))) 66 | 67 | # do 68 | plt.subplot(2, 4, 6) 69 | plt.plot(auc_all_do[:, :1], auc_all_do[:, 1:], marker='d') 70 | plt.title( 71 | "{}_test || (EPOCH={} , AUC={:0.4f})".format("DO", np.argmax(auc_all_do[:, 1:] + 1), np.max(auc_all_do[:, 1:]))) 72 | 73 | # eo 74 | plt.subplot(2, 4, 7) 75 | plt.plot(auc_all_eo[:, :1], auc_all_eo[:, 1:], marker='d') 76 | plt.title( 77 | "{}_test || (EPOCH={} , AUC={:0.4f})".format("EO", np.argmax(auc_all_eo[:, 1:]) + 1, np.max(auc_all_eo[:, 1:]))) 78 | 79 | # plt.savefig("vis_train.png") 80 | plt.show() 81 | 82 | 83 | if __name__ == '__main__': 84 | parser = argparse.ArgumentParser( 85 | description='Result') 86 | parser.add_argument( 87 | '-p', 88 | '--out_path', 89 | type=str, 90 | default="out_loss_auc", 91 | help='ouput path' 92 | ) 93 | main(parser.parse_args()) 94 | 95 | -------------------------------------------------------------------------------- /train_shape_net.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import time 4 | 5 | import torch 6 
import torch.backends.cudnn as cudnn
import torch.nn.parallel
import torch.optim
import torch.utils.data
from progress.bar import Bar
from tensorboardX.writer import SummaryWriter
from termcolor import cprint

from model import shape_net
from datasets import SIK1M
from losses import shape_loss
# select proper device to run
from utils import misc
from utils.eval.evalutils import AverageMeter
import numpy as np

# Module-level state used by train(): a TensorBoard writer logging to ./log,
# the compute device (CUDA when available, else CPU) and a global step counter.
writer = SummaryWriter('log')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
cudnn.benchmark = True
steps = 0


def print_args(args):
    """Print every parsed command-line option, sorted by option name."""
    opts = vars(args)
    cprint("{:>30} Options {}".format("=" * 15, "=" * 15), 'yellow')
    for k, v in sorted(opts.items()):
        print("{:>30} : {}".format(k, v))
    cprint("{:>30} Options {}".format("=" * 15, "=" * 15), 'yellow')


def main(args):
    """Train (or, with ``--evaluate``, only evaluate) the shape regressor.

    Builds ``ShapeNet``, the SIK1M train/test loaders and an Adam optimizer over
    ``model.shapereg_layers``. Each epoch trains, saves a checkpoint, validates,
    and appends the epoch averages to ``log/*.npy``.

    :param args: parsed ``argparse.Namespace`` (see the ``__main__`` section).
    :return: 0 when done.
    """
    if not os.path.isdir(args.checkpoint):
        os.makedirs(args.checkpoint)
    print_args(args)
    print("\nCREATE NETWORK")
    model = shape_net.ShapeNet(_mano_root='mano/models')
    model = model.to(device)

    criterion = shape_loss.SIKLoss(
        lambda_joint=0.0,
        lambda_shape=1.0
    )

    optimizer = torch.optim.Adam(
        [
            {
                'params': model.shapereg_layers.parameters(),
                # StepLR below is built with last_epoch=args.start_epoch (!= -1),
                # which requires 'initial_lr' in every param group.
                'initial_lr': args.learning_rate
            },
        ],
        lr=args.learning_rate,
    )

    train_dataset = SIK1M.SIK1M(
        data_root=args.data_root,
        data_split="train"
    )

    val_dataset = SIK1M.SIK1M(
        data_root=args.data_root,
        data_split="test"
    )

    print("Total train dataset size: {}".format(len(train_dataset)))
    print("Total val dataset size: {}".format(len(val_dataset)))

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.train_batch,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=True
    )
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=args.test_batch,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True
    )

    if args.evaluate or args.resume:
        shape_net.load_checkpoint(
            model, os.path.join(args.checkpoint, 'ckp_siknet_synth.pth.tar')
        )
        if args.evaluate:
            # Freeze all weights for evaluation. NOTE(review): this used to touch
            # ``model.invk_layers``; ShapeNet appears to expose ``shapereg_layers``
            # instead (see the optimizer above), so that attribute access likely
            # raised AttributeError.
            for params in model.parameters():
                params.requires_grad = False

    if args.evaluate:
        # criterion is a required argument of validate().
        validate(val_loader, model, criterion, args)
        cprint('Eval All Done', 'yellow', attrs=['bold'])
        return 0

    model = torch.nn.DataParallel(model)
    print("\nUSING {} GPUs".format(torch.cuda.device_count()))
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, args.lr_decay_step, gamma=args.gamma,
        last_epoch=args.start_epoch
    )
    train_bone_len = []
    train_shape_l2 = []
    test_bone_len = []
    test_shape_l2 = []
    # NOTE(review): this runs args.epochs - args.start_epoch + 1 epochs (151 with
    # the defaults), one more than the --epochs help text suggests.
    for epoch in range(args.start_epoch, args.epochs + 1):
        print('\nEpoch: %d' % (epoch + 1))
        for group_idx, group in enumerate(optimizer.param_groups):
            print('group %d lr:' % group_idx, group['lr'])
        # ------------- train for one epoch -------------
        t1, t2 = train(
            train_loader,
            model,
            criterion,
            optimizer,
            args=args,
        )
        # ------------------------------------------------
        shape_net.save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.module.state_dict(),
            },
            checkpoint=args.checkpoint,
            filename='{}.pth.tar'.format(args.saved_prefix),
            snapshot=args.snapshot,
            is_best=False
        )
        t3, t4 = validate(val_loader, model, criterion, args)
        train_bone_len.append(t1)
        train_shape_l2.append(t2)
        test_bone_len.append(t3)
        test_shape_l2.append(t4)
        np.save('log/train_bone_len.npy', train_bone_len)
        np.save('log/test_bone_len.npy', test_bone_len)
        np.save('log/train_shape_l2.npy', train_shape_l2)
        np.save('log/test_shape_l2.npy', test_shape_l2)
        scheduler.step()
    cprint('All Done', 'yellow', attrs=['bold'])
    return 0  # end of main


def validate(val_loader, model, criterion, args):
    """Evaluate ``model`` on ``val_loader`` with gradients disabled.

    :return: (average ``bone_len`` loss, average ``shape_reg`` loss), each
        averaged over samples (weighted by batch size).
    """
    am_shape_l2 = AverageMeter()
    am_bone_len = AverageMeter()
    model.eval()
    bar = Bar('\033[33m Eval \033[0m', max=len(val_loader))
    with torch.no_grad():
        for i, metas in enumerate(val_loader):
            # train=True is required: with train=False one_forward_pass skips the
            # criterion and returns an empty ``losses`` dict.
            results, targets, total_loss, losses = one_forward_pass(
                metas, model, criterion, args, train=True
            )
            am_shape_l2.update(losses['shape_reg'].item(), targets['batch_size'])
            am_bone_len.update(losses['bone_len'].item(), targets['batch_size'])
            bar.suffix = (
                '({batch}/{size}) '
                't: {total:}s | '
                'eta:{eta:}s | '
                'lN: {lossLen:.5f} | '
                'lL2: {lossL2:.5f} | '
            ).format(
                batch=i + 1,
                size=len(val_loader),
                total=bar.elapsed_td,
                eta=bar.eta_td,
                lossLen=am_bone_len.avg,
                lossL2=am_shape_l2.avg,
            )
            bar.next()
        bar.finish()
    return (am_bone_len.avg, am_shape_l2.avg)


def train(train_loader, model, criterion, optimizer, args):
    """Run one training epoch, logging the total loss to TensorBoard per step.

    :return: (average ``bone_len`` loss, average ``shape_reg`` loss) for the epoch.
    """
    global steps
    batch_time = AverageMeter()
    data_time = AverageMeter()
    am_shape_l2 = AverageMeter()
    am_bone_len = AverageMeter()

    last = time.time()
    # switch to train mode
    model.train()
    bar = Bar('\033[31m Train \033[0m', max=len(train_loader))
    for i, metas in enumerate(train_loader):
        data_time.update(time.time() - last)
        results, targets, total_loss, losses = one_forward_pass(
            metas, model, criterion, args, train=True
        )
        steps += 1
        writer.add_scalar('loss', total_loss.item(), steps)
        am_shape_l2.update(losses['shape_reg'].item(), targets['batch_size'])
        am_bone_len.update(losses['bone_len'].item(), targets['batch_size'])

        # backward and step
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        # progress
        batch_time.update(time.time() - last)
        last = time.time()
        bar.suffix = (
            '({batch}/{size}) '
            'd: {data:.2f}s | '
            'b: {bt:.2f}s | '
            't: {total:}s | '
            'eta:{eta:}s | '
            'lN: {lossLen:.5f} | '
            'lL2: {lossL2:.5f} | '
        ).format(
            batch=i + 1,
            size=len(train_loader),
            data=data_time.avg,
            bt=batch_time.avg,
            total=bar.elapsed_td,
            eta=bar.eta_td,
            lossLen=am_bone_len.avg,
            lossL2=am_shape_l2.avg,
        )
        bar.next()
    bar.finish()
    return (am_bone_len.avg, am_shape_l2.avg)


def one_forward_pass(metas, model, criterion, args, train=True):
    """Run the model on one batch and, if ``train``, compute its losses.

    :param metas: batch dict from the loader; only ``metas['rel_bone_len']`` is read.
    :param train: when False, return a zero loss and an empty ``losses`` dict
        without calling the criterion.
    :return: (results, targets, total_loss, losses)
    """
    # prepare targets
    rel_bone_len = metas['rel_bone_len'].float().to(device, non_blocking=True)
    targets = {
        'batch_size': rel_bone_len.shape[0],
        'rel_bone_len': rel_bone_len
    }
    # ---------------- Forward Pass ----------------
    results = model(rel_bone_len)
    # ---------------- Forward End -----------------

    # Allocate on the active device; the previous ``.cuda()`` call crashed on
    # CPU-only machines even though ``device`` falls back to CPU.
    total_loss = torch.zeros(1, device=device)
    losses = {}
    if not train:
        return results, targets, total_loss, losses

    # compute losses
    total_loss, losses = criterion.compute_loss(results, targets)
    return results, targets, total_loss, losses


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='PyTorch Train dl shape')
    # Miscs
    parser.add_argument(
        '-ckp',
        '--checkpoint',
        default='checkpoints',
        type=str,
        metavar='PATH',
        help='path to save checkpoint (default: checkpoint)'
    )

    parser.add_argument(
        '-dr',
        '--data_root',
        type=str,
        default='data',
        help='dataset root directory'
    )

    parser.add_argument(
        '-sp',
        '--saved_prefix',
        default='ckp_siknet_synth',
        type=str,
        metavar='PATH',
        help='path to save checkpoint (default: checkpoint)'
    )
    parser.add_argument(
        '--snapshot',
        default=1, type=int,
        help='save models for every #snapshot epochs (default: 1)'
    )

    parser.add_argument(
        '-e', '--evaluate',
        dest='evaluate',
        action='store_true',
        help='evaluate model on validation set'
    )

    parser.add_argument(
        '-r', '--resume',
        dest='resume',
        action='store_true',
        help='resume model on validation set'
    )

    # Training Parameters
    parser.add_argument(
        '-j', '--workers',
        default=8,
        type=int,
        metavar='N',
        help='number of data loading workers (default: 8)'
    )
    parser.add_argument(
        '--epochs',
        default=150,
        type=int,
        metavar='N',
        help='number of total epochs to run'
    )
    parser.add_argument(
        '-se', '--start_epoch',
        default=0,
        type=int,
        metavar='N',
        help='manual epoch number (useful on restarts)'
    )
    parser.add_argument(
        '-b', '--train_batch',
        default=1024,
        type=int,
        metavar='N',
        help='train batchsize'
    )
    parser.add_argument(
        '-tb', '--test_batch',
        default=512,
        type=int,
        metavar='N',
        help='test batchsize'
    )

    parser.add_argument(
        '-lr', '--learning-rate',
        default=1.0e-4,
        type=float,
        metavar='LR',
        help='initial learning rate'
    )
    parser.add_argument(
        "--lr_decay_step",
        default=40,
        type=int,
        help="Epochs after which to decay learning rate",
    )
    parser.add_argument(
        '--gamma',
        type=float,
        default=0.1,
        help='LR is multiplied by gamma on schedule.'
    )
    main(parser.parse_args())

# ------------------------------------------------------------------------------
# /utils/AIK.py:
# ------------------------------------------------------------------------------
# Copyright (c) Hao Meng. All Rights Reserved.
import numpy as np
import transforms3d

import config as cfg

# Per-joint twist angles fed to the twist rotation in adaptive_IK. All zero,
# i.e. bone twist is ignored.
angels0 = np.zeros((1, 21))


def to_dict(joints):
    """Return ``{i: joints[:, [i]]}`` for i in 0..20 (each column kept 2-D)."""
    return {i: joints[:, [i]] for i in range(21)}


def _unit_perpendicular(vec):
    """Return a unit 3-vector orthogonal to ``vec`` (the x-axis if ``vec`` is ~0)."""
    # Cross with the basis vector least aligned with vec so the result is well-conditioned.
    basis = np.zeros(3)
    basis[np.argmin(np.abs(vec))] = 1.0
    perp = np.cross(vec, basis)
    norm = np.linalg.norm(perp)
    if norm < 1e-12:
        return np.array([1.0, 0.0, 0.0])
    return perp / norm


def adaptive_IK(T_, P_):
    '''
    Computes pose parameters given template and predictions.
    We think the twist of hand bone could be omitted.

    :param T_: template joints, 21*3
    :param P_: target (predicted) joints, 21*3
    :return: rotation matrices of shape (1, 16, 3, 3). Entry 0 holds the global
        rotation R0; the others hold parent-relative rotations placed via
        ``cfg.ID2ROT``.
    '''
    T = T_.copy().astype(np.float64)
    P = P_.copy().astype(np.float64)

    P = P.transpose(1, 0)
    T = T.transpose(1, 0)

    # to dict: joint index -> (3, 1) column
    P = to_dict(P)
    T = to_dict(T)

    # some globals
    R = {}
    R_pa_k = {}
    q = {}

    q[0] = T[0]  # in fact, q[0] = P[0] = T[0].

    # Compute R0. Here R0 is required to be a rotation, not just an orthogonal matrix.
    # See "Least-Squares Fitting of Two 3-D Point Sets", K. S. Arun, T. S. Huang, S. D. Blostein.
    # This differs slightly from https://github.com/Jeff-sjtu/HybrIK/blob/main/hybrik/utils/pose_utils.py#L4,
    # which treats R0 as an orthogonal matrix only. Their method might further boost accuracy.
    P_0 = np.concatenate([P[1] - P[0], P[5] - P[0],
                          P[9] - P[0], P[13] - P[0],
                          P[17] - P[0]], axis=-1)
    T_0 = np.concatenate([T[1] - T[0], T[5] - T[0],
                          T[9] - T[0], T[13] - T[0],
                          T[17] - T[0]], axis=-1)
    H = np.matmul(T_0, P_0.T)

    U, S, V_T = np.linalg.svd(H)
    V = V_T.T
    R0 = np.matmul(V, U.T)

    # V U^T may be a reflection (det == -1). Flipping the column of V that pairs
    # with the smallest singular value (numpy returns S in descending order) gives
    # the least-squares proper rotation (Kabsch/Umeyama correction).
    # This matches the previous behaviour in Arun's degenerate case (smallest
    # singular value ~ 0) and also fixes the noisy case, where the reflection
    # used to be returned unchanged.
    if np.linalg.det(R0) < 0:
        V_ = V.copy()
        V_[:, 2] = -V_[:, 2]
        R0 = np.matmul(V_, U.T)

    R[0] = R0

    # the bones from 1, 5, 9, 13, 17 to 0 share the same rotation
    R[1] = R[0].copy()
    R[5] = R[0].copy()
    R[9] = R[0].copy()
    R[13] = R[0].copy()
    R[17] = R[0].copy()

    # compute rotations along the kinematic tree
    for k in cfg.kinematic_tree:
        pa = cfg.SNAP_PARENT[k]
        pa_pa = cfg.SNAP_PARENT[pa]
        q[pa] = np.matmul(R[pa], (T[pa] - T[pa_pa])) + q[pa_pa]
        delta_p_k = np.matmul(np.linalg.inv(R[pa]), P[k] - q[pa])
        delta_p_k = delta_p_k.reshape((3,))
        delta_t_k = T[k] - T[pa]
        delta_t_k = delta_t_k.reshape((3,))
        temp_axis = np.cross(delta_t_k, delta_p_k)
        axis_norm = np.linalg.norm(temp_axis, axis=-1)
        if axis_norm < 1e-12:
            # (Anti-)parallel bones: the swing axis is undefined and a zero axis
            # made axangle2mat return NaNs. Any axis perpendicular to the template
            # bone gives the right rotation (identity for alpha=0, half turn for pi).
            axis = _unit_perpendicular(delta_t_k)
        else:
            axis = temp_axis / (axis_norm + 1e-8)
        temp = (np.linalg.norm(delta_t_k, axis=0) + 1e-8) * (np.linalg.norm(delta_p_k, axis=0) + 1e-8)
        # Clip so that rounding cannot push the cosine outside arccos' domain.
        cos_alpha = np.clip(np.dot(delta_t_k, delta_p_k) / temp, -1.0, 1.0)

        alpha = np.arccos(cos_alpha)

        twist = delta_t_k
        D_sw = transforms3d.axangles.axangle2mat(axis=axis, angle=alpha, is_normalized=False)
        # Scalar twist angle (was a 1-element array, which math.cos only accepts via deprecated conversion).
        D_tw = transforms3d.axangles.axangle2mat(axis=twist, angle=angels0[0, k], is_normalized=False)
        R_pa_k[k] = np.matmul(D_sw, D_tw)
        R[k] = np.matmul(R[pa], R_pa_k[k])

    pose_R = np.zeros((1, 16, 3, 3))
    pose_R[0, 0] = R[0]
    for key in cfg.ID2ROT.keys():
        value = cfg.ID2ROT[key]
        pose_R[0, value] = R_pa_k[key]

    return pose_R

# ------------------------------------------------------------------------------
# ------------------------------------------------------------------------------
# /utils/LM.py:
# ------------------------------------------------------------------------------
# Copyright (c) Hao Meng. All Rights Reserved.

import numpy as np
import torch
from manopth.manolayer import ManoLayer

from utils import bone


class LM_Solver():
    """Fit MANO shape parameters (beta) to target relative bone lengths.

    :meth:`LM` runs a Levenberg-Marquardt-style loop. It minimises the squared
    distance between the relative bone lengths produced by ``beta`` and
    ``lb_target``, plus a small ``||beta||`` regulariser. Gradients come from
    central finite differences.
    """

    def __init__(self, num_Iter=500, th_beta=None, th_pose=None, lb_target=None,
                 weight=0.01, _mano_root='mano/models'):
        """
        :param num_Iter: maximum number of LM iterations.
        :param th_beta: initial shape tensor (10 values); its numpy view is
            the starting point of :meth:`LM`.
        :param th_pose: pose tensor passed to the MANO layer.
        :param lb_target: 15 target relative bone lengths.
        :param weight: stored as ``self.weight``; not used by this class' losses.
        :param _mano_root: directory containing the MANO model files.
        """
        self.count = 0
        self.minimal_loss = 9999
        self.best_beta = np.zeros([10, 1])
        self.num_Iter = num_Iter

        self.th_beta = th_beta
        self.th_pose = th_pose

        # NOTE(review): Tensor.numpy() shares memory with th_beta, and LM()
        # updates a reshape view of this array in place, so LM() appears to
        # modify the caller's th_beta as well. Confirm whether callers rely on it.
        self.beta = th_beta.numpy()
        self.pose = th_pose.numpy()

        self.mano_layer = ManoLayer(side="right",
                                    mano_root=_mano_root, use_pca=False, flat_hand_mean=True)

        self.threshold_stop = 10 ** -13
        self.weight = weight
        self.residual_memory = []

        self.lb = np.zeros(21)

        _, self.joints = self.mano_layer(self.th_pose, self.th_beta)
        self.joints = self.joints.numpy().reshape(21, 3)

        self.lb_target = lb_target.reshape(15, 1)

    def update(self, beta_):
        """Run MANO with ``beta_`` and return (useful bone lengths, bone 6 length)."""
        beta = beta_.copy()
        self.count += 1
        my_th_beta = torch.from_numpy(beta).float().reshape(1, 10)
        _, joints = self.mano_layer(self.th_pose, my_th_beta)

        useful_lb = bone.caculate_length(joints, label="useful")
        lb_ref = useful_lb[6]
        return useful_lb, lb_ref

    def new_cal_ref_bone(self, _shape):
        """Relative bone lengths for a batch of shapes, without a full MANO pass.

        Joints are regressed from the shaped, unposed template vertices, so only
        beta matters. Lengths are divided by bone 4 (joint 4 to its parent 0)
        and reordered thumb, index, middle, ring, pinky.

        :param _shape: array holding 10 values per row (reshaped to (-1, 10)).
        :return: numpy array of shape (batch, 15).
        """
        parent_index = [0,
                        0, 1, 2,
                        0, 4, 5,
                        0, 7, 8,
                        0, 10, 11,
                        0, 13, 14
                        ]
        # MANO joint order: 0 root; 1-3 index; 4-6 middle; 7-9 pinky;
        # 10-12 ring; 13-15 thumb.
        reoder_index = [
            13, 14, 15,
            1, 2, 3,
            4, 5, 6,
            10, 11, 12,
            7, 8, 9]
        shape = torch.Tensor(_shape.reshape((-1, 10)))
        th_v_shaped = torch.matmul(self.mano_layer.th_shapedirs,
                                   shape.transpose(1, 0)).permute(2, 0, 1) \
            + self.mano_layer.th_v_template
        th_j = torch.matmul(self.mano_layer.th_J_regressor, th_v_shaped)
        temp1 = th_j.clone().detach()
        temp2 = th_j.clone().detach()[:, parent_index, :]
        result = temp1 - temp2
        result = torch.norm(result, dim=-1, keepdim=True)
        ref_len = result[:, [4]]
        result = result / ref_len
        return torch.squeeze(result, dim=-1)[:, reoder_index].cpu().numpy()

    def get_residual(self, beta_):
        """Return relative bone lengths for ``beta_`` minus the target, shape (15, 1)."""
        beta = beta_.copy()
        lb, lb_ref = self.update(beta)
        # 15 useful bones (matches lb_target and get_bones); reshape(45, 1) always raised.
        lb = lb.reshape(15, 1)
        return lb / lb_ref - self.lb_target

    def get_count(self):
        """Return how many times :meth:`update` has run MANO."""
        return self.count

    def get_bones(self, beta_):
        """Return the 15 useful bone lengths for ``beta_`` as a (15, 1) array."""
        beta = beta_.copy()
        lb, _ = self.update(beta)
        lb = lb.reshape(15, 1)

        return lb

    # Vectorized implementation
    def batch_get_l2_loss(self, beta_):
        """Per-row loss: squared bone-length error + 1e-5 * ||beta_row||.

        :param beta_: one or more shapes, 10 values each (e.g. (10, 1) or (n, 10)).
        :return: array of shape (n,).
        """
        weight = 1e-5
        # Normalise to (n, 10) so the regulariser uses each row's full norm; a
        # (10, 1) input previously regularised only |beta_0|.
        beta = np.asarray(beta_).reshape(-1, 10)
        temp = self.new_cal_ref_bone(beta)
        loss = np.transpose(temp)
        loss = np.linalg.norm(loss - self.lb_target, axis=0) ** 2 + \
            weight * np.linalg.norm(beta, axis=-1)
        return loss

    def new_get_derivative(self, beta_):
        """Central finite-difference gradient of :meth:`batch_get_l2_loss`.

        :param beta_: 10 shape parameters (e.g. 10*1).
        :return: gradient as a (1, 10) array.
        """
        beta = beta_.copy().reshape((1, 10))
        temp_shape = np.zeros((20, beta.shape[1]))  # rows 0-9: +step, rows 10-19: -step
        step = 0.01
        for t2 in range(10):  # index of the perturbed beta component
            t3 = 10 + t2
            temp_shape[t2] = beta.copy()
            temp_shape[t3] = beta.copy()
            temp_shape[t2, t2] += step
            temp_shape[t3, t2] -= step

        res = self.batch_get_l2_loss(temp_shape)
        d = res[0:10] - res[10:20]  # 10*1
        d = d.reshape((1, 10)) / (2 * step)
        return d

    # LM algorithm
    def LM(self):
        """Iteratively refine beta starting from ``self.beta``.

        Each step solves (J^T J + u I) delta = J^T * loss and subtracts delta.
        The damping u shrinks while the loss decrease grows and grows otherwise.
        The loop stops early once the loss changes by less than
        ``self.threshold_stop``. Losses are appended to ``self.residual_memory``.
        The lowest loss and its beta are kept in ``self.minimal_loss`` and
        ``self.best_beta``.

        :return: the final beta, shape (10, 1).
        """
        u = 1e-2  # damping factor
        v = 1.5   # damping adjustment ratio
        beta = self.beta.reshape(10, 1)

        last_update = 0
        last_loss = 0
        for _ in range(self.num_Iter):
            loss = self.batch_get_l2_loss(beta)[0]
            if loss < self.minimal_loss:
                self.minimal_loss = loss
                # Copy: beta is updated in place below, so keeping the reference
                # made best_beta track the latest iterate instead of the best one.
                self.best_beta = beta.copy()

            if abs(loss - last_loss) < self.threshold_stop:
                return beta

            Jacobian = self.new_get_derivative(beta)
            jtj = np.matmul(Jacobian.T, Jacobian)
            jtj = jtj + u * np.eye(jtj.shape[0])

            update = last_loss - loss
            delta = (np.matmul(np.linalg.inv(jtj), Jacobian.T) * loss)

            beta -= delta

            if update > last_update and update > 0:
                u /= v
            else:
                u *= v

            last_update = update
            last_loss = loss
            self.residual_memory.append(loss)

        return beta

    def get_result(self):
        """Return the list of per-iteration losses recorded by :meth:`LM`."""
        return self.residual_memory

# ------------------------------------------------------------------------------
# /utils/LM_new.py:
# ------------------------------------------------------------------------------
# Copyright (c) Hao Meng. All Rights Reserved.
import time

import numpy as np
import torch
from manopth.manolayer import ManoLayer

from utils import bone


class LM_Solver():
    """Fit MANO shape parameters (beta) to target relative bone lengths.

    :meth:`LM` runs a Levenberg-Marquardt-style loop on :meth:`new_get_loss`,
    using finite-difference gradients from :meth:`new_get_derivative`. Timing
    information is kept in ``time_in_mano`` / ``time_total``.
    """

    def __init__(self, num_Iter=500, th_beta=None, th_pose=None, lb_target=None,
                 weight=0.01, _mano_root='mano/models'):
        """
        :param num_Iter: maximum number of LM iterations.
        :param th_beta: initial shape tensor (10 values); its numpy view is
            the starting point of :meth:`LM`.
        :param th_pose: pose tensor passed to the MANO layer.
        :param lb_target: 15 target relative bone lengths.
        :param weight: regulariser weight used by get_loss / new_get_loss.
        :param _mano_root: directory containing the MANO model files.
        """
        self.count = 0
        self.time_start = time.time()
        self.time_in_mano = 0
        self.minimal_loss = 9999
        self.best_beta = np.zeros([10, 1])
        self.num_Iter = num_Iter

        self.th_beta = th_beta
        self.th_pose = th_pose

        # NOTE(review): Tensor.numpy() shares memory with th_beta, and LM()
        # updates a reshape view of this array in place, so LM() appears to
        # modify the caller's th_beta as well. Confirm whether callers rely on it.
        self.beta = th_beta.numpy()
        self.pose = th_pose.numpy()

        # Honour _mano_root (it was previously ignored in favour of a hard-coded path).
        self.mano_layer = ManoLayer(side="right",
                                    mano_root=_mano_root, use_pca=False, flat_hand_mean=True)

        self.threshold_stop = 10 ** -13
        self.weight = weight
        self.residual_memory = []

        self.lb = np.zeros(21)

        _, self.joints = self.mano_layer(self.th_pose, self.th_beta)
        self.joints = self.joints.numpy().reshape(21, 3)

        self.lb_target = lb_target.reshape(15, 1)
        self.test_time = 0

    def update_target(self, target):
        """Replace the 15 target relative bone lengths."""
        self.lb_target = target.copy().reshape(15, 1)

    def update(self, beta_):
        """Run MANO with ``beta_`` and return (useful bone lengths, bone 6 length)."""
        beta = beta_.copy()
        self.count += 1
        now = time.time()
        my_th_beta = torch.from_numpy(beta).float().reshape(1, 10)
        _, joints = self.mano_layer(self.th_pose, my_th_beta)
        self.time_in_mano = time.time() - now

        useful_lb = bone.caculate_length(joints, label="useful")
        lb_ref = useful_lb[6]
        return useful_lb, lb_ref

    def new_cal_ref_bone(self, _shape):
        """Relative bone lengths for a batch of shapes, without a full MANO pass.

        Joints are regressed from the shaped, unposed template vertices, so only
        beta matters. Lengths are divided by bone 4 (joint 4 to its parent 0)
        and reordered thumb, index, middle, ring, pinky.

        :param _shape: array holding 10 values per row (reshaped to (-1, 10)).
        :return: numpy array of shape (batch, 15).
        """
        now = time.time()
        parent_index = [0,
                        0, 1, 2,
                        0, 4, 5,
                        0, 7, 8,
                        0, 10, 11,
                        0, 13, 14
                        ]
        # MANO joint order: 0 root; 1-3 index; 4-6 middle; 7-9 pinky;
        # 10-12 ring; 13-15 thumb.
        reoder_index = [
            13, 14, 15,
            1, 2, 3,
            4, 5, 6,
            10, 11, 12,
            7, 8, 9]
        shape = torch.Tensor(_shape.reshape((-1, 10)))
        th_v_shaped = torch.matmul(self.mano_layer.th_shapedirs,
                                   shape.transpose(1, 0)).permute(2, 0, 1) \
            + self.mano_layer.th_v_template
        th_j = torch.matmul(self.mano_layer.th_J_regressor, th_v_shaped)
        temp1 = th_j.clone().detach()
        temp2 = th_j.clone().detach()[:, parent_index, :]
        result = temp1 - temp2
        result = torch.norm(result, dim=-1, keepdim=True)
        ref_len = result[:, [4]]
        result = result / ref_len
        self.time_in_mano = time.time() - now
        return torch.squeeze(result, dim=-1)[:, reoder_index].cpu().numpy()

    def get_residual(self, beta_):
        """Return relative bone lengths for ``beta_`` minus the target, shape (15, 1)."""
        beta = beta_.copy()
        lb, lb_ref = self.update(beta)
        # 15 useful bones (matches lb_target and get_bones); reshape(45, 1) always raised.
        lb = lb.reshape(15, 1)
        return lb / lb_ref - self.lb_target

    def get_count(self):
        """Return how many times :meth:`update` has run MANO."""
        return self.count

    def get_bones(self, beta_):
        """Return the 15 useful bone lengths for ``beta_`` as a (15, 1) array."""
        beta = beta_.copy()
        lb, _ = self.update(beta)
        lb = lb.reshape(15, 1)
        return lb

    def get_loss(self, beta_):
        """Loss via a full MANO pass: squared bone error + weight * ||beta||^2."""
        beta = beta_.copy()

        lb, lb_ref = self.update(beta)
        lb = lb.reshape(15, 1)

        loss = np.linalg.norm(lb / lb_ref - self.lb_target) ** 2 + \
            self.weight * np.linalg.norm(beta) ** 2

        return loss

    def new_get_loss(self, beta_):
        """Loss via new_cal_ref_bone: squared bone error + weight * ||beta||.

        NOTE(review): unlike get_loss the regulariser is not squared. Also, LM()
        descends along the gradient of batch_new_get_loss, which uses weight
        1e-5 instead of self.weight. Confirm which weighting is intended.
        """
        temp = self.new_cal_ref_bone(beta_)
        loss = temp.reshape((15, 1))
        loss = np.linalg.norm(loss - self.lb_target) ** 2 + \
            self.weight * np.linalg.norm(beta_)
        return loss

    def get_derivative(self, beta_, n):
        """Central finite-difference derivative of new_get_loss w.r.t. beta[n]."""
        beta = beta_.copy()
        params1 = np.array(beta)
        params2 = np.array(beta)
        step = 0.01
        params1[n] += step
        params2[n] -= step

        res1 = self.new_get_loss(params1)
        res2 = self.new_get_loss(params2)

        d = (res1 - res2) / (2 * step)

        return d.ravel()

    def batch_new_get_loss(self, beta_):
        """Per-row loss: squared bone-length error + 1e-5 * ||beta_row||.

        :param beta_: one or more shapes, 10 values each (e.g. (n, 10)).
        :return: array of shape (n,).
        """
        weight = 1e-5
        # Normalise to (n, 10) so the regulariser always uses each row's full norm.
        beta = np.asarray(beta_).reshape(-1, 10)
        temp = self.new_cal_ref_bone(beta)
        loss = np.transpose(temp)
        loss = np.linalg.norm(loss - self.lb_target, axis=0) ** 2 + \
            weight * np.linalg.norm(beta, axis=-1)
        return loss

    def new_get_derivative(self, beta_):
        """Central finite-difference gradient of :meth:`batch_new_get_loss`.

        :param beta_: 10 shape parameters (e.g. 10*1).
        :return: gradient as a (1, 10) array.
        """
        beta = beta_.copy().reshape((1, 10))
        temp_shape = np.zeros((20, beta.shape[1]))  # rows 0-9: +step, rows 10-19: -step
        step = 0.01
        for t2 in range(10):  # index of the perturbed beta component
            t3 = 10 + t2
            temp_shape[t2] = beta.copy()
            temp_shape[t3] = beta.copy()
            temp_shape[t2, t2] += step
            temp_shape[t3, t2] -= step

        res = self.batch_new_get_loss(temp_shape)
        d = res[0:10] - res[10:20]  # 10*1
        d = d.reshape((1, 10)) / (2 * step)
        return d

    # LM algorithm
    def LM(self):
        """Iteratively refine beta starting from ``self.beta``.

        Each step solves (J^T J + u I) delta = J^T * loss and subtracts delta.
        The damping u shrinks while the loss decrease grows and grows otherwise.
        On early stop (loss change below ``self.threshold_stop``) the elapsed
        time since construction is stored in ``self.time_total``.

        :return: the final beta, shape (10, 1).
        """
        u = 1e-2  # damping factor
        v = 1.5   # damping adjustment ratio
        beta = self.beta.reshape(10, 1)

        last_update = 0
        last_loss = 0
        self.test_time = 0
        for _ in range(self.num_Iter):
            loss = self.new_get_loss(beta)
            if loss < self.minimal_loss:
                self.minimal_loss = loss
                # Copy: beta is updated in place below, so keeping the reference
                # made best_beta track the latest iterate instead of the best one.
                self.best_beta = beta.copy()

            if abs(loss - last_loss) < self.threshold_stop:
                self.time_total = time.time() - self.time_start
                return beta

            Jacobian = self.new_get_derivative(beta)
            jtj = np.matmul(Jacobian.T, Jacobian)
            jtj = jtj + u * np.eye(jtj.shape[0])

            update = last_loss - loss
            delta = (np.matmul(np.linalg.inv(jtj), Jacobian.T) * loss)

            beta -= delta

            if update > last_update and update > 0:
                u /= v
            else:
                u *= v

            last_update = update
            last_loss = loss
            self.residual_memory.append(loss)

        return beta

    def get_result(self):
        """Return the list of per-iteration losses recorded by :meth:`LM`."""
        return self.residual_memory

# ------------------------------------------------------------------------------
# /utils/__init__.py:
#   https://raw.githubusercontent.com/handopen/Minimal-Hand-pytorch/c078433136b20c1e4fb3f353254edd409ea65f1b/utils/__init__.py
# /utils/__pycache__/AIK.cpython-37.pyc:
#   https://raw.githubusercontent.com/handopen/Minimal-Hand-pytorch/c078433136b20c1e4fb3f353254edd409ea65f1b/utils/__pycache__/AIK.cpython-37.pyc
# /utils/__pycache__/LM.cpython-37.pyc:
#   https://raw.githubusercontent.com/handopen/Minimal-Hand-pytorch/c078433136b20c1e4fb3f353254edd409ea65f1b/utils/__pycache__/LM.cpython-37.pyc
# /utils/__pycache__/__init__.cpython-37.pyc:
#   https://raw.githubusercontent.com/handopen/Minimal-Hand-pytorch/c078433136b20c1e4fb3f353254edd409ea65f1b/utils/__pycache__/__init__.cpython-37.pyc
# /utils/__pycache__/align.cpython-37.pyc:
#   https://raw.githubusercontent.com/handopen/Minimal-Hand-pytorch/c078433136b20c1e4fb3f353254edd409ea65f1b/utils/__pycache__/align.cpython-37.pyc
# /utils/__pycache__/bone.cpython-37.pyc:
# ------------------------------------------------------------------------------
https://raw.githubusercontent.com/handopen/Minimal-Hand-pytorch/c078433136b20c1e4fb3f353254edd409ea65f1b/utils/__pycache__/bone.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/func.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/handopen/Minimal-Hand-pytorch/c078433136b20c1e4fb3f353254edd409ea65f1b/utils/__pycache__/func.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/handutils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/handopen/Minimal-Hand-pytorch/c078433136b20c1e4fb3f353254edd409ea65f1b/utils/__pycache__/handutils.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/heatmaputils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/handopen/Minimal-Hand-pytorch/c078433136b20c1e4fb3f353254edd409ea65f1b/utils/__pycache__/heatmaputils.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/imgutils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/handopen/Minimal-Hand-pytorch/c078433136b20c1e4fb3f353254edd409ea65f1b/utils/__pycache__/imgutils.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/misc.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/handopen/Minimal-Hand-pytorch/c078433136b20c1e4fb3f353254edd409ea65f1b/utils/__pycache__/misc.cpython-37.pyc -------------------------------------------------------------------------------- 
# /utils/__pycache__/vis.cpython-37.pyc:
#   https://raw.githubusercontent.com/handopen/Minimal-Hand-pytorch/c078433136b20c1e4fb3f353254edd409ea65f1b/utils/__pycache__/vis.cpython-37.pyc
# ------------------------------------------------------------------------------
# /utils/align.py:
# ------------------------------------------------------------------------------
import numpy as np


def _stack_ragged(arrays):
    """Stack equally-shaped arrays into one ndarray, else build a 1-D object array.

    NumPy >= 1.24 refuses ``np.array`` on ragged lists, so the object array is
    built explicitly. When all shapes agree (or the list is empty) the result
    is exactly ``np.array(arrays)``.
    """
    if len({np.shape(a) for a in arrays}) <= 1:
        return np.array(arrays)
    out = np.empty(len(arrays), dtype=object)
    for idx, arr in enumerate(arrays):
        out[idx] = arr
    return out


def global_align(gtj0, prj0, key):
    """Align predicted joints to ground truth before computing errors.

    ``"stb"`` / ``"rhd"`` (B*21*3): each prediction is translated so that joint 9
    lands on the ground-truth joint 9. It is scaled so that its joint 0-9
    distance matches the ground truth.

    ``"do"`` / ``"eo"``: the ground truth holds 5 joints per sample, NaN where
    missing. Predictions are reduced to joints 4, 8, 12, 16, 20. For each sample
    with at least two valid joints, the valid predicted points are re-centred
    and scaled onto the valid ground-truth points. Other samples are dropped.

    :return: (ground-truth joints, aligned predicted joints).
    :raises ValueError: if ``key`` is not one of "stb", "rhd", "do", "eo".
    """
    gtj = gtj0.copy()
    prj = prj0.copy()

    if key in ["stb", "rhd"]:
        # gtj :B*21*3
        # prj :B*21*3
        root_idx = 9  # root
        ref_bone_link = [0, 9]  # mid mcp
        pred_align = prj.copy()
        for i in range(prj.shape[0]):
            pred_ref_bone_len = np.linalg.norm(prj[i][ref_bone_link[0]] - prj[i][ref_bone_link[1]])
            gt_ref_bone_len = np.linalg.norm(gtj[i][ref_bone_link[0]] - gtj[i][ref_bone_link[1]])
            scale = gt_ref_bone_len / pred_ref_bone_len

            # Joints 0..20: put the root on the GT root and scale offsets from it.
            pred_align[i][:21] = gtj[i][root_idx] + scale * (prj[i][:21] - prj[i][root_idx])

        return gtj, pred_align

    if key in ["do", "eo"]:
        # gtj :B*5*3
        # prj :B*N*3 with N > 20 (reduced to 5 joints below)
        prj_ = prj.copy()[:, [4, 8, 12, 16, 20], :]  # B*5*3

        gtj_valid = []
        prj_valid_align = []

        for i in range(prj_.shape[0]):
            # 5 flags: which ground-truth joints are present
            mask = ~(np.isnan(gtj[i][:, 0]))
            if mask.sum() < 2:
                continue

            prj_mask = prj_[i][mask]  # m*3
            gtj_mask = gtj[i][mask]  # m*3

            gtj_valid_center = np.mean(gtj_mask, 0)
            prj_valid_center = np.mean(prj_mask, 0)

            gtj_center_length = np.linalg.norm(gtj_mask - gtj_valid_center, axis=1).mean()
            prj_center_length = np.linalg.norm(prj_mask - prj_valid_center, axis=1).mean()
            scale = gtj_center_length / prj_center_length

            prj_valid_align_i = gtj_valid_center + scale * (prj_mask - prj_valid_center)

            gtj_valid.append(gtj_mask)
            prj_valid_align.append(prj_valid_align_i)

        # Samples can have different numbers of valid joints (ragged).
        return _stack_ragged(gtj_valid), _stack_ragged(prj_valid_align)

    # Previously fell through and returned None, which failed later at the caller.
    raise ValueError(
        "unsupported key {!r}; expected one of 'stb', 'rhd', 'do', 'eo'".format(key)
    )

# ------------------------------------------------------------------------------
# /utils/bone.py:
# ------------------------------------------------------------------------------
import config as cfg
import numpy as np
import torch


def caculate_length(j3d_, label=None):
    """Compute the lengths of the 21 bones of a hand skeleton.

    Bone ``i`` joins joint ``i`` and its parent ``cfg.SNAP_PARENT[i]``.

    :param j3d_: joint positions as a torch tensor or numpy array. It is
        squeezed to 2-D when it has extra dims and indexed as (joint, xyz).
    :param label: ``"full"`` for all 21 bone lengths or ``"useful"`` for the
        subset selected by ``cfg.USEFUL_BONE``.
    :return: bone lengths with a trailing singleton axis ((21, 1) for "full").
    :raises ValueError: for any other ``label``.
    """
    if isinstance(j3d_, torch.Tensor):
        j3d = j3d_.clone().detach().cpu().numpy()
    else:
        j3d = j3d_.copy()

    if len(j3d.shape) != 2:
        j3d = j3d.squeeze()

    # Vectorised: row i is joint i minus its parent joint.
    parents = [cfg.SNAP_PARENT[i] for i in range(21)]
    bone = j3d[:21] - j3d[parents]
    bone_len = np.linalg.norm(
        bone, ord=2, axis=-1, keepdims=True  # 21*1
    )

    if label == "full":
        return bone_len
    elif label == "useful":
        return bone_len[cfg.USEFUL_BONE]
    else:
        raise ValueError("{} not in ['full'|'useful']".format(label))

# ------------------------------------------------------------------------------
# /utils/eval/__pycache__/evalutils.cpython-37.pyc:
#   https://raw.githubusercontent.com/handopen/Minimal-Hand-pytorch/c078433136b20c1e4fb3f353254edd409ea65f1b/utils/eval/__pycache__/evalutils.cpython-37.pyc
# /utils/eval/__pycache__/zimeval.cpython-37.pyc:
#   https://raw.githubusercontent.com/handopen/Minimal-Hand-pytorch/c078433136b20c1e4fb3f353254edd409ea65f1b/utils/eval/__pycache__/zimeval.cpython-37.pyc
# ------------------------------------------------------------------------------
# /utils/eval/evalutils.py:
# ------------------------------------------------------------------------------
# 1 | import math 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | from random import
randint 5 | 6 | from utils.misc import * 7 | from utils.heatmaputils import get_heatmap_pred 8 | 9 | 10 | class AverageMeter(object): 11 | """Computes and stores the average and current value""" 12 | 13 | def __init__(self): 14 | self.reset() 15 | 16 | def reset(self): 17 | self.val = 0. 18 | self.avg = 0. 19 | self.sum = 0. 20 | self.count = 0. 21 | 22 | def update(self, val, n=1): 23 | self.val = val 24 | self.sum += val * n 25 | self.count += n 26 | self.avg = self.sum / self.count 27 | 28 | 29 | def calc_dists(preds, target, normalize, mask): 30 | preds = preds.float() 31 | target = target.float() 32 | dists = torch.zeros(preds.size(1), preds.size(0)) # (njoint, B) 33 | for b in range(preds.size(0)): 34 | for j in range(preds.size(1)): 35 | if mask[b][j] == 0: 36 | dists[j, b] = -1 37 | elif target[b, j, 0] < 1 or target[b, j, 1] < 1: 38 | dists[j, b] = -1 39 | else: 40 | dists[j, b] = torch.dist(preds[b, j, :], target[b, j, :]) / normalize[b] 41 | 42 | return dists 43 | 44 | 45 | def dist_acc(dist, thr=0.5): 46 | """ Return percentage below threshold while ignoring values with a -1 """ 47 | dist = dist[dist != -1] 48 | if len(dist) > 0: 49 | return 1.0 * (dist < thr).sum().item() / len(dist) 50 | else: 51 | return -1 52 | 53 | 54 | def accuracy_heatmap(output, target, mask, thr=0.5): 55 | """ Calculate accuracy according to PCK, but uses ground truth heatmap rather than x,y locations 56 | First to be returned is average accuracy across 'idxs', Second is individual accuracies 57 | """ 58 | preds = get_heatmap_pred(output).float() # (B, njoint, 2) 59 | gts = get_heatmap_pred(target).float() 60 | norm = torch.ones(preds.size(0)) * output.size(3) / 10.0 # (B, ), all 6.4:(1/10 of heatmap side) 61 | dists = calc_dists(preds, gts, norm, mask) # (njoint, B) 62 | 63 | acc = torch.zeros(mask.size(1)) 64 | avg_acc = 0 65 | cnt = 0 66 | 67 | for i in range(mask.size(1)): # njoint 68 | acc[i] = dist_acc(dists[i], thr) 69 | if acc[i] >= 0: 70 | avg_acc += acc[i] 71 | cnt += 
1 72 | 73 | if cnt != 0: 74 | avg_acc /= cnt 75 | 76 | return avg_acc, acc 77 | -------------------------------------------------------------------------------- /utils/eval/zimeval.py: -------------------------------------------------------------------------------- 1 | # ColorHandPose3DNetwork - Network for estimating 3D Hand Pose from a single RGB Image 2 | # Copyright (C) 2017 Christian Zimmermann 3 | # 4 | # This program is free software: you can redistribute it and/or modify 5 | # it under the terms of the GNU General Public License as published by 6 | # the Free Software Foundation, either version 2 of the License, or 7 | # (at your option) any later version. 8 | # 9 | # This program is distributed in the hope that it will be useful, 10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | # GNU General Public License for more details. 13 | # 14 | # You should have received a copy of the GNU General Public License 15 | # along with this program. If not, see . 16 | 17 | import numpy as np 18 | import torch 19 | 20 | 21 | class EvalUtil: 22 | """ Util class for evaluation networks. 23 | """ 24 | 25 | def __init__(self, num_kp=21): 26 | # init empty data storage 27 | self.data = list() 28 | self.num_kp = num_kp 29 | for _ in range(num_kp): 30 | self.data.append(list()) 31 | 32 | def feed(self, keypoint_gt, keypoint_pred, keypoint_vis=None): 33 | """ 34 | Used to feed data to the class. 35 | Stores the euclidean distance between gt and pred, when it is visible. 
36 | """ 37 | if isinstance(keypoint_gt, torch.Tensor): 38 | keypoint_gt = keypoint_gt.detach().cpu() 39 | keypoint_gt = keypoint_gt.numpy() 40 | if isinstance(keypoint_pred, torch.Tensor): 41 | keypoint_pred = keypoint_pred.detach().cpu() 42 | keypoint_pred = keypoint_pred.numpy() 43 | keypoint_gt = np.squeeze(keypoint_gt) 44 | keypoint_pred = np.squeeze(keypoint_pred) 45 | 46 | if keypoint_vis is None: 47 | keypoint_vis = np.ones_like(keypoint_gt[:, 0]) 48 | keypoint_vis = np.squeeze(keypoint_vis).astype("bool") 49 | 50 | assert len(keypoint_gt.shape) == 2 51 | assert len(keypoint_pred.shape) == 2 52 | assert len(keypoint_vis.shape) == 1 53 | 54 | # calc euclidean distance 55 | diff = keypoint_gt - keypoint_pred 56 | euclidean_dist = np.sqrt(np.sum(np.square(diff), axis=1)) 57 | 58 | num_kp = keypoint_gt.shape[0] 59 | for i in range(num_kp): 60 | if keypoint_vis[i]: 61 | self.data[i].append(euclidean_dist[i]) 62 | 63 | def _get_pck(self, kp_id, threshold): 64 | """ Returns pck for one keypoint for the given threshold. """ 65 | if len(self.data[kp_id]) == 0: 66 | return None 67 | 68 | data = np.array(self.data[kp_id]) 69 | pck = np.mean((data <= threshold).astype("float")) 70 | return pck 71 | 72 | def get_pck_all(self, threshold): 73 | pckall = [] 74 | for kp_id in range(self.num_kp): 75 | pck = self._get_pck(kp_id, threshold) 76 | pckall.append(pck) 77 | pckall = np.mean(np.array(pckall)) 78 | return pckall 79 | 80 | def _get_epe(self, kp_id): 81 | """ Returns end point error for one keypoint. """ 82 | if len(self.data[kp_id]) == 0: 83 | return None, None 84 | 85 | data = np.array(self.data[kp_id]) 86 | epe_mean = np.mean(data) 87 | epe_median = np.median(data) 88 | return epe_mean, epe_median 89 | 90 | def get_measures(self, val_min, val_max, steps): 91 | """ Outputs the average mean and median error as well as the pck score. 
""" 92 | thresholds = np.linspace(val_min, val_max, steps) 93 | thresholds = np.array(thresholds) 94 | norm_factor = np.trapz(np.ones_like(thresholds), thresholds) 95 | 96 | # init mean measures 97 | epe_mean_all = list() 98 | epe_median_all = list() 99 | auc_all = list() 100 | pck_curve_all = list() 101 | 102 | # Create one plot for each part 103 | for part_id in range(self.num_kp): 104 | # mean/median error 105 | mean, median = self._get_epe(part_id) 106 | 107 | if mean is None: 108 | # there was no valid measurement for this keypoint 109 | continue 110 | 111 | epe_mean_all.append(mean) 112 | epe_median_all.append(median) 113 | 114 | # pck/auc 115 | pck_curve = list() 116 | for t in thresholds: 117 | pck = self._get_pck(part_id, t) 118 | pck_curve.append(pck) 119 | 120 | pck_curve = np.array(pck_curve) 121 | pck_curve_all.append(pck_curve) 122 | auc = np.trapz(pck_curve, thresholds) 123 | auc /= norm_factor 124 | auc_all.append(auc) 125 | # Display error per keypoint 126 | epe_mean_joint = epe_mean_all 127 | epe_mean_all = np.mean(np.array(epe_mean_all)) 128 | epe_median_all = np.mean(np.array(epe_median_all)) 129 | auc_all = np.mean(np.array(auc_all)) 130 | pck_curve_all = np.mean(np.array(pck_curve_all), axis=0) # mean only over keypoints 131 | 132 | return ( 133 | epe_mean_all, 134 | epe_mean_joint, 135 | epe_median_all, 136 | auc_all, 137 | pck_curve_all, 138 | thresholds, 139 | ) 140 | 141 | # return { 142 | # 'epe_mean_all': epe_mean_all, 143 | # 'epe_mean_joint': epe_mean_joint, 144 | # "epe_median_all": epe_median_all, 145 | # "auc_all": auc_all, 146 | # "pck_curve_all": pck_curve_all, 147 | # "thresholds": thresholds 148 | # } 149 | -------------------------------------------------------------------------------- /utils/func.py: -------------------------------------------------------------------------------- 1 | from torchvision.transforms.functional import * 2 | 3 | 4 | def batch_denormalize(tensor, mean, std, inplace=False): 5 | """Normalize a tensor 
image with mean and standard deviation. 6 | 7 | .. note:: 8 | This transform acts out_testset of place by default, i.e., it does not mutates the input tensor. 9 | 10 | See :class:`~torchvision.transforms.Normalize` for more details. 11 | 12 | Args: 13 | tensor (Tensor): Tensor image of size (B, C, H, W) to be normalized. 14 | mean (sequence): Sequence of means for each channel. 15 | std (sequence): Sequence of standard deviations for each channel. 16 | inplace(bool,optional): Bool to make this operation inplace. 17 | 18 | Returns: 19 | Tensor: Normalized Tensor image. 20 | """ 21 | if not torch.is_tensor(tensor) or tensor.ndimension() != 4: 22 | raise TypeError('invalid tensor or tensor channel is not BCHW') 23 | 24 | if not inplace: 25 | tensor = tensor.clone() 26 | 27 | dtype = tensor.dtype 28 | mean = torch.as_tensor(mean, dtype=dtype, device=tensor.device) 29 | std = torch.as_tensor(std, dtype=dtype, device=tensor.device) 30 | tensor.mul_(std[None, :, None, None]).sub_(-1 * mean[None, :, None, None]) 31 | return tensor 32 | 33 | 34 | def to_numpy(tensor): 35 | if torch.is_tensor(tensor): 36 | return tensor.detach().cpu().numpy() 37 | elif type(tensor).__module__ != 'numpy': 38 | raise ValueError("Cannot convert {} to numpy array" 39 | .format(type(tensor))) 40 | else: 41 | return tensor 42 | 43 | 44 | def bhwc_2_bchw(tensor): 45 | """ 46 | :param x: torch tensor, B x H x W x C 47 | :return: torch tensor, B x C x H x W 48 | """ 49 | if not torch.is_tensor(tensor) or tensor.ndimension() != 4: 50 | raise TypeError('invalid tensor or tensor channel is not BCHW') 51 | return tensor.unsqueeze(1).transpose(1, -1).squeeze(-1) 52 | 53 | 54 | def bchw_2_bhwc(tensor): 55 | """ 56 | :param x: torch tensor, B x C x H x W 57 | :return: torch tensor, B x H x W x C 58 | """ 59 | if not torch.is_tensor(tensor) or tensor.ndimension() != 4: 60 | raise TypeError('invalid tensor or tensor channel is not BCHW') 61 | return tensor.unsqueeze(-1).transpose(1, -1).squeeze(1) 62 | 63 | 
def initiate(label=None): 64 | if label == "zero": 65 | shape = torch.zeros(10).unsqueeze(0) 66 | pose = torch.zeros(48).unsqueeze(0) 67 | elif label == "uniform": 68 | shape = torch.from_numpy(np.random.normal(size=[1, 10])).float() 69 | pose = torch.from_numpy(np.random.normal(size=[1, 48])).float() 70 | elif label == "01": 71 | shape = torch.rand(1, 10) 72 | pose = torch.rand(1, 48) 73 | else: 74 | raise ValueError("{} not in ['zero'|'uniform'|'01']".format(label)) 75 | return pose, shape 76 | -------------------------------------------------------------------------------- /utils/heatmaputils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Lixin YANG, Jiasen Li. All Rights Reserved. 2 | import torch 3 | import numpy as np 4 | 5 | 6 | def gen_heatmap(img, pt, sigma): 7 | """generate heatmap based on pt coord. 8 | 9 | :param img: original heatmap, zeros 10 | :type img: np (H,W) float32 11 | :param pt: keypoint coord. 12 | :type pt: np (2,) int32 13 | :param sigma: guassian sigma 14 | :type sigma: float 15 | :return 16 | - generated heatmap, np (H, W) each pixel values id a probability 17 | - flag 0 or 1: indicate wheather this heatmap is valid(1) 18 | 19 | """ 20 | 21 | pt = pt.astype(np.int32) 22 | # Check that any part of the gaussian is in-bounds 23 | ul = [int(pt[0] - 3 * sigma), int(pt[1] - 3 * sigma)] 24 | br = [int(pt[0] + 3 * sigma + 1), int(pt[1] + 3 * sigma + 1)] 25 | if ( 26 | ul[0] >= img.shape[1] 27 | or ul[1] >= img.shape[0] 28 | or br[0] < 0 29 | or br[1] < 0 30 | ): 31 | # If not, just return the image as is 32 | print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") 33 | return img, 0 34 | 35 | # Generate gaussian 36 | size = 6 * sigma + 1 37 | x = np.arange(0, size, 1, float) 38 | y = x[:, np.newaxis] 39 | x0 = y0 = size // 2 40 | g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2)) 41 | # Usable gaussian range 42 | g_x = max(0, -ul[0]), min(br[0], img.shape[1]) - 
ul[0] 43 | g_y = max(0, -ul[1]), min(br[1], img.shape[0]) - ul[1] 44 | # Image range 45 | img_x = max(0, ul[0]), min(br[0], img.shape[1]) 46 | img_y = max(0, ul[1]), min(br[1], img.shape[0]) 47 | 48 | img[img_y[0]:img_y[1], img_x[0]:img_x[1]] = g[g_y[0]:g_y[1], g_x[0]:g_x[1]] 49 | return img, 1 50 | 51 | 52 | def get_heatmap_pred(heatmaps): 53 | """ get predictions from heatmaps in torch Tensor 54 | return type: torch.LongTensor 55 | """ 56 | assert heatmaps.dim() == 4, 'Score maps should be 4-dim (B, nJoints, H, W)' 57 | maxval, idx = torch.max(heatmaps.view(heatmaps.size(0), heatmaps.size(1), -1), 2) 58 | 59 | maxval = maxval.view(heatmaps.size(0), heatmaps.size(1), 1) 60 | idx = idx.view(heatmaps.size(0), heatmaps.size(1), 1) 61 | 62 | preds = idx.repeat(1, 1, 2).float() # (B, njoint, 2) 63 | 64 | preds[:, :, 0] = (preds[:, :, 0]) % heatmaps.size(3) # + 1 65 | preds[:, :, 1] = torch.floor((preds[:, :, 1]) / heatmaps.size(3)) # + 1 66 | 67 | pred_mask = maxval.gt(0).repeat(1, 1, 2).float() 68 | preds *= pred_mask 69 | return preds 70 | -------------------------------------------------------------------------------- /utils/imgutils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import cv2 3 | import numpy as np 4 | import random 5 | import torchvision 6 | import utils.func as func 7 | import config as cfg 8 | 9 | 10 | def get_color_params(brightness=0, contrast=0, saturation=0, hue=0): 11 | if brightness > 0: 12 | brightness_factor = random.uniform( 13 | max(0, 1 - brightness), 1 + brightness) 14 | else: 15 | brightness_factor = None 16 | 17 | if contrast > 0: 18 | contrast_factor = random.uniform(max(0, 1 - contrast), 1 + contrast) 19 | else: 20 | contrast_factor = None 21 | 22 | if saturation > 0: 23 | saturation_factor = random.uniform( 24 | max(0, 1 - saturation), 1 + saturation) 25 | else: 26 | saturation_factor = None 27 | 28 | if hue > 0: 29 | hue_factor = random.uniform(-hue, hue) 30 | else: 31 | 
hue_factor = None 32 | return brightness_factor, contrast_factor, saturation_factor, hue_factor 33 | 34 | 35 | def color_jitter(img, brightness=0, contrast=0, saturation=0, hue=0): 36 | brightness, contrast, saturation, hue = get_color_params( 37 | brightness=brightness, 38 | contrast=contrast, 39 | saturation=saturation, 40 | hue=hue) 41 | 42 | # Create img transform function sequence 43 | img_transforms = [] 44 | if brightness is not None: 45 | img_transforms.append(lambda img: torchvision.transforms.functional.adjust_brightness(img, brightness)) 46 | if saturation is not None: 47 | img_transforms.append(lambda img: torchvision.transforms.functional.adjust_saturation(img, saturation)) 48 | if hue is not None: 49 | img_transforms.append( 50 | lambda img: torchvision.transforms.functional.adjust_hue(img, hue)) 51 | if contrast is not None: 52 | img_transforms.append(lambda img: torchvision.transforms.functional.adjust_contrast(img, contrast)) 53 | random.shuffle(img_transforms) 54 | 55 | jittered_img = img 56 | for func in img_transforms: 57 | jittered_img = func(jittered_img) 58 | return jittered_img 59 | 60 | 61 | def batch_with_dep(clrs, deps): 62 | clrs = func.to_numpy(clrs) 63 | if clrs.dtype is not np.uint8: 64 | clrs = (clrs * 255).astype(np.uint8) 65 | assert len(deps.shape) == 4, "deps should have shape (B, 1, H, W)" 66 | deps = func.to_numpy(deps) 67 | deps = deps.swapaxes(1, 2).swapaxes(2, 3) 68 | deps = deps.repeat(3, axis=3) 69 | if deps.dtype is not np.uint8: 70 | deps = (deps * 255).astype(np.uint8) 71 | 72 | batch_size = clrs.shape[0] 73 | 74 | alpha = 0.6 75 | beta = 0.9 76 | gamma = 0 77 | 78 | batch = [] 79 | for i in range(16): 80 | if i >= batch_size: 81 | batch.append(np.zeros((64, 64, 3)).astype(np.uint8)) 82 | continue 83 | clr = clrs[i] 84 | clr = cv2.resize(clr, (64, 64)) 85 | dep = deps[i] 86 | dep_img = cv2.addWeighted(clr, alpha, dep, beta, gamma) 87 | batch.append(dep_img) 88 | 89 | resu = [] 90 | for i in range(4): 91 | 
resu.append(np.concatenate(batch[i * 4: i * 4 + 4], axis=1)) 92 | resu = np.concatenate(resu) 93 | return resu 94 | 95 | 96 | def batch_with_joint(clrs, uvds): 97 | clrs = func.to_numpy(clrs) 98 | if clrs.dtype is not np.uint8: 99 | clrs = (clrs * 255).astype(np.uint8) 100 | uvds = func.to_numpy(uvds) 101 | 102 | batch_size = clrs.shape[0] 103 | 104 | batch = [] 105 | for i in range(16): 106 | if i >= batch_size: 107 | batch.append(np.zeros((256, 256, 3)).astype(np.uint8)) 108 | continue 109 | clr = clrs[i] 110 | uv = (np.array(uvds[i][:, :2]) * clr.shape[0]).astype(np.uint8) # (256) 111 | clr = draw_hand_skeloten(clr, uv, cfg.SNAP_BONES) 112 | batch.append(clr) 113 | 114 | resu = [] 115 | for i in range(4): 116 | resu.append(np.concatenate(batch[i * 4: i * 4 + 4], axis=1)) 117 | resu = np.concatenate(resu) 118 | return resu 119 | 120 | 121 | def draw_hand_skeloten(clr, uv, bone_links, colors=cfg.JOINT_COLORS): 122 | for i in range(len(bone_links)): 123 | bone = bone_links[i] 124 | for j in bone: 125 | cv2.circle(clr, tuple(uv[j]), 4, colors[i], -1) 126 | for j, nj in zip(bone[:-1], bone[1:]): 127 | cv2.line(clr, tuple(uv[j]), tuple(uv[nj]), colors[i], 2) 128 | return clr 129 | 130 | 131 | def batch_with_heatmap( 132 | inputs, 133 | heatmaps, 134 | num_rows=2, 135 | parts_to_show=None, 136 | n_in_batch=1, 137 | ): 138 | # inputs = func.to_numpy(inputs * 255) # 0~1 -> 0 ~255 139 | heatmaps = func.to_numpy(heatmaps) 140 | batch_img = [] 141 | for n in range(min(inputs.shape[0], n_in_batch)): 142 | inp = inputs[n] 143 | batch_img.append( 144 | sample_with_heatmap( 145 | inp, 146 | heatmaps[n], 147 | num_rows=num_rows, 148 | parts_to_show=parts_to_show 149 | ) 150 | ) 151 | resu = np.concatenate(batch_img) 152 | return resu 153 | 154 | 155 | def sample_with_heatmap(img, heatmap, num_rows=2, parts_to_show=None): 156 | if parts_to_show is None: 157 | parts_to_show = np.arange(heatmap.shape[0]) # 21 158 | 159 | # Generate a single image to display input/output pair 160 | 
num_cols = int(np.ceil(float(len(parts_to_show)) / num_rows)) 161 | size = img.shape[0] // num_rows 162 | 163 | full_img = np.zeros((img.shape[0], size * (num_cols + num_rows), 3), np.uint8) 164 | full_img[:img.shape[0], :img.shape[1]] = img 165 | 166 | inp_small = cv2.resize(img, (size, size)) 167 | 168 | # Set up heatmap display for each part 169 | for i, part in enumerate(parts_to_show): 170 | part_idx = part 171 | out_resized = cv2.resize(heatmap[part_idx], (size, size)) 172 | out_resized = out_resized.astype(float) 173 | out_img = inp_small.copy() * .4 174 | color_hm = color_heatmap(out_resized) 175 | out_img += color_hm * .6 176 | 177 | col_offset = (i % num_cols + num_rows) * size 178 | row_offset = (i // num_cols) * size 179 | full_img[row_offset:row_offset + size, col_offset:col_offset + size] = out_img 180 | 181 | return full_img 182 | 183 | 184 | def color_heatmap(x): 185 | color = np.zeros((x.shape[0], x.shape[1], 3)) 186 | color[:, :, 0] = gauss(x, .5, .6, .2) + gauss(x, 1, .8, .3) 187 | color[:, :, 1] = gauss(x, 1, .5, .3) 188 | color[:, :, 2] = gauss(x, 1, .2, .3) 189 | color[color > 1] = 1 190 | color = (color * 255).astype(np.uint8) 191 | return color 192 | 193 | 194 | def gauss(x, a, b, c, d=0): 195 | return a * np.exp(-(x - b) ** 2 / (2 * c ** 2)) + d 196 | -------------------------------------------------------------------------------- /utils/misc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | 4 | import numpy as np 5 | import scipy.io 6 | import torch 7 | from termcolor import colored, cprint 8 | 9 | import utils.func as func 10 | import copy 11 | 12 | 13 | def print_args(args): 14 | opts = vars(args) 15 | cprint("{:>30} Options {}".format("=" * 15, "=" * 15), 'yellow') 16 | for k, v in sorted(opts.items()): 17 | print("{:>30} : {}".format(k, v)) 18 | cprint("{:>30} Options {}".format("=" * 15, "=" * 15), 'yellow') 19 | 20 | 21 | def param_count(net): 22 | return 
sum(p.numel() for p in net.parameters()) / 1e6 23 | 24 | 25 | 26 | 27 | def out_loss_auc( 28 | loss_all_, auc_all_, acc_hm_all_, outpath 29 | ): 30 | loss_all = copy.deepcopy(loss_all_) 31 | acc_hm_all = copy.deepcopy(acc_hm_all_) 32 | auc_all = copy.deepcopy(auc_all_) 33 | 34 | for k, l in zip(loss_all.keys(), loss_all.values()): 35 | np.save(os.path.join(outpath, "{}.npy".format(k)), np.vstack((np.arange(1, len(l) + 1), np.array(l))).T) 36 | 37 | if len(acc_hm_all): 38 | for key ,value in acc_hm_all.items(): 39 | acc_hm_all[key]=np.array(value) 40 | np.save(os.path.join(outpath, "acc_hm_all.npy"), acc_hm_all) 41 | 42 | 43 | if len(auc_all): 44 | for key ,value in auc_all.items(): 45 | auc_all[key]=np.array(value) 46 | np.save(os.path.join(outpath, "auc_all.npy"), np.array(auc_all)) 47 | 48 | 49 | def saveloss(d): 50 | for k, v in zip(d.keys(), d.values()): 51 | mat = np.array(v) 52 | np.save(os.path.join("losses", "{}.npy".format(k)), mat) 53 | 54 | 55 | def save_checkpoint( 56 | state, 57 | checkpoint='checkpoint', 58 | filename='checkpoint.pth', 59 | snapshot=None, 60 | # is_best=False 61 | is_best=None 62 | ): 63 | # preds = to_numpy(preds) 64 | filepath = os.path.join(checkpoint, filename) 65 | fileprefix = filename.split('.')[0] 66 | # torch.save(state, filepath) 67 | torch.save(state['model'].state_dict(), filepath) 68 | 69 | if snapshot and state['epoch'] % snapshot == 0: 70 | shutil.copyfile( 71 | filepath, 72 | os.path.join( 73 | checkpoint, 74 | '{}_{}.pth'.format(fileprefix, state['epoch']) 75 | ) 76 | ) 77 | 78 | [auc, best_acc] = is_best 79 | 80 | for key in auc.keys(): 81 | if auc[key] > best_acc[key]: 82 | shutil.copyfile( 83 | filepath, 84 | os.path.join( 85 | checkpoint, 86 | '{}_{}best.pth'.format(fileprefix, key) 87 | ) 88 | ) 89 | 90 | 91 | # def load_checkpoint(model, checkpoint): 92 | # name = checkpoint 93 | # checkpoint = torch.load(name) 94 | # pretrain_dict = clean_state_dict(checkpoint['state_dict']) 95 | # model_state = 
model.state_dict() 96 | # state = {} 97 | # for k, v in pretrain_dict.items(): 98 | # if k in model_state: 99 | # state[k] = v 100 | # else: 101 | # print(k, ' is NOT in current model') 102 | # model_state.update(state) 103 | # model.load_state_dict(model_state) 104 | # print(colored('loaded {}'.format(name), 'cyan')) 105 | 106 | def load_checkpoint(model, checkpoint): 107 | name = checkpoint 108 | checkpoint = torch.load(name) 109 | pretrain_dict = clean_state_dict(checkpoint['state_dict']) 110 | model_state = model.state_dict() 111 | state = {} 112 | for k, v in pretrain_dict.items(): 113 | if k in model_state: 114 | state[k] = v 115 | else: 116 | print(k, ' is NOT in current model') 117 | model_state.update(state) 118 | model.load_state_dict(model_state) 119 | print(colored('loaded {}'.format(name), 'cyan')) 120 | 121 | 122 | def clean_state_dict(state_dict): 123 | """save a cleaned version of model without dict and DataParallel 124 | 125 | Arguments: 126 | state_dict {collections.OrderedDict} -- [description] 127 | 128 | Returns: 129 | clean_model {collections.OrderedDict} -- [description] 130 | """ 131 | 132 | clean_model = state_dict 133 | # create new OrderedDict that does not contain `module.` 134 | from collections import OrderedDict 135 | clean_model = OrderedDict() 136 | if any(key.startswith('module') for key in state_dict): 137 | for k, v in state_dict.items(): 138 | name = k[7:] # remove `module.` 139 | clean_model[name] = v 140 | else: 141 | return state_dict 142 | 143 | return clean_model 144 | 145 | 146 | def save_pred(preds, checkpoint='checkpoint', filename='preds_valid.mat'): 147 | preds = func.to_numpy(preds) 148 | filepath = os.path.join(checkpoint, filename) 149 | scipy.io.savemat(filepath, mdict={'preds': preds}) 150 | 151 | 152 | def adjust_learning_rate(optimizer, epoch, lr, schedule, gamma): 153 | """Sets the learning rate to the initial LR decayed by schedule""" 154 | if epoch in schedule: 155 | lr *= gamma 156 | print("adjust learning 
rate to: %.3e" % lr) 157 | for param_group in optimizer.param_groups: 158 | param_group['lr'] = lr 159 | return lr 160 | 161 | 162 | def adjust_learning_rate_in_group(optimizer, group_id, epoch, lr, schedule, gamma): 163 | """Sets the learning rate to the initial LR decayed by schedule""" 164 | if epoch in schedule: 165 | lr *= gamma 166 | print("adjust learning rate of group %d to: %.3e" % (group_id, lr)) 167 | optimizer.param_groups[group_id]['lr'] = lr 168 | return lr 169 | 170 | 171 | def resume_learning_rate(optimizer, epoch, lr, schedule, gamma): 172 | for decay_id in schedule: 173 | if epoch > decay_id: 174 | lr *= gamma 175 | print("adjust learning rate to: %.3e" % lr) 176 | for param_group in optimizer.param_groups: 177 | param_group['lr'] = lr 178 | return lr 179 | 180 | 181 | def resume_learning_rate_in_group(optimizer, group_id, epoch, lr, schedule, gamma): 182 | for decay_id in schedule: 183 | if epoch > decay_id: 184 | lr *= gamma 185 | print("adjust learning rate of group %d to: %.3e" % (group_id, lr)) 186 | optimizer.param_groups[group_id]['lr'] = lr 187 | return lr 188 | -------------------------------------------------------------------------------- /utils/smoother.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class LowPassFilter: 5 | def __init__(self): 6 | self.prev_raw_value = None 7 | self.prev_filtered_value = None 8 | 9 | def process(self, value, alpha): 10 | if self.prev_raw_value is None: 11 | s = value 12 | else: 13 | s = alpha * value + (1.0 - alpha) * self.prev_filtered_value 14 | self.prev_raw_value = value 15 | self.prev_filtered_value = s 16 | return s 17 | 18 | 19 | class OneEuroFilter: 20 | def __init__(self, mincutoff=1.0, beta=0.0, dcutoff=1.0, freq=30): 21 | self.freq = freq 22 | self.mincutoff = mincutoff 23 | self.beta = beta 24 | self.dcutoff = dcutoff 25 | self.x_filter = LowPassFilter() 26 | self.dx_filter = LowPassFilter() 27 | 28 | def 
compute_alpha(self, cutoff): 29 | te = 1.0 / self.freq 30 | tau = 1.0 / (2 * np.pi * cutoff) 31 | return 1.0 / (1.0 + tau / te) 32 | 33 | def process(self, x): 34 | prev_x = self.x_filter.prev_raw_value 35 | dx = 0.0 if prev_x is None else (x - prev_x) * self.freq 36 | edx = self.dx_filter.process(dx, self.compute_alpha(self.dcutoff)) 37 | cutoff = self.mincutoff + self.beta * np.abs(edx) 38 | return self.x_filter.process(x, self.compute_alpha(cutoff)) 39 | 40 | 41 | if __name__ == '__main__': 42 | fliter = OneEuroFilter(4.0, 0.0) 43 | noise = 0.01 * np.random.rand(1000) 44 | x = np.linspace(0, 1, 1000) 45 | X = x + noise 46 | import matplotlib.pyplot as plt 47 | 48 | plt.plot(x) 49 | plt.plot(X) 50 | y = np.zeros((1000,)) 51 | for i in range(1000): 52 | y[i] = fliter.process(x[i]) 53 | plt.plot(y) 54 | plt.draw() 55 | plt.show() 56 | -------------------------------------------------------------------------------- /utils/vis.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | 3 | 4 | def plot3d(joints_,ax, title=None): 5 | joints = joints_.copy() 6 | ax.plot(joints[:, 0], joints[:, 1], joints[:, 2], 'yo', label='keypoint') 7 | 8 | ax.plot(joints[:5, 0], joints[:5, 1], 9 | joints[:5, 2], 10 | 'r', 11 | label='thumb') 12 | 13 | ax.plot(joints[[0, 5, 6, 7, 8, ], 0], joints[[0, 5, 6, 7, 8, ], 1], 14 | joints[[0, 5, 6, 7, 8, ], 2], 15 | 'b', 16 | label='index') 17 | ax.plot(joints[[0, 9, 10, 11, 12, ], 0], joints[[0, 9, 10, 11, 12], 1], 18 | joints[[0, 9, 10, 11, 12], 2], 19 | 'b', 20 | label='middle') 21 | ax.plot(joints[[0, 13, 14, 15, 16], 0], joints[[0, 13, 14, 15, 16], 1], 22 | joints[[0, 13, 14, 15, 16], 2], 23 | 'b', 24 | label='ring') 25 | ax.plot(joints[[0, 17, 18, 19, 20], 0], joints[[0, 17, 18, 19, 20], 1], 26 | joints[[0, 17, 18, 19, 20], 2], 27 | 'b', 28 | label='pinky') 29 | # snap convention 30 | ax.plot(joints[4][0], joints[4][1], joints[4][2], 'rD', label='thumb') 31 | 
ax.plot(joints[8][0], joints[8][1], joints[8][2], 'ro', label='index') 32 | ax.plot(joints[12][0], joints[12][1], joints[12][2], 'ro', label='middle') 33 | ax.plot(joints[16][0], joints[16][1], joints[16][2], 'ro', label='ring') 34 | ax.plot(joints[20][0], joints[20][1], joints[20][2], 'ro', label='pinky') 35 | # plt.plot(joints [1:, 0], joints [1:, 1], joints [1:, 2], 'o') 36 | ax.set_xlabel('x') 37 | ax.set_ylabel('y') 38 | ax.set_zlabel('z') 39 | ax.set_xlim(xmin=-1.0,xmax=1.0) 40 | ax.set_ylim(ymin=-1.0,ymax=1.0) 41 | ax.set_zlim(zmin=-1.0,zmax=1.0) 42 | # plt.legend() 43 | # ax.view_init(330, 110) 44 | ax.view_init(-90, -90) 45 | return ax 46 | 47 | 48 | def multi_plot3d(jointss_, title=None): 49 | jointss = jointss_.copy() 50 | fig = plt.figure(figsize=[50, 50]) 51 | 52 | ax = fig.add_subplot(111, projection='3d') 53 | 54 | colors = ['b', 'r', "g"] 55 | 56 | for i in range(len(jointss)): 57 | joints = jointss[i] 58 | 59 | plt.plot(joints[:, 0], joints[:, 1], joints[:, 2], 'yo') 60 | 61 | plt.plot(joints[:5, 0], joints[:5, 1], 62 | joints[:5, 2], 63 | colors[i], 64 | ) 65 | 66 | plt.plot(joints[[0, 5, 6, 7, 8, ], 0], joints[[0, 5, 6, 7, 8, ], 1], 67 | joints[[0, 5, 6, 7, 8, ], 2], 68 | colors[i], 69 | ) 70 | plt.plot(joints[[0, 9, 10, 11, 12, ], 0], joints[[0, 9, 10, 11, 12], 1], 71 | joints[[0, 9, 10, 11, 12], 2], 72 | colors[i], 73 | ) 74 | plt.plot(joints[[0, 13, 14, 15, 16], 0], joints[[0, 13, 14, 15, 16], 1], 75 | joints[[0, 13, 14, 15, 16], 2], 76 | colors[i], 77 | ) 78 | plt.plot(joints[[0, 17, 18, 19, 20], 0], joints[[0, 17, 18, 19, 20], 1], 79 | joints[[0, 17, 18, 19, 20], 2], 80 | colors[i], 81 | ) 82 | 83 | ####### 84 | # plt.plot(joints[:1, 0], joints[:1, 1], 85 | # joints[:1, 2], 86 | # colors[i], 87 | # ) 88 | # 89 | # plt.plot(joints[[0, 5, ], 0], joints[[0, 5, ], 1], 90 | # joints[[0, 5, ], 2], 91 | # colors[i], 92 | # ) 93 | # plt.plot(joints[[0, 9, ], 0], joints[[0, 9, ], 1], 94 | # joints[[0, 9,], 2], 95 | # colors[i], 96 | # ) 97 | # 
plt.plot(joints[[0, 13, ], 0], joints[[0, 13, ], 1], 98 | # joints[[0, 13, ], 2], 99 | # colors[i], 100 | # ) 101 | # plt.plot(joints[[0, 17, ], 0], joints[[0, 17, ], 1], 102 | # joints[[0, 17, ], 2], 103 | # colors[i], 104 | # ) 105 | 106 | # snap convention 107 | plt.plot(joints[4][0], joints[4][1], joints[4][2], 'rD') 108 | plt.plot(joints[8][0], joints[8][1], joints[8][2], 'ro', ) 109 | plt.plot(joints[12][0], joints[12][1], joints[12][2], 'ro', ) 110 | plt.plot(joints[16][0], joints[16][1], joints[16][2], 'ro', ) 111 | plt.plot(joints[20][0], joints[20][1], joints[20][2], 'ro', ) 112 | # plt.plot(joints [1:, 0], joints [1:, 1], joints [1:, 2], 'o') 113 | 114 | plt.title(title) 115 | ax.set_xlabel('x') 116 | ax.set_ylabel('y') 117 | ax.set_zlabel('z') 118 | plt.legend() 119 | # ax.view_init(330, 110) 120 | ax.view_init(-90, -90) 121 | 122 | if title: 123 | title_ = "" 124 | for i in range(len(title)): 125 | title_ += "{}: {} ".format(colors[i], title[i]) 126 | 127 | ax.set_title(title_, fontsize=12, color='black') 128 | else: 129 | ax.set_title("None", fontsize=12, color='black') 130 | plt.show() 131 | --------------------------------------------------------------------------------