├── .idea
│   ├── .gitignore
│   ├── Minimal-Hand-pytorch.iml
│   ├── inspectionProfiles
│   │   ├── Project_Default.xml
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── other.xml
│   └── vcs.xml
├── LICENSE
├── README.md
├── ShapeNet.md
├── __pycache__
│   └── config.cpython-37.pyc
├── aik_pose.py
├── assets
│   ├── DEMO2.gif
│   ├── demo.gif
│   └── results.png
├── config.py
├── create_data.py
├── datasets
│   ├── SIK1M.py
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-35.pyc
│   │   ├── __init__.cpython-37.pyc
│   │   ├── dexter_object.cpython-37.pyc
│   │   ├── egodexter.cpython-37.pyc
│   │   ├── ganerated_hands.cpython-37.pyc
│   │   ├── hand143_panopticdb.cpython-37.pyc
│   │   ├── hand_labels.cpython-37.pyc
│   │   ├── handataset.cpython-37.pyc
│   │   ├── rhd.cpython-37.pyc
│   │   └── stb.cpython-37.pyc
│   ├── dexter_object.py
│   ├── egodexter.py
│   ├── ganerated_hands.py
│   ├── hand143_panopticdb.py
│   ├── hand_labels.py
│   ├── handataset.py
│   ├── rhd.py
│   └── stb.py
├── demo.py
├── demo_dl.py
├── dl_shape_estimate.py
├── environment.yml
├── losses
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-37.pyc
│   │   └── detloss.cpython-37.pyc
│   ├── detloss.py
│   └── shape_loss.py
├── manopth
│   └── rotproj.py
├── model
│   ├── detnet
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   └── detnet.cpython-37.pyc
│   │   └── detnet.py
│   ├── helper
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   └── resnet_helper.cpython-37.pyc
│   │   └── resnet_helper.py
│   └── shape_net.py
├── op_pso.py
├── optimize_shape.py
├── plot.py
├── train_detnet.py
├── train_shape_net.py
└── utils
    ├── AIK.py
    ├── LM.py
    ├── LM_new.py
    ├── __init__.py
    ├── __pycache__
    │   ├── AIK.cpython-37.pyc
    │   ├── LM.cpython-37.pyc
    │   ├── __init__.cpython-37.pyc
    │   ├── align.cpython-37.pyc
    │   ├── bone.cpython-37.pyc
    │   ├── func.cpython-37.pyc
    │   ├── handutils.cpython-37.pyc
    │   ├── heatmaputils.cpython-37.pyc
    │   ├── imgutils.cpython-37.pyc
    │   ├── misc.cpython-37.pyc
    │   └── vis.cpython-37.pyc
    ├── align.py
    ├── bone.py
    ├── eval
    │   ├── __pycache__
    │   │   ├── evalutils.cpython-37.pyc
    │   │   └── zimeval.cpython-37.pyc
    │   ├── evalutils.py
    │   └── zimeval.py
    ├── func.py
    ├── handutils.py
    ├── heatmaputils.py
    ├── imgutils.py
    ├── misc.py
    ├── smoother.py
    └── vis.py
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Datasource local storage ignored files
5 | /dataSources/
6 | /dataSources.local.xml
7 | # Editor-based HTTP Client requests
8 | /httpRequests/
9 |
--------------------------------------------------------------------------------
/.idea/Minimal-Hand-pytorch.iml:
--------------------------------------------------------------------------------
(IDE configuration XML; markup stripped in this dump)
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/Project_Default.xml:
--------------------------------------------------------------------------------
(IDE configuration XML; markup stripped in this dump)
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
(IDE configuration XML; markup stripped in this dump)
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
(IDE configuration XML; markup stripped in this dump)
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
(IDE configuration XML; markup stripped in this dump)
--------------------------------------------------------------------------------
/.idea/other.xml:
--------------------------------------------------------------------------------
(IDE configuration XML; markup stripped in this dump)
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
(IDE configuration XML; markup stripped in this dump)
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Hao Meng
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Minimal Hand Pytorch
2 |
3 | **Unofficial** PyTorch reimplementation of [minimal-hand](https://calciferzh.github.io/files/zhou2020monocular.pdf) (CVPR2020).
4 |
5 | 
6 | 
7 |
8 | You can also find the demo video on YouTube or Bilibili.
9 |
18 | This project reimplements the following components:
19 | 1. Training (DetNet) and Evaluation Code
20 | 1. Shape Estimation
21 | 1. Pose Estimation: instead of the IKNet in the original paper, an [analytical inverse kinematics method](https://arxiv.org/abs/2011.14672) is used.
22 |
23 |
24 |
25 | Official project link:
26 | [\[minimal-hand\]](https://github.com/CalciferZh/minimal-hand)
27 |
28 | ## Update
30 | * 2021/08/22 Many people get errors when creating the environment from the `.yml` file; you may refer to [here](https://github.com/MengHao666/Minimal-Hand-pytorch/issues/29#issue-976328196)
31 | * 2021/03/09 update about `utils/LM.py`, **time cost drop from 12s/item to 1.57s/item**
32 |
33 | * 2021/03/12 update about `utils/LM.py`, **time cost drop from 1.57s/item to 0.27s/item**
34 |
35 | * 2021/03/17 Real-time performance is achieved when using PSO to estimate shape; coming soon
36 |
37 | * 2021/03/20 Add PSO to estimate shape. ~~AUC is decreased by about 0.01 on the STB and RHD datasets, and increased a little on the EO and DO datasets.~~ Modified `utils/vis.py` to improve real-time performance
38 |
39 | * 2021/03/24 Fixed some errors in calculating AUC. Updated the 3D PCK AUC Difference.
40 |
41 | * 2021/06/14 A new method that estimates shape parameters with a fully connected neural network is added. This was completed by @maitetsu as part of his undergraduate graduation project; please refer to [ShapeNet.md](./ShapeNet.md) for details. Thanks to @kishan1823 and @EEWenbinWu for pointing out the mistake. There are slight differences between the manopth I used and the official manopth; see [issue 11](https://github.com/MengHao666/Minimal-Hand-pytorch/issues/11) for details. manopth/rotproj.py is the modified rotproj.py. **This achieves much faster real-time performance!**
42 |
43 |
44 | ## Usage
45 |
46 | - Retrieve the code
47 | ```sh
48 | git clone https://github.com/MengHao666/Minimal-Hand-pytorch
49 | cd Minimal-Hand-pytorch
50 | ```
51 |
52 | - Create and activate the virtual environment with the Python dependencies
53 | ```
54 | conda env create --file=environment.yml
55 | conda activate minimal-hand-torch
56 | ```
57 |
58 | ### Prepare MANO hand model
59 | 1. Download MANO model from [here](https://mano.is.tue.mpg.de/) and unzip it.
60 | 1. Create an account by clicking *Sign Up* and provide your information
61 | 1. Download Models and Code (the downloaded file should have the format mano_v*_*.zip). Note that all code and data from this download falls under the [MANO license](http://mano.is.tue.mpg.de/license).
62 | 1. Unzip it and copy the content of the *models* folder into the `mano` folder
63 |
64 | 1. Your structure should look like this:
65 |
66 | ```
67 | Minimal-Hand-pytorch/
68 |     mano/
69 |         models/
70 |         webuser/
71 | ```
72 |
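If the layout above is right, the MANO layer should load without errors. A minimal sanity check, mirroring how `aik_pose.py` constructs the layer (the printed shapes are what manopth is expected to return):

```python
import torch
from manopth import manolayer

# same arguments as in aik_pose.py
mano = manolayer.ManoLayer(flat_hand_mean=True, side='right',
                           mano_root='mano/models', use_pca=False,
                           root_rot_mode='rotmat', joint_rot_mode='rotmat')
pose0 = torch.eye(3).repeat(1, 16, 1, 1)         # identity rotations = flat hand
verts, joints = mano(pose0, torch.zeros(1, 10))  # mean shape
print(verts.shape, joints.shape)                 # (1, 778, 3) and (1, 21, 3)
```
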
73 | ## Download and Prepare datasets
74 |
75 | ### Training dataset
76 | * CMU HandDB [part1](http://domedb.perception.cs.cmu.edu/panopticDB/hands/hand143_panopticdb.tar) ; [part2](http://domedb.perception.cs.cmu.edu/panopticDB/hands/hand_labels.zip)
77 | * [Rendered Handpose Dataset](https://lmb.informatik.uni-freiburg.de/resources/datasets/RenderedHandposeDataset.en.html)
78 | * [GANerated Hands Dataset](https://handtracker.mpi-inf.mpg.de/projects/GANeratedHands/GANeratedDataset.htm)
79 |
80 | ### Evaluation dataset
81 | * [STB Dataset](https://github.com/zhjwustc/icip17_stereo_hand_pose_dataset), or you can find it [here](https://bhpan.buaa.edu.cn:443/link/55321872BA66E9205C91BA30D9FADC8F):
82 |
83 | STB_supp: for license reasons, the download link can be found in [bihand](https://github.com/lixiny/bihand)
84 |
85 | * [DexterObjectDataset](https://handtracker.mpi-inf.mpg.de/projects/RealtimeHO/dexter+object.htm) ;
86 |
87 | DO_supp: [Google Drive](https://drive.google.com/file/d/1uhqJGfVJs_-Yviqj9Y2Ngo7NPt5hUihl/view?usp=sharing) or
88 | [Baidu Pan](https://pan.baidu.com/s/1ckfLnaBQUfZJG3IDvMo09Q) (`s892`)
89 | * [EgoDexterDataset](http://handtracker.mpi-inf.mpg.de/projects/OccludedHands/EgoDexter.htm)
90 |
91 | EO_supp: [Google Drive](https://drive.google.com/file/d/1WRHVTp7ZmryE41xN2Yhp-qet0ddeOim4/view?usp=sharing) or
92 | [Baidu Pan](https://pan.baidu.com/s/1sK4Nfvs6og-eXJGwDQCHlQ) (`axkm`)
93 |
94 |
95 |
96 | ### Processing
97 | - Create a `data` directory and extract all of the above datasets and supplementary materials into it
98 |
99 | Now your `data` folder structure should look like this:
100 | ```
101 | data/
102 |
103 |     CMU/
104 |         hand143_panopticdb/
105 |             datasets/
106 |             ...
107 |         hand_labels/
108 |             datasets/
109 |             ...
110 |
111 |     RHD/
112 |         RHD_published_v2/
113 |             evaluation/
114 |             training/
115 |             view_sample.py
116 |             ...
117 |
118 |     GANeratedHands_Release/
119 |         data/
120 |         ...
121 |
122 |     STB/
123 |         images/
124 |             B1Counting/
125 |                 SK_color_0.png
126 |                 SK_depth_0.png
127 |                 SK_depth_seg_0.png <-- merged from STB_supp
128 |                 ...
129 |             ...
130 |         labels/
131 |             B1Counting_BB.mat
132 |             ...
133 |
134 |     dexter+object/
135 |         calibration/
136 |         bbox_dexter+object.csv
137 |         DO_pred_2d.npy
138 |         data/
139 |             Grasp1/
140 |                 annotations/
141 |                     Grasp13D.txt
142 |                     my_Grasp13D.txt
143 |                     ...
144 |                 ...
145 |             Grasp2/
146 |                 annotations/
147 |                     Grasp23D.txt
148 |                     my_Grasp23D.txt
149 |                     ...
150 |                 ...
151 |             Occlusion/
152 |                 annotations/
153 |                     Occlusion3D.txt
154 |                     my_Occlusion3D.txt
155 |                     ...
156 |                 ...
157 |             Pinch/
158 |                 annotations/
159 |                     Pinch3D.txt
160 |                     my_Pinch3D.txt
161 |                     ...
162 |                 ...
163 |             Rigid/
164 |                 annotations/
165 |                     Rigid3D.txt
166 |                     my_Rigid3D.txt
167 |                     ...
168 |                 ...
169 |             Rotate/
170 |                 annotations/
171 |                     Rotate3D.txt
172 |                     my_Rotate3D.txt
173 |                     ...
174 |                 ...
175 |
176 |
177 |     EgoDexter/
178 |         preview/
179 |         data/
180 |             Desk/
181 |                 annotation.txt_3D.txt
182 |                 my_annotation.txt_3D.txt
183 |                 ...
184 |             Fruits/
185 |                 annotation.txt_3D.txt
186 |                 my_annotation.txt_3D.txt
187 |                 ...
188 |             Kitchen/
189 |                 annotation.txt_3D.txt
190 |                 my_annotation.txt_3D.txt
191 |                 ...
192 |             Rotunda/
193 |                 annotation.txt_3D.txt
194 |                 my_annotation.txt_3D.txt
195 |                 ...
196 |
197 | ```
198 |
199 | ### Note
200 | - **All code and data from these download falls under their own licenses.**
201 | - DO represents "dexter+object" dataset; EO represents "EgoDexter" dataset
202 | - `DO_supp` and `EO_supp` are modified from original ones.
203 | - DO_pred_2d.npy are 2D predictions from 2D part of DetNet.
204 | - some labels of DO and EO is obviously wrong (u could find some examples with original labels from [dexter_object.py](datasets/dexter_object.py) or [egodexter.py](datasets/egodexter.py)), when projected into image plane, thus should be omitted.
205 | Here come `my_{}3D.txt` and `my_annotation.txt_3D.txt`.
206 |
207 | ## Download my Results
208 |
209 | - my_results: [Google Drive](https://drive.google.com/file/d/1e6aG4ZSOB6Ri_1TjXI9N-1r7MtwmjA6w/view?usp=sharing) or
210 | [Baidu Pan](https://pan.baidu.com/s/1Hh0ZU8p04prFVSp9bQm_IA) (`2rv7`)
211 | - Extract it in the project folder
212 | - **The parameters used in the real-time demo can be found at [google_drive](https://drive.google.com/file/d/1fug29PBMo1Cb2DwAtX7f2E_yLHjDBmiM/view?usp=sharing) or [baidu](https://pan.baidu.com/s/1gr3xSkLuvsveSQ7nW1taSA) (un06). They were trained together with the losses of [Hand-BMC-pytorch](https://github.com/MengHao666/Hand-BMC-pytorch)!**
213 |
214 |
215 | Real-time demo with PSO-based shape estimation:
216 |
217 | ```
218 | python demo.py
219 | ```
220 |
221 | Real-time demo with learning-based shape estimation:
222 |
223 | ```
224 | python demo_dl.py
225 | ```
226 |
227 | ## DetNet Training and Evaluation
228 |
229 | Run the training code
230 | ```
231 | python train_detnet.py --data_root data/
232 | ```
233 |
234 |
235 | Run the evaluation code
236 | ```
237 | python train_detnet.py --data_root data/ --datasets_test testset_name_to_test --evaluate --evaluate_id checkpoints_id_to_load
238 | ```
239 | or use my results
240 | ```
241 | python train_detnet.py --checkpoint my_results/checkpoints --datasets_test "rhd" --evaluate --evaluate_id 106
242 |
243 | python train_detnet.py --checkpoint my_results/checkpoints --datasets_test "stb" --evaluate --evaluate_id 71
244 |
245 | python train_detnet.py --checkpoint my_results/checkpoints --datasets_test "do" --evaluate --evaluate_id 68
246 |
247 | python train_detnet.py --checkpoint my_results/checkpoints --datasets_test "eo" --evaluate --evaluate_id 101
248 | ```
249 |
250 | ## Shape Estimation with LM algorithm
251 |
252 | Run the shape optimization code. This can be very time-consuming when the weight parameter is quite small.
253 | ```
254 | python optimize_shape.py --weight 1e-5
255 | ```
256 | or use my results
257 | ```
258 | python optimize_shape.py --path my_results/out_testset/
259 | ```
260 |
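For intuition: the optimization searches for MANO shape parameters whose relative bone lengths match DetNet's predictions, and `--weight` scales an L2 regularizer on the shape vector (a tiny weight widens the search space, hence the runtime). Below is a rough Levenberg-Marquardt sketch of that objective, reusing `create_data.DataSet` for the bone-length forward pass; the actual implementation lives in `utils/LM.py` and differs in detail:

```python
import torch
from create_data import DataSet  # new_cal_ref_bone(shape) -> relative bone lengths (1, 15)

def numeric_jacobian(f, x, eps=1e-4):
    # forward-difference Jacobian; good enough for a sketch
    fx = f(x)
    J = torch.zeros(fx.numel(), x.numel())
    for i in range(x.numel()):
        dx = torch.zeros(x.numel())
        dx[i] = eps
        J[:, i] = (f(x + dx.reshape(x.shape)) - fx) / eps
    return J

def lm_fit_shape(target_rel_bone, weight=1e-5, iters=50, damping=1e-2):
    ds = DataSet()
    beta = torch.zeros(1, 10)  # start from the mean MANO shape

    def residuals(b):
        r_fit = (ds.new_cal_ref_bone(b) - target_rel_bone).reshape(-1)  # bone-length error
        r_reg = (weight ** 0.5) * b.reshape(-1)                         # L2 shape regularizer
        return torch.cat([r_fit, r_reg])

    for _ in range(iters):
        J = numeric_jacobian(residuals, beta)  # (25, 10)
        r = residuals(beta)
        # damped normal equations: (J^T J + mu * I) step = J^T r
        step = torch.linalg.solve(J.T @ J + damping * torch.eye(10), J.T @ r)
        beta = beta - step.reshape(1, 10)
    return beta
```
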
261 | ## Pose Estimation
262 |
263 | Run the following code, which uses an analytical inverse kinematics method.
264 | ```
265 | python aik_pose.py
266 | ```
267 | or use my results
268 | ```
269 | python aik_pose.py --path my_results/out_testset/
270 | ```
271 |
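Under the hood (condensed from `recon_eval()` in the copy of `aik_pose.py` shown later in this dump), each predicted skeleton is rescaled so that its wrist-to-middle-MCP bone matches the MANO template built from the estimated shape, re-anchored at the wrist, and then the joint rotations are solved analytically:

```python
import numpy as np
import torch
from utils import AIK

def solve_pose(template, j3d_pre):
    # match the wrist (joint 0) -> middle-MCP (joint 9) bone length, then re-anchor at the wrist
    ratio = np.linalg.norm(template[9] - template[0]) / np.linalg.norm(j3d_pre[9] - j3d_pre[0])
    j3d = j3d_pre * ratio
    j3d = j3d - j3d[0] + template[0]
    pose_R = AIK.adaptive_IK(template, j3d)  # per-joint rotation matrices for the MANO layer
    return torch.from_numpy(pose_R).float()
```
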
272 |
273 | ### DetNet training and evaluation curves
274 | Run the following code to see my results
275 | ```
276 | python plot.py --path my_results/out_loss_auc
277 | ```
278 |
279 | (AUC means 3D PCK, and ACC_HM means 2D PCK)
280 | 
281 |
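For reference, 3D PCK at a threshold t is the fraction of keypoints whose Euclidean error is below t, and the AUC averages that curve over thresholds (20-50 mm here, matching the `evaluator.get_measures(20, 50, 15)` call in `aik_pose.py`). A back-of-the-envelope sketch of the metric, not the project's `EvalUtil` implementation:

```python
import numpy as np

def pck_auc(pred, gt, t_min=20.0, t_max=50.0, steps=15):
    # pred, gt: (N, 21, 3) joint positions in millimetres
    err = np.linalg.norm(pred - gt, axis=-1).reshape(-1)     # per-keypoint errors
    thresholds = np.linspace(t_min, t_max, steps)
    pck = np.array([(err <= t).mean() for t in thresholds])  # PCK curve
    return np.trapz(pck, thresholds) / (t_max - t_min)       # normalized area under the curve
```
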
282 | ### 3D PCK AUC Difference
283 |
284 | \* means this project
285 |
286 | | Dataset | DetNet(paper) | DetNet(*) | DetNet+IKNet(paper) | DetNet+LM+AIK(*) | DetNet+PSO+AIK(*) | DetNet+DL+AIK(*) |
287 | | :-----: | :-----------: | :-------: | :-----------------: | :--------------: | :-------------: | :-------------: |
288 | | **RHD** | - | 0.9339 | 0.856 | 0.9301 | 0.9310 | 0.9272 |
289 | | **STB** | 0.891 | 0.8744 | 0.898 | 0.8647 | 0.8671 | 0.8624 |
290 | | **DO** | 0.923 | 0.9378 | 0.948 | 0.9392 | 0.9342 | 0.9400 |
291 | | **EO** | 0.804 | 0.9270 | 0.811 | 0.9288 | 0.9277 | 0.9365 |
292 |
293 |
294 |
295 | ### Note
296 |
297 | - Carefully adjusting the training parameters, longer training, a more complicated network, or **[Biomechanical Constraint Losses](https://github.com/MengHao666/Hand-BMC-pytorch)** could further boost accuracy.
298 | - As there is no official open-source release for the original paper, the comparison above is a little rough.
299 |
300 | ## Citation
301 |
302 | This is the **unofficial** PyTorch reimplementation of the paper "Monocular Real-time Hand Shape and Motion Capture using Multi-modal Data" (CVPR 2020).
303 |
304 |
305 |
306 | If you find this project helpful, please star it and cite the paper:
307 | ```
308 | @inproceedings{zhou2020monocular,
309 | title={Monocular Real-time Hand Shape and Motion Capture using Multi-modal Data},
310 | author={Zhou, Yuxiao and Habermann, Marc and Xu, Weipeng and Habibie, Ikhsanul and Theobalt, Christian and Xu, Feng},
311 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
312 | pages={0--0},
313 | year={2020}
314 | }
315 | ```
316 |
317 | ## Acknowledgement
318 |
319 | - Code of the MANO PyTorch layer was adapted from [manopth](https://github.com/hassony2/manopth).
320 |
321 | - Code for evaluating the hand PCK and AUC in `utils/eval/zimeval.py` was adapted from [hand3d](https://github.com/lmb-freiburg/hand3d).
322 |
323 | - Part of the code for data augmentation, dataset parsing, and utilities was adapted from [bihand](https://github.com/lixiny/bihand) and [3D-Hand-Pose-Estimation](https://github.com/OlgaChernytska/3D-Hand-Pose-Estimation).
324 |
325 | - Code of network model was adapted from [Minimal-Hand](https://github.com/lingtengqiu/Minimal-Hand).
326 |
327 | - @Mrsirovo for the starter code of `utils/LM.py`; @maitetsu updated it later.
328 |
329 | - @maitetsu for the starter code of `utils/AIK.py`, and for the implementations of the PSO and deep-learning methods for shape estimation.
330 |
--------------------------------------------------------------------------------
/ShapeNet.md:
--------------------------------------------------------------------------------
1 | ## ShapeNet
2 |
3 | #### train ShapeNet
4 |
5 | ShapeNet is a model that uses a fully connected neural network to estimate shape parameters. It is faster than PSO. The code of ShapeNet is adapted from [bihand](https://github.com/lixiny/bihand). If you want to train your own ShapeNet, just run these:
6 |
7 | ```sh
8 | # create training set
9 | python create_data.py
10 | # train the model
11 | python train_shape_net.py
12 | ```
13 | As for the pre-trained model and the training set generated by me, you can download them from [google drive](https://drive.google.com/drive/folders/1OEM8Q8SIJjXkembz3wlp6UiXFfpzAsLu?usp=sharing) or [bhpan](https://bhpan.buaa.edu.cn:443/link/88CD866C9EB3F30906D57571678FFE6D). Put `data_bone.npy` and `data_shape.npy` in `ROOT_DIR/data`. Put `ckp_siknet_synth_41.pth.tar` in `ROOT_DIR/checkpoints`.
14 |
15 | ```
16 | ROOT_DIR
17 | ...
18 | |--data
19 | |  |--data_bone.npy
20 | |  |--data_shape.npy
21 | ...
22 | |--checkpoints
23 | |  |--ckp_siknet_synth_41.pth.tar
24 | ...
25 | ```
26 |
27 |
28 |
29 | #### training set
30 |
31 | The training set is generated with MANO. See `create_data.py` for more details.
32 |
33 | 1. First, sample shape parameters from the normal distribution N(0, 3).
34 |
35 | 2. Calculate the relative bone lengths corresponding to the shape parameters (see the sketch below).
36 |
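Both steps map directly onto `create_data.DataSet`; a minimal usage sketch:

```python
from create_data import DataSet

ds = DataSet()                        # builds a MANO layer from mano/models
rel_bone, shape = ds.batch_sample(4)  # step 1 and step 2 in one call
print(shape.shape, rel_bone.shape)    # torch.Size([4, 10]) torch.Size([4, 15])
```
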
37 | #### loss
38 |
39 | The loss of ShapeNet consists of two parts: one is the error of the relative bone lengths, the other is the regularization loss on the shape parameters.
40 |
41 | $$
42 | L = \lambda_1 ||\hat{x} - x||_2^2 + \lambda_2 ||\hat{y}||_2^2
43 | $$
44 |
45 | $x$ denotes the input relative bone lengths.
46 | $\hat{y}$ denotes the shape parameters output by ShapeNet.
47 | $\hat{x}$ denotes the relative bone lengths corresponding to $\hat{y}$.
48 |
49 | In my opinion, shape parameters encode not only relative bone length information but also absolute bone length information, so one cannot guarantee that a given set of relative bone lengths corresponds to exactly one set of shape parameters, which a neural network regression would require. Therefore, the loss function does not directly compute the error between the predicted shape parameters and their labels.
50 |
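A direct PyTorch transcription of this loss (the lambda values below are placeholders, not necessarily the ones used by `train_shape_net.py`):

```python
import torch

def shape_net_loss(x, x_hat, y_hat, lambda1=1.0, lambda2=1e-3):
    # x:     input relative bone lengths                      (B, 15)
    # x_hat: bone lengths recovered from the predicted shape  (B, 15)
    # y_hat: predicted shape parameters                       (B, 10)
    recon = ((x_hat - x) ** 2).sum(dim=-1).mean()  # ||x_hat - x||_2^2 term
    reg = (y_hat ** 2).sum(dim=-1).mean()          # ||y_hat||_2^2 term
    return lambda1 * recon + lambda2 * reg
```
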
51 | #### AUC
52 |
53 | For the AUC of ShapeNet, refer to [README.md](./README.md). You can get a higher AUC if you change `beta = torch.tanh(beta)` (line 85 of model/shape_net.py) to `beta = 3 * torch.tanh(beta)`. This enlarges the output range of ShapeNet and yields a higher AUC; according to my experiments, the fingers become thinner.
54 |
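For orientation, here is a hypothetical sketch of where that line sits in a fully connected shape regressor; the real architecture is in `model/shape_net.py`, and only the tanh squashing is quoted from it:

```python
import torch
import torch.nn as nn

class TinyShapeNet(nn.Module):
    # illustrative stand-in, not the actual model/shape_net.py
    def __init__(self, n_bones=15, n_shape=10, hidden=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(n_bones, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, n_shape),
        )

    def forward(self, rel_bone_len):
        beta = self.mlp(rel_bone_len)
        # the change discussed above: 3 * torch.tanh(beta) widens the output range
        return torch.tanh(beta)
```
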
55 | I did not tune the parameters carefully; you may get better results if you do.
56 |
57 |
--------------------------------------------------------------------------------
/__pycache__/config.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/handopen/Minimal-Hand-pytorch/c078433136b20c1e4fb3f353254edd409ea65f1b/__pycache__/config.cpython-37.pyc
--------------------------------------------------------------------------------
/aik_pose.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import numpy as np
4 | import torch
5 | from manopth import demo
6 | from manopth import manolayer
7 | from tqdm import tqdm
8 |
9 | from utils import AIK, align, vis
10 | from utils.eval.zimeval import EvalUtil
11 |
12 |
13 | def recon_eval(op_shapes, pre_j3ds, gt_j3ds, visual, key):
14 | pose0 = torch.eye(3).repeat(1, 16, 1, 1)
15 | mano = manolayer.ManoLayer(flat_hand_mean=True,
16 | side="right",
17 | mano_root='mano/models',
18 | use_pca=False,
19 | root_rot_mode='rotmat',
20 | joint_rot_mode='rotmat')
21 |
22 | j3d_recons = []
23 | evaluator = EvalUtil()
24 | for i in tqdm(range(pre_j3ds.shape[0])):
25 | j3d_pre = pre_j3ds[i]
26 |
27 | op_shape = torch.tensor(op_shapes[i]).float().unsqueeze(0)
28 | _, j3d_p0_ops = mano(pose0, op_shape)
29 | template = j3d_p0_ops.cpu().numpy().squeeze() / 1000.0 # template, m
30 |
31 | ratio = np.linalg.norm(template[9] - template[0]) / np.linalg.norm(j3d_pre[9] - j3d_pre[0])
32 | j3d_pre_process = j3d_pre * ratio # template, m
33 | j3d_pre_process = j3d_pre_process - j3d_pre_process[0] + template[0]
34 |
35 | pose_R = AIK.adaptive_IK(template, j3d_pre_process)
36 | pose_R = torch.from_numpy(pose_R).float()
37 |
38 | # reconstruction
39 | hand_verts, j3d_recon = mano(pose_R, op_shape.float())
40 |
41 | # visualization
42 | if visual:
43 | demo.display_hand({
44 | 'verts': hand_verts.cpu(),
45 | 'joints': j3d_recon.cpu()
46 | },
47 | mano_faces=mano.th_faces)
48 |
49 | j3d_recon = j3d_recon.cpu().numpy().squeeze() / 1000.
50 | j3d_recons.append(j3d_recon)
51 |
52 | # visualization
53 | if visual:
54 | vis.multi_plot3d([j3d_recon, j3d_pre_process], title=["recon", "pre"])
55 | j3d_recons = np.array(j3d_recons)
56 | gt_joint, j3d_recon_align_gt = align.global_align(gt_j3ds, j3d_recons, key=key)
57 |
58 | for targj, predj_a in zip(gt_joint, j3d_recon_align_gt):
59 | evaluator.feed(targj * 1000.0, predj_a * 1000.0)
60 |
61 | (
62 | _1, _2, _3,
63 | auc_all,
64 | pck_curve_all,
65 | thresholds
66 | ) = evaluator.get_measures(
67 | 20, 50, 15
68 | )
69 | print("Reconstruction AUC all of {}_test_set is : {}".format(key, auc_all))
70 |
71 |
72 | def main(args):
73 | path = args.path
74 | for key_i in args.dataset:
75 | print("load {}'s joint 3D".format(key_i))
76 | _path = "{}/{}_dl.npy".format(path, key_i)
77 | print('load {}'.format(_path))
78 | op_shapes = np.load(_path)
79 | pre_j3ds = np.load("{}/{}_pre_joints.npy".format(path, key_i))
80 | gt_j3ds = np.load("{}/{}_gt_joints.npy".format(path, key_i))
81 | recon_eval(op_shapes, pre_j3ds, gt_j3ds, args.visualize, key_i)
82 |
83 |
84 | if __name__ == '__main__':
85 | parser = argparse.ArgumentParser(
86 | description=' get pose params. of mano model ')
87 |
88 | parser.add_argument(
89 | '-ds',
90 | "--dataset",
91 | nargs="+",
92 | default=['rhd', 'stb', 'do', 'eo'],
93 | type=list,
94 | help="sub datasets, should be listed in: [stb|rhd|do|eo]"
95 | )
96 |
97 | parser.add_argument(
98 | '-p',
99 | "--path",
100 | default="out_testset",
101 | type=str,
102 | help="path"
103 | )
104 |
105 | parser.add_argument(
106 | '-vis',
107 | '--visualize',
108 | action='store_true',
109 | help='visualize reconstruction result',
110 | default=False
111 | )
112 |
113 | main(parser.parse_args())
114 |
--------------------------------------------------------------------------------
/assets/DEMO2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/handopen/Minimal-Hand-pytorch/c078433136b20c1e4fb3f353254edd409ea65f1b/assets/DEMO2.gif
--------------------------------------------------------------------------------
/assets/demo.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/handopen/Minimal-Hand-pytorch/c078433136b20c1e4fb3f353254edd409ea65f1b/assets/demo.gif
--------------------------------------------------------------------------------
/assets/results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/handopen/Minimal-Hand-pytorch/c078433136b20c1e4fb3f353254edd409ea65f1b/assets/results.png
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | DEPTH_RANGE = 3.0
2 | DEPTH_MIN = -1.5
3 |
4 | stb_joints = [
5 | 'loc_bn_palm_L',
6 | 'loc_bn_pinky_L_01',
7 | 'loc_bn_pinky_L_02',
8 | 'loc_bn_pinky_L_03',
9 | 'loc_bn_pinky_L_04',
10 | 'loc_bn_ring_L_01',
11 | 'loc_bn_ring_L_02',
12 | 'loc_bn_ring_L_03',
13 | 'loc_bn_ring_L_04',
14 | 'loc_bn_mid_L_01',
15 | 'loc_bn_mid_L_02',
16 | 'loc_bn_mid_L_03',
17 | 'loc_bn_mid_L_04',
18 | 'loc_bn_index_L_01',
19 | 'loc_bn_index_L_02',
20 | 'loc_bn_index_L_03',
21 | 'loc_bn_index_L_04',
22 | 'loc_bn_thumb_L_01',
23 | 'loc_bn_thumb_L_02',
24 | 'loc_bn_thumb_L_03',
25 | 'loc_bn_thumb_L_04',
26 | ]
27 |
28 | rhd_joints = [
29 | 'loc_bn_palm_L',
30 | 'loc_bn_thumb_L_04',
31 | 'loc_bn_thumb_L_03',
32 | 'loc_bn_thumb_L_02',
33 | 'loc_bn_thumb_L_01',
34 | 'loc_bn_index_L_04',
35 | 'loc_bn_index_L_03',
36 | 'loc_bn_index_L_02',
37 | 'loc_bn_index_L_01',
38 | 'loc_bn_mid_L_04',
39 | 'loc_bn_mid_L_03',
40 | 'loc_bn_mid_L_02',
41 | 'loc_bn_mid_L_01',
42 | 'loc_bn_ring_L_04',
43 | 'loc_bn_ring_L_03',
44 | 'loc_bn_ring_L_02',
45 | 'loc_bn_ring_L_01',
46 | 'loc_bn_pinky_L_04',
47 | 'loc_bn_pinky_L_03',
48 | 'loc_bn_pinky_L_02',
49 | 'loc_bn_pinky_L_01'
50 | ]
51 |
52 | snap_joint_names = [
53 | 'loc_bn_palm_L',
54 | 'loc_bn_thumb_L_01',
55 | 'loc_bn_thumb_L_02',
56 | 'loc_bn_thumb_L_03',
57 | 'loc_bn_thumb_L_04',
58 | 'loc_bn_index_L_01',
59 | 'loc_bn_index_L_02',
60 | 'loc_bn_index_L_03',
61 | 'loc_bn_index_L_04',
62 | 'loc_bn_mid_L_01',
63 | 'loc_bn_mid_L_02',
64 | 'loc_bn_mid_L_03',
65 | 'loc_bn_mid_L_04',
66 | 'loc_bn_ring_L_01',
67 | 'loc_bn_ring_L_02',
68 | 'loc_bn_ring_L_03',
69 | 'loc_bn_ring_L_04',
70 | 'loc_bn_pinky_L_01',
71 | 'loc_bn_pinky_L_02',
72 | 'loc_bn_pinky_L_03',
73 | 'loc_bn_pinky_L_04'
74 | ]
75 |
76 | SNAP_BONES = [
77 | (0, 1, 2, 3, 4),
78 | (0, 5, 6, 7, 8),
79 | (0, 9, 10, 11, 12),
80 | (0, 13, 14, 15, 16),
81 | (0, 17, 18, 19, 20)
82 | ]
83 |
84 | SNAP_PARENT = [
85 | 0, # 0's parent
86 | 0, # 1's parent
87 | 1,
88 | 2,
89 | 3,
90 | 0, # 5's parent
91 | 5,
92 | 6,
93 | 7,
94 | 0, # 9's parent
95 | 9,
96 | 10,
97 | 11,
98 | 0, # 13's parent
99 | 13,
100 | 14,
101 | 15,
102 | 0, # 17's parent
103 | 17,
104 | 18,
105 | 19,
106 | ]
107 |
108 | JOINT_COLORS = (
109 | (216, 31, 53),
110 | (214, 208, 0),
111 | (136, 72, 152),
112 | (126, 199, 216),
113 | (0, 0, 230),
114 | )
115 |
116 | DEFAULT_CACHE_DIR = 'datasets/data/.cache'
117 |
118 | USEFUL_BONE = [1, 2, 3,
119 | 5, 6, 7,
120 | 9, 10, 11,
121 | 13, 14, 15,
122 | 17, 18, 19]
123 |
124 | kinematic_tree = [2, 3, 4, 6, 7, 8, 10, 11, 12, 14, 15, 16, 18, 19, 20]
125 |
126 | ID2ROT = {
127 | 2: 13, 3: 14, 4: 15,
128 | 6: 1, 7: 2, 8: 3,
129 | 10: 4, 11: 5, 12: 6,
130 | 14: 10, 15: 11, 16: 12,
131 | 18: 7, 19: 8, 20: 9,
132 | }
--------------------------------------------------------------------------------
/create_data.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from manopth import manolayer
3 | import os
4 |
5 |
6 | class DataSet:
7 | def __init__(self, device=torch.device('cpu'), _mano_root='mano/models'):
8 | args = {'flat_hand_mean': True, 'root_rot_mode': 'axisang',
9 | 'ncomps': 45, 'mano_root': _mano_root,
10 | 'no_pca': True, 'joint_rot_mode': 'axisang', 'side': 'right'}
11 | self.mano = manolayer.ManoLayer(flat_hand_mean=args['flat_hand_mean'],
12 | side=args['side'],
13 | mano_root=args['mano_root'],
14 | ncomps=args['ncomps'],
15 | use_pca=not args['no_pca'],
16 | root_rot_mode=args['root_rot_mode'],
17 | joint_rot_mode=args['joint_rot_mode']
18 | ).to(device)
19 | self.device = device
20 |
21 | def new_cal_ref_bone(self, _shape):
22 | parent_index = [0,
23 | 0, 1, 2,
24 | 0, 4, 5,
25 | 0, 7, 8,
26 | 0, 10, 11,
27 | 0, 13, 14
28 | ]
29 | index = [0,
30 | 1, 2, 3, # index
31 | 4, 5, 6, # middle
32 | 7, 8, 9, # pinky
33 | 10, 11, 12, # ring
34 | 13, 14, 15] # thumb
35 | reoder_index = [
36 | 13, 14, 15,
37 | 1, 2, 3,
38 | 4, 5, 6,
39 | 10, 11, 12,
40 | 7, 8, 9]
41 | shape = _shape.clone().detach()
42 | th_v_shaped = torch.matmul(self.mano.th_shapedirs,
43 | shape.transpose(1, 0)).permute(2, 0, 1) \
44 | + self.mano.th_v_template
45 | th_j = torch.matmul(self.mano.th_J_regressor, th_v_shaped)
46 | temp1 = th_j.clone().detach()
47 | temp2 = th_j.clone().detach()[:, parent_index, :]
48 | result = temp1 - temp2
49 | ref_len = th_j[:, [4], :] - th_j[:, [0], :]  # reference bone: wrist -> middle MCP (joint 4 in MANO order)
50 | ref_len = torch.norm(ref_len, dim=-1, keepdim=True)
51 | result = torch.norm(result, dim=-1, keepdim=True)
52 | result = result / ref_len
53 | return torch.squeeze(result, dim=-1)[:, reoder_index]
54 |
55 | def sample(self):
56 | shape = 3 * torch.randn((1, 10))
57 | result = self.new_cal_ref_bone(shape)
58 | return (result, shape)
59 |
60 | def batch_sample(self, batch_size):
61 | shape = 3 * torch.randn((batch_size, 10))
62 | result = self.new_cal_ref_bone(shape)
63 | return (result, shape)
64 |
65 | @staticmethod
66 | def cal_ref_bone(_Jtr):
67 | parent_index = [0,
68 | 0, 1, 2, 3,
69 | 0, 5, 6, 7,
70 | 0, 9, 10, 11,
71 | 0, 13, 14, 15,
72 | 0, 17, 18, 19
73 | ]
74 | index = [1, 2, 3,
75 | 5, 6, 7,
76 | 9, 10, 11,
77 | 13, 14, 15,
78 | 17, 18, 19]
79 | temp1 = _Jtr.clone().detach()
80 | temp2 = _Jtr.clone().detach()[:, parent_index, :]
81 | result = temp1 - temp2
82 | result = result[:, index, :]
83 | ref_len = _Jtr[:, [9], :] - _Jtr[:, [0], :]
84 | ref_len = torch.norm(ref_len, dim=-1, keepdim=True)
85 | result = torch.norm(result, dim=-1, keepdim=True)
86 | # result = result / ref_len
87 | return torch.squeeze(result, dim=-1)
88 |
89 |
90 | if __name__ == '__main__':
91 | dataset = DataSet()
92 | import numpy as np
93 | import tqdm
94 |
95 | Total_Num = 1000000
96 | NUM = 10000
97 | data_bone = np.zeros((Total_Num, 15))
98 | data_shape = np.zeros((Total_Num, 10))
99 | for i in tqdm.tqdm(range(Total_Num // NUM)):
100 | t1 = i * NUM
101 | t2 = t1 + NUM
102 | temp_1, temp_2 = dataset.batch_sample(NUM)
103 | data_bone[t1:t2] = temp_1
104 | data_shape[t1:t2] = temp_2
105 | print(t1, t2)
106 |
107 | save_dir = 'data'
108 | if os.path.exists(save_dir):
109 | pass
110 | else:
111 | os.mkdir(save_dir)
112 | np.save(os.path.join(save_dir, 'data_bone.npy'), data_bone)
113 | np.save(os.path.join(save_dir, 'data_shape.npy'), data_shape)
114 | print('*' * 10, 'test', '*' * 10)
115 | data_bone = np.load(os.path.join(save_dir, 'data_bone.npy'))
116 | data_shape = np.load(os.path.join(save_dir, 'data_shape.npy'))
117 | test_flag = 1
118 | for i in tqdm.tqdm(range(Total_Num // NUM)):
119 | t1 = i * NUM
120 | t2 = t1 + NUM
121 | test_shape = data_shape[t1:t2]
122 | test_shape = torch.tensor(test_shape, dtype=torch.float)
123 | test_bone = data_bone[t1:t2]
124 | temp_1 = dataset.new_cal_ref_bone(test_shape)
125 | flag = np.allclose(temp_1, test_bone)
126 | flag = int(flag)
127 | test_flag = test_flag * flag
128 | print(test_flag)
129 |
--------------------------------------------------------------------------------
/datasets/SIK1M.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import pickle
3 | import torch
4 | import os
5 | from torch.utils import data
6 | from termcolor import colored, cprint
7 | import numpy as np
8 |
9 | sik1m_inst = 0
10 |
11 |
12 | class _SIK1M(data.Dataset):
13 | """
14 | The Loader for joints so3 and quat
15 | """
16 |
17 | def __init__(
18 | self,
19 | data_root="data",
20 | data_source=None
21 | ):
22 | print("Initialize _SIK1M instance")
23 | bone_len_path = os.path.join(data_root, 'data_bone.npy')
24 | shape_path = os.path.join(data_root, 'data_shape.npy')
25 | self.bone_len = np.load(bone_len_path)
26 | self.shape = np.load(shape_path)
27 |
28 | def __len__(self):
29 | return self.shape.shape[0]
30 |
31 | def __getitem__(self, index):
32 | temp_bone_len = self.bone_len[index]
33 | temp_shape = self.shape[index]
34 |
35 | metas = {
36 | 'rel_bone_len': temp_bone_len,
37 | 'shape': temp_shape
38 | }
39 | return metas
40 |
41 |
42 | class SIK1M(data.Dataset):
43 | def __init__(
44 | self,
45 | data_split="train",
46 | data_root="data",
47 | split_ratio=0.8
48 | ):
49 | global sik1m_inst
50 | if not sik1m_inst:
51 | sik1m_inst = _SIK1M(data_root=data_root)
52 | self.sik1m = sik1m_inst
53 | self.permu = list(range(len(self.sik1m)))
54 | self.alllen = len(self.sik1m)
55 | self.data_split = data_split
56 | # add 0.1 * the std of the relative bone lengths as noise; you can change it or leave it out
57 | self.noise = np.array([0.02906406, 0.02663224, 0.01769793, 0.0274501, 0.02573783, 0.0222863,
58 | 0., 0.02855567, 0.02330295, 0.0253288, 0.0266308, 0.02495683, 0.03685857, 0.02430637,
59 | 0.02349446])
60 | self.noise = self.noise / 10.0
61 | if data_split == "train":
62 | self.vislen = int(len(self.sik1m) * split_ratio)
63 | self.sub_permu = self.permu[:self.vislen]
64 | elif data_split in ["val", "test"]:
65 | self.vislen = self.alllen - int(len(self.sik1m) * split_ratio)
66 | self.sub_permu = self.permu[(self.alllen - self.vislen):]
67 | else:
68 | self.vislen = len(self.sik1m)
69 | self.sub_permu = self.permu[:self.vislen]
70 |
71 | def __len__(self):
72 | return self.vislen
73 |
74 | def __getitem__(self, index):
75 | item = self.sik1m[self.sub_permu[index]]
76 | temp = np.random.randn(15, )
77 | temp = np.multiply(self.noise, temp)
78 | item['rel_bone_len'] += temp
79 | return item
80 |
81 |
82 | def main():
83 | sik1m_train = SIK1M(
84 | data_split="train",
85 | data_root="data"
86 | )
87 | sik1m_test = SIK1M(
88 | data_split="test"
89 | )
90 |
91 | metas = sik1m_train[2]
92 | print(metas)
93 | metas = sik1m_train[2]
94 | print(metas)
95 |
96 |
97 | if __name__ == "__main__":
98 | main()
99 |
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/handopen/Minimal-Hand-pytorch/c078433136b20c1e4fb3f353254edd409ea65f1b/datasets/__init__.py
--------------------------------------------------------------------------------
/datasets/__pycache__/__init__.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/handopen/Minimal-Hand-pytorch/c078433136b20c1e4fb3f353254edd409ea65f1b/datasets/__pycache__/__init__.cpython-35.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/handopen/Minimal-Hand-pytorch/c078433136b20c1e4fb3f353254edd409ea65f1b/datasets/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/dexter_object.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/handopen/Minimal-Hand-pytorch/c078433136b20c1e4fb3f353254edd409ea65f1b/datasets/__pycache__/dexter_object.cpython-37.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/egodexter.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/handopen/Minimal-Hand-pytorch/c078433136b20c1e4fb3f353254edd409ea65f1b/datasets/__pycache__/egodexter.cpython-37.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/ganerated_hands.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/handopen/Minimal-Hand-pytorch/c078433136b20c1e4fb3f353254edd409ea65f1b/datasets/__pycache__/ganerated_hands.cpython-37.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/hand143_panopticdb.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/handopen/Minimal-Hand-pytorch/c078433136b20c1e4fb3f353254edd409ea65f1b/datasets/__pycache__/hand143_panopticdb.cpython-37.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/hand_labels.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/handopen/Minimal-Hand-pytorch/c078433136b20c1e4fb3f353254edd409ea65f1b/datasets/__pycache__/hand_labels.cpython-37.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/handataset.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/handopen/Minimal-Hand-pytorch/c078433136b20c1e4fb3f353254edd409ea65f1b/datasets/__pycache__/handataset.cpython-37.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/rhd.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/handopen/Minimal-Hand-pytorch/c078433136b20c1e4fb3f353254edd409ea65f1b/datasets/__pycache__/rhd.cpython-37.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/stb.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/handopen/Minimal-Hand-pytorch/c078433136b20c1e4fb3f353254edd409ea65f1b/datasets/__pycache__/stb.cpython-37.pyc
--------------------------------------------------------------------------------
/datasets/ganerated_hands.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Hao Meng. All Rights Reserved.
2 | r"""
3 | GANeratedDataset
4 | GANerated Hands for Real-Time 3D Hand Tracking from Monocular RGB, CVPR 2018
5 | Link to dataset: https://handtracker.mpi-inf.mpg.de/projects/GANeratedHands/GANeratedDataset.htm
6 | """
7 |
8 | import os
9 | import pickle
10 | from builtins import print
11 |
12 | import PIL
13 | import matplotlib.pyplot as plt
14 | import numpy as np
15 | import torch
16 | from PIL import Image
17 | from termcolor import colored
18 | from tqdm import tqdm
19 |
20 | import config as cfg
21 | import utils.handutils as handutils
22 |
23 | CACHE_HOME = os.path.expanduser(cfg.DEFAULT_CACHE_DIR)
24 | snap_joint_name2id = {w: i for i, w in enumerate(cfg.snap_joint_names)}
25 |
26 |
27 | class GANeratedDataset(torch.utils.data.Dataset):
28 |
29 | def __init__(self,
30 | data_root,
31 | data_split='train',
32 | hand_side='right',
33 | njoints=21,
34 | use_cache=True,
35 | vis=False):
36 | if not os.path.exists(data_root):
37 | raise ValueError("data_root: %s not exist" % data_root)
38 | self.name = 'GANeratedHands Dataset'
39 | self.data_split = data_split
40 | self.hand_side = hand_side
41 | self.clr_paths = []
42 | self.kp2ds = []
43 | self.joints = []
44 | self.centers = []
45 | self.my_scales = []
46 | self.njoints = njoints
47 | self.reslu = [256, 256]
48 |
49 | self.vis = vis
50 |
51 | self.root_id = snap_joint_name2id['loc_bn_palm_L'] # 0
52 | self.mid_mcp_id = snap_joint_name2id['loc_bn_mid_L_01'] # 9
53 |
54 | self.intr = np.array([
55 | [617.173, 0, 315.453],
56 | [0, 617.173, 242.259],
57 | [0, 0, 1]])
58 |
59 | # [train|test|val|train_val|all]
60 | if data_split == 'train':
61 | self.sequence = ['training', ]
62 | else:
63 | print("GANeratedDataset only has train set!")
64 | return None
65 |
66 | self.cache_folder = os.path.join(CACHE_HOME, "my-train", "GANeratedHands")
67 | os.makedirs(self.cache_folder, exist_ok=True)
68 | cache_path = os.path.join(
69 | self.cache_folder, "{}.pkl".format(self.data_split)
70 | )
71 |
72 | if os.path.exists(cache_path) and use_cache:
73 | with open(cache_path, "rb") as fid:
74 | annotations = pickle.load(fid)
75 | self.clr_paths = annotations["clr_paths"]
76 | self.kp2ds = annotations["kp2ds"]
77 | self.joints = annotations["joints"]
78 | self.centers = annotations["centers"]
79 | self.my_scales = annotations["my_scales"]
80 | print("GANeratedHands {} gt loaded from {}".format(self.data_split, cache_path))
81 | return
82 |
83 | print("init GANeratedHands {}, It will take a while at first time".format(data_split))
84 |
85 | for img_type in ['noObject/', 'withObject/']:
86 | folders = os.listdir(data_root + img_type)
87 | folders = sorted(folders)
88 | folders = [img_type + x + '/' for x in folders if len(x) == 4]
89 |
90 | for folder in folders:
91 | images = os.listdir(os.path.join(data_root + folder))
92 | images = [data_root + folder + x for x in images if x.find('.png') > 0]
93 | images = sorted(images)
94 |
95 | self.clr_paths.extend(images)
96 |
97 | for idx in tqdm(range(len(self.clr_paths))):
98 | img_name = self.clr_paths[idx]
99 |
100 | fn_2d_keypoints = img_name.replace('color_composed.png', 'joint2D.txt')
101 | arr_2d_keypoints = np.loadtxt(fn_2d_keypoints, delimiter=',')
102 | arr_2d_keypoints = arr_2d_keypoints.reshape([-1, 2])
103 |
104 | center = handutils.get_annot_center(arr_2d_keypoints)
105 | self.centers.append(center[np.newaxis, :])
106 |
107 | my_scale = handutils.get_ori_crop_scale(mask=None, mask_flag=False, side=None, kp2d=arr_2d_keypoints,
108 | )
109 | my_scale = (np.atleast_1d(my_scale))[np.newaxis, :]
110 | self.my_scales.append(my_scale)
111 |
112 | arr_2d_keypoints = arr_2d_keypoints[np.newaxis, :, :]
113 | self.kp2ds.append(arr_2d_keypoints)
114 |
115 | fn_3d_keypoints = img_name.replace('color_composed.png', 'joint_pos_global.txt')
116 | arr_3d_keypoints = np.loadtxt(fn_3d_keypoints, delimiter=',')
117 | arr_3d_keypoints = arr_3d_keypoints.reshape([-1, 3])
118 | arr_3d_keypoints = arr_3d_keypoints[np.newaxis, :, :]
119 | self.joints.append(arr_3d_keypoints)
120 |
121 | self.joints = np.concatenate(self.joints, axis=0).astype(np.float32)
122 | self.kp2ds = np.concatenate(self.kp2ds, axis=0).astype(np.float32) # (N, 21, 2)
123 | self.centers = np.concatenate(self.centers, axis=0).astype(np.float32) # (N, 2)
124 | self.my_scales = np.concatenate(self.my_scales, axis=0).astype(np.float32)
125 |
126 | if use_cache:
127 | full_info = {
128 | "clr_paths": self.clr_paths,
129 | "joints": self.joints,
130 | "kp2ds": self.kp2ds,
131 | "centers": self.centers,
132 | "my_scales": self.my_scales,
133 | }
134 | with open(cache_path, "wb") as fid:
135 | pickle.dump(full_info, fid)
136 | print("Wrote cache for dataset GANeratedDataset {} to {}".format(
137 | self.data_split, cache_path
138 | ))
139 | return
140 |
141 | def __len__(self):
142 | """for GANeratedHands Dataset total (1,500 * 2) * 2 * 6 = 36,000 samples
143 | """
144 | return len(self.clr_paths)
145 |
146 | def __str__(self):
147 | info = "GANeratedHands {} set. lenth {}".format(
148 | self.data_split, len(self.clr_paths)
149 | )
150 | return colored(info, 'blue', attrs=['bold'])
151 |
152 | def _is_valid(self, clr, index):
153 | valid_data = isinstance(clr, (np.ndarray, PIL.Image.Image))
154 |
155 | if not valid_data:
156 | raise Exception("Encountered error processing GAN[{}]".format(index))
157 | return valid_data
158 |
159 | def get_sample(self, index):
160 | flip = True if self.hand_side != "left" else False
161 |
162 | intr = self.intr
163 |
164 | # prepare color image
165 | clr = Image.open(self.clr_paths[index]).convert("RGB")
166 | self._is_valid(clr, index)
167 |
168 | # prepare kp2d
169 | kp2d = self.kp2ds[index].copy()
170 |
171 | # prepare joint
172 | joint = self.joints[index].copy() # (21, 3)
173 | center = self.centers[index].copy()
174 | my_scale = self.my_scales[index].copy()
175 | if flip:
176 | clr = clr.transpose(Image.FLIP_LEFT_RIGHT)
177 | center[0] = clr.size[0] - center[0]
178 | kp2d[:, 0] = clr.size[0] - kp2d[:, 0]
179 | joint[:, 0] = -joint[:, 0]
180 |
181 | sample = {
182 | 'index': index,
183 | 'clr': clr,
184 | 'kp2d': kp2d,
185 | 'center': center,
186 | 'my_scale': my_scale,
187 | 'joint': joint,
188 | 'intr': intr,
189 | }
190 |
191 | # visualization
192 | if self.vis:
193 | fig = plt.figure(figsize=(20, 20))
194 | clr_ = np.array(clr)
195 |
196 | plt.subplot(1, 3, 1)
197 | clr1 = clr_.copy()
198 | plt.imshow(clr1)
199 |
200 | plt.subplot(1, 3, 2)
201 | clr2 = clr_.copy()
202 | plt.imshow(clr2)
203 |
204 | for p in range(kp2d.shape[0]):
205 | plt.plot(kp2d[p][0], kp2d[p][1], 'r.')
206 | plt.text(kp2d[p][0], kp2d[p][1], '{0}'.format(p), fontsize=5)
207 |
208 | ax = fig.add_subplot(133, projection='3d')
209 | plt.plot(joint[:, 0], joint[:, 1], joint[:, 2], 'yo', label='keypoint')
210 | plt.plot(joint[:5, 0], joint[:5, 1],
211 | joint[:5, 2],
212 | 'r',
213 | label='thumb')
214 | plt.plot(joint[[0, 5, 6, 7, 8, ], 0], joint[[0, 5, 6, 7, 8, ], 1],
215 | joint[[0, 5, 6, 7, 8, ], 2],
216 | 'b',
217 | label='index')
218 | plt.plot(joint[[0, 9, 10, 11, 12, ], 0], joint[[0, 9, 10, 11, 12], 1],
219 | joint[[0, 9, 10, 11, 12], 2],
220 | 'b',
221 | label='middle')
222 | plt.plot(joint[[0, 13, 14, 15, 16], 0], joint[[0, 13, 14, 15, 16], 1],
223 | joint[[0, 13, 14, 15, 16], 2],
224 | 'b',
225 | label='ring')
226 | plt.plot(joint[[0, 17, 18, 19, 20], 0], joint[[0, 17, 18, 19, 20], 1],
227 | joint[[0, 17, 18, 19, 20], 2],
228 | 'b',
229 | label='pinky')
230 | # snap convention
231 | plt.plot(joint[4][0], joint[4][1], joint[4][2], 'rD', label='thumb')
232 | plt.plot(joint[8][0], joint[8][1], joint[8][2], 'ro', label='index')
233 | plt.plot(joint[12][0], joint[12][1], joint[12][2], 'ro', label='middle')
234 | plt.plot(joint[16][0], joint[16][1], joint[16][2], 'ro', label='ring')
235 | plt.plot(joint[20][0], joint[20][1], joint[20][2], 'ro', label='pinky')
236 |
237 | plt.title('3D annotations')
238 | ax.set_xlabel('x')
239 | ax.set_ylabel('y')
240 | ax.set_zlabel('z')
241 | plt.legend()
242 | ax.view_init(-90, -90)
243 | plt.show()
244 |
245 | return sample
246 |
247 |
248 | if __name__ == '__main__':
249 | data_split = 'train'
250 | gan = GANeratedDataset(
251 | data_root='/home/chen/datasets/GANeratedHands_Release/data/',
252 | data_split=data_split,
253 | hand_side='right',
254 | njoints=21,
255 | use_cache=False,
256 | vis=True)
257 | print("len(gan)=", len(gan))
258 | for i in range(len(gan)):
259 | print("i=", i)
260 | data = gan.get_sample(i)
261 |
--------------------------------------------------------------------------------
/datasets/hand143_panopticdb.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Hao Meng. All Rights Reserved.
2 | r"""
3 | Hands from Panoptic Studio by Multiview Bootstrapping (14817 annotations)
4 | Hand Keypoint Detection in Single Images using Multiview Bootstrapping, CVPR 2017
5 | Link to dataset: http://domedb.perception.cs.cmu.edu/handdb.html
6 | Download:http://domedb.perception.cs.cmu.edu/panopticDB/hands/hand143_panopticdb.tar
7 | """
8 |
9 | import json
10 | import os
11 | import pickle
12 |
13 | import PIL
14 | import matplotlib.pyplot as plt
15 | import numpy as np
16 | import torch.utils.data
17 | from PIL import Image
18 | from termcolor import colored
19 | from tqdm import tqdm
20 |
21 | import config as cfg
22 | import utils.handutils as handutils
23 |
24 | CACHE_HOME = os.path.expanduser(cfg.DEFAULT_CACHE_DIR)
25 |
26 | snap_joint_name2id = {w: i for i, w in enumerate(cfg.snap_joint_names)}
27 |
28 |
29 | class Hand143_panopticdb(torch.utils.data.Dataset):
30 | def __init__(
31 | self,
32 | data_root="/home/chen/datasets/CMU/hand143_panopticdb",
33 | data_split='train',
34 | hand_side='right',
35 | njoints=21,
36 | use_cache=True,
37 | vis=False
38 | ):
39 |
40 | if not os.path.exists(data_root):
41 | raise ValueError("data_root: %s not exist" % data_root)
42 |
43 | self.name = 'hand143_panopticdb'
44 | self.data_split = data_split
45 | self.hand_side = hand_side
46 | self.clr_paths = []
47 | self.kp2ds = []
48 | self.centers = []
49 | self.my_scales = []
50 | self.njoints = njoints
51 | self.reslu = [1920, 1080]
52 | self.vis = vis
53 |
54 | self.root_id = snap_joint_name2id['loc_bn_palm_L'] # 0
55 | self.mid_mcp_id = snap_joint_name2id['loc_bn_mid_L_01'] # 9
56 |
57 | # [train|test|val|train_val|all]
58 | if data_split == 'train':
59 | self.sequence = ['training', ]
60 | else:
61 | print("hand143_panopticdb only has train_set!")
62 | return
63 |
64 | self.cache_folder = os.path.join(CACHE_HOME, "my-train", "hand143_panopticdb")
65 | os.makedirs(self.cache_folder, exist_ok=True)
66 | cache_path = os.path.join(
67 | self.cache_folder, "{}.pkl".format(self.data_split)
68 | )
69 |
70 | if os.path.exists(cache_path) and use_cache:
71 | with open(cache_path, "rb") as fid:
72 | annotations = pickle.load(fid)
73 | self.clr_paths = annotations["clr_paths"]
74 | self.kp2ds = annotations["kp2ds"]
75 | self.centers = annotations["centers"]
76 | self.my_scales = annotations["my_scales"]
77 | print("hand143_panopticdb {} gt loaded from {}".format(self.data_split, cache_path))
78 | return
79 |
80 | self.clr_root_list = [
81 | os.path.join(data_root, "imgs")
82 | ]
83 |
84 | self.ann_list = [
85 | os.path.join(
86 | data_root,
87 | "hands_v143_14817.json"
88 | )
89 | ]
90 |
91 | for clr_root, ann in zip(self.clr_root_list, self.ann_list):
92 |
93 | jsonPath = os.path.join(ann)
94 | with open(jsonPath, 'r') as fid:
95 | dat_all = json.load(fid)
96 | dat_all = dat_all['root']
97 |
98 | for i in tqdm(range(len(dat_all))):
99 | clrpth = os.path.join(clr_root, '%.8d.jpg' % i)
100 | self.clr_paths.append(clrpth)
101 |
102 | dat = dat_all[i]
103 | kp2d = np.array(dat['joint_self'])[:, : 2] # kp 2d left & right hand
104 | center = handutils.get_annot_center(kp2d)
105 | my_scale = handutils.get_ori_crop_scale(mask=None, side=None, mask_flag=False, kp2d=kp2d)
106 |
107 | kp2d = kp2d[np.newaxis, :, :]
108 | self.kp2ds.append(kp2d)
109 |
110 | center = center[np.newaxis, :]
111 | self.centers.append(center)
112 |
113 | my_scale = (np.atleast_1d(my_scale))[np.newaxis, :]
114 | self.my_scales.append(my_scale)
115 |
116 | self.kp2ds = np.concatenate(self.kp2ds, axis=0).astype(np.float32) # (N, 21, 2)
117 | self.centers = np.concatenate(self.centers, axis=0).astype(np.float32) # (N, 2)
118 | self.my_scales = np.concatenate(self.my_scales, axis=0).astype(np.float32) # (N, 1)
119 |
120 | if use_cache:
121 | full_info = {
122 | "clr_paths": self.clr_paths,
123 | "kp2ds": self.kp2ds,
124 | "centers": self.centers,
125 | "my_scales": self.my_scales,
126 | }
127 | with open(cache_path, "wb") as fid:
128 | pickle.dump(full_info, fid)
129 | print("Wrote cache for dataset hand143_panopticdb {} to {}".format(
130 | self.data_split, cache_path
131 | ))
132 | return
133 |
134 | def _is_valid(self, clr, index):
135 | valid_data = isinstance(clr, (np.ndarray, PIL.Image.Image))
136 |
137 | if not valid_data:
138 | raise Exception("Encountered error processing cmu_1_[{}]".format(index))
139 | return valid_data
140 |
141 | def __len__(self):
142 | return len(self.clr_paths)
143 |
144 | def __str__(self):
145 | info = "Hand143_panopticdb {} set. lenth {}".format(
146 | self.data_split, len(self.clr_paths)
147 | )
148 | return colored(info, 'blue', attrs=['bold'])
149 |
150 | def get_sample(self, index):
151 | flip = True if self.hand_side != 'right' else False
152 |
153 | # prepare color image
154 | clr = Image.open(self.clr_paths[index]).convert("RGB")
155 | self._is_valid(clr, index)
156 |
157 | # prepare kp2d
158 | kp2d = self.kp2ds[index].copy()
159 | center = self.centers[index].copy()
160 | my_scale = self.my_scales[index].copy()
161 | if flip:
162 | clr = clr.transpose(Image.FLIP_LEFT_RIGHT)
163 | center[0] = clr.size[0] - center[0]
164 | kp2d[:, 0] = clr.size[0] - kp2d[:, 0]
165 |
166 | sample = {
167 | 'index': index,
168 | 'clr': clr,
169 | 'kp2d': kp2d,
170 | 'center': center,
171 | 'my_scale': my_scale,
172 | }
173 |
174 | # visualization
175 | if self.vis:
176 | plt.figure(figsize=(20, 20))
177 | clr_ = np.array(clr)
178 |
179 | plt.subplot(1, 2, 1)
180 | clr1 = clr_.copy()
181 | plt.imshow(clr1)
182 | plt.title('color image')
183 |
184 | plt.subplot(1, 2, 2)
185 | clr2 = clr_.copy()
186 | plt.imshow(clr2)
187 | plt.plot(200, 100, 'r.', linewidth=10) # opencv convention
188 | for p in range(kp2d.shape[0]):
189 | plt.plot(kp2d[p][0], kp2d[p][1], 'r.')
190 | plt.text(kp2d[p][0], kp2d[p][1], '{0}'.format(p), fontsize=5)
191 | plt.title('2D annotations')
192 |
193 | plt.show()
194 |
195 | return sample
196 |
197 |
198 | def main():
199 | hand143_panopticdb = Hand143_panopticdb(
200 | data_root="/home/chen/datasets/CMU/hand143_panopticdb",
201 | data_split='train',
202 | hand_side='right',
203 | njoints=21,
204 | use_cache=True,
205 | vis=True
206 | )
207 | print("len(hand143_panopticdb)=", len(hand143_panopticdb))
208 |
209 | for i in tqdm(range(len(hand143_panopticdb))):
210 | print("i=", i)
211 | hand143_panopticdb.get_sample(i)
212 |
213 |
214 | if __name__ == "__main__":
215 | main()
216 |
--------------------------------------------------------------------------------
/datasets/hand_labels.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Hao Meng. All Rights Reserved.
2 | r"""
3 | Hands with Manual Keypoint Annotations (Training: 1912 annotations, Testing: 846 annotations)
4 | Hand Keypoint Detection in Single Images using Multiview Bootstrapping, CVPR 2017
5 | Link to dataset: http://domedb.perception.cs.cmu.edu/handdb.html
6 | Download:http://domedb.perception.cs.cmu.edu/panopticDB/hands/hand_labels.zip
7 | """
8 |
9 | import json
10 | import os
11 | import pickle
12 |
13 | import PIL
14 | import matplotlib.pyplot as plt
15 | import numpy as np
16 | import torch.utils.data
17 | from PIL import Image
18 | from termcolor import colored
19 | from tqdm import tqdm
20 |
21 | import config as cfg
22 | import utils.handutils as handutils
23 |
24 | CACHE_HOME = os.path.expanduser(cfg.DEFAULT_CACHE_DIR)
25 |
26 | snap_joint_name2id = {w: i for i, w in enumerate(cfg.snap_joint_names)}
27 |
28 |
29 | class Hand_labels(torch.utils.data.Dataset):
30 | def __init__(
31 | self,
32 | data_root="/home/chen/datasets/CMU/hand_labels",
33 | data_split='train',
34 | hand_side='right',
35 | njoints=21,
36 | use_cache=True,
37 | vis=False
38 | ):
39 |
40 | if not os.path.exists(data_root):
41 | raise ValueError("data_root: %s not exist" % data_root)
42 |
43 | self.vis = vis
44 | self.name = 'CMU:hand_labels'
45 | self.data_split = data_split
46 | self.hand_side = hand_side
47 | self.clr_paths = []
48 | self.kp2ds = []
49 | self.centers = []
50 | self.sides = []
51 | self.my_scales = []
52 | self.njoints = njoints
53 | self.reslu = [1920, 1080]
54 |
55 | self.root_id = snap_joint_name2id['loc_bn_palm_L'] # 0
56 | self.mid_mcp_id = snap_joint_name2id['loc_bn_mid_L_01'] # 9
57 |
58 | # [train|test|val|train_val|all]
59 | if data_split == 'train':
60 | self.sequence = ['manual_train', ]
61 | elif data_split == 'test':
62 | self.sequence = ['manual_test', ]
63 | elif data_split == 'val':
64 | self.sequence = ['manual_test', ]
65 | elif data_split == 'train_val':
66 | self.sequence = ['manual_train', ]
67 | elif data_split == 'all':
68 | self.sequence = ['manual_train', 'manual_test']
69 | else:
70 | raise ValueError("hand_labels: unsupported data_split '%s'" % data_split)
71 |
72 | self.cache_folder = os.path.join(CACHE_HOME, "my-train", "hand_labels")
73 | os.makedirs(self.cache_folder, exist_ok=True)
74 | cache_path = os.path.join(
75 | self.cache_folder, "{}.pkl".format(self.data_split)
76 | )
77 |
78 | if os.path.exists(cache_path) and use_cache:
79 | with open(cache_path, "rb") as fid:
80 | annotations = pickle.load(fid)
81 | self.sides = annotations["sides"]
82 | self.clr_paths = annotations["clr_paths"]
83 | self.kp2ds = annotations["kp2ds"]
84 | self.centers = annotations["centers"]
85 | self.my_scales = annotations["my_scales"]
86 | print("hand_labels {} gt loaded from {}".format(self.data_split, cache_path))
87 | return
88 |
89 | datapath_list = [
90 | os.path.join(data_root, seq) for seq in self.sequence
91 | ]
92 |
93 | for datapath in datapath_list:
94 | files = sorted([f for f in os.listdir(datapath) if f.endswith('.json')])
95 |
96 | for idx in tqdm(range(len(files))):
97 | f = files[idx]
98 | with open(os.path.join(datapath, f), 'r') as fid:
99 | dat = json.load(fid)
100 |
101 | kp2d = np.array(dat['hand_pts'])[:, : 2]
102 | is_left = dat['is_left']
103 | self.sides.append("left" if is_left else "right")
104 |
105 | clr_pth = os.path.join(datapath, f[0:-5] + '.jpg')
106 | self.clr_paths.append(clr_pth)
107 | center = handutils.get_annot_center(kp2d)
108 | my_scale = handutils.get_ori_crop_scale(mask=False, mask_flag=False, side=None, kp2d=kp2d)
109 |
110 | kp2d = kp2d[np.newaxis, :, :]
111 | self.kp2ds.append(kp2d)
112 |
113 | center = center[np.newaxis, :]
114 | self.centers.append(center)
115 |
116 | my_scale = (np.atleast_1d(my_scale))[np.newaxis, :]
117 | self.my_scales.append(my_scale)
118 |
119 | self.kp2ds = np.concatenate(self.kp2ds, axis=0).astype(np.float32) # (N, 21, 2)
120 |         self.centers = np.concatenate(self.centers, axis=0).astype(np.float32)  # (N, 2)
121 | self.my_scales = np.concatenate(self.my_scales, axis=0).astype(np.float32) # (N, 1)
122 |
123 | if use_cache:
124 | full_info = {
125 | "sides": self.sides,
126 | "clr_paths": self.clr_paths,
127 | "kp2ds": self.kp2ds,
128 | "centers": self.centers,
129 | "my_scales": self.my_scales,
130 | }
131 | with open(cache_path, "wb") as fid:
132 | pickle.dump(full_info, fid)
133 | print("Wrote cache for dataset hand_labels {} to {}".format(
134 | self.data_split, cache_path
135 | ))
136 | return
137 |
138 | def _is_valid(self, clr, index):
139 | valid_data = isinstance(clr, (np.ndarray, PIL.Image.Image))
140 |
141 | if not valid_data:
142 | raise Exception("Encountered error processing CMU_2_[{}]".format(index))
143 | return valid_data
144 |
145 | def __len__(self):
146 | return len(self.clr_paths)
147 |
148 | def __str__(self):
149 |         info = "hand_labels {} set. length {}".format(
150 | self.data_split, len(self.clr_paths)
151 | )
152 | return colored(info, 'blue', attrs=['bold'])
153 |
154 | def get_sample(self, index):
155 |         flip = self.hand_side != self.sides[index]
156 |
157 | # prepare color image
158 | clr = Image.open(self.clr_paths[index]).convert("RGB")
159 | self._is_valid(clr, index)
160 |
161 | # prepare kp2d
162 | kp2d = self.kp2ds[index].copy()
163 | center = self.centers[index].copy()
164 | my_scale = self.my_scales[index].copy()
165 | if flip:
166 | clr = clr.transpose(Image.FLIP_LEFT_RIGHT)
167 | center[0] = clr.size[0] - center[0]
168 | kp2d[:, 0] = clr.size[0] - kp2d[:, 0]
169 |
170 | sample = {
171 | 'index': index,
172 | 'clr': clr,
173 | 'kp2d': kp2d,
174 | 'center': center,
175 | 'my_scale': my_scale,
176 | }
177 |
178 | # visualization
179 | if self.vis:
180 | clr_ = np.array(clr)
181 | plt.figure(figsize=(20, 20))
182 | plt.subplot(1, 2, 1)
183 | clr1 = clr_.copy()
184 | plt.imshow(clr1)
185 | plt.title('color image')
186 |
187 | plt.subplot(1, 2, 2)
188 | clr2 = clr_.copy()
189 | plt.imshow(clr2)
190 |             plt.plot(200, 100, 'r.', linewidth=10)  # fixed reference dot at pixel (x=200, y=100), illustrating the (x, y) convention
191 | for p in range(kp2d.shape[0]):
192 | plt.plot(kp2d[p][0], kp2d[p][1], 'r.')
193 | plt.text(kp2d[p][0], kp2d[p][1], '{0}'.format(p), fontsize=5)
194 | plt.title('2D annotations')
195 |
196 | plt.show()
197 |
198 | return sample
199 |
200 |
201 | def main():
202 | hand_labels = Hand_labels(
203 | data_root="/home/chen/datasets/CMU/hand_labels",
204 | data_split='train',
205 | hand_side='right',
206 | njoints=21,
207 | use_cache=True,
208 | vis=True
209 | )
210 | print("len(hand_labels)=", len(hand_labels))
211 |
212 | for i in tqdm(range(len(hand_labels))):
213 | print("i=", i)
214 | data = hand_labels.get_sample(i)
215 |
216 |
217 | if __name__ == "__main__":
218 | main()
219 |
--------------------------------------------------------------------------------
/datasets/handataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Hao Meng. All Rights Reserved.
2 | r"""
3 | Hand dataset that wraps and dispatches to all sub-datasets
4 | """
5 |
6 | import os
7 | import random
8 |
9 | import matplotlib.pyplot as plt
10 | import numpy as np
11 | import torch.utils.data
12 | from PIL import Image, ImageFilter
13 | from termcolor import colored
14 | from tqdm import tqdm
15 |
16 | import config as cfg
17 | import utils.func as func
18 | import utils.handutils as handutils
19 | import utils.heatmaputils as hmutils
20 | import utils.imgutils as imutils
21 | from datasets.dexter_object import DexterObjectDataset
22 | from datasets.ganerated_hands import GANeratedDataset
23 | from datasets.hand143_panopticdb import Hand143_panopticdb
24 | from datasets.hand_labels import Hand_labels
25 | from datasets.rhd import RHDDataset
26 | from datasets.stb import STBDataset
27 |
28 | snap_joint_name2id = {w: i for i, w in enumerate(cfg.snap_joint_names)}
29 |
30 |
31 | class HandDataset(torch.utils.data.Dataset):
32 | def __init__(
33 | self,
34 | data_split='train',
35 | data_root="/disk1/data",
36 | subset_name=['rhd', 'stb'],
37 | hand_side='right',
38 | sigma=1.0,
39 | inp_res=128,
40 | hm_res=32,
41 | njoints=21,
42 | train=True,
43 | scale_jittering=0.1,
44 |             center_jittering=0.1,
45 | max_rot=np.pi,
46 | hue=0.15,
47 | saturation=0.5,
48 | contrast=0.5,
49 | brightness=0.5,
50 | blur_radius=0.5, vis=False
51 | ):
52 |
53 | self.inp_res = inp_res # 128 # network input resolution
54 |         self.hm_res = hm_res  # 32 # output heatmap resolution
55 | self.njoints = njoints
56 | self.sigma = sigma
57 | self.max_rot = max_rot
58 |
59 | # Training attributes
60 | self.train = train
61 | self.scale_jittering = scale_jittering
62 |         self.center_jittering = center_jittering
63 |
64 | # Color jitter attributes
65 | self.hue = hue
66 | self.contrast = contrast
67 | self.brightness = brightness
68 | self.saturation = saturation
69 | self.blur_radius = blur_radius
70 |
71 | self.datasets = []
72 | self.ref_bone_link = (0, 9) # mid mcp
73 | self.joint_root_idx = 9 # root
74 |
75 | self.vis = vis
76 |
77 | if 'stb' in subset_name:
78 | self.stb = STBDataset(
79 | data_root=os.path.join(data_root, 'STB'),
80 | data_split=data_split,
81 | hand_side=hand_side,
82 | njoints=njoints,
83 | )
84 | print(self.stb)
85 | self.datasets.append(self.stb)
86 |
87 | if 'rhd' in subset_name:
88 | self.rhd = RHDDataset(
89 | data_root=os.path.join(data_root, 'RHD/RHD_published_v2'),
90 | data_split=data_split,
91 | hand_side=hand_side,
92 | njoints=njoints,
93 | )
94 | print(self.rhd)
95 | self.datasets.append(self.rhd)
96 |
97 | if 'cmu' in subset_name:
98 | self.hand143_panopticdb = Hand143_panopticdb(
99 | data_root=os.path.join(data_root, 'CMU/hand143_panopticdb'),
100 | data_split=data_split,
101 | hand_side=hand_side,
102 | njoints=njoints,
103 | )
104 | print(self.hand143_panopticdb)
105 | self.datasets.append(self.hand143_panopticdb)
106 |
107 | self.hand_labels = Hand_labels(
108 | data_root=os.path.join(data_root, 'CMU/hand_labels'),
109 | data_split=data_split,
110 | hand_side=hand_side,
111 | njoints=njoints,
112 | )
113 | print(self.hand_labels)
114 | self.datasets.append(self.hand_labels)
115 |
116 |             info = "CMU {} set. length {}".format(
117 | data_split, len(self.hand_labels) + len(self.hand143_panopticdb)
118 | )
119 | print(colored(info, 'yellow', attrs=['bold']))
120 |
121 | if 'gan' in subset_name:
122 | self.gan = GANeratedDataset(
123 | data_root=os.path.join(data_root, 'GANeratedHands_Release/data/'),
124 | data_split=data_split,
125 | hand_side=hand_side,
126 | njoints=njoints,
127 | )
128 | print(self.gan)
129 | self.datasets.append(self.gan)
130 |
131 | if 'do' in subset_name:
132 | self.do = DexterObjectDataset(
133 | data_root=os.path.join(data_root, 'dexter+object'),
134 | data_split=data_split,
135 | hand_side=hand_side,
136 | njoints=njoints,
137 | )
138 | print(self.do)
139 | self.datasets.append(self.do)
140 |
141 | self.total_data = 0
142 | for ds in self.datasets:
143 | self.total_data += len(ds)
144 |
145 | def __getitem__(self, index):
146 | rng = np.random.RandomState(seed=random.randint(0, 1024))
147 | try:
148 | sample, ds = self._get_sample(index)
149 |         except Exception:  # bad sample: fall back to a random index
150 | index = np.random.randint(0, len(self))
151 | sample, ds = self._get_sample(index)
152 |
153 | clr = sample['clr']
154 | my_clr1 = clr.copy()
155 | center = sample['center']
156 | scale = sample['my_scale']
157 | if 'intr' in sample.keys():
158 | intr = sample['intr']
159 |
160 | # Data augmentation
161 | if self.train:
162 | center_offsets = (
163 | self.center_jittering
164 | * scale
165 | * rng.uniform(low=-1, high=1, size=2)
166 | )
167 | center = center + center_offsets.astype(int)
168 |
169 | # Scale jittering
170 | scale_jittering = self.scale_jittering * rng.randn() + 1
171 | scale_jittering = np.clip(
172 | scale_jittering,
173 | 1 - self.scale_jittering,
174 | 1 + self.scale_jittering,
175 | )
176 | scale = scale * scale_jittering
177 | rot = rng.uniform(low=-self.max_rot, high=self.max_rot)
178 | else:
179 | rot = 0
180 |
181 | rot_mat = np.array([
182 | [np.cos(rot), -np.sin(rot), 0],
183 | [np.sin(rot), np.cos(rot), 0],
184 | [0, 0, 1],
185 | ]).astype(np.float32)
186 |
187 | if 'intr' in sample.keys():
188 | affinetrans, post_rot_trans = handutils.get_affine_transform(
189 | center=center,
190 | scale=scale,
191 | optical_center=[intr[0, 2], intr[1, 2]],
192 | out_res=[self.inp_res, self.inp_res],
193 | rot=rot
194 | )
195 | else:
196 | affinetrans, post_rot_trans = handutils.get_affine_transform_test(
197 | center, scale, [self.inp_res, self.inp_res], rot=rot
198 | )
199 |
200 | ''' prepare kp2d '''
201 | kp2d = sample['kp2d']
202 | kp2d_ori = kp2d.copy()
203 | kp2d = handutils.transform_coords(kp2d, affinetrans)
204 |
205 |         ''' Generate GT Gaussian heatmaps and the heatmap veil '''
206 | hm = np.zeros(
207 | (self.njoints, self.hm_res, self.hm_res),
208 | dtype='float32'
209 | ) # (CHW)
210 | hm_veil = np.ones(self.njoints, dtype='float32')
211 | for i in range(self.njoints):
212 | kp = (
213 | (kp2d[i] / self.inp_res) * self.hm_res
214 |             ).astype(np.int32)  # kp uv: [0, inp_res) -> [0, hm_res), here [0, 128) -> [0, 32)
215 | hm[i], aval = hmutils.gen_heatmap(hm[i], kp, self.sigma)
216 | hm_veil[i] *= aval
217 |
218 | joint = np.zeros([21, 3])
219 | delta_map = np.zeros([21, 3, 32, 32])
220 | location_map = np.zeros([21, 3, 32, 32])
221 | flag = 0
222 |
223 | if 'joint' in sample.keys():
224 |
225 | flag = 1
226 | ''' prepare joint '''
227 | joint = sample['joint']
228 | if self.train:
229 | joint = rot_mat.dot(
230 | joint.transpose(1, 0)
231 | ).transpose()
232 |
233 | joint_bone = 0
234 | for jid, nextjid in zip(self.ref_bone_link[:-1], self.ref_bone_link[1:]):
235 | joint_bone += np.linalg.norm(joint[nextjid] - joint[jid])
236 | joint_root = joint[self.joint_root_idx]
237 | joint_bone = np.atleast_1d(joint_bone)
238 |
239 | '''prepare location maps L'''
240 | jointR = joint - joint_root[np.newaxis, :] # root relative
241 | jointRS = jointR / joint_bone # scale invariant
242 |             # broadcast jointRS (21, 3) to location_map (21, 3, 32, 32)
243 | location_map = jointRS[:, :, np.newaxis, np.newaxis].repeat(32, axis=-2).repeat(32, axis=-1)
244 |
245 | '''prepare delta maps D'''
246 | kin_chain = [
247 | jointRS[i] - jointRS[cfg.SNAP_PARENT[i]]
248 | for i in range(21)
249 | ]
250 | kin_chain = np.array(kin_chain) # id 0's parent is itself #21*3
251 | kin_len = np.linalg.norm(
252 | kin_chain, ord=2, axis=-1, keepdims=True # 21*1
253 | )
254 | kin_chain[1:] = kin_chain[1:] / kin_len[1:]
255 |             # broadcast kin_chain (21, 3) to delta_map (21, 3, 32, 32)
256 | delta_map = kin_chain[:, :, np.newaxis, np.newaxis].repeat(32, axis=-2).repeat(32, axis=-1)
257 |
258 | if 'tip' in sample.keys():
259 | joint = sample['tip']
260 | if self.train:
261 | joint = rot_mat.dot(
262 | joint.transpose(1, 0)
263 | ).transpose()
264 |
265 | ''' prepare clr image '''
266 | if self.train:
267 | blur_radius = random.random() * self.blur_radius
268 | clr = clr.filter(ImageFilter.GaussianBlur(blur_radius))
269 | clr = imutils.color_jitter(
270 | clr,
271 | brightness=self.brightness,
272 | saturation=self.saturation,
273 | hue=self.hue,
274 | contrast=self.contrast,
275 | )
276 |
277 | # Transform and crop
278 | clr = handutils.transform_img(
279 | clr, affinetrans, [self.inp_res, self.inp_res]
280 | )
281 | clr = clr.crop((0, 0, self.inp_res, self.inp_res))
282 | my_clr2 = clr.copy()
283 |
284 | ''' implicit HWC -> CHW, 255 -> 1 '''
285 | clr = func.to_tensor(clr).float()
286 | ''' 0-mean, 1 std, [0,1] -> [-0.5, 0.5] '''
287 | clr = func.normalize(clr, [0.5, 0.5, 0.5], [1, 1, 1])
288 |
289 | # visualization
290 | if self.vis:
291 |
292 | clr1 = my_clr1.copy()
293 |
294 | fig = plt.figure(figsize=(20, 10))
295 | plt.subplot(1, 4, 1)
296 | plt.imshow(np.asarray(clr1))
297 | plt.title('ori_Color+2D annotations')
298 | plt.plot(kp2d_ori[0, 0], kp2d_ori[0, 1], 'ro', markersize=5)
299 | plt.text(kp2d_ori[0][0], kp2d_ori[0][1], '0', color="w", fontsize=7.5)
300 | for p in range(1, kp2d_ori.shape[0]):
301 | plt.plot(kp2d_ori[p][0], kp2d_ori[p][1], 'bo', markersize=5)
302 | plt.text(kp2d_ori[p][0], kp2d_ori[p][1], '{0}'.format(p), color="w", fontsize=5)
303 |
304 | plt.subplot(1, 4, 2)
305 | clr2 = np.array(my_clr2.copy())
306 | plt.imshow(clr2)
307 | plt.plot(kp2d[0, 0], kp2d[0, 1], 'ro', markersize=5)
308 | plt.text(kp2d[0][0], kp2d[0][1], '0', color="w", fontsize=7.5)
309 | for p in range(1, kp2d.shape[0]):
310 | plt.plot(kp2d[p][0], kp2d[p][1], 'bo', markersize=5)
311 | plt.text(kp2d[p][0], kp2d[p][1], '{0}'.format(p), color="w", fontsize=5)
312 | plt.title('cropped_Color+2D annotations')
313 |
314 | plt.subplot(1, 4, 3)
315 | clr3 = my_clr2.copy().resize((self.hm_res, self.hm_res), Image.ANTIALIAS)
316 | tmp = clr3.convert('L')
317 | tmp = np.array(tmp)
318 | for k in range(hm.shape[0]):
319 | tmp = tmp + hm[k] * 64
320 | plt.imshow(tmp)
321 | plt.title('heatmap')
322 |
323 | if 'joint' in sample.keys():
324 | ax = fig.add_subplot(144, projection='3d')
325 |
326 | plt.plot(joint[:, 0], joint[:, 1], joint[:, 2], 'yo', label='keypoint')
327 |
328 | plt.plot(joint[:5, 0], joint[:5, 1],
329 | joint[:5, 2],
330 | 'r',
331 | label='thumb')
332 |
333 | plt.plot(joint[[0, 5, 6, 7, 8, ], 0], joint[[0, 5, 6, 7, 8, ], 1],
334 | joint[[0, 5, 6, 7, 8, ], 2],
335 | 'b',
336 | label='index')
337 | plt.plot(joint[[0, 9, 10, 11, 12, ], 0], joint[[0, 9, 10, 11, 12], 1],
338 | joint[[0, 9, 10, 11, 12], 2],
339 | 'b',
340 | label='middle')
341 | plt.plot(joint[[0, 13, 14, 15, 16], 0], joint[[0, 13, 14, 15, 16], 1],
342 | joint[[0, 13, 14, 15, 16], 2],
343 | 'b',
344 | label='ring')
345 | plt.plot(joint[[0, 17, 18, 19, 20], 0], joint[[0, 17, 18, 19, 20], 1],
346 | joint[[0, 17, 18, 19, 20], 2],
347 | 'b',
348 | label='pinky')
349 | # snap convention
350 | plt.plot(joint[4][0], joint[4][1], joint[4][2], 'rD', label='thumb')
351 | plt.plot(joint[8][0], joint[8][1], joint[8][2], 'r*', label='index')
352 | plt.plot(joint[12][0], joint[12][1], joint[12][2], 'rs', label='middle')
353 | plt.plot(joint[16][0], joint[16][1], joint[16][2], 'ro', label='ring')
354 | plt.plot(joint[20][0], joint[20][1], joint[20][2], 'rv', label='pinky')
355 |
356 | plt.title('3D annotations')
357 | ax.set_xlabel('x')
358 | ax.set_ylabel('y')
359 | ax.set_zlabel('z')
360 | plt.legend()
361 | ax.view_init(-90, -90)
362 |
363 | plt.show()
364 |
365 |         # to torch tensors
366 |         clr = clr  # clr is already a tensor (converted by func.to_tensor above)
367 | hm = torch.from_numpy(hm).float()
368 | hm_veil = torch.from_numpy(hm_veil).float()
369 | joint = torch.from_numpy(joint).float()
370 | location_map = torch.from_numpy(location_map).float()
371 | delta_map = torch.from_numpy(delta_map).float()
372 |
373 | metas = {
374 | 'index': index,
375 | 'clr': clr,
376 | 'hm': hm,
377 | 'hm_veil': hm_veil,
378 | 'location_map': location_map,
379 | 'delta_map': delta_map,
380 | 'flag_3d': flag,
381 | "joint": joint
382 | }
383 |
384 | return metas
385 |
386 | def _get_sample(self, index):
387 | base = 0
388 | dataset = None
389 | for ds in self.datasets:
390 | if index < base + len(ds):
391 | sample = ds.get_sample(index - base)
392 | dataset = ds
393 | break
394 | else:
395 | base += len(ds)
396 | return sample, dataset
397 |
398 | def __len__(self):
399 | return self.total_data
400 |
401 |
402 | if __name__ == '__main__':
403 | test_set = HandDataset(
404 | data_split='test',
405 | train=False,
406 | scale_jittering=0.1,
407 |         center_jittering=0.1,
408 | max_rot=0.5 * np.pi,
409 | subset_name=["rhd", "stb", "do", "eo"],
410 | data_root="/home/chen/datasets/", vis=True
411 | )
412 |
413 | for id in tqdm(range(0, len(test_set), 10)):
414 | print("id=", id)
415 | data = test_set[id]
416 |
--------------------------------------------------------------------------------
/datasets/rhd.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Lixin YANG. All Rights Reserved.
2 | r"""
3 | Rendered Handpose Dataset (RHD)
4 | Learning to Estimate 3D Hand Pose from Single RGB Images, ICCV 2017
5 | """
6 |
7 | import os
8 | import pickle
9 |
10 | import PIL
11 | import matplotlib.pyplot as plt
12 | import numpy as np
13 | import torch.utils.data
14 | from PIL import Image
15 | from progress.bar import Bar
16 | from termcolor import colored
17 | from tqdm import tqdm
18 |
19 | import config as cfg
20 | import utils.handutils as handutils
21 |
22 | CACHE_HOME = os.path.expanduser(cfg.DEFAULT_CACHE_DIR)
23 |
24 | snap_joint_name2id = {w: i for i, w in enumerate(cfg.snap_joint_names)}
25 | rhd_joint_name2id = {w: i for i, w in enumerate(cfg.rhd_joints)}
26 | rhd_to_snap_id = [snap_joint_name2id[joint_name] for joint_name in cfg.rhd_joints]
27 |
28 |
29 | class RHDDataset(torch.utils.data.Dataset):
30 | def __init__(
31 | self,
32 | data_root="/disk1/data/RHD/RHD_published_v2",
33 | data_split='train',
34 | hand_side='right',
35 | njoints=21,
36 | use_cache=True,
37 | visual=False
38 | ):
39 |
40 | if not os.path.exists(data_root):
41 | raise ValueError("data_root: %s not exist" % data_root)
42 |
43 | self.name = 'rhd'
44 | self.data_split = data_split
45 | self.hand_side = hand_side
46 | self.clr_paths = []
47 | self.mask_paths = []
48 | self.joints = []
49 | self.kp2ds = []
50 | self.centers = []
51 | self.my_scales = []
52 | self.sides = []
53 | self.intrs = []
54 | self.njoints = njoints # total 21 hand parts
55 | self.reslu = [320, 320]
56 |
57 | self.visual = visual
58 |
59 | self.root_id = snap_joint_name2id['loc_bn_palm_L'] # 0
60 | self.mid_mcp_id = snap_joint_name2id['loc_bn_mid_L_01'] # 9
61 |
62 | # [train|test|val|train_val|all]
63 | if data_split == 'train':
64 | self.sequence = ['training', ]
65 | elif data_split == 'test':
66 | self.sequence = ['evaluation', ]
67 | elif data_split == 'val':
68 | self.sequence = ['evaluation', ]
69 | elif data_split == 'train_val':
70 | self.sequence = ['training', ]
71 | elif data_split == 'all':
72 | self.sequence = ['training', 'evaluation']
73 | else:
74 | raise ValueError("split {} not in [train|test|val|train_val|all]".format(data_split))
75 |
76 | self.cache_folder = os.path.join(CACHE_HOME, "my-{}".format(data_split), "rhd")
77 | os.makedirs(self.cache_folder, exist_ok=True)
78 | cache_path = os.path.join(
79 | self.cache_folder, "{}.pkl".format(self.data_split)
80 | )
81 | if os.path.exists(cache_path) and use_cache:
82 | with open(cache_path, "rb") as fid:
83 | annotations = pickle.load(fid)
84 | self.sides = annotations["sides"]
85 | self.clr_paths = annotations["clr_paths"]
86 | self.mask_paths = annotations["mask_paths"]
87 | self.joints = annotations["joints"]
88 | self.kp2ds = annotations["kp2ds"]
89 | self.intrs = annotations["intrs"]
90 | self.centers = annotations["centers"]
91 | self.my_scales = annotations["my_scales"]
92 | print("rhd {} gt loaded from {}".format(self.data_split, cache_path))
93 | return
94 |
95 | datapath_list = [
96 | os.path.join(data_root, seq) for seq in self.sequence
97 | ]
98 | annoname_list = [
99 | "anno_{}.pickle".format(seq) for seq in self.sequence
100 | ]
101 | anno_list = [
102 | os.path.join(datapath, annoname) \
103 | for datapath, annoname in zip(datapath_list, annoname_list)
104 | ]
105 | clr_root_list = [
106 | os.path.join(datapath, "color") for datapath in datapath_list
107 | ]
108 | dep_root_list = [
109 | os.path.join(datapath, "depth") for datapath in datapath_list
110 | ]
111 | mask_root_list = [
112 | os.path.join(datapath, "mask") for datapath in datapath_list
113 | ]
114 |
115 |         print("init RHD {}, this will take a while the first time".format(data_split))
116 | for anno, clr_root, dep_root, mask_root \
117 | in zip(
118 | anno_list,
119 | clr_root_list,
120 | dep_root_list,
121 | mask_root_list
122 | ):
123 |
124 | with open(anno, 'rb') as fi:
125 | rawdatas = pickle.load(fi)
126 |                 fi.close()  # redundant: the with-block already closes the file
127 |
128 | bar = Bar('RHD', max=len(rawdatas))
129 | for i in tqdm(range(len(rawdatas))):
130 |
131 | raw = rawdatas[i]
132 | rawkp2d = raw['uv_vis'][:, : 2] # kp 2d left & right hand
133 | rawvis = raw['uv_vis'][:, 2]
134 |
135 | rawjoint = raw['xyz'] # x, y, z coordinates of the keypoints, in meters
136 | rawintr = raw['K']
137 |
138 | ''' "both" means left, right'''
139 | kp2dboth = [
140 | rawkp2d[:21][rhd_to_snap_id, :],
141 | rawkp2d[21:][rhd_to_snap_id, :]
142 | ]
143 | visboth = [
144 | rawvis[:21][rhd_to_snap_id],
145 | rawvis[21:][rhd_to_snap_id]
146 | ]
147 | jointboth = [
148 | rawjoint[:21][rhd_to_snap_id, :],
149 | rawjoint[21:][rhd_to_snap_id, :]
150 | ]
151 |
152 | intrboth = [rawintr, rawintr]
153 | sideboth = ['l', 'r']
154 |
155 | l_kp_count = np.sum(raw['uv_vis'][:21, 2] == 1)
156 | r_kp_count = np.sum(raw['uv_vis'][21:, 2] == 1)
157 | vis_side = 'l' if l_kp_count > r_kp_count else 'r'
158 |
159 | for kp2d, vis, joint, side, intr \
160 | in zip(kp2dboth, visboth, jointboth, sideboth, intrboth):
161 | if side != vis_side:
162 | continue
163 |
164 | clrpth = os.path.join(clr_root, '%.5d.png' % i)
165 | maskpth = os.path.join(mask_root, '%.5d.png' % i)
166 | self.clr_paths.append(clrpth)
167 | self.mask_paths.append(maskpth)
168 | self.sides.append(side)
169 |
170 | joint = joint[np.newaxis, :, :]
171 | self.joints.append(joint)
172 |
173 | center = handutils.get_annot_center(kp2d)
174 | kp2d = kp2d[np.newaxis, :, :]
175 | self.kp2ds.append(kp2d)
176 |
177 | center = center[np.newaxis, :]
178 | self.centers.append(center)
179 |
180 | mask = Image.open(maskpth).convert("RGB")
181 | mask = np.array(mask)[:, :, 2:]
182 | my_scale = handutils.get_ori_crop_scale(mask, side, kp2d.squeeze(0))
183 | my_scale = (np.atleast_1d(my_scale))[np.newaxis, :]
184 | self.my_scales.append(my_scale)
185 |
186 | intr = intr[np.newaxis, :]
187 | self.intrs.append(intr)
188 |
189 | bar.suffix = ('({n}/{all}), total:{t:}s, eta:{eta:}s').format(
190 | n=i + 1, all=len(rawdatas), t=bar.elapsed_td, eta=bar.eta_td)
191 | bar.next()
192 |
193 | bar.finish()
194 | self.joints = np.concatenate(self.joints, axis=0).astype(np.float32) # (N, 21, 3)
195 |
196 | self.kp2ds = np.concatenate(self.kp2ds, axis=0).astype(np.float32) # (N, 21, 2)
197 |         self.centers = np.concatenate(self.centers, axis=0).astype(np.float32)  # (N, 2)
198 |         self.my_scales = np.concatenate(self.my_scales, axis=0).astype(np.float32)  # (N, 1)
199 |         self.intrs = np.concatenate(self.intrs, axis=0).astype(np.float32)  # (N, 3, 3)
200 |
201 | if use_cache:
202 | full_info = {
203 | "sides": self.sides,
204 | "clr_paths": self.clr_paths,
205 | "mask_paths": self.mask_paths,
206 | "joints": self.joints,
207 | "kp2ds": self.kp2ds,
208 | "intrs": self.intrs,
209 | "centers": self.centers,
210 | "my_scales": self.my_scales,
211 | }
212 | with open(cache_path, "wb") as fid:
213 | pickle.dump(full_info, fid)
214 | print("Wrote cache for dataset rhd {} to {}".format(
215 | self.data_split, cache_path
216 | ))
217 | return
218 |
219 | def get_sample(self, index):
220 | side = self.sides[index]
221 |         # sides are stored as 'l'/'r'; flip when the stored side is not part of self.hand_side
222 |         flip = side not in self.hand_side
223 |
224 | clr = Image.open(self.clr_paths[index]).convert("RGB")
225 | self._is_valid(clr, index)
226 | mask = Image.open(self.mask_paths[index]).convert("RGB")
227 | self._is_valid(mask, index)
228 |
229 |         # prepare joint
230 | joint = self.joints[index].copy()
231 |
232 | # prepare kp2d
233 | kp2d = self.kp2ds[index].copy()
234 |
235 | center = self.centers[index].copy()
236 | # scale = self.scales[index].copy()
237 |
238 | my_scale = self.my_scales[index].copy()
239 |
240 | if flip:
241 | clr = clr.transpose(Image.FLIP_LEFT_RIGHT)
242 | center[0] = clr.size[0] - center[0] # clr.size[0] represents width of image
243 | kp2d[:, 0] = clr.size[0] - kp2d[:, 0]
244 | joint[:, 0] = -joint[:, 0]
245 |
246 | sample = {
247 | 'index': index,
248 | 'clr': clr,
249 | 'kp2d': kp2d,
250 | 'center': center,
251 | 'my_scale': my_scale,
252 | 'joint': joint,
253 | 'intr': self.intrs[index],
254 | }
255 |
256 | if self.visual:
257 | fig = plt.figure(figsize=(20, 20))
258 | plt.subplot(1, 3, 1)
259 | plt.imshow(clr.copy())
260 | plt.title('Color')
261 |
262 | plt.subplot(1, 3, 2)
263 | plt.imshow(clr.copy())
264 | plt.plot(kp2d[:, :1], kp2d[:, 1:], 'ro')
265 | plt.title('Color+2D annotations')
266 |
267 | ax = fig.add_subplot(133, projection='3d')
268 | plt.plot(joint[:, 0], joint[:, 1], joint[:, 2], 'yo', label='keypoint')
269 | plt.plot(joint[:5, 0], joint[:5, 1],
270 | joint[:5, 2],
271 | 'r',
272 | label='thumb')
273 | plt.plot(joint[[0, 5, 6, 7, 8, ], 0], joint[[0, 5, 6, 7, 8, ], 1],
274 | joint[[0, 5, 6, 7, 8, ], 2],
275 | 'b',
276 | label='index')
277 | plt.plot(joint[[0, 9, 10, 11, 12, ], 0], joint[[0, 9, 10, 11, 12], 1],
278 | joint[[0, 9, 10, 11, 12], 2],
279 | 'b',
280 | label='middle')
281 | plt.plot(joint[[0, 13, 14, 15, 16], 0], joint[[0, 13, 14, 15, 16], 1],
282 | joint[[0, 13, 14, 15, 16], 2],
283 | 'b',
284 | label='ring')
285 | plt.plot(joint[[0, 17, 18, 19, 20], 0], joint[[0, 17, 18, 19, 20], 1],
286 | joint[[0, 17, 18, 19, 20], 2],
287 | 'b',
288 | label='pinky')
289 | # snap convention
290 | plt.plot(joint[4][0], joint[4][1], joint[4][2], 'rD', label='thumb')
291 | plt.plot(joint[8][0], joint[8][1], joint[8][2], 'ro', label='index')
292 | plt.plot(joint[12][0], joint[12][1], joint[12][2], 'ro', label='middle')
293 | plt.plot(joint[16][0], joint[16][1], joint[16][2], 'ro', label='ring')
294 | plt.plot(joint[20][0], joint[20][1], joint[20][2], 'ro', label='pinky')
295 | # plt.plot(joint [1:, 0], joint [1:, 1], joint [1:, 2], 'o')
296 |
297 | plt.title('3D annotations')
298 | ax.set_xlabel('x')
299 | ax.set_ylabel('y')
300 | ax.set_zlabel('z')
301 | plt.legend()
302 | ax.view_init(-90, -90)
303 | plt.show()
304 |
305 | return sample
306 |
307 | def _apply_mask(self, dep, mask, side):
308 | ''' follow the label rules in RHD datasets '''
309 |         if side == 'l':
310 | valid_mask_id = [i for i in range(2, 18)]
311 | else:
312 | valid_mask_id = [i for i in range(18, 34)]
313 |
314 | mask = np.array(mask)[:, :, 2:]
315 | dep = np.array(dep)
316 | ll = valid_mask_id[0]
317 | uu = valid_mask_id[-1]
318 | mask[mask < ll] = 0
319 | mask[mask > uu] = 0
320 | mask[mask > 0] = 1
321 | if mask.dtype != np.uint8:
322 | mask = mask.astype(np.uint8)
323 | dep = np.multiply(dep, mask)
324 | dep = Image.fromarray(dep, mode="RGB")
325 | return dep
326 |
327 | def __len__(self):
328 | return len(self.clr_paths)
329 |
330 | def __str__(self):
331 | info = "RHD {} set. lenth {}".format(
332 | self.data_split, len(self.clr_paths)
333 | )
334 | return colored(info, 'yellow', attrs=['bold'])
335 |
336 | def norm_dep_img(self, dep_):
337 | """RHD depthmap to depth image
338 | """
339 | if isinstance(dep_, PIL.Image.Image):
340 | dep_ = np.array(dep_)
341 | assert (dep_.shape[-1] == 3) # used to be "RGB"
342 |
343 | ''' Converts a RGB-coded depth into float valued depth. '''
344 | dep = (dep_[:, :, 0] * 2 ** 8 + dep_[:, :, 1]).astype('float32')
345 | dep /= float(2 ** 16 - 1)
346 |         dep *= 5.0  # depth in meters (RHD encodes depths up to 5 m)
347 |
348 | return dep
349 |
350 | def _is_valid(self, img, index):
351 | valid_data = isinstance(img, (np.ndarray, PIL.Image.Image))
352 | if not valid_data:
353 | raise Exception("Encountered error processing rhd[{}]".format(index))
354 | return valid_data
355 |
356 |
357 | def main():
358 | data_split = 'test'
359 | rhd = RHDDataset(
360 | data_root="/home/chen/datasets/RHD/RHD_published_v2",
361 | data_split=data_split,
362 | hand_side='right',
363 | njoints=21,
364 | use_cache=False,
365 | visual=True
366 | )
367 | print("len(rhd)=", len(rhd))
368 |
369 | for i in tqdm(range(len(rhd))):
370 |         print("i=", i)
371 | data = rhd.get_sample(i)
372 |
373 |
374 | if __name__ == "__main__":
375 | main()
376 |
--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import torch
3 | from manopth import manolayer
4 | from model.detnet import detnet
5 | from utils import func, bone, AIK, smoother
6 | import numpy as np
7 | import matplotlib.pyplot as plt
8 | from utils import vis
9 | from op_pso import PSO
10 | import open3d
11 |
12 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
13 | _mano_root = 'mano/models'
14 |
15 | module = detnet().to(device)
16 | print('load model start')
17 | check_point = torch.load('new_check_point/ckp_detnet_83.pth', map_location=device)
18 | model_state = module.state_dict()
19 | state = {}
20 | for k, v in check_point.items():
21 | if k in model_state:
22 | state[k] = v
23 | else:
24 | print(k, ' is NOT in current model')
25 | model_state.update(state)
26 | module.load_state_dict(model_state)
27 | print('load model finished')
28 | pose, shape = func.initiate("zero")
29 | pre_useful_bone_len = np.zeros((1, 15))
30 | pose0 = torch.eye(3).repeat(1, 16, 1, 1)
31 |
32 | mano = manolayer.ManoLayer(flat_hand_mean=True,
33 | side="right",
34 | mano_root=_mano_root,
35 | use_pca=False,
36 | root_rot_mode='rotmat',
37 | joint_rot_mode='rotmat')
38 | print('start opencv')
39 | point_fliter = smoother.OneEuroFilter(4.0, 0.0)
40 | mesh_fliter = smoother.OneEuroFilter(4.0, 0.0)
41 | shape_fliter = smoother.OneEuroFilter(4.0, 0.0)
42 | cap = cv2.VideoCapture(0)
43 | print('opencv finished')
44 | flag = 1
45 | plt.ion()
46 | f = plt.figure()
47 |
48 | fliter_ax = f.add_subplot(111, projection='3d')
49 | plt.show()
50 | view_mat = np.array([[1.0, 0.0, 0.0],
51 | [0.0, -1.0, 0],
52 | [0.0, 0, -1.0]])
53 | mesh = open3d.geometry.TriangleMesh()
54 | hand_verts, j3d_recon = mano(pose0, shape.float())
55 | mesh.triangles = open3d.utility.Vector3iVector(mano.th_faces)
56 | hand_verts = hand_verts.clone().detach().cpu().numpy()[0]
57 | mesh.vertices = open3d.utility.Vector3dVector(hand_verts)
58 | viewer = open3d.visualization.Visualizer()
59 | viewer.create_window(width=480, height=480, window_name='mesh')
60 | viewer.add_geometry(mesh)
61 | viewer.update_renderer()
62 |
63 | print('start pose estimate')
64 |
65 | pre_uv = None
66 | shape_time = 0
67 | opt_shape = None
68 | shape_flag = True
69 | while cap.isOpened():
70 | ret_flag, img = cap.read()
71 | input = np.flip(img.copy(), -1)
72 | k = cv2.waitKey(1) & 0xFF
73 | if input.shape[0] > input.shape[1]:
74 | margin = (input.shape[0] - input.shape[1]) // 2
75 | input = input[margin:-margin]
76 | else:
77 | margin = (input.shape[1] - input.shape[0]) // 2
78 | input = input[:, margin:-margin]
79 | img = input.copy()
80 | img = np.flip(img, -1)
81 | cv2.imshow("Capture_Test", img)
82 | input = cv2.resize(input, (128, 128))
83 | input = torch.tensor(input.transpose([2, 0, 1]), dtype=torch.float, device=device) # hwc -> chw
84 | input = func.normalize(input, [0.5, 0.5, 0.5], [1, 1, 1])
85 | result = module(input.unsqueeze(0))
86 |
87 | pre_joints = result['xyz'].squeeze(0)
88 | now_uv = result['uv'].clone().detach().cpu().numpy()[0, 0]
89 |     now_uv = now_uv.astype(np.float64)  # np.float was a deprecated alias for the builtin float
90 | trans = np.zeros((1, 3))
91 | trans[0, 0:2] = now_uv - 16.0
92 | trans = trans / 16.0
93 | new_tran = np.array([[trans[0, 1], trans[0, 0], trans[0, 2]]])
94 | pre_joints = pre_joints.clone().detach().cpu().numpy()
95 |
96 | flited_joints = point_fliter.process(pre_joints)
97 |
98 | fliter_ax.cla()
99 |
100 | filted_ax = vis.plot3d(flited_joints + new_tran, fliter_ax)
101 | pre_useful_bone_len = bone.caculate_length(pre_joints, label="useful")
102 |
103 | NGEN = 100
104 | popsize = 100
105 | low = np.zeros((1, 10)) - 3.0
106 | up = np.zeros((1, 10)) + 3.0
107 | parameters = [NGEN, popsize, low, up]
108 |     pso = PSO(parameters, pre_useful_bone_len.reshape((1, 15)), _mano_root)
109 | pso.main()
110 | opt_shape = pso.ng_best
111 | opt_shape = shape_fliter.process(opt_shape)
112 |
113 | opt_tensor_shape = torch.tensor(opt_shape, dtype=torch.float)
114 | _, j3d_p0_ops = mano(pose0, opt_tensor_shape)
115 |     template = j3d_p0_ops.cpu().numpy().squeeze(0) / 1000.0  # template joints in meters, (21, 3)
116 | ratio = np.linalg.norm(template[9] - template[0]) / np.linalg.norm(pre_joints[9] - pre_joints[0])
117 |     j3d_pre_process = pre_joints * ratio  # prediction rescaled to template scale (meters)
118 | j3d_pre_process = j3d_pre_process - j3d_pre_process[0] + template[0]
119 | pose_R = AIK.adaptive_IK(template, j3d_pre_process)
120 | pose_R = torch.from_numpy(pose_R).float()
121 | # reconstruction
122 | hand_verts, j3d_recon = mano(pose_R, opt_tensor_shape.float())
123 | mesh.triangles = open3d.utility.Vector3iVector(mano.th_faces)
124 | hand_verts = hand_verts.clone().detach().cpu().numpy()[0]
125 | hand_verts = mesh_fliter.process(hand_verts)
126 | hand_verts = np.matmul(view_mat, hand_verts.T).T
127 | hand_verts[:, 0] = hand_verts[:, 0] - 50
128 | hand_verts[:, 1] = hand_verts[:, 1] - 50
129 | mesh_tran = np.array([[-new_tran[0, 0], new_tran[0, 1], new_tran[0, 2]]])
130 | hand_verts = hand_verts - 100 * mesh_tran
131 |
132 | mesh.vertices = open3d.utility.Vector3dVector(hand_verts)
133 | mesh.paint_uniform_color([228 / 255, 178 / 255, 148 / 255])
134 | mesh.compute_triangle_normals()
135 | mesh.compute_vertex_normals()
136 | viewer.update_geometry(mesh)
137 | viewer.poll_events()
138 | if k == ord('q'):
139 | break
140 | cap.release()
141 | cv2.destroyAllWindows()
142 |
--------------------------------------------------------------------------------
/demo_dl.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import torch
3 | from manopth import manolayer
4 | from model.detnet import detnet
5 | from utils import func, bone, AIK, smoother
6 | import numpy as np
7 | import matplotlib.pyplot as plt
8 | from utils import vis
9 | from op_pso import PSO
10 | import open3d
11 | from model import shape_net
12 | import os
13 |
14 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
15 | _mano_root = 'mano/models'
16 |
17 | module = detnet().to(device)
18 | print('load model start')
19 | check_point = torch.load('new_check_point/ckp_detnet_83.pth', map_location=device)
20 | model_state = module.state_dict()
21 | state = {}
22 | for k, v in check_point.items():
23 | if k in model_state:
24 | state[k] = v
25 | else:
26 | print(k, ' is NOT in current model')
27 | model_state.update(state)
28 | module.load_state_dict(model_state)
29 | print('load model finished')
30 |
31 | shape_model = shape_net.ShapeNet()
32 | shape_net.load_checkpoint(
33 | shape_model, os.path.join('checkpoints', 'ckp_siknet_synth_41.pth.tar')
34 | )
35 | for params in shape_model.parameters():
36 | params.requires_grad = False
37 |
38 | pose, shape = func.initiate("zero")
39 | pre_useful_bone_len = np.zeros((1, 15))
40 | pose0 = torch.eye(3).repeat(1, 16, 1, 1)
41 |
42 | mano = manolayer.ManoLayer(flat_hand_mean=True,
43 | side="right",
44 | mano_root=_mano_root,
45 | use_pca=False,
46 | root_rot_mode='rotmat',
47 | joint_rot_mode='rotmat')
48 | print('start opencv')
49 | point_fliter = smoother.OneEuroFilter(4.0, 0.0)
50 | mesh_fliter = smoother.OneEuroFilter(4.0, 0.0)
51 | shape_fliter = smoother.OneEuroFilter(4.0, 0.0)
52 | cap = cv2.VideoCapture(0)
53 | print('opencv finished')
54 | flag = 1
55 | plt.ion()
56 | f = plt.figure()
57 |
58 | fliter_ax = f.add_subplot(111, projection='3d')
59 | plt.show()
60 | view_mat = np.array([[1.0, 0.0, 0.0],
61 | [0.0, -1.0, 0],
62 | [0.0, 0, -1.0]])
63 | mesh = open3d.geometry.TriangleMesh()
64 | hand_verts, j3d_recon = mano(pose0, shape.float())
65 | mesh.triangles = open3d.utility.Vector3iVector(mano.th_faces)
66 | hand_verts = hand_verts.clone().detach().cpu().numpy()[0]
67 | mesh.vertices = open3d.utility.Vector3dVector(hand_verts)
68 | viewer = open3d.visualization.Visualizer()
69 | viewer.create_window(width=480, height=480, window_name='mesh')
70 | viewer.add_geometry(mesh)
71 | viewer.update_renderer()
72 |
73 | print('start pose estimate')
74 |
75 | pre_uv = None
76 | shape_time = 0
77 | opt_shape = None
78 | shape_flag = True
79 | while cap.isOpened():
80 | ret_flag, img = cap.read()
81 | input = np.flip(img.copy(), -1)
82 | k = cv2.waitKey(1) & 0xFF
83 | if input.shape[0] > input.shape[1]:
84 | margin = (input.shape[0] - input.shape[1]) // 2
85 | input = input[margin:-margin]
86 | else:
87 | margin = (input.shape[1] - input.shape[0]) // 2
88 | input = input[:, margin:-margin]
89 | img = input.copy()
90 | img = np.flip(img, -1)
91 | cv2.imshow("Capture_Test", img)
92 | input = cv2.resize(input, (128, 128))
93 | input = torch.tensor(input.transpose([2, 0, 1]), dtype=torch.float, device=device) # hwc -> chw
94 | input = func.normalize(input, [0.5, 0.5, 0.5], [1, 1, 1])
95 | result = module(input.unsqueeze(0))
96 |
97 | pre_joints = result['xyz'].squeeze(0)
98 | now_uv = result['uv'].clone().detach().cpu().numpy()[0, 0]
99 |     now_uv = now_uv.astype(np.float64)  # np.float was a deprecated alias for the builtin float
100 | trans = np.zeros((1, 3))
101 | trans[0, 0:2] = now_uv - 16.0
102 | trans = trans / 16.0
103 | new_tran = np.array([[trans[0, 1], trans[0, 0], trans[0, 2]]])
104 | pre_joints = pre_joints.clone().detach().cpu().numpy()
105 |
106 | flited_joints = point_fliter.process(pre_joints)
107 |
108 | fliter_ax.cla()
109 |
110 | filted_ax = vis.plot3d(flited_joints + new_tran, fliter_ax)
111 | pre_useful_bone_len = bone.caculate_length(pre_joints, label="useful")
112 |
113 | shape_model_input = torch.tensor(pre_useful_bone_len, dtype=torch.float)
114 | shape_model_input = shape_model_input.reshape((1, 15))
115 | dl_shape = shape_model(shape_model_input)
116 | dl_shape = dl_shape['beta'].numpy()
117 | dl_shape = shape_fliter.process(dl_shape)
118 | opt_tensor_shape = torch.tensor(dl_shape, dtype=torch.float)
119 | _, j3d_p0_ops = mano(pose0, opt_tensor_shape)
120 |     template = j3d_p0_ops.cpu().numpy().squeeze(0) / 1000.0  # template joints in meters, (21, 3)
121 | ratio = np.linalg.norm(template[9] - template[0]) / np.linalg.norm(pre_joints[9] - pre_joints[0])
122 |     j3d_pre_process = pre_joints * ratio  # prediction rescaled to template scale (meters)
123 | j3d_pre_process = j3d_pre_process - j3d_pre_process[0] + template[0]
124 | pose_R = AIK.adaptive_IK(template, j3d_pre_process)
125 | pose_R = torch.from_numpy(pose_R).float()
126 | # reconstruction
127 | hand_verts, j3d_recon = mano(pose_R, opt_tensor_shape.float())
128 | mesh.triangles = open3d.utility.Vector3iVector(mano.th_faces)
129 | hand_verts = hand_verts.clone().detach().cpu().numpy()[0]
130 | hand_verts = mesh_fliter.process(hand_verts)
131 | hand_verts = np.matmul(view_mat, hand_verts.T).T
132 | hand_verts[:, 0] = hand_verts[:, 0] - 50
133 | hand_verts[:, 1] = hand_verts[:, 1] - 50
134 | mesh_tran = np.array([[-new_tran[0, 0], new_tran[0, 1], new_tran[0, 2]]])
135 | hand_verts = hand_verts - 100 * mesh_tran
136 |
137 | mesh.vertices = open3d.utility.Vector3dVector(hand_verts)
138 | mesh.paint_uniform_color([228 / 255, 178 / 255, 148 / 255])
139 | mesh.compute_triangle_normals()
140 | mesh.compute_vertex_normals()
141 | viewer.update_geometry(mesh)
142 | viewer.poll_events()
143 | if k == ord('q'):
144 | break
145 | cap.release()
146 | cv2.destroyAllWindows()
147 |
--------------------------------------------------------------------------------
/dl_shape_estimate.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 |
5 | import create_data
6 | from model import shape_net
7 |
8 | import numpy as np
9 |
10 |
11 |
12 | def align_bone_len(opt_, pre_):
13 | opt = opt_.copy()
14 | pre = pre_.copy()
15 |
16 | opt_align = opt.copy()
17 | for i in range(opt.shape[0]):
18 | ratio = pre[i][6] / opt[i][6]
19 | opt_align[i] = ratio * opt_align[i]
20 |
21 | err = np.abs(opt_align - pre).mean(0)
22 |
23 | return err
24 |
25 | def fun(_shape, _label, data_loader):
26 |     # compute relative bone lengths
27 | shape = _shape.clone().detach()
28 | label = _label.detach().clone()
29 |     # compute relative bone lengths from the shape parameters
30 | X = data_loader.new_cal_ref_bone(shape)
31 | err = align_bone_len(X.cpu().numpy(), label.cpu().numpy())
32 | return err.sum()
33 |
34 | checkpoint = 'checkpoints'
35 |
36 | model = shape_net.ShapeNet()
37 | shape_net.load_checkpoint(
38 | model, os.path.join(checkpoint, 'ckp_siknet_synth_41.pth.tar')
39 | )
40 | for params in model.parameters():
41 | params.requires_grad = False
42 |
43 | data_set = ['rhd', 'stb', 'do', 'eo']
44 | temp_data = create_data.DataSet(_mano_root='mano/models')
45 | for data in data_set:
46 | print('*' * 20)
47 |     print('loading dataset: ' + data)
48 | print('*' * 20)
49 |     # load predicted joints
50 | pre_path = os.path.join('out_testset/', data + '_pre_joints.npy')
51 | temp = np.load(pre_path)
52 | temp = torch.Tensor(temp)
53 | _x = temp_data.cal_ref_bone(temp)
54 |     # regress MANO shape with the model
55 | Y = model(_x)
56 | Y = Y['beta']
57 | np.save('out_testset/' + data + '_dl.npy', Y.clone().detach().cpu().numpy())
58 | dl_err = fun(Y, _x, temp_data)
59 |     print('regression error: {}'.format(dl_err))
60 |
61 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: minimal-hand-torch
2 | channels:
3 | - pytorch
4 | - defaults
5 | dependencies:
6 | - _libgcc_mutex=0.1=main
7 | - blas=1.0=mkl
8 | - bzip2=1.0.8=h7b6447c_0
9 | - ca-certificates=2021.1.19=h06a4308_0
10 | - cairo=1.14.12=h8948797_3
11 | - certifi=2020.12.5=py37h06a4308_0
12 | - cffi=1.14.5=py37h261ae71_0
13 | - cudatoolkit=10.0.130=0
14 | - cycler=0.10.0=py37_0
15 | - dbus=1.13.18=hb2f20db_0
16 | - expat=2.2.10=he6710b0_2
17 | - ffmpeg=4.0=hcdf2ecd_0
18 | - fontconfig=2.13.1=h6c09931_0
19 | - freeglut=3.0.0=hf484d3e_5
20 | - freetype=2.10.4=h5ab3b9f_0
21 | - glib=2.67.4=h36276a3_1
22 | - graphite2=1.3.14=h23475e2_0
23 | - gst-plugins-base=1.14.0=h8213a91_2
24 | - gstreamer=1.14.0=h28cd5cc_2
25 | - harfbuzz=1.8.8=hffaf4a1_0
26 | - hdf5=1.10.2=hba1933b_1
27 | - icu=58.2=he6710b0_3
28 | - intel-openmp=2020.2=254
29 | - jasper=2.0.14=h07fcdf6_1
30 | - jpeg=9b=h024ee3a_2
31 | - kiwisolver=1.3.1=py37h2531618_0
32 | - lcms2=2.11=h396b838_0
33 | - ld_impl_linux-64=2.33.1=h53a641e_7
34 | - libedit=3.1.20191231=h14c3975_1
35 | - libffi=3.3=he6710b0_2
36 | - libgcc-ng=9.1.0=hdf63c60_0
37 | - libgfortran-ng=7.3.0=hdf63c60_0
38 | - libglu=9.0.0=hf484d3e_1
39 | - libopencv=3.4.2=hb342d67_1
40 | - libopus=1.3.1=h7b6447c_0
41 | - libpng=1.6.37=hbc83047_0
42 | - libstdcxx-ng=9.1.0=hdf63c60_0
43 | - libtiff=4.1.0=h2733197_1
44 | - libuuid=1.0.3=h1bed415_2
45 | - libuv=1.40.0=h7b6447c_0
46 | - libvpx=1.7.0=h439df22_0
47 | - libxcb=1.14=h7b6447c_0
48 | - libxml2=2.9.10=hb55368b_3
49 | - lz4-c=1.9.3=h2531618_0
50 | - matplotlib=3.3.4=py37h06a4308_0
51 | - matplotlib-base=3.3.4=py37h62a2d02_0
52 | - mkl=2020.2=256
53 | - mkl-service=2.3.0=py37he8ac12f_0
54 | - mkl_fft=1.3.0=py37h54f3939_0
55 | - mkl_random=1.1.1=py37h0573a6f_0
56 | - ncurses=6.2=he6710b0_1
57 | - ninja=1.10.2=py37hff7bd54_0
58 | - numpy=1.19.2=py37h54aff64_0
59 | - numpy-base=1.19.2=py37hfa32c7d_0
60 | - olefile=0.46=py_0
61 | - opencv=3.4.2=py37h6fd60c2_1
62 | - openssl=1.1.1j=h27cfd23_0
63 | - pandas=1.2.2=py37ha9443f7_0
64 | - pcre=8.44=he6710b0_0
65 | - pillow=8.1.0=py37he98fc37_0
66 | - pip=21.0.1=py37h06a4308_0
67 | - pixman=0.40.0=h7b6447c_0
68 | - py-opencv=3.4.2=py37hb342d67_1
69 | - pycparser=2.20=py_2
70 | - pyparsing=2.4.7=pyhd3eb1b0_0
71 | - pyqt=5.9.2=py37h05f1152_2
72 | - python=3.7.10=hdb3f193_0
73 | - python-dateutil=2.8.1=pyhd3eb1b0_0
74 | - pytz=2021.1=pyhd3eb1b0_0
75 | - qt=5.9.7=h5867ecd_1
76 | - readline=8.1=h27cfd23_0
77 | - scipy=1.6.1=py37h91f5cce_0
78 | - setuptools=52.0.0=py37h06a4308_0
79 | - sip=4.19.8=py37hf484d3e_0
80 | - six=1.15.0=pyhd3eb1b0_0
81 | - sqlite=3.33.0=h62c20be_0
82 | - termcolor=1.1.0=py37h06a4308_1
83 | - tk=8.6.10=hbc83047_0
84 | - tornado=6.1=py37h27cfd23_0
85 | - tqdm=4.56.0=pyhd3eb1b0_0
86 | - typing_extensions=3.7.4.3=pyha847dfd_0
87 | - wheel=0.36.2=pyhd3eb1b0_0
88 | - xz=5.2.5=h7b6447c_0
89 | - zlib=1.2.11=h7b6447c_3
90 | - zstd=1.4.5=h9ceee32_0
91 | - pytorch=1.2.0=py3.7_cuda10.0.130_cudnn7.6.2_0
92 | - torchvision=0.4.0=py37_cu100
93 | - pip:
94 | - art==3.7
95 | - chumpy==0.70
96 | - coverage==4.5.3
97 | - einops==0.3.0
98 | - manopth==0.0.1
99 | - mxnet==1.6.0
100 | - progress==1.5
101 | - torch==1.2.0
102 | - transforms3d==0.3.1
103 | - git+https://github.com/hassony2/manopth.git
104 |
105 |
--------------------------------------------------------------------------------
/losses/__init__.py:
--------------------------------------------------------------------------------
1 | from .detloss import *
2 |
--------------------------------------------------------------------------------
/losses/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/handopen/Minimal-Hand-pytorch/c078433136b20c1e4fb3f353254edd409ea65f1b/losses/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/losses/__pycache__/detloss.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/handopen/Minimal-Hand-pytorch/c078433136b20c1e4fb3f353254edd409ea65f1b/losses/__pycache__/detloss.cpython-37.pyc
--------------------------------------------------------------------------------
/losses/detloss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as torch_f
3 |
4 |
5 | class DetLoss:
6 | def __init__(
7 | self,
8 | lambda_hm=100,
9 | lambda_dm=1.0,
10 | lambda_lm=1.0,
11 |
12 | ):
13 | self.lambda_hm = lambda_hm
14 | self.lambda_dm = lambda_dm
15 | self.lambda_lm = lambda_lm
16 |
17 | def compute_loss(self, preds, targs, infos):
18 |
19 | hm_veil = infos['hm_veil']
20 | batch_size = infos['batch_size']
21 | flag = targs['flag_3d']
22 | batch_3d_size = flag.sum()
23 |
24 | flag = flag.bool()
25 |
26 | final_loss = torch.Tensor([0]).cuda()
27 | det_losses = {}
28 |
29 | pred_hm = preds['h_map']
30 | pred_dm = preds['d_map'][flag]
31 | pred_lm = preds['l_map'][flag]
32 |
33 | targ_hm = targs['hm'] # B*21*32*32
34 |
35 | targ_hm_tile = \
36 | targ_hm.unsqueeze(2).expand(targ_hm.size(0), targ_hm.size(1), 3, targ_hm.size(2), targ_hm.size(3),
37 | )[flag] # B'*21*3*32*32
38 | targ_dm = targs['dm'][flag]
39 | targ_lm = targs['lm'][flag]
40 |
41 | # compute hmloss anyway
42 | hm_loss = torch.Tensor([0]).cuda()
43 | if self.lambda_hm:
44 | hm_veil = hm_veil.unsqueeze(-1)
45 | njoints = pred_hm.size(1)
46 | pred_hm = pred_hm.reshape((batch_size, njoints, -1)).split(1, 1)
47 | targ_hm = targ_hm.reshape((batch_size, njoints, -1)).split(1, 1)
48 | for idx in range(njoints):
49 | pred_hmapi = pred_hm[idx].squeeze() # (B, 1, 1024)->(B, 1024)
50 | targ_hmi = targ_hm[idx].squeeze()
51 | hm_loss += 0.5 * torch_f.mse_loss(
52 | pred_hmapi.mul(hm_veil[:, idx]), # (B, 1024) mul (B, 1)
53 | targ_hmi.mul(hm_veil[:, idx])
54 |             )  # per-joint MSE over the flattened maps (effectively minibatch_loss / (32*32))
55 | final_loss += self.lambda_hm * hm_loss
56 | det_losses["det_hm"] = hm_loss
57 |
58 | # compute dm loss
59 | loss_dm = torch.Tensor([0]).cuda()
60 | if self.lambda_dm:
61 | loss_dm = torch.norm(
62 | (pred_dm - targ_dm) * targ_hm_tile) / batch_3d_size # loss of every sample
63 | final_loss += self.lambda_dm * loss_dm
64 | det_losses["det_dm"] = loss_dm
65 |
66 | # compute lm loss
67 | loss_lm = torch.Tensor([0]).cuda()
68 | if self.lambda_lm:
69 | loss_lm = torch.norm(
70 | (pred_lm - targ_lm) * targ_hm_tile) / batch_3d_size # loss of every sample
71 | final_loss += self.lambda_lm * loss_lm
72 | det_losses["det_lm"] = loss_lm
73 |
74 | det_losses["det_total"] = final_loss
75 |
76 | return final_loss, det_losses, batch_3d_size
77 |
--------------------------------------------------------------------------------
/losses/shape_loss.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch import nn
4 | import torch.nn.functional as torch_f
5 |
6 |
7 |
8 | class SIKLoss:
9 | def __init__(
10 | self,
11 | lambda_joint=1.0,
12 | lambda_shape=1.0
13 | ):
14 | self.lambda_joint = lambda_joint
15 | self.lambda_shape = lambda_shape
16 |
17 | def compute_loss(self, preds, targs):
18 | batch_size = targs['batch_size']
19 | final_loss = torch.Tensor([0]).cuda()
20 | invk_losses = {}
21 |
22 | if self.lambda_joint:
23 | joint_loss = torch_f.mse_loss(
24 | 1000 * preds['jointRS'] * targs['joint_bone'].unsqueeze(1),
25 | 1000 * targs['jointRS'] * targs['joint_bone'].unsqueeze(1)
26 | )
27 | final_loss += self.lambda_joint * joint_loss
28 | else:
29 | joint_loss = None
30 | invk_losses["joint"] = joint_loss
31 |
32 | if self.lambda_shape:
33 | # shape_reg_loss = 10.0 * torch_f.mse_loss(
34 | # preds["beta"],
35 | # torch.zeros_like(preds["beta"])
36 | # )
37 | shape_reg_loss = torch.norm(preds['beta'], dim=-1, keepdim=True)
38 | shape_reg_loss = torch.pow(shape_reg_loss, 2.0)
39 | shape_reg_loss = torch.mean(shape_reg_loss)
40 |
41 | pred_rel_len = preds['bone_len_hat']
42 |
43 | # kin_len_loss = torch_f.mse_loss(
44 | # pred_rel_len.reshape(batch_size, -1),
45 | # targs['rel_bone_len'].reshape(batch_size, -1)
46 | # )
47 | kin_len_loss = torch.norm(pred_rel_len -
48 | targs['rel_bone_len'].reshape(batch_size, -1),
49 | dim=-1, keepdim=True)
50 | kin_len_loss = torch.pow(kin_len_loss, 2.0)
51 | kin_len_loss = torch.mean(kin_len_loss)
52 | shape_total_loss = kin_len_loss + 1e-3 * shape_reg_loss
53 | final_loss += self.lambda_shape * shape_total_loss
54 | else:
55 | shape_reg_loss, kin_len_loss = None, None
56 | invk_losses['shape_reg'] = shape_reg_loss
57 | invk_losses['bone_len'] = kin_len_loss
58 |
59 | return final_loss, invk_losses
60 |
--------------------------------------------------------------------------------
/manopth/rotproj.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def batch_rotprojs(batches_rotmats):
5 | proj_rotmats = []
6 | for batch_idx, batch_rotmats in enumerate(batches_rotmats):
7 | proj_batch_rotmats = []
8 | for rot_idx, rotmat in enumerate(batch_rotmats):
9 | # GPU implementation of svd is VERY slow
10 | # ~ 2 10^-3 per hit vs 5 10^-5 on cpu
11 | _device = rotmat.device
12 | U, S, V = rotmat.cpu().svd()
13 | rotmat = torch.matmul(U, V.transpose(0, 1))
14 | orth_det = rotmat.det()
15 | # Remove reflection
16 | if orth_det < 0:
17 | rotmat[:, 2] = -1 * rotmat[:, 2]
18 |
19 | rotmat = rotmat.to(_device)
20 | proj_batch_rotmats.append(rotmat)
21 | proj_rotmats.append(torch.stack(proj_batch_rotmats))
22 | return torch.stack(proj_rotmats)
23 |
--------------------------------------------------------------------------------
/model/detnet/__init__.py:
--------------------------------------------------------------------------------
1 | from .detnet import detnet
2 | __all__ = ['detnet']
--------------------------------------------------------------------------------
/model/detnet/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/handopen/Minimal-Hand-pytorch/c078433136b20c1e4fb3f353254edd409ea65f1b/model/detnet/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/model/detnet/__pycache__/detnet.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/handopen/Minimal-Hand-pytorch/c078433136b20c1e4fb3f353254edd409ea65f1b/model/detnet/__pycache__/detnet.cpython-37.pyc
--------------------------------------------------------------------------------
/model/detnet/detnet.py:
--------------------------------------------------------------------------------
1 | '''
2 | detnet based on PyTorch
3 | this is modified from https://github.com/lingtengqiu/Minimal-Hand
4 | '''
5 | import sys
6 |
7 | import torch
8 |
9 | sys.path.append("./")
10 | from torch import nn
11 | from einops import rearrange, repeat
12 | from model.helper import resnet50, conv3x3
13 | import numpy as np
14 |
15 |
16 | # my modification
17 | def get_pose_tile_torch(N):
18 | pos_tile = np.expand_dims(
19 | np.stack(
20 | [
21 | np.tile(np.linspace(-1, 1, 32).reshape([1, 32]), [32, 1]),
22 | np.tile(np.linspace(-1, 1, 32).reshape([32, 1]), [1, 32])
23 | ], -1
24 | ), 0
25 | )
26 | pos_tile = np.tile(pos_tile, (N, 1, 1, 1))
27 | retv = torch.from_numpy(pos_tile).float()
28 | return rearrange(retv, 'b h w c -> b c h w')
29 |
30 |
31 | class net_2d(nn.Module):
32 | def __init__(self, input_features, output_features, stride, joints=21):
33 | super().__init__()
34 | self.project = nn.Sequential(conv3x3(input_features, output_features, stride), nn.BatchNorm2d(output_features),
35 | nn.ReLU())
36 |
37 | self.prediction = nn.Conv2d(output_features, joints, 1, 1, 0)
38 |
39 | def forward(self, x):
40 | x = self.project(x)
41 | x = self.prediction(x).sigmoid()
42 | return x
43 |
44 |
45 | class net_3d(nn.Module):
46 | def __init__(self, input_features, output_features, stride, joints=21, need_norm=False):
47 | super().__init__()
48 | self.need_norm = need_norm
49 | self.project = nn.Sequential(conv3x3(input_features, output_features, stride), nn.BatchNorm2d(output_features),
50 | nn.ReLU())
51 | self.prediction = nn.Conv2d(output_features, joints * 3, 1, 1, 0)
52 |
53 | def forward(self, x):
54 | x = self.prediction(self.project(x))
55 |
56 | dmap = rearrange(x, 'b (j l) h w -> b j l h w', l=3)
57 |
58 | return dmap
59 |
60 |
61 | class detnet(nn.Module):
62 | def __init__(self, stacks=1):
63 | super().__init__()
64 | self.resnet50 = resnet50()
65 |
66 | self.hmap_0 = net_2d(258, 256, 1)
67 | self.dmap_0 = net_3d(279, 256, 1)
68 | self.lmap_0 = net_3d(342, 256, 1)
69 | self.stacks = stacks
70 |
71 | def forward(self, x):
72 | features = self.resnet50(x)
73 |
74 | device = x.device
75 | pos_tile = get_pose_tile_torch(features.shape[0]).to(device)
76 |
77 | x = torch.cat([features, pos_tile], dim=1)
78 |
79 | hmaps = []
80 | dmaps = []
81 | lmaps = []
82 |
83 | for _ in range(self.stacks):
84 | heat_map = self.hmap_0(x)
85 | hmaps.append(heat_map)
86 | x = torch.cat([x, heat_map], dim=1)
87 |
88 | dmap = self.dmap_0(x)
89 | dmaps.append(dmap)
90 |
91 | x = torch.cat([x, rearrange(dmap, 'b j l h w -> b (j l) h w')], dim=1)
92 |
93 | lmap = self.lmap_0(x)
94 | lmaps.append(lmap)
95 | hmap, dmap, lmap = hmaps[-1], dmaps[-1], lmaps[-1]
96 |
97 | uv, argmax = self.map_to_uv(hmap)
98 |
99 | delta = self.dmap_to_delta(dmap, argmax)
100 | xyz = self.lmap_to_xyz(lmap, argmax)
101 |
102 | det_result = {
103 | "h_map": hmap,
104 | "d_map": dmap,
105 | "l_map": lmap,
106 | "delta": delta,
107 | "xyz": xyz,
108 | "uv": uv
109 | }
110 |
111 | return det_result
112 |
113 |     @property
114 |     def pos(self):  # unused: self.__pos_tile is never assigned, so accessing this would raise AttributeError
115 |         return self.__pos_tile
116 |
117 | @staticmethod
118 | def map_to_uv(hmap):
119 | b, j, h, w = hmap.shape
120 | hmap = rearrange(hmap, 'b j h w -> b j (h w)')
121 | argmax = torch.argmax(hmap, -1, keepdim=True)
122 | u = argmax // w
123 | v = argmax % w
124 | uv = torch.cat([u, v], dim=-1)
125 |
126 | return uv, argmax
127 |
128 | @staticmethod
129 | def dmap_to_delta(dmap, argmax):
130 | return detnet.lmap_to_xyz(dmap, argmax)
131 |
132 | @staticmethod
133 | def lmap_to_xyz(lmap, argmax):
134 | lmap = rearrange(lmap, 'b j l h w -> b j (h w) l')
135 | index = repeat(argmax, 'b j i -> b j i c', c=3)
136 | xyz = torch.gather(lmap, dim=2, index=index).squeeze(2)
137 | return xyz
138 |
139 |
140 | if __name__ == '__main__':
141 | mydet = detnet()
142 | img_crop = torch.randn(10, 3, 128, 128)
143 | res = mydet(img_crop)
144 |
145 | hmap = res["h_map"]
146 | dmap = res["d_map"]
147 | lmap = res["l_map"]
148 | delta = res["delta"]
149 | xyz = res["xyz"]
150 | uv = res["uv"]
151 |
152 | print("hmap.shape=", hmap.shape)
153 | print("dmap.shape=", dmap.shape)
154 | print("lmap.shape=", lmap.shape)
155 | print("delta.shape=", delta.shape)
156 | print("xyz.shape=", xyz.shape)
157 | print("uv.shape=", uv.shape)
158 |
--------------------------------------------------------------------------------
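sketch: recovering uv and xyz from detnet's maps (illustrative, not a file in this repository):
--------------------------------------------------------------------------------
map_to_uv flattens each per-joint heat map and takes its argmax; lmap_to_xyz then gathers the 3-vector stored at that same location in the location map. A minimal, self-contained sketch with random tensors (batch size, joint count, and map size below are arbitrary):

import torch
from einops import rearrange, repeat

b, j, h, w = 2, 21, 32, 32                 # arbitrary batch, joints, map size
hmap = torch.rand(b, j, h, w)
lmap = torch.rand(b, j, 3, h, w)

# per-joint argmax over the flattened heat map, as in detnet.map_to_uv
flat = rearrange(hmap, 'b j h w -> b j (h w)')
argmax = torch.argmax(flat, -1, keepdim=True)          # (b, j, 1)
uv = torch.cat([argmax // w, argmax % w], dim=-1)      # (b, j, 2): row, column in the repo's convention

# gather the 3-vector stored at that location, as in detnet.lmap_to_xyz
lmap_flat = rearrange(lmap, 'b j l h w -> b j (h w) l')
index = repeat(argmax, 'b j i -> b j i c', c=3)        # broadcast the index over xyz
xyz = torch.gather(lmap_flat, dim=2, index=index).squeeze(2)
print(uv.shape, xyz.shape)                             # (2, 21, 2) and (2, 21, 3)
--------------------------------------------------------------------------------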
/model/helper/__init__.py:
--------------------------------------------------------------------------------
1 | from .resnet_helper import *
2 | __all__ = ['resnet50', 'conv3x3', 'conv1x1']
--------------------------------------------------------------------------------
/model/helper/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/handopen/Minimal-Hand-pytorch/c078433136b20c1e4fb3f353254edd409ea65f1b/model/helper/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/model/helper/__pycache__/resnet_helper.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/handopen/Minimal-Hand-pytorch/c078433136b20c1e4fb3f353254edd409ea65f1b/model/helper/__pycache__/resnet_helper.cpython-37.pyc
--------------------------------------------------------------------------------
/model/shape_net.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Lixin YANG. All Rights Reserved.
2 | import os
3 | import shutil
4 |
5 | import numpy as np
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as torch_f
9 | from manopth.manolayer import ManoLayer
10 |
11 |
12 | class ShapeNet(nn.Module):
13 | def __init__(
14 | self,
15 | dropout=0,
16 | _mano_root='mano/models'
17 | ):
18 | super(ShapeNet, self).__init__()
19 |
20 | ''' shape '''
21 | hidden_neurons = [128, 256, 512, 256, 128]
22 | in_neurons = 15
23 | out_neurons = 10
24 | neurons = [in_neurons] + hidden_neurons
25 |
26 | shapereg_layers = []
27 | for layer_idx, (inps, outs) in enumerate(
28 | zip(neurons[:-1], neurons[1:])
29 | ):
30 | if dropout:
31 | shapereg_layers.append(nn.Dropout(p=dropout))
32 | shapereg_layers.append(nn.Linear(inps, outs))
33 | shapereg_layers.append(nn.ReLU())
34 |
35 | shapereg_layers.append(nn.Linear(neurons[-1], out_neurons))
36 | self.shapereg_layers = nn.Sequential(*shapereg_layers)
37 | args = {'flat_hand_mean': True, 'root_rot_mode': 'axisang',
38 | 'ncomps': 45, 'mano_root': _mano_root,
39 | 'no_pca': True, 'joint_rot_mode': 'axisang', 'side': 'right'}
40 | self.mano_layer = ManoLayer(flat_hand_mean=args['flat_hand_mean'],
41 | side=args['side'],
42 | mano_root=args['mano_root'],
43 | ncomps=args['ncomps'],
44 | use_pca=not args['no_pca'],
45 | root_rot_mode=args['root_rot_mode'],
46 | joint_rot_mode=args['joint_rot_mode']
47 | )
48 |
49 | def new_cal_ref_bone(self, _shape):
50 | parent_index = [0,
51 | 0, 1, 2,
52 | 0, 4, 5,
53 | 0, 7, 8,
54 | 0, 10, 11,
55 | 0, 13, 14
56 | ]
57 | index = [0,  # unused; kept to document the MANO joint order
58 | 1, 2, 3, # index
59 | 4, 5, 6, # middle
60 | 7, 8, 9, # pinky
61 | 10, 11, 12, # ring
62 | 13, 14, 15] # thumb
63 | reoder_index = [
64 | 13, 14, 15,
65 | 1, 2, 3,
66 | 4, 5, 6,
67 | 10, 11, 12,
68 | 7, 8, 9]
69 | shape = _shape
70 | th_v_shaped = torch.matmul(self.mano_layer.th_shapedirs,
71 | shape.transpose(1, 0)).permute(2, 0, 1) \
72 | + self.mano_layer.th_v_template
73 | th_j = torch.matmul(self.mano_layer.th_J_regressor, th_v_shaped)
74 | temp1 = th_j
75 | temp2 = th_j[:, parent_index, :]
76 | result = temp1 - temp2
77 | ref_len = th_j[:, [4], :] - th_j[:, [0], :]
78 | ref_len = torch.norm(ref_len, dim=-1, keepdim=True)
79 | result = torch.norm(result, dim=-1, keepdim=True)
80 | result = result / ref_len
81 | return torch.squeeze(result, dim=-1)[:, reoder_index]
82 |
83 | def forward(self, bone_len):
84 | beta = self.shapereg_layers(bone_len)
85 | beta = torch.tanh(beta)
86 | bone_len_hat = self.new_cal_ref_bone(beta)
87 |
88 | results = {
89 | 'beta': beta,
90 | 'bone_len_hat': bone_len_hat
91 | }
92 | return results
93 |
94 | def save_checkpoint(
95 | state,
96 | checkpoint='checkpoint',
97 | filename='checkpoint.pth.tar',
98 | snapshot=None,
99 | is_best=False
100 | ):
101 | # preds = to_numpy(preds)
102 | filepath = os.path.join(checkpoint, filename)
103 | fileprefix = filename.split('.')[0]
104 | torch.save(state, filepath)
105 |
106 | if snapshot and state['epoch'] % snapshot == 0:
107 | shutil.copyfile(
108 | filepath,
109 | os.path.join(
110 | checkpoint,
111 | '{}_{}.pth.tar'.format(fileprefix, state['epoch'])
112 | )
113 | )
114 |
115 | if is_best:
116 | shutil.copyfile(
117 | filepath,
118 | os.path.join(
119 | checkpoint,
120 | '{}_best.pth.tar'.format(fileprefix)
121 | )
122 | )
123 |
124 | def load_checkpoint(model, checkpoint):
125 | name = checkpoint
126 | checkpoint = torch.load(name)
127 | pretrain_dict = clean_state_dict(checkpoint['state_dict'])
128 | model_state = model.state_dict()
129 | state = {}
130 | for k, v in pretrain_dict.items():
131 | if k in model_state:
132 | state[k] = v
133 | else:
134 | print(k, ' is NOT in current model')
135 | model_state.update(state)
136 | model.load_state_dict(model_state)
137 |
138 | def clean_state_dict(state_dict):
139 | """save a cleaned version of model without dict and DataParallel
140 |
141 | Arguments:
142 | state_dict {collections.OrderedDict} -- [description]
143 |
144 | Returns:
145 | clean_model {collections.OrderedDict} -- [description]
146 | """
147 |
148 | from collections import OrderedDict
149 | 
150 | # build a new OrderedDict whose keys do not carry the `module.` prefix
151 | clean_model = OrderedDict()
152 | if any(key.startswith('module') for key in state_dict):
153 | for k, v in state_dict.items():
154 | name = k[7:] # remove `module.`
155 | clean_model[name] = v
156 | else:
157 | return state_dict
158 |
159 | return clean_model
160 |
161 | if __name__ == '__main__':
162 | input = torch.rand((10, 15))
163 | model = ShapeNet()
164 | out_put = model(input)
165 | loss = torch.mean(out_put['bone_len_hat'])  # the model returns a dict, so reduce one of its tensors
166 | loss.backward()
167 |
--------------------------------------------------------------------------------
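sketch: bone lengths from the MANO shape blend (illustrative, not a file in this repository):
--------------------------------------------------------------------------------
new_cal_ref_bone needs only the shape blend of MANO: joints are regressed from the beta-deformed template, and bone lengths are child-minus-parent norms, normalized by a reference bone. The sketch below mirrors that pipeline with random stand-ins for the ManoLayer buffers; the buffer shapes follow MANO (778 vertices, 10 shape dims, 16 joints), but the values are random, so only the mechanics carry over:

import torch

# random stand-ins for the ManoLayer buffers (shapes follow MANO, values do not)
V, S, J = 778, 10, 16
th_shapedirs = torch.randn(V, 3, S)
th_v_template = torch.randn(1, V, 3)
th_J_regressor = torch.rand(J, V)

beta = torch.zeros(2, S)  # batch of two shape vectors

# shape blend only: vertices depend on beta, pose never enters bone lengths
v_shaped = torch.matmul(th_shapedirs, beta.transpose(1, 0)).permute(2, 0, 1) + th_v_template
j3d = torch.matmul(th_J_regressor, v_shaped)            # (2, 16, 3) joints

parent = [0, 0, 1, 2, 0, 4, 5, 0, 7, 8, 0, 10, 11, 0, 13, 14]
bone_len = torch.norm(j3d - j3d[:, parent, :], dim=-1)  # child minus parent, per bone
rel = bone_len / bone_len[:, [4]]                       # normalize by the reference bone
print(rel.shape)                                        # torch.Size([2, 16])
--------------------------------------------------------------------------------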
/op_pso.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | import numpy as np
4 | import torch
5 | from manopth.manolayer import ManoLayer
6 | from tqdm import tqdm
7 |
8 | from optimize_shape import align_bone_len
9 | from utils import bone
10 | import matplotlib.pyplot as plt
11 |
12 | from utils.LM_new import LM_Solver
13 |
14 |
15 | class PSO:
16 | def __init__(self, parameters, target, _mano_root='mano/models'):
17 | """
18 | particle swarm optimization
19 | parameters: a list [NGEN, pop_size, lower_bound, upper_bound]; the bounds are (1, var_num) arrays
20 | """
21 | self.mano_layer = ManoLayer(side="right",
22 | mano_root=_mano_root, use_pca=False, flat_hand_mean=True)
23 |
24 | # initialization
25 | self.NGEN = parameters[0]
26 | self.pop_size = parameters[1]
27 | self.var_num = parameters[2].shape[1]
28 | self.bound = []
29 | self.bound.append(parameters[2])
30 | self.bound.append(parameters[3])
31 | self.set_target(target)
32 |
33 | def set_target(self, target):
34 | self.target = target.copy()
35 | self.pop_x = np.random.randn(self.pop_size, self.var_num)
36 | self.pop_v = np.random.random((self.pop_size, self.var_num))
37 | self.p_best = self.pop_x.copy()
38 | self.p_best_fit = self.batch_new_get_loss(self.pop_x)
39 | g_best_index = np.argmin(self.p_best_fit, axis=0)
40 | if g_best_index.shape[0] > 1:
41 | g_best_index = g_best_index[[0]]
42 | self.g_best = self.p_best[g_best_index].copy()
43 |
44 | def new_cal_ref_bone(self, _shape):
45 | parent_index = [0,
46 | 0, 1, 2,
47 | 0, 4, 5,
48 | 0, 7, 8,
49 | 0, 10, 11,
50 | 0, 13, 14
51 | ]
52 | index = [0,
53 | 1, 2, 3, # index
54 | 4, 5, 6, # middle
55 | 7, 8, 9, # pinky
56 | 10, 11, 12, # ring
57 | 13, 14, 15] # thumb
58 | reoder_index = [
59 | 13, 14, 15,
60 | 1, 2, 3,
61 | 4, 5, 6,
62 | 10, 11, 12,
63 | 7, 8, 9]
64 | shape = torch.Tensor(_shape.reshape((-1, 10)))
65 | th_v_shaped = torch.matmul(self.mano_layer.th_shapedirs,
66 | shape.transpose(1, 0)).permute(2, 0, 1) \
67 | + self.mano_layer.th_v_template
68 | th_j = torch.matmul(self.mano_layer.th_J_regressor, th_v_shaped)
69 | temp1 = th_j.clone().detach()
70 | temp2 = th_j.clone().detach()[:, parent_index, :]
71 | result = temp1 - temp2
72 | result = torch.norm(result, dim=-1, keepdim=True)
73 | ref_len = result[:, [4]]
74 | result = result / ref_len
75 | return torch.squeeze(result, dim=-1)[:, reoder_index].cpu().numpy()
76 |
77 | def batch_new_get_loss(self, beta_):
78 | weight = 1e-3
79 | beta = beta_.copy()
80 | temp = self.new_cal_ref_bone(beta)
81 | loss = np.linalg.norm(temp - self.target, axis=-1, keepdims=True) ** 2 + \
82 | weight * np.linalg.norm(beta, axis=-1, keepdims=True)
83 | return loss
84 |
85 | def update_operator(self, pop_size):
86 |
87 | c1 = 2
88 | c2 = 2
89 | w = 0.4
90 |
91 | self.pop_v = w * self.pop_v \
92 | + c1 * np.multiply(np.random.rand(pop_size, 1), (self.p_best - self.pop_x)) \
93 | + c2 * np.multiply(np.random.rand(pop_size, 1), (self.g_best - self.pop_x))
94 | self.pop_x = self.pop_x + self.pop_v
95 | low_flag = self.pop_x < self.bound[0]
96 | up_flag = self.pop_x > self.bound[1]
97 | self.pop_x[low_flag] = -3.0
98 | self.pop_x[up_flag] = 3.0
99 | temp = self.batch_new_get_loss(self.pop_x)
100 | p_best_flag = temp < self.p_best_fit
101 | p_best_flag = p_best_flag.reshape((pop_size,))
102 | self.p_best[p_best_flag] = self.pop_x[p_best_flag]
103 | self.p_best_fit[p_best_flag] = temp[p_best_flag]
104 | g_best_index = np.argmin(self.p_best_fit, axis=0)
105 | if g_best_index.shape[0] > 1:
106 | g_best_index = g_best_index[[0]]
107 | self.g_best = self.pop_x[g_best_index]
108 | self.g_best_fit = self.p_best_fit[g_best_index][0][0]
109 |
110 | def main(self, solver=None, return_err=False):
111 | best_fit = []
112 | self.ng_best = np.zeros((1, self.var_num))
113 | self.ng_best_fit = self.batch_new_get_loss(self.ng_best)[0][0]
114 | for gen in range(self.NGEN):
115 | self.update_operator(self.pop_size)
116 | # print('############ Generation {} ############'.format(str(gen + 1)))
117 | dis = self.g_best_fit - self.ng_best_fit
118 | if self.g_best_fit < self.ng_best_fit:
119 | self.ng_best = self.g_best.copy()
120 | self.ng_best_fit = self.g_best_fit
121 | if abs(dis) < 1e-6:
122 | break
123 | # print(':{}'.format(self.ng_best))
124 | # print(':{}'.format(self.ng_best_fit))
125 | # best_fit.append(self.ng_best_fit)
126 | # print("---- End of (successful) Searching ----")
127 | #
128 | # plt.figure()
129 | # plt.title("Figure1")
130 | # plt.xlabel("iterators", size=14)
131 | # plt.ylabel("fitness", size=14)
132 | # t = [t for t in range(self.NGEN)]
133 | # plt.plot(t, best_fit, color='b', linewidth=2)
134 | # plt.show()
135 | if return_err:
136 | err = solver.new_cal_ref_bone(self.ng_best)
137 | err = align_bone_len(err, self.target)
138 | return err
139 |
140 |
141 | if __name__ == '__main__':
142 | import time
143 |
144 | data_set = ['rhd', 'stb', 'do', 'eo']
145 | for data in data_set:
146 | solver = LM_Solver(num_Iter=500, th_beta=torch.zeros((1, 10)), th_pose=torch.zeros((1, 48)),
147 | lb_target=np.zeros((15, 1)),
148 | weight=1e-5)
149 |
150 | NGEN = 100
151 | popsize = 100
152 | low = np.zeros((1, 10)) - 3.0
153 | up = np.zeros((1, 10)) + 3.0
154 | parameters = [NGEN, popsize, low, up]
155 | err = np.zeros((1, 15))
156 | path = 'out_testset/' + data + '_pre_joints.npy'
157 | print('load:{}'.format(path))
158 | target = np.load(path)
159 | pso_shape = np.zeros((target.shape[0], 10))
160 | for i in tqdm(range(target.shape[0])):
161 | _target = target[[i]]  # the i-th prediction
162 | _target = bone.caculate_length(_target, label='useful')
163 | _target = _target.reshape((1, 15))
164 | pso = PSO(parameters, _target)
165 | err += pso.main(solver=solver, return_err=True)
166 | pso_shape[[i]] = pso.ng_best
167 | print(err.sum() / target.shape[0])
168 | save_path = 'out_testset/' + data + '_pso.npy'
169 | print('save:{}'.format(save_path))
170 | np.save(save_path, pso_shape)
171 |
--------------------------------------------------------------------------------
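sketch: the PSO update rule in isolation (illustrative, not a file in this repository):
--------------------------------------------------------------------------------
update_operator above is the canonical particle-swarm update, v <- w*v + c1*r1*(p_best - x) + c2*r2*(g_best - x), followed by a box clamp. The same rule on a toy quadratic fitness, reusing the coefficients and the [-3, 3] box from op_pso.py:

import numpy as np

rng = np.random.default_rng(0)
pop, dim = 50, 10
x = rng.standard_normal((pop, dim))                    # positions, as in set_target
v = rng.random((pop, dim))                             # velocities

def fitness(x):                                        # toy stand-in for the bone-length loss
    return (x ** 2).sum(axis=1)

p_best, p_fit = x.copy(), fitness(x)
g_best = p_best[np.argmin(p_fit)].copy()

w, c1, c2 = 0.4, 2.0, 2.0                              # the coefficients used in update_operator
for _ in range(100):
    r1, r2 = rng.random((pop, 1)), rng.random((pop, 1))
    v = w * v + c1 * r1 * (p_best - x) + c2 * r2 * (g_best - x)
    x = np.clip(x + v, -3.0, 3.0)                      # same [-3, 3] box as the shape search
    fit = fitness(x)
    better = fit < p_fit
    p_best[better], p_fit[better] = x[better], fit[better]
    g_best = p_best[np.argmin(p_fit)].copy()
print(fitness(g_best[None])[0])                        # close to 0 for this toy fitness
--------------------------------------------------------------------------------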
/optimize_shape.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import numpy as np
4 | from tqdm import tqdm
5 |
6 | from utils import func, bone
7 | from utils.LM import LM_Solver
8 |
9 |
10 | def align_bone_len(opt_, pre_):
11 | opt = opt_.copy()
12 | pre = pre_.copy()
13 |
14 | opt_align = opt.copy()
15 | for i in range(opt.shape[0]):
16 | ratio = pre[i][6] / opt[i][6]
17 | opt_align[i] = ratio * opt_align[i]
18 |
19 | err = np.abs(opt_align - pre).mean(0)
20 |
21 | return err
22 |
23 |
24 | def main(args):
25 | path=args.path
26 | for dataset in args.dataset:
27 | # load predictions (N*21*3)
28 | print("load {}'s joint 3D".format(dataset))
29 | pred_j3d = np.load("{}/{}_pre_joints.npy".format(path, dataset),allow_pickle=True)
30 |
31 | opt_shapes = []
32 | opt_bone_lens = []
33 | pre_useful_bone_lens = []
34 |
35 | # loop
36 | for pred in tqdm(pred_j3d):
37 | # 0 initialization
38 | pose, shape = func.initiate("zero")
39 |
40 | pre_useful_bone_len = bone.caculate_length(pred, label="useful")
41 | pre_useful_bone_lens.append(pre_useful_bone_len)
42 |
43 | # optimize here!
44 | solver = LM_Solver(num_Iter=500, th_beta=shape, th_pose=pose, lb_target=pre_useful_bone_len,
45 | weight=args.weight)
46 | opt_shape = solver.LM()
47 | opt_shapes.append(opt_shape)
48 |
49 | opt_bone_len = solver.get_bones(opt_shape)
50 | opt_bone_lens.append(opt_bone_len)
51 |
52 | # plt.plot(solver.get_result(), 'r')
53 | # plt.show()
54 |
55 | # break
56 |
57 | opt_shapes = np.array(opt_shapes).reshape(-1, 10)
58 | opt_bone_lens = np.array(opt_bone_lens).reshape(-1, 15)
59 | pre_useful_bone_lens = np.array(pre_useful_bone_lens).reshape(-1, 15)
60 |
61 | np.save("{}/{}_shapes.npy".format(path, dataset), opt_shapes)
62 |
63 | error = align_bone_len(opt_bone_lens, pre_useful_bone_lens)
64 |
65 | print("dataset:{} weight:{} ERR sum: {}".format(dataset, args.weight, error.sum()))
66 |
67 |
68 | if __name__ == '__main__':
69 | parser = argparse.ArgumentParser(
70 | description='optimize shape parameters of the MANO model')
71 |
72 | # Dataset setting
73 | parser.add_argument(
74 | '-ds',
75 | "--dataset",
76 | nargs="+",
77 | default=['rhd', 'stb', 'do', 'eo'],
78 | type=str,
79 | help="sub datasets, should be listed in: [stb|rhd|do|eo]"
80 | )
81 | parser.add_argument(
82 | '-wt', '--weight',
83 | default=1e-5,
84 | type=float,
85 | metavar='weight',
86 | help='weight of L2 regularizer '
87 | )
88 |
89 | parser.add_argument(
90 | '-p',
91 | '--path',
92 | default='out_testset',
93 | type=str,
94 | metavar='data_root',
95 | help='directory')
96 |
97 | main(parser.parse_args())
98 |
--------------------------------------------------------------------------------
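sketch: align_bone_len is scale-invariant (illustrative, not a file in this repository):
--------------------------------------------------------------------------------
align_bone_len rescales each row of opt_ so that its bone 6 (the reference bone) matches pre_ before taking the mean absolute difference, so a prediction that differs only by a global scale yields near-zero error. A quick sanity check:

import numpy as np
from optimize_shape import align_bone_len

pre = np.abs(np.random.randn(4, 15)) + 0.5   # 4 hands, 15 relative bone lengths
opt = 1.7 * pre                              # same proportions, different global scale
err = align_bone_len(opt, pre)               # rescale via bone 6, then mean abs diff
print(err.sum())                             # ~0: the global scale is factored out
--------------------------------------------------------------------------------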
/plot.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import matplotlib.pyplot as plt
4 | import numpy as np
5 |
6 |
7 | def main(args):
8 | path=args.out_path
9 |
10 | # losses in train
11 | lossD = np.load(os.path.join(path, "lossD.npy"))
12 | lossH = np.load(os.path.join(path, "lossH.npy"))
13 | lossL = np.load(os.path.join(path, "lossL.npy"))
14 |
15 | auc_all = np.load(os.path.join(path, "auc_all.npy"), allow_pickle=True).item()
16 | acc_hm_all = np.load(os.path.join(path, "acc_hm_all.npy"), allow_pickle=True).item()
17 |
18 | # rhd
19 | auc_all_rhd = np.array(auc_all['rhd'])
20 | acc_hm_rhd = np.array(acc_hm_all["rhd"])
21 |
22 | # stb
23 | auc_all_stb = np.array(auc_all['stb'])
24 | acc_hm_stb = np.array(acc_hm_all["stb"])
25 |
26 | # do
27 | auc_all_do = np.array(auc_all['do'])
28 |
29 | # eo
30 | auc_all_eo = np.array(auc_all['eo'])
31 |
32 | plt.figure(figsize=[50, 50])
33 |
34 | plt.subplot(2, 4, 1)
35 | plt.plot(lossH[:, :1], lossH[:, 1:], marker='o', label='lossH')
36 | plt.plot(lossD[:, :1], lossD[:, 1:], marker='*', label='lossD')
37 | plt.plot(lossL[:, :1], lossL[:, 1:], marker='h', label='lossL')
38 | plt.title("LOSSES")
39 | plt.legend(title='Losses Category:')
40 |
41 | # rhd
42 | plt.subplot(2, 4, 2)
43 | plt.plot(auc_all_rhd[:, :1], auc_all_rhd[:, 1:], marker='d')
44 | plt.title(
45 | "{}_test || (EPOCH={} , AUC={:0.4f})".format("RHD", np.argmax(auc_all_rhd[:, 1:]) + 1,
46 | np.max(auc_all_rhd[:, 1:])))
47 |
48 | plt.subplot(2, 4, 3)
49 | plt.plot(acc_hm_rhd[:, :1], acc_hm_rhd[:, 1:], marker='d')
50 | plt.title(
51 | "{}_test || (EPOCH={} , ACC_HM={:0.4f})".format("RHD", np.argmax(acc_hm_rhd[:, 1:]) + 1,
52 | np.max(acc_hm_rhd[:, 1:])))
53 |
54 | # stb
55 | plt.subplot(2, 4, 4)
56 | plt.plot(auc_all_stb[:, :1], auc_all_stb[:, 1:], marker='d')
57 | plt.title(
58 | "{}_test || (EPOCH={} , AUC={:0.4f})".format("STB", np.argmax(auc_all_stb[:, 1:]) + 1,
59 | np.max(auc_all_stb[:, 1:])))
60 |
61 | plt.subplot(2, 4, 5)
62 | plt.plot(acc_hm_stb[:, :1], acc_hm_stb[:, 1:], marker='d')
63 | plt.title(
64 | "{}_test || (EPOCH={} , ACC_HM={:0.4f})".format("STB", np.argmax(acc_hm_stb[:, 1:]) + 1,
65 | np.max(acc_hm_stb[:, 1:])))
66 |
67 | # do
68 | plt.subplot(2, 4, 6)
69 | plt.plot(auc_all_do[:, :1], auc_all_do[:, 1:], marker='d')
70 | plt.title(
71 | "{}_test || (EPOCH={} , AUC={:0.4f})".format("DO", np.argmax(auc_all_do[:, 1:] + 1), np.max(auc_all_do[:, 1:])))
72 |
73 | # eo
74 | plt.subplot(2, 4, 7)
75 | plt.plot(auc_all_eo[:, :1], auc_all_eo[:, 1:], marker='d')
76 | plt.title(
77 | "{}_test || (EPOCH={} , AUC={:0.4f})".format("EO", np.argmax(auc_all_eo[:, 1:]) + 1, np.max(auc_all_eo[:, 1:])))
78 |
79 | # plt.savefig("vis_train.png")
80 | plt.show()
81 |
82 |
83 | if __name__ == '__main__':
84 | parser = argparse.ArgumentParser(
85 | description='Result')
86 | parser.add_argument(
87 | '-p',
88 | '--out_path',
89 | type=str,
90 | default="out_loss_auc",
91 | help='output path'
92 | )
93 | main(parser.parse_args())
94 |
95 |
--------------------------------------------------------------------------------
/train_shape_net.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import time
4 |
5 | import torch
6 | import torch.backends.cudnn as cudnn
7 | import torch.nn.parallel
8 | import torch.optim
9 | import torch.utils.data
10 | from progress.bar import Bar
11 | from tensorboardX.writer import SummaryWriter
12 | from termcolor import cprint
13 |
14 | from model import shape_net
15 | from datasets import SIK1M
16 | from losses import shape_loss
17 | # select proper device to run
18 | from utils import misc
19 | from utils.eval.evalutils import AverageMeter
20 | import numpy as np
21 |
22 | writer = SummaryWriter('log')
23 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24 | cudnn.benchmark = True
25 | steps = 0
26 |
27 |
28 |
29 | def print_args(args):
30 | opts = vars(args)
31 | cprint("{:>30} Options {}".format("=" * 15, "=" * 15), 'yellow')
32 | for k, v in sorted(opts.items()):
33 | print("{:>30} : {}".format(k, v))
34 | cprint("{:>30} Options {}".format("=" * 15, "=" * 15), 'yellow')
35 |
36 |
37 | def main(args):
38 | if not os.path.isdir(args.checkpoint):
39 | os.makedirs(args.checkpoint)
40 | print_args(args)
41 | print("\nCREATE NETWORK")
42 | model = shape_net.ShapeNet(_mano_root='mano/models')
43 | model = model.to(device)
44 |
45 | criterion = shape_loss.SIKLoss(
46 | lambda_joint=0.0,
47 | lambda_shape=1.0
48 | )
49 |
50 | optimizer = torch.optim.Adam(
51 | [
52 | {
53 | 'params': model.shapereg_layers.parameters(),
54 | 'initial_lr': args.learning_rate
55 | },
56 |
57 | ],
58 | lr=args.learning_rate,
59 | )
60 |
61 | train_dataset = SIK1M.SIK1M(
62 | data_root=args.data_root,
63 | data_split="train"
64 | )
65 |
66 | val_dataset = SIK1M.SIK1M(
67 | data_root=args.data_root,
68 | data_split="test"
69 | )
70 |
71 | print("Total train dataset size: {}".format(len(train_dataset)))
72 | print("Total val dataset size: {}".format(len(val_dataset)))
73 |
74 | train_loader = torch.utils.data.DataLoader(
75 | train_dataset,
76 | batch_size=args.train_batch,
77 | shuffle=True,
78 | num_workers=args.workers,
79 | pin_memory=True
80 | )
81 | val_loader = torch.utils.data.DataLoader(
82 | val_dataset,
83 | batch_size=args.test_batch,
84 | shuffle=False,
85 | num_workers=args.workers,
86 | pin_memory=True
87 | )
88 |
89 | if args.evaluate or args.resume:
90 | shape_net.load_checkpoint(
91 | model, os.path.join(args.checkpoint, 'ckp_siknet_synth.pth.tar')
92 | )
93 | if args.evaluate:
94 | for params in model.shapereg_layers.parameters():
95 | params.requires_grad = False
96 |
97 | if args.evaluate:
98 | validate(val_loader, model, criterion, args=args)
99 | cprint('Eval All Done', 'yellow', attrs=['bold'])
100 | return 0
101 |
102 | model = torch.nn.DataParallel(model)
103 | print("\nUSING {} GPUs".format(torch.cuda.device_count()))
104 | scheduler = torch.optim.lr_scheduler.StepLR(
105 | optimizer, args.lr_decay_step, gamma=args.gamma,
106 | last_epoch=args.start_epoch
107 | )
108 | train_bone_len = []
109 | train_shape_l2 = []
110 | test_bone_len = []
111 | test_shape_l2 = []
112 | for epoch in range(args.start_epoch, args.epochs + 1):
113 | print('\nEpoch: %d' % (epoch + 1))
114 | for i in range(len(optimizer.param_groups)):
115 | print('group %d lr:' % i, optimizer.param_groups[i]['lr'])
116 | ############# train for one epoch ###############
117 | t1, t2 = train(
118 | train_loader,
119 | model,
120 | criterion,
121 | optimizer,
122 | args=args,
123 | )
124 | ##################################################
125 | shape_net.save_checkpoint(
126 | {
127 | 'epoch': epoch + 1,
128 | 'state_dict': model.module.state_dict(),
129 | },
130 | checkpoint=args.checkpoint,
131 | filename='{}.pth.tar'.format(args.saved_prefix),
132 | snapshot=args.snapshot,
133 | is_best=False
134 | )
135 | t3, t4 = validate(val_loader, model, criterion, args)
136 | train_bone_len.append(t1)
137 | train_shape_l2.append(t2)
138 | test_bone_len.append(t3)
139 | test_shape_l2.append(t4)
140 | np.save('log/train_bone_len.npy', train_bone_len)
141 | np.save('log/test_bone_len.npy', test_bone_len)
142 | np.save('log/train_shape_l2.npy', train_shape_l2)
143 | np.save('log/test_shape_l2.npy', test_shape_l2)
144 | scheduler.step()
145 | cprint('All Done', 'yellow', attrs=['bold'])
146 | return 0 # end of main
147 |
148 |
149 | def validate(val_loader, model, criterion, args):
150 | am_shape_l2 = AverageMeter()
151 | am_bone_len = AverageMeter()
152 | model.eval()
153 | bar = Bar('\033[33m Eval \033[0m', max=len(val_loader))
154 | with torch.no_grad():
155 | for i, metas in enumerate(val_loader):
156 | results, targets, total_loss, losses = one_forward_pass(
157 | metas, model, criterion, args, train=True  # train=True so the losses are computed for reporting
158 | )
159 | am_shape_l2.update(losses['shape_reg'].item(), targets['batch_size'])
160 | am_bone_len.update(losses['bone_len'].item(), targets['batch_size'])
161 | bar.suffix = (
162 | '({batch}/{size}) '
163 | 't: {total:}s | '
164 | 'eta:{eta:}s | '
165 | 'lN: {lossLen:.5f} | '
166 | 'lL2: {lossL2:.5f} | '
167 | ).format(
168 | batch=i + 1,
169 | size=len(val_loader),
170 | total=bar.elapsed_td,
171 | eta=bar.eta_td,
172 | lossLen=am_bone_len.avg,
173 | lossL2=am_shape_l2.avg,
174 | )
175 | bar.next()
176 | bar.finish()
177 | return (am_bone_len.avg, am_shape_l2.avg)
178 |
179 |
180 | def train(train_loader, model, criterion, optimizer, args):
181 | batch_time = AverageMeter()
182 | data_time = AverageMeter()
183 | am_shape_l2 = AverageMeter()
184 | am_bone_len = AverageMeter()
185 |
186 | last = time.time()
187 | # switch to train mode
188 | model.train()
189 | bar = Bar('\033[31m Train \033[0m', max=len(train_loader))
190 | for i, metas in enumerate(train_loader):
191 | data_time.update(time.time() - last)
192 | results, targets, total_loss, losses = one_forward_pass(
193 | metas, model, criterion, args, train=True
194 | )
195 | global steps
196 | steps += 1
197 | writer.add_scalar('loss', total_loss.item(), steps)
198 | am_shape_l2.update(losses['shape_reg'].item(), targets['batch_size'])
199 | am_bone_len.update(losses['bone_len'].item(), targets['batch_size'])
200 | ''' backward and step '''
201 | optimizer.zero_grad()
202 | total_loss.backward()
203 | optimizer.step()
204 |
205 | ''' progress '''
206 | batch_time.update(time.time() - last)
207 | last = time.time()
208 | bar.suffix = (
209 | '({batch}/{size}) '
210 | 'd: {data:.2f}s | '
211 | 'b: {bt:.2f}s | '
212 | 't: {total:}s | '
213 | 'eta:{eta:}s | '
214 | 'lN: {lossLen:.5f} | '
215 | 'lL2: {lossL2:.5f} | '
216 | ).format(
217 | batch=i + 1,
218 | size=len(train_loader),
219 | data=data_time.avg,
220 | bt=batch_time.avg,
221 | total=bar.elapsed_td,
222 | eta=bar.eta_td,
223 | lossLen=am_bone_len.avg,
224 | lossL2=am_shape_l2.avg,
225 | )
226 | bar.next()
227 | bar.finish()
228 | return (am_bone_len.avg, am_shape_l2.avg)
229 |
230 |
231 | def one_forward_pass(metas, model, criterion, args, train=True):
232 | ''' prepare targets '''
233 | rel_bone_len = metas['rel_bone_len'].float().to(device, non_blocking=True)
234 | targets = {
235 | 'batch_size': rel_bone_len.shape[0],
236 | 'rel_bone_len': rel_bone_len
237 | }
238 | ''' ---------------- Forward Pass ---------------- '''
239 | results = model(rel_bone_len)
240 | ''' ---------------- Forward End ---------------- '''
241 |
242 | total_loss = torch.Tensor([0]).to(device)  # works on CPU-only machines as well
243 | losses = {}
244 | if not train:
245 | return results, targets, total_loss, losses
246 |
247 | ''' compute losses '''
248 | total_loss, losses = criterion.compute_loss(results, targets)
249 | return results, targets, total_loss, losses
250 |
251 |
252 | if __name__ == '__main__':
253 | parser = argparse.ArgumentParser(
254 | description='PyTorch Train dl shape')
255 | # Miscs
256 | parser.add_argument(
257 | '-ckp',
258 | '--checkpoint',
259 | default='checkpoints',
260 | type=str,
261 | metavar='PATH',
262 | help='path to save checkpoints (default: checkpoints)'
263 | )
264 |
265 | parser.add_argument(
266 | '-dr',
267 | '--data_root',
268 | type=str,
269 | default='data',
270 | help='dataset root directory'
271 | )
272 |
273 | parser.add_argument(
274 | '-sp',
275 | '--saved_prefix',
276 | default='ckp_siknet_synth',
277 | type=str,
278 | metavar='PATH',
279 | help='filename prefix for saved checkpoints (default: ckp_siknet_synth)'
280 | )
281 | parser.add_argument(
282 | '--snapshot',
283 | default=1, type=int,
284 | help='save models for every #snapshot epochs (default: 1)'
285 | )
286 |
287 | parser.add_argument(
288 | '-e', '--evaluate',
289 | dest='evaluate',
290 | action='store_true',
291 | help='evaluate model on validation set'
292 | )
293 |
294 | parser.add_argument(
295 | '-r', '--resume',
296 | dest='resume',
297 | action='store_true',
298 | help='resume training from the saved checkpoint'
299 | )
300 |
301 | # Training Parameters
302 | parser.add_argument(
303 | '-j', '--workers',
304 | default=8,
305 | type=int,
306 | metavar='N',
307 | help='number of data loading workers (default: 8)'
308 | )
309 | parser.add_argument(
310 | '--epochs',
311 | default=150,
312 | type=int,
313 | metavar='N',
314 | help='number of total epochs to run'
315 | )
316 | parser.add_argument(
317 | '-se', '--start_epoch',
318 | default=0,
319 | type=int,
320 | metavar='N',
321 | help='manual epoch number (useful on restarts)'
322 | )
323 | parser.add_argument(
324 | '-b', '--train_batch',
325 | default=1024,
326 | type=int,
327 | metavar='N',
328 | help='train batchsize'
329 | )
330 | parser.add_argument(
331 | '-tb', '--test_batch',
332 | default=512,
333 | type=int,
334 | metavar='N',
335 | help='test batchsize'
336 | )
337 |
338 | parser.add_argument(
339 | '-lr', '--learning-rate',
340 | default=1.0e-4,
341 | type=float,
342 | metavar='LR',
343 | help='initial learning rate'
344 | )
345 | parser.add_argument(
346 | "--lr_decay_step",
347 | default=40,
348 | type=int,
349 | help="Epochs after which to decay learning rate",
350 | )
351 | parser.add_argument(
352 | '--gamma',
353 | type=float,
354 | default=0.1,
355 | help='LR is multiplied by gamma on schedule.'
356 | )
357 | main(parser.parse_args())
358 |
--------------------------------------------------------------------------------
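note: why the Adam param group sets 'initial_lr' (illustrative, not a file in this repository):
--------------------------------------------------------------------------------
The scheduler above is built with last_epoch=args.start_epoch (0 by default), i.e. not -1, and PyTorch schedulers refuse to "resume" a param group that lacks an 'initial_lr' entry; that is why the Adam group in main() carries one explicitly. A minimal reproduction:

import torch

p = torch.nn.Parameter(torch.zeros(1))

# without 'initial_lr', constructing a scheduler with last_epoch != -1 raises a KeyError
try:
    opt = torch.optim.Adam([{'params': [p]}], lr=1e-4)
    torch.optim.lr_scheduler.StepLR(opt, step_size=40, gamma=0.1, last_epoch=0)
except KeyError as err:
    print(err)  # 'initial_lr' must be specified when resuming (last_epoch != -1)

# mirroring train_shape_net.py: set 'initial_lr' explicitly and construction succeeds
opt = torch.optim.Adam([{'params': [p], 'initial_lr': 1e-4}], lr=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=40, gamma=0.1, last_epoch=0)
--------------------------------------------------------------------------------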
/utils/AIK.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Hao Meng. All Rights Reserved.
2 | import numpy as np
3 | import transforms3d
4 |
5 | import config as cfg
6 |
7 | angels0 = np.zeros((1, 21))
8 |
9 |
10 | def to_dict(joints):
11 | temp_dict = dict()
12 | for i in range(21):
13 | temp_dict[i] = joints[:, [i]]
14 | return temp_dict
15 |
16 |
17 | def adaptive_IK(T_, P_):
18 | '''
19 | Computes pose parameters given a template and predictions.
20 | We assume the twist of the hand bones can be omitted.
21 | 
22 | :param T_: template, 21*3
23 | :param P_: target, 21*3
24 | :return: pose params.
25 | '''
26 |
27 | T = T_.copy().astype(np.float64)
28 | P = P_.copy().astype(np.float64)
29 |
30 | P = P.transpose(1, 0)
31 | T = T.transpose(1, 0)
32 |
33 | # to dict
34 | P = to_dict(P)
35 | T = to_dict(T)
36 |
37 | # some globals
38 | R = {}
39 | R_pa_k = {}
40 | q = {}
41 |
42 | q[0] = T[0] # in fact, q[0] = P[0] = T[0].
43 |
44 | # compute R0; here we require R0 to be not merely an orthogonal matrix but a proper rotation matrix.
45 | # see the paper "Least-Squares Fitting of Two 3-D Point Sets", K. S. Arun, T. S. Huang, S. D. Blostein.
46 | # This differs slightly from https://github.com/Jeff-sjtu/HybrIK/blob/main/hybrik/utils/pose_utils.py#L4, where R0 is treated as an orthogonal matrix only.
47 | # Using their method might further boost accuracy.
48 | P_0 = np.concatenate([P[1] - P[0], P[5] - P[0],
49 | P[9] - P[0], P[13] - P[0],
50 | P[17] - P[0]], axis=-1)
51 | T_0 = np.concatenate([T[1] - T[0], T[5] - T[0],
52 | T[9] - T[0], T[13] - T[0],
53 | T[17] - T[0]], axis=-1)
54 | H = np.matmul(T_0, P_0.T)
55 |
56 | U, S, V_T = np.linalg.svd(H)
57 | V = V_T.T
58 | R0 = np.matmul(V, U.T)
59 |
60 | det0 = np.linalg.det(R0)
61 |
62 | if abs(det0 + 1) < 1e-6:
63 | V_ = V.copy()
64 |
65 | if (abs(S) < 1e-4).sum():
66 | V_[:, 2] = -V_[:, 2]
67 | R0 = np.matmul(V_, U.T)
68 |
69 | R[0] = R0
70 |
71 | # the bones from joints 1, 5, 9, 13, 17 to the wrist (joint 0) share the same rotation
72 | R[1] = R[0].copy()
73 | R[5] = R[0].copy()
74 | R[9] = R[0].copy()
75 | R[13] = R[0].copy()
76 | R[17] = R[0].copy()
77 |
78 | # compute rotation along kinematics
79 | for k in cfg.kinematic_tree:
80 | pa = cfg.SNAP_PARENT[k]
81 | pa_pa = cfg.SNAP_PARENT[pa]
82 | q[pa] = np.matmul(R[pa], (T[pa] - T[pa_pa])) + q[pa_pa]
83 | delta_p_k = np.matmul(np.linalg.inv(R[pa]), P[k] - q[pa])
84 | delta_p_k = delta_p_k.reshape((3,))
85 | delta_t_k = T[k] - T[pa]
86 | delta_t_k = delta_t_k.reshape((3,))
87 | temp_axis = np.cross(delta_t_k, delta_p_k)
88 | axis = temp_axis / (np.linalg.norm(temp_axis, axis=-1) + 1e-8)
89 | temp = (np.linalg.norm(delta_t_k, axis=0) + 1e-8) * (np.linalg.norm(delta_p_k, axis=0) + 1e-8)
90 | cos_alpha = np.dot(delta_t_k, delta_p_k) / temp
91 |
92 | alpha = np.arccos(cos_alpha)
93 |
94 | twist = delta_t_k
95 | D_sw = transforms3d.axangles.axangle2mat(axis=axis, angle=alpha, is_normalized=False)
96 | D_tw = transforms3d.axangles.axangle2mat(axis=twist, angle=angels0[:, k], is_normalized=False)
97 | R_pa_k[k] = np.matmul(D_sw, D_tw)
98 | R[k] = np.matmul(R[pa], R_pa_k[k])
99 |
100 | pose_R = np.zeros((1, 16, 3, 3))
101 | pose_R[0, 0] = R[0]
102 | for key in cfg.ID2ROT.keys():
103 | value = cfg.ID2ROT[key]
104 | pose_R[0, value] = R_pa_k[key]
105 |
106 | return pose_R
107 |
--------------------------------------------------------------------------------
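sketch: the SVD rotation fit behind R0 (illustrative, not a file in this repository):
--------------------------------------------------------------------------------
The root rotation R0 is the least-squares rotation between the centred template and prediction directions (Arun et al., 1987): build H = T0 P0^T, take its SVD, and form R = V U^T, flipping a column of V if a reflection (det = -1) sneaks in. A standalone sketch with a known ground-truth rotation; the reflection guard below is the textbook one, while adaptive_IK's variant additionally checks for a near-zero singular value:

import numpy as np

rng = np.random.default_rng(0)
T0 = rng.standard_normal((3, 5))                        # centred template directions (3 x k)
ang = np.pi / 5
R_true = np.array([[np.cos(ang), -np.sin(ang), 0.0],
                   [np.sin(ang),  np.cos(ang), 0.0],
                   [0.0,          0.0,         1.0]])
P0 = R_true @ T0                                        # rotated copy of the template

H = T0 @ P0.T                                           # same construction as adaptive_IK
U, S, Vt = np.linalg.svd(H)
R = Vt.T @ U.T
if np.linalg.det(R) < 0:                                # reflection guard
    V = Vt.T.copy()
    V[:, 2] = -V[:, 2]
    R = V @ U.T
print(np.allclose(R, R_true))                           # True: the rotation is recovered
--------------------------------------------------------------------------------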
/utils/LM.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Hao Meng. All Rights Reserved.
2 | # import time
3 |
4 | import numpy as np
5 | import torch
6 | from manopth.manolayer import ManoLayer
7 |
8 | from utils import bone
9 |
10 |
11 | class LM_Solver():
12 | def __init__(self, num_Iter=500, th_beta=None, th_pose=None, lb_target=None,
13 | weight=0.01):
14 | self.count = 0
15 | # self.time_start = time.time()
16 | # self.time_in_mano = 0
17 | self.minimal_loss = 9999
18 | self.best_beta = np.zeros([10, 1])
19 | self.num_Iter = num_Iter
20 |
21 | self.th_beta = th_beta
22 | self.th_pose = th_pose
23 |
24 | self.beta = th_beta.numpy()
25 | self.pose = th_pose.numpy()
26 |
27 | self.mano_layer = ManoLayer(side="right",
28 | mano_root='mano/models', use_pca=False, flat_hand_mean=True)
29 |
30 | self.threshold_stop = 10 ** -13
31 | self.weight = weight
32 | self.residual_memory = []
33 |
34 | self.lb = np.zeros(21)
35 |
36 | _, self.joints = self.mano_layer(self.th_pose, self.th_beta)
37 | self.joints = self.joints.numpy().reshape(21, 3)
38 |
39 | self.lb_target = lb_target.reshape(15, 1)
40 | # self.test_time = 0
41 |
42 | def update(self, beta_):
43 | beta = beta_.copy()
44 | self.count += 1
45 | # now = time.time()
46 | my_th_beta = torch.from_numpy(beta).float().reshape(1, 10)
47 | _, joints = self.mano_layer(self.th_pose, my_th_beta)
48 | # self.time_in_mano = time.time() - now
49 |
50 | useful_lb = bone.caculate_length(joints, label="useful")
51 | lb_ref = useful_lb[6]
52 | return useful_lb, lb_ref
53 |
54 | def new_cal_ref_bone(self, _shape):
55 | # now = time.time()
56 | parent_index = [0,
57 | 0, 1, 2,
58 | 0, 4, 5,
59 | 0, 7, 8,
60 | 0, 10, 11,
61 | 0, 13, 14
62 | ]
63 | # index = [0,
64 | # 1, 2, 3, # index
65 | # 4, 5, 6, # middle
66 | # 7, 8, 9, # pinky
67 | # 10, 11, 12, # ring
68 | # 13, 14, 15] # thumb
69 | reoder_index = [
70 | 13, 14, 15,
71 | 1, 2, 3,
72 | 4, 5, 6,
73 | 10, 11, 12,
74 | 7, 8, 9]
75 | shape = torch.Tensor(_shape.reshape((-1, 10)))
76 | th_v_shaped = torch.matmul(self.mano_layer.th_shapedirs,
77 | shape.transpose(1, 0)).permute(2, 0, 1) \
78 | + self.mano_layer.th_v_template
79 | th_j = torch.matmul(self.mano_layer.th_J_regressor, th_v_shaped)
80 | temp1 = th_j.clone().detach()
81 | temp2 = th_j.clone().detach()[:, parent_index, :]
82 | result = temp1 - temp2
83 | result = torch.norm(result, dim=-1, keepdim=True)
84 | ref_len = result[:, [4]]
85 | result = result / ref_len
86 | # self.time_in_mano = time.time() - now
87 | return torch.squeeze(result, dim=-1)[:, reoder_index].cpu().numpy()
88 |
89 | def get_residual(self, beta_):
90 | beta = beta_.copy()
91 | lb, lb_ref = self.update(beta)
92 | lb = lb.reshape(45, 1)
93 | return lb / lb_ref - self.lb_target
94 |
95 | def get_count(self):
96 | return self.count
97 |
98 | def get_bones(self, beta_):
99 | beta = beta_.copy()
100 | lb, _ = self.update(beta)
101 | lb = lb.reshape(15, 1)
102 |
103 | return lb
104 |
105 | # Vectorization implementation
106 | def batch_get_l2_loss(self, beta_):
107 | weight = 1e-5
108 | beta = beta_.copy()
109 | temp = self.new_cal_ref_bone(beta)
110 | loss = np.transpose(temp)
111 | loss = np.linalg.norm(loss - self.lb_target, axis=0) ** 2 + \
112 | weight * np.linalg.norm(beta, axis=-1)
113 | return loss
114 |
115 | def new_get_derivative(self, beta_):
116 | # params: beta_ 10*1
117 | # return: 1*10
118 | beta = beta_.copy().reshape((1, 10))
119 | temp_shape = np.zeros((20, beta.shape[1])) # 20*10
120 | step = 0.01
121 | for t2 in range(10):  # shape parameter index
122 | t3 = 10 + t2
123 | temp_shape[t2] = beta.copy()
124 | temp_shape[t3] = beta.copy()
125 | temp_shape[t2, t2] += step
126 | temp_shape[t3, t2] -= step
127 |
128 | res = self.batch_get_l2_loss(temp_shape)
129 | d = res[0:10] - res[10:20] # 10*1
130 | d = d.reshape((1, 10)) / (2 * step)
131 | return d
132 |
133 | # LM algorithm
134 | def LM(self):
135 | u = 1e-2
136 | v = 1.5
137 | beta = self.beta.reshape(10, 1)
138 |
139 | out_n = 1
140 | # num_beta = np.shape(beta)[0] # the number of beta
141 | # calculate the initial Jacobian matrix
142 | Jacobian = np.zeros([out_n, beta.shape[0]])
143 |
144 | last_update = 0
145 | last_loss = 0
146 | # self.test_time = 0
147 | for i in range(self.num_Iter):
148 | # loss = self.new_get_loss(beta)
149 | loss = self.batch_get_l2_loss(beta)
150 | loss = loss[0]
151 | if loss < self.minimal_loss:
152 | self.minimal_loss = loss
153 | self.best_beta = beta
154 |
155 | if abs(loss - last_loss) < self.threshold_stop:
156 | # self.time_total = time.time() - self.time_start
157 | return beta
158 |
159 | # for k in range(num_beta):
160 | # Jacobian[:, k] = self.get_derivative(beta, k)
161 | Jacobian = self.new_get_derivative(beta)
162 | jtj = np.matmul(Jacobian.T, Jacobian)
163 | jtj = jtj + u * np.eye(jtj.shape[0])
164 |
165 | update = last_loss - loss
166 | delta = (np.matmul(np.linalg.inv(jtj), Jacobian.T) * loss)
167 |
168 | beta -= delta
169 |
170 | if update > last_update and update > 0:
171 | u /= v
172 | else:
173 | u *= v
174 |
175 | last_update = update
176 | last_loss = loss
177 | self.residual_memory.append(loss)
178 |
179 | return beta
180 |
181 | def get_result(self):
182 | return self.residual_memory
183 |
--------------------------------------------------------------------------------
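sketch: the scalar-loss LM loop on a toy problem (illustrative, not a file in this repository):
--------------------------------------------------------------------------------
LM() treats the scalar loss itself as the residual: the "Jacobian" is the numeric gradient of the loss, and the step inv(J^T J + u*I) J^T * loss is in effect a damped, loss-scaled gradient step rather than textbook Levenberg-Marquardt. The same loop on a toy least-squares problem (M is orthonormalized so the behaviour is easy to follow; all names and sizes here are arbitrary):

import numpy as np

rng = np.random.default_rng(0)
M, _ = np.linalg.qr(rng.standard_normal((15, 2)))       # orthonormal columns, for clarity
b_true = np.array([0.7, -1.2])
y = M @ b_true

def loss(b):                                            # scalar loss, as in new_get_loss
    return float(np.sum((M @ b - y) ** 2))

def num_grad(b, step=1e-2):                             # central differences, as in new_get_derivative
    g = np.zeros_like(b)
    for k in range(b.size):
        e = np.zeros_like(b); e[k] = step
        g[k] = (loss(b + e) - loss(b - e)) / (2 * step)
    return g

b, u, v = np.zeros(2), 1e-2, 1.5
last_loss = last_update = 0.0
for _ in range(500):
    f = loss(b)
    if abs(f - last_loss) < 1e-13:                      # same stopping rule as LM()
        break
    J = num_grad(b).reshape(1, 2)
    jtj = J.T @ J + u * np.eye(2)                       # damped normal matrix
    b = b - (np.linalg.inv(jtj) @ J.T).ravel() * f      # loss-scaled, damped step
    update = last_loss - f
    u = u / v if (update > last_update and update > 0) else u * v
    last_update, last_loss = update, f
print(b, loss(b))                                       # b is driven toward b_true as the residual shrinks
--------------------------------------------------------------------------------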
/utils/LM_new.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Hao Meng. All Rights Reserved.
2 | import time
3 |
4 | import numpy as np
5 | import torch
6 | from manopth.manolayer import ManoLayer
7 |
8 | from utils import bone
9 |
10 |
11 | class LM_Solver():
12 | def __init__(self, num_Iter=500, th_beta=None, th_pose=None, lb_target=None,
13 | weight=0.01,_mano_root='mano/models'):
14 | self.count = 0
15 | self.time_start = time.time()
16 | self.time_in_mano = 0
17 | self.minimal_loss = 9999
18 | self.best_beta = np.zeros([10, 1])
19 | self.num_Iter = num_Iter
20 |
21 | self.th_beta = th_beta
22 | self.th_pose = th_pose
23 |
24 | self.beta = th_beta.numpy()
25 | self.pose = th_pose.numpy()
26 |
27 | self.mano_layer = ManoLayer(side="right",
28 | mano_root='mano/models', use_pca=False, flat_hand_mean=True)
29 |
30 | self.threshold_stop = 10 ** -13
31 | self.weight = weight
32 | self.residual_memory = []
33 |
34 | self.lb = np.zeros(21)
35 |
36 | _, self.joints = self.mano_layer(self.th_pose, self.th_beta)
37 | self.joints = self.joints.numpy().reshape(21, 3)
38 |
39 | self.lb_target = lb_target.reshape(15, 1)
40 | self.test_time = 0
41 |
42 | def update_target(self, target):
43 | self.lb_target = target.copy().reshape(15, 1)
44 |
45 | def update(self, beta_):
46 | beta = beta_.copy()
47 | self.count += 1
48 | now = time.time()
49 | my_th_beta = torch.from_numpy(beta).float().reshape(1, 10)
50 | _, joints = self.mano_layer(self.th_pose, my_th_beta)
51 | self.time_in_mano = time.time() - now
52 |
53 | useful_lb = bone.caculate_length(joints, label="useful")
54 | lb_ref = useful_lb[6]
55 | return useful_lb, lb_ref
56 |
57 | def new_cal_ref_bone(self, _shape):
58 | now = time.time()
59 | parent_index = [0,
60 | 0, 1, 2,
61 | 0, 4, 5,
62 | 0, 7, 8,
63 | 0, 10, 11,
64 | 0, 13, 14
65 | ]
66 | index = [0,
67 | 1, 2, 3, # index
68 | 4, 5, 6, # middle
69 | 7, 8, 9, # pinky
70 | 10, 11, 12, # ring
71 | 13, 14, 15] # thumb
72 | reoder_index = [
73 | 13, 14, 15,
74 | 1, 2, 3,
75 | 4, 5, 6,
76 | 10, 11, 12,
77 | 7, 8, 9]
78 | shape = torch.Tensor(_shape.reshape((-1, 10)))
79 | th_v_shaped = torch.matmul(self.mano_layer.th_shapedirs,
80 | shape.transpose(1, 0)).permute(2, 0, 1) \
81 | + self.mano_layer.th_v_template
82 | th_j = torch.matmul(self.mano_layer.th_J_regressor, th_v_shaped)
83 | temp1 = th_j.clone().detach()
84 | temp2 = th_j.clone().detach()[:, parent_index, :]
85 | result = temp1 - temp2
86 | result = torch.norm(result, dim=-1, keepdim=True)
87 | ref_len = result[:, [4]]
88 | result = result / ref_len
89 | self.time_in_mano = time.time() - now
90 | return torch.squeeze(result, dim=-1)[:, reoder_index].cpu().numpy()
91 |
92 | def get_residual(self, beta_):
93 | beta = beta_.copy()
94 | lb, lb_ref = self.update(beta)
95 | lb = lb.reshape(45, 1)
96 | return lb / lb_ref - self.lb_target
97 |
98 | def get_count(self):
99 | return self.count
100 |
101 | def get_bones(self, beta_):
102 | beta = beta_.copy()
103 | lb, _ = self.update(beta)
104 | lb = lb.reshape(15, 1)
105 | return lb
106 |
107 | def get_loss(self, beta_):
108 |
109 | beta = beta_.copy()
110 |
111 | lb, lb_ref = self.update(beta)
112 | lb = lb.reshape(15, 1)
113 |
114 | loss = np.linalg.norm(lb / lb_ref - self.lb_target) ** 2 + \
115 | self.weight * np.linalg.norm(beta) ** 2
116 |
117 | return loss
118 |
119 | def new_get_loss(self, beta_):
120 | beta = beta_.copy()
121 | temp = self.new_cal_ref_bone(beta_)
122 | loss = temp.reshape((15, 1))
123 | loss = np.linalg.norm(loss - self.lb_target) ** 2 + \
124 | self.weight * np.linalg.norm(beta_)
125 | return loss
126 |
127 | def get_derivative(self, beta_, n):
128 |
129 | beta = beta_.copy()
130 | params1 = np.array(beta)
131 | params2 = np.array(beta)
132 | step = 0.01
133 | params1[n] += step
134 | params2[n] -= step
135 |
136 | res1 = self.new_get_loss(params1)
137 | res2 = self.new_get_loss(params2)
138 |
139 | d = (res1 - res2) / (2 * step)
140 |
141 | return d.ravel()
142 |
143 | def batch_new_get_loss(self, beta_):
144 | weight = 1e-5
145 | beta = beta_.copy()
146 | temp = self.new_cal_ref_bone(beta)
147 | loss = np.transpose(temp)
148 | loss = np.linalg.norm(loss - self.lb_target, axis=0) ** 2 + \
149 | weight * np.linalg.norm(beta, axis=-1)
150 | return loss
151 |
152 | def new_get_derivative(self, beta_):
153 | # params: beta_ 10*1
154 | # return: 1*10
155 | beta = beta_.copy().reshape((1, 10))
156 | temp_shape = np.zeros((20, beta.shape[1])) # 20*10
157 | step = 0.01
158 | for t2 in range(10):  # shape parameter index
159 | t3 = 10 + t2
160 | temp_shape[t2] = beta.copy()
161 | temp_shape[t3] = beta.copy()
162 | temp_shape[t2, t2] += step
163 | temp_shape[t3, t2] -= step
164 |
165 | res = self.batch_new_get_loss(temp_shape)
166 | d = res[0:10] - res[10:20] # 10*1
167 | d = d.reshape((1, 10)) / (2 * step)
168 | return d
169 |
170 | # LM algorithm
171 | def LM(self):
172 | u = 1e-2
173 | v = 1.5
174 | beta = self.beta.reshape(10, 1)
175 |
176 | out_n = 1
177 | num_beta = np.shape(beta)[0] # the number of beta
178 | # calculating the initial Jacobian matrix
179 | Jacobian = np.zeros([out_n, beta.shape[0]])
180 |
181 | last_update = 0
182 | last_loss = 0
183 | self.test_time = 0
184 | for i in range(self.num_Iter):
185 | loss = self.new_get_loss(beta)
186 | if loss < self.minimal_loss:
187 | self.minimal_loss = loss
188 | self.best_beta = beta
189 |
190 | if abs(loss - last_loss) < self.threshold_stop:
191 | self.time_total = time.time() - self.time_start
192 | return beta
193 |
194 | # for k in range(num_beta):
195 | # Jacobian[:, k] = self.get_derivative(beta, k)
196 | Jacobian = self.new_get_derivative(beta)
197 | jtj = np.matmul(Jacobian.T, Jacobian)
198 | jtj = jtj + u * np.eye(jtj.shape[0])
199 |
200 | update = last_loss - loss
201 | delta = (np.matmul(np.linalg.inv(jtj), Jacobian.T) * loss)
202 |
203 | beta -= delta
204 |
205 | if update > last_update and update > 0:
206 | u /= v
207 | else:
208 | u *= v
209 |
210 | last_update = update
211 | last_loss = loss
212 | self.residual_memory.append(loss)
213 |
214 | return beta
215 |
216 | def get_result(self):
217 | return self.residual_memory
218 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/handopen/Minimal-Hand-pytorch/c078433136b20c1e4fb3f353254edd409ea65f1b/utils/__init__.py
--------------------------------------------------------------------------------
/utils/__pycache__/AIK.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/handopen/Minimal-Hand-pytorch/c078433136b20c1e4fb3f353254edd409ea65f1b/utils/__pycache__/AIK.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/LM.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/handopen/Minimal-Hand-pytorch/c078433136b20c1e4fb3f353254edd409ea65f1b/utils/__pycache__/LM.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/handopen/Minimal-Hand-pytorch/c078433136b20c1e4fb3f353254edd409ea65f1b/utils/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/align.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/handopen/Minimal-Hand-pytorch/c078433136b20c1e4fb3f353254edd409ea65f1b/utils/__pycache__/align.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/bone.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/handopen/Minimal-Hand-pytorch/c078433136b20c1e4fb3f353254edd409ea65f1b/utils/__pycache__/bone.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/func.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/handopen/Minimal-Hand-pytorch/c078433136b20c1e4fb3f353254edd409ea65f1b/utils/__pycache__/func.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/handutils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/handopen/Minimal-Hand-pytorch/c078433136b20c1e4fb3f353254edd409ea65f1b/utils/__pycache__/handutils.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/heatmaputils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/handopen/Minimal-Hand-pytorch/c078433136b20c1e4fb3f353254edd409ea65f1b/utils/__pycache__/heatmaputils.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/imgutils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/handopen/Minimal-Hand-pytorch/c078433136b20c1e4fb3f353254edd409ea65f1b/utils/__pycache__/imgutils.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/misc.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/handopen/Minimal-Hand-pytorch/c078433136b20c1e4fb3f353254edd409ea65f1b/utils/__pycache__/misc.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/vis.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/handopen/Minimal-Hand-pytorch/c078433136b20c1e4fb3f353254edd409ea65f1b/utils/__pycache__/vis.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/align.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | def global_align(gtj0, prj0, key):
3 | gtj = gtj0.copy()
4 | prj = prj0.copy()
5 |
6 | if key in ["stb", "rhd"]:
7 | # gtj :B*21*3
8 | # prj :B*21*3
9 | root_idx = 9 # root
10 | ref_bone_link = [0, 9]  # wrist -> middle MCP
11 | pred_align = prj.copy()
12 | for i in range(prj.shape[0]):
13 |
14 | pred_ref_bone_len = np.linalg.norm(prj[i][ref_bone_link[0]] - prj[i][ref_bone_link[1]])
15 | gt_ref_bone_len = np.linalg.norm(gtj[i][ref_bone_link[0]] - gtj[i][ref_bone_link[1]])
16 | scale = gt_ref_bone_len / pred_ref_bone_len
17 |
18 | for j in range(21):
19 | pred_align[i][j] = gtj[i][root_idx] + scale * (prj[i][j] - prj[i][root_idx])
20 |
21 | return gtj, pred_align
22 |
23 | if key in ["do", "eo"]:
24 | # gtj :B*5*3
25 | # prj :B*5*3
26 |
27 | prj_ = prj.copy()[:, [4, 8, 12, 16, 20], :] # B*5*3
28 |
29 | gtj_valid = []
30 | prj_valid_align = []
31 |
32 | for i in range(prj_.shape[0]):
33 | # 5*3
34 | mask = ~(np.isnan(gtj[i][:, 0]))
35 | if mask.sum() < 2:
36 | continue
37 |
38 | prj_mask = prj_[i][mask] # m*3
39 | gtj_mask = gtj[i][mask] # m*3
40 |
41 | gtj_valid_center = np.mean(gtj_mask, 0)
42 | prj_valid_center = np.mean(prj_mask, 0)
43 |
44 | gtj_center_length = np.linalg.norm(gtj_mask - gtj_valid_center, axis=1).mean()
45 | prj_center_length = np.linalg.norm(prj_mask - prj_valid_center, axis=1).mean()
46 | scale = gtj_center_length / prj_center_length
47 |
48 | prj_valid_align_i = gtj_valid_center + scale * (prj_[i][mask] - prj_valid_center)
49 |
50 | gtj_valid.append(gtj_mask)
51 | prj_valid_align.append(prj_valid_align_i)
52 |
53 | return np.array(gtj_valid), np.array(prj_valid_align)
--------------------------------------------------------------------------------
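sketch: global_align removes scale and root offset (illustrative, not a file in this repository):
--------------------------------------------------------------------------------
For stb/rhd the alignment rescales the prediction by the ratio of reference-bone lengths and re-roots it at the ground-truth root, so a prediction that differs from the ground truth only by a uniform scale about the root plus a translation aligns exactly:

import numpy as np
from utils.align import global_align

gt = np.random.randn(4, 21, 3)
# shrink the prediction about the root (joint 9), then translate it
pred = 0.8 * (gt - gt[:, [9]]) + gt[:, [9]] + 0.05
gt_out, pred_aligned = global_align(gt, pred, key="stb")
print(np.abs(pred_aligned - gt).max())   # ~0: scale and offset are removed
--------------------------------------------------------------------------------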
/utils/bone.py:
--------------------------------------------------------------------------------
1 | import config as cfg
2 | import numpy as np
3 | import torch
4 |
5 |
6 | def caculate_length(j3d_, label=None):
7 | if isinstance(j3d_, torch.Tensor):
8 | j3d = j3d_.clone()
9 | j3d = j3d.detach().cpu()
10 | j3d = j3d.numpy()
11 | else:
12 | j3d = j3d_.copy()
13 |
14 | if len(j3d.shape) != 2:
15 | j3d = j3d.squeeze()
16 |
17 | bone = [
18 | j3d[i] - j3d[cfg.SNAP_PARENT[i]]
19 | for i in range(21)
20 | ]
21 | bone_len = np.linalg.norm(
22 | bone, ord=2, axis=-1, keepdims=True # 21*1
23 | )
24 |
25 | if label == "full":
26 | return bone_len
27 | elif label == "useful":
28 | return bone_len[cfg.USEFUL_BONE]
29 | else:
30 | raise ValueError("{} not in ['full'|'useful']".format(label))
31 |
--------------------------------------------------------------------------------
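sketch: bone lengths from a parent table (illustrative, not a file in this repository):
--------------------------------------------------------------------------------
caculate_length subtracts each joint's parent and takes L2 norms. The parent table lives in config.py; the SNAP_PARENT list below is the usual wrist-rooted 21-joint convention and is an assumption here, not a copy of the repo's config:

import numpy as np

# assumed wrist-rooted 21-joint parent table (the repo's actual table is cfg.SNAP_PARENT)
SNAP_PARENT = [0, 0, 1, 2, 3, 0, 5, 6, 7, 0, 9, 10, 11, 0, 13, 14, 15, 0, 17, 18, 19]

j3d = np.random.randn(21, 3)
bone = np.stack([j3d[i] - j3d[SNAP_PARENT[i]] for i in range(21)])
bone_len = np.linalg.norm(bone, ord=2, axis=-1, keepdims=True)  # (21, 1)
print(bone_len[0])  # the wrist is its own parent, so bone 0 always has length 0
--------------------------------------------------------------------------------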
/utils/eval/__pycache__/evalutils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/handopen/Minimal-Hand-pytorch/c078433136b20c1e4fb3f353254edd409ea65f1b/utils/eval/__pycache__/evalutils.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/eval/__pycache__/zimeval.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/handopen/Minimal-Hand-pytorch/c078433136b20c1e4fb3f353254edd409ea65f1b/utils/eval/__pycache__/zimeval.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/eval/evalutils.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 | import matplotlib.pyplot as plt
4 | from random import randint
5 |
6 | from utils.misc import *
7 | from utils.heatmaputils import get_heatmap_pred
8 |
9 |
10 | class AverageMeter(object):
11 | """Computes and stores the average and current value"""
12 |
13 | def __init__(self):
14 | self.reset()
15 |
16 | def reset(self):
17 | self.val = 0.
18 | self.avg = 0.
19 | self.sum = 0.
20 | self.count = 0.
21 |
22 | def update(self, val, n=1):
23 | self.val = val
24 | self.sum += val * n
25 | self.count += n
26 | self.avg = self.sum / self.count
27 |
28 |
29 | def calc_dists(preds, target, normalize, mask):
30 | preds = preds.float()
31 | target = target.float()
32 | dists = torch.zeros(preds.size(1), preds.size(0)) # (njoint, B)
33 | for b in range(preds.size(0)):
34 | for j in range(preds.size(1)):
35 | if mask[b][j] == 0:
36 | dists[j, b] = -1
37 | elif target[b, j, 0] < 1 or target[b, j, 1] < 1:
38 | dists[j, b] = -1
39 | else:
40 | dists[j, b] = torch.dist(preds[b, j, :], target[b, j, :]) / normalize[b]
41 |
42 | return dists
43 |
44 |
45 | def dist_acc(dist, thr=0.5):
46 | """ Return percentage below threshold while ignoring values with a -1 """
47 | dist = dist[dist != -1]
48 | if len(dist) > 0:
49 | return 1.0 * (dist < thr).sum().item() / len(dist)
50 | else:
51 | return -1
52 |
53 |
54 | def accuracy_heatmap(output, target, mask, thr=0.5):
55 | """ Calculate accuracy according to PCK, but uses ground truth heatmap rather than x,y locations
56 | First to be returned is average accuracy across 'idxs', Second is individual accuracies
57 | """
58 | preds = get_heatmap_pred(output).float() # (B, njoint, 2)
59 | gts = get_heatmap_pred(target).float()
60 | norm = torch.ones(preds.size(0)) * output.size(3) / 10.0  # (B,), 1/10 of the heat-map side (6.4 for 64 px)
61 | dists = calc_dists(preds, gts, norm, mask) # (njoint, B)
62 |
63 | acc = torch.zeros(mask.size(1))
64 | avg_acc = 0
65 | cnt = 0
66 |
67 | for i in range(mask.size(1)): # njoint
68 | acc[i] = dist_acc(dists[i], thr)
69 | if acc[i] >= 0:
70 | avg_acc += acc[i]
71 | cnt += 1
72 |
73 | if cnt != 0:
74 | avg_acc /= cnt
75 |
76 | return avg_acc, acc
77 |
--------------------------------------------------------------------------------
/utils/eval/zimeval.py:
--------------------------------------------------------------------------------
1 | # ColorHandPose3DNetwork - Network for estimating 3D Hand Pose from a single RGB Image
2 | # Copyright (C) 2017 Christian Zimmermann
3 | #
4 | # This program is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU General Public License as published by
6 | # the Free Software Foundation, either version 2 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # This program is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU General Public License
15 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
16 |
17 | import numpy as np
18 | import torch
19 |
20 |
21 | class EvalUtil:
22 | """ Util class for evaluation networks.
23 | """
24 |
25 | def __init__(self, num_kp=21):
26 | # init empty data storage
27 | self.data = list()
28 | self.num_kp = num_kp
29 | for _ in range(num_kp):
30 | self.data.append(list())
31 |
32 | def feed(self, keypoint_gt, keypoint_pred, keypoint_vis=None):
33 | """
34 | Used to feed data to the class.
35 | Stores the euclidean distance between gt and pred, when it is visible.
36 | """
37 | if isinstance(keypoint_gt, torch.Tensor):
38 | keypoint_gt = keypoint_gt.detach().cpu()
39 | keypoint_gt = keypoint_gt.numpy()
40 | if isinstance(keypoint_pred, torch.Tensor):
41 | keypoint_pred = keypoint_pred.detach().cpu()
42 | keypoint_pred = keypoint_pred.numpy()
43 | keypoint_gt = np.squeeze(keypoint_gt)
44 | keypoint_pred = np.squeeze(keypoint_pred)
45 |
46 | if keypoint_vis is None:
47 | keypoint_vis = np.ones_like(keypoint_gt[:, 0])
48 | keypoint_vis = np.squeeze(keypoint_vis).astype("bool")
49 |
50 | assert len(keypoint_gt.shape) == 2
51 | assert len(keypoint_pred.shape) == 2
52 | assert len(keypoint_vis.shape) == 1
53 |
54 | # calc euclidean distance
55 | diff = keypoint_gt - keypoint_pred
56 | euclidean_dist = np.sqrt(np.sum(np.square(diff), axis=1))
57 |
58 | num_kp = keypoint_gt.shape[0]
59 | for i in range(num_kp):
60 | if keypoint_vis[i]:
61 | self.data[i].append(euclidean_dist[i])
62 |
63 | def _get_pck(self, kp_id, threshold):
64 | """ Returns pck for one keypoint for the given threshold. """
65 | if len(self.data[kp_id]) == 0:
66 | return None
67 |
68 | data = np.array(self.data[kp_id])
69 | pck = np.mean((data <= threshold).astype("float"))
70 | return pck
71 |
72 | def get_pck_all(self, threshold):
73 | pckall = []
74 | for kp_id in range(self.num_kp):
75 | pck = self._get_pck(kp_id, threshold)
76 | pckall.append(pck)
77 | pckall = np.mean(np.array(pckall))
78 | return pckall
79 |
80 | def _get_epe(self, kp_id):
81 | """ Returns end point error for one keypoint. """
82 | if len(self.data[kp_id]) == 0:
83 | return None, None
84 |
85 | data = np.array(self.data[kp_id])
86 | epe_mean = np.mean(data)
87 | epe_median = np.median(data)
88 | return epe_mean, epe_median
89 |
90 | def get_measures(self, val_min, val_max, steps):
91 | """ Outputs the average mean and median error as well as the pck score. """
92 | thresholds = np.linspace(val_min, val_max, steps)
93 | thresholds = np.array(thresholds)
94 | norm_factor = np.trapz(np.ones_like(thresholds), thresholds)
95 |
96 | # init mean measures
97 | epe_mean_all = list()
98 | epe_median_all = list()
99 | auc_all = list()
100 | pck_curve_all = list()
101 |
102 | # Create one plot for each part
103 | for part_id in range(self.num_kp):
104 | # mean/median error
105 | mean, median = self._get_epe(part_id)
106 |
107 | if mean is None:
108 | # there was no valid measurement for this keypoint
109 | continue
110 |
111 | epe_mean_all.append(mean)
112 | epe_median_all.append(median)
113 |
114 | # pck/auc
115 | pck_curve = list()
116 | for t in thresholds:
117 | pck = self._get_pck(part_id, t)
118 | pck_curve.append(pck)
119 |
120 | pck_curve = np.array(pck_curve)
121 | pck_curve_all.append(pck_curve)
122 | auc = np.trapz(pck_curve, thresholds)
123 | auc /= norm_factor
124 | auc_all.append(auc)
125 | # keep the per-keypoint means before averaging over keypoints
126 | epe_mean_joint = epe_mean_all
127 | epe_mean_all = np.mean(np.array(epe_mean_all))
128 | epe_median_all = np.mean(np.array(epe_median_all))
129 | auc_all = np.mean(np.array(auc_all))
130 | pck_curve_all = np.mean(np.array(pck_curve_all), axis=0) # mean only over keypoints
131 |
132 | return (
133 | epe_mean_all,
134 | epe_mean_joint,
135 | epe_median_all,
136 | auc_all,
137 | pck_curve_all,
138 | thresholds,
139 | )
140 |
149 |
--------------------------------------------------------------------------------
/utils/func.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torchvision.transforms.functional import *  # noqa: F403 -- the star import also leaks torch/np in older torchvision
4 | def batch_denormalize(tensor, mean, std, inplace=False):
5 | """Normalize a tensor image with mean and standard deviation.
6 |
7 | .. note::
8 | This transform acts out of place by default, i.e., it does not mutate the input tensor.
9 |
10 | See :class:`~torchvision.transforms.Normalize` for more details.
11 |
12 | Args:
13 | tensor (Tensor): Tensor image of size (B, C, H, W) to be denormalized.
14 | mean (sequence): Sequence of means for each channel.
15 | std (sequence): Sequence of standard deviations for each channel.
16 | inplace(bool,optional): Bool to make this operation inplace.
17 |
18 | Returns:
19 | Tensor: Denormalized tensor image.
20 | """
21 | if not torch.is_tensor(tensor) or tensor.ndimension() != 4:
22 | raise TypeError('expected a 4-D (B, C, H, W) tensor')
23 |
24 | if not inplace:
25 | tensor = tensor.clone()
26 |
27 | dtype = tensor.dtype
28 | mean = torch.as_tensor(mean, dtype=dtype, device=tensor.device)
29 | std = torch.as_tensor(std, dtype=dtype, device=tensor.device)
30 | tensor.mul_(std[None, :, None, None]).add_(mean[None, :, None, None])
31 | return tensor
32 |
33 |
34 | def to_numpy(tensor):
35 | if torch.is_tensor(tensor):
36 | return tensor.detach().cpu().numpy()
37 | elif type(tensor).__module__ != 'numpy':
38 | raise ValueError("Cannot convert {} to numpy array"
39 | .format(type(tensor)))
40 | else:
41 | return tensor
42 |
43 |
44 | def bhwc_2_bchw(tensor):
45 | """
46 | :param tensor: torch tensor, B x H x W x C
47 | :return: torch tensor, B x C x H x W
48 | """
49 | if not torch.is_tensor(tensor) or tensor.ndimension() != 4:
50 | raise TypeError('expected a 4-D (B, H, W, C) tensor')
51 | return tensor.unsqueeze(1).transpose(1, -1).squeeze(-1)
52 |
53 |
54 | def bchw_2_bhwc(tensor):
55 | """
56 | :param tensor: torch tensor, B x C x H x W
57 | :return: torch tensor, B x H x W x C
58 | """
59 | if not torch.is_tensor(tensor) or tensor.ndimension() != 4:
60 | raise TypeError('expected a 4-D (B, C, H, W) tensor')
61 | return tensor.unsqueeze(-1).transpose(1, -1).squeeze(1)
62 |
63 | def initiate(label=None):
64 | if label == "zero":
65 | shape = torch.zeros(10).unsqueeze(0)
66 | pose = torch.zeros(48).unsqueeze(0)
67 | elif label == "uniform":
68 | shape = torch.from_numpy(np.random.normal(size=[1, 10])).float()
69 | pose = torch.from_numpy(np.random.normal(size=[1, 48])).float()
70 | elif label == "01":
71 | shape = torch.rand(1, 10)
72 | pose = torch.rand(1, 48)
73 | else:
74 | raise ValueError("{} not in ['zero'|'uniform'|'01']".format(label))
75 | return pose, shape
76 |
--------------------------------------------------------------------------------
/utils/heatmaputils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Lixin YANG, Jiasen Li. All Rights Reserved.
2 | import torch
3 | import numpy as np
4 |
5 |
6 | def gen_heatmap(img, pt, sigma):
7 | """generate heatmap based on pt coord.
8 |
9 | :param img: original heatmap, zeros
10 | :type img: np (H,W) float32
11 | :param pt: keypoint coord.
12 | :type pt: np (2,) int32
13 | :param sigma: gaussian sigma
14 | :type sigma: float
15 | :return
16 | - generated heatmap, np (H, W), each pixel value is a probability
17 | - flag 0 or 1: indicates whether this heatmap is valid (1) or out of bounds (0)
18 |
19 | """
20 |
21 | pt = pt.astype(np.int32)
22 | # Check that any part of the gaussian is in-bounds
23 | ul = [int(pt[0] - 3 * sigma), int(pt[1] - 3 * sigma)]
24 | br = [int(pt[0] + 3 * sigma + 1), int(pt[1] + 3 * sigma + 1)]
25 | if (
26 | ul[0] >= img.shape[1]
27 | or ul[1] >= img.shape[0]
28 | or br[0] < 0
29 | or br[1] < 0
30 | ):
31 | # If not, just return the image as is
32 | print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
33 | return img, 0
34 |
35 | # Generate gaussian
36 | size = 6 * sigma + 1
37 | x = np.arange(0, size, 1, float)
38 | y = x[:, np.newaxis]
39 | x0 = y0 = size // 2
40 | g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2))
41 | # Usable gaussian range
42 | g_x = max(0, -ul[0]), min(br[0], img.shape[1]) - ul[0]
43 | g_y = max(0, -ul[1]), min(br[1], img.shape[0]) - ul[1]
44 | # Image range
45 | img_x = max(0, ul[0]), min(br[0], img.shape[1])
46 | img_y = max(0, ul[1]), min(br[1], img.shape[0])
47 |
48 | img[img_y[0]:img_y[1], img_x[0]:img_x[1]] = g[g_y[0]:g_y[1], g_x[0]:g_x[1]]
49 | return img, 1
50 |
51 |
52 | def get_heatmap_pred(heatmaps):
53 | """ get predictions from heatmaps in torch Tensor
54 | return type: torch.LongTensor
55 | """
56 | assert heatmaps.dim() == 4, 'Score maps should be 4-dim (B, nJoints, H, W)'
57 | maxval, idx = torch.max(heatmaps.view(heatmaps.size(0), heatmaps.size(1), -1), 2)
58 |
59 | maxval = maxval.view(heatmaps.size(0), heatmaps.size(1), 1)
60 | idx = idx.view(heatmaps.size(0), heatmaps.size(1), 1)
61 |
62 | preds = idx.repeat(1, 1, 2).float() # (B, njoint, 2)
63 |
64 | preds[:, :, 0] = preds[:, :, 0] % heatmaps.size(3)               # x = idx mod W
65 | preds[:, :, 1] = torch.floor(preds[:, :, 1] / heatmaps.size(3))  # y = idx div W
66 |
67 | pred_mask = maxval.gt(0).repeat(1, 1, 2).float()
68 | preds *= pred_mask
69 | return preds
70 |
--------------------------------------------------------------------------------
/utils/imgutils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import cv2
3 | import numpy as np
4 | import random
5 | import torchvision
6 | import utils.func as func
7 | import config as cfg
8 |
9 |
10 | def get_color_params(brightness=0, contrast=0, saturation=0, hue=0):
11 | if brightness > 0:
12 | brightness_factor = random.uniform(
13 | max(0, 1 - brightness), 1 + brightness)
14 | else:
15 | brightness_factor = None
16 |
17 | if contrast > 0:
18 | contrast_factor = random.uniform(max(0, 1 - contrast), 1 + contrast)
19 | else:
20 | contrast_factor = None
21 |
22 | if saturation > 0:
23 | saturation_factor = random.uniform(
24 | max(0, 1 - saturation), 1 + saturation)
25 | else:
26 | saturation_factor = None
27 |
28 | if hue > 0:
29 | hue_factor = random.uniform(-hue, hue)
30 | else:
31 | hue_factor = None
32 | return brightness_factor, contrast_factor, saturation_factor, hue_factor
33 |
34 |
35 | def color_jitter(img, brightness=0, contrast=0, saturation=0, hue=0):
36 | brightness, contrast, saturation, hue = get_color_params(
37 | brightness=brightness,
38 | contrast=contrast,
39 | saturation=saturation,
40 | hue=hue)
41 |
42 | # Create img transform function sequence
43 | img_transforms = []
44 | if brightness is not None:
45 | img_transforms.append(lambda img: torchvision.transforms.functional.adjust_brightness(img, brightness))
46 | if saturation is not None:
47 | img_transforms.append(lambda img: torchvision.transforms.functional.adjust_saturation(img, saturation))
48 | if hue is not None:
49 | img_transforms.append(
50 | lambda img: torchvision.transforms.functional.adjust_hue(img, hue))
51 | if contrast is not None:
52 | img_transforms.append(lambda img: torchvision.transforms.functional.adjust_contrast(img, contrast))
53 | random.shuffle(img_transforms)
54 |
55 | jittered_img = img
56 | for transform in img_transforms:  # a loop variable named 'func' would shadow the utils.func import
57 | jittered_img = transform(jittered_img)
58 | return jittered_img
59 |
60 |
61 | def batch_with_dep(clrs, deps):
62 | clrs = func.to_numpy(clrs)
63 | if clrs.dtype != np.uint8:  # compare dtype by equality; 'is not np.uint8' is always True
64 | clrs = (clrs * 255).astype(np.uint8)
65 | assert len(deps.shape) == 4, "deps should have shape (B, 1, H, W)"
66 | deps = func.to_numpy(deps)
67 | deps = deps.swapaxes(1, 2).swapaxes(2, 3)
68 | deps = deps.repeat(3, axis=3)
69 | if deps.dtype != np.uint8:
70 | deps = (deps * 255).astype(np.uint8)
71 |
72 | batch_size = clrs.shape[0]
73 |
74 | alpha = 0.6
75 | beta = 0.9
76 | gamma = 0
77 |
78 | batch = []
79 | for i in range(16):  # fixed 4x4 tile grid
80 | if i >= batch_size:
81 | batch.append(np.zeros((64, 64, 3)).astype(np.uint8))
82 | continue
83 | clr = clrs[i]
84 | clr = cv2.resize(clr, (64, 64))
85 | dep = deps[i]
86 | dep_img = cv2.addWeighted(clr, alpha, dep, beta, gamma)
87 | batch.append(dep_img)
88 |
89 | resu = []
90 | for i in range(4):
91 | resu.append(np.concatenate(batch[i * 4: i * 4 + 4], axis=1))
92 | resu = np.concatenate(resu)
93 | return resu
94 |
95 |
96 | def batch_with_joint(clrs, uvds):
97 | clrs = func.to_numpy(clrs)
98 | if clrs.dtype != np.uint8:
99 | clrs = (clrs * 255).astype(np.uint8)
100 | uvds = func.to_numpy(uvds)
101 |
102 | batch_size = clrs.shape[0]
103 |
104 | batch = []
105 | for i in range(16):  # fixed 4x4 tile grid
106 | if i >= batch_size:
107 | batch.append(np.zeros((256, 256, 3)).astype(np.uint8))
108 | continue
109 | clr = clrs[i]
110 | uv = (np.array(uvds[i][:, :2]) * clr.shape[0]).astype(np.int32)  # pixel coords; uint8 would overflow at 256
111 | clr = draw_hand_skeloten(clr, uv, cfg.SNAP_BONES)
112 | batch.append(clr)
113 |
114 | resu = []
115 | for i in range(4):
116 | resu.append(np.concatenate(batch[i * 4: i * 4 + 4], axis=1))
117 | resu = np.concatenate(resu)
118 | return resu
119 |
120 |
121 | def draw_hand_skeloten(clr, uv, bone_links, colors=cfg.JOINT_COLORS):
122 | for i in range(len(bone_links)):
123 | bone = bone_links[i]
124 | for j in bone:
125 | cv2.circle(clr, tuple(uv[j]), 4, colors[i], -1)
126 | for j, nj in zip(bone[:-1], bone[1:]):
127 | cv2.line(clr, tuple(uv[j]), tuple(uv[nj]), colors[i], 2)
128 | return clr
129 |
130 |
131 | def batch_with_heatmap(
132 | inputs,
133 | heatmaps,
134 | num_rows=2,
135 | parts_to_show=None,
136 | n_in_batch=1,
137 | ):
138 | # inputs = func.to_numpy(inputs * 255) # 0~1 -> 0 ~255
139 | heatmaps = func.to_numpy(heatmaps)
140 | batch_img = []
141 | for n in range(min(inputs.shape[0], n_in_batch)):
142 | inp = inputs[n]
143 | batch_img.append(
144 | sample_with_heatmap(
145 | inp,
146 | heatmaps[n],
147 | num_rows=num_rows,
148 | parts_to_show=parts_to_show
149 | )
150 | )
151 | resu = np.concatenate(batch_img)
152 | return resu
153 |
154 |
155 | def sample_with_heatmap(img, heatmap, num_rows=2, parts_to_show=None):
156 | if parts_to_show is None:
157 | parts_to_show = np.arange(heatmap.shape[0]) # 21
158 |
159 | # Generate a single image to display input/output pair
160 | num_cols = int(np.ceil(float(len(parts_to_show)) / num_rows))
161 | size = img.shape[0] // num_rows
162 |
163 | full_img = np.zeros((img.shape[0], size * (num_cols + num_rows), 3), np.uint8)
164 | full_img[:img.shape[0], :img.shape[1]] = img
165 |
166 | inp_small = cv2.resize(img, (size, size))
167 |
168 | # Set up heatmap display for each part
169 | for i, part in enumerate(parts_to_show):
170 | part_idx = part
171 | out_resized = cv2.resize(heatmap[part_idx], (size, size))
172 | out_resized = out_resized.astype(float)
173 | out_img = inp_small.copy() * .4
174 | color_hm = color_heatmap(out_resized)
175 | out_img += color_hm * .6
176 |
177 | col_offset = (i % num_cols + num_rows) * size
178 | row_offset = (i // num_cols) * size
179 | full_img[row_offset:row_offset + size, col_offset:col_offset + size] = out_img
180 |
181 | return full_img
182 |
183 |
184 | def color_heatmap(x):
185 | color = np.zeros((x.shape[0], x.shape[1], 3))
186 | color[:, :, 0] = gauss(x, .5, .6, .2) + gauss(x, 1, .8, .3)
187 | color[:, :, 1] = gauss(x, 1, .5, .3)
188 | color[:, :, 2] = gauss(x, 1, .2, .3)
189 | color[color > 1] = 1
190 | color = (color * 255).astype(np.uint8)
191 | return color
192 |
193 |
194 | def gauss(x, a, b, c, d=0):
195 | return a * np.exp(-(x - b) ** 2 / (2 * c ** 2)) + d
196 |
--------------------------------------------------------------------------------
/utils/misc.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 |
4 | import numpy as np
5 | import scipy.io
6 | import torch
7 | from termcolor import colored, cprint
8 |
9 | import utils.func as func
10 | import copy
11 |
12 |
13 | def print_args(args):
14 | opts = vars(args)
15 | cprint("{:>30} Options {}".format("=" * 15, "=" * 15), 'yellow')
16 | for k, v in sorted(opts.items()):
17 | print("{:>30} : {}".format(k, v))
18 | cprint("{:>30} Options {}".format("=" * 15, "=" * 15), 'yellow')
19 |
20 |
21 | def param_count(net):
22 | return sum(p.numel() for p in net.parameters()) / 1e6
23 |
24 |
25 |
26 |
27 | def out_loss_auc(
28 | loss_all_, auc_all_, acc_hm_all_, outpath
29 | ):
30 | loss_all = copy.deepcopy(loss_all_)
31 | acc_hm_all = copy.deepcopy(acc_hm_all_)
32 | auc_all = copy.deepcopy(auc_all_)
33 |
34 | for k, l in loss_all.items():
35 | np.save(os.path.join(outpath, "{}.npy".format(k)), np.vstack((np.arange(1, len(l) + 1), np.array(l))).T)
36 |
37 | if len(acc_hm_all):
38 | for key, value in acc_hm_all.items():
39 | acc_hm_all[key] = np.array(value)
40 | np.save(os.path.join(outpath, "acc_hm_all.npy"), acc_hm_all)
41 |
42 |
43 | if len(auc_all):
44 | for key, value in auc_all.items():
45 | auc_all[key] = np.array(value)
46 | np.save(os.path.join(outpath, "auc_all.npy"), np.array(auc_all))
47 |
48 |
49 | def saveloss(d):
50 | for k, v in d.items():
51 | mat = np.array(v)
52 | np.save(os.path.join("losses", "{}.npy".format(k)), mat)
53 |
54 |
55 | def save_checkpoint(
56 | state,
57 | checkpoint='checkpoint',
58 | filename='checkpoint.pth',
59 | snapshot=None,
61 | is_best=None  # optional [auc, best_acc] pair of dicts
62 | ):
63 | # preds = to_numpy(preds)
64 | filepath = os.path.join(checkpoint, filename)
65 | fileprefix = filename.split('.')[0]
66 | # torch.save(state, filepath)
67 | torch.save(state['model'].state_dict(), filepath)
68 |
69 | if snapshot and state['epoch'] % snapshot == 0:
70 | shutil.copyfile(
71 | filepath,
72 | os.path.join(
73 | checkpoint,
74 | '{}_{}.pth'.format(fileprefix, state['epoch'])
75 | )
76 | )
77 |
78 | [auc, best_acc] = is_best if is_best is not None else ({}, {})  # guard the None default
79 |
80 | for key in auc.keys():
81 | if auc[key] > best_acc[key]:
82 | shutil.copyfile(
83 | filepath,
84 | os.path.join(
85 | checkpoint,
86 | '{}_{}best.pth'.format(fileprefix, key)
87 | )
88 | )
89 |
90 |
106 | def load_checkpoint(model, checkpoint):
107 | name = checkpoint
108 | checkpoint = torch.load(name)
109 | pretrain_dict = clean_state_dict(checkpoint['state_dict'])
110 | model_state = model.state_dict()
111 | state = {}
112 | for k, v in pretrain_dict.items():
113 | if k in model_state:
114 | state[k] = v
115 | else:
116 | print(k, ' is NOT in current model')
117 | model_state.update(state)
118 | model.load_state_dict(model_state)
119 | print(colored('loaded {}'.format(name), 'cyan'))
120 |
121 |
122 | def clean_state_dict(state_dict):
123 | """save a cleaned version of model without dict and DataParallel
124 |
125 | Arguments:
126 | state_dict {collections.OrderedDict} -- [description]
127 |
128 | Returns:
129 | clean_model {collections.OrderedDict} -- [description]
130 | """
131 |
132 | # create a new OrderedDict that does not contain the `module.` prefix
133 | from collections import OrderedDict
134 | clean_model = OrderedDict()
136 | if any(key.startswith('module') for key in state_dict):
137 | for k, v in state_dict.items():
138 | name = k[7:] # remove `module.`
139 | clean_model[name] = v
140 | else:
141 | return state_dict
142 |
143 | return clean_model
144 |
145 |
146 | def save_pred(preds, checkpoint='checkpoint', filename='preds_valid.mat'):
147 | preds = func.to_numpy(preds)
148 | filepath = os.path.join(checkpoint, filename)
149 | scipy.io.savemat(filepath, mdict={'preds': preds})
150 |
151 |
152 | def adjust_learning_rate(optimizer, epoch, lr, schedule, gamma):
153 | """Sets the learning rate to the initial LR decayed by schedule"""
154 | if epoch in schedule:
155 | lr *= gamma
156 | print("adjust learning rate to: %.3e" % lr)
157 | for param_group in optimizer.param_groups:
158 | param_group['lr'] = lr
159 | return lr
160 |
161 |
162 | def adjust_learning_rate_in_group(optimizer, group_id, epoch, lr, schedule, gamma):
163 | """Sets the learning rate to the initial LR decayed by schedule"""
164 | if epoch in schedule:
165 | lr *= gamma
166 | print("adjust learning rate of group %d to: %.3e" % (group_id, lr))
167 | optimizer.param_groups[group_id]['lr'] = lr
168 | return lr
169 |
170 |
171 | def resume_learning_rate(optimizer, epoch, lr, schedule, gamma):
172 | for decay_id in schedule:
173 | if epoch > decay_id:
174 | lr *= gamma
175 | print("adjust learning rate to: %.3e" % lr)
176 | for param_group in optimizer.param_groups:
177 | param_group['lr'] = lr
178 | return lr
179 |
180 |
181 | def resume_learning_rate_in_group(optimizer, group_id, epoch, lr, schedule, gamma):
182 | for decay_id in schedule:
183 | if epoch > decay_id:
184 | lr *= gamma
185 | print("adjust learning rate of group %d to: %.3e" % (group_id, lr))
186 | optimizer.param_groups[group_id]['lr'] = lr
187 | return lr
188 |
--------------------------------------------------------------------------------
/utils/smoother.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class LowPassFilter:
5 | def __init__(self):
6 | self.prev_raw_value = None
7 | self.prev_filtered_value = None
8 |
9 | def process(self, value, alpha):
10 | if self.prev_raw_value is None:
11 | s = value
12 | else:
13 | s = alpha * value + (1.0 - alpha) * self.prev_filtered_value
14 | self.prev_raw_value = value
15 | self.prev_filtered_value = s
16 | return s
17 |
18 |
19 | class OneEuroFilter:
20 | def __init__(self, mincutoff=1.0, beta=0.0, dcutoff=1.0, freq=30):
21 | self.freq = freq
22 | self.mincutoff = mincutoff
23 | self.beta = beta
24 | self.dcutoff = dcutoff
25 | self.x_filter = LowPassFilter()
26 | self.dx_filter = LowPassFilter()
27 |
28 | def compute_alpha(self, cutoff):
29 | te = 1.0 / self.freq
30 | tau = 1.0 / (2 * np.pi * cutoff)
31 | return 1.0 / (1.0 + tau / te)
32 |
33 | def process(self, x):
34 | prev_x = self.x_filter.prev_raw_value
35 | dx = 0.0 if prev_x is None else (x - prev_x) * self.freq
36 | edx = self.dx_filter.process(dx, self.compute_alpha(self.dcutoff))
37 | cutoff = self.mincutoff + self.beta * np.abs(edx)
38 | return self.x_filter.process(x, self.compute_alpha(cutoff))
39 |
40 |
41 | if __name__ == '__main__':
42 | one_euro = OneEuroFilter(mincutoff=4.0, beta=0.0)
43 | noise = 0.01 * np.random.rand(1000)
44 | x = np.linspace(0, 1, 1000)
45 | X = x + noise
46 | import matplotlib.pyplot as plt
47 |
48 | plt.plot(x)
49 | plt.plot(X)
50 | y = np.zeros((1000,))
51 | for i in range(1000):
52 | y[i] = one_euro.process(X[i])  # filter the noisy signal, not the clean one
53 | plt.plot(y)
54 | plt.draw()
55 | plt.show()
56 |
--------------------------------------------------------------------------------
/utils/vis.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 -- registers the '3d' projection on older matplotlib
3 |
4 | def plot3d(joints_, ax, title=None):
5 | joints = joints_.copy()
6 | ax.plot(joints[:, 0], joints[:, 1], joints[:, 2], 'yo', label='keypoint')
7 |
8 | ax.plot(joints[:5, 0], joints[:5, 1],
9 | joints[:5, 2],
10 | 'r',
11 | label='thumb')
12 |
13 | ax.plot(joints[[0, 5, 6, 7, 8, ], 0], joints[[0, 5, 6, 7, 8, ], 1],
14 | joints[[0, 5, 6, 7, 8, ], 2],
15 | 'b',
16 | label='index')
17 | ax.plot(joints[[0, 9, 10, 11, 12, ], 0], joints[[0, 9, 10, 11, 12], 1],
18 | joints[[0, 9, 10, 11, 12], 2],
19 | 'b',
20 | label='middle')
21 | ax.plot(joints[[0, 13, 14, 15, 16], 0], joints[[0, 13, 14, 15, 16], 1],
22 | joints[[0, 13, 14, 15, 16], 2],
23 | 'b',
24 | label='ring')
25 | ax.plot(joints[[0, 17, 18, 19, 20], 0], joints[[0, 17, 18, 19, 20], 1],
26 | joints[[0, 17, 18, 19, 20], 2],
27 | 'b',
28 | label='pinky')
29 | # snap convention
30 | ax.plot(joints[4][0], joints[4][1], joints[4][2], 'rD', label='thumb')
31 | ax.plot(joints[8][0], joints[8][1], joints[8][2], 'ro', label='index')
32 | ax.plot(joints[12][0], joints[12][1], joints[12][2], 'ro', label='middle')
33 | ax.plot(joints[16][0], joints[16][1], joints[16][2], 'ro', label='ring')
34 | ax.plot(joints[20][0], joints[20][1], joints[20][2], 'ro', label='pinky')
35 | # plt.plot(joints [1:, 0], joints [1:, 1], joints [1:, 2], 'o')
36 | ax.set_xlabel('x')
37 | ax.set_ylabel('y')
38 | ax.set_zlabel('z')
39 | ax.set_xlim(-1.0, 1.0)
40 | ax.set_ylim(-1.0, 1.0)
41 | ax.set_zlim(-1.0, 1.0)
42 | # plt.legend()
43 | # ax.view_init(330, 110)
44 | ax.view_init(-90, -90)
45 | return ax
46 |
47 |
48 | def multi_plot3d(jointss_, title=None):
49 | jointss = jointss_.copy()
50 | fig = plt.figure(figsize=[50, 50])
51 |
52 | ax = fig.add_subplot(111, projection='3d')
53 |
54 | colors = ['b', 'r', "g"]
55 |
56 | for i in range(len(jointss)):
57 | joints = jointss[i]
58 |
59 | plt.plot(joints[:, 0], joints[:, 1], joints[:, 2], 'yo')
60 |
61 | plt.plot(joints[:5, 0], joints[:5, 1],
62 | joints[:5, 2],
63 | colors[i],
64 | )
65 |
66 | plt.plot(joints[[0, 5, 6, 7, 8, ], 0], joints[[0, 5, 6, 7, 8, ], 1],
67 | joints[[0, 5, 6, 7, 8, ], 2],
68 | colors[i],
69 | )
70 | plt.plot(joints[[0, 9, 10, 11, 12, ], 0], joints[[0, 9, 10, 11, 12], 1],
71 | joints[[0, 9, 10, 11, 12], 2],
72 | colors[i],
73 | )
74 | plt.plot(joints[[0, 13, 14, 15, 16], 0], joints[[0, 13, 14, 15, 16], 1],
75 | joints[[0, 13, 14, 15, 16], 2],
76 | colors[i],
77 | )
78 | plt.plot(joints[[0, 17, 18, 19, 20], 0], joints[[0, 17, 18, 19, 20], 1],
79 | joints[[0, 17, 18, 19, 20], 2],
80 | colors[i],
81 | )
82 |
106 | # snap convention
107 | plt.plot(joints[4][0], joints[4][1], joints[4][2], 'rD')
108 | plt.plot(joints[8][0], joints[8][1], joints[8][2], 'ro', )
109 | plt.plot(joints[12][0], joints[12][1], joints[12][2], 'ro', )
110 | plt.plot(joints[16][0], joints[16][1], joints[16][2], 'ro', )
111 | plt.plot(joints[20][0], joints[20][1], joints[20][2], 'ro', )
112 | # plt.plot(joints [1:, 0], joints [1:, 1], joints [1:, 2], 'o')
113 |
114 | # title may be a list of per-color labels; it is rendered via ax.set_title below
115 | ax.set_xlabel('x')
116 | ax.set_ylabel('y')
117 | ax.set_zlabel('z')
118 | # plt.legend()  # no labeled artists in this figure
119 | # ax.view_init(330, 110)
120 | ax.view_init(-90, -90)
121 |
122 | if title:
123 | title_ = ""
124 | for i in range(len(title)):
125 | title_ += "{}: {} ".format(colors[i], title[i])
126 |
127 | ax.set_title(title_, fontsize=12, color='black')
128 | else:
129 | ax.set_title("None", fontsize=12, color='black')
130 | plt.show()
131 |
--------------------------------------------------------------------------------